Skip to content

Commit

Permalink
Simplify error handling in Stream() to ensure immediate disconnect
Browse files Browse the repository at this point in the history
  • Loading branch information
tomekjarosik committed Nov 28, 2024
1 parent 9dce3b5 commit d462238
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 23 deletions.
21 changes: 3 additions & 18 deletions pkg/streams/streams.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,6 @@ func (f *FilesBasedStreams) Close() error {
}

func (f *FilesBasedStreams) Stream(ctx context.Context, attach api.AttachIO, loggerPrintf func(format string, v ...any)) error {
allowableError := func(err error) bool {
return err == nil || errors.Is(err, context.Canceled) || errors.Is(err, io.EOF)
}

// Create an errgroup with the provided context
eg, ctx := errgroup.WithContext(ctx)

Expand All @@ -122,12 +118,8 @@ func (f *FilesBasedStreams) Stream(ctx context.Context, attach api.AttachIO, log
eg.Go(func() error {
defer f.cleanupWG.Done()
err := followFileStream(ctx, attach.Stdout(), f.stdoutFile.Name(), loggerPrintf)
if !allowableError(err) {
loggerPrintf("Error streaming stdout: %v", err)
return err
}
loggerPrintf("stdout copy completed")
return nil
return err
})
}

Expand All @@ -137,12 +129,8 @@ func (f *FilesBasedStreams) Stream(ctx context.Context, attach api.AttachIO, log
eg.Go(func() error {
defer f.cleanupWG.Done()
err := followFileStream(ctx, attach.Stderr(), f.stderrFile.Name(), loggerPrintf)
if !allowableError(err) {
loggerPrintf("Error streaming stderr: %v", err)
return err
}
loggerPrintf("stderr copy completed")
return nil
return err
})
}

Expand All @@ -153,10 +141,7 @@ func (f *FilesBasedStreams) Stream(ctx context.Context, attach api.AttachIO, log
defer f.cleanupWG.Done()
// TODO: This blocks until if stdin has no data, even if context is cancelled
_, err := io.Copy(f.stdinWriter, attach.Stdin())
if !allowableError(err) {
loggerPrintf("Error streaming stdin: %v", err)
}
loggerPrintf("stdin copy completed")
loggerPrintf("stdin copy completed: %v", err)
}()
}

Expand Down
11 changes: 6 additions & 5 deletions pkg/streams/streams_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package streams
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"os"
Expand Down Expand Up @@ -74,7 +75,7 @@ func TestStreamStdout(t *testing.T) {
go func() {
defer wg.Done()
err := fbs.Stream(ctx, attachIO, t.Logf)
if err != nil {
if !errors.Is(err, context.Canceled) {
t.Errorf("Stream returned error: %v", err)
}
}()
Expand Down Expand Up @@ -118,7 +119,7 @@ func TestStreamStderr(t *testing.T) {
go func() {
defer wg.Done()
err := fbs.Stream(ctx, attachIO, t.Logf)
if err != nil {
if !errors.Is(err, context.Canceled) {
t.Errorf("Stream returned error: %v", err)
}
}()
Expand Down Expand Up @@ -158,7 +159,7 @@ func TestStreamStdin(t *testing.T) {
go func() {
defer wg.Done()
err := fbs.Stream(ctx, attachIO, t.Logf)
if err != nil {
if !errors.Is(err, context.Canceled) {
t.Errorf("Stream returned error: %v", err)
}
}()
Expand Down Expand Up @@ -198,7 +199,7 @@ func TestContextCancellation(t *testing.T) {
go func() {
defer close(doneCh)
err := fbs.Stream(ctx, attachIO, t.Logf)
if err != nil {
if !errors.Is(err, context.Canceled) {
t.Errorf("Stream returned error: %v", err)
}
}()
Expand Down Expand Up @@ -261,7 +262,7 @@ func TestTTYResizeEvents(t *testing.T) {
go func() {
defer wg.Done()
err := fbs.Stream(ctx, attachIO, t.Logf)
if err != nil {
if !errors.Is(err, context.Canceled) {
t.Errorf("Stream returned error: %v", err)
}
}()
Expand Down

0 comments on commit d462238

Please sign in to comment.