Commit 6e07e5a0 authored by Andrei Mihu's avatar Andrei Mihu
Browse files

Improve authoritative match handler tracker usage.

parent f53bf2d3
Loading
Loading
Loading
Loading
+66 −3
Original line number Diff line number Diff line
@@ -17,6 +17,7 @@ package server
import (
	"context"
	"fmt"
	"sync"
	"time"

	"github.com/gofrs/uuid"
@@ -25,6 +26,58 @@ import (
	"go.uber.org/zap"
)

type MatchPresenceList struct {
	sync.RWMutex
	presences []*PresenceID
}

func (m *MatchPresenceList) Join(joins []*MatchPresence) {
	m.Lock()
	for _, join := range joins {
		m.presences = append(m.presences, &PresenceID{
			Node:      join.Node,
			SessionID: join.SessionID,
		})
	}
	m.Unlock()
}

func (m *MatchPresenceList) Leave(leaves []*MatchPresence) {
	m.Lock()
	for _, leave := range leaves {
		for i, presenceID := range m.presences {
			if presenceID.SessionID == leave.SessionID && presenceID.Node == leave.Node {
				m.presences = append(m.presences[:i], m.presences[i+1:]...)
				break
			}
		}
	}
	m.Unlock()
}

func (m *MatchPresenceList) Contains(presence *PresenceID) bool {
	var found bool
	m.RLock()
	for _, p := range m.presences {
		if p.SessionID == presence.SessionID && p.Node == p.Node {
			found = true
			break
		}
	}
	m.RUnlock()
	return found
}

func (m *MatchPresenceList) List() []*PresenceID {
	m.RLock()
	list := make([]*PresenceID, 0, len(m.presences))
	for _, presence := range m.presences {
		list = append(list, presence)
	}
	m.RUnlock()
	return list
}

type MatchDataMessage struct {
	UserID      uuid.UUID
	SessionID   uuid.UUID
@@ -72,6 +125,7 @@ type MatchHandler struct {
	tracker       Tracker
	router        MessageRouter

	presenceList *MatchPresenceList
	core         RuntimeMatchCore

	// Identification not (directly) controlled by match init.
@@ -100,7 +154,11 @@ type MatchHandler struct {
}

func NewMatchHandler(logger *zap.Logger, config Config, matchRegistry MatchRegistry, core RuntimeMatchCore, label *atomic.String, id uuid.UUID, node string, params map[string]interface{}) (*MatchHandler, error) {
	state, rateInt, labelStr, err := core.MatchInit(params)
	presenceList := &MatchPresenceList{
		presences: make([]*PresenceID, 0, 10),
	}

	state, rateInt, labelStr, err := core.MatchInit(presenceList, params)
	if err != nil {
		core.Cancel()
		return nil, err
@@ -120,6 +178,7 @@ func NewMatchHandler(logger *zap.Logger, config Config, matchRegistry MatchRegis
		logger:        logger,
		matchRegistry: matchRegistry,

		presenceList: presenceList,
		core:         core,

		ID:    id,
@@ -303,6 +362,8 @@ func (mh *MatchHandler) QueueJoin(joins []*MatchPresence) bool {
			return
		}

		mh.presenceList.Join(joins)

		state, err := mh.core.MatchJoin(mh.tick, mh.state, joins)
		if err != nil {
			mh.Stop()
@@ -331,6 +392,8 @@ func (mh *MatchHandler) QueueLeave(leaves []*MatchPresence) bool {
			return
		}

		mh.presenceList.Leave(leaves)

		state, err := mh.core.MatchLeave(mh.tick, mh.state, leaves)
		if err != nil {
			mh.Stop()
+1 −1
Original line number Diff line number Diff line
@@ -215,7 +215,7 @@ func (e RuntimeExecutionMode) String() string {
}

type RuntimeMatchCore interface {
	MatchInit(params map[string]interface{}) (interface{}, int, string, error)
	MatchInit(presenceList *MatchPresenceList, params map[string]interface{}) (interface{}, int, string, error)
	MatchJoinAttempt(tick int64, state interface{}, userID, sessionID uuid.UUID, username, node string, metadata map[string]string) (interface{}, bool, string, error)
	MatchJoin(tick int64, state interface{}, joins []*MatchPresence) (interface{}, error)
	MatchLeave(tick int64, state interface{}, leaves []*MatchPresence) (interface{}, error)
+2 −1
Original line number Diff line number Diff line
@@ -1758,7 +1758,8 @@ func NewRuntimeProviderGo(logger, startupLogger *zap.Logger, db *sql.DB, config
			return nil, err
		}

		return NewRuntimeGoMatchCore(logger, matchRegistry, tracker, router, id, node, labelUpdateFn, db, env, nk, match)
		//return NewRuntimeGoMatchCore(logger, matchRegistry, tracker, router, id, node, labelUpdateFn, db, env, nk, match)
		return NewRuntimeGoMatchCore(logger, matchRegistry, router, id, node, labelUpdateFn, db, env, nk, match)
	}
	nk.SetMatchCreateFn(matchCreateFn)
	matchNamesListFn := func() []string {
+13 −10
Original line number Diff line number Diff line
@@ -28,10 +28,10 @@ import (
type RuntimeGoMatchCore struct {
	logger        *zap.Logger
	matchRegistry MatchRegistry
	tracker       Tracker
	router        MessageRouter

	labelUpdateFn RuntimeMatchLabelUpdateFunction
	presenceList  *MatchPresenceList

	match runtime.Match

@@ -48,7 +48,8 @@ type RuntimeGoMatchCore struct {
	ctxCancelFn context.CancelFunc
}

func NewRuntimeGoMatchCore(logger *zap.Logger, matchRegistry MatchRegistry, tracker Tracker, router MessageRouter, id uuid.UUID, node string, labelUpdateFn RuntimeMatchLabelUpdateFunction, db *sql.DB, env map[string]string, nk runtime.NakamaModule, match runtime.Match) (RuntimeMatchCore, error) {
//func NewRuntimeGoMatchCore(logger *zap.Logger, matchRegistry MatchRegistry, tracker Tracker, router MessageRouter, id uuid.UUID, node string, labelUpdateFn RuntimeMatchLabelUpdateFunction, db *sql.DB, env map[string]string, nk runtime.NakamaModule, match runtime.Match) (RuntimeMatchCore, error) {
func NewRuntimeGoMatchCore(logger *zap.Logger, matchRegistry MatchRegistry, router MessageRouter, id uuid.UUID, node string, labelUpdateFn RuntimeMatchLabelUpdateFunction, db *sql.DB, env map[string]string, nk runtime.NakamaModule, match runtime.Match) (RuntimeMatchCore, error) {
	ctx, ctxCancelFn := context.WithCancel(context.Background())
	ctx = NewRuntimeGoContext(ctx, env, RuntimeExecutionModeMatch, nil, 0, "", "", "", "", "")
	ctx = context.WithValue(ctx, runtime.RUNTIME_CTX_MATCH_ID, fmt.Sprintf("%v.%v", id.String(), node))
@@ -57,10 +58,10 @@ func NewRuntimeGoMatchCore(logger *zap.Logger, matchRegistry MatchRegistry, trac
	return &RuntimeGoMatchCore{
		logger:        logger,
		matchRegistry: matchRegistry,
		tracker:       tracker,
		router:        router,

		labelUpdateFn: labelUpdateFn,
		// presenceList set in MatchInit.

		match: match,

@@ -82,7 +83,7 @@ func NewRuntimeGoMatchCore(logger *zap.Logger, matchRegistry MatchRegistry, trac
	}, nil
}

func (r *RuntimeGoMatchCore) MatchInit(params map[string]interface{}) (interface{}, int, string, error) {
func (r *RuntimeGoMatchCore) MatchInit(presenceList *MatchPresenceList, params map[string]interface{}) (interface{}, int, string, error) {
	state, tickRate, label := r.match.MatchInit(r.ctx, r.runtimeLogger, r.db, r.nk, params)

	if len(label) > 256 {
@@ -95,6 +96,8 @@ func (r *RuntimeGoMatchCore) MatchInit(params map[string]interface{}) (interface
	r.ctx = context.WithValue(r.ctx, runtime.RUNTIME_CTX_MATCH_TICK_RATE, tickRate)
	r.ctx = context.WithValue(r.ctx, runtime.RUNTIME_CTX_MATCH_LABEL, label)

	r.presenceList = presenceList

	return state, tickRate, label, nil
}

@@ -199,17 +202,17 @@ func (r *RuntimeGoMatchCore) BroadcastMessage(opCode int64, data []byte, presenc
		// Ensure specific presences actually exist to prevent sending bogus messages to arbitrary users.
		if len(presenceIDs) == 1 {
			// Shorter validation cycle if there is only one intended recipient.
			userID, err := uuid.FromString(presences[0].GetUserId())
			_, err := uuid.FromString(presences[0].GetUserId())
			if err != nil {
				return errors.New("Presence contains an invalid User ID")
			}
			if r.tracker.GetBySessionIDStreamUserID(presenceIDs[0].Node, presenceIDs[0].SessionID, r.stream, userID) == nil {
			if !r.presenceList.Contains(presenceIDs[0]) {
				// The one intended recipient is not a match member.
				return nil
			}
		} else {
			// Validate multiple filtered recipients.
			actualPresenceIDs := r.tracker.ListPresenceIDByStream(r.stream)
			actualPresenceIDs := r.presenceList.List()
			for i := 0; i < len(presenceIDs); i++ {
				found := false
				presenceID := presenceIDs[i]
@@ -244,11 +247,11 @@ func (r *RuntimeGoMatchCore) BroadcastMessage(opCode int64, data []byte, presenc
	}}}

	if presenceIDs == nil {
		r.router.SendToStream(r.logger, r.stream, msg)
	} else {
		r.router.SendToPresenceIDs(r.logger, presenceIDs, true, StreamModeMatchAuthoritative, msg)
		presenceIDs = r.presenceList.List()
	}

	r.router.SendToPresenceIDs(r.logger, presenceIDs, true, StreamModeMatchAuthoritative, msg)

	return nil
}

+11 −10
Original line number Diff line number Diff line
@@ -32,10 +32,10 @@ import (
type RuntimeLuaMatchCore struct {
	logger        *zap.Logger
	matchRegistry MatchRegistry
	tracker       Tracker
	router        MessageRouter

	labelUpdateFn RuntimeMatchLabelUpdateFunction
	presenceList  *MatchPresenceList

	id     uuid.UUID
	node   string
@@ -134,10 +134,10 @@ func NewRuntimeLuaMatchCore(logger *zap.Logger, db *sql.DB, jsonpbUnmarshaler *j
	core := &RuntimeLuaMatchCore{
		logger:        logger,
		matchRegistry: matchRegistry,
		tracker:       tracker,
		router:        router,

		labelUpdateFn: labelUpdateFn,
		// presenceList set in MatchInit.

		id:    id,
		node:  node,
@@ -170,7 +170,7 @@ func NewRuntimeLuaMatchCore(logger *zap.Logger, db *sql.DB, jsonpbUnmarshaler *j
	return core, nil
}

func (r *RuntimeLuaMatchCore) MatchInit(params map[string]interface{}) (interface{}, int, string, error) {
func (r *RuntimeLuaMatchCore) MatchInit(presenceList *MatchPresenceList, params map[string]interface{}) (interface{}, int, string, error) {
	// Run the match_init sequence.
	r.vm.Push(LSentinel)
	r.vm.Push(r.initFn)
@@ -231,6 +231,8 @@ func (r *RuntimeLuaMatchCore) MatchInit(params map[string]interface{}) (interfac
	r.ctx.RawSetString(__RUNTIME_LUA_CTX_MATCH_LABEL, label)
	r.ctx.RawSetString(__RUNTIME_LUA_CTX_MATCH_TICK_RATE, rate)

	r.presenceList = presenceList

	return state, rateInt, labelStr, nil
}

@@ -618,17 +620,16 @@ func (r *RuntimeLuaMatchCore) broadcastMessage(l *lua.LState) int {
				l.ArgError(3, "expects each presence to have a valid user_id")
				return 0
			}
			userID, err := uuid.FromString(userIDValue.String())
			_, err := uuid.FromString(userIDValue.String())
			if err != nil {
				l.ArgError(3, "expects each presence to have a valid user_id")
				return 0
			}
			if r.tracker.GetBySessionIDStreamUserID(presenceIDs[0].Node, presenceIDs[0].SessionID, r.stream, userID) == nil {
				// The one intended recipient is not a match member.
			if !r.presenceList.Contains(presenceIDs[0]) {
				return 0
			}
		} else {
			actualPresenceIDs := r.tracker.ListPresenceIDByStream(r.stream)
			actualPresenceIDs := r.presenceList.List()
			for i := 0; i < len(presenceIDs); i++ {
				found := false
				presenceID := presenceIDs[i]
@@ -663,11 +664,11 @@ func (r *RuntimeLuaMatchCore) broadcastMessage(l *lua.LState) int {
	}}}

	if presenceIDs == nil {
		r.router.SendToStream(r.logger, r.stream, msg)
	} else {
		r.router.SendToPresenceIDs(r.logger, presenceIDs, true, StreamModeMatchAuthoritative, msg)
		presenceIDs = r.presenceList.List()
	}

	r.router.SendToPresenceIDs(r.logger, presenceIDs, true, StreamModeMatchAuthoritative, msg)

	return 0
}