Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix non-root hosts failing on resolving DNS #2269

Merged
merged 1 commit into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 69 additions & 58 deletions rpcclient/infrastructure.go
Original file line number Diff line number Diff line change
Expand Up @@ -759,41 +759,26 @@ out:
// result, unmarshalling it, and delivering the unmarshalled result to the
// provided response channel.
func (c *Client) handleSendPostMessage(jReq *jsonRequest) {
protocol := "http"
if !c.config.DisableTLS {
protocol = "https"
}

var (
err, lastErr error
lastErr error
backoff time.Duration
httpResponse *http.Response
)

parsedAddr, err := ParseAddressString(c.config.Host)
httpURL, err := c.config.httpURL()
if err != nil {
jReq.responseChan <- &Response{
err: fmt.Errorf("failed to parse address %v", err),
}
return
}

var url string
switch parsedAddr.Network() {
case "unix", "unixpacket":
// Using a placeholder URL because a non-empty URL is required.
// The Unix domain socket is specified in the DialContext.
url = protocol + "://unix"
default:
url = protocol + "://" + c.config.Host
}

tries := 10
for i := 0; i < tries; i++ {
var httpReq *http.Request

bodyReader := bytes.NewReader(jReq.marshalledJSON)
httpReq, err = http.NewRequest("POST", url, bodyReader)
httpReq, err = http.NewRequest("POST", httpURL, bodyReader)
if err != nil {
jReq.responseChan <- &Response{result: nil, err: err}
return
Expand Down Expand Up @@ -1355,16 +1340,21 @@ func newHTTPClient(config *ConnConfig) (*http.Client, error) {
}
}

parsedAddr, err := ParseAddressString(config.Host)
parsedDialAddr, err := ParseAddressString(config.Host)
if err != nil {
return nil, err
}
client := http.Client{
Transport: &http.Transport{
Proxy: proxyFunc,
TLSClientConfig: tlsConfig,
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return net.Dial(parsedAddr.Network(), parsedAddr.String())
DialContext: func(_ context.Context, _,
_ string) (net.Conn, error) {

return net.Dial(
parsedDialAddr.Network(),
parsedDialAddr.String(),
)
},
},
Timeout: defaultHTTPTimeout,
Expand All @@ -1373,6 +1363,32 @@ func newHTTPClient(config *ConnConfig) (*http.Client, error) {
return &client, nil
}

// httpURL returns the URL to use for HTTP POST requests.
func (config *ConnConfig) httpURL() (string, error) {
protocol := "http"
if !config.DisableTLS {
protocol = "https"
}

parsedAddr, err := ParseAddressString(config.Host)
if err != nil {
return "", fmt.Errorf("error parsing host '%v': %v",
config.Host, err)
}

var httpURL string
switch parsedAddr.Network() {
case "unix", "unixpacket":
// Using a placeholder URL because a non-empty URL is required.
// The Unix domain socket is specified in the DialContext.
httpURL = protocol + "://unix"
default:
httpURL = protocol + "://" + config.Host
}

return httpURL, nil
}

// dial opens a websocket connection using the passed connection configuration
// details.
func dial(config *ConnConfig) (*websocket.Conn, error) {
Expand Down Expand Up @@ -1733,53 +1749,48 @@ func (c *Client) Send() error {
return nil
}

// cutPrefix returns s without the provided leading prefix string
// and reports whether it found the prefix.
// If s doesn't start with prefix, cutPrefix returns s, false.
// If prefix is the empty string, cutPrefix returns s, true.
// Copied from go1.20 version.
func cutPrefix(s, prefix string) (after string, found bool) {
if !strings.HasPrefix(s, prefix) {
return s, false
}
return s[len(prefix):], true
}

// ParseAddressString converts an address in string format to a net.Addr that is
// compatible with btcd. UDP is not supported because btcd needs reliable
// connections. We accept a custom function to resolve any TCP addresses so
// that caller is able control exactly how resolution is performed.
// connections.
func ParseAddressString(strAddress string) (net.Addr, error) {
var parsedNetwork, parsedAddr string
// Addresses can either be in unix://address, unixpacket://address URL
// format, or just address:port host format for tcp.
if after, ok := cutPrefix(strAddress, "unix://"); ok {
return net.ResolveUnixAddr("unix", after)
}
if after, ok := cutPrefix(strAddress, "unixpacket://"); ok {
return net.ResolveUnixAddr("unixpacket", after)
}

// Addresses can either be in network://address:port format,
// network:address:port, address:port, or just port. We want to support
// all possible types.
if strings.Contains(strAddress, "://") {
parts := strings.Split(strAddress, "://")
parsedNetwork, parsedAddr = parts[0], parts[1]
} else if strings.Contains(strAddress, ":") {
parts := strings.Split(strAddress, ":")
parsedNetwork = parts[0]
parsedAddr = strings.Join(parts[1:], ":")
} else {
parsedAddr = strAddress
// Not supporting :// anywhere in the host or path.
return nil, fmt.Errorf("unsupported protocol in address: %s",
strAddress)
}

// Only TCP and Unix socket addresses are valid. We can't use IP or
// UDP only connections for anything we do in lnd.
switch parsedNetwork {
case "unix", "unixpacket":
return net.ResolveUnixAddr(parsedNetwork, parsedAddr)

case "tcp", "tcp4", "tcp6":
return net.ResolveTCPAddr(parsedNetwork, verifyPort(parsedAddr))

case "ip", "ip4", "ip6", "udp", "udp4", "udp6", "unixgram":
return nil, fmt.Errorf("only TCP or unix socket "+
"addresses are supported: %s", parsedAddr)

default:
// We'll now possibly use the local host short circuit
// or parse out an all interfaces listen.
addrWithPort := verifyPort(strAddress)

// Otherwise, we'll attempt to resolve the host.
return net.ResolveTCPAddr("tcp", addrWithPort)
// Parse it as a dummy URL to get the host and port.
u, err := url.Parse("dummy://" + strAddress)
if err != nil {
return nil, err
}
return net.ResolveTCPAddr("tcp", verifyPort(u.Host))
guggero marked this conversation as resolved.
Show resolved Hide resolved
}

// verifyPort makes sure that an address string has both a host and a port.
// If the address is just a port, then we'll assume that the user is using the
// short cut to specify a localhost:port address.
// shortcut to specify a localhost:port address.
func verifyPort(address string) string {
host, port, err := net.SplitHostPort(address)
if err != nil {
Expand All @@ -1801,8 +1812,8 @@ func verifyPort(address string) string {
return net.JoinHostPort(address, "")
}

// In the case that both the host and port are empty, we'll use the
// an empty port.
// In the case that both the host and port are empty, we'll use an empty
// port.
if host == "" && port == "" {
return ":"
}
Expand Down
110 changes: 110 additions & 0 deletions rpcclient/infrastructure_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package rpcclient

import (
"testing"

"github.com/stretchr/testify/require"
)

// TestParseAddressString checks different variation of supported and
// unsupported addresses.
func TestParseAddressString(t *testing.T) {
t.Parallel()

// Using localhost only to avoid network calls.
testCases := []struct {
name string
addressString string
expNetwork string
expAddress string
expErrStr string
}{
{
name: "localhost",
addressString: "localhost",
expNetwork: "tcp",
expAddress: "127.0.0.1:0",
},
{
name: "localhost ip",
addressString: "127.0.0.1",
expNetwork: "tcp",
expAddress: "127.0.0.1:0",
},
{
name: "localhost ipv6",
addressString: "::1",
expNetwork: "tcp",
expAddress: "[::1]:0",
},
{
name: "localhost and port",
addressString: "localhost:80",
expNetwork: "tcp",
expAddress: "127.0.0.1:80",
},
{
name: "localhost ipv6 and port",
addressString: "[::1]:80",
expNetwork: "tcp",
expAddress: "[::1]:80",
},
{
name: "colon and port",
addressString: ":80",
expNetwork: "tcp",
expAddress: ":80",
},
{
name: "colon only",
addressString: ":",
expNetwork: "tcp",
expAddress: ":0",
},
{
name: "localhost and path",
addressString: "localhost/path",
expNetwork: "tcp",
expAddress: "127.0.0.1:0",
},
{
name: "localhost port and path",
addressString: "localhost:80/path",
expNetwork: "tcp",
expAddress: "127.0.0.1:80",
},
{
name: "unix prefix",
addressString: "unix://the/rest/of/the/path",
expNetwork: "unix",
expAddress: "the/rest/of/the/path",
},
{
name: "unix prefix",
addressString: "unixpacket://the/rest/of/the/path",
expNetwork: "unixpacket",
expAddress: "the/rest/of/the/path",
},
{
name: "error http prefix",
addressString: "http://localhost:1010",
expErrStr: "unsupported protocol in address",
},
}

for _, tc := range testCases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
addr, err := ParseAddressString(tc.addressString)
if tc.expErrStr != "" {
require.Error(t, err)
require.Contains(t, err.Error(), tc.expErrStr)
return
}
require.NoError(t, err)
require.Equal(t, tc.expNetwork, addr.Network())
require.Equal(t, tc.expAddress, addr.String())
})
}
}
Loading