diff --git a/server/party_handler.go b/server/party_handler.go index 1caf658813e3dbd30dd4089e65e4d75b9fdff727..4e26061b0aa6759ccd3b72017cb71e684d190410 100644 --- a/server/party_handler.go +++ b/server/party_handler.go @@ -53,6 +53,7 @@ type PartyHandler struct { Stream PresenceStream stopped bool + expectedInitialLeader *rtapi.UserPresence leader *PresenceID leaderUserPresence *rtapi.UserPresence members []*PresenceID @@ -62,7 +63,7 @@ type PartyHandler struct { joinRequestUserPresences []*rtapi.UserPresence } -func NewPartyHandler(logger *zap.Logger, partyRegistry PartyRegistry, matchmaker Matchmaker, tracker Tracker, streamManager StreamManager, router MessageRouter, id uuid.UUID, node string, open bool, maxSize int) *PartyHandler { +func NewPartyHandler(logger *zap.Logger, partyRegistry PartyRegistry, matchmaker Matchmaker, tracker Tracker, streamManager StreamManager, router MessageRouter, id uuid.UUID, node string, open bool, maxSize int, presence *rtapi.UserPresence) *PartyHandler { idStr := fmt.Sprintf("%v.%v", id.String(), node) return &PartyHandler{ logger: logger.With(zap.String("party_id", idStr)), @@ -80,6 +81,7 @@ func NewPartyHandler(logger *zap.Logger, partyRegistry PartyRegistry, matchmaker Stream: PresenceStream{Mode: StreamModeParty, Subject: id, Label: node}, stopped: false, + expectedInitialLeader: presence, leader: nil, leaderUserPresence: nil, members: make([]*PresenceID, 0, maxSize), @@ -158,7 +160,25 @@ func (p *PartyHandler) Join(presences []*Presence) { } // Assign the party leader if this is the first join. + var initialLeader *Presence if p.leader == nil { + if p.expectedInitialLeader != nil { + expectedInitialLeader := p.expectedInitialLeader + p.expectedInitialLeader = nil + for _, presence := range presences { + if presence.GetUserId() == expectedInitialLeader.UserId && presence.GetSessionId() == expectedInitialLeader.SessionId { + // The initial leader is joining the party at creation time. + initialLeader = presence + p.leader = &presence.ID + p.leaderUserPresence = &rtapi.UserPresence{ + UserId: presence.GetUserId(), + SessionId: presence.GetSessionId(), + Username: presence.GetUsername(), + } + break + } + } + } p.leader = &presences[0].ID p.leaderUserPresence = &rtapi.UserPresence{ UserId: presences[0].GetUserId(), @@ -178,7 +198,16 @@ func (p *PartyHandler) Join(presences []*Presence) { SessionId: presence.GetSessionId(), Username: presence.GetUsername(), } + p.members = append(p.members, ¤tPresence.ID) + p.memberUserPresences = append(p.memberUserPresences, memberUserPresence) + memberUserPresences = append(memberUserPresences, memberUserPresence) + p.joinsInProgress-- + // Prepare message to be sent to the new presences. + if initialLeader != nil && presence == initialLeader { + // The party creator has already received this message in the pipeline, do not send it to them again. + continue + } presenceIDs[¤tPresence.ID] = &rtapi.Envelope{ Message: &rtapi.Envelope_Party{ Party: &rtapi.Party{ @@ -191,10 +220,6 @@ func (p *PartyHandler) Join(presences []*Presence) { }, }, } - p.members = append(p.members, ¤tPresence.ID) - p.memberUserPresences = append(p.memberUserPresences, memberUserPresence) - memberUserPresences = append(memberUserPresences, memberUserPresence) - p.joinsInProgress-- } p.Unlock() diff --git a/server/party_registry.go b/server/party_registry.go index 658a4298b2132607190a4da5cda3b75600209a38..627871f3e9c2343ccd54591daa96be0b6e8f7580 100644 --- a/server/party_registry.go +++ b/server/party_registry.go @@ -27,7 +27,7 @@ import ( var ErrPartyNotFound = errors.New("party not found") type PartyRegistry interface { - Create(open bool, maxSize int) *PartyHandler + Create(open bool, maxSize int, leader *rtapi.UserPresence) *PartyHandler Delete(id uuid.UUID) Join(id uuid.UUID, presences []*Presence) @@ -68,9 +68,9 @@ func NewLocalPartyRegistry(logger *zap.Logger, matchmaker Matchmaker, tracker Tr } } -func (p *LocalPartyRegistry) Create(open bool, maxSize int) *PartyHandler { +func (p *LocalPartyRegistry) Create(open bool, maxSize int, presence *rtapi.UserPresence) *PartyHandler { id := uuid.Must(uuid.NewV4()) - partyHandler := NewPartyHandler(p.logger, p, p.matchmaker, p.tracker, p.streamManager, p.router, id, p.node, open, maxSize) + partyHandler := NewPartyHandler(p.logger, p, p.matchmaker, p.tracker, p.streamManager, p.router, id, p.node, open, maxSize, presence) p.parties.Store(id, partyHandler) diff --git a/server/pipeline_party.go b/server/pipeline_party.go index f3091befbbb1354d2f1f14a52d85779e127618d6..43845cb5ec9b264412d2d9ac5258f6b2dd63f592 100644 --- a/server/pipeline_party.go +++ b/server/pipeline_party.go @@ -35,8 +35,14 @@ func (p *Pipeline) partyCreate(logger *zap.Logger, session Session, envelope *rt return } + presence := &rtapi.UserPresence{ + UserId: session.UserID().String(), + SessionId: session.ID().String(), + Username: session.Username(), + } + // Handle through the party registry. - ph := p.partyRegistry.Create(incoming.Open, int(incoming.MaxSize)) + ph := p.partyRegistry.Create(incoming.Open, int(incoming.MaxSize), presence) if ph == nil { session.Send(&rtapi.Envelope{Cid: envelope.Cid, Message: &rtapi.Envelope_Error{Error: &rtapi.Error{ Code: int32(rtapi.Error_RUNTIME_EXCEPTION), @@ -59,7 +65,14 @@ func (p *Pipeline) partyCreate(logger *zap.Logger, session Session, envelope *rt return } - session.Send(&rtapi.Envelope{Cid: envelope.Cid}, true) + session.Send(&rtapi.Envelope{Cid: envelope.Cid, Message: &rtapi.Envelope_Party{Party: &rtapi.Party{ + PartyId: ph.IDStr, + Open: incoming.Open, + MaxSize: incoming.MaxSize, + Self: presence, + Leader: presence, + Presences: []*rtapi.UserPresence{presence}, + }}}, true) } func (p *Pipeline) partyJoin(logger *zap.Logger, session Session, envelope *rtapi.Envelope) {