From b3ff5e6c5323a78e9681d82620f41f77f957011a Mon Sep 17 00:00:00 2001 From: Kedaya <99012336+woshikedayaa@users.noreply.github.com> Date: Wed, 1 Jan 2025 01:25:09 +0800 Subject: [PATCH] feat: enhance privilege elevation logic (#722) --- cmd/internal/su.go | 130 +++++++++++++++++++++++++++++------------ cmd/reload.go | 2 +- cmd/run.go | 15 +++-- cmd/sysdump.go | 18 +++--- control/dns_control.go | 10 ++-- trace/trace.go | 61 ++++++++++--------- 6 files changed, 146 insertions(+), 90 deletions(-) diff --git a/cmd/internal/su.go b/cmd/internal/su.go index 1151c3bc46..c87f444677 100644 --- a/cmd/internal/su.go +++ b/cmd/internal/su.go @@ -6,44 +6,100 @@ package internal import ( - "fmt" - "os" - "os/exec" - "path/filepath" - - "github.com/sirupsen/logrus" + "fmt" + "github.com/sirupsen/logrus" + "os" + "os/exec" ) func AutoSu() { - if os.Getuid() == 0 { - return - } - program := filepath.Base(os.Args[0]) - pathSudo, err := exec.LookPath("sudo") - if err != nil { - // skip - return - } - // https://github.com/WireGuard/wireguard-tools/blob/71799a8f6d1450b63071a21cad6ed434b348d3d5/src/wg-quick/linux.bash#L85 - p, err := os.StartProcess(pathSudo, append([]string{ - pathSudo, - "-E", - "-p", - fmt.Sprintf("%v must be run as root. Please enter the password for %%u to continue: ", program), - "--", - }, os.Args...), &os.ProcAttr{ - Files: []*os.File{ - os.Stdin, - os.Stdout, - os.Stderr, - }, - }) - if err != nil { - logrus.Fatal(err) - } - stat, err := p.Wait() - if err != nil { - os.Exit(1) - } - os.Exit(stat.ExitCode()) + if os.Geteuid() == 0 { + return + } + path, arg := trySudo() + if path == "" { + path, arg = tryDoas() + } + if path == "" { + path, arg = tryPolkit() + } + + if path == "" { + return + } + logrus.Infof("use [ %s ] to elevate privileges to run [ %s ]", path, os.Args[0]) + p, err := os.StartProcess(path, append(arg, os.Args...), &os.ProcAttr{ + Files: []*os.File{ + os.Stdin, + os.Stdout, + os.Stderr, + }, + }) + if err != nil { + logrus.Fatal(err) + } + stat, err := p.Wait() + if err != nil { + os.Exit(1) + } + os.Exit(stat.ExitCode()) +} + +func trySudo() (path string, arg []string) { + pathSudo, err := exec.LookPath("sudo") + if err != nil || !isExistAndExecutable(pathSudo) { + return "", nil + } + // https://github.com/WireGuard/wireguard-tools/blob/71799a8f6d1450b63071a21cad6ed434b348d3d5/src/wg-quick/linux.bash#L85 + return pathSudo, []string{ + pathSudo, + "-E", + "-p", + fmt.Sprintf("Please enter the password for %%u to continue: "), + "--", + } +} + +func tryDoas() (path string, arg []string) { + // https://man.archlinux.org/man/doas.1 + var err error + path, err = exec.LookPath("doas") + if err != nil { + return "", nil + } + return path, []string{path, "-u", "root"} +} + +func tryPolkit() (path string, arg []string) { + // https://github.com/systemd/systemd/releases/tag/v256 + // introduced run0 which is a polkit wrapper. + var possible = []string{"run0", "pkexec"} + for _, v := range possible { + path, err := exec.LookPath(v) + if err != nil { + continue + } + if isExistAndExecutable(path) { + switch v { + case "run0": + return path, []string{path} + case "pkexec": + return path, []string{path, "--keep-cwd", "--user", "root"} + } + } + } + return "", nil +} + +func isExistAndExecutable(path string) bool { + if path == "" { + return false + } + + st, err := os.Stat(path) + if err == nil { + // https://stackoverflow.com/questions/60128401/how-to-check-if-a-file-is-executable-in-go + return st.Mode()&0o111 == 0o111 + } + return false } diff --git a/cmd/reload.go b/cmd/reload.go index 4f8815e85f..d6d18616d6 100644 --- a/cmd/reload.go +++ b/cmd/reload.go @@ -38,7 +38,7 @@ var ( Use: "reload [pid]", Short: "To reload config file without interrupt connections.", Run: func(cmd *cobra.Command, args []string) { - internal.AutoSu() + internal.AutoSu() if len(args) == 0 { _pid, err := os.ReadFile(PidFilePath) if err != nil { diff --git a/cmd/run.go b/cmd/run.go index 16f2fc5b4e..0c5a033d0d 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -59,9 +59,9 @@ func init() { runCmd.PersistentFlags().StringVar(&logFile, "logfile", "", "Log file to write. Empty means writing to stdout and stderr.") runCmd.PersistentFlags().IntVar(&logFileMaxSize, "logfile-maxsize", 30, "Unit: MB. The maximum size in megabytes of the log file before it gets rotated.") runCmd.PersistentFlags().IntVar(&logFileMaxBackups, "logfile-maxbackups", 3, "The maximum number of old log files to retain.") - runCmd.PersistentFlags().BoolVarP(&disableTimestamp, "disable-timestamp", "", false, "Disable timestamp.") - runCmd.PersistentFlags().BoolVarP(&disablePidFile, "disable-pidfile", "", false, "Not generate /var/run/dae.pid.") - + runCmd.PersistentFlags().BoolVar(&disableTimestamp, "disable-timestamp", false, "Disable timestamp.") + runCmd.PersistentFlags().BoolVar(&disablePidFile, "disable-pidfile", false, "Not generate /var/run/dae.pid.") + runCmd.PersistentFlags().BoolVar(&disableAuthSudo, "disable-sudo", false, "Disable sudo prompt ,may cause startup failure due to insufficient permissions") rand.Shuffle(len(CheckNetworkLinks), func(i, j int) { CheckNetworkLinks[i], CheckNetworkLinks[j] = CheckNetworkLinks[j], CheckNetworkLinks[i] }) @@ -74,6 +74,7 @@ var ( logFileMaxBackups int disableTimestamp bool disablePidFile bool + disableAuthSudo bool runCmd = &cobra.Command{ Use: "run", @@ -82,9 +83,13 @@ var ( if cfgFile == "" { logrus.Fatalln("Argument \"--config\" or \"-c\" is required but not provided.") } - + if disableAuthSudo && os.Geteuid() != 0 { + logrus.Fatalln("Auto-sudo is disabled and current user is not root.") + } // Require "sudo" if necessary. - internal.AutoSu() + if !disableAuthSudo { + internal.AutoSu() + } // Read config from --config cfgFile. conf, includes, err := readConfig(cfgFile) diff --git a/cmd/sysdump.go b/cmd/sysdump.go index 997f43c8cf..f7395a02a3 100644 --- a/cmd/sysdump.go +++ b/cmd/sysdump.go @@ -7,18 +7,18 @@ package cmd import ( "bytes" - "io/ioutil" "fmt" + "io/ioutil" "os" "os/exec" "path/filepath" "strings" - "time" + "time" - "github.com/vishvananda/netlink" - "github.com/spf13/cobra" "github.com/mholt/archiver/v3" "github.com/shirou/gopsutil/v4/net" + "github.com/spf13/cobra" + "github.com/vishvananda/netlink" "golang.org/x/sys/unix" ) @@ -46,7 +46,7 @@ func dumpNetworkInfo() { dumpNetfilter(tempDir) dumpIPTables(tempDir) - tarFile := fmt.Sprintf("dae-sysdump.%d.tar.gz",time.Now().Unix()) + tarFile := fmt.Sprintf("dae-sysdump.%d.tar.gz", time.Now().Unix()) if err := archiver.Archive([]string{tempDir}, tarFile); err != nil { fmt.Printf("Failed to create tar archive: %v\n", err) return @@ -55,7 +55,6 @@ func dumpNetworkInfo() { fmt.Printf("System network information collected and saved to %s\n", tarFile) } - // Translate scope enum into semantic words func scopeToString(scope netlink.Scope) string { switch scope { @@ -74,7 +73,6 @@ func scopeToString(scope netlink.Scope) string { } } - // Translate protocol enum into semantic words func protocolToString(proto int) string { switch proto { @@ -157,7 +155,6 @@ func typeToString(typ int) string { } } - func dumpRouting(outputDir string) { routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL) if err != nil { @@ -232,7 +229,6 @@ func dumpNetInterfaces(outputDir string) { ioutil.WriteFile(filepath.Join(outputDir, "interfaces.txt"), buffer.Bytes(), 0644) } - func dumpSysctl(outputDir string) { sysctlPath := "/proc/sys/net" var buffer bytes.Buffer @@ -281,12 +277,12 @@ func dumpIPTables(outputDir string) { ioutil.WriteFile(filepath.Join(outputDir, "iptables.txt"), output, 0644) } - ip6tables := exec.Command("ip6tables-save","-c") + ip6tables := exec.Command("ip6tables-save", "-c") output, err = ip6tables.CombinedOutput() if err != nil { fmt.Printf("Failed to get ip6tables: %v\n", err) } else { - ioutil.WriteFile(filepath.Join(outputDir, "ip6tables.txt"), output, 0644) + ioutil.WriteFile(filepath.Join(outputDir, "ip6tables.txt"), output, 0644) } } diff --git a/control/dns_control.go b/control/dns_control.go index 5435814738..82245713f2 100644 --- a/control/dns_control.go +++ b/control/dns_control.go @@ -76,8 +76,8 @@ type DnsController struct { fixedDomainTtl map[string]int // mutex protects the dnsCache. - dnsCacheMu sync.Mutex - dnsCache map[string]*DnsCache + dnsCacheMu sync.Mutex + dnsCache map[string]*DnsCache dnsForwarderCacheMu sync.Mutex dnsForwarderCache map[string]DnsForwarder } @@ -113,9 +113,9 @@ func NewDnsController(routing *dns.Dns, option *DnsControllerOption) (c *DnsCont bestDialerChooser: option.BestDialerChooser, timeoutExceedCallback: option.TimeoutExceedCallback, - fixedDomainTtl: option.FixedDomainTtl, - dnsCacheMu: sync.Mutex{}, - dnsCache: make(map[string]*DnsCache), + fixedDomainTtl: option.FixedDomainTtl, + dnsCacheMu: sync.Mutex{}, + dnsCache: make(map[string]*DnsCache), dnsForwarderCacheMu: sync.Mutex{}, dnsForwarderCache: make(map[string]DnsForwarder), }, nil diff --git a/trace/trace.go b/trace/trace.go index b801ecbf37..ae6c816e64 100644 --- a/trace/trace.go +++ b/trace/trace.go @@ -11,9 +11,9 @@ import ( "encoding/binary" "errors" "fmt" - "slices" "net" "os" + "slices" "syscall" "unsafe" @@ -278,43 +278,42 @@ func handleEvents(ctx context.Context, objs *bpfObjects, outputFile string, kfre logrus.Debugf("failed to parse ringbuf event: %+v", err) continue } - if skb2events[event.Skb]==nil { - skb2events[event.Skb] = []bpfEvent{} + if skb2events[event.Skb] == nil { + skb2events[event.Skb] = []bpfEvent{} } - skb2events[event.Skb] = append(skb2events[event.Skb],event) + skb2events[event.Skb] = append(skb2events[event.Skb], event) - - sym := NearestSymbol(event.Pc); - if skb2symNames[event.Skb]==nil { + sym := NearestSymbol(event.Pc) + if skb2symNames[event.Skb] == nil { skb2symNames[event.Skb] = []string{} } - skb2symNames[event.Skb] = append(skb2symNames[event.Skb],sym.Name) + skb2symNames[event.Skb] = append(skb2symNames[event.Skb], sym.Name) switch sym.Name { - case "__kfree_skb","kfree_skbmem": - // most skb end in the call of kfree_skbmem - if !dropOnly || slices.Contains(skb2symNames[event.Skb],"kfree_skb_reason") { - // trace dropOnly with drop reason or all skb - for _,skb_ev := range skb2events[event.Skb] { - fmt.Fprintf(writer, "%x mark=%x netns=%010d if=%d(%s) proc=%d(%s) ", skb_ev.Skb, skb_ev.Mark, skb_ev.Netns, skb_ev.Ifindex, TrimNull(string(skb_ev.Ifname[:])), skb_ev.Pid, TrimNull(string(skb_ev.Pname[:]))) - if event.L3Proto == syscall.ETH_P_IP { - fmt.Fprintf(writer, "%s:%d > %s:%d ", net.IP(skb_ev.Saddr[:4]).String(), Ntohs(skb_ev.Sport), net.IP(skb_ev.Daddr[:4]).String(), Ntohs(skb_ev.Dport)) - } else { - fmt.Fprintf(writer, "[%s]:%d > [%s]:%d ", net.IP(skb_ev.Saddr[:]).String(), Ntohs(skb_ev.Sport), net.IP(skb_ev.Daddr[:]).String(), Ntohs(skb_ev.Dport)) - } - if event.L4Proto == syscall.IPPROTO_TCP { - fmt.Fprintf(writer, "tcp_flags=%s ", TcpFlags(skb_ev.TcpFlags)) - } - fmt.Fprintf(writer, "payload_len=%d ", event.PayloadLen) - sym := NearestSymbol(skb_ev.Pc) - fmt.Fprintf(writer, "%s", sym.Name) - if sym.Name == "kfree_skb_reason" { - fmt.Fprintf(writer, "(%s)", kfreeSkbReasons[skb_ev.SecondParam]) - } - fmt.Fprintf(writer, "\n") + case "__kfree_skb", "kfree_skbmem": + // most skb end in the call of kfree_skbmem + if !dropOnly || slices.Contains(skb2symNames[event.Skb], "kfree_skb_reason") { + // trace dropOnly with drop reason or all skb + for _, skb_ev := range skb2events[event.Skb] { + fmt.Fprintf(writer, "%x mark=%x netns=%010d if=%d(%s) proc=%d(%s) ", skb_ev.Skb, skb_ev.Mark, skb_ev.Netns, skb_ev.Ifindex, TrimNull(string(skb_ev.Ifname[:])), skb_ev.Pid, TrimNull(string(skb_ev.Pname[:]))) + if event.L3Proto == syscall.ETH_P_IP { + fmt.Fprintf(writer, "%s:%d > %s:%d ", net.IP(skb_ev.Saddr[:4]).String(), Ntohs(skb_ev.Sport), net.IP(skb_ev.Daddr[:4]).String(), Ntohs(skb_ev.Dport)) + } else { + fmt.Fprintf(writer, "[%s]:%d > [%s]:%d ", net.IP(skb_ev.Saddr[:]).String(), Ntohs(skb_ev.Sport), net.IP(skb_ev.Daddr[:]).String(), Ntohs(skb_ev.Dport)) + } + if event.L4Proto == syscall.IPPROTO_TCP { + fmt.Fprintf(writer, "tcp_flags=%s ", TcpFlags(skb_ev.TcpFlags)) } + fmt.Fprintf(writer, "payload_len=%d ", event.PayloadLen) + sym := NearestSymbol(skb_ev.Pc) + fmt.Fprintf(writer, "%s", sym.Name) + if sym.Name == "kfree_skb_reason" { + fmt.Fprintf(writer, "(%s)", kfreeSkbReasons[skb_ev.SecondParam]) + } + fmt.Fprintf(writer, "\n") + } delete(skb2events, event.Skb) - delete(skb2symNames, event.Skb) - } + delete(skb2symNames, event.Skb) + } } } }