Commit 5b849cef authored by Andrei Mihu's avatar Andrei Mihu
Browse files

Clean up session registry.

parent fbd0cc7f
Loading
Loading
Loading
Loading
+19 −26
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@ package server

import (
	"context"
	"go.uber.org/atomic"
	"sync"

	"github.com/gofrs/uuid"
@@ -62,54 +63,46 @@ type SessionRegistry interface {
}

type LocalSessionRegistry struct {
	sync.RWMutex

	sessions map[uuid.UUID]Session
	sessions     *sync.Map
	sessionCount *atomic.Int32
}

func NewLocalSessionRegistry() SessionRegistry {
	return &LocalSessionRegistry{
		sessions: make(map[uuid.UUID]Session),
		sessions:     &sync.Map{},
		sessionCount: atomic.NewInt32(0),
	}
}

func (r *LocalSessionRegistry) Stop() {}

func (r *LocalSessionRegistry) Count() int {
	var count int
	r.RLock()
	count = len(r.sessions)
	r.RUnlock()
	return count
	return int(r.sessionCount.Load())
}

func (r *LocalSessionRegistry) Get(sessionID uuid.UUID) Session {
	var session Session
	r.RLock()
	session = r.sessions[sessionID]
	r.RUnlock()
	return session
	session, ok := r.sessions.Load(sessionID)
	if !ok {
		return nil
	}
	return session.(Session)
}

func (r *LocalSessionRegistry) Add(session Session) {
	r.Lock()
	r.sessions[session.ID()] = session
	r.Unlock()
	r.sessions.Store(session.ID(), session)
	r.sessionCount.Inc()
}

func (r *LocalSessionRegistry) Remove(sessionID uuid.UUID) {
	r.Lock()
	delete(r.sessions, sessionID)
	r.Unlock()
	r.sessions.Delete(sessionID)
	r.sessionCount.Dec()
}

func (r *LocalSessionRegistry) Disconnect(ctx context.Context, sessionID uuid.UUID, node string) error {
	var session Session
	r.RLock()
	session = r.sessions[sessionID]
	r.RUnlock()
	if session != nil {
		session.Close("server-side session disconnect")
	session, ok := r.sessions.Load(sessionID)
	if ok {
		// No need to remove the session from the map, session.Close() will do that.
		session.(Session).Close("server-side session disconnect")
	}
	return nil
}