Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update dialer to prefer IPv6 #600

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion http_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,8 @@ func (p *Proxy) createFilterChain(bl *blacklist.Blacklist) (filters.Chain, proxy

// Google anomaly detection can be triggered very often over IPv6.
// Prefer IPv4 to mitigate, see issue #97
_dialer := preferIPV4Dialer(timeoutToDialOriginSite)
// TODO: remove the comment above when the issue is resolved
_dialer := preferIPV6Dialer(timeoutToDialOriginSite)
dialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
// resolve separately so that we can track the DNS resolution time
resolvedAddr, resolveErr := net.ResolveTCPAddr(network, addr)
Expand Down
70 changes: 70 additions & 0 deletions prefer_ipv6.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package proxy

import (
"context"
"errors"
"fmt"
"net"
"strconv"
"time"
)

func preferIPV6Dialer(timeout time.Duration) func(ctx context.Context, network, hostport string) (net.Conn, error) {
return func(ctx context.Context, network, hostport string) (net.Conn, error) {
tcpAddr, err := tcpAddrPrefer6IPv4Fallback(hostport)
if err != nil {
return nil, err
}

dialer := net.Dialer{
Deadline: time.Now().Add(timeout),
}
conn, err := dialer.DialContext(ctx, "tcp6", tcpAddr.String())
if err != nil {
var e *net.AddrError
// if this is a network address error, we will retry with the specified network instead (tcp4 most likely)
if errors.As(err, &e) {
conn, err = dialer.DialContext(ctx, network, hostport)
}
}
return conn, err
}
}

func tcpAddrPrefer6IPv4Fallback(hostport string) (*net.TCPAddr, error) {
host, portStr, err := net.SplitHostPort(hostport)
if err != nil {
return nil, err
}

port, err := strconv.Atoi(portStr)
if err != nil {
return nil, err
}

// Attempt to directly resolve as IPv6 to avoid unnecessary lookups
ipv6Addr, err := net.ResolveIPAddr("ip6", host)
if err == nil && ipv6Addr.IP.To4() == nil {
return &net.TCPAddr{IP: ipv6Addr.IP, Port: port}, nil
}

// If IPv6 resolution failed, fall back to a full lookup and prefer any IPv6 addresses found
ips, err := net.LookupIP(host)
if err != nil {
return nil, err
}

for _, ip := range ips {
if ip.To4() == nil {
return &net.TCPAddr{IP: ip, Port: port}, nil
}
}

// If no IPv6 addresses are found, try resolving as IPv4
ipv4Addr, err := net.ResolveTCPAddr("tcp4", hostport)
if err == nil {
return ipv4Addr, nil
}

return nil, fmt.Errorf("unable to resolve any IP addresses for host: %s", host)
}
84 changes: 84 additions & 0 deletions prefer_ipv6_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package proxy

import (
"context"
"net"
"testing"
"time"
)

func TestPreferIPV6Dialer(t *testing.T) {
tests := []struct {
name string
timeout time.Duration
network string
hostport string
server func(t *testing.T, hostport string) func()
wantErr bool
}{
{
name: "IPv6 address",
timeout: 1 * time.Second,
network: "tcp",
hostport: "[::1]:8080",
server: func(t *testing.T, hostport string) func() { return runTestServer(t, "[::1]:8080") },
wantErr: false,
},
{
name: "IPv4 address",
timeout: 1 * time.Second,
network: "tcp",
hostport: "127.0.0.1:8080",
server: func(t *testing.T, hostport string) func() { return runTestServer(t, "127.0.0.1:8080") },
wantErr: false,
},
{
name: "Invalid address",
timeout: 1 * time.Second,
network: "tcp",
hostport: "invalid",
server: nil,
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.server != nil {
closer := tt.server(t, tt.hostport)
defer closer()
}

dialer := preferIPV6Dialer(tt.timeout)
conn, err := dialer(context.Background(), tt.network, tt.hostport)
if (err != nil) != tt.wantErr {
t.Errorf("preferIPV6Dialer() error = %v, wantErr %v", err, tt.wantErr)
return
}
if conn != nil {
conn.Close()
}
})
}
}

func runTestServer(t *testing.T, addr string) func() {
listener, err := net.Listen("tcp", addr)
if err != nil {
t.Fatalf("Failed to start server: %v", err)
}

go func() {
for {
conn, err := listener.Accept()
if err != nil {
return
}
conn.Close()
}
}()

return func() {
listener.Close()
}
}
Loading