From f3af2bbe289a790acbc300e96d21193e520d160d Mon Sep 17 00:00:00 2001
From: Andrei Mihu <andrei@heroiclabs.com>
Date: Wed, 18 Aug 2021 14:18:09 +0100
Subject: [PATCH] Correctly register purchase validation before/after hooks in
 JavaScript/Lua runtimes.

---
 CHANGELOG.md                 |  1 +
 server/runtime.go            | 36 ++++++++++++++++++------------------
 server/runtime_javascript.go | 36 ++++++++++++++++++++++++++++++++++++
 server/runtime_lua.go        | 36 ++++++++++++++++++++++++++++++++++++
 4 files changed, 91 insertions(+), 18 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5eae569a8..f0c11e0aa 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,7 @@ The format is based on [keep a changelog](http://keepachangelog.com) and this pr
 
 ### Fixed
 - Fix log level in Lua runtime log calls which use logger fields.
+- Correctly register purchase validation before/after hooks in JavaScript/Lua runtimes.
 
 ## [3.5.0] - 2021-08-10
 ### Added
diff --git a/server/runtime.go b/server/runtime.go
index e12652786..0f8d5a40b 100644
--- a/server/runtime.go
+++ b/server/runtime.go
@@ -877,13 +877,13 @@ func NewRuntime(ctx context.Context, logger, startupLogger *zap.Logger, db *sql.
 		startupLogger.Info("Registered JavaScript runtime Before function invocation", zap.String("id", "getusers"))
 	}
 	if allBeforeReqFunctions.beforeValidatePurchaseAppleFunction != nil {
-		startupLogger.Info("Registered JavaScript runtime Before function invocation", zap.String("id", "receiptvalidateapple"))
+		startupLogger.Info("Registered JavaScript runtime Before function invocation", zap.String("id", "validatepurchaseapple"))
 	}
 	if allBeforeReqFunctions.beforeValidatePurchaseGoogleFunction != nil {
-		startupLogger.Info("Registered JavaScript runtime Before function invocation", zap.String("id", "receiptvalidategoogle"))
+		startupLogger.Info("Registered JavaScript runtime Before function invocation", zap.String("id", "validatepurchasegoogle"))
 	}
 	if allBeforeReqFunctions.beforeValidatePurchaseHuaweiFunction != nil {
-		startupLogger.Info("Registered JavaScript runtime Before function invocation", zap.String("id", "receiptvalidatehuawei"))
+		startupLogger.Info("Registered JavaScript runtime Before function invocation", zap.String("id", "validatepurchasehuawei"))
 	}
 	if allBeforeReqFunctions.beforeEventFunction != nil {
 		startupLogger.Info("Registered JavaScript runtime Before custom events function invocation")
@@ -1160,15 +1160,15 @@ func NewRuntime(ctx context.Context, logger, startupLogger *zap.Logger, db *sql.
 	}
 	if luaBeforeReqFunctions.beforeValidatePurchaseAppleFunction != nil {
 		allBeforeReqFunctions.beforeValidatePurchaseAppleFunction = luaBeforeReqFunctions.beforeValidatePurchaseAppleFunction
-		startupLogger.Info("Registered Lua runtime Before function invocation", zap.String("id", "receiptvalidateapple"))
+		startupLogger.Info("Registered Lua runtime Before function invocation", zap.String("id", "validatepurchaseapple"))
 	}
 	if luaBeforeReqFunctions.beforeValidatePurchaseGoogleFunction != nil {
 		allBeforeReqFunctions.beforeValidatePurchaseGoogleFunction = luaBeforeReqFunctions.beforeValidatePurchaseGoogleFunction
-		startupLogger.Info("Registered Lua runtime Before function invocation", zap.String("id", "receiptvalidategoogle"))
+		startupLogger.Info("Registered Lua runtime Before function invocation", zap.String("id", "validatepurchasegoogle"))
 	}
 	if luaBeforeReqFunctions.beforeValidatePurchaseHuaweiFunction != nil {
 		allBeforeReqFunctions.beforeValidatePurchaseHuaweiFunction = luaBeforeReqFunctions.beforeValidatePurchaseHuaweiFunction
-		startupLogger.Info("Registered Lua runtime Before function invocation", zap.String("id", "receiptvalidatehuawei"))
+		startupLogger.Info("Registered Lua runtime Before function invocation", zap.String("id", "validatepurchasehuawei"))
 	}
 	if luaBeforeReqFunctions.beforeEventFunction != nil {
 		allBeforeReqFunctions.beforeEventFunction = luaBeforeReqFunctions.beforeEventFunction
@@ -1450,15 +1450,15 @@ func NewRuntime(ctx context.Context, logger, startupLogger *zap.Logger, db *sql.
 	}
 	if goBeforeReqFunctions.beforeValidatePurchaseAppleFunction != nil {
 		allBeforeReqFunctions.beforeValidatePurchaseAppleFunction = goBeforeReqFunctions.beforeValidatePurchaseAppleFunction
-		startupLogger.Info("Registered Go runtime Before function invocation", zap.String("id", "receiptvalidateapple"))
+		startupLogger.Info("Registered Go runtime Before function invocation", zap.String("id", "validateapple"))
 	}
 	if goBeforeReqFunctions.beforeValidatePurchaseGoogleFunction != nil {
 		allBeforeReqFunctions.beforeValidatePurchaseGoogleFunction = goBeforeReqFunctions.beforeValidatePurchaseGoogleFunction
-		startupLogger.Info("Registered Go runtime Before function invocation", zap.String("id", "receiptvalidategoogle"))
+		startupLogger.Info("Registered Go runtime Before function invocation", zap.String("id", "validatepurchasegoogle"))
 	}
 	if goBeforeReqFunctions.beforeValidatePurchaseHuaweiFunction != nil {
 		allBeforeReqFunctions.beforeValidatePurchaseHuaweiFunction = goBeforeReqFunctions.beforeValidatePurchaseHuaweiFunction
-		startupLogger.Info("Registered Go runtime Before function invocation", zap.String("id", "receiptvalidatehuawei"))
+		startupLogger.Info("Registered Go runtime Before function invocation", zap.String("id", "validatepurchasehuawei"))
 	}
 	if goBeforeReqFunctions.beforeEventFunction != nil {
 		allBeforeReqFunctions.beforeEventFunction = goBeforeReqFunctions.beforeEventFunction
@@ -1672,13 +1672,13 @@ func NewRuntime(ctx context.Context, logger, startupLogger *zap.Logger, db *sql.
 		startupLogger.Info("Registered JavaScript runtime After function invocation", zap.String("id", "getusers"))
 	}
 	if allAfterReqFunctions.afterValidatePurchaseAppleFunction != nil {
-		startupLogger.Info("Registered JavaScript runtime Before function invocation", zap.String("id", "receiptvalidateapple"))
+		startupLogger.Info("Registered JavaScript runtime Before function invocation", zap.String("id", "validatepurchaseapple"))
 	}
 	if allAfterReqFunctions.afterValidatePurchaseGoogleFunction != nil {
-		startupLogger.Info("Registered JavaScript runtime Before function invocation", zap.String("id", "receiptvalidategoogle"))
+		startupLogger.Info("Registered JavaScript runtime Before function invocation", zap.String("id", "validatepurchasegoogle"))
 	}
 	if allAfterReqFunctions.afterValidatePurchaseHuaweiFunction != nil {
-		startupLogger.Info("Registered JavaScript runtime Before function invocation", zap.String("id", "receiptvalidatehuawei"))
+		startupLogger.Info("Registered JavaScript runtime Before function invocation", zap.String("id", "validatepurchasehuawei"))
 	}
 	if allAfterReqFunctions.afterEventFunction != nil {
 		startupLogger.Info("Registered JavaScript runtime After custom events function invocation")
@@ -1955,15 +1955,15 @@ func NewRuntime(ctx context.Context, logger, startupLogger *zap.Logger, db *sql.
 	}
 	if luaAfterReqFunctions.afterValidatePurchaseAppleFunction != nil {
 		allAfterReqFunctions.afterValidatePurchaseAppleFunction = luaAfterReqFunctions.afterValidatePurchaseAppleFunction
-		startupLogger.Info("Registered Lua runtime Before function invocation", zap.String("id", "receiptvalidateapple"))
+		startupLogger.Info("Registered Lua runtime Before function invocation", zap.String("id", "validatepurchaseapple"))
 	}
 	if luaAfterReqFunctions.afterValidatePurchaseGoogleFunction != nil {
 		allAfterReqFunctions.afterValidatePurchaseGoogleFunction = luaAfterReqFunctions.afterValidatePurchaseGoogleFunction
-		startupLogger.Info("Registered Lua runtime Before function invocation", zap.String("id", "receiptvalidategoogle"))
+		startupLogger.Info("Registered Lua runtime Before function invocation", zap.String("id", "validatepurchasegoogle"))
 	}
 	if luaAfterReqFunctions.afterValidatePurchaseHuaweiFunction != nil {
 		allAfterReqFunctions.afterValidatePurchaseHuaweiFunction = luaAfterReqFunctions.afterValidatePurchaseHuaweiFunction
-		startupLogger.Info("Registered Lua runtime Before function invocation", zap.String("id", "receiptvalidatehuawei"))
+		startupLogger.Info("Registered Lua runtime Before function invocation", zap.String("id", "validatepurchasehuawei"))
 	}
 	if luaAfterReqFunctions.afterEventFunction != nil {
 		allAfterReqFunctions.afterEventFunction = luaAfterReqFunctions.afterEventFunction
@@ -2245,15 +2245,15 @@ func NewRuntime(ctx context.Context, logger, startupLogger *zap.Logger, db *sql.
 	}
 	if goAfterReqFunctions.afterValidatePurchaseAppleFunction != nil {
 		allAfterReqFunctions.afterValidatePurchaseAppleFunction = goAfterReqFunctions.afterValidatePurchaseAppleFunction
-		startupLogger.Info("Registered Go runtime Before function invocation", zap.String("id", "receiptvalidateapple"))
+		startupLogger.Info("Registered Go runtime Before function invocation", zap.String("id", "validatepurchaseapple"))
 	}
 	if goAfterReqFunctions.afterValidatePurchaseGoogleFunction != nil {
 		allAfterReqFunctions.afterValidatePurchaseGoogleFunction = goAfterReqFunctions.afterValidatePurchaseGoogleFunction
-		startupLogger.Info("Registered Go runtime Before function invocation", zap.String("id", "receiptvalidategoogle"))
+		startupLogger.Info("Registered Go runtime Before function invocation", zap.String("id", "validatepurchasegoogle"))
 	}
 	if goAfterReqFunctions.afterValidatePurchaseHuaweiFunction != nil {
 		allAfterReqFunctions.afterValidatePurchaseHuaweiFunction = goAfterReqFunctions.afterValidatePurchaseHuaweiFunction
-		startupLogger.Info("Registered Go runtime Before function invocation", zap.String("id", "receiptvalidatehuawei"))
+		startupLogger.Info("Registered Go runtime Before function invocation", zap.String("id", "validatepurchasehuawei"))
 	}
 	if goAfterReqFunctions.afterEventFunction != nil {
 		allAfterReqFunctions.afterEventFunction = goAfterReqFunctions.afterEventFunction
diff --git a/server/runtime_javascript.go b/server/runtime_javascript.go
index b9f001faa..355e16e57 100644
--- a/server/runtime_javascript.go
+++ b/server/runtime_javascript.go
@@ -1170,6 +1170,30 @@ func NewRuntimeProviderJS(logger, startupLogger *zap.Logger, db *sql.DB, protojs
 						}
 						return result.(*api.GetUsersRequest), nil, 0
 					}
+				case "validatepurchaseapple":
+					beforeReqFunctions.beforeValidatePurchaseAppleFunction = func(ctx context.Context, logger *zap.Logger, userID, username string, vars map[string]string, expiry int64, clientIP, clientPort string, in *api.ValidatePurchaseAppleRequest) (*api.ValidatePurchaseAppleRequest, error, codes.Code) {
+						result, err, code := runtimeProviderJS.BeforeReq(ctx, id, logger, userID, username, vars, expiry, clientIP, clientPort, in)
+						if result == nil || err != nil {
+							return nil, err, code
+						}
+						return result.(*api.ValidatePurchaseAppleRequest), nil, 0
+					}
+				case "validatepurchasegoogle":
+					beforeReqFunctions.beforeValidatePurchaseGoogleFunction = func(ctx context.Context, logger *zap.Logger, userID, username string, vars map[string]string, expiry int64, clientIP, clientPort string, in *api.ValidatePurchaseGoogleRequest) (*api.ValidatePurchaseGoogleRequest, error, codes.Code) {
+						result, err, code := runtimeProviderJS.BeforeReq(ctx, id, logger, userID, username, vars, expiry, clientIP, clientPort, in)
+						if result == nil || err != nil {
+							return nil, err, code
+						}
+						return result.(*api.ValidatePurchaseGoogleRequest), nil, 0
+					}
+				case "validatepurchasehuawei":
+					beforeReqFunctions.beforeValidatePurchaseHuaweiFunction = func(ctx context.Context, logger *zap.Logger, userID, username string, vars map[string]string, expiry int64, clientIP, clientPort string, in *api.ValidatePurchaseHuaweiRequest) (*api.ValidatePurchaseHuaweiRequest, error, codes.Code) {
+						result, err, code := runtimeProviderJS.BeforeReq(ctx, id, logger, userID, username, vars, expiry, clientIP, clientPort, in)
+						if result == nil || err != nil {
+							return nil, err, code
+						}
+						return result.(*api.ValidatePurchaseHuaweiRequest), nil, 0
+					}
 				case "event":
 					beforeReqFunctions.beforeEventFunction = func(ctx context.Context, logger *zap.Logger, userID, username string, vars map[string]string, expiry int64, clientIP, clientPort string, in *api.Event) (*api.Event, error, codes.Code) {
 						result, err, code := runtimeProviderJS.BeforeReq(ctx, id, logger, userID, username, vars, expiry, clientIP, clientPort, in)
@@ -1456,6 +1480,18 @@ func NewRuntimeProviderJS(logger, startupLogger *zap.Logger, db *sql.DB, protojs
 					afterReqFunctions.afterGetUsersFunction = func(ctx context.Context, logger *zap.Logger, userID, username string, vars map[string]string, expiry int64, clientIP, clientPort string, out *api.Users, in *api.GetUsersRequest) error {
 						return runtimeProviderJS.AfterReq(ctx, id, logger, userID, username, vars, expiry, clientIP, clientPort, out, in)
 					}
+				case "validatepurchaseapple":
+					afterReqFunctions.afterValidatePurchaseAppleFunction = func(ctx context.Context, logger *zap.Logger, userID, username string, vars map[string]string, expiry int64, clientIP, clientPort string, out *api.ValidatePurchaseResponse, in *api.ValidatePurchaseAppleRequest) error {
+						return runtimeProviderJS.AfterReq(ctx, id, logger, userID, username, vars, expiry, clientIP, clientPort, out, in)
+					}
+				case "validatepurchasegoogle":
+					afterReqFunctions.afterValidatePurchaseGoogleFunction = func(ctx context.Context, logger *zap.Logger, userID, username string, vars map[string]string, expiry int64, clientIP, clientPort string, out *api.ValidatePurchaseResponse, in *api.ValidatePurchaseGoogleRequest) error {
+						return runtimeProviderJS.AfterReq(ctx, id, logger, userID, username, vars, expiry, clientIP, clientPort, out, in)
+					}
+				case "validatepurchasehuawei":
+					afterReqFunctions.afterValidatePurchaseHuaweiFunction = func(ctx context.Context, logger *zap.Logger, userID, username string, vars map[string]string, expiry int64, clientIP, clientPort string, out *api.ValidatePurchaseResponse, in *api.ValidatePurchaseHuaweiRequest) error {
+						return runtimeProviderJS.AfterReq(ctx, id, logger, userID, username, vars, expiry, clientIP, clientPort, out, in)
+					}
 				case "event":
 					afterReqFunctions.afterEventFunction = func(ctx context.Context, logger *zap.Logger, userID, username string, vars map[string]string, expiry int64, clientIP, clientPort string, in *api.Event) error {
 						return runtimeProviderJS.AfterReq(ctx, id, logger, userID, username, vars, expiry, clientIP, clientPort, nil, in)
diff --git a/server/runtime_lua.go b/server/runtime_lua.go
index dd5e2a5fc..b0e9d5bc6 100644
--- a/server/runtime_lua.go
+++ b/server/runtime_lua.go
@@ -713,6 +713,30 @@ func NewRuntimeProviderLua(logger, startupLogger *zap.Logger, db *sql.DB, protoj
 						}
 						return result.(*api.GetUsersRequest), nil, 0
 					}
+				case "validatepurchaseapple":
+					beforeReqFunctions.beforeValidatePurchaseAppleFunction = func(ctx context.Context, logger *zap.Logger, userID, username string, vars map[string]string, expiry int64, clientIP, clientPort string, in *api.ValidatePurchaseAppleRequest) (*api.ValidatePurchaseAppleRequest, error, codes.Code) {
+						result, err, code := runtimeProviderLua.BeforeReq(ctx, id, logger, userID, username, vars, expiry, clientIP, clientPort, in)
+						if result == nil || err != nil {
+							return nil, err, code
+						}
+						return result.(*api.ValidatePurchaseAppleRequest), nil, 0
+					}
+				case "validatepurchasegoogle":
+					beforeReqFunctions.beforeValidatePurchaseGoogleFunction = func(ctx context.Context, logger *zap.Logger, userID, username string, vars map[string]string, expiry int64, clientIP, clientPort string, in *api.ValidatePurchaseGoogleRequest) (*api.ValidatePurchaseGoogleRequest, error, codes.Code) {
+						result, err, code := runtimeProviderLua.BeforeReq(ctx, id, logger, userID, username, vars, expiry, clientIP, clientPort, in)
+						if result == nil || err != nil {
+							return nil, err, code
+						}
+						return result.(*api.ValidatePurchaseGoogleRequest), nil, 0
+					}
+				case "validatepurchasehuawei":
+					beforeReqFunctions.beforeValidatePurchaseHuaweiFunction = func(ctx context.Context, logger *zap.Logger, userID, username string, vars map[string]string, expiry int64, clientIP, clientPort string, in *api.ValidatePurchaseHuaweiRequest) (*api.ValidatePurchaseHuaweiRequest, error, codes.Code) {
+						result, err, code := runtimeProviderLua.BeforeReq(ctx, id, logger, userID, username, vars, expiry, clientIP, clientPort, in)
+						if result == nil || err != nil {
+							return nil, err, code
+						}
+						return result.(*api.ValidatePurchaseHuaweiRequest), nil, 0
+					}
 				case "event":
 					beforeReqFunctions.beforeEventFunction = func(ctx context.Context, logger *zap.Logger, userID, username string, vars map[string]string, expiry int64, clientIP, clientPort string, in *api.Event) (*api.Event, error, codes.Code) {
 						result, err, code := runtimeProviderLua.BeforeReq(ctx, id, logger, userID, username, vars, expiry, clientIP, clientPort, in)
@@ -999,6 +1023,18 @@ func NewRuntimeProviderLua(logger, startupLogger *zap.Logger, db *sql.DB, protoj
 					afterReqFunctions.afterGetUsersFunction = func(ctx context.Context, logger *zap.Logger, userID, username string, vars map[string]string, expiry int64, clientIP, clientPort string, out *api.Users, in *api.GetUsersRequest) error {
 						return runtimeProviderLua.AfterReq(ctx, id, logger, userID, username, vars, expiry, clientIP, clientPort, out, in)
 					}
+				case "validatepurchaseapple":
+					afterReqFunctions.afterValidatePurchaseAppleFunction = func(ctx context.Context, logger *zap.Logger, userID, username string, vars map[string]string, expiry int64, clientIP, clientPort string, out *api.ValidatePurchaseResponse, in *api.ValidatePurchaseAppleRequest) error {
+						return runtimeProviderLua.AfterReq(ctx, id, logger, userID, username, vars, expiry, clientIP, clientPort, out, in)
+					}
+				case "validatepurchasegoogle":
+					afterReqFunctions.afterValidatePurchaseGoogleFunction = func(ctx context.Context, logger *zap.Logger, userID, username string, vars map[string]string, expiry int64, clientIP, clientPort string, out *api.ValidatePurchaseResponse, in *api.ValidatePurchaseGoogleRequest) error {
+						return runtimeProviderLua.AfterReq(ctx, id, logger, userID, username, vars, expiry, clientIP, clientPort, out, in)
+					}
+				case "validatepurchasehuawei":
+					afterReqFunctions.afterValidatePurchaseHuaweiFunction = func(ctx context.Context, logger *zap.Logger, userID, username string, vars map[string]string, expiry int64, clientIP, clientPort string, out *api.ValidatePurchaseResponse, in *api.ValidatePurchaseHuaweiRequest) error {
+						return runtimeProviderLua.AfterReq(ctx, id, logger, userID, username, vars, expiry, clientIP, clientPort, out, in)
+					}
 				case "event":
 					afterReqFunctions.afterEventFunction = func(ctx context.Context, logger *zap.Logger, userID, username string, vars map[string]string, expiry int64, clientIP, clientPort string, in *api.Event) error {
 						return runtimeProviderLua.AfterReq(ctx, id, logger, userID, username, vars, expiry, clientIP, clientPort, nil, in)
-- 
GitLab