From cf6826919b0aefac23758852ce97320eafcc4ea6 Mon Sep 17 00:00:00 2001 From: ptrus Date: Wed, 2 Oct 2024 10:15:20 +0200 Subject: [PATCH] go/common/grpc: allow non-tls connections to loopback addresses --- .changelog/5878.feature.md | 1 + go/common/grpc/grpc.go | 55 +++++++++++++++++++++++++++ go/common/grpc/grpc_test.go | 45 ++++++++++++++++++++++ go/oasis-node/cmd/common/grpc/grpc.go | 13 +++++-- 4 files changed, 111 insertions(+), 3 deletions(-) create mode 100644 .changelog/5878.feature.md create mode 100644 go/common/grpc/grpc_test.go diff --git a/.changelog/5878.feature.md b/.changelog/5878.feature.md new file mode 100644 index 00000000000..12552d01cf9 --- /dev/null +++ b/.changelog/5878.feature.md @@ -0,0 +1 @@ +go/common/grpc: allow non-tls connections to loopback addresses diff --git a/go/common/grpc/grpc.go b/go/common/grpc/grpc.go index aef91c881dd..25ae5bab693 100644 --- a/go/common/grpc/grpc.go +++ b/go/common/grpc/grpc.go @@ -9,6 +9,7 @@ import ( "net" "os" "strconv" + "strings" "sync" "sync/atomic" "time" @@ -703,6 +704,60 @@ func Dial(target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) { return grpc.Dial(target, dialOpts...) } +// IsSocketAddress checks if the gRPC address is a socket address. +func IsSocketAddress(addr string) bool { + if strings.HasPrefix(addr, "unix:") || strings.HasPrefix(addr, "unix-abstract:") { + return true + } + if strings.HasPrefix(addr, "vsock:") { + return true + } + return false +} + +// IsLocalAddress checks if the gRPC address points to a local socket or loopback address. +// +// This function takes a conservative approach and may return false for complex addresses +// such as those with authorities or non-standard schemes. +// +// The expected format for the gRPC address is specified at: +// https://github.com/grpc/grpc/blob/master/doc/naming.md +func IsLocalAddress(addr string) bool { + // Sockets are considered local. + if IsSocketAddress(addr) { + return true + } + + // Strip dns: scheme if present, other schemes are not supported. + if strings.HasPrefix(addr, "dns:") { + addr = strings.TrimPrefix(addr, "dns:") + + // If authority is present, consider the address non-local as it might rely on complex resolver logic. + if strings.HasPrefix(addr, "//") { + return false + } + } + + // Validate the address. + host, _, err := net.SplitHostPort(addr) + if err != nil { + // Try parsing with a port. + host, _, err = net.SplitHostPort(addr + ":" + "80") + if err != nil { + // This means that the scheme was not trimmed (e.g. was not 'dns:'). + // Consider such addresses non-local as they might rely on complex resolver logic. + return false + } + // The address is parsed fine with port attached, continue. + } + + ip, err := net.LookupIP(host) + if err != nil || len(ip) == 0 { + return false + } + return ip[0].IsLoopback() +} + func init() { Flags.Bool(CfgLogDebug, false, "gRPC request/responses in debug logs (very verbose)") _ = Flags.MarkHidden(CfgLogDebug) diff --git a/go/common/grpc/grpc_test.go b/go/common/grpc/grpc_test.go new file mode 100644 index 00000000000..d18d4e18d73 --- /dev/null +++ b/go/common/grpc/grpc_test.go @@ -0,0 +1,45 @@ +package grpc + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsLocalRPC(t *testing.T) { + for _, tc := range []struct { + name string + addr string + expected bool + }{ + // Invalid inputs. + {"Empty", "", false}, + {"Invalid scheme", "test:localhost:8080", false}, + + // Local unix sockets/vsock. + {"Local unix socket", "unix:///tmp/socket", true}, + {"Local unix socket (abstract)", "unix-abstract:abstract_path", true}, + {"Local vsock", "vsock:2:12345", true}, + + // Loopback addresses. + {"Local IPv4 loopback", "127.0.0.1:8080", true}, + {"Local IPv4 loopback no port", "127.0.0.1", true}, + {"Local IPv6 loopback", "[::1]:8080", true}, + {"Local IPv6 loopback no port", "[::1]", true}, + {"Localhost no port", "localhost", true}, + {"Localhost", "localhost:8080", true}, + {"Localhost explicit scheme", "dns:localhost:8080", true}, + + // Non-local addresses. + {"Non-local address", "example.com", false}, + {"Non-local address with port", "example.com:8080", false}, + {"Non-local with explicit scheme", "dns:example.com", false}, + {"Non-local with authority", "dns://authority/example.com:8080", false}, + + // Complex addresses (conservatively considered non-local). + {"Localhost with authority", "dns://authority/localhost:8080", false}, + {"Localhost non-standard scheme", "test:localhost:8080", false}, + } { + require.Equal(t, tc.expected, IsLocalAddress(tc.addr), tc.name+": "+tc.addr) + } +} diff --git a/go/oasis-node/cmd/common/grpc/grpc.go b/go/oasis-node/cmd/common/grpc/grpc.go index 5afe5516a19..7e7a66220b0 100644 --- a/go/oasis-node/cmd/common/grpc/grpc.go +++ b/go/oasis-node/cmd/common/grpc/grpc.go @@ -5,7 +5,6 @@ import ( "crypto/tls" "fmt" "os" - "strings" "github.com/spf13/cobra" flag "github.com/spf13/pflag" @@ -27,6 +26,8 @@ const ( CfgAddress = "address" // CfgWait waits for the remote address to become available. CfgWait = "wait" + // CfgInsecureLoopback allows non-TLS connection to loopback addresses. + CfgInsecureLoopback = "insecure" defaultAddress = "unix:" + common.InternalSocketName ) @@ -81,9 +82,14 @@ func NewClient(cmd *cobra.Command) (*grpc.ClientConn, error) { } var creds credentials.TransportCredentials - if strings.HasPrefix(addr, "unix:") { + switch { + case cmnGrpc.IsSocketAddress(addr): creds = insecure.NewCredentials() - } else { + case viper.GetBool(CfgInsecureLoopback) && cmnGrpc.IsLocalAddress(addr): + creds = insecure.NewCredentials() + case viper.GetBool(CfgInsecureLoopback): + return nil, fmt.Errorf("insecure loopback requested but address is not loopback: %s", addr) + default: creds = credentials.NewTLS(&tls.Config{}) } opts := []grpc.DialOption{grpc.WithTransportCredentials(creds)} @@ -112,6 +118,7 @@ func init() { ClientFlags.StringP(CfgAddress, "a", defaultAddress, "remote gRPC address") ClientFlags.Bool(CfgWait, false, "wait for gRPC address to become available") + ClientFlags.BoolP(CfgInsecureLoopback, "k", false, "allows non-TLS connection to loopback addresses") ClientFlags.AddFlagSet(cmnGrpc.Flags) _ = viper.BindPFlags(ClientFlags) }