From 9b2b03b0f310335a9dd7c26e6a1c6928826d90d9 Mon Sep 17 00:00:00 2001 From: jared2501 Date: Mon, 23 Apr 2018 00:02:51 -0700 Subject: [PATCH 1/2] Implement simple stream write priorities (#1) --- session.go | 80 +++++++++++++++++++++++++++++++++++++++---------- session_test.go | 23 +++++++++++--- stream.go | 19 ++++++++---- 3 files changed, 98 insertions(+), 24 deletions(-) diff --git a/session.go b/session.go index e93317e..29d7e20 100644 --- a/session.go +++ b/session.go @@ -8,6 +8,7 @@ import ( "time" "github.com/pkg/errors" + "container/heap" ) const ( @@ -21,8 +22,10 @@ const ( ) type writeRequest struct { - frame Frame - result chan writeResult + niceness uint8 + sequence uint64 // Used to keep the heap ordered by time + frame Frame + result chan writeResult } type writeResult struct { @@ -30,6 +33,31 @@ type writeResult struct { err error } +type writeHeap []writeRequest + +func (h writeHeap) Len() int { return len(h) } +func (h writeHeap) Less(i, j int) bool { + if h[i].niceness == h[j].niceness { + return h[i].sequence < h[j].sequence + } + return h[i].niceness < h[j].niceness +} +func (h writeHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } + +func (h *writeHeap) Push(x interface{}) { + // Push and Pop use pointer receivers because they modify the slice's length, + // not just its contents. 
+ *h = append(*h, x.(writeRequest)) +} + +func (h *writeHeap) Pop() interface{} { + old := *h + n := len(old) + x := old[n-1] + *h = old[0 : n-1] + return x +} + // Session defines a multiplexed connection for streams type Session struct { conn io.ReadWriteCloser @@ -54,7 +82,10 @@ type Session struct { deadline atomic.Value - writes chan writeRequest + writeTicket chan struct{} + writesLock sync.Mutex + writes writeHeap + writeSequenceNum uint64 } func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { @@ -66,7 +97,7 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { s.chAccepts = make(chan *Stream, defaultAcceptBacklog) s.bucket = int32(config.MaxReceiveBuffer) s.bucketNotify = make(chan struct{}, 1) - s.writes = make(chan writeRequest) + s.writeTicket = make(chan struct{}) if client { s.nextStreamID = 1 @@ -79,8 +110,12 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { return s } -// OpenStream is used to create a new stream func (s *Session) OpenStream() (*Stream, error) { + return s.OpenStreamOpt(100) +} + +// OpenStream is used to create a new stream +func (s *Session) OpenStreamOpt(niceness uint8) (*Stream, error) { if s.IsClosed() { return nil, errors.New(errBrokenPipe) } @@ -101,9 +136,9 @@ func (s *Session) OpenStream() (*Stream, error) { } s.nextStreamIDLock.Unlock() - stream := newStream(sid, s.config.MaxFrameSize, s) + stream := newStream(sid, niceness, s.config.MaxFrameSize, s) - if _, err := s.writeFrame(newFrame(cmdSYN, sid)); err != nil { + if _, err := s.writeFrame(0, newFrame(cmdSYN, sid)); err != nil { return nil, errors.Wrap(err, "writeFrame") } @@ -113,9 +148,13 @@ func (s *Session) OpenStream() (*Stream, error) { return stream, nil } +func (s *Session) AcceptStream() (*Stream, error) { + return s.AcceptStreamOpt(100) +} + // AcceptStream is used to block until the next available stream // is ready to be accepted. 
-func (s *Session) AcceptStream() (*Stream, error) { +func (s *Session) AcceptStreamOpt(niceness uint8) (*Stream, error) { var deadline <-chan time.Time if d, ok := s.deadline.Load().(time.Time); ok && !d.IsZero() { timer := time.NewTimer(time.Until(d)) @@ -124,6 +163,7 @@ func (s *Session) AcceptStream() (*Stream, error) { } select { case stream := <-s.chAccepts: + stream.niceness = niceness return stream, nil case <-deadline: return nil, errTimeout @@ -247,7 +287,7 @@ func (s *Session) recvLoop() { case cmdSYN: s.streamLock.Lock() if _, ok := s.streams[f.sid]; !ok { - stream := newStream(f.sid, s.config.MaxFrameSize, s) + stream := newStream(f.sid, 255, s.config.MaxFrameSize, s) s.streams[f.sid] = stream select { case s.chAccepts <- stream: @@ -289,7 +329,7 @@ func (s *Session) keepalive() { for { select { case <-tickerPing.C: - s.writeFrame(newFrame(cmdNOP, 0)) + s.writeFrame(0, newFrame(cmdNOP, 0)) s.notifyBucket() // force a signal to the recvLoop case <-tickerTimeout.C: if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) { @@ -308,7 +348,11 @@ func (s *Session) sendLoop() { select { case <-s.die: return - case request := <-s.writes: + case <-s.writeTicket: + s.writesLock.Lock() + request := heap.Pop(&s.writes).(writeRequest) + s.writesLock.Unlock() + buf[0] = request.frame.ver buf[1] = request.frame.cmd binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.data))) @@ -334,15 +378,21 @@ func (s *Session) sendLoop() { // writeFrame writes the frame to the underlying connection // and returns the number of bytes written if successful -func (s *Session) writeFrame(f Frame) (n int, err error) { +func (s *Session) writeFrame(niceness uint8, f Frame) (n int, err error) { req := writeRequest{ - frame: f, - result: make(chan writeResult, 1), + niceness: niceness, + sequence: atomic.AddUint64(&s.writeSequenceNum, 1), + frame: f, + result: make(chan writeResult, 1), } + + s.writesLock.Lock() + heap.Push(&s.writes, req) + s.writesLock.Unlock() select { case 
<-s.die: return 0, errors.New(errBrokenPipe) - case s.writes <- req: + case s.writeTicket <- struct{}{}: } result := <-req.result diff --git a/session_test.go b/session_test.go index 760642d..03e28db 100644 --- a/session_test.go +++ b/session_test.go @@ -11,6 +11,8 @@ import ( "sync" "testing" "time" + "container/heap" + "github.com/stretchr/testify/assert" ) // setupServer starts new server listening on a random localhost port and @@ -58,6 +60,19 @@ func handleConnection(conn net.Conn) { } } +func TestWriteHeap(t *testing.T) { + var reqs writeHeap + req1 := writeRequest{niceness: 1} + heap.Push(&reqs, req1) + req3 := writeRequest{niceness: 3} + heap.Push(&reqs, req3) + req2 := writeRequest{niceness: 2} + heap.Push(&reqs, req2) + assert.Equal(t, heap.Pop(&reqs), req1) + assert.Equal(t, heap.Pop(&reqs), req2) + assert.Equal(t, heap.Pop(&reqs), req3) +} + func TestEcho(t *testing.T) { _, stop, cli, err := setupServer(t) if err != nil { @@ -461,7 +476,7 @@ func TestRandomFrame(t *testing.T) { session, _ = Client(cli, nil) for i := 0; i < 100; i++ { f := newFrame(cmdSYN, 1000) - session.writeFrame(f) + session.writeFrame(0, f) } cli.Close() @@ -474,7 +489,7 @@ func TestRandomFrame(t *testing.T) { session, _ = Client(cli, nil) for i := 0; i < 100; i++ { f := newFrame(allcmds[rand.Int()%len(allcmds)], rand.Uint32()) - session.writeFrame(f) + session.writeFrame(0, f) } cli.Close() @@ -486,7 +501,7 @@ func TestRandomFrame(t *testing.T) { session, _ = Client(cli, nil) for i := 0; i < 100; i++ { f := newFrame(byte(rand.Uint32()), rand.Uint32()) - session.writeFrame(f) + session.writeFrame(0, f) } cli.Close() @@ -499,7 +514,7 @@ func TestRandomFrame(t *testing.T) { for i := 0; i < 100; i++ { f := newFrame(byte(rand.Uint32()), rand.Uint32()) f.ver = byte(rand.Uint32()) - session.writeFrame(f) + session.writeFrame(0, f) } cli.Close() diff --git a/stream.go b/stream.go index 57a0bc6..234291a 100644 --- a/stream.go +++ b/stream.go @@ -9,11 +9,13 @@ import ( "time" 
"github.com/pkg/errors" + "container/heap" ) // Stream implements net.Conn type Stream struct { id uint32 + niceness uint8 rstflag int32 sess *Session buffer bytes.Buffer @@ -27,8 +29,9 @@ type Stream struct { } // newStream initiates a Stream struct -func newStream(id uint32, frameSize int, sess *Session) *Stream { +func newStream(id uint32, niceness uint8, frameSize int, sess *Session) *Stream { s := new(Stream) + s.niceness = niceness s.id = id s.chReadEvent = make(chan struct{}, 1) s.frameSize = frameSize @@ -102,12 +105,18 @@ func (s *Stream) Write(b []byte) (n int, err error) { sent := 0 for k := range frames { req := writeRequest{ - frame: frames[k], - result: make(chan writeResult, 1), + niceness: s.niceness, + sequence: atomic.AddUint64(&s.sess.writeSequenceNum, 1), + frame: frames[k], + result: make(chan writeResult, 1), } + // TODO(jnewman): replace with session.writeFrame(..)? + s.sess.writesLock.Lock() + heap.Push(&s.sess.writes, req) + s.sess.writesLock.Unlock() select { - case s.sess.writes <- req: + case s.sess.writeTicket <- struct{}{}: case <-s.die: return sent, errors.New(errBrokenPipe) case <-deadline: @@ -141,7 +150,7 @@ func (s *Stream) Close() error { close(s.die) s.dieLock.Unlock() s.sess.streamClosed(s.id) - _, err := s.sess.writeFrame(newFrame(cmdFIN, s.id)) + _, err := s.sess.writeFrame(0, newFrame(cmdFIN, s.id)) return err } } From a468d322c7440fccec2b2ffb2149396602cb5eac Mon Sep 17 00:00:00 2001 From: Jared Newman Date: Mon, 23 Apr 2018 00:04:00 -0700 Subject: [PATCH 2/2] Implement a bucket per stream to prevent HOLB --- frame.go | 1 + mux.go | 14 ++++----- mux_test.go | 2 +- session.go | 62 ++++++++++++++++------------------------ session_test.go | 76 +++++++++++++++++++++++++++++++++++++++++++++++++ stream.go | 64 +++++++++++++++++++++++++++-------------- 6 files changed, 153 insertions(+), 66 deletions(-) diff --git a/frame.go b/frame.go index 36062d7..bfa6a34 100644 --- a/frame.go +++ b/frame.go @@ -13,6 +13,7 @@ const ( // cmds 
cmdSYN byte = iota // stream open cmdFIN // stream close, a.k.a EOF mark cmdPSH // data push + cmdACK // data ack cmdNOP // no operation ) diff --git a/mux.go b/mux.go index afcf58b..40d1135 100644 --- a/mux.go +++ b/mux.go @@ -21,18 +21,18 @@ type Config struct { // frame size to sent to the remote MaxFrameSize int - // MaxReceiveBuffer is used to control the maximum + // MaxPerStreamReceiveBuffer is used to control the maximum // number of data in the buffer pool - MaxReceiveBuffer int + MaxPerStreamReceiveBuffer int } // DefaultConfig is used to return a default configuration func DefaultConfig() *Config { return &Config{ - KeepAliveInterval: 10 * time.Second, - KeepAliveTimeout: 30 * time.Second, - MaxFrameSize: 4096, - MaxReceiveBuffer: 4194304, + KeepAliveInterval: 10 * time.Second, + KeepAliveTimeout: 30 * time.Second, + MaxFrameSize: 4096, + MaxPerStreamReceiveBuffer: 4194304, } } @@ -50,7 +50,7 @@ func VerifyConfig(config *Config) error { if config.MaxFrameSize > 65535 { return errors.New("max frame size must not be larger than 65535") } - if config.MaxReceiveBuffer <= 0 { + if config.MaxPerStreamReceiveBuffer <= 0 { return errors.New("max receive buffer must be positive") } return nil diff --git a/mux_test.go b/mux_test.go index 638e67c..f4ae72c 100644 --- a/mux_test.go +++ b/mux_test.go @@ -51,7 +51,7 @@ func TestConfig(t *testing.T) { } config = DefaultConfig() - config.MaxReceiveBuffer = 0 + config.MaxPerStreamReceiveBuffer = 0 err = VerifyConfig(config) t.Log(err) if err == nil { diff --git a/session.go b/session.go index 29d7e20..1491c4d 100644 --- a/session.go +++ b/session.go @@ -66,9 +66,6 @@ type Session struct { nextStreamID uint32 // next stream identifier nextStreamIDLock sync.Mutex - bucket int32 // token bucket - bucketNotify chan struct{} // used for waiting for tokens - streams map[uint32]*Stream // all streams in this session streamLock sync.Mutex // locks streams @@ -95,8 +92,6 @@ func newSession(config *Config, conn io.ReadWriteCloser, 
client bool) *Session { s.config = config s.streams = make(map[uint32]*Stream) s.chAccepts = make(chan *Stream, defaultAcceptBacklog) - s.bucket = int32(config.MaxReceiveBuffer) - s.bucketNotify = make(chan struct{}, 1) s.writeTicket = make(chan struct{}) if client { @@ -136,7 +131,7 @@ func (s *Session) OpenStreamOpt(niceness uint8) (*Stream, error) { } s.nextStreamIDLock.Unlock() - stream := newStream(sid, niceness, s.config.MaxFrameSize, s) + stream := newStream(sid, niceness, s.config.MaxFrameSize, int32(s.config.MaxPerStreamReceiveBuffer), s) if _, err := s.writeFrame(0, newFrame(cmdSYN, sid)); err != nil { return nil, errors.Wrap(err, "writeFrame") @@ -188,19 +183,10 @@ func (s *Session) Close() (err error) { s.streams[k].sessionClose() } s.streamLock.Unlock() - s.notifyBucket() return s.conn.Close() } } -// notifyBucket notifies recvLoop that bucket is available -func (s *Session) notifyBucket() { - select { - case s.bucketNotify <- struct{}{}: - default: - } -} - // IsClosed does a safe check to see if we have shutdown func (s *Session) IsClosed() bool { select { @@ -231,20 +217,15 @@ func (s *Session) SetDeadline(t time.Time) error { // notify the session that a stream has closed func (s *Session) streamClosed(sid uint32) { s.streamLock.Lock() - if n := s.streams[sid].recycleTokens(); n > 0 { // return remaining tokens to the bucket - if atomic.AddInt32(&s.bucket, int32(n)) > 0 { - s.notifyBucket() - } - } delete(s.streams, sid) s.streamLock.Unlock() } -// returnTokens is called by stream to return token after read -func (s *Session) returnTokens(n int) { - if atomic.AddInt32(&s.bucket, int32(n)) > 0 { - s.notifyBucket() - } +func (s *Session) queueAcks(streamId uint32, n int32) { + ack := newFrame(cmdACK, streamId) + ack.data = make([]byte, 4) + binary.BigEndian.PutUint32(ack.data, uint32(n)) + s.queueFrame(0, ack) } // session read a frame from underlying connection @@ -275,10 +256,6 @@ func (s *Session) readFrame(buffer []byte) (f Frame, err error) { 
func (s *Session) recvLoop() { buffer := make([]byte, (1<<16)+headerSize) for { - for atomic.LoadInt32(&s.bucket) <= 0 && !s.IsClosed() { - <-s.bucketNotify - } - if f, err := s.readFrame(buffer); err == nil { atomic.StoreInt32(&s.dataReady, 1) @@ -287,7 +264,7 @@ func (s *Session) recvLoop() { case cmdSYN: s.streamLock.Lock() if _, ok := s.streams[f.sid]; !ok { - stream := newStream(f.sid, 255, s.config.MaxFrameSize, s) + stream := newStream(f.sid, 255, s.config.MaxFrameSize, int32(s.config.MaxPerStreamReceiveBuffer), s) s.streams[f.sid] = stream select { case s.chAccepts <- stream: @@ -305,11 +282,17 @@ func (s *Session) recvLoop() { case cmdPSH: s.streamLock.Lock() if stream, ok := s.streams[f.sid]; ok { - atomic.AddInt32(&s.bucket, -int32(len(f.data))) stream.pushBytes(f.data) stream.notifyReadEvent() } s.streamLock.Unlock() + case cmdACK: + s.streamLock.Lock() + if stream, ok := s.streams[f.sid]; ok { + tokens := binary.BigEndian.Uint32(f.data) + stream.receiveAck(int32(tokens)) + } + s.streamLock.Unlock() default: s.Close() return @@ -330,7 +313,6 @@ func (s *Session) keepalive() { select { case <-tickerPing.C: s.writeFrame(0, newFrame(cmdNOP, 0)) - s.notifyBucket() // force a signal to the recvLoop case <-tickerTimeout.C: if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) { s.Close() @@ -376,25 +358,31 @@ func (s *Session) sendLoop() { } } -// writeFrame writes the frame to the underlying connection -// and returns the number of bytes written if successful -func (s *Session) writeFrame(niceness uint8, f Frame) (n int, err error) { +func (s *Session) queueFrame(niceness uint8, f Frame) (writeRequest, error) { req := writeRequest{ niceness: niceness, sequence: atomic.AddUint64(&s.writeSequenceNum, 1), frame: f, result: make(chan writeResult, 1), } - s.writesLock.Lock() heap.Push(&s.writes, req) s.writesLock.Unlock() select { case <-s.die: - return 0, errors.New(errBrokenPipe) + return req, errors.New(errBrokenPipe) case s.writeTicket <- struct{}{}: } + return 
req, nil +} +// writeFrame writes the frame to the underlying connection +// and returns the number of bytes written if successful +func (s *Session) writeFrame(niceness uint8, f Frame) (n int, err error) { + req, err := s.queueFrame(niceness, f) + if err != nil { + return 0, err + } result := <-req.result return result.n, result.err } diff --git a/session_test.go b/session_test.go index 03e28db..b78cba3 100644 --- a/session_test.go +++ b/session_test.go @@ -13,6 +13,7 @@ import ( "time" "container/heap" "github.com/stretchr/testify/assert" + "io/ioutil" ) // setupServer starts new server listening on a random localhost port and @@ -593,6 +594,81 @@ func TestWriteDeadline(t *testing.T) { session.Close() } +func TestSlowReceiverDoesNotBlock(t *testing.T) { + config := &Config{ + KeepAliveInterval: 10 * time.Second, + KeepAliveTimeout: 30 * time.Second, + MaxFrameSize: 100, + MaxPerStreamReceiveBuffer: 1000, + } + ln, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + go func() { + err := func() error { + conn, err := ln.Accept() + if err != nil { + return err + } + defer conn.Close() + session, err := Server(conn, config) + if err != nil { + return err + } + defer session.Close() + // Accept stream1 but ready nothing from it + _, err = session.AcceptStream() + if err != nil { + return err + } + stream2, err := session.AcceptStream() + io.Copy(ioutil.Discard, stream2) + return nil + }() + if err != nil { + t.Error(err) + } + }() + + cli, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer cli.Close() + session, err := Client(cli, config) + if err != nil { + t.Fatal(err) + } + + // Open the first stream and write more than the receive buffer, verifying that it finishes with an error + go func() { + buf := make([]byte, 2*config.MaxPerStreamReceiveBuffer) + if stream, err := session.OpenStream(); err == nil { + _, err := stream.Write(buf) + assert.NotNil(t, err) + } else { + t.Fatal(err) 
+ } + }() + + // Wait until the go routine above runs to create the first stream + // TODO(jnewman): make this less flaky + time.Sleep(100 * time.Millisecond) + + // The second stream can be written to + stream, err := session.OpenStream() + if err != nil { + t.Fatal(err) + } + buf := make([]byte, 10*config.MaxPerStreamReceiveBuffer) + n, err := stream.Write(buf) + assert.Equal(t, n, 10*config.MaxPerStreamReceiveBuffer) + assert.Nil(t, err) + session.Close() +} + func BenchmarkAcceptClose(b *testing.B) { _, stop, cli, err := setupServer(b) if err != nil { diff --git a/stream.go b/stream.go index 234291a..43b1fb7 100644 --- a/stream.go +++ b/stream.go @@ -14,22 +14,24 @@ import ( // Stream implements net.Conn type Stream struct { - id uint32 - niceness uint8 - rstflag int32 - sess *Session - buffer bytes.Buffer - bufferLock sync.Mutex - frameSize int - chReadEvent chan struct{} // notify a read event - die chan struct{} // flag the stream has closed - dieLock sync.Mutex - readDeadline atomic.Value - writeDeadline atomic.Value + id uint32 + niceness uint8 + rstflag int32 + sess *Session + buffer bytes.Buffer + bufferLock sync.Mutex + frameSize int + chReadEvent chan struct{} // notify a read event + die chan struct{} // flag the stream has closed + dieLock sync.Mutex + readDeadline atomic.Value + writeDeadline atomic.Value + writeTokenBucket int32 // write tokens required for writing to the sessions + writeTokenBucketNotify chan struct{} // used for waiting for tokens } // newStream initiates a Stream struct -func newStream(id uint32, niceness uint8, frameSize int, sess *Session) *Stream { +func newStream(id uint32, niceness uint8, frameSize int, writeTokenBucketSize int32, sess *Session) *Stream { s := new(Stream) s.niceness = niceness s.id = id @@ -37,6 +39,8 @@ func newStream(id uint32, niceness uint8, frameSize int, sess *Session) *Stream s.frameSize = frameSize s.sess = sess s.die = make(chan struct{}) + s.writeTokenBucket = writeTokenBucketSize + 
 s.writeTokenBucketNotify = make(chan struct{}, 1) return s } @@ -69,7 +73,7 @@ READ: s.bufferLock.Unlock() if n > 0 { - s.sess.returnTokens(n) + s.sess.queueAcks(s.id, int32(n)) return n, nil } else if atomic.LoadInt32(&s.rstflag) == 1 { _ = s.Close() @@ -104,6 +108,18 @@ func (s *Stream) Write(b []byte) (n int, err error) { frames := s.split(b, cmdPSH, s.id) sent := 0 for k := range frames { + for atomic.LoadInt32(&s.writeTokenBucket) <= 0 { + select { + case <-s.writeTokenBucketNotify: + case <-s.die: + return sent, errors.New(errBrokenPipe) + case <-deadline: + return sent, errTimeout + } + } + + atomic.AddInt32(&s.writeTokenBucket, -int32(len(frames[k].data))) + req := writeRequest{ niceness: s.niceness, sequence: atomic.AddUint64(&s.sess.writeSequenceNum, 1), @@ -223,13 +239,19 @@ func (s *Stream) pushBytes(p []byte) { s.bufferLock.Unlock() } -// recycleTokens transform remaining bytes to tokens(will truncate buffer) -func (s *Stream) recycleTokens() (n int) { - s.bufferLock.Lock() - n = s.buffer.Len() - s.buffer.Reset() - s.bufferLock.Unlock() - return +// receiveAck replenishes the writeTokenBucket so that more writes can proceed +func (s *Stream) receiveAck(numTokens int32) { + if atomic.AddInt32(&s.writeTokenBucket, numTokens) > 0 { + s.notifyBucket() + } +} + +// notifyBucket notifies waiting write loops that there are more tokens in the writeTokenBucket +func (s *Stream) notifyBucket() { + select { + case s.writeTokenBucketNotify <- struct{}{}: + default: + } } // split large byte buffer into smaller frames, reference only