Skip to content

Commit

Permalink
Support attaching to Stdin and TTY from multiple clients at the same …
Browse files Browse the repository at this point in the history
…time
  • Loading branch information
tomekjarosik committed Nov 28, 2024
1 parent 01b23b9 commit 2b68f7c
Show file tree
Hide file tree
Showing 5 changed files with 543 additions and 64 deletions.
28 changes: 28 additions & 0 deletions pkg/ctxio/multicontext.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package ctxio

import (
"context"
"sync"
)

// MultiContext returns a context that is canceled when any of the provided contexts are canceled.
func MultiContext(ctxs ...context.Context) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancel(context.Background())

var once sync.Once
for _, c := range ctxs {
c := c // Capture range variable
go func() {
select {
case <-c.Done():
once.Do(func() {
cancel()
})
case <-ctx.Done():
// The merged context was canceled elsewhere
}
}()
}

return ctx, cancel
}
5 changes: 5 additions & 0 deletions pkg/fugaci/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"github.com/macvmio/fugaci/pkg/ctxio"
"github.com/macvmio/fugaci/pkg/curie"
"github.com/macvmio/fugaci/pkg/portforwarder"
"github.com/macvmio/fugaci/pkg/sshrunner"
Expand Down Expand Up @@ -226,6 +227,8 @@ func (s *Provider) PortForward(ctx context.Context, namespace, podName string, p
if err != nil {
return fmt.Errorf("failed to find VM for pod %s/%s: %w", namespace, podName, err)
}
ctx, cancel := ctxio.MultiContext(ctx, vm.cmdLifetimeCtx)
defer cancel()
return vm.PortForward(ctx, port, stream)
}

Expand All @@ -234,6 +237,8 @@ func (s *Provider) AttachToContainer(ctx context.Context, namespace, podName, co
if err != nil {
return fmt.Errorf("failed to find VM for pod %s/%s: %w", namespace, podName, err)
}
ctx, cancel := ctxio.MultiContext(ctx, vm.cmdLifetimeCtx)
defer cancel()
return vm.AttachToContainer(ctx, attach)
}

Expand Down
12 changes: 7 additions & 5 deletions pkg/fugaci/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,12 +373,14 @@ func (s *VM) Run() {

err = s.waitAndRunCommandInside(s.cmdLifetimeCtx, startedAt.Time, containerID)

s.wg.Add(1)
// This needs to be done on separate thread, because otherwise will result in defunct process,
// and Stop() method will keep running
s.wg.Add(1)
go s.stopContainer(containerID, startedAt.Time)

err = runCmd.Wait()
s.cmdCancelFunc(nil)

s.storyLine.Add("container_exitcode", runCmd.ProcessState.ExitCode())
s.storyLine.Add("container_process_state", runCmd.ProcessState)
if err != nil && runCmd.ProcessState.ExitCode() != 0 {
Expand All @@ -388,6 +390,10 @@ func (s *VM) Run() {
return
}

err = s.streams.Close()
if err != nil {
s.storyLine.Add("streamsClosingErr", err)
}
s.logger.Printf("container '%v' finished successfully: %v, exit code=%d\n", containerID, runCmd, runCmd.ProcessState.ExitCode())
s.safeUpdateState(v1.ContainerState{Terminated: &v1.ContainerStateTerminated{
ExitCode: int32(runCmd.ProcessState.ExitCode()),
Expand Down Expand Up @@ -487,10 +493,6 @@ func (s *VM) Cleanup() error {
}
return err2
}
err = s.streams.Cleanup()
if err != nil {
s.storyLine.Add("streamsCleanUpErr", err)
}
return nil
}

Expand Down
140 changes: 81 additions & 59 deletions pkg/streams/streams.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"github.com/macvmio/fugaci/pkg/ctxio"
"github.com/virtual-kubelet/virtual-kubelet/node/api"
"golang.org/x/sync/errgroup"
"io"
"os"
"sync"
Expand Down Expand Up @@ -74,38 +75,32 @@ func (f *FilesBasedStreams) Resize() <-chan api.TermSize {
return f.termSizeCh
}

// Cleanup removes the temporary files created for stdin, stdout, and stderr.
// It is safe to call Cleanup multiple times concurrently.
func (f *FilesBasedStreams) Cleanup() error {
// Close removes the temporary files created for stdin, stdout, and stderr.
// It is safe to call Close multiple times concurrently.
func (f *FilesBasedStreams) Close() error {
var errs []error
// Wait for any ongoing operations to finish
f.cleanupWG.Wait()

f.cleanOnce.Do(func() {
f.mu.Lock()
defer f.mu.Unlock()

// Wait for any ongoing operations to finish
f.cleanupWG.Wait()

if err := f.stdoutFile.Close(); err != nil {
errs = append(errs, fmt.Errorf("failed to close stdout file: %w", err))
}
if err := os.Remove(f.stdoutFile.Name()); err != nil {
errs = append(errs, fmt.Errorf("failed to remove stdout file: %w", err))
}
if err := f.stdoutFile.Close(); err != nil {
errs = append(errs, fmt.Errorf("failed to close stdout file: %w", err))
}
if err := os.Remove(f.stdoutFile.Name()); err != nil {
errs = append(errs, fmt.Errorf("failed to remove stdout file: %w", err))
}

if err := f.stderrFile.Close(); err != nil {
errs = append(errs, fmt.Errorf("failed to close stderr file: %w", err))
}
if err := os.Remove(f.stderrFile.Name()); err != nil {
errs = append(errs, fmt.Errorf("failed to remove stderr file: %w", err))
}
if err := f.stderrFile.Close(); err != nil {
errs = append(errs, fmt.Errorf("failed to close stderr file: %w", err))
}
if err := os.Remove(f.stderrFile.Name()); err != nil {
errs = append(errs, fmt.Errorf("failed to remove stderr file: %w", err))
}

// Close termSizeCh
if f.termSizeCh != nil {
close(f.termSizeCh)
f.termSizeCh = nil
}
})
// Close termSizeCh
if f.termSizeCh != nil {
close(f.termSizeCh)
f.termSizeCh = nil
}

if len(errs) > 0 {
return fmt.Errorf("cleanup encountered errors: %v", errs)
Expand All @@ -114,66 +109,93 @@ func (f *FilesBasedStreams) Cleanup() error {
}

func (f *FilesBasedStreams) Stream(ctx context.Context, attach api.AttachIO, loggerPrintf func(format string, v ...any)) error {
f.cleanupWG.Add(1)
allowableError := func(err error) bool {
if err == nil {
return true
}
return errors.Is(err, context.Canceled) || errors.Is(err, io.EOF)
return err == nil || errors.Is(err, context.Canceled) || errors.Is(err, io.EOF)
}
go func() {
defer f.cleanupWG.Done()
// Start streaming stdout
if attach.Stdout() != nil {
if err := followFileStream(ctx, attach.Stdout(), f.stdoutFile.Name(), loggerPrintf); !allowableError(err) {

// Create an errgroup with the provided context
eg, ctx := errgroup.WithContext(ctx)

// Start streaming stdout
if attach.Stdout() != nil {
f.cleanupWG.Add(1)
eg.Go(func() error {
defer f.cleanupWG.Done()
err := followFileStream(ctx, attach.Stdout(), f.stdoutFile.Name(), loggerPrintf)
if !allowableError(err) {
loggerPrintf("Error streaming stdout: %v", err)
return err
}
}
}()

f.cleanupWG.Add(1)
go func() {
defer f.cleanupWG.Done()
// Start streaming stderr
if attach.Stderr() != nil {
if err := followFileStream(ctx, attach.Stderr(), f.stderrFile.Name(), loggerPrintf); !allowableError(err) {
loggerPrintf("stdout copy completed")
return nil
})
}

// Start streaming stderr
if attach.Stderr() != nil {
f.cleanupWG.Add(1)
eg.Go(func() error {
defer f.cleanupWG.Done()
err := followFileStream(ctx, attach.Stderr(), f.stderrFile.Name(), loggerPrintf)
if !allowableError(err) {
loggerPrintf("Error streaming stderr: %v", err)
return err
}
}
}()
loggerPrintf("stderr copy completed")
return nil
})
}

// Handle stdin
if f.stdinWriter != nil && attach.Stdin() != nil {
f.cleanupWG.Add(1)
go func() {
defer f.cleanupWG.Done()
// TODO: This blocks until if stdin has no data, even if context is cancelled
_, err := io.Copy(f.stdinWriter, attach.Stdin())
if !allowableError(err) {
loggerPrintf("Error streaming stdin: %v", err)
}
loggerPrintf("stdin copy completed")
}()
}

// Handle TTY resize events
if attach.TTY() {
f.cleanupWG.Add(1)
go func() {
eg.Go(func() error {
defer f.cleanupWG.Done()
for termSize := range attach.Resize() {
f.termSizeCh <- termSize
defer loggerPrintf("attach tty channel completed")
for {
select {
case termSize, ok := <-attach.Resize():
if !ok {
// The attach.Resize() channel is closed
return nil
}
select {
case f.termSizeCh <- termSize:
case <-ctx.Done():
return ctx.Err()
}
case <-ctx.Done():
return ctx.Err()
}
}
}()
})
}
// Wait for context cancellation
loggerPrintf("waiting for Stream to finish")
<-ctx.Done()
loggerPrintf("Stream has completed")
return nil

loggerPrintf("waiting for Stream() to finish")
err := eg.Wait()
loggerPrintf("Stream() has completed")
return err
}

func followFileStream(ctx context.Context, writer io.Writer, filename string, loggerPrintf func(format string, v ...any)) error {
func followFileStream(ctx context.Context, writer io.WriteCloser, filename string, loggerPrintf func(format string, v ...any)) error {
if writer == nil {
return fmt.Errorf("writer cannot be nil")
}
defer writer.Close()

tailReader, err := ctxio.NewTailReader(ctx, filename)
if err != nil {
Expand Down
Loading

0 comments on commit 2b68f7c

Please sign in to comment.