From ec72c6f43e695749cf5e204b816d896e5097abf1 Mon Sep 17 00:00:00 2001 From: James Pickett Date: Fri, 24 Jan 2025 09:17:00 -0800 Subject: [PATCH] add tracing to all allowed cmds (#2061) --- .golangci.yml | 2 + ee/allowedcmd/cmd.go | 50 ++++++++++++++++-- ee/allowedcmd/cmd_darwin.go | 67 ++++++++++++------------ ee/allowedcmd/cmd_linux.go | 75 +++++++++++++-------------- ee/allowedcmd/cmd_test.go | 14 ++--- ee/allowedcmd/cmd_windows.go | 23 ++++---- ee/debug/checkups/network.go | 2 +- ee/debug/checkups/power_windows.go | 4 +- ee/debug/checkups/services_windows.go | 4 +- ee/tables/pwpolicy/pwpolicy_test.go | 10 ++-- ee/tables/tablehelpers/exec.go | 2 +- 11 files changed, 149 insertions(+), 104 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 2e867cbd0..80cd604f4 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -41,6 +41,8 @@ linters-settings: msg: do not use panic so that launcher can shut down gracefully - p: ^go func.*$ msg: use gowrapper.Go() instead of raw goroutines for proper panic handling + - p: \.Cmd\.(Run|Start|Output|CombinedOutput) + msg: "Do not call embedded exec.Cmd methods directly on TracedCmd; call TracedCmd.Run(), TracedCmd.Start(), etc. instead" sloglint: kv-only: true context: "all" diff --git a/ee/allowedcmd/cmd.go b/ee/allowedcmd/cmd.go index cf7e381a2..d31f39886 100644 --- a/ee/allowedcmd/cmd.go +++ b/ee/allowedcmd/cmd.go @@ -13,17 +13,59 @@ import ( "os" "os/exec" "path/filepath" + + "github.com/kolide/launcher/pkg/traces" ) -type AllowedCommand func(ctx context.Context, arg ...string) (*exec.Cmd, error) +type AllowedCommand func(ctx context.Context, arg ...string) (*TracedCmd, error) + +type TracedCmd struct { + Ctx context.Context // nolint:containedctx // This is an approved usage of context for short lived cmd + *exec.Cmd +} + +// Start overrides the Start method to add tracing before executing the command. +func (t *TracedCmd) Start() error { + _, span := traces.StartSpan(t.Ctx, "path", t.Cmd.Path, "args", fmt.Sprintf("%+v", t.Cmd.Args)) + defer span.End() + + return t.Cmd.Start() //nolint:forbidigo // This is our approved usage of t.Cmd.Start() +} + +// Run overrides the Run method to add tracing before running the command. +func (t *TracedCmd) Run() error { + _, span := traces.StartSpan(t.Ctx, "path", t.Cmd.Path, "args", fmt.Sprintf("%+v", t.Cmd.Args)) + defer span.End() + + return t.Cmd.Run() //nolint:forbidigo // This is our approved usage of t.Cmd.Run() +} + +// Output overrides the Output method to add tracing before capturing output. +func (t *TracedCmd) Output() ([]byte, error) { + _, span := traces.StartSpan(t.Ctx, "path", t.Cmd.Path, "args", fmt.Sprintf("%+v", t.Cmd.Args)) + defer span.End() + + return t.Cmd.Output() //nolint:forbidigo // This is our approved usage of t.Cmd.Output() +} -func newCmd(ctx context.Context, fullPathToCmd string, arg ...string) *exec.Cmd { - return exec.CommandContext(ctx, fullPathToCmd, arg...) //nolint:forbidigo // This is our approved usage of exec.CommandContext +// CombinedOutput overrides the CombinedOutput method to add tracing before capturing combined output. +func (t *TracedCmd) CombinedOutput() ([]byte, error) { + _, span := traces.StartSpan(t.Ctx, "path", t.Cmd.Path, "args", fmt.Sprintf("%+v", t.Cmd.Args)) + defer span.End() + + return t.Cmd.CombinedOutput() //nolint:forbidigo // This is our approved usage of t.Cmd.CombinedOutput() +} + +func newCmd(ctx context.Context, fullPathToCmd string, arg ...string) *TracedCmd { + return &TracedCmd{ + Ctx: ctx, + Cmd: exec.CommandContext(ctx, fullPathToCmd, arg...), //nolint:forbidigo // This is our approved usage of exec.CommandContext + } } var ErrCommandNotFound = errors.New("command not found") -func validatedCommand(ctx context.Context, knownPath string, arg ...string) (*exec.Cmd, error) { +func validatedCommand(ctx context.Context, knownPath string, arg ...string) (*TracedCmd, error) { knownPath = filepath.Clean(knownPath) if _, err := os.Stat(knownPath); err == nil { diff --git a/ee/allowedcmd/cmd_darwin.go b/ee/allowedcmd/cmd_darwin.go index 459c43884..fce5345bd 100644 --- a/ee/allowedcmd/cmd_darwin.go +++ b/ee/allowedcmd/cmd_darwin.go @@ -6,22 +6,21 @@ package allowedcmd import ( "context" "errors" - "os/exec" ) -func Airport(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Airport(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/System/Library/PrivateFrameworks/Apple80211.framework/Versions/Current/Resources/airport", arg...) } -func Bioutil(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Bioutil(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/bioutil", arg...) } -func Bputil(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Bputil(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/bputil", arg...) } -func Brew(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Brew(ctx context.Context, arg ...string) (*TracedCmd, error) { for _, p := range []string{"/opt/homebrew/bin/brew", "/usr/local/bin/brew"} { validatedCmd, err := validatedCommand(ctx, p, arg...) if err != nil { @@ -36,118 +35,118 @@ func Brew(ctx context.Context, arg ...string) (*exec.Cmd, error) { return nil, errors.New("homebrew not found") } -func Diskutil(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Diskutil(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/sbin/diskutil", arg...) } -func Echo(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Echo(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/bin/echo", arg...) } -func Falconctl(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Falconctl(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/Applications/Falcon.app/Contents/Resources/falconctl", arg...) } -func Fdesetup(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Fdesetup(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/fdesetup", arg...) } -func Firmwarepasswd(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Firmwarepasswd(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/sbin/firmwarepasswd", arg...) } -func Ifconfig(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Ifconfig(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/sbin/ifconfig", arg...) } -func Ioreg(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Ioreg(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/sbin/ioreg", arg...) } -func Launchctl(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Launchctl(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/bin/launchctl", arg...) } -func Lsof(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Lsof(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/sbin/lsof", arg...) } -func Mdfind(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Mdfind(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/mdfind", arg...) } -func Mdmclient(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Mdmclient(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/libexec/mdmclient", arg...) } -func Netstat(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Netstat(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/sbin/netstat", arg...) } -func NixEnv(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func NixEnv(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/nix/var/nix/profiles/default/bin/nix-env", arg...) } -func Open(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Open(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/open", arg...) } -func Pkgutil(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Pkgutil(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/sbin/pkgutil", arg...) } -func Powermetrics(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Powermetrics(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/powermetrics", arg...) } -func Profiles(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Profiles(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/profiles", arg...) } -func Ps(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Ps(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/bin/ps", arg...) } -func Pwpolicy(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Pwpolicy(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/pwpolicy", arg...) } -func Remotectl(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Remotectl(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/libexec/remotectl", arg...) } -func Repcli(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Repcli(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/Applications/VMware Carbon Black Cloud/repcli.bundle/Contents/MacOS/repcli", arg...) } -func Scutil(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Scutil(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/sbin/scutil", arg...) } -func Socketfilterfw(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Socketfilterfw(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/libexec/ApplicationFirewall/socketfilterfw", arg...) } -func Softwareupdate(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Softwareupdate(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/sbin/softwareupdate", arg...) } -func SystemProfiler(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func SystemProfiler(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/sbin/system_profiler", arg...) } -func Tmutil(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Tmutil(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/tmutil", arg...) } -func ZerotierCli(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func ZerotierCli(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/local/bin/zerotier-cli", arg...) } -func Zfs(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Zfs(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/sbin/zfs", arg...) } -func Zpool(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Zpool(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/sbin/zpool", arg...) } diff --git a/ee/allowedcmd/cmd_linux.go b/ee/allowedcmd/cmd_linux.go index 32aa870e4..5fcfd87ce 100644 --- a/ee/allowedcmd/cmd_linux.go +++ b/ee/allowedcmd/cmd_linux.go @@ -6,14 +6,13 @@ package allowedcmd import ( "context" "errors" - "os/exec" ) -func Apt(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Apt(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/apt", arg...) } -func Brew(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Brew(ctx context.Context, arg ...string) (*TracedCmd, error) { validatedCmd, err := validatedCommand(ctx, "/home/linuxbrew/.linuxbrew/bin/brew", arg...) if err != nil { return nil, err @@ -24,11 +23,11 @@ func Brew(ctx context.Context, arg ...string) (*exec.Cmd, error) { return validatedCmd, nil } -func Coredumpctl(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Coredumpctl(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/coredumpctl", arg...) } -func Cryptsetup(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Cryptsetup(ctx context.Context, arg ...string) (*TracedCmd, error) { for _, p := range []string{"/usr/sbin/cryptsetup", "/sbin/cryptsetup"} { validatedCmd, err := validatedCommand(ctx, p, arg...) if err != nil { @@ -41,55 +40,55 @@ func Cryptsetup(ctx context.Context, arg ...string) (*exec.Cmd, error) { return nil, errors.New("cryptsetup not found") } -func Dnf(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Dnf(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/dnf", arg...) } -func Dpkg(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Dpkg(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/dpkg", arg...) } -func Echo(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Echo(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/echo", arg...) } -func Falconctl(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Falconctl(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/opt/CrowdStrike/falconctl", arg...) } -func FalconKernelCheck(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func FalconKernelCheck(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/opt/CrowdStrike/falcon-kernel-check", arg...) } -func Flatpak(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Flatpak(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/flatpak", arg...) } -func GnomeExtensions(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func GnomeExtensions(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/gnome-extensions", arg...) } -func Gsettings(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Gsettings(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/gsettings", arg...) } -func Ifconfig(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Ifconfig(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/sbin/ifconfig", arg...) } -func Ip(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Ip(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/sbin/ip", arg...) } -func Journalctl(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Journalctl(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/journalctl", arg...) } -func Loginctl(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Loginctl(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/loginctl", arg...) } -func Lsblk(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Lsblk(ctx context.Context, arg ...string) (*TracedCmd, error) { for _, p := range []string{"/bin/lsblk", "/usr/bin/lsblk"} { validatedCmd, err := validatedCommand(ctx, p, arg...) if err != nil { @@ -102,43 +101,43 @@ func Lsblk(ctx context.Context, arg ...string) (*exec.Cmd, error) { return nil, errors.New("lsblk not found") } -func Lsof(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Lsof(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/lsof", arg...) } -func NixEnv(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func NixEnv(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/run/current-system/sw/bin/nix-env", arg...) } -func Nftables(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Nftables(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/sbin/nft", arg...) } -func Nmcli(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Nmcli(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/nmcli", arg...) } -func NotifySend(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func NotifySend(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/notify-send", arg...) } -func Pacman(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Pacman(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/pacman", arg...) } -func Patchelf(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Patchelf(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/run/current-system/sw/bin/patchelf", arg...) } -func Ps(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Ps(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/ps", arg...) } -func Repcli(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Repcli(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/opt/carbonblack/psc/bin/repcli", arg...) } -func Rpm(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Rpm(ctx context.Context, arg ...string) (*TracedCmd, error) { for _, p := range []string{"/bin/rpm", "/usr/bin/rpm"} { validatedCmd, err := validatedCommand(ctx, p, arg...) if err != nil { @@ -151,15 +150,15 @@ func Rpm(ctx context.Context, arg ...string) (*exec.Cmd, error) { return nil, errors.New("rpm not found") } -func Snap(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Snap(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/snap", arg...) } -func Systemctl(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Systemctl(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/systemctl", arg...) } -func Ws1HubUtil(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Ws1HubUtil(ctx context.Context, arg ...string) (*TracedCmd, error) { for _, p := range []string{"/usr/bin/ws1HubUtil", "/opt/vmware/ws1-hub/bin/ws1HubUtil"} { validatedCmd, err := validatedCommand(ctx, p, arg...) if err != nil { @@ -172,30 +171,30 @@ func Ws1HubUtil(ctx context.Context, arg ...string) (*exec.Cmd, error) { return nil, errors.New("ws1HubUtil not found") } -func XdgOpen(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func XdgOpen(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/xdg-open", arg...) } -func Xrdb(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Xrdb(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/xrdb", arg...) } -func XWwwBrowser(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func XWwwBrowser(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/x-www-browser", arg...) } -func ZerotierCli(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func ZerotierCli(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/local/bin/zerotier-cli", arg...) } -func Zfs(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Zfs(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/sbin/zfs", arg...) } -func Zpool(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Zpool(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/sbin/zpool", arg...) } -func Zypper(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Zypper(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, "/usr/bin/zypper", arg...) } diff --git a/ee/allowedcmd/cmd_test.go b/ee/allowedcmd/cmd_test.go index 9f071ba57..16119ba32 100644 --- a/ee/allowedcmd/cmd_test.go +++ b/ee/allowedcmd/cmd_test.go @@ -13,18 +13,18 @@ func TestEcho(t *testing.T) { t.Parallel() // echo is the only one available on all platforms and likely to be available in CI - cmd, err := Echo(context.TODO(), "hello") + tracedCmd, err := Echo(context.TODO(), "hello") require.NoError(t, err) - require.Contains(t, cmd.Path, "echo") - require.Contains(t, cmd.Args, "hello") + require.Contains(t, tracedCmd.Path, "echo") + require.Contains(t, tracedCmd.Args, "hello") } func Test_newCmd(t *testing.T) { t.Parallel() cmdPath := filepath.Join("some", "path", "to", "a", "command") - cmd := newCmd(context.TODO(), cmdPath) - require.Equal(t, cmdPath, cmd.Path) + tracedCmd := newCmd(context.TODO(), cmdPath) + require.Equal(t, cmdPath, tracedCmd.Path) } func Test_validatedCommand(t *testing.T) { @@ -38,10 +38,10 @@ func Test_validatedCommand(t *testing.T) { cmdPath = `C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe` } - cmd, err := validatedCommand(context.TODO(), cmdPath) + tracedCmd, err := validatedCommand(context.TODO(), cmdPath) require.NoError(t, err) - require.Equal(t, cmdPath, cmd.Path) + require.Equal(t, cmdPath, tracedCmd.Path) } func Test_validatedCommand_doesNotSearchPathOnNonNixOS(t *testing.T) { diff --git a/ee/allowedcmd/cmd_windows.go b/ee/allowedcmd/cmd_windows.go index 5425a5926..13bb44b67 100644 --- a/ee/allowedcmd/cmd_windows.go +++ b/ee/allowedcmd/cmd_windows.go @@ -6,52 +6,51 @@ package allowedcmd import ( "context" "os" - "os/exec" "path/filepath" ) -func CommandPrompt(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func CommandPrompt(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, filepath.Join(os.Getenv("WINDIR"), "System32", "cmd.exe"), arg...) } -func Dism(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Dism(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, filepath.Join(os.Getenv("WINDIR"), "System32", "Dism.exe"), arg...) } -func Dsregcmd(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Dsregcmd(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, filepath.Join(os.Getenv("WINDIR"), "System32", "dsregcmd.exe"), arg...) } -func Echo(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Echo(ctx context.Context, arg ...string) (*TracedCmd, error) { // echo on Windows is only available as a command in cmd.exe return newCmd(ctx, "echo", arg...), nil } -func Ipconfig(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Ipconfig(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, filepath.Join(os.Getenv("WINDIR"), "System32", "ipconfig.exe"), arg...) } -func Powercfg(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Powercfg(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, filepath.Join(os.Getenv("WINDIR"), "System32", "powercfg.exe"), arg...) } -func Powershell(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Powershell(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, filepath.Join(os.Getenv("WINDIR"), "System32", "WindowsPowerShell", "v1.0", "powershell.exe"), arg...) } -func Repcli(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Repcli(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, filepath.Join(os.Getenv("PROGRAMFILES"), "Confer", "repcli"), arg...) } -func Secedit(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Secedit(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, filepath.Join(os.Getenv("WINDIR"), "System32", "SecEdit.exe"), arg...) } -func Taskkill(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func Taskkill(ctx context.Context, arg ...string) (*TracedCmd, error) { return validatedCommand(ctx, filepath.Join(os.Getenv("WINDIR"), "System32", "taskkill.exe"), arg...) } -func ZerotierCli(ctx context.Context, arg ...string) (*exec.Cmd, error) { +func ZerotierCli(ctx context.Context, arg ...string) (*TracedCmd, error) { // For windows, "-q" should be prepended before all other args return validatedCommand(ctx, filepath.Join(os.Getenv("SYSTEMROOT"), "ProgramData", "ZeroTier", "One", "zerotier-one_x64.exe"), append([]string{"-q"}, arg...)...) } diff --git a/ee/debug/checkups/network.go b/ee/debug/checkups/network.go index 4a1f49ea2..da703f733 100644 --- a/ee/debug/checkups/network.go +++ b/ee/debug/checkups/network.go @@ -55,7 +55,7 @@ func (n *networkCheckup) Run(ctx context.Context, extraWriter io.Writer) error { if err != nil { continue } - _ = runCmdMarkdownLogged(cmd, commandOutput) + _ = runCmdMarkdownLogged(cmd.Cmd, commandOutput) } for _, fileLocation := range listFiles() { diff --git a/ee/debug/checkups/power_windows.go b/ee/debug/checkups/power_windows.go index 6c988ed4f..2ce99f631 100644 --- a/ee/debug/checkups/power_windows.go +++ b/ee/debug/checkups/power_windows.go @@ -32,7 +32,7 @@ func (p *powerCheckup) Run(ctx context.Context, extraWriter io.Writer) error { if err != nil { return fmt.Errorf("creating powercfg command: %w", err) } - hideWindow(powerCfgCmd) + hideWindow(powerCfgCmd.Cmd) if out, err := powerCfgCmd.CombinedOutput(); err != nil { return fmt.Errorf("running powercfg.exe: error %w, output %s", err, string(out)) } @@ -51,7 +51,7 @@ func (p *powerCheckup) Run(ctx context.Context, extraWriter io.Writer) error { return fmt.Errorf("creating powercfg sleep states command: %w", err) } - hideWindow(powerCfgSleepStatesCmd) + hideWindow(powerCfgSleepStatesCmd.Cmd) availableSleepStatesOutput, err := powerCfgSleepStatesCmd.CombinedOutput() if err != nil { return fmt.Errorf("running powercfg.exe for sleep states: error %w, output %s", err, string(availableSleepStatesOutput)) diff --git a/ee/debug/checkups/services_windows.go b/ee/debug/checkups/services_windows.go index d6550132e..d5cd3634c 100644 --- a/ee/debug/checkups/services_windows.go +++ b/ee/debug/checkups/services_windows.go @@ -253,7 +253,7 @@ func gatherServiceManagerEvents(ctx context.Context, z *zip.Writer) error { if err != nil { return fmt.Errorf("creating powershell command: %w", err) } - hideWindow(cmd) + hideWindow(cmd.Cmd) cmd.Stdout = out cmd.Stderr = out if err := cmd.Run(); err != nil { @@ -284,7 +284,7 @@ func gatherServiceManagerEventLogs(ctx context.Context, z *zip.Writer) error { if err != nil { return fmt.Errorf("creating powershell command: %w", err) } - hideWindow(getEventLogCmd) + hideWindow(getEventLogCmd.Cmd) getEventLogCmd.Stdout = eventLogOut getEventLogCmd.Stderr = eventLogOut if err := getEventLogCmd.Run(); err != nil { diff --git a/ee/tables/pwpolicy/pwpolicy_test.go b/ee/tables/pwpolicy/pwpolicy_test.go index 0f674a756..4af6a068b 100644 --- a/ee/tables/pwpolicy/pwpolicy_test.go +++ b/ee/tables/pwpolicy/pwpolicy_test.go @@ -9,6 +9,7 @@ import ( "path" "testing" + "github.com/kolide/launcher/ee/allowedcmd" "github.com/kolide/launcher/ee/tables/tablehelpers" "github.com/kolide/launcher/pkg/log/multislogger" "github.com/stretchr/testify/assert" @@ -71,8 +72,11 @@ func TestQueries(t *testing.T) { } -func execFaker(filename string) func(context.Context, ...string) (*exec.Cmd, error) { - return func(ctx context.Context, _ ...string) (*exec.Cmd, error) { - return exec.CommandContext(ctx, "/bin/cat", filename), nil //nolint:forbidigo // Fine to use exec.CommandContext in test +func execFaker(filename string) func(context.Context, ...string) (*allowedcmd.TracedCmd, error) { + return func(ctx context.Context, _ ...string) (*allowedcmd.TracedCmd, error) { + return &allowedcmd.TracedCmd{ + Ctx: ctx, + Cmd: exec.CommandContext(ctx, "/bin/cat", filename), //nolint:forbidigo // Fine to use exec.CommandContext in test + }, nil } } diff --git a/ee/tables/tablehelpers/exec.go b/ee/tables/tablehelpers/exec.go index 3316bce97..07a224f1b 100644 --- a/ee/tables/tablehelpers/exec.go +++ b/ee/tables/tablehelpers/exec.go @@ -64,7 +64,7 @@ func Run(ctx context.Context, slogger *slog.Logger, timeoutSeconds int, execCmd } for _, opt := range opts { - if err := opt(cmd); err != nil { + if err := opt(cmd.Cmd); err != nil { return fmt.Errorf("applying option: %w", err) } }