From e70975e506d4f5db7b310e3d0dfc619dcaac9e8e Mon Sep 17 00:00:00 2001 From: Kiran Kumar Mohanty Date: Wed, 21 Feb 2024 19:06:53 +0530 Subject: [PATCH] Update dialer to prefer IPv6 --- http_proxy.go | 3 +- prefer_ipv6.go | 70 +++++++++++++++++++++++++++++++++++++ prefer_ipv6_test.go | 84 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 156 insertions(+), 1 deletion(-) create mode 100644 prefer_ipv6.go create mode 100644 prefer_ipv6_test.go diff --git a/http_proxy.go b/http_proxy.go index e315ccff..166310da 100644 --- a/http_proxy.go +++ b/http_proxy.go @@ -604,7 +604,8 @@ func (p *Proxy) createFilterChain(bl *blacklist.Blacklist) (filters.Chain, proxy // Google anomaly detection can be triggered very often over IPv6. // Prefer IPv4 to mitigate, see issue #97 - _dialer := preferIPV4Dialer(timeoutToDialOriginSite) + // TODO: remove the comment above when the issue is resolved + _dialer := preferIPV6Dialer(timeoutToDialOriginSite) dialer := func(ctx context.Context, network, addr string) (net.Conn, error) { // resolve separately so that we can track the DNS resolution time resolvedAddr, resolveErr := net.ResolveTCPAddr(network, addr) diff --git a/prefer_ipv6.go b/prefer_ipv6.go new file mode 100644 index 00000000..e422db97 --- /dev/null +++ b/prefer_ipv6.go @@ -0,0 +1,70 @@ +package proxy + +import ( + "context" + "errors" + "fmt" + "net" + "strconv" + "time" +) + +func preferIPV6Dialer(timeout time.Duration) func(ctx context.Context, network, hostport string) (net.Conn, error) { + return func(ctx context.Context, network, hostport string) (net.Conn, error) { + tcpAddr, err := tcpAddrPrefer6IPv4Fallback(hostport) + if err != nil { + return nil, err + } + + dialer := net.Dialer{ + Deadline: time.Now().Add(timeout), + } + conn, err := dialer.DialContext(ctx, "tcp6", tcpAddr.String()) + if err != nil { + var e *net.AddrError + // if this is a network address error, we will retry with the specified network instead (tcp4 most likely) + if errors.As(err, &e) { + conn, err = dialer.DialContext(ctx, network, hostport) + } + } + return conn, err + } +} + +func tcpAddrPrefer6IPv4Fallback(hostport string) (*net.TCPAddr, error) { + host, portStr, err := net.SplitHostPort(hostport) + if err != nil { + return nil, err + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return nil, err + } + + // Attempt to directly resolve as IPv6 to avoid unnecessary lookups + ipv6Addr, err := net.ResolveIPAddr("ip6", host) + if err == nil && ipv6Addr.IP.To4() == nil { + return &net.TCPAddr{IP: ipv6Addr.IP, Port: port}, nil + } + + // If IPv6 resolution failed, fall back to a full lookup and prefer any IPv6 addresses found + ips, err := net.LookupIP(host) + if err != nil { + return nil, err + } + + for _, ip := range ips { + if ip.To4() == nil { + return &net.TCPAddr{IP: ip, Port: port}, nil + } + } + + // If no IPv6 addresses are found, try resolving as IPv4 + ipv4Addr, err := net.ResolveTCPAddr("tcp4", hostport) + if err == nil { + return ipv4Addr, nil + } + + return nil, fmt.Errorf("unable to resolve any IP addresses for host: %s", host) +} diff --git a/prefer_ipv6_test.go b/prefer_ipv6_test.go new file mode 100644 index 00000000..d9942e42 --- /dev/null +++ b/prefer_ipv6_test.go @@ -0,0 +1,84 @@ +package proxy + +import ( + "context" + "net" + "testing" + "time" +) + +func TestPreferIPV6Dialer(t *testing.T) { + tests := []struct { + name string + timeout time.Duration + network string + hostport string + server func(t *testing.T, hostport string) func() + wantErr bool + }{ + { + name: "IPv6 address", + timeout: 1 * time.Second, + network: "tcp", + hostport: "[::1]:8080", + server: func(t *testing.T, hostport string) func() { return runTestServer(t, "[::1]:8080") }, + wantErr: false, + }, + { + name: "IPv4 address", + timeout: 1 * time.Second, + network: "tcp", + hostport: "127.0.0.1:8080", + server: func(t *testing.T, hostport string) func() { return runTestServer(t, "127.0.0.1:8080") }, + wantErr: false, + }, + { + name: "Invalid address", + timeout: 1 * time.Second, + network: "tcp", + hostport: "invalid", + server: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.server != nil { + closer := tt.server(t, tt.hostport) + defer closer() + } + + dialer := preferIPV6Dialer(tt.timeout) + conn, err := dialer(context.Background(), tt.network, tt.hostport) + if (err != nil) != tt.wantErr { + t.Errorf("preferIPV6Dialer() error = %v, wantErr %v", err, tt.wantErr) + return + } + if conn != nil { + conn.Close() + } + }) + } +} + +func runTestServer(t *testing.T, addr string) func() { + listener, err := net.Listen("tcp", addr) + if err != nil { + t.Fatalf("Failed to start server: %v", err) + } + + go func() { + for { + conn, err := listener.Accept() + if err != nil { + return + } + conn.Close() + } + }() + + return func() { + listener.Close() + } +}