Commit f32695bb authored by Andrei Mihu's avatar Andrei Mihu
Browse files

Improved validation of dispatcher broadcast message filters.

parent bb2c43b3
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -6,6 +6,7 @@ The format is based on [keep a changelog](http://keepachangelog.com) and this pr
## [Unreleased]
### Changed
- Improved cancellation of ongoing work when clients disconnect.
- Improved validation of dispatcher broadcast message filters.

### Fixed
- Use leaderboard expires rather than end active IDs with leaderboard resets.
+13 −6
Original line number Diff line number Diff line
@@ -179,17 +179,24 @@ func (p *Pipeline) ProcessRequest(logger *zap.Logger, session Session, envelope
	}

	// Stats measurement start boundary.
	var span *trace.Span
	var startNanos int64
	var statsCtx context.Context
	if pipelineName != "matchDataSend" {
		name := fmt.Sprintf("nakama.rtapi.%v", pipelineName)
	statsCtx, _ := tag.New(context.Background(), tag.Upsert(MetricsFunction, name))
	startNanos := time.Now().UTC().UnixNano()
	span := trace.NewSpan(name, nil, trace.StartOptions{})
		statsCtx, _ = tag.New(context.Background(), tag.Upsert(MetricsFunction, name))
		startNanos = time.Now().UTC().UnixNano()
		span = trace.NewSpan(name, nil, trace.StartOptions{})
	}

	// Actual function execution.
	pipelineFn(logger, session, envelope)

	// Stats measurement end boundary.
	if span != nil {
		span.End()
		stats.Record(statsCtx, MetricsRtapiTimeSpentMsec.M(float64(time.Now().UTC().UnixNano()-startNanos)/1000), MetricsRtapiCount.M(1))
	}

	if messageName != "" {
		if fn := p.runtime.AfterRt(messageNameID); fn != nil {
+33 −20
Original line number Diff line number Diff line
@@ -198,6 +198,18 @@ func (r *RuntimeGoMatchCore) BroadcastMessage(opCode int64, data []byte, presenc

	if presenceIDs != nil {
		// Ensure specific presences actually exist to prevent sending bogus messages to arbitrary users.
		if len(presenceIDs) == 1 {
			// Shorter validation cycle if there is only one intended recipient.
			userID, err := uuid.FromString(presences[0].GetUserId())
			if err != nil {
				return errors.New("Presence contains an invalid User ID")
			}
			if r.tracker.GetBySessionIDStreamUserID(presenceIDs[0].Node, presenceIDs[0].SessionID, r.stream, userID) == nil {
				// The one intended recipient is not a match member.
				return nil
			}
		} else {
			// Validate multiple filtered recipients.
			actualPresenceIDs := r.tracker.ListPresenceIDByStream(r.stream)
			for i := 0; i < len(presenceIDs); i++ {
				found := false
@@ -223,6 +235,7 @@ func (r *RuntimeGoMatchCore) BroadcastMessage(opCode int64, data []byte, presenc
				return nil
			}
		}
	}

	msg := &rtapi.Envelope{Message: &rtapi.Envelope_MatchData{MatchData: &rtapi.MatchData{
		MatchId:  r.idStr,
+56 −24
Original line number Diff line number Diff line
@@ -501,7 +501,7 @@ func (r *RuntimeLuaMatchCore) broadcastMessage(l *lua.LState) int {
			pt, ok := p.(*lua.LTable)
			if !ok {
				conversionError = true
				l.ArgError(1, "expects a valid set of presences")
				l.ArgError(3, "expects a valid set of presences")
				return
			}

@@ -512,14 +512,14 @@ func (r *RuntimeLuaMatchCore) broadcastMessage(l *lua.LState) int {
					sid, err := uuid.FromString(v.String())
					if err != nil {
						conversionError = true
						l.ArgError(1, "expects each presence to have a valid session_id")
						l.ArgError(3, "expects each presence to have a valid session_id")
						return
					}
					presenceID.SessionID = sid
				case "node":
					if v.Type() != lua.LTString {
						conversionError = true
						l.ArgError(1, "expects node to be string")
						l.ArgError(3, "expects node to be string")
						return
					}
					presenceID.Node = v.String()
@@ -527,7 +527,7 @@ func (r *RuntimeLuaMatchCore) broadcastMessage(l *lua.LState) int {
			})
			if presenceID.SessionID == uuid.Nil || presenceID.Node == "" {
				conversionError = true
				l.ArgError(1, "expects each presence to have a valid session_id and node")
				l.ArgError(3, "expects each presence to have a valid session_id and node")
				return
			}
			if conversionError {
@@ -590,6 +590,37 @@ func (r *RuntimeLuaMatchCore) broadcastMessage(l *lua.LState) int {

	if presenceIDs != nil {
		// Ensure specific presences actually exist to prevent sending bogus messages to arbitrary users.
		if len(presenceIDs) == 1 {
			// Shorter validation cycle if there is only one intended recipient.
			presenceValue := filter.RawGetInt(1)
			if presenceValue == lua.LNil {
				l.ArgError(3, "expects each presence to be non-nil")
				return 0
			}
			presenceTable, ok := presenceValue.(*lua.LTable)
			if !ok {
				l.ArgError(3, "expects each presence to be a table")
				return 0
			}
			userIDValue := presenceTable.RawGetString("user_id")
			if userIDValue == nil {
				l.ArgError(3, "expects each presence to have a valid user_id")
				return 0
			}
			if userIDValue.Type() != lua.LTString {
				l.ArgError(3, "expects each presence to have a valid user_id")
				return 0
			}
			userID, err := uuid.FromString(userIDValue.String())
			if err != nil {
				l.ArgError(3, "expects each presence to have a valid user_id")
				return 0
			}
			if r.tracker.GetBySessionIDStreamUserID(presenceIDs[0].Node, presenceIDs[0].SessionID, r.stream, userID) == nil {
				// The one intended recipient is not a match member.
				return 0
			}
		} else {
			actualPresenceIDs := r.tracker.ListPresenceIDByStream(r.stream)
			for i := 0; i < len(presenceIDs); i++ {
				found := false
@@ -615,6 +646,7 @@ func (r *RuntimeLuaMatchCore) broadcastMessage(l *lua.LState) int {
				return 0
			}
		}
	}

	msg := &rtapi.Envelope{Message: &rtapi.Envelope_MatchData{MatchData: &rtapi.MatchData{
		MatchId:  r.idStr,