From 1174be795ccdf39343ac136d256a93dfb1c326b5 Mon Sep 17 00:00:00 2001 From: Andrei Mihu Date: Fri, 28 Dec 2018 20:56:18 +0000 Subject: [PATCH] Improve Lua runtime context cancellation when waiting for available runtimes. --- CHANGELOG.md | 1 + server/leaderboard_scheduler.go | 6 +- server/pipeline_matchmaker.go | 3 +- server/runtime.go | 8 +- server/runtime_go.go | 16 +- server/runtime_lua.go | 288 ++++++++++++++++++-------------- 6 files changed, 179 insertions(+), 143 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c0e007e55..efedcf264 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,6 +43,7 @@ The format is based on [keep a changelog](http://keepachangelog.com) and this pr - Check group max allowed user when promoting a user. - Correct Lua runtime decoding of stream identifying parameters. - Correctly use optional parameters when they are passed to group creation operations. +- Lua runtime operations now observe context cancellation while waiting for an available Lua instance. ## [2.2.1] - 2018-11-20 ### Added diff --git a/server/leaderboard_scheduler.go b/server/leaderboard_scheduler.go index 875e5ff52..1be261af9 100644 --- a/server/leaderboard_scheduler.go +++ b/server/leaderboard_scheduler.go @@ -277,7 +277,7 @@ WHERE id = $1` // Trigger callback on a goroutine so any extended processing does not block future scheduling. go func() { - if err := fn(tournament, int64(tournament.EndActive), int64(tournament.NextReset)); err != nil { + if err := fn(ls.ctx, tournament, int64(tournament.EndActive), int64(tournament.NextReset)); err != nil { ls.logger.Warn("Failed to invoke tournament end callback", zap.Error(err)) } }() @@ -327,7 +327,7 @@ WHERE id = $1` if fnTournamentReset != nil { // Trigger callback on a goroutine so any extended processing does not block future scheduling. go func() { - if err := fnTournamentReset(tournament, int64(tournament.EndActive), int64(tournament.NextReset)); err != nil { + if err := fnTournamentReset(ls.ctx, tournament, int64(tournament.EndActive), int64(tournament.NextReset)); err != nil { ls.logger.Warn("Failed to invoke tournament reset callback", zap.Error(err)) } }() @@ -342,7 +342,7 @@ WHERE id = $1` // Trigger callback on a goroutine so any extended processing does not block future scheduling. go func() { - if err := fnLeaderboardReset(leaderboardOrTournament, nextReset); err != nil { + if err := fnLeaderboardReset(ls.ctx, leaderboardOrTournament, nextReset); err != nil { ls.logger.Warn("Failed to invoke leaderboard reset callback", zap.Error(err)) } }() diff --git a/server/pipeline_matchmaker.go b/server/pipeline_matchmaker.go index d40a0c7f3..9aa92f051 100644 --- a/server/pipeline_matchmaker.go +++ b/server/pipeline_matchmaker.go @@ -15,6 +15,7 @@ package server import ( + "context" "fmt" "time" @@ -79,7 +80,7 @@ func (p *Pipeline) matchmakerAdd(logger *zap.Logger, session Session, envelope * // Check if there's a matchmaker matched runtime callback, call it, and see if it returns a match ID. fn := p.runtime.MatchmakerMatched() if fn != nil { - tokenOrMatchID, isMatchID, err = fn(entries) + tokenOrMatchID, isMatchID, err = fn(context.Background(), entries) if err != nil { p.logger.Error("Error running Matchmaker Matched hook.", zap.Error(err)) } diff --git a/server/runtime.go b/server/runtime.go index b91c3b9f5..dc6088b41 100644 --- a/server/runtime.go +++ b/server/runtime.go @@ -161,15 +161,15 @@ type ( RuntimeBeforeGetUsersFunction func(ctx context.Context, logger *zap.Logger, userID, username string, expiry int64, clientIP, clientPort string, in *api.GetUsersRequest) (*api.GetUsersRequest, error, codes.Code) RuntimeAfterGetUsersFunction func(ctx context.Context, logger *zap.Logger, userID, username string, expiry int64, clientIP, clientPort string, out *api.Users, in *api.GetUsersRequest) error - RuntimeMatchmakerMatchedFunction func(entries []*MatchmakerEntry) (string, bool, error) + RuntimeMatchmakerMatchedFunction func(ctx context.Context, entries []*MatchmakerEntry) (string, bool, error) RuntimeMatchCreateFunction func(ctx context.Context, logger *zap.Logger, id uuid.UUID, node string, name string) (RuntimeMatchCore, error) RuntimeMatchDeferMessageFunction func(msg *DeferredMessage) error - RuntimeTournamentEndFunction func(tournament *api.Tournament, end, reset int64) error - RuntimeTournamentResetFunction func(tournament *api.Tournament, end, reset int64) error + RuntimeTournamentEndFunction func(ctx context.Context, tournament *api.Tournament, end, reset int64) error + RuntimeTournamentResetFunction func(ctx context.Context, tournament *api.Tournament, end, reset int64) error - RuntimeLeaderboardResetFunction func(leaderboard runtime.Leaderboard, reset int64) error + RuntimeLeaderboardResetFunction func(ctx context.Context, leaderboard runtime.Leaderboard, reset int64) error ) type RuntimeExecutionMode int diff --git a/server/runtime_go.go b/server/runtime_go.go index 21465b695..e709ad1d6 100644 --- a/server/runtime_go.go +++ b/server/runtime_go.go @@ -1690,8 +1690,8 @@ func (ri *RuntimeGoInitializer) RegisterAfterGetUsers(fn func(ctx context.Contex } func (ri *RuntimeGoInitializer) RegisterMatchmakerMatched(fn func(ctx context.Context, logger runtime.Logger, db *sql.DB, nk runtime.NakamaModule, entries []runtime.MatchmakerEntry) (string, error)) error { - ri.matchmakerMatched = func(entries []*MatchmakerEntry) (string, bool, error) { - ctx := NewRuntimeGoContext(context.Background(), ri.env, RuntimeExecutionModeMatchmaker, nil, 0, "", "", "", "", "") + ri.matchmakerMatched = func(ctx context.Context, entries []*MatchmakerEntry) (string, bool, error) { + ctx = NewRuntimeGoContext(ctx, ri.env, RuntimeExecutionModeMatchmaker, nil, 0, "", "", "", "", "") runtimeEntries := make([]runtime.MatchmakerEntry, len(entries)) for i, entry := range entries { runtimeEntries[i] = runtime.MatchmakerEntry(entry) @@ -1706,24 +1706,24 @@ func (ri *RuntimeGoInitializer) RegisterMatchmakerMatched(fn func(ctx context.Co } func (ri *RuntimeGoInitializer) RegisterTournamentEnd(fn func(ctx context.Context, logger runtime.Logger, db *sql.DB, nk runtime.NakamaModule, tournament *api.Tournament, end, reset int64) error) error { - ri.tournamentEnd = func(tournament *api.Tournament, end, reset int64) error { - ctx := NewRuntimeGoContext(context.Background(), ri.env, RuntimeExecutionModeTournamentEnd, nil, 0, "", "", "", "", "") + ri.tournamentEnd = func(ctx context.Context, tournament *api.Tournament, end, reset int64) error { + ctx = NewRuntimeGoContext(ctx, ri.env, RuntimeExecutionModeTournamentEnd, nil, 0, "", "", "", "", "") return fn(ctx, ri.logger, ri.db, ri.nk, tournament, end, reset) } return nil } func (ri *RuntimeGoInitializer) RegisterTournamentReset(fn func(ctx context.Context, logger runtime.Logger, db *sql.DB, nk runtime.NakamaModule, tournament *api.Tournament, end, reset int64) error) error { - ri.tournamentReset = func(tournament *api.Tournament, end, reset int64) error { - ctx := NewRuntimeGoContext(context.Background(), ri.env, RuntimeExecutionModeTournamentReset, nil, 0, "", "", "", "", "") + ri.tournamentReset = func(ctx context.Context, tournament *api.Tournament, end, reset int64) error { + ctx = NewRuntimeGoContext(ctx, ri.env, RuntimeExecutionModeTournamentReset, nil, 0, "", "", "", "", "") return fn(ctx, ri.logger, ri.db, ri.nk, tournament, end, reset) } return nil } func (ri *RuntimeGoInitializer) RegisterLeaderboardReset(fn func(ctx context.Context, logger runtime.Logger, db *sql.DB, nk runtime.NakamaModule, leaderboard runtime.Leaderboard, reset int64) error) error { - ri.leaderboardReset = func(leaderboard runtime.Leaderboard, reset int64) error { - ctx := NewRuntimeGoContext(context.Background(), ri.env, RuntimeExecutionModeLeaderboardReset, nil, 0, "", "", "", "", "") + ri.leaderboardReset = func(ctx context.Context, leaderboard runtime.Leaderboard, reset int64) error { + ctx = NewRuntimeGoContext(ctx, ri.env, RuntimeExecutionModeLeaderboardReset, nil, 0, "", "", "", "", "") return fn(ctx, ri.logger, ri.db, ri.nk, leaderboard, reset) } return nil diff --git a/server/runtime_lua.go b/server/runtime_lua.go index 16a1fef82..b06c3057d 100644 --- a/server/runtime_lua.go +++ b/server/runtime_lua.go @@ -21,6 +21,7 @@ import ( "encoding/json" "errors" "fmt" + "go.uber.org/atomic" "io/ioutil" "os" "path/filepath" @@ -83,7 +84,6 @@ func (mc *RuntimeLuaModuleCache) Add(m *RuntimeLuaModule) { } type RuntimeProviderLua struct { - sync.Mutex logger *zap.Logger db *sql.DB jsonpbMarshaler *jsonpb.Marshaler @@ -100,8 +100,8 @@ type RuntimeProviderLua struct { once *sync.Once poolCh chan *RuntimeLua - maxCount int - currentCount int + maxCount uint32 + currentCount *atomic.Uint32 newFn func() *RuntimeLua statsCtx context.Context @@ -117,7 +117,10 @@ func NewRuntimeProviderLua(logger, startupLogger *zap.Logger, db *sql.DB, jsonpb // Override before Package library is invoked. lua.LuaLDir = rootPath lua.LuaPathDefault = lua.LuaLDir + string(os.PathSeparator) + "?.lua;" + lua.LuaLDir + string(os.PathSeparator) + "?" + string(os.PathSeparator) + "init.lua" - os.Setenv(lua.LuaPath, lua.LuaPathDefault) + if err := os.Setenv(lua.LuaPath, lua.LuaPathDefault); err != nil { + startupLogger.Error("Could not set Lua module path", zap.Error(err)) + return nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, err + } startupLogger.Info("Initialising Lua runtime provider", zap.String("path", lua.LuaLDir)) @@ -196,9 +199,9 @@ func NewRuntimeProviderLua(logger, startupLogger *zap.Logger, db *sql.DB, jsonpb once: once, poolCh: make(chan *RuntimeLua, config.GetRuntime().MaxCount), - maxCount: config.GetRuntime().MaxCount, + maxCount: uint32(config.GetRuntime().MaxCount), // Set the current count assuming we'll warm up the pool in a moment. - currentCount: config.GetRuntime().MinCount, + currentCount: atomic.NewUint32(uint32(config.GetRuntime().MinCount)), newFn: func() *RuntimeLua { r, err := newRuntimeLuaVM(logger, db, jsonpbUnmarshaler, config, socialClient, leaderboardCache, leaderboardRankCache, leaderboardScheduler, sessionRegistry, matchRegistry, tracker, router, stdLibs, moduleCache, once, localCache, allMatchCreateFn, nil) if err != nil { @@ -923,20 +926,20 @@ func NewRuntimeProviderLua(logger, startupLogger *zap.Logger, db *sql.DB, jsonpb } } case RuntimeExecutionModeMatchmaker: - matchmakerMatchedFunction = func(entries []*MatchmakerEntry) (string, bool, error) { - return runtimeProviderLua.MatchmakerMatched(entries) + matchmakerMatchedFunction = func(ctx context.Context, entries []*MatchmakerEntry) (string, bool, error) { + return runtimeProviderLua.MatchmakerMatched(ctx, entries) } case RuntimeExecutionModeTournamentEnd: - tournamentEndFunction = func(tournament *api.Tournament, end, reset int64) error { - return runtimeProviderLua.TournamentEnd(tournament, end, reset) + tournamentEndFunction = func(ctx context.Context, tournament *api.Tournament, end, reset int64) error { + return runtimeProviderLua.TournamentEnd(ctx, tournament, end, reset) } case RuntimeExecutionModeTournamentReset: - tournamentResetFunction = func(tournament *api.Tournament, end, reset int64) error { - return runtimeProviderLua.TournamentReset(tournament, end, reset) + tournamentResetFunction = func(ctx context.Context, tournament *api.Tournament, end, reset int64) error { + return runtimeProviderLua.TournamentReset(ctx, tournament, end, reset) } case RuntimeExecutionModeLeaderboardReset: - leaderboardResetFunction = func(leaderboard runtime.Leaderboard, reset int64) error { - return runtimeProviderLua.LeaderboardReset(leaderboard, reset) + leaderboardResetFunction = func(ctx context.Context, leaderboard runtime.Leaderboard, reset int64) error { + return runtimeProviderLua.LeaderboardReset(ctx, leaderboard, reset) } } }) @@ -948,7 +951,7 @@ func NewRuntimeProviderLua(logger, startupLogger *zap.Logger, db *sql.DB, jsonpb startupLogger.Info("Lua runtime modules loaded") // Warm up the pool. - startupLogger.Info("Allocating minimum runtime pool", zap.Int("count", runtimeProviderLua.currentCount)) + startupLogger.Info("Allocating minimum runtime pool", zap.Int("count", config.GetRuntime().MinCount)) if len(moduleCache.Names) > 0 { // Only if there are runtime modules to load. for i := 0; i < config.GetRuntime().MinCount; i++ { @@ -962,17 +965,20 @@ func NewRuntimeProviderLua(logger, startupLogger *zap.Logger, db *sql.DB, jsonpb } func (rp *RuntimeProviderLua) Rpc(ctx context.Context, id string, queryParams map[string][]string, userID, username string, expiry int64, sessionID, clientIP, clientPort, payload string) (string, error, codes.Code) { - runtime := rp.Get() - lf := runtime.GetCallback(RuntimeExecutionModeRPC, id) + r, err := rp.Get(ctx) + if err != nil { + return "", err, codes.Internal + } + lf := r.GetCallback(RuntimeExecutionModeRPC, id) if lf == nil { - rp.Put(runtime) + rp.Put(r) return "", ErrRuntimeRPCNotFound, codes.NotFound } - runtime.vm.SetContext(ctx) - result, fnErr, code := runtime.InvokeFunction(RuntimeExecutionModeRPC, lf, queryParams, userID, username, expiry, sessionID, clientIP, clientPort, payload) - runtime.vm.SetContext(context.Background()) - rp.Put(runtime) + r.vm.SetContext(ctx) + result, fnErr, code := r.InvokeFunction(RuntimeExecutionModeRPC, lf, queryParams, userID, username, expiry, sessionID, clientIP, clientPort, payload) + r.vm.SetContext(context.Background()) + rp.Put(r) if fnErr != nil { rp.logger.Error("Runtime RPC function caused an error", zap.String("id", id), zap.Error(fnErr)) @@ -1012,30 +1018,33 @@ func (rp *RuntimeProviderLua) Rpc(ctx context.Context, id string, queryParams ma } func (rp *RuntimeProviderLua) BeforeRt(ctx context.Context, id string, logger *zap.Logger, userID, username string, expiry int64, sessionID, clientIP, clientPort string, envelope *rtapi.Envelope) (*rtapi.Envelope, error) { - runtime := rp.Get() - lf := runtime.GetCallback(RuntimeExecutionModeBefore, id) + r, err := rp.Get(ctx) + if err != nil { + return nil, err + } + lf := r.GetCallback(RuntimeExecutionModeBefore, id) if lf == nil { - rp.Put(runtime) + rp.Put(r) return nil, errors.New("Runtime Before function not found.") } envelopeJSON, err := rp.jsonpbMarshaler.MarshalToString(envelope) if err != nil { - rp.Put(runtime) + rp.Put(r) logger.Error("Could not marshall envelope to JSON", zap.Any("envelope", envelope), zap.Error(err)) return nil, errors.New("Could not run runtime Before function.") } var envelopeMap map[string]interface{} if err := json.Unmarshal([]byte(envelopeJSON), &envelopeMap); err != nil { - rp.Put(runtime) + rp.Put(r) logger.Error("Could not unmarshall envelope to interface{}", zap.Any("envelope_json", envelopeJSON), zap.Error(err)) return nil, errors.New("Could not run runtime Before function.") } - runtime.vm.SetContext(ctx) - result, fnErr, _ := runtime.InvokeFunction(RuntimeExecutionModeBefore, lf, nil, userID, username, expiry, sessionID, clientIP, clientPort, envelopeMap) - runtime.vm.SetContext(context.Background()) - rp.Put(runtime) + r.vm.SetContext(ctx) + result, fnErr, _ := r.InvokeFunction(RuntimeExecutionModeBefore, lf, nil, userID, username, expiry, sessionID, clientIP, clientPort, envelopeMap) + r.vm.SetContext(context.Background()) + rp.Put(r) if fnErr != nil { logger.Error("Runtime Before function caused an error.", zap.String("id", id), zap.Error(fnErr)) @@ -1075,30 +1084,33 @@ func (rp *RuntimeProviderLua) BeforeRt(ctx context.Context, id string, logger *z } func (rp *RuntimeProviderLua) AfterRt(ctx context.Context, id string, logger *zap.Logger, userID, username string, expiry int64, sessionID, clientIP, clientPort string, envelope *rtapi.Envelope) error { - runtime := rp.Get() - lf := runtime.GetCallback(RuntimeExecutionModeAfter, id) + r, err := rp.Get(ctx) + if err != nil { + return err + } + lf := r.GetCallback(RuntimeExecutionModeAfter, id) if lf == nil { - rp.Put(runtime) + rp.Put(r) return errors.New("Runtime After function not found.") } envelopeJSON, err := rp.jsonpbMarshaler.MarshalToString(envelope) if err != nil { - rp.Put(runtime) + rp.Put(r) logger.Error("Could not marshall envelope to JSON", zap.Any("envelope", envelope), zap.Error(err)) return errors.New("Could not run runtime After function.") } var envelopeMap map[string]interface{} if err := json.Unmarshal([]byte(envelopeJSON), &envelopeMap); err != nil { - rp.Put(runtime) + rp.Put(r) logger.Error("Could not unmarshall envelope to interface{}", zap.Any("envelope_json", envelopeJSON), zap.Error(err)) return errors.New("Could not run runtime After function.") } - runtime.vm.SetContext(ctx) - _, fnErr, _ := runtime.InvokeFunction(RuntimeExecutionModeAfter, lf, nil, userID, username, expiry, sessionID, clientIP, clientPort, envelopeMap) - runtime.vm.SetContext(context.Background()) - rp.Put(runtime) + r.vm.SetContext(ctx) + _, fnErr, _ := r.InvokeFunction(RuntimeExecutionModeAfter, lf, nil, userID, username, expiry, sessionID, clientIP, clientPort, envelopeMap) + r.vm.SetContext(context.Background()) + rp.Put(r) if fnErr != nil { logger.Error("Runtime After function caused an error.", zap.String("id", id), zap.Error(fnErr)) @@ -1123,10 +1135,13 @@ func (rp *RuntimeProviderLua) AfterRt(ctx context.Context, id string, logger *za } func (rp *RuntimeProviderLua) BeforeReq(ctx context.Context, id string, logger *zap.Logger, userID, username string, expiry int64, clientIP, clientPort string, req interface{}) (interface{}, error, codes.Code) { - runtime := rp.Get() - lf := runtime.GetCallback(RuntimeExecutionModeBefore, id) + r, err := rp.Get(ctx) + if err != nil { + return nil, err, codes.Internal + } + lf := r.GetCallback(RuntimeExecutionModeBefore, id) if lf == nil { - rp.Put(runtime) + rp.Put(r) return nil, errors.New("Runtime Before function not found."), codes.NotFound } @@ -1137,27 +1152,27 @@ func (rp *RuntimeProviderLua) BeforeReq(ctx context.Context, id string, logger * var ok bool reqProto, ok = req.(proto.Message) if !ok { - rp.Put(runtime) + rp.Put(r) logger.Error("Could not cast request to message", zap.Any("request", req)) return nil, errors.New("Could not run runtime Before function."), codes.Internal } reqJSON, err := rp.jsonpbMarshaler.MarshalToString(reqProto) if err != nil { - rp.Put(runtime) + rp.Put(r) logger.Error("Could not marshall request to JSON", zap.Any("request", reqProto), zap.Error(err)) return nil, errors.New("Could not run runtime Before function."), codes.Internal } if err := json.Unmarshal([]byte(reqJSON), &reqMap); err != nil { - rp.Put(runtime) + rp.Put(r) logger.Error("Could not unmarshall request to interface{}", zap.Any("request_json", reqJSON), zap.Error(err)) return nil, errors.New("Could not run runtime Before function."), codes.Internal } } - runtime.vm.SetContext(ctx) - result, fnErr, code := runtime.InvokeFunction(RuntimeExecutionModeBefore, lf, nil, userID, username, expiry, "", clientIP, clientPort, reqMap) - runtime.vm.SetContext(context.Background()) - rp.Put(runtime) + r.vm.SetContext(ctx) + result, fnErr, code := r.InvokeFunction(RuntimeExecutionModeBefore, lf, nil, userID, username, expiry, "", clientIP, clientPort, reqMap) + r.vm.SetContext(context.Background()) + rp.Put(r) if fnErr != nil { logger.Error("Runtime Before function caused an error.", zap.String("id", id), zap.Error(fnErr)) @@ -1198,10 +1213,13 @@ func (rp *RuntimeProviderLua) BeforeReq(ctx context.Context, id string, logger * } func (rp *RuntimeProviderLua) AfterReq(ctx context.Context, id string, logger *zap.Logger, userID, username string, expiry int64, clientIP, clientPort string, res interface{}, req interface{}) error { - runtime := rp.Get() - lf := runtime.GetCallback(RuntimeExecutionModeAfter, id) + r, err := rp.Get(ctx) + if err != nil { + return err + } + lf := r.GetCallback(RuntimeExecutionModeAfter, id) if lf == nil { - rp.Put(runtime) + rp.Put(r) return errors.New("Runtime After function not found.") } @@ -1210,19 +1228,19 @@ func (rp *RuntimeProviderLua) AfterReq(ctx context.Context, id string, logger *z // Res may be nil if there is no response body. resProto, ok := res.(proto.Message) if !ok { - rp.Put(runtime) + rp.Put(r) logger.Error("Could not cast response to message", zap.Any("response", res)) return errors.New("Could not run runtime After function.") } resJSON, err := rp.jsonpbMarshaler.MarshalToString(resProto) if err != nil { - rp.Put(runtime) + rp.Put(r) logger.Error("Could not marshall response to JSON", zap.Any("response", resProto), zap.Error(err)) return errors.New("Could not run runtime After function.") } if err := json.Unmarshal([]byte(resJSON), &resMap); err != nil { - rp.Put(runtime) + rp.Put(r) logger.Error("Could not unmarshall response to interface{}", zap.Any("response_json", resJSON), zap.Error(err)) return errors.New("Could not run runtime After function.") } @@ -1233,28 +1251,28 @@ func (rp *RuntimeProviderLua) AfterReq(ctx context.Context, id string, logger *z // Req may be nil if there is no request body. reqProto, ok := req.(proto.Message) if !ok { - rp.Put(runtime) + rp.Put(r) logger.Error("Could not cast request to message", zap.Any("request", req)) return errors.New("Could not run runtime After function.") } reqJSON, err := rp.jsonpbMarshaler.MarshalToString(reqProto) if err != nil { - rp.Put(runtime) + rp.Put(r) logger.Error("Could not marshall request to JSON", zap.Any("request", reqProto), zap.Error(err)) return errors.New("Could not run runtime After function.") } if err := json.Unmarshal([]byte(reqJSON), &reqMap); err != nil { - rp.Put(runtime) + rp.Put(r) logger.Error("Could not unmarshall request to interface{}", zap.Any("request_json", reqJSON), zap.Error(err)) return errors.New("Could not run runtime After function.") } } - runtime.vm.SetContext(ctx) - _, fnErr, _ := runtime.InvokeFunction(RuntimeExecutionModeAfter, lf, nil, userID, username, expiry, "", clientIP, clientPort, resMap, reqMap) - runtime.vm.SetContext(context.Background()) - rp.Put(runtime) + r.vm.SetContext(ctx) + _, fnErr, _ := r.InvokeFunction(RuntimeExecutionModeAfter, lf, nil, userID, username, expiry, "", clientIP, clientPort, resMap, reqMap) + r.vm.SetContext(context.Background()) + rp.Put(r) if fnErr != nil { logger.Error("Runtime After function caused an error.", zap.String("id", id), zap.Error(fnErr)) @@ -1278,25 +1296,28 @@ func (rp *RuntimeProviderLua) AfterReq(ctx context.Context, id string, logger *z return nil } -func (rp *RuntimeProviderLua) MatchmakerMatched(entries []*MatchmakerEntry) (string, bool, error) { - runtime := rp.Get() - lf := runtime.GetCallback(RuntimeExecutionModeMatchmaker, "") +func (rp *RuntimeProviderLua) MatchmakerMatched(ctx context.Context, entries []*MatchmakerEntry) (string, bool, error) { + r, err := rp.Get(ctx) + if err != nil { + return "", false, err + } + lf := r.GetCallback(RuntimeExecutionModeMatchmaker, "") if lf == nil { - rp.Put(runtime) + rp.Put(r) return "", false, errors.New("Runtime Matchmaker Matched function not found.") } - ctx := NewRuntimeLuaContext(runtime.vm, runtime.luaEnv, RuntimeExecutionModeMatchmaker, nil, 0, "", "", "", "", "") + luaCtx := NewRuntimeLuaContext(r.vm, r.luaEnv, RuntimeExecutionModeMatchmaker, nil, 0, "", "", "", "", "") - entriesTable := runtime.vm.CreateTable(len(entries), 0) + entriesTable := r.vm.CreateTable(len(entries), 0) for i, entry := range entries { - presenceTable := runtime.vm.CreateTable(0, 4) + presenceTable := r.vm.CreateTable(0, 4) presenceTable.RawSetString("user_id", lua.LString(entry.Presence.UserId)) presenceTable.RawSetString("session_id", lua.LString(entry.Presence.SessionId)) presenceTable.RawSetString("username", lua.LString(entry.Presence.Username)) presenceTable.RawSetString("node", lua.LString(entry.Presence.Node)) - propertiesTable := runtime.vm.CreateTable(0, len(entry.StringProperties)+len(entry.NumericProperties)) + propertiesTable := r.vm.CreateTable(0, len(entry.StringProperties)+len(entry.NumericProperties)) for k, v := range entry.StringProperties { propertiesTable.RawSetString(k, lua.LString(v)) } @@ -1304,15 +1325,15 @@ func (rp *RuntimeProviderLua) MatchmakerMatched(entries []*MatchmakerEntry) (str propertiesTable.RawSetString(k, lua.LNumber(v)) } - entryTable := runtime.vm.CreateTable(0, 2) + entryTable := r.vm.CreateTable(0, 2) entryTable.RawSetString("presence", presenceTable) entryTable.RawSetString("properties", propertiesTable) entriesTable.RawSetInt(i+1, entryTable) } - retValue, err, _ := runtime.invokeFunction(runtime.vm, lf, ctx, entriesTable) - rp.Put(runtime) + retValue, err, _ := r.invokeFunction(r.vm, lf, luaCtx, entriesTable) + rp.Put(r) if err != nil { return "", false, fmt.Errorf("Error running runtime Matchmaker Matched hook: %v", err.Error()) } @@ -1342,17 +1363,20 @@ func (rp *RuntimeProviderLua) MatchmakerMatched(entries []*MatchmakerEntry) (str return "", false, errors.New("Unexpected return type from runtime Matchmaker Matched hook, must be string or nil.") } -func (rp *RuntimeProviderLua) TournamentEnd(tournament *api.Tournament, end, reset int64) error { - runtime := rp.Get() - lf := runtime.GetCallback(RuntimeExecutionModeTournamentEnd, "") +func (rp *RuntimeProviderLua) TournamentEnd(ctx context.Context, tournament *api.Tournament, end, reset int64) error { + r, err := rp.Get(ctx) + if err != nil { + return err + } + lf := r.GetCallback(RuntimeExecutionModeTournamentEnd, "") if lf == nil { - rp.Put(runtime) + rp.Put(r) return errors.New("Runtime Tournament End function not found.") } - ctx := NewRuntimeLuaContext(runtime.vm, runtime.luaEnv, RuntimeExecutionModeTournamentEnd, nil, 0, "", "", "", "", "") + luaCtx := NewRuntimeLuaContext(r.vm, r.luaEnv, RuntimeExecutionModeTournamentEnd, nil, 0, "", "", "", "", "") - tournamentTable := runtime.vm.CreateTable(0, 16) + tournamentTable := r.vm.CreateTable(0, 16) tournamentTable.RawSetString("id", lua.LString(tournament.Id)) tournamentTable.RawSetString("title", lua.LString(tournament.Title)) @@ -1371,12 +1395,12 @@ func (rp *RuntimeProviderLua) TournamentEnd(tournament *api.Tournament, end, res tournamentTable.RawSetString("can_enter", lua.LBool(tournament.CanEnter)) tournamentTable.RawSetString("next_reset", lua.LNumber(tournament.NextReset)) metadataMap := make(map[string]interface{}) - err := json.Unmarshal([]byte(tournament.Metadata), &metadataMap) + err = json.Unmarshal([]byte(tournament.Metadata), &metadataMap) if err != nil { - rp.Put(runtime) + rp.Put(r) return fmt.Errorf("failed to convert metadata to json: %s", err.Error()) } - metadataTable := RuntimeLuaConvertMap(runtime.vm, metadataMap) + metadataTable := RuntimeLuaConvertMap(r.vm, metadataMap) tournamentTable.RawSetString("metadata", metadataTable) tournamentTable.RawSetString("create_time", lua.LNumber(tournament.CreateTime.Seconds)) tournamentTable.RawSetString("start_time", lua.LNumber(tournament.StartTime.Seconds)) @@ -1386,8 +1410,8 @@ func (rp *RuntimeProviderLua) TournamentEnd(tournament *api.Tournament, end, res tournamentTable.RawSetString("end_time", lua.LNumber(tournament.EndTime.Seconds)) } - retValue, err, _ := runtime.invokeFunction(runtime.vm, lf, ctx, tournamentTable, lua.LNumber(end), lua.LNumber(reset)) - rp.Put(runtime) + retValue, err, _ := r.invokeFunction(r.vm, lf, luaCtx, tournamentTable, lua.LNumber(end), lua.LNumber(reset)) + rp.Put(r) if err != nil { return fmt.Errorf("Error running runtime Tournament End hook: %v", err.Error()) } @@ -1397,20 +1421,23 @@ func (rp *RuntimeProviderLua) TournamentEnd(tournament *api.Tournament, end, res return nil } - return errors.New("Unexpected return type from runtime Tournament End hook, must be string or nil.") + return errors.New("Unexpected return type from runtime Tournament End hook, must be nil.") } -func (rp *RuntimeProviderLua) TournamentReset(tournament *api.Tournament, end, reset int64) error { - runtime := rp.Get() - lf := runtime.GetCallback(RuntimeExecutionModeTournamentReset, "") +func (rp *RuntimeProviderLua) TournamentReset(ctx context.Context, tournament *api.Tournament, end, reset int64) error { + r, err := rp.Get(ctx) + if err != nil { + return err + } + lf := r.GetCallback(RuntimeExecutionModeTournamentReset, "") if lf == nil { - rp.Put(runtime) + rp.Put(r) return errors.New("Runtime Tournament Reset function not found.") } - ctx := NewRuntimeLuaContext(runtime.vm, runtime.luaEnv, RuntimeExecutionModeTournamentReset, nil, 0, "", "", "", "", "") + luaCtx := NewRuntimeLuaContext(r.vm, r.luaEnv, RuntimeExecutionModeTournamentReset, nil, 0, "", "", "", "", "") - tournamentTable := runtime.vm.CreateTable(0, 16) + tournamentTable := r.vm.CreateTable(0, 16) tournamentTable.RawSetString("id", lua.LString(tournament.Id)) tournamentTable.RawSetString("title", lua.LString(tournament.Title)) @@ -1429,12 +1456,12 @@ func (rp *RuntimeProviderLua) TournamentReset(tournament *api.Tournament, end, r tournamentTable.RawSetString("can_enter", lua.LBool(tournament.CanEnter)) tournamentTable.RawSetString("next_reset", lua.LNumber(tournament.NextReset)) metadataMap := make(map[string]interface{}) - err := json.Unmarshal([]byte(tournament.Metadata), &metadataMap) + err = json.Unmarshal([]byte(tournament.Metadata), &metadataMap) if err != nil { - rp.Put(runtime) + rp.Put(r) return fmt.Errorf("failed to convert metadata to json: %s", err.Error()) } - metadataTable := RuntimeLuaConvertMap(runtime.vm, metadataMap) + metadataTable := RuntimeLuaConvertMap(r.vm, metadataMap) tournamentTable.RawSetString("metadata", metadataTable) tournamentTable.RawSetString("create_time", lua.LNumber(tournament.CreateTime.Seconds)) tournamentTable.RawSetString("start_time", lua.LNumber(tournament.StartTime.Seconds)) @@ -1444,8 +1471,8 @@ func (rp *RuntimeProviderLua) TournamentReset(tournament *api.Tournament, end, r tournamentTable.RawSetString("end_time", lua.LNumber(tournament.EndTime.Seconds)) } - retValue, err, _ := runtime.invokeFunction(runtime.vm, lf, ctx, tournamentTable, lua.LNumber(end), lua.LNumber(reset)) - rp.Put(runtime) + retValue, err, _ := r.invokeFunction(r.vm, lf, luaCtx, tournamentTable, lua.LNumber(end), lua.LNumber(reset)) + rp.Put(r) if err != nil { return fmt.Errorf("Error running runtime Tournament Reset hook: %v", err.Error()) } @@ -1455,32 +1482,35 @@ func (rp *RuntimeProviderLua) TournamentReset(tournament *api.Tournament, end, r return nil } - return errors.New("Unexpected return type from runtime Tournament Reset hook, must be string or nil.") + return errors.New("Unexpected return type from runtime Tournament Reset hook, must be nil.") } -func (rp *RuntimeProviderLua) LeaderboardReset(leaderboard runtime.Leaderboard, reset int64) error { - runtime := rp.Get() - lf := runtime.GetCallback(RuntimeExecutionModeLeaderboardReset, "") +func (rp *RuntimeProviderLua) LeaderboardReset(ctx context.Context, leaderboard runtime.Leaderboard, reset int64) error { + r, err := rp.Get(ctx) + if err != nil { + return err + } + lf := r.GetCallback(RuntimeExecutionModeLeaderboardReset, "") if lf == nil { - rp.Put(runtime) + rp.Put(r) return errors.New("Runtime Leaderboard Reset function not found.") } - ctx := NewRuntimeLuaContext(runtime.vm, runtime.luaEnv, RuntimeExecutionModeLeaderboardReset, nil, 0, "", "", "", "", "") + luaCtx := NewRuntimeLuaContext(r.vm, r.luaEnv, RuntimeExecutionModeLeaderboardReset, nil, 0, "", "", "", "", "") - leaderboardTable := runtime.vm.CreateTable(0, 13) + leaderboardTable := r.vm.CreateTable(0, 13) leaderboardTable.RawSetString("id", lua.LString(leaderboard.GetId())) leaderboardTable.RawSetString("authoritative", lua.LBool(leaderboard.GetAuthoritative())) leaderboardTable.RawSetString("sort_order", lua.LString(leaderboard.GetSortOrder())) leaderboardTable.RawSetString("operator", lua.LString(leaderboard.GetOperator())) leaderboardTable.RawSetString("reset", lua.LString(leaderboard.GetReset())) - metadataTable := RuntimeLuaConvertMap(runtime.vm, leaderboard.GetMetadata()) + metadataTable := RuntimeLuaConvertMap(r.vm, leaderboard.GetMetadata()) leaderboardTable.RawSetString("metadata", metadataTable) leaderboardTable.RawSetString("create_time", lua.LNumber(leaderboard.GetCreateTime())) - retValue, err, _ := runtime.invokeFunction(runtime.vm, lf, ctx, leaderboardTable, lua.LNumber(reset)) - rp.Put(runtime) + retValue, err, _ := r.invokeFunction(r.vm, lf, luaCtx, leaderboardTable, lua.LNumber(reset)) + rp.Put(r) if err != nil { return fmt.Errorf("Error running runtime Leaderboard Reset hook: %v", err.Error()) } @@ -1490,35 +1520,39 @@ func (rp *RuntimeProviderLua) LeaderboardReset(leaderboard runtime.Leaderboard, return nil } - return errors.New("Unexpected return type from runtime Leaderboard Reset hook, must be string or nil.") + return errors.New("Unexpected return type from runtime Leaderboard Reset hook, must be nil.") } -func (rp *RuntimeProviderLua) Get() *RuntimeLua { +func (rp *RuntimeProviderLua) Get(ctx context.Context) (*RuntimeLua, error) { select { + case <-ctx.Done(): + return nil, ctx.Err() case r := <-rp.poolCh: // Ideally use an available idle runtime. - return r + return r, nil default: // If there was no idle runtime, see if we can allocate a new one. - rp.Lock() - if rp.currentCount >= rp.maxCount { - rp.Unlock() - // If we've reached the max allowed allocation block on an available runtime. - return <-rp.poolCh + if rp.currentCount.Load() >= rp.maxCount { + // No further runtime allocations allowed. + break } - // Inside the locked region now, last chance to use an available idle runtime. - // Note: useful in case a runtime becomes available while waiting to acquire lock. - select { - case r := <-rp.poolCh: - rp.Unlock() - return r - default: - // Allocate a new runtime. - rp.currentCount++ - rp.Unlock() - stats.Record(rp.statsCtx, MetricsRuntimeCount.M(1)) - return rp.newFn() + if rp.currentCount.Inc() > rp.maxCount { + // When we've incremented see if we can still allocate or a concurrent operation has already done so up to the limit. + // The current count value may go above max count value, but we will never over-allocate runtimes. + // This discrepancy is allowed as it avoids a full mutex locking scenario. + break } + stats.Record(rp.statsCtx, MetricsRuntimeCount.M(1)) + return rp.newFn(), nil + } + + // If we reach here then we were unable to find an available idle runtime, and allocation was not allowed. + // Wait as needed. + select { + case <-ctx.Done(): + return nil, ctx.Err() + case r := <-rp.poolCh: + return r, nil } } -- GitLab