diff --git a/frame.go b/frame.go
index 36062d7..bfa6a34 100644
--- a/frame.go
+++ b/frame.go
@@ -13,6 +13,7 @@ const ( // cmds
 	cmdSYN byte = iota // stream open
 	cmdFIN             // stream close, a.k.a EOF mark
 	cmdPSH             // data push
+	cmdACK             // data ack
 	cmdNOP             // no operation
 )
 
diff --git a/mux.go b/mux.go
index afcf58b..40d1135 100644
--- a/mux.go
+++ b/mux.go
@@ -21,18 +21,18 @@ type Config struct {
 	// frame size to sent to the remote
 	MaxFrameSize int
 
-	// MaxReceiveBuffer is used to control the maximum
+	// MaxPerStreamReceiveBuffer is used to control the maximum
 	// number of data in the buffer pool
-	MaxReceiveBuffer int
+	MaxPerStreamReceiveBuffer int
 }
 
 // DefaultConfig is used to return a default configuration
 func DefaultConfig() *Config {
 	return &Config{
-		KeepAliveInterval: 10 * time.Second,
-		KeepAliveTimeout:  30 * time.Second,
-		MaxFrameSize:      4096,
-		MaxReceiveBuffer:  4194304,
+		KeepAliveInterval:         10 * time.Second,
+		KeepAliveTimeout:          30 * time.Second,
+		MaxFrameSize:              4096,
+		MaxPerStreamReceiveBuffer: 4194304,
 	}
 }
 
@@ -50,7 +50,7 @@ func VerifyConfig(config *Config) error {
 	if config.MaxFrameSize > 65535 {
 		return errors.New("max frame size must not be larger than 65535")
 	}
-	if config.MaxReceiveBuffer <= 0 {
+	if config.MaxPerStreamReceiveBuffer <= 0 {
 		return errors.New("max receive buffer must be positive")
 	}
 	return nil
diff --git a/mux_test.go b/mux_test.go
index 638e67c..f4ae72c 100644
--- a/mux_test.go
+++ b/mux_test.go
@@ -51,7 +51,7 @@ func TestConfig(t *testing.T) {
 	}
 
 	config = DefaultConfig()
-	config.MaxReceiveBuffer = 0
+	config.MaxPerStreamReceiveBuffer = 0
 	err = VerifyConfig(config)
 	t.Log(err)
 	if err == nil {
diff --git a/session.go b/session.go
index e93317e..1491c4d 100644
--- a/session.go
+++ b/session.go
@@ -8,6 +8,7 @@ import (
 	"time"
 
 	"github.com/pkg/errors"
+	"container/heap"
 )
 
 const (
@@ -21,8 +22,10 @@
 )
 
 type writeRequest struct {
-	frame  Frame
-	result chan writeResult
+	niceness uint8
+	sequence uint64 // Used to keep the heap ordered by time
+	frame    Frame
+	result   chan writeResult
 }
 
 type writeResult struct {
@@ -30,6 +33,31 @@ type writeResult struct {
 	err error
 }
 
+type writeHeap []writeRequest
+
+func (h writeHeap) Len() int { return len(h) }
+func (h writeHeap) Less(i, j int) bool {
+	if h[i].niceness == h[j].niceness {
+		return h[i].sequence < h[j].sequence
+	}
+	return h[i].niceness < h[j].niceness
+}
+func (h writeHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
+
+func (h *writeHeap) Push(x interface{}) {
+	// Push and Pop use pointer receivers because they modify the slice's length,
+	// not just its contents.
+	*h = append(*h, x.(writeRequest))
+}
+
+func (h *writeHeap) Pop() interface{} {
+	old := *h
+	n := len(old)
+	x := old[n-1]
+	*h = old[0 : n-1]
+	return x
+}
+
 // Session defines a multiplexed connection for streams
 type Session struct {
 	conn io.ReadWriteCloser
@@ -38,9 +66,6 @@ type Session struct {
 	nextStreamID     uint32 // next stream identifier
 	nextStreamIDLock sync.Mutex
 
-	bucket       int32         // token bucket
-	bucketNotify chan struct{} // used for waiting for tokens
-
 	streams    map[uint32]*Stream // all streams in this session
 	streamLock sync.Mutex         // locks streams
 
@@ -54,7 +79,10 @@ type Session struct {
 
 	deadline atomic.Value
 
-	writes chan writeRequest
+	writeTicket      chan struct{}
+	writesLock       sync.Mutex
+	writes           writeHeap
+	writeSequenceNum uint64
 }
 
 func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
@@ -64,9 +92,7 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
 	s.config = config
 	s.streams = make(map[uint32]*Stream)
 	s.chAccepts = make(chan *Stream, defaultAcceptBacklog)
-	s.bucket = int32(config.MaxReceiveBuffer)
-	s.bucketNotify = make(chan struct{}, 1)
-	s.writes = make(chan writeRequest)
+	s.writeTicket = make(chan struct{})
 
 	if client {
 		s.nextStreamID = 1
@@ -79,8 +105,12 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
 	return s
 }
 
-// OpenStream is used to create a new stream
 func (s *Session) OpenStream() (*Stream, error) {
+	return s.OpenStreamOpt(100)
+}
+
+// OpenStreamOpt is used to create a new stream with the given niceness (lower niceness means higher priority)
+func (s *Session) OpenStreamOpt(niceness uint8) (*Stream, error) {
 	if s.IsClosed() {
 		return nil, errors.New(errBrokenPipe)
 	}
@@ -101,9 +131,9 @@ func (s *Session) OpenStream() (*Stream, error) {
 	}
 	s.nextStreamIDLock.Unlock()
 
-	stream := newStream(sid, s.config.MaxFrameSize, s)
+	stream := newStream(sid, niceness, s.config.MaxFrameSize, int32(s.config.MaxPerStreamReceiveBuffer), s)
 
-	if _, err := s.writeFrame(newFrame(cmdSYN, sid)); err != nil {
+	if _, err := s.writeFrame(0, newFrame(cmdSYN, sid)); err != nil {
 		return nil, errors.Wrap(err, "writeFrame")
 	}
 
@@ -113,9 +143,13 @@ func (s *Session) OpenStream() (*Stream, error) {
 	return stream, nil
 }
 
+func (s *Session) AcceptStream() (*Stream, error) {
+	return s.AcceptStreamOpt(100)
+}
+
 // AcceptStream is used to block until the next available stream
 // is ready to be accepted.
-func (s *Session) AcceptStream() (*Stream, error) {
+func (s *Session) AcceptStreamOpt(niceness uint8) (*Stream, error) {
 	var deadline <-chan time.Time
 	if d, ok := s.deadline.Load().(time.Time); ok && !d.IsZero() {
 		timer := time.NewTimer(time.Until(d))
@@ -124,6 +158,7 @@
 	}
 	select {
 	case stream := <-s.chAccepts:
+		stream.niceness = niceness
 		return stream, nil
 	case <-deadline:
 		return nil, errTimeout
@@ -148,19 +183,10 @@ func (s *Session) Close() (err error) {
 			s.streams[k].sessionClose()
 		}
 		s.streamLock.Unlock()
-		s.notifyBucket()
 		return s.conn.Close()
 	}
 }
 
-// notifyBucket notifies recvLoop that bucket is available
-func (s *Session) notifyBucket() {
-	select {
-	case s.bucketNotify <- struct{}{}:
-	default:
-	}
-}
-
 // IsClosed does a safe check to see if we have shutdown
 func (s *Session) IsClosed() bool {
 	select {
@@ -191,20 +217,15 @@ func (s *Session) SetDeadline(t time.Time) error {
 // notify the session that a stream has closed
 func (s *Session) streamClosed(sid uint32) {
 	s.streamLock.Lock()
-	if n := s.streams[sid].recycleTokens(); n > 0 { // return remaining tokens to the bucket
-		if atomic.AddInt32(&s.bucket, int32(n)) > 0 {
-			s.notifyBucket()
-		}
-	}
 	delete(s.streams, sid)
 	s.streamLock.Unlock()
 }
 
-// returnTokens is called by stream to return token after read
-func (s *Session) returnTokens(n int) {
-	if atomic.AddInt32(&s.bucket, int32(n)) > 0 {
-		s.notifyBucket()
-	}
+func (s *Session) queueAcks(streamId uint32, n int32) {
+	ack := newFrame(cmdACK, streamId)
+	ack.data = make([]byte, 4)
+	binary.BigEndian.PutUint32(ack.data, uint32(n))
+	s.queueFrame(0, ack)
 }
 
 // session read a frame from underlying connection
@@ -235,10 +256,6 @@ func (s *Session) readFrame(buffer []byte) (f Frame, err error) {
 func (s *Session) recvLoop() {
 	buffer := make([]byte, (1<<16)+headerSize)
 	for {
-		for atomic.LoadInt32(&s.bucket) <= 0 && !s.IsClosed() {
-			<-s.bucketNotify
-		}
-
 		if f, err := s.readFrame(buffer); err == nil {
 			atomic.StoreInt32(&s.dataReady, 1)
 
@@ -247,7 +264,7 @@
 			case cmdSYN:
 				s.streamLock.Lock()
 				if _, ok := s.streams[f.sid]; !ok {
-					stream := newStream(f.sid, s.config.MaxFrameSize, s)
+					stream := newStream(f.sid, 255, s.config.MaxFrameSize, int32(s.config.MaxPerStreamReceiveBuffer), s)
 					s.streams[f.sid] = stream
 					select {
 					case s.chAccepts <- stream:
@@ -265,11 +282,17 @@
 			case cmdPSH:
 				s.streamLock.Lock()
 				if stream, ok := s.streams[f.sid]; ok {
-					atomic.AddInt32(&s.bucket, -int32(len(f.data)))
 					stream.pushBytes(f.data)
 					stream.notifyReadEvent()
 				}
 				s.streamLock.Unlock()
+			case cmdACK:
+				s.streamLock.Lock()
+				if stream, ok := s.streams[f.sid]; ok && len(f.data) >= 4 {
+					tokens := binary.BigEndian.Uint32(f.data)
+					stream.receiveAck(int32(tokens))
+				}
+				s.streamLock.Unlock()
 			default:
 				s.Close()
 				return
@@ -289,8 +312,7 @@ func (s *Session) keepalive() {
 	for {
 		select {
 		case <-tickerPing.C:
-			s.writeFrame(newFrame(cmdNOP, 0))
-			s.notifyBucket() // force a signal to the recvLoop
+			s.writeFrame(0, newFrame(cmdNOP, 0))
 		case <-tickerTimeout.C:
 			if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) {
 				s.Close()
@@ -308,7 +330,11 @@ func (s *Session) sendLoop() {
 		select {
 		case <-s.die:
 			return
-		case request := <-s.writes:
+		case <-s.writeTicket:
+			s.writesLock.Lock()
+			request := heap.Pop(&s.writes).(writeRequest)
+			s.writesLock.Unlock()
+
 			buf[0] = request.frame.ver
 			buf[1] = request.frame.cmd
 			binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.data)))
@@ -332,19 +358,31 @@
 	}
 }
 
-// writeFrame writes the frame to the underlying connection
-// and returns the number of bytes written if successful
-func (s *Session) writeFrame(f Frame) (n int, err error) {
+func (s *Session) queueFrame(niceness uint8, f Frame) (writeRequest, error) {
 	req := writeRequest{
-		frame:  f,
-		result: make(chan writeResult, 1),
+		niceness: niceness,
+		sequence: atomic.AddUint64(&s.writeSequenceNum, 1),
+		frame:    f,
+		result:   make(chan writeResult, 1),
 	}
+	s.writesLock.Lock()
+	heap.Push(&s.writes, req)
+	s.writesLock.Unlock()
 	select {
 	case <-s.die:
-		return 0, errors.New(errBrokenPipe)
-	case s.writes <- req:
+		return req, errors.New(errBrokenPipe)
+	case s.writeTicket <- struct{}{}:
 	}
+	return req, nil
+}
+// writeFrame writes the frame to the underlying connection
+// and returns the number of bytes written if successful
+func (s *Session) writeFrame(niceness uint8, f Frame) (n int, err error) {
+	req, err := s.queueFrame(niceness, f)
+	if err != nil {
+		return 0, err
+	}
 
 	result := <-req.result
 	return result.n, result.err
 }
diff --git a/session_test.go b/session_test.go
index 760642d..b78cba3 100644
--- a/session_test.go
+++ b/session_test.go
@@ -11,6 +11,9 @@ import (
 	"sync"
 	"testing"
 	"time"
+	"container/heap"
+	"github.com/stretchr/testify/assert"
+	"io/ioutil"
 )
 
 // setupServer starts new server listening on a random localhost port and
@@ -58,6 +61,19 @@ func handleConnection(conn net.Conn) {
 	}
 }
 
+func TestWriteHeap(t *testing.T) {
+	var reqs writeHeap
+	req1 := writeRequest{niceness: 1}
+	heap.Push(&reqs, req1)
+	req3 := writeRequest{niceness: 3}
+	heap.Push(&reqs, req3)
+	req2 := writeRequest{niceness: 2}
+	heap.Push(&reqs, req2)
+	assert.Equal(t, heap.Pop(&reqs), req1)
+	assert.Equal(t, heap.Pop(&reqs), req2)
+	assert.Equal(t, heap.Pop(&reqs), req3)
+}
+
 func TestEcho(t *testing.T) {
 	_, stop, cli, err := setupServer(t)
 	if err != nil {
@@ -461,7 +477,7 @@ func TestRandomFrame(t *testing.T) {
 	session, _ = Client(cli, nil)
 	for i := 0; i < 100; i++ {
 		f := newFrame(cmdSYN, 1000)
-		session.writeFrame(f)
+		session.writeFrame(0, f)
 	}
 	cli.Close()
 
@@ -474,7 +490,7 @@ func TestRandomFrame(t *testing.T) {
 	session, _ = Client(cli, nil)
 	for i := 0; i < 100; i++ {
 		f := newFrame(allcmds[rand.Int()%len(allcmds)], rand.Uint32())
-		session.writeFrame(f)
+		session.writeFrame(0, f)
 	}
 	cli.Close()
 
@@ -486,7 +502,7 @@ func TestRandomFrame(t *testing.T) {
 	session, _ = Client(cli, nil)
 	for i := 0; i < 100; i++ {
 		f := newFrame(byte(rand.Uint32()), rand.Uint32())
-		session.writeFrame(f)
+		session.writeFrame(0, f)
 	}
 	cli.Close()
 
@@ -499,7 +515,7 @@ func TestRandomFrame(t *testing.T) {
 	for i := 0; i < 100; i++ {
 		f := newFrame(byte(rand.Uint32()), rand.Uint32())
 		f.ver = byte(rand.Uint32())
-		session.writeFrame(f)
+		session.writeFrame(0, f)
 	}
 	cli.Close()
 
@@ -578,6 +594,81 @@ func TestWriteDeadline(t *testing.T) {
 	session.Close()
 }
 
+func TestSlowReceiverDoesNotBlock(t *testing.T) {
+	config := &Config{
+		KeepAliveInterval:         10 * time.Second,
+		KeepAliveTimeout:          30 * time.Second,
+		MaxFrameSize:              100,
+		MaxPerStreamReceiveBuffer: 1000,
+	}
+	ln, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer ln.Close()
+	go func() {
+		err := func() error {
+			conn, err := ln.Accept()
+			if err != nil {
+				return err
+			}
+			defer conn.Close()
+			session, err := Server(conn, config)
+			if err != nil {
+				return err
+			}
+			defer session.Close()
+			// Accept stream1 but read nothing from it
+			_, err = session.AcceptStream()
+			if err != nil {
+				return err
+			}
+			stream2, err := session.AcceptStream()
+			io.Copy(ioutil.Discard, stream2)
+			return nil
+		}()
+		if err != nil {
+			t.Error(err)
+		}
+	}()
+
+	cli, err := net.Dial("tcp", ln.Addr().String())
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer cli.Close()
+	session, err := Client(cli, config)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Open the first stream and write more than the receive buffer, verifying that it finishes with an error
+	go func() {
+		buf := make([]byte, 2*config.MaxPerStreamReceiveBuffer)
+		if stream, err := session.OpenStream(); err == nil {
+			_, err := stream.Write(buf)
+			assert.NotNil(t, err)
+		} else {
+			t.Error(err)
+		}
+	}()
+
+	// Wait until the go routine above runs to create the first stream
+	// TODO(jnewman): make this less flaky
+	time.Sleep(100 * time.Millisecond)
+
+	// The second stream can be written to
+	stream, err := session.OpenStream()
+	if err != nil {
+		t.Fatal(err)
+	}
+	buf := make([]byte, 10*config.MaxPerStreamReceiveBuffer)
+	n, err := stream.Write(buf)
+	assert.Equal(t, n, 10*config.MaxPerStreamReceiveBuffer)
+	assert.Nil(t, err)
+	session.Close()
+}
+
 func BenchmarkAcceptClose(b *testing.B) {
 	_, stop, cli, err := setupServer(b)
 	if err != nil {
diff --git a/stream.go b/stream.go
index 57a0bc6..43b1fb7 100644
--- a/stream.go
+++ b/stream.go
@@ -9,31 +9,38 @@ import (
 	"time"
 
 	"github.com/pkg/errors"
+	"container/heap"
 )
 
 // Stream implements net.Conn
 type Stream struct {
-	id            uint32
-	rstflag       int32
-	sess          *Session
-	buffer        bytes.Buffer
-	bufferLock    sync.Mutex
-	frameSize     int
-	chReadEvent   chan struct{} // notify a read event
-	die           chan struct{} // flag the stream has closed
-	dieLock       sync.Mutex
-	readDeadline  atomic.Value
-	writeDeadline atomic.Value
+	id                     uint32
+	niceness               uint8
+	rstflag                int32
+	sess                   *Session
+	buffer                 bytes.Buffer
+	bufferLock             sync.Mutex
+	frameSize              int
+	chReadEvent            chan struct{} // notify a read event
+	die                    chan struct{} // flag the stream has closed
+	dieLock                sync.Mutex
+	readDeadline           atomic.Value
+	writeDeadline          atomic.Value
+	writeTokenBucket       int32         // write tokens available for writing to the session
+	writeTokenBucketNotify chan struct{} // used for waiting for tokens
 }
 
 // newStream initiates a Stream struct
-func newStream(id uint32, frameSize int, sess *Session) *Stream {
+func newStream(id uint32, niceness uint8, frameSize int, writeTokenBucketSize int32, sess *Session) *Stream {
 	s := new(Stream)
+	s.niceness = niceness
 	s.id = id
 	s.chReadEvent = make(chan struct{}, 1)
 	s.frameSize = frameSize
 	s.sess = sess
 	s.die = make(chan struct{})
+	s.writeTokenBucket = writeTokenBucketSize
+	s.writeTokenBucketNotify = make(chan struct{}, 1)
 	return s
 }
 
@@ -66,7 +73,7 @@ READ:
 	s.bufferLock.Unlock()
 
 	if n > 0 {
-		s.sess.returnTokens(n)
+		s.sess.queueAcks(s.id, int32(n))
 		return n, nil
 	} else if atomic.LoadInt32(&s.rstflag) == 1 {
 		_ = s.Close()
@@ -101,13 +108,31 @@ func (s *Stream) Write(b []byte) (n int, err error) {
 	frames := s.split(b, cmdPSH, s.id)
 	sent := 0
 	for k := range frames {
+		for atomic.LoadInt32(&s.writeTokenBucket) <= 0 {
+			select {
+			case <-s.writeTokenBucketNotify:
+			case <-s.die:
+				return sent, errors.New(errBrokenPipe)
+			case <-deadline:
+				return sent, errTimeout
+			}
+		}
+
+		atomic.AddInt32(&s.writeTokenBucket, -int32(len(frames[k].data)))
+
 		req := writeRequest{
-			frame:  frames[k],
-			result: make(chan writeResult, 1),
+			niceness: s.niceness,
+			sequence: atomic.AddUint64(&s.sess.writeSequenceNum, 1),
+			frame:    frames[k],
+			result:   make(chan writeResult, 1),
 		}
+		// TODO(jnewman): replace with session.writeFrame(..)?
+		s.sess.writesLock.Lock()
+		heap.Push(&s.sess.writes, req)
+		s.sess.writesLock.Unlock()
 
 		select {
-		case s.sess.writes <- req:
+		case s.sess.writeTicket <- struct{}{}:
 		case <-s.die:
 			return sent, errors.New(errBrokenPipe)
 		case <-deadline:
@@ -141,7 +166,7 @@ func (s *Stream) Close() error {
 		close(s.die)
 		s.dieLock.Unlock()
 		s.sess.streamClosed(s.id)
-		_, err := s.sess.writeFrame(newFrame(cmdFIN, s.id))
+		_, err := s.sess.writeFrame(0, newFrame(cmdFIN, s.id))
 		return err
 	}
 }
@@ -214,13 +239,19 @@ func (s *Stream) pushBytes(p []byte) {
 	s.bufferLock.Unlock()
 }
 
-// recycleTokens transform remaining bytes to tokens(will truncate buffer)
-func (s *Stream) recycleTokens() (n int) {
-	s.bufferLock.Lock()
-	n = s.buffer.Len()
-	s.buffer.Reset()
-	s.bufferLock.Unlock()
-	return
+// receiveAck replenishes the write token bucket so that more writes can proceed
+func (s *Stream) receiveAck(numTokens int32) {
+	if atomic.AddInt32(&s.writeTokenBucket, numTokens) > 0 {
+		s.notifyBucket()
+	}
+}
+
+// notifyBucket notifies waiting write loops that there are more tokens in the writeTokenBucket
+func (s *Stream) notifyBucket() {
+	select {
+	case s.writeTokenBucketNotify <- struct{}{}:
+	default:
+	}
 }
 
 // split large byte buffer into smaller frames, reference only
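Illustrative usage sketch of the niceness API introduced above, not part of the patch: it assumes the code sits in a _test.go file inside package smux next to session_test.go, uses net.Pipe as the transport, and the function name plus the niceness values 0 and 200 are invented for illustration. Lower niceness is popped from the shared per-session write heap first, so frames of the control stream are scheduled ahead of bulk frames that are still queued.

package smux

import (
	"io"
	"io/ioutil"
	"net"
)

// sketchNicenessUsage opens one bulk stream and one control stream over the
// same session; the control stream's lower niceness lets its frames jump the
// per-session write heap ahead of queued bulk frames.
func sketchNicenessUsage() {
	clientConn, serverConn := net.Pipe()

	// Server side: accept every stream and discard whatever arrives.
	go func() {
		sess, err := Server(serverConn, nil)
		if err != nil {
			return
		}
		for {
			stream, err := sess.AcceptStream() // accepted streams default to niceness 100
			if err != nil {
				return
			}
			go io.Copy(ioutil.Discard, stream)
		}
	}()

	sess, err := Client(clientConn, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	bulk, _ := sess.OpenStreamOpt(200)  // low priority: yields to anything nicer
	control, _ := sess.OpenStreamOpt(0) // highest priority for this sketch

	go bulk.Write(make([]byte, 64*1024)) // large low-priority payload
	control.Write([]byte("ping"))        // small message that should not wait behind bulk data
}

Because SYN, FIN, ACK, and keepalive frames are queued by the session itself with niceness 0, application streams opened with higher values cannot starve connection-management traffic regardless of the niceness they choose.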