Commit b44817ae authored by Andrei Mihu's avatar Andrei Mihu
Browse files

Improve parsing of client IP and port.

parent 5b1adece
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -6,6 +6,7 @@ The format is based on [keep a changelog](http://keepachangelog.com) and this pr
## [Unreleased]
### Changed
- Default maximum database connection lifetime is now 1 hour.
- Improved parsing of client address and port for incoming requests and socket connections.

### Fixed
- CRON expressions for leaderboard and tournament resets now allow concurrent processing.
+38 −14
Original line number Diff line number Diff line
@@ -432,10 +432,8 @@ func decompressHandler(logger *zap.Logger, h http.Handler) http.HandlerFunc {
	})
}

func extractClientAddress(logger *zap.Logger, ctx context.Context) (string, string) {
	clientAddr := ""
	clientIP := ""
	clientPort := ""
func extractClientAddressFromContext(logger *zap.Logger, ctx context.Context) (string, string) {
	var clientAddr string
	md, _ := metadata.FromIncomingContext(ctx)
	if ips := md.Get("x-forwarded-for"); len(ips) > 0 {
		// Look for gRPC-Gateway / LB header.
@@ -445,14 +443,41 @@ func extractClientAddress(logger *zap.Logger, ctx context.Context) (string, stri
		clientAddr = peerInfo.Addr.String()
	}

	return extractClientAddress(logger, clientAddr)
}

func extractClientAddressFromRequest(logger *zap.Logger, r *http.Request) (string, string) {
	var clientAddr string
	if ips := r.Header.Get("x-forwarded-for"); len(ips) > 0 {
		clientAddr = strings.Split(ips, ",")[0]
	} else {
		clientAddr = r.RemoteAddr
	}

	return extractClientAddress(logger, clientAddr)
}

func extractClientAddress(logger *zap.Logger, clientAddr string) (string, string) {
	var clientIP, clientPort string

	if clientAddr != "" {
		// It's possible the request metadata had no client address string.

		clientAddr = strings.TrimSpace(clientAddr)
		if host, port, err := net.SplitHostPort(clientAddr); err == nil {
			clientIP = host
			clientPort = port
	} else if addrErr, ok := err.(*net.AddrError); ok && addrErr.Err == "missing port in address" {
		} else if addrErr, ok := err.(*net.AddrError); ok {
			switch addrErr.Err {
			case "missing port in address":
				fallthrough
			case "too many colons in address":
				clientIP = clientAddr
	} else {
		logger.Debug("Could not extract client address from request.", zap.Error(err))
			default:
				// Unknown address error, ignore the address.
			}
		}
		// At this point err may still be a non-nil value that's not a *net.AddrError, ignore the address.
	}

	return clientIP, clientPort
@@ -460,12 +485,11 @@ func extractClientAddress(logger *zap.Logger, ctx context.Context) (string, stri

func traceApiBefore(ctx context.Context, logger *zap.Logger, fullMethodName string, fn func(clientIP, clientPort string) error) error {
	name := fmt.Sprintf("%v-before", fullMethodName)
	clientIP, clientPort := extractClientAddress(logger, ctx)
	clientIP, clientPort := extractClientAddressFromContext(logger, ctx)
	statsCtx, err := tag.New(ctx, tag.Upsert(MetricsFunction, name))
	if err != nil {
		// If there was an error processing the stats, just execute the function.
		logger.Warn("Error tagging API before stats", zap.String("full_method_name", fullMethodName), zap.Error(err))
		clientIP, clientPort := extractClientAddress(logger, ctx)
		return fn(clientIP, clientPort)
	}
	startNanos := time.Now().UTC().UnixNano()
@@ -481,7 +505,7 @@ func traceApiBefore(ctx context.Context, logger *zap.Logger, fullMethodName stri

func traceApiAfter(ctx context.Context, logger *zap.Logger, fullMethodName string, fn func(clientIP, clientPort string)) {
	name := fmt.Sprintf("%v-after", logger)
	clientIP, clientPort := extractClientAddress(logger, ctx)
	clientIP, clientPort := extractClientAddressFromContext(logger, ctx)
	statsCtx, err := tag.New(ctx, tag.Upsert(MetricsFunction, name))
	if err != nil {
		// If there was an error processing the stats, just execute the function.
+2 −23
Original line number Diff line number Diff line
@@ -15,16 +15,13 @@
package server

import (
	"net"
	"strings"

	"github.com/gofrs/uuid"
	"github.com/heroiclabs/nakama/api"
	"go.uber.org/zap"
	"golang.org/x/net/context"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/peer"
	"google.golang.org/grpc/status"
)

@@ -64,26 +61,8 @@ func (s *ApiServer) RpcFunc(ctx context.Context, in *api.Rpc) (*api.Rpc, error)
	if e := ctx.Value(ctxExpiryKey{}); e != nil {
		expiry = e.(int64)
	}
	clientAddr := ""
	clientIP := ""
	clientPort := ""
	md, _ := metadata.FromIncomingContext(ctx)
	if ips := md.Get("x-forwarded-for"); len(ips) > 0 {
		// look for gRPC-Gateway / LB header
		clientAddr = strings.Split(ips[0], ",")[0]
	} else if peerInfo, ok := peer.FromContext(ctx); ok {
		// if missing, try to look up gRPC peer info
		clientAddr = peerInfo.Addr.String()
	}
	clientAddr = strings.TrimSpace(clientAddr)
	if host, port, err := net.SplitHostPort(clientAddr); err == nil {
		clientIP = host
		clientPort = port
	} else if addrErr, ok := err.(*net.AddrError); ok && addrErr.Err == "missing port in address" {
		clientIP = clientAddr
	} else {
		s.logger.Debug("Could not extract client address from request.", zap.Error(err))
	}

	clientIP, clientPort := extractClientAddressFromContext(s.logger, ctx)

	result, fnErr, code := fn(ctx, queryParams, uid, username, expiry, "", clientIP, clientPort, in.Payload)
	if fnErr != nil {
+2 −22
Original line number Diff line number Diff line
@@ -15,11 +15,8 @@
package server

import (
	"net"
	"net/http"
	"strings"

	"context"
	"net/http"
	"time"

	"github.com/golang/protobuf/jsonpb"
@@ -67,24 +64,7 @@ func NewSocketWsAcceptor(logger *zap.Logger, config Config, sessionRegistry Sess
			return
		}

		clientAddr := ""
		clientIP := ""
		clientPort := ""
		if ips := r.Header.Get("x-forwarded-for"); len(ips) > 0 {
			clientAddr = strings.Split(ips, ",")[0]
		} else {
			clientAddr = r.RemoteAddr
		}

		clientAddr = strings.TrimSpace(clientAddr)
		if host, port, err := net.SplitHostPort(clientAddr); err == nil {
			clientIP = host
			clientPort = port
		} else if addrErr, ok := err.(*net.AddrError); ok && addrErr.Err == "missing port in address" {
			clientIP = clientAddr
		} else {
			logger.Debug("Could not extract client address from request.", zap.Error(err))
		}
		clientIP, clientPort := extractClientAddressFromRequest(logger, r)

		status := false
		if r.URL.Query().Get("status") == "true" {