Commit e5da3c84 authored by Mo Firouz's avatar Mo Firouz
Browse files

Better strategy for database checking. Update Google Login authentication flow.

parent 7277e37f
Loading
Loading
Loading
Loading
+24 −20
Original line number Diff line number Diff line
@@ -27,9 +27,11 @@ import (

	"github.com/rubenv/sql-migrate"
	"go.uber.org/zap"
	"github.com/lib/pq"
)

const (
	dbErrorDuplicateDatabase = "42P04"
	migrationTable = "migration_info"
	dialect        = "postgres"
	defaultLimit   = -1
@@ -104,21 +106,27 @@ func MigrateParse(args []string, logger *zap.Logger) {

	ms.parseSubcommand(args[1:])

	rawurl := fmt.Sprintf("postgresql://%s?sslmode=disable", ms.dbAddress)
	url, err := url.Parse(rawurl)
	rawUrl := fmt.Sprintf("postgresql://%s", ms.dbAddress)
	parsedUrl, err := url.Parse(rawUrl)
	if err != nil {
		logger.Fatal("Bad connection URL", zap.Error(err))
	}

	query := parsedUrl.Query()
	if len(query.Get("sslmode")) == 0 {
		query.Set("sslmode", "disable")
		parsedUrl.RawQuery = query.Encode()
	}

	dbname := "nakama"
	if len(url.Path) > 1 {
		dbname = url.Path[1:]
	if len(parsedUrl.Path) > 1 {
		dbname = parsedUrl.Path[1:]
	}

	logger.Info("Database connection", zap.String("db", ms.dbAddress))

	url.Path = ""
	db, err := sql.Open(dialect, url.String())
	parsedUrl.Path = ""
	db, err := sql.Open(dialect, parsedUrl.String())
	if err != nil {
		logger.Fatal("Failed to open database", zap.Error(err))
	}
@@ -132,24 +140,20 @@ func MigrateParse(args []string, logger *zap.Logger) {
	}
	logger.Info("Database information", zap.String("version", dbVersion))

	var exists bool
	err = db.QueryRow("SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname = $1)", dbname).Scan(&exists)
start:
	switch {
	case err != nil:
		logger.Fatal("Database query failed", zap.Error(err))
	case !exists:
		_, err = db.Exec(fmt.Sprintf("CREATE DATABASE %s", dbname))
		exists = err == nil
		goto start
	case exists:
	if _, err = db.Exec(fmt.Sprintf("CREATE DATABASE %s", dbname)); err != nil {
		if e, ok := err.(*pq.Error); ok && e.Code == dbErrorDuplicateDatabase {
			logger.Info("Using existing database", zap.String("name", dbname))
		} else {
			logger.Fatal("Database query failed", zap.Error(err))
		}
	} else {
		logger.Info("Creating new database", zap.String("name", dbname))
	}
	db.Close()

	// Append dbname to data source name.
	url.Path = fmt.Sprintf("/%s", dbname)
	db, err = sql.Open(dialect, url.String())
	parsedUrl.Path = fmt.Sprintf("/%s", dbname)
	db, err = sql.Open(dialect, parsedUrl.String())
	if err != nil {
		logger.Fatal("Failed to open database", zap.Error(err))
	}
+7 −3
Original line number Diff line number Diff line
@@ -81,7 +81,7 @@ func main() {
	multiLogger.Info("Data directory", zap.String("path", config.GetDataDir()))
	multiLogger.Info("Database connections", zap.Strings("dsns", config.GetDatabase().Addresses))

	db, dbVersion := dbConnect(multiLogger, config.GetDatabase().Addresses)
	db, dbVersion := dbConnect(multiLogger, config)
	multiLogger.Info("Database information", zap.String("version", dbVersion))

	// Check migration status and log if the schema has diverged.
@@ -149,9 +149,9 @@ func main() {
	select {}
}

func dbConnect(multiLogger *zap.Logger, dsns []string) (*sql.DB, string) {
func dbConnect(multiLogger *zap.Logger, config server.Config) (*sql.DB, string) {
	// TODO config database pooling
	rawurl := fmt.Sprintf("postgresql://%s", dsns[0])
	rawurl := fmt.Sprintf("postgresql://%s", config.GetDatabase().Addresses[0])
	url, err := url.Parse(rawurl)
	if err != nil {
		multiLogger.Fatal("Bad connection URL", zap.Error(err))
@@ -175,6 +175,10 @@ func dbConnect(multiLogger *zap.Logger, dsns []string) (*sql.DB, string) {
		multiLogger.Fatal("Error pinging database", zap.Error(err))
	}

	db.SetConnMaxLifetime(time.Millisecond * time.Duration(config.GetDatabase().ConnMaxLifetimeMs))
	db.SetMaxOpenConns(config.GetDatabase().MaxOpenConns)
	db.SetMaxIdleConns(config.GetDatabase().MaxIdleConns)

	var dbVersion string
	if err := db.QueryRow("SELECT version()").Scan(&dbVersion); err != nil {
		multiLogger.Fatal("Error querying database version", zap.Error(err))
+132 −13
Original line number Diff line number Diff line
@@ -29,10 +29,17 @@ import (
	"strconv"
	"strings"
	"time"
	"github.com/dgrijalva/jwt-go"
	"crypto/rsa"
	"sync"
	"crypto"
)

// Client is responsible for making calls to different providers
type Client struct {
	sync.RWMutex
	googleCerts          []*rsa.PublicKey
	googleCertsRefreshAt int64
	client           *http.Client
	gamecenterCaCert *x509.Certificate
}
@@ -63,12 +70,22 @@ type facebookFriends struct {
	Paging facebookPaging    `json:"paging"`
}

// GoogleProfile is an abbreviated version of a Google profile.
// GoogleProfile is an abbreviated version of a Google profile extracted from in a verified ID token.
type GoogleProfile struct {
	ID     string `json:"id"`
	Name   string `json:"name"`
	// Fields available in all tokens.
	Iss string `json:"iss"`
	Sub string `json:"sub"`
	Azp string `json:"azp"`
	Aud string `json:"aud"`
	Iat string `json:"iat"`
	Exp string `json:"exp"`
	// Fields available only if the user granted the "profile" and "email" OAuth scopes.
	Email         string `json:"email"`
	Gender string `json:"gender"`
	EmailVerified string `json:"email_verified"`
	Name          string `json:"name"`
	Picture       string `json:"picture"`
	GivenName     string `json:"given_name"`
	FamilyName    string `json:"family_name"`
	Locale        string `json:"locale"`
}

@@ -160,14 +177,116 @@ func (c *Client) GetFacebookFriends(accessToken string) ([]FacebookProfile, erro
	}
}

// GetGoogleProfile retrieves the user's Google Profile given the accessToken
func (c *Client) GetGoogleProfile(accessToken string) (*GoogleProfile, error) {
	path := "https://www.googleapis.com/oauth2/v2/userinfo?alt=json"
	var profile GoogleProfile
	err := c.request("google profile", path, map[string]string{"Authorization": "Bearer " + accessToken}, &profile)
// CheckGoogleToken extracts the user's Google Profile from a given ID token.
func (c *Client) CheckGoogleToken(idToken string) (*GoogleProfile, error) {
	c.RLock()
	if c.googleCertsRefreshAt < time.Now().UTC().Unix() {
		// Release the read lock and perform a certificate refresh.
		c.RUnlock()
		c.Lock()
		if c.googleCertsRefreshAt < time.Now().UTC().Unix() {
			certs := make(map[string]string, 2)
			err := c.request("google cert", "https://www.googleapis.com/oauth2/v1/certs", nil, &certs)
			if err != nil {
				c.Unlock()
				return nil, err
			}
			newCerts := make([]*rsa.PublicKey, 0, 2)
			var newRefreshAt int64
			for _, data := range certs {
				currentBlock, _ := pem.Decode([]byte(data))
				if currentBlock == nil {
					// Block was invalid, ignore it and try the next.
					continue
				}
				currentCert, err := x509.ParseCertificate(currentBlock.Bytes)
				if err != nil {
					// Certificate was invalid, ignore it and try the next.
					continue
				}
				t := time.Now()
				if currentCert.NotBefore.After(t) || currentCert.NotAfter.After(t) {
					// Certificate not yet valid or has already expired, skip it.
					continue
				}
				pub, ok := currentCert.PublicKey.(*rsa.PublicKey)
				if !ok {

				}
				newCerts = append(newCerts, pub)
				if newRefreshAt == 0 || newRefreshAt > currentCert.NotAfter.UTC().Unix() {
					// Refresh all certs 1 hour before the soonest expiry is due.
					newRefreshAt = currentCert.NotAfter.UTC().Unix() - 3600
				}
			}
			if len(newCerts) == 0 {
				c.Unlock()
				return nil, errors.New("error finding valid google cert")
			}
			c.googleCerts = newCerts
			c.googleCertsRefreshAt = newRefreshAt
		}
		c.Unlock()
		c.RLock()
	}

	var err error
	var token *jwt.Token
	for _, cert := range c.googleCerts {
		// Try to parse and verify the token with each of the currently available certificates.
		token, err = jwt.Parse(idToken, func(token *jwt.Token) (interface{}, error) {
			if s, ok := token.Method.(*jwt.SigningMethodRSA); !ok || s.Hash != crypto.SHA256 {
				return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
			}
			claims := token.Claims.(jwt.MapClaims)
			if !claims.VerifyIssuer("accounts.google.com", true) && !claims.VerifyIssuer("https://accounts.google.com", true) {
				return nil, fmt.Errorf("unexpected issuer: %v", claims["iss"])
			}
			return cert, nil
		})
		if err == nil {
			// If any certificate worked, the token is valid.
			break
		}
	}

	// All verification attempts failed.
	if token == nil {
		c.RUnlock()
		return nil, errors.New("google id token invalid")
	}

	claims := token.Claims.(jwt.MapClaims)
	profile := GoogleProfile{
		Iss: claims["iss"].(string),
		Sub: claims["sub"].(string),
		Azp: claims["azp"].(string),
		Aud: claims["aud"].(string),
		Iat: claims["iat"].(string),
		Exp: claims["exp"].(string),
	}
	if v, ok := claims["email"]; ok {
		profile.Email = v.(string)
	}
	if v, ok := claims["email_verified"]; ok {
		profile.EmailVerified = v.(string)
	}
	if v, ok := claims["name"]; ok {
		profile.Name = v.(string)
	}
	if v, ok := claims["picture"]; ok {
		profile.Picture = v.(string)
	}
	if v, ok := claims["given_name"]; ok {
		profile.GivenName = v.(string)
	}
	if v, ok := claims["family_name"]; ok {
		profile.FamilyName = v.(string)
	}
	if v, ok := claims["locale"]; ok {
		profile.Locale = v.(string)
	}

	return &profile, nil
}

+2 −2
Original line number Diff line number Diff line
@@ -170,7 +170,7 @@ func (p *pipeline) linkGoogle(logger *zap.Logger, session session, envelope *Env
		return
	}

	googleProfile, err := p.socialClient.GetGoogleProfile(accessToken)
	googleProfile, err := p.socialClient.CheckGoogleToken(accessToken)
	if err != nil {
		logger.Warn("Could not get Google profile", zap.Error(err))
		session.Send(ErrorMessage(envelope.CollationId, USER_LINK_PROVIDER_UNAVAILABLE, "Could not get Google profile"), true)
@@ -186,7 +186,7 @@ AND NOT EXISTS
     FROM users
     WHERE google_id = $2)`,
		session.UserID(),
		googleProfile.ID,
		googleProfile.Sub,
		nowMs())

	if err != nil {
+5 −5
Original line number Diff line number Diff line
@@ -631,7 +631,7 @@ func (a *authenticationService) loginGoogle(authReq *AuthenticateRequest) (strin
		return "", "", 0, "Invalid Google access token, no spaces or control characters allowed", BAD_INPUT
	}

	googleProfile, err := a.socialClient.GetGoogleProfile(accessToken)
	googleProfile, err := a.socialClient.CheckGoogleToken(accessToken)
	if err != nil {
		a.logger.Warn("Could not get Google profile", zap.Error(err))
		return "", "", 0, errorCouldNotLogin, AUTH_ERROR
@@ -641,7 +641,7 @@ func (a *authenticationService) loginGoogle(authReq *AuthenticateRequest) (strin
	var handle string
	var disabledAt int64
	err = a.db.QueryRow("SELECT id, handle, disabled_at FROM users WHERE google_id = $1",
		googleProfile.ID).
		googleProfile.Sub).
		Scan(&userID, &handle, &disabledAt)
	if err != nil {
		if err == sql.ErrNoRows {
@@ -960,7 +960,7 @@ func (a *authenticationService) registerGoogle(tx *sql.Tx, authReq *Authenticate
		return "", "", "", "Invalid Google access token, no spaces or control characters allowed", BAD_INPUT
	}

	googleProfile, err := a.socialClient.GetGoogleProfile(accessToken)
	googleProfile, err := a.socialClient.CheckGoogleToken(accessToken)
	if err != nil {
		a.logger.Warn("Could not get Google profile", zap.Error(err))
		return "", "", "", errorCouldNotRegister, AUTH_ERROR
@@ -982,7 +982,7 @@ WHERE NOT EXISTS
 WHERE google_id = $3::VARCHAR)`,
		userID,
		handle,
		googleProfile.ID,
		googleProfile.Sub,
		updatedAt)

	if err != nil {
@@ -998,7 +998,7 @@ WHERE NOT EXISTS
		return "", "", "", errorCouldNotRegister, RUNTIME_EXCEPTION
	}

	return userID, handle, googleProfile.ID, "", 0
	return userID, handle, googleProfile.Sub, "", 0
}

func (a *authenticationService) registerGameCenter(tx *sql.Tx, authReq *AuthenticateRequest) (string, string, string, string, Error_Code) {