Unverified Commit 06ed00f6 authored by Maxim Ivanov's avatar Maxim Ivanov Committed by GitHub
Browse files

Use fixed number of query args in StorageReadObjects. (#1044)

parent 1b24c986
Loading
Loading
Loading
Loading
+50 −36
Original line number Diff line number Diff line
@@ -416,51 +416,65 @@ func storageListObjects(rows *sql.Rows, limit int) (*api.StorageObjectList, erro
}

func StorageReadObjects(ctx context.Context, logger *zap.Logger, db *sql.DB, caller uuid.UUID, objectIDs []*api.ReadStorageObjectId) (*api.StorageObjects, error) {
	params := make([]interface{}, 0, len(objectIDs)*3)
	collectionParam := make([]string, 0, len(objectIDs))
	keyParam := make([]string, 0, len(objectIDs))
	userIdParam := make([]uuid.UUID, 0, len(objectIDs))

	whereClause := ""
	for _, id := range objectIDs {
		l := len(params)
		if whereClause != "" {
			whereClause += " OR "
	// When selecting variable number of object we'd like to keep number of
	// SQL query arguments constant, otherwise query statistics explode, because
	// from PostgreSQL perspective query with different number of arguments is a distinct query
	//
	// To keep number of arguments static instead of building
	// WHERE (a = $1 and b = $2) OR (a = $3 and b = $4) OR ...
	// we use JOIN with "virtual" table built from columns provided as arrays:
	//
	// JOIN ROWS FROM (
	//		unnest($1::type_of_a[]),
	//      unnest($2::type_of_b[])
	// ) v(a, b)
	//
	// This way regardless of how many objects we query, we pass same number of args: one per column
	query := `SELECT collection, key, user_id, value, version, read, write, create_time, update_time
		FROM storage
		NATURAL JOIN ROWS FROM (
			unnest($1::text[]),
			unnest($2::text[]),
			unnest($3::uuid[])
		) v(collection, key, user_id)
	`

	if caller != uuid.Nil {
		// Caller is not nil: either read public (read=2) object from requested user
		// or private (read=1) object owned by caller
		query += `
		WHERE (read = 2 or (read = 1 and storage.user_id = $4))
		`
	}

		if caller == uuid.Nil {
			// Disregard permissions if called authoritatively.
			whereClause += fmt.Sprintf(" (collection = $%v AND key = $%v AND user_id = $%v) ", l+1, l+2, l+3)
			if id.UserId == "" {
				params = append(params, id.Collection, id.Key, uuid.Nil)
	for _, id := range objectIDs {
		collectionParam = append(collectionParam, id.Collection)
		keyParam = append(keyParam, id.Key)
		var reqUid uuid.UUID
		if uid := id.GetUserId(); uid != "" {
			if uid, err := uuid.FromString(uid); err == nil {
				reqUid = uid
			} else {
				params = append(params, id.Collection, id.Key, id.UserId)
				logger.Error("Could not read storage objects. Unable to parse requested user_id", zap.Error(err))
				return nil, err
			}
		} else if id.GetUserId() == "" {
			whereClause += fmt.Sprintf(" (collection = $%v AND key = $%v AND user_id = $%v AND read = 2) ", l+1, l+2, l+3)
			params = append(params, id.Collection, id.Key, uuid.Nil)
		} else {
			whereClause += fmt.Sprintf(" (collection = $%v AND key = $%v AND user_id = $%v AND (read = 2 OR (read = 1 AND user_id = $%v))) ", l+1, l+2, l+3, l+4)
			params = append(params, id.Collection, id.Key, id.UserId, caller)
		}
		userIdParam = append(userIdParam, reqUid)
	}

	query := `
SELECT collection, key, user_id, value, version, read, write, create_time, update_time
FROM storage
WHERE
` + whereClause
	params := []interface{}{collectionParam, keyParam, userIdParam}
	if caller != uuid.Nil {
		params = append(params, caller)
	}

	var objects *api.StorageObjects
	err := ExecuteRetryable(func() error {
		rows, err := db.QueryContext(ctx, query, params...)
		if err != nil {
			if err == sql.ErrNoRows {
				objects = &api.StorageObjects{Objects: make([]*api.StorageObject, 0)}
				return nil
			}
			logger.Error("Could not read storage objects.", zap.Error(err))
			return err
		}
	err := ExecuteRetryablePgx(ctx, db, func(conn *pgx.Conn) error {
		rows, _ := conn.Query(ctx, query, params...)
		defer rows.Close()

		funcObjects := &api.StorageObjects{Objects: make([]*api.StorageObject, 0, len(objectIDs))}
		for rows.Next() {
			o := &api.StorageObject{CreateTime: &timestamppb.Timestamp{}, UpdateTime: &timestamppb.Timestamp{}}
@@ -476,7 +490,7 @@ WHERE

			funcObjects.Objects = append(funcObjects.Objects, o)
		}
		if err = rows.Err(); err != nil {
		if err := rows.Err(); err != nil {
			logger.Error("Could not read storage objects.", zap.Error(err))
			return err
		}
+23 −0
Original line number Diff line number Diff line
@@ -233,6 +233,29 @@ func ExecuteRetryable(fn func() error) error {
	return nil
}

// ExecuteRetryablePgx Retry functions that perform non-transactional database operations on PgConn
func ExecuteRetryablePgx(ctx context.Context, db *sql.DB, fn func(conn *pgx.Conn) error) error {
	c, err := db.Conn(ctx)
	if err != nil {
		return err
	}
	defer c.Close()
	return c.Raw(func(dc any) (err error) {
		conn := dc.(*stdlib.Conn).Conn()
		for i := 0; i < 5; i++ {
			err = fn(conn)
			var pgErr *pgconn.PgError
			if errors.As(errorCause(err), &pgErr) && pgErr.Code[:2] == "40" {
				// 40XXXX codes are retriable errors
				continue
			}
			// return on non retryable error or success
			return err
		}
		return err
	})
}

// ExecuteInTx runs fn inside tx which should already have begun.
// fn is subject to the same restrictions as the fn passed to ExecuteTx.
func ExecuteInTx(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) error {