Commit 85467ff8 authored by Andrei Mihu's avatar Andrei Mihu
Browse files

Enable pure Protobuf binary messaging over WebSocket. (#279)

parent d37914a7
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@ The format is based on [keep a changelog](http://keepachangelog.com) and this pr

## [Unreleased]
### Added
- WebSocket connections can now use pure Protobuf binary messaging.
- Lua runtime tournament listings now return duration, end active, and end time fields.
- Lua runtime tournament end hooks now contain duration, end active, and end time fields.
- Lua runtime tournament reset hooks now contain duration, end active, and end time fields.
+35 −7
Original line number Diff line number Diff line
@@ -15,7 +15,9 @@
package server

import (
	"bytes"
	"github.com/golang/protobuf/jsonpb"
	"github.com/golang/protobuf/proto"
	"github.com/heroiclabs/nakama/rtapi"
	"go.uber.org/zap"
)
@@ -45,19 +47,45 @@ func (r *LocalMessageRouter) SendToPresenceIDs(logger *zap.Logger, presenceIDs [
		return
	}

	payload, err := r.jsonpbMarshaler.MarshalToString(envelope)
	if err != nil {
		logger.Error("Could not marshall message to json", zap.Error(err))
		return
	}
	payloadBytes := []byte(payload)
	// Prepare payload variables but do not initialize until we hit a session that needs them to avoid unnecessary work.
	var payloadProtobuf []byte
	var payloadJson []byte

	for _, presenceID := range presenceIDs {
		session := r.sessionRegistry.Get(presenceID.SessionID)
		if session == nil {
			logger.Debug("No session to route to", zap.String("sid", presenceID.SessionID.String()))
			continue
		}
		if err := session.SendBytes(isStream, mode, payloadBytes); err != nil {

		var err error
		switch session.Format() {
		case SessionFormatProtobuf:
			if payloadProtobuf == nil {
				// Marshal the payload now that we know this format is needed.
				payloadProtobuf, err = proto.Marshal(envelope)
				if err != nil {
					logger.Error("Could not marshal message", zap.Error(err))
					return
				}
			}
			err = session.SendBytes(isStream, mode, payloadProtobuf)
		case SessionFormatJson:
			fallthrough
		default:
			if payloadJson == nil {
				// Marshal the payload now that we know this format is needed.
				var buf bytes.Buffer
				if err = r.jsonpbMarshaler.Marshal(&buf, envelope); err == nil {
					payloadJson = buf.Bytes()
				} else {
					logger.Error("Could not marshal message", zap.Error(err))
					return
				}
			}
			err = session.SendBytes(isStream, mode, payloadJson)
		}
		if err != nil {
			logger.Error("Failed to route to", zap.String("sid", presenceID.SessionID.String()), zap.Error(err))
		}
	}
+58 −22
Original line number Diff line number Diff line
@@ -19,6 +19,7 @@ import (
	"context"
	"errors"
	"fmt"
	"github.com/golang/protobuf/proto"
	"sync"
	"time"

@@ -39,6 +40,7 @@ type sessionWS struct {
	logger     *zap.Logger
	config     Config
	id         uuid.UUID
	format     SessionFormat
	userID     uuid.UUID
	username   *atomic.String
	expiry     int64
@@ -50,6 +52,7 @@ type sessionWS struct {

	jsonpbMarshaler        *jsonpb.Marshaler
	jsonpbUnmarshaler      *jsonpb.Unmarshaler
	wsMessageType          int
	queuePriorityThreshold int
	pingPeriodDuration     time.Duration
	pongWaitDuration       time.Duration
@@ -68,18 +71,24 @@ type sessionWS struct {
	outgoingStopCh         chan struct{}
}

func NewSessionWS(logger *zap.Logger, config Config, userID uuid.UUID, username string, expiry int64, clientIP string, clientPort string, jsonpbMarshaler *jsonpb.Marshaler, jsonpbUnmarshaler *jsonpb.Unmarshaler, conn *websocket.Conn, sessionRegistry *SessionRegistry, matchmaker Matchmaker, tracker Tracker) Session {
func NewSessionWS(logger *zap.Logger, config Config, format SessionFormat, userID uuid.UUID, username string, expiry int64, clientIP string, clientPort string, jsonpbMarshaler *jsonpb.Marshaler, jsonpbUnmarshaler *jsonpb.Unmarshaler, conn *websocket.Conn, sessionRegistry *SessionRegistry, matchmaker Matchmaker, tracker Tracker) Session {
	sessionID := uuid.Must(uuid.NewV4())
	sessionLogger := logger.With(zap.String("uid", userID.String()), zap.String("sid", sessionID.String()))

	sessionLogger.Info("New WebSocket session connected")
	sessionLogger.Info("New WebSocket session connected", zap.Uint8("format", uint8(format)))

	ctx, ctxCancelFn := context.WithCancel(context.Background())

	wsMessageType := websocket.TextMessage
	if format == SessionFormatProtobuf {
		wsMessageType = websocket.BinaryMessage
	}

	return &sessionWS{
		logger:     sessionLogger,
		config:     config,
		id:         sessionID,
		format:     format,
		userID:     userID,
		username:   atomic.NewString(username),
		expiry:     expiry,
@@ -91,6 +100,7 @@ func NewSessionWS(logger *zap.Logger, config Config, userID uuid.UUID, username

		jsonpbMarshaler:        jsonpbMarshaler,
		jsonpbUnmarshaler:      jsonpbUnmarshaler,
		wsMessageType:          wsMessageType,
		queuePriorityThreshold: (config.GetSocket().OutgoingQueueSize / 3) * 2,
		pingPeriodDuration:     time.Duration(config.GetSocket().PingPeriodMs) * time.Millisecond,
		pongWaitDuration:       time.Duration(config.GetSocket().PongWaitMs) * time.Millisecond,
@@ -159,7 +169,7 @@ func (s *sessionWS) Consume(processRequest func(logger *zap.Logger, session Sess
	go s.processOutgoing()

	for {
		_, data, err := s.conn.ReadMessage()
		messageType, data, err := s.conn.ReadMessage()
		if err != nil {
			// Ignore "normal" WebSocket errors.
			if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
@@ -170,6 +180,12 @@ func (s *sessionWS) Consume(processRequest func(logger *zap.Logger, session Sess
			}
			break
		}
		if messageType != s.wsMessageType {
			// Expected text but received binary, or expected binary but received text.
			// Disconnect client if it attempts to use this kind of mixed protocol mode.
			s.logger.Debug("Received unexpected WebSocket message type", zap.Int("expected", s.wsMessageType), zap.Int("actual", messageType))
			break
		}

		s.receivedMessageCounter--
		if s.receivedMessageCounter <= 0 {
@@ -178,11 +194,20 @@ func (s *sessionWS) Consume(processRequest func(logger *zap.Logger, session Sess
		}

		request := &rtapi.Envelope{}
		if err = s.jsonpbUnmarshaler.Unmarshal(bytes.NewReader(data), request); err != nil {
		switch s.format {
		case SessionFormatProtobuf:
			err = proto.Unmarshal(data, request)
		case SessionFormatJson:
			fallthrough
		default:
			err = s.jsonpbUnmarshaler.Unmarshal(bytes.NewReader(data), request)
		}
		if err != nil {
			// If the payload is malformed the client is incompatible or misbehaving, either way disconnect it now.
			s.logger.Warn("Received malformed payload", zap.String("data", string(data)))
			s.logger.Warn("Received malformed payload", zap.Binary("data", data))
			break
		} else {
		}

		switch request.Cid {
		case "":
			if !processRequest(s.logger, s, request) {
@@ -196,7 +221,6 @@ func (s *sessionWS) Consume(processRequest func(logger *zap.Logger, session Sess
		}
	}
}
}

func (s *sessionWS) maybeResetPingTimer() {
	// If there's already a reset in progress there's no need to wait.
@@ -238,7 +262,7 @@ func (s *sessionWS) processOutgoing() {
			}
			// Process the outgoing message queue.
			s.conn.SetWriteDeadline(time.Now().Add(s.writeWaitDuration))
			if err := s.conn.WriteMessage(websocket.TextMessage, payload); err != nil {
			if err := s.conn.WriteMessage(s.wsMessageType, payload); err != nil {
				s.Unlock()
				s.logger.Warn("Could not write message", zap.Error(err))
				return
@@ -268,22 +292,34 @@ func (s *sessionWS) pingNow() bool {
}

func (s *sessionWS) Format() SessionFormat {
	return SessionFormatJson
	return s.format
}

func (s *sessionWS) Send(isStream bool, mode uint8, envelope *rtapi.Envelope) error {
	payload, err := s.jsonpbMarshaler.MarshalToString(envelope)
	var payload []byte
	var err error
	switch s.format {
	case SessionFormatProtobuf:
		payload, err = proto.Marshal(envelope)
	case SessionFormatJson:
		fallthrough
	default:
		var buf bytes.Buffer
		if err = s.jsonpbMarshaler.Marshal(&buf, envelope); err == nil {
			payload = buf.Bytes()
		}
	}
	if err != nil {
		s.logger.Warn("Could not marshal to json", zap.Error(err))
		s.logger.Warn("Could not marshal envelope", zap.Error(err))
		return err
	}

	if s.logger.Core().Enabled(zap.DebugLevel) {
		switch envelope.Message.(type) {
		case *rtapi.Envelope_Error:
			s.logger.Debug("Sending error message", zap.String("payload", payload))
			s.logger.Debug("Sending error message", zap.Binary("payload", payload))
		default:
			s.logger.Debug(fmt.Sprintf("Sending %T message", envelope.Message), zap.String("payload", payload))
			s.logger.Debug(fmt.Sprintf("Sending %T message", envelope.Message), zap.Any("envelope", envelope))
		}
	}

+16 −1
Original line number Diff line number Diff line
@@ -40,6 +40,21 @@ func NewSocketWsAcceptor(logger *zap.Logger, config Config, sessionRegistry *Ses

	// This handler will be attached to the API Gateway server.
	return func(w http.ResponseWriter, r *http.Request) {
		// Check format.
		var format SessionFormat
		switch r.URL.Query().Get("format") {
		case "protobuf":
			format = SessionFormatProtobuf
		case "json":
			fallthrough
		case "":
			format = SessionFormatJson
		default:
			// Invalid values are rejected.
			http.Error(w, "Invalid format parameter", 400)
			return
		}

		// Check authentication.
		token := r.URL.Query().Get("token")
		if token == "" {
@@ -90,7 +105,7 @@ func NewSocketWsAcceptor(logger *zap.Logger, config Config, sessionRegistry *Ses
		span := trace.NewSpan("nakama.session.ws", nil, trace.StartOptions{})

		// Wrap the connection for application handling.
		s := NewSessionWS(logger, config, userID, username, expiry, clientIP, clientPort, jsonpbMarshaler, jsonpbUnmarshaler, conn, sessionRegistry, matchmaker, tracker)
		s := NewSessionWS(logger, config, format, userID, username, expiry, clientIP, clientPort, jsonpbMarshaler, jsonpbUnmarshaler, conn, sessionRegistry, matchmaker, tracker)

		// Add to the session registry.
		sessionRegistry.add(s)