Commit f64d8e27 authored by Simon Esposito's avatar Simon Esposito
Browse files

Upgrade pgx to v5

Use trimmed sql-migrate
Config and db creation refactor.
parent 8dc44aa3
Loading
Loading
Loading
Loading
+10 −3
Original line number Diff line number Diff line
@@ -18,11 +18,11 @@ require (
	github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0
	github.com/heroiclabs/nakama-common v1.30.0
	github.com/jackc/pgconn v1.14.0
	github.com/heroiclabs/sql-migrate v0.0.0-20230615133120-fb3ad977aaaf
	github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa
	github.com/jackc/pgtype v1.14.0
	github.com/jackc/pgx/v4 v4.18.1
	github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59
	github.com/jackc/pgx/v5 v5.4.1
	github.com/prometheus/client_golang v1.16.0
	github.com/rubenv/sql-migrate v1.2.0
	github.com/stretchr/testify v1.8.4
	github.com/twmb/murmur3 v1.1.8
	github.com/uber-go/tally/v4 v4.1.7
@@ -61,13 +61,18 @@ require (
	github.com/gofrs/uuid v4.3.0+incompatible // indirect
	github.com/golang/glog v1.1.0 // indirect
	github.com/golang/protobuf v1.5.3 // indirect
<<<<<<< HEAD
	github.com/google/pprof v0.0.0-20230207041349-798e818bf904 // indirect
	github.com/jackc/chunkreader/v2 v2.0.1 // indirect
=======
	github.com/jackc/pgconn v1.14.0 // indirect
>>>>>>> bf0c798f (Upgrade pgx to v5)
	github.com/jackc/pgio v1.0.0 // indirect
	github.com/jackc/pgpassfile v1.0.0 // indirect
	github.com/jackc/pgproto3/v2 v2.3.2 // indirect
	github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
	github.com/klauspost/compress v1.15.2 // indirect
	github.com/lib/pq v1.10.0 // indirect
	github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
	github.com/mschoch/smat v0.2.0 // indirect
	github.com/pkg/errors v0.9.1 // indirect
@@ -75,10 +80,12 @@ require (
	github.com/prometheus/client_model v0.3.0 // indirect
	github.com/prometheus/common v0.42.0 // indirect
	github.com/prometheus/procfs v0.10.1 // indirect
	github.com/rogpeppe/go-internal v1.8.0 // indirect
	go.uber.org/multierr v1.6.0 // indirect
	golang.org/x/net v0.12.0 // indirect
	golang.org/x/sys v0.10.0 // indirect
	golang.org/x/text v0.11.0 // indirect
	golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
	google.golang.org/appengine v1.6.7 // indirect
	google.golang.org/genproto v0.0.0-20230530153820-e85fd2cbaebc // indirect
	google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc // indirect
+49 −473

File changed.

Preview size limit exceeded, changes collapsed.

+37 −9
Original line number Diff line number Diff line
@@ -34,7 +34,8 @@ import (
	"github.com/heroiclabs/nakama/v3/migrate"
	"github.com/heroiclabs/nakama/v3/server"
	"github.com/heroiclabs/nakama/v3/social"
	_ "github.com/jackc/pgx/v4/stdlib"
	"github.com/jackc/pgx/v5/stdlib"
	_ "github.com/jackc/pgx/v5/stdlib" // Blank import to register SQL driver
	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"
	"google.golang.org/protobuf/encoding/protojson"
@@ -71,13 +72,33 @@ func main() {

	tmpLogger := server.NewJSONLogger(os.Stdout, zapcore.InfoLevel, server.JSONFormat)

	ctx, ctxCancelFn := context.WithCancel(context.Background())

	if len(os.Args) > 1 {
		switch os.Args[1] {
		case "--version":
			fmt.Println(semver)
			return
		case "migrate":
			migrate.Parse(os.Args[2:], tmpLogger)
			config := server.ParseArgs(tmpLogger, os.Args[2:])
			config.ValidateDatabase(tmpLogger)
			db := server.DbConnect(ctx, tmpLogger, config, true)
			defer db.Close()

			conn, err := db.Conn(ctx)
			if err != nil {
				tmpLogger.Fatal("Failed to acquire db conn for migration", zap.Error(err))
			}

			if err = conn.Raw(func(driverConn any) error {
				pgxConn := driverConn.(*stdlib.Conn).Conn()
				migrate.RunCmd(ctx, tmpLogger, pgxConn, os.Args[2], config.GetLimit(), config.GetLogger().Format)

				return nil
			}); err != nil {
				tmpLogger.Fatal("Failed to acquire pgx conn for migration", zap.Error(err))
			}

			return
		case "check":
			// Parse any command line args to look up runtime path.
@@ -108,7 +129,7 @@ func main() {

	config := server.ParseArgs(tmpLogger, os.Args)
	logger, startupLogger := server.SetupLogging(tmpLogger, config)
	configWarnings := server.CheckConfig(logger, config)
	configWarnings := config.Validate(logger)

	startupLogger.Info("Nakama starting")
	startupLogger.Info("Node", zap.String("name", config.GetName()), zap.String("version", semver), zap.String("runtime", runtime.Version()), zap.Int("cpu", runtime.NumCPU()), zap.Int("proc", runtime.GOMAXPROCS(0)))
@@ -125,14 +146,21 @@ func main() {
	}
	startupLogger.Info("Database connections", zap.Strings("dsns", redactedAddresses))

	// Global server context.
	ctx, ctxCancelFn := context.WithCancel(context.Background())

	db, dbVersion := server.DbConnect(ctx, startupLogger, config)
	startupLogger.Info("Database information", zap.String("version", dbVersion))
	db := server.DbConnect(ctx, startupLogger, config, false)

	// Check migration status and fail fast if the schema has diverged.
	migrate.StartupCheck(startupLogger, db)
	conn, err := db.Conn(context.Background())
	if err != nil {
		logger.Fatal("Failed to acquire db conn for migration check", zap.Error(err))
	}

	if err = conn.Raw(func(driverConn any) error {
		pgxConn := driverConn.(*stdlib.Conn).Conn()
		migrate.Check(ctx, startupLogger, pgxConn)
		return nil
	}); err != nil {
		logger.Fatal("Failed to acquire pgx conn for migration check", zap.Error(err))
	}

	// Access to social provider integrations.
	socialClient := social.NewClient(logger, 5*time.Second, config.GetGoogleAuth().OAuthConfig)
+61 −168
Original line number Diff line number Diff line
// Copyright 2018 The Nakama Authors
// Copyright 2021 Heroic Labs.
// All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// NOTICE: All information contained herein is, and remains the property of Heroic
// Labs. and its suppliers, if any. The intellectual and technical concepts
// contained herein are proprietary to Heroic Labs. and its suppliers and may be
// covered by U.S. and Foreign Patents, patents in process, and are protected by
// trade secret or copyright law. Dissemination of this information or reproduction
// of this material is strictly forbidden unless prior written permission is
// obtained from Heroic Labs.

package migrate

import (
	"database/sql"
	"context"
	"embed"
	"errors"
	"flag"
	"fmt"
	"math"
	"net/url"
	"os"
	"strings"
	"time"

	"github.com/heroiclabs/nakama/v3/server"
	"github.com/jackc/pgconn"
	"github.com/jackc/pgerrcode"
	_ "github.com/jackc/pgx/v4/stdlib"
	migrate "github.com/rubenv/sql-migrate"
	sqlmigrate "github.com/heroiclabs/sql-migrate"
	"github.com/jackc/pgx/v5"
	_ "github.com/jackc/pgx/v5/stdlib" // Blank import to register SQL driver
	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"
)

const (
	dbErrorDatabaseDoesNotExist = pgerrcode.InvalidCatalogName
	migrationTable = "migration_info"
	dialect                     = "postgres"
	defaultLimit   = -1
)

@@ -53,18 +43,17 @@ type statusRow struct {
}

type migrationService struct {
	dbAddress    string
	limit        int
	loggerFormat server.LoggingFormat
	migrations   *migrate.EmbedFileSystemMigrationSource
	db           *sql.DB
	migrations   *sqlmigrate.EmbedFileSystemMigrationSource
	execFn       func(ctx context.Context, logger *zap.Logger, db *pgx.Conn)
}

func StartupCheck(logger *zap.Logger, db *sql.DB) {
	migrate.SetTable(migrationTable)
	migrate.SetIgnoreUnknown(true)
func Check(ctx context.Context, logger *zap.Logger, db *pgx.Conn) {
	sqlmigrate.SetTable(migrationTable)
	sqlmigrate.SetIgnoreUnknown(true)

	ms := &migrate.EmbedFileSystemMigrationSource{
	ms := &sqlmigrate.EmbedFileSystemMigrationSource{
		FileSystem: sqlMigrateFS,
		Root:       "sql",
	}
@@ -73,7 +62,7 @@ func StartupCheck(logger *zap.Logger, db *sql.DB) {
	if err != nil {
		logger.Fatal("Could not find migrations", zap.Error(err))
	}
	records, err := migrate.GetMigrationRecords(db, dialect)
	records, err := sqlmigrate.GetMigrationRecords(ctx, db)
	if err != nil {
		logger.Fatal("Could not get migration records, run `nakama migrate up`", zap.Error(err))
	}
@@ -85,140 +74,37 @@ func StartupCheck(logger *zap.Logger, db *sql.DB) {
	if diff < 0 {
		logger.Warn("DB schema newer, update Nakama", zap.Int64("migrations", int64(math.Abs(float64(diff)))))
	}
	db.Close(ctx)
}

func Parse(args []string, tmpLogger *zap.Logger) {
	if len(args) == 0 {
func RunCmd(ctx context.Context, tmpLogger *zap.Logger, db *pgx.Conn, cmd string, limit int, loggerFormat string) {
	if cmd == "" {
		tmpLogger.Fatal("Migrate requires a subcommand. Available commands are: 'up', 'down', 'redo', 'status'.")
	}

	migrate.SetTable(migrationTable)
	migrate.SetIgnoreUnknown(true)
	sqlmigrate.SetTable(migrationTable)
	sqlmigrate.SetIgnoreUnknown(true)
	ms := &migrationService{
		migrations: &migrate.EmbedFileSystemMigrationSource{
		migrations: &sqlmigrate.EmbedFileSystemMigrationSource{
			FileSystem: sqlMigrateFS,
			Root:       "sql",
		},
		limit: limit,
	}

	var exec func(logger *zap.Logger)
	switch args[0] {
	case "up":
		exec = ms.up
	case "down":
		exec = ms.down
	case "redo":
		exec = ms.redo
	case "status":
		exec = ms.status
	default:
		tmpLogger.Fatal("Unrecognized migrate subcommand. Available commands are: 'up', 'down', 'redo', 'status'.")
		return
	}

	ms.parseSubcommand(args[1:], tmpLogger)
	ms.parseParams(tmpLogger, cmd, loggerFormat)
	logger := server.NewJSONLogger(os.Stdout, zapcore.InfoLevel, ms.loggerFormat)

	rawURL := ms.dbAddress
	if !(strings.HasPrefix(rawURL, "postgresql://") || strings.HasPrefix(rawURL, "postgres://")) {
		rawURL = fmt.Sprintf("postgres://%s", rawURL)
	}
	parsedURL, err := url.Parse(rawURL)
	if err != nil {
		logger.Fatal("Bad connection URL", zap.Error(err))
	}
	query := parsedURL.Query()
	var queryUpdated bool
	if len(query.Get("sslmode")) == 0 {
		query.Set("sslmode", "prefer")
		queryUpdated = true
	}
	//if len(query.Get("statement_cache_mode")) == 0 {
	//	query.Set("statement_cache_mode", "describe")
	//	queryUpdated = true
	//}
	if queryUpdated {
		parsedURL.RawQuery = query.Encode()
	}

	if len(parsedURL.User.Username()) < 1 {
		parsedURL.User = url.User("root")
	}
	dbname := "nakama"
	if len(parsedURL.Path) > 1 {
		dbname = parsedURL.Path[1:]
	} else {
		// Default dbname to 'nakama'
		parsedURL.Path = "/nakama"
	}

	logger.Info("Database connection", zap.String("dsn", parsedURL.Redacted()))

	db, err := sql.Open("pgx", parsedURL.String())
	if err != nil {
		logger.Fatal("Failed to open database", zap.Error(err))
	}

	var nakamaDBExists bool
	if err = db.QueryRow("SELECT EXISTS (SELECT 1 from pg_database WHERE datname = $1)", dbname).Scan(&nakamaDBExists); err != nil {
		var pgErr *pgconn.PgError
		if errors.As(err, &pgErr) && pgErr.Code == dbErrorDatabaseDoesNotExist {
			nakamaDBExists = false
		} else {
			db.Close()
			logger.Fatal("Failed to check if db exists", zap.String("db", dbname), zap.Error(err))
		}
	}

	if !nakamaDBExists {
		// Database does not exist, create it
		logger.Info("Creating new database", zap.String("name", dbname))
		db.Close()
		// Connect to anonymous db
		parsedURL.Path = ""
		db, err = sql.Open("pgx", parsedURL.String())
		if err != nil {
			logger.Fatal("Failed to open database", zap.Error(err))
		}
		if _, err = db.Exec(fmt.Sprintf("CREATE DATABASE %q", dbname)); err != nil {
			db.Close()
			logger.Fatal("Failed to create database", zap.Error(err))
		}
		db.Close()
		parsedURL.Path = fmt.Sprintf("/%s", dbname)
		db, err = sql.Open("pgx", parsedURL.String())
		if err != nil {
			db.Close()
			logger.Fatal("Failed to open database", zap.Error(err))
		}
	}

	// Get database version
	var dbVersion string
	if err = db.QueryRow("SELECT version()").Scan(&dbVersion); err != nil {
		db.Close()
		logger.Fatal("Error querying database version", zap.Error(err))
	}

	logger.Info("Database information", zap.String("version", dbVersion))

	if err = db.Ping(); err != nil {
		db.Close()
		logger.Fatal("Error pinging database", zap.Error(err))
	}

	ms.db = db

	exec(logger)
	db.Close()
	ms.runMigration(ctx, logger, db)
	db.Close(ctx)
}

func (ms *migrationService) up(logger *zap.Logger) {
func (ms *migrationService) up(ctx context.Context, logger *zap.Logger, db *pgx.Conn) {
	if ms.limit < defaultLimit {
		ms.limit = 0
	}

	appliedMigrations, err := migrate.ExecMax(ms.db, dialect, ms.migrations, migrate.Up, ms.limit)
	appliedMigrations, err := sqlmigrate.ExecMax(ctx, db, ms.migrations, sqlmigrate.Up, ms.limit)
	if err != nil {
		logger.Fatal("Failed to apply migrations", zap.Int("count", appliedMigrations), zap.Error(err))
	}
@@ -226,12 +112,12 @@ func (ms *migrationService) up(logger *zap.Logger) {
	logger.Info("Successfully applied migration", zap.Int("count", appliedMigrations))
}

func (ms *migrationService) down(logger *zap.Logger) {
func (ms *migrationService) down(ctx context.Context, logger *zap.Logger, db *pgx.Conn) {
	if ms.limit < defaultLimit {
		ms.limit = 1
	}

	appliedMigrations, err := migrate.ExecMax(ms.db, dialect, ms.migrations, migrate.Down, ms.limit)
	appliedMigrations, err := sqlmigrate.ExecMax(ctx, db, ms.migrations, sqlmigrate.Down, ms.limit)
	if err != nil {
		logger.Fatal("Failed to migrate back", zap.Int("count", appliedMigrations), zap.Error(err))
	}
@@ -239,25 +125,25 @@ func (ms *migrationService) down(logger *zap.Logger) {
	logger.Info("Successfully migrated back", zap.Int("count", appliedMigrations))
}

func (ms *migrationService) redo(logger *zap.Logger) {
func (ms *migrationService) redo(ctx context.Context, logger *zap.Logger, db *pgx.Conn) {
	if ms.limit > defaultLimit {
		logger.Warn("Limit is ignored when redo is invoked")
	}

	appliedMigrations, err := migrate.ExecMax(ms.db, dialect, ms.migrations, migrate.Down, 1)
	appliedMigrations, err := sqlmigrate.ExecMax(ctx, db, ms.migrations, sqlmigrate.Down, 1)
	if err != nil {
		logger.Fatal("Failed to migrate back", zap.Int("count", appliedMigrations), zap.Error(err))
	}
	logger.Info("Successfully migrated back", zap.Int("count", appliedMigrations))

	appliedMigrations, err = migrate.ExecMax(ms.db, dialect, ms.migrations, migrate.Up, 1)
	appliedMigrations, err = sqlmigrate.ExecMax(ctx, db, ms.migrations, sqlmigrate.Up, 1)
	if err != nil {
		logger.Fatal("Failed to apply migrations", zap.Int("count", appliedMigrations), zap.Error(err))
	}
	logger.Info("Successfully applied migration", zap.Int("count", appliedMigrations))
}

func (ms *migrationService) status(logger *zap.Logger) {
func (ms *migrationService) status(ctx context.Context, logger *zap.Logger, db *pgx.Conn) {
	if ms.limit > defaultLimit {
		logger.Warn("Limit is ignored when status is invoked")
	}
@@ -267,7 +153,7 @@ func (ms *migrationService) status(logger *zap.Logger) {
		logger.Fatal("Could not find migrations", zap.Error(err))
	}

	records, err := migrate.GetMigrationRecords(ms.db, dialect)
	records, err := sqlmigrate.GetMigrationRecords(ctx, db)
	if err != nil {
		logger.Fatal("Could not get migration records", zap.Error(err))
	}
@@ -305,19 +191,18 @@ func (ms *migrationService) status(logger *zap.Logger) {
	}
}

func (ms *migrationService) parseSubcommand(args []string, logger *zap.Logger) {
	var loggerFormat string
	flags := flag.NewFlagSet("migrate", flag.ExitOnError)
	flags.StringVar(&ms.dbAddress, "database.address", "root@localhost:26257", "Address of CockroachDB server (username:password@address:port/dbname)")
	flags.IntVar(&ms.limit, "limit", defaultLimit, "Number of migrations to apply forwards or backwards.")
	flags.StringVar(&loggerFormat, "logger.format", "json", "Number of migrations to apply forwards or backwards.")

	if err := flags.Parse(args); err != nil {
		logger.Fatal("Could not parse migration flags.")
	}

	if ms.dbAddress == "" {
		logger.Fatal("Database connection details are required.")
func (ms *migrationService) parseParams(logger *zap.Logger, cmd, loggerFormat string) {
	switch cmd {
	case "up":
		ms.execFn = ms.up
	case "down":
		ms.execFn = ms.down
	case "redo":
		ms.execFn = ms.redo
	case "status":
		ms.execFn = ms.status
	default:
		logger.Fatal("Unrecognized migrate subcommand. Available commands are: 'up', 'down', 'redo', 'status'.")
	}

	ms.loggerFormat = server.JSONFormat
@@ -332,3 +217,11 @@ func (ms *migrationService) parseSubcommand(args []string, logger *zap.Logger) {
		logger.Fatal("Logger mode invalid, must be one of: '', 'json', or 'stackdriver")
	}
}

func (ms *migrationService) runMigration(ctx context.Context, logger *zap.Logger, db *pgx.Conn) {
	if ms.execFn == nil {
		logger.Fatal("Cannot run migration without a set command")
	}

	ms.execFn(ctx, logger, db)
}
+1 −1
Original line number Diff line number Diff line
@@ -20,7 +20,7 @@ import (

	"github.com/gofrs/uuid/v5"
	"github.com/heroiclabs/nakama-common/api"
	"github.com/jackc/pgconn"
	"github.com/jackc/pgx/v5/pgconn"
	"go.uber.org/zap"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
Loading