Commit c66891df authored by Andrei Mihu's avatar Andrei Mihu
Browse files

Include groups in GDPR handling.

parent 4ab4a0a1
Loading
Loading
Loading
Loading
+5 −1
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@ package server

import (
	"github.com/heroiclabs/nakama/api"
	"github.com/satori/go.uuid"
	"golang.org/x/net/context"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
@@ -44,9 +45,12 @@ func (s *ApiServer) ListChannelMessages(ctx context.Context, in *api.ListChannel
		return nil, status.Error(codes.InvalidArgument, "Invalid channel ID.")
	}

	messageList, err := ChannelMessagesList(s.logger, s.db, streamConversionResult.Stream, in.ChannelId, limit, forward, in.Cursor)
	userID := ctx.Value(ctxUserIDKey{}).(uuid.UUID)
	messageList, err := ChannelMessagesList(s.logger, s.db, userID, streamConversionResult.Stream, in.ChannelId, limit, forward, in.Cursor)
	if err == ErrChannelCursorInvalid {
		return nil, status.Error(codes.InvalidArgument, "Cursor is invalid or expired.")
	} else if err == ErrChannelGroupNotFound {
		return nil, status.Error(codes.InvalidArgument, "Group not found.")
	} else if err != nil {
		return nil, status.Error(codes.Internal, "Error listing messages from channel.")
	}
+38 −37
Original line number Diff line number Diff line
@@ -32,11 +32,10 @@ func (s *ApiServer) CreateGroup(ctx context.Context, in *api.CreateGroupRequest)

	group, err := CreateGroup(s.logger, s.db, userID, userID, in.GetName(), in.GetLangTag(), in.GetDescription(), in.GetAvatarUrl(), "", in.GetOpen(), -1)
	if err != nil {
		return nil, status.Error(codes.Internal, "Error while trying to create group.")
		if err == ErrGroupNameInUse {
			return nil, status.Error(codes.InvalidArgument, "Group name is in use.")
		}

	if group == nil {
		return nil, status.Error(codes.InvalidArgument, "Did not create group as a group already exists with the same name.")
		return nil, status.Error(codes.Internal, "Error while trying to create group.")
	}

	return group, nil
@@ -65,13 +64,16 @@ func (s *ApiServer) UpdateGroup(ctx context.Context, in *api.UpdateGroupRequest)
	}

	userID := ctx.Value(ctxUserIDKey{}).(uuid.UUID)
	updated, err := UpdateGroup(s.logger, s.db, groupID, userID, nil, in.GetName(), in.GetLangTag(), in.GetDescription(), in.GetAvatarUrl(), nil, in.GetOpen(), -1)
	err = UpdateGroup(s.logger, s.db, groupID, userID, nil, in.GetName(), in.GetLangTag(), in.GetDescription(), in.GetAvatarUrl(), nil, in.GetOpen(), -1)
	if err != nil {
		return nil, status.Error(codes.Internal, "Error while trying to update group.")
		if err == ErrGroupPermissionDenied {
			return nil, status.Error(codes.NotFound, "Group not found or you're not allowed to update.")
		} else if err == ErrGroupNoUpdateOps {
			return nil, status.Error(codes.InvalidArgument, "Specify at least one field to update.")
		} else if err == ErrGroupNotUpdated {
			return nil, status.Error(codes.InvalidArgument, "No new fields in group update.")
		}

	if !updated {
		return nil, status.Error(codes.InvalidArgument, "Did not update group - Make sure that group exists, group name is unique and you have the correct permissions.")
		return nil, status.Error(codes.Internal, "Error while trying to update group.")
	}

	return &empty.Empty{}, nil
@@ -88,13 +90,12 @@ func (s *ApiServer) DeleteGroup(ctx context.Context, in *api.DeleteGroupRequest)
	}

	userID := ctx.Value(ctxUserIDKey{}).(uuid.UUID)
	deleted, err := DeleteGroup(s.logger, s.db, groupID, userID)
	err = DeleteGroup(s.logger, s.db, groupID, userID)
	if err != nil {
		return nil, status.Error(codes.Internal, "Error while trying to delete group.")
		if err == ErrGroupPermissionDenied {
			return nil, status.Error(codes.InvalidArgument, "Group not found or you're not allowed to delete.")
		}

	if !deleted {
		return nil, status.Error(codes.InvalidArgument, "Did not delete group - Make sure that group exists and you have the correct permissions.")
		return nil, status.Error(codes.Internal, "Error while trying to delete group.")
	}

	return &empty.Empty{}, nil
@@ -112,14 +113,14 @@ func (s *ApiServer) JoinGroup(ctx context.Context, in *api.JoinGroupRequest) (*e

	userID := ctx.Value(ctxUserIDKey{}).(uuid.UUID)

	joined, err := JoinGroup(s.logger, s.db, groupID, userID)

	err = JoinGroup(s.logger, s.db, groupID, userID)
	if err != nil {
		return nil, status.Error(codes.Internal, "Error while trying to join group.")
		if err == ErrGroupNotFound {
			return nil, status.Error(codes.NotFound, "Group not found.")
		} else if err == ErrGroupFull {
			return nil, status.Error(codes.InvalidArgument, "Group is full.")
		}

	if !joined {
		return nil, status.Error(codes.InvalidArgument, "Did not join group - Make sure that group exists and maximum count has not been reached.")
		return nil, status.Error(codes.Internal, "Error while trying to join group.")
	}

	return &empty.Empty{}, nil
@@ -136,14 +137,12 @@ func (s *ApiServer) LeaveGroup(ctx context.Context, in *api.LeaveGroupRequest) (
	}

	userID := ctx.Value(ctxUserIDKey{}).(uuid.UUID)
	left, err := LeaveGroup(s.logger, s.db, groupID, userID)

	err = LeaveGroup(s.logger, s.db, groupID, userID)
	if err != nil {
		return nil, status.Error(codes.Internal, "Error while trying to leave group.")
		if err == ErrGroupLastSuperadmin {
			return nil, status.Error(codes.InvalidArgument, "Cannot leave group when you are the last superadmin.")
		}

	if !left {
		return nil, status.Error(codes.InvalidArgument, "Did not leave group - Make sure that group exists and you have the correct permissions.")
		return nil, status.Error(codes.Internal, "Error while trying to leave group.")
	}

	return &empty.Empty{}, nil
@@ -173,14 +172,14 @@ func (s *ApiServer) AddGroupUsers(ctx context.Context, in *api.AddGroupUsersRequ
	}

	userID := ctx.Value(ctxUserIDKey{}).(uuid.UUID)
	done, err := AddGroupUsers(s.logger, s.db, userID, groupID, userIDs)

	err = AddGroupUsers(s.logger, s.db, userID, groupID, userIDs)
	if err != nil {
		return nil, status.Error(codes.Internal, "Error while trying to add users to a group.")
		if err == ErrGroupPermissionDenied {
			return nil, status.Error(codes.NotFound, "Group not found or permission denied.")
		} else if err == ErrGroupFull {
			return nil, status.Error(codes.InvalidArgument, "Group is full.")
		}

	if !done {
		return nil, status.Error(codes.InvalidArgument, "Did not add users to group - Make sure that group exists, you have correct permissions, and maximum member count is not reached.")
		return nil, status.Error(codes.Internal, "Error while trying to add users to a group.")
	}

	return &empty.Empty{}, nil
@@ -211,6 +210,9 @@ func (s *ApiServer) KickGroupUsers(ctx context.Context, in *api.KickGroupUsersRe

	userID := ctx.Value(ctxUserIDKey{}).(uuid.UUID)
	if err = KickGroupUsers(s.logger, s.db, userID, groupID, userIDs); err != nil {
		if err == ErrGroupPermissionDenied {
			return nil, status.Error(codes.NotFound, "Group not found or permission denied.")
		}
		return nil, status.Error(codes.Internal, "Error while trying to kick users from a group.")
	}

@@ -241,13 +243,12 @@ func (s *ApiServer) PromoteGroupUsers(ctx context.Context, in *api.PromoteGroupU
	}

	userID := ctx.Value(ctxUserIDKey{}).(uuid.UUID)
	promoted, err := PromoteGroupUsers(s.logger, s.db, userID, groupID, userIDs)
	err = PromoteGroupUsers(s.logger, s.db, userID, groupID, userIDs)
	if err != nil {
		return nil, status.Error(codes.Internal, "Error while trying to promote users in a group.")
		if err == ErrGroupPermissionDenied {
			return nil, status.Error(codes.NotFound, "Group not found or permission denied.")
		}

	if !promoted {
		return nil, status.Error(codes.InvalidArgument, "Did not promote users to group - Make sure that group exists and you have correct permissions.")
		return nil, status.Error(codes.Internal, "Error while trying to promote users in a group.")
	}

	return &empty.Empty{}, nil
+44 −12
Original line number Diff line number Diff line
@@ -18,8 +18,11 @@ import (
	"context"

	"encoding/json"

	"github.com/cockroachdb/cockroach-go/crdb"
	"github.com/golang/protobuf/ptypes/empty"
	"github.com/golang/protobuf/ptypes/timestamp"
	"github.com/heroiclabs/nakama/api"
	"github.com/heroiclabs/nakama/console"
	"github.com/satori/go.uuid"
	"go.uber.org/zap"
@@ -33,22 +36,41 @@ func (s *ConsoleServer) DeleteAccount(ctx context.Context, in *console.AccountId
		return nil, status.Error(codes.InvalidArgument, "Invalid user ID was provided.")
	}

	count, err := DeleteUser(s.db, userID)
	tx, err := s.db.Begin()
	if err != nil {
		s.logger.Error("Could not delete user", zap.Error(err), zap.String("user_id", in.Id))
		s.logger.Error("Could not begin database transaction.", zap.Error(err))
		return nil, status.Error(codes.Internal, "An error occurred while trying to delete the user.")
	}

	if err := crdb.ExecuteInTx(context.Background(), tx, func() error {
		count, err := DeleteUser(tx, userID)
		if err != nil {
			s.logger.Debug("Could not delete user", zap.Error(err), zap.String("user_id", in.Id))
			return err
		} else if count == 0 {
			s.logger.Info("No user was found to delete. Skipping blacklist.", zap.String("user_id", in.Id))
		return &empty.Empty{}, nil
			return nil
		}

	err = LeaderboardRecordsDeleteAll(s.logger, s.db, userID)
		err = LeaderboardRecordsDeleteAll(s.logger, tx, userID)
		if err != nil {
		return nil, status.Error(codes.Internal, "An error occurred while trying to delete the user.")
			s.logger.Debug("Could not delete leaderboard records.", zap.Error(err), zap.String("user_id", in.Id))
			return err
		}

	if _, err = s.db.Exec(`INSERT INTO user_tombstone (user_id) VALUES ($1) ON CONFLICT(user_id) DO NOTHING`, userID); err != nil {
		s.logger.Error("Could not insert user ID into tombstone", zap.Error(err), zap.String("user_id", in.Id))
		err = GroupDeleteAll(s.logger, tx, userID)
		if err != nil {
			s.logger.Debug("Could not delete groups and relationships.", zap.Error(err), zap.String("user_id", in.Id))
			return err
		}

		if _, err = tx.Exec(`INSERT INTO user_tombstone (user_id) VALUES ($1) ON CONFLICT(user_id) DO NOTHING`, userID); err != nil {
			s.logger.Debug("Could not insert user ID into tombstone", zap.Error(err), zap.String("user_id", in.Id))
			return err
		}
		return nil
	}); err != nil {
		s.logger.Error("Error occurred while trying to delete the user.", zap.Error(err), zap.String("user_id", in.Id))
		return nil, status.Error(codes.Internal, "An error occurred while trying to delete the user.")
	}

@@ -89,6 +111,16 @@ func (s *ConsoleServer) ExportAccount(ctx context.Context, in *console.AccountId
		return nil, status.Error(codes.Internal, "An error occurred while trying to export user data.")
	}

	groups := make([]*api.Group, 0)
	groupUsers, err := ListUserGroups(s.logger, s.db, userID)
	if err != nil {
		s.logger.Error("Could not fetch groups that belong to the user", zap.Error(err), zap.String("user_id", in.Id))
		return nil, status.Error(codes.Internal, "An error occurred while trying to export user data.")
	}
	for _, g := range groupUsers.UserGroups {
		groups = append(groups, g.Group)
	}

	// Notifications.
	notifications, err := NotificationList(s.logger, s.db, userID, 0, "", nil)
	if err != nil {
@@ -131,12 +163,12 @@ func (s *ConsoleServer) ExportAccount(ctx context.Context, in *console.AccountId
		}
	}

	// TODO(mo, zyro) add groups
	export := &console.AccountExport{
		Account:            account,
		Objects:            storageObjects,
		Friends:            friends.GetFriends(),
		Messages:           messages,
		Groups:             groups,
		LeaderboardRecords: leaderboardRecords,
		Notifications:      notifications.GetNotifications(),
		WalletLedgers:      wl,
+55 −23
Original line number Diff line number Diff line
@@ -34,6 +34,7 @@ import (
var (
	ErrChannelIdInvalid     = errors.New("invalid channel id")
	ErrChannelCursorInvalid = errors.New("invalid channel cursor")
	ErrChannelGroupNotFound = errors.New("group not found")
)

// Wrapper type to avoid allocating a stream struct when the input is invalid.
@@ -52,7 +53,7 @@ type channelMessageListCursor struct {
	IsNext           bool
}

func ChannelMessagesList(logger *zap.Logger, db *sql.DB, stream PresenceStream, channelId string, limit int, forward bool, cursor string) (*api.ChannelMessageList, error) {
func ChannelMessagesList(logger *zap.Logger, db *sql.DB, caller uuid.UUID, stream PresenceStream, channelId string, limit int, forward bool, cursor string) (*api.ChannelMessageList, error) {
	var incomingCursor *channelMessageListCursor
	if cursor != "" {
		if cb, err := base64.StdEncoding.DecodeString(cursor); err != nil {
@@ -82,6 +83,17 @@ func ChannelMessagesList(logger *zap.Logger, db *sql.DB, stream PresenceStream,
		}
	}

	// If it's a group, check membership.
	if !uuid.Equal(uuid.Nil, caller) && stream.Mode == StreamModeGroup {
		allowed, err := groupCheckUserPermission(logger, db, stream.Subject, caller, 2)
		if err != nil {
			return nil, err
		}
		if !allowed {
			return nil, ErrChannelGroupNotFound
		}
	}

	query := `SELECT id, code, sender_id, username, content, create_time, update_time FROM message
WHERE stream_mode = $1 AND stream_subject = $2::UUID AND stream_descriptor = $3::UUID AND stream_label = $4`
	if incomingCursor == nil {
@@ -263,7 +275,7 @@ func ChannelIdToStream(channelId string) (*ChannelIdToStreamResult, error) {
		return nil, ErrChannelIdInvalid
	}

	components := strings.SplitN(channelId, ":", 4)
	components := strings.SplitN(channelId, ".", 4)
	if len(components) != 4 {
		return nil, ErrChannelIdInvalid
	}
@@ -275,33 +287,53 @@ func ChannelIdToStream(channelId string) (*ChannelIdToStreamResult, error) {
	// Parse and assign mode.
	switch components[0] {
	case "2":
		// StreamModeChannel
		// StreamModeChannel.
		// Expect no subject or descriptor.
		if components[1] != "" || components[2] != "" {
			return nil, ErrChannelIdInvalid
		}
		// Label.
		if l := len(components[3]); l < 1 || l > 64 {
			return nil, ErrChannelIdInvalid
		}
		stream.Label = components[3]
	case "3":
		// Expect no descriptor or label.
		if components[2] != "" || components[3] != "" {
			return nil, ErrChannelIdInvalid
		}
		// Subject.
		var err error
		if components[1] != "" {
			if stream.Subject, err = uuid.FromString(components[1]); err != nil {
				return nil, ErrChannelIdInvalid
			}
		}
		// Mode.
		stream.Mode = StreamModeGroup
	case "4":
		stream.Mode = StreamModeDM
	default:
		// Expect lo label.
		if components[3] != "" {
			return nil, ErrChannelIdInvalid
		}

	var err error

		// Subject.
		var err error
		if components[1] != "" {
			if stream.Subject, err = uuid.FromString(components[1]); err != nil {
				return nil, ErrChannelIdInvalid
			}
		}

		// Descriptor.
		if components[2] != "" {
			if stream.Descriptor, err = uuid.FromString(components[2]); err != nil {
				return nil, ErrChannelIdInvalid
			}
		}

	// Label.
	stream.Label = components[3]
		// Mode.
		stream.Mode = StreamModeDM
	default:
		return nil, ErrChannelIdInvalid
	}

	return &ChannelIdToStreamResult{Stream: stream}, nil
}
@@ -320,5 +352,5 @@ func StreamToChannelId(stream PresenceStream) (string, error) {
		descriptor = stream.Descriptor.String()
	}

	return fmt.Sprintf("%v:%v:%v:%v", stream.Mode, subject, descriptor, stream.Label), nil
	return fmt.Sprintf("%v.%v.%v.%v", stream.Mode, subject, descriptor, stream.Label), nil
}
+200 −92

File changed.

Preview size limit exceeded, changes collapsed.

Loading