Commit b33a885d authored by Mo Firouz's avatar Mo Firouz Committed by Andrei Mihu
Browse files

Device and email authentication. (#147)

parent ccdb0fd4
Loading
Loading
Loading
Loading
+215 −16
Original line number Diff line number Diff line
@@ -27,19 +27,22 @@ import (
	"database/sql"
	"github.com/dgrijalva/jwt-go"
	"strings"
	"github.com/lib/pq"
	"golang.org/x/crypto/bcrypt"
)

var (
	invalidCharsRegex = regexp.MustCompilePOSIX("([[:cntrl:]]|[[:space:]])+")
	emailRegex        = regexp.MustCompile("^.+@.+\\..+$")
)

func (s *ApiServer) AuthenticateCustomFunc(ctx context.Context, in *api.AuthenticateCustom) (*api.Session, error) {
	if in.Account == nil || in.Account.Id == "" {
		return nil, status.Error(codes.InvalidArgument, "Custom ID is required")
		return nil, status.Error(codes.InvalidArgument, "Custom ID is required.")
	} else if invalidCharsRegex.MatchString(in.Account.Id) {
		return nil, status.Error(codes.InvalidArgument, "Custom ID invalid, no spaces or control characters allowed")
		return nil, status.Error(codes.InvalidArgument, "Custom ID invalid, no spaces or control characters allowed.")
	} else if len(in.Account.Id) < 10 || len(in.Account.Id) > 128 {
		return nil, status.Error(codes.InvalidArgument, "Custom ID invalid, must be 10-128 bytes")
		return nil, status.Error(codes.InvalidArgument, "Custom ID invalid, must be 10-128 bytes.")
	}

	if in.Create == nil || in.Create.Value {
@@ -48,9 +51,9 @@ func (s *ApiServer) AuthenticateCustomFunc(ctx context.Context, in *api.Authenti
		if username == "" {
			username = generateUsername(s.random)
		} else if invalidCharsRegex.MatchString(username) {
			return nil, status.Error(codes.InvalidArgument, "Username invalid, no spaces or control characters allowed")
			return nil, status.Error(codes.InvalidArgument, "Username invalid, no spaces or control characters allowed.")
		} else if len(username) > 128 {
			return nil, status.Error(codes.InvalidArgument, "Username invalid, must be 1-128 bytes")
			return nil, status.Error(codes.InvalidArgument, "Username invalid, must be 1-128 bytes.")
		}

		userID := uuid.NewV4().String()
@@ -71,16 +74,17 @@ RETURNING id, username, custom_id, disabled_at`
		var dbDisabledAt int64
		err := s.db.QueryRow(query, params...).Scan(&dbUserID, &dbUsername, &dbCustomId, &dbDisabledAt)
		if err != nil {
			if strings.HasSuffix(err.Error(), "violates unique constraint \"users_username_key\"") {
			if e, ok := err.(*pq.Error); ok && e.Code == dbErrorUniqueViolation && strings.Contains(e.Message, "users_username_key") {
				// Username is already in use by a different account.
				return nil, status.Error(codes.AlreadyExists, "Username is already in use")
				return nil, status.Error(codes.AlreadyExists, "Username is already in use.")
			}
			s.logger.Error("Cannot find or create user with custom ID, query error", zap.Error(err))
			return nil, status.Error(codes.Internal, "Error finding or creating user account")
			s.logger.Error("Cannot find or create user with custom ID.", zap.Error(err), zap.Any("input", in))
			return nil, status.Error(codes.Internal, "Error finding or creating user account.")
		}

		if dbDisabledAt != 0 {
			return nil, status.Error(codes.PermissionDenied, "User account is disabled")
			s.logger.Debug("User account is disabled.", zap.Any("input", in))
			return nil, status.Error(codes.Unauthenticated, "Error finding or creating user account.")
		}

		token := generateToken(s.config, dbUserID, dbUsername)
@@ -100,15 +104,16 @@ WHERE custom_id = $1`
		if err != nil {
			if err == sql.ErrNoRows {
				// No user account found.
				return nil, status.Error(codes.NotFound, "User account not found")
				return nil, status.Error(codes.NotFound, "User account not found.")
			} else {
				s.logger.Error("Cannot find user with custom ID, query error", zap.Error(err))
				return nil, status.Error(codes.Internal, "Error finding user user account")
				s.logger.Error("Cannot find user with custom ID.", zap.Error(err), zap.Any("input", in))
				return nil, status.Error(codes.Internal, "Error finding user account.")
			}
		}

		if dbDisabledAt != 0 {
			return nil, status.Error(codes.PermissionDenied, "User account is disabled")
			s.logger.Debug("User account is disabled.", zap.Any("input", in))
			return nil, status.Error(codes.Unauthenticated, "Error finding or creating user account.")
		}

		token := generateToken(s.config, dbUserID, dbUsername)
@@ -117,11 +122,205 @@ WHERE custom_id = $1`
}

func (s *ApiServer) AuthenticateDeviceFunc(ctx context.Context, in *api.AuthenticateDevice) (*api.Session, error) {
	return nil, nil
	if in.Account == nil || in.Account.Id == "" {
		return nil, status.Error(codes.InvalidArgument, "Device ID is required.")
	} else if invalidCharsRegex.MatchString(in.Account.Id) {
		return nil, status.Error(codes.InvalidArgument, "Device ID invalid, no spaces or control characters allowed.")
	} else if len(in.Account.Id) < 10 || len(in.Account.Id) > 128 {
		return nil, status.Error(codes.InvalidArgument, "Device ID invalid, must be 10-128 bytes.")
	}

	if in.Create == nil || in.Create.Value {
		// Use existing user account if found, otherwise create a new user account.
		username := in.Username
		if username == "" {
			username = generateUsername(s.random)
		} else if invalidCharsRegex.MatchString(username) {
			return nil, status.Error(codes.InvalidArgument, "Username invalid, no spaces or control characters allowed.")
		} else if len(username) > 128 {
			return nil, status.Error(codes.InvalidArgument, "Username invalid, must be 1-128 bytes.")
		}

		var dbUserID string
		var dbUsername string
		fnErr := Transact(s.logger, s.db, func (tx *sql.Tx) error {
			userID := uuid.NewV4().String()
			ts := time.Now().UTC().Unix()
			query := `
INSERT INTO users (id, username, created_at, updated_at)
SELECT $1 AS id,
			 $2 AS username,
			 $4 AS created_at,
			 $4 AS updated_at
WHERE NOT EXISTS
    (SELECT id
     FROM user_device
     WHERE id = $3::VARCHAR)
RETURNING id, username, disabled_at`
			params := []interface{}{userID, username, in.Account.Id, ts}

			var dbDisabledAt int64
			err := tx.QueryRow(query, params...).Scan(&dbUserID, &dbUsername, &dbDisabledAt)
			if err != nil {
				if e, ok := err.(*pq.Error); ok && e.Code == dbErrorUniqueViolation && strings.Contains(e.Message, "users_username_key") {
					return status.Error(codes.AlreadyExists, "Username is already in use.")
				}
				s.logger.Error("Cannot find or create user with device ID.", zap.Error(err), zap.Any("input", in))
				return status.Error(codes.Internal, "Error finding or creating user account.")
			}

			if dbDisabledAt != 0 {
				s.logger.Debug("User account is disabled.", zap.Any("input", in))
				return status.Error(codes.Unauthenticated, "Error finding or creating user account.")
			}

			query = "INSERT INTO user_device (id, user_id) VALUES ($1, $2)"
			params = []interface{}{in.Account.Id, userID}
			_, err = tx.Exec(query, params...)
			if err != nil {
				s.logger.Error("Cannot add device ID.", zap.Error(err), zap.Any("input", in))
				return status.Error(codes.Internal, "Error finding or creating user account.")
			}

			return nil
		})

		if fnErr != nil {
			return nil, fnErr
		}

		token := generateToken(s.config, dbUserID, dbUsername)
		return &api.Session{Token: token}, nil
	} else {
		query := "SELECT user_id FROM user_device WHERE id = $1"
		params := []interface{}{in.Account.Id}

		var dbUserID string
		err := s.db.QueryRow(query, params...).Scan(&dbUserID)
		if err != nil {
			if err == sql.ErrNoRows {
				// No user account found.
				return nil, status.Error(codes.NotFound, "Device ID not found.")
			} else {
				s.logger.Error("Cannot find user with device ID.", zap.Error(err), zap.Any("input", in))
				return nil, status.Error(codes.Internal, "Error finding user account.")
			}
		}

		query = "SELECT username, disabled_at FROM users WHERE id = $1"
		params = []interface{}{dbUserID}
		var dbUsername string
		var dbDisabledAt int64

		err = s.db.QueryRow(query, params...).Scan(&dbUsername, &dbDisabledAt)
		if err != nil {
			s.logger.Error("Cannot find user with device ID.", zap.Error(err), zap.Any("input", in))
			return nil, status.Error(codes.Internal, "Error finding user account.")
		}

		if dbDisabledAt != 0 {
			s.logger.Debug("User account is disabled.", zap.Any("input", in))
			return nil, status.Error(codes.Unauthenticated, "Error finding or creating user account.")
		}

		token := generateToken(s.config, dbUserID, dbUsername)
		return &api.Session{Token: token}, nil
	}
}

func (s *ApiServer) AuthenticateEmailFunc(ctx context.Context, in *api.AuthenticateEmail) (*api.Session, error) {
	return nil, nil
	email := in.Account
	if email == nil || email.Email == "" || email.Password == "" {
		return nil, status.Error(codes.InvalidArgument, "Email address and password is required.")
	} else if invalidCharsRegex.MatchString(email.Email) {
		return nil, status.Error(codes.InvalidArgument, "Invalid email address, no spaces or control characters allowed.")
	} else if len(email.Password) < 8 {
		return nil, status.Error(codes.InvalidArgument, "Password must be longer than 8 characters.")
	} else if !emailRegex.MatchString(email.Email) {
		return nil, status.Error(codes.InvalidArgument, "Invalid email address format.")
	} else if len(email.Email) < 10 || len(email.Email) > 255 {
		return nil, status.Error(codes.InvalidArgument, "Invalid email address, must be 10-255 bytes.")
	}

	cleanEmail := strings.ToLower(email.Email)
	hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(email.Password), bcrypt.DefaultCost)

	if in.Create == nil || in.Create.Value {
		// Use existing user account if found, otherwise create a new user account.
		username := in.Username
		if username == "" {
			username = generateUsername(s.random)
		} else if invalidCharsRegex.MatchString(username) {
			return nil, status.Error(codes.InvalidArgument, "Username invalid , no spaces or control characters allowed.")
		} else if len(username) > 128 {
			return nil, status.Error(codes.InvalidArgument, "Username invalid, must be 1-128 bytes.")
		}

		userID := uuid.NewV4().String()
		ts := time.Now().UTC().Unix()
		query := `
INSERT INTO users (id, username, email, password, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $5)
ON CONFLICT (email) DO UPDATE SET email = $3, password = $4
RETURNING id, username, disabled_at`
		params := []interface{}{userID, username, cleanEmail, hashedPassword, ts}

		var dbUserID string
		var dbUsername string
		var dbDisabledAt int64
		err := s.db.QueryRow(query, params...).Scan(&dbUserID, &dbUsername, &dbDisabledAt)
		if err != nil {
			if e, ok := err.(*pq.Error); ok && e.Code == dbErrorUniqueViolation && strings.Contains(e.Message, "users_username_key") {
				// Username is already in use by a different account.
				return nil, status.Error(codes.AlreadyExists, "Username is already in use.")
			}
			s.logger.Error("Cannot find or create user with email.", zap.Error(err), zap.Any("input", in))
			return nil, status.Error(codes.Internal, "Error finding or creating user account.")
		}

		if dbDisabledAt != 0 {
			s.logger.Debug("User account is disabled.", zap.Any("input", in))
			return nil, status.Error(codes.Unauthenticated, "Error finding or creating user account.")
		}

		token := generateToken(s.config, dbUserID, dbUsername)
		return &api.Session{Token: token}, nil
	} else {
		// Do not create a new user account.
		query := `
SELECT id, username, password, disabled_at
FROM users
WHERE email = $1`
		params := []interface{}{cleanEmail}

		var dbUserID string
		var dbUsername string
		var dbPassword string
		var dbDisabledAt int64
		err := s.db.QueryRow(query, params...).Scan(&dbUserID, &dbUsername, &dbPassword, &dbDisabledAt)
		if err != nil {
			if err == sql.ErrNoRows {
				// No user account found.
				return nil, status.Error(codes.NotFound, "User account not found.")
			} else {
				s.logger.Error("Cannot find user with email.", zap.Error(err), zap.Any("input", in))
				return nil, status.Error(codes.Internal, "Error finding user account.")
			}
		}

		if dbDisabledAt != 0 {
			s.logger.Debug("User account is disabled.", zap.Any("input", in))
			return nil, status.Error(codes.Unauthenticated, "Error finding or creating user account.")
		}

		err = bcrypt.CompareHashAndPassword(hashedPassword, []byte(email.Password))
		if err != nil {
			return nil, status.Error(codes.Unauthenticated, "Invalid credentials.")
		}

		token := generateToken(s.config, dbUserID, dbUsername)
		return &api.Session{Token: token}, nil
	}
}

func (s *ApiServer) AuthenticateFacebookFunc(ctx context.Context, in *api.AuthenticateFacebook) (*api.Session, error) {
+121 −3
Original line number Diff line number Diff line
@@ -18,18 +18,136 @@ import (
	"golang.org/x/net/context"
	"github.com/heroiclabs/nakama/api"
	"github.com/golang/protobuf/ptypes/empty"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"go.uber.org/zap"
	"time"
	"database/sql"
	"github.com/lib/pq"
	"golang.org/x/crypto/bcrypt"
	"strings"
)

func (s *ApiServer) LinkCustomFunc(ctx context.Context, in *api.AccountCustom) (*empty.Empty, error) {
	return nil, nil
	customID := in.Id
	if customID == "" {
		return nil, status.Error(codes.InvalidArgument, "Custom ID is required.")
	} else if invalidCharsRegex.MatchString(customID) {
		return nil, status.Error(codes.InvalidArgument,  "Invalid custom ID, no spaces or control characters allowed.")
	} else if len(customID) < 10 || len(customID) > 128 {
		return nil, status.Error(codes.InvalidArgument,  "Invalid custom ID, must be 10-128 bytes.")
	}

	userID := ctx.Value(ctxUserIDKey{})
	ts := time.Now().UTC().Unix()
	res, err := s.db.Exec(`
UPDATE users
SET custom_id = $2, updated_at = $3
WHERE (id = $1)
AND (NOT EXISTS
    (SELECT id
     FROM users
     WHERE custom_id = $2 AND NOT id = $1))`,
		userID,
		customID,
		ts)

	if err != nil {
		s.logger.Warn("Could not link custom ID.", zap.Error(err), zap.Any("input", in))
		return nil, status.Error(codes.Internal, "Error while trying to link Custom ID.")
	} else if count, _ := res.RowsAffected(); count == 0 {
		return nil, status.Error(codes.AlreadyExists, "Custom ID is already in use.")
	}

	return &empty.Empty{}, nil
}

func (s *ApiServer) LinkDeviceFunc(ctx context.Context, in *api.AccountDevice) (*empty.Empty, error) {
	return nil, nil
	deviceID := in.Id
	if deviceID == "" {
		return nil, status.Error(codes.InvalidArgument, "Device ID is required.")
	} else if invalidCharsRegex.MatchString(deviceID) {
		return nil, status.Error(codes.InvalidArgument,  "Device ID invalid, no spaces or control characters allowed.")
	} else if len(deviceID) < 10 || len(deviceID) > 128 {
		return nil, status.Error(codes.InvalidArgument,  "Device ID invalid, must be 10-128 bytes.")
	}

	fnErr := Transact(s.logger, s.db, func (tx *sql.Tx) error {
		userID := ctx.Value(ctxUserIDKey{})
		ts := time.Now().UTC().Unix()

		var dbDeviceIdLinkedUser int64
		err := tx.QueryRow("SELECT COUNT(id) FROM user_device WHERE id = $1 AND user_id = $2 LIMIT 1", deviceID, userID).Scan(&dbDeviceIdLinkedUser)
		if err != nil {
			s.logger.Error("Cannot link device ID.", zap.Error(err), zap.Any("input", in))
			return status.Error(codes.Internal, "Error linking Device ID.")
		}

		if dbDeviceIdLinkedUser == 0 {
			_, err = tx.Exec("INSERT INTO user_device (id, user_id) VALUES ($1, $2)", deviceID, userID)
			if err != nil {
				if e, ok := err.(*pq.Error); ok && e.Code == dbErrorUniqueViolation {
					return status.Error(codes.AlreadyExists, "Device ID already in use.")
				}
				s.logger.Error("Cannot link device ID.", zap.Error(err), zap.Any("input", in))
				return status.Error(codes.Internal, "Error linking Device ID.")
			}
		}

		_, err = tx.Exec("UPDATE users SET updated_at = $1 WHERE id = $2", ts, userID)
		if err != nil {
			s.logger.Error("Cannot update users table while linking.", zap.Error(err), zap.Any("input", in))
			return status.Error(codes.Internal, "Error linking Device ID.")
		}
		return nil
	})

	if fnErr != nil {
		return nil, fnErr
	}

	return &empty.Empty{}, nil
}

func (s *ApiServer) LinkEmailFunc(ctx context.Context, in *api.AccountEmail) (*empty.Empty, error) {
	return nil, nil
	if in.Email == "" || in.Password == "" {
		return nil, status.Error(codes.InvalidArgument, "Email address and password is required.")
	} else if invalidCharsRegex.MatchString(in.Email) {
		return nil, status.Error(codes.InvalidArgument, "Invalid email address, no spaces or control characters allowed.")
	} else if len(in.Password) < 8 {
		return nil, status.Error(codes.InvalidArgument, "Password must be longer than 8 characters.")
	} else if !emailRegex.MatchString(in.Email) {
		return nil, status.Error(codes.InvalidArgument, "Invalid email address format.")
	} else if len(in.Email) < 10 || len(in.Email) > 255 {
		return nil, status.Error(codes.InvalidArgument, "Invalid email address, must be 10-255 bytes.")
	}

	cleanEmail := strings.ToLower(in.Email)
	hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(in.Password), bcrypt.DefaultCost)

	userID := ctx.Value(ctxUserIDKey{})
	ts := time.Now().UTC().Unix()
	res, err := s.db.Exec(`
UPDATE users
SET email = $2, password = $3, updated_at = $4
WHERE (id = $1)
AND (NOT EXISTS
    (SELECT id
     FROM users
     WHERE email = $2 AND NOT id = $1))`,
		userID,
		cleanEmail,
		hashedPassword,
		ts)

	if err != nil {
		s.logger.Warn("Could not link email.", zap.Error(err), zap.Any("input", in))
		return nil, status.Error(codes.Internal, "Error while trying to link email.")
	} else if count, _ := res.RowsAffected(); count == 0 {
		return nil, status.Error(codes.AlreadyExists, "Email is already in use.")
	}

	return &empty.Empty{}, nil
}

func (s *ApiServer) LinkFacebookFunc(ctx context.Context, in *api.AccountFacebook) (*empty.Empty, error) {
+93 −3
Original line number Diff line number Diff line
@@ -18,18 +18,108 @@ import (
	"golang.org/x/net/context"
	"github.com/heroiclabs/nakama/api"
	"github.com/golang/protobuf/ptypes/empty"
	"go.uber.org/zap"
	"time"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"database/sql"
	"strings"
)

func (s *ApiServer) UnlinkCustomFunc(ctx context.Context, in *api.AccountCustom) (*empty.Empty, error) {
	return nil, nil
	query := `UPDATE users SET custom_id = NULL, updated_at = $3
WHERE id = $1
AND custom_id = $2
AND ((facebook_id IS NOT NULL
      OR google_id IS NOT NULL
      OR gamecenter_id IS NOT NULL
      OR steam_id IS NOT NULL
      OR email IS NOT NULL)
     OR
     EXISTS (SELECT id FROM user_device WHERE user_id = $1 LIMIT 1))`

	userID := ctx.Value(ctxUserIDKey{})
	ts := time.Now().UTC().Unix()
	res, err := s.db.Exec(query, userID, in.Id, ts)

	if err != nil {
		s.logger.Warn("Could not unlink custom ID.", zap.Error(err), zap.Any("input", in))
		return nil, status.Error(codes.Internal, "Error while trying to unlink custom ID.")
	} else if count, _ := res.RowsAffected(); count == 0 {
		return nil, status.Error(codes.PermissionDenied, "Cannot unlink last account identifier. Check profile exists and is not last link.")
	}

	return &empty.Empty{}, nil
}

func (s *ApiServer) UnlinkDeviceFunc(ctx context.Context, in *api.AccountDevice) (*empty.Empty, error) {
	return nil, nil
	fnErr := Transact(s.logger, s.db, func (tx *sql.Tx) error {
		userID := ctx.Value(ctxUserIDKey{})
		ts := time.Now().UTC().Unix()

		query := `DELETE FROM user_device WHERE id = $2 AND user_id = $1
AND (EXISTS (SELECT id FROM users WHERE id = $1 AND
    (facebook_id IS NOT NULL
     OR google_id IS NOT NULL
     OR gamecenter_id IS NOT NULL
     OR steam_id IS NOT NULL
     OR email IS NOT NULL
     OR custom_id IS NOT NULL))
   OR EXISTS (SELECT id FROM user_device WHERE user_id = $1 AND id <> $2))`

    res, err := tx.Exec(query, userID, in.Id)
		if err != nil {
			s.logger.Warn("Could not unlink device ID.", zap.Error(err), zap.Any("input", in))
			return status.Error(codes.Internal, "Could not unlink Device ID.")
		}
		if count, _ := res.RowsAffected(); count == 0 {
			return status.Error(codes.PermissionDenied, "Cannot unlink last account identifier. Check profile exists and is not last link.")
		}

		res, err = tx.Exec("UPDATE users SET updated_at = $2 WHERE id = $1", userID, ts)
		if err != nil {
			s.logger.Warn("Could not unlink device ID.", zap.Error(err), zap.Any("input", in))
			return status.Error(codes.Internal, "Could not unlink Device ID.")
		}
		if count, _ := res.RowsAffected(); count == 0 {
			return status.Error(codes.PermissionDenied, "Cannot unlink last account identifier. Check profile exists and is not last link.")
		}

		return nil
	})

	if fnErr != nil {
		return nil, fnErr
	}

	return &empty.Empty{}, nil
}

func (s *ApiServer) UnlinkEmailFunc(ctx context.Context, in *api.AccountEmail) (*empty.Empty, error) {
	return nil, nil
	query := `UPDATE users SET email = NULL, password = NULL, updated_at = $3
WHERE id = $1
AND email = $2
AND ((facebook_id IS NOT NULL
      OR google_id IS NOT NULL
      OR gamecenter_id IS NOT NULL
      OR steam_id IS NOT NULL
      OR custom_id IS NOT NULL)
     OR
     EXISTS (SELECT id FROM user_device WHERE user_id = $1 LIMIT 1))`

	userID := ctx.Value(ctxUserIDKey{})
	ts := time.Now().UTC().Unix()
	cleanEmail := strings.ToLower(in.Email)
	res, err := s.db.Exec(query, userID, cleanEmail, ts)

	if err != nil {
		s.logger.Warn("Could not unlink email.", zap.Error(err), zap.Any("input", in))
		return nil, status.Error(codes.Internal, "Error while trying to unlink email.")
	} else if count, _ := res.RowsAffected(); count == 0 {
		return nil, status.Error(codes.PermissionDenied, "Cannot unlink last account identifier. Check profile exists and is not last link.")
	}

	return &empty.Empty{}, nil
}

func (s *ApiServer) UnlinkFacebookFunc(ctx context.Context, in *api.AccountFacebook) (*empty.Empty, error) {

server/db.go

0 → 100644
+52 −0
Original line number Diff line number Diff line
// Copyright 2018 The Nakama Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package server

import (
	"database/sql"
	"go.uber.org/zap"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

const (
	dbErrorUniqueViolation = "23505"
)

func Transact(logger *zap.Logger, db *sql.DB, txFunc func(*sql.Tx) error) (err error) {
	tx, err := db.Begin()
	if err != nil {
		logger.Error("Could not begin database transaction.", zap.Error(err))
		return
	}

	fnErr := txFunc(tx)

	if p := recover(); p != nil {
		if err = tx.Rollback(); err != nil {
			logger.Error("Could not rollback database transaction.", zap.Error(err))
		}
	} else if fnErr != nil {
		if err = tx.Rollback(); err != nil {
			logger.Error("Could not rollback database transaction.", zap.Error(err))
		}
	} else {
		if err = tx.Commit(); err != nil {
			logger.Error("Could not commit database transaction.", zap.Error(err))
			return status.Error(codes.Internal, "Could not complete operation.")
		}
	}
	return fnErr
}