From ca07adf6dd50184d20c0aac50deca024e1262790 Mon Sep 17 00:00:00 2001 From: Maxim Ivanov Date: Tue, 18 Jul 2023 16:26:52 +0000 Subject: [PATCH] storage: optimize multiple objects write (#1059) * Fix tests * storage: additional test cases * db: add Pgx version of ExecuteInTx helper * storage: optimize multiple objects write Previously multiple objects write required 2 round trips to DB per object: one to fetch object state from DB and second to update object. Each query is indexed and fast, but with latency to DB comparable to query execution time it adds significant overhead. This change optimizes multiple objects write in 2 steps: 1. Instead of reading DB state for each object and then deciding on possible write operation, perform write operation unconditionally with correct predicates (version and permissions) defined where applicable. That said write operation might not succeed if row doesn't match predicate. Write query is structured in such way, that final state of the row in the database is returned, regardless whether writeop successed or not. By inspecting returned row we can infer whether it was success, version conflict or permission error. 2. Now that each object is written to DB in a single query, there is no dependencies between queries and all of them can be blasted to DB in a batch without waiting for result of each. Whole batch continues to be executed in a single transaction, so outcome is the same, but batching negates latency penalty. --- server/core_account.go | 7 +- server/core_multi.go | 3 +- server/core_storage.go | 262 +++++++++++++++----------- server/core_storage_test.go | 146 ++++++++++++++ server/core_wallet.go | 17 +- server/db.go | 128 ++++++++++++- server/leaderboard_rank_cache_test.go | 3 +- server/match_common_test.go | 9 +- server/match_presence_test.go | 3 +- 9 files changed, 444 insertions(+), 134 deletions(-) diff --git a/server/core_account.go b/server/core_account.go index 8aa21fbc5..9d52894ba 100644 --- a/server/core_account.go +++ b/server/core_account.go @@ -28,6 +28,7 @@ import ( "github.com/heroiclabs/nakama/v3/console" "github.com/jackc/pgconn" "github.com/jackc/pgtype" + pgx "github.com/jackc/pgx/v4" "go.uber.org/zap" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -243,7 +244,7 @@ WHERE u.id IN (` + strings.Join(statements, ",") + `)` } func UpdateAccounts(ctx context.Context, logger *zap.Logger, db *sql.DB, updates []*accountUpdate) error { - if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { + if err := ExecuteInTxPgx(ctx, db, func(tx pgx.Tx) error { updateErr := updateAccounts(ctx, logger, tx, updates) if updateErr != nil { return updateErr @@ -260,7 +261,7 @@ func UpdateAccounts(ctx context.Context, logger *zap.Logger, db *sql.DB, updates return nil } -func updateAccounts(ctx context.Context, logger *zap.Logger, tx *sql.Tx, updates []*accountUpdate) error { +func updateAccounts(ctx context.Context, logger *zap.Logger, tx pgx.Tx, updates []*accountUpdate) error { for _, update := range updates { updateStatements := make([]string, 0, 7) distinctStatements := make([]string, 0, 7) @@ -346,7 +347,7 @@ func updateAccounts(ctx context.Context, logger *zap.Logger, tx *sql.Tx, updates query := "UPDATE users SET update_time = now(), " + strings.Join(updateStatements, ", ") + " WHERE id = $1 AND (" + strings.Join(distinctStatements, " OR ") + ")" - if _, err := tx.ExecContext(ctx, query, params...); err != nil { + if _, err := tx.Exec(ctx, query, params...); err != nil { var pgErr *pgconn.PgError if errors.As(err, &pgErr) && pgErr.Code == dbErrorUniqueViolation && strings.Contains(pgErr.Message, "users_username_key") { return errors.New("Username is already in use.") diff --git a/server/core_multi.go b/server/core_multi.go index 8b29800de..568b813a5 100644 --- a/server/core_multi.go +++ b/server/core_multi.go @@ -20,6 +20,7 @@ import ( "github.com/heroiclabs/nakama-common/api" "github.com/heroiclabs/nakama-common/runtime" + pgx "github.com/jackc/pgx/v4" "go.uber.org/zap" ) @@ -31,7 +32,7 @@ func MultiUpdate(ctx context.Context, logger *zap.Logger, db *sql.DB, metrics Me var storageWriteAcks []*api.StorageObjectAck var walletUpdateResults []*runtime.WalletUpdateResult - if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { + if err := ExecuteInTxPgx(ctx, db, func(tx pgx.Tx) error { storageWriteAcks = nil walletUpdateResults = nil diff --git a/server/core_storage.go b/server/core_storage.go index 0b8fd7867..a25b213eb 100644 --- a/server/core_storage.go +++ b/server/core_storage.go @@ -21,6 +21,7 @@ import ( "database/sql" "encoding/base64" "encoding/gob" + "encoding/hex" "errors" "fmt" "sort" @@ -30,6 +31,7 @@ import ( "github.com/heroiclabs/nakama-common/runtime" "github.com/jackc/pgconn" "github.com/jackc/pgtype" + pgx "github.com/jackc/pgx/v4" "go.uber.org/zap" "google.golang.org/grpc/codes" "google.golang.org/protobuf/types/known/timestamppb" @@ -49,6 +51,28 @@ type StorageOpWrite struct { Object *api.WriteStorageObject } +// Desired `read` persmission after this Op completes +func (op *StorageOpWrite) permissionRead() int32 { + if op.Object.PermissionRead != nil { + return op.Object.PermissionRead.Value + } + return 1 +} + +// Desired `write` persmission after this Op completes +func (op *StorageOpWrite) permissionWrite() int32 { + if op.Object.PermissionWrite != nil { + return op.Object.PermissionWrite.Value + } + return 1 +} + +// Expected object version after this Op completes +func (op *StorageOpWrite) expectedVersion() string { + hash := md5.Sum([]byte(op.Object.Value)) + return hex.EncodeToString(hash[:]) +} + func (s StorageOpWrites) Len() int { return len(s) } @@ -466,11 +490,17 @@ WHERE func StorageWriteObjects(ctx context.Context, logger *zap.Logger, db *sql.DB, metrics Metrics, storageIndex StorageIndex, authoritativeWrite bool, ops StorageOpWrites) (*api.StorageObjectAcks, codes.Code, error) { var acks []*api.StorageObjectAck - if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { + if err := ExecuteInTxPgx(ctx, db, func(tx pgx.Tx) error { // If the transaction is retried ensure we wipe any acks that may have been prepared by previous attempts. var writeErr error acks, writeErr = storageWriteObjects(ctx, logger, metrics, tx, authoritativeWrite, ops) if writeErr != nil { + if writeErr == runtime.ErrStorageRejectedVersion || writeErr == runtime.ErrStorageRejectedPermission { + logger.Debug("Error writing storage objects.", zap.Error(writeErr)) + return StatusError(codes.InvalidArgument, "Storage write rejected.", writeErr) + } else { + logger.Error("Error writing storage objects.", zap.Error(writeErr)) + } return writeErr } return nil @@ -487,7 +517,7 @@ func StorageWriteObjects(ctx context.Context, logger *zap.Logger, db *sql.DB, me return &api.StorageObjectAcks{Acks: acks}, codes.OK, nil } -func storageWriteObjects(ctx context.Context, logger *zap.Logger, metrics Metrics, tx *sql.Tx, authoritativeWrite bool, ops StorageOpWrites) ([]*api.StorageObjectAck, error) { +func storageWriteObjects(ctx context.Context, logger *zap.Logger, metrics Metrics, tx pgx.Tx, authoritativeWrite bool, ops StorageOpWrites) ([]*api.StorageObjectAck, error) { // Ensure writes are processed in a consistent order to avoid deadlocks from concurrent operations. // Sorting done on a copy to ensure we don't modify the input, which may be re-used on transaction retries. sortedOps := make(StorageOpWrites, 0, len(ops)) @@ -497,143 +527,151 @@ func storageWriteObjects(ctx context.Context, logger *zap.Logger, metrics Metric indexedOps[op] = i } sort.Sort(sortedOps) - // Run operations in the sorted order. acks := make([]*api.StorageObjectAck, ops.Len()) - for _, op := range sortedOps { - ack, writeErr := storageWriteObject(ctx, logger, metrics, tx, authoritativeWrite, op.OwnerID, op.Object) - if writeErr != nil { - if writeErr == runtime.ErrStorageRejectedVersion || writeErr == runtime.ErrStorageRejectedPermission { - return nil, StatusError(codes.InvalidArgument, "Storage write rejected.", writeErr) - } - logger.Debug("Error writing storage objects.", zap.Error(writeErr)) - return nil, writeErr - } - // Acks are returned in the original order. - acks[indexedOps[op]] = ack + batch := &pgx.Batch{} + for _, op := range sortedOps { + storagePrepBatch(batch, authoritativeWrite, op) } - return acks, nil -} -func storageWriteObject(ctx context.Context, logger *zap.Logger, metrics Metrics, tx *sql.Tx, authoritativeWrite bool, ownerID string, object *api.WriteStorageObject) (*api.StorageObjectAck, error) { - var dbVersion sql.NullString - var dbPermissionWrite sql.NullInt64 - var dbPermissionRead sql.NullInt64 - err := tx.QueryRowContext(ctx, "SELECT version, read, write FROM storage WHERE collection = $1 AND key = $2 AND user_id = $3", object.Collection, object.Key, ownerID).Scan(&dbVersion, &dbPermissionRead, &dbPermissionWrite) - if err != nil { - if err == sql.ErrNoRows { - if object.Version != "" && object.Version != "*" { - // Conditional write with a specific version but the object did not exist at all. - metrics.StorageWriteRejectCount(map[string]string{"collection": object.Collection}, 1) + br := tx.SendBatch(ctx, batch) + defer br.Close() // TODO: need to "drain" batch, otherwise it logs all unprocessed queries + for _, op := range sortedOps { + object := op.Object + var resultRead int32 + var resultWrite int32 + var resultVersion string + err := br.QueryRow().Scan(&resultRead, &resultWrite, &resultVersion) + var pgErr *pgconn.PgError + if err != nil && errors.As(err, &pgErr) { + if pgErr.Code == dbErrorUniqueViolation { + metrics.StorageWriteRejectCount(map[string]string{"collection": object.Collection, "reason": "version"}, 1) return nil, runtime.ErrStorageRejectedVersion } - } else { - logger.Debug("Error in write storage object pre-flight.", zap.Any("object", object), zap.Error(err)) + return nil, err + } else if err == pgx.ErrNoRows { + // Not every case from storagePrepWriteObject can return NoRows, but those + // which do it is always ErrStorageRejectedVersion + metrics.StorageWriteRejectCount(map[string]string{"collection": object.Collection, "reason": "version"}, 1) + return nil, runtime.ErrStorageRejectedVersion + } else if err != nil { return nil, err } - } - - if dbVersion.Valid && (object.Version == "*" || (object.Version != "" && object.Version != dbVersion.String)) { - // An object existed, and it's a conditional write that either: - // - Expects no object. - // - Or expects a given version, but it does not match. - metrics.StorageWriteRejectCount(map[string]string{"collection": object.Collection}, 1) - return nil, runtime.ErrStorageRejectedVersion - } - - if dbPermissionWrite.Valid && dbPermissionWrite.Int64 == 0 && !authoritativeWrite { - // Non-authoritative write to an existing storage object with permission 0. - return nil, runtime.ErrStorageRejectedPermission - } - - newVersion := fmt.Sprintf("%x", md5.Sum([]byte(object.Value))) - newPermissionRead := int32(1) - if object.PermissionRead != nil { - newPermissionRead = object.PermissionRead.Value - } - newPermissionWrite := int32(1) - if object.PermissionWrite != nil { - newPermissionWrite = object.PermissionWrite.Value - } - if dbVersion.Valid && dbVersion.String == newVersion && dbPermissionRead.Int64 == int64(newPermissionRead) && dbPermissionWrite.Int64 == int64(newPermissionWrite) { - // Stored object existed, and exactly matches the new object's version and read/write permissions. + if !(op.permissionRead() == resultRead && + op.permissionWrite() == resultWrite && + op.expectedVersion() == resultVersion) { + // Write failed, it can happen for 3 reasons: + // - constraint violation on insert (handles elsewhere) + // - permission: non authoritative write & original row write != 1 + // - version mismatch + if !authoritativeWrite && resultWrite != 1 { + metrics.StorageWriteRejectCount(map[string]string{"collection": object.Collection, "reason": "permission"}, 1) + return nil, runtime.ErrStorageRejectedPermission + } else { + // version check failed + metrics.StorageWriteRejectCount(map[string]string{"collection": object.Collection, "reason": "version"}, 1) + return nil, runtime.ErrStorageRejectedVersion + } + } ack := &api.StorageObjectAck{ Collection: object.Collection, Key: object.Key, - Version: newVersion, + Version: resultVersion, + UserId: op.OwnerID, } - if ownerID != uuid.Nil.String() { - ack.UserId = ownerID - } - return ack, nil + acks[indexedOps[op]] = ack } + return acks, nil +} + +func storagePrepBatch(batch *pgx.Batch, authoritativeWrite bool, op *StorageOpWrite) { + object := op.Object + ownerID := op.OwnerID + + newVersion := op.expectedVersion() + newPermissionRead := op.permissionRead() + newPermissionWrite := op.permissionWrite() + params := []interface{}{object.Collection, object.Key, ownerID, object.Value, newVersion, newPermissionRead, newPermissionWrite} var query string + + writeCheck := "" + // Respect permissions in non-authoritative writes. + if !authoritativeWrite { + writeCheck = " AND storage.write = 1" + } + switch { case object.Version != "" && object.Version != "*": // OCC if-match. - query = "UPDATE storage SET value = $4, version = $5, read = $6, write = $7, update_time = now() WHERE collection = $1 AND key = $2 AND user_id = $3::UUID AND version = $8" + + // Query pattern + // (UPDATE t ... RETURNING) UNION ALL (SELECT FROM t) LIMIT 1 + // allows us to fetch row state after update even if update itself fails WHERE + // condition. + // That is returned values are final state of the row regardless of UPDATE success + query = ` + WITH upd AS ( + UPDATE storage SET value = $4, version = $5, read = $6, write = $7, update_time = now() + WHERE collection = $1 AND key = $2 AND user_id = $3::UUID AND version = $8 + ` + writeCheck + ` + AND NOT (storage.version = $5 AND storage.read = $6 AND storage.write = $7) -- micro optimization: don't update row unnecessary + RETURNING read, write, version + ) + (SELECT read, write, version from upd) + UNION ALL + (SELECT read, write, version FROM storage WHERE collection = $1 and key = $2 and user_id = $3) + LIMIT 1` + params = append(params, object.Version) - // Respect permissions in non-authoritative writes. - if !authoritativeWrite { - query += " AND write = 1" - } - case dbVersion.Valid && object.Version == "": - // An existing storage object was present, but no OCC of any kind is specified. - query = "UPDATE storage SET value = $4, version = $5, read = $6, write = $7, update_time = now() WHERE collection = $1 AND key = $2 AND user_id = $3::UUID" - // Respect permissions in non-authoritative writes. - if !authoritativeWrite { - query += " AND write = 1" - } - case !dbVersion.Valid && object.Version == "": - // An existing storage object was not present, and no OCC of any kind is specified. - // Separate to the case above to handle concurrent non-OCC object creations, where all but the first must become updates. - query = "INSERT INTO storage (collection, key, user_id, value, version, read, write, create_time, update_time) VALUES ($1, $2, $3::UUID, $4, $5, $6, $7, now(), now()) ON CONFLICT (collection, read, key, user_id) DO UPDATE SET value = $4, version = $5, read = $6, write = $7, update_time = now()" - // Respect permissions in non-authoritative writes, where this operation also loses the race to insert the object. - if !authoritativeWrite { - query += " WHERE storage.write = 1" - } - case dbVersion.Valid && object.Version != "*": - // An existing storage object was present, but no OCC if-not-exists required. - query = "UPDATE storage SET value = $4, version = $5, read = $6, write = $7, update_time = now() WHERE collection = $1 AND key = $2 AND user_id = $3::UUID AND version = $8" - params = append(params, dbVersion.String) - // Respect permissions in non-authoritative writes. - if !authoritativeWrite { - query += " AND write = 1" - } - default: + + // Outcomes: + // - No rows: if no rows returned, then object was not found in DB and can't be updated + // - We have row returned, but now we need to know if update happened, that is its WHERE matched + // * write != 1 means no permission to write + // * dbVersion != original version means OCC failure + + case object.Version == "": + // non-OCC write, "last write wins" kind of write + + // Similar pattern as in case above, but supports case when row + // didn't exist in the database. Another difference is that there is no version + // check for existing row. + query = ` + WITH upd AS ( + INSERT INTO storage (collection, key, user_id, value, version, read, write, create_time, update_time) + VALUES ($1, $2, $3::UUID, $4, $5, $6, $7, now(), now()) + ON CONFLICT (collection, key, user_id) DO + UPDATE SET value = $4, version = $5, read = $6, write = $7, update_time = now() + WHERE TRUE` + writeCheck + ` + AND NOT (storage.version = $5 AND storage.read = $6 AND storage.write = $7) -- micro optimization: don't update row unnecessary + RETURNING read, write, version + ) + (SELECT read, write, version from upd) + UNION ALL + (SELECT read, write, version FROM storage WHERE collection = $1 and key = $2 and user_id = $3) + LIMIT 1` + + // Outcomes: + // - Row is always returned, need to know if update happened, that is its WHERE matches + // - write != 1 means no permission to write + + case object.Version == "*": // OCC if-not-exists, and all other non-OCC cases. - query = "INSERT INTO storage (collection, key, user_id, value, version, read, write, create_time, update_time) VALUES ($1, $2, $3::UUID, $4, $5, $6, $7, now(), now())" // Existing permission checks are not applicable for new storage objects. - } - - res, err := tx.ExecContext(ctx, query, params...) - if err != nil { - logger.Debug("Could not write storage object, exec error.", zap.Any("object", object), zap.String("query", query), zap.Error(err)) - var pgErr *pgconn.PgError - if errors.As(err, &pgErr) && pgErr.Code == dbErrorUniqueViolation { - metrics.StorageWriteRejectCount(map[string]string{"collection": object.Collection}, 1) - return nil, runtime.ErrStorageRejectedVersion - } - return nil, err - } - if rowsAffected, err := res.RowsAffected(); rowsAffected != 1 { - logger.Debug("Could not write storage object, rowsAffected error.", zap.Any("object", object), zap.String("query", query), zap.Error(err)) - metrics.StorageWriteRejectCount(map[string]string{"collection": object.Collection}, 1) - return nil, runtime.ErrStorageRejectedVersion - } + query = ` + INSERT INTO storage (collection, key, user_id, value, version, read, write, create_time, update_time) + VALUES ($1, $2, $3::UUID, $4, $5, $6, $7, now(), now()) + RETURNING read, write, version` - ack := &api.StorageObjectAck{ - Collection: object.Collection, - Key: object.Key, - Version: newVersion, - UserId: ownerID, + // Outcomes: + // - NoRows - insert failed due to constraint violation (concurrent insert) } - return ack, nil + batch.Queue(query, params...) } func StorageDeleteObjects(ctx context.Context, logger *zap.Logger, db *sql.DB, storageIndex StorageIndex, authoritativeDelete bool, ops StorageOpDeletes) (codes.Code, error) { diff --git a/server/core_storage_test.go b/server/core_storage_test.go index ada4a6dda..7ea3d5b2f 100644 --- a/server/core_storage_test.go +++ b/server/core_storage_test.go @@ -17,11 +17,14 @@ package server import ( "context" "crypto/md5" + "database/sql" + "encoding/hex" "fmt" "testing" "github.com/gofrs/uuid/v5" "github.com/heroiclabs/nakama-common/api" + "github.com/heroiclabs/nakama-common/runtime" "github.com/stretchr/testify/assert" "google.golang.org/grpc/codes" "google.golang.org/protobuf/types/known/wrapperspb" @@ -2078,3 +2081,146 @@ func TestStorageListNoRepeats(t *testing.T) { assert.Len(t, values.Objects, 7, "values length was not 7") assert.Equal(t, "", values.Cursor, "cursor was not nil") } + +// DB State and expected outcome when performing write op +type writeTestDBState struct { + write int + v string + expectedCode codes.Code + expectedError error + descr string +} + +// Test no OCC, last write wins ("") +func TestNonOCCNonAuthoritative(t *testing.T) { + v := "{}" + + statesOutcomes := []writeTestDBState{ + {0, "", codes.OK, nil, "did not exists"}, + {1, v, codes.OK, nil, "existed and permission allows write, version match"}, + {1, `{"a":1}`, codes.OK, nil, "existed and permission allows write, version does not match"}, + {0, v, codes.InvalidArgument, runtime.ErrStorageRejectedPermission, "existed and permission reject, version match"}, + {0, `{"a":1}`, codes.InvalidArgument, runtime.ErrStorageRejectedPermission, "existed and permission reject, version does not match"}, + } + testWrite(t, `{"newV": true}`, "", 1, false, statesOutcomes) +} + +// Test no OCC, last write wins ("") +func TestNonOCCAuthoritative(t *testing.T) { + v := "{}" + + statesOutcomes := []writeTestDBState{ + {0, "", codes.OK, nil, "did not exists"}, + {1, v, codes.OK, nil, "existed and permission allows write, version match"}, + {1, `{"a":1}`, codes.OK, nil, "existed and permission allows write, version does not match"}, + {0, v, codes.OK, nil, "existed and permission reject, version match"}, + {0, `{"a":1}`, codes.OK, nil, "existed and permission reject, version does not match"}, + } + testWrite(t, `{"newV": true}`, "", 1, true, statesOutcomes) +} + +// Test when OCC requires non-existing object ("*") +func TestOCCNotExistsAuthoritative(t *testing.T) { + v := "{}" + + statesOutcomes := []writeTestDBState{ + {0, "", codes.OK, nil, "did not exists"}, + {1, v, codes.InvalidArgument, runtime.ErrStorageRejectedVersion, "existed and permission allows write, version match"}, + {1, `{"a":1}`, codes.InvalidArgument, runtime.ErrStorageRejectedVersion, "existed and permission allows write, version does not match"}, + {0, v, codes.InvalidArgument, runtime.ErrStorageRejectedVersion, "existed and permission reject, version match"}, + {0, `{"a":1}`, codes.InvalidArgument, runtime.ErrStorageRejectedVersion, "existed and permission reject, version does not match"}, + } + testWrite(t, `{"newV": true}`, "*", 1, true, statesOutcomes) +} + +// Test when OCC requires non-existing object ("*") +func TestOCCNotExistsNonAuthoritative(t *testing.T) { + v := "{}" + + statesOutcomes := []writeTestDBState{ + {0, "", codes.OK, nil, "did not exists"}, + {1, v, codes.InvalidArgument, runtime.ErrStorageRejectedVersion, "existed and permission allows write, version match"}, + {1, `{"a":1}`, codes.InvalidArgument, runtime.ErrStorageRejectedVersion, "existed and permission allows write, version does not match"}, + {0, v, codes.InvalidArgument, runtime.ErrStorageRejectedVersion, "existed and permission reject, version match"}, + {0, `{"a":1}`, codes.InvalidArgument, runtime.ErrStorageRejectedVersion, "existed and permission reject, version does not match"}, + } + testWrite(t, `{"newV": true}`, "*", 1, false, statesOutcomes) +} + +// Test when OCC requires existing object with known version ('#hash#') +func TestOCCWriteNonAuthoritative(t *testing.T) { + v := "{}" + + statesOutcomes := []writeTestDBState{ + {0, "", codes.InvalidArgument, runtime.ErrStorageRejectedVersion, "did not exists"}, + {1, v, codes.OK, nil, "existed and permission allows write, version match"}, + {1, `{"a":1}`, codes.InvalidArgument, runtime.ErrStorageRejectedVersion, "existed and permission allows write, version does not match"}, + {0, v, codes.InvalidArgument, runtime.ErrStorageRejectedPermission, "existed and permission reject, version match"}, + {0, `{"a":1}`, codes.InvalidArgument, runtime.ErrStorageRejectedPermission, "existed and permission reject, version does not match"}, + } + testWrite(t, `{"newV": true}`, v, 1, false, statesOutcomes) +} + +// Test when OCC requires existing object with known version ('#hash#') +func TestOCCWriteAuthoritative(t *testing.T) { + v := "{}" + + statesOutcomes := []writeTestDBState{ + {0, "", codes.InvalidArgument, runtime.ErrStorageRejectedVersion, "did not exists"}, + {1, v, codes.OK, nil, "existed and permission allows write, version match"}, + {1, `{"a":1}`, codes.InvalidArgument, runtime.ErrStorageRejectedVersion, "existed and permission allows write, version does not match"}, + {0, v, codes.OK, nil, "existed and permission reject, version match"}, + {0, `{"a":1}`, codes.InvalidArgument, runtime.ErrStorageRejectedVersion, "existed and permission reject, version does not match"}, + } + testWrite(t, `{"newV": true}`, v, 1, true, statesOutcomes) +} + +func testWrite(t *testing.T, newVal, prevVal string, permWrite int, authoritative bool, states []writeTestDBState) { + db := NewDB(t) + defer db.Close() + + collection, userId := "testcollection", uuid.Must(uuid.NewV4()) + InsertUser(t, db, userId) + + for _, w := range states { + t.Run(w.descr, func(t *testing.T) { + key := GenerateString() + // Prepare DB with expected state + if w.v != "" { + if _, _, err := writeObject(t, db, collection, key, userId, w.write, w.v, "", true); err != nil { + t.Fatal(err) + } + } + + var version string + if prevVal != "" && prevVal != "*" { + hash := md5.Sum([]byte(prevVal)) + version = hex.EncodeToString(hash[:]) + } else { + version = prevVal + } + + // Test writing object and assert expected result + _, code, err := writeObject(t, db, collection, key, userId, permWrite, newVal, version, authoritative) + if code != w.expectedCode || err != w.expectedError { + t.Errorf("Failed: code=%d (expected=%d) err=%v", code, w.expectedCode, err) + } + }) + } +} + +func writeObject(t *testing.T, db *sql.DB, collection, key string, owner uuid.UUID, writePerm int, newV, version string, authoritative bool) (*api.StorageObjectAcks, codes.Code, error) { + t.Helper() + ops := StorageOpWrites{&StorageOpWrite{ + OwnerID: owner.String(), + Object: &api.WriteStorageObject{ + Collection: collection, + Key: key, + Value: newV, + Version: version, + PermissionRead: &wrapperspb.Int32Value{Value: 2}, + PermissionWrite: &wrapperspb.Int32Value{Value: int32(writePerm)}, + }, + }} + return StorageWriteObjects(context.Background(), logger, db, metrics, storageIdx, authoritative, ops) +} diff --git a/server/core_wallet.go b/server/core_wallet.go index 291f20cbc..ed3550f96 100644 --- a/server/core_wallet.go +++ b/server/core_wallet.go @@ -30,6 +30,7 @@ import ( "github.com/gofrs/uuid/v5" "github.com/heroiclabs/nakama-common/runtime" "github.com/jackc/pgtype" + pgx "github.com/jackc/pgx/v4" "go.uber.org/zap" ) @@ -89,7 +90,7 @@ func UpdateWallets(ctx context.Context, logger *zap.Logger, db *sql.DB, updates var results []*runtime.WalletUpdateResult - if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { + if err := ExecuteInTxPgx(ctx, db, func(tx pgx.Tx) error { var updateErr error results, updateErr = updateWallets(ctx, logger, tx, updates, updateLedger) if updateErr != nil { @@ -110,7 +111,7 @@ func UpdateWallets(ctx context.Context, logger *zap.Logger, db *sql.DB, updates return results, nil } -func updateWallets(ctx context.Context, logger *zap.Logger, tx *sql.Tx, updates []*walletUpdate, updateLedger bool) ([]*runtime.WalletUpdateResult, error) { +func updateWallets(ctx context.Context, logger *zap.Logger, tx pgx.Tx, updates []*walletUpdate, updateLedger bool) ([]*runtime.WalletUpdateResult, error) { if len(updates) == 0 { return nil, nil } @@ -126,7 +127,7 @@ func updateWallets(ctx context.Context, logger *zap.Logger, tx *sql.Tx, updates // Select the wallets from the DB and decode them. wallets := make(map[string]map[string]int64, len(updates)) - rows, err := tx.QueryContext(ctx, initialQuery, initialParams...) + rows, err := tx.Query(ctx, initialQuery, initialParams...) if err != nil { logger.Debug("Error retrieving user wallets.", zap.Error(err)) return nil, err @@ -136,7 +137,7 @@ func updateWallets(ctx context.Context, logger *zap.Logger, tx *sql.Tx, updates var wallet sql.NullString err = rows.Scan(&id, &wallet) if err != nil { - _ = rows.Close() + rows.Close() logger.Debug("Error reading user wallets.", zap.Error(err)) return nil, err } @@ -144,14 +145,14 @@ func updateWallets(ctx context.Context, logger *zap.Logger, tx *sql.Tx, updates var walletMap map[string]int64 err = json.Unmarshal([]byte(wallet.String), &walletMap) if err != nil { - _ = rows.Close() + rows.Close() logger.Debug("Error converting user wallet.", zap.String("user_id", id), zap.Error(err)) return nil, err } wallets[id] = walletMap } - _ = rows.Close() + rows.Close() results := make([]*runtime.WalletUpdateResult, 0, len(updates)) @@ -232,7 +233,7 @@ func updateWallets(ctx context.Context, logger *zap.Logger, tx *sql.Tx, updates logger.Warn("Missing wallet update for user.", zap.String("user_id", userID)) continue } - _, err = tx.ExecContext(ctx, "UPDATE users SET update_time = now(), wallet = $2 WHERE id = $1", userID, updatedWallet) + _, err = tx.Exec(ctx, "UPDATE users SET update_time = now(), wallet = $2 WHERE id = $1", userID, updatedWallet) if err != nil { logger.Debug("Error writing user wallet.", zap.String("user_id", userID), zap.Error(err)) return nil, err @@ -241,7 +242,7 @@ func updateWallets(ctx context.Context, logger *zap.Logger, tx *sql.Tx, updates // Write the ledger updates, if any. if updateLedger && (len(statements) > 0) { - _, err = tx.ExecContext(ctx, "INSERT INTO wallet_ledger (id, user_id, changeset, metadata) VALUES "+strings.Join(statements, ", "), params...) + _, err = tx.Exec(ctx, "INSERT INTO wallet_ledger (id, user_id, changeset, metadata) VALUES "+strings.Join(statements, ", "), params...) if err != nil { logger.Debug("Error writing user wallet ledgers.", zap.Error(err)) return nil, err diff --git a/server/db.go b/server/db.go index d148170f4..651533192 100644 --- a/server/db.go +++ b/server/db.go @@ -28,6 +28,7 @@ import ( "github.com/jackc/pgconn" "github.com/jackc/pgerrcode" + pgx "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/stdlib" "go.uber.org/zap" ) @@ -236,16 +237,16 @@ func ExecuteRetryable(fn func() error) error { // fn is subject to the same restrictions as the fn passed to ExecuteTx. func ExecuteInTx(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) error { if isCockroach { - return ExecuteInTxCockroach(ctx, db, fn) + return executeInTxCockroach(ctx, db, fn) } else { - return ExecuteInTxPostgres(ctx, db, fn) + return executeInTxPostgres(ctx, db, fn) } } // Retries fn() if transaction commit returned retryable error code // Every call to fn() happens in its own transaction. On retry previous transaction // is ROLLBACK'ed and new transaction is opened. -func ExecuteInTxPostgres(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) (err error) { +func executeInTxPostgres(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) (err error) { var tx *sql.Tx defer func() { if tx != nil { @@ -283,7 +284,7 @@ func ExecuteInTxPostgres(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error // It has special optimization for `SAVEPOINT cockroach_restart`, called "retry savepoint", // which increases transaction priority every time it has to ROLLBACK due to serialization conflicts. // See: https://www.cockroachlabs.com/docs/stable/advanced-client-side-transaction-retries.html -func ExecuteInTxCockroach(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) error { +func executeInTxCockroach(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) error { tx, err := db.BeginTx(ctx, nil) if err != nil { // Can fail only if undernath connection is broken return err @@ -334,6 +335,125 @@ func ExecuteInTxCockroach(ctx context.Context, db *sql.DB, fn func(*sql.Tx) erro return err } +// Same as ExecuteInTx, but passes pgx.Tx to callback +func ExecuteInTxPgx(ctx context.Context, db *sql.DB, fn func(pgx.Tx) error) error { + if isCockroach { + return executeInTxCockroachPgx(ctx, db, fn) + } else { + return executeInTxPostgresPgx(ctx, db, fn) + } +} + +// Retries fn() if transaction commit returned retryable error code +// Every call to fn() happens in its own transaction. On retry previous transaction +// is ROLLBACK'ed and new transaction is opened. +func executeInTxPostgresPgx(ctx context.Context, db *sql.DB, fn func(pgx.Tx) error) (err error) { + conn, err := db.Conn(ctx) + if err != nil { + return err + } + defer conn.Close() + return conn.Raw(func(driverConn any) error { + conn := driverConn.(*stdlib.Conn).Conn() + + var tx pgx.Tx + defer func() { + if tx != nil { + _ = tx.Rollback(ctx) + } + }() + + // Prevent infinite loop (unlikely, but possible) + for i := 0; i < 5; i++ { + if tx, err = conn.BeginTx(ctx, pgx.TxOptions{}); err != nil { // Can fail only if undernath connection is broken + tx = nil + return err + } + if err = fn(tx); err == nil { + err = tx.Commit(ctx) + } + var pgErr *pgconn.PgError + if errors.As(errorCause(err), &pgErr) && pgErr.Code[:2] == "40" { + // 40XXXX codes are retriable errors + if err = tx.Rollback(ctx); err != nil && err != sql.ErrTxDone { + tx = nil + return err + } + continue + } else { + // Exit on successfull Commit or non retriable error + return err + } + } + // Stop trying after 5 attempts and return last op error + return err + }) +} + +// CockroachDB has it's own way to resolve serialization conflicts. +// It has special optimization for `SAVEPOINT cockroach_restart`, called "retry savepoint", +// which increases transaction priority every time it has to ROLLBACK due to serialization conflicts. +// See: https://www.cockroachlabs.com/docs/stable/advanced-client-side-transaction-retries.html +func executeInTxCockroachPgx(ctx context.Context, db *sql.DB, fn func(pgx.Tx) error) error { + conn, err := db.Conn(ctx) + if err != nil { + return err + } + defer conn.Close() + + return conn.Raw(func(driverConn any) error { + conn := driverConn.(*stdlib.Conn).Conn() + tx, err := conn.BeginTx(ctx, pgx.TxOptions{}) + if err != nil { // Can fail only if undernath connection is broken + return err + } + defer func() { + if err == nil { + // Ignore commit errors. The tx has already been committed by RELEASE. + _ = tx.Commit(ctx) + } else { + // We always need to execute a Rollback() so sql.DB releases the + // connection. + _ = tx.Rollback(ctx) + } + }() + // Specify that we intend to retry this txn in case of database retryable errors. + if _, err = tx.Exec(ctx, "SAVEPOINT cockroach_restart"); err != nil { + return err + } + + // Prevent infinite loop (unlikely, but possible) + for i := 0; i < 5; i++ { + released := false + err = fn(tx) + if err == nil { + // RELEASE acts like COMMIT in CockroachDB. We use it since it gives us an + // opportunity to react to retryable errors, whereas tx.Commit() doesn't. + released = true + if _, err = tx.Exec(ctx, "RELEASE SAVEPOINT cockroach_restart"); err == nil { + return nil + } + } + // We got an error; let's see if it's a retryable one and, if so, restart. We look + // for either the standard PG errcode SerializationFailureError:40001 or the Cockroach extension + // errcode RetriableError:CR000. The Cockroach extension has been removed server-side, but support + // for it has been left here for now to maintain backwards compatibility. + var pgErr *pgconn.PgError + if retryable := errors.As(errorCause(err), &pgErr) && (pgErr.Code == "CR000" || pgErr.Code == pgerrcode.SerializationFailure); !retryable { + if released { + err = newAmbiguousCommitError(err) + } + return err + } + if _, retryErr := tx.Exec(ctx, "ROLLBACK TO SAVEPOINT cockroach_restart"); retryErr != nil { + return newTxnRestartError(retryErr, err) + } + } + // Stop trying after 5 attempts and return last op error + return err + }) +} + type int64Tuple struct { Tuple []int64 Valid bool // Valid is true if Tuple is not NULL diff --git a/server/leaderboard_rank_cache_test.go b/server/leaderboard_rank_cache_test.go index c8dc1c2e4..6c8c8aded 100644 --- a/server/leaderboard_rank_cache_test.go +++ b/server/leaderboard_rank_cache_test.go @@ -15,10 +15,11 @@ package server import ( + "testing" + "github.com/gofrs/uuid/v5" "github.com/heroiclabs/nakama-common/api" "github.com/stretchr/testify/assert" - "testing" ) func TestLocalLeaderboardRankCache_Insert_Ascending(t *testing.T) { diff --git a/server/match_common_test.go b/server/match_common_test.go index af1eb7ff1..a222237fa 100644 --- a/server/match_common_test.go +++ b/server/match_common_test.go @@ -17,16 +17,17 @@ package server import ( "context" "database/sql" + "os" + "strconv" + "testing" + "time" + "github.com/gofrs/uuid/v5" "github.com/heroiclabs/nakama-common/rtapi" "github.com/heroiclabs/nakama-common/runtime" "go.uber.org/atomic" "go.uber.org/zap" "go.uber.org/zap/zapcore" - "os" - "strconv" - "testing" - "time" ) // loggerForTest allows for easily adjusting log output produced by tests in one place diff --git a/server/match_presence_test.go b/server/match_presence_test.go index 0b62a7638..6f6aa54f3 100644 --- a/server/match_presence_test.go +++ b/server/match_presence_test.go @@ -15,8 +15,9 @@ package server import ( - "github.com/gofrs/uuid/v5" "testing" + + "github.com/gofrs/uuid/v5" ) func TestMatchPresenceList(t *testing.T) { -- GitLab