diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index fb42256..6f2e270 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -19,7 +19,7 @@ jobs: - name: Set up Go 3.x uses: actions/setup-go@v3 with: - go-version: '1.18.3' + go-version: '^1.19' - name: Get dependencies run: | diff --git a/README.md b/README.md index e382450..279fd39 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,12 @@ docker run -d --pull always -p 80:80 -p 443:443 -p 53:53/udp -v "$(pwd):/tmp/" g In order for `sniproxy` to work properly, ports 80, 443 and 53 need to be open. if you're using ubuntu, there's a good chance that `systemd-resolved` is using port 53. to disable it, follow [these instructions](https://gist.github.com/zoilomora/f7d264cefbb589f3f1b1fc2cea2c844c) +if you would like to keep `systemd-resolved` and disable the builtin resolver, you can use the following: +```bash +sed -i 's/#DNS=/DNS=9.9.9.9/; s/#DNSStubListener=yes/DNSStubListener=no/' /etc/systemd/resolved.conf +systemctl restart systemd-resolved +``` +above will replace the builtin resolver with 9.9.9.9 Issue ===== diff --git a/dns.go b/dns.go index c4b551e..4b9c4f3 100644 --- a/dns.go +++ b/dns.go @@ -2,7 +2,6 @@ package main import ( "bufio" - "crypto/tls" "fmt" "net/http" "net/url" @@ -10,34 +9,43 @@ import ( "strings" "time" + "github.com/golang-collections/collections/tst" + doqclient "github.com/mosajjal/doqd/pkg/client" "github.com/mosajjal/sniproxy/doh" - doqclient "github.com/natesales/doqd/pkg/client" log "github.com/sirupsen/logrus" "github.com/miekg/dns" ) -// inDomainList returns true if the domain exists in the routeDomainList +var ( + matchPrefix = uint8(1) + matchSuffix = uint8(2) + matchFQDN = uint8(3) +) + +// inDomainList returns true if the domain is meant to be SKIPPED and not go through sni proxy // todo: this needs to be replaced by a few tst -func inDomainList(name string) bool { - for _, item := range c.routeDomainList { - if len(item) == 2 { - if item[1] == "suffix" { - if strings.HasSuffix(name, item[0]) { - return true - } - } else if item[1] == "fqdn" { - if name == item[0] { - return true - } - } else if item[1] == "prefix" { - if strings.HasPrefix(name, item[0]) { - return true - } - } +func inDomainList(fqdn string) bool { + fqdnLower := strings.ToLower(fqdn) + // check for fqdn match + if c.routeFQDNs[fqdnLower] == matchFQDN { + return false + } + // check for prefix match + if longestPrefix := c.routePrefixes.GetLongestPrefix(fqdnLower); longestPrefix != nil { + // check if the longest prefix is present in the type hashtable as a prefix + if c.routeFQDNs[longestPrefix.(string)] == matchPrefix { + return false + } + } + // check for suffix match. Note that suffix is just prefix reversed + if longestSuffix := c.routeSuffixes.GetLongestPrefix(reverse(fqdnLower)); longestSuffix != nil { + // check if the longest suffix is present in the type hashtable as a suffix + if c.routeFQDNs[longestSuffix.(string)] == matchSuffix { + return false } } - return false + return true } var dnsClient struct { @@ -46,13 +54,26 @@ var dnsClient struct { classicDNS dns.Client } -func loadDomainsToList(Filename string) [][]string { - log.Info("Loading the domain from file/url to a list") - var lines [][]string +func reverse(s string) string { + r := []rune(s) + for i, j := 0, len(r)-1; i < len(r)/2; i, j = i+1, j-1 { + r[i], r[j] = r[j], r[i] + } + return string(r) +} + +// LoadDomainsCsv loads a domains Csv file/URL. returns 3 parameters: +// 1. a TST for all the prefixes (type 1) +// 2. a TST for all the suffixes (type 2) +// 3. a hashtable for all the full match fqdn (type 3) +func LoadDomainsCsv(Filename string) (prefix *tst.TernarySearchTree, suffix *tst.TernarySearchTree, all map[string]uint8) { + prefix = tst.New() + suffix = tst.New() + all = make(map[string]uint8) + log.Info("Loading the domain from file/url") var scanner *bufio.Scanner if strings.HasPrefix(Filename, "http://") || strings.HasPrefix(Filename, "https://") { log.Info("domain list is a URL, trying to fetch") - http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} client := http.Client{ CheckRedirect: func(r *http.Request, via []*http.Request) error { r.URL.Opaque = r.URL.Path @@ -79,16 +100,36 @@ func loadDomainsToList(Filename string) [][]string { for scanner.Scan() { lowerCaseLine := strings.ToLower(scanner.Text()) - lines = append(lines, strings.Split(lowerCaseLine, ",")) + // split the line by comma to understand the logic + fqdn := strings.Split(lowerCaseLine, ",") + if len(fqdn) != 2 { + log.Warnf("%s is not a valid line, assuming fqdn", lowerCaseLine) + fqdn = []string{lowerCaseLine, "fqdn"} + } + // add the fqdn to the hashtable with its type + switch entryType := fqdn[1]; entryType { + case "prefix": + all[fqdn[0]] = matchPrefix + prefix.Insert(fqdn[0], fqdn[0]) + case "suffix": + all[fqdn[0]] = matchSuffix + // suffix match is much faster if we reverse the strings and match for prefix + suffix.Insert(reverse(fqdn[0]), fqdn[0]) + case "fqdn": + all[fqdn[0]] = matchFQDN + default: + log.Warnf("%s is not a valid line, assuming fqdn", lowerCaseLine) + all[fqdn[0]] = matchFQDN + } } - log.Infof("%s loaded with %d lines", Filename, len(lines)) - return lines + log.Infof("%s loaded with %d prefix, %d suffix and %d fqdn", Filename, prefix.Len(), suffix.Len(), len(all)-prefix.Len()-suffix.Len()) + return prefix, suffix, all } func performExternalQuery(question dns.Question, server string) (*dns.Msg, time.Duration, error) { dnsURL, err := url.Parse(server) if err != nil { - log.Fatalf("Invalid upstream DNS URL: %s", server) + log.Fatalf("[DNS] Invalid upstream DNS URL: %s", server) } msg := dns.Msg{ MsgHdr: dns.MsgHdr{ @@ -112,14 +153,14 @@ func performExternalQuery(question dns.Question, server string) (*dns.Msg, time. } func processQuestion(q dns.Question) ([]dns.RR, error) { - if c.AllDomains || inDomainList(q.Name) { + if c.AllDomains || !inDomainList(q.Name) { // Return the public IP. rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, c.PublicIP)) if err != nil { return nil, err } - log.Printf("returned sniproxy address for domain: %s", q.Name) + log.Infof("[DNS] returned sniproxy address for domain: %s", q.Name) return []dns.RR{rr}, nil } @@ -130,7 +171,7 @@ func processQuestion(q dns.Question) ([]dns.RR, error) { return nil, err } - log.Printf("returned origin address for domain: %s, rtt: %s", q.Name, rtt) + log.Infof("[DNS] returned origin address for domain: %s, rtt: %s", q.Name, rtt) return resp.Answer, nil } diff --git a/go.mod b/go.mod index 7afd7a1..23ee903 100644 --- a/go.mod +++ b/go.mod @@ -1,13 +1,14 @@ module github.com/mosajjal/sniproxy -go 1.19 +go 1.18 require ( + github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3 github.com/miekg/dns v1.1.50 - github.com/natesales/doqd v0.2.1 + github.com/mosajjal/doqd v0.0.0-20221017212049-9745a8eb6912 github.com/sirupsen/logrus v1.9.0 github.com/spf13/pflag v1.0.5 - golang.org/x/net v0.0.0-20221014081412-f15817d10f9b + golang.org/x/net v0.0.0-20221017152216-f25eb7ecb193 ) require ( @@ -30,7 +31,7 @@ require ( golang.org/x/crypto v0.0.0-20221012134737-56aed061732a // indirect golang.org/x/exp v0.0.0-20221012211006-4de253d81b95 // indirect golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect - golang.org/x/sys v0.0.0-20221013171732-95e765b1cc43 // indirect + golang.org/x/sys v0.1.0 // indirect golang.org/x/tools v0.1.12 // indirect google.golang.org/protobuf v1.28.1 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect diff --git a/go.sum b/go.sum index 42c4aa1..d966297 100644 --- a/go.sum +++ b/go.sum @@ -77,6 +77,8 @@ github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/me github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3 h1:zN2lZNZRflqFyxVaTIU61KNKQ9C0055u9CAfpmqUvo4= +github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3/go.mod h1:nPpo7qLxd6XL3hWJG/O60sR8ZKfMCiIoNap5GvD12KU= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -168,6 +170,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/mosajjal/doqd v0.0.0-20221017212049-9745a8eb6912 h1:vLKIgK5v4cyBdrBEJ5DbAafDzj34k1c6Sn8TfePSPRY= +github.com/mosajjal/doqd v0.0.0-20221017212049-9745a8eb6912/go.mod h1:JGEePNwJX0biMT8VFlVdtgycjIba/td/jlqGgxYxrA4= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/natesales/doqd v0.2.1 h1:I1JRd58SHZf68xIIQ8Tg2+LvOQAjXaHALjhP8jXlkzU= @@ -317,8 +321,8 @@ golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.0.0-20221014081412-f15817d10f9b h1:tvrvnPFcdzp294diPnrdZZZ8XUt2Tyj7svb7X52iDuU= -golang.org/x/net v0.0.0-20221014081412-f15817d10f9b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/net v0.0.0-20221017152216-f25eb7ecb193 h1:3Moaxt4TfzNcQH6DWvlYKraN1ozhBXQHcgvXjRGeim0= +golang.org/x/net v0.0.0-20221017152216-f25eb7ecb193/go.mod h1:RpDiru2p0u2F0lLpEoqnP2+7xs0ifAuOcJ442g6GU2s= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -387,8 +391,8 @@ golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20221013171732-95e765b1cc43 h1:OK7RB6t2WQX54srQQYSXMW8dF5C6/8+oA/s5QBmmto4= -golang.org/x/sys v0.0.0-20221013171732-95e765b1cc43/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/main.go b/main.go index 10540f3..761bc2d 100644 --- a/main.go +++ b/main.go @@ -13,9 +13,10 @@ import ( "strings" "time" + "github.com/golang-collections/collections/tst" + doqclient "github.com/mosajjal/doqd/pkg/client" + doqserver "github.com/mosajjal/doqd/pkg/server" "github.com/mosajjal/sniproxy/doh" - doqclient "github.com/natesales/doqd/pkg/client" - doqserver "github.com/natesales/doqd/pkg/server" flag "github.com/spf13/pflag" "github.com/miekg/dns" @@ -23,16 +24,18 @@ import ( ) type runConfig struct { - BindIP string `json:"bindIP"` - PublicIP string `json:"publicIP"` - UpstreamDNS string `json:"upstreamDNS"` - DomainListPath string `json:"domainListPath"` - DomainListRefreshInterval duration `json:"domainListRefreshInterval"` - BindDNSOverTCP bool `json:"bindDnsOverTcp"` - BindDNSOverTLS bool `json:"bindDnsOverTls"` - BindDNSOverQuic bool `json:"bindDnsOverQuic"` - AllDomains bool `json:"allDomains"` - routeDomainList [][]string `json:"-"` + BindIP string `json:"bindIP"` + PublicIP string `json:"publicIP"` + UpstreamDNS string `json:"upstreamDNS"` + DomainListPath string `json:"domainListPath"` + DomainListRefreshInterval duration `json:"domainListRefreshInterval"` + BindDNSOverTCP bool `json:"bindDnsOverTcp"` + BindDNSOverTLS bool `json:"bindDnsOverTls"` + BindDNSOverQuic bool `json:"bindDnsOverQuic"` + AllDomains bool `json:"allDomains"` + routePrefixes *tst.TernarySearchTree + routeSuffixes *tst.TernarySearchTree + routeFQDNs map[string]uint8 } var c runConfig @@ -140,6 +143,7 @@ func lookupDomain4(domain string) (net.IP, error) { } func handle443(conn net.Conn) error { + defer conn.Close() incoming := make([]byte, 2048) n, err := conn.Read(incoming) if err != nil { @@ -151,17 +155,19 @@ func handle443(conn net.Conn) error { log.Println(err) return err } - // rAddrDns, err := performExternalQuery(dns.Question{Name: sni + ".", Qtype: dns.TypeA, Qclass: dns.ClassINET}, *upstreamDNS) - // if err != nil { - // log.Println(err) - // return err - // } - // rAddr := rAddrDns.Answer[0].(*dns.A).A + // check SNI against domainlist for an extra layer of security + if !c.AllDomains && inDomainList(sni) { + log.Warnf("[TCP] a client requested connection to %s, but it's not allowed as per configuration.. resetting TCP", sni) + conn.Close() + return nil + } rAddr, err := lookupDomain4(sni) if err != nil || rAddr == nil { log.Println(err) return err } + // TODO: handle timeout and context here + log.Infof("[TCP] connecting to %s (%s)", rAddr, sni) target, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: rAddr, Port: 443}) if err != nil { log.Println("could not connect to target", err) @@ -221,8 +227,9 @@ func runHTTPS() { handleError(err) go func() { go handle443(c) - <-time.After(30 * time.Second) - c.Close() + //TODO: there's a better way to handle TCP timeouts than just a blanket 30 seconds rule + // <-time.After(30 * time.Second) + // c.Close() }() } } @@ -297,7 +304,7 @@ func runDNS() { } // Accept QUIC connections - log.Infof("Starting QUIC listener on %s\n", ":443") + log.Infof("Starting QUIC listener on %s\n", ":8853") go doqServer.Listen() } @@ -349,27 +356,27 @@ func main() { log.Fatalf("Invalid upstream DNS URL: %s", c.UpstreamDNS) } if dnsURL.Scheme == "quic" { - c, err := doqclient.New(dnsURL.Host, true, true) + dnsC, err := doqclient.New(dnsURL.Host, true, true) if err != nil { log.Fatalf("Failed to connect to upstream DNS: %s", err.Error()) } - dnsClient.Doq = c + dnsClient.Doq = dnsC } else if dnsURL.Scheme == "https" { - c, err := doh.New(*dnsURL, true, true) + dnsC, err := doh.New(*dnsURL, true, true) if err != nil { log.Fatalf("Failed to connect to upstream DNS: %s", err.Error()) } - dnsClient.Doh = c + dnsClient.Doh = dnsC } else { - c := dns.Client{ + dnsC := dns.Client{ Net: dnsURL.Scheme, } // this dial is not used and it's only good for testing - _, err := c.Dial(dnsURL.Host) + _, err := dnsC.Dial(dnsURL.Host) if err != nil { log.Fatalf("Failed to connect to upstream DNS: %s", err.Error()) } - dnsClient.classicDNS = c + dnsClient.classicDNS = dnsC } go runHTTP() @@ -378,9 +385,11 @@ func main() { // fetch domain list and refresh them periodically if !c.AllDomains { - c.routeDomainList = loadDomainsToList(c.DomainListPath) + // c.routeDomainList = loadDomainsToList(c.DomainListPath) + c.routePrefixes, c.routeSuffixes, c.routeFQDNs = LoadDomainsCsv(c.DomainListPath) for range time.NewTicker(c.DomainListRefreshInterval.Duration).C { - c.routeDomainList = loadDomainsToList(c.DomainListPath) + // c.routeDomainList = loadDomainsToList(c.DomainListPath) + c.routePrefixes, c.routeSuffixes, c.routeFQDNs = LoadDomainsCsv(c.DomainListPath) } } else { select {}