Unverified Commit 3f1fc4f2 authored by Simon Esposito's avatar Simon Esposito Committed by GitHub
Browse files

Improve google refund polling handling. (#982)

General improvements around IAP validation.
parent ada6f942
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -92,10 +92,11 @@ func init() {

type ValidateReceiptAppleResponseReceiptInApp struct {
	OriginalTransactionID string `json:"original_transaction_id"`
	TransactionId         string `json:"transaction_id"` // Different than OriginalTransactionId if the user Auto-renews subscription or restores a purchase.
	TransactionId         string `json:"transaction_id"` // Different from OriginalTransactionId if the user Auto-renews subscription or restores a purchase.
	ProductID             string `json:"product_id"`
	ExpiresDateMs         string `json:"expires_date_ms"` // Subscription expiration or renewal date.
	PurchaseDateMs        string `json:"purchase_date_ms"`
	CancellationDateMs    string `json:"cancellation_date_ms"`
}

type ValidateReceiptAppleResponseReceipt struct {
+26 −41
Original line number Diff line number Diff line
@@ -172,6 +172,7 @@ func ValidatePurchaseGoogle(ctx context.Context, logger *zap.Logger, db *sql.DB,
	if !persist {
		validatedPurchases := []*api.ValidatedPurchase{
			{
				UserId:           userID.String(),
				ProductId:        sPurchase.productId,
				TransactionId:    sPurchase.transactionId,
				Store:            sPurchase.store,
@@ -446,6 +447,7 @@ func ListPurchases(ctx context.Context, logger *zap.Logger, db *sql.DB, userID s
			purchase_time,
			create_time,
			update_time,
			refund_time,
			environment
	FROM
			purchase
@@ -472,9 +474,10 @@ func ListPurchases(ctx context.Context, logger *zap.Logger, db *sql.DB, userID s
		var purchaseTime pgtype.Timestamptz
		var createTime pgtype.Timestamptz
		var updateTime pgtype.Timestamptz
		var refundTime pgtype.Timestamptz
		var environment api.StoreEnvironment

		if err = rows.Scan(&dbUserID, &transactionId, &productId, &store, &rawResponse, &purchaseTime, &createTime, &updateTime, &environment); err != nil {
		if err = rows.Scan(&dbUserID, &transactionId, &productId, &store, &rawResponse, &purchaseTime, &createTime, &updateTime, &refundTime, &environment); err != nil {
			logger.Error("Error retrieving purchases.", zap.Error(err))
			return nil, err
		}
@@ -500,6 +503,9 @@ func ListPurchases(ctx context.Context, logger *zap.Logger, db *sql.DB, userID s
			ProviderResponse: rawResponse,
			Environment:      environment,
		}
		if refundTime.Time.Unix() != 0 {
			purchase.RefundTime = timestamppb.New(purchase.RefundTime.AsTime())
		}

		purchases = append(purchases, purchase)

@@ -575,6 +581,8 @@ func upsertPurchases(ctx context.Context, db *sql.DB, purchases []*storagePurcha
		return nil, errors.New("expects at least one receipt")
	}

	userIDIn := purchases[0].userID

	statements := make([]string, 0, len(purchases))
	params := make([]interface{}, 0, len(purchases)*8)
	transactionIDsToPurchase := make(map[string]*storagePurchase)
@@ -613,72 +621,49 @@ VALUES
ON CONFLICT
    (transaction_id)
DO UPDATE SET
    refund_time = $8, update_time = now()
    refund_time = $8,
    update_time = now()
RETURNING
    transaction_id, create_time, update_time, refund_time
		user_id,
    transaction_id,
    create_time,
    update_time,
    refund_time
`
	insertedTransactionIDs := make(map[string]struct{})
	rows, err := db.QueryContext(ctx, query, params...)
	if err != nil {
		return nil, err
	}
	for rows.Next() {
		// Newly inserted purchases
		var dbUserID uuid.UUID
		var transactionId string
		var createTime pgtype.Timestamptz
		var updateTime pgtype.Timestamptz
		var refundTime pgtype.Timestamptz
		if err = rows.Scan(&transactionId, &createTime, &updateTime, &refundTime); err != nil {
		if err = rows.Scan(&dbUserID, &transactionId, &createTime, &updateTime, &refundTime); err != nil {
			rows.Close()
			return nil, err
		}
		storedPurchase, _ := transactionIDsToPurchase[transactionId]
		storedPurchase.createTime = createTime.Time
		storedPurchase.updateTime = updateTime.Time
		storedPurchase.seenBefore = updateTime.Time.After(createTime.Time)
		if refundTime.Time.Unix() != 0 {
			storedPurchase.refundTime = refundTime.Time
		storedPurchase.seenBefore = false
		insertedTransactionIDs[storedPurchase.transactionId] = struct{}{}
	}
	rows.Close()
	if err := rows.Err(); err != nil {
		return nil, err
	}

	// Go over purchases that have not been inserted (already exist in the DB) and fetch createTime and updateTime
	if len(transactionIDsToPurchase) > len(insertedTransactionIDs) {
		seenIDs := make([]string, 0, len(transactionIDsToPurchase))
		for tID, _ := range transactionIDsToPurchase {
			if _, ok := insertedTransactionIDs[tID]; !ok {
				seenIDs = append(seenIDs, tID)
		}
	}

		rows, err = db.QueryContext(ctx, "SELECT transaction_id, create_time, update_time FROM purchase WHERE transaction_id IN ($1)", strings.Join(seenIDs, ", "))
		if err != nil {
			return nil, err
		}
		for rows.Next() {
			// Already seen purchases
			var transactionId string
			var createTime pgtype.Timestamptz
			var updateTime pgtype.Timestamptz
			if err = rows.Scan(&transactionId, &createTime, &updateTime); err != nil {
				rows.Close()
				return nil, err
			}
			storedPurchase, _ := transactionIDsToPurchase[transactionId]
			storedPurchase.createTime = createTime.Time
			storedPurchase.updateTime = updateTime.Time
			storedPurchase.seenBefore = true
		}
	rows.Close()
	if err := rows.Err(); err != nil {
		return nil, err
	}
	}

	storedPurchases := make([]*storagePurchase, 0, len(transactionIDsToPurchase))
	for _, purchase := range transactionIDsToPurchase {
		if purchase.seenBefore && purchase.userID != userIDIn {
			// Mismatch between userID requesting validation and existing receipt userID, return error.
			return nil, status.Error(codes.FailedPrecondition, "Invalid receipt for userID.")
		}
		storedPurchases = append(storedPurchases, purchase)
	}

+8 −2
Original line number Diff line number Diff line
@@ -581,9 +581,10 @@ DO
		raw_notification = coalesce(to_jsonb(nullif($9, '')), subscription.raw_notification::jsonb),
		refund_time = coalesce($10, subscription.refund_time)
RETURNING
    create_time, update_time, expire_time, refund_time, raw_response, raw_notification
    user_id, create_time, update_time, expire_time, refund_time, raw_response, raw_notification
`
	var (
		userID          uuid.UUID
		createTime      pgtype.Timestamptz
		updateTime      pgtype.Timestamptz
		expireTime      pgtype.Timestamptz
@@ -591,10 +592,15 @@ RETURNING
		rawResponse     string
		rawNotification string
	)
	if err := db.QueryRowContext(ctx, query, sub.userID, sub.store, sub.originalTransactionId, sub.productId, sub.purchaseTime, sub.environment, sub.expireTime, sub.rawResponse, sub.rawNotification, sub.refundTime).Scan(&createTime, &updateTime, &expireTime, &refundTime, &rawResponse, &rawNotification); err != nil {
	if err := db.QueryRowContext(ctx, query, sub.userID, sub.store, sub.originalTransactionId, sub.productId, sub.purchaseTime, sub.environment, sub.expireTime, sub.rawResponse, sub.rawNotification, sub.refundTime).Scan(&userID, &createTime, &updateTime, &expireTime, &refundTime, &rawResponse, &rawNotification); err != nil {
		return err
	}

	if sub.userID != userID {
		// Subscription receipt has been seen before for a different user.
		return status.Error(codes.FailedPrecondition, "Invalid receipt for userID")
	}

	sub.createTime = createTime.Time
	sub.updateTime = updateTime.Time
	sub.expireTime = expireTime.Time
+18 −16
Original line number Diff line number Diff line
@@ -96,14 +96,14 @@ func (g *LocalGoogleRefundScheduler) Start(runtime *Runtime) {
					}

					for _, vr := range voidedReceipts {
						switch vr.Kind {
						case "androidpublisher#productPurchase":
						purchase, err := getPurchaseByTransactionId(g.ctx, g.db, vr.PurchaseToken)
						if err != nil && err != sql.ErrNoRows {
								g.logger.Warn("Failed to find purchase for Google refund callback", zap.Error(err), zap.String("purchase_token", vr.PurchaseToken))
							g.logger.Error("Failed to get purchase by transaction_id", zap.Error(err), zap.String("purchase_token", vr.PurchaseToken))
							continue
						}

						if purchase != nil {
							// Refunded purchase.
							if purchase.RefundTime.Seconds != 0 {
								// Purchase refund already handled, skip it.
								continue
@@ -144,8 +144,9 @@ func (g *LocalGoogleRefundScheduler) Start(runtime *Runtime) {
								PurchaseTime:  timestamppb.New(dbPurchase.purchaseTime),
								CreateTime:    timestamppb.New(dbPurchase.createTime),
								UpdateTime:    timestamppb.New(dbPurchase.updateTime),
								RefundTime:    timestamppb.New(refundTime),
								RefundTime:    timestamppb.New(dbPurchase.refundTime),
								Environment:   purchase.Environment,
								SeenBefore:    dbPurchase.seenBefore,
							}

							json, err := json.Marshal(vr)
@@ -159,16 +160,19 @@ func (g *LocalGoogleRefundScheduler) Start(runtime *Runtime) {
									g.logger.Warn("Failed to invoke Google purchase refund hook", zap.Error(err))
								}
							}

						case "androidpublisher#subscriptionPurchase":
						} else {
							subscription, err := getSubscriptionByOriginalTransactionId(g.ctx, g.db, vr.PurchaseToken)
							if err != nil {
								if err != sql.ErrNoRows {
									g.logger.Error("Failed to find subscription for Google refund callback", zap.Error(err), zap.String("transaction_id", vr.PurchaseToken))
							if err != nil && err != sql.ErrNoRows {
								g.logger.Error("Failed to get subscription by original_transaction_id", zap.Error(err), zap.String("original_transaction_id", vr.PurchaseToken))
								continue
							}

							if subscription == nil {
								// No subscription was found.
								continue
							}

							// Refunded subscription.
							if subscription.RefundTime.Seconds != 0 {
								// Subscription refund already handled, skip it.
								continue
@@ -231,8 +235,6 @@ func (g *LocalGoogleRefundScheduler) Start(runtime *Runtime) {
									g.logger.Warn("Failed to invoke Google subscription refund hook", zap.Error(err))
								}
							}
						default:
							g.logger.Warn("Unhandled IAP Google voided receipt kind", zap.String("kind", vr.Kind))
						}
					}
				}