Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure verification tag is zero for all INIT packets #359

Merged
merged 1 commit into from
Jan 24, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 139 additions & 0 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1181,6 +1181,145 @@ func TestAssocUnreliable(t *testing.T) {
})
}

// This test ensures that verification tag is set to 0 for all INIT packets.
// A test for this PR https://github.com/pion/sctp/pull/341
// We drop the first INIT ACK, and we expect the verification tag to be 0 on
// retransmission.
func TestInitVerificationTagIsZero(t *testing.T) {
lim := test.TimeOut(time.Second * 10)
defer lim.Stop()

const si uint16 = 1
const msg = "ABC"
br := test.NewBridge()
ackCount := 0
recvBufSize := uint32(0)

var a0, a1 *Association
var err0, err1 error
loggerFactory := logging.NewDefaultLoggerFactory()

handshake0Ch := make(chan bool)
handshake1Ch := make(chan bool)
fatalChannel := make(chan error)

fitlerFunc := func(pkt []byte) bool {
t.Helper()

packetData := packet{}

assert.NoError(t, packetData.unmarshal(true, pkt))

// Init chunk and Init Ack chunk are never bundled.
if len(packetData.chunks) != 1 {
return true
}

switch packetData.chunks[0].(type) {
case *chunkInit:
if packetData.verificationTag != 0 {
// Even without this we will get WARNING:
// failed validating packet init chunk expects a verification tag of 0 on the packet when out-of-the-blue
// And the connection will fail silently.
go func() {
fatalChannel <- errors.New("verification tag should be 0 for Init chunk") //nolint:err113
}()

return false
}
// Drop the first two Init Ack chunk.
case *chunkInitAck:
ackCount++
return ackCount > 2
}

return true
}

br.Filter(0, fitlerFunc)

br.Filter(1, fitlerFunc)

go func() {
a0, err0 = Client(Config{
Name: "a0",
NetConn: br.GetConn0(),
MaxReceiveBufferSize: recvBufSize,
LoggerFactory: loggerFactory,
})

handshake0Ch <- true
}()
go func() {
a1, err1 = Client(Config{
Name: "a1",
NetConn: br.GetConn1(),
MaxReceiveBufferSize: recvBufSize,
LoggerFactory: loggerFactory,
})
handshake1Ch <- true
}()

a0handshakeDone := false
a1handshakeDone := false

loop1:
for i := 0; i < 1e3; i++ {
time.Sleep(10 * time.Millisecond)
br.Tick()

select {
case a0handshakeDone = <-handshake0Ch:
if a1handshakeDone {
break loop1
}
case a1handshakeDone = <-handshake1Ch:
if a0handshakeDone {
break loop1
}
case err := <-fatalChannel:
t.Fatal(err)
default:
}
}

assert.Equal(t, a0handshakeDone, true, "handshake failed e0")
assert.Equal(t, a1handshakeDone, true, "handshake failed e1")

assert.NoError(t, err0, "failed to create association a0")
assert.NoError(t, err1, "failed to create association a1")

a0.ackMode = ackModeNoDelay
a1.ackMode = ackModeNoDelay

s0, s1, err := establishSessionPair(br, a0, a1, si)
assert.Nil(t, err, "failed to establish session pair")

assert.Equal(t, 0, a0.bufferedAmount(), "incorrect bufferedAmount")

n, err := s0.WriteSCTP([]byte(msg), PayloadTypeWebRTCBinary)
if err != nil {
assert.FailNow(t, "failed due to earlier error")
}
assert.Equal(t, len(msg), n, "unexpected length of received data")
assert.Equal(t, len(msg), a0.bufferedAmount(), "incorrect bufferedAmount")

flushBuffers(br, a0, a1)

buf := make([]byte, 32)
n, ppi, err := s1.ReadSCTP(buf)
if !assert.Nil(t, err, "ReadSCTP failed") {
assert.FailNow(t, "failed due to earlier error")
}
assert.Equal(t, n, len(msg), "unexpected length of received data")
assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi")

assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable")
assert.Equal(t, 0, a0.bufferedAmount(), "incorrect bufferedAmount")

closeAssociationPair(br, a0, a1)
}

func TestCreateForwardTSN(t *testing.T) {
loggerFactory := logging.NewDefaultLoggerFactory()

Expand Down
Loading