Skip to content

Commit

Permalink
Merge pull request #2793 from OffchainLabs/feed-handshake
Browse files Browse the repository at this point in the history
remove gobwas/ws handshake extensions race condition workaround
  • Loading branch information
joshuacolvin0 authored Nov 19, 2024
2 parents bbc4120 + 7f70cc0 commit 438a992
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 39 deletions.
30 changes: 25 additions & 5 deletions broadcastclient/broadcastclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,10 @@ type BroadcastClient struct {

chainId uint64

// Protects conn and shuttingDown
connMutex sync.Mutex
conn net.Conn
// Protects conn, shuttingDown and compression
connMutex sync.Mutex
conn net.Conn
compression bool

retryCount atomic.Int64

Expand Down Expand Up @@ -299,7 +300,7 @@ func (bc *BroadcastClient) connect(ctx context.Context, nextSeqNum arbutil.Messa
return nil, nil
}

conn, br, _, err := timeoutDialer.Dial(ctx, bc.websocketUrl)
conn, br, hs, err := timeoutDialer.Dial(ctx, bc.websocketUrl)
if errors.Is(err, ErrIncorrectFeedServerVersion) || errors.Is(err, ErrIncorrectChainId) {
return nil, err
}
Expand All @@ -325,6 +326,24 @@ func (bc *BroadcastClient) connect(ctx context.Context, nextSeqNum arbutil.Messa
return nil, ErrMissingFeedServerVersion
}

compressionNegotiated := false
for _, ext := range hs.Extensions {
if ext.Equal(deflateExt) {
compressionNegotiated = true
break
}
}
if !compressionNegotiated && config.EnableCompression {
log.Warn("Compression was not negotiated when connecting to feed server.")
}
if compressionNegotiated && !config.EnableCompression {
err := conn.Close()
if err != nil {
return nil, fmt.Errorf("error closing connection when negotiated disabled extension: %w", err)
}
return nil, errors.New("error dialing feed server: negotiated compression ws extension, but it is disabled")
}

var earlyFrameData io.Reader
if br != nil {
// Depending on how long the client takes to read the response, there may be
Expand All @@ -339,6 +358,7 @@ func (bc *BroadcastClient) connect(ctx context.Context, nextSeqNum arbutil.Messa

bc.connMutex.Lock()
bc.conn = conn
bc.compression = compressionNegotiated
bc.connMutex.Unlock()
log.Info("Feed connected", "feedServerVersion", feedServerVersion, "chainId", chainId, "requestedSeqNum", nextSeqNum)

Expand All @@ -362,7 +382,7 @@ func (bc *BroadcastClient) startBackgroundReader(earlyFrameData io.Reader) {
var op ws.OpCode
var err error
config := bc.config()
msg, op, err = wsbroadcastserver.ReadData(ctx, bc.conn, earlyFrameData, config.Timeout, ws.StateClientSide, config.EnableCompression, flateReader)
msg, op, err = wsbroadcastserver.ReadData(ctx, bc.conn, earlyFrameData, config.Timeout, ws.StateClientSide, bc.compression, flateReader)
if err != nil {
if bc.isShuttingDown() {
return
Expand Down
53 changes: 20 additions & 33 deletions broadcastclient/broadcastclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,43 +30,30 @@ import (
"github.com/offchainlabs/nitro/wsbroadcastserver"
)

func TestReceiveMessagesWithoutCompression(t *testing.T) {
func TestReceiveMessages(t *testing.T) {
t.Parallel()
testReceiveMessages(t, false, false, false, false)
}

func TestReceiveMessagesWithCompression(t *testing.T) {
t.Parallel()
testReceiveMessages(t, true, true, false, false)
}

func TestReceiveMessagesWithServerOptionalCompression(t *testing.T) {
t.Parallel()
testReceiveMessages(t, true, true, false, false)
}

func TestReceiveMessagesWithServerOnlyCompression(t *testing.T) {
t.Parallel()
testReceiveMessages(t, false, true, false, false)
}

func TestReceiveMessagesWithClientOnlyCompression(t *testing.T) {
t.Parallel()
testReceiveMessages(t, true, false, false, false)
}

func TestReceiveMessagesWithRequiredCompression(t *testing.T) {
t.Parallel()
testReceiveMessages(t, true, true, true, false)
}

func TestReceiveMessagesWithRequiredCompressionButClientDisabled(t *testing.T) {
t.Parallel()
testReceiveMessages(t, false, true, true, true)
t.Run("withoutCompression", func(t *testing.T) {
testReceiveMessages(t, false, false, false, false)
})
t.Run("withServerOptionalCompression", func(t *testing.T) {
testReceiveMessages(t, true, true, false, false)
})
t.Run("withServerOnlyCompression", func(t *testing.T) {
testReceiveMessages(t, false, true, false, false)
})
t.Run("withClientOnlyCompression", func(t *testing.T) {
testReceiveMessages(t, true, false, false, false)
})
t.Run("withRequiredCompression", func(t *testing.T) {
testReceiveMessages(t, true, true, true, false)
})
t.Run("withRequiredCompressionButClientDisabled", func(t *testing.T) {
testReceiveMessages(t, false, true, true, true)
})
}

func testReceiveMessages(t *testing.T, clientCompression bool, serverCompression bool, serverRequire bool, expectNoMessagesReceived bool) {
t.Helper()
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

Expand Down
2 changes: 1 addition & 1 deletion wsbroadcastserver/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func ReadData(ctx context.Context, conn net.Conn, earlyFrameData io.Reader, time
var data []byte
if msg.IsCompressed() {
if !compression {
return nil, 0, errors.New("Received compressed frame even though compression is disabled")
return nil, 0, errors.New("Received compressed frame even though compression extension wasn't negotiated")
}
flateReader.Reset(&reader)
data, err = io.ReadAll(flateReader)
Expand Down

0 comments on commit 438a992

Please sign in to comment.