Unverified Commit e4d71980 authored by Peyman Narimani's avatar Peyman Narimani Committed by GitHub
Browse files

Moved const/var errors from nakama to nakama-common

Resolves #691
parent a35db395
Loading
Loading
Loading
Loading
+30 −38
Original line number Diff line number Diff line
@@ -23,6 +23,7 @@ import (
	"errors"
	"fmt"
	"github.com/heroiclabs/nakama-common/rtapi"
	"github.com/heroiclabs/nakama-common/runtime"
	"strings"
	"time"
	"unicode/utf8"
@@ -35,12 +36,6 @@ import (
	"google.golang.org/protobuf/types/known/wrapperspb"
)

var (
	ErrChannelIDInvalid     = errors.New("invalid channel id")
	ErrChannelCursorInvalid = errors.New("invalid channel cursor")
	ErrChannelGroupNotFound = errors.New("group not found")
)

// Wrapper type to avoid allocating a stream struct when the input is invalid.
type ChannelIdToStreamResult struct {
	Stream PresenceStream
@@ -62,28 +57,28 @@ func ChannelMessagesList(ctx context.Context, logger *zap.Logger, db *sql.DB, ca
	if cursor != "" {
		cb, err := base64.StdEncoding.DecodeString(cursor)
		if err != nil {
			return nil, ErrChannelCursorInvalid
			return nil, runtime.ErrChannelCursorInvalid
		}
		incomingCursor = &channelMessageListCursor{}
		if err := gob.NewDecoder(bytes.NewReader(cb)).Decode(incomingCursor); err != nil {
			return nil, ErrChannelCursorInvalid
			return nil, runtime.ErrChannelCursorInvalid
		}

		if forward != incomingCursor.Forward {
			// Cursor is for a different channel message list direction.
			return nil, ErrChannelCursorInvalid
			return nil, runtime.ErrChannelCursorInvalid
		} else if stream.Mode != incomingCursor.StreamMode {
			// Stream mode does not match.
			return nil, ErrChannelCursorInvalid
			return nil, runtime.ErrChannelCursorInvalid
		} else if stream.Subject.String() != incomingCursor.StreamSubject {
			// Stream subject does not match.
			return nil, ErrChannelCursorInvalid
			return nil, runtime.ErrChannelCursorInvalid
		} else if stream.Subcontext.String() != incomingCursor.StreamSubcontext {
			// Stream subcontext does not match.
			return nil, ErrChannelCursorInvalid
			return nil, runtime.ErrChannelCursorInvalid
		} else if stream.Label != incomingCursor.StreamLabel {
			// Stream label does not match.
			return nil, ErrChannelCursorInvalid
			return nil, runtime.ErrChannelCursorInvalid
		}
	}

@@ -94,7 +89,7 @@ func ChannelMessagesList(ctx context.Context, logger *zap.Logger, db *sql.DB, ca
			return nil, err
		}
		if !allowed {
			return nil, ErrChannelGroupNotFound
			return nil, runtime.ErrChannelGroupNotFound
		}
	}

@@ -331,12 +326,12 @@ func GetChannelMessages(ctx context.Context, logger *zap.Logger, db *sql.DB, use

func ChannelIdToStream(channelID string) (*ChannelIdToStreamResult, error) {
	if channelID == "" {
		return nil, ErrChannelIDInvalid
		return nil, runtime.ErrChannelIDInvalid
	}

	components := strings.SplitN(channelID, ".", 4)
	if len(components) != 4 {
		return nil, ErrChannelIDInvalid
		return nil, runtime.ErrChannelIDInvalid
	}

	stream := PresenceStream{
@@ -349,23 +344,23 @@ func ChannelIdToStream(channelID string) (*ChannelIdToStreamResult, error) {
		// StreamModeChannel.
		// Expect no subject or subcontext.
		if components[1] != "" || components[2] != "" {
			return nil, ErrChannelIDInvalid
			return nil, runtime.ErrChannelIDInvalid
		}
		// Label.
		if l := len(components[3]); l < 1 || l > 64 {
			return nil, ErrChannelIDInvalid
			return nil, runtime.ErrChannelIDInvalid
		}
		stream.Label = components[3]
	case "3":
		// Expect no subcontext or label.
		if components[2] != "" || components[3] != "" {
			return nil, ErrChannelIDInvalid
			return nil, runtime.ErrChannelIDInvalid
		}
		// Subject.
		var err error
		if components[1] != "" {
			if stream.Subject, err = uuid.FromString(components[1]); err != nil {
				return nil, ErrChannelIDInvalid
				return nil, runtime.ErrChannelIDInvalid
			}
		}
		// Mode.
@@ -373,25 +368,25 @@ func ChannelIdToStream(channelID string) (*ChannelIdToStreamResult, error) {
	case "4":
		// Expect lo label.
		if components[3] != "" {
			return nil, ErrChannelIDInvalid
			return nil, runtime.ErrChannelIDInvalid
		}
		// Subject.
		var err error
		if components[1] != "" {
			if stream.Subject, err = uuid.FromString(components[1]); err != nil {
				return nil, ErrChannelIDInvalid
				return nil, runtime.ErrChannelIDInvalid
			}
		}
		// Subcontext.
		if components[2] != "" {
			if stream.Subcontext, err = uuid.FromString(components[2]); err != nil {
				return nil, ErrChannelIDInvalid
				return nil, runtime.ErrChannelIDInvalid
			}
		}
		// Mode.
		stream.Mode = StreamModeDM
	default:
		return nil, ErrChannelIDInvalid
		return nil, runtime.ErrChannelIDInvalid
	}

	return &ChannelIdToStreamResult{Stream: stream}, nil
@@ -399,7 +394,7 @@ func ChannelIdToStream(channelID string) (*ChannelIdToStreamResult, error) {

func StreamToChannelId(stream PresenceStream) (string, error) {
	if stream.Mode != StreamModeChannel && stream.Mode != StreamModeGroup && stream.Mode != StreamModeDM {
		return "", ErrChannelIDInvalid
		return "", runtime.ErrChannelIDInvalid
	}

	subject := ""
@@ -414,12 +409,9 @@ func StreamToChannelId(stream PresenceStream) (string, error) {
	return fmt.Sprintf("%v.%v.%v.%v", stream.Mode, subject, subcontext, stream.Label), nil
}

var errInvalidChannelTarget = errors.New("Invalid channel target")
var errInvalidChannelType = errors.New("Invalid channel type")

func BuildChannelId(ctx context.Context, logger *zap.Logger, db *sql.DB, userID uuid.UUID, target string, chanType rtapi.ChannelJoin_Type) (string, PresenceStream, error) {
	if target == "" {
		return "", PresenceStream{}, errInvalidChannelTarget
		return "", PresenceStream{}, runtime.ErrInvalidChannelTarget
	}

	stream := PresenceStream{
@@ -432,13 +424,13 @@ func BuildChannelId(ctx context.Context, logger *zap.Logger, db *sql.DB, userID
		fallthrough
	case rtapi.ChannelJoin_ROOM:
		if len(target) < 1 || len(target) > 64 {
			return "", PresenceStream{}, fmt.Errorf("Channel name is required and must be 1-64 chars: %w", errInvalidChannelTarget)
			return "", PresenceStream{}, fmt.Errorf("Channel name is required and must be 1-64 chars: %w", runtime.ErrInvalidChannelTarget)
		}
		if controlCharsRegex.MatchString(target) {
			return "", PresenceStream{}, fmt.Errorf("Channel name must not contain control chars: %w", errInvalidChannelTarget)
			return "", PresenceStream{}, fmt.Errorf("Channel name must not contain control chars: %w", runtime.ErrInvalidChannelTarget)
		}
		if !utf8.ValidString(target) {
			return "", PresenceStream{}, fmt.Errorf("Channel name must only contain valid UTF-8 bytes: %w", errInvalidChannelTarget)
			return "", PresenceStream{}, fmt.Errorf("Channel name must only contain valid UTF-8 bytes: %w", runtime.ErrInvalidChannelTarget)
		}
		stream.Label = target
		// Channel mode is already set by default above.
@@ -446,11 +438,11 @@ func BuildChannelId(ctx context.Context, logger *zap.Logger, db *sql.DB, userID
		// Check if user ID is valid.
		uid, err := uuid.FromString(target)
		if err != nil {
			return "", PresenceStream{}, fmt.Errorf("Invalid user ID in direct message join: %w", errInvalidChannelTarget)
			return "", PresenceStream{}, fmt.Errorf("Invalid user ID in direct message join: %w", runtime.ErrInvalidChannelTarget)
		}
		// Not allowed to chat to the nil uuid.
		if uid == uuid.Nil {
			return "", PresenceStream{}, fmt.Errorf("Invalid user ID in direct message join: %w", errInvalidChannelTarget)
			return "", PresenceStream{}, fmt.Errorf("Invalid user ID in direct message join: %w", runtime.ErrInvalidChannelTarget)
		}
		// If userID is the system user, skip these checks
		if userID != uuid.Nil {
@@ -460,7 +452,7 @@ func BuildChannelId(ctx context.Context, logger *zap.Logger, db *sql.DB, userID
				return "", PresenceStream{}, errors.New("Failed to look up user ID")
			}
			if !allowed {
				return "", PresenceStream{}, fmt.Errorf("User ID not found: %w", errInvalidChannelTarget)
				return "", PresenceStream{}, fmt.Errorf("User ID not found: %w", runtime.ErrInvalidChannelTarget)
			}
			// Assign the ID pair in a consistent order.
			if uid.String() > userID.String() {
@@ -476,7 +468,7 @@ func BuildChannelId(ctx context.Context, logger *zap.Logger, db *sql.DB, userID
		// Check if group ID is valid.
		gid, err := uuid.FromString(target)
		if err != nil {
			return "", PresenceStream{}, fmt.Errorf("Invalid group ID in group channel join: %w", errInvalidChannelTarget)
			return "", PresenceStream{}, fmt.Errorf("Invalid group ID in group channel join: %w", runtime.ErrInvalidChannelTarget)
		}
		if userID != uuid.Nil {
			allowed, err := groupCheckUserPermission(ctx, logger, db, gid, userID, 2)
@@ -484,14 +476,14 @@ func BuildChannelId(ctx context.Context, logger *zap.Logger, db *sql.DB, userID
				return "", PresenceStream{}, errors.New("Failed to look up group membership")
			}
			if !allowed {
				return "", PresenceStream{}, fmt.Errorf("Group not found: %w", errInvalidChannelTarget)
				return "", PresenceStream{}, fmt.Errorf("Group not found: %w", runtime.ErrInvalidChannelTarget)
			}
		}

		stream.Subject = gid
		stream.Mode = StreamModeGroup
	default:
		return "", PresenceStream{}, errInvalidChannelType
		return "", PresenceStream{}, runtime.ErrInvalidChannelType
	}

	channelID, err := StreamToChannelId(stream)
+4 −5
Original line number Diff line number Diff line
@@ -23,6 +23,7 @@ import (
	"encoding/json"
	"errors"
	"fmt"
	"github.com/heroiclabs/nakama-common/runtime"
	"strconv"
	"time"

@@ -34,8 +35,6 @@ import (
	"google.golang.org/protobuf/types/known/wrapperspb"
)

var ErrFriendInvalidCursor = errors.New("friend cursor invalid")

type edgeListCursor struct {
	// ID fields.
	State    int64
@@ -91,16 +90,16 @@ func ListFriends(ctx context.Context, logger *zap.Logger, db *sql.DB, tracker Tr
	if cursor != "" {
		cb, err := base64.StdEncoding.DecodeString(cursor)
		if err != nil {
			return nil, ErrFriendInvalidCursor
			return nil, runtime.ErrFriendInvalidCursor
		}
		incomingCursor = &edgeListCursor{}
		if err := gob.NewDecoder(bytes.NewReader(cb)).Decode(incomingCursor); err != nil {
			return nil, ErrFriendInvalidCursor
			return nil, runtime.ErrFriendInvalidCursor
		}

		// Cursor and filter mismatch. Perhaps the caller has sent an old cursor with a changed filter.
		if state != nil && int64(state.Value) != incomingCursor.State {
			return nil, ErrFriendInvalidCursor
			return nil, runtime.ErrFriendInvalidCursor
		}
	}

+44 −57
Original line number Diff line number Diff line
@@ -23,6 +23,7 @@ import (
	"encoding/json"
	"errors"
	"fmt"
	"github.com/heroiclabs/nakama-common/runtime"
	"math"
	"strconv"
	"strings"
@@ -40,20 +41,6 @@ import (
	"google.golang.org/protobuf/types/known/wrapperspb"
)

var (
	ErrGroupNameInUse         = errors.New("group name in use")
	ErrGroupPermissionDenied  = errors.New("group permission denied")
	ErrGroupNoUpdateOps       = errors.New("no group updates")
	ErrGroupNotUpdated        = errors.New("group not updated")
	ErrGroupNotFound          = errors.New("group not found")
	ErrGroupFull              = errors.New("group is full")
	ErrGroupUserNotFound      = errors.New("user not found")
	ErrGroupLastSuperadmin    = errors.New("user is last group superadmin")
	ErrGroupUserInvalidCursor = errors.New("group user cursor invalid")
	ErrUserGroupInvalidCursor = errors.New("user group cursor invalid")
	ErrGroupCreatorInvalid    = errors.New("group creator user ID not valid")
)

type groupListCursor struct {
	Lang       string
	EdgeCount  int32
@@ -75,7 +62,7 @@ func (c *groupListCursor) GetUpdateTime() time.Time {

func CreateGroup(ctx context.Context, logger *zap.Logger, db *sql.DB, userID uuid.UUID, creatorID uuid.UUID, name, lang, desc, avatarURL, metadata string, open bool, maxCount int) (*api.Group, error) {
	if userID == uuid.Nil {
		return nil, ErrGroupCreatorInvalid
		return nil, runtime.ErrGroupCreatorInvalid
	}

	state := 1
@@ -124,7 +111,7 @@ RETURNING id, creator_id, name, description, avatar_url, state, edge_count, lang
			var pgErr *pgconn.PgError
			if errors.As(err, &pgErr) && pgErr.Code == dbErrorUniqueViolation {
				logger.Info("Could not create group as it already exists.", zap.String("name", name))
				return ErrGroupNameInUse
				return runtime.ErrGroupNameInUse
			}
			logger.Debug("Could not create group.", zap.Error(err))
			return err
@@ -136,7 +123,7 @@ RETURNING id, creator_id, name, description, avatar_url, state, edge_count, lang
			var pgErr *pgconn.PgError
			if errors.As(err, &pgErr) && pgErr.Code == dbErrorUniqueViolation {
				logger.Info("Could not create group as it already exists.", zap.String("name", name))
				return ErrGroupNameInUse
				return runtime.ErrGroupNameInUse
			}
			logger.Debug("Could not parse rows.", zap.Error(err))
			return err
@@ -151,8 +138,8 @@ RETURNING id, creator_id, name, description, avatar_url, state, edge_count, lang

		return nil
	}); err != nil {
		if err == ErrGroupNameInUse {
			return nil, ErrGroupNameInUse
		if err == runtime.ErrGroupNameInUse {
			return nil, runtime.ErrGroupNameInUse
		}
		logger.Error("Error creating group.", zap.Error(err))
		return nil, err
@@ -172,7 +159,7 @@ func UpdateGroup(ctx context.Context, logger *zap.Logger, db *sql.DB, groupID uu

		if !allowedUser {
			logger.Info("User does not have permission to update group.", zap.String("group", groupID.String()), zap.String("user", userID.String()))
			return ErrGroupPermissionDenied
			return runtime.ErrGroupPermissionDenied
		}
	}

@@ -241,7 +228,7 @@ func UpdateGroup(ctx context.Context, logger *zap.Logger, db *sql.DB, groupID uu

	if len(statements) == 0 {
		logger.Info("Did not update group as no fields were changed.")
		return ErrGroupNoUpdateOps
		return runtime.ErrGroupNoUpdateOps
	}

	query := "UPDATE groups SET update_time = now(), " + strings.Join(statements, ", ") + " WHERE (id = $1) AND (disable_time = '1970-01-01 00:00:00 UTC')"
@@ -250,7 +237,7 @@ func UpdateGroup(ctx context.Context, logger *zap.Logger, db *sql.DB, groupID uu
		var pgErr *pgconn.PgError
		if errors.As(err, &pgErr) && pgErr.Code == dbErrorUniqueViolation {
			logger.Info("Could not update group as it already exists.", zap.String("group_id", groupID.String()))
			return ErrGroupNameInUse
			return runtime.ErrGroupNameInUse
		}
		logger.Error("Could not update group.", zap.Error(err))
		return err
@@ -262,7 +249,7 @@ func UpdateGroup(ctx context.Context, logger *zap.Logger, db *sql.DB, groupID uu
		return err
	}
	if rowsAffected == 0 {
		return ErrGroupNotUpdated
		return runtime.ErrGroupNotUpdated
	}

	logger.Info("Group updated.", zap.String("group_id", groupID.String()), zap.String("user_id", userID.String()))
@@ -280,7 +267,7 @@ func DeleteGroup(ctx context.Context, logger *zap.Logger, db *sql.DB, groupID uu

		if !allowedUser {
			logger.Info("User does not have permission to delete group.", zap.String("group_id", groupID.String()), zap.String("user_id", userID.String()))
			return ErrGroupPermissionDenied
			return runtime.ErrGroupPermissionDenied
		}
	}

@@ -322,13 +309,13 @@ WHERE (id = $1) AND (disable_time = '1970-01-01 00:00:00 UTC')`

	if len(groups) == 0 {
		logger.Info("Group does not exist.", zap.Error(err), zap.String("group_id", groupID.String()))
		return ErrGroupNotFound
		return runtime.ErrGroupNotFound
	}

	group := groups[0]
	if group.EdgeCount >= group.MaxCount {
		logger.Info("Group maximum count has reached.", zap.Error(err), zap.String("group_id", groupID.String()))
		return ErrGroupFull
		return runtime.ErrGroupFull
	}

	state := 2
@@ -492,7 +479,7 @@ func LeaveGroup(ctx context.Context, logger *zap.Logger, db *sql.DB, router Mess

		if otherSuperadminCount.Int64 == 0 {
			logger.Info("Cannot leave group as user is last superadmin.", zap.String("group_id", groupID.String()), zap.String("user_id", userID.String()))
			return ErrGroupLastSuperadmin
			return runtime.ErrGroupLastSuperadmin
		}
	}

@@ -551,7 +538,7 @@ func LeaveGroup(ctx context.Context, logger *zap.Logger, db *sql.DB, router Mess

			if rowsAffected == 0 {
				logger.Debug("Did not update group edge_count as group is disabled.")
				return ErrGroupNotFound
				return runtime.ErrGroupNotFound
			}
		}

@@ -581,7 +568,7 @@ func AddGroupUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, router M
		if err := db.QueryRowContext(ctx, query, groupID, caller).Scan(&dbState); err != nil {
			if err == sql.ErrNoRows {
				logger.Info("Could not retrieve state as no group relationship exists.", zap.String("group_id", groupID.String()), zap.String("user_id", caller.String()))
				return ErrGroupPermissionDenied
				return runtime.ErrGroupPermissionDenied
			}
			logger.Error("Could not retrieve state from group_edge.", zap.Error(err), zap.String("group_id", groupID.String()), zap.String("user_id", caller.String()))
			return err
@@ -589,7 +576,7 @@ func AddGroupUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, router M

		if dbState.Int64 > 1 {
			logger.Info("Cannot add users as user does not have correct permissions.", zap.String("group_id", groupID.String()), zap.String("user_id", caller.String()), zap.Int64("state", dbState.Int64))
			return ErrGroupPermissionDenied
			return runtime.ErrGroupPermissionDenied
		}
	}

@@ -598,7 +585,7 @@ func AddGroupUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, router M
	if err := db.QueryRowContext(ctx, query, groupID).Scan(&groupName); err != nil {
		if err == sql.ErrNoRows {
			logger.Info("Cannot add users to disabled group.", zap.String("group_id", groupID.String()))
			return ErrGroupNotFound
			return runtime.ErrGroupNotFound
		}
		logger.Error("Could not look up group when adding users.", zap.Error(err), zap.String("group_id", groupID.String()))
		return err
@@ -648,7 +635,7 @@ func AddGroupUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, router M
			query := "SELECT username FROM users WHERE id = $1::UUID"
			if err := tx.QueryRowContext(ctx, query, uid).Scan(&username); err != nil {
				if err == sql.ErrNoRows {
					return ErrGroupUserNotFound
					return runtime.ErrGroupUserNotFound
				}
				logger.Debug("Could not retrieve username to add user to group.", zap.Error(err), zap.String("group_id", groupID.String()), zap.String("user_id", uid.String()))
				return err
@@ -692,7 +679,7 @@ func AddGroupUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, router M
					return err
				} else if rowsAffected == 0 {
					logger.Info("Could not add users as group maximum count was reached.", zap.String("group_id", groupID.String()), zap.String("user_id", uid.String()))
					return ErrGroupFull
					return runtime.ErrGroupFull
				}
			} else {
				// If we reach here then this was a repeated (or failed, if the user was banned) operation.
@@ -759,7 +746,7 @@ func BanGroupUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, router M
		if err := db.QueryRowContext(ctx, query, groupID, caller).Scan(&dbState); err != nil {
			if err == sql.ErrNoRows {
				logger.Info("Could not retrieve state as no group relationship exists.", zap.String("group_id", groupID.String()), zap.String("user_id", caller.String()))
				return ErrGroupPermissionDenied
				return runtime.ErrGroupPermissionDenied
			}
			logger.Error("Could not retrieve state from group_edge.", zap.Error(err), zap.String("group_id", groupID.String()), zap.String("user_id", caller.String()))
			return err
@@ -768,7 +755,7 @@ func BanGroupUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, router M
		myState = int(dbState.Int64)
		if myState > 1 {
			logger.Info("Cannot ban users as user does not have correct permissions.", zap.String("group_id", groupID.String()), zap.String("user_id", caller.String()), zap.Int("state", myState))
			return ErrGroupPermissionDenied
			return runtime.ErrGroupPermissionDenied
		}
	}

@@ -875,7 +862,7 @@ UPDATE SET state = $2, update_time = now()`
				query = "SELECT username FROM users WHERE id = $1::UUID"
				if err := tx.QueryRowContext(ctx, query, uid).Scan(&username); err != nil {
					if err == sql.ErrNoRows {
						return ErrGroupUserNotFound
						return runtime.ErrGroupUserNotFound
					}
					logger.Debug("Could not retrieve username to ban user from group.", zap.Error(err), zap.String("group_id", groupID.String()), zap.String("user_id", uid.String()))
					return err
@@ -925,7 +912,7 @@ func KickGroupUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, router
		if err := db.QueryRowContext(ctx, query, groupID, caller).Scan(&dbState); err != nil {
			if err == sql.ErrNoRows {
				logger.Info("Could not retrieve state as no group relationship exists.", zap.String("group_id", groupID.String()), zap.String("user_id", caller.String()))
				return ErrGroupPermissionDenied
				return runtime.ErrGroupPermissionDenied
			}
			logger.Error("Could not retrieve state from group_edge.", zap.Error(err), zap.String("group_id", groupID.String()), zap.String("user_id", caller.String()))
			return err
@@ -934,7 +921,7 @@ func KickGroupUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, router
		myState = int(dbState.Int64)
		if myState > 1 {
			logger.Info("Cannot kick users as user does not have correct permissions.", zap.String("group_id", groupID.String()), zap.String("user_id", caller.String()), zap.Int("state", myState))
			return ErrGroupPermissionDenied
			return runtime.ErrGroupPermissionDenied
		}
	}

@@ -1030,7 +1017,7 @@ RETURNING state`
				query = "SELECT username FROM users WHERE id = $1::UUID"
				if err := tx.QueryRowContext(ctx, query, uid).Scan(&username); err != nil {
					if err == sql.ErrNoRows {
						return ErrGroupUserNotFound
						return runtime.ErrGroupUserNotFound
					}
					logger.Debug("Could not retrieve username to kick user from group.", zap.Error(err), zap.String("group_id", groupID.String()), zap.String("user_id", uid.String()))
					return err
@@ -1080,7 +1067,7 @@ func PromoteGroupUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, rout
		if err := db.QueryRowContext(ctx, query, groupID, caller).Scan(&dbState); err != nil {
			if err == sql.ErrNoRows {
				logger.Info("Could not retrieve state as no group relationship exists.", zap.String("group_id", groupID.String()), zap.String("user_id", caller.String()))
				return ErrGroupPermissionDenied
				return runtime.ErrGroupPermissionDenied
			}
			logger.Error("Could not retrieve state from group_edge.", zap.Error(err), zap.String("group_id", groupID.String()), zap.String("user_id", caller.String()))
			return err
@@ -1089,7 +1076,7 @@ func PromoteGroupUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, rout
		myState = int(dbState.Int64)
		if myState > 1 {
			logger.Info("Cannot promote users as user does not have correct permissions.", zap.String("group_id", groupID.String()), zap.String("user_id", caller.String()), zap.Int("state", myState))
			return ErrGroupPermissionDenied
			return runtime.ErrGroupPermissionDenied
		}
	}

@@ -1102,7 +1089,7 @@ func PromoteGroupUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, rout
	}
	if !groupExists.Bool {
		logger.Info("Cannot promote users to disabled group.", zap.String("group_id", groupID.String()))
		return ErrGroupNotFound
		return runtime.ErrGroupNotFound
	}

	// Prepare the messages we'll need to send to the group channel.
@@ -1163,7 +1150,7 @@ RETURNING state`
					return err
				} else if rowsAffected == 0 {
					logger.Debug("Did not update group edge count - check edge count has not reached max count.", zap.String("group_id", groupID.String()), zap.String("user_id", uid.String()))
					return ErrGroupFull
					return runtime.ErrGroupFull
				}
			}

@@ -1172,7 +1159,7 @@ RETURNING state`
			query = "SELECT username FROM users WHERE id = $1::UUID"
			if err := tx.QueryRowContext(ctx, query, uid).Scan(&username); err != nil {
				if err == sql.ErrNoRows {
					return ErrGroupUserNotFound
					return runtime.ErrGroupUserNotFound
				}
				logger.Debug("Could not retrieve username to promote user in group.", zap.Error(err), zap.String("group_id", groupID.String()), zap.String("user_id", uid.String()))
				return err
@@ -1221,7 +1208,7 @@ func DemoteGroupUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, route
		if err := db.QueryRowContext(ctx, query, groupID, caller).Scan(&dbState); err != nil {
			if err == sql.ErrNoRows {
				logger.Info("Could not retrieve state as no group relationship exists.", zap.String("group_id", groupID.String()), zap.String("user_id", caller.String()))
				return ErrGroupPermissionDenied
				return runtime.ErrGroupPermissionDenied
			}
			logger.Error("Could not retrieve state from group_edge.", zap.Error(err), zap.String("group_id", groupID.String()), zap.String("user_id", caller.String()))
			return err
@@ -1230,7 +1217,7 @@ func DemoteGroupUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, route
		myState = int(dbState.Int64)
		if myState > 1 {
			logger.Info("Cannot demote users as user does not have correct permissions.", zap.String("group_id", groupID.String()), zap.String("user_id", caller.String()), zap.Int("state", myState))
			return ErrGroupPermissionDenied
			return runtime.ErrGroupPermissionDenied
		}
	}

@@ -1243,7 +1230,7 @@ func DemoteGroupUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, route
	}
	if !groupExists.Bool {
		logger.Info("Cannot demote users in a disabled group.", zap.String("group_id", groupID.String()))
		return ErrGroupNotFound
		return runtime.ErrGroupNotFound
	}

	// Prepare the messages we'll need to send to the group channel.
@@ -1318,7 +1305,7 @@ RETURNING state`
			query = "SELECT username FROM users WHERE id = $1::UUID"
			if err := tx.QueryRowContext(ctx, query, uid).Scan(&username); err != nil {
				if err == sql.ErrNoRows {
					return ErrGroupUserNotFound
					return runtime.ErrGroupUserNotFound
				}
				logger.Debug("Could not retrieve username to demote user in group.", zap.Error(err), zap.String("group_id", groupID.String()), zap.String("user_id", uid.String()))
				return err
@@ -1364,16 +1351,16 @@ func ListGroupUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, tracker
	if cursor != "" {
		cb, err := base64.StdEncoding.DecodeString(cursor)
		if err != nil {
			return nil, ErrGroupUserInvalidCursor
			return nil, runtime.ErrGroupUserInvalidCursor
		}
		incomingCursor = &edgeListCursor{}
		if err := gob.NewDecoder(bytes.NewReader(cb)).Decode(incomingCursor); err != nil {
			return nil, ErrGroupUserInvalidCursor
			return nil, runtime.ErrGroupUserInvalidCursor
		}

		// Cursor and filter mismatch. Perhaps the caller has sent an old cursor with a changed filter.
		if state != nil && int64(state.Value) != incomingCursor.State {
			return nil, ErrGroupUserInvalidCursor
			return nil, runtime.ErrGroupUserInvalidCursor
		}
	}

@@ -1411,7 +1398,7 @@ WHERE u.id = ge.destination_id AND ge.source_id = $1`
	rows, err := db.QueryContext(ctx, query, params...)
	if err != nil {
		if err == sql.ErrNoRows {
			return nil, ErrGroupNotFound
			return nil, runtime.ErrGroupNotFound
		}

		logger.Debug("Could not list users in group.", zap.Error(err), zap.String("group_id", groupID.String()))
@@ -1446,7 +1433,7 @@ WHERE u.id = ge.destination_id AND ge.source_id = $1`
		if err := rows.Scan(&id, &username, &displayName, &avatarURL, &langTag, &location, &timezone, &metadata,
			&apple, &facebook, &facebookInstantGame, &google, &gamecenter, &steam, &edgeCount, &createTime, &updateTime, &state, &position); err != nil {
			if err == sql.ErrNoRows {
				return nil, ErrGroupNotFound
				return nil, runtime.ErrGroupNotFound
			}
			logger.Error("Could not parse rows when listing users in a group.", zap.Error(err), zap.String("group_id", groupID.String()))
			return nil, err
@@ -1502,16 +1489,16 @@ func ListUserGroups(ctx context.Context, logger *zap.Logger, db *sql.DB, userID
	if cursor != "" {
		cb, err := base64.StdEncoding.DecodeString(cursor)
		if err != nil {
			return nil, ErrUserGroupInvalidCursor
			return nil, runtime.ErrUserGroupInvalidCursor
		}
		incomingCursor = &edgeListCursor{}
		if err := gob.NewDecoder(bytes.NewReader(cb)).Decode(incomingCursor); err != nil {
			return nil, ErrUserGroupInvalidCursor
			return nil, runtime.ErrUserGroupInvalidCursor
		}

		// Cursor and filter mismatch. Perhaps the caller has sent an old cursor with a changed filter.
		if state != nil && int64(state.Value) != incomingCursor.State {
			return nil, ErrUserGroupInvalidCursor
			return nil, runtime.ErrUserGroupInvalidCursor
		}
	}

@@ -1548,7 +1535,7 @@ WHERE g.id = ge.destination_id AND ge.source_id = $1`
	rows, err := db.QueryContext(ctx, query, params...)
	if err != nil {
		if err == sql.ErrNoRows {
			return nil, ErrGroupNotFound
			return nil, runtime.ErrGroupNotFound
		}
		logger.Debug("Could not list groups for a user.", zap.Error(err), zap.String("user_id", userID.String()))
		return nil, err
+7 −11

File changed.

Preview size limit exceeded, changes collapsed.

+19 −28

File changed.

Preview size limit exceeded, changes collapsed.

Loading