From 33eb0ee368269813d949860797231823994c0805 Mon Sep 17 00:00:00 2001 From: Andrei Mihu Date: Thu, 27 Dec 2018 23:55:32 +0000 Subject: [PATCH] Add runtime batch account get function. --- CHANGELOG.md | 1 + runtime/runtime.go | 1 + server/core_account.go | 143 +++++++++++++++++++++++++++-------- server/runtime_go_nakama.go | 14 ++++ server/runtime_lua_nakama.go | 113 +++++++++++++++++++++++++++ 5 files changed, 239 insertions(+), 33 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d360a4a3..0e1424fc7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [keep a changelog](http://keepachangelog.com) and this pr - Allow client email authentication requests to optionally authenticate with username/password instead of email/password. - Allow runtime email authentication calls to authenticate with username/password instead of email/password. - New authoritative match dispatcher function to defer message broadcasts until the end of the tick. +- New runtime function to retrieve multiple user accounts by user ID. ### Changed - Replace standard logger supplied to the Go runtime with a more powerful interface. diff --git a/runtime/runtime.go b/runtime/runtime.go index ea27cf536..2af3fcfee 100644 --- a/runtime/runtime.go +++ b/runtime/runtime.go @@ -744,6 +744,7 @@ type NakamaModule interface { AuthenticateTokenGenerate(userID, username string, exp int64) (string, int64, error) AccountGetId(ctx context.Context, userID string) (*api.Account, error) + AccountsGetId(ctx context.Context, userIDs []string) ([]*api.Account, error) AccountUpdateId(ctx context.Context, userID, username string, metadata map[string]interface{}, displayName, timezone, location, langTag, avatarUrl string) error UsersGetId(ctx context.Context, userIDs []string) ([]*api.User, error) diff --git a/server/core_account.go b/server/core_account.go index 046245369..46be2d66f 100644 --- a/server/core_account.go +++ b/server/core_account.go @@ -36,7 +36,7 @@ func GetAccount(ctx context.Context, logger *zap.Logger, db *sql.DB, tracker Tra var username sql.NullString var avatarURL sql.NullString var langTag sql.NullString - var locat sql.NullString + var location sql.NullString var timezone sql.NullString var metadata sql.NullString var wallet sql.NullString @@ -46,19 +46,20 @@ func GetAccount(ctx context.Context, logger *zap.Logger, db *sql.DB, tracker Tra var gamecenter sql.NullString var steam sql.NullString var customID sql.NullString - var edge_count int + var edgeCount int var createTime pq.NullTime var updateTime pq.NullTime var verifyTime pq.NullTime + var deviceIDs pq.StringArray query := ` -SELECT username, display_name, avatar_url, lang_tag, location, timezone, metadata, wallet, - email, facebook_id, google_id, gamecenter_id, steam_id, custom_id, edge_count, - create_time, update_time, verify_time -FROM users -WHERE id = $1` +SELECT u.username, u.display_name, u.avatar_url, u.lang_tag, u.location, u.timezone, u.metadata, u.wallet, + u.email, u.facebook_id, u.google_id, u.gamecenter_id, u.steam_id, u.custom_id, u.edge_count, + u.create_time, u.update_time, u.verify_time, array(select ud.id from user_device ud where u.id = ud.user_id) +FROM users u +WHERE u.id = $1` - if err := db.QueryRowContext(ctx, query, userID).Scan(&username, &displayName, &avatarURL, &langTag, &locat, &timezone, &metadata, &wallet, &email, &facebook, &google, &gamecenter, &steam, &customID, &edge_count, &createTime, &updateTime, &verifyTime); err != nil { + if err := db.QueryRowContext(ctx, query, userID).Scan(&username, &displayName, &avatarURL, &langTag, &location, &timezone, &metadata, &wallet, &email, &facebook, &google, &gamecenter, &steam, &customID, &edgeCount, &createTime, &updateTime, &verifyTime, &deviceIDs); err != nil { if err == sql.ErrNoRows { return nil, ErrAccountNotFound } @@ -66,28 +67,9 @@ WHERE id = $1` return nil, err } - rows, err := db.QueryContext(ctx, "SELECT id FROM user_device WHERE user_id = $1", userID) - if err != nil { - logger.Error("Error retrieving user account.", zap.Error(err)) - return nil, err - } - defer rows.Close() - - deviceIDs := make([]*api.AccountDevice, 0) - for rows.Next() { - var deviceID sql.NullString - err = rows.Scan(&deviceID) - if err != nil { - logger.Error("Error retrieving user account.", zap.Error(err)) - return nil, err - } - if deviceID.Valid { - deviceIDs = append(deviceIDs, &api.AccountDevice{Id: deviceID.String}) - } - } - if err = rows.Err(); err != nil { - logger.Error("Error retrieving user account.", zap.Error(err)) - return nil, err + devices := make([]*api.AccountDevice, 0, len(deviceIDs)) + for _, deviceID := range deviceIDs { + devices = append(devices, &api.AccountDevice{Id: deviceID}) } var verifyTimestamp *timestamp.Timestamp = nil @@ -107,26 +89,121 @@ WHERE id = $1` DisplayName: displayName.String, AvatarUrl: avatarURL.String, LangTag: langTag.String, - Location: locat.String, + Location: location.String, Timezone: timezone.String, Metadata: metadata.String, FacebookId: facebook.String, GoogleId: google.String, GamecenterId: gamecenter.String, SteamId: steam.String, - EdgeCount: int32(edge_count), + EdgeCount: int32(edgeCount), CreateTime: ×tamp.Timestamp{Seconds: createTime.Time.Unix()}, UpdateTime: ×tamp.Timestamp{Seconds: updateTime.Time.Unix()}, Online: online, }, Wallet: wallet.String, Email: email.String, - Devices: deviceIDs, + Devices: devices, CustomId: customID.String, VerifyTime: verifyTimestamp, }, nil } +func GetAccounts(ctx context.Context, logger *zap.Logger, db *sql.DB, tracker Tracker, userIDs []string) ([]*api.Account, error) { + statements := make([]string, 0, len(userIDs)) + parameters := make([]interface{}, 0, len(userIDs)) + for _, userID := range userIDs { + parameters = append(parameters, userID) + statements = append(statements, "$"+strconv.Itoa(len(parameters))) + } + + query := ` +SELECT u.id, u.username, u.display_name, u.avatar_url, u.lang_tag, u.location, u.timezone, u.metadata, u.wallet, + u.email, u.facebook_id, u.google_id, u.gamecenter_id, u.steam_id, u.custom_id, u.edge_count, + u.create_time, u.update_time, u.verify_time, array(select ud.id from user_device ud where u.id = ud.user_id) +FROM users u +WHERE u.id IN (` + strings.Join(statements, ",") + `)` + rows, err := db.QueryContext(ctx, query, parameters...) + if err != nil { + logger.Error("Error retrieving user accounts.", zap.Error(err)) + return nil, err + } + defer rows.Close() + + accounts := make([]*api.Account, 0, len(userIDs)) + for rows.Next() { + var userID string + var username sql.NullString + var displayName sql.NullString + var avatarURL sql.NullString + var langTag sql.NullString + var location sql.NullString + var timezone sql.NullString + var metadata sql.NullString + var wallet sql.NullString + var email sql.NullString + var facebook sql.NullString + var google sql.NullString + var gamecenter sql.NullString + var steam sql.NullString + var customID sql.NullString + var edgeCount int + var createTime pq.NullTime + var updateTime pq.NullTime + var verifyTime pq.NullTime + var deviceIDs pq.StringArray + + err = rows.Scan(&userID, &username, &displayName, &avatarURL, &langTag, &location, &timezone, &metadata, &wallet, &email, &facebook, &google, &gamecenter, &steam, &customID, &edgeCount, &createTime, &updateTime, &verifyTime, &deviceIDs) + if err != nil { + logger.Error("Error retrieving user accounts.", zap.Error(err)) + return nil, err + } + + devices := make([]*api.AccountDevice, 0, len(deviceIDs)) + for _, deviceID := range deviceIDs { + devices = append(devices, &api.AccountDevice{Id: deviceID}) + } + + var verifyTimestamp *timestamp.Timestamp + if verifyTime.Valid && verifyTime.Time.Unix() != 0 { + verifyTimestamp = ×tamp.Timestamp{Seconds: verifyTime.Time.Unix()} + } + + online := false + if tracker != nil { + online = tracker.StreamExists(PresenceStream{Mode: StreamModeNotifications, Subject: uuid.FromStringOrNil(userID)}) + } + + accounts = append(accounts, &api.Account{ + User: &api.User{ + Id: userID, + Username: username.String, + DisplayName: displayName.String, + AvatarUrl: avatarURL.String, + LangTag: langTag.String, + Location: location.String, + Timezone: timezone.String, + Metadata: metadata.String, + FacebookId: facebook.String, + GoogleId: google.String, + GamecenterId: gamecenter.String, + SteamId: steam.String, + EdgeCount: int32(edgeCount), + CreateTime: ×tamp.Timestamp{Seconds: createTime.Time.Unix()}, + UpdateTime: ×tamp.Timestamp{Seconds: updateTime.Time.Unix()}, + Online: online, + }, + Wallet: wallet.String, + Email: email.String, + Devices: devices, + CustomId: customID.String, + VerifyTime: verifyTimestamp, + }) + } + + return accounts, nil +} + func UpdateAccount(ctx context.Context, logger *zap.Logger, db *sql.DB, userID uuid.UUID, username string, displayName, timezone, location, langTag, avatarURL, metadata *wrappers.StringValue) error { index := 1 statements := make([]string, 0) diff --git a/server/runtime_go_nakama.go b/server/runtime_go_nakama.go index 50fcb9562..2b826e1b9 100644 --- a/server/runtime_go_nakama.go +++ b/server/runtime_go_nakama.go @@ -273,6 +273,20 @@ func (n *RuntimeGoNakamaModule) AccountGetId(ctx context.Context, userID string) return GetAccount(ctx, n.logger, n.db, n.tracker, u) } +func (n *RuntimeGoNakamaModule) AccountsGetId(ctx context.Context, userIDs []string) ([]*api.Account, error) { + if len(userIDs) == 0 { + return make([]*api.Account, 0), nil + } + + for _, id := range userIDs { + if _, err := uuid.FromString(id); err != nil { + return nil, errors.New("each user id must be a valid id string") + } + } + + return GetAccounts(ctx, n.logger, n.db, n.tracker, userIDs) +} + func (n *RuntimeGoNakamaModule) AccountUpdateId(ctx context.Context, userID, username string, metadata map[string]interface{}, displayName, timezone, location, langTag, avatarUrl string) error { u, err := uuid.FromString(userID) if err != nil { diff --git a/server/runtime_lua_nakama.go b/server/runtime_lua_nakama.go index 07716ad86..710abc052 100644 --- a/server/runtime_lua_nakama.go +++ b/server/runtime_lua_nakama.go @@ -151,6 +151,7 @@ func (n *RuntimeLuaNakamaModule) Loader(l *lua.LState) int { "logger_warn": n.loggerWarn, "logger_error": n.loggerError, "account_get_id": n.accountGetId, + "accounts_get_id": n.accountsGetId, "account_update_id": n.accountUpdateId, "users_get_id": n.usersGetId, "users_get_username": n.usersGetUsername, @@ -1495,6 +1496,118 @@ func (n *RuntimeLuaNakamaModule) accountGetId(l *lua.LState) int { return 1 } +func (n *RuntimeLuaNakamaModule) accountsGetId(l *lua.LState) int { + // Input table validation. + input := l.OptTable(1, nil) + if input == nil { + l.ArgError(1, "invalid user id list") + return 0 + } + if input.Len() == 0 { + l.Push(l.CreateTable(0, 0)) + return 1 + } + + userIDs := make([]string, 0, input.Len()) + var conversionError bool + input.ForEach(func(k lua.LValue, v lua.LValue) { + if conversionError { + return + } + if v.Type() != lua.LTString { + l.ArgError(1, "user id must be a string") + conversionError = true + return + } + vs := v.String() + if _, err := uuid.FromString(vs); err != nil { + l.ArgError(1, "user id must be a valid identifier string") + conversionError = true + return + } + userIDs = append(userIDs, vs) + }) + if conversionError { + return 0 + } + + accounts, err := GetAccounts(l.Context(), n.logger, n.db, n.tracker, userIDs) + if err != nil { + l.RaiseError("failed to get accounts: %s", err.Error()) + return 0 + } + + accountsTable := l.CreateTable(len(accounts), 0) + for i, account := range accounts { + accountTable := l.CreateTable(0, 21) + accountTable.RawSetString("user_id", lua.LString(account.User.Id)) + accountTable.RawSetString("username", lua.LString(account.User.Username)) + accountTable.RawSetString("display_name", lua.LString(account.User.DisplayName)) + accountTable.RawSetString("avatar_url", lua.LString(account.User.AvatarUrl)) + accountTable.RawSetString("lang_tag", lua.LString(account.User.LangTag)) + accountTable.RawSetString("location", lua.LString(account.User.Location)) + accountTable.RawSetString("timezone", lua.LString(account.User.Timezone)) + if account.User.FacebookId != "" { + accountTable.RawSetString("facebook_id", lua.LString(account.User.FacebookId)) + } + if account.User.GoogleId != "" { + accountTable.RawSetString("google_id", lua.LString(account.User.GoogleId)) + } + if account.User.GamecenterId != "" { + accountTable.RawSetString("gamecenter_id", lua.LString(account.User.GamecenterId)) + } + if account.User.SteamId != "" { + accountTable.RawSetString("steam_id", lua.LString(account.User.SteamId)) + } + accountTable.RawSetString("online", lua.LBool(account.User.Online)) + accountTable.RawSetString("edge_count", lua.LNumber(account.User.EdgeCount)) + accountTable.RawSetString("create_time", lua.LNumber(account.User.CreateTime.Seconds)) + accountTable.RawSetString("update_time", lua.LNumber(account.User.UpdateTime.Seconds)) + + metadataMap := make(map[string]interface{}) + err = json.Unmarshal([]byte(account.User.Metadata), &metadataMap) + if err != nil { + l.RaiseError(fmt.Sprintf("failed to convert metadata to json: %s", err.Error())) + return 0 + } + metadataTable := RuntimeLuaConvertMap(l, metadataMap) + accountTable.RawSetString("metadata", metadataTable) + + walletMap := make(map[string]interface{}) + err = json.Unmarshal([]byte(account.Wallet), &walletMap) + if err != nil { + l.RaiseError(fmt.Sprintf("failed to convert wallet to json: %s", err.Error())) + return 0 + } + walletTable := RuntimeLuaConvertMap(l, walletMap) + accountTable.RawSetString("wallet", walletTable) + + if account.Email != "" { + accountTable.RawSetString("email", lua.LString(account.Email)) + } + if len(account.Devices) != 0 { + devicesTable := l.CreateTable(len(account.Devices), 0) + for i, device := range account.Devices { + deviceTable := l.CreateTable(0, 1) + deviceTable.RawSetString("id", lua.LString(device.Id)) + devicesTable.RawSetInt(i+1, deviceTable) + } + accountTable.RawSetString("devices", devicesTable) + } + if account.CustomId != "" { + accountTable.RawSetString("custom_id", lua.LString(account.CustomId)) + } + if account.VerifyTime != nil { + accountTable.RawSetString("verify_time", lua.LNumber(account.VerifyTime.Seconds)) + } + + accountsTable.RawSetInt(i+1, accountTable) + } + + l.Push(accountsTable) + return 1 +} + func (n *RuntimeLuaNakamaModule) usersGetId(l *lua.LState) int { // Input table validation. input := l.OptTable(1, nil) -- GitLab