From e54b956e3f2aa1eddb3b784ab26a6bd57d8e4d63 Mon Sep 17 00:00:00 2001 From: Andrei Mihu Date: Mon, 23 Jul 2018 17:08:11 +0100 Subject: [PATCH] Extend available runtime hash functions. (#221) --- CHANGELOG.md | 6 + main.go | 20 +-- server/config.go | 60 ++++----- server/runtime_bit32.go | 209 ++++++++++++++++++++++++++++++++ server/runtime_module_cache.go | 1 + server/runtime_nakama_module.go | 29 +++++ server/session_ws.go | 77 +++++++----- tests/runtime_test.go | 169 +++++++++++++++++++++++++- 8 files changed, 493 insertions(+), 78 deletions(-) create mode 100644 server/runtime_bit32.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 776092f03..b18848dae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,13 @@ All notable changes to this project are documented below. The format is based on [keep a changelog](http://keepachangelog.com) and this project uses [semantic versioning](http://semver.org). ## [Unreleased] +### Added +- New `bit32` module available in the code runtime. +- New code runtime function to create MD5 hashes. +- New code runtime function to create SHA256 hashes. +### Changed +- Reduce the frequency of socket checks on known active connections. ## [2.0.2] - 2018-07-09 ### Added diff --git a/main.go b/main.go index 9c63017d4..e141f74c5 100644 --- a/main.go +++ b/main.go @@ -128,7 +128,7 @@ func main() { cookie := newOrLoadCookie(config) gacode := "UA-89792135-1" if gaenabled { - runTelemetry(startupLogger, http.DefaultClient, gacode, cookie) + runTelemetry(http.DefaultClient, gacode, cookie) } // Respect OS stop signals. @@ -206,24 +206,14 @@ func dbConnect(multiLogger *zap.Logger, config server.Config) (*sql.DB, string) // // This information is sent via Google Analytics which allows the Nakama team to // analyze usage patterns and errors in order to help improve the server. -func runTelemetry(startupLogger *zap.Logger, httpc *http.Client, gacode string, cookie string) { - err := ga.SendSessionStart(httpc, gacode, cookie) - if err != nil { - startupLogger.Debug("Send start session event failed.", zap.Error(err)) +func runTelemetry(httpc *http.Client, gacode string, cookie string) { + if ga.SendSessionStart(httpc, gacode, cookie) != nil { return } - - err = ga.SendEvent(httpc, gacode, cookie, &ga.Event{Ec: "version", Ea: fmt.Sprintf("%s+%s", version, commitID)}) - if err != nil { - startupLogger.Debug("Send event failed.", zap.Error(err)) - return - } - - err = ga.SendEvent(httpc, gacode, cookie, &ga.Event{Ec: "variant", Ea: "nakama"}) - if err != nil { - startupLogger.Debug("Send event failed.", zap.Error(err)) + if ga.SendEvent(httpc, gacode, cookie, &ga.Event{Ec: "version", Ea: fmt.Sprintf("%s+%s", version, commitID)}) != nil { return } + ga.SendEvent(httpc, gacode, cookie, &ga.Event{Ec: "variant", Ea: "nakama"}) } func newOrLoadCookie(config server.Config) string { diff --git a/server/config.go b/server/config.go index a71a3e71b..b05d70081 100644 --- a/server/config.go +++ b/server/config.go @@ -317,40 +317,42 @@ func NewSessionConfig() *SessionConfig { // SocketConfig is configuration relevant to the transport socket and protocol. type SocketConfig struct { - ServerKey string `yaml:"server_key" json:"server_key" usage:"Server key to use to establish a connection to the server."` - Port int `yaml:"port" json:"port" usage:"The port for accepting connections from the client for the given interface(s), address(es), and protocol(s). Default 7350."` - Address string `yaml:"address" json:"address" usage:"The IP address of the interface to listen for client traffic on. Default listen on all available addresses/interfaces."` - Protocol string `yaml:"protocol" json:"protocol" usage:"The network protocol to listen for traffic on. Possible values are 'tcp' for both IPv4 and IPv6, 'tcp4' for IPv4 only, or 'tcp6' for IPv6 only. Default 'tcp'."` - MaxMessageSizeBytes int64 `yaml:"max_message_size_bytes" json:"max_message_size_bytes" usage:"Maximum amount of data in bytes allowed to be read from the client socket per message. Used for real-time, gRPC and HTTP connections."` - ReadTimeoutMs int `yaml:"read_timeout_ms" json:"read_timeout_ms" usage:"Maximum duration in milliseconds for reading the entire request. Used for HTTP connections."` - WriteTimeoutMs int `yaml:"write_timeout_ms" json:"write_timeout_ms" usage:"Maximum duration in milliseconds before timing out writes of the response. Used for HTTP connections."` - IdleTimeoutMs int `yaml:"idle_timeout_ms" json:"idle_timeout_ms" usage:"Maximum amount of time in milliseconds to wait for the next request when keep-alives are enabled. Used for HTTP connections."` - WriteWaitMs int `yaml:"write_wait_ms" json:"write_wait_ms" usage:"Time in milliseconds to wait for an ack from the client when writing data. Used for real-time connections."` - PongWaitMs int `yaml:"pong_wait_ms" json:"pong_wait_ms" usage:"Time in milliseconds to wait between pong messages received from the client. Used for real-time connections."` - PingPeriodMs int `yaml:"ping_period_ms" json:"ping_period_ms" usage:"Time in milliseconds to wait between sending ping messages to the client. This value must be less than the pong_wait_ms. Used for real-time connections."` - OutgoingQueueSize int `yaml:"outgoing_queue_size" json:"outgoing_queue_size" usage:"The maximum number of messages waiting to be sent to the client. If this is exceeded the client is considered too slow and will disconnect. Used when processing real-time connections."` - SSLCertificate string `yaml:"ssl_certificate" json:"ssl_certificate" usage:"Path to certificate file if you want the server to use SSL directly. Must also supply ssl_private_key. NOT recommended for production use."` - SSLPrivateKey string `yaml:"ssl_private_key" json:"ssl_private_key" usage:"Path to private key file if you want the server to use SSL directly. Must also supply ssl_certificate. NOT recommended for production use."` - TLSCert []tls.Certificate // Created by processing SSLCertificate and SSLPrivateKey, not set from input args directly. + ServerKey string `yaml:"server_key" json:"server_key" usage:"Server key to use to establish a connection to the server."` + Port int `yaml:"port" json:"port" usage:"The port for accepting connections from the client for the given interface(s), address(es), and protocol(s). Default 7350."` + Address string `yaml:"address" json:"address" usage:"The IP address of the interface to listen for client traffic on. Default listen on all available addresses/interfaces."` + Protocol string `yaml:"protocol" json:"protocol" usage:"The network protocol to listen for traffic on. Possible values are 'tcp' for both IPv4 and IPv6, 'tcp4' for IPv4 only, or 'tcp6' for IPv6 only. Default 'tcp'."` + MaxMessageSizeBytes int64 `yaml:"max_message_size_bytes" json:"max_message_size_bytes" usage:"Maximum amount of data in bytes allowed to be read from the client socket per message. Used for real-time, gRPC and HTTP connections."` + ReadTimeoutMs int `yaml:"read_timeout_ms" json:"read_timeout_ms" usage:"Maximum duration in milliseconds for reading the entire request. Used for HTTP connections."` + WriteTimeoutMs int `yaml:"write_timeout_ms" json:"write_timeout_ms" usage:"Maximum duration in milliseconds before timing out writes of the response. Used for HTTP connections."` + IdleTimeoutMs int `yaml:"idle_timeout_ms" json:"idle_timeout_ms" usage:"Maximum amount of time in milliseconds to wait for the next request when keep-alives are enabled. Used for HTTP connections."` + WriteWaitMs int `yaml:"write_wait_ms" json:"write_wait_ms" usage:"Time in milliseconds to wait for an ack from the client when writing data. Used for real-time connections."` + PongWaitMs int `yaml:"pong_wait_ms" json:"pong_wait_ms" usage:"Time in milliseconds to wait between pong messages received from the client. Used for real-time connections."` + PingPeriodMs int `yaml:"ping_period_ms" json:"ping_period_ms" usage:"Time in milliseconds to wait between sending ping messages to the client. This value must be less than the pong_wait_ms. Used for real-time connections."` + PingBackoffThreshold int `yaml:"ping_backoff_threshold" json:"ping_backoff_threshold" usage:"Minimum number of messages received from the client during a single ping period that will delay the sending of a ping until the next ping period, to avoid sending unnecessary pings on regularly active connections. Default 20."` + OutgoingQueueSize int `yaml:"outgoing_queue_size" json:"outgoing_queue_size" usage:"The maximum number of messages waiting to be sent to the client. If this is exceeded the client is considered too slow and will disconnect. Used when processing real-time connections."` + SSLCertificate string `yaml:"ssl_certificate" json:"ssl_certificate" usage:"Path to certificate file if you want the server to use SSL directly. Must also supply ssl_private_key. NOT recommended for production use."` + SSLPrivateKey string `yaml:"ssl_private_key" json:"ssl_private_key" usage:"Path to private key file if you want the server to use SSL directly. Must also supply ssl_certificate. NOT recommended for production use."` + TLSCert []tls.Certificate // Created by processing SSLCertificate and SSLPrivateKey, not set from input args directly. } // NewTransportConfig creates a new TransportConfig struct. func NewSocketConfig() *SocketConfig { return &SocketConfig{ - ServerKey: "defaultkey", - Port: 7350, - Address: "", - Protocol: "tcp", - MaxMessageSizeBytes: 4096, - ReadTimeoutMs: 10 * 1000, - WriteTimeoutMs: 10 * 1000, - IdleTimeoutMs: 60 * 1000, - WriteWaitMs: 5000, - PongWaitMs: 10000, - PingPeriodMs: 8000, - OutgoingQueueSize: 64, - SSLCertificate: "", - SSLPrivateKey: "", + ServerKey: "defaultkey", + Port: 7350, + Address: "", + Protocol: "tcp", + MaxMessageSizeBytes: 4096, + ReadTimeoutMs: 10 * 1000, + WriteTimeoutMs: 10 * 1000, + IdleTimeoutMs: 60 * 1000, + WriteWaitMs: 5000, + PongWaitMs: 10000, + PingPeriodMs: 8000, + PingBackoffThreshold: 20, + OutgoingQueueSize: 64, + SSLCertificate: "", + SSLPrivateKey: "", } } diff --git a/server/runtime_bit32.go b/server/runtime_bit32.go new file mode 100644 index 000000000..940cc0a47 --- /dev/null +++ b/server/runtime_bit32.go @@ -0,0 +1,209 @@ +// Copyright 2018 The Nakama Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "github.com/yuin/gopher-lua" + "math" + "math/bits" +) + +var ( + Bit32LibName = "bit32" + Bit32Default64 = int64(math.Pow(2, 32) - 1) +) + +func OpenBit32(l *lua.LState) int { + mod := l.RegisterModule(Bit32LibName, bit32Funcs) + l.Push(mod) + return 1 +} + +var bit32Funcs = map[string]lua.LGFunction{ + "arshift": bit32arshift, + "band": bit32band, + "bnot": bit32not, + "bor": bit32or, + "btest": bit32btest, + "bxor": bit32xor, + "extract": bit32extract, + "replace": bit32replace, + "lrotate": bit32lrotate, + "lshift": bit32lshift, + "rrotate": bit32rrotate, + "rshift": bit32rshift, +} + +func bit32arshift(l *lua.LState) int { + a := uint32(l.CheckInt64(1)) + n := l.CheckInt(2) + if n < 0 { + l.Push(lua.LNumber(a << uint32(n*-1))) + } else if a>>uint32(31) != 0 { + l.Push(lua.LNumber((a >> uint32(n)) | (uint32(math.Pow(2, float64(n))-1) << uint32(32-n)))) + } else { + l.Push(lua.LNumber(a >> uint32(n))) + } + return 1 +} + +func bit32band(l *lua.LState) int { + a := uint32(l.OptInt64(1, Bit32Default64)) + next := 2 + for { + val := l.Get(next) + if val == lua.LNil { + break + } + if val.Type() != lua.LTNumber { + l.TypeError(next, lua.LTNumber) + return 0 + } + b := val.(lua.LNumber) + a = a & uint32(b) + next++ + } + l.Push(lua.LNumber(a)) + return 1 +} + +func bit32not(l *lua.LState) int { + a := uint32(l.CheckInt64(1)) + l.Push(lua.LNumber(^a)) + return 1 +} + +func bit32or(l *lua.LState) int { + a := uint32(l.OptInt64(1, 0)) + next := 2 + for { + val := l.Get(next) + if val == lua.LNil { + break + } + if val.Type() != lua.LTNumber { + l.TypeError(next, lua.LTNumber) + return 0 + } + b := val.(lua.LNumber) + a = a | uint32(b) + next++ + } + l.Push(lua.LNumber(a)) + return 1 +} + +func bit32btest(l *lua.LState) int { + a := uint32(l.OptInt64(1, Bit32Default64)) + next := 2 + for { + val := l.Get(next) + if val == lua.LNil { + break + } + if val.Type() != lua.LTNumber { + l.TypeError(next, lua.LTNumber) + return 0 + } + b := val.(lua.LNumber) + a = a & uint32(b) + next++ + } + l.Push(lua.LBool(a != 0)) + return 1 +} + +func bit32xor(l *lua.LState) int { + a := uint32(l.OptInt64(1, 0)) + next := 2 + for { + val := l.Get(next) + if val == lua.LNil { + break + } + if val.Type() != lua.LTNumber { + l.TypeError(next, lua.LTNumber) + return 0 + } + b := val.(lua.LNumber) + a = a ^ uint32(b) + next++ + } + l.Push(lua.LNumber(a)) + return 1 +} + +func bit32extract(l *lua.LState) int { + a := uint32(l.CheckInt64(1)) + offset := l.CheckInt(2) + width := l.OptInt(3, 1) + if offset < 0 || offset > 31 || width < 1 || width > 32 || (offset+width) > 32 { + l.RaiseError("trying to access non-existent bits") + return 0 + } + l.Push(lua.LNumber((a >> uint32(offset)) & (1< 31 || width < 1 || width > 32 || (offset+width) > 32 { + l.RaiseError("trying to access non-existent bits") + return 0 + } + a = a ^ (((a >> uint32(offset)) & (1<> uint32(32-width)) << uint32(offset) + l.Push(lua.LNumber(a | v)) + return 1 +} + +func bit32lrotate(l *lua.LState) int { + a := uint32(l.CheckInt64(1)) + n := l.CheckInt(2) + l.Push(lua.LNumber(bits.RotateLeft32(a, n))) + return 1 +} + +func bit32lshift(l *lua.LState) int { + a := uint32(l.CheckInt64(1)) + n := l.CheckInt(2) + if n < 0 { + l.Push(lua.LNumber(a >> uint32(n*-1))) + } else { + l.Push(lua.LNumber(a << uint32(n))) + } + return 1 +} + +func bit32rrotate(l *lua.LState) int { + a := uint32(l.CheckInt64(1)) + n := l.CheckInt(2) + l.Push(lua.LNumber(bits.RotateLeft32(a, n*-1))) + return 1 +} + +func bit32rshift(l *lua.LState) int { + a := uint32(l.CheckInt64(1)) + n := l.CheckInt(2) + if n < 0 { + l.Push(lua.LNumber(a << uint32(n*-1))) + } else { + l.Push(lua.LNumber(a >> uint32(n))) + } + return 1 +} diff --git a/server/runtime_module_cache.go b/server/runtime_module_cache.go index 0909cae6c..051a2f3f7 100644 --- a/server/runtime_module_cache.go +++ b/server/runtime_module_cache.go @@ -102,6 +102,7 @@ func LoadRuntimeModules(startupLogger *zap.Logger, config Config) (map[string]lu lua.OsLibName: OpenOs, lua.StringLibName: lua.OpenString, lua.MathLibName: lua.OpenMath, + Bit32LibName: OpenBit32, } startupLogger.Info("Found runtime modules", zap.Int("count", len(modulePaths)), zap.Strings("modules", modulePaths)) diff --git a/server/runtime_nakama_module.go b/server/runtime_nakama_module.go index 93c96afbd..14062834f 100644 --- a/server/runtime_nakama_module.go +++ b/server/runtime_nakama_module.go @@ -40,6 +40,7 @@ import ( "crypto/hmac" "crypto/sha256" + "crypto/md5" "github.com/golang/protobuf/ptypes/timestamp" "github.com/golang/protobuf/ptypes/wrappers" "github.com/gorhill/cronexpr" @@ -129,6 +130,8 @@ func (n *NakamaModule) Loader(l *lua.LState) int { "base16_decode": n.base16Decode, "aes128_encrypt": n.aes128Encrypt, "aes128_decrypt": n.aes128Decrypt, + "md5_hash": n.md5Hash, + "sha256_hash": n.sha256Hash, "hmac_sha256_hash": n.hmacSHA256Hash, "bcrypt_hash": n.bcryptHash, "bcrypt_compare": n.bcryptCompare, @@ -812,6 +815,32 @@ func (n *NakamaModule) aes128Decrypt(l *lua.LState) int { return 1 } +func (n *NakamaModule) md5Hash(l *lua.LState) int { + input := l.CheckString(1) + if input == "" { + l.ArgError(1, "expects input string") + return 0 + } + + hash := fmt.Sprintf("%x", md5.Sum([]byte(input))) + + l.Push(lua.LString(hash)) + return 1 +} + +func (n *NakamaModule) sha256Hash(l *lua.LState) int { + input := l.CheckString(1) + if input == "" { + l.ArgError(1, "expects input string") + return 0 + } + + hash := fmt.Sprintf("%x", sha256.Sum256([]byte(input))) + + l.Push(lua.LString(hash)) + return 1 +} + func (n *NakamaModule) hmacSHA256Hash(l *lua.LState) int { input := l.CheckString(1) if input == "" { diff --git a/server/session_ws.go b/server/session_ws.go index 45e54a988..acc35b4bb 100644 --- a/server/session_ws.go +++ b/server/session_ws.go @@ -45,16 +45,20 @@ type sessionWS struct { jsonpbMarshaler *jsonpb.Marshaler jsonpbUnmarshaler *jsonpb.Unmarshaler queuePriorityThreshold int + pingPeriodDuration time.Duration + pongWaitDuration time.Duration + writeWaitDuration time.Duration sessionRegistry *SessionRegistry matchmaker Matchmaker tracker Tracker - stopped bool - conn *websocket.Conn - pingTicker *time.Ticker - outgoingCh chan []byte - outgoingStopCh chan struct{} + stopped bool + conn *websocket.Conn + receivedMessageCounter int + pingTimer *time.Timer + outgoingCh chan []byte + outgoingStopCh chan struct{} } func NewSessionWS(logger *zap.Logger, config Config, userID uuid.UUID, username string, expiry int64, jsonpbMarshaler *jsonpb.Marshaler, jsonpbUnmarshaler *jsonpb.Unmarshaler, conn *websocket.Conn, sessionRegistry *SessionRegistry, matchmaker Matchmaker, tracker Tracker) Session { @@ -74,16 +78,20 @@ func NewSessionWS(logger *zap.Logger, config Config, userID uuid.UUID, username jsonpbMarshaler: jsonpbMarshaler, jsonpbUnmarshaler: jsonpbUnmarshaler, queuePriorityThreshold: (config.GetSocket().OutgoingQueueSize / 3) * 2, + pingPeriodDuration: time.Duration(config.GetSocket().PingPeriodMs) * time.Millisecond, + pongWaitDuration: time.Duration(config.GetSocket().PongWaitMs) * time.Millisecond, + writeWaitDuration: time.Duration(config.GetSocket().WriteWaitMs) * time.Millisecond, sessionRegistry: sessionRegistry, matchmaker: matchmaker, tracker: tracker, - stopped: false, - conn: conn, - pingTicker: time.NewTicker(time.Duration(config.GetSocket().PingPeriodMs) * time.Millisecond), - outgoingCh: make(chan []byte, config.GetSocket().OutgoingQueueSize), - outgoingStopCh: make(chan struct{}), + stopped: false, + conn: conn, + receivedMessageCounter: config.GetSocket().PingBackoffThreshold, + pingTimer: time.NewTimer(time.Duration(config.GetSocket().PingPeriodMs) * time.Millisecond), + outgoingCh: make(chan []byte, config.GetSocket().OutgoingQueueSize), + outgoingStopCh: make(chan struct{}), } } @@ -114,18 +122,13 @@ func (s *sessionWS) Expiry() int64 { func (s *sessionWS) Consume(processRequest func(logger *zap.Logger, session Session, envelope *rtapi.Envelope) bool) { defer s.cleanupClosedConnection() s.conn.SetReadLimit(s.config.GetSocket().MaxMessageSizeBytes) - s.conn.SetReadDeadline(time.Now().Add(time.Duration(s.config.GetSocket().PongWaitMs) * time.Millisecond)) + s.conn.SetReadDeadline(time.Now().Add(s.pongWaitDuration)) s.conn.SetPongHandler(func(string) error { - s.conn.SetReadDeadline(time.Now().Add(time.Duration(s.config.GetSocket().PongWaitMs) * time.Millisecond)) + s.pingTimer.Reset(s.pingPeriodDuration) + s.conn.SetReadDeadline(time.Now().Add(s.pongWaitDuration)) return nil }) - // Send an initial ping immediately. - if !s.pingNow() { - // If the first ping fails abort the rest of the consume sequence immediately. - return - } - // Start a routine to process outbound messages. go s.processOutgoing() @@ -142,6 +145,13 @@ func (s *sessionWS) Consume(processRequest func(logger *zap.Logger, session Sess break } + s.receivedMessageCounter-- + if s.receivedMessageCounter <= 0 { + s.receivedMessageCounter = s.config.GetSocket().PingBackoffThreshold + s.pingTimer.Reset(s.pingPeriodDuration) + s.conn.SetReadDeadline(time.Now().Add(s.pongWaitDuration)) + } + request := &rtapi.Envelope{} if err = s.jsonpbUnmarshaler.Unmarshal(bytes.NewReader(data), request); err != nil { // If the payload is malformed the client is incompatible or misbehaving, either way disconnect it now. @@ -149,9 +159,16 @@ func (s *sessionWS) Consume(processRequest func(logger *zap.Logger, session Sess break } else { // TODO Add session-global context here to cancel in-progress operations when the session is closed. - requestLogger := s.logger.With(zap.String("cid", request.Cid)) - if !processRequest(requestLogger, s, request) { - break + switch request.Cid { + case "": + if !processRequest(s.logger, s, request) { + break + } + default: + requestLogger := s.logger.With(zap.String("cid", request.Cid)) + if !processRequest(requestLogger, s, request) { + break + } } } } @@ -163,7 +180,7 @@ func (s *sessionWS) processOutgoing() { case <-s.outgoingStopCh: // Session is closing, close the outgoing process routine. return - case <-s.pingTicker.C: + case <-s.pingTimer.C: // Periodically send pings. if !s.pingNow() { // If ping fails the session will be stopped, clean up the loop. @@ -178,7 +195,7 @@ func (s *sessionWS) processOutgoing() { return } // Process the outgoing message queue. - s.conn.SetWriteDeadline(time.Now().Add(time.Duration(s.config.GetSocket().WriteWaitMs) * time.Millisecond)) + s.conn.SetWriteDeadline(time.Now().Add(s.writeWaitDuration)) if err := s.conn.WriteMessage(websocket.TextMessage, payload); err != nil { s.Unlock() s.logger.Warn("Could not write message", zap.Error(err)) @@ -195,8 +212,9 @@ func (s *sessionWS) pingNow() bool { s.Unlock() return false } - s.conn.SetWriteDeadline(time.Now().Add(time.Duration(s.config.GetSocket().WriteWaitMs) * time.Millisecond)) - err := s.conn.WriteMessage(websocket.PingMessage, []byte{}) + t := time.Now() + s.conn.SetWriteDeadline(t.Add(s.writeWaitDuration)) + err := s.conn.WriteMessage(websocket.BinaryMessage, []byte{}) s.Unlock() if err != nil { s.logger.Warn("Could not send ping, closing channel", zap.String("remoteAddress", s.conn.RemoteAddr().String()), zap.Error(err)) @@ -204,6 +222,9 @@ func (s *sessionWS) pingNow() bool { s.cleanupClosedConnection() return false } + s.pingTimer.Reset(s.pingPeriodDuration) + // Workaround for poor behaviour in some WebSocket clients. + s.conn.SetReadDeadline(t.Add(s.pongWaitDuration)) return true } @@ -289,7 +310,7 @@ func (s *sessionWS) cleanupClosedConnection() { s.tracker.UntrackAll(s.id) // Clean up internals. - s.pingTicker.Stop() + s.pingTimer.Stop() close(s.outgoingStopCh) close(s.outgoingCh) @@ -310,12 +331,12 @@ func (s *sessionWS) Close() { // Expect the caller of this session.Close() to clean up external resources (like presences) separately. // Clean up internals. - s.pingTicker.Stop() + s.pingTimer.Stop() close(s.outgoingStopCh) close(s.outgoingCh) // Send close message. - err := s.conn.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(time.Duration(s.config.GetSocket().WriteWaitMs)*time.Millisecond)) + err := s.conn.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(s.writeWaitDuration)) if err != nil { s.logger.Warn("Could not send close message, closing prematurely", zap.String("remoteAddress", s.conn.RemoteAddr().String()), zap.Error(err)) } diff --git a/tests/runtime_test.go b/tests/runtime_test.go index 2c768e12f..7d4ac405d 100644 --- a/tests/runtime_test.go +++ b/tests/runtime_test.go @@ -35,12 +35,13 @@ import ( func vm(t *testing.T, moduleCache *server.ModuleCache) *server.RuntimePool { stdLibs := map[string]lua.LGFunction{ - lua.LoadLibName: server.OpenPackage(moduleCache), - lua.BaseLibName: lua.OpenBase, - lua.TabLibName: lua.OpenTable, - lua.OsLibName: server.OpenOs, - lua.StringLibName: lua.OpenString, - lua.MathLibName: lua.OpenMath, + lua.LoadLibName: server.OpenPackage(moduleCache), + lua.BaseLibName: lua.OpenBase, + lua.TabLibName: lua.OpenTable, + lua.OsLibName: server.OpenOs, + lua.StringLibName: lua.OpenString, + lua.MathLibName: lua.OpenMath, + server.Bit32LibName: server.OpenBit32, } db := NewDB(t) @@ -187,6 +188,136 @@ print(stats.mean(t)) vm(t, modules) } +func TestRuntimeBit32(t *testing.T) { + modules := &server.ModuleCache{ + Names: make([]string, 0), + Modules: make(map[string]*server.RuntimeModule, 0), + } + writeLuaModule(modules, "bit32-tests", ` +--[[ +Original under MIT license at https://github.com/Shopify/lua-tests/blob/master/bitwise.lua +--]] + +print("testing bitwise operations") + +assert(bit32.band() == bit32.bnot(0)) +assert(bit32.btest() == true) +assert(bit32.bor() == 0) +assert(bit32.bxor() == 0) + +assert(bit32.band() == bit32.band(0xffffffff)) +assert(bit32.band(1,2) == 0) + + +-- out-of-range numbers +assert(bit32.band(-1) == 0xffffffff) +assert(bit32.band(2^33 - 1) == 0xffffffff) +assert(bit32.band(-2^33 - 1) == 0xffffffff) +assert(bit32.band(2^33 + 1) == 1) +assert(bit32.band(-2^33 + 1) == 1) +assert(bit32.band(-2^40) == 0) +assert(bit32.band(2^40) == 0) +assert(bit32.band(-2^40 - 2) == 0xfffffffe) +assert(bit32.band(2^40 - 4) == 0xfffffffc) + +assert(bit32.lrotate(0, -1) == 0) +assert(bit32.lrotate(0, 7) == 0) +assert(bit32.lrotate(0x12345678, 4) == 0x23456781) +assert(bit32.rrotate(0x12345678, -4) == 0x23456781) +assert(bit32.lrotate(0x12345678, -8) == 0x78123456) +assert(bit32.rrotate(0x12345678, 8) == 0x78123456) +assert(bit32.lrotate(0xaaaaaaaa, 2) == 0xaaaaaaaa) +assert(bit32.lrotate(0xaaaaaaaa, -2) == 0xaaaaaaaa) +for i = -50, 50 do + assert(bit32.lrotate(0x89abcdef, i) == bit32.lrotate(0x89abcdef, i%32)) +end + +assert(bit32.lshift(0x12345678, 4) == 0x23456780) +assert(bit32.lshift(0x12345678, 8) == 0x34567800) +assert(bit32.lshift(0x12345678, -4) == 0x01234567) +assert(bit32.lshift(0x12345678, -8) == 0x00123456) +assert(bit32.lshift(0x12345678, 32) == 0) +assert(bit32.lshift(0x12345678, -32) == 0) +assert(bit32.rshift(0x12345678, 4) == 0x01234567) +assert(bit32.rshift(0x12345678, 8) == 0x00123456) +assert(bit32.rshift(0x12345678, 32) == 0) +assert(bit32.rshift(0x12345678, -32) == 0) +assert(bit32.arshift(0x12345678, 0) == 0x12345678) +assert(bit32.arshift(0x12345678, 1) == 0x12345678 / 2) +assert(bit32.arshift(0x12345678, -1) == 0x12345678 * 2) +assert(bit32.arshift(-1, 1) == 0xffffffff) +assert(bit32.arshift(-1, 24) == 0xffffffff) +assert(bit32.arshift(-1, 32) == 0xffffffff) +assert(bit32.arshift(-1, -1) == (-1 * 2) % 2^32) + +print("+") +-- some special cases +local c = {0, 1, 2, 3, 10, 0x80000000, 0xaaaaaaaa, 0x55555555, + 0xffffffff, 0x7fffffff} + +for _, b in pairs(c) do + assert(bit32.band(b) == b) + assert(bit32.band(b, b) == b) + assert(bit32.btest(b, b) == (b ~= 0)) + assert(bit32.band(b, b, b) == b) + assert(bit32.btest(b, b, b) == (b ~= 0)) + assert(bit32.band(b, bit32.bnot(b)) == 0) + assert(bit32.bor(b, bit32.bnot(b)) == bit32.bnot(0)) + assert(bit32.bor(b) == b) + assert(bit32.bor(b, b) == b) + assert(bit32.bor(b, b, b) == b) + assert(bit32.bxor(b) == b) + assert(bit32.bxor(b, b) == 0) + assert(bit32.bxor(b, 0) == b) + assert(bit32.bnot(b) ~= b) + assert(bit32.bnot(bit32.bnot(b)) == b) + assert(bit32.bnot(b) == 2^32 - 1 - b) + assert(bit32.lrotate(b, 32) == b) + assert(bit32.rrotate(b, 32) == b) + assert(bit32.lshift(bit32.lshift(b, -4), 4) == bit32.band(b, bit32.bnot(0xf))) + assert(bit32.rshift(bit32.rshift(b, 4), -4) == bit32.band(b, bit32.bnot(0xf))) + for i = -40, 40 do + assert(bit32.lshift(b, i) == math.floor((b * 2^i) % 2^32)) + end +end + +assert(not pcall(bit32.band, {})) +assert(not pcall(bit32.bnot, "a")) +assert(not pcall(bit32.lshift, 45)) +assert(not pcall(bit32.lshift, 45, print)) +assert(not pcall(bit32.rshift, 45, print)) + +print("+") + + +-- testing extract/replace + +assert(bit32.extract(0x12345678, 0, 4) == 8) +assert(bit32.extract(0x12345678, 4, 4) == 7) +assert(bit32.extract(0xa0001111, 28, 4) == 0xa) +assert(bit32.extract(0xa0001111, 31, 1) == 1) +assert(bit32.extract(0x50000111, 31, 1) == 0) +assert(bit32.extract(0xf2345679, 0, 32) == 0xf2345679) + +assert(not pcall(bit32.extract, 0, -1)) +assert(not pcall(bit32.extract, 0, 32)) +assert(not pcall(bit32.extract, 0, 0, 33)) +assert(not pcall(bit32.extract, 0, 31, 2)) + +assert(bit32.replace(0x12345678, 5, 28, 4) == 0x52345678) +assert(bit32.replace(0x12345678, 0x87654321, 0, 32) == 0x87654321) +assert(bit32.replace(0, 1, 2) == 2^2) +assert(bit32.replace(0, -1, 4) == 2^4) +assert(bit32.replace(-1, 0, 31) == 2^31 - 1) +assert(bit32.replace(-1, 0, 1, 2) == 2^32 - 7) + + +print'OK' +`) + + vm(t, modules) +} + func TestRuntimeRegisterRPCWithPayload(t *testing.T) { modules := &server.ModuleCache{ Names: make([]string, 0), @@ -420,6 +551,32 @@ nakama.register_rpc(test, "test") } } +func TestRuntimeMD5Hash(t *testing.T) { + modules := &server.ModuleCache{ + Names: make([]string, 0), + Modules: make(map[string]*server.RuntimeModule, 0), + } + writeLuaModule(modules, "md5hash-test", ` +local nk = require("nakama") +assert(nk.md5_hash("test") == "098f6bcd4621d373cade4e832627b4f6") +`) + + vm(t, modules) +} + +func TestRuntimeSHA256Hash(t *testing.T) { + modules := &server.ModuleCache{ + Names: make([]string, 0), + Modules: make(map[string]*server.RuntimeModule, 0), + } + writeLuaModule(modules, "sha256hash-test", ` +local nk = require("nakama") +assert(nk.sha256_hash("test") == "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08") +`) + + vm(t, modules) +} + func TestRuntimeBcryptHash(t *testing.T) { modules := &server.ModuleCache{ Names: make([]string, 0), -- GitLab