Unverified Commit 211f6d15 authored by Fernando Takagi's avatar Fernando Takagi Committed by GitHub
Browse files

Move queries with variable number of args to a fixed number of args syntax (#1114)

WHERE x in ($1,$2,$3, ...) is replaced with WHERE x = ANY($1::type_of_x[]), with all values passed as a single array argument.

INSERT INTO t (a,b,c) VALUES ($1,$2,$3), ($4,$5,$6), ... is replaced with INSERT INTO .. SELECT unnest()
parent 2ee6c926
Loading
Loading
Loading
Loading
+4 −10
Original line number Diff line number Diff line
@@ -5,8 +5,6 @@ import (
	"database/sql"
	"net/url"
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/gofrs/uuid/v5"
@@ -66,20 +64,16 @@ func (s *ConsoleServer) DeleteChannelMessages(ctx context.Context, in *console.D
		s.logger.Info("Messages deleted.", zap.Int64("affected", affected), zap.String("timestamp", deleteBefore.String()))
	}
	if len(in.Ids) > 0 {
		params := make([]interface{}, 0, len(in.Ids))
		statements := make([]string, len(in.Ids))
		for i, id := range in.Ids {
			idStr, err := uuid.FromString(id)
		for _, id := range in.Ids {
			_, err := uuid.FromString(id)
			if err != nil {
				return nil, status.Error(codes.InvalidArgument, "Requires a valid message ID.")
			}
			params = append(params, idStr)
			statements[i] = "$" + strconv.Itoa(i+1)
		}
		query := "DELETE FROM message WHERE id IN (" + strings.Join(statements, ",") + ")"
		query := "DELETE FROM message WHERE id = ANY($1)"
		var res sql.Result
		var err error
		if res, err = s.db.ExecContext(ctx, query, params...); err != nil {
		if res, err = s.db.ExecContext(ctx, query, in.Ids); err != nil {
			s.logger.Error("Could not delete messages.", zap.Error(err))
			return nil, status.Error(codes.Internal, "An error occurred while trying to delete messages.")
		}
+2 −9
Original line number Diff line number Diff line
@@ -139,20 +139,13 @@ WHERE u.id = $1`
}

func GetAccounts(ctx context.Context, logger *zap.Logger, db *sql.DB, statusRegistry *StatusRegistry, userIDs []string) ([]*api.Account, error) {
	statements := make([]string, 0, len(userIDs))
	parameters := make([]interface{}, 0, len(userIDs))
	for _, userID := range userIDs {
		parameters = append(parameters, userID)
		statements = append(statements, "$"+strconv.Itoa(len(parameters)))
	}

	query := `
SELECT u.id, u.username, u.display_name, u.avatar_url, u.lang_tag, u.location, u.timezone, u.metadata, u.wallet,
	u.email, u.apple_id, u.facebook_id, u.facebook_instant_game_id, u.google_id, u.gamecenter_id, u.steam_id, u.custom_id, u.edge_count,
	u.create_time, u.update_time, u.verify_time, u.disable_time, array(select ud.id from user_device ud where u.id = ud.user_id)
FROM users u
WHERE u.id IN (` + strings.Join(statements, ",") + `)`
	rows, err := db.QueryContext(ctx, query, parameters...)
WHERE u.id = ANY($1)`
	rows, err := db.QueryContext(ctx, query, userIDs)
	if err != nil {
		logger.Error("Error retrieving user accounts.", zap.Error(err))
		return nil, err
+16 −22
Original line number Diff line number Diff line
@@ -854,15 +854,13 @@ func importSteamFriends(ctx context.Context, logger *zap.Logger, db *sql.DB, mes
			return nil
		}

		statements := make([]string, 0, len(steamProfiles))
		params := make([]interface{}, 0, len(steamProfiles))
		for i, steamProfile := range steamProfiles {
			statements = append(statements, "$"+strconv.Itoa(i+1))
			params = append(params, strconv.FormatUint(steamProfile.SteamID, 10))
		steamIDs := make([]string, 0, len(steamProfiles))
		for _, steamProfile := range steamProfiles {
			steamIDs = append(steamIDs, strconv.FormatUint(steamProfile.SteamID, 10))
		}

		query := "SELECT id FROM users WHERE steam_id IN (" + strings.Join(statements, ", ") + ")"
		rows, err := tx.QueryContext(ctx, query, params...)
		query := "SELECT id FROM users WHERE steam_id = ANY($1::text[])"
		rows, err := tx.QueryContext(ctx, query, steamIDs)
		if err != nil {
			if err == sql.ErrNoRows {
				// None of the friend profiles exist.
@@ -872,7 +870,7 @@ func importSteamFriends(ctx context.Context, logger *zap.Logger, db *sql.DB, mes
		}

		var id string
		possibleFriendIDs := make([]uuid.UUID, 0, len(statements))
		possibleFriendIDs := make([]uuid.UUID, 0, len(steamIDs))
		for rows.Next() {
			err = rows.Scan(&id)
			if err != nil {
@@ -930,15 +928,13 @@ func importFacebookFriends(ctx context.Context, logger *zap.Logger, db *sql.DB,
			return nil
		}

		statements := make([]string, 0, len(facebookProfiles))
		params := make([]interface{}, 0, len(facebookProfiles))
		for i, facebookProfile := range facebookProfiles {
			statements = append(statements, "$"+strconv.Itoa(i+1))
		params := make([]string, 0, len(facebookProfiles))
		for _, facebookProfile := range facebookProfiles {
			params = append(params, facebookProfile.ID)
		}

		query := "SELECT id FROM users WHERE facebook_id IN (" + strings.Join(statements, ", ") + ")"
		rows, err := tx.QueryContext(ctx, query, params...)
		query := "SELECT id FROM users WHERE facebook_id = ANY($1::text[])"
		rows, err := tx.QueryContext(ctx, query, params)
		if err != nil {
			if err == sql.ErrNoRows {
				// None of the friend profiles exist.
@@ -948,7 +944,7 @@ func importFacebookFriends(ctx context.Context, logger *zap.Logger, db *sql.DB,
		}

		var id string
		possibleFriendIDs := make([]uuid.UUID, 0, len(statements))
		possibleFriendIDs := make([]uuid.UUID, 0, len(params))
		for rows.Next() {
			err = rows.Scan(&id)
			if err != nil {
@@ -1005,8 +1001,7 @@ func resetUserFriends(ctx context.Context, tx *sql.Tx, userID uuid.UUID) error {
	if err != nil {
		return err
	}
	statements := make([]string, 0, 10)
	params := make([]interface{}, 0, 10)
	params := make([]string, 0, 10)
	for rows.Next() {
		var id string
		err = rows.Scan(&id)
@@ -1015,17 +1010,16 @@ func resetUserFriends(ctx context.Context, tx *sql.Tx, userID uuid.UUID) error {
			return err
		}
		params = append(params, id)
		statements = append(statements, "$"+strconv.Itoa(len(params)))
	}
	_ = rows.Close()

	if len(statements) > 0 {
		query = "UPDATE users SET edge_count = edge_count - 1 WHERE id IN (" + strings.Join(statements, ",") + ")"
		result, err := tx.ExecContext(ctx, query, params...)
	if len(params) > 0 {
		query = "UPDATE users SET edge_count = edge_count - 1 WHERE id = ANY($1)"
		result, err := tx.ExecContext(ctx, query, params)
		if err != nil {
			return err
		}
		if rowsAffectedCount, _ := result.RowsAffected(); rowsAffectedCount != int64(len(statements)) {
		if rowsAffectedCount, _ := result.RowsAffected(); rowsAffectedCount != int64(len(params)) {
			return errors.New("error updating edge count after friend reset")
		}
	}
+2 −10
Original line number Diff line number Diff line
@@ -1624,18 +1624,10 @@ func GetGroups(ctx context.Context, logger *zap.Logger, db *sql.DB, ids []string
		return make([]*api.Group, 0), nil
	}

	statements := make([]string, 0, len(ids))
	params := make([]interface{}, 0, len(ids))
	for i, id := range ids {
		statements = append(statements, "$"+strconv.Itoa(i+1))
		params = append(params, id)
	}

	query := `SELECT id, creator_id, name, description, avatar_url, state, edge_count, lang_tag, max_count, metadata, create_time, update_time
FROM groups
WHERE disable_time = '1970-01-01 00:00:00 UTC'
AND id IN (` + strings.Join(statements, ",") + `)`
	rows, err := db.QueryContext(ctx, query, params...)
WHERE disable_time = '1970-01-01 00:00:00 UTC' AND id = ANY($1)`
	rows, err := db.QueryContext(ctx, query, ids)
	if err != nil {
		if err == sql.ErrNoRows {
			return make([]*api.Group, 0), nil
+4 −9
Original line number Diff line number Diff line
@@ -22,7 +22,6 @@ import (
	"encoding/gob"
	"errors"
	"sort"
	"strconv"
	"strings"
	"time"

@@ -286,15 +285,11 @@ func LeaderboardRecordsList(ctx context.Context, logger *zap.Logger, db *sql.DB,
	}

	if len(ownerIds) != 0 {
		params := make([]interface{}, 0, len(ownerIds)+2)
		params = append(params, leaderboardId, time.Unix(expiryTime, 0).UTC())
		statements := make([]string, len(ownerIds))
		for i, ownerID := range ownerIds {
			params = append(params, ownerID)
			statements[i] = "$" + strconv.Itoa(i+3)
		}
		params := []any{leaderboardId, time.Unix(expiryTime, 0).UTC(), ownerIds}
		query := `SELECT owner_id, username, score, subscore, num_score, max_num_score, metadata, create_time, update_time
FROM leaderboard_record
WHERE leaderboard_id = $1 AND expiry_time = $2 AND owner_id = ANY($3)`

		query := "SELECT owner_id, username, score, subscore, num_score, max_num_score, metadata, create_time, update_time FROM leaderboard_record WHERE leaderboard_id = $1 AND expiry_time = $2 AND owner_id IN (" + strings.Join(statements, ", ") + ")"
		rows, err := db.QueryContext(ctx, query, params...)
		if err != nil {
			logger.Error("Error reading leaderboard records", zap.Error(err))
Loading