Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove mutex from sessionWS #1141

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 62 additions & 162 deletions server/session_ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"errors"
"fmt"
"net"
"sync"
"time"

"github.com/gofrs/uuid/v5"
Expand All @@ -35,7 +34,6 @@ import (
var ErrSessionQueueFull = errors.New("session outgoing queue full")

type sessionWS struct {
sync.Mutex
logger *zap.Logger
config Config
id uuid.UUID
Expand Down Expand Up @@ -66,11 +64,9 @@ type sessionWS struct {
pipeline *Pipeline
runtime *Runtime

stopped bool
stopped *atomic.Bool
conn *websocket.Conn
receivedMessageCounter int
pingTimer *time.Timer
pingTimerCAS *atomic.Uint32
receivedMessageCounter *atomic.Int32
outgoingCh chan []byte
}

Expand Down Expand Up @@ -117,11 +113,9 @@ func NewSessionWS(logger *zap.Logger, config Config, format SessionFormat, sessi
pipeline: pipeline,
runtime: runtime,

stopped: false,
stopped: atomic.NewBool(false),
conn: conn,
receivedMessageCounter: config.GetSocket().PingBackoffThreshold,
pingTimer: time.NewTimer(time.Duration(config.GetSocket().PingPeriodMs) * time.Millisecond),
pingTimerCAS: atomic.NewUint32(1),
receivedMessageCounter: atomic.NewInt32(0),
outgoingCh: make(chan []byte, config.GetSocket().OutgoingQueueSize),
}
}
Expand Down Expand Up @@ -183,8 +177,7 @@ func (s *sessionWS) Consume() {
return
}
s.conn.SetPongHandler(func(string) error {
s.maybeResetPingTimer()
return nil
return s.conn.SetReadDeadline(time.Now().Add(s.pongWaitDuration))
})

// Start a routine to process outbound messages.
Expand Down Expand Up @@ -215,15 +208,7 @@ IncomingLoop:
break
}

s.receivedMessageCounter--
if s.receivedMessageCounter <= 0 {
s.receivedMessageCounter = s.config.GetSocket().PingBackoffThreshold
if !s.maybeResetPingTimer() {
// Problems resetting the ping timer indicate an error so we need to close the loop.
reason = "error updating ping timer"
break
}
}
s.receivedMessageCounter.Add(1)

request := &rtapi.Envelope{}
switch s.format {
Expand Down Expand Up @@ -267,102 +252,77 @@ IncomingLoop:
s.Close(reason, runtime.PresenceReasonDisconnect)
}

// maybeResetPingTimer restarts the ping timer with pingPeriodDuration and pushes the
// connection read deadline out by pongWaitDuration.
//
// It returns true when the reset succeeded, or when another reset was already in
// progress and this call was skipped. It returns false when the session is already
// stopped, or when setting the read deadline failed; in the latter case the session
// is also closed before returning.
func (s *sessionWS) maybeResetPingTimer() bool {
// If there's already a reset in progress there's no need to wait.
// pingTimerCAS acts as a try-lock: 1 means idle, 0 means a reset is in progress.
if !s.pingTimerCAS.CompareAndSwap(1, 0) {
return true
}
// Mark the timer as idle again on every exit path below.
defer s.pingTimerCAS.CompareAndSwap(0, 1)

s.Lock()
if s.stopped {
s.Unlock()
return false
}
// CAS ensures concurrency is not a problem here.
// If Stop reports the timer was not active, drain a tick that may still be pending
// on the channel (without blocking) before calling Reset.
if !s.pingTimer.Stop() {
select {
case <-s.pingTimer.C:
default:
}
}
s.pingTimer.Reset(s.pingPeriodDuration)
err := s.conn.SetReadDeadline(time.Now().Add(s.pongWaitDuration))
// Release the lock before the error handling below, which calls s.Close.
s.Unlock()
if err != nil {
s.logger.Warn("Failed to set read deadline", zap.Error(err))
s.Close("failed to set read deadline", runtime.PresenceReasonDisconnect)
return false
}
return true
}

func (s *sessionWS) processOutgoing() {
var reason string
ticker := time.NewTicker(s.pingPeriodDuration)

defer func() {
ticker.Stop()
s.Close(reason, runtime.PresenceReasonDisconnect)
}()

OutgoingLoop:
for {
select {
case <-s.ctx.Done():
// Session is closing, close the outgoing process routine.
break OutgoingLoop
case <-s.pingTimer.C:
// Periodically send pings.
if msg, ok := s.pingNow(); !ok {
// If ping fails the session will be stopped, clean up the loop.
reason = msg
break OutgoingLoop
}
case payload := <-s.outgoingCh:
s.Lock()
if s.stopped {
// The connection may have stopped between the payload being queued on the outgoing channel and reaching here.
// If that's the case then abort outgoing processing at this point and exit.
s.Unlock()
break OutgoingLoop
}
// Process the outgoing message queue.
if err := s.conn.SetWriteDeadline(time.Now().Add(s.writeWaitDuration)); err != nil {
s.Unlock()
s.logger.Warn("Failed to set write deadline", zap.Error(err))
reason = err.Error()
break OutgoingLoop
return
}
if err := s.conn.WriteMessage(s.wsMessageType, payload); err != nil {
s.Unlock()
s.logger.Warn("Could not write message", zap.Error(err))
reason = err.Error()
break OutgoingLoop
return
}
s.Unlock()

// Update outgoing message metrics.
s.metrics.MessageBytesSent(int64(len(payload)))
}
}

s.Close(reason, runtime.PresenceReasonDisconnect)
}
case <-s.ctx.Done():
// Session is closing, close the outgoing process routine.
if err := s.conn.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(s.writeWaitDuration)); err != nil {
// This may not be possible if the socket was already fully closed by an error.
s.logger.Debug("Could not send close message", zap.Error(err))
}
// Close WebSocket.
if err := s.conn.Close(); err != nil {
s.logger.Debug("Could not close", zap.Error(err))
}

func (s *sessionWS) pingNow() (string, bool) {
s.Lock()
if s.stopped {
s.Unlock()
return "", false
}
if err := s.conn.SetWriteDeadline(time.Now().Add(s.writeWaitDuration)); err != nil {
s.Unlock()
s.logger.Warn("Could not set write deadline to ping", zap.Error(err))
return err.Error(), false
}
err := s.conn.WriteMessage(websocket.PingMessage, []byte{})
s.Unlock()
if err != nil {
s.logger.Warn("Could not send ping", zap.Error(err))
return err.Error(), false
}
s.logger.Info("Closed client connection")

return

case <-ticker.C:
messagesReceived := s.receivedMessageCounter.Swap(0)
if int(messagesReceived) >= s.config.GetSocket().PingBackoffThreshold {
// Received enough messages to skip sending a ping

// Update read deadline, since we aren't sending a ping, which means the pong handler won't be triggered
if err := s.conn.SetReadDeadline(time.Now().Add(s.pongWaitDuration)); err != nil {
s.logger.Warn("Failed to set read deadline", zap.Error(err))
reason = err.Error()
return
}

return "", true
continue
}

err := s.conn.SetWriteDeadline(time.Now().Add(s.writeWaitDuration))
if err != nil {
s.logger.Warn("Failed to set write deadline", zap.Error(err))
reason = err.Error()
return
}
err = s.conn.WriteMessage(websocket.PingMessage, nil)
if err != nil {
s.logger.Warn("Could not write message", zap.Error(err))
reason = err.Error()
return
}
}
}
}

func (s *sessionWS) Format() SessionFormat {
Expand Down Expand Up @@ -400,39 +360,25 @@ func (s *sessionWS) Send(envelope *rtapi.Envelope, reliable bool) error {
}

// SendBytes queues an already-encoded payload on the session's outgoing channel
// without blocking.
//
// It returns nil when the payload was queued, and also when the session is already
// stopped (the payload is then dropped without an error). If the outgoing queue is
// full, the session is closed and ErrSessionQueueFull is returned.
//
// The reliable parameter is not read anywhere in this function body.
func (s *sessionWS) SendBytes(payload []byte, reliable bool) error {
s.Lock()
if s.stopped {
s.Unlock()
return nil
}

// Attempt to queue messages and observe failures.
// The default case makes the send non-blocking: a full channel falls through immediately.
select {
case s.outgoingCh <- payload:
s.Unlock()
return nil
default:
// The outgoing queue is full, likely because the remote client can't keep up.
// Terminate the connection immediately because the only alternative that doesn't block the server is
// to start dropping messages, which might cause unexpected behaviour.
// The lock is released before s.Close is called.
s.Unlock()
s.logger.Warn("Could not write message, session outgoing queue full")
s.Close(ErrSessionQueueFull.Error(), runtime.PresenceReasonDisconnect)
return ErrSessionQueueFull
}
}

func (s *sessionWS) Close(msg string, reason runtime.PresenceReason, envelopes ...*rtapi.Envelope) {
s.Lock()
if s.stopped {
s.Unlock()
if !s.stopped.CompareAndSwap(false, true) {
// connection already closed
return
}
s.stopped = true
s.Unlock()

// Cancel any ongoing operations tied to this session.
s.ctxCancelFn()

if s.logger.Core().Enabled(zap.DebugLevel) {
s.logger.Info("Cleaning up closed client connection")
Expand All @@ -458,63 +404,17 @@ func (s *sessionWS) Close(msg string, reason runtime.PresenceReason, envelopes .
s.logger.Info("Cleaned up closed connection session registry")
}

// Clean up internals.
s.pingTimer.Stop()
close(s.outgoingCh)

// Send final messages, if any are specified.
for _, envelope := range envelopes {
var payload []byte
var err error
switch s.format {
case SessionFormatProtobuf:
payload, err = proto.Marshal(envelope)
case SessionFormatJson:
fallthrough
default:
if buf, err := s.protojsonMarshaler.Marshal(envelope); err == nil {
payload = buf
}
}
err := s.Send(envelope, false)
if err != nil {
s.logger.Warn("Could not marshal envelope", zap.Error(err))
s.logger.Warn("Failed to send envelope", zap.Error(err))
continue
}

if s.logger.Core().Enabled(zap.DebugLevel) {
switch envelope.Message.(type) {
case *rtapi.Envelope_Error:
s.logger.Debug("Sending error message", zap.Binary("payload", payload))
default:
s.logger.Debug(fmt.Sprintf("Sending %T message", envelope.Message), zap.Any("envelope", envelope))
}
}

s.Lock()
if err := s.conn.SetWriteDeadline(time.Now().Add(s.writeWaitDuration)); err != nil {
s.Unlock()
s.logger.Warn("Failed to set write deadline", zap.Error(err))
continue
}
if err := s.conn.WriteMessage(s.wsMessageType, payload); err != nil {
s.Unlock()
s.logger.Warn("Could not write message", zap.Error(err))
continue
}
s.Unlock()
}

// Send close message.
if err := s.conn.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(s.writeWaitDuration)); err != nil {
// This may not be possible if the socket was already fully closed by an error.
s.logger.Debug("Could not send close message", zap.Error(err))
}
// Close WebSocket.
if err := s.conn.Close(); err != nil {
s.logger.Debug("Could not close", zap.Error(err))
}

s.logger.Info("Closed client connection")
// Cancel any ongoing operations tied to this session. This will trigger a close message.
s.ctxCancelFn()

// Fire an event for session end.
if fn := s.runtime.EventSessionEnd(); fn != nil {
Expand Down