Commit 33eb0ee3 authored by Andrei Mihu's avatar Andrei Mihu
Browse files

Add runtime batch account get function.

parent b0bcf00a
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -17,6 +17,7 @@ The format is based on [keep a changelog](http://keepachangelog.com) and this pr
- Allow client email authentication requests to optionally authenticate with username/password instead of email/password.
- Allow runtime email authentication calls to authenticate with username/password instead of email/password.
- New authoritative match dispatcher function to defer message broadcasts until the end of the tick.
- New runtime function to retrieve multiple user accounts by user ID.

### Changed
- Replace standard logger supplied to the Go runtime with a more powerful interface.
+1 −0
Original line number Diff line number Diff line
@@ -744,6 +744,7 @@ type NakamaModule interface {
	AuthenticateTokenGenerate(userID, username string, exp int64) (string, int64, error)

	AccountGetId(ctx context.Context, userID string) (*api.Account, error)
	AccountsGetId(ctx context.Context, userIDs []string) ([]*api.Account, error)
	AccountUpdateId(ctx context.Context, userID, username string, metadata map[string]interface{}, displayName, timezone, location, langTag, avatarUrl string) error

	UsersGetId(ctx context.Context, userIDs []string) ([]*api.User, error)
+110 −33
Original line number Diff line number Diff line
@@ -36,7 +36,7 @@ func GetAccount(ctx context.Context, logger *zap.Logger, db *sql.DB, tracker Tra
	var username sql.NullString
	var avatarURL sql.NullString
	var langTag sql.NullString
	var locat sql.NullString
	var location sql.NullString
	var timezone sql.NullString
	var metadata sql.NullString
	var wallet sql.NullString
@@ -46,19 +46,20 @@ func GetAccount(ctx context.Context, logger *zap.Logger, db *sql.DB, tracker Tra
	var gamecenter sql.NullString
	var steam sql.NullString
	var customID sql.NullString
	var edge_count int
	var edgeCount int
	var createTime pq.NullTime
	var updateTime pq.NullTime
	var verifyTime pq.NullTime
	var deviceIDs pq.StringArray

	query := `
SELECT username, display_name, avatar_url, lang_tag, location, timezone, metadata, wallet,
	email, facebook_id, google_id, gamecenter_id, steam_id, custom_id, edge_count,
	create_time, update_time, verify_time
FROM users
WHERE id = $1`
SELECT u.username, u.display_name, u.avatar_url, u.lang_tag, u.location, u.timezone, u.metadata, u.wallet,
	u.email, u.facebook_id, u.google_id, u.gamecenter_id, u.steam_id, u.custom_id, u.edge_count,
	u.create_time, u.update_time, u.verify_time, array(select ud.id from user_device ud where u.id = ud.user_id)
FROM users u
WHERE u.id = $1`

	if err := db.QueryRowContext(ctx, query, userID).Scan(&username, &displayName, &avatarURL, &langTag, &locat, &timezone, &metadata, &wallet, &email, &facebook, &google, &gamecenter, &steam, &customID, &edge_count, &createTime, &updateTime, &verifyTime); err != nil {
	if err := db.QueryRowContext(ctx, query, userID).Scan(&username, &displayName, &avatarURL, &langTag, &location, &timezone, &metadata, &wallet, &email, &facebook, &google, &gamecenter, &steam, &customID, &edgeCount, &createTime, &updateTime, &verifyTime, &deviceIDs); err != nil {
		if err == sql.ErrNoRows {
			return nil, ErrAccountNotFound
		}
@@ -66,65 +67,141 @@ WHERE id = $1`
		return nil, err
	}

	rows, err := db.QueryContext(ctx, "SELECT id FROM user_device WHERE user_id = $1", userID)
	devices := make([]*api.AccountDevice, 0, len(deviceIDs))
	for _, deviceID := range deviceIDs {
		devices = append(devices, &api.AccountDevice{Id: deviceID})
	}

	var verifyTimestamp *timestamp.Timestamp = nil
	if verifyTime.Valid && verifyTime.Time.Unix() != 0 {
		verifyTimestamp = &timestamp.Timestamp{Seconds: verifyTime.Time.Unix()}
	}

	online := false
	if tracker != nil {
		online = tracker.StreamExists(PresenceStream{Mode: StreamModeNotifications, Subject: userID})
	}

	return &api.Account{
		User: &api.User{
			Id:           userID.String(),
			Username:     username.String,
			DisplayName:  displayName.String,
			AvatarUrl:    avatarURL.String,
			LangTag:      langTag.String,
			Location:     location.String,
			Timezone:     timezone.String,
			Metadata:     metadata.String,
			FacebookId:   facebook.String,
			GoogleId:     google.String,
			GamecenterId: gamecenter.String,
			SteamId:      steam.String,
			EdgeCount:    int32(edgeCount),
			CreateTime:   &timestamp.Timestamp{Seconds: createTime.Time.Unix()},
			UpdateTime:   &timestamp.Timestamp{Seconds: updateTime.Time.Unix()},
			Online:       online,
		},
		Wallet:     wallet.String,
		Email:      email.String,
		Devices:    devices,
		CustomId:   customID.String,
		VerifyTime: verifyTimestamp,
	}, nil
}

func GetAccounts(ctx context.Context, logger *zap.Logger, db *sql.DB, tracker Tracker, userIDs []string) ([]*api.Account, error) {
	statements := make([]string, 0, len(userIDs))
	parameters := make([]interface{}, 0, len(userIDs))
	for _, userID := range userIDs {
		parameters = append(parameters, userID)
		statements = append(statements, "$"+strconv.Itoa(len(parameters)))
	}

	query := `
SELECT u.id, u.username, u.display_name, u.avatar_url, u.lang_tag, u.location, u.timezone, u.metadata, u.wallet,
	u.email, u.facebook_id, u.google_id, u.gamecenter_id, u.steam_id, u.custom_id, u.edge_count,
	u.create_time, u.update_time, u.verify_time, array(select ud.id from user_device ud where u.id = ud.user_id)
FROM users u
WHERE u.id IN (` + strings.Join(statements, ",") + `)`
	rows, err := db.QueryContext(ctx, query, parameters...)
	if err != nil {
		logger.Error("Error retrieving user account.", zap.Error(err))
		logger.Error("Error retrieving user accounts.", zap.Error(err))
		return nil, err
	}
	defer rows.Close()

	deviceIDs := make([]*api.AccountDevice, 0)
	accounts := make([]*api.Account, 0, len(userIDs))
	for rows.Next() {
		var deviceID sql.NullString
		err = rows.Scan(&deviceID)
		var userID string
		var username sql.NullString
		var displayName sql.NullString
		var avatarURL sql.NullString
		var langTag sql.NullString
		var location sql.NullString
		var timezone sql.NullString
		var metadata sql.NullString
		var wallet sql.NullString
		var email sql.NullString
		var facebook sql.NullString
		var google sql.NullString
		var gamecenter sql.NullString
		var steam sql.NullString
		var customID sql.NullString
		var edgeCount int
		var createTime pq.NullTime
		var updateTime pq.NullTime
		var verifyTime pq.NullTime
		var deviceIDs pq.StringArray

		err = rows.Scan(&userID, &username, &displayName, &avatarURL, &langTag, &location, &timezone, &metadata, &wallet, &email, &facebook, &google, &gamecenter, &steam, &customID, &edgeCount, &createTime, &updateTime, &verifyTime, &deviceIDs)
		if err != nil {
			logger.Error("Error retrieving user account.", zap.Error(err))
			logger.Error("Error retrieving user accounts.", zap.Error(err))
			return nil, err
		}
		if deviceID.Valid {
			deviceIDs = append(deviceIDs, &api.AccountDevice{Id: deviceID.String})
		}
	}
	if err = rows.Err(); err != nil {
		logger.Error("Error retrieving user account.", zap.Error(err))
		return nil, err

		devices := make([]*api.AccountDevice, 0, len(deviceIDs))
		for _, deviceID := range deviceIDs {
			devices = append(devices, &api.AccountDevice{Id: deviceID})
		}

	var verifyTimestamp *timestamp.Timestamp = nil
		var verifyTimestamp *timestamp.Timestamp
		if verifyTime.Valid && verifyTime.Time.Unix() != 0 {
			verifyTimestamp = &timestamp.Timestamp{Seconds: verifyTime.Time.Unix()}
		}

		online := false
		if tracker != nil {
		online = tracker.StreamExists(PresenceStream{Mode: StreamModeNotifications, Subject: userID})
			online = tracker.StreamExists(PresenceStream{Mode: StreamModeNotifications, Subject: uuid.FromStringOrNil(userID)})
		}

	return &api.Account{
		accounts = append(accounts, &api.Account{
			User: &api.User{
			Id:           userID.String(),
				Id:           userID,
				Username:     username.String,
				DisplayName:  displayName.String,
				AvatarUrl:    avatarURL.String,
				LangTag:      langTag.String,
			Location:     locat.String,
				Location:     location.String,
				Timezone:     timezone.String,
				Metadata:     metadata.String,
				FacebookId:   facebook.String,
				GoogleId:     google.String,
				GamecenterId: gamecenter.String,
				SteamId:      steam.String,
			EdgeCount:    int32(edge_count),
				EdgeCount:    int32(edgeCount),
				CreateTime:   &timestamp.Timestamp{Seconds: createTime.Time.Unix()},
				UpdateTime:   &timestamp.Timestamp{Seconds: updateTime.Time.Unix()},
				Online:       online,
			},
			Wallet:     wallet.String,
			Email:      email.String,
		Devices:    deviceIDs,
			Devices:    devices,
			CustomId:   customID.String,
			VerifyTime: verifyTimestamp,
	}, nil
		})
	}

	return accounts, nil
}

func UpdateAccount(ctx context.Context, logger *zap.Logger, db *sql.DB, userID uuid.UUID, username string, displayName, timezone, location, langTag, avatarURL, metadata *wrappers.StringValue) error {
+14 −0
Original line number Diff line number Diff line
@@ -273,6 +273,20 @@ func (n *RuntimeGoNakamaModule) AccountGetId(ctx context.Context, userID string)
	return GetAccount(ctx, n.logger, n.db, n.tracker, u)
}

func (n *RuntimeGoNakamaModule) AccountsGetId(ctx context.Context, userIDs []string) ([]*api.Account, error) {
	if len(userIDs) == 0 {
		return make([]*api.Account, 0), nil
	}

	for _, id := range userIDs {
		if _, err := uuid.FromString(id); err != nil {
			return nil, errors.New("each user id must be a valid id string")
		}
	}

	return GetAccounts(ctx, n.logger, n.db, n.tracker, userIDs)
}

func (n *RuntimeGoNakamaModule) AccountUpdateId(ctx context.Context, userID, username string, metadata map[string]interface{}, displayName, timezone, location, langTag, avatarUrl string) error {
	u, err := uuid.FromString(userID)
	if err != nil {
+113 −0
Original line number Diff line number Diff line
@@ -151,6 +151,7 @@ func (n *RuntimeLuaNakamaModule) Loader(l *lua.LState) int {
		"logger_warn":                 n.loggerWarn,
		"logger_error":                n.loggerError,
		"account_get_id":              n.accountGetId,
		"accounts_get_id":             n.accountsGetId,
		"account_update_id":           n.accountUpdateId,
		"users_get_id":                n.usersGetId,
		"users_get_username":          n.usersGetUsername,
@@ -1495,6 +1496,118 @@ func (n *RuntimeLuaNakamaModule) accountGetId(l *lua.LState) int {
	return 1
}

func (n *RuntimeLuaNakamaModule) accountsGetId(l *lua.LState) int {
	// Input table validation.
	input := l.OptTable(1, nil)
	if input == nil {
		l.ArgError(1, "invalid user id list")
		return 0
	}
	if input.Len() == 0 {
		l.Push(l.CreateTable(0, 0))
		return 1
	}

	userIDs := make([]string, 0, input.Len())
	var conversionError bool
	input.ForEach(func(k lua.LValue, v lua.LValue) {
		if conversionError {
			return
		}
		if v.Type() != lua.LTString {
			l.ArgError(1, "user id must be a string")
			conversionError = true
			return
		}
		vs := v.String()
		if _, err := uuid.FromString(vs); err != nil {
			l.ArgError(1, "user id must be a valid identifier string")
			conversionError = true
			return
		}
		userIDs = append(userIDs, vs)
	})
	if conversionError {
		return 0
	}

	accounts, err := GetAccounts(l.Context(), n.logger, n.db, n.tracker, userIDs)
	if err != nil {
		l.RaiseError("failed to get accounts: %s", err.Error())
		return 0
	}

	accountsTable := l.CreateTable(len(accounts), 0)
	for i, account := range accounts {
		accountTable := l.CreateTable(0, 21)
		accountTable.RawSetString("user_id", lua.LString(account.User.Id))
		accountTable.RawSetString("username", lua.LString(account.User.Username))
		accountTable.RawSetString("display_name", lua.LString(account.User.DisplayName))
		accountTable.RawSetString("avatar_url", lua.LString(account.User.AvatarUrl))
		accountTable.RawSetString("lang_tag", lua.LString(account.User.LangTag))
		accountTable.RawSetString("location", lua.LString(account.User.Location))
		accountTable.RawSetString("timezone", lua.LString(account.User.Timezone))
		if account.User.FacebookId != "" {
			accountTable.RawSetString("facebook_id", lua.LString(account.User.FacebookId))
		}
		if account.User.GoogleId != "" {
			accountTable.RawSetString("google_id", lua.LString(account.User.GoogleId))
		}
		if account.User.GamecenterId != "" {
			accountTable.RawSetString("gamecenter_id", lua.LString(account.User.GamecenterId))
		}
		if account.User.SteamId != "" {
			accountTable.RawSetString("steam_id", lua.LString(account.User.SteamId))
		}
		accountTable.RawSetString("online", lua.LBool(account.User.Online))
		accountTable.RawSetString("edge_count", lua.LNumber(account.User.EdgeCount))
		accountTable.RawSetString("create_time", lua.LNumber(account.User.CreateTime.Seconds))
		accountTable.RawSetString("update_time", lua.LNumber(account.User.UpdateTime.Seconds))

		metadataMap := make(map[string]interface{})
		err = json.Unmarshal([]byte(account.User.Metadata), &metadataMap)
		if err != nil {
			l.RaiseError(fmt.Sprintf("failed to convert metadata to json: %s", err.Error()))
			return 0
		}
		metadataTable := RuntimeLuaConvertMap(l, metadataMap)
		accountTable.RawSetString("metadata", metadataTable)

		walletMap := make(map[string]interface{})
		err = json.Unmarshal([]byte(account.Wallet), &walletMap)
		if err != nil {
			l.RaiseError(fmt.Sprintf("failed to convert wallet to json: %s", err.Error()))
			return 0
		}
		walletTable := RuntimeLuaConvertMap(l, walletMap)
		accountTable.RawSetString("wallet", walletTable)

		if account.Email != "" {
			accountTable.RawSetString("email", lua.LString(account.Email))
		}
		if len(account.Devices) != 0 {
			devicesTable := l.CreateTable(len(account.Devices), 0)
			for i, device := range account.Devices {
				deviceTable := l.CreateTable(0, 1)
				deviceTable.RawSetString("id", lua.LString(device.Id))
				devicesTable.RawSetInt(i+1, deviceTable)
			}
			accountTable.RawSetString("devices", devicesTable)
		}
		if account.CustomId != "" {
			accountTable.RawSetString("custom_id", lua.LString(account.CustomId))
		}
		if account.VerifyTime != nil {
			accountTable.RawSetString("verify_time", lua.LNumber(account.VerifyTime.Seconds))
		}

		accountsTable.RawSetInt(i+1, accountTable)
	}

	l.Push(accountsTable)
	return 1
}

func (n *RuntimeLuaNakamaModule) usersGetId(l *lua.LState) int {
	// Input table validation.
	input := l.OptTable(1, nil)