diff --git a/server/core_account.go b/server/core_account.go index 8aa21fbc597b61fcb0dad5f779319572e4c2a5b8..9d52894ba9bf1958d0a3731b9d3a97e13c5276f0 100644 --- a/server/core_account.go +++ b/server/core_account.go @@ -28,6 +28,7 @@ import ( "github.com/heroiclabs/nakama/v3/console" "github.com/jackc/pgconn" "github.com/jackc/pgtype" + pgx "github.com/jackc/pgx/v4" "go.uber.org/zap" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -243,7 +244,7 @@ WHERE u.id IN (` + strings.Join(statements, ",") + `)` } func UpdateAccounts(ctx context.Context, logger *zap.Logger, db *sql.DB, updates []*accountUpdate) error { - if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { + if err := ExecuteInTxPgx(ctx, db, func(tx pgx.Tx) error { updateErr := updateAccounts(ctx, logger, tx, updates) if updateErr != nil { return updateErr @@ -260,7 +261,7 @@ func UpdateAccounts(ctx context.Context, logger *zap.Logger, db *sql.DB, updates return nil } -func updateAccounts(ctx context.Context, logger *zap.Logger, tx *sql.Tx, updates []*accountUpdate) error { +func updateAccounts(ctx context.Context, logger *zap.Logger, tx pgx.Tx, updates []*accountUpdate) error { for _, update := range updates { updateStatements := make([]string, 0, 7) distinctStatements := make([]string, 0, 7) @@ -346,7 +347,7 @@ func updateAccounts(ctx context.Context, logger *zap.Logger, tx *sql.Tx, updates query := "UPDATE users SET update_time = now(), " + strings.Join(updateStatements, ", ") + " WHERE id = $1 AND (" + strings.Join(distinctStatements, " OR ") + ")" - if _, err := tx.ExecContext(ctx, query, params...); err != nil { + if _, err := tx.Exec(ctx, query, params...); err != nil { var pgErr *pgconn.PgError if errors.As(err, &pgErr) && pgErr.Code == dbErrorUniqueViolation && strings.Contains(pgErr.Message, "users_username_key") { return errors.New("Username is already in use.") diff --git a/server/core_multi.go b/server/core_multi.go index 8b29800de532f094f3f9778576044f8028b53f4d..568b813a5b8a45cecd1f43f5f56f8519230f1493 100644 --- a/server/core_multi.go +++ b/server/core_multi.go @@ -20,6 +20,7 @@ import ( "github.com/heroiclabs/nakama-common/api" "github.com/heroiclabs/nakama-common/runtime" + pgx "github.com/jackc/pgx/v4" "go.uber.org/zap" ) @@ -31,7 +32,7 @@ func MultiUpdate(ctx context.Context, logger *zap.Logger, db *sql.DB, metrics Me var storageWriteAcks []*api.StorageObjectAck var walletUpdateResults []*runtime.WalletUpdateResult - if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { + if err := ExecuteInTxPgx(ctx, db, func(tx pgx.Tx) error { storageWriteAcks = nil walletUpdateResults = nil diff --git a/server/core_storage.go b/server/core_storage.go index 0b8fd7867441362b78894449924e374ec56b130f..a25b213eb5be50cc8877ce9751636891ba02eb28 100644 --- a/server/core_storage.go +++ b/server/core_storage.go @@ -21,6 +21,7 @@ import ( "database/sql" "encoding/base64" "encoding/gob" + "encoding/hex" "errors" "fmt" "sort" @@ -30,6 +31,7 @@ import ( "github.com/heroiclabs/nakama-common/runtime" "github.com/jackc/pgconn" "github.com/jackc/pgtype" + pgx "github.com/jackc/pgx/v4" "go.uber.org/zap" "google.golang.org/grpc/codes" "google.golang.org/protobuf/types/known/timestamppb" @@ -49,6 +51,28 @@ type StorageOpWrite struct { Object *api.WriteStorageObject } +// Desired `read` persmission after this Op completes +func (op *StorageOpWrite) permissionRead() int32 { + if op.Object.PermissionRead != nil { + return op.Object.PermissionRead.Value + } + return 1 +} + +// Desired `write` persmission after this Op completes +func (op *StorageOpWrite) permissionWrite() int32 { + if op.Object.PermissionWrite != nil { + return op.Object.PermissionWrite.Value + } + return 1 +} + +// Expected object version after this Op completes +func (op *StorageOpWrite) expectedVersion() string { + hash := md5.Sum([]byte(op.Object.Value)) + return hex.EncodeToString(hash[:]) +} + func (s StorageOpWrites) Len() int { return len(s) } @@ -466,11 +490,17 @@ WHERE func StorageWriteObjects(ctx context.Context, logger *zap.Logger, db *sql.DB, metrics Metrics, storageIndex StorageIndex, authoritativeWrite bool, ops StorageOpWrites) (*api.StorageObjectAcks, codes.Code, error) { var acks []*api.StorageObjectAck - if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { + if err := ExecuteInTxPgx(ctx, db, func(tx pgx.Tx) error { // If the transaction is retried ensure we wipe any acks that may have been prepared by previous attempts. var writeErr error acks, writeErr = storageWriteObjects(ctx, logger, metrics, tx, authoritativeWrite, ops) if writeErr != nil { + if writeErr == runtime.ErrStorageRejectedVersion || writeErr == runtime.ErrStorageRejectedPermission { + logger.Debug("Error writing storage objects.", zap.Error(writeErr)) + return StatusError(codes.InvalidArgument, "Storage write rejected.", writeErr) + } else { + logger.Error("Error writing storage objects.", zap.Error(writeErr)) + } return writeErr } return nil @@ -487,7 +517,7 @@ func StorageWriteObjects(ctx context.Context, logger *zap.Logger, db *sql.DB, me return &api.StorageObjectAcks{Acks: acks}, codes.OK, nil } -func storageWriteObjects(ctx context.Context, logger *zap.Logger, metrics Metrics, tx *sql.Tx, authoritativeWrite bool, ops StorageOpWrites) ([]*api.StorageObjectAck, error) { +func storageWriteObjects(ctx context.Context, logger *zap.Logger, metrics Metrics, tx pgx.Tx, authoritativeWrite bool, ops StorageOpWrites) ([]*api.StorageObjectAck, error) { // Ensure writes are processed in a consistent order to avoid deadlocks from concurrent operations. // Sorting done on a copy to ensure we don't modify the input, which may be re-used on transaction retries. sortedOps := make(StorageOpWrites, 0, len(ops)) @@ -497,143 +527,151 @@ func storageWriteObjects(ctx context.Context, logger *zap.Logger, metrics Metric indexedOps[op] = i } sort.Sort(sortedOps) - // Run operations in the sorted order. acks := make([]*api.StorageObjectAck, ops.Len()) - for _, op := range sortedOps { - ack, writeErr := storageWriteObject(ctx, logger, metrics, tx, authoritativeWrite, op.OwnerID, op.Object) - if writeErr != nil { - if writeErr == runtime.ErrStorageRejectedVersion || writeErr == runtime.ErrStorageRejectedPermission { - return nil, StatusError(codes.InvalidArgument, "Storage write rejected.", writeErr) - } - logger.Debug("Error writing storage objects.", zap.Error(writeErr)) - return nil, writeErr - } - // Acks are returned in the original order. - acks[indexedOps[op]] = ack + batch := &pgx.Batch{} + for _, op := range sortedOps { + storagePrepBatch(batch, authoritativeWrite, op) } - return acks, nil -} -func storageWriteObject(ctx context.Context, logger *zap.Logger, metrics Metrics, tx *sql.Tx, authoritativeWrite bool, ownerID string, object *api.WriteStorageObject) (*api.StorageObjectAck, error) { - var dbVersion sql.NullString - var dbPermissionWrite sql.NullInt64 - var dbPermissionRead sql.NullInt64 - err := tx.QueryRowContext(ctx, "SELECT version, read, write FROM storage WHERE collection = $1 AND key = $2 AND user_id = $3", object.Collection, object.Key, ownerID).Scan(&dbVersion, &dbPermissionRead, &dbPermissionWrite) - if err != nil { - if err == sql.ErrNoRows { - if object.Version != "" && object.Version != "*" { - // Conditional write with a specific version but the object did not exist at all. - metrics.StorageWriteRejectCount(map[string]string{"collection": object.Collection}, 1) + br := tx.SendBatch(ctx, batch) + defer br.Close() // TODO: need to "drain" batch, otherwise it logs all unprocessed queries + for _, op := range sortedOps { + object := op.Object + var resultRead int32 + var resultWrite int32 + var resultVersion string + err := br.QueryRow().Scan(&resultRead, &resultWrite, &resultVersion) + var pgErr *pgconn.PgError + if err != nil && errors.As(err, &pgErr) { + if pgErr.Code == dbErrorUniqueViolation { + metrics.StorageWriteRejectCount(map[string]string{"collection": object.Collection, "reason": "version"}, 1) return nil, runtime.ErrStorageRejectedVersion } - } else { - logger.Debug("Error in write storage object pre-flight.", zap.Any("object", object), zap.Error(err)) + return nil, err + } else if err == pgx.ErrNoRows { + // Not every case from storagePrepWriteObject can return NoRows, but those + // which do it is always ErrStorageRejectedVersion + metrics.StorageWriteRejectCount(map[string]string{"collection": object.Collection, "reason": "version"}, 1) + return nil, runtime.ErrStorageRejectedVersion + } else if err != nil { return nil, err } - } - - if dbVersion.Valid && (object.Version == "*" || (object.Version != "" && object.Version != dbVersion.String)) { - // An object existed, and it's a conditional write that either: - // - Expects no object. - // - Or expects a given version, but it does not match. - metrics.StorageWriteRejectCount(map[string]string{"collection": object.Collection}, 1) - return nil, runtime.ErrStorageRejectedVersion - } - - if dbPermissionWrite.Valid && dbPermissionWrite.Int64 == 0 && !authoritativeWrite { - // Non-authoritative write to an existing storage object with permission 0. - return nil, runtime.ErrStorageRejectedPermission - } - - newVersion := fmt.Sprintf("%x", md5.Sum([]byte(object.Value))) - newPermissionRead := int32(1) - if object.PermissionRead != nil { - newPermissionRead = object.PermissionRead.Value - } - newPermissionWrite := int32(1) - if object.PermissionWrite != nil { - newPermissionWrite = object.PermissionWrite.Value - } - if dbVersion.Valid && dbVersion.String == newVersion && dbPermissionRead.Int64 == int64(newPermissionRead) && dbPermissionWrite.Int64 == int64(newPermissionWrite) { - // Stored object existed, and exactly matches the new object's version and read/write permissions. + if !(op.permissionRead() == resultRead && + op.permissionWrite() == resultWrite && + op.expectedVersion() == resultVersion) { + // Write failed, it can happen for 3 reasons: + // - constraint violation on insert (handles elsewhere) + // - permission: non authoritative write & original row write != 1 + // - version mismatch + if !authoritativeWrite && resultWrite != 1 { + metrics.StorageWriteRejectCount(map[string]string{"collection": object.Collection, "reason": "permission"}, 1) + return nil, runtime.ErrStorageRejectedPermission + } else { + // version check failed + metrics.StorageWriteRejectCount(map[string]string{"collection": object.Collection, "reason": "version"}, 1) + return nil, runtime.ErrStorageRejectedVersion + } + } ack := &api.StorageObjectAck{ Collection: object.Collection, Key: object.Key, - Version: newVersion, + Version: resultVersion, + UserId: op.OwnerID, } - if ownerID != uuid.Nil.String() { - ack.UserId = ownerID - } - return ack, nil + acks[indexedOps[op]] = ack } + return acks, nil +} + +func storagePrepBatch(batch *pgx.Batch, authoritativeWrite bool, op *StorageOpWrite) { + object := op.Object + ownerID := op.OwnerID + + newVersion := op.expectedVersion() + newPermissionRead := op.permissionRead() + newPermissionWrite := op.permissionWrite() + params := []interface{}{object.Collection, object.Key, ownerID, object.Value, newVersion, newPermissionRead, newPermissionWrite} var query string + + writeCheck := "" + // Respect permissions in non-authoritative writes. + if !authoritativeWrite { + writeCheck = " AND storage.write = 1" + } + switch { case object.Version != "" && object.Version != "*": // OCC if-match. - query = "UPDATE storage SET value = $4, version = $5, read = $6, write = $7, update_time = now() WHERE collection = $1 AND key = $2 AND user_id = $3::UUID AND version = $8" + + // Query pattern + // (UPDATE t ... RETURNING) UNION ALL (SELECT FROM t) LIMIT 1 + // allows us to fetch row state after update even if update itself fails WHERE + // condition. + // That is returned values are final state of the row regardless of UPDATE success + query = ` + WITH upd AS ( + UPDATE storage SET value = $4, version = $5, read = $6, write = $7, update_time = now() + WHERE collection = $1 AND key = $2 AND user_id = $3::UUID AND version = $8 + ` + writeCheck + ` + AND NOT (storage.version = $5 AND storage.read = $6 AND storage.write = $7) -- micro optimization: don't update row unnecessary + RETURNING read, write, version + ) + (SELECT read, write, version from upd) + UNION ALL + (SELECT read, write, version FROM storage WHERE collection = $1 and key = $2 and user_id = $3) + LIMIT 1` + params = append(params, object.Version) - // Respect permissions in non-authoritative writes. - if !authoritativeWrite { - query += " AND write = 1" - } - case dbVersion.Valid && object.Version == "": - // An existing storage object was present, but no OCC of any kind is specified. - query = "UPDATE storage SET value = $4, version = $5, read = $6, write = $7, update_time = now() WHERE collection = $1 AND key = $2 AND user_id = $3::UUID" - // Respect permissions in non-authoritative writes. - if !authoritativeWrite { - query += " AND write = 1" - } - case !dbVersion.Valid && object.Version == "": - // An existing storage object was not present, and no OCC of any kind is specified. - // Separate to the case above to handle concurrent non-OCC object creations, where all but the first must become updates. - query = "INSERT INTO storage (collection, key, user_id, value, version, read, write, create_time, update_time) VALUES ($1, $2, $3::UUID, $4, $5, $6, $7, now(), now()) ON CONFLICT (collection, read, key, user_id) DO UPDATE SET value = $4, version = $5, read = $6, write = $7, update_time = now()" - // Respect permissions in non-authoritative writes, where this operation also loses the race to insert the object. - if !authoritativeWrite { - query += " WHERE storage.write = 1" - } - case dbVersion.Valid && object.Version != "*": - // An existing storage object was present, but no OCC if-not-exists required. - query = "UPDATE storage SET value = $4, version = $5, read = $6, write = $7, update_time = now() WHERE collection = $1 AND key = $2 AND user_id = $3::UUID AND version = $8" - params = append(params, dbVersion.String) - // Respect permissions in non-authoritative writes. - if !authoritativeWrite { - query += " AND write = 1" - } - default: + + // Outcomes: + // - No rows: if no rows returned, then object was not found in DB and can't be updated + // - We have row returned, but now we need to know if update happened, that is its WHERE matched + // * write != 1 means no permission to write + // * dbVersion != original version means OCC failure + + case object.Version == "": + // non-OCC write, "last write wins" kind of write + + // Similar pattern as in case above, but supports case when row + // didn't exist in the database. Another difference is that there is no version + // check for existing row. + query = ` + WITH upd AS ( + INSERT INTO storage (collection, key, user_id, value, version, read, write, create_time, update_time) + VALUES ($1, $2, $3::UUID, $4, $5, $6, $7, now(), now()) + ON CONFLICT (collection, key, user_id) DO + UPDATE SET value = $4, version = $5, read = $6, write = $7, update_time = now() + WHERE TRUE` + writeCheck + ` + AND NOT (storage.version = $5 AND storage.read = $6 AND storage.write = $7) -- micro optimization: don't update row unnecessary + RETURNING read, write, version + ) + (SELECT read, write, version from upd) + UNION ALL + (SELECT read, write, version FROM storage WHERE collection = $1 and key = $2 and user_id = $3) + LIMIT 1` + + // Outcomes: + // - Row is always returned, need to know if update happened, that is its WHERE matches + // - write != 1 means no permission to write + + case object.Version == "*": // OCC if-not-exists, and all other non-OCC cases. - query = "INSERT INTO storage (collection, key, user_id, value, version, read, write, create_time, update_time) VALUES ($1, $2, $3::UUID, $4, $5, $6, $7, now(), now())" // Existing permission checks are not applicable for new storage objects. - } - - res, err := tx.ExecContext(ctx, query, params...) - if err != nil { - logger.Debug("Could not write storage object, exec error.", zap.Any("object", object), zap.String("query", query), zap.Error(err)) - var pgErr *pgconn.PgError - if errors.As(err, &pgErr) && pgErr.Code == dbErrorUniqueViolation { - metrics.StorageWriteRejectCount(map[string]string{"collection": object.Collection}, 1) - return nil, runtime.ErrStorageRejectedVersion - } - return nil, err - } - if rowsAffected, err := res.RowsAffected(); rowsAffected != 1 { - logger.Debug("Could not write storage object, rowsAffected error.", zap.Any("object", object), zap.String("query", query), zap.Error(err)) - metrics.StorageWriteRejectCount(map[string]string{"collection": object.Collection}, 1) - return nil, runtime.ErrStorageRejectedVersion - } + query = ` + INSERT INTO storage (collection, key, user_id, value, version, read, write, create_time, update_time) + VALUES ($1, $2, $3::UUID, $4, $5, $6, $7, now(), now()) + RETURNING read, write, version` - ack := &api.StorageObjectAck{ - Collection: object.Collection, - Key: object.Key, - Version: newVersion, - UserId: ownerID, + // Outcomes: + // - NoRows - insert failed due to constraint violation (concurrent insert) } - return ack, nil + batch.Queue(query, params...) } func StorageDeleteObjects(ctx context.Context, logger *zap.Logger, db *sql.DB, storageIndex StorageIndex, authoritativeDelete bool, ops StorageOpDeletes) (codes.Code, error) { diff --git a/server/core_storage_test.go b/server/core_storage_test.go index ada4a6dda67d706c7aa6078f6dd5ce630d3f3132..7ea3d5b2f786dfdd2c190dd682024a7c264bee7a 100644 --- a/server/core_storage_test.go +++ b/server/core_storage_test.go @@ -17,11 +17,14 @@ package server import ( "context" "crypto/md5" + "database/sql" + "encoding/hex" "fmt" "testing" "github.com/gofrs/uuid/v5" "github.com/heroiclabs/nakama-common/api" + "github.com/heroiclabs/nakama-common/runtime" "github.com/stretchr/testify/assert" "google.golang.org/grpc/codes" "google.golang.org/protobuf/types/known/wrapperspb" @@ -2078,3 +2081,146 @@ func TestStorageListNoRepeats(t *testing.T) { assert.Len(t, values.Objects, 7, "values length was not 7") assert.Equal(t, "", values.Cursor, "cursor was not nil") } + +// DB State and expected outcome when performing write op +type writeTestDBState struct { + write int + v string + expectedCode codes.Code + expectedError error + descr string +} + +// Test no OCC, last write wins ("") +func TestNonOCCNonAuthoritative(t *testing.T) { + v := "{}" + + statesOutcomes := []writeTestDBState{ + {0, "", codes.OK, nil, "did not exists"}, + {1, v, codes.OK, nil, "existed and permission allows write, version match"}, + {1, `{"a":1}`, codes.OK, nil, "existed and permission allows write, version does not match"}, + {0, v, codes.InvalidArgument, runtime.ErrStorageRejectedPermission, "existed and permission reject, version match"}, + {0, `{"a":1}`, codes.InvalidArgument, runtime.ErrStorageRejectedPermission, "existed and permission reject, version does not match"}, + } + testWrite(t, `{"newV": true}`, "", 1, false, statesOutcomes) +} + +// Test no OCC, last write wins ("") +func TestNonOCCAuthoritative(t *testing.T) { + v := "{}" + + statesOutcomes := []writeTestDBState{ + {0, "", codes.OK, nil, "did not exists"}, + {1, v, codes.OK, nil, "existed and permission allows write, version match"}, + {1, `{"a":1}`, codes.OK, nil, "existed and permission allows write, version does not match"}, + {0, v, codes.OK, nil, "existed and permission reject, version match"}, + {0, `{"a":1}`, codes.OK, nil, "existed and permission reject, version does not match"}, + } + testWrite(t, `{"newV": true}`, "", 1, true, statesOutcomes) +} + +// Test when OCC requires non-existing object ("*") +func TestOCCNotExistsAuthoritative(t *testing.T) { + v := "{}" + + statesOutcomes := []writeTestDBState{ + {0, "", codes.OK, nil, "did not exists"}, + {1, v, codes.InvalidArgument, runtime.ErrStorageRejectedVersion, "existed and permission allows write, version match"}, + {1, `{"a":1}`, codes.InvalidArgument, runtime.ErrStorageRejectedVersion, "existed and permission allows write, version does not match"}, + {0, v, codes.InvalidArgument, runtime.ErrStorageRejectedVersion, "existed and permission reject, version match"}, + {0, `{"a":1}`, codes.InvalidArgument, runtime.ErrStorageRejectedVersion, "existed and permission reject, version does not match"}, + } + testWrite(t, `{"newV": true}`, "*", 1, true, statesOutcomes) +} + +// Test when OCC requires non-existing object ("*") +func TestOCCNotExistsNonAuthoritative(t *testing.T) { + v := "{}" + + statesOutcomes := []writeTestDBState{ + {0, "", codes.OK, nil, "did not exists"}, + {1, v, codes.InvalidArgument, runtime.ErrStorageRejectedVersion, "existed and permission allows write, version match"}, + {1, `{"a":1}`, codes.InvalidArgument, runtime.ErrStorageRejectedVersion, "existed and permission allows write, version does not match"}, + {0, v, codes.InvalidArgument, runtime.ErrStorageRejectedVersion, "existed and permission reject, version match"}, + {0, `{"a":1}`, codes.InvalidArgument, runtime.ErrStorageRejectedVersion, "existed and permission reject, version does not match"}, + } + testWrite(t, `{"newV": true}`, "*", 1, false, statesOutcomes) +} + +// Test when OCC requires existing object with known version ('#hash#') +func TestOCCWriteNonAuthoritative(t *testing.T) { + v := "{}" + + statesOutcomes := []writeTestDBState{ + {0, "", codes.InvalidArgument, runtime.ErrStorageRejectedVersion, "did not exists"}, + {1, v, codes.OK, nil, "existed and permission allows write, version match"}, + {1, `{"a":1}`, codes.InvalidArgument, runtime.ErrStorageRejectedVersion, "existed and permission allows write, version does not match"}, + {0, v, codes.InvalidArgument, runtime.ErrStorageRejectedPermission, "existed and permission reject, version match"}, + {0, `{"a":1}`, codes.InvalidArgument, runtime.ErrStorageRejectedPermission, "existed and permission reject, version does not match"}, + } + testWrite(t, `{"newV": true}`, v, 1, false, statesOutcomes) +} + +// Test when OCC requires existing object with known version ('#hash#') +func TestOCCWriteAuthoritative(t *testing.T) { + v := "{}" + + statesOutcomes := []writeTestDBState{ + {0, "", codes.InvalidArgument, runtime.ErrStorageRejectedVersion, "did not exists"}, + {1, v, codes.OK, nil, "existed and permission allows write, version match"}, + {1, `{"a":1}`, codes.InvalidArgument, runtime.ErrStorageRejectedVersion, "existed and permission allows write, version does not match"}, + {0, v, codes.OK, nil, "existed and permission reject, version match"}, + {0, `{"a":1}`, codes.InvalidArgument, runtime.ErrStorageRejectedVersion, "existed and permission reject, version does not match"}, + } + testWrite(t, `{"newV": true}`, v, 1, true, statesOutcomes) +} + +func testWrite(t *testing.T, newVal, prevVal string, permWrite int, authoritative bool, states []writeTestDBState) { + db := NewDB(t) + defer db.Close() + + collection, userId := "testcollection", uuid.Must(uuid.NewV4()) + InsertUser(t, db, userId) + + for _, w := range states { + t.Run(w.descr, func(t *testing.T) { + key := GenerateString() + // Prepare DB with expected state + if w.v != "" { + if _, _, err := writeObject(t, db, collection, key, userId, w.write, w.v, "", true); err != nil { + t.Fatal(err) + } + } + + var version string + if prevVal != "" && prevVal != "*" { + hash := md5.Sum([]byte(prevVal)) + version = hex.EncodeToString(hash[:]) + } else { + version = prevVal + } + + // Test writing object and assert expected result + _, code, err := writeObject(t, db, collection, key, userId, permWrite, newVal, version, authoritative) + if code != w.expectedCode || err != w.expectedError { + t.Errorf("Failed: code=%d (expected=%d) err=%v", code, w.expectedCode, err) + } + }) + } +} + +func writeObject(t *testing.T, db *sql.DB, collection, key string, owner uuid.UUID, writePerm int, newV, version string, authoritative bool) (*api.StorageObjectAcks, codes.Code, error) { + t.Helper() + ops := StorageOpWrites{&StorageOpWrite{ + OwnerID: owner.String(), + Object: &api.WriteStorageObject{ + Collection: collection, + Key: key, + Value: newV, + Version: version, + PermissionRead: &wrapperspb.Int32Value{Value: 2}, + PermissionWrite: &wrapperspb.Int32Value{Value: int32(writePerm)}, + }, + }} + return StorageWriteObjects(context.Background(), logger, db, metrics, storageIdx, authoritative, ops) +} diff --git a/server/core_wallet.go b/server/core_wallet.go index 291f20cbcbbd888bb8c6cc6252682d012aef2a01..ed3550f96399b349b89bb47987a80b741bb399ae 100644 --- a/server/core_wallet.go +++ b/server/core_wallet.go @@ -30,6 +30,7 @@ import ( "github.com/gofrs/uuid/v5" "github.com/heroiclabs/nakama-common/runtime" "github.com/jackc/pgtype" + pgx "github.com/jackc/pgx/v4" "go.uber.org/zap" ) @@ -89,7 +90,7 @@ func UpdateWallets(ctx context.Context, logger *zap.Logger, db *sql.DB, updates var results []*runtime.WalletUpdateResult - if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { + if err := ExecuteInTxPgx(ctx, db, func(tx pgx.Tx) error { var updateErr error results, updateErr = updateWallets(ctx, logger, tx, updates, updateLedger) if updateErr != nil { @@ -110,7 +111,7 @@ func UpdateWallets(ctx context.Context, logger *zap.Logger, db *sql.DB, updates return results, nil } -func updateWallets(ctx context.Context, logger *zap.Logger, tx *sql.Tx, updates []*walletUpdate, updateLedger bool) ([]*runtime.WalletUpdateResult, error) { +func updateWallets(ctx context.Context, logger *zap.Logger, tx pgx.Tx, updates []*walletUpdate, updateLedger bool) ([]*runtime.WalletUpdateResult, error) { if len(updates) == 0 { return nil, nil } @@ -126,7 +127,7 @@ func updateWallets(ctx context.Context, logger *zap.Logger, tx *sql.Tx, updates // Select the wallets from the DB and decode them. wallets := make(map[string]map[string]int64, len(updates)) - rows, err := tx.QueryContext(ctx, initialQuery, initialParams...) + rows, err := tx.Query(ctx, initialQuery, initialParams...) if err != nil { logger.Debug("Error retrieving user wallets.", zap.Error(err)) return nil, err @@ -136,7 +137,7 @@ func updateWallets(ctx context.Context, logger *zap.Logger, tx *sql.Tx, updates var wallet sql.NullString err = rows.Scan(&id, &wallet) if err != nil { - _ = rows.Close() + rows.Close() logger.Debug("Error reading user wallets.", zap.Error(err)) return nil, err } @@ -144,14 +145,14 @@ func updateWallets(ctx context.Context, logger *zap.Logger, tx *sql.Tx, updates var walletMap map[string]int64 err = json.Unmarshal([]byte(wallet.String), &walletMap) if err != nil { - _ = rows.Close() + rows.Close() logger.Debug("Error converting user wallet.", zap.String("user_id", id), zap.Error(err)) return nil, err } wallets[id] = walletMap } - _ = rows.Close() + rows.Close() results := make([]*runtime.WalletUpdateResult, 0, len(updates)) @@ -232,7 +233,7 @@ func updateWallets(ctx context.Context, logger *zap.Logger, tx *sql.Tx, updates logger.Warn("Missing wallet update for user.", zap.String("user_id", userID)) continue } - _, err = tx.ExecContext(ctx, "UPDATE users SET update_time = now(), wallet = $2 WHERE id = $1", userID, updatedWallet) + _, err = tx.Exec(ctx, "UPDATE users SET update_time = now(), wallet = $2 WHERE id = $1", userID, updatedWallet) if err != nil { logger.Debug("Error writing user wallet.", zap.String("user_id", userID), zap.Error(err)) return nil, err @@ -241,7 +242,7 @@ func updateWallets(ctx context.Context, logger *zap.Logger, tx *sql.Tx, updates // Write the ledger updates, if any. if updateLedger && (len(statements) > 0) { - _, err = tx.ExecContext(ctx, "INSERT INTO wallet_ledger (id, user_id, changeset, metadata) VALUES "+strings.Join(statements, ", "), params...) + _, err = tx.Exec(ctx, "INSERT INTO wallet_ledger (id, user_id, changeset, metadata) VALUES "+strings.Join(statements, ", "), params...) if err != nil { logger.Debug("Error writing user wallet ledgers.", zap.Error(err)) return nil, err diff --git a/server/db.go b/server/db.go index d148170f4a509703a39f9114d3faf2cd96bdd6f4..6515331921df054e1bca6ee59be0c666cb660a78 100644 --- a/server/db.go +++ b/server/db.go @@ -28,6 +28,7 @@ import ( "github.com/jackc/pgconn" "github.com/jackc/pgerrcode" + pgx "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/stdlib" "go.uber.org/zap" ) @@ -236,16 +237,16 @@ func ExecuteRetryable(fn func() error) error { // fn is subject to the same restrictions as the fn passed to ExecuteTx. func ExecuteInTx(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) error { if isCockroach { - return ExecuteInTxCockroach(ctx, db, fn) + return executeInTxCockroach(ctx, db, fn) } else { - return ExecuteInTxPostgres(ctx, db, fn) + return executeInTxPostgres(ctx, db, fn) } } // Retries fn() if transaction commit returned retryable error code // Every call to fn() happens in its own transaction. On retry previous transaction // is ROLLBACK'ed and new transaction is opened. -func ExecuteInTxPostgres(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) (err error) { +func executeInTxPostgres(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) (err error) { var tx *sql.Tx defer func() { if tx != nil { @@ -283,7 +284,7 @@ func ExecuteInTxPostgres(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error // It has special optimization for `SAVEPOINT cockroach_restart`, called "retry savepoint", // which increases transaction priority every time it has to ROLLBACK due to serialization conflicts. // See: https://www.cockroachlabs.com/docs/stable/advanced-client-side-transaction-retries.html -func ExecuteInTxCockroach(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) error { +func executeInTxCockroach(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) error { tx, err := db.BeginTx(ctx, nil) if err != nil { // Can fail only if undernath connection is broken return err @@ -334,6 +335,125 @@ func ExecuteInTxCockroach(ctx context.Context, db *sql.DB, fn func(*sql.Tx) erro return err } +// Same as ExecuteInTx, but passes pgx.Tx to callback +func ExecuteInTxPgx(ctx context.Context, db *sql.DB, fn func(pgx.Tx) error) error { + if isCockroach { + return executeInTxCockroachPgx(ctx, db, fn) + } else { + return executeInTxPostgresPgx(ctx, db, fn) + } +} + +// Retries fn() if transaction commit returned retryable error code +// Every call to fn() happens in its own transaction. On retry previous transaction +// is ROLLBACK'ed and new transaction is opened. +func executeInTxPostgresPgx(ctx context.Context, db *sql.DB, fn func(pgx.Tx) error) (err error) { + conn, err := db.Conn(ctx) + if err != nil { + return err + } + defer conn.Close() + return conn.Raw(func(driverConn any) error { + conn := driverConn.(*stdlib.Conn).Conn() + + var tx pgx.Tx + defer func() { + if tx != nil { + _ = tx.Rollback(ctx) + } + }() + + // Prevent infinite loop (unlikely, but possible) + for i := 0; i < 5; i++ { + if tx, err = conn.BeginTx(ctx, pgx.TxOptions{}); err != nil { // Can fail only if undernath connection is broken + tx = nil + return err + } + if err = fn(tx); err == nil { + err = tx.Commit(ctx) + } + var pgErr *pgconn.PgError + if errors.As(errorCause(err), &pgErr) && pgErr.Code[:2] == "40" { + // 40XXXX codes are retriable errors + if err = tx.Rollback(ctx); err != nil && err != sql.ErrTxDone { + tx = nil + return err + } + continue + } else { + // Exit on successfull Commit or non retriable error + return err + } + } + // Stop trying after 5 attempts and return last op error + return err + }) +} + +// CockroachDB has it's own way to resolve serialization conflicts. +// It has special optimization for `SAVEPOINT cockroach_restart`, called "retry savepoint", +// which increases transaction priority every time it has to ROLLBACK due to serialization conflicts. +// See: https://www.cockroachlabs.com/docs/stable/advanced-client-side-transaction-retries.html +func executeInTxCockroachPgx(ctx context.Context, db *sql.DB, fn func(pgx.Tx) error) error { + conn, err := db.Conn(ctx) + if err != nil { + return err + } + defer conn.Close() + + return conn.Raw(func(driverConn any) error { + conn := driverConn.(*stdlib.Conn).Conn() + tx, err := conn.BeginTx(ctx, pgx.TxOptions{}) + if err != nil { // Can fail only if undernath connection is broken + return err + } + defer func() { + if err == nil { + // Ignore commit errors. The tx has already been committed by RELEASE. + _ = tx.Commit(ctx) + } else { + // We always need to execute a Rollback() so sql.DB releases the + // connection. + _ = tx.Rollback(ctx) + } + }() + // Specify that we intend to retry this txn in case of database retryable errors. + if _, err = tx.Exec(ctx, "SAVEPOINT cockroach_restart"); err != nil { + return err + } + + // Prevent infinite loop (unlikely, but possible) + for i := 0; i < 5; i++ { + released := false + err = fn(tx) + if err == nil { + // RELEASE acts like COMMIT in CockroachDB. We use it since it gives us an + // opportunity to react to retryable errors, whereas tx.Commit() doesn't. + released = true + if _, err = tx.Exec(ctx, "RELEASE SAVEPOINT cockroach_restart"); err == nil { + return nil + } + } + // We got an error; let's see if it's a retryable one and, if so, restart. We look + // for either the standard PG errcode SerializationFailureError:40001 or the Cockroach extension + // errcode RetriableError:CR000. The Cockroach extension has been removed server-side, but support + // for it has been left here for now to maintain backwards compatibility. + var pgErr *pgconn.PgError + if retryable := errors.As(errorCause(err), &pgErr) && (pgErr.Code == "CR000" || pgErr.Code == pgerrcode.SerializationFailure); !retryable { + if released { + err = newAmbiguousCommitError(err) + } + return err + } + if _, retryErr := tx.Exec(ctx, "ROLLBACK TO SAVEPOINT cockroach_restart"); retryErr != nil { + return newTxnRestartError(retryErr, err) + } + } + // Stop trying after 5 attempts and return last op error + return err + }) +} + type int64Tuple struct { Tuple []int64 Valid bool // Valid is true if Tuple is not NULL diff --git a/server/leaderboard_rank_cache_test.go b/server/leaderboard_rank_cache_test.go index c8dc1c2e46e0c0132c12d8104e8806644a8768d0..6c8c8adedbd96584b98d3e8212b33e27cc348d5c 100644 --- a/server/leaderboard_rank_cache_test.go +++ b/server/leaderboard_rank_cache_test.go @@ -15,10 +15,11 @@ package server import ( + "testing" + "github.com/gofrs/uuid/v5" "github.com/heroiclabs/nakama-common/api" "github.com/stretchr/testify/assert" - "testing" ) func TestLocalLeaderboardRankCache_Insert_Ascending(t *testing.T) { diff --git a/server/match_common_test.go b/server/match_common_test.go index af1eb7ff14ecf96fa99986a34e910853b90b8a52..a222237fa613e1aac7ec378e1e71f0f7b58d27cf 100644 --- a/server/match_common_test.go +++ b/server/match_common_test.go @@ -17,16 +17,17 @@ package server import ( "context" "database/sql" + "os" + "strconv" + "testing" + "time" + "github.com/gofrs/uuid/v5" "github.com/heroiclabs/nakama-common/rtapi" "github.com/heroiclabs/nakama-common/runtime" "go.uber.org/atomic" "go.uber.org/zap" "go.uber.org/zap/zapcore" - "os" - "strconv" - "testing" - "time" ) // loggerForTest allows for easily adjusting log output produced by tests in one place diff --git a/server/match_presence_test.go b/server/match_presence_test.go index 0b62a7638a9f11d989db4384c169c1684b193a92..6f6aa54f31690b9a378fcca97651e3d3e9370f9b 100644 --- a/server/match_presence_test.go +++ b/server/match_presence_test.go @@ -15,8 +15,9 @@ package server import ( - "github.com/gofrs/uuid/v5" "testing" + + "github.com/gofrs/uuid/v5" ) func TestMatchPresenceList(t *testing.T) {