diff --git a/packetio/buffer.go b/packetio/buffer.go index 719b425..7c4c319 100644 --- a/packetio/buffer.go +++ b/packetio/buffer.go @@ -287,10 +287,7 @@ func (b *Buffer) Close() (err error) { b.mutex.Unlock() if waiting { - select { - case b.notify <- struct{}{}: - default: - } + close(b.notify) } return nil diff --git a/packetio/buffer_test.go b/packetio/buffer_test.go index 1711d35..2b8cb03 100644 --- a/packetio/buffer_test.go +++ b/packetio/buffer_test.go @@ -582,3 +582,47 @@ func BenchmarkBuffer140(b *testing.B) { func BenchmarkBuffer1400(b *testing.B) { benchmarkBuffer(b, 1400) } + +func TestBufferConcurrentRead(t *testing.T) { + assert := assert.New(t) + + buffer := NewBuffer() + packet := make([]byte, 4) + + // Write twice + n, err := buffer.Write([]byte{2, 3, 4}) + assert.NoError(err) + assert.Equal(3, n) + + n, err = buffer.Write([]byte{5, 6, 7}) + assert.NoError(err) + assert.Equal(3, n) + + // Read twice + n, err = buffer.Read(packet) + assert.NoError(err) + assert.Equal(3, n) + assert.Equal([]byte{2, 3, 4}, packet[:n]) + + n, err = buffer.Read(packet) + assert.NoError(err) + assert.Equal(3, n) + assert.Equal([]byte{5, 6, 7}, packet[:n]) + + errCh := make(chan error, 2) + readIntoErr := func() { + _, readErr := buffer.Read(packet) + errCh <- readErr + } + go readIntoErr() + go readIntoErr() + + // Close + err = buffer.Close() + assert.NoError(err) + + err = <-errCh + assert.Equal(io.EOF, err) + err = <-errCh + assert.Equal(io.EOF, err) +}