Commit f3d38570 authored by Mo Firouz's avatar Mo Firouz
Browse files

Thread through client IP and port to Lua context. #226

parent 066cd7e0
Loading
Loading
Loading
Loading
+26 −2
Original line number Diff line number Diff line
@@ -28,6 +28,7 @@ import (

	"compress/flate"
	"compress/gzip"

	"github.com/dgrijalva/jwt-go"
	"github.com/gofrs/uuid"
	"github.com/golang/protobuf/jsonpb"
@@ -49,6 +50,7 @@ import (
	"google.golang.org/grpc/credentials"
	_ "google.golang.org/grpc/encoding/gzip" // enable gzip compression on server for grpc
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/peer"
	"google.golang.org/grpc/status"
)

@@ -256,8 +258,30 @@ func apiInterceptorFunc(logger *zap.Logger, config Config, runtimePool *RuntimeP
		startNanos := time.Now().UTC().UnixNano()
		span := trace.NewSpan(name, nil, trace.StartOptions{})

		clientAddr := ""
		clientIP := ""
		clientPort := ""
		md, _ := metadata.FromIncomingContext(ctx)
		if ips := md.Get("x-forwarded-for"); len(ips) > 0 {
			// look for gRPC-Gateway / LB header
			clientAddr = strings.Split(ips[0], ",")[0]
		} else if peerInfo, ok := peer.FromContext(ctx); ok {
			// if missing, try to look up gRPC peer info
			clientAddr = peerInfo.Addr.String()
		}

		clientAddr = strings.TrimSpace(clientAddr)
		if host, port, err := net.SplitHostPort(clientAddr); err == nil {
			clientIP = host
			clientPort = port
		} else if addrErr, ok := err.(*net.AddrError); ok && addrErr.Err == "missing port in address" {
			clientIP = clientAddr
		} else {
			logger.Debug("Could not extract client address from request.", zap.Error(err))
		}

		// Actual before hook function execution.
		beforeHookResult, hookErr := invokeReqBeforeHook(logger, config, runtimePool, jsonpbMarshaler, jsonpbUnmarshaler, "", uid, username, expiry, info.FullMethod, req)
		beforeHookResult, hookErr := invokeReqBeforeHook(logger, config, runtimePool, jsonpbMarshaler, jsonpbUnmarshaler, "", uid, username, expiry, clientIP, clientPort, info.FullMethod, req)

		// Stats measurement end boundary.
		span.End()
@@ -283,7 +307,7 @@ func apiInterceptorFunc(logger *zap.Logger, config Config, runtimePool *RuntimeP
			span := trace.NewSpan(name, nil, trace.StartOptions{})

			// Actual after hook function execution.
			invokeReqAfterHook(logger, config, runtimePool, jsonpbMarshaler, "", uid, username, expiry, info.FullMethod, handlerResult)
			invokeReqAfterHook(logger, config, runtimePool, jsonpbMarshaler, "", uid, username, expiry, info.FullMethod, clientIP, clientPort, handlerResult)

			// Stats measurement end boundary.
			span.End()
+25 −1
Original line number Diff line number Diff line
@@ -15,6 +15,7 @@
package server

import (
	"net"
	"strings"

	"github.com/gofrs/uuid"
@@ -25,6 +26,7 @@ import (
	"golang.org/x/net/context"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/peer"
	"google.golang.org/grpc/status"
)

@@ -71,7 +73,29 @@ func (s *ApiServer) RpcFunc(ctx context.Context, in *api.Rpc) (*api.Rpc, error)
		return nil, status.Error(codes.NotFound, "RPC function not found")
	}

	result, fnErr, code := runtime.InvokeFunction(ExecutionModeRPC, lf, queryParams, uid, username, expiry, "", in.Payload)
	clientAddr := ""
	clientIP := ""
	clientPort := ""
	md, _ := metadata.FromIncomingContext(ctx)
	if ips := md.Get("x-forwarded-for"); len(ips) > 0 {
		// look for gRPC-Gateway / LB header
		clientAddr = strings.Split(ips[0], ",")[0]
	} else if peerInfo, ok := peer.FromContext(ctx); ok {
		// if missing, try to look up gRPC peer info
		clientAddr = peerInfo.Addr.String()
	}

	clientAddr = strings.TrimSpace(clientAddr)
	if host, port, err := net.SplitHostPort(clientAddr); err == nil {
		clientIP = host
		clientPort = port
	} else if addrErr, ok := err.(*net.AddrError); ok && addrErr.Err == "missing port in address" {
		clientIP = clientAddr
	} else {
		s.logger.Debug("Could not extract client address from request.", zap.Error(err))
	}

	result, fnErr, code := runtime.InvokeFunction(ExecutionModeRPC, lf, queryParams, uid, username, expiry, "", clientIP, clientPort, in.Payload)
	s.runtimePool.Put(runtime)

	if fnErr != nil {
+5 −4
Original line number Diff line number Diff line
@@ -19,14 +19,15 @@ import (
	"fmt"

	"context"
	"strings"
	"time"

	"github.com/golang/protobuf/jsonpb"
	"github.com/heroiclabs/nakama/rtapi"
	"go.opencensus.io/stats"
	"go.opencensus.io/tag"
	"go.opencensus.io/trace"
	"go.uber.org/zap"
	"strings"
	"time"
)

type Pipeline struct {
@@ -147,7 +148,7 @@ func (p *Pipeline) ProcessRequest(logger *zap.Logger, session Session, envelope
		span := trace.NewSpan(name, nil, trace.StartOptions{})

		// Actual before hook function execution.
		hookResult, hookErr := invokeReqBeforeHook(logger, p.config, p.runtimePool, p.jsonpbMarshaler, p.jsonpbUnmarshaler, session.ID().String(), session.UserID(), session.Username(), session.Expiry(), messageName, envelope)
		hookResult, hookErr := invokeReqBeforeHook(logger, p.config, p.runtimePool, p.jsonpbMarshaler, p.jsonpbUnmarshaler, session.ID().String(), session.UserID(), session.Username(), session.Expiry(), session.ClientIP(), session.ClientPort(), messageName, envelope)

		// Stats measurement end boundary.
		span.End()
@@ -202,7 +203,7 @@ func (p *Pipeline) ProcessRequest(logger *zap.Logger, session Session, envelope
		span := trace.NewSpan(name, nil, trace.StartOptions{})

		// Actual after hook function execution.
		invokeReqAfterHook(logger, p.config, p.runtimePool, p.jsonpbMarshaler, session.ID().String(), session.UserID(), session.Username(), session.Expiry(), messageName, envelope)
		invokeReqAfterHook(logger, p.config, p.runtimePool, p.jsonpbMarshaler, session.ID().String(), session.UserID(), session.Username(), session.Expiry(), session.ClientIP(), session.ClientPort(), messageName, envelope)

		// Stats measurement end boundary.
		span.End()
+1 −1
Original line number Diff line number Diff line
@@ -55,7 +55,7 @@ func (p *Pipeline) rpc(logger *zap.Logger, session Session, envelope *rtapi.Enve
		return
	}

	result, fnErr, _ := runtime.InvokeFunction(ExecutionModeRPC, lf, nil, session.UserID().String(), session.Username(), session.Expiry(), session.ID().String(), rpcMessage.Payload)
	result, fnErr, _ := runtime.InvokeFunction(ExecutionModeRPC, lf, nil, session.UserID().String(), session.Username(), session.Expiry(), session.ID().String(), session.ClientIP(), session.ClientPort(), rpcMessage.Payload)
	p.runtimePool.Put(runtime)
	if fnErr != nil {
		logger.Error("Runtime RPC function caused an error", zap.String("id", rpcMessage.Id), zap.Error(fnErr))
+2 −2
Original line number Diff line number Diff line
@@ -270,8 +270,8 @@ func (r *Runtime) GetCallback(e ExecutionMode, key string) *lua.LFunction {
	return nil
}

func (r *Runtime) InvokeFunction(execMode ExecutionMode, fn *lua.LFunction, queryParams map[string][]string, uid string, username string, sessionExpiry int64, sid string, payload interface{}) (interface{}, error, codes.Code) {
	ctx := NewLuaContext(r.vm, r.luaEnv, execMode, queryParams, uid, username, sessionExpiry, sid)
func (r *Runtime) InvokeFunction(execMode ExecutionMode, fn *lua.LFunction, queryParams map[string][]string, uid string, username string, sessionExpiry int64, sid string, clientIP string, clientPort string, payload interface{}) (interface{}, error, codes.Code) {
	ctx := NewLuaContext(r.vm, r.luaEnv, execMode, queryParams, sessionExpiry, uid, username, sid, clientIP, clientPort)
	var lv lua.LValue
	if payload != nil {
		lv = ConvertValue(r.vm, payload)
Loading