diff --git a/wsflate/cbuf.go b/wsflate/cbuf.go index 9a3e9d2..2c10332 100644 --- a/wsflate/cbuf.go +++ b/wsflate/cbuf.go @@ -61,6 +61,26 @@ type suffixedReader struct { r io.Reader pos int // position in the suffix. suffix [9]byte + + rx struct{ io.Reader } +} + +func (r *suffixedReader) iface() io.Reader { + if _, ok := r.r.(io.ByteReader); ok { + // If source io.Reader implements io.ByteReader, return full set of + // methods from suffixedReader struct (Read() and ReadByte()). + // This actually is an optimization needed for those Decompressor + // implementations (such as default flate.Reader) which do check if + // given source is already "buffered" by checking if source implements + // io.ByteReader. So without this checks we will always result in + // double-buffering for default decompressors. + return r + } + // Source io.Reader doesn't support io.ByteReader, so we should cut off the + // ReadByte() method from suffixedReader struct. We use r.srx field to + // avoid allocations. + r.rx.Reader = r + return &r.rx } func (r *suffixedReader) Read(p []byte) (n int, err error) { @@ -80,6 +100,27 @@ func (r *suffixedReader) Read(p []byte) (n int, err error) { return n, nil } +func (r *suffixedReader) ReadByte() (b byte, err error) { + if r.r != nil { + br, ok := r.r.(io.ByteReader) + if !ok { + panic("wsflate: internal error: incorrect use of suffixedReader") + } + b, err = br.ReadByte() + if err == io.EOF { + err = nil + r.r = nil + } + return b, err + } + if r.pos >= len(r.suffix) { + return 0, io.EOF + } + b = r.suffix[r.pos] + r.pos += 1 + return b, nil +} + func (r *suffixedReader) reset(src io.Reader) { r.r = src r.pos = 0 diff --git a/wsflate/reader.go b/wsflate/reader.go index afba9d5..8f0f660 100644 --- a/wsflate/reader.go +++ b/wsflate/reader.go @@ -50,10 +50,11 @@ func (r *Reader) Reset(src io.Reader) { r.err = nil r.src = src r.sr.reset(src) + if x, ok := r.d.(ReadResetter); ok { - x.Reset(&r.sr) + x.Reset(r.sr.iface()) } else { - r.d = r.ctor(&r.sr) + r.d = r.ctor(r.sr.iface()) } } diff --git a/wsflate/reader_test.go b/wsflate/reader_test.go index 55c7a35..a283f7f 100644 --- a/wsflate/reader_test.go +++ b/wsflate/reader_test.go @@ -1 +1,37 @@ package wsflate + +import ( + "bytes" + "fmt" + "io" + "testing" +) + +func TestSuffixedReaderIface(t *testing.T) { + for _, test := range []struct { + src io.Reader + exp bool + }{ + { + src: bytes.NewReader(nil), + exp: true, + }, + { + src: io.TeeReader(nil, nil), + exp: false, + }, + } { + t.Run(fmt.Sprintf("%T", test.src), func(t *testing.T) { + isByteReader := func(r io.Reader) bool { + _, ok := r.(io.ByteReader) + return ok + } + s := &suffixedReader{ + r: test.src, + } + if act, exp := isByteReader(s.iface()), test.exp; act != exp { + t.Fatalf("unexpected io.ByteReader: %t; want %t", act, exp) + } + }) + } +}