Commit c2d92e4a authored by Andrei Mihu's avatar Andrei Mihu
Browse files

Optionally allow JSON encoding in user register/login. Merge #52

parent e6a9dd74
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -4,6 +4,8 @@ All notable changes to this project are documented below.
The format is based on [keep a changelog](http://keepachangelog.com/) and this project uses [semantic versioning](http://semver.org/).

## [Unreleased]
### Added
- Optionally allow JSON encoding in user login/register operations and responses.

## [0.12.0] - 2017-03-19
### Added
+76 −26
Original line number Diff line number Diff line
@@ -26,7 +26,9 @@ import (

	"nakama/pkg/social"

	"bytes"
	"github.com/dgrijalva/jwt-go"
	"github.com/gogo/protobuf/jsonpb"
	"github.com/gogo/protobuf/proto"
	"github.com/gorilla/handlers"
	"github.com/gorilla/mux"
@@ -34,6 +36,7 @@ import (
	"github.com/satori/go.uuid"
	"github.com/uber-go/zap"
	"golang.org/x/crypto/bcrypt"
	"mime"
)

const (
@@ -62,6 +65,8 @@ type authenticationService struct {
	upgrader          *websocket.Upgrader
	socialClient      *social.Client
	random            *rand.Rand
	jsonpbMarshaler   *jsonpb.Marshaler
	jsonpbUnmarshaler *jsonpb.Unmarshaler
}

// NewAuthenticationService creates a new AuthenticationService
@@ -75,13 +80,22 @@ func NewAuthenticationService(logger zap.Logger, config Config, db *sql.DB, regi
		registry:       registry,
		pipeline:       p,
		hmacSecretByte: []byte(config.GetSession().EncryptionKey),
		socialClient:   s,
		random:         rand.New(rand.NewSource(time.Now().UnixNano())),
		upgrader: &websocket.Upgrader{
			ReadBufferSize:  1024,
			WriteBufferSize: 1024,
			CheckOrigin:     func(r *http.Request) bool { return true },
		},
		socialClient: s,
		random:       rand.New(rand.NewSource(time.Now().UnixNano())),
		jsonpbMarshaler: &jsonpb.Marshaler{
			EnumsAsInts:  true,
			EmitDefaults: false,
			Indent:       "",
			OrigName:     false,
		},
		jsonpbUnmarshaler: &jsonpb.Unmarshaler{
			AllowUnknownFields: false,
		},
	}

	a.configure()
@@ -155,32 +169,48 @@ func (a *authenticationService) handleAuth(w http.ResponseWriter, r *http.Reques

	username, _, ok := r.BasicAuth()
	if !ok {
		a.sendAuthError(w, "Missing or invalid authentication header", 400, nil)
		a.sendAuthError(w, r, "Missing or invalid authentication header", 400, nil)
		return
	} else if username != a.config.GetTransport().ServerKey {
		a.sendAuthError(w, "Invalid server key", 401, nil)
		a.sendAuthError(w, r, "Invalid server key", 401, nil)
		return
	}

	data, err := ioutil.ReadAll(http.MaxBytesReader(w, r.Body, a.config.GetTransport().MaxMessageSizeBytes))
	if err != nil {
		a.logger.Warn("Could not read body", zap.Error(err))
		a.sendAuthError(w, "Could not read request body", 400, nil)
		a.sendAuthError(w, r, "Could not read request body", 400, nil)
		return
	}

	contentType := r.Header.Get("content-type")
	if contentType == "" {
		contentType = "application/octet-stream"
	}
	mediaType, _, err := mime.ParseMediaType(contentType)
	if err != nil {
		a.logger.Warn("Could not decode content type header", zap.Error(err))
		a.sendAuthError(w, r, "Could not decode content type header", 400, nil)
		return
	}

	authReq := &AuthenticateRequest{}
	switch mediaType {
	case "application/json":
		err = a.jsonpbUnmarshaler.Unmarshal(bytes.NewReader(data), authReq)
	default:
		err = proto.Unmarshal(data, authReq)
	}
	if err != nil {
		a.logger.Warn("Could not decode body", zap.Error(err))
		a.sendAuthError(w, "Could not decode body", 400, nil)
		a.sendAuthError(w, r, "Could not decode body", 400, nil)
		return
	}

	userID, handle, errString, errCode := retrieveUserID(authReq)
	if errString != "" {
		a.logger.Debug("Could not retrieve user ID", zap.String("error", errString), zap.Int("code", errCode))
		a.sendAuthError(w, errString, errCode, authReq)
		a.sendAuthError(w, r, errString, errCode, authReq)
		return
	}

@@ -194,31 +224,51 @@ func (a *authenticationService) handleAuth(w http.ResponseWriter, r *http.Reques
	signedToken, _ := token.SignedString(a.hmacSecretByte)

	authResponse := &AuthenticateResponse{CollationId: authReq.CollationId, Payload: &AuthenticateResponse_Session_{&AuthenticateResponse_Session{Token: signedToken}}}
	a.sendAuthResponse(w, authResponse)
	a.sendAuthResponse(w, r, 200, authResponse)
}

func (a *authenticationService) sendAuthError(w http.ResponseWriter, error string, errorCode int, authRequest *AuthenticateRequest) {
func (a *authenticationService) sendAuthError(w http.ResponseWriter, r *http.Request, error string, errorCode int, authRequest *AuthenticateRequest) {
	var collationID string
	if authRequest != nil {
		collationID = authRequest.CollationId
	}
	w.Header().Set("X-Content-Type-Options", "nosniff")
	w.WriteHeader(errorCode)
	authResponse := &AuthenticateResponse{CollationId: collationID, Payload: &AuthenticateResponse_Error_{&AuthenticateResponse_Error{
		Code:    int32(AUTH_ERROR),
		Message: error,
		Request: authRequest,
	}}}
	a.sendAuthResponse(w, authResponse)
	a.sendAuthResponse(w, r, errorCode, authResponse)
}

func (a *authenticationService) sendAuthResponse(w http.ResponseWriter, response *AuthenticateResponse) {
	payload, err := proto.Marshal(response)
func (a *authenticationService) sendAuthResponse(w http.ResponseWriter, r *http.Request, code int, response *AuthenticateResponse) {
	accept := r.Header.Get("accept")
	if accept == "" {
		accept = "application/octet-stream"
	}
	mediaType, _, err := mime.ParseMediaType(accept)
	if err != nil {
		a.logger.Error("Could not marshall Response to byte[]", zap.Error(err))
		a.logger.Warn("Could not decode accept header, defaulting to Protobuf output", zap.Error(err))
		err = nil
	}

	var payload []byte
	switch mediaType {
	case "application/json":
		payloadString, err := a.jsonpbMarshaler.MarshalToString(response)
		if err == nil {
			payload = []byte(payloadString)
			w.Header().Set("Content-Type", "application/json")
		}
	default:
		payload, err = proto.Marshal(response)
	}
	if err != nil {
		a.logger.Error("Could not marshal AuthenticateResponse", zap.Error(err))
		return
	}

	w.Header().Set("X-Content-Type-Options", "nosniff")
	w.WriteHeader(code)
	w.Write(payload)
}