Unverified Commit cc786c59 authored by Simon Esposito's avatar Simon Esposito Committed by GitHub
Browse files

Improve db migration setup (#606)

Add support for accepting db url including protocol but keeping backwards compatibility.
Default db ssl mode to 'prefer' if no option is specified.
Resolves #602.
parent 682fb038
Loading
Loading
Loading
Loading
+5 −2
Original line number Diff line number Diff line
@@ -230,14 +230,17 @@ func main() {
}

func dbConnect(multiLogger *zap.Logger, config server.Config) (*sql.DB, string) {
	rawURL := fmt.Sprintf("postgresql://%s", config.GetDatabase().Addresses[0])
	rawURL := config.GetDatabase().Addresses[0]
	if !(strings.HasPrefix(rawURL, "postgresql://") || strings.HasPrefix(rawURL, "postgres://")) {
		rawURL = fmt.Sprintf("postgres://%s", rawURL)
	}
	parsedURL, err := url.Parse(rawURL)
	if err != nil {
		multiLogger.Fatal("Bad database connection URL", zap.Error(err))
	}
	query := parsedURL.Query()
	if len(query.Get("sslmode")) == 0 {
		query.Set("sslmode", "disable")
		query.Set("sslmode", "prefer")
		parsedURL.RawQuery = query.Encode()
	}

+45 −27
Original line number Diff line number Diff line
@@ -37,6 +37,7 @@ import (
)

const (
	dbErrorDatabaseDoesNotExist = "3D000"
	dbErrorDuplicateDatabase    = "42P04"
	migrationTable              = "migration_info"
	dialect                     = "postgres"
@@ -152,14 +153,17 @@ func Parse(args []string, tmpLogger *zap.Logger) {
	ms.parseSubcommand(args[1:], tmpLogger)
	logger := server.NewJSONLogger(os.Stdout, zapcore.InfoLevel, ms.loggerFormat)

	rawURL := fmt.Sprintf("postgresql://%s", ms.dbAddress)
	rawURL := ms.dbAddress
	if !(strings.HasPrefix(rawURL, "postgresql://") || strings.HasPrefix(rawURL, "postgres://")) {
		rawURL = fmt.Sprintf("postgres://%s", rawURL)
	}
	parsedURL, err := url.Parse(rawURL)
	if err != nil {
		logger.Fatal("Bad connection URL", zap.Error(err))
	}
	query := parsedURL.Query()
	if len(query.Get("sslmode")) == 0 {
		query.Set("sslmode", "disable")
		query.Set("sslmode", "prefer")
		parsedURL.RawQuery = query.Encode()
	}

@@ -169,48 +173,62 @@ func Parse(args []string, tmpLogger *zap.Logger) {
	dbname := "nakama"
	if len(parsedURL.Path) > 1 {
		dbname = parsedURL.Path[1:]
	} else {
		// Default dbname to 'nakama'
		parsedURL.Path = "/nakama"
	}

	logger.Info("Database connection", zap.String("dsn", parsedURL.Redacted()))

	parsedURL.Path = ""
	db, err := sql.Open("pgx", parsedURL.String())
	if err != nil {
		logger.Fatal("Failed to open database", zap.Error(err))
	}
	if err = db.Ping(); err != nil {
		logger.Fatal("Error pinging database", zap.Error(err))
	}

	var dbVersion string
	if err = db.QueryRow("SELECT version()").Scan(&dbVersion); err != nil {
		logger.Fatal("Error querying database version", zap.Error(err))
	}
	logger.Info("Database information", zap.String("version", dbVersion))

	if _, err = db.Exec(fmt.Sprintf("CREATE DATABASE %q", dbname)); err != nil {
		if e, ok := err.(pgx.PgError); ok && e.Code == dbErrorDuplicateDatabase {
			logger.Info("Using existing database", zap.String("name", dbname))
		} else {
			logger.Fatal("Database query failed", zap.Error(err))
		}
	} else {
		if e, ok := err.(pgx.PgError); ok && e.Code == dbErrorDatabaseDoesNotExist {
			// Database does not exist, try to create a new one
			logger.Info("Creating new database", zap.String("name", dbname))
			db.Close()
			// Connect to anonymous db
			parsedURL.Path = ""
			db, err = sql.Open("pgx", parsedURL.String())
			if err != nil {
				logger.Fatal("Failed to open database", zap.Error(err))
			}
	_ = db.Close()

	// Append dbname to data source name.
			if _, err = db.Exec(fmt.Sprintf("CREATE DATABASE %s", dbname)); err != nil {
				db.Close()
				logger.Fatal("Failed to create database", zap.Error(err))
			}
			db.Close()
			parsedURL.Path = fmt.Sprintf("/%s", dbname)
			db, err = sql.Open("pgx", parsedURL.String())
			if err != nil {
				db.Close()
				logger.Fatal("Failed to open database", zap.Error(err))
			}
			// Reattempt to get database version
			if err = db.QueryRow("SELECT version()").Scan(&dbVersion); err != nil {
				db.Close()
				logger.Fatal("Error querying database version", zap.Error(err))
			}
		} else {
			db.Close()
			logger.Fatal("Error querying database version", zap.Error(err))
		}
	}
	logger.Info("Database information", zap.String("version", dbVersion))

	if err = db.Ping(); err != nil {
		db.Close()
		logger.Fatal("Error pinging database", zap.Error(err))
	}

	ms.db = db

	exec(logger)
	db.Close()
	os.Exit(0)
}