Loading CHANGELOG.md +3 −0 Original line number Diff line number Diff line Loading @@ -4,6 +4,9 @@ All notable changes to this project are documented below. The format is based on [keep a changelog](http://keepachangelog.com) and this project uses [semantic versioning](http://semver.org). ## [Unreleased] ### Added - Periodically check database hostname for underlying address changes. ### Fixed - Fix optimistic email imports when linking social profiles. Loading main.go +3 −3 Original line number Diff line number Diff line Loading @@ -114,12 +114,12 @@ func main() { } startupLogger.Info("Database connections", zap.Strings("dsns", redactedAddresses)) db, dbVersion := server.DbConnect(startupLogger, config) startupLogger.Info("Database information", zap.String("version", dbVersion)) // Global server context. ctx, ctxCancelFn := context.WithCancel(context.Background()) db, dbVersion := server.DbConnect(ctx, startupLogger, config) startupLogger.Info("Database information", zap.String("version", dbVersion)) // Check migration status and fail fast if the schema has diverged. migrate.StartupCheck(startupLogger, db) Loading server/db.go +117 −12 Original line number Diff line number Diff line Loading @@ -19,23 +19,27 @@ import ( "database/sql" "errors" "fmt" "net" "net/url" "strings" "time" "github.com/jackc/pgconn" "github.com/jackc/pgerrcode" "github.com/jackc/pgx/v4/stdlib" "go.uber.org/zap" ) func DbConnect(multiLogger *zap.Logger, config Config) (*sql.DB, string) { var ErrDatabaseDriverMismatch = errors.New("database driver mismatch") func DbConnect(ctx context.Context, logger *zap.Logger, config Config) (*sql.DB, string) { rawURL := config.GetDatabase().Addresses[0] if !(strings.HasPrefix(rawURL, "postgresql://") || strings.HasPrefix(rawURL, "postgres://")) { rawURL = fmt.Sprintf("postgres://%s", rawURL) } parsedURL, err := url.Parse(rawURL) if err != nil { multiLogger.Fatal("Bad database connection URL", zap.Error(err)) logger.Fatal("Bad database connection URL", zap.Error(err)) } query := parsedURL.Query() if len(query.Get("sslmode")) == 0 { Loading @@ -50,19 +54,22 @@ func DbConnect(multiLogger *zap.Logger, config Config) (*sql.DB, string) { parsedURL.Path = "/nakama" } multiLogger.Debug("Complete database connection URL", zap.String("raw_url", parsedURL.String())) // Resolve initial database address based on host before connecting. resolvedAddr, resolvedAddrMap := dbResolveAddress(ctx, logger, parsedURL.Host) logger.Debug("Complete database connection URL", zap.String("raw_url", parsedURL.String())) db, err := sql.Open("pgx", parsedURL.String()) if err != nil { multiLogger.Fatal("Error connecting to database", zap.Error(err)) logger.Fatal("Error connecting to database", zap.Error(err)) } // Limit the time allowed to ping database and get version to 15 seconds total. ctx, ctxCancelFn := context.WithTimeout(context.Background(), 15*time.Second) defer ctxCancelFn() if err = db.PingContext(ctx); err != nil { // Limit max time allowed across database ping and version fetch to 15 seconds total. pingCtx, pingCtxCancelFn := context.WithTimeout(ctx, 15*time.Second) defer pingCtxCancelFn() if err = db.PingContext(pingCtx); err != nil { if strings.HasSuffix(err.Error(), "does not exist (SQLSTATE 3D000)") { multiLogger.Fatal("Database schema not found, run `nakama migrate up`", zap.Error(err)) logger.Fatal("Database schema not found, run `nakama migrate up`", zap.Error(err)) } multiLogger.Fatal("Error pinging database", zap.Error(err)) logger.Fatal("Error pinging database", zap.Error(err)) } db.SetConnMaxLifetime(time.Millisecond * time.Duration(config.GetDatabase().ConnMaxLifetimeMs)) Loading @@ -70,13 +77,111 @@ func DbConnect(multiLogger *zap.Logger, config Config) (*sql.DB, string) { db.SetMaxIdleConns(config.GetDatabase().MaxIdleConns) var dbVersion string if err = db.QueryRowContext(ctx, "SELECT version()").Scan(&dbVersion); err != nil { multiLogger.Fatal("Error querying database version", zap.Error(err)) if err = db.QueryRowContext(pingCtx, "SELECT version()").Scan(&dbVersion); err != nil { logger.Fatal("Error querying database version", zap.Error(err)) } // Periodically check database hostname for underlying address changes. go func() { ticker := time.NewTicker(1 * time.Minute) for { select { case <-ctx.Done(): return case <-ticker.C: newResolvedAddr, newResolvedAddrMap := dbResolveAddress(ctx, logger, parsedURL.Host) if len(resolvedAddr) == 0 { // Could only happen when initial resolve above failed, and all resolves since have also failed. // Trust the database driver in this case. resolvedAddr = newResolvedAddr resolvedAddrMap = newResolvedAddrMap break } if len(newResolvedAddr) == 0 { // New addresses failed to resolve, but had previous ones. Trust the database driver in this case. break } // Check for any changes in the resolved addresses. drain := len(resolvedAddrMap) != len(newResolvedAddrMap) if !drain { for addr := range newResolvedAddrMap { if _, found := resolvedAddrMap[addr]; !found { drain = true break } } } if !drain { // No changes. break } // Changes found. Drain the pool and allow the database driver to open fresh connections. // Rely on the database driver to re-do its own hostname to address resolution. var acquired int conns := make([]*sql.Conn, 0, config.GetDatabase().MaxOpenConns) for acquired < config.GetDatabase().MaxOpenConns { acquired++ conn, err := db.Conn(ctx) if err != nil { if err == context.Canceled { // Server shutting down. return } // Log errors acquiring connections, but proceed without the failed connection anyway. logger.Error("Error acquiring database connection", zap.Error(err)) continue } conns = append(conns, conn) } logger.Warn("Database rotating all open connections due to address change", zap.Int("count", len(conns)), zap.Strings("previous", resolvedAddr), zap.Strings("updated", newResolvedAddr)) resolvedAddr = newResolvedAddr resolvedAddrMap = newResolvedAddrMap for _, conn := range conns { if err := conn.Raw(func(driverConn interface{}) error { pgc, ok := driverConn.(*stdlib.Conn) if !ok { return ErrDatabaseDriverMismatch } if err := pgc.Close(); err != nil { return err } return nil }); err != nil { logger.Error("Error closing database connection", zap.Error(err)) } if err := conn.Close(); err != nil { logger.Error("Error releasing database connection", zap.Error(err)) } } } } }() return db, dbVersion } func dbResolveAddress(ctx context.Context, logger *zap.Logger, host string) ([]string, map[string]struct{}) { resolveCtx, resolveCtxCancelFn := context.WithTimeout(ctx, 15*time.Second) defer resolveCtxCancelFn() addr, err := net.DefaultResolver.LookupHost(resolveCtx, host) if err != nil { logger.Debug("Error resolving database address", zap.String("host", host), zap.Error(err)) return nil, nil } addrMap := make(map[string]struct{}, len(addr)) for _, a := range addr { addrMap[a] = struct{}{} } return addr, addrMap } // Tx is used to permit clients to implement custom transaction logic. type Tx interface { ExecContext(context.Context, string, ...interface{}) (sql.Result, error) Loading Loading
CHANGELOG.md +3 −0 Original line number Diff line number Diff line Loading @@ -4,6 +4,9 @@ All notable changes to this project are documented below. The format is based on [keep a changelog](http://keepachangelog.com) and this project uses [semantic versioning](http://semver.org). ## [Unreleased] ### Added - Periodically check database hostname for underlying address changes. ### Fixed - Fix optimistic email imports when linking social profiles. Loading
main.go +3 −3 Original line number Diff line number Diff line Loading @@ -114,12 +114,12 @@ func main() { } startupLogger.Info("Database connections", zap.Strings("dsns", redactedAddresses)) db, dbVersion := server.DbConnect(startupLogger, config) startupLogger.Info("Database information", zap.String("version", dbVersion)) // Global server context. ctx, ctxCancelFn := context.WithCancel(context.Background()) db, dbVersion := server.DbConnect(ctx, startupLogger, config) startupLogger.Info("Database information", zap.String("version", dbVersion)) // Check migration status and fail fast if the schema has diverged. migrate.StartupCheck(startupLogger, db) Loading
server/db.go +117 −12 Original line number Diff line number Diff line Loading @@ -19,23 +19,27 @@ import ( "database/sql" "errors" "fmt" "net" "net/url" "strings" "time" "github.com/jackc/pgconn" "github.com/jackc/pgerrcode" "github.com/jackc/pgx/v4/stdlib" "go.uber.org/zap" ) func DbConnect(multiLogger *zap.Logger, config Config) (*sql.DB, string) { var ErrDatabaseDriverMismatch = errors.New("database driver mismatch") func DbConnect(ctx context.Context, logger *zap.Logger, config Config) (*sql.DB, string) { rawURL := config.GetDatabase().Addresses[0] if !(strings.HasPrefix(rawURL, "postgresql://") || strings.HasPrefix(rawURL, "postgres://")) { rawURL = fmt.Sprintf("postgres://%s", rawURL) } parsedURL, err := url.Parse(rawURL) if err != nil { multiLogger.Fatal("Bad database connection URL", zap.Error(err)) logger.Fatal("Bad database connection URL", zap.Error(err)) } query := parsedURL.Query() if len(query.Get("sslmode")) == 0 { Loading @@ -50,19 +54,22 @@ func DbConnect(multiLogger *zap.Logger, config Config) (*sql.DB, string) { parsedURL.Path = "/nakama" } multiLogger.Debug("Complete database connection URL", zap.String("raw_url", parsedURL.String())) // Resolve initial database address based on host before connecting. resolvedAddr, resolvedAddrMap := dbResolveAddress(ctx, logger, parsedURL.Host) logger.Debug("Complete database connection URL", zap.String("raw_url", parsedURL.String())) db, err := sql.Open("pgx", parsedURL.String()) if err != nil { multiLogger.Fatal("Error connecting to database", zap.Error(err)) logger.Fatal("Error connecting to database", zap.Error(err)) } // Limit the time allowed to ping database and get version to 15 seconds total. ctx, ctxCancelFn := context.WithTimeout(context.Background(), 15*time.Second) defer ctxCancelFn() if err = db.PingContext(ctx); err != nil { // Limit max time allowed across database ping and version fetch to 15 seconds total. pingCtx, pingCtxCancelFn := context.WithTimeout(ctx, 15*time.Second) defer pingCtxCancelFn() if err = db.PingContext(pingCtx); err != nil { if strings.HasSuffix(err.Error(), "does not exist (SQLSTATE 3D000)") { multiLogger.Fatal("Database schema not found, run `nakama migrate up`", zap.Error(err)) logger.Fatal("Database schema not found, run `nakama migrate up`", zap.Error(err)) } multiLogger.Fatal("Error pinging database", zap.Error(err)) logger.Fatal("Error pinging database", zap.Error(err)) } db.SetConnMaxLifetime(time.Millisecond * time.Duration(config.GetDatabase().ConnMaxLifetimeMs)) Loading @@ -70,13 +77,111 @@ func DbConnect(multiLogger *zap.Logger, config Config) (*sql.DB, string) { db.SetMaxIdleConns(config.GetDatabase().MaxIdleConns) var dbVersion string if err = db.QueryRowContext(ctx, "SELECT version()").Scan(&dbVersion); err != nil { multiLogger.Fatal("Error querying database version", zap.Error(err)) if err = db.QueryRowContext(pingCtx, "SELECT version()").Scan(&dbVersion); err != nil { logger.Fatal("Error querying database version", zap.Error(err)) } // Periodically check database hostname for underlying address changes. go func() { ticker := time.NewTicker(1 * time.Minute) for { select { case <-ctx.Done(): return case <-ticker.C: newResolvedAddr, newResolvedAddrMap := dbResolveAddress(ctx, logger, parsedURL.Host) if len(resolvedAddr) == 0 { // Could only happen when initial resolve above failed, and all resolves since have also failed. // Trust the database driver in this case. resolvedAddr = newResolvedAddr resolvedAddrMap = newResolvedAddrMap break } if len(newResolvedAddr) == 0 { // New addresses failed to resolve, but had previous ones. Trust the database driver in this case. break } // Check for any changes in the resolved addresses. drain := len(resolvedAddrMap) != len(newResolvedAddrMap) if !drain { for addr := range newResolvedAddrMap { if _, found := resolvedAddrMap[addr]; !found { drain = true break } } } if !drain { // No changes. break } // Changes found. Drain the pool and allow the database driver to open fresh connections. // Rely on the database driver to re-do its own hostname to address resolution. var acquired int conns := make([]*sql.Conn, 0, config.GetDatabase().MaxOpenConns) for acquired < config.GetDatabase().MaxOpenConns { acquired++ conn, err := db.Conn(ctx) if err != nil { if err == context.Canceled { // Server shutting down. return } // Log errors acquiring connections, but proceed without the failed connection anyway. logger.Error("Error acquiring database connection", zap.Error(err)) continue } conns = append(conns, conn) } logger.Warn("Database rotating all open connections due to address change", zap.Int("count", len(conns)), zap.Strings("previous", resolvedAddr), zap.Strings("updated", newResolvedAddr)) resolvedAddr = newResolvedAddr resolvedAddrMap = newResolvedAddrMap for _, conn := range conns { if err := conn.Raw(func(driverConn interface{}) error { pgc, ok := driverConn.(*stdlib.Conn) if !ok { return ErrDatabaseDriverMismatch } if err := pgc.Close(); err != nil { return err } return nil }); err != nil { logger.Error("Error closing database connection", zap.Error(err)) } if err := conn.Close(); err != nil { logger.Error("Error releasing database connection", zap.Error(err)) } } } } }() return db, dbVersion } func dbResolveAddress(ctx context.Context, logger *zap.Logger, host string) ([]string, map[string]struct{}) { resolveCtx, resolveCtxCancelFn := context.WithTimeout(ctx, 15*time.Second) defer resolveCtxCancelFn() addr, err := net.DefaultResolver.LookupHost(resolveCtx, host) if err != nil { logger.Debug("Error resolving database address", zap.String("host", host), zap.Error(err)) return nil, nil } addrMap := make(map[string]struct{}, len(addr)) for _, a := range addr { addrMap[a] = struct{}{} } return addr, addrMap } // Tx is used to permit clients to implement custom transaction logic. type Tx interface { ExecContext(context.Context, string, ...interface{}) (sql.Result, error) Loading