diff --git a/server/session_ws.go b/server/session_ws.go index 58ca96eb301639d92d0fa0605472e50eb9f72cd2..060109471ffae2fb6190b1344cad99da017f0c85 100644 --- a/server/session_ws.go +++ b/server/session_ws.go @@ -63,6 +63,7 @@ type sessionWS struct { conn *websocket.Conn receivedMessageCounter int pingTimer *time.Timer + pingTimerCAS *atomic.Uint32 outgoingCh chan []byte outgoingStopCh chan struct{} } @@ -103,6 +104,7 @@ func NewSessionWS(logger *zap.Logger, config Config, userID uuid.UUID, username conn: conn, receivedMessageCounter: config.GetSocket().PingBackoffThreshold, pingTimer: time.NewTimer(time.Duration(config.GetSocket().PingPeriodMs) * time.Millisecond), + pingTimerCAS: atomic.NewUint32(1), outgoingCh: make(chan []byte, config.GetSocket().OutgoingQueueSize), outgoingStopCh: make(chan struct{}), } @@ -149,8 +151,7 @@ func (s *sessionWS) Consume(processRequest func(logger *zap.Logger, session Sess s.conn.SetReadLimit(s.config.GetSocket().MaxMessageSizeBytes) s.conn.SetReadDeadline(time.Now().Add(s.pongWaitDuration)) s.conn.SetPongHandler(func(string) error { - s.pingTimer.Reset(s.pingPeriodDuration) - s.conn.SetReadDeadline(time.Now().Add(s.pongWaitDuration)) + s.maybeResetPingTimer() return nil }) @@ -173,8 +174,7 @@ func (s *sessionWS) Consume(processRequest func(logger *zap.Logger, session Sess s.receivedMessageCounter-- if s.receivedMessageCounter <= 0 { s.receivedMessageCounter = s.config.GetSocket().PingBackoffThreshold - s.pingTimer.Reset(s.pingPeriodDuration) - s.conn.SetReadDeadline(time.Now().Add(s.pongWaitDuration)) + s.maybeResetPingTimer() } request := &rtapi.Envelope{} @@ -183,7 +183,6 @@ func (s *sessionWS) Consume(processRequest func(logger *zap.Logger, session Sess s.logger.Warn("Received malformed payload", zap.String("data", string(data))) break } else { - // TODO Add session-global context here to cancel in-progress operations when the session is closed. switch request.Cid { case "": if !processRequest(s.logger, s, request) { @@ -199,6 +198,24 @@ func (s *sessionWS) Consume(processRequest func(logger *zap.Logger, session Sess } } +func (s *sessionWS) maybeResetPingTimer() { + // If there's already a reset in progress there's no need to wait. + if !s.pingTimerCAS.CAS(1, 0) { + return + } + + // CAS ensures concurrency is not a problem here. + if !s.pingTimer.Stop() { + select { + case <-s.pingTimer.C: + default: + } + } + s.pingTimer.Reset(s.pingPeriodDuration) + s.conn.SetReadDeadline(time.Now().Add(s.pongWaitDuration)) + s.pingTimerCAS.CAS(0, 1) +} + func (s *sessionWS) processOutgoing() { for { select { @@ -246,7 +263,7 @@ func (s *sessionWS) pingNow() bool { s.cleanupClosedConnection() return false } - s.pingTimer.Reset(s.pingPeriodDuration) + return true }