From 81f573a5160f82d57593b0fe104172ce520db384 Mon Sep 17 00:00:00 2001 From: Joe Turki Date: Tue, 21 Jan 2025 15:47:22 -0600 Subject: [PATCH] Ensure verification tag is zero for INIT packets This test verifies that the verification tag is correctly set to 0 for all INIT packets, including retransmissions. --- association_test.go | 139 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 139 insertions(+) diff --git a/association_test.go b/association_test.go index 44a71acf..5f1c2ed0 100644 --- a/association_test.go +++ b/association_test.go @@ -1181,6 +1181,145 @@ func TestAssocUnreliable(t *testing.T) { }) } +// This test ensures that verification tag is set to 0 for all INIT packets. +// A test for this PR https://github.com/pion/sctp/pull/341 +// We drop the first INIT ACK, and we expect the verification tag to be 0 on +// retransmission. +func TestInitVerificationTagIsZero(t *testing.T) { + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + const si uint16 = 1 + const msg = "ABC" + br := test.NewBridge() + ackCount := 0 + recvBufSize := uint32(0) + + var a0, a1 *Association + var err0, err1 error + loggerFactory := logging.NewDefaultLoggerFactory() + + handshake0Ch := make(chan bool) + handshake1Ch := make(chan bool) + fatalChannel := make(chan error) + + fitlerFunc := func(pkt []byte) bool { + t.Helper() + + packetData := packet{} + + assert.NoError(t, packetData.unmarshal(true, pkt)) + + // Init chunk and Init Ack chunk are never bundled. + if len(packetData.chunks) != 1 { + return true + } + + switch packetData.chunks[0].(type) { + case *chunkInit: + if packetData.verificationTag != 0 { + // Even without this we will get WARNING: + // failed validating packet init chunk expects a verification tag of 0 on the packet when out-of-the-blue + // And the connection will fail silently. + go func() { + fatalChannel <- errors.New("verification tag should be 0 for Init chunk") //nolint:err113 + }() + + return false + } + // Drop the first two Init Ack chunk. + case *chunkInitAck: + ackCount++ + return ackCount > 2 + } + + return true + } + + br.Filter(0, fitlerFunc) + + br.Filter(1, fitlerFunc) + + go func() { + a0, err0 = Client(Config{ + Name: "a0", + NetConn: br.GetConn0(), + MaxReceiveBufferSize: recvBufSize, + LoggerFactory: loggerFactory, + }) + + handshake0Ch <- true + }() + go func() { + a1, err1 = Client(Config{ + Name: "a1", + NetConn: br.GetConn1(), + MaxReceiveBufferSize: recvBufSize, + LoggerFactory: loggerFactory, + }) + handshake1Ch <- true + }() + + a0handshakeDone := false + a1handshakeDone := false + +loop1: + for i := 0; i < 1e3; i++ { + time.Sleep(10 * time.Millisecond) + br.Tick() + + select { + case a0handshakeDone = <-handshake0Ch: + if a1handshakeDone { + break loop1 + } + case a1handshakeDone = <-handshake1Ch: + if a0handshakeDone { + break loop1 + } + case err := <-fatalChannel: + t.Fatal(err) + default: + } + } + + assert.Equal(t, a0handshakeDone, true, "handshake failed e0") + assert.Equal(t, a1handshakeDone, true, "handshake failed e1") + + assert.NoError(t, err0, "failed to create association a0") + assert.NoError(t, err1, "failed to create association a1") + + a0.ackMode = ackModeNoDelay + a1.ackMode = ackModeNoDelay + + s0, s1, err := establishSessionPair(br, a0, a1, si) + assert.Nil(t, err, "failed to establish session pair") + + assert.Equal(t, 0, a0.bufferedAmount(), "incorrect bufferedAmount") + + n, err := s0.WriteSCTP([]byte(msg), PayloadTypeWebRTCBinary) + if err != nil { + assert.FailNow(t, "failed due to earlier error") + } + assert.Equal(t, len(msg), n, "unexpected length of received data") + assert.Equal(t, len(msg), a0.bufferedAmount(), "incorrect bufferedAmount") + + flushBuffers(br, a0, a1) + + buf := make([]byte, 32) + n, ppi, err := s1.ReadSCTP(buf) + if !assert.Nil(t, err, "ReadSCTP failed") { + assert.FailNow(t, "failed due to earlier error") + } + assert.Equal(t, n, len(msg), "unexpected length of received data") + assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") + + assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") + assert.Equal(t, 0, a0.bufferedAmount(), "incorrect bufferedAmount") + + closeAssociationPair(br, a0, a1) +} + func TestCreateForwardTSN(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory()