diff --git a/session.go b/session.go index 8f4ad61..88015e3 100644 --- a/session.go +++ b/session.go @@ -17,6 +17,14 @@ const ( openCloseTimeout = 30 * time.Second // stream open/close timeout ) +// define frame class +type CLASSID int + +const ( + CLSCTRL CLASSID = iota + CLSDATA +) + var ( ErrInvalidProtocol = errors.New("invalid protocol") ErrConsumed = errors.New("peer consumed more than sent") @@ -26,7 +34,7 @@ var ( ) type writeRequest struct { - prio uint32 + class CLASSID frame Frame seq uint32 result chan writeResult @@ -396,7 +404,7 @@ func (s *Session) keepalive() { for { select { case <-tickerPing.C: - s.writeFrameInternal(newFrame(byte(s.config.Version), cmdNOP, 0), tickerPing.C, 0) + s.writeFrameInternal(newFrame(byte(s.config.Version), cmdNOP, 0), tickerPing.C, CLSCTRL) s.notifyBucket() // force a signal to the recvLoop case <-tickerTimeout.C: if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) { @@ -515,13 +523,13 @@ func (s *Session) sendLoop() { // writeFrame writes the frame to the underlying connection // and returns the number of bytes written if successful func (s *Session) writeFrame(f Frame) (n int, err error) { - return s.writeFrameInternal(f, time.After(openCloseTimeout), 0) + return s.writeFrameInternal(f, time.After(openCloseTimeout), CLSCTRL) } // internal writeFrame version to support deadline used in keepalive -func (s *Session) writeFrameInternal(f Frame, deadline <-chan time.Time, prio uint32) (int, error) { +func (s *Session) writeFrameInternal(f Frame, deadline <-chan time.Time, class CLASSID) (int, error) { req := writeRequest{ - prio: prio, + class: class, frame: f, seq: atomic.AddUint32(&s.requestID, 1), result: make(chan writeResult, 1), diff --git a/session_test.go b/session_test.go index 3479570..1c49ae7 100644 --- a/session_test.go +++ b/session_test.go @@ -867,7 +867,7 @@ func TestWriteFrameInternal(t *testing.T) { session.Close() for i := 0; i < 100; i++ { f := newFrame(1, byte(rand.Uint32()), rand.Uint32()) - session.writeFrameInternal(f, time.After(session.config.KeepAliveTimeout), 0) + session.writeFrameInternal(f, time.After(session.config.KeepAliveTimeout), CLSDATA) } // random cmds @@ -879,14 +879,14 @@ func TestWriteFrameInternal(t *testing.T) { session, _ = Client(cli, nil) for i := 0; i < 100; i++ { f := newFrame(1, allcmds[rand.Int()%len(allcmds)], rand.Uint32()) - session.writeFrameInternal(f, time.After(session.config.KeepAliveTimeout), 0) + session.writeFrameInternal(f, time.After(session.config.KeepAliveTimeout), CLSDATA) } //deadline occur { c := make(chan time.Time) close(c) f := newFrame(1, allcmds[rand.Int()%len(allcmds)], rand.Uint32()) - _, err := session.writeFrameInternal(f, c, 0) + _, err := session.writeFrameInternal(f, c, CLSDATA) if !strings.Contains(err.Error(), "timeout") { t.Fatal("write frame with deadline failed", err) } @@ -911,7 +911,7 @@ func TestWriteFrameInternal(t *testing.T) { time.Sleep(time.Second) close(c) }() - _, err = session.writeFrameInternal(f, c, 0) + _, err = session.writeFrameInternal(f, c, CLSDATA) if !strings.Contains(err.Error(), "closed pipe") { t.Fatal("write frame with to closed conn failed", err) } diff --git a/shaper.go b/shaper.go index 35773ee..8d52ef7 100644 --- a/shaper.go +++ b/shaper.go @@ -6,8 +6,14 @@ func _itimediff(later, earlier uint32) int32 { type shaperHeap []writeRequest -func (h shaperHeap) Len() int { return len(h) } -func (h shaperHeap) Less(i, j int) bool { return _itimediff(h[j].seq, h[i].seq) > 0 } +func (h shaperHeap) Len() int { return len(h) } +func (h shaperHeap) Less(i, j int) bool { + if h[i].class != h[j].class { + return h[i].class < h[j].class + } + return _itimediff(h[j].seq, h[i].seq) > 0 +} + func (h shaperHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } func (h *shaperHeap) Push(x interface{}) { *h = append(*h, x.(writeRequest)) } diff --git a/shaper_test.go b/shaper_test.go index 95d0079..b02e317 100644 --- a/shaper_test.go +++ b/shaper_test.go @@ -6,11 +6,11 @@ import ( ) func TestShaper(t *testing.T) { - w1 := writeRequest{prio: 10} - w2 := writeRequest{prio: 10} - w3 := writeRequest{prio: 20} - w4 := writeRequest{prio: 100} - w5 := writeRequest{prio: (1 << 32) - 1} + w1 := writeRequest{seq: 1} + w2 := writeRequest{seq: 2} + w3 := writeRequest{seq: 3} + w4 := writeRequest{seq: 4} + w5 := writeRequest{seq: 5} var reqs shaperHeap heap.Push(&reqs, w5) @@ -19,25 +19,20 @@ func TestShaper(t *testing.T) { heap.Push(&reqs, w2) heap.Push(&reqs, w1) - var lastPrio = reqs[0].prio for len(reqs) > 0 { w := heap.Pop(&reqs).(writeRequest) - if int32(w.prio-lastPrio) < 0 { - t.Fatal("incorrect shaper priority") - } - - t.Log("prio:", w.prio) - lastPrio = w.prio + t.Log("sid:", w.frame.sid, "seq:", w.seq) } } func TestShaper2(t *testing.T) { - w1 := writeRequest{prio: 10, seq: 1} // stream 0 - w2 := writeRequest{prio: 10, seq: 2} - w3 := writeRequest{prio: 20, seq: 3} - w4 := writeRequest{prio: 100, seq: 4} - w5 := writeRequest{prio: (1 << 32) - 1, seq: 5} - w6 := writeRequest{prio: 10, seq: 1, frame: Frame{sid: 10}} // stream 1 + w1 := writeRequest{class: CLSDATA, seq: 1} // stream 0 + w2 := writeRequest{class: CLSDATA, seq: 2} + w3 := writeRequest{class: CLSDATA, seq: 3} + w4 := writeRequest{class: CLSDATA, seq: 4} + w5 := writeRequest{class: CLSDATA, seq: 5} + w6 := writeRequest{class: CLSCTRL, seq: 6, frame: Frame{sid: 10}} // ctrl 1 + w7 := writeRequest{class: CLSCTRL, seq: 7, frame: Frame{sid: 11}} // ctrl 2 var reqs shaperHeap heap.Push(&reqs, w6) @@ -46,9 +41,10 @@ func TestShaper2(t *testing.T) { heap.Push(&reqs, w3) heap.Push(&reqs, w2) heap.Push(&reqs, w1) + heap.Push(&reqs, w7) for len(reqs) > 0 { w := heap.Pop(&reqs).(writeRequest) - t.Log("prio:", w.prio, "sid:", w.frame.sid, "seq:", w.seq) + t.Log("sid:", w.frame.sid, "seq:", w.seq) } } diff --git a/stream.go b/stream.go index 94e858e..0d7e045 100644 --- a/stream.go +++ b/stream.go @@ -255,7 +255,7 @@ func (s *Stream) sendWindowUpdate(consumed uint32) error { binary.LittleEndian.PutUint32(hdr[:], consumed) binary.LittleEndian.PutUint32(hdr[4:], uint32(s.sess.config.MaxStreamBuffer)) frame.data = hdr[:] - _, err := s.sess.writeFrameInternal(frame, deadline, 0) + _, err := s.sess.writeFrameInternal(frame, deadline, CLSDATA) return err } @@ -325,7 +325,7 @@ func (s *Stream) Write(b []byte) (n int, err error) { } frame.data = bts[:sz] bts = bts[sz:] - n, err := s.sess.writeFrameInternal(frame, deadline, 0) + n, err := s.sess.writeFrameInternal(frame, deadline, CLSDATA) s.numWritten++ sent += n if err != nil { @@ -393,7 +393,7 @@ func (s *Stream) writeV2(b []byte) (n int, err error) { } frame.data = bts[:sz] bts = bts[sz:] - n, err := s.sess.writeFrameInternal(frame, deadline, 0) + n, err := s.sess.writeFrameInternal(frame, deadline, CLSDATA) atomic.AddUint32(&s.numWritten, uint32(sz)) sent += n if err != nil {