Commit 177b8ea6 authored by Andrei Mihu's avatar Andrei Mihu
Browse files

Add runtime multi-update functions. (#455)

parent 6531ac1c
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -7,11 +7,12 @@ The format is based on [keep a changelog](http://keepachangelog.com) and this pr
### Added
- Support for Apple Sign In authentication, linking, and unlinking.
- Wallet operations now return the updated and previous state of the wallet.
- New runtime multi-update function.

### Changed
- Sanitize metric names and properties fields.
- Wallet operations now use int64 values for all numeric operations.
- Update to nakama-common 1.6.0 release.
- Update to nakama-common 1.7.3 release.

### Fixed
- Prevent bad presence list input to dispatcher message broadcasts from causing unexpected errors.
+1 −1
Original line number Diff line number Diff line
@@ -29,7 +29,7 @@ require (
	github.com/gorilla/mux v1.7.4
	github.com/gorilla/websocket v1.4.2
	github.com/grpc-ecosystem/grpc-gateway v1.13.0
	github.com/heroiclabs/nakama-common v1.7.2
	github.com/heroiclabs/nakama-common v1.7.3
	github.com/jackc/fake v0.0.0-20150926172116-812a484cc733 // indirect
	github.com/jackc/pgx v3.5.0+incompatible
	github.com/jmhodges/levigo v1.0.0 // indirect
+2 −0
+10 −1
Original line number Diff line number Diff line
@@ -103,7 +103,16 @@ func (s *ApiServer) UpdateAccount(ctx context.Context, in *api.UpdateAccountRequ
		}
	}

	err := UpdateAccount(ctx, s.logger, s.db, userID, username, in.GetDisplayName(), in.GetTimezone(), in.GetLocation(), in.GetLangTag(), in.GetAvatarUrl(), nil)
	err := UpdateAccounts(ctx, s.logger, s.db, []*accountUpdate{{
		userID:      userID,
		username:    username,
		displayName: in.GetDisplayName(),
		timezone:    in.GetTimezone(),
		location:    in.GetLocation(),
		langTag:     in.GetLangTag(),
		avatarURL:   in.GetAvatarUrl(),
		metadata:    nil,
	}})
	if err != nil {
		if _, ok := err.(pgx.PgError); ok {
			return nil, status.Error(codes.Internal, "Error while trying to update account.")
+118 −80
Original line number Diff line number Diff line
@@ -35,6 +35,18 @@ import (

var ErrAccountNotFound = errors.New("account not found")

// Not an API entity, only used to receive data from runtime environment.
type accountUpdate struct {
	userID      uuid.UUID
	username    string
	displayName *wrappers.StringValue
	timezone    *wrappers.StringValue
	location    *wrappers.StringValue
	langTag     *wrappers.StringValue
	avatarURL   *wrappers.StringValue
	metadata    *wrappers.StringValue
}

func GetAccount(ctx context.Context, logger *zap.Logger, db *sql.DB, tracker Tracker, userID uuid.UUID) (*api.Account, error) {
	var displayName sql.NullString
	var username sql.NullString
@@ -228,25 +240,50 @@ WHERE u.id IN (` + strings.Join(statements, ",") + `)`
	return accounts, nil
}

func UpdateAccount(ctx context.Context, logger *zap.Logger, db *sql.DB, userID uuid.UUID, username string, displayName, timezone, location, langTag, avatarURL, metadata *wrappers.StringValue) error {
func UpdateAccounts(ctx context.Context, logger *zap.Logger, db *sql.DB, updates []*accountUpdate) error {
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		logger.Error("Could not begin database transaction.", zap.Error(err))
		return err
	}

	if err = ExecuteInTx(ctx, tx, func() error {
		updateErr := updateAccounts(ctx, logger, tx, updates)
		if updateErr != nil {
			return updateErr
		}
		return nil
	}); err != nil {
		if e, ok := err.(*statusError); ok {
			return e.Cause()
		}
		logger.Error("Error updating user accounts.", zap.Error(err))
		return err
	}

	return nil
}

func updateAccounts(ctx context.Context, logger *zap.Logger, tx *sql.Tx, updates []*accountUpdate) error {
	for _, update := range updates {
		updateStatements := make([]string, 0, 7)
		distinctStatements := make([]string, 0, 7)
		params := make([]interface{}, 0, 8)

		// Ensure user ID is always present.
	params = append(params, userID)
		params = append(params, update.userID)

	if username != "" {
		if invalidCharsRegex.MatchString(username) {
		if update.username != "" {
			if invalidCharsRegex.MatchString(update.username) {
				return errors.New("Username invalid, no spaces or control characters allowed.")
			}
		params = append(params, username)
			params = append(params, update.username)
			updateStatements = append(updateStatements, "username = $"+strconv.Itoa(len(params)))
			distinctStatements = append(distinctStatements, "username IS DISTINCT FROM $"+strconv.Itoa(len(params)))
		}

	if displayName != nil {
		if d := displayName.GetValue(); d == "" {
		if update.displayName != nil {
			if d := update.displayName.GetValue(); d == "" {
				updateStatements = append(updateStatements, "display_name = NULL")
				distinctStatements = append(distinctStatements, "display_name IS NOT NULL")
			} else {
@@ -256,8 +293,8 @@ func UpdateAccount(ctx context.Context, logger *zap.Logger, db *sql.DB, userID u
			}
		}

	if timezone != nil {
		if t := timezone.GetValue(); t == "" {
		if update.timezone != nil {
			if t := update.timezone.GetValue(); t == "" {
				updateStatements = append(updateStatements, "timezone = NULL")
				distinctStatements = append(distinctStatements, "timezone IS NOT NULL")
			} else {
@@ -267,8 +304,8 @@ func UpdateAccount(ctx context.Context, logger *zap.Logger, db *sql.DB, userID u
			}
		}

	if location != nil {
		if l := location.GetValue(); l == "" {
		if update.location != nil {
			if l := update.location.GetValue(); l == "" {
				updateStatements = append(updateStatements, "location = NULL")
				distinctStatements = append(distinctStatements, "location IS NOT NULL")
			} else {
@@ -278,8 +315,8 @@ func UpdateAccount(ctx context.Context, logger *zap.Logger, db *sql.DB, userID u
			}
		}

	if langTag != nil {
		if l := langTag.GetValue(); l == "" {
		if update.langTag != nil {
			if l := update.langTag.GetValue(); l == "" {
				updateStatements = append(updateStatements, "lang_tag = NULL")
				distinctStatements = append(distinctStatements, "lang_tag IS NOT NULL")
			} else {
@@ -289,8 +326,8 @@ func UpdateAccount(ctx context.Context, logger *zap.Logger, db *sql.DB, userID u
			}
		}

	if avatarURL != nil {
		if a := avatarURL.GetValue(); a == "" {
		if update.avatarURL != nil {
			if a := update.avatarURL.GetValue(); a == "" {
				updateStatements = append(updateStatements, "avatar_url = NULL")
				distinctStatements = append(distinctStatements, "avatar_url IS NOT NULL")
			} else {
@@ -300,8 +337,8 @@ func UpdateAccount(ctx context.Context, logger *zap.Logger, db *sql.DB, userID u
			}
		}

	if metadata != nil {
		params = append(params, metadata.GetValue())
		if update.metadata != nil {
			params = append(params, update.metadata.GetValue())
			updateStatements = append(updateStatements, "metadata = $"+strconv.Itoa(len(params)))
			distinctStatements = append(distinctStatements, "metadata IS DISTINCT FROM $"+strconv.Itoa(len(params)))
		}
@@ -313,20 +350,21 @@ func UpdateAccount(ctx context.Context, logger *zap.Logger, db *sql.DB, userID u
		query := "UPDATE users SET update_time = now(), " + strings.Join(updateStatements, ", ") +
			" WHERE id = $1 AND (" + strings.Join(distinctStatements, " OR ") + ")"

	if _, err := db.ExecContext(ctx, query, params...); err != nil {
		if _, err := tx.ExecContext(ctx, query, params...); err != nil {
			if e, ok := err.(pgx.PgError); ok && e.Code == dbErrorUniqueViolation && strings.Contains(e.Message, "users_username_key") {
				return errors.New("Username is already in use.")
			}

			logger.Error("Could not update user account.", zap.Error(err),
			zap.String("username", username),
			zap.Any("display_name", displayName),
			zap.Any("timezone", timezone),
			zap.Any("location", location),
			zap.Any("lang_tag", langTag),
			zap.Any("avatar_url", avatarURL))
				zap.String("username", update.username),
				zap.Any("display_name", update.displayName),
				zap.Any("timezone", update.timezone),
				zap.Any("location", update.location),
				zap.Any("lang_tag", update.langTag),
				zap.Any("avatar_url", update.avatarURL))
			return err
		}
	}

	return nil
}
Loading