Unverified Commit 07e0165d authored by Simon Esposito's avatar Simon Esposito Committed by GitHub
Browse files

Improve group list cursor precision. (#906)

parent e4cd2cee
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -113,7 +113,7 @@ FROM groups ORDER BY id ASC LIMIT $1`
	var previousGroup *api.Group

	for rows.Next() {
		group, err := convertToGroup(rows)
		group, _, err := convertToGroup(rows)
		if err != nil {
			_ = rows.Close()
			s.logger.Error("Error scanning groups.", zap.Any("in", in), zap.Error(err))
+51 −36
Original line number Diff line number Diff line
@@ -123,7 +123,7 @@ RETURNING id, creator_id, name, description, avatar_url, state, edge_count, lang
		}
		// Rows closed in groupConvertRows()

		groups, err := groupConvertRows(rows, 1)
		groups, _, err := groupConvertRows(rows, 1)
		if err != nil {
			logger.Debug("Could not parse rows.", zap.Error(err))
			return err
@@ -301,7 +301,7 @@ WHERE (id = $1) AND (disable_time = '1970-01-01 00:00:00 UTC')`
	}
	// Rows closed in groupConvertRows()

	groups, err := groupConvertRows(rows, 1)
	groups, _, err := groupConvertRows(rows, 1)
	if err != nil {
		logger.Error("Could not parse groups.", zap.Error(err))
		return err
@@ -1679,7 +1679,7 @@ AND id IN (` + strings.Join(statements, ",") + `)`
	}
	// Rows closed in groupConvertRows()

	groups, err := groupConvertRows(rows, len(ids))
	groups, _, err := groupConvertRows(rows, len(ids))
	if err != nil {
		if err == sql.ErrNoRows {
			return make([]*api.Group, 0), nil
@@ -1862,7 +1862,7 @@ WHERE disable_time = '1970-01-01 00:00:00 UTC'`
	}

	// Rows closed in groupConvertRows()
	groups, err := groupConvertRows(rows, limit+1)
	groups, newCursorStr, err := groupConvertRows(rows, limit)
	if err != nil {
		if err == sql.ErrNoRows {
			return groupList, nil
@@ -1871,26 +1871,8 @@ WHERE disable_time = '1970-01-01 00:00:00 UTC'`
		return nil, err
	}

	groupList.Groups = groups[:int(math.Min(float64(len(groups)), float64(limit)))]

	cursorBuf := new(bytes.Buffer)
	if len(groups) > limit {
		lastGroup := groupList.Groups[len(groupList.Groups)-1]
		newCursor := &groupListCursor{
			ID:         uuid.Must(uuid.FromString(lastGroup.Id)),
			EdgeCount:  lastGroup.EdgeCount,
			Lang:       lastGroup.LangTag,
			Name:       lastGroup.Name,
			Open:       lastGroup.Open.Value,
			UpdateTime: lastGroup.UpdateTime.AsTime().UnixNano(),
		}

		if err := gob.NewEncoder(cursorBuf).Encode(newCursor); err != nil {
			logger.Error("Could not create new cursor.", zap.Error(err))
			return nil, err
		}
		groupList.Cursor = base64.RawURLEncoding.EncodeToString(cursorBuf.Bytes())
	}
	groupList.Groups = groups
	groupList.Cursor = newCursorStr

	return groupList, nil
}
@@ -1910,7 +1892,7 @@ type groupSqlStruct struct {
	updateTime  pgtype.Timestamptz
}

func sqlMapper(row *groupSqlStruct) *api.Group {
func sqlMapper(row *groupSqlStruct) (*api.Group, *time.Time) {
	open := true
	if row.state.Int64 == 1 {
		open = false
@@ -1928,34 +1910,64 @@ func sqlMapper(row *groupSqlStruct) *api.Group {
		MaxCount:    int32(row.maxCount.Int64),
		CreateTime:  &timestamppb.Timestamp{Seconds: row.createTime.Time.Unix()},
		UpdateTime:  &timestamppb.Timestamp{Seconds: row.updateTime.Time.Unix()},
	}
	}, &row.updateTime.Time
}

func convertToGroup(rows *sql.Rows) (*api.Group, error) {
func convertToGroup(rows *sql.Rows) (*api.Group, *time.Time, error) {
	s := groupSqlStruct{}
	groupStruct, fields := &s, []interface{}{&s.id, &s.creatorID, &s.name, &s.description, &s.avatarURL, &s.state, &s.edgeCount, &s.lang,
		&s.maxCount, &s.metadata, &s.createTime, &s.updateTime}
	if err := rows.Scan(fields...); err != nil {
		return nil, err
		return nil, nil, err
	}
	return sqlMapper(groupStruct), nil
	group, updateTime := sqlMapper(groupStruct)

	return group, updateTime, nil
}

func groupConvertRows(rows *sql.Rows, limit int) ([]*api.Group, error) {
func groupConvertRows(rows *sql.Rows, limit int) ([]*api.Group, string, error) {
	defer rows.Close()
	groups := make([]*api.Group, 0, limit)

	groups := make([]*api.Group, 0, limit+1)
	updateTimes := make([]*time.Time, 0, limit+1)
	var updateTime *time.Time
	for rows.Next() {
		if group, err := convertToGroup(rows); err != nil {
			return nil, err
		var group *api.Group
		var err error
		if group, updateTime, err = convertToGroup(rows); err != nil {
			return nil, "", err
		} else {
			groups = append(groups, group)
			updateTimes = append(updateTimes, updateTime)
		}
	}
	if err := rows.Err(); err != nil {
		return nil, err
		return nil, "", err
	}

	return groups, nil
	outGroups := groups[:int(math.Min(float64(len(groups)), float64(limit)))]

	var cursor string
	cursorBuf := new(bytes.Buffer)
	if len(groups) > limit {
		lastGroup := outGroups[len(outGroups)-1]
		newCursor := &groupListCursor{
			ID:         uuid.Must(uuid.FromString(lastGroup.Id)),
			EdgeCount:  lastGroup.EdgeCount,
			Lang:       lastGroup.LangTag,
			Name:       lastGroup.Name,
			Open:       lastGroup.Open.Value,
			UpdateTime: updateTimes[len(outGroups)-1].UnixNano(),
		}

		if err := gob.NewEncoder(cursorBuf).Encode(newCursor); err != nil {
			return nil, "", err
		}

		cursor = base64.RawURLEncoding.EncodeToString(cursorBuf.Bytes())
	}

	return outGroups, cursor, nil
}

func groupAddUser(ctx context.Context, db *sql.DB, tx *sql.Tx, groupID uuid.UUID, userID uuid.UUID, state int) (int64, error) {
@@ -2172,7 +2184,10 @@ func getGroup(ctx context.Context, logger *zap.Logger, db *sql.DB, groupID uuid.
		logger.Error("Error retrieving group.", zap.Error(err))
		return nil, err
	}
	return sqlMapper(groupStruct), nil

	group, _ := sqlMapper(groupStruct)

	return group, nil
}

func incrementGroupEdge(ctx context.Context, logger *zap.Logger, tx *sql.Tx, uid uuid.UUID, groupID uuid.UUID) error {