From b33a885d3676260c29119a4c988c9e6b8b6f8d44 Mon Sep 17 00:00:00 2001 From: Mo Firouz Date: Tue, 6 Feb 2018 18:33:19 -0800 Subject: [PATCH] Device and email authentication. (#147) --- server/api_authenticate.go | 231 ++++++++++++++++++++++++++++++++++--- server/api_link.go | 124 +++++++++++++++++++- server/api_unlink.go | 96 ++++++++++++++- server/db.go | 52 +++++++++ 4 files changed, 481 insertions(+), 22 deletions(-) create mode 100644 server/db.go diff --git a/server/api_authenticate.go b/server/api_authenticate.go index 272ae1788..8ce1a5af7 100644 --- a/server/api_authenticate.go +++ b/server/api_authenticate.go @@ -27,19 +27,22 @@ import ( "database/sql" "github.com/dgrijalva/jwt-go" "strings" + "github.com/lib/pq" + "golang.org/x/crypto/bcrypt" ) var ( invalidCharsRegex = regexp.MustCompilePOSIX("([[:cntrl:]]|[[:space:]])+") + emailRegex = regexp.MustCompile("^.+@.+\\..+$") ) func (s *ApiServer) AuthenticateCustomFunc(ctx context.Context, in *api.AuthenticateCustom) (*api.Session, error) { if in.Account == nil || in.Account.Id == "" { - return nil, status.Error(codes.InvalidArgument, "Custom ID is required") + return nil, status.Error(codes.InvalidArgument, "Custom ID is required.") } else if invalidCharsRegex.MatchString(in.Account.Id) { - return nil, status.Error(codes.InvalidArgument, "Custom ID invalid, no spaces or control characters allowed") + return nil, status.Error(codes.InvalidArgument, "Custom ID invalid, no spaces or control characters allowed.") } else if len(in.Account.Id) < 10 || len(in.Account.Id) > 128 { - return nil, status.Error(codes.InvalidArgument, "Custom ID invalid, must be 10-128 bytes") + return nil, status.Error(codes.InvalidArgument, "Custom ID invalid, must be 10-128 bytes.") } if in.Create == nil || in.Create.Value { @@ -48,9 +51,9 @@ func (s *ApiServer) AuthenticateCustomFunc(ctx context.Context, in *api.Authenti if username == "" { username = generateUsername(s.random) } else if invalidCharsRegex.MatchString(username) { - return nil, status.Error(codes.InvalidArgument, "Username invalid, no spaces or control characters allowed") + return nil, status.Error(codes.InvalidArgument, "Username invalid, no spaces or control characters allowed.") } else if len(username) > 128 { - return nil, status.Error(codes.InvalidArgument, "Username invalid, must be 1-128 bytes") + return nil, status.Error(codes.InvalidArgument, "Username invalid, must be 1-128 bytes.") } userID := uuid.NewV4().String() @@ -71,16 +74,17 @@ RETURNING id, username, custom_id, disabled_at` var dbDisabledAt int64 err := s.db.QueryRow(query, params...).Scan(&dbUserID, &dbUsername, &dbCustomId, &dbDisabledAt) if err != nil { - if strings.HasSuffix(err.Error(), "violates unique constraint \"users_username_key\"") { + if e, ok := err.(*pq.Error); ok && e.Code == dbErrorUniqueViolation && strings.Contains(e.Message, "users_username_key") { // Username is already in use by a different account. - return nil, status.Error(codes.AlreadyExists, "Username is already in use") + return nil, status.Error(codes.AlreadyExists, "Username is already in use.") } - s.logger.Error("Cannot find or create user with custom ID, query error", zap.Error(err)) - return nil, status.Error(codes.Internal, "Error finding or creating user account") + s.logger.Error("Cannot find or create user with custom ID.", zap.Error(err), zap.Any("input", in)) + return nil, status.Error(codes.Internal, "Error finding or creating user account.") } if dbDisabledAt != 0 { - return nil, status.Error(codes.PermissionDenied, "User account is disabled") + s.logger.Debug("User account is disabled.", zap.Any("input", in)) + return nil, status.Error(codes.Unauthenticated, "Error finding or creating user account.") } token := generateToken(s.config, dbUserID, dbUsername) @@ -100,15 +104,16 @@ WHERE custom_id = $1` if err != nil { if err == sql.ErrNoRows { // No user account found. - return nil, status.Error(codes.NotFound, "User account not found") + return nil, status.Error(codes.NotFound, "User account not found.") } else { - s.logger.Error("Cannot find user with custom ID, query error", zap.Error(err)) - return nil, status.Error(codes.Internal, "Error finding user user account") + s.logger.Error("Cannot find user with custom ID.", zap.Error(err), zap.Any("input", in)) + return nil, status.Error(codes.Internal, "Error finding user account.") } } if dbDisabledAt != 0 { - return nil, status.Error(codes.PermissionDenied, "User account is disabled") + s.logger.Debug("User account is disabled.", zap.Any("input", in)) + return nil, status.Error(codes.Unauthenticated, "Error finding or creating user account.") } token := generateToken(s.config, dbUserID, dbUsername) @@ -117,11 +122,205 @@ WHERE custom_id = $1` } func (s *ApiServer) AuthenticateDeviceFunc(ctx context.Context, in *api.AuthenticateDevice) (*api.Session, error) { - return nil, nil + if in.Account == nil || in.Account.Id == "" { + return nil, status.Error(codes.InvalidArgument, "Device ID is required.") + } else if invalidCharsRegex.MatchString(in.Account.Id) { + return nil, status.Error(codes.InvalidArgument, "Device ID invalid, no spaces or control characters allowed.") + } else if len(in.Account.Id) < 10 || len(in.Account.Id) > 128 { + return nil, status.Error(codes.InvalidArgument, "Device ID invalid, must be 10-128 bytes.") + } + + if in.Create == nil || in.Create.Value { + // Use existing user account if found, otherwise create a new user account. + username := in.Username + if username == "" { + username = generateUsername(s.random) + } else if invalidCharsRegex.MatchString(username) { + return nil, status.Error(codes.InvalidArgument, "Username invalid, no spaces or control characters allowed.") + } else if len(username) > 128 { + return nil, status.Error(codes.InvalidArgument, "Username invalid, must be 1-128 bytes.") + } + + var dbUserID string + var dbUsername string + fnErr := Transact(s.logger, s.db, func (tx *sql.Tx) error { + userID := uuid.NewV4().String() + ts := time.Now().UTC().Unix() + query := ` +INSERT INTO users (id, username, created_at, updated_at) +SELECT $1 AS id, + $2 AS username, + $4 AS created_at, + $4 AS updated_at +WHERE NOT EXISTS + (SELECT id + FROM user_device + WHERE id = $3::VARCHAR) +RETURNING id, username, disabled_at` + params := []interface{}{userID, username, in.Account.Id, ts} + + var dbDisabledAt int64 + err := tx.QueryRow(query, params...).Scan(&dbUserID, &dbUsername, &dbDisabledAt) + if err != nil { + if e, ok := err.(*pq.Error); ok && e.Code == dbErrorUniqueViolation && strings.Contains(e.Message, "users_username_key") { + return status.Error(codes.AlreadyExists, "Username is already in use.") + } + s.logger.Error("Cannot find or create user with device ID.", zap.Error(err), zap.Any("input", in)) + return status.Error(codes.Internal, "Error finding or creating user account.") + } + + if dbDisabledAt != 0 { + s.logger.Debug("User account is disabled.", zap.Any("input", in)) + return status.Error(codes.Unauthenticated, "Error finding or creating user account.") + } + + query = "INSERT INTO user_device (id, user_id) VALUES ($1, $2)" + params = []interface{}{in.Account.Id, userID} + _, err = tx.Exec(query, params...) + if err != nil { + s.logger.Error("Cannot add device ID.", zap.Error(err), zap.Any("input", in)) + return status.Error(codes.Internal, "Error finding or creating user account.") + } + + return nil + }) + + if fnErr != nil { + return nil, fnErr + } + + token := generateToken(s.config, dbUserID, dbUsername) + return &api.Session{Token: token}, nil + } else { + query := "SELECT user_id FROM user_device WHERE id = $1" + params := []interface{}{in.Account.Id} + + var dbUserID string + err := s.db.QueryRow(query, params...).Scan(&dbUserID) + if err != nil { + if err == sql.ErrNoRows { + // No user account found. + return nil, status.Error(codes.NotFound, "Device ID not found.") + } else { + s.logger.Error("Cannot find user with device ID.", zap.Error(err), zap.Any("input", in)) + return nil, status.Error(codes.Internal, "Error finding user account.") + } + } + + query = "SELECT username, disabled_at FROM users WHERE id = $1" + params = []interface{}{dbUserID} + var dbUsername string + var dbDisabledAt int64 + + err = s.db.QueryRow(query, params...).Scan(&dbUsername, &dbDisabledAt) + if err != nil { + s.logger.Error("Cannot find user with device ID.", zap.Error(err), zap.Any("input", in)) + return nil, status.Error(codes.Internal, "Error finding user account.") + } + + if dbDisabledAt != 0 { + s.logger.Debug("User account is disabled.", zap.Any("input", in)) + return nil, status.Error(codes.Unauthenticated, "Error finding or creating user account.") + } + + token := generateToken(s.config, dbUserID, dbUsername) + return &api.Session{Token: token}, nil + } } func (s *ApiServer) AuthenticateEmailFunc(ctx context.Context, in *api.AuthenticateEmail) (*api.Session, error) { - return nil, nil + email := in.Account + if email == nil || email.Email == "" || email.Password == "" { + return nil, status.Error(codes.InvalidArgument, "Email address and password is required.") + } else if invalidCharsRegex.MatchString(email.Email) { + return nil, status.Error(codes.InvalidArgument, "Invalid email address, no spaces or control characters allowed.") + } else if len(email.Password) < 8 { + return nil, status.Error(codes.InvalidArgument, "Password must be longer than 8 characters.") + } else if !emailRegex.MatchString(email.Email) { + return nil, status.Error(codes.InvalidArgument, "Invalid email address format.") + } else if len(email.Email) < 10 || len(email.Email) > 255 { + return nil, status.Error(codes.InvalidArgument, "Invalid email address, must be 10-255 bytes.") + } + + cleanEmail := strings.ToLower(email.Email) + hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(email.Password), bcrypt.DefaultCost) + + if in.Create == nil || in.Create.Value { + // Use existing user account if found, otherwise create a new user account. + username := in.Username + if username == "" { + username = generateUsername(s.random) + } else if invalidCharsRegex.MatchString(username) { + return nil, status.Error(codes.InvalidArgument, "Username invalid , no spaces or control characters allowed.") + } else if len(username) > 128 { + return nil, status.Error(codes.InvalidArgument, "Username invalid, must be 1-128 bytes.") + } + + userID := uuid.NewV4().String() + ts := time.Now().UTC().Unix() + query := ` +INSERT INTO users (id, username, email, password, created_at, updated_at) +VALUES ($1, $2, $3, $4, $5, $5) +ON CONFLICT (email) DO UPDATE SET email = $3, password = $4 +RETURNING id, username, disabled_at` + params := []interface{}{userID, username, cleanEmail, hashedPassword, ts} + + var dbUserID string + var dbUsername string + var dbDisabledAt int64 + err := s.db.QueryRow(query, params...).Scan(&dbUserID, &dbUsername, &dbDisabledAt) + if err != nil { + if e, ok := err.(*pq.Error); ok && e.Code == dbErrorUniqueViolation && strings.Contains(e.Message, "users_username_key") { + // Username is already in use by a different account. + return nil, status.Error(codes.AlreadyExists, "Username is already in use.") + } + s.logger.Error("Cannot find or create user with email.", zap.Error(err), zap.Any("input", in)) + return nil, status.Error(codes.Internal, "Error finding or creating user account.") + } + + if dbDisabledAt != 0 { + s.logger.Debug("User account is disabled.", zap.Any("input", in)) + return nil, status.Error(codes.Unauthenticated, "Error finding or creating user account.") + } + + token := generateToken(s.config, dbUserID, dbUsername) + return &api.Session{Token: token}, nil + } else { + // Do not create a new user account. + query := ` +SELECT id, username, password, disabled_at +FROM users +WHERE email = $1` + params := []interface{}{cleanEmail} + + var dbUserID string + var dbUsername string + var dbPassword string + var dbDisabledAt int64 + err := s.db.QueryRow(query, params...).Scan(&dbUserID, &dbUsername, &dbPassword, &dbDisabledAt) + if err != nil { + if err == sql.ErrNoRows { + // No user account found. + return nil, status.Error(codes.NotFound, "User account not found.") + } else { + s.logger.Error("Cannot find user with email.", zap.Error(err), zap.Any("input", in)) + return nil, status.Error(codes.Internal, "Error finding user account.") + } + } + + if dbDisabledAt != 0 { + s.logger.Debug("User account is disabled.", zap.Any("input", in)) + return nil, status.Error(codes.Unauthenticated, "Error finding or creating user account.") + } + + err = bcrypt.CompareHashAndPassword(hashedPassword, []byte(email.Password)) + if err != nil { + return nil, status.Error(codes.Unauthenticated, "Invalid credentials.") + } + + token := generateToken(s.config, dbUserID, dbUsername) + return &api.Session{Token: token}, nil + } } func (s *ApiServer) AuthenticateFacebookFunc(ctx context.Context, in *api.AuthenticateFacebook) (*api.Session, error) { diff --git a/server/api_link.go b/server/api_link.go index f01d65217..1e0d6b50a 100644 --- a/server/api_link.go +++ b/server/api_link.go @@ -18,18 +18,136 @@ import ( "golang.org/x/net/context" "github.com/heroiclabs/nakama/api" "github.com/golang/protobuf/ptypes/empty" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "go.uber.org/zap" + "time" + "database/sql" + "github.com/lib/pq" + "golang.org/x/crypto/bcrypt" + "strings" ) func (s *ApiServer) LinkCustomFunc(ctx context.Context, in *api.AccountCustom) (*empty.Empty, error) { - return nil, nil + customID := in.Id + if customID == "" { + return nil, status.Error(codes.InvalidArgument, "Custom ID is required.") + } else if invalidCharsRegex.MatchString(customID) { + return nil, status.Error(codes.InvalidArgument, "Invalid custom ID, no spaces or control characters allowed.") + } else if len(customID) < 10 || len(customID) > 128 { + return nil, status.Error(codes.InvalidArgument, "Invalid custom ID, must be 10-128 bytes.") + } + + userID := ctx.Value(ctxUserIDKey{}) + ts := time.Now().UTC().Unix() + res, err := s.db.Exec(` +UPDATE users +SET custom_id = $2, updated_at = $3 +WHERE (id = $1) +AND (NOT EXISTS + (SELECT id + FROM users + WHERE custom_id = $2 AND NOT id = $1))`, + userID, + customID, + ts) + + if err != nil { + s.logger.Warn("Could not link custom ID.", zap.Error(err), zap.Any("input", in)) + return nil, status.Error(codes.Internal, "Error while trying to link Custom ID.") + } else if count, _ := res.RowsAffected(); count == 0 { + return nil, status.Error(codes.AlreadyExists, "Custom ID is already in use.") + } + + return &empty.Empty{}, nil } func (s *ApiServer) LinkDeviceFunc(ctx context.Context, in *api.AccountDevice) (*empty.Empty, error) { - return nil, nil + deviceID := in.Id + if deviceID == "" { + return nil, status.Error(codes.InvalidArgument, "Device ID is required.") + } else if invalidCharsRegex.MatchString(deviceID) { + return nil, status.Error(codes.InvalidArgument, "Device ID invalid, no spaces or control characters allowed.") + } else if len(deviceID) < 10 || len(deviceID) > 128 { + return nil, status.Error(codes.InvalidArgument, "Device ID invalid, must be 10-128 bytes.") + } + + fnErr := Transact(s.logger, s.db, func (tx *sql.Tx) error { + userID := ctx.Value(ctxUserIDKey{}) + ts := time.Now().UTC().Unix() + + var dbDeviceIdLinkedUser int64 + err := tx.QueryRow("SELECT COUNT(id) FROM user_device WHERE id = $1 AND user_id = $2 LIMIT 1", deviceID, userID).Scan(&dbDeviceIdLinkedUser) + if err != nil { + s.logger.Error("Cannot link device ID.", zap.Error(err), zap.Any("input", in)) + return status.Error(codes.Internal, "Error linking Device ID.") + } + + if dbDeviceIdLinkedUser == 0 { + _, err = tx.Exec("INSERT INTO user_device (id, user_id) VALUES ($1, $2)", deviceID, userID) + if err != nil { + if e, ok := err.(*pq.Error); ok && e.Code == dbErrorUniqueViolation { + return status.Error(codes.AlreadyExists, "Device ID already in use.") + } + s.logger.Error("Cannot link device ID.", zap.Error(err), zap.Any("input", in)) + return status.Error(codes.Internal, "Error linking Device ID.") + } + } + + _, err = tx.Exec("UPDATE users SET updated_at = $1 WHERE id = $2", ts, userID) + if err != nil { + s.logger.Error("Cannot update users table while linking.", zap.Error(err), zap.Any("input", in)) + return status.Error(codes.Internal, "Error linking Device ID.") + } + return nil + }) + + if fnErr != nil { + return nil, fnErr + } + + return &empty.Empty{}, nil } func (s *ApiServer) LinkEmailFunc(ctx context.Context, in *api.AccountEmail) (*empty.Empty, error) { - return nil, nil + if in.Email == "" || in.Password == "" { + return nil, status.Error(codes.InvalidArgument, "Email address and password is required.") + } else if invalidCharsRegex.MatchString(in.Email) { + return nil, status.Error(codes.InvalidArgument, "Invalid email address, no spaces or control characters allowed.") + } else if len(in.Password) < 8 { + return nil, status.Error(codes.InvalidArgument, "Password must be longer than 8 characters.") + } else if !emailRegex.MatchString(in.Email) { + return nil, status.Error(codes.InvalidArgument, "Invalid email address format.") + } else if len(in.Email) < 10 || len(in.Email) > 255 { + return nil, status.Error(codes.InvalidArgument, "Invalid email address, must be 10-255 bytes.") + } + + cleanEmail := strings.ToLower(in.Email) + hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(in.Password), bcrypt.DefaultCost) + + userID := ctx.Value(ctxUserIDKey{}) + ts := time.Now().UTC().Unix() + res, err := s.db.Exec(` +UPDATE users +SET email = $2, password = $3, updated_at = $4 +WHERE (id = $1) +AND (NOT EXISTS + (SELECT id + FROM users + WHERE email = $2 AND NOT id = $1))`, + userID, + cleanEmail, + hashedPassword, + ts) + + if err != nil { + s.logger.Warn("Could not link email.", zap.Error(err), zap.Any("input", in)) + return nil, status.Error(codes.Internal, "Error while trying to link email.") + } else if count, _ := res.RowsAffected(); count == 0 { + return nil, status.Error(codes.AlreadyExists, "Email is already in use.") + } + + return &empty.Empty{}, nil } func (s *ApiServer) LinkFacebookFunc(ctx context.Context, in *api.AccountFacebook) (*empty.Empty, error) { diff --git a/server/api_unlink.go b/server/api_unlink.go index 5be7841ec..fe59262f2 100644 --- a/server/api_unlink.go +++ b/server/api_unlink.go @@ -18,18 +18,108 @@ import ( "golang.org/x/net/context" "github.com/heroiclabs/nakama/api" "github.com/golang/protobuf/ptypes/empty" + "go.uber.org/zap" + "time" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "database/sql" + "strings" ) func (s *ApiServer) UnlinkCustomFunc(ctx context.Context, in *api.AccountCustom) (*empty.Empty, error) { - return nil, nil + query := `UPDATE users SET custom_id = NULL, updated_at = $3 +WHERE id = $1 +AND custom_id = $2 +AND ((facebook_id IS NOT NULL + OR google_id IS NOT NULL + OR gamecenter_id IS NOT NULL + OR steam_id IS NOT NULL + OR email IS NOT NULL) + OR + EXISTS (SELECT id FROM user_device WHERE user_id = $1 LIMIT 1))` + + userID := ctx.Value(ctxUserIDKey{}) + ts := time.Now().UTC().Unix() + res, err := s.db.Exec(query, userID, in.Id, ts) + + if err != nil { + s.logger.Warn("Could not unlink custom ID.", zap.Error(err), zap.Any("input", in)) + return nil, status.Error(codes.Internal, "Error while trying to unlink custom ID.") + } else if count, _ := res.RowsAffected(); count == 0 { + return nil, status.Error(codes.PermissionDenied, "Cannot unlink last account identifier. Check profile exists and is not last link.") + } + + return &empty.Empty{}, nil } func (s *ApiServer) UnlinkDeviceFunc(ctx context.Context, in *api.AccountDevice) (*empty.Empty, error) { - return nil, nil + fnErr := Transact(s.logger, s.db, func (tx *sql.Tx) error { + userID := ctx.Value(ctxUserIDKey{}) + ts := time.Now().UTC().Unix() + + query := `DELETE FROM user_device WHERE id = $2 AND user_id = $1 +AND (EXISTS (SELECT id FROM users WHERE id = $1 AND + (facebook_id IS NOT NULL + OR google_id IS NOT NULL + OR gamecenter_id IS NOT NULL + OR steam_id IS NOT NULL + OR email IS NOT NULL + OR custom_id IS NOT NULL)) + OR EXISTS (SELECT id FROM user_device WHERE user_id = $1 AND id <> $2))` + + res, err := tx.Exec(query, userID, in.Id) + if err != nil { + s.logger.Warn("Could not unlink device ID.", zap.Error(err), zap.Any("input", in)) + return status.Error(codes.Internal, "Could not unlink Device ID.") + } + if count, _ := res.RowsAffected(); count == 0 { + return status.Error(codes.PermissionDenied, "Cannot unlink last account identifier. Check profile exists and is not last link.") + } + + res, err = tx.Exec("UPDATE users SET updated_at = $2 WHERE id = $1", userID, ts) + if err != nil { + s.logger.Warn("Could not unlink device ID.", zap.Error(err), zap.Any("input", in)) + return status.Error(codes.Internal, "Could not unlink Device ID.") + } + if count, _ := res.RowsAffected(); count == 0 { + return status.Error(codes.PermissionDenied, "Cannot unlink last account identifier. Check profile exists and is not last link.") + } + + return nil + }) + + if fnErr != nil { + return nil, fnErr + } + + return &empty.Empty{}, nil } func (s *ApiServer) UnlinkEmailFunc(ctx context.Context, in *api.AccountEmail) (*empty.Empty, error) { - return nil, nil + query := `UPDATE users SET email = NULL, password = NULL, updated_at = $3 +WHERE id = $1 +AND email = $2 +AND ((facebook_id IS NOT NULL + OR google_id IS NOT NULL + OR gamecenter_id IS NOT NULL + OR steam_id IS NOT NULL + OR custom_id IS NOT NULL) + OR + EXISTS (SELECT id FROM user_device WHERE user_id = $1 LIMIT 1))` + + userID := ctx.Value(ctxUserIDKey{}) + ts := time.Now().UTC().Unix() + cleanEmail := strings.ToLower(in.Email) + res, err := s.db.Exec(query, userID, cleanEmail, ts) + + if err != nil { + s.logger.Warn("Could not unlink email.", zap.Error(err), zap.Any("input", in)) + return nil, status.Error(codes.Internal, "Error while trying to unlink email.") + } else if count, _ := res.RowsAffected(); count == 0 { + return nil, status.Error(codes.PermissionDenied, "Cannot unlink last account identifier. Check profile exists and is not last link.") + } + + return &empty.Empty{}, nil } func (s *ApiServer) UnlinkFacebookFunc(ctx context.Context, in *api.AccountFacebook) (*empty.Empty, error) { diff --git a/server/db.go b/server/db.go new file mode 100644 index 000000000..62fb1c677 --- /dev/null +++ b/server/db.go @@ -0,0 +1,52 @@ +// Copyright 2018 The Nakama Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "database/sql" + "go.uber.org/zap" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +const ( + dbErrorUniqueViolation = "23505" +) + +func Transact(logger *zap.Logger, db *sql.DB, txFunc func(*sql.Tx) error) (err error) { + tx, err := db.Begin() + if err != nil { + logger.Error("Could not begin database transaction.", zap.Error(err)) + return + } + + fnErr := txFunc(tx) + + if p := recover(); p != nil { + if err = tx.Rollback(); err != nil { + logger.Error("Could not rollback database transaction.", zap.Error(err)) + } + } else if fnErr != nil { + if err = tx.Rollback(); err != nil { + logger.Error("Could not rollback database transaction.", zap.Error(err)) + } + } else { + if err = tx.Commit(); err != nil { + logger.Error("Could not commit database transaction.", zap.Error(err)) + return status.Error(codes.Internal, "Could not complete operation.") + } + } + return fnErr +} -- GitLab