Commit 09ade675 authored by Andrei Mihu's avatar Andrei Mihu
Browse files

Improved cancellation of ongoing work when clients disconnect. (#249)

parent 5ee05f05
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -4,6 +4,9 @@ All notable changes to this project are documented below.
The format is based on [keep a changelog](http://keepachangelog.com) and this project uses [semantic versioning](http://semver.org).

## [Unreleased]
### Changed
- Improved cancellation of ongoing work when clients disconnect.

### Fixed
- Use leaderboard expires rather than end active IDs with leaderboard resets.
- Better validation of tournament duration when a reset schedule is set.
+1 −1
Original line number Diff line number Diff line
@@ -310,7 +310,7 @@ type NakamaModule interface {
	StreamClose(mode uint8, subject, descriptor, label string) error
	StreamSend(mode uint8, subject, descriptor, label, data string) error

	MatchCreate(module string, params map[string]interface{}) (string, error)
	MatchCreate(ctx context.Context, module string, params map[string]interface{}) (string, error)
	MatchList(ctx context.Context, limit int, authoritative bool, label string, minSize, maxSize int) []*api.Match

	NotificationSend(ctx context.Context, userID, subject string, content map[string]interface{}, code int, sender string, persistent bool) error
+6 −6
Original line number Diff line number Diff line
@@ -44,7 +44,7 @@ func (s *ApiServer) GetAccount(ctx context.Context, in *empty.Empty) (*api.Accou

		// Extract request information and execute the hook.
		clientIP, clientPort := extractClientAddress(s.logger, ctx)
		err, code := fn(s.logger, userID.String(), ctx.Value(ctxUsernameKey{}).(string), ctx.Value(ctxExpiryKey{}).(int64), clientIP, clientPort)
		err, code := fn(ctx, s.logger, userID.String(), ctx.Value(ctxUsernameKey{}).(string), ctx.Value(ctxExpiryKey{}).(int64), clientIP, clientPort)
		if err != nil {
			return nil, status.Error(code, err.Error())
		}
@@ -55,7 +55,7 @@ func (s *ApiServer) GetAccount(ctx context.Context, in *empty.Empty) (*api.Accou
		stats.Record(statsCtx, MetricsApiTimeSpentMsec.M(float64(time.Now().UTC().UnixNano()-startNanos)/1000), MetricsApiCount.M(1))
	}

	user, err := GetAccount(s.logger, s.db, s.tracker, userID)
	user, err := GetAccount(ctx, s.logger, s.db, s.tracker, userID)
	if err != nil {
		if err == ErrAccountNotFound {
			return nil, status.Error(codes.NotFound, "Account not found.")
@@ -73,7 +73,7 @@ func (s *ApiServer) GetAccount(ctx context.Context, in *empty.Empty) (*api.Accou

		// Extract request information and execute the hook.
		clientIP, clientPort := extractClientAddress(s.logger, ctx)
		fn(s.logger, userID.String(), ctx.Value(ctxUsernameKey{}).(string), ctx.Value(ctxExpiryKey{}).(int64), clientIP, clientPort, user)
		fn(ctx, s.logger, userID.String(), ctx.Value(ctxUsernameKey{}).(string), ctx.Value(ctxExpiryKey{}).(int64), clientIP, clientPort, user)

		// Stats measurement end boundary.
		span.End()
@@ -97,7 +97,7 @@ func (s *ApiServer) UpdateAccount(ctx context.Context, in *api.UpdateAccountRequ

		// Extract request information and execute the hook.
		clientIP, clientPort := extractClientAddress(s.logger, ctx)
		result, err, code := fn(s.logger, userID.String(), ctx.Value(ctxUsernameKey{}).(string), ctx.Value(ctxExpiryKey{}).(int64), clientIP, clientPort, in)
		result, err, code := fn(ctx, s.logger, userID.String(), ctx.Value(ctxUsernameKey{}).(string), ctx.Value(ctxExpiryKey{}).(int64), clientIP, clientPort, in)
		if err != nil {
			return nil, status.Error(code, err.Error())
		}
@@ -120,7 +120,7 @@ func (s *ApiServer) UpdateAccount(ctx context.Context, in *api.UpdateAccountRequ
		}
	}

	err := UpdateAccount(s.db, s.logger, userID, username, in.GetDisplayName(), in.GetTimezone(), in.GetLocation(), in.GetLangTag(), in.GetAvatarUrl(), nil)
	err := UpdateAccount(ctx, s.logger, s.db, userID, username, in.GetDisplayName(), in.GetTimezone(), in.GetLocation(), in.GetLangTag(), in.GetAvatarUrl(), nil)
	if err != nil {
		if _, ok := err.(*pq.Error); ok {
			return nil, status.Error(codes.Internal, "Error while trying to update account.")
@@ -138,7 +138,7 @@ func (s *ApiServer) UpdateAccount(ctx context.Context, in *api.UpdateAccountRequ

		// Extract request information and execute the hook.
		clientIP, clientPort := extractClientAddress(s.logger, ctx)
		fn(s.logger, userID.String(), ctx.Value(ctxUsernameKey{}).(string), ctx.Value(ctxExpiryKey{}).(int64), clientIP, clientPort, in)
		fn(ctx, s.logger, userID.String(), ctx.Value(ctxUsernameKey{}).(string), ctx.Value(ctxExpiryKey{}).(int64), clientIP, clientPort, in)

		// Stats measurement end boundary.
		span.End()
+22 −22
Original line number Diff line number Diff line
@@ -50,7 +50,7 @@ func (s *ApiServer) AuthenticateCustom(ctx context.Context, in *api.Authenticate

		// Extract request information and execute the hook.
		clientIP, clientPort := extractClientAddress(s.logger, ctx)
		result, err, code := fn(s.logger, "", "", 0, clientIP, clientPort, in)
		result, err, code := fn(ctx, s.logger, "", "", 0, clientIP, clientPort, in)
		if err != nil {
			return nil, status.Error(code, err.Error())
		}
@@ -85,7 +85,7 @@ func (s *ApiServer) AuthenticateCustom(ctx context.Context, in *api.Authenticate

	create := in.Create == nil || in.Create.Value

	dbUserID, dbUsername, created, err := AuthenticateCustom(s.logger, s.db, in.Account.Id, username, create)
	dbUserID, dbUsername, created, err := AuthenticateCustom(ctx, s.logger, s.db, in.Account.Id, username, create)
	if err != nil {
		return nil, err
	}
@@ -103,7 +103,7 @@ func (s *ApiServer) AuthenticateCustom(ctx context.Context, in *api.Authenticate

		// Extract request information and execute the hook.
		clientIP, clientPort := extractClientAddress(s.logger, ctx)
		fn(s.logger, dbUserID, dbUsername, exp, clientIP, clientPort, session, in)
		fn(ctx, s.logger, dbUserID, dbUsername, exp, clientIP, clientPort, session, in)

		// Stats measurement end boundary.
		span.End()
@@ -125,7 +125,7 @@ func (s *ApiServer) AuthenticateDevice(ctx context.Context, in *api.Authenticate

		// Extract request information and execute the hook.
		clientIP, clientPort := extractClientAddress(s.logger, ctx)
		result, err, code := fn(s.logger, "", "", 0, clientIP, clientPort, in)
		result, err, code := fn(ctx, s.logger, "", "", 0, clientIP, clientPort, in)
		if err != nil {
			return nil, status.Error(code, err.Error())
		}
@@ -160,7 +160,7 @@ func (s *ApiServer) AuthenticateDevice(ctx context.Context, in *api.Authenticate

	create := in.Create == nil || in.Create.Value

	dbUserID, dbUsername, created, err := AuthenticateDevice(s.logger, s.db, in.Account.Id, username, create)
	dbUserID, dbUsername, created, err := AuthenticateDevice(ctx, s.logger, s.db, in.Account.Id, username, create)
	if err != nil {
		return nil, err
	}
@@ -178,7 +178,7 @@ func (s *ApiServer) AuthenticateDevice(ctx context.Context, in *api.Authenticate

		// Extract request information and execute the hook.
		clientIP, clientPort := extractClientAddress(s.logger, ctx)
		fn(s.logger, dbUserID, dbUsername, exp, clientIP, clientPort, session, in)
		fn(ctx, s.logger, dbUserID, dbUsername, exp, clientIP, clientPort, session, in)

		// Stats measurement end boundary.
		span.End()
@@ -200,7 +200,7 @@ func (s *ApiServer) AuthenticateEmail(ctx context.Context, in *api.AuthenticateE

		// Extract request information and execute the hook.
		clientIP, clientPort := extractClientAddress(s.logger, ctx)
		result, err, code := fn(s.logger, "", "", 0, clientIP, clientPort, in)
		result, err, code := fn(ctx, s.logger, "", "", 0, clientIP, clientPort, in)
		if err != nil {
			return nil, status.Error(code, err.Error())
		}
@@ -242,7 +242,7 @@ func (s *ApiServer) AuthenticateEmail(ctx context.Context, in *api.AuthenticateE

	create := in.Create == nil || in.Create.Value

	dbUserID, dbUsername, created, err := AuthenticateEmail(s.logger, s.db, cleanEmail, email.Password, username, create)
	dbUserID, dbUsername, created, err := AuthenticateEmail(ctx, s.logger, s.db, cleanEmail, email.Password, username, create)
	if err != nil {
		return nil, err
	}
@@ -260,7 +260,7 @@ func (s *ApiServer) AuthenticateEmail(ctx context.Context, in *api.AuthenticateE

		// Extract request information and execute the hook.
		clientIP, clientPort := extractClientAddress(s.logger, ctx)
		fn(s.logger, dbUserID, dbUsername, exp, clientIP, clientPort, session, in)
		fn(ctx, s.logger, dbUserID, dbUsername, exp, clientIP, clientPort, session, in)

		// Stats measurement end boundary.
		span.End()
@@ -282,7 +282,7 @@ func (s *ApiServer) AuthenticateFacebook(ctx context.Context, in *api.Authentica

		// Extract request information and execute the hook.
		clientIP, clientPort := extractClientAddress(s.logger, ctx)
		result, err, code := fn(s.logger, "", "", 0, clientIP, clientPort, in)
		result, err, code := fn(ctx, s.logger, "", "", 0, clientIP, clientPort, in)
		if err != nil {
			return nil, status.Error(code, err.Error())
		}
@@ -313,14 +313,14 @@ func (s *ApiServer) AuthenticateFacebook(ctx context.Context, in *api.Authentica

	create := in.Create == nil || in.Create.Value

	dbUserID, dbUsername, created, err := AuthenticateFacebook(s.logger, s.db, s.socialClient, in.Account.Token, username, create)
	dbUserID, dbUsername, created, err := AuthenticateFacebook(ctx, s.logger, s.db, s.socialClient, in.Account.Token, username, create)
	if err != nil {
		return nil, err
	}

	// Import friends if requested.
	if in.Import == nil || in.Import.Value {
		importFacebookFriends(s.logger, s.db, s.router, s.socialClient, uuid.FromStringOrNil(dbUserID), dbUsername, in.Account.Token, false)
		importFacebookFriends(ctx, s.logger, s.db, s.router, s.socialClient, uuid.FromStringOrNil(dbUserID), dbUsername, in.Account.Token, false)
	}

	token, exp := generateToken(s.config, dbUserID, dbUsername)
@@ -336,7 +336,7 @@ func (s *ApiServer) AuthenticateFacebook(ctx context.Context, in *api.Authentica

		// Extract request information and execute the hook.
		clientIP, clientPort := extractClientAddress(s.logger, ctx)
		fn(s.logger, dbUserID, dbUsername, exp, clientIP, clientPort, session, in)
		fn(ctx, s.logger, dbUserID, dbUsername, exp, clientIP, clientPort, session, in)

		// Stats measurement end boundary.
		span.End()
@@ -358,7 +358,7 @@ func (s *ApiServer) AuthenticateGameCenter(ctx context.Context, in *api.Authenti

		// Extract request information and execute the hook.
		clientIP, clientPort := extractClientAddress(s.logger, ctx)
		result, err, code := fn(s.logger, "", "", 0, clientIP, clientPort, in)
		result, err, code := fn(ctx, s.logger, "", "", 0, clientIP, clientPort, in)
		if err != nil {
			return nil, status.Error(code, err.Error())
		}
@@ -401,7 +401,7 @@ func (s *ApiServer) AuthenticateGameCenter(ctx context.Context, in *api.Authenti

	create := in.Create == nil || in.Create.Value

	dbUserID, dbUsername, created, err := AuthenticateGameCenter(s.logger, s.db, s.socialClient, in.Account.PlayerId, in.Account.BundleId, in.Account.TimestampSeconds, in.Account.Salt, in.Account.Signature, in.Account.PublicKeyUrl, username, create)
	dbUserID, dbUsername, created, err := AuthenticateGameCenter(ctx, s.logger, s.db, s.socialClient, in.Account.PlayerId, in.Account.BundleId, in.Account.TimestampSeconds, in.Account.Salt, in.Account.Signature, in.Account.PublicKeyUrl, username, create)
	if err != nil {
		return nil, err
	}
@@ -419,7 +419,7 @@ func (s *ApiServer) AuthenticateGameCenter(ctx context.Context, in *api.Authenti

		// Extract request information and execute the hook.
		clientIP, clientPort := extractClientAddress(s.logger, ctx)
		fn(s.logger, dbUserID, dbUsername, exp, clientIP, clientPort, session, in)
		fn(ctx, s.logger, dbUserID, dbUsername, exp, clientIP, clientPort, session, in)

		// Stats measurement end boundary.
		span.End()
@@ -441,7 +441,7 @@ func (s *ApiServer) AuthenticateGoogle(ctx context.Context, in *api.Authenticate

		// Extract request information and execute the hook.
		clientIP, clientPort := extractClientAddress(s.logger, ctx)
		result, err, code := fn(s.logger, "", "", 0, clientIP, clientPort, in)
		result, err, code := fn(ctx, s.logger, "", "", 0, clientIP, clientPort, in)
		if err != nil {
			return nil, status.Error(code, err.Error())
		}
@@ -472,7 +472,7 @@ func (s *ApiServer) AuthenticateGoogle(ctx context.Context, in *api.Authenticate

	create := in.Create == nil || in.Create.Value

	dbUserID, dbUsername, created, err := AuthenticateGoogle(s.logger, s.db, s.socialClient, in.Account.Token, username, create)
	dbUserID, dbUsername, created, err := AuthenticateGoogle(ctx, s.logger, s.db, s.socialClient, in.Account.Token, username, create)
	if err != nil {
		return nil, err
	}
@@ -490,7 +490,7 @@ func (s *ApiServer) AuthenticateGoogle(ctx context.Context, in *api.Authenticate

		// Extract request information and execute the hook.
		clientIP, clientPort := extractClientAddress(s.logger, ctx)
		fn(s.logger, dbUserID, dbUsername, exp, clientIP, clientPort, session, in)
		fn(ctx, s.logger, dbUserID, dbUsername, exp, clientIP, clientPort, session, in)

		// Stats measurement end boundary.
		span.End()
@@ -512,7 +512,7 @@ func (s *ApiServer) AuthenticateSteam(ctx context.Context, in *api.AuthenticateS

		// Extract request information and execute the hook.
		clientIP, clientPort := extractClientAddress(s.logger, ctx)
		result, err, code := fn(s.logger, "", "", 0, clientIP, clientPort, in)
		result, err, code := fn(ctx, s.logger, "", "", 0, clientIP, clientPort, in)
		if err != nil {
			return nil, status.Error(code, err.Error())
		}
@@ -547,7 +547,7 @@ func (s *ApiServer) AuthenticateSteam(ctx context.Context, in *api.AuthenticateS

	create := in.Create == nil || in.Create.Value

	dbUserID, dbUsername, created, err := AuthenticateSteam(s.logger, s.db, s.socialClient, s.config.GetSocial().Steam.AppID, s.config.GetSocial().Steam.PublisherKey, in.Account.Token, username, create)
	dbUserID, dbUsername, created, err := AuthenticateSteam(ctx, s.logger, s.db, s.socialClient, s.config.GetSocial().Steam.AppID, s.config.GetSocial().Steam.PublisherKey, in.Account.Token, username, create)
	if err != nil {
		return nil, err
	}
@@ -565,7 +565,7 @@ func (s *ApiServer) AuthenticateSteam(ctx context.Context, in *api.AuthenticateS

		// Extract request information and execute the hook.
		clientIP, clientPort := extractClientAddress(s.logger, ctx)
		fn(s.logger, dbUserID, dbUsername, exp, clientIP, clientPort, session, in)
		fn(ctx, s.logger, dbUserID, dbUsername, exp, clientIP, clientPort, session, in)

		// Stats measurement end boundary.
		span.End()
+3 −3
Original line number Diff line number Diff line
@@ -42,7 +42,7 @@ func (s *ApiServer) ListChannelMessages(ctx context.Context, in *api.ListChannel

		// Extract request information and execute the hook.
		clientIP, clientPort := extractClientAddress(s.logger, ctx)
		result, err, code := fn(s.logger, userID.String(), ctx.Value(ctxUsernameKey{}).(string), ctx.Value(ctxExpiryKey{}).(int64), clientIP, clientPort, in)
		result, err, code := fn(ctx, s.logger, userID.String(), ctx.Value(ctxUsernameKey{}).(string), ctx.Value(ctxExpiryKey{}).(int64), clientIP, clientPort, in)
		if err != nil {
			return nil, status.Error(code, err.Error())
		}
@@ -80,7 +80,7 @@ func (s *ApiServer) ListChannelMessages(ctx context.Context, in *api.ListChannel
		return nil, status.Error(codes.InvalidArgument, "Invalid channel ID.")
	}

	messageList, err := ChannelMessagesList(s.logger, s.db, userID, streamConversionResult.Stream, in.ChannelId, limit, forward, in.Cursor)
	messageList, err := ChannelMessagesList(ctx, s.logger, s.db, userID, streamConversionResult.Stream, in.ChannelId, limit, forward, in.Cursor)
	if err == ErrChannelCursorInvalid {
		return nil, status.Error(codes.InvalidArgument, "Cursor is invalid or expired.")
	} else if err == ErrChannelGroupNotFound {
@@ -99,7 +99,7 @@ func (s *ApiServer) ListChannelMessages(ctx context.Context, in *api.ListChannel

		// Extract request information and execute the hook.
		clientIP, clientPort := extractClientAddress(s.logger, ctx)
		fn(s.logger, userID.String(), ctx.Value(ctxUsernameKey{}).(string), ctx.Value(ctxExpiryKey{}).(int64), clientIP, clientPort, messageList, in)
		fn(ctx, s.logger, userID.String(), ctx.Value(ctxUsernameKey{}).(string), ctx.Value(ctxExpiryKey{}).(int64), clientIP, clientPort, messageList, in)

		// Stats measurement end boundary.
		span.End()
Loading