diff --git a/server/session_ws.go b/server/session_ws.go index 3719b1dbeb720e502e70af1eb9c6fc708a1bf78c..3a2696c12e3fd08aabb913daf16a27003356a74e 100644 --- a/server/session_ws.go +++ b/server/session_ws.go @@ -18,6 +18,8 @@ import ( "context" "errors" "fmt" + "io" + "io/ioutil" "net" "sync" "time" @@ -170,6 +172,33 @@ func (s *sessionWS) Expiry() int64 { return s.expiry } +type ReaderFunc func(p []byte) (n int, err error) + +func (f ReaderFunc) Read(p []byte) (n int, err error) { + return f(p) +} + +func (s *sessionWS) ReadMessageWithLivenessUpdate() (messageType int, p []byte, err error) { + var r io.Reader + messageType, r, err = s.conn.NextReader() + if err != nil { + return messageType, nil, err + } + + // Wrap original Reader into ours, which resets ping timer on every successful + // read. This allows us to transmit large messages over slow connection and don't + // disconnect clients due to ping timeouts + LivenessUpdatingReader := func(p []byte) (n int, err error) { + n, err = r.Read(p) + if err != nil && n > 0 { + s.maybeResetPingTimer() + } + return + } + p, err = ioutil.ReadAll(ReaderFunc(LivenessUpdatingReader)) + return +} + func (s *sessionWS) Consume() { // Fire an event for session start. if fn := s.runtime.EventSessionStart(); fn != nil { @@ -195,7 +224,7 @@ func (s *sessionWS) Consume() { IncomingLoop: for { - messageType, data, err := s.conn.ReadMessage() + messageType, data, err := s.ReadMessageWithLivenessUpdate() if err != nil { // Ignore "normal" WebSocket errors. if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {