Commit 20d6ab3c authored by Andrei Mihu's avatar Andrei Mihu
Browse files

Allow RPC functions to receive and return raw JSON data.

parent 16bf2a02
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -4,6 +4,9 @@ All notable changes to this project are documented below.
The format is based on [keep a changelog](http://keepachangelog.com) and this project uses [semantic versioning](http://semver.org).

## [Unreleased]
### Added
- Allow RPC functions to receive and return raw JSON data.

### Changed
- Update devconsole lodash (4.17.13) and lodash.template (4.5.0) dependencies.

+6 −1
Original line number Diff line number Diff line
@@ -179,11 +179,16 @@ func StartApiServer(logger *zap.Logger, startupLogger *zap.Logger, db *sql.DB, j
	grpcGatewayRouter.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) }).Methods("GET")
	grpcGatewayRouter.HandleFunc("/ws", NewSocketWsAcceptor(logger, config, sessionRegistry, matchmaker, tracker, runtime, jsonpbMarshaler, jsonpbUnmarshaler, pipeline)).Methods("GET")

	// Another nested router to hijack RPC requests bound for GRPC Gateway.
	grpcGatewayMux := mux.NewRouter()
	grpcGatewayMux.HandleFunc("/v2/rpc/{id:.*}", s.RpcFuncHttp).Methods("GET", "POST")
	grpcGatewayMux.NewRoute().Handler(grpcGateway)

	// Enable stats recording on all request paths except:
	// "/" is not tracked at all.
	// "/ws" implements its own separate tracking.
	handlerWithStats := &ochttp.Handler{
		Handler:          grpcGateway,
		Handler:          grpcGatewayMux,
		IsPublicEndpoint: true,
	}

+171 −0
Original line number Diff line number Diff line
@@ -15,16 +15,187 @@
package server

import (
	"encoding/json"
	"io/ioutil"
	"net/http"
	"strings"

	"github.com/gofrs/uuid"
	"github.com/gorilla/mux"
	"github.com/grpc-ecosystem/grpc-gateway/runtime"
	"github.com/heroiclabs/nakama/api"
	"go.uber.org/zap"
	"golang.org/x/net/context"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
)

var (
	authTokenInvalidBytes    = []byte(`{"error":"Auth token invalid","message":"Auth token invalid","code":16}`)
	httpKeyInvalidBytes      = []byte(`{"error":"HTTP key invalid","message":"HTTP key invalid","code":16}`)
	noAuthBytes              = []byte(`{"error":"Auth token or HTTP key required","message":"Auth token or HTTP key required","code":16}`)
	rpcIdMustBeSetBytes      = []byte(`{"error":"RPC ID must be set","message":"RPC ID must be set","code":3}`)
	rpcFunctionNotFoundBytes = []byte(`{"error":"RPC function not found","message":"RPC function not found","code":5}`)
	internalServerErrorBytes = []byte(`{"error":"Internal Server Error","message":"Internal Server Error","code":13}`)
	badJsonBytes             = []byte(`{"error":"json: cannot unmarshal object into Go value of type string","message":"json: cannot unmarshal object into Go value of type string","code":3}`)
)

func (s *ApiServer) RpcFuncHttp(w http.ResponseWriter, r *http.Request) {
	// Check first token then HTTP key for authentication, and add user info to the context.
	queryParams := r.URL.Query()
	var tokenAuth bool
	var userID uuid.UUID
	var username string
	var expiry int64
	if auth := r.Header["Authorization"]; len(auth) >= 1 {
		userID, username, expiry, tokenAuth = parseBearerAuth([]byte(s.config.GetSession().EncryptionKey), auth[0])
		if !tokenAuth {
			// Auth token not valid or expired.
			w.WriteHeader(http.StatusUnauthorized)
			w.Header().Set("content-type", "application/json")
			_, err := w.Write(authTokenInvalidBytes)
			if err != nil {
				s.logger.Debug("Error writing response to client", zap.Error(err))
			}
			return
		}
	} else if httpKey := queryParams.Get("http_key"); httpKey != "" {
		if httpKey != s.config.GetRuntime().HTTPKey {
			// HTTP key did not match.
			w.WriteHeader(http.StatusUnauthorized)
			w.Header().Set("content-type", "application/json")
			_, err := w.Write(httpKeyInvalidBytes)
			if err != nil {
				s.logger.Debug("Error writing response to client", zap.Error(err))
			}
			return
		}
	} else {
		// No authentication present.
		w.WriteHeader(http.StatusUnauthorized)
		w.Header().Set("content-type", "application/json")
		_, err := w.Write(noAuthBytes)
		if err != nil {
			s.logger.Debug("Error writing response to client", zap.Error(err))
		}
		return
	}

	// Check the RPC function ID.
	maybeId, ok := mux.Vars(r)["id"]
	if !ok || maybeId == "" {
		// Missing RPC function ID.
		w.WriteHeader(http.StatusBadRequest)
		w.Header().Set("content-type", "application/json")
		_, err := w.Write(rpcIdMustBeSetBytes)
		if err != nil {
			s.logger.Debug("Error writing response to client", zap.Error(err))
		}
		return
	}
	id := strings.ToLower(maybeId)

	// Find the correct RPC function.
	fn := s.runtime.Rpc(id)
	if fn == nil {
		// No function registered for this ID.
		w.WriteHeader(http.StatusNotFound)
		w.Header().Set("content-type", "application/json")
		_, err := w.Write(rpcFunctionNotFoundBytes)
		if err != nil {
			s.logger.Debug("Error writing response to client", zap.Error(err))
		}
		return
	}

	// Check if we need to mimic existing GRPC Gateway behaviour or expect to receive/send unwrapped data.
	// Any value for this query parameter, including the parameter existing with an empty value, will
	// indicate that raw behaviour is expected.
	_, unwrap := queryParams["unwrap"]

	// Prepare input to function.
	var payload string
	if r.Method == "POST" {
		b, err := ioutil.ReadAll(r.Body)
		if err != nil {
			// Error reading request body.
			w.WriteHeader(http.StatusInternalServerError)
			w.Header().Set("content-type", "application/json")
			_, err := w.Write(internalServerErrorBytes)
			if err != nil {
				s.logger.Debug("Error writing response to client", zap.Error(err))
			}
			return
		}

		// Maybe attempt to decode to a JSON string to mimic existing GRPC Gateway behaviour.
		if !unwrap {
			err = json.Unmarshal(b, &payload)
			if err != nil {
				w.WriteHeader(http.StatusBadRequest)
				w.Header().Set("content-type", "application/json")
				_, err := w.Write(badJsonBytes)
				if err != nil {
					s.logger.Debug("Error writing response to client", zap.Error(err))
				}
				return
			}
		} else {
			payload = string(b)
		}
	}

	queryParams.Del("http_key")

	uid := ""
	if tokenAuth {
		uid = userID.String()
	}

	clientIP, clientPort := extractClientAddressFromRequest(s.logger, r)

	// Execute the function.
	result, fnErr, code := fn(r.Context(), queryParams, uid, username, expiry, "", clientIP, clientPort, payload)
	if fnErr != nil {
		response, _ := json.Marshal(map[string]interface{}{"error": fnErr, "message": fnErr.Error(), "code": code})
		w.WriteHeader(runtime.HTTPStatusFromCode(code))
		w.Header().Set("content-type", "application/json")
		_, err := w.Write(response)
		if err != nil {
			s.logger.Debug("Error writing response to client", zap.Error(err))
		}
		return
	}

	// Return the successful result.
	var response []byte
	if !unwrap {
		// GRPC Gateway equivalent behaviour.
		var err error
		response, err = json.Marshal(map[string]interface{}{"payload": result})
		if err != nil {
			// Failed to encode the wrapped response.
			s.logger.Error("Error marshaling wrapped response to client", zap.Error(err))
			w.WriteHeader(http.StatusInternalServerError)
			w.Header().Set("content-type", "application/json")
			_, err := w.Write(internalServerErrorBytes)
			if err != nil {
				s.logger.Debug("Error writing response to client", zap.Error(err))
			}
			return
		}
	} else {
		// "Unwrapped" response.
		response = []byte(result)
	}
	w.WriteHeader(http.StatusOK)
	_, err := w.Write(response)
	if err != nil {
		s.logger.Debug("Error writing response to client", zap.Error(err))
	}
}

func (s *ApiServer) RpcFunc(ctx context.Context, in *api.Rpc) (*api.Rpc, error) {
	if in.Id == "" {
		return nil, status.Error(codes.InvalidArgument, "RPC ID must be set")