diff --git a/pkg/ctxio/multicontext.go b/pkg/ctxio/multicontext.go new file mode 100644 index 0000000..a8309f6 --- /dev/null +++ b/pkg/ctxio/multicontext.go @@ -0,0 +1,28 @@ +package ctxio + +import ( + "context" + "sync" +) + +// MultiContext returns a context that is canceled when any of the provided contexts are canceled. +func MultiContext(ctxs ...context.Context) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(context.Background()) + + var once sync.Once + for _, c := range ctxs { + c := c // Capture range variable + go func() { + select { + case <-c.Done(): + once.Do(func() { + cancel() + }) + case <-ctx.Done(): + // The merged context was canceled elsewhere + } + }() + } + + return ctx, cancel +} diff --git a/pkg/fugaci/provider.go b/pkg/fugaci/provider.go index 1f84477..b659e55 100644 --- a/pkg/fugaci/provider.go +++ b/pkg/fugaci/provider.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "github.com/macvmio/fugaci/pkg/ctxio" "github.com/macvmio/fugaci/pkg/curie" "github.com/macvmio/fugaci/pkg/portforwarder" "github.com/macvmio/fugaci/pkg/sshrunner" @@ -226,6 +227,8 @@ func (s *Provider) PortForward(ctx context.Context, namespace, podName string, p if err != nil { return fmt.Errorf("failed to find VM for pod %s/%s: %w", namespace, podName, err) } + ctx, cancel := ctxio.MultiContext(ctx, vm.cmdLifetimeCtx) + defer cancel() return vm.PortForward(ctx, port, stream) } @@ -234,6 +237,8 @@ func (s *Provider) AttachToContainer(ctx context.Context, namespace, podName, co if err != nil { return fmt.Errorf("failed to find VM for pod %s/%s: %w", namespace, podName, err) } + ctx, cancel := ctxio.MultiContext(ctx, vm.cmdLifetimeCtx) + defer cancel() return vm.AttachToContainer(ctx, attach) } diff --git a/pkg/fugaci/vm.go b/pkg/fugaci/vm.go index d3208f2..6b4133e 100644 --- a/pkg/fugaci/vm.go +++ b/pkg/fugaci/vm.go @@ -373,12 +373,14 @@ func (s *VM) Run() { err = s.waitAndRunCommandInside(s.cmdLifetimeCtx, startedAt.Time, containerID) - s.wg.Add(1) // This needs to be done on separate thread, because otherwise will result in defunct process, // and Stop() method will keep running + s.wg.Add(1) go s.stopContainer(containerID, startedAt.Time) err = runCmd.Wait() + s.cmdCancelFunc(nil) + s.storyLine.Add("container_exitcode", runCmd.ProcessState.ExitCode()) s.storyLine.Add("container_process_state", runCmd.ProcessState) if err != nil && runCmd.ProcessState.ExitCode() != 0 { @@ -388,6 +390,10 @@ func (s *VM) Run() { return } + err = s.streams.Close() + if err != nil { + s.storyLine.Add("streamsClosingErr", err) + } s.logger.Printf("container '%v' finished successfully: %v, exit code=%d\n", containerID, runCmd, runCmd.ProcessState.ExitCode()) s.safeUpdateState(v1.ContainerState{Terminated: &v1.ContainerStateTerminated{ ExitCode: int32(runCmd.ProcessState.ExitCode()), @@ -487,10 +493,6 @@ func (s *VM) Cleanup() error { } return err2 } - err = s.streams.Cleanup() - if err != nil { - s.storyLine.Add("streamsCleanUpErr", err) - } return nil } diff --git a/pkg/streams/streams.go b/pkg/streams/streams.go index 4947536..875265d 100644 --- a/pkg/streams/streams.go +++ b/pkg/streams/streams.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/macvmio/fugaci/pkg/ctxio" "github.com/virtual-kubelet/virtual-kubelet/node/api" + "golang.org/x/sync/errgroup" "io" "os" "sync" @@ -74,38 +75,32 @@ func (f *FilesBasedStreams) Resize() <-chan api.TermSize { return f.termSizeCh } -// Cleanup removes the temporary files created for stdin, stdout, and stderr. -// It is safe to call Cleanup multiple times concurrently. -func (f *FilesBasedStreams) Cleanup() error { +// Close removes the temporary files created for stdin, stdout, and stderr. +// It is safe to call Close multiple times concurrently. +func (f *FilesBasedStreams) Close() error { var errs []error + // Wait for any ongoing operations to finish + f.cleanupWG.Wait() - f.cleanOnce.Do(func() { - f.mu.Lock() - defer f.mu.Unlock() - - // Wait for any ongoing operations to finish - f.cleanupWG.Wait() - - if err := f.stdoutFile.Close(); err != nil { - errs = append(errs, fmt.Errorf("failed to close stdout file: %w", err)) - } - if err := os.Remove(f.stdoutFile.Name()); err != nil { - errs = append(errs, fmt.Errorf("failed to remove stdout file: %w", err)) - } + if err := f.stdoutFile.Close(); err != nil { + errs = append(errs, fmt.Errorf("failed to close stdout file: %w", err)) + } + if err := os.Remove(f.stdoutFile.Name()); err != nil { + errs = append(errs, fmt.Errorf("failed to remove stdout file: %w", err)) + } - if err := f.stderrFile.Close(); err != nil { - errs = append(errs, fmt.Errorf("failed to close stderr file: %w", err)) - } - if err := os.Remove(f.stderrFile.Name()); err != nil { - errs = append(errs, fmt.Errorf("failed to remove stderr file: %w", err)) - } + if err := f.stderrFile.Close(); err != nil { + errs = append(errs, fmt.Errorf("failed to close stderr file: %w", err)) + } + if err := os.Remove(f.stderrFile.Name()); err != nil { + errs = append(errs, fmt.Errorf("failed to remove stderr file: %w", err)) + } - // Close termSizeCh - if f.termSizeCh != nil { - close(f.termSizeCh) - f.termSizeCh = nil - } - }) + // Close termSizeCh + if f.termSizeCh != nil { + close(f.termSizeCh) + f.termSizeCh = nil + } if len(errs) > 0 { return fmt.Errorf("cleanup encountered errors: %v", errs) @@ -114,66 +109,93 @@ func (f *FilesBasedStreams) Cleanup() error { } func (f *FilesBasedStreams) Stream(ctx context.Context, attach api.AttachIO, loggerPrintf func(format string, v ...any)) error { - f.cleanupWG.Add(1) allowableError := func(err error) bool { - if err == nil { - return true - } - return errors.Is(err, context.Canceled) || errors.Is(err, io.EOF) + return err == nil || errors.Is(err, context.Canceled) || errors.Is(err, io.EOF) } - go func() { - defer f.cleanupWG.Done() - // Start streaming stdout - if attach.Stdout() != nil { - if err := followFileStream(ctx, attach.Stdout(), f.stdoutFile.Name(), loggerPrintf); !allowableError(err) { + + // Create an errgroup with the provided context + eg, ctx := errgroup.WithContext(ctx) + + // Start streaming stdout + if attach.Stdout() != nil { + f.cleanupWG.Add(1) + eg.Go(func() error { + defer f.cleanupWG.Done() + err := followFileStream(ctx, attach.Stdout(), f.stdoutFile.Name(), loggerPrintf) + if !allowableError(err) { loggerPrintf("Error streaming stdout: %v", err) + return err } - } - }() - - f.cleanupWG.Add(1) - go func() { - defer f.cleanupWG.Done() - // Start streaming stderr - if attach.Stderr() != nil { - if err := followFileStream(ctx, attach.Stderr(), f.stderrFile.Name(), loggerPrintf); !allowableError(err) { + loggerPrintf("stdout copy completed") + return nil + }) + } + + // Start streaming stderr + if attach.Stderr() != nil { + f.cleanupWG.Add(1) + eg.Go(func() error { + defer f.cleanupWG.Done() + err := followFileStream(ctx, attach.Stderr(), f.stderrFile.Name(), loggerPrintf) + if !allowableError(err) { loggerPrintf("Error streaming stderr: %v", err) + return err } - } - }() + loggerPrintf("stderr copy completed") + return nil + }) + } // Handle stdin if f.stdinWriter != nil && attach.Stdin() != nil { f.cleanupWG.Add(1) go func() { defer f.cleanupWG.Done() + // TODO: This blocks until if stdin has no data, even if context is cancelled _, err := io.Copy(f.stdinWriter, attach.Stdin()) if !allowableError(err) { loggerPrintf("Error streaming stdin: %v", err) } + loggerPrintf("stdin copy completed") }() } + // Handle TTY resize events if attach.TTY() { f.cleanupWG.Add(1) - go func() { + eg.Go(func() error { defer f.cleanupWG.Done() - for termSize := range attach.Resize() { - f.termSizeCh <- termSize + defer loggerPrintf("attach tty channel completed") + for { + select { + case termSize, ok := <-attach.Resize(): + if !ok { + // The attach.Resize() channel is closed + return nil + } + select { + case f.termSizeCh <- termSize: + case <-ctx.Done(): + return ctx.Err() + } + case <-ctx.Done(): + return ctx.Err() + } } - }() + }) } - // Wait for context cancellation - loggerPrintf("waiting for Stream to finish") - <-ctx.Done() - loggerPrintf("Stream has completed") - return nil + + loggerPrintf("waiting for Stream() to finish") + err := eg.Wait() + loggerPrintf("Stream() has completed") + return err } -func followFileStream(ctx context.Context, writer io.Writer, filename string, loggerPrintf func(format string, v ...any)) error { +func followFileStream(ctx context.Context, writer io.WriteCloser, filename string, loggerPrintf func(format string, v ...any)) error { if writer == nil { return fmt.Errorf("writer cannot be nil") } + defer writer.Close() tailReader, err := ctxio.NewTailReader(ctx, filename) if err != nil { diff --git a/pkg/streams/streams_test.go b/pkg/streams/streams_test.go new file mode 100644 index 0000000..5b64531 --- /dev/null +++ b/pkg/streams/streams_test.go @@ -0,0 +1,422 @@ +package streams + +import ( + "bytes" + "context" + "fmt" + "io" + "os" + "strings" + "sync" + "testing" + "time" + + "github.com/virtual-kubelet/virtual-kubelet/node/api" +) + +func setupFilesBasedStreams(t *testing.T, allocateStdin, allocateTTY bool) *FilesBasedStreams { + t.Helper() + fbs, err := NewFilesBasedStreams(t.TempDir(), "test", allocateStdin, allocateTTY) + if err != nil { + t.Fatalf("Failed to create FilesBasedStreams: %v", err) + } + return fbs +} + +func teardownFilesBasedStreams(t *testing.T, fbs *FilesBasedStreams) { + t.Helper() + if err := fbs.Close(); err != nil { + t.Errorf("Failed to close FilesBasedStreams: %v", err) + } +} + +func TestNewFilesBasedStreams(t *testing.T) { + fbs := setupFilesBasedStreams(t, true, true) + defer teardownFilesBasedStreams(t, fbs) + + if fbs.stdoutFile == nil { + t.Error("stdoutFile should not be nil") + } + if fbs.stderrFile == nil { + t.Error("stderrFile should not be nil") + } + if fbs.stdinReader == nil || fbs.stdinWriter == nil { + t.Error("stdinReader and stdinWriter should not be nil when allocateStdin is true") + } + if fbs.allocateTTY && fbs.termSizeCh == nil { + t.Error("termSizeCh should not be nil when allocateTTY is true") + } +} + +func TestStreamStdout(t *testing.T) { + fbs := setupFilesBasedStreams(t, false, false) + defer teardownFilesBasedStreams(t, fbs) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Prepare mock attachIO + stdoutBuf := &bytes.Buffer{} + attachIO := &MockAttachIO{ + stdout: stdoutBuf, + } + + // Write data to stdoutFile + expectedOutput := "Hello, stdout!" + _, err := fbs.stdoutFile.WriteString(expectedOutput + "\n") + if err != nil { + t.Fatalf("Failed to write to stdoutFile: %v", err) + } + + // Start streaming + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + err := fbs.Stream(ctx, attachIO, t.Logf) + if err != nil { + t.Errorf("Stream returned error: %v", err) + } + }() + + // Give some time for the data to be streamed + time.Sleep(25 * time.Millisecond) + + // Cancel context to stop streaming + cancel() + wg.Wait() + + // Verify that data was received + output := stdoutBuf.String() + if !strings.Contains(output, expectedOutput) { + t.Errorf("Expected output %q in stdout, got %q", expectedOutput, output) + } +} + +func TestStreamStderr(t *testing.T) { + fbs := setupFilesBasedStreams(t, false, false) + defer teardownFilesBasedStreams(t, fbs) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Prepare mock attachIO + stderrBuf := &bytes.Buffer{} + attachIO := &MockAttachIO{ + stderr: stderrBuf, + } + // Write data to stderrFile + expectedOutput := "Hello, stderr!" + _, err := fbs.stderrFile.WriteString(expectedOutput + "\n") + if err != nil { + t.Fatalf("Failed to write to stderrFile: %v", err) + } + + // Start streaming + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + err := fbs.Stream(ctx, attachIO, t.Logf) + if err != nil { + t.Errorf("Stream returned error: %v", err) + } + }() + + // Give some time for the data to be streamed + time.Sleep(25 * time.Millisecond) + + // Cancel context to stop streaming + cancel() + wg.Wait() + + // Verify that data was received + output := stderrBuf.String() + if !strings.Contains(output, expectedOutput) { + t.Errorf("Expected output %q in stderr, got %q", expectedOutput, output) + } + +} + +func TestStreamStdin(t *testing.T) { + fbs := setupFilesBasedStreams(t, true, false) + defer teardownFilesBasedStreams(t, fbs) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Prepare mock attachIO + stdinData := "Hello, stdin!\n" + stdinBuf := bytes.NewBufferString(stdinData) + attachIO := &MockAttachIO{ + stdin: stdinBuf, + } + + // Start streaming + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + err := fbs.Stream(ctx, attachIO, t.Logf) + if err != nil { + t.Errorf("Stream returned error: %v", err) + } + }() + + // Read data from stdinReader + receivedData := make([]byte, len(stdinData)) + _, err := io.ReadFull(fbs.stdinReader, receivedData) + if err != nil { + t.Fatalf("Failed to read from stdinReader: %v", err) + } + + // Verify that data matches + if string(receivedData) != stdinData { + t.Errorf("Expected input %q, got %q", stdinData, string(receivedData)) + } + + // Cancel context to stop streaming + cancel() + wg.Wait() +} + +func TestContextCancellation(t *testing.T) { + fbs := setupFilesBasedStreams(t, true, false) + defer teardownFilesBasedStreams(t, fbs) + + ctx, cancel := context.WithCancel(context.Background()) + + // Prepare mock attachIO + attachIO := &MockAttachIO{ + stdin: &bytes.Buffer{}, + stdout: &bytes.Buffer{}, + stderr: &bytes.Buffer{}, + } + + // Start streaming + doneCh := make(chan struct{}) + go func() { + defer close(doneCh) + err := fbs.Stream(ctx, attachIO, t.Logf) + if err != nil { + t.Errorf("Stream returned error: %v", err) + } + }() + + // Cancel context after a short delay + time.Sleep(25 * time.Millisecond) + cancel() + + // Wait for Stream to exit + select { + case <-doneCh: + // Success + case <-time.After(1 * time.Second): + t.Error("Stream did not exit after context cancellation") + } +} + +func TestClose(t *testing.T) { + fbs := setupFilesBasedStreams(t, true, false) + + // Start some operation to ensure cleanup is necessary + fbs.cleanupWG.Add(1) + go func() { + defer fbs.cleanupWG.Done() + time.Sleep(25 * time.Millisecond) + }() + + err := fbs.Close() + if err != nil { + t.Errorf("Close returned error: %v", err) + } + + // Verify that files are closed and removed + if _, err := os.Stat(fbs.stdoutFile.Name()); !os.IsNotExist(err) { + t.Errorf("stdoutFile was not removed") + } + if _, err := os.Stat(fbs.stderrFile.Name()); !os.IsNotExist(err) { + t.Errorf("stderrFile was not removed") + } +} + +func TestTTYResizeEvents(t *testing.T) { + fbs := setupFilesBasedStreams(t, false, true) + defer teardownFilesBasedStreams(t, fbs) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Prepare mock attachIO + resizeCh := make(chan api.TermSize, 1) + attachIO := &MockAttachIO{ + stdout: &bytes.Buffer{}, + resizeCh: resizeCh, + tty: true, + } + + // Start streaming + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + err := fbs.Stream(ctx, attachIO, t.Logf) + if err != nil { + t.Errorf("Stream returned error: %v", err) + } + }() + + // Send a resize event + expectedSize := api.TermSize{Width: 80, Height: 24} + resizeCh <- expectedSize + + // Receive the resize event + select { + case receivedSize, ok := <-fbs.Resize(): + if !ok { + t.Error("termSizeCh was closed unexpectedly") + } + if receivedSize != expectedSize { + t.Errorf("Expected term size %v, got %v", expectedSize, receivedSize) + } + case <-time.After(1 * time.Second): + t.Error("Did not receive term size event") + } + + // Close resize channel to simulate end of events + close(resizeCh) + + // Cancel context to stop streaming + cancel() + wg.Wait() +} + +func TestGoroutinesExit(t *testing.T) { + fbs := setupFilesBasedStreams(t, false, false) + defer teardownFilesBasedStreams(t, fbs) + + ctx, cancel := context.WithCancel(context.Background()) + + // Prepare mock attachIO + attachIO := &MockAttachIO{ + stdout: &bytes.Buffer{}, + stderr: &bytes.Buffer{}, + } + + // Start streaming + go func() { + _ = fbs.Stream(ctx, attachIO, t.Logf) + }() + + // Cancel context to stop streaming + cancel() + // Will be closed by teardown +} + +func TestStreamStdoutError(t *testing.T) { + fbs := setupFilesBasedStreams(t, false, false) + defer teardownFilesBasedStreams(t, fbs) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Prepare mock attachIO with an ErrorWriter + errWriter := &ErrorWriter{Err: fmt.Errorf("write error")} + attachIO := &MockAttachIO{ + stdout: errWriter, + stderr: &bytes.Buffer{}, + } + + _, err := fbs.stdoutFile.Write([]byte("Hello, stdout!")) + if err != nil { + t.Errorf("Failed to write to stdoutFile: %v", err) + } + go func() { + time.Sleep(25 * time.Millisecond) + cancel() + }() + // Start streaming + err = fbs.Stream(ctx, attachIO, t.Logf) + if err == nil { + t.Error("Expected error from Stream, got nil") + } +} + +func TestStreamStderrError(t *testing.T) { + fbs := setupFilesBasedStreams(t, false, false) + defer teardownFilesBasedStreams(t, fbs) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Prepare mock attachIO with an ErrorWriter + errWriter := &ErrorWriter{Err: fmt.Errorf("write error")} + attachIO := &MockAttachIO{ + stdout: &bytes.Buffer{}, + stderr: errWriter, + } + + _, err := fbs.stderrFile.Write([]byte("Hello, stdout!")) + if err != nil { + t.Errorf("Failed to write to stdoutFile: %v", err) + } + go func() { + time.Sleep(25 * time.Millisecond) + cancel() + }() + // Start streaming + err = fbs.Stream(ctx, attachIO, t.Logf) + if err == nil { + t.Error("Expected error from Stream, got nil") + } +} + +// MockAttachIO implements api.AttachIO for testing purposes +type MockAttachIO struct { + stdin io.Reader + stdout io.Writer + stderr io.Writer + resizeCh chan api.TermSize + tty bool +} + +func (m *MockAttachIO) Stdin() io.Reader { + return m.stdin +} + +func (m *MockAttachIO) Stdout() io.WriteCloser { + if wc, ok := m.stdout.(io.WriteCloser); ok { + return wc + } + return nopWriteCloser{m.stdout} +} + +func (m *MockAttachIO) Stderr() io.WriteCloser { + if wc, ok := m.stderr.(io.WriteCloser); ok { + return wc + } + return nopWriteCloser{m.stderr} +} + +func (m *MockAttachIO) TTY() bool { + return m.tty +} + +func (m *MockAttachIO) Resize() <-chan api.TermSize { + return m.resizeCh +} + +type nopWriteCloser struct { + io.Writer +} + +func (nopWriteCloser) Close() error { return nil } + +// ErrorWriter is an io.Writer that returns an error on Write +type ErrorWriter struct { + Err error +} + +func (w *ErrorWriter) Write(p []byte) (int, error) { + return 0, w.Err +}