Commit f4009a45 authored by Fernando Takagi's avatar Fernando Takagi
Browse files

fixes

parent b2bdc11a
Loading
Loading
Loading
Loading
+43 −56
Original line number Diff line number Diff line
@@ -427,39 +427,27 @@ func StorageReadObjects(ctx context.Context, logger *zap.Logger, db *sql.DB, cal
	keyParam := make([]string, 0, len(objectIDs))
	userIdParam := make([]uuid.UUID, 0, len(objectIDs))

	collectionSet := make(map[string]struct{})
	keySet := make(map[string]struct{})
	userIdSet := make(map[uuid.UUID]struct{})
	isCollectionSetUnique := true
	isKeySetUnique := true
	isUserIdSetUnique := true

	isSingleCollection := true
	isSingleKey := true
	isSingleUserId := true

	multipleArgs := make([]storageQueryArg, 0, 3)
	singleArgs := make([]storageQueryArg, 0, 3)
	distinctArgs := make([]storageQueryArg, 0, 3)
	uniqueArgs := make([]storageQueryArg, 0, 3)

	for _, id := range objectIDs {
		collectionParam = append(collectionParam, id.Collection)
		if isSingleCollection {
			_, ok := collectionSet[id.Collection]
			if !ok {
				collectionSet[id.Collection] = struct{}{}
			} else {
				isSingleCollection = false
				collectionSet = make(map[string]struct{})
				multipleArgs = append(multipleArgs, storageQueryArg{name: "collection", dbType: "text[]", param: collectionParam})
		if isCollectionSetUnique && len(collectionParam) > 1 {
			if id.Collection != collectionParam[0] {
				isCollectionSetUnique = false
				distinctArgs = append(distinctArgs, storageQueryArg{name: "collection", dbType: "text[]", param: collectionParam})
			}
		}

		keyParam = append(keyParam, id.Key)
		if isSingleKey {
			_, ok := keySet[id.Key]
			if !ok {
				keySet[id.Key] = struct{}{}
			} else {
				isSingleKey = false
				keySet = make(map[string]struct{})
				multipleArgs = append(multipleArgs, storageQueryArg{name: "key", dbType: "text[]", param: keyParam})
		if isKeySetUnique && len(keyParam) > 1 {
			if id.Key != keyParam[0] {
				isKeySetUnique = false
				distinctArgs = append(distinctArgs, storageQueryArg{name: "key", dbType: "text[]", param: keyParam})
			}
		}

@@ -473,50 +461,46 @@ func StorageReadObjects(ctx context.Context, logger *zap.Logger, db *sql.DB, cal
			}
		}
		userIdParam = append(userIdParam, reqUid)
		if isSingleUserId {
			_, ok := userIdSet[reqUid]
			if !ok {
				userIdSet[reqUid] = struct{}{}
			} else {
				isSingleUserId = false
				userIdSet = make(map[uuid.UUID]struct{})
				multipleArgs = append(multipleArgs, storageQueryArg{name: "user_id", dbType: "uuid[]", param: userIdParam})
		if isUserIdSetUnique && len(userIdParam) > 1 {
			if reqUid != userIdParam[0] {
				isUserIdSetUnique = false
				distinctArgs = append(distinctArgs, storageQueryArg{name: "user_id", dbType: "uuid[]", param: userIdParam})
			}
		}
	}

	if isSingleCollection {
		singleArgs = append(singleArgs, storageQueryArg{name: "collection", param: collectionParam[0]})
	if isCollectionSetUnique {
		uniqueArgs = append(uniqueArgs, storageQueryArg{name: "collection", param: collectionParam[0]})
	}
	if isSingleKey {
		singleArgs = append(singleArgs, storageQueryArg{name: "key", param: keyParam[0]})
	if isKeySetUnique {
		uniqueArgs = append(uniqueArgs, storageQueryArg{name: "key", param: keyParam[0]})
	}
	if isSingleUserId {
		singleArgs = append(singleArgs, storageQueryArg{name: "user_id", param: userIdParam[0]})
	if isUserIdSetUnique {
		uniqueArgs = append(uniqueArgs, storageQueryArg{name: "user_id", param: userIdParam[0]})
	}

	var query string
	var params []any
	switch len(multipleArgs) {
	switch len(distinctArgs) {
	case 0:
		query = fmt.Sprintf(`
SELECT collection, key, user_id, value, version, read, write, create_time, update_time FROM storage WHERE %s = $1 AND %s = $2 and %s = $3
`, singleArgs[0].name, singleArgs[1].name, singleArgs[2].name)
		params = []any{singleArgs[0].param, singleArgs[1].param, singleArgs[2].param}
SELECT collection, key, user_id, value, version, read, write, create_time, update_time FROM storage WHERE %s = $1 AND %s = $2 AND %s = $3`,
			uniqueArgs[0].name, uniqueArgs[1].name, uniqueArgs[2].name)
		params = []any{uniqueArgs[0].param, uniqueArgs[1].param, uniqueArgs[2].param}
	case 1:
		query = fmt.Sprintf(`
SELECT collection, key, user_id, value, version, read, write, create_time, update_time FROM storage WHERE %s = $1 AND %s = $2 and %s = ANY($3::%s)
`, singleArgs[0].name, singleArgs[1].name, multipleArgs[0].name, multipleArgs[0].dbType)
		params = []any{singleArgs[0].param, singleArgs[1].param, multipleArgs[0].param}
SELECT collection, key, user_id, value, version, read, write, create_time, update_time FROM storage WHERE %s = $1 AND %s = $2 AND %s = ANY($3::%s)`,
			uniqueArgs[0].name, uniqueArgs[1].name, distinctArgs[0].name, distinctArgs[0].dbType)
		params = []any{uniqueArgs[0].param, uniqueArgs[1].param, distinctArgs[0].param}
	case 2:
		query = fmt.Sprintf(`
SELECT collection, key, user_id, value, version, read, write, create_time, update_time FROM storage NATURAL JOIN ROWS FROM (
  unnest($1::%s),
  unnest($2::%s)
) t(%s, %s)
WHERE %s = $3
`, multipleArgs[0].dbType, multipleArgs[1].dbType, multipleArgs[0].name, multipleArgs[1].name, singleArgs[0].name)
		params = []any{multipleArgs[0].param, multipleArgs[1].param, singleArgs[0].param}
WHERE %s = $3`,
			distinctArgs[0].dbType, distinctArgs[1].dbType, distinctArgs[0].name, distinctArgs[1].name, uniqueArgs[0].name)
		params = []any{distinctArgs[0].param, distinctArgs[1].param, uniqueArgs[0].param}
	case 3:
		// When selecting a variable number of objects we'd like to keep number of
		// SQL query arguments constant, otherwise query statistics explode, because
@@ -537,20 +521,23 @@ SELECT collection, key, user_id, value, version, read, write, create_time, updat
  unnest($1::%s),
  unnest($2::%s),
  unnest($3::%s)
) t(%s, %s, %s)
`, multipleArgs[0].dbType, multipleArgs[1].dbType, multipleArgs[2].dbType, multipleArgs[0].name, multipleArgs[1].name, multipleArgs[2].name)
		params = []any{multipleArgs[0].param, multipleArgs[1].param, multipleArgs[2].param}
) t(%s, %s, %s)`,
			distinctArgs[0].dbType, distinctArgs[1].dbType, distinctArgs[2].dbType, distinctArgs[0].name, distinctArgs[1].name, distinctArgs[2].name)
		params = []any{distinctArgs[0].param, distinctArgs[1].param, distinctArgs[2].param}
	default:
		logger.Error("Unexpected code path.", zap.Int("multipleArgs", len(multipleArgs)))
		logger.Error("Unexpected code path.", zap.Int("multipleArgs", len(distinctArgs)))
		return nil, errors.New("unexpected code path")
	}

	if caller != uuid.Nil {
		if len(distinctArgs) == 3 {
			query += ` WHERE `
		} else {
			query += ` AND `
		}
		// Caller is not nil: either read public (read=2) object from requested user
		// or private (read=1) object owned by caller
		query += `
		WHERE (read = 2 or (read = 1 and storage.user_id = $4))
		`
		query += `(read = 2 or (read = 1 and storage.user_id = $4))`
		params = append(params, caller)
	}