Commit fc970cb8 authored by Andrei Mihu's avatar Andrei Mihu
Browse files

Expose auth functions in runtime. Merged #148

parent b33a885d
Loading
Loading
Loading
Loading
+45 −243
Original line number Diff line number Diff line
@@ -22,13 +22,8 @@ import (
	"google.golang.org/grpc/status"
	"math/rand"
	"time"
	"github.com/satori/go.uuid"
	"go.uber.org/zap"
	"database/sql"
	"github.com/dgrijalva/jwt-go"
	"strings"
	"github.com/lib/pq"
	"golang.org/x/crypto/bcrypt"
)

var (
@@ -45,8 +40,6 @@ func (s *ApiServer) AuthenticateCustomFunc(ctx context.Context, in *api.Authenti
		return nil, status.Error(codes.InvalidArgument, "Custom ID invalid, must be 10-128 bytes.")
	}

	if in.Create == nil || in.Create.Value {
		// Use existing user account if found, otherwise create a new user account.
	username := in.Username
	if username == "" {
		username = generateUsername(s.random)
@@ -56,70 +49,16 @@ func (s *ApiServer) AuthenticateCustomFunc(ctx context.Context, in *api.Authenti
		return nil, status.Error(codes.InvalidArgument, "Username invalid, must be 1-128 bytes.")
	}

		userID := uuid.NewV4().String()
		ts := time.Now().UTC().Unix()
		// NOTE: This query relies on the `custom_id` conflict triggering before the `users_username_key`
		// constraint violation to ensure we fall to the RETURNING case and ignore the new username for
		// existing user accounts. The DO UPDATE SET is to trick the DB into having the data we need to return.
		query := `
INSERT INTO users (id, username, custom_id, created_at, updated_at)
VALUES ($1, $2, $3, $4, $4)
ON CONFLICT (custom_id) DO UPDATE SET custom_id = $3
RETURNING id, username, custom_id, disabled_at`
		params := []interface{}{userID, username, in.Account.Id, ts}
	create := in.Create == nil || in.Create.Value

		var dbUserID string
		var dbUsername string
		var dbCustomId sql.NullString
		var dbDisabledAt int64
		err := s.db.QueryRow(query, params...).Scan(&dbUserID, &dbUsername, &dbCustomId, &dbDisabledAt)
	dbUserID, dbUsername, err := AuthenticateCustom(s.logger, s.db, in.Account.Id, username, create)
	if err != nil {
			if e, ok := err.(*pq.Error); ok && e.Code == dbErrorUniqueViolation && strings.Contains(e.Message, "users_username_key") {
				// Username is already in use by a different account.
				return nil, status.Error(codes.AlreadyExists, "Username is already in use.")
			}
			s.logger.Error("Cannot find or create user with custom ID.", zap.Error(err), zap.Any("input", in))
			return nil, status.Error(codes.Internal, "Error finding or creating user account.")
		}

		if dbDisabledAt != 0 {
			s.logger.Debug("User account is disabled.", zap.Any("input", in))
			return nil, status.Error(codes.Unauthenticated, "Error finding or creating user account.")
		}

		token := generateToken(s.config, dbUserID, dbUsername)
		return &api.Session{Token: token}, nil
	} else {
		// Do not create a new user account.
		query := `
SELECT id, username, disabled_at
FROM users
WHERE custom_id = $1`
		params := []interface{}{in.Account.Id}

		var dbUserID string
		var dbUsername string
		var dbDisabledAt int64
		err := s.db.QueryRow(query, params...).Scan(&dbUserID, &dbUsername, &dbDisabledAt)
		if err != nil {
			if err == sql.ErrNoRows {
				// No user account found.
				return nil, status.Error(codes.NotFound, "User account not found.")
			} else {
				s.logger.Error("Cannot find user with custom ID.", zap.Error(err), zap.Any("input", in))
				return nil, status.Error(codes.Internal, "Error finding user account.")
			}
		}

		if dbDisabledAt != 0 {
			s.logger.Debug("User account is disabled.", zap.Any("input", in))
			return nil, status.Error(codes.Unauthenticated, "Error finding or creating user account.")
		return nil, err
	}

	token := generateToken(s.config, dbUserID, dbUsername)
	return &api.Session{Token: token}, nil
}
}

func (s *ApiServer) AuthenticateDeviceFunc(ctx context.Context, in *api.AuthenticateDevice) (*api.Session, error) {
	if in.Account == nil || in.Account.Id == "" {
@@ -130,8 +69,6 @@ func (s *ApiServer) AuthenticateDeviceFunc(ctx context.Context, in *api.Authenti
		return nil, status.Error(codes.InvalidArgument, "Device ID invalid, must be 10-128 bytes.")
	}

	if in.Create == nil || in.Create.Value {
		// Use existing user account if found, otherwise create a new user account.
	username := in.Username
	if username == "" {
		username = generateUsername(s.random)
@@ -141,91 +78,15 @@ func (s *ApiServer) AuthenticateDeviceFunc(ctx context.Context, in *api.Authenti
		return nil, status.Error(codes.InvalidArgument, "Username invalid, must be 1-128 bytes.")
	}

		var dbUserID string
		var dbUsername string
		fnErr := Transact(s.logger, s.db, func (tx *sql.Tx) error {
			userID := uuid.NewV4().String()
			ts := time.Now().UTC().Unix()
			query := `
INSERT INTO users (id, username, created_at, updated_at)
SELECT $1 AS id,
			 $2 AS username,
			 $4 AS created_at,
			 $4 AS updated_at
WHERE NOT EXISTS
    (SELECT id
     FROM user_device
     WHERE id = $3::VARCHAR)
RETURNING id, username, disabled_at`
			params := []interface{}{userID, username, in.Account.Id, ts}
	create := in.Create == nil || in.Create.Value

			var dbDisabledAt int64
			err := tx.QueryRow(query, params...).Scan(&dbUserID, &dbUsername, &dbDisabledAt)
	dbUserID, dbUsername, err := AuthenticateDevice(s.logger, s.db, in.Account.Id, username, create)
	if err != nil {
				if e, ok := err.(*pq.Error); ok && e.Code == dbErrorUniqueViolation && strings.Contains(e.Message, "users_username_key") {
					return status.Error(codes.AlreadyExists, "Username is already in use.")
				}
				s.logger.Error("Cannot find or create user with device ID.", zap.Error(err), zap.Any("input", in))
				return status.Error(codes.Internal, "Error finding or creating user account.")
			}

			if dbDisabledAt != 0 {
				s.logger.Debug("User account is disabled.", zap.Any("input", in))
				return status.Error(codes.Unauthenticated, "Error finding or creating user account.")
			}

			query = "INSERT INTO user_device (id, user_id) VALUES ($1, $2)"
			params = []interface{}{in.Account.Id, userID}
			_, err = tx.Exec(query, params...)
			if err != nil {
				s.logger.Error("Cannot add device ID.", zap.Error(err), zap.Any("input", in))
				return status.Error(codes.Internal, "Error finding or creating user account.")
			}

			return nil
		})

		if fnErr != nil {
			return nil, fnErr
		return nil, err
	}

	token := generateToken(s.config, dbUserID, dbUsername)
	return &api.Session{Token: token}, nil
	} else {
		query := "SELECT user_id FROM user_device WHERE id = $1"
		params := []interface{}{in.Account.Id}

		var dbUserID string
		err := s.db.QueryRow(query, params...).Scan(&dbUserID)
		if err != nil {
			if err == sql.ErrNoRows {
				// No user account found.
				return nil, status.Error(codes.NotFound, "Device ID not found.")
			} else {
				s.logger.Error("Cannot find user with device ID.", zap.Error(err), zap.Any("input", in))
				return nil, status.Error(codes.Internal, "Error finding user account.")
			}
		}

		query = "SELECT username, disabled_at FROM users WHERE id = $1"
		params = []interface{}{dbUserID}
		var dbUsername string
		var dbDisabledAt int64

		err = s.db.QueryRow(query, params...).Scan(&dbUsername, &dbDisabledAt)
		if err != nil {
			s.logger.Error("Cannot find user with device ID.", zap.Error(err), zap.Any("input", in))
			return nil, status.Error(codes.Internal, "Error finding user account.")
		}

		if dbDisabledAt != 0 {
			s.logger.Debug("User account is disabled.", zap.Any("input", in))
			return nil, status.Error(codes.Unauthenticated, "Error finding or creating user account.")
		}

		token := generateToken(s.config, dbUserID, dbUsername)
		return &api.Session{Token: token}, nil
	}
}

func (s *ApiServer) AuthenticateEmailFunc(ctx context.Context, in *api.AuthenticateEmail) (*api.Session, error) {
@@ -243,10 +104,7 @@ func (s *ApiServer) AuthenticateEmailFunc(ctx context.Context, in *api.Authentic
	}

	cleanEmail := strings.ToLower(email.Email)
	hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(email.Password), bcrypt.DefaultCost)

	if in.Create == nil || in.Create.Value {
		// Use existing user account if found, otherwise create a new user account.
	username := in.Username
	if username == "" {
		username = generateUsername(s.random)
@@ -256,72 +114,16 @@ func (s *ApiServer) AuthenticateEmailFunc(ctx context.Context, in *api.Authentic
		return nil, status.Error(codes.InvalidArgument, "Username invalid, must be 1-128 bytes.")
	}

		userID := uuid.NewV4().String()
		ts := time.Now().UTC().Unix()
		query := `
INSERT INTO users (id, username, email, password, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $5)
ON CONFLICT (email) DO UPDATE SET email = $3, password = $4
RETURNING id, username, disabled_at`
		params := []interface{}{userID, username, cleanEmail, hashedPassword, ts}

		var dbUserID string
		var dbUsername string
		var dbDisabledAt int64
		err := s.db.QueryRow(query, params...).Scan(&dbUserID, &dbUsername, &dbDisabledAt)
		if err != nil {
			if e, ok := err.(*pq.Error); ok && e.Code == dbErrorUniqueViolation && strings.Contains(e.Message, "users_username_key") {
				// Username is already in use by a different account.
				return nil, status.Error(codes.AlreadyExists, "Username is already in use.")
			}
			s.logger.Error("Cannot find or create user with email.", zap.Error(err), zap.Any("input", in))
			return nil, status.Error(codes.Internal, "Error finding or creating user account.")
		}

		if dbDisabledAt != 0 {
			s.logger.Debug("User account is disabled.", zap.Any("input", in))
			return nil, status.Error(codes.Unauthenticated, "Error finding or creating user account.")
		}
	create := in.Create == nil || in.Create.Value

		token := generateToken(s.config, dbUserID, dbUsername)
		return &api.Session{Token: token}, nil
	} else {
		// Do not create a new user account.
		query := `
SELECT id, username, password, disabled_at
FROM users
WHERE email = $1`
		params := []interface{}{cleanEmail}

		var dbUserID string
		var dbUsername string
		var dbPassword string
		var dbDisabledAt int64
		err := s.db.QueryRow(query, params...).Scan(&dbUserID, &dbUsername, &dbPassword, &dbDisabledAt)
	dbUserID, dbUsername, err := AuthenticateEmail(s.logger, s.db, cleanEmail, email.Password, username, create)
	if err != nil {
			if err == sql.ErrNoRows {
				// No user account found.
				return nil, status.Error(codes.NotFound, "User account not found.")
			} else {
				s.logger.Error("Cannot find user with email.", zap.Error(err), zap.Any("input", in))
				return nil, status.Error(codes.Internal, "Error finding user account.")
			}
		}

		if dbDisabledAt != 0 {
			s.logger.Debug("User account is disabled.", zap.Any("input", in))
			return nil, status.Error(codes.Unauthenticated, "Error finding or creating user account.")
		}

		err = bcrypt.CompareHashAndPassword(hashedPassword, []byte(email.Password))
		if err != nil {
			return nil, status.Error(codes.Unauthenticated, "Invalid credentials.")
		return nil, err
	}

	token := generateToken(s.config, dbUserID, dbUsername)
	return &api.Session{Token: token}, nil
}
}

func (s *ApiServer) AuthenticateFacebookFunc(ctx context.Context, in *api.AuthenticateFacebook) (*api.Session, error) {
	return nil, nil
+242 −0
Original line number Diff line number Diff line
// Copyright 2018 The Nakama Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package server

import (
	"google.golang.org/grpc/codes"
	"github.com/lib/pq"
	"go.uber.org/zap"
	"database/sql"
	"github.com/satori/go.uuid"
	"time"
	"google.golang.org/grpc/status"
	"strings"
	"golang.org/x/crypto/bcrypt"
)

func AuthenticateCustom(logger *zap.Logger, db *sql.DB, customID, username string, create bool) (string, string, error) {
	if create {
		// Use existing user account if found, otherwise create a new user account.
		userID := uuid.NewV4().String()
		ts := time.Now().UTC().Unix()
		// NOTE: This query relies on the `custom_id` conflict triggering before the `users_username_key`
		// constraint violation to ensure we fall to the RETURNING case and ignore the new username for
		// existing user accounts. The DO UPDATE SET is to trick the DB into having the data we need to return.
		query := `
INSERT INTO users (id, username, custom_id, created_at, updated_at)
VALUES ($1, $2, $3, $4, $4)
ON CONFLICT (custom_id) DO UPDATE SET custom_id = $3
RETURNING id, username, disabled_at`

		var dbUserID string
		var dbUsername string
		var dbDisabledAt int64
		err := db.QueryRow(query, userID, username, customID, ts).Scan(&dbUserID, &dbUsername, &dbDisabledAt)
		if err != nil {
			if e, ok := err.(*pq.Error); ok && e.Code == dbErrorUniqueViolation && strings.Contains(e.Message, "users_username_key") {
				// Username is already in use by a different account.
				return "", "", status.Error(codes.AlreadyExists, "Username is already in use.")
			}
			logger.Error("Cannot find or create user with custom ID.", zap.Error(err), zap.String("customID", customID), zap.String("username", username), zap.Bool("create", create))
			return "", "", status.Error(codes.Internal, "Error finding or creating user account.")
		}

		if dbDisabledAt != 0 {
			logger.Debug("User account is disabled.", zap.String("customID", customID), zap.String("username", username), zap.Bool("create", create))
			return "", "", status.Error(codes.Unauthenticated, "Error finding or creating user account.")
		}

		return dbUserID, dbUsername, nil
	} else {
		// Do not create a new user account.
		query := `
SELECT id, username, disabled_at
FROM users
WHERE custom_id = $1`

		var dbUserID string
		var dbUsername string
		var dbDisabledAt int64
		err := db.QueryRow(query, customID).Scan(&dbUserID, &dbUsername, &dbDisabledAt)
		if err != nil {
			if err == sql.ErrNoRows {
				// No user account found.
				return "", "", status.Error(codes.NotFound, "User account not found.")
			} else {
				logger.Error("Cannot find user with custom ID.", zap.Error(err), zap.String("customID", customID), zap.String("username", username), zap.Bool("create", create))
				return "", "", status.Error(codes.Internal, "Error finding user account.")
			}
		}

		if dbDisabledAt != 0 {
			logger.Debug("User account is disabled.", zap.String("customID", customID), zap.String("username", username), zap.Bool("create", create))
			return "", "", status.Error(codes.Unauthenticated, "Error finding or creating user account.")
		}

		return dbUserID, dbUsername, nil
	}
}

func AuthenticateDevice(logger *zap.Logger, db *sql.DB, deviceID, username string, create bool) (string, string, error) {
	if create {
		// Use existing user account if found, otherwise create a new user account.
		var dbUserID string
		var dbUsername string
		fnErr := Transact(logger, db, func (tx *sql.Tx) error {
			userID := uuid.NewV4().String()
			ts := time.Now().UTC().Unix()
			query := `
INSERT INTO users (id, username, created_at, updated_at)
SELECT $1 AS id,
			 $2 AS username,
			 $4 AS created_at,
			 $4 AS updated_at
WHERE NOT EXISTS
    (SELECT id
     FROM user_device
     WHERE id = $3::VARCHAR)
RETURNING id, username, disabled_at`

			var dbDisabledAt int64
			err := tx.QueryRow(query, userID, username, deviceID, ts).Scan(&dbUserID, &dbUsername, &dbDisabledAt)
			if err != nil {
				if e, ok := err.(*pq.Error); ok && e.Code == dbErrorUniqueViolation && strings.Contains(e.Message, "users_username_key") {
					return status.Error(codes.AlreadyExists, "Username is already in use.")
				}
				logger.Error("Cannot find or create user with device ID.", zap.Error(err), zap.String("deviceID", deviceID), zap.String("username", username), zap.Bool("create", create))
				return status.Error(codes.Internal, "Error finding or creating user account.")
			}

			if dbDisabledAt != 0 {
				logger.Debug("User account is disabled.", zap.String("deviceID", deviceID), zap.String("username", username), zap.Bool("create", create))
				return status.Error(codes.Unauthenticated, "Error finding or creating user account.")
			}

			query = "INSERT INTO user_device (id, user_id) VALUES ($1, $2)"
			_, err = tx.Exec(query, deviceID, userID)
			if err != nil {
				logger.Error("Cannot add device ID.", zap.Error(err), zap.String("deviceID", deviceID), zap.String("username", username), zap.Bool("create", create))
				return status.Error(codes.Internal, "Error finding or creating user account.")
			}

			return nil
		})

		if fnErr != nil {
			return dbUserID, dbUsername, fnErr
		}

		return dbUserID, dbUsername, nil
	} else {
		query := "SELECT user_id FROM user_device WHERE id = $1"

		var dbUserID string
		err := db.QueryRow(query, deviceID).Scan(&dbUserID)
		if err != nil {
			if err == sql.ErrNoRows {
				// No user account found.
				return "", "", status.Error(codes.NotFound, "Device ID not found.")
			} else {
				logger.Error("Cannot find user with device ID.", zap.Error(err), zap.String("deviceID", deviceID), zap.String("username", username), zap.Bool("create", create))
				return "", "", status.Error(codes.Internal, "Error finding user account.")
			}
		}

		query = "SELECT username, disabled_at FROM users WHERE id = $1"
		var dbUsername string
		var dbDisabledAt int64

		err = db.QueryRow(query, dbUserID).Scan(&dbUsername, &dbDisabledAt)
		if err != nil {
			logger.Error("Cannot find user with device ID.", zap.Error(err), zap.String("deviceID", deviceID), zap.String("username", username), zap.Bool("create", create))
			return "", "", status.Error(codes.Internal, "Error finding user account.")
		}

		if dbDisabledAt != 0 {
			logger.Debug("User account is disabled.", zap.String("deviceID", deviceID), zap.String("username", username), zap.Bool("create", create))
			return "", "", status.Error(codes.Unauthenticated, "Error finding or creating user account.")
		}

		return dbUserID, dbUsername, nil
	}
}

func AuthenticateEmail(logger *zap.Logger, db *sql.DB, email, password, username string, create bool) (string, string, error) {
	hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)

	if create {
		// Use existing user account if found, otherwise create a new user account.
		userID := uuid.NewV4().String()
		ts := time.Now().UTC().Unix()
		query := `
INSERT INTO users (id, username, email, password, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $5)
ON CONFLICT (email) DO UPDATE SET email = $3, password = $4
RETURNING id, username, disabled_at`

		var dbUserID string
		var dbUsername string
		var dbDisabledAt int64
		err := db.QueryRow(query, userID, username, email, hashedPassword, ts).Scan(&dbUserID, &dbUsername, &dbDisabledAt)
		if err != nil {
			if e, ok := err.(*pq.Error); ok && e.Code == dbErrorUniqueViolation && strings.Contains(e.Message, "users_username_key") {
				// Username is already in use by a different account.
				return "", "", status.Error(codes.AlreadyExists, "Username is already in use.")
			}
			logger.Error("Cannot find or create user with email.", zap.Error(err), zap.String("email", email), zap.String("username", username), zap.Bool("create", create))
			return "", "", status.Error(codes.Internal, "Error finding or creating user account.")
		}

		if dbDisabledAt != 0 {
			logger.Debug("User account is disabled.", zap.String("email", email), zap.String("username", username), zap.Bool("create", create))
			return "", "", status.Error(codes.Unauthenticated, "Error finding or creating user account.")
		}

		return dbUserID, dbUsername, nil
	} else {
		// Do not create a new user account.
		query := `
SELECT id, username, password, disabled_at
FROM users
WHERE email = $1`

		var dbUserID string
		var dbUsername string
		var dbPassword string
		var dbDisabledAt int64
		err := db.QueryRow(query, email).Scan(&dbUserID, &dbUsername, &dbPassword, &dbDisabledAt)
		if err != nil {
			if err == sql.ErrNoRows {
				// No user account found.
				return "", "", status.Error(codes.NotFound, "User account not found.")
			} else {
				logger.Error("Cannot find user with email.", zap.Error(err), zap.String("email", email), zap.String("username", username), zap.Bool("create", create))
				return "", "", status.Error(codes.Internal, "Error finding user account.")
			}
		}

		if dbDisabledAt != 0 {
			logger.Debug("User account is disabled.", zap.String("email", email), zap.String("username", username), zap.Bool("create", create))
			return "", "", status.Error(codes.Unauthenticated, "Error finding or creating user account.")
		}

		err = bcrypt.CompareHashAndPassword(hashedPassword, []byte(password))
		if err != nil {
			return "", "", status.Error(codes.Unauthenticated, "Invalid credentials.")
		}

		return dbUserID, dbUsername, nil
	}
}
+6 −2
Original line number Diff line number Diff line
@@ -30,6 +30,8 @@ import (
	"golang.org/x/net/context"
	"io/ioutil"
	"sync"
	"math/rand"
	"time"
)

const (
@@ -113,10 +115,12 @@ func NewRuntimePool(logger *zap.Logger, multiLogger *zap.Logger, db *sql.DB, con
		vm.Call(1, 0)
	}

	// Used to generate usernames in auth functions.
	random := rand.New(rand.NewSource(time.Now().UnixNano()))
	// Used to govern once-per-server-start executions.
	once := &sync.Once{}

	nakamaModule := NewNakamaModule(logger, db, config, vm, registry, tracker, router, once,
	nakamaModule := NewNakamaModule(logger, db, config, vm, registry, tracker, router, random, once,
		func(id string) {
			regRPC[id] = struct{}{}
			logger.Info("Registered RPC function invocation", zap.String("id", id))
@@ -155,7 +159,7 @@ func NewRuntimePool(logger *zap.Logger, multiLogger *zap.Logger, db *sql.DB, con
					vm.Call(1, 0)
				}

				nakamaModule := NewNakamaModule(logger, db, config, vm, registry, tracker, router, once, nil)
				nakamaModule := NewNakamaModule(logger, db, config, vm, registry, tracker, router, random, once, nil)
				vm.PreloadModule("nakama", nakamaModule.Loader)

				r := &Runtime{
+191 −48

File changed.

Preview size limit exceeded, changes collapsed.