diff --git a/CREATING_TOXICS.md b/CREATING_TOXICS.md index ab39bc83..46672088 100644 --- a/CREATING_TOXICS.md +++ b/CREATING_TOXICS.md @@ -145,19 +145,16 @@ An implementation of the noop toxic above using the stream package would look so ```go func (t *NoopToxic) Pipe(stub *toxics.ToxicStub) { buf := make([]byte, 32*1024) - writer := stream.NewChanWriter(stub.Output) - reader := stream.NewChanReader(stub.Input) - reader.SetInterrupt(stub.Interrupt) for { - n, err := reader.Read(buf) + n, err := stub.Reader.Read(buf) if err == stream.ErrInterrupted { - writer.Write(buf[:n]) return } else if err == io.EOF { stub.Close() return } - writer.Write(buf[:n]) + stub.Writer.Write(buf[:n]) + stub.Reader.Checkpoint(0) } } ``` diff --git a/link.go b/link.go index 69d053cc..6bffe98f 100644 --- a/link.go +++ b/link.go @@ -97,6 +97,7 @@ func (link *ToxicLink) AddToxic(toxic *toxics.ToxicWrapper) { // Interrupt the last toxic so that we don't have a race when moving channels if link.stubs[i-1].InterruptToxic() { link.stubs[i-1].Output = newin + link.stubs[i-1].Writer.SetOutput(newin) if stateful, ok := toxic.Toxic.(toxics.StatefulToxic); ok { link.stubs[i].State = stateful.NewState() @@ -129,8 +130,11 @@ func (link *ToxicLink) RemoveToxic(toxic *toxics.ToxicWrapper) { stop <- link.stubs[i-1].InterruptToxic() }() - // Unblock the previous toxic if it is trying to flush - // If the previous toxic is closed, continue flusing until we reach the end. + // Flush toxic's internal buffer + link.stubs[i].Reader.FlushTo(link.stubs[i].Writer) + + // Unblock the previous toxic if it is trying to flush. + // If the previous toxic is closed, continue flushing until we reach the end. interrupted := false stopped := false for !interrupted { @@ -149,7 +153,7 @@ func (link *ToxicLink) RemoveToxic(toxic *toxics.ToxicWrapper) { } } - // Empty the toxic's buffer if necessary + // Empty the toxic's input buffer if necessary for len(link.stubs[i].Input) > 0 { tmp := <-link.stubs[i].Input if tmp == nil { @@ -160,6 +164,7 @@ func (link *ToxicLink) RemoveToxic(toxic *toxics.ToxicWrapper) { } link.stubs[i-1].Output = link.stubs[i].Output + link.stubs[i-1].Writer.SetOutput(link.stubs[i].Output) link.stubs = append(link.stubs[:i], link.stubs[i+1:]...) go link.stubs[i-1].Run(link.toxics.chain[link.direction][i-1]) diff --git a/stream/io_chan.go b/stream/io_chan.go index 4038c32e..c4c25fee 100644 --- a/stream/io_chan.go +++ b/stream/io_chan.go @@ -1,6 +1,7 @@ package stream import ( + "bytes" "fmt" "io" "time" @@ -32,12 +33,20 @@ func NewChanWriter(output chan<- *StreamChunk) *ChanWriter { // Write `buf` as a StreamChunk to the channel. The full buffer is always written, and error // will always be nil. Calling `Write()` after closing the channel will panic. func (c *ChanWriter) Write(buf []byte) (int, error) { + if len(buf) == 0 { + return 0, nil + } + packet := &StreamChunk{make([]byte, len(buf)), time.Now()} copy(packet.Data, buf) // Make a copy before sending it to the channel c.output <- packet return len(buf), nil } +func (c *ChanWriter) SetOutput(output chan<- *StreamChunk) { + c.output = output +} + // Close the output channel func (c *ChanWriter) Close() error { close(c.output) @@ -71,7 +80,7 @@ func (c *ChanReader) Read(out []byte) (int, error) { } n := copy(out, c.buffer) c.buffer = c.buffer[n:] - if len(out) <= len(c.buffer) { + if len(out) == n { return n, nil } else if n > 0 { // We have some data to return, so make the channel read optional @@ -106,3 +115,102 @@ func (c *ChanReader) Read(out []byte) (int, error) { c.buffer = p.Data[n2:] return n + n2, nil } + +// TransactionalReader is a ChanReader that can rollback its progress to checkpoints. +// This is useful when using other buffered readers, since they may read past the end of a message. +// The buffered reader can later be removed by rolling back any buffered bytes. +// +// chan []byte -> ChanReader -> TeeReader -> Read() -> output +// V ^ +// bytes.Buffer -> bytes.Reader +type TransactionalReader struct { + buffer *bytes.Buffer + bufReader *bytes.Reader + reader *ChanReader + tee io.Reader +} + +func NewTransactionalReader(input <-chan *StreamChunk) *TransactionalReader { + t := &TransactionalReader{ + buffer: bytes.NewBuffer(make([]byte, 0, 32*1024)), + reader: NewChanReader(input), + } + t.tee = io.TeeReader(t.reader, t.buffer) + return t +} + +// Reads from the input channel either directly, or from a buffer if Rollback() has been called. +// If the reader returns `ErrInterrupted`, it will automatically call Rollback() +func (t *TransactionalReader) Read(out []byte) (n int, err error) { + defer func() { + if err == ErrInterrupted { + t.Rollback() + } + }() + + if t.bufReader != nil { + n, err := t.bufReader.Read(out) + if err == io.EOF { + t.bufReader = nil + if n > 0 { + return n, nil + } else { + return t.tee.Read(out) + } + } + return n, err + } else { + return t.tee.Read(out) + } +} + +// Flushes all buffers past the current position in the reader to the specified writer. +func (t *TransactionalReader) FlushTo(writer io.Writer) { + n := 0 + if t.bufReader != nil { + n = t.bufReader.Len() + } + buf := make([]byte, n+len(t.reader.buffer)) + if n > 0 { + t.bufReader.Read(buf[:n]) + } + if len(buf[n:]) > 0 { + t.reader.Read(buf[n:]) + } + writer.Write(buf) + t.bufReader = nil + t.buffer.Reset() +} + +// Sets a checkpoint in the reader. A call to Rollback() will begin reading from this point. +// If offset is negative, the checkpoint will be set N bytes before the current position. +// If the offset is positive, the checkpoint will be set N bytes after the previous checkpoint. +// An offset of 0 will set the checkpoint to the current position. +func (t *TransactionalReader) Checkpoint(offset int) { + current := t.buffer.Len() + if t.bufReader != nil { + current = int(t.bufReader.Size()) - t.bufReader.Len() + } + + n := offset + if offset <= 0 { + n = current + offset + } + + if n >= t.buffer.Len() { + t.buffer.Reset() + } else { + t.buffer.Next(n) + } +} + +// Rolls back the reader to start from the last checkpoint. +func (t *TransactionalReader) Rollback() { + if t.buffer.Len() > 0 { + t.bufReader = bytes.NewReader(t.buffer.Bytes()) + } +} + +func (t *TransactionalReader) SetInterrupt(interrupt <-chan struct{}) { + t.reader.SetInterrupt(interrupt) +} diff --git a/stream/io_chan_test.go b/stream/io_chan_test.go index 9e140ef1..38933a7c 100644 --- a/stream/io_chan_test.go +++ b/stream/io_chan_test.go @@ -1,6 +1,7 @@ package stream import ( + "bufio" "bytes" "io" "testing" @@ -153,6 +154,55 @@ func TestMultiWriteWithCopy(t *testing.T) { } } +func TestMultiRead(t *testing.T) { + send := []byte("hello world") + c := make(chan *StreamChunk) + writer := NewChanWriter(c) + reader := NewChanReader(c) + passed := make(chan bool) + go func() { + writer.Write(send) + select { + case c <- &StreamChunk{[]byte("garbage"), time.Now()}: + case <-passed: + } + writer.Close() + }() + buf := make([]byte, len(send)) + + n, err := reader.Read(buf[:8]) + if n != 8 { + t.Fatalf("Read wrong number of bytes: %d expected 8", n) + } + if err != nil { + t.Fatal("Couldn't read from stream", err) + } + if !bytes.Equal(buf[:8], send[:8]) { + t.Fatal("Got wrong message from stream", string(buf[:8])) + } + time.Sleep(10 * time.Millisecond) + + n, err = reader.Read(buf[8:]) + if n != len(buf[8:]) { + t.Fatalf("Read wrong number of bytes: %d expected %d", n, len(buf[8:])) + } + if err != nil { + t.Fatal("Couldn't read from stream", err) + } + if !bytes.Equal(buf, send) { + t.Fatal("Got wrong message from stream", string(buf)) + } + + passed <- true + + n, err = reader.Read(buf) + if n != 0 { + t.Fatalf("Read from channel occured when it shouldn't have: %s", string(buf[:n])) + } else if err != io.EOF { + t.Fatal("Read returned wrong error after close:", err) + } +} + func TestReadInterrupt(t *testing.T) { send := []byte("hello world") c := make(chan *StreamChunk) @@ -199,3 +249,282 @@ func TestReadInterrupt(t *testing.T) { t.Fatal("Got wrong message from stream", string(buf)) } } + +func TestBlankWrite(t *testing.T) { + c := make(chan *StreamChunk, 2) + writer := NewChanWriter(c) + writer.Write([]byte{}) + writer.Write(nil) + writer.Close() + + for v := range c { + t.Fatalf("Unexpected write to channel: %+v", v) + } +} + +type TestReadWriter struct { + *TransactionalReader + *ChanWriter + input chan *StreamChunk + output chan *StreamChunk + closer chan struct{} +} + +func (c *TestReadWriter) Close() { + close(c.closer) +} + +func NewTestReadWriter() *TestReadWriter { + rw := &TestReadWriter{ + input: make(chan *StreamChunk, 2), + output: make(chan *StreamChunk, 1), + closer: make(chan struct{}), + } + rw.TransactionalReader = NewTransactionalReader(rw.input) + rw.ChanWriter = NewChanWriter(rw.output) + rw.input <- &StreamChunk{[]byte("hello world"), time.Now()} + rw.input <- &StreamChunk{[]byte("foobar"), time.Now()} + close(rw.input) + return rw +} + +func HandleError(t *testing.T, rw *TestReadWriter, err error, expectedErr error) { + if err != expectedErr { + t.Fatalf("Unexpected error during read: %v != %v", err, expectedErr) + } + if err == io.EOF { + rw.Rollback() + rw.FlushTo(rw) + rw.Close() + } +} + +func AssertRead(t *testing.T, rw *TestReadWriter, buf []byte, msg string, expectedErr error) { + n, err := rw.Read(buf) + + if n != len(msg) { + t.Fatalf("Read wrong number of bytes: %d expected %d (%s expected %s)", n, len(msg), string(buf[:n]), msg) + } + if !bytes.Equal(buf[:n], []byte(msg)) { + t.Fatalf("Got wrong message from stream: %s expected %s", string(buf[:n]), msg) + } + + HandleError(t, rw, err, expectedErr) +} + +func AssertClosed(t *testing.T, rw *TestReadWriter, expectedOutput []byte) { + select { + case msg := <-rw.output: + if expectedOutput == nil { + t.Fatal("Unexpected message written to output channel:", string(msg.Data)) + } else if !bytes.Equal(msg.Data, expectedOutput) { + t.Fatal("Wrong message written to output channel:", string(msg.Data), "expected", string(expectedOutput)) + } + default: + if expectedOutput != nil { + t.Fatal("Expected message to be written to output channel:", string(expectedOutput)) + } + } + + select { + case <-rw.closer: + default: + t.Fatal("Closer was not closed at end of stream") + } +} + +func TestReadWriterBasicFull(t *testing.T) { + rw := NewTestReadWriter() + buf := make([]byte, 32) + + AssertRead(t, rw, buf, "hello world", nil) + rw.Checkpoint(0) + AssertRead(t, rw, buf, "foobar", nil) + rw.Checkpoint(0) + AssertRead(t, rw, buf, "", io.EOF) + + AssertClosed(t, rw, nil) +} + +func TestReadWriterNoopRollback(t *testing.T) { + rw := NewTestReadWriter() + buf := make([]byte, 32) + + AssertRead(t, rw, buf, "hello world", nil) + AssertRead(t, rw, buf, "foobar", nil) + rw.Checkpoint(0) + rw.Rollback() + if rw.bufReader != nil { + t.Fatal("bufReader was set when it shouldn't have been") + } + AssertRead(t, rw, buf, "", io.EOF) + + AssertClosed(t, rw, nil) +} + +func TestReadWriterCheckpointRollback(t *testing.T) { + rw := NewTestReadWriter() + buf := make([]byte, 8) + + AssertRead(t, rw, buf, "hello wo", nil) + rw.Checkpoint(0) + AssertRead(t, rw, buf, "rldfooba", nil) + rw.Rollback() + AssertRead(t, rw, buf, "rldfooba", nil) + rw.Checkpoint(-3) + AssertRead(t, rw, buf, "r", nil) + rw.Rollback() + AssertRead(t, rw, buf, "obar", nil) + rw.Checkpoint(0) + AssertRead(t, rw, buf, "", io.EOF) + + AssertClosed(t, rw, nil) +} + +func TestReadWriterCheckpointMidBufReader(t *testing.T) { + rw := NewTestReadWriter() + buf := make([]byte, 8) + + AssertRead(t, rw, buf, "hello wo", nil) + AssertRead(t, rw, buf, "rldfooba", nil) + rw.Rollback() + AssertRead(t, rw, buf, "hello wo", nil) + rw.Checkpoint(0) + AssertRead(t, rw, buf, "rldfooba", nil) + rw.Rollback() + AssertRead(t, rw, buf, "rldfooba", nil) + AssertRead(t, rw, buf, "r", nil) + rw.Checkpoint(0) + AssertRead(t, rw, buf, "", io.EOF) + + AssertClosed(t, rw, nil) +} + +func TestReadWriterFlush(t *testing.T) { + rw := NewTestReadWriter() + buf := make([]byte, 8) + + AssertRead(t, rw, buf, "hello wo", nil) + rw.FlushTo(rw) + AssertRead(t, rw, buf, "foobar", nil) + rw.Checkpoint(0) + AssertRead(t, rw, buf, "", io.EOF) + + AssertClosed(t, rw, []byte("rld")) +} + +func TestReadWriterDoubleFlush(t *testing.T) { + rw := NewTestReadWriter() + buf := make([]byte, 8) + + AssertRead(t, rw, buf, "hello wo", nil) + rw.FlushTo(rw) + rw.FlushTo(rw) + AssertRead(t, rw, buf, "foobar", nil) + rw.FlushTo(rw) + AssertRead(t, rw, buf, "", io.EOF) + + AssertClosed(t, rw, []byte("rld")) +} + +func TestReadWriterNoCheckpoint(t *testing.T) { + rw := NewTestReadWriter() + buf := make([]byte, 32) + + AssertRead(t, rw, buf, "hello world", nil) + AssertRead(t, rw, buf, "foobar", nil) + AssertRead(t, rw, buf, "", io.EOF) + + AssertClosed(t, rw, []byte("hello worldfoobar")) +} + +func TestReadWriterInterrupt(t *testing.T) { + rw := &TestReadWriter{ + input: make(chan *StreamChunk, 1), + output: make(chan *StreamChunk, 1), + closer: make(chan struct{}), + } + rw.TransactionalReader = NewTransactionalReader(rw.input) + rw.ChanWriter = NewChanWriter(rw.output) + rw.input <- &StreamChunk{[]byte("hello world"), time.Now()} + + interrupt := make(chan struct{}) + rw.SetInterrupt(interrupt) + buf := make([]byte, 32) + sync := make(chan struct{}) + + AssertRead(t, rw, buf, "hello world", nil) + rw.Checkpoint(0) + go func() { + interrupt <- struct{}{} + rw.input <- &StreamChunk{[]byte("foobar"), time.Now()} + + <-sync + interrupt <- struct{}{} + close(rw.input) + }() + AssertRead(t, rw, buf, "", ErrInterrupted) + AssertRead(t, rw, buf, "foobar", nil) + + close(sync) + // Interrupt should rollback automatically + AssertRead(t, rw, buf, "", ErrInterrupted) + AssertRead(t, rw, buf, "foobar", nil) + rw.Checkpoint(0) + + AssertRead(t, rw, buf, "", io.EOF) + + AssertClosed(t, rw, nil) +} + +func TestReadWriterBufferedReader(t *testing.T) { + rw := NewTestReadWriter() + buf := make([]byte, 32) + reader := bufio.NewReader(rw) + msg, err := reader.ReadString(' ') + HandleError(t, rw, err, nil) + + if msg != "hello " { + t.Fatal("Buffered reader read wrong message:", msg) + } + if reader.Buffered() != 5 { + t.Fatal("Unexpected number of buffered bytes in reader:", reader.Buffered()) + } + + rw.Checkpoint(-reader.Buffered()) + rw.Rollback() + + AssertRead(t, rw, buf, "world", nil) + rw.Checkpoint(0) + AssertRead(t, rw, buf, "foobar", nil) + rw.Checkpoint(0) + AssertRead(t, rw, buf, "", io.EOF) + + AssertClosed(t, rw, nil) +} + +func TestReadWriterBufferedReaderAlternate(t *testing.T) { + rw := NewTestReadWriter() + buf := make([]byte, 32) + reader := bufio.NewReader(rw) + msg, err := reader.ReadString(' ') + HandleError(t, rw, err, nil) + + if msg != "hello " { + t.Fatal("Buffered reader read wrong message:", msg) + } + if reader.Buffered() != 5 { + t.Fatal("Unexpected number of buffered bytes in reader:", reader.Buffered()) + } + + rw.Checkpoint(len(msg)) + rw.Rollback() + + AssertRead(t, rw, buf, "world", nil) + rw.Checkpoint(0) + AssertRead(t, rw, buf, "foobar", nil) + rw.Checkpoint(0) + AssertRead(t, rw, buf, "", io.EOF) + + AssertClosed(t, rw, nil) +} diff --git a/toxics/toxic.go b/toxics/toxic.go index 342b7e07..fe004e8f 100644 --- a/toxics/toxic.go +++ b/toxics/toxic.go @@ -54,18 +54,24 @@ type ToxicStub struct { Input <-chan *stream.StreamChunk Output chan<- *stream.StreamChunk State interface{} + Reader *stream.TransactionalReader + Writer *stream.ChanWriter Interrupt chan struct{} running chan struct{} closed chan struct{} } func NewToxicStub(input <-chan *stream.StreamChunk, output chan<- *stream.StreamChunk) *ToxicStub { - return &ToxicStub{ + stub := &ToxicStub{ Interrupt: make(chan struct{}), closed: make(chan struct{}), Input: input, Output: output, + Reader: stream.NewTransactionalReader(input), + Writer: stream.NewChanWriter(output), } + stub.Reader.SetInterrupt(stub.Interrupt) + return stub } // Begin running a toxic on this stub, can be interrupted. @@ -93,6 +99,9 @@ func (s *ToxicStub) InterruptToxic() bool { } func (s *ToxicStub) Close() { + s.Reader.Rollback() + s.Reader.FlushTo(s.Writer) + close(s.closed) close(s.Output) }