Commit 5ee5f485 authored by Mo Firouz's avatar Mo Firouz Committed by Andrei Mihu
Browse files

Add Before/After hooks. (#179)

Merge config file runtime env with values passed via command line.
parent 3afb6cd2
Loading
Loading
Loading
Loading
+10 −0
Original line number Diff line number Diff line
@@ -15,6 +15,7 @@
--]]

local nk = require("nakama")
local du = require("debug_utils")

--[[
  Test RPC function calls from client libraries.
@@ -69,3 +70,12 @@ local function create_authoritative_match(_context, _payload)
  return nk.json_encode({ match_id = match_id })
end
nk.register_rpc(create_authoritative_match, "clientrpc.create_authoritative_match")

local function print_env(context, _)
  print("env:\n" .. du.print_r(context.env))
  local response = {
    message = context.env
  }
  return nk.json_encode(response)
end
nk.register_rpc(print_env, "clientrpc.print_env")
+3 −3
Original line number Diff line number Diff line
@@ -102,12 +102,12 @@ func main() {
	tracker.SetMatchLeaveListener(matchRegistry.Leave)
	// Separate module evaluation/validation from module loading.
	// We need the match registry to be available to wire all functions exposed to the runtime, which in turn needs the modules at least cached first.
	regRPC, err := server.ValidateRuntimeModules(jsonLogger, multiLogger, db, config, socialClient, sessionRegistry, matchRegistry, tracker, router, stdLibs, modules, once)
	regCallbacks, err := server.ValidateRuntimeModules(jsonLogger, multiLogger, db, config, socialClient, sessionRegistry, matchRegistry, tracker, router, stdLibs, modules, once)
	if err != nil {
		multiLogger.Fatal("Failed initializing runtime modules", zap.Error(err))
	}
	runtimePool := server.NewRuntimePool(jsonLogger, multiLogger, db, config, socialClient, sessionRegistry, matchRegistry, tracker, router, stdLibs, modules, regRPC, once)
	pipeline := server.NewPipeline(config, db, sessionRegistry, matchRegistry, tracker, router, runtimePool)
	runtimePool := server.NewRuntimePool(jsonLogger, multiLogger, db, config, socialClient, sessionRegistry, matchRegistry, tracker, router, stdLibs, modules, regCallbacks, once)
	pipeline := server.NewPipeline(config, db, jsonpbMarshaler, jsonpbUnmarshaler, sessionRegistry, matchRegistry, tracker, router, runtimePool)
	metrics := server.NewMetrics(multiLogger, config)
	apiServer := server.StartApiServer(jsonLogger, multiLogger, db, jsonpbMarshaler, jsonpbUnmarshaler, config, socialClient, sessionRegistry, matchRegistry, tracker, router, pipeline, runtimePool)

+150 −109
Original line number Diff line number Diff line
@@ -61,11 +61,11 @@ type ApiServer struct {
	grpcGatewayServer *http.Server
}

func StartApiServer(logger *zap.Logger, multiLogger *zap.Logger, db *sql.DB, jsonpbMarshaler *jsonpb.Marshaler, jsonpbUnmarshaler *jsonpb.Unmarshaler, config Config, socialClient *social.Client, sessionRegistry *SessionRegistry, matchRegistry MatchRegistry, tracker Tracker, router MessageRouter, pipeline *pipeline, runtimePool *RuntimePool) *ApiServer {
func StartApiServer(logger *zap.Logger, multiLogger *zap.Logger, db *sql.DB, jsonpbMarshaler *jsonpb.Marshaler, jsonpbUnmarshaler *jsonpb.Unmarshaler, config Config, socialClient *social.Client, sessionRegistry *SessionRegistry, matchRegistry MatchRegistry, tracker Tracker, router MessageRouter, pipeline *Pipeline, runtimePool *RuntimePool) *ApiServer {
	grpcServer := grpc.NewServer(
		grpc.StatsHandler(&ocgrpc.ServerHandler{IsPublicEndpoint: true}),
		grpc.MaxRecvMsgSize(int(config.GetSocket().MaxMessageSizeBytes)),
		grpc.UnaryInterceptor(SecurityInterceptorFunc(logger, config)),
		grpc.UnaryInterceptor(interceptorFunc(logger, config, runtimePool, jsonpbMarshaler, jsonpbUnmarshaler)),
	)

	s := &ApiServer{
@@ -132,7 +132,7 @@ func StartApiServer(logger *zap.Logger, multiLogger *zap.Logger, db *sql.DB, jso
		Addr:         fmt.Sprintf(":%d", config.GetSocket().Port-1),
		ReadTimeout:  time.Millisecond * time.Duration(int64(config.GetSocket().ReadTimeoutMs)),
		WriteTimeout: time.Millisecond * time.Duration(int64(config.GetSocket().WriteTimeoutMs)),
		IdleTimeout:  time.Millisecond * time.Duration(int64(config.GetSocket().IdeaTimeoutMs)),
		IdleTimeout:  time.Millisecond * time.Duration(int64(config.GetSocket().IdleTimeoutMs)),
		Handler:      handlerWithCORS,
	}

@@ -159,13 +159,55 @@ func (s *ApiServer) Healthcheck(ctx context.Context, in *empty.Empty) (*empty.Em
	return &empty.Empty{}, nil
}

func SecurityInterceptorFunc(logger *zap.Logger, config Config) func(context.Context, interface{}, *grpc.UnaryServerInfo, grpc.UnaryHandler) (interface{}, error) {
func interceptorFunc(logger *zap.Logger, config Config, runtimePool *RuntimePool, jsonpbMarshaler *jsonpb.Marshaler, jsonpbUnmarshaler *jsonpb.Unmarshaler) func(context.Context, interface{}, *grpc.UnaryServerInfo, grpc.UnaryHandler) (interface{}, error) {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		//logger.Debug("Security interceptor fired", zap.Any("ctx", ctx), zap.Any("req", req), zap.Any("info", info))
		ctx, err := securityInterceptorFunc(logger, config, ctx, req, info)
		if err != nil {
			return nil, err
		}

		switch info.FullMethod {
		case "/nakama.api.Nakama/Healthcheck":
			// Healthcheck has no security.
			fallthrough
		case "/nakama.api.Nakama/RpcFunc":
			return handler(ctx, req)
		}

		uid := uuid.Nil
		username := ""
		expiry := int64(0)
		if ctx.Value(ctxUserIDKey{}) != nil {
			// incase of authentication methods, uid is nil
			uid = ctx.Value(ctxUserIDKey{}).(uuid.UUID)
			username = ctx.Value(ctxUsernameKey{}).(string)
			expiry = ctx.Value(ctxExpiryKey{}).(int64)
		}

		beforeHookResult, hookErr := invokeReqBeforeHook(logger, config, runtimePool, jsonpbMarshaler, jsonpbUnmarshaler, "", uid, username, expiry, info.FullMethod, req)
		if hookErr != nil {
			return nil, hookErr
		} else if beforeHookResult == nil {
			// if result is nil, requested resource is disabled.
			logger.Warn("Intercepted a disabled resource.",
				zap.String("resource", info.FullMethod),
				zap.String("uid", uid.String()),
				zap.String("username", username))
			return nil, status.Error(codes.NotFound, "Requested resource was not found.")
		}

		handlerResult, handlerErr := handler(ctx, beforeHookResult)
		if handlerErr == nil {
			invokeReqAfterHook(logger, config, runtimePool, jsonpbMarshaler, "", uid, username, expiry, info.FullMethod, handlerResult)
		}
		return handlerResult, handlerErr
	}
}

func securityInterceptorFunc(logger *zap.Logger, config Config, ctx context.Context, req interface{}, info *grpc.UnaryServerInfo) (context.Context, error) {
	switch info.FullMethod {
	case "/nakama.api.Nakama/Healthcheck":
		// Healthcheck has no security.
		return nil, nil
	case "/nakama.api.Nakama/AuthenticateCustom":
		fallthrough
	case "/nakama.api.Nakama/AuthenticateDevice":
@@ -197,7 +239,7 @@ func SecurityInterceptorFunc(logger *zap.Logger, config Config) func(context.Con
			// Value of "authorization" or "grpc-authorization" was empty or repeated.
			return nil, status.Error(codes.Unauthenticated, "Server key required")
		}
			username, _, ok := ParseBasicAuth(auth[0])
		username, _, ok := parseBasicAuth(auth[0])
		if !ok {
			// Value of "authorization" or "grpc-authorization" was malformed.
			return nil, status.Error(codes.Unauthenticated, "Server key invalid")
@@ -232,13 +274,13 @@ func SecurityInterceptorFunc(logger *zap.Logger, config Config) func(context.Con
				// Value of HTTP key username component did not match.
				return nil, status.Error(codes.Unauthenticated, "HTTP key invalid")
			}
				return handler(ctx, req)
			return ctx, nil
		}
		if len(auth) != 1 {
			// Value of "authorization" or "grpc-authorization" was empty or repeated.
			return nil, status.Error(codes.Unauthenticated, "Auth token invalid")
		}
			userID, username, exp, ok := ParseBearerAuth([]byte(config.GetSession().EncryptionKey), auth[0])
		userID, username, exp, ok := parseBearerAuth([]byte(config.GetSession().EncryptionKey), auth[0])
		if !ok {
			// Value of "authorization" or "grpc-authorization" was malformed or expired.
			return nil, status.Error(codes.Unauthenticated, "Auth token invalid")
@@ -263,18 +305,17 @@ func SecurityInterceptorFunc(logger *zap.Logger, config Config) func(context.Con
			// Value of "authorization" or "grpc-authorization" was empty or repeated.
			return nil, status.Error(codes.Unauthenticated, "Auth token invalid")
		}
			userID, username, exp, ok := ParseBearerAuth([]byte(config.GetSession().EncryptionKey), auth[0])
		userID, username, exp, ok := parseBearerAuth([]byte(config.GetSession().EncryptionKey), auth[0])
		if !ok {
			// Value of "authorization" or "grpc-authorization" was malformed or expired.
			return nil, status.Error(codes.Unauthenticated, "Auth token invalid")
		}
		ctx = context.WithValue(context.WithValue(context.WithValue(ctx, ctxUserIDKey{}, userID), ctxUsernameKey{}, username), ctxExpiryKey{}, exp)
	}
		return handler(ctx, req)
	}
	return ctx, nil
}

func ParseBasicAuth(auth string) (username, password string, ok bool) {
func parseBasicAuth(auth string) (username, password string, ok bool) {
	if auth == "" {
		return
	}
@@ -294,7 +335,7 @@ func ParseBasicAuth(auth string) (username, password string, ok bool) {
	return cs[:s], cs[s+1:], true
}

func ParseBearerAuth(hmacSecretByte []byte, auth string) (userID uuid.UUID, username string, exp int64, ok bool) {
func parseBearerAuth(hmacSecretByte []byte, auth string) (userID uuid.UUID, username string, exp int64, ok bool) {
	if auth == "" {
		return
	}
@@ -302,10 +343,10 @@ func ParseBearerAuth(hmacSecretByte []byte, auth string) (userID uuid.UUID, user
	if !strings.HasPrefix(auth, prefix) {
		return
	}
	return ParseToken(hmacSecretByte, string(auth[len(prefix):]))
	return parseToken(hmacSecretByte, string(auth[len(prefix):]))
}

func ParseToken(hmacSecretByte []byte, tokenString string) (userID uuid.UUID, username string, exp int64, ok bool) {
func parseToken(hmacSecretByte []byte, tokenString string) (userID uuid.UUID, username string, exp int64, ok bool) {
	token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
		if s, ok := token.Method.(*jwt.SigningMethodHMAC); !ok || s.Hash != crypto.SHA256 {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
+14 −4
Original line number Diff line number Diff line
@@ -33,7 +33,7 @@ func (s *ApiServer) RpcFunc(ctx context.Context, in *api.Rpc) (*api.Rpc, error)

	id := strings.ToLower(in.Id)

	if !s.runtimePool.HasRPC(id) {
	if !s.runtimePool.HasCallback(RPC, id) {
		return nil, status.Error(codes.NotFound, "RPC function not found")
	}

@@ -51,14 +51,15 @@ func (s *ApiServer) RpcFunc(ctx context.Context, in *api.Rpc) (*api.Rpc, error)
	}

	runtime := s.runtimePool.Get()
	lf := runtime.GetRuntimeCallback(RPC, id)
	lf := runtime.GetCallback(RPC, id)
	if lf == nil {
		s.runtimePool.Put(runtime)
		return nil, status.Error(codes.NotFound, "RPC function not found")
	}

	result, fnErr, code := runtime.InvokeFunctionRPC(lf, uid, username, expiry, "", in.Payload)
	result, fnErr, code := runtime.InvokeFunction(RPC, lf, uid, username, expiry, "", in.Payload)
	s.runtimePool.Put(runtime)

	if fnErr != nil {
		s.logger.Error("Runtime RPC function caused an error", zap.String("id", in.Id), zap.Error(fnErr))
		if apiErr, ok := fnErr.(*lua.ApiError); ok && !s.config.GetLog().Verbose {
@@ -78,5 +79,14 @@ func (s *ApiServer) RpcFunc(ctx context.Context, in *api.Rpc) (*api.Rpc, error)
		}
	}

	return &api.Rpc{Payload: result}, nil
	if result == nil {
		return &api.Rpc{}, nil
	}

	if payload, ok := result.(string); !ok {
		s.logger.Warn("Runtime function returned invalid data", zap.Any("result", result))
		return nil, status.Error(codes.Internal, "Runtime function returned invalid data - only allowed one return value of type String/Byte.")
	} else {
		return &api.Rpc{Payload: payload}, nil
	}
}
+35 −9
Original line number Diff line number Diff line
@@ -72,6 +72,7 @@ func ParseArgs(logger *zap.Logger, args []string) Config {
			}
		}
	}
	runtimeEnvironment := convertRuntimeEnv(logger, mainConfig.GetRuntime().Environment, mainConfig.GetRuntime().Env)

	// Override config with those passed from command-line.
	mainFlagSet := flag.NewFlagSet("nakama", flag.ExitOnError)
@@ -99,6 +100,8 @@ func ParseArgs(logger *zap.Logger, args []string) Config {
		mainConfig.GetRuntime().Path = filepath.Join(mainConfig.GetDataDir(), "modules")
	}

	mainConfig.GetRuntime().Environment = convertRuntimeEnv(logger, runtimeEnvironment, mainConfig.GetRuntime().Env)

	// Log warnings for insecure default parameter values.
	if mainConfig.GetSocket().ServerKey == "defaultkey" {
		logger.Warn("WARNING: insecure default parameter value, change this for production!", zap.String("param", "socket.server_key"))
@@ -113,6 +116,27 @@ func ParseArgs(logger *zap.Logger, args []string) Config {
	return mainConfig
}

func convertRuntimeEnv(logger *zap.Logger, existingEnv map[string]interface{}, mergeEnv []string) map[string]interface{} {
	envMap := make(map[string]interface{}, len(existingEnv))
	for k, v := range existingEnv {
		envMap[k] = v
	}

	for _, e := range mergeEnv {
		if !strings.Contains(e, "=") {
			logger.Fatal("Invalid runtime environment value.", zap.String("value", e))
		}

		kv := strings.SplitN(e, "=", 2) // the value can contain the character "=" many times over.
		if len(kv) == 1 {
			envMap[kv[0]] = ""
		} else if len(kv) == 2 {
			envMap[kv[0]] = kv[1]
		}
	}
	return envMap
}

type config struct {
	Name     string          `yaml:"name" json:"name" usage:"Nakama server’s node name - must be unique."`
	Config   string          `yaml:"config" json:"config" usage:"The absolute file path to configuration YAML file."`
@@ -201,8 +225,8 @@ func NewLogConfig() *LogConfig {

// MetricsConfig is configuration relevant to metrics capturing and output.
type MetricsConfig struct {
	ReportingFreqSec     int    `yaml:"reporting_freq_sec" json:"reporting_freq_sec" usage:"Frequency of metrics exports. Default is 1 second."`
	StackdriverProjectID string `yaml:"stackdriver_projectid" json:"stackdriver_projectid" usage:"This is the identifier of the Stackdriver project the server is uploading the stats data to. Setting this enabled metrics to be exported to Stackdriver."`
	ReportingFreqSec     int    `yaml:"reporting_freq_sec" json:"reporting_freq_sec" usage:"Frequency of metrics exports. Default is 10 seconds."`
	StackdriverProjectID string `yaml:"stackdriver_projectid" json:"stackdriver_projectid" usage:"This is the identifier of the Stackdriver project the server is uploading the stats data to. Setting this enables metrics to be exported to Stackdriver."`
	Namespace            string `yaml:"namespace" json:"namespace" usage:"Namespace for Prometheus or prefix for Stackdriver metrics. It will always prepend node name."`
	PrometheusPort       int    `yaml:"prometheus_port" json:"prometheus_port" usage:"Port to expose Prometheus. If '0' Prometheus exports are disabled."`
}
@@ -210,7 +234,7 @@ type MetricsConfig struct {
// NewMetricsConfig creates a new MatricsConfig struct.
func NewMetricsConfig() *MetricsConfig {
	return &MetricsConfig{
		ReportingFreqSec:     1,
		ReportingFreqSec:     10,
		StackdriverProjectID: "",
		Namespace:            "",
		PrometheusPort:       0,
@@ -238,7 +262,7 @@ type SocketConfig struct {
	MaxMessageSizeBytes int64  `yaml:"max_message_size_bytes" json:"max_message_size_bytes" usage:"Maximum amount of data in bytes allowed to be read from the client socket per message. Used for real-time, gRPC and HTTP connections."`
	ReadTimeoutMs       int    `yaml:"read_timeout_ms" json:"read_timeout_ms" usage:"Maximum duration in milliseconds for reading the entire request. Used for HTTP connections."`
	WriteTimeoutMs      int    `yaml:"write_timeout_ms" json:"write_timeout_ms" usage:"Maximum duration in milliseconds before timing out writes of the response. Used for HTTP connections."`
	IdeaTimeoutMs       int    `yaml:"idle_timeout_ms" json:"idle_timeout_ms" usage:"Maximum amount of time in milliseconds to wait for the next request when keep-alives are enabled. Used for HTTP connections."`
	IdleTimeoutMs       int    `yaml:"idle_timeout_ms" json:"idle_timeout_ms" usage:"Maximum amount of time in milliseconds to wait for the next request when keep-alives are enabled. Used for HTTP connections."`
	WriteWaitMs         int    `yaml:"write_wait_ms" json:"write_wait_ms" usage:"Time in milliseconds to wait for an ack from the client when writing data. Used for real-time connections."`
	PongWaitMs          int    `yaml:"pong_wait_ms" json:"pong_wait_ms" usage:"Time in milliseconds to wait between pong messages received from the client. Used for real-time connections."`
	PingPeriodMs        int    `yaml:"ping_period_ms" json:"ping_period_ms" usage:"Time in milliseconds to wait between sending ping messages to the client. This value must be less than the pong_wait_ms. Used for real-time connections."`
@@ -255,7 +279,7 @@ func NewSocketConfig() *SocketConfig {
		MaxMessageSizeBytes: 2048,
		ReadTimeoutMs:       10 * 1000,
		WriteTimeoutMs:      10 * 1000,
		IdeaTimeoutMs:       60 * 1000,
		IdleTimeoutMs:       60 * 1000,
		WriteWaitMs:         5000,
		PongWaitMs:          10000,
		PingPeriodMs:        8000,
@@ -306,7 +330,8 @@ func NewSocialConfig() *SocialConfig {

// RuntimeConfig is configuration relevant to the Runtime Lua VM.
type RuntimeConfig struct {
	Environment map[string]interface{} `yaml:"env" json:"env"` // Not supported in FlagOverrides.
	Environment map[string]interface{}
	Env         []string `yaml:"env" json:"env"`
	Path        string   `yaml:"path" json:"path" usage:"Path for the server to scan for *.lua files."`
	HTTPKey     string   `yaml:"http_key" json:"http_key" usage:"Runtime HTTP Invocation key."`
}
@@ -314,7 +339,8 @@ type RuntimeConfig struct {
// NewRuntimeConfig creates a new RuntimeConfig struct.
func NewRuntimeConfig() *RuntimeConfig {
	return &RuntimeConfig{
		Environment: make(map[string]interface{}),
		Environment: make(map[string]interface{}, 0),
		Env:         make([]string, 0),
		Path:        "",
		HTTPKey:     "defaultkey",
	}
Loading