Commit c47ce13e authored by Andrei Mihu's avatar Andrei Mihu
Browse files

Improve socket session close semantics.

parent 33eb0ee3
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -28,6 +28,8 @@ The format is based on [keep a changelog](http://keepachangelog.com) and this pr
- Authoritative matches now complete their stop phase faster to avoid unnecessary processing.
- Authoritative match join attempts now have their own bounded queue and no longer count towards the match call queue limit.
- Lua runtime group create function now sets the correct default max size if one is not specified.
- Improve socket session close semantics.
- Session logging now prints correct remote address if available when the connection is through a proxy.

### Fixed
- Correctly report execution mode in Lua runtime after hooks.
+1 −10
Original line number Diff line number Diff line
@@ -64,16 +64,7 @@ func NewSessionRegistry() *SessionRegistry {
	}
}

func (r *SessionRegistry) Stop() {
	r.Lock()
	for sessionID, session := range r.sessions {
		delete(r.sessions, sessionID)
		// Send graceful close messages to client connections.
		// No need to clean up presences or matchmaker entries because we only expect to be here on server shutdown.
		session.Close()
	}
	r.Unlock()
}
func (r *SessionRegistry) Stop() {}

func (r *SessionRegistry) Get(sessionID uuid.UUID) Session {
	var s Session
+54 −48
Original line number Diff line number Diff line
@@ -68,7 +68,6 @@ type sessionWS struct {
	pingTimer              *time.Timer
	pingTimerCAS           *atomic.Uint32
	outgoingCh             chan []byte
	outgoingStopCh         chan struct{}
}

func NewSessionWS(logger *zap.Logger, config Config, format SessionFormat, userID uuid.UUID, username string, expiry int64, clientIP string, clientPort string, jsonpbMarshaler *jsonpb.Marshaler, jsonpbUnmarshaler *jsonpb.Unmarshaler, conn *websocket.Conn, sessionRegistry *SessionRegistry, matchmaker Matchmaker, tracker Tracker) Session {
@@ -116,7 +115,6 @@ func NewSessionWS(logger *zap.Logger, config Config, format SessionFormat, userI
		pingTimer:              time.NewTimer(time.Duration(config.GetSocket().PingPeriodMs) * time.Millisecond),
		pingTimerCAS:           atomic.NewUint32(1),
		outgoingCh:             make(chan []byte, config.GetSocket().OutgoingQueueSize),
		outgoingStopCh:         make(chan struct{}),
	}
}

@@ -157,9 +155,12 @@ func (s *sessionWS) Expiry() int64 {
}

func (s *sessionWS) Consume(processRequest func(logger *zap.Logger, session Session, envelope *rtapi.Envelope) bool) {
	defer s.cleanupClosedConnection()
	defer s.Close()
	s.conn.SetReadLimit(s.config.GetSocket().MaxMessageSizeBytes)
	s.conn.SetReadDeadline(time.Now().Add(s.pongWaitDuration))
	if err := s.conn.SetReadDeadline(time.Now().Add(s.pongWaitDuration)); err != nil {
		s.logger.Warn("Failed to set initial read deadline", zap.Error(err))
		return
	}
	s.conn.SetPongHandler(func(string) error {
		s.maybeResetPingTimer()
		return nil
@@ -190,7 +191,10 @@ func (s *sessionWS) Consume(processRequest func(logger *zap.Logger, session Sess
		s.receivedMessageCounter--
		if s.receivedMessageCounter <= 0 {
			s.receivedMessageCounter = s.config.GetSocket().PingBackoffThreshold
			s.maybeResetPingTimer()
			if !s.maybeResetPingTimer() {
				// Problems resetting the ping timer indicate an error so we need to close the loop.
				break
			}
		}

		request := &rtapi.Envelope{}
@@ -222,12 +226,18 @@ func (s *sessionWS) Consume(processRequest func(logger *zap.Logger, session Sess
	}
}

func (s *sessionWS) maybeResetPingTimer() {
func (s *sessionWS) maybeResetPingTimer() bool {
	// If there's already a reset in progress there's no need to wait.
	if !s.pingTimerCAS.CAS(1, 0) {
		return
		return true
	}
	defer s.pingTimerCAS.CAS(0, 1)

	s.Lock()
	if s.stopped {
		s.Unlock()
		return false
	}
	// CAS ensures concurrency is not a problem here.
	if !s.pingTimer.Stop() {
		select {
@@ -236,14 +246,21 @@ func (s *sessionWS) maybeResetPingTimer() {
		}
	}
	s.pingTimer.Reset(s.pingPeriodDuration)
	s.conn.SetReadDeadline(time.Now().Add(s.pongWaitDuration))
	s.pingTimerCAS.CAS(0, 1)
	err := s.conn.SetReadDeadline(time.Now().Add(s.pongWaitDuration))
	s.Unlock()
	if err != nil {
		s.logger.Warn("Failed to set read deadline", zap.Error(err))
		s.Close()
		return false
	}
	return true
}

func (s *sessionWS) processOutgoing() {
	defer s.Close()
	for {
		select {
		case <-s.outgoingStopCh:
		case <-s.ctx.Done():
			// Session is closing, close the outgoing process routine.
			return
		case <-s.pingTimer.C:
@@ -278,13 +295,15 @@ func (s *sessionWS) pingNow() bool {
		s.Unlock()
		return false
	}
	s.conn.SetWriteDeadline(time.Now().Add(s.writeWaitDuration))
	if err := s.conn.SetWriteDeadline(time.Now().Add(s.writeWaitDuration)); err != nil {
		s.Unlock()
		s.logger.Warn("Could not set write deadline to ping", zap.Error(err))
		return false
	}
	err := s.conn.WriteMessage(websocket.PingMessage, []byte{})
	s.Unlock()
	if err != nil {
		s.logger.Warn("Could not send ping, closing channel", zap.String("remoteAddress", s.conn.RemoteAddr().String()), zap.Error(err))
		// The connection has already failed.
		s.cleanupClosedConnection()
		s.logger.Warn("Could not send ping", zap.String("remoteAddress", fmt.Sprintf("%v:%v", s.clientIP, s.clientPort)), zap.Error(err))
		return false
	}

@@ -361,12 +380,12 @@ func (s *sessionWS) SendBytes(isStream bool, mode uint8, payload []byte) error {
		// to start dropping messages, which might cause unexpected behaviour.
		s.Unlock()
		s.logger.Warn("Could not write message, session outgoing queue full")
		s.cleanupClosedConnection()
		s.Close()
		return ErrSessionQueueFull
	}
}

func (s *sessionWS) cleanupClosedConnection() {
func (s *sessionWS) Close() {
	s.Lock()
	if s.stopped {
		s.Unlock()
@@ -379,50 +398,37 @@ func (s *sessionWS) cleanupClosedConnection() {
	s.ctxCancelFn()

	if s.logger.Core().Enabled(zap.DebugLevel) {
		s.logger.Info("Cleaning up closed client connection", zap.String("remoteAddress", s.conn.RemoteAddr().String()))
		s.logger.Info("Cleaning up closed client connection", zap.String("remoteAddress", fmt.Sprintf("%v:%v", s.clientIP, s.clientPort)))
	}

	// When connection close originates internally in the session, ensure cleanup of external resources and references.
	s.sessionRegistry.remove(s.id)
	s.matchmaker.RemoveAll(s.id)
	if err := s.matchmaker.RemoveAll(s.id); err != nil {
		s.logger.Warn("Failed to remove all matchmaking tickets", zap.Error(err))
	}
	if s.logger.Core().Enabled(zap.DebugLevel) {
		s.logger.Info("Cleaned up closed connection matchmaker", zap.String("remoteAddress", fmt.Sprintf("%v:%v", s.clientIP, s.clientPort)))
	}
	s.tracker.UntrackAll(s.id)

	// Clean up internals.
	s.pingTimer.Stop()
	close(s.outgoingStopCh)
	close(s.outgoingCh)

	// Close WebSocket.
	s.conn.Close()
	s.logger.Info("Closed client connection")
	if s.logger.Core().Enabled(zap.DebugLevel) {
		s.logger.Info("Cleaned up closed connection tracker", zap.String("remoteAddress", fmt.Sprintf("%v:%v", s.clientIP, s.clientPort)))
	}

func (s *sessionWS) Close() {
	s.Lock()
	if s.stopped {
		s.Unlock()
		return
	s.sessionRegistry.remove(s.id)
	if s.logger.Core().Enabled(zap.DebugLevel) {
		s.logger.Info("Cleaned up closed connection session registry", zap.String("remoteAddress", fmt.Sprintf("%v:%v", s.clientIP, s.clientPort)))
	}
	s.stopped = true
	s.Unlock()

	// Cancel any ongoing operations tied to this session.
	s.ctxCancelFn()

	// Expect the caller of this session.Close() to clean up external resources (like presences) separately.

	// Clean up internals.
	s.pingTimer.Stop()
	close(s.outgoingStopCh)
	close(s.outgoingCh)

	// Send close message.
	err := s.conn.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(s.writeWaitDuration))
	if err != nil {
		s.logger.Warn("Could not send close message, closing prematurely", zap.String("remoteAddress", s.conn.RemoteAddr().String()), zap.Error(err))
	if err := s.conn.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(s.writeWaitDuration)); err != nil {
		s.logger.Debug("Could not send close message", zap.String("remoteAddress", fmt.Sprintf("%v:%v", s.clientIP, s.clientPort)), zap.Error(err))
	}

	// Close WebSocket.
	s.conn.Close()
	if err := s.conn.Close(); err != nil {
		s.logger.Debug("Could not close", zap.String("remoteAddress", fmt.Sprintf("%v:%v", s.clientIP, s.clientPort)), zap.Error(err))
	}

	s.logger.Info("Closed client connection")
}