Unverified Commit fd381860 authored by Simon Esposito's avatar Simon Esposito Committed by GitHub
Browse files

Add refresh token rotation (#1091)

When a session is refreshed a new refresh token is now provided.
parent 89c58491
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -15,6 +15,7 @@ The format is based on [keep a changelog](http://keepachangelog.com) and this pr
- Add storage index create flag to read only from the index.
- Add caller id param to storage listing and storage index listing runtime APIs.
- Update Facebook Graph API usage from v11 to v18.
- Add support to refresh token rotation.

### Fixed
- Fixed multiple issues found by linter.
+7 −7
Original line number Diff line number Diff line
@@ -403,12 +403,12 @@ func securityInterceptorFunc(logger *zap.Logger, config Config, sessionCache Ses
			// Value of "authorization" or "grpc-authorization" was empty or repeated.
			return nil, status.Error(codes.Unauthenticated, "Auth token invalid")
		}
		userID, username, vars, exp, token, ok := parseBearerAuth([]byte(config.GetSession().EncryptionKey), auth[0])
		userID, username, vars, exp, tokenId, ok := parseBearerAuth([]byte(config.GetSession().EncryptionKey), auth[0])
		if !ok {
			// Value of "authorization" or "grpc-authorization" was malformed or expired.
			return nil, status.Error(codes.Unauthenticated, "Auth token invalid")
		}
		if !sessionCache.IsValidSession(userID, exp, token) {
		if !sessionCache.IsValidSession(userID, exp, tokenId) {
			return nil, status.Error(codes.Unauthenticated, "Auth token invalid")
		}
		ctx = context.WithValue(context.WithValue(context.WithValue(context.WithValue(ctx, ctxUserIDKey{}, userID), ctxUsernameKey{}, username), ctxVarsKey{}, vars), ctxExpiryKey{}, exp)
@@ -431,12 +431,12 @@ func securityInterceptorFunc(logger *zap.Logger, config Config, sessionCache Ses
			// Value of "authorization" or "grpc-authorization" was empty or repeated.
			return nil, status.Error(codes.Unauthenticated, "Auth token invalid")
		}
		userID, username, vars, exp, token, ok := parseBearerAuth([]byte(config.GetSession().EncryptionKey), auth[0])
		userID, username, vars, exp, tokenId, ok := parseBearerAuth([]byte(config.GetSession().EncryptionKey), auth[0])
		if !ok {
			// Value of "authorization" or "grpc-authorization" was malformed or expired.
			return nil, status.Error(codes.Unauthenticated, "Auth token invalid")
		}
		if !sessionCache.IsValidSession(userID, exp, token) {
		if !sessionCache.IsValidSession(userID, exp, tokenId) {
			return nil, status.Error(codes.Unauthenticated, "Auth token invalid")
		}
		ctx = context.WithValue(context.WithValue(context.WithValue(context.WithValue(ctx, ctxUserIDKey{}, userID), ctxUsernameKey{}, username), ctxVarsKey{}, vars), ctxExpiryKey{}, exp)
@@ -464,7 +464,7 @@ func parseBasicAuth(auth string) (username, password string, ok bool) {
	return cs[:s], cs[s+1:], true
}

func parseBearerAuth(hmacSecretByte []byte, auth string) (userID uuid.UUID, username string, vars map[string]string, exp int64, token string, ok bool) {
func parseBearerAuth(hmacSecretByte []byte, auth string) (userID uuid.UUID, username string, vars map[string]string, exp int64, tokenId string, ok bool) {
	if auth == "" {
		return
	}
@@ -475,7 +475,7 @@ func parseBearerAuth(hmacSecretByte []byte, auth string) (userID uuid.UUID, user
	return parseToken(hmacSecretByte, auth[len(prefix):])
}

func parseToken(hmacSecretByte []byte, tokenString string) (userID uuid.UUID, username string, vars map[string]string, exp int64, token string, ok bool) {
func parseToken(hmacSecretByte []byte, tokenString string) (userID uuid.UUID, username string, vars map[string]string, exp int64, tokenId string, ok bool) {
	jwtToken, err := jwt.ParseWithClaims(tokenString, &SessionTokenClaims{}, func(token *jwt.Token) (interface{}, error) {
		if s, ok := token.Method.(*jwt.SigningMethodHMAC); !ok || s.Hash != crypto.SHA256 {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
@@ -493,7 +493,7 @@ func parseToken(hmacSecretByte []byte, tokenString string) (userID uuid.UUID, us
	if err != nil {
		return
	}
	return userID, claims.Username, claims.Vars, claims.ExpiresAt, tokenString, true
	return userID, claims.Username, claims.Vars, claims.ExpiresAt, claims.TokenId, true
}

func decompressHandler(logger *zap.Logger, h http.Handler) http.HandlerFunc {
+43 −32
Original line number Diff line number Diff line
@@ -37,6 +37,7 @@ var (
)

type SessionTokenClaims struct {
	TokenId   string            `json:"tid,omitempty"`
	UserId    string            `json:"uid,omitempty"`
	Username  string            `json:"usn,omitempty"`
	Vars      map[string]string `json:"vrs,omitempty"`
@@ -102,9 +103,10 @@ func (s *ApiServer) AuthenticateApple(ctx context.Context, in *api.AuthenticateA
		return nil, err
	}

	token, exp := generateToken(s.config, dbUserID, dbUsername, in.Account.Vars)
	refreshToken, refreshExp := generateRefreshToken(s.config, dbUserID, dbUsername, in.Account.Vars)
	s.sessionCache.Add(uuid.FromStringOrNil(dbUserID), exp, token, refreshExp, refreshToken)
	tokenID := uuid.Must(uuid.NewV4()).String()
	token, exp := generateToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars)
	refreshToken, refreshExp := generateRefreshToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars)
	s.sessionCache.Add(uuid.FromStringOrNil(dbUserID), exp, tokenID, refreshExp, tokenID)
	session := &api.Session{Created: created, Token: token, RefreshToken: refreshToken}

	// After hook.
@@ -168,9 +170,10 @@ func (s *ApiServer) AuthenticateCustom(ctx context.Context, in *api.Authenticate
		return nil, err
	}

	token, exp := generateToken(s.config, dbUserID, dbUsername, in.Account.Vars)
	refreshToken, refreshExp := generateRefreshToken(s.config, dbUserID, dbUsername, in.Account.Vars)
	s.sessionCache.Add(uuid.FromStringOrNil(dbUserID), exp, token, refreshExp, refreshToken)
	tokenID := uuid.Must(uuid.NewV4()).String()
	token, exp := generateToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars)
	refreshToken, refreshExp := generateRefreshToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars)
	s.sessionCache.Add(uuid.FromStringOrNil(dbUserID), exp, tokenID, refreshExp, tokenID)
	session := &api.Session{Created: created, Token: token, RefreshToken: refreshToken}

	// After hook.
@@ -234,9 +237,10 @@ func (s *ApiServer) AuthenticateDevice(ctx context.Context, in *api.Authenticate
		return nil, err
	}

	token, exp := generateToken(s.config, dbUserID, dbUsername, in.Account.Vars)
	refreshToken, refreshExp := generateRefreshToken(s.config, dbUserID, dbUsername, in.Account.Vars)
	s.sessionCache.Add(uuid.FromStringOrNil(dbUserID), exp, token, refreshExp, refreshToken)
	tokenID := uuid.Must(uuid.NewV4()).String()
	token, exp := generateToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars)
	refreshToken, refreshExp := generateRefreshToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars)
	s.sessionCache.Add(uuid.FromStringOrNil(dbUserID), exp, tokenID, refreshExp, tokenID)
	session := &api.Session{Created: created, Token: token, RefreshToken: refreshToken}

	// After hook.
@@ -330,9 +334,10 @@ func (s *ApiServer) AuthenticateEmail(ctx context.Context, in *api.AuthenticateE
		return nil, err
	}

	token, exp := generateToken(s.config, dbUserID, username, in.Account.Vars)
	refreshToken, refreshExp := generateRefreshToken(s.config, dbUserID, username, in.Account.Vars)
	s.sessionCache.Add(uuid.FromStringOrNil(dbUserID), exp, token, refreshExp, refreshToken)
	tokenID := uuid.Must(uuid.NewV4()).String()
	token, exp := generateToken(s.config, tokenID, dbUserID, username, in.Account.Vars)
	refreshToken, refreshExp := generateRefreshToken(s.config, tokenID, dbUserID, username, in.Account.Vars)
	s.sessionCache.Add(uuid.FromStringOrNil(dbUserID), exp, tokenID, refreshExp, tokenID)
	session := &api.Session{Created: created, Token: token, RefreshToken: refreshToken}

	// After hook.
@@ -397,9 +402,10 @@ func (s *ApiServer) AuthenticateFacebook(ctx context.Context, in *api.Authentica
		_ = importFacebookFriends(ctx, s.logger, s.db, s.router, s.socialClient, uuid.FromStringOrNil(dbUserID), dbUsername, in.Account.Token, false)
	}

	token, exp := generateToken(s.config, dbUserID, dbUsername, in.Account.Vars)
	refreshToken, refreshExp := generateRefreshToken(s.config, dbUserID, dbUsername, in.Account.Vars)
	s.sessionCache.Add(uuid.FromStringOrNil(dbUserID), exp, token, refreshExp, refreshToken)
	tokenID := uuid.Must(uuid.NewV4()).String()
	token, exp := generateToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars)
	refreshToken, refreshExp := generateRefreshToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars)
	s.sessionCache.Add(uuid.FromStringOrNil(dbUserID), exp, tokenID, refreshExp, tokenID)
	session := &api.Session{Created: created, Token: token, RefreshToken: refreshToken}

	// After hook.
@@ -458,9 +464,10 @@ func (s *ApiServer) AuthenticateFacebookInstantGame(ctx context.Context, in *api
	if err != nil {
		return nil, err
	}
	token, exp := generateToken(s.config, dbUserID, dbUsername, in.Account.Vars)
	refreshToken, refreshExp := generateRefreshToken(s.config, dbUserID, dbUsername, in.Account.Vars)
	s.sessionCache.Add(uuid.FromStringOrNil(dbUserID), exp, token, refreshExp, refreshToken)
	tokenID := uuid.Must(uuid.NewV4()).String()
	token, exp := generateToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars)
	refreshToken, refreshExp := generateRefreshToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars)
	s.sessionCache.Add(uuid.FromStringOrNil(dbUserID), exp, tokenID, refreshExp, tokenID)
	session := &api.Session{Created: created, Token: token, RefreshToken: refreshToken}

	// After hook.
@@ -532,9 +539,10 @@ func (s *ApiServer) AuthenticateGameCenter(ctx context.Context, in *api.Authenti
		return nil, err
	}

	token, exp := generateToken(s.config, dbUserID, dbUsername, in.Account.Vars)
	refreshToken, refreshExp := generateRefreshToken(s.config, dbUserID, dbUsername, in.Account.Vars)
	s.sessionCache.Add(uuid.FromStringOrNil(dbUserID), exp, token, refreshExp, refreshToken)
	tokenID := uuid.Must(uuid.NewV4()).String()
	token, exp := generateToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars)
	refreshToken, refreshExp := generateRefreshToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars)
	s.sessionCache.Add(uuid.FromStringOrNil(dbUserID), exp, tokenID, refreshExp, tokenID)
	session := &api.Session{Created: created, Token: token, RefreshToken: refreshToken}

	// After hook.
@@ -594,9 +602,10 @@ func (s *ApiServer) AuthenticateGoogle(ctx context.Context, in *api.Authenticate
		return nil, err
	}

	token, exp := generateToken(s.config, dbUserID, dbUsername, in.Account.Vars)
	refreshToken, refreshExp := generateRefreshToken(s.config, dbUserID, dbUsername, in.Account.Vars)
	s.sessionCache.Add(uuid.FromStringOrNil(dbUserID), exp, token, refreshExp, refreshToken)
	tokenID := uuid.Must(uuid.NewV4()).String()
	token, exp := generateToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars)
	refreshToken, refreshExp := generateRefreshToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars)
	s.sessionCache.Add(uuid.FromStringOrNil(dbUserID), exp, tokenID, refreshExp, tokenID)
	session := &api.Session{Created: created, Token: token, RefreshToken: refreshToken}

	// After hook.
@@ -665,9 +674,10 @@ func (s *ApiServer) AuthenticateSteam(ctx context.Context, in *api.AuthenticateS
		_ = importSteamFriends(ctx, s.logger, s.db, s.router, s.socialClient, uuid.FromStringOrNil(dbUserID), dbUsername, s.config.GetSocial().Steam.PublisherKey, steamID, false)
	}

	token, exp := generateToken(s.config, dbUserID, dbUsername, in.Account.Vars)
	refreshToken, refreshExp := generateRefreshToken(s.config, dbUserID, dbUsername, in.Account.Vars)
	s.sessionCache.Add(uuid.FromStringOrNil(dbUserID), exp, token, refreshExp, refreshToken)
	tokenID := uuid.Must(uuid.NewV4()).String()
	token, exp := generateToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars)
	refreshToken, refreshExp := generateRefreshToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars)
	s.sessionCache.Add(uuid.FromStringOrNil(dbUserID), exp, tokenID, refreshExp, tokenID)
	session := &api.Session{Created: created, Token: token, RefreshToken: refreshToken}

	// After hook.
@@ -683,18 +693,19 @@ func (s *ApiServer) AuthenticateSteam(ctx context.Context, in *api.AuthenticateS
	return session, nil
}

func generateToken(config Config, userID, username string, vars map[string]string) (string, int64) {
func generateToken(config Config, tokenID, userID, username string, vars map[string]string) (string, int64) {
	exp := time.Now().UTC().Add(time.Duration(config.GetSession().TokenExpirySec) * time.Second).Unix()
	return generateTokenWithExpiry(config.GetSession().EncryptionKey, userID, username, vars, exp)
	return generateTokenWithExpiry(config.GetSession().EncryptionKey, tokenID, userID, username, vars, exp)
}

func generateRefreshToken(config Config, userID string, username string, vars map[string]string) (string, int64) {
func generateRefreshToken(config Config, tokenID, userID string, username string, vars map[string]string) (string, int64) {
	exp := time.Now().UTC().Add(time.Duration(config.GetSession().RefreshTokenExpirySec) * time.Second).Unix()
	return generateTokenWithExpiry(config.GetSession().RefreshEncryptionKey, userID, username, vars, exp)
	return generateTokenWithExpiry(config.GetSession().RefreshEncryptionKey, tokenID, userID, username, vars, exp)
}

func generateTokenWithExpiry(signingKey, userID, username string, vars map[string]string, exp int64) (string, int64) {
func generateTokenWithExpiry(signingKey, tokenID, userID, username string, vars map[string]string, exp int64) (string, int64) {
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, &SessionTokenClaims{
		TokenId:   tokenID,
		UserId:    userID,
		Username:  username,
		Vars:      vars,
+13 −5
Original line number Diff line number Diff line
@@ -53,7 +53,7 @@ func (s *ApiServer) SessionRefresh(ctx context.Context, in *api.SessionRefreshRe
		return nil, status.Error(codes.InvalidArgument, "Refresh token is required.")
	}

	userID, username, vars, err := SessionRefresh(ctx, s.logger, s.db, s.config, s.sessionCache, in.Token)
	userID, username, vars, tokenId, err := SessionRefresh(ctx, s.logger, s.db, s.config, s.sessionCache, in.Token)
	if err != nil {
		return nil, err
	}
@@ -65,14 +65,22 @@ func (s *ApiServer) SessionRefresh(ctx context.Context, in *api.SessionRefreshRe
	}
	userIDStr := userID.String()

	token, exp := generateToken(s.config, userIDStr, username, useVars)
	s.sessionCache.Add(userID, exp, token, 0, "")
	session := &api.Session{Created: false, Token: token, RefreshToken: in.Token}
	//newTokenId := uuid.Must(uuid.NewV4()).String()
	//token, tokenExp := generateToken(s.config, newTokenId, userIDStr, username, useVars)
	//refreshToken, refreshTokenExp := generateRefreshToken(s.config, newTokenId, userIDStr, username, useVars)
	//s.sessionCache.Remove(userID, tokenExp, "", refreshTokenExp, tokenId)
	//s.sessionCache.Add(userID, tokenExp, newTokenId, refreshTokenExp, newTokenId)
	//session := &api.Session{Created: false, Token: token, RefreshToken: refreshToken}

	token, tokenExp := generateToken(s.config, tokenId, userIDStr, username, useVars)
	refreshToken, refreshTokenExp := generateRefreshToken(s.config, tokenId, userIDStr, username, useVars)
	s.sessionCache.Add(userID, tokenExp, tokenId, refreshTokenExp, tokenId)
	session := &api.Session{Created: false, Token: token, RefreshToken: refreshToken}

	// After hook.
	if fn := s.runtime.AfterSessionRefresh(); fn != nil {
		afterFn := func(clientIP, clientPort string) error {
			return fn(ctx, s.logger, userIDStr, username, useVars, exp, clientIP, clientPort, session, in)
			return fn(ctx, s.logger, userIDStr, username, useVars, tokenExp, clientIP, clientPort, session, in)
		}

		// Execute the after function lambda wrapped in a trace for stats measurement.
+15 −15
Original line number Diff line number Diff line
@@ -31,13 +31,13 @@ var (
	ErrRefreshTokenInvalid = errors.New("refresh token invalid")
)

func SessionRefresh(ctx context.Context, logger *zap.Logger, db *sql.DB, config Config, sessionCache SessionCache, token string) (uuid.UUID, string, map[string]string, error) {
	userID, _, vars, exp, _, ok := parseToken([]byte(config.GetSession().RefreshEncryptionKey), token)
func SessionRefresh(ctx context.Context, logger *zap.Logger, db *sql.DB, config Config, sessionCache SessionCache, token string) (uuid.UUID, string, map[string]string, string, error) {
	userID, _, vars, exp, tokenId, ok := parseToken([]byte(config.GetSession().RefreshEncryptionKey), token)
	if !ok {
		return uuid.Nil, "", nil, status.Error(codes.Unauthenticated, "Refresh token invalid or expired.")
		return uuid.Nil, "", nil, "", status.Error(codes.Unauthenticated, "Refresh token invalid or expired.")
	}
	if !sessionCache.IsValidRefresh(userID, exp, token) {
		return uuid.Nil, "", nil, status.Error(codes.Unauthenticated, "Refresh token invalid or expired.")
	if !sessionCache.IsValidRefresh(userID, exp, tokenId) {
		return uuid.Nil, "", nil, "", status.Error(codes.Unauthenticated, "Refresh token invalid or expired.")
	}

	// Look for an existing account.
@@ -48,49 +48,49 @@ func SessionRefresh(ctx context.Context, logger *zap.Logger, db *sql.DB, config
	if err != nil {
		if err == sql.ErrNoRows {
			// Account not found and creation is never allowed for this type.
			return uuid.Nil, "", nil, status.Error(codes.NotFound, "User account not found.")
			return uuid.Nil, "", nil, "", status.Error(codes.NotFound, "User account not found.")
		}
		logger.Error("Error looking up user by ID.", zap.Error(err), zap.String("id", userID.String()))
		return uuid.Nil, "", nil, status.Error(codes.Internal, "Error finding user account.")
		return uuid.Nil, "", nil, "", status.Error(codes.Internal, "Error finding user account.")
	}

	// Check if it's disabled.
	if dbDisableTime.Status == pgtype.Present && dbDisableTime.Time.Unix() != 0 {
		logger.Info("User account is disabled.", zap.String("id", userID.String()))
		return uuid.Nil, "", nil, status.Error(codes.PermissionDenied, "User account banned.")
		return uuid.Nil, "", nil, "", status.Error(codes.PermissionDenied, "User account banned.")
	}

	return userID, dbUsername, vars, nil
	return userID, dbUsername, vars, tokenId, nil
}

func SessionLogout(config Config, sessionCache SessionCache, userID uuid.UUID, token, refreshToken string) error {
	var maybeSessionExp int64
	var maybeSessionToken string
	var maybeSessionTokenId string
	if token != "" {
		var sessionUserID uuid.UUID
		var ok bool
		sessionUserID, _, _, maybeSessionExp, maybeSessionToken, ok = parseToken([]byte(config.GetSession().EncryptionKey), token)
		sessionUserID, _, _, maybeSessionExp, maybeSessionTokenId, ok = parseToken([]byte(config.GetSession().EncryptionKey), token)
		if !ok || sessionUserID != userID {
			return ErrSessionTokenInvalid
		}
	}

	var maybeRefreshExp int64
	var maybeRefreshToken string
	var maybeRefreshTokenId string
	if refreshToken != "" {
		var refreshUserID uuid.UUID
		var ok bool
		refreshUserID, _, _, maybeRefreshExp, maybeRefreshToken, ok = parseToken([]byte(config.GetSession().RefreshEncryptionKey), refreshToken)
		refreshUserID, _, _, maybeRefreshExp, maybeRefreshTokenId, ok = parseToken([]byte(config.GetSession().RefreshEncryptionKey), refreshToken)
		if !ok || refreshUserID != userID {
			return ErrRefreshTokenInvalid
		}
	}

	if maybeSessionToken == "" && maybeRefreshToken == "" {
	if maybeSessionTokenId == "" && maybeRefreshTokenId == "" {
		sessionCache.RemoveAll(userID)
		return nil
	}

	sessionCache.Remove(userID, maybeSessionExp, maybeSessionToken, maybeRefreshExp, maybeRefreshToken)
	sessionCache.Remove(userID, maybeSessionExp, maybeSessionTokenId, maybeRefreshExp, maybeRefreshTokenId)
	return nil
}
Loading