Commit 03b664d9 authored by Mo Firouz's avatar Mo Firouz
Browse files

Improve concurrency for closed sockets. Merged #25

parent 158c6500
Loading
Loading
Loading
Loading
+5 −0
Original line number Diff line number Diff line
@@ -8,8 +8,13 @@ The format is based on [keep a changelog](http://keepachangelog.com/) and this p
- Include Dockerfile and Docker instructions.
- Use a default limit in topic message listings if one is not provided.

### Changed
- Improve warning message on migration database creation.
- Print database connections to logs on server start.

### Fixed
- Enforce concurrency control on outgoing socket messages.
- Improve concurrency for closed sockets.
- Correct session lookup for realtime message routing.
- Fix input validation when sending topic messages.
- Correct handling of IDs in various login options.
+5 −1
Original line number Diff line number Diff line
@@ -126,7 +126,11 @@ func MigrateParse(args []string, logger zap.Logger) {

	_, err = db.Exec(fmt.Sprintf("CREATE DATABASE %s", dbname))
	if err != nil {
		logger.Info("Database could not be created", zap.Error(err))
		if err.Error() == fmt.Sprintf("pq: database \"%s\" already exists", dbname) {
			logger.Info("Using existing database", zap.String("name", dbname))
		} else {
			logger.Fatal("Database could not be created", zap.Error(err))
		}
	} else {
		logger.Info("Database created", zap.String("name", dbname))
	}
+1 −0
Original line number Diff line number Diff line
@@ -85,6 +85,7 @@ func main() {
	mlogger.Info("Nakama starting", zap.String("at", time.Now().UTC().Format("2006-01-02 15:04:05.000 -0700 MST")))
	mlogger.Info("Node", zap.String("name", config.GetName()), zap.String("version", semver))
	mlogger.Info("Data directory", zap.String("path", config.GetDataDir()))
	mlogger.Info("Database connections", zap.Object("dsns", config.GetDSNS()))

	db := dbConnect(mlogger, config.GetDSNS())

+15 −12
Original line number Diff line number Diff line
@@ -29,16 +29,18 @@ type pipeline struct {
	socialClient    *social.Client
	tracker         Tracker
	messageRouter   MessageRouter
	sessionRegistry *SessionRegistry
}

// NewPipeline creates a new Pipeline
func NewPipeline(config Config, db *sql.DB, socialClient *social.Client, tracker Tracker, messageRouter MessageRouter) *pipeline {
func NewPipeline(config Config, db *sql.DB, socialClient *social.Client, tracker Tracker, messageRouter MessageRouter, registry *SessionRegistry) *pipeline {
	return &pipeline{
		config:          config,
		db:              db,
		socialClient:    socialClient,
		tracker:         tracker,
		messageRouter:   messageRouter,
		sessionRegistry: registry,
	}
}

@@ -48,7 +50,8 @@ func (p *pipeline) processRequest(logger zap.Logger, session *session, envelope
	switch envelope.Payload.(type) {
	case *Envelope_Logout:
		// TODO Store JWT into a blacklist until remaining JWT expiry.
		session.Close()
		p.sessionRegistry.remove(session)
		session.close()

	case *Envelope_Link:
		p.linkID(logger, session, envelope)
+14 −10
Original line number Diff line number Diff line
@@ -102,7 +102,6 @@ func (s *session) pingPeriodically() {
}

func (s *session) pingNow() bool {
	// Websocket ping.
	s.Lock()
	if s.stopped {
		s.Unlock()
@@ -113,7 +112,7 @@ func (s *session) pingNow() bool {
	s.Unlock()
	if err != nil {
		s.logger.Warn("Could not send ping. Closing channel", zap.String("remoteAddress", s.conn.RemoteAddr().String()), zap.Error(err))
		s.Close()
		s.cleanupClosedConnection() // The connection has already failed
		return false
	}

@@ -148,7 +147,13 @@ func (s *session) SendBytes(payload []byte) error {
	}

	s.conn.SetWriteDeadline(time.Now().Add(time.Duration(s.config.GetTransport().WriteWaitMs) * time.Millisecond))
	return s.conn.WriteMessage(websocket.BinaryMessage, payload)
	err := s.conn.WriteMessage(websocket.BinaryMessage, payload)
	if err != nil {
		s.logger.Warn("Could not write message", zap.Error(err))
		//TODO investigate whether we need to cleanupClosedConnection if write fails
	}

	return err
}

func (s *session) cleanupClosedConnection() {
@@ -159,27 +164,26 @@ func (s *session) cleanupClosedConnection() {
	s.stopped = true
	s.Unlock()

	s.logger.Info("Clean up closed client connection.", zap.String("remoteAddress", s.conn.RemoteAddr().String()))

	s.logger.Info("Cleaning up closed client connection", zap.String("remoteAddress", s.conn.RemoteAddr().String()))
	s.unregister(s)
	s.pingTicker.Stop()
	s.conn.Close()
	s.logger.Info("Closed client connection")
}

func (s *session) Close() {
	s.unregister(s)
	s.pingTicker.Stop()

func (s *session) close() {
	s.Lock()
	defer s.Unlock()
	if s.stopped {
		return
	}
	s.stopped = true
	s.Unlock()

	s.pingTicker.Stop()
	err := s.conn.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(time.Duration(s.config.GetTransport().WriteWaitMs)*time.Millisecond))
	if err != nil {
		s.logger.Warn("Could not send close message. Closing prematurely.", zap.String("remoteAddress", s.conn.RemoteAddr().String()), zap.Error(err))
	}
	s.conn.Close()
	s.logger.Info("Closed client connection")
}
Loading