From 42c3ec48b0ac85ed736cb8906489c9c53c448598 Mon Sep 17 00:00:00 2001 From: Michael Quigley Date: Thu, 12 Sep 2024 12:44:54 -0400 Subject: [PATCH] rough proctree implementation (#748) --- agent/proctree/impl_posix.go | 59 +++++++++++++++++++++++++ agent/proctree/impl_windows.go | 79 ++++++++++++++++++++++++++++++++++ agent/proctree/proctree.go | 66 ++++++++++++++++++++++++++++ 3 files changed, 204 insertions(+) create mode 100644 agent/proctree/impl_posix.go create mode 100755 agent/proctree/impl_windows.go create mode 100755 agent/proctree/proctree.go diff --git a/agent/proctree/impl_posix.go b/agent/proctree/impl_posix.go new file mode 100644 index 000000000..aa8596c56 --- /dev/null +++ b/agent/proctree/impl_posix.go @@ -0,0 +1,59 @@ +//go:build !windows + +package proctree + +import ( + "os/exec" + "sync" +) + +func Init(_ string) error { + return nil +} + +func StartChild(tail TailFunction, args ...string) (*Child, error) { + cmd := exec.Command(args[0], args[1:]...) + + cld := &Child{ + TailFunction: tail, + cmd: cmd, + outStream: make(chan []byte), + errStream: make(chan []byte), + wg: new(sync.WaitGroup), + } + + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, err + } + stderr, err := cmd.StderrPipe() + if err != nil { + return nil, err + } + + if err := cmd.Start(); err != nil { + return nil, err + } + + cld.wg.Add(3) + go reader(stdout, cld.outStream, cld.wg) + go reader(stderr, cld.errStream, cld.wg) + go cld.combiner(cld.wg) + + return cld, nil +} + +func WaitChild(c *Child) error { + c.wg.Wait() + if err := c.cmd.Wait(); err != nil { + return err + } + return nil +} + +func StopChild(c *Child) error { + if err := c.cmd.Process.Kill(); err != nil { + return err + } + return nil +} diff --git a/agent/proctree/impl_windows.go b/agent/proctree/impl_windows.go new file mode 100755 index 000000000..6a446fe8b --- /dev/null +++ b/agent/proctree/impl_windows.go @@ -0,0 +1,79 @@ +//go:build windows + +package proctree + +import ( + "github.com/kolesnikovae/go-winjob" + "golang.org/x/sys/windows" + "os/exec" + "sync" +) + +var job *winjob.JobObject + +func Init(name string) error { + var err error + if job == nil { + job, err = winjob.Create(name, winjob.LimitKillOnJobClose, winjob.LimitBreakawayOK) + if err != nil { + return err + } + } + return nil +} + +func StartChild(tail TailFunction, args ...string) (*Child, error) { + cmd := exec.Command(args[0], args[1:]...) + cmd.SysProcAttr = &windows.SysProcAttr{CreationFlags: windows.CREATE_SUSPENDED} + + cld := &Child{ + TailFunction: tail, + cmd: cmd, + outStream: make(chan []byte), + errStream: make(chan []byte), + wg: new(sync.WaitGroup), + } + + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, err + } + stderr, err := cmd.StderrPipe() + if err != nil { + return nil, err + } + + if err := cmd.Start(); err != nil { + return nil, err + } + + if err := job.Assign(cmd.Process); err != nil { + return nil, err + } + + if err := winjob.ResumeProcess(cmd.Process.Pid); err != nil { + return nil, err + } + + cld.wg.Add(3) + go reader(stdout, cld.outStream, cld.wg) + go reader(stderr, cld.errStream, cld.wg) + go cld.combiner(cld.wg) + + return cld, nil +} + +func WaitChild(c *Child) error { + c.wg.Wait() + if err := c.cmd.Wait(); err != nil { + return err + } + return nil +} + +func StopChild(c *Child) error { + if err := c.cmd.Process.Kill(); err != nil { + return err + } + return nil +} diff --git a/agent/proctree/proctree.go b/agent/proctree/proctree.go new file mode 100755 index 000000000..201517e7a --- /dev/null +++ b/agent/proctree/proctree.go @@ -0,0 +1,66 @@ +package proctree + +import ( + "fmt" + "io" + "os/exec" + "sync" +) + +type Child struct { + TailFunction TailFunction + cmd *exec.Cmd + outStream chan []byte + errStream chan []byte + wg *sync.WaitGroup +} + +type TailFunction func(data []byte) + +func (c *Child) combiner(wg *sync.WaitGroup) { + defer wg.Done() + + outDone := false + errDone := false + for { + select { + case data := <-c.outStream: + if data != nil { + if c.TailFunction != nil { + c.TailFunction(data) + } + } else { + outDone = true + } + case data := <-c.errStream: + if data != nil { + if c.TailFunction != nil { + c.TailFunction(data) + } + } else { + errDone = true + } + } + if outDone && errDone { + return + } + } +} + +func reader(r io.ReadCloser, o chan []byte, wg *sync.WaitGroup) { + defer close(o) + defer wg.Done() + + buf := make([]byte, 64*1024) + for { + n, err := r.Read(buf) + if err != nil { + if err == io.EOF { + return + } + fmt.Printf("error reading: %v", err) + return + } + o <- buf[:n] + } +}