Skip to content

Commit

Permalink
Merge branch 'main' into add-persist
Browse files Browse the repository at this point in the history
  • Loading branch information
st0nie authored Dec 31, 2024
2 parents de90c0e + b3ff5e6 commit ca4de62
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 90 deletions.
130 changes: 93 additions & 37 deletions cmd/internal/su.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,44 +6,100 @@
package internal

import (
"fmt"
"os"
"os/exec"
"path/filepath"

"github.com/sirupsen/logrus"
"fmt"
"github.com/sirupsen/logrus"
"os"
"os/exec"
)

func AutoSu() {
if os.Getuid() == 0 {
return
}
program := filepath.Base(os.Args[0])
pathSudo, err := exec.LookPath("sudo")
if err != nil {
// skip
return
}
// https://github.com/WireGuard/wireguard-tools/blob/71799a8f6d1450b63071a21cad6ed434b348d3d5/src/wg-quick/linux.bash#L85
p, err := os.StartProcess(pathSudo, append([]string{
pathSudo,
"-E",
"-p",
fmt.Sprintf("%v must be run as root. Please enter the password for %%u to continue: ", program),
"--",
}, os.Args...), &os.ProcAttr{
Files: []*os.File{
os.Stdin,
os.Stdout,
os.Stderr,
},
})
if err != nil {
logrus.Fatal(err)
}
stat, err := p.Wait()
if err != nil {
os.Exit(1)
}
os.Exit(stat.ExitCode())
if os.Geteuid() == 0 {
return
}
path, arg := trySudo()
if path == "" {
path, arg = tryDoas()
}
if path == "" {
path, arg = tryPolkit()
}

if path == "" {
return
}
logrus.Infof("use [ %s ] to elevate privileges to run [ %s ]", path, os.Args[0])
p, err := os.StartProcess(path, append(arg, os.Args...), &os.ProcAttr{
Files: []*os.File{
os.Stdin,
os.Stdout,
os.Stderr,
},
})
if err != nil {
logrus.Fatal(err)
}
stat, err := p.Wait()
if err != nil {
os.Exit(1)
}
os.Exit(stat.ExitCode())
}

func trySudo() (path string, arg []string) {
pathSudo, err := exec.LookPath("sudo")
if err != nil || !isExistAndExecutable(pathSudo) {
return "", nil
}
// https://github.com/WireGuard/wireguard-tools/blob/71799a8f6d1450b63071a21cad6ed434b348d3d5/src/wg-quick/linux.bash#L85
return pathSudo, []string{
pathSudo,
"-E",
"-p",
fmt.Sprintf("Please enter the password for %%u to continue: "),
"--",
}
}

func tryDoas() (path string, arg []string) {
// https://man.archlinux.org/man/doas.1
var err error
path, err = exec.LookPath("doas")
if err != nil {
return "", nil
}
return path, []string{path, "-u", "root"}
}

func tryPolkit() (path string, arg []string) {
// https://github.com/systemd/systemd/releases/tag/v256
// introduced run0 which is a polkit wrapper.
var possible = []string{"run0", "pkexec"}
for _, v := range possible {
path, err := exec.LookPath(v)
if err != nil {
continue
}
if isExistAndExecutable(path) {
switch v {
case "run0":
return path, []string{path}
case "pkexec":
return path, []string{path, "--keep-cwd", "--user", "root"}
}
}
}
return "", nil
}

func isExistAndExecutable(path string) bool {
if path == "" {
return false
}

st, err := os.Stat(path)
if err == nil {
// https://stackoverflow.com/questions/60128401/how-to-check-if-a-file-is-executable-in-go
return st.Mode()&0o111 == 0o111
}
return false
}
2 changes: 1 addition & 1 deletion cmd/reload.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ var (
Use: "reload [pid]",
Short: "To reload config file without interrupt connections.",
Run: func(cmd *cobra.Command, args []string) {
internal.AutoSu()
internal.AutoSu()
if len(args) == 0 {
_pid, err := os.ReadFile(PidFilePath)
if err != nil {
Expand Down
15 changes: 10 additions & 5 deletions cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ func init() {
runCmd.PersistentFlags().StringVar(&logFile, "logfile", "", "Log file to write. Empty means writing to stdout and stderr.")
runCmd.PersistentFlags().IntVar(&logFileMaxSize, "logfile-maxsize", 30, "Unit: MB. The maximum size in megabytes of the log file before it gets rotated.")
runCmd.PersistentFlags().IntVar(&logFileMaxBackups, "logfile-maxbackups", 3, "The maximum number of old log files to retain.")
runCmd.PersistentFlags().BoolVarP(&disableTimestamp, "disable-timestamp", "", false, "Disable timestamp.")
runCmd.PersistentFlags().BoolVarP(&disablePidFile, "disable-pidfile", "", false, "Not generate /var/run/dae.pid.")

runCmd.PersistentFlags().BoolVar(&disableTimestamp, "disable-timestamp", false, "Disable timestamp.")
runCmd.PersistentFlags().BoolVar(&disablePidFile, "disable-pidfile", false, "Not generate /var/run/dae.pid.")
runCmd.PersistentFlags().BoolVar(&disableAuthSudo, "disable-sudo", false, "Disable sudo prompt ,may cause startup failure due to insufficient permissions")
rand.Shuffle(len(CheckNetworkLinks), func(i, j int) {
CheckNetworkLinks[i], CheckNetworkLinks[j] = CheckNetworkLinks[j], CheckNetworkLinks[i]
})
Expand All @@ -74,6 +74,7 @@ var (
logFileMaxBackups int
disableTimestamp bool
disablePidFile bool
disableAuthSudo bool

runCmd = &cobra.Command{
Use: "run",
Expand All @@ -82,9 +83,13 @@ var (
if cfgFile == "" {
logrus.Fatalln("Argument \"--config\" or \"-c\" is required but not provided.")
}

if disableAuthSudo && os.Geteuid() != 0 {
logrus.Fatalln("Auto-sudo is disabled and current user is not root.")
}
// Require "sudo" if necessary.
internal.AutoSu()
if !disableAuthSudo {
internal.AutoSu()
}

// Read config from --config cfgFile.
conf, includes, err := readConfig(cfgFile)
Expand Down
18 changes: 7 additions & 11 deletions cmd/sysdump.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@ package cmd

import (
"bytes"
"io/ioutil"
"fmt"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
"time"

"github.com/vishvananda/netlink"
"github.com/spf13/cobra"
"github.com/mholt/archiver/v3"
"github.com/shirou/gopsutil/v4/net"
"github.com/spf13/cobra"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
)

Expand Down Expand Up @@ -46,7 +46,7 @@ func dumpNetworkInfo() {
dumpNetfilter(tempDir)
dumpIPTables(tempDir)

tarFile := fmt.Sprintf("dae-sysdump.%d.tar.gz",time.Now().Unix())
tarFile := fmt.Sprintf("dae-sysdump.%d.tar.gz", time.Now().Unix())
if err := archiver.Archive([]string{tempDir}, tarFile); err != nil {
fmt.Printf("Failed to create tar archive: %v\n", err)
return
Expand All @@ -55,7 +55,6 @@ func dumpNetworkInfo() {
fmt.Printf("System network information collected and saved to %s\n", tarFile)
}


// Translate scope enum into semantic words
func scopeToString(scope netlink.Scope) string {
switch scope {
Expand All @@ -74,7 +73,6 @@ func scopeToString(scope netlink.Scope) string {
}
}


// Translate protocol enum into semantic words
func protocolToString(proto int) string {
switch proto {
Expand Down Expand Up @@ -157,7 +155,6 @@ func typeToString(typ int) string {
}
}


func dumpRouting(outputDir string) {
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
if err != nil {
Expand Down Expand Up @@ -232,7 +229,6 @@ func dumpNetInterfaces(outputDir string) {
ioutil.WriteFile(filepath.Join(outputDir, "interfaces.txt"), buffer.Bytes(), 0644)
}


func dumpSysctl(outputDir string) {
sysctlPath := "/proc/sys/net"
var buffer bytes.Buffer
Expand Down Expand Up @@ -281,12 +277,12 @@ func dumpIPTables(outputDir string) {
ioutil.WriteFile(filepath.Join(outputDir, "iptables.txt"), output, 0644)
}

ip6tables := exec.Command("ip6tables-save","-c")
ip6tables := exec.Command("ip6tables-save", "-c")
output, err = ip6tables.CombinedOutput()
if err != nil {
fmt.Printf("Failed to get ip6tables: %v\n", err)
} else {
ioutil.WriteFile(filepath.Join(outputDir, "ip6tables.txt"), output, 0644)
ioutil.WriteFile(filepath.Join(outputDir, "ip6tables.txt"), output, 0644)
}
}

Expand Down
10 changes: 5 additions & 5 deletions control/dns_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ type DnsController struct {

fixedDomainTtl map[string]int
// mutex protects the dnsCache.
dnsCacheMu sync.Mutex
dnsCache map[string]*DnsCache
dnsCacheMu sync.Mutex
dnsCache map[string]*DnsCache
dnsForwarderCacheMu sync.Mutex
dnsForwarderCache map[string]DnsForwarder
}
Expand Down Expand Up @@ -113,9 +113,9 @@ func NewDnsController(routing *dns.Dns, option *DnsControllerOption) (c *DnsCont
bestDialerChooser: option.BestDialerChooser,
timeoutExceedCallback: option.TimeoutExceedCallback,

fixedDomainTtl: option.FixedDomainTtl,
dnsCacheMu: sync.Mutex{},
dnsCache: make(map[string]*DnsCache),
fixedDomainTtl: option.FixedDomainTtl,
dnsCacheMu: sync.Mutex{},
dnsCache: make(map[string]*DnsCache),
dnsForwarderCacheMu: sync.Mutex{},
dnsForwarderCache: make(map[string]DnsForwarder),
}, nil
Expand Down
61 changes: 30 additions & 31 deletions trace/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ import (
"encoding/binary"
"errors"
"fmt"
"slices"
"net"
"os"
"slices"
"syscall"
"unsafe"

Expand Down Expand Up @@ -278,43 +278,42 @@ func handleEvents(ctx context.Context, objs *bpfObjects, outputFile string, kfre
logrus.Debugf("failed to parse ringbuf event: %+v", err)
continue
}
if skb2events[event.Skb]==nil {
skb2events[event.Skb] = []bpfEvent{}
if skb2events[event.Skb] == nil {
skb2events[event.Skb] = []bpfEvent{}
}
skb2events[event.Skb] = append(skb2events[event.Skb],event)
skb2events[event.Skb] = append(skb2events[event.Skb], event)


sym := NearestSymbol(event.Pc);
if skb2symNames[event.Skb]==nil {
sym := NearestSymbol(event.Pc)
if skb2symNames[event.Skb] == nil {
skb2symNames[event.Skb] = []string{}
}
skb2symNames[event.Skb] = append(skb2symNames[event.Skb],sym.Name)
skb2symNames[event.Skb] = append(skb2symNames[event.Skb], sym.Name)
switch sym.Name {
case "__kfree_skb","kfree_skbmem":
// most skb end in the call of kfree_skbmem
if !dropOnly || slices.Contains(skb2symNames[event.Skb],"kfree_skb_reason") {
// trace dropOnly with drop reason or all skb
for _,skb_ev := range skb2events[event.Skb] {
fmt.Fprintf(writer, "%x mark=%x netns=%010d if=%d(%s) proc=%d(%s) ", skb_ev.Skb, skb_ev.Mark, skb_ev.Netns, skb_ev.Ifindex, TrimNull(string(skb_ev.Ifname[:])), skb_ev.Pid, TrimNull(string(skb_ev.Pname[:])))
if event.L3Proto == syscall.ETH_P_IP {
fmt.Fprintf(writer, "%s:%d > %s:%d ", net.IP(skb_ev.Saddr[:4]).String(), Ntohs(skb_ev.Sport), net.IP(skb_ev.Daddr[:4]).String(), Ntohs(skb_ev.Dport))
} else {
fmt.Fprintf(writer, "[%s]:%d > [%s]:%d ", net.IP(skb_ev.Saddr[:]).String(), Ntohs(skb_ev.Sport), net.IP(skb_ev.Daddr[:]).String(), Ntohs(skb_ev.Dport))
}
if event.L4Proto == syscall.IPPROTO_TCP {
fmt.Fprintf(writer, "tcp_flags=%s ", TcpFlags(skb_ev.TcpFlags))
}
fmt.Fprintf(writer, "payload_len=%d ", event.PayloadLen)
sym := NearestSymbol(skb_ev.Pc)
fmt.Fprintf(writer, "%s", sym.Name)
if sym.Name == "kfree_skb_reason" {
fmt.Fprintf(writer, "(%s)", kfreeSkbReasons[skb_ev.SecondParam])
}
fmt.Fprintf(writer, "\n")
case "__kfree_skb", "kfree_skbmem":
// most skb end in the call of kfree_skbmem
if !dropOnly || slices.Contains(skb2symNames[event.Skb], "kfree_skb_reason") {
// trace dropOnly with drop reason or all skb
for _, skb_ev := range skb2events[event.Skb] {
fmt.Fprintf(writer, "%x mark=%x netns=%010d if=%d(%s) proc=%d(%s) ", skb_ev.Skb, skb_ev.Mark, skb_ev.Netns, skb_ev.Ifindex, TrimNull(string(skb_ev.Ifname[:])), skb_ev.Pid, TrimNull(string(skb_ev.Pname[:])))
if event.L3Proto == syscall.ETH_P_IP {
fmt.Fprintf(writer, "%s:%d > %s:%d ", net.IP(skb_ev.Saddr[:4]).String(), Ntohs(skb_ev.Sport), net.IP(skb_ev.Daddr[:4]).String(), Ntohs(skb_ev.Dport))
} else {
fmt.Fprintf(writer, "[%s]:%d > [%s]:%d ", net.IP(skb_ev.Saddr[:]).String(), Ntohs(skb_ev.Sport), net.IP(skb_ev.Daddr[:]).String(), Ntohs(skb_ev.Dport))
}
if event.L4Proto == syscall.IPPROTO_TCP {
fmt.Fprintf(writer, "tcp_flags=%s ", TcpFlags(skb_ev.TcpFlags))
}
fmt.Fprintf(writer, "payload_len=%d ", event.PayloadLen)
sym := NearestSymbol(skb_ev.Pc)
fmt.Fprintf(writer, "%s", sym.Name)
if sym.Name == "kfree_skb_reason" {
fmt.Fprintf(writer, "(%s)", kfreeSkbReasons[skb_ev.SecondParam])
}
fmt.Fprintf(writer, "\n")
}
delete(skb2events, event.Skb)
delete(skb2symNames, event.Skb)
}
delete(skb2symNames, event.Skb)
}
}
}
}

0 comments on commit ca4de62

Please sign in to comment.