Commit 425c599a authored by Andrei Mihu's avatar Andrei Mihu
Browse files

Update username handling and hook registration validation.

parent 67292354
Loading
Loading
Loading
Loading
+0 −7
Original line number Diff line number Diff line
@@ -43,7 +43,6 @@ func (s *ApiServer) AuthenticateCustom(ctx context.Context, in *api.Authenticate
	}

	username := in.Username
	username = strings.ToLower(username)
	if username == "" {
		username = generateUsername()
	} else if invalidCharsRegex.MatchString(username) {
@@ -73,7 +72,6 @@ func (s *ApiServer) AuthenticateDevice(ctx context.Context, in *api.Authenticate
	}

	username := in.Username
	username = strings.ToLower(username)
	if username == "" {
		username = generateUsername()
	} else if invalidCharsRegex.MatchString(username) {
@@ -110,7 +108,6 @@ func (s *ApiServer) AuthenticateEmail(ctx context.Context, in *api.AuthenticateE
	cleanEmail := strings.ToLower(email.Email)

	username := in.Username
	username = strings.ToLower(username)
	if username == "" {
		username = generateUsername()
	} else if invalidCharsRegex.MatchString(username) {
@@ -136,7 +133,6 @@ func (s *ApiServer) AuthenticateFacebook(ctx context.Context, in *api.Authentica
	}

	username := in.Username
	username = strings.ToLower(username)
	if username == "" {
		username = generateUsername()
	} else if invalidCharsRegex.MatchString(username) {
@@ -179,7 +175,6 @@ func (s *ApiServer) AuthenticateGameCenter(ctx context.Context, in *api.Authenti
	}

	username := in.Username
	username = strings.ToLower(username)
	if username == "" {
		username = generateUsername()
	} else if invalidCharsRegex.MatchString(username) {
@@ -205,7 +200,6 @@ func (s *ApiServer) AuthenticateGoogle(ctx context.Context, in *api.Authenticate
	}

	username := in.Username
	username = strings.ToLower(username)
	if username == "" {
		username = generateUsername()
	} else if invalidCharsRegex.MatchString(username) {
@@ -235,7 +229,6 @@ func (s *ApiServer) AuthenticateSteam(ctx context.Context, in *api.AuthenticateS
	}

	username := in.Username
	username = strings.ToLower(username)
	if username == "" {
		username = generateUsername()
	} else if invalidCharsRegex.MatchString(username) {
+14 −11
Original line number Diff line number Diff line
@@ -130,57 +130,60 @@ func UpdateAccount(db *sql.DB, logger *zap.Logger, userID uuid.UUID, username st
	params := make([]interface{}, 0)

	if username != "" {
		if invalidCharsRegex.MatchString(username) {
			return errors.New("Username invalid, no spaces or control characters allowed.")
		}
		statements = append(statements, "username = $"+strconv.Itoa(index))
		params = append(params, strings.ToLower(username))
		params = append(params, username)
		index++
	}

	if displayName != nil {
		if displayName.GetValue() == "" {
		if d := displayName.GetValue(); d == "" {
			statements = append(statements, "display_name = NULL")
		} else {
			statements = append(statements, "display_name = $"+strconv.Itoa(index))
			params = append(params, displayName.GetValue())
			params = append(params, d)
			index++
		}
	}

	if timezone != nil {
		if timezone.GetValue() == "" {
		if t := timezone.GetValue(); t == "" {
			statements = append(statements, "timezone = NULL")
		} else {
			statements = append(statements, "timezone = $"+strconv.Itoa(index))
			params = append(params, timezone.GetValue())
			params = append(params, t)
			index++
		}
	}

	if location != nil {
		if location.GetValue() == "" {
		if l := location.GetValue(); l == "" {
			statements = append(statements, "location = NULL")
		} else {
			statements = append(statements, "location = $"+strconv.Itoa(index))
			params = append(params, location.GetValue())
			params = append(params, l)
			index++
		}
	}

	if langTag != nil {
		if langTag.GetValue() == "" {
		if l := langTag.GetValue(); l == "" {
			statements = append(statements, "lang_tag = NULL")
		} else {
			statements = append(statements, "lang_tag = $"+strconv.Itoa(index))
			params = append(params, langTag.GetValue())
			params = append(params, l)
			index++
		}
	}

	if avatarURL != nil {
		if avatarURL.GetValue() == "" {
		if a := avatarURL.GetValue(); a == "" {
			statements = append(statements, "avatar_url = NULL")
		} else {
			statements = append(statements, "avatar_url = $"+strconv.Itoa(index))
			params = append(params, avatarURL.GetValue())
			params = append(params, a)
			index++
		}
	}
+1 −1
Original line number Diff line number Diff line
@@ -106,7 +106,7 @@ func UpdateWalletLedger(logger *zap.Logger, db *sql.DB, id uuid.UUID, metadata s
	var changeset sql.NullString
	var createTime pq.NullTime
	var updateTime pq.NullTime
	query := "UPDATE wallet_ledger SET update_time = now(), metadata = $2 WHERE id = $1::UUID RETURNING user_id, changeset, create_time, update_time"
	query := "UPDATE wallet_ledger SET update_time = now(), metadata = metadata || $2 WHERE id = $1::UUID RETURNING user_id, changeset, create_time, update_time"
	err := db.QueryRow(query, id, metadata).Scan(&userId, &changeset, &createTime, &updateTime)
	if err != nil {
		logger.Error("Error updating user wallet ledger.", zap.String("id", id.String()), zap.Error(err))
+20 −0
Original line number Diff line number Diff line
@@ -186,6 +186,10 @@ func (n *NakamaModule) registerRPC(l *lua.LState) int {
	id = strings.ToLower(id)

	rc := l.Context().Value(CALLBACKS).(*Callbacks)
	if _, ok := rc.RPC[id]; ok {
		l.RaiseError("rpc id already registered")
		return 0
	}
	rc.RPC[id] = fn
	if n.announceCallback != nil {
		n.announceCallback(RPC, id)
@@ -205,6 +209,10 @@ func (n *NakamaModule) registerReqBefore(l *lua.LState) int {
	id = strings.ToLower(API_PREFIX + id)

	rc := l.Context().Value(CALLBACKS).(*Callbacks)
	if _, ok := rc.Before[id]; ok {
		l.RaiseError("before id already registered")
		return 0
	}
	rc.Before[id] = fn
	if n.announceCallback != nil {
		n.announceCallback(BEFORE, id)
@@ -224,6 +232,10 @@ func (n *NakamaModule) registerReqAfter(l *lua.LState) int {
	id = strings.ToLower(API_PREFIX + id)

	rc := l.Context().Value(CALLBACKS).(*Callbacks)
	if _, ok := rc.After[id]; ok {
		l.RaiseError("after id already registered")
		return 0
	}
	rc.After[id] = fn
	if n.announceCallback != nil {
		n.announceCallback(AFTER, id)
@@ -243,6 +255,10 @@ func (n *NakamaModule) registerRTBefore(l *lua.LState) int {
	id = strings.ToLower(RTAPI_PREFIX + id)

	rc := l.Context().Value(CALLBACKS).(*Callbacks)
	if _, ok := rc.Before[id]; ok {
		l.RaiseError("before id already registered")
		return 0
	}
	rc.Before[id] = fn
	if n.announceCallback != nil {
		n.announceCallback(BEFORE, id)
@@ -262,6 +278,10 @@ func (n *NakamaModule) registerRTAfter(l *lua.LState) int {
	id = strings.ToLower(RTAPI_PREFIX + id)

	rc := l.Context().Value(CALLBACKS).(*Callbacks)
	if _, ok := rc.After[id]; ok {
		l.RaiseError("before id already registered")
		return 0
	}
	rc.After[id] = fn
	if n.announceCallback != nil {
		n.announceCallback(AFTER, id)