Commit 84a1f990 authored by Andrei Mihu
Browse files

Periodically check database hostname for underlying address changes.

parent e9501524
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -4,6 +4,9 @@ All notable changes to this project are documented below.
The format is based on [keep a changelog](http://keepachangelog.com) and this project uses [semantic versioning](http://semver.org).

## [Unreleased]
### Added
- Periodically check database hostname for underlying address changes.

### Fixed
- Fix optimistic email imports when linking social profiles.

+3 −3
Original line number Diff line number Diff line
@@ -114,12 +114,12 @@ func main() {
	}
	startupLogger.Info("Database connections", zap.Strings("dsns", redactedAddresses))

	db, dbVersion := server.DbConnect(startupLogger, config)
	startupLogger.Info("Database information", zap.String("version", dbVersion))

	// Global server context.
	ctx, ctxCancelFn := context.WithCancel(context.Background())

	db, dbVersion := server.DbConnect(ctx, startupLogger, config)
	startupLogger.Info("Database information", zap.String("version", dbVersion))

	// Check migration status and fail fast if the schema has diverged.
	migrate.StartupCheck(startupLogger, db)

+117 −12
Original line number Diff line number Diff line
@@ -19,23 +19,27 @@ import (
	"database/sql"
	"errors"
	"fmt"
	"net"
	"net/url"
	"strings"
	"time"

	"github.com/jackc/pgconn"
	"github.com/jackc/pgerrcode"
	"github.com/jackc/pgx/v4/stdlib"
	"go.uber.org/zap"
)

func DbConnect(multiLogger *zap.Logger, config Config) (*sql.DB, string) {
var ErrDatabaseDriverMismatch = errors.New("database driver mismatch")

func DbConnect(ctx context.Context, logger *zap.Logger, config Config) (*sql.DB, string) {
	rawURL := config.GetDatabase().Addresses[0]
	if !(strings.HasPrefix(rawURL, "postgresql://") || strings.HasPrefix(rawURL, "postgres://")) {
		rawURL = fmt.Sprintf("postgres://%s", rawURL)
	}
	parsedURL, err := url.Parse(rawURL)
	if err != nil {
		multiLogger.Fatal("Bad database connection URL", zap.Error(err))
		logger.Fatal("Bad database connection URL", zap.Error(err))
	}
	query := parsedURL.Query()
	if len(query.Get("sslmode")) == 0 {
@@ -50,19 +54,22 @@ func DbConnect(multiLogger *zap.Logger, config Config) (*sql.DB, string) {
		parsedURL.Path = "/nakama"
	}

	multiLogger.Debug("Complete database connection URL", zap.String("raw_url", parsedURL.String()))
	// Resolve initial database address based on host before connecting.
	resolvedAddr, resolvedAddrMap := dbResolveAddress(ctx, logger, parsedURL.Host)

	logger.Debug("Complete database connection URL", zap.String("raw_url", parsedURL.String()))
	db, err := sql.Open("pgx", parsedURL.String())
	if err != nil {
		multiLogger.Fatal("Error connecting to database", zap.Error(err))
		logger.Fatal("Error connecting to database", zap.Error(err))
	}
	// Limit the time allowed to ping database and get version to 15 seconds total.
	ctx, ctxCancelFn := context.WithTimeout(context.Background(), 15*time.Second)
	defer ctxCancelFn()
	if err = db.PingContext(ctx); err != nil {
	// Limit max time allowed across database ping and version fetch to 15 seconds total.
	pingCtx, pingCtxCancelFn := context.WithTimeout(ctx, 15*time.Second)
	defer pingCtxCancelFn()
	if err = db.PingContext(pingCtx); err != nil {
		if strings.HasSuffix(err.Error(), "does not exist (SQLSTATE 3D000)") {
			multiLogger.Fatal("Database schema not found, run `nakama migrate up`", zap.Error(err))
			logger.Fatal("Database schema not found, run `nakama migrate up`", zap.Error(err))
		}
		multiLogger.Fatal("Error pinging database", zap.Error(err))
		logger.Fatal("Error pinging database", zap.Error(err))
	}

	db.SetConnMaxLifetime(time.Millisecond * time.Duration(config.GetDatabase().ConnMaxLifetimeMs))
@@ -70,13 +77,111 @@ func DbConnect(multiLogger *zap.Logger, config Config) (*sql.DB, string) {
	db.SetMaxIdleConns(config.GetDatabase().MaxIdleConns)

	var dbVersion string
	if err = db.QueryRowContext(ctx, "SELECT version()").Scan(&dbVersion); err != nil {
		multiLogger.Fatal("Error querying database version", zap.Error(err))
	if err = db.QueryRowContext(pingCtx, "SELECT version()").Scan(&dbVersion); err != nil {
		logger.Fatal("Error querying database version", zap.Error(err))
	}

	// Periodically check database hostname for underlying address changes.
	//
	// The goroutine runs until ctx is cancelled. Once a minute it re-resolves the
	// database host and compares the result against the last known address set.
	// When the set differs it checks out up to MaxOpenConns connections from the
	// pool, closes their underlying driver connections, and releases them so that
	// the driver opens fresh connections on next use.
	go func() {
		ticker := time.NewTicker(1 * time.Minute)
		// Stop the ticker when the goroutine exits so its resources are released.
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				newResolvedAddr, newResolvedAddrMap := dbResolveAddress(ctx, logger, parsedURL.Host)
				if len(resolvedAddr) == 0 {
					// Could only happen when initial resolve above failed, and all resolves since have also failed.
					// Trust the database driver in this case.
					resolvedAddr = newResolvedAddr
					resolvedAddrMap = newResolvedAddrMap
					break
				}
				if len(newResolvedAddr) == 0 {
					// New addresses failed to resolve, but had previous ones. Trust the database driver in this case.
					break
				}

				// Check for any changes in the resolved addresses. With equal sizes, the sets
				// differ only if some new address is missing from the previous set.
				drain := len(resolvedAddrMap) != len(newResolvedAddrMap)
				if !drain {
					for addr := range newResolvedAddrMap {
						if _, found := resolvedAddrMap[addr]; !found {
							drain = true
							break
						}
					}
				}
				if !drain {
					// No changes.
					break
				}

				// Changes found. Drain the pool and allow the database driver to open fresh connections.
				// Rely on the database driver to re-do its own hostname to address resolution.
				var acquired int
				conns := make([]*sql.Conn, 0, config.GetDatabase().MaxOpenConns)
				for acquired < config.GetDatabase().MaxOpenConns {
					// Count the attempt up front so a failing acquire cannot loop forever.
					acquired++
					conn, err := db.Conn(ctx)
					if err != nil {
						if errors.Is(err, context.Canceled) || ctx.Err() != nil {
							// Server shutting down. Release anything already checked out
							// so the pool is not left holding in-use connections.
							for _, c := range conns {
								_ = c.Close() // Best effort during shutdown.
							}
							return
						}
						// Log errors acquiring connections, but proceed without the failed connection anyway.
						logger.Error("Error acquiring database connection", zap.Error(err))
						continue
					}
					conns = append(conns, conn)
				}

				logger.Warn("Database rotating all open connections due to address change",
					zap.Int("count", len(conns)),
					zap.Strings("previous", resolvedAddr),
					zap.Strings("updated", newResolvedAddr))

				resolvedAddr = newResolvedAddr
				resolvedAddrMap = newResolvedAddrMap
				for _, conn := range conns {
					// Close the underlying driver connection first, then release the
					// sql.Conn handle back to the pool.
					if err := conn.Raw(func(driverConn interface{}) error {
						pgc, ok := driverConn.(*stdlib.Conn)
						if !ok {
							return ErrDatabaseDriverMismatch
						}
						return pgc.Close()
					}); err != nil {
						logger.Error("Error closing database connection", zap.Error(err))
					}
					if err := conn.Close(); err != nil {
						logger.Error("Error releasing database connection", zap.Error(err))
					}
				}
			}
		}
	}()

	return db, dbVersion
}

// dbResolveAddress looks up the addresses that host currently resolves to,
// using the default resolver with a 15 second timeout derived from ctx.
//
// It returns the addresses both as a slice, in the order the resolver gave
// them, and as a set keyed by address for membership checks. If the lookup
// fails the error is logged at debug level and both results are nil.
func dbResolveAddress(ctx context.Context, logger *zap.Logger, host string) ([]string, map[string]struct{}) {
	lookupCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
	defer cancel()

	addresses, lookupErr := net.DefaultResolver.LookupHost(lookupCtx, host)
	if lookupErr != nil {
		logger.Debug("Error resolving database address", zap.String("host", host), zap.Error(lookupErr))
		return nil, nil
	}

	// Build the set form alongside the ordered slice.
	addressSet := make(map[string]struct{}, len(addresses))
	for i := range addresses {
		addressSet[addresses[i]] = struct{}{}
	}
	return addresses, addressSet
}

// Tx is used to permit clients to implement custom transaction logic.
type Tx interface {
	ExecContext(context.Context, string, ...interface{}) (sql.Result, error)