Commit 158c6500 authored by Andrei Mihu's avatar Andrei Mihu
Browse files

Realtime chat improvements. Merge #24

parent f3ea246d
Loading
Loading
Loading
Loading
+9 −0
Original line number Diff line number Diff line
@@ -6,6 +6,15 @@ The format is based on [keep a changelog](http://keepachangelog.com/) and this p
## [Unreleased]
### Added
- Include Dockerfile and Docker instructions.
- Use a default limit in topic message listings if one is not provided.

### Fixed
- Enforce concurrency control on outgoing socket messages.
- Correct session lookup for realtime message routing.
- Fix input validation when sending topic messages.
- Correct handling of IDs in various login options.
- Fix presence service shutdown sequence.
- More graceful handling of session operations while connection is closing.

## [0.11.1] - 2017-02-12
### Changed
+1 −1
Original line number Diff line number Diff line
@@ -47,7 +47,7 @@ func (m *messageRouterService) Send(logger zap.Logger, ps []Presence, msg proto.

	for _, p := range ps {
		session := m.registry.Get(p.ID.SessionID)
		if session == nil {
		if session != nil {
			err := session.SendBytes(payload)
			if err != nil {
				logger.Error("Failed to route to", zap.Object("p", p), zap.Error(err))
+33 −18
Original line number Diff line number Diff line
@@ -610,21 +610,26 @@ func (p *pipeline) groupJoin(l zap.Logger, session *session, envelope *Envelope)
	}
	defer func() {
		if err != nil {
			logger.Error("Could not add user to group", zap.Error(err))
			logger.Error("Could not join group", zap.Error(err))
			err = tx.Rollback()
			if err != nil {
				logger.Error("Could not rollback transaction", zap.Error(err))
			}

			session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Could not add user to group"}}})
			session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Could not join group"}}})
		} else {
			err = tx.Commit()
			if err != nil {
				logger.Error("Could not commit transaction", zap.Error(err))
				session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Could not add user to group"}}})
				session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Could not join group"}}})
			} else {
				logger.Info("Added user to the group")
				logger.Info("User joined group")
				session.Send(&Envelope{CollationId: envelope.CollationId})

				err = p.storeAndDeliverMessage(logger, session, &TopicId{Id: &TopicId_GroupId{GroupId: groupID.Bytes()}}, 1, []byte("{}"))
				if err != nil {
					logger.Error("Error handling group user join notification topic message", zap.Error(err))
				}
			}
		}
	}()
@@ -662,8 +667,6 @@ VALUES ($1, $2, $2, $3, $4), ($3, $2, $2, $1, $4)`,
	if err != nil {
		return
	}

	p.storeAndDeliverMessage(logger, session, &TopicId{Id: &TopicId_GroupId{GroupId: groupID.Bytes()}}, 1, []byte("{}"))
}

func (p *pipeline) groupLeave(l zap.Logger, session *session, envelope *Envelope) {
@@ -699,8 +702,13 @@ func (p *pipeline) groupLeave(l zap.Logger, session *session, envelope *Envelope
				logger.Error("Could not commit transaction", zap.Error(err))
				session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: failureReason}}})
			} else {
				logger.Info("Left group")
				logger.Info("User left group")
				session.Send(&Envelope{CollationId: envelope.CollationId})

				err = p.storeAndDeliverMessage(logger, session, &TopicId{Id: &TopicId_GroupId{GroupId: groupID.Bytes()}}, 3, []byte("{}"))
				if err != nil {
					logger.Error("Error handling group user leave notification topic message", zap.Error(err))
				}
			}
		}
	}()
@@ -769,8 +777,6 @@ OR
	if err != nil {
		return
	}

	p.storeAndDeliverMessage(logger, session, &TopicId{Id: &TopicId_GroupId{GroupId: groupID.Bytes()}}, 3, []byte("{}"))
}

func (p *pipeline) groupUserAdd(l zap.Logger, session *session, envelope *Envelope) {
@@ -789,6 +795,7 @@ func (p *pipeline) groupUserAdd(l zap.Logger, session *session, envelope *Envelo
	}

	logger := l.With(zap.String("group_id", groupID.String()), zap.String("user_id", userID.String()))
	var handle string

	tx, err := p.db.Begin()
	if err != nil {
@@ -817,12 +824,17 @@ func (p *pipeline) groupUserAdd(l zap.Logger, session *session, envelope *Envelo
			} else {
				logger.Info("Added user to the group")
				session.Send(&Envelope{CollationId: envelope.CollationId})

				data, _ := json.Marshal(map[string]string{"user_id": userID.String(), "handle": handle})
				err = p.storeAndDeliverMessage(logger, session, &TopicId{Id: &TopicId_GroupId{GroupId: groupID.Bytes()}}, 2, data)
				if err != nil {
					logger.Error("Error handling group user added notification topic message", zap.Error(err))
				}
			}
		}
	}()

	// Look up the user being added.
	var handle string
	err = tx.QueryRow("SELECT handle FROM users WHERE id = $1 AND disabled_at = 0", userID.Bytes()).Scan(&handle)
	if err != nil {
		return
@@ -857,9 +869,6 @@ DO UPDATE SET state = 1, updated_at = $2::INT`,
	if err != nil {
		return
	}

	data, err := json.Marshal(map[string]string{"user_id": userID.String(), "handle": handle})
	p.storeAndDeliverMessage(logger, session, &TopicId{Id: &TopicId_GroupId{GroupId: groupID.Bytes()}}, 2, data)
}

func (p *pipeline) groupUserKick(l zap.Logger, session *session, envelope *Envelope) {
@@ -884,6 +893,7 @@ func (p *pipeline) groupUserKick(l zap.Logger, session *session, envelope *Envel
	}

	logger := l.With(zap.String("group_id", groupID.String()), zap.String("user_id", userID.String()))
	var handle string

	failureReason := "Could not kick user from group"
	tx, err := p.db.Begin()
@@ -913,6 +923,12 @@ func (p *pipeline) groupUserKick(l zap.Logger, session *session, envelope *Envel
			} else {
				logger.Info("Kicked user from group")
				session.Send(&Envelope{CollationId: envelope.CollationId})

				data, _ := json.Marshal(map[string]string{"user_id": userID.String(), "handle": handle})
				err = p.storeAndDeliverMessage(logger, session, &TopicId{Id: &TopicId_GroupId{GroupId: groupID.Bytes()}}, 4, data)
				if err != nil {
					logger.Error("Error handling group user kicked notification topic message", zap.Error(err))
				}
			}
		}
	}()
@@ -946,14 +962,10 @@ AND
	}

	// Look up the user being kicked. Allow kicking disabled users.
	var handle string
	err = tx.QueryRow("SELECT handle FROM users WHERE id = $1", userID.Bytes()).Scan(&handle)
	if err != nil {
		return
	}

	data, err := json.Marshal(map[string]string{"user_id": userID.String(), "handle": handle})
	p.storeAndDeliverMessage(logger, session, &TopicId{Id: &TopicId_GroupId{GroupId: groupID.Bytes()}}, 4, data)
}

func (p *pipeline) groupUserPromote(l zap.Logger, session *session, envelope *Envelope) {
@@ -1011,7 +1023,10 @@ AND
	}

	data, _ := json.Marshal(map[string]string{"user_id": userID.String(), "handle": handle})
	p.storeAndDeliverMessage(logger, session, &TopicId{Id: &TopicId_GroupId{GroupId: groupID.Bytes()}}, 5, data)
	err = p.storeAndDeliverMessage(logger, session, &TopicId{Id: &TopicId_GroupId{GroupId: groupID.Bytes()}}, 5, data)
	if err != nil {
		logger.Error("Error handling group user promoted notification topic message", zap.Error(err))
	}

	session.Send(&Envelope{CollationId: envelope.CollationId})
}
+43 −17
Original line number Diff line number Diff line
@@ -233,7 +233,7 @@ func (p *pipeline) topicMessageSend(logger zap.Logger, session *session, envelop
		session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Topic ID is required"}}})
		return
	}
	data := envelope.GetTopicMessage().Data
	data := envelope.GetTopicMessageSend().Data
	if data == nil || len(data) == 0 || len(data) > 1000 {
		session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Data is required and must be 1-1000 JSON bytes"}}})
		return
@@ -330,12 +330,14 @@ func (p *pipeline) topicMessageSend(logger zap.Logger, session *session, envelop
		return
	}

	messageID, handle, createdAt, expiresAt, err := p.storeAndDeliverMessage(logger, session, topic, 0, data)
	// Store message to history.
	messageID, handle, createdAt, expiresAt, err := p.storeMessage(logger, session, topic, 0, data)
	if err != nil {
		session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Error storing message"}}})
		return
	}

	// Return receipt to sender.
	ack := &TTopicMessageAck{
		MessageId: messageID,
		CreatedAt: createdAt,
@@ -343,6 +345,9 @@ func (p *pipeline) topicMessageSend(logger zap.Logger, session *session, envelop
		Handle:    handle,
	}
	session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_TopicMessageAck{TopicMessageAck: ack}})

	// Deliver message to topic.
	p.deliverMessage(logger, session, topic, 0, data, messageID, handle, createdAt, expiresAt)
}

func (p *pipeline) topicMessagesList(logger zap.Logger, session *session, envelope *Envelope) {
@@ -351,7 +356,11 @@ func (p *pipeline) topicMessagesList(logger zap.Logger, session *session, envelo
		session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Topic ID is required"}}})
		return
	}
	if input.Limit < 10 || input.Limit > 100 {
	limit := input.Limit
	if limit == 0 {
		limit = 10
	}
	if limit < 10 || limit > 100 {
		session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Limit must be 10-100"}}})
		return
	}
@@ -435,7 +444,7 @@ func (p *pipeline) topicMessagesList(logger zap.Logger, session *session, envelo
	}

	query := "SELECT message_id, user_id, created_at, expires_at, handle, type, data FROM message WHERE topic = $2 AND topic_type = $3"
	params := []interface{}{input.Limit + 1, topicBytes, topicType}
	params := []interface{}{limit + 1, topicBytes, topicType}

	// Only paginate if all cursor components are available.
	if input.Cursor != nil {
@@ -477,7 +486,7 @@ func (p *pipeline) topicMessagesList(logger zap.Logger, session *session, envelo
	var msgType int64
	var data []byte
	for rows.Next() {
		if int64(len(messages)) >= input.Limit {
		if int64(len(messages)) >= limit {
			cursorBuf := new(bytes.Buffer)
			if gob.NewEncoder(cursorBuf).Encode(&messageCursor{MessageID: messageID, UserID: userID, CreatedAt: createdAt}); err != nil {
				logger.Error("Error creating topic messages list cursor", zap.Error(err))
@@ -551,25 +560,17 @@ AND ue.source_id = $2`, checkUserID, blocksUserID).Scan(&uid, &state)
}

// Assumes `topic` has already been validated, or was constructed internally.
func (p *pipeline) storeAndDeliverMessage(logger zap.Logger, session *session, topic *TopicId, msgType int64, data []byte) ([]byte, string, int64, int64, error) {
	var trackerTopic string
func (p *pipeline) storeMessage(logger zap.Logger, session *session, topic *TopicId, msgType int64, data []byte) ([]byte, string, int64, int64, error) {
	var topicBytes []byte
	var topicType int64
	switch topic.Id.(type) {
	case *TopicId_Dm:
		bothUserIDBytes := topic.GetDm()
		userID1 := uuid.FromBytesOrNil(bothUserIDBytes[:16])
		userID2 := uuid.FromBytesOrNil(bothUserIDBytes[16:])

		trackerTopic = "dm:" + userID1.String() + ":" + userID2.String()
		topicBytes = bothUserIDBytes
		topicBytes = topic.GetDm()
		topicType = 0
	case *TopicId_Room:
		trackerTopic = "room:" + string(topic.GetRoom())
		topicBytes = []byte(topic.GetRoom())
		topicBytes = topic.GetRoom()
		topicType = 1
	case *TopicId_GroupId:
		trackerTopic = "group:" + uuid.FromBytesOrNil(topic.GetGroupId()).String()
		topicBytes = topic.GetGroupId()
		topicType = 2
	}
@@ -588,6 +589,24 @@ RETURNING handle`, topicBytes, topicType, messageID, session.userID.Bytes(), cre
		return nil, "", 0, 0, err
	}

	return messageID, handle, createdAt, expiresAt, nil
}

func (p *pipeline) deliverMessage(logger zap.Logger, session *session, topic *TopicId, msgType int64, data []byte, messageID []byte, handle string, createdAt int64, expiresAt int64) {
	var trackerTopic string
	switch topic.Id.(type) {
	case *TopicId_Dm:
		bothUserIDBytes := topic.GetDm()
		userID1 := uuid.FromBytesOrNil(bothUserIDBytes[:16])
		userID2 := uuid.FromBytesOrNil(bothUserIDBytes[16:])

		trackerTopic = "dm:" + userID1.String() + ":" + userID2.String()
	case *TopicId_Room:
		trackerTopic = "room:" + string(topic.GetRoom())
	case *TopicId_GroupId:
		trackerTopic = "group:" + uuid.FromBytesOrNil(topic.GetGroupId()).String()
	}

	outgoing := &Envelope{
		Payload: &Envelope_TopicMessage{
			TopicMessage: &TopicMessage{
@@ -605,6 +624,13 @@ RETURNING handle`, topicBytes, topicType, messageID, session.userID.Bytes(), cre

	presences := p.tracker.ListByTopic(trackerTopic)
	p.messageRouter.Send(logger, presences, outgoing)
}

	return messageID, handle, createdAt, expiresAt, nil
func (p *pipeline) storeAndDeliverMessage(logger zap.Logger, session *session, topic *TopicId, msgType int64, data []byte) error {
	messageID, handle, createdAt, expiresAt, err := p.storeMessage(logger, session, topic, msgType, data)
	if err != nil {
		return err
	}
	p.deliverMessage(logger, session, topic, msgType, data, messageID, handle, createdAt, expiresAt)
	return nil
}
+18 −7
Original line number Diff line number Diff line
@@ -85,8 +85,7 @@ func (s *session) Consume(processRequest func(logger zap.Logger, session *sessio
			s.logger.Warn("Received malformed payload", zap.Object("data", data))
			s.Send(&Envelope{CollationId: request.CollationId, Payload: &Envelope_Error{&Error{Reason: "Unrecognized message"}}})
		} else {
			//TODO(mofirouz, zyro) Add session-global context here
			//to cancel in-progress operations when the session is closed
			// TODO Add session-global context here to cancel in-progress operations when the session is closed.
			requestLogger := s.logger.With(zap.String("cid", request.CollationId))
			processRequest(requestLogger, s, request)
		}
@@ -104,8 +103,14 @@ func (s *session) pingPeriodically() {

func (s *session) pingNow() bool {
	// Websocket ping.
	s.Lock()
	if s.stopped {
		s.Unlock()
		return false
	}
	s.conn.SetWriteDeadline(time.Now().Add(time.Duration(s.config.GetTransport().WriteWaitMs) * time.Millisecond))
	err := s.conn.WriteMessage(websocket.PingMessage, []byte{})
	s.Unlock()
	if err != nil {
		s.logger.Warn("Could not send ping. Closing channel", zap.String("remoteAddress", s.conn.RemoteAddr().String()), zap.Error(err))
		s.Close()
@@ -135,6 +140,13 @@ func (s *session) Send(envelope *Envelope) error {
}

func (s *session) SendBytes(payload []byte) error {
	// TODO Improve on mutex usage here.
	s.Lock()
	defer s.Unlock()
	if s.stopped {
		return nil
	}

	s.conn.SetWriteDeadline(time.Now().Add(time.Duration(s.config.GetTransport().WriteWaitMs) * time.Millisecond))
	return s.conn.WriteMessage(websocket.BinaryMessage, payload)
}
@@ -155,17 +167,16 @@ func (s *session) cleanupClosedConnection() {
}

func (s *session) Close() {
	s.unregister(s)
	s.pingTicker.Stop()

	s.Lock()
	defer s.Unlock()
	if s.stopped {
		return
	}
	s.stopped = true
	s.Unlock()

	s.logger.Info("Closing client connection.", zap.String("remoteAddress", s.conn.RemoteAddr().String()))

	s.unregister(s)
	s.pingTicker.Stop()
	err := s.conn.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(time.Duration(s.config.GetTransport().WriteWaitMs)*time.Millisecond))
	if err != nil {
		s.logger.Warn("Could not send close message. Closing prematurely.", zap.String("remoteAddress", s.conn.RemoteAddr().String()), zap.Error(err))
Loading