diff --git a/data/modules/clientrpc.lua b/data/modules/clientrpc.lua index e1c6c37f5f1749f90050edef989334bda60ab7ed..5b4ab28a8a6933f771c37e88f6da6a782bd9abf6 100644 --- a/data/modules/clientrpc.lua +++ b/data/modules/clientrpc.lua @@ -42,12 +42,12 @@ local function send_notification(context, payload) local decoded = nk.json_decode(payload) local new_notifications = { { - Code = 1, - Content = { reward_coins = 1000 }, - Persistent = true, - SenderId = context.UserId, - Subject = "You've unlocked level 100!", - UserId = decoded.user_id + code = 1, + content = { reward_coins = 1000 }, + persistent = true, + sender_id = context.user_id, + subject = "You've unlocked level 100!", + user_id = decoded.user_id } } nk.notifications_send(new_notifications) @@ -56,10 +56,10 @@ nk.register_rpc(send_notification, "clientrpc.send_notification") local function send_stream_data(context, payload) local stream = { - Mode = 20, - Label = "Stream Data Test", + mode = 20, + label = "Stream Data Test", } - nk.stream_user_join(context.UserId, context.SessionId, stream, false, false) + nk.stream_user_join(context.user_id, context.session_id, stream, false, false) nk.stream_send(stream, tostring(payload)) end nk.register_rpc(send_stream_data, "clientrpc.send_stream_data") diff --git a/data/modules/match.lua b/data/modules/match.lua index b02832834a16ad77a4352eaed8107eb19df3fb36..9e0ea7851882f6c031873a1cf4ba26c2a717eac6 100644 --- a/data/modules/match.lua +++ b/data/modules/match.lua @@ -21,10 +21,10 @@ Called when a match is created as a result of nk.match_create(). Context represents information about the match and server, for information purposes. Format: { - Env = {}, -- key-value data set in the runtime.env server configuration. - ExecutionMode = "Match", - MatchId = "client-friendly match ID, can be shared with clients and used in match join operations", - MatchNode = "name of the Nakama node hosting this match" + env = {}, -- key-value data set in the runtime.env server configuration. + execution_mode = "Match", + match_id = "client-friendly match ID, can be shared with clients and used in match join operations", + match_node = "name of the Nakama node hosting this match" } Params is the optional arbitrary second argument passed to `nk.match_create()`, or `nil` if none was used. @@ -52,12 +52,12 @@ Called when a user attempts to join the match using the client's match join oper Context represents information about the match and server, for information purposes. Format: { - Env = {}, -- key-value data set in the runtime.env server configuration. - ExecutionMode = "Match", - MatchId = "client-friendly match ID, can be shared with clients and used in match join operations", - MatchNode = "name of the Nakama node hosting this match", - MatchLabel = "the label string returned from match_init", - MatchTickrate = 1 -- the tick rate returned by match_init + env = {}, -- key-value data set in the runtime.env server configuration. + execution_mode = "Match", + match_id = "client-friendly match ID, can be shared with clients and used in match join operations", + match_node = "name of the Nakama node hosting this match", + match_label = "the label string returned from match_init", + match_tick_rate = 1 -- the tick rate returned by match_init } Dispatcher exposes useful functions to the match. Format: @@ -78,10 +78,10 @@ State is the current in-memory match state, may be any Lua term except nil. Presence is the user attempting to join the match. Format: { - UserId: "user unique ID", - SessionId: "session ID of the user's current connection", - Username: "user's unique username", - Node: "name of the Nakama node the user is connected to" + user_id: "user unique ID", + session_id: "session ID of the user's current connection", + username: "user's unique username", + node: "name of the Nakama node the user is connected to" } Expected return these values (all required) in order: @@ -100,12 +100,12 @@ Called when one or more users have left the match for any reason, including conn Context represents information about the match and server, for information purposes. Format: { - Env = {}, -- key-value data set in the runtime.env server configuration. - ExecutionMode = "Match", - MatchId = "client-friendly match ID, can be shared with clients and used in match join operations", - MatchNode = "name of the Nakama node hosting this match", - MatchLabel = "the label string returned from match_init", - MatchTickrate = 1 -- the tick rate returned by match_init + env = {}, -- key-value data set in the runtime.env server configuration. + execution_mode = "Match", + match_id = "client-friendly match ID, can be shared with clients and used in match join operations", + match_node = "name of the Nakama node hosting this match", + match_label = "the label string returned from match_init", + match_tick_rate = 1 -- the tick rate returned by match_init } Dispatcher exposes useful functions to the match. Format: @@ -127,10 +127,10 @@ State is the current in-memory match state, may be any Lua term except nil. Presences is a list of users that have left the match. Format: { { - UserId: "user unique ID", - SessionId: "session ID of the user's current connection", - Username: "user's unique username", - Node: "name of the Nakama node the user is connected to" + user_id: "user unique ID", + session_id: "session ID of the user's current connection", + username: "user's unique username", + node: "name of the Nakama node the user is connected to" }, ... } @@ -150,12 +150,12 @@ Called on an interval based on the tick rate returned by match_init. Context represents information about the match and server, for information purposes. Format: { - Env = {}, -- key-value data set in the runtime.env server configuration. - ExecutionMode = "Match", - MatchId = "client-friendly match ID, can be shared with clients and used in match join operations", - MatchNode = "name of the Nakama node hosting this match", - MatchLabel = "the label string returned from match_init", - MatchTickrate = 1 -- the tick rate returned by match_init + env = {}, -- key-value data set in the runtime.env server configuration. + executionMode = "Match", + match_id = "client-friendly match ID, can be shared with clients and used in match join operations", + match_node = "name of the Nakama node hosting this match", + match_label = "the label string returned from match_init", + match_tick_rate = 1 -- the tick rate returned by match_init } Dispatcher exposes useful functions to the match. Format: @@ -177,14 +177,14 @@ State is the current in-memory match state, may be any Lua term except nil. Messages is a list of data messages received from users between the previous and current ticks. Format: { { - Sender = { - UserId: "user unique ID", - SessionId: "session ID of the user's current connection", - Username: "user's unique username", - Node: "name of the Nakama node the user is connected to" + sender = { + user_id: "user unique ID", + session_id: "session ID of the user's current connection", + username: "user's unique username", + node: "name of the Nakama node the user is connected to" }, - OpCode = 1, -- numeric op code set by the sender. - Data = "any string data set by the sender" -- may be nil. + op_code = 1, -- numeric op code set by the sender. + data = "any string data set by the sender" -- may be nil. }, ... } @@ -194,8 +194,8 @@ Expected return these values (all required) in order: --]] local function match_loop(context, dispatcher, tick, state, messages) if state.debug then - print("match " .. context.MatchId .. " tick " .. tick) - print("match " .. context.MatchId .. " messages:\n" .. du.print_r(messages)) + print("match " .. context.match_id .. " tick " .. tick) + print("match " .. context.match_id .. " messages:\n" .. du.print_r(messages)) end if tick < 180 then return state diff --git a/data/modules/match_init.lua b/data/modules/match_init.lua index 400486798a565ba40fee74049fd2681e3c12609a..08298c5f444ad543c1df5f8200ed9d37cf976b16 100644 --- a/data/modules/match_init.lua +++ b/data/modules/match_init.lua @@ -14,4 +14,4 @@ limitations under the License. --]] ---require("nakama").match_create("match", {debug = true}) +require("nakama").match_create("match", {debug = true}) diff --git a/server/api_authenticate.go b/server/api_authenticate.go index acb55d75d286f7e8c850933874dde1c8a1f01881..8c0a83dc4fe4f179549e966f1de3a2e8f42cab11 100644 --- a/server/api_authenticate.go +++ b/server/api_authenticate.go @@ -154,7 +154,7 @@ func (s *ApiServer) AuthenticateFacebook(ctx context.Context, in *api.Authentica // Import friends if requested. if in.Import == nil || in.Import.Value { - importFacebookFriends(s.logger, s.db, s.socialClient, uuid.FromStringOrNil(dbUserID), dbUsername, in.Account.Token) + importFacebookFriends(s.logger, s.db, s.socialClient, uuid.FromStringOrNil(dbUserID), dbUsername, in.Account.Token, false) } token := generateToken(s.config, dbUserID, dbUsername) diff --git a/server/api_friend.go b/server/api_friend.go index 5b534eddcdb50b325bcd524527267e5bda1b8d8d..7c64c737edb137259160b1065c541a3bcc14c7b5 100644 --- a/server/api_friend.go +++ b/server/api_friend.go @@ -166,5 +166,15 @@ func (s *ApiServer) BlockFriends(ctx context.Context, in *api.BlockFriendsReques } func (s *ApiServer) ImportFacebookFriends(ctx context.Context, in *api.ImportFacebookFriendsRequest) (*empty.Empty, error) { - return nil, nil + if in.Account == nil || in.Account.Token == "" { + return nil, status.Error(codes.InvalidArgument, "Facebook token is required.") + } + + err := importFacebookFriends(s.logger, s.db, s.socialClient, ctx.Value(ctxUserIDKey{}).(uuid.UUID), ctx.Value(ctxUsernameKey{}).(string), in.Account.Token, in.Reset_ != nil && in.Reset_.Value) + if err != nil { + // Already logged inside the core importFacebookFriends function. + return nil, err + } + + return &empty.Empty{}, nil } diff --git a/server/api_link.go b/server/api_link.go index a3dd947270b7affc99135f45634f2ea51f78cd08..ecff03b4c1057a7ac68deb7291e07dbeb82a5b3f 100644 --- a/server/api_link.go +++ b/server/api_link.go @@ -16,8 +16,8 @@ package server import ( "database/sql" - "strings" "strconv" + "strings" "time" "github.com/golang/protobuf/ptypes/empty" @@ -187,7 +187,7 @@ AND (NOT EXISTS // Import friends if requested. if in.Import == nil || in.Import.Value { - importFacebookFriends(s.logger, s.db, s.socialClient, userID.(uuid.UUID), ctx.Value(ctxUsernameKey{}).(string), in.Account.Token) + importFacebookFriends(s.logger, s.db, s.socialClient, userID.(uuid.UUID), ctx.Value(ctxUsernameKey{}).(string), in.Account.Token, false) } return &empty.Empty{}, nil diff --git a/server/core_authenticate.go b/server/core_authenticate.go index 813819b48af36b95ce3bff111286a83c83ef8bf1..cafec7e09ed0c71ed0217d88b6e8301d67b00c28 100644 --- a/server/core_authenticate.go +++ b/server/core_authenticate.go @@ -22,6 +22,7 @@ import ( "encoding/json" "strconv" + "errors" "github.com/golang/protobuf/ptypes/timestamp" "github.com/heroiclabs/nakama/api" "github.com/heroiclabs/nakama/social" @@ -95,16 +96,53 @@ func AuthenticateCustom(logger *zap.Logger, db *sql.DB, customID, username strin } func AuthenticateDevice(logger *zap.Logger, db *sql.DB, deviceID, username string, create bool) (string, string, error) { + found := true + + // Look for an existing account. + query := "SELECT user_id FROM user_device WHERE id = $1" + var dbUserID string + err := db.QueryRow(query, deviceID).Scan(&dbUserID) + if err != nil { + if err == sql.ErrNoRows { + found = false + // No user account found. + //return "", "", status.Error(codes.NotFound, "Device ID not found.") + } else { + logger.Error("Cannot find user with device ID.", zap.Error(err), zap.String("deviceID", deviceID), zap.String("username", username), zap.Bool("create", create)) + return "", "", status.Error(codes.Internal, "Error finding user account.") + } + } + + // Existing account found. + if found { + // Load its details. + query = "SELECT username, disable_time FROM users WHERE id = $1" + var dbUsername string + var dbDisableTime int64 + err = db.QueryRow(query, dbUserID).Scan(&dbUsername, &dbDisableTime) + if err != nil { + logger.Error("Cannot find user with device ID.", zap.Error(err), zap.String("deviceID", deviceID), zap.String("username", username), zap.Bool("create", create)) + return "", "", status.Error(codes.Internal, "Error finding user account.") + } + + // Check if it's disabled. + if dbDisableTime != 0 { + logger.Debug("User account is disabled.", zap.String("deviceID", deviceID), zap.String("username", username), zap.Bool("create", create)) + return "", "", status.Error(codes.Unauthenticated, "Error finding or creating user account.") + } + + return dbUserID, dbUsername, nil + } + if !create { - return LoginDevice(logger, db, deviceID, username, create) + // No user account found, and creation is not allowed. + return "", "", status.Error(codes.NotFound, "User account not found.") } - // Use existing user account if found, otherwise create a new user account. - var dbUserID string - var dbUsername string + // Create a new account. + userID := uuid.NewV4().String() fnErr := Transact(logger, db, func(tx *sql.Tx) error { - userID := uuid.NewV4().String() - ts := time.Now().UTC().Unix() + //ts := time.Now().UTC().Unix() query := ` INSERT INTO users (id, username, create_time, update_time) SELECT $1 AS id, @@ -114,137 +152,70 @@ SELECT $1 AS id, WHERE NOT EXISTS (SELECT id FROM user_device - WHERE id = $3::VARCHAR) -ON CONFLICT(id) DO NOTHING -RETURNING id, username, disable_time` + WHERE id = $3::VARCHAR)` - var dbDisableTime int64 - err := tx.QueryRow(query, userID, username, deviceID, ts).Scan(&dbUserID, &dbUsername, &dbDisableTime) + result, err := tx.Exec(query, userID, username, deviceID, time.Now().UTC().Unix()) if err != nil { - if e, ok := err.(*pq.Error); ok && e.Code == dbErrorUniqueViolation && strings.Contains(e.Message, "users_username_key") { - return status.Error(codes.AlreadyExists, "Username is already in use.") - } - if err == sql.ErrNoRows { - // let's catch this case as it could be there could be a device ID already - // linked to a ID so let's attempt a vanilla login - dbUserID, dbUsername, err = LoginDevice(logger, db, deviceID, username, create) - return err - } else { - logger.Error("Cannot find or create user with device ID.", zap.Error(err), zap.String("deviceID", deviceID), zap.String("username", username), zap.Bool("create", create)) + // A concurrent write has inserted this device ID. + logger.Debug("Did not insert new user as device ID already exists.", zap.Error(err), zap.String("deviceID", deviceID), zap.String("username", username), zap.Bool("create", create)) return status.Error(codes.Internal, "Error finding or creating user account.") + } else if e, ok := err.(*pq.Error); ok && e.Code == dbErrorUniqueViolation && strings.Contains(e.Message, "users_username_key") { + return status.Error(codes.AlreadyExists, "Username is already in use.") } + logger.Error("Cannot find or create user with device ID.", zap.Error(err), zap.String("deviceID", deviceID), zap.String("username", username), zap.Bool("create", create)) + return status.Error(codes.Internal, "Error finding or creating user account.") } - if dbDisableTime != 0 { - logger.Debug("User account is disabled.", zap.String("deviceID", deviceID), zap.String("username", username), zap.Bool("create", create)) - return status.Error(codes.Unauthenticated, "Error finding or creating user account.") + if rowsAffectedCount, _ := result.RowsAffected(); rowsAffectedCount != 1 { + logger.Error("Did not insert new user.", zap.Int64("rows_affected", rowsAffectedCount)) + return status.Error(codes.Internal, "Error finding or creating user account.") } - query = "INSERT INTO user_device (id, user_id) VALUES ($1, $2) ON CONFLICT(id) DO NOTHING" - _, err = tx.Exec(query, deviceID, userID) + query = "INSERT INTO user_device (id, user_id) VALUES ($1, $2)" + result, err = tx.Exec(query, deviceID, userID) if err != nil { logger.Error("Cannot add device ID.", zap.Error(err), zap.String("deviceID", deviceID), zap.String("username", username), zap.Bool("create", create)) return status.Error(codes.Internal, "Error finding or creating user account.") } + if rowsAffectedCount, _ := result.RowsAffected(); rowsAffectedCount != 1 { + logger.Error("Did not insert new user.", zap.Int64("rows_affected", rowsAffectedCount)) + return status.Error(codes.Internal, "Error finding or creating user account.") + } + return nil }) - if fnErr != nil { - return dbUserID, dbUsername, fnErr + return "", "", fnErr } - return dbUserID, dbUsername, nil + return userID, username, nil } -func LoginDevice(logger *zap.Logger, db *sql.DB, deviceID, username string, create bool) (string, string, error) { - query := "SELECT user_id FROM user_device WHERE id = $1" +func AuthenticateEmail(logger *zap.Logger, db *sql.DB, email, password, username string, create bool) (string, string, error) { + found := true + hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + // Look for an existing account. + query := "SELECT id, username, password, disable_time FROM users WHERE email = $1" var dbUserID string - err := db.QueryRow(query, deviceID).Scan(&dbUserID) + var dbUsername string + var dbPassword string + var dbDisableTime int64 + err := db.QueryRow(query, email).Scan(&dbUserID, &dbUsername, &dbPassword, &dbDisableTime) if err != nil { if err == sql.ErrNoRows { - // No user account found. - return "", "", status.Error(codes.NotFound, "Device ID not found.") + found = false } else { - logger.Error("Cannot find user with device ID.", zap.Error(err), zap.String("deviceID", deviceID), zap.String("username", username), zap.Bool("create", create)) + logger.Error("Cannot find user with email.", zap.Error(err), zap.String("email", email), zap.String("username", username), zap.Bool("create", create)) return "", "", status.Error(codes.Internal, "Error finding user account.") } } - query = "SELECT username, disable_time FROM users WHERE id = $1" - var dbUsername string - var dbDisableTime int64 - - err = db.QueryRow(query, dbUserID).Scan(&dbUsername, &dbDisableTime) - if err != nil { - logger.Error("Cannot find user with device ID.", zap.Error(err), zap.String("deviceID", deviceID), zap.String("username", username), zap.Bool("create", create)) - return "", "", status.Error(codes.Internal, "Error finding user account.") - } - - if dbDisableTime != 0 { - logger.Debug("User account is disabled.", zap.String("deviceID", deviceID), zap.String("username", username), zap.Bool("create", create)) - return "", "", status.Error(codes.Unauthenticated, "Error finding or creating user account.") - } - - return dbUserID, dbUsername, nil -} - -func AuthenticateEmail(logger *zap.Logger, db *sql.DB, email, password, username string, create bool) (string, string, error) { - hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) - - if create { - // Use existing user account if found, otherwise create a new user account. - userID := uuid.NewV4().String() - ts := time.Now().UTC().Unix() - query := ` -INSERT INTO users (id, username, email, password, create_time, update_time) -VALUES ($1, $2, $3, $4, $5, $5) -ON CONFLICT (email) DO UPDATE SET email = $3, password = $4 -RETURNING id, username, disable_time` - - var dbUserID string - var dbUsername string - var dbDisableTime int64 - err := db.QueryRow(query, userID, username, email, hashedPassword, ts).Scan(&dbUserID, &dbUsername, &dbDisableTime) - if err != nil { - if e, ok := err.(*pq.Error); ok && e.Code == dbErrorUniqueViolation && strings.Contains(e.Message, "users_username_key") { - // Username is already in use by a different account. - return "", "", status.Error(codes.AlreadyExists, "Username is already in use.") - } - logger.Error("Cannot find or create user with email.", zap.Error(err), zap.String("email", email), zap.String("username", username), zap.Bool("create", create)) - return "", "", status.Error(codes.Internal, "Error finding or creating user account.") - } - - if dbDisableTime != 0 { - logger.Debug("User account is disabled.", zap.String("email", email), zap.String("username", username), zap.Bool("create", create)) - return "", "", status.Error(codes.Unauthenticated, "Error finding or creating user account.") - } - - return dbUserID, dbUsername, nil - } else { - // Do not create a new user account. - query := ` -SELECT id, username, password, disable_time -FROM users -WHERE email = $1` - - var dbUserID string - var dbUsername string - var dbPassword string - var dbDisableTime int64 - err := db.QueryRow(query, email).Scan(&dbUserID, &dbUsername, &dbPassword, &dbDisableTime) - if err != nil { - if err == sql.ErrNoRows { - // No user account found. - return "", "", status.Error(codes.NotFound, "User account not found.") - } else { - logger.Error("Cannot find user with email.", zap.Error(err), zap.String("email", email), zap.String("username", username), zap.Bool("create", create)) - return "", "", status.Error(codes.Internal, "Error finding user account.") - } - } - + // Existing account found. + if found { + // Check if it's disabled. if dbDisableTime != 0 { logger.Debug("User account is disabled.", zap.String("email", email), zap.String("username", username), zap.Bool("create", create)) return "", "", status.Error(codes.Unauthenticated, "Error finding or creating user account.") @@ -257,6 +228,37 @@ WHERE email = $1` return dbUserID, dbUsername, nil } + + if !create { + // No user account found, and creation is not allowed. + return "", "", status.Error(codes.NotFound, "User account not found.") + } + + // Create a new account. + userID := uuid.NewV4().String() + query = "INSERT INTO users (id, username, email, password, create_time, update_time) VALUES ($1, $2, $3, $4, $5, $5)" + result, err := db.Exec(query, userID, username, email, hashedPassword, time.Now().UTC().Unix()) + if err != nil { + if e, ok := err.(*pq.Error); ok && e.Code == dbErrorUniqueViolation { + if strings.Contains(e.Message, "users_username_key") { + // Username is already in use by a different account. + return "", "", status.Error(codes.AlreadyExists, "Username is already in use.") + } else if strings.Contains(e.Message, "users_email_key") { + // A concurrent write has inserted this email. + logger.Debug("Did not insert new user as email already exists.", zap.Error(err), zap.String("email", email), zap.String("username", username), zap.Bool("create", create)) + return "", "", status.Error(codes.Internal, "Error finding or creating user account.") + } + } + logger.Error("Cannot find or create user with email.", zap.Error(err), zap.String("email", email), zap.String("username", username), zap.Bool("create", create)) + return "", "", status.Error(codes.Internal, "Error finding or creating user account.") + } + + if rowsAffectedCount, _ := result.RowsAffected(); rowsAffectedCount != 1 { + logger.Error("Did not insert new user.", zap.Int64("rows_affected", rowsAffectedCount)) + return "", "", status.Error(codes.Internal, "Error finding or creating user account.") + } + + return userID, username, nil } func AuthenticateFacebook(logger *zap.Logger, db *sql.DB, client *social.Client, accessToken, username string, create bool) (string, string, error) { @@ -265,64 +267,64 @@ func AuthenticateFacebook(logger *zap.Logger, db *sql.DB, client *social.Client, logger.Debug("Could not authenticate Facebook profile.", zap.Error(err)) return "", "", status.Error(codes.Unauthenticated, "Could not authenticate Facebook profile.") } + found := true - if create { - // Use existing user account if found, otherwise create a new user account. - userID := uuid.NewV4().String() - ts := time.Now().UTC().Unix() - query := ` -INSERT INTO users (id, username, facebook_id, create_time, update_time) -VALUES ($1, $2, $3, $4, $4) -ON CONFLICT (facebook_id) DO UPDATE SET facebook_id = $3 -RETURNING id, username, disable_time` - - var dbUserID string - var dbUsername string - var dbDisableTime int64 - err := db.QueryRow(query, userID, username, facebookProfile.ID, ts).Scan(&dbUserID, &dbUsername, &dbDisableTime) - if err != nil { - if e, ok := err.(*pq.Error); ok && e.Code == dbErrorUniqueViolation && strings.Contains(e.Message, "users_username_key") { - // Username is already in use by a different account. - return "", "", status.Error(codes.AlreadyExists, "Username is already in use.") - } - logger.Error("Cannot find or create user with Facebook ID.", zap.Error(err), zap.String("facebookID", facebookProfile.ID), zap.String("username", username), zap.Bool("create", create)) - return "", "", status.Error(codes.Internal, "Error finding or creating user account.") + // Look for an existing account. + query := "SELECT id, username, disable_time FROM users WHERE facebook_id = $1" + var dbUserID string + var dbUsername string + var dbDisableTime int64 + err = db.QueryRow(query, facebookProfile.ID).Scan(&dbUserID, &dbUsername, &dbDisableTime) + if err != nil { + if err == sql.ErrNoRows { + found = false + } else { + logger.Error("Cannot find user with Facebook ID.", zap.Error(err), zap.String("facebookID", facebookProfile.ID), zap.String("username", username), zap.Bool("create", create)) + return "", "", status.Error(codes.Internal, "Error finding user account.") } + } + // Existing account found. + if found { + // Check if it's disabled. if dbDisableTime != 0 { logger.Debug("User account is disabled.", zap.String("facebookID", facebookProfile.ID), zap.String("username", username), zap.Bool("create", create)) return "", "", status.Error(codes.Unauthenticated, "Error finding or creating user account.") } return dbUserID, dbUsername, nil - } else { - // Do not create a new user account. - query := ` -SELECT id, username, disable_time -FROM users -WHERE facebook_id = $1` + } - var dbUserID string - var dbUsername string - var dbDisableTime int64 - err := db.QueryRow(query, facebookProfile.ID).Scan(&dbUserID, &dbUsername, &dbDisableTime) - if err != nil { - if err == sql.ErrNoRows { - // No user account found. - return "", "", status.Error(codes.NotFound, "User account not found.") - } else { - logger.Error("Cannot find user with Facebook ID.", zap.Error(err), zap.String("facebookID", facebookProfile.ID), zap.String("username", username), zap.Bool("create", create)) - return "", "", status.Error(codes.Internal, "Error finding user account.") - } - } + if !create { + // No user account found, and creation is not allowed. + return "", "", status.Error(codes.NotFound, "User account not found.") + } - if dbDisableTime != 0 { - logger.Debug("User account is disabled.", zap.String("facebookID", facebookProfile.ID), zap.String("username", username), zap.Bool("create", create)) - return "", "", status.Error(codes.Unauthenticated, "Error finding or creating user account.") + // Create a new account. + userID := uuid.NewV4().String() + query = "INSERT INTO users (id, username, facebook_id, create_time, update_time) VALUES ($1, $2, $3, $4, $4)" + result, err := db.Exec(query, userID, username, facebookProfile.ID, time.Now().UTC().Unix()) + if err != nil { + if e, ok := err.(*pq.Error); ok && e.Code == dbErrorUniqueViolation { + if strings.Contains(e.Message, "users_username_key") { + // Username is already in use by a different account. + return "", "", status.Error(codes.AlreadyExists, "Username is already in use.") + } else if strings.Contains(e.Message, "users_facebook_id_key") { + // A concurrent write has inserted this Facebook ID. + logger.Debug("Did not insert new user as Facebook ID already exists.", zap.Error(err), zap.String("facebookID", facebookProfile.ID), zap.String("username", username), zap.Bool("create", create)) + return "", "", status.Error(codes.Internal, "Error finding or creating user account.") + } } + logger.Error("Cannot find or create user with Facebook ID.", zap.Error(err), zap.String("facebookID", facebookProfile.ID), zap.String("username", username), zap.Bool("create", create)) + return "", "", status.Error(codes.Internal, "Error finding or creating user account.") + } - return dbUserID, dbUsername, nil + if rowsAffectedCount, _ := result.RowsAffected(); rowsAffectedCount != 1 { + logger.Error("Did not insert new user.", zap.Int64("rows_affected", rowsAffectedCount)) + return "", "", status.Error(codes.Internal, "Error finding or creating user account.") } + + return userID, username, nil } func AuthenticateGameCenter(logger *zap.Logger, db *sql.DB, client *social.Client, playerID, bundleID string, timestamp int64, salt, signature, publicKeyUrl, username string, create bool) (string, string, error) { @@ -331,64 +333,64 @@ func AuthenticateGameCenter(logger *zap.Logger, db *sql.DB, client *social.Clien logger.Debug("Could not authenticate GameCenter profile.", zap.Error(err), zap.Bool("valid", valid)) return "", "", status.Error(codes.Unauthenticated, "Could not authenticate GameCenter profile.") } + found := true - if create { - // Use existing user account if found, otherwise create a new user account. - userID := uuid.NewV4().String() - ts := time.Now().UTC().Unix() - query := ` -INSERT INTO users (id, username, gamecenter_id, create_time, update_time) -VALUES ($1, $2, $3, $4, $4) -ON CONFLICT (gamecenter_id) DO UPDATE SET gamecenter_id = $3 -RETURNING id, username, disable_time` - - var dbUserID string - var dbUsername string - var dbDisableTime int64 - err := db.QueryRow(query, userID, username, playerID, ts).Scan(&dbUserID, &dbUsername, &dbDisableTime) - if err != nil { - if e, ok := err.(*pq.Error); ok && e.Code == dbErrorUniqueViolation && strings.Contains(e.Message, "users_username_key") { - // Username is already in use by a different account. - return "", "", status.Error(codes.AlreadyExists, "Username is already in use.") - } - logger.Error("Cannot find or create user with GameCenter ID.", zap.Error(err), zap.String("gameCenterID", playerID), zap.String("username", username), zap.Bool("create", create)) - return "", "", status.Error(codes.Internal, "Error finding or creating user account.") + // Look for an existing account. + query := "SELECT id, username, disable_time FROM users WHERE gamecenter_id = $1" + var dbUserID string + var dbUsername string + var dbDisableTime int64 + err = db.QueryRow(query, playerID).Scan(&dbUserID, &dbUsername, &dbDisableTime) + if err != nil { + if err == sql.ErrNoRows { + found = false + } else { + logger.Error("Cannot find user with GameCenter ID.", zap.Error(err), zap.String("gameCenterID", playerID), zap.String("username", username), zap.Bool("create", create)) + return "", "", status.Error(codes.Internal, "Error finding user account.") } + } + // Existing account found. + if found { + // Check if it's disabled. if dbDisableTime != 0 { logger.Debug("User account is disabled.", zap.String("gameCenterID", playerID), zap.String("username", username), zap.Bool("create", create)) return "", "", status.Error(codes.Unauthenticated, "Error finding or creating user account.") } return dbUserID, dbUsername, nil - } else { - // Do not create a new user account. - query := ` -SELECT id, username, disable_time -FROM users -WHERE gamecenter_id = $1` + } - var dbUserID string - var dbUsername string - var dbDisableTime int64 - err := db.QueryRow(query, playerID).Scan(&dbUserID, &dbUsername, &dbDisableTime) - if err != nil { - if err == sql.ErrNoRows { - // No user account found. - return "", "", status.Error(codes.NotFound, "User account not found.") - } else { - logger.Error("Cannot find user with GameCenter ID.", zap.Error(err), zap.String("gameCenterID", playerID), zap.String("username", username), zap.Bool("create", create)) - return "", "", status.Error(codes.Internal, "Error finding user account.") - } - } + if !create { + // No user account found, and creation is not allowed. + return "", "", status.Error(codes.NotFound, "User account not found.") + } - if dbDisableTime != 0 { - logger.Debug("User account is disabled.", zap.String("gameCenterID", playerID), zap.String("username", username), zap.Bool("create", create)) - return "", "", status.Error(codes.Unauthenticated, "Error finding or creating user account.") + // Create a new account. + userID := uuid.NewV4().String() + query = "INSERT INTO users (id, username, gamecenter_id, create_time, update_time) VALUES ($1, $2, $3, $4, $4)" + result, err := db.Exec(query, userID, username, playerID, time.Now().UTC().Unix()) + if err != nil { + if e, ok := err.(*pq.Error); ok && e.Code == dbErrorUniqueViolation { + if strings.Contains(e.Message, "users_username_key") { + // Username is already in use by a different account. + return "", "", status.Error(codes.AlreadyExists, "Username is already in use.") + } else if strings.Contains(e.Message, "users_gamecenter_id_key") { + // A concurrent write has inserted this GameCenter ID. + logger.Debug("Did not insert new user as GameCenter ID already exists.", zap.Error(err), zap.String("gameCenterID", playerID), zap.String("username", username), zap.Bool("create", create)) + return "", "", status.Error(codes.Internal, "Error finding or creating user account.") + } } + logger.Error("Cannot find or create user with GameCenter ID.", zap.Error(err), zap.String("gameCenterID", playerID), zap.String("username", username), zap.Bool("create", create)) + return "", "", status.Error(codes.Internal, "Error finding or creating user account.") + } - return dbUserID, dbUsername, nil + if rowsAffectedCount, _ := result.RowsAffected(); rowsAffectedCount != 1 { + logger.Error("Did not insert new user.", zap.Int64("rows_affected", rowsAffectedCount)) + return "", "", status.Error(codes.Internal, "Error finding or creating user account.") } + + return userID, username, nil } func AuthenticateGoogle(logger *zap.Logger, db *sql.DB, client *social.Client, idToken, username string, create bool) (string, string, error) { @@ -397,64 +399,64 @@ func AuthenticateGoogle(logger *zap.Logger, db *sql.DB, client *social.Client, i logger.Debug("Could not authenticate Google profile.", zap.Error(err)) return "", "", status.Error(codes.Unauthenticated, "Could not authenticate Google profile.") } + found := true - if create { - // Use existing user account if found, otherwise create a new user account. - userID := uuid.NewV4().String() - ts := time.Now().UTC().Unix() - query := ` -INSERT INTO users (id, username, google_id, create_time, update_time) -VALUES ($1, $2, $3, $4, $4) -ON CONFLICT (google_id) DO UPDATE SET google_id = $3 -RETURNING id, username, disable_time` - - var dbUserID string - var dbUsername string - var dbDisableTime int64 - err := db.QueryRow(query, userID, username, googleProfile.Sub, ts).Scan(&dbUserID, &dbUsername, &dbDisableTime) - if err != nil { - if e, ok := err.(*pq.Error); ok && e.Code == dbErrorUniqueViolation && strings.Contains(e.Message, "users_username_key") { - // Username is already in use by a different account. - return "", "", status.Error(codes.AlreadyExists, "Username is already in use.") - } - logger.Error("Cannot find or create user with Google ID.", zap.Error(err), zap.String("googleID", googleProfile.Sub), zap.String("username", username), zap.Bool("create", create)) - return "", "", status.Error(codes.Internal, "Error finding or creating user account.") + // Look for an existing account. + query := "SELECT id, username, disable_time FROM users WHERE google_id = $1" + var dbUserID string + var dbUsername string + var dbDisableTime int64 + err = db.QueryRow(query, googleProfile.Sub).Scan(&dbUserID, &dbUsername, &dbDisableTime) + if err != nil { + if err == sql.ErrNoRows { + found = false + } else { + logger.Error("Cannot find user with Google ID.", zap.Error(err), zap.String("googleID", googleProfile.Sub), zap.String("username", username), zap.Bool("create", create)) + return "", "", status.Error(codes.Internal, "Error finding user account.") } + } + // Existing account found. + if found { + // Check if it's disabled. if dbDisableTime != 0 { logger.Debug("User account is disabled.", zap.String("googleID", googleProfile.Sub), zap.String("username", username), zap.Bool("create", create)) return "", "", status.Error(codes.Unauthenticated, "Error finding or creating user account.") } return dbUserID, dbUsername, nil - } else { - // Do not create a new user account. - query := ` -SELECT id, username, disable_time -FROM users -WHERE google_id = $1` + } - var dbUserID string - var dbUsername string - var dbDisableTime int64 - err := db.QueryRow(query, googleProfile.Sub).Scan(&dbUserID, &dbUsername, &dbDisableTime) - if err != nil { - if err == sql.ErrNoRows { - // No user account found. - return "", "", status.Error(codes.NotFound, "User account not found.") - } else { - logger.Error("Cannot find user with Google ID.", zap.Error(err), zap.String("googleID", googleProfile.Sub), zap.String("username", username), zap.Bool("create", create)) - return "", "", status.Error(codes.Internal, "Error finding user account.") - } - } + if !create { + // No user account found, and creation is not allowed. + return "", "", status.Error(codes.NotFound, "User account not found.") + } - if dbDisableTime != 0 { - logger.Debug("User account is disabled.", zap.String("googleID", googleProfile.Sub), zap.String("username", username), zap.Bool("create", create)) - return "", "", status.Error(codes.Unauthenticated, "Error finding or creating user account.") + // Create a new account. + userID := uuid.NewV4().String() + query = "INSERT INTO users (id, username, google_id, create_time, update_time) VALUES ($1, $2, $3, $4, $4)" + result, err := db.Exec(query, userID, username, googleProfile.Sub, time.Now().UTC().Unix()) + if err != nil { + if e, ok := err.(*pq.Error); ok && e.Code == dbErrorUniqueViolation { + if strings.Contains(e.Message, "users_username_key") { + // Username is already in use by a different account. + return "", "", status.Error(codes.AlreadyExists, "Username is already in use.") + } else if strings.Contains(e.Message, "users_google_id_key") { + // A concurrent write has inserted this Google ID. + logger.Debug("Did not insert new user as Google ID already exists.", zap.Error(err), zap.String("googleID", googleProfile.Sub), zap.String("username", username), zap.Bool("create", create)) + return "", "", status.Error(codes.Internal, "Error finding or creating user account.") + } } + logger.Error("Cannot find or create user with Google ID.", zap.Error(err), zap.String("googleID", googleProfile.Sub), zap.String("username", username), zap.Bool("create", create)) + return "", "", status.Error(codes.Internal, "Error finding or creating user account.") + } - return dbUserID, dbUsername, nil + if rowsAffectedCount, _ := result.RowsAffected(); rowsAffectedCount != 1 { + logger.Error("Did not insert new user.", zap.Int64("rows_affected", rowsAffectedCount)) + return "", "", status.Error(codes.Internal, "Error finding or creating user account.") } + + return userID, username, nil } func AuthenticateSteam(logger *zap.Logger, db *sql.DB, client *social.Client, appID int, publisherKey, token, username string, create bool) (string, string, error) { @@ -463,83 +465,129 @@ func AuthenticateSteam(logger *zap.Logger, db *sql.DB, client *social.Client, ap logger.Debug("Could not authenticate Steam profile.", zap.Error(err)) return "", "", status.Error(codes.Unauthenticated, "Could not authenticate Steam profile.") } - steamID := strconv.FormatUint(steamProfile.SteamID, 10) + found := true - if create { - // Use existing user account if found, otherwise create a new user account. - userID := uuid.NewV4().String() - ts := time.Now().UTC().Unix() - query := ` -INSERT INTO users (id, username, steam_id, create_time, update_time) -VALUES ($1, $2, $3, $4, $4) -ON CONFLICT (steam_id) DO UPDATE SET steam_id = $3 -RETURNING id, username, disable_time` - - var dbUserID string - var dbUsername string - var dbDisableTime int64 - err := db.QueryRow(query, userID, username, steamID, ts).Scan(&dbUserID, &dbUsername, &dbDisableTime) - if err != nil { - if e, ok := err.(*pq.Error); ok && e.Code == dbErrorUniqueViolation && strings.Contains(e.Message, "users_username_key") { - // Username is already in use by a different account. - return "", "", status.Error(codes.AlreadyExists, "Username is already in use.") - } - logger.Error("Cannot find or create user with Steam ID.", zap.Error(err), zap.String("steamID", steamID), zap.String("username", username), zap.Bool("create", create)) - return "", "", status.Error(codes.Internal, "Error finding or creating user account.") + // Look for an existing account. + query := "SELECT id, username, disable_time FROM users WHERE steam_id = $1" + var dbUserID string + var dbUsername string + var dbDisableTime int64 + err = db.QueryRow(query, steamID).Scan(&dbUserID, &dbUsername, &dbDisableTime) + if err != nil { + if err == sql.ErrNoRows { + found = false + } else { + logger.Error("Cannot find user with Steam ID.", zap.Error(err), zap.String("steamID", steamID), zap.String("username", username), zap.Bool("create", create)) + return "", "", status.Error(codes.Internal, "Error finding user account.") } + } + // Existing account found. + if found { + // Check if it's disabled. if dbDisableTime != 0 { - logger.Debug("User account is disabled.", zap.String("steamID", steamID), zap.String("username", username), zap.Bool("create", create)) + logger.Debug("User account is disabled.", zap.Error(err), zap.String("steamID", steamID), zap.String("username", username), zap.Bool("create", create)) return "", "", status.Error(codes.Unauthenticated, "Error finding or creating user account.") } return dbUserID, dbUsername, nil - } else { - // Do not create a new user account. - query := ` -SELECT id, username, disable_time -FROM users -WHERE steam_id = $1` + } - var dbUserID string - var dbUsername string - var dbDisableTime int64 - err := db.QueryRow(query, steamID).Scan(&dbUserID, &dbUsername, &dbDisableTime) - if err != nil { - if err == sql.ErrNoRows { - // No user account found. - return "", "", status.Error(codes.NotFound, "User account not found.") - } else { - logger.Error("Cannot find user with Steam ID.", zap.Error(err), zap.String("steamID", steamID), zap.String("username", username), zap.Bool("create", create)) - return "", "", status.Error(codes.Internal, "Error finding user account.") - } - } + if !create { + // No user account found, and creation is not allowed. + return "", "", status.Error(codes.NotFound, "User account not found.") + } - if dbDisableTime != 0 { - logger.Debug("User account is disabled.", zap.String("steamID", steamID), zap.String("username", username), zap.Bool("create", create)) - return "", "", status.Error(codes.Unauthenticated, "Error finding or creating user account.") + // Create a new account. + userID := uuid.NewV4().String() + query = "INSERT INTO users (id, username, steam_id, create_time, update_time) VALUES ($1, $2, $3, $4, $4)" + result, err := db.Exec(query, userID, username, steamID, time.Now().UTC().Unix()) + if err != nil { + if e, ok := err.(*pq.Error); ok && e.Code == dbErrorUniqueViolation { + if strings.Contains(e.Message, "users_username_key") { + // Username is already in use by a different account. + return "", "", status.Error(codes.AlreadyExists, "Username is already in use.") + } else if strings.Contains(e.Message, "users_steam_id_key") { + // A concurrent write has inserted this Steam ID. + logger.Debug("Did not insert new user as Steam ID already exists.", zap.Error(err), zap.String("steamID", steamID), zap.String("username", username), zap.Bool("create", create)) + return "", "", status.Error(codes.Internal, "Error finding or creating user account.") + } } + logger.Error("Cannot find or create user with Steam ID.", zap.Error(err), zap.String("steamID", steamID), zap.String("username", username), zap.Bool("create", create)) + return "", "", status.Error(codes.Internal, "Error finding or creating user account.") + } - return dbUserID, dbUsername, nil + if rowsAffectedCount, _ := result.RowsAffected(); rowsAffectedCount != 1 { + logger.Error("Did not insert new user.", zap.Int64("rows_affected", rowsAffectedCount)) + return "", "", status.Error(codes.Internal, "Error finding or creating user account.") } + + return userID, username, nil } -func importFacebookFriends(logger *zap.Logger, db *sql.DB, client *social.Client, userID uuid.UUID, username, token string) { +func importFacebookFriends(logger *zap.Logger, db *sql.DB, client *social.Client, userID uuid.UUID, username, token string, reset bool) error { facebookProfiles, err := client.GetFacebookFriends(token) if err != nil { logger.Debug("Could not import Facebook friends.", zap.Error(err)) - return + return status.Error(codes.Unauthenticated, "Could not authenticate Facebook profile.") } - if len(facebookProfiles) == 0 { - return + if len(facebookProfiles) == 0 && !reset { + // No Facebook friends to import, and friend reset not requested - no work to do. + return nil } ts := time.Now().UTC().Unix() friendUserIDs := make([]uuid.UUID, 0) err = Transact(logger, db, func(tx *sql.Tx) error { + if reset { + // Reset all friends for the current user, replacing them entirely with their Facebook friends. + // Note: will NOT remove blocked users. + query := "DELETE FROM user_edge WHERE source_id = $1 AND state != 3" + result, err := tx.Exec(query, userID) + if err != nil { + return err + } + if rowsAffectedCount, _ := result.RowsAffected(); rowsAffectedCount != 0 { + // Update edge count to reflect removed friends. + query = "UPDATE user SET edge_count = edge_count - $2 WHERE id = $1" + result, err := tx.Exec(query, userID, rowsAffectedCount) + if err != nil { + return err + } + if rowsAffectedCount, _ := result.RowsAffected(); rowsAffectedCount != 1 { + return errors.New("error updating edge count after friends reset") + } + } + + // Remove links to the current user. + // Note: will NOT remove blocks. + query = "DELETE FROM user_edge WHERE destination_id = $1 AND state != 3 RETURNING source_id" + rows, err := tx.Query(query, userID) + if err != nil { + return err + } + defer rows.Close() + var id string + query = "UPDATE user SET edge_count = edge_count - 1 WHERE id = $1" + for rows.Next() { + // Update edge count to reflect each removed friend. + err = rows.Scan(&id) + if err != nil { + return err + } + result, err := tx.Exec(query, id) + if err != nil { + return err + } + if rowsAffectedCount, _ := result.RowsAffected(); rowsAffectedCount != 1 { + return errors.New("error updating edge count after friend reset") + } + } + } + statements := make([]string, 0, len(facebookProfiles)) params := make([]interface{}, 0, len(facebookProfiles)) count := 1 @@ -637,6 +685,7 @@ AND EXISTS }) if err != nil { logger.Error("Error importing Facebook friends.", zap.Error(err)) + return status.Error(codes.Internal, "Error importing Facebook friends.") } if len(friendUserIDs) != 0 { @@ -655,4 +704,6 @@ AND EXISTS }} } } + + return nil } diff --git a/server/match_handler.go b/server/match_handler.go index df97d77c2bb9490b15d53f4a4a0d14db5483e703..427b9d827f9c8b93d037141a0f994d50f78be971 100644 --- a/server/match_handler.go +++ b/server/match_handler.go @@ -320,18 +320,18 @@ func loop(mh *MatchHandler) { msg := <-mh.inputCh presence := mh.vm.CreateTable(4, 4) - presence.RawSetString("UserId", lua.LString(msg.UserID.String())) - presence.RawSetString("SessionId", lua.LString(msg.SessionID.String())) - presence.RawSetString("Username", lua.LString(msg.Username)) - presence.RawSetString("Node", lua.LString(msg.Node)) + presence.RawSetString("user_id", lua.LString(msg.UserID.String())) + presence.RawSetString("session_id", lua.LString(msg.SessionID.String())) + presence.RawSetString("username", lua.LString(msg.Username)) + presence.RawSetString("node", lua.LString(msg.Node)) in := mh.vm.CreateTable(3, 3) - in.RawSetString("Sender", presence) - in.RawSetString("OpCode", lua.LNumber(msg.OpCode)) + in.RawSetString("sender", presence) + in.RawSetString("op_code", lua.LNumber(msg.OpCode)) if msg.Data != nil { - in.RawSetString("Data", lua.LString(msg.Data)) + in.RawSetString("data", lua.LString(msg.Data)) } else { - in.RawSetString("Data", lua.LNil) + in.RawSetString("data", lua.LNil) } input.RawSetInt(i, in) @@ -384,10 +384,10 @@ func JoinAttempt(resultCh chan *MatchJoinResult, userID, sessionID uuid.UUID, us mh.Unlock() presence := mh.vm.CreateTable(4, 4) - presence.RawSetString("UserId", lua.LString(userID.String())) - presence.RawSetString("SessionId", lua.LString(sessionID.String())) - presence.RawSetString("Username", lua.LString(username)) - presence.RawSetString("Node", lua.LString(node)) + presence.RawSetString("user_id", lua.LString(userID.String())) + presence.RawSetString("session_id", lua.LString(sessionID.String())) + presence.RawSetString("username", lua.LString(username)) + presence.RawSetString("node", lua.LString(node)) // Execute the match_join_attempt call. mh.vm.Push(LSentinel) @@ -456,10 +456,10 @@ func Leave(leaves []*MatchPresence) func(mh *MatchHandler) { presences := mh.vm.CreateTable(size, size) for i, p := range leaves { presence := mh.vm.CreateTable(4, 4) - presence.RawSetString("UserId", lua.LString(p.UserID.String())) - presence.RawSetString("SessionId", lua.LString(p.SessionID.String())) - presence.RawSetString("Username", lua.LString(p.Username)) - presence.RawSetString("Node", lua.LString(p.Node)) + presence.RawSetString("user_id", lua.LString(p.UserID.String())) + presence.RawSetString("session_id", lua.LString(p.SessionID.String())) + presence.RawSetString("username", lua.LString(p.Username)) + presence.RawSetString("node", lua.LString(p.Node)) presences.RawSetInt(i+1, presence) } @@ -532,18 +532,18 @@ func (mh *MatchHandler) broadcastMessage(l *lua.LState) int { presenceID := &PresenceID{} pt.ForEach(func(k, v lua.LValue) { switch k.String() { - case "SessionId": + case "session_id": sid, err := uuid.FromString(v.String()) if err != nil { conversionError = true - l.ArgError(1, "expects each presence to have a valid SessionId") + l.ArgError(1, "expects each presence to have a valid session_id") return } presenceID.SessionID = sid - case "Node": + case "node": if v.Type() != lua.LTString { conversionError = true - l.ArgError(1, "expects Node to be string") + l.ArgError(1, "expects node to be string") return } presenceID.Node = v.String() @@ -551,7 +551,7 @@ func (mh *MatchHandler) broadcastMessage(l *lua.LState) int { }) if presenceID.SessionID == uuid.Nil || presenceID.Node == "" { conversionError = true - l.ArgError(1, "expects each presence to have a valid UserId, SessionId, and Node") + l.ArgError(1, "expects each presence to have a valid session_id and node") return } if conversionError { @@ -576,35 +576,35 @@ func (mh *MatchHandler) broadcastMessage(l *lua.LState) int { conversionError := false sender.ForEach(func(k, v lua.LValue) { switch k.String() { - case "UserId": + case "user_id": s := v.String() _, err := uuid.FromString(s) if err != nil { conversionError = true - l.ArgError(4, "expects presence to have a valid UserId") + l.ArgError(4, "expects presence to have a valid user_id") return } presence.UserId = s - case "SessionId": + case "session_id": s := v.String() _, err := uuid.FromString(s) if err != nil { conversionError = true - l.ArgError(4, "expects presence to have a valid SessionId") + l.ArgError(4, "expects presence to have a valid session_id") return } presence.SessionId = s - case "Username": + case "username": if v.Type() != lua.LTString { conversionError = true - l.ArgError(4, "expects Username to be string") + l.ArgError(4, "expects username to be string") return } presence.Username = v.String() } }) if presence.UserId == "" || presence.SessionId == "" || presence.Username == "" { - l.ArgError(4, "expects presence to have a valid UserId, SessionId, and Username") + l.ArgError(4, "expects presence to have a valid user_id, session_id, and username") return 0 } if conversionError { @@ -679,26 +679,26 @@ func (mh *MatchHandler) matchKick(l *lua.LState) int { presence := &MatchPresence{} pt.ForEach(func(k, v lua.LValue) { switch k.String() { - case "UserId": + case "user_id": uid, err := uuid.FromString(v.String()) if err != nil { conversionError = true - l.ArgError(1, "expects each presence to have a valid UserId") + l.ArgError(1, "expects each presence to have a valid user_id") return } presence.UserID = uid - case "SessionId": + case "session_id": sid, err := uuid.FromString(v.String()) if err != nil { conversionError = true - l.ArgError(1, "expects each presence to have a valid SessionId") + l.ArgError(1, "expects each presence to have a valid session_id") return } presence.SessionID = sid - case "Node": + case "node": if v.Type() != lua.LTString { conversionError = true - l.ArgError(1, "expects Node to be string") + l.ArgError(1, "expects node to be string") return } presence.Node = v.String() @@ -706,7 +706,7 @@ func (mh *MatchHandler) matchKick(l *lua.LState) int { }) if presence.UserID == uuid.Nil || presence.SessionID == uuid.Nil || presence.Node == "" { conversionError = true - l.ArgError(1, "expects each presence to have a valid UserId, SessionId, and Node") + l.ArgError(1, "expects each presence to have a valid user_id, session_id, and node") return } if conversionError { diff --git a/server/runtime.go b/server/runtime.go index eb1afcbc3cb129219ad6043a7acfec271cc2e055..dcb0af99fe1849a481965b8e17dfa83d6c6683f0 100644 --- a/server/runtime.go +++ b/server/runtime.go @@ -42,9 +42,9 @@ func (s *LSentinelType) Type() lua.LValueType { return LTSentinel } var LSentinel = lua.LValue(&LSentinelType{}) type RuntimeModule struct { - name string - path string - content []byte + Name string + Path string + Content []byte } type RuntimePool struct { @@ -154,13 +154,13 @@ func (r *Runtime) loadModules(modules []*RuntimeModule) error { preload := r.vm.GetField(r.vm.GetField(r.vm.Get(lua.EnvironIndex), "package"), "preload") fns := make(map[string]*lua.LFunction) for _, module := range modules { - f, err := r.vm.Load(bytes.NewReader(module.content), module.path) + f, err := r.vm.Load(bytes.NewReader(module.Content), module.Path) if err != nil { - r.logger.Error("Could not load module", zap.String("name", module.path), zap.Error(err)) + r.logger.Error("Could not load module", zap.String("name", module.Path), zap.Error(err)) return err } else { - r.vm.SetField(preload, module.name, f) - fns[module.name] = f + r.vm.SetField(preload, module.Name, f) + fns[module.Name] = f } } diff --git a/server/runtime_loadlib.go b/server/runtime_loadlib.go index d45c89995cdf463668d26b4d844913ee7daf17ea..3403d6d8c9ae447dcb79fbe4f526628e2cd49deb 100644 --- a/server/runtime_loadlib.go +++ b/server/runtime_loadlib.go @@ -64,7 +64,7 @@ func OpenPackage(modules *sync.Map) func(L *lua.LState) int { L.Push(lua.LString(fmt.Sprintf("invalid cached module '%s'", name))) return 1 } - fn, err := L.Load(bytes.NewReader(module.content), module.path) + fn, err := L.Load(bytes.NewReader(module.Content), module.Path) if err != nil { L.RaiseError(err.Error()) } diff --git a/server/runtime_lua_context.go b/server/runtime_lua_context.go index dc5ae2eea2e15762da89ce561aa81e496bdd88f3..8b14fd3c566e25db8a29487b1e3a07b4afb9161a 100644 --- a/server/runtime_lua_context.go +++ b/server/runtime_lua_context.go @@ -42,16 +42,16 @@ func (e ExecutionMode) String() string { } const ( - __CTX_ENV = "Env" - __CTX_MODE = "ExecutionMode" - __CTX_USER_ID = "UserId" - __CTX_USERNAME = "Username" - __CTX_USER_SESSION_EXP = "UserSessionExp" - __CTX_SESSION_ID = "SessionId" - __CTX_MATCH_ID = "MatchId" - __CTX_MATCH_NODE = "MatchNode" - __CTX_MATCH_LABEL = "MatchLabel" - __CTX_MATCH_TICK_RATE = "MatchTickRate" + __CTX_ENV = "env" + __CTX_MODE = "execution_mode" + __CTX_USER_ID = "user_id" + __CTX_USERNAME = "username" + __CTX_USER_SESSION_EXP = "user_session_exp" + __CTX_SESSION_ID = "session_id" + __CTX_MATCH_ID = "match_id" + __CTX_MATCH_NODE = "match_node" + __CTX_MATCH_LABEL = "match_label" + __CTX_MATCH_TICK_RATE = "match_tick_rate" ) func NewLuaContext(l *lua.LState, env *lua.LTable, mode ExecutionMode, uid string, username string, sessionExpiry int64, sid string) *lua.LTable { diff --git a/server/runtime_module_cache.go b/server/runtime_module_cache.go index 75c3b43f7107d2bdd2e459bb6ec430937b517b7d..568a94d5d9c4539e55c33dfa9999db4e79b28cf0 100644 --- a/server/runtime_module_cache.go +++ b/server/runtime_module_cache.go @@ -57,9 +57,9 @@ func LoadRuntimeModules(logger, multiLogger *zap.Logger, config Config) (map[str // Make paths Lua friendly. name = strings.Replace(name, "/", ".", -1) modules.Store(name, &RuntimeModule{ - name: name, - path: path, - content: content, + Name: name, + Path: path, + Content: content, }) modulePaths = append(modulePaths, relPath) } diff --git a/server/runtime_nakama_module.go b/server/runtime_nakama_module.go index d5f323f86f17d739edcbeefc9189b3ba6f419621..45e8764478f531412b4535f776af724562dea3e1 100644 --- a/server/runtime_nakama_module.go +++ b/server/runtime_nakama_module.go @@ -793,7 +793,7 @@ func (n *NakamaModule) authenticateFacebook(l *lua.LState) int { // Import friends if requested. if importFriends { - importFacebookFriends(n.logger, n.db, n.socialClient, uuid.FromStringOrNil(dbUserID), dbUsername, token) + importFacebookFriends(n.logger, n.db, n.socialClient, uuid.FromStringOrNil(dbUserID), dbUsername, token, false) } l.Push(lua.LString(dbUserID)) @@ -1057,37 +1057,37 @@ func (n *NakamaModule) streamUserGet(l *lua.LState) int { conversionError := "" streamTable.ForEach(func(k lua.LValue, v lua.LValue) { switch k.String() { - case "Mode": + case "mode": if v.Type() != lua.LTNumber { - conversionError = "stream Mode must be a number" + conversionError = "stream mode must be a number" return } stream.Mode = uint8(lua.LVAsNumber(v)) - case "Subject": + case "subject": if v.Type() != lua.LTString { - conversionError = "stream Subject must be a string" + conversionError = "stream subject must be a string" return } sid, err := uuid.FromString(v.String()) if err != nil { - conversionError = "stream Subject must be a valid identifier" + conversionError = "stream subject must be a valid identifier" return } stream.Subject = sid - case "Descriptor": + case "descriptor": if v.Type() != lua.LTString { - conversionError = "stream Descriptor must be a string" + conversionError = "stream descriptor must be a string" return } did, err := uuid.FromString(v.String()) if err != nil { - conversionError = "stream Descriptor must be a valid identifier" + conversionError = "stream descriptor must be a valid identifier" return } stream.Subject = did - case "Label": + case "label": if v.Type() != lua.LTString { - conversionError = "stream Label must be a string" + conversionError = "stream label must be a string" return } stream.Label = v.String() @@ -1103,10 +1103,10 @@ func (n *NakamaModule) streamUserGet(l *lua.LState) int { l.Push(lua.LNil) } else { metaTable := l.CreateTable(4, 4) - metaTable.RawSetString("Hidden", lua.LBool(meta.Hidden)) - metaTable.RawSetString("Persistence", lua.LBool(meta.Persistence)) - metaTable.RawSetString("Username", lua.LString(meta.Username)) - metaTable.RawSetString("Status", lua.LString(meta.Status)) + metaTable.RawSetString("hidden", lua.LBool(meta.Hidden)) + metaTable.RawSetString("persistence", lua.LBool(meta.Persistence)) + metaTable.RawSetString("username", lua.LString(meta.Username)) + metaTable.RawSetString("status", lua.LString(meta.Status)) l.Push(metaTable) } return 1 @@ -1147,37 +1147,37 @@ func (n *NakamaModule) streamUserJoin(l *lua.LState) int { conversionError := "" streamTable.ForEach(func(k lua.LValue, v lua.LValue) { switch k.String() { - case "Mode": + case "mode": if v.Type() != lua.LTNumber { - conversionError = "stream Mode must be a number" + conversionError = "stream mode must be a number" return } stream.Mode = uint8(lua.LVAsNumber(v)) - case "Subject": + case "subject": if v.Type() != lua.LTString { - conversionError = "stream Subject must be a string" + conversionError = "stream subject must be a string" return } sid, err := uuid.FromString(v.String()) if err != nil { - conversionError = "stream Subject must be a valid identifier" + conversionError = "stream subject must be a valid identifier" return } stream.Subject = sid - case "Descriptor": + case "descriptor": if v.Type() != lua.LTString { - conversionError = "stream Descriptor must be a string" + conversionError = "stream descriptor must be a string" return } did, err := uuid.FromString(v.String()) if err != nil { - conversionError = "stream Descriptor must be a valid identifier" + conversionError = "stream descriptor must be a valid identifier" return } stream.Subject = did - case "Label": + case "label": if v.Type() != lua.LTString { - conversionError = "stream Label must be a string" + conversionError = "stream label must be a string" return } stream.Label = v.String() @@ -1250,37 +1250,37 @@ func (n *NakamaModule) streamUserLeave(l *lua.LState) int { conversionError := "" streamTable.ForEach(func(k lua.LValue, v lua.LValue) { switch k.String() { - case "Mode": + case "mode": if v.Type() != lua.LTNumber { - conversionError = "stream Mode must be a number" + conversionError = "stream mode must be a number" return } stream.Mode = uint8(lua.LVAsNumber(v)) - case "Subject": + case "subject": if v.Type() != lua.LTString { - conversionError = "stream Subject must be a string" + conversionError = "stream subject must be a string" return } sid, err := uuid.FromString(v.String()) if err != nil { - conversionError = "stream Subject must be a valid identifier" + conversionError = "stream subject must be a valid identifier" return } stream.Subject = sid - case "Descriptor": + case "descriptor": if v.Type() != lua.LTString { - conversionError = "stream Descriptor must be a string" + conversionError = "stream descriptor must be a string" return } did, err := uuid.FromString(v.String()) if err != nil { - conversionError = "stream Descriptor must be a valid identifier" + conversionError = "stream descriptor must be a valid identifier" return } stream.Subject = did - case "Label": + case "label": if v.Type() != lua.LTString { - conversionError = "stream Label must be a string" + conversionError = "stream label must be a string" return } stream.Label = v.String() @@ -1307,37 +1307,37 @@ func (n *NakamaModule) streamCount(l *lua.LState) int { conversionError := "" streamTable.ForEach(func(k lua.LValue, v lua.LValue) { switch k.String() { - case "Mode": + case "mode": if v.Type() != lua.LTNumber { - conversionError = "stream Mode must be a number" + conversionError = "stream mode must be a number" return } stream.Mode = uint8(lua.LVAsNumber(v)) - case "Subject": + case "subject": if v.Type() != lua.LTString { - conversionError = "stream Subject must be a string" + conversionError = "stream subject must be a string" return } sid, err := uuid.FromString(v.String()) if err != nil { - conversionError = "stream Subject must be a valid identifier" + conversionError = "stream subject must be a valid identifier" return } stream.Subject = sid - case "Descriptor": + case "descriptor": if v.Type() != lua.LTString { - conversionError = "stream Descriptor must be a string" + conversionError = "stream descriptor must be a string" return } did, err := uuid.FromString(v.String()) if err != nil { - conversionError = "stream Descriptor must be a valid identifier" + conversionError = "stream descriptor must be a valid identifier" return } stream.Subject = did - case "Label": + case "label": if v.Type() != lua.LTString { - conversionError = "stream Label must be a string" + conversionError = "stream label must be a string" return } stream.Label = v.String() @@ -1365,37 +1365,37 @@ func (n *NakamaModule) streamClose(l *lua.LState) int { conversionError := "" streamTable.ForEach(func(k lua.LValue, v lua.LValue) { switch k.String() { - case "Mode": + case "mode": if v.Type() != lua.LTNumber { - conversionError = "stream Mode must be a number" + conversionError = "stream mode must be a number" return } stream.Mode = uint8(lua.LVAsNumber(v)) - case "Subject": + case "subject": if v.Type() != lua.LTString { - conversionError = "stream Subject must be a string" + conversionError = "stream subject must be a string" return } sid, err := uuid.FromString(v.String()) if err != nil { - conversionError = "stream Subject must be a valid identifier" + conversionError = "stream subject must be a valid identifier" return } stream.Subject = sid - case "Descriptor": + case "descriptor": if v.Type() != lua.LTString { - conversionError = "stream Descriptor must be a string" + conversionError = "stream descriptor must be a string" return } did, err := uuid.FromString(v.String()) if err != nil { - conversionError = "stream Descriptor must be a valid identifier" + conversionError = "stream descriptor must be a valid identifier" return } stream.Subject = did - case "Label": + case "label": if v.Type() != lua.LTString { - conversionError = "stream Label must be a string" + conversionError = "stream label must be a string" return } stream.Label = v.String() @@ -1422,37 +1422,37 @@ func (n *NakamaModule) streamSend(l *lua.LState) int { conversionError := "" streamTable.ForEach(func(k lua.LValue, v lua.LValue) { switch k.String() { - case "Mode": + case "mode": if v.Type() != lua.LTNumber { - conversionError = "stream Mode must be a number" + conversionError = "stream mode must be a number" return } stream.Mode = uint8(lua.LVAsNumber(v)) - case "Subject": + case "subject": if v.Type() != lua.LTString { - conversionError = "stream Subject must be a string" + conversionError = "stream subject must be a string" return } sid, err := uuid.FromString(v.String()) if err != nil { - conversionError = "stream Subject must be a valid identifier" + conversionError = "stream subject must be a valid identifier" return } stream.Subject = sid - case "Descriptor": + case "descriptor": if v.Type() != lua.LTString { - conversionError = "stream Descriptor must be a string" + conversionError = "stream descriptor must be a string" return } did, err := uuid.FromString(v.String()) if err != nil { - conversionError = "stream Descriptor must be a valid identifier" + conversionError = "stream descriptor must be a valid identifier" return } stream.Subject = did - case "Label": + case "label": if v.Type() != lua.LTString { - conversionError = "stream Label must be a string" + conversionError = "stream label must be a string" return } stream.Label = v.String() @@ -1558,14 +1558,14 @@ func (n *NakamaModule) matchList(l *lua.LState) int { matches := l.CreateTable(s, s) for i, result := range results { match := l.CreateTable(4, 4) - match.RawSetString("MatchId", lua.LString(result.MatchId)) - match.RawSetString("Authoritative", lua.LBool(result.Authoritative)) + match.RawSetString("match_id", lua.LString(result.MatchId)) + match.RawSetString("authoritative", lua.LBool(result.Authoritative)) if result.Label == nil { - match.RawSetString("Label", lua.LNil) + match.RawSetString("label", lua.LNil) } else { - match.RawSetString("Label", lua.LString(result.Label.Value)) + match.RawSetString("label", lua.LString(result.Label.Value)) } - match.RawSetString("Size", lua.LNumber(result.Size)) + match.RawSetString("size", lua.LNumber(result.Size)) matches.RawSetInt(i+1, match) } l.Push(matches) @@ -1595,13 +1595,13 @@ func (n *NakamaModule) notificationSend(l *lua.LState) int { u := l.CheckString(1) userID, err := uuid.FromString(u) if err != nil { - l.ArgError(1, "expects UserID to be a valid UUID") + l.ArgError(1, "expects user_id to be a valid UUID") return 0 } subject := l.CheckString(2) if subject == "" { - l.ArgError(2, "expects Subject to be non-empty") + l.ArgError(2, "expects subject to be a non-empty string") return 0 } @@ -1615,7 +1615,7 @@ func (n *NakamaModule) notificationSend(l *lua.LState) int { code := l.CheckInt(4) if code <= 0 { - l.ArgError(4, "expects Code to number above 0") + l.ArgError(4, "expects code to number above 0") return 0 } @@ -1624,7 +1624,7 @@ func (n *NakamaModule) notificationSend(l *lua.LState) int { if s != "" { suid, err := uuid.FromString(s) if err != nil { - l.ArgError(5, "expects senderID to either be not set, empty string or a valid UUID") + l.ArgError(5, "expects sender)id to either be not set, empty string or a valid UUID") return 0 } senderID = suid.String() @@ -1673,24 +1673,24 @@ func (n *NakamaModule) notificationsSend(l *lua.LState) int { userID := uuid.Nil notificationTable.ForEach(func(k lua.LValue, v lua.LValue) { switch k.String() { - case "Persistent": + case "persistent": if v.Type() != lua.LTBool { conversionError = true - l.ArgError(1, "expects Persistent to be boolean") + l.ArgError(1, "expects persistent to be boolean") return } notification.Persistent = lua.LVAsBool(v) - case "Subject": + case "subject": if v.Type() != lua.LTString { conversionError = true - l.ArgError(1, "expects Subject to be string") + l.ArgError(1, "expects subject to be string") return } notification.Subject = v.String() - case "Content": + case "content": if v.Type() != lua.LTTable { conversionError = true - l.ArgError(1, "expects Content to be a table") + l.ArgError(1, "expects content to be a table") return } @@ -1703,47 +1703,47 @@ func (n *NakamaModule) notificationsSend(l *lua.LState) int { } notification.Content = string(contentBytes) - case "Code": + case "code": if v.Type() != lua.LTNumber { conversionError = true - l.ArgError(1, "expects Code to be number") + l.ArgError(1, "expects code to be number") return } number := int(lua.LVAsNumber(v)) if number <= 0 { - l.ArgError(1, "expects Code to number above 0") + l.ArgError(1, "expects code to number above 0") return } notification.Code = int32(number) - case "UserId": + case "user_id": if v.Type() != lua.LTString { conversionError = true - l.ArgError(1, "expects UserId to be string") + l.ArgError(1, "expects user_id to be string") return } u := v.String() if u == "" { - l.ArgError(1, "expects UserId to be a valid UUID") + l.ArgError(1, "expects user_id to be a valid UUID") return } uid, err := uuid.FromString(u) if err != nil { - l.ArgError(1, "expects UserId to be a valid UUID") + l.ArgError(1, "expects user_id to be a valid UUID") return } userID = uid - case "SenderId": + case "sender_id": if v.Type() == lua.LTNil { return } if v.Type() != lua.LTString { conversionError = true - l.ArgError(1, "expects SenderId to be string") + l.ArgError(1, "expects sender_id to be string") return } u := v.String() if u == "" { - l.ArgError(1, "expects SenderId to be a valid UUID") + l.ArgError(1, "expects sender_id to be a valid UUID") return } notification.SenderId = u @@ -1751,16 +1751,16 @@ func (n *NakamaModule) notificationsSend(l *lua.LState) int { }) if notification.Subject == "" { - l.ArgError(1, "expects Subject to be non-empty") + l.ArgError(1, "expects subject to be non-empty") return } else if len(notification.Content) == 0 { - l.ArgError(1, "expects Content to be a valid JSON") + l.ArgError(1, "expects content to be a valid JSON") return } else if uuid.Equal(uuid.Nil, userID) { - l.ArgError(1, "expects UserId to be a valid UUID") + l.ArgError(1, "expects user_id to be a valid UUID") return } else if notification.Code == 0 { - l.ArgError(1, "expects Code to number above 0") + l.ArgError(1, "expects code to number above 0") return } diff --git a/tests/runtime_test.go b/tests/runtime_test.go index 15540143102ab062985b3dd0708213a6a326c4db..0f61685b416e44fc3875fb4c65d0405f3df50990 100644 --- a/tests/runtime_test.go +++ b/tests/runtime_test.go @@ -6,20 +6,26 @@ import ( "io/ioutil" "net/http" "os" - "path/filepath" "strings" "testing" + "fmt" + "github.com/heroiclabs/nakama/rtapi" "github.com/heroiclabs/nakama/server" + "github.com/yuin/gopher-lua" "go.uber.org/zap" "golang.org/x/crypto/bcrypt" + "sync" ) +type DummyMessageRouter struct{} + +func (d *DummyMessageRouter) SendToPresenceIDs(*zap.Logger, []*server.PresenceID, *rtapi.Envelope) {} +func (d *DummyMessageRouter) SendToStream(*zap.Logger, server.PresenceStream, *rtapi.Envelope) {} + var ( - tempDir, _ = ioutil.TempDir("", "nakama") - luaPath = filepath.Join(tempDir, "modules") - config = server.NewConfig() - logger = server.NewConsoleLogger(os.Stdout, true) + config = server.NewConfig() + logger = server.NewConsoleLogger(os.Stdout, true) ) func db(t *testing.T) *sql.DB { @@ -34,23 +40,30 @@ func db(t *testing.T) *sql.DB { return db } -func vm(t *testing.T) *server.RuntimePool { - config.Runtime.Path = luaPath - runtimePool, err := server.NewRuntimePool(logger, logger, db(t), config, nil, nil, nil) - if err != nil { - t.Error("Failed initializing runtime modules", zap.Error(err)) +func vm(t *testing.T, modules *sync.Map, regRPC map[string]struct{}) *server.RuntimePool { + stdLibs := map[string]lua.LGFunction{ + lua.LoadLibName: server.OpenPackage(modules), + lua.BaseLibName: lua.OpenBase, + lua.TabLibName: lua.OpenTable, + lua.OsLibName: server.OpenOs, + lua.StringLibName: lua.OpenString, + lua.MathLibName: lua.OpenMath, } + runtimePool := server.NewRuntimePool(logger, logger, db(t), config, nil, nil, nil, nil, &DummyMessageRouter{}, stdLibs, modules, regRPC, &sync.Once{}) return runtimePool } -func writeLuaModule(name, content string) { - os.MkdirAll(luaPath, os.ModePerm) - ioutil.WriteFile(filepath.Join(luaPath, name), []byte(content), 0644) +func writeLuaModule(modules *sync.Map, name, content string) { + modules.Store(name, &server.RuntimeModule{ + Name: name, + Path: fmt.Sprintf("%v.lua", name), + Content: []byte(content), + }) } -func writeStatsModule() { - writeLuaModule("stats.lua", ` +func writeStatsModule(modules *sync.Map) { + writeLuaModule(modules, "stats", ` stats={} -- Get the mean value of a table function stats.mean( t ) @@ -68,8 +81,8 @@ print("Stats Module Loaded") return stats`) } -func writeTestModule() { - writeLuaModule("test.lua", ` +func writeTestModule(modules *sync.Map) { + writeLuaModule(modules, "test", ` test={} -- Get the mean value of a table function test.printWorld() @@ -82,7 +95,7 @@ return test } func TestRuntimeSampleScript(t *testing.T) { - rp := vm(t) + rp := vm(t, new(sync.Map), make(map[string]struct{}, 0)) r := rp.Get() defer r.Stop() @@ -100,7 +113,7 @@ end`) } func TestRuntimeDisallowStandardLibs(t *testing.T) { - rp := vm(t) + rp := vm(t, new(sync.Map), make(map[string]struct{}, 0)) r := rp.Get() defer r.Stop() @@ -124,44 +137,44 @@ file_exists "./"`) // Have a look at the stdout messages to see if the module was loaded multiple times // You should only see "Test Module Loaded" once func TestRuntimeRequireEval(t *testing.T) { - defer os.RemoveAll(luaPath) - writeTestModule() - writeLuaModule("test-invoke.lua", ` + modules := new(sync.Map) + writeTestModule(modules) + writeLuaModule(modules, "test-invoke", ` local nakama = require("nakama") local test = require("test") test.printWorld() `) - vm(t) + vm(t, modules, make(map[string]struct{}, 0)) } func TestRuntimeRequireFile(t *testing.T) { - defer os.RemoveAll(luaPath) - writeStatsModule() - writeLuaModule("local_test.lua", ` + modules := new(sync.Map) + writeStatsModule(modules) + writeLuaModule(modules, "local_test", ` local stats = require("stats") t = {[1]=5, [2]=7, [3]=8, [4]='Something else.'} assert(stats.mean(t) > 0) `) - vm(t) + vm(t, modules, make(map[string]struct{}, 0)) } func TestRuntimeRequirePreload(t *testing.T) { - defer os.RemoveAll(luaPath) - writeStatsModule() - writeLuaModule("states-invoke.lua", ` + modules := new(sync.Map) + writeStatsModule(modules) + writeLuaModule(modules, "states-invoke", ` local stats = require("stats") t = {[1]=5, [2]=7, [3]=8, [4]='Something else.'} print(stats.mean(t)) `) - vm(t) + vm(t, modules, make(map[string]struct{}, 0)) } func TestRuntimeRegisterRPCWithPayload(t *testing.T) { - defer os.RemoveAll(luaPath) - writeLuaModule("test.lua", ` + modules := new(sync.Map) + writeLuaModule(modules, "test", ` test={} -- Get the mean value of a table function test.printWorld(ctx, payload) @@ -172,20 +185,20 @@ end print("Test Module Loaded") return test `) - writeLuaModule("http-invoke.lua", ` + writeLuaModule(modules, "http-invoke", ` local nakama = require("nakama") local test = require("test") nakama.register_rpc(test.printWorld, "helloworld") `) - rp := vm(t) + rp := vm(t, modules, map[string]struct{}{"helloworld": struct{}{}}) r := rp.Get() defer r.Stop() fn := r.GetRuntimeCallback(server.RPC, "helloworld") payload := "Hello World" - m, err := r.InvokeFunctionRPC(fn, "", "", 0, "", payload) + m, err, _ := r.InvokeFunctionRPC(fn, "", "", 0, "", payload) if err != nil { t.Error(err) } @@ -196,8 +209,8 @@ nakama.register_rpc(test.printWorld, "helloworld") } func TestRuntimeRegisterRPCWithPayloadEndToEnd(t *testing.T) { - defer os.RemoveAll(luaPath) - writeLuaModule("test.lua", ` + modules := new(sync.Map) + writeLuaModule(modules, "test", ` test={} -- Get the mean value of a table function test.printWorld(ctx, payload) @@ -208,23 +221,23 @@ end print("Test Module Loaded") return test `) - writeLuaModule("http-invoke.lua", ` + writeLuaModule(modules, "http-invoke", ` local nakama = require("nakama") local test = require("test") nakama.register_rpc(test.printWorld, "helloworld") `) - rp := vm(t) + rp := vm(t, modules, map[string]struct{}{"helloworld": struct{}{}}) r := rp.Get() defer r.Stop() - pipeline := server.NewPipeline(config, nil, nil, nil, nil, rp) - apiServer := server.StartApiServer(logger, nil, nil, nil, config, nil, nil, nil, pipeline, rp) + pipeline := server.NewPipeline(config, nil, nil, nil, nil, nil, rp) + apiServer := server.StartApiServer(logger, nil, nil, nil, config, nil, nil, nil, nil, nil, pipeline, rp) defer apiServer.Stop() payload := "\"Hello World\"" client := &http.Client{} - request, _ := http.NewRequest("POST", "http://localhost:7351/v2/rpc/helloworld?http_key=defaultkey", strings.NewReader(payload)) + request, _ := http.NewRequest("POST", "http://localhost:7349/v2/rpc/helloworld?http_key=defaultkey", strings.NewReader(payload)) request.Header.Add("Content-Type", "Application/JSON") res, err := client.Do(request) if err != nil { @@ -243,8 +256,8 @@ nakama.register_rpc(test.printWorld, "helloworld") } func TestRuntimeHTTPRequest(t *testing.T) { - defer os.RemoveAll(luaPath) - writeLuaModule("test.lua", ` + modules := new(sync.Map) + writeLuaModule(modules, "test", ` local nakama = require("nakama") function test(ctx, payload) local success, code, headers, body = pcall(nakama.http_request, "http://httpbin.org/status/200", "GET", {}) @@ -253,12 +266,12 @@ end nakama.register_rpc(test, "test") `) - rp := vm(t) + rp := vm(t, modules, map[string]struct{}{"test": struct{}{}}) r := rp.Get() defer r.Stop() fn := r.GetRuntimeCallback(server.RPC, "test") - m, err := r.InvokeFunctionRPC(fn, "", "", 0, "", "") + m, err, _ := r.InvokeFunctionRPC(fn, "", "", 0, "", "") if err != nil { t.Error(err) } @@ -269,8 +282,8 @@ nakama.register_rpc(test, "test") } func TestRuntimeJson(t *testing.T) { - defer os.RemoveAll(luaPath) - writeLuaModule("test.lua", ` + modules := new(sync.Map) + writeLuaModule(modules, "test", ` local nakama = require("nakama") function test(ctx, payload) return nakama.json_encode(nakama.json_decode(payload)) @@ -278,13 +291,13 @@ end nakama.register_rpc(test, "test") `) - rp := vm(t) + rp := vm(t, modules, map[string]struct{}{"test": struct{}{}}) r := rp.Get() defer r.Stop() payload := "{\"key\":\"value\"}" fn := r.GetRuntimeCallback(server.RPC, "test") - m, err := r.InvokeFunctionRPC(fn, "", "", 0, "", payload) + m, err, _ := r.InvokeFunctionRPC(fn, "", "", 0, "", payload) if err != nil { t.Error(err) } @@ -295,8 +308,8 @@ nakama.register_rpc(test, "test") } func TestRuntimeBase64(t *testing.T) { - defer os.RemoveAll(luaPath) - writeLuaModule("test.lua", ` + modules := new(sync.Map) + writeLuaModule(modules, "test", ` local nakama = require("nakama") function test(ctx, payload) return nakama.base64_decode(nakama.base64_encode(payload)) @@ -304,13 +317,13 @@ end nakama.register_rpc(test, "test") `) - rp := vm(t) + rp := vm(t, modules, map[string]struct{}{"test": struct{}{}}) r := rp.Get() defer r.Stop() payload := "{\"key\":\"value\"}" fn := r.GetRuntimeCallback(server.RPC, "test") - m, err := r.InvokeFunctionRPC(fn, "", "", 0, "", payload) + m, err, _ := r.InvokeFunctionRPC(fn, "", "", 0, "", payload) if err != nil { t.Error(err) } @@ -321,8 +334,8 @@ nakama.register_rpc(test, "test") } func TestRuntimeBase16(t *testing.T) { - defer os.RemoveAll(luaPath) - writeLuaModule("test.lua", ` + modules := new(sync.Map) + writeLuaModule(modules, "test", ` local nakama = require("nakama") function test(ctx, payload) return nakama.base16_decode(nakama.base16_encode(payload)) @@ -330,13 +343,13 @@ end nakama.register_rpc(test, "test") `) - rp := vm(t) + rp := vm(t, modules, map[string]struct{}{"test": struct{}{}}) r := rp.Get() defer r.Stop() payload := "{\"key\":\"value\"}" fn := r.GetRuntimeCallback(server.RPC, "test") - m, err := r.InvokeFunctionRPC(fn, "", "", 0, "", payload) + m, err, _ := r.InvokeFunctionRPC(fn, "", "", 0, "", payload) if err != nil { t.Error(err) } @@ -347,8 +360,8 @@ nakama.register_rpc(test, "test") } func TestRuntimeAes128(t *testing.T) { - defer os.RemoveAll(luaPath) - writeLuaModule("test.lua", ` + modules := new(sync.Map) + writeLuaModule(modules, "test", ` local nakama = require("nakama") function test(ctx, payload) return nakama.aes128_decrypt(nakama.aes128_encrypt(payload, "goldenbridge_key"), "goldenbridge_key") @@ -356,13 +369,13 @@ end nakama.register_rpc(test, "test") `) - rp := vm(t) + rp := vm(t, modules, map[string]struct{}{"test": struct{}{}}) r := rp.Get() defer r.Stop() payload := "{\"key\":\"value\"}" fn := r.GetRuntimeCallback(server.RPC, "test") - m, err := r.InvokeFunctionRPC(fn, "", "", 0, "", payload) + m, err, _ := r.InvokeFunctionRPC(fn, "", "", 0, "", payload) if err != nil { t.Error(err) } @@ -373,8 +386,8 @@ nakama.register_rpc(test, "test") } func TestRuntimeBcryptHash(t *testing.T) { - defer os.RemoveAll(luaPath) - writeLuaModule("test.lua", ` + modules := new(sync.Map) + writeLuaModule(modules, "test", ` local nakama = require("nakama") function test(ctx, payload) return nakama.bcrypt_hash(payload) @@ -382,13 +395,13 @@ end nakama.register_rpc(test, "test") `) - rp := vm(t) + rp := vm(t, modules, map[string]struct{}{"test": struct{}{}}) r := rp.Get() defer r.Stop() payload := "{\"key\":\"value\"}" fn := r.GetRuntimeCallback(server.RPC, "test") - m, err := r.InvokeFunctionRPC(fn, "", "", 0, "", payload) + m, err, _ := r.InvokeFunctionRPC(fn, "", "", 0, "", payload) if err != nil { t.Error(err) } @@ -400,8 +413,8 @@ nakama.register_rpc(test, "test") } func TestRuntimeBcryptCompare(t *testing.T) { - defer os.RemoveAll(luaPath) - writeLuaModule("test.lua", ` + modules := new(sync.Map) + writeLuaModule(modules, "test", ` local nakama = require("nakama") function test(ctx, payload) return tostring(nakama.bcrypt_compare(payload, "something_to_encrypt")) @@ -409,14 +422,14 @@ end nakama.register_rpc(test, "test") `) - rp := vm(t) + rp := vm(t, modules, map[string]struct{}{"test": struct{}{}}) r := rp.Get() defer r.Stop() payload := "something_to_encrypt" hash, _ := bcrypt.GenerateFromPassword([]byte(payload), bcrypt.DefaultCost) fn := r.GetRuntimeCallback(server.RPC, "test") - m, err := r.InvokeFunctionRPC(fn, "", "", 0, "", string(hash)) + m, err, _ := r.InvokeFunctionRPC(fn, "", "", 0, "", string(hash)) if err != nil { t.Error(err) } @@ -427,8 +440,8 @@ nakama.register_rpc(test, "test") } func TestRuntimeNotificationsSend(t *testing.T) { - defer os.RemoveAll(luaPath) - writeLuaModule("test.lua", ` + modules := new(sync.Map) + writeLuaModule(modules, "test", ` local nk = require("nakama") local subject = "You've unlocked level 100!" @@ -439,19 +452,19 @@ local user_id = "4c2ae592-b2a7-445e-98ec-697694478b1c" -- who to send local code = 1 local new_notifications = { - { Subject = subject, Content = content, UserId = user_id, Code = code, Persistent = false} + { subject = subject, content = content, user_id = user_id, code = code, persistent = false} } nk.notifications_send(new_notifications) `) - rp := vm(t) + rp := vm(t, modules, make(map[string]struct{}, 0)) r := rp.Get() defer r.Stop() } func TestRuntimeNotificationSend(t *testing.T) { - defer os.RemoveAll(luaPath) - writeLuaModule("test.lua", ` + modules := new(sync.Map) + writeLuaModule(modules, "test", ` local nk = require("nakama") local subject = "You've unlocked level 100!" @@ -464,14 +477,14 @@ local code = 1 nk.notification_send(user_id, subject, content, code, "", false) `) - rp := vm(t) + rp := vm(t, modules, make(map[string]struct{}, 0)) r := rp.Get() defer r.Stop() } func TestRuntimeWalletWrite(t *testing.T) { - defer os.RemoveAll(luaPath) - writeLuaModule("test.lua", ` + modules := new(sync.Map) + writeLuaModule(modules, "test", ` local nk = require("nakama") local content = { @@ -482,7 +495,7 @@ local user_id = "95f05d94-cc66-445a-b4d1-9e262662cf79" -- who to send nk.wallet_write(user_id, content) `) - rp := vm(t) + rp := vm(t, modules, make(map[string]struct{}, 0)) r := rp.Get() defer r.Stop() }