Commit 2b88b543 authored by Andrei Mihu's avatar Andrei Mihu
Browse files

Clean up outgoing authoritative message validation.

parent 82aad397
Loading
Loading
Loading
Loading
+21 −8
Original line number Diff line number Diff line
@@ -107,6 +107,7 @@ type MatchPresenceList struct {
	presences       []*MatchPresenceListItem
	presenceMap     map[uuid.UUID]string
	presencesRead   *atomic.Value
	presenceIDsRead *atomic.Value
}

type MatchPresenceListItem struct {
@@ -120,8 +121,10 @@ func NewMatchPresenceList() *MatchPresenceList {
		presences:       make([]*MatchPresenceListItem, 0, 10),
		presenceMap:     make(map[uuid.UUID]string, 10),
		presencesRead:   &atomic.Value{},
		presenceIDsRead: &atomic.Value{},
	}
	m.presencesRead.Store(make([]*MatchPresence, 0))
	m.presenceIDsRead.Store(make([]*PresenceID, 0))
	return m
}

@@ -144,10 +147,13 @@ func (m *MatchPresenceList) Join(joins []*MatchPresence) []*MatchPresence {
	l := len(processed)
	if l != 0 {
		presencesRead := make([]*MatchPresence, 0, len(m.presences))
		presenceIDsRead := make([]*PresenceID, 0, len(m.presences))
		for _, presence := range m.presences {
			presencesRead = append(presencesRead, presence.Presence)
			presenceIDsRead = append(presenceIDsRead, presence.PresenceID)
		}
		m.presencesRead.Store(presencesRead)
		m.presenceIDsRead.Store(presenceIDsRead)
	}
	m.Unlock()
	if l != 0 {
@@ -176,10 +182,13 @@ func (m *MatchPresenceList) Leave(leaves []*MatchPresence) []*MatchPresence {
	l := len(processed)
	if l != 0 {
		presencesRead := make([]*MatchPresence, 0, len(m.presences))
		presenceIDsRead := make([]*PresenceID, 0, len(m.presences))
		for _, presence := range m.presences {
			presencesRead = append(presencesRead, presence.Presence)
			presenceIDsRead = append(presenceIDsRead, presence.PresenceID)
		}
		m.presencesRead.Store(presencesRead)
		m.presenceIDsRead.Store(presenceIDsRead)
	}
	m.Unlock()
	if l != 0 {
@@ -216,6 +225,10 @@ func (m *MatchPresenceList) ListPresences() []*MatchPresence {
	return m.presencesRead.Load().([]*MatchPresence)
}

func (m *MatchPresenceList) ListPresenceIDs() []*PresenceID {
	return m.presenceIDsRead.Load().([]*PresenceID)
}

func (m *MatchPresenceList) Size() int {
	return int(m.size.Load())
}
+1 −6
Original line number Diff line number Diff line
@@ -25,7 +25,6 @@ import (
// Deferred message expected to be batched with other deferred messages.
// All deferred messages in a batch are expected to be for the same stream/mode and share a logger context.
type DeferredMessage struct {
	Stream      *PresenceStream
	PresenceIDs []*PresenceID
	Envelope    *rtapi.Envelope
	Reliable    bool
@@ -108,10 +107,6 @@ func (r *LocalMessageRouter) SendToStream(logger *zap.Logger, stream PresenceStr

func (r *LocalMessageRouter) SendDeferred(logger *zap.Logger, messages []*DeferredMessage) {
	for _, message := range messages {
		if message.Stream != nil {
			r.SendToStream(logger, *message.Stream, message.Envelope, message.Reliable)
		} else {
		r.SendToPresenceIDs(logger, message.PresenceIDs, message.Envelope, message.Reliable)
	}
}
}
+7 −15
Original line number Diff line number Diff line
@@ -208,15 +208,11 @@ func (r *RuntimeGoMatchCore) BroadcastMessage(opCode int64, data []byte, presenc
	if err != nil {
		return err
	}
	if msg == nil {
	if len(presenceIDs) == 0 {
		return nil
	}

	if len(presenceIDs) == 0 {
		r.router.SendToStream(r.logger, r.stream, msg, reliable)
	} else {
	r.router.SendToPresenceIDs(r.logger, presenceIDs, msg, reliable)
	}

	return nil
}
@@ -230,16 +226,8 @@ func (r *RuntimeGoMatchCore) BroadcastMessageDeferred(opCode int64, data []byte,
	if err != nil {
		return err
	}
	if msg == nil {
		return nil
	}

	if len(presenceIDs) == 0 {
		return r.deferMessageFn(&DeferredMessage{
			Stream:   &r.stream,
			Envelope: msg,
			Reliable: reliable,
		})
		return nil
	}

	return r.deferMessageFn(&DeferredMessage{
@@ -330,6 +318,10 @@ func (r *RuntimeGoMatchCore) validateBroadcast(opCode int64, data []byte, presen
		Reliable: reliable,
	}}}

	if presenceIDs == nil {
		presenceIDs = r.presenceList.ListPresenceIDs()
	}

	return presenceIDs, msg, nil
}

+6 −20
Original line number Diff line number Diff line
@@ -428,13 +428,7 @@ func (rm *RuntimeJavaScriptMatchCore) broadcastMessage(r *goja.Runtime) func(goj
		}

		presenceIDs, msg, reliable := rm.validateBroadcast(r, f)
		if msg == nil {
			return goja.Undefined()
		}

		if len(presenceIDs) == 0 {
			rm.router.SendToStream(rm.logger, rm.stream, msg, reliable)
		} else {
		if len(presenceIDs) != 0 {
			rm.router.SendToPresenceIDs(rm.logger, presenceIDs, msg, reliable)
		}

@@ -449,19 +443,7 @@ func (rm *RuntimeJavaScriptMatchCore) broadcastMessageDeferred(r *goja.Runtime)
		}

		presenceIDs, msg, reliable := rm.validateBroadcast(r, f)
		if msg == nil {
			return goja.Undefined()
		}

		if len(presenceIDs) == 0 {
			if err := rm.deferMessageFn(&DeferredMessage{
				Stream:   &rm.stream,
				Envelope: msg,
				Reliable: reliable,
			}); err != nil {
				panic(r.NewGoError(fmt.Errorf("error deferring message broadcast: %v", err)))
			}
		} else {
		if len(presenceIDs) != 0 {
			if err := rm.deferMessageFn(&DeferredMessage{
				PresenceIDs: presenceIDs,
				Envelope:    msg,
@@ -617,6 +599,10 @@ func (rm *RuntimeJavaScriptMatchCore) validateBroadcast(r *goja.Runtime, f goja.
		Reliable: reliable,
	}}}

	if presenceIDs == nil {
		presenceIDs = rm.presenceList.ListPresenceIDs()
	}

	return presenceIDs, msg, reliable
}

+6 −20
Original line number Diff line number Diff line
@@ -571,13 +571,7 @@ func (r *RuntimeLuaMatchCore) broadcastMessage(l *lua.LState) int {
	}

	presenceIDs, msg, reliable := r.validateBroadcast(l)
	if msg == nil {
		return 0
	}

	if len(presenceIDs) == 0 {
		r.router.SendToStream(r.logger, r.stream, msg, reliable)
	} else {
	if len(presenceIDs) != 0 {
		r.router.SendToPresenceIDs(r.logger, presenceIDs, msg, reliable)
	}

@@ -591,19 +585,7 @@ func (r *RuntimeLuaMatchCore) broadcastMessageDeferred(l *lua.LState) int {
	}

	presenceIDs, msg, reliable := r.validateBroadcast(l)
	if msg == nil {
		return 0
	}

	if len(presenceIDs) == 0 {
		if err := r.deferMessageFn(&DeferredMessage{
			Stream:   &r.stream,
			Envelope: msg,
			Reliable: reliable,
		}); err != nil {
			l.RaiseError("error deferring message broadcast: %v", err)
		}
	} else {
	if len(presenceIDs) != 0 {
		if err := r.deferMessageFn(&DeferredMessage{
			PresenceIDs: presenceIDs,
			Envelope:    msg,
@@ -778,6 +760,10 @@ func (r *RuntimeLuaMatchCore) validateBroadcast(l *lua.LState) ([]*PresenceID, *
		Reliable: reliable,
	}}}

	if presenceIDs == nil {
		presenceIDs = r.presenceList.ListPresenceIDs()
	}

	return presenceIDs, msg, reliable
}