From 7b659c6bb1fec0240a3c93692bd71af3987a41b3 Mon Sep 17 00:00:00 2001 From: Johnny Graettinger Date: Fri, 20 Sep 2024 00:13:49 -0500 Subject: [PATCH] go/network: refactored and updated connector networking feature Move connector networking entirely into this repo, from the legacy data-plane-gatweay repo, and significantly retool it along the way to: * Improve latency and throughput of HTTP reverse-proxy cases, by allowing the reverse proxy to use multiple pooled connections built atop network proxy RPCs with reasonable idle timeouts. This improves concurrency as many HTTP/2 requests can be in flight at once, and improves latency to the end user by ammortizing connections to reduce aggregate TCP and TLS startup time. * Improve user-facing error experience around misconfigurations, by often assuming an HTTP protocol and yielding a more informative error. * Overhauling metrics that we collect. * Updating the authorization flow, laying groundwork for the UI to use the /authorize/user/task API (but not requiring it just yet). --- go/network/auth.go | 120 ++++++ go/network/frontend.go | 392 ++++++++++++++++++ go/network/metrics.go | 51 +++ go/network/proxy_client.go | 196 +++++++++ .../proxy.go => network/proxy_server.go} | 55 +-- go/network/sni.go | 115 +++++ go/network/tap.go | 50 +++ go/runtime/flow_consumer.go | 41 +- go/runtime/task.go | 2 +- 9 files changed, 972 insertions(+), 50 deletions(-) create mode 100644 go/network/auth.go create mode 100644 go/network/frontend.go create mode 100644 go/network/metrics.go create mode 100644 go/network/proxy_client.go rename go/{runtime/proxy.go => network/proxy_server.go} (71%) create mode 100644 go/network/sni.go create mode 100644 go/network/tap.go diff --git a/go/network/auth.go b/go/network/auth.go new file mode 100644 index 0000000000..8f7c581293 --- /dev/null +++ b/go/network/auth.go @@ -0,0 +1,120 @@ +package network + +import ( + "errors" + "fmt" + "net/http" + "net/url" + + pb "go.gazette.dev/core/broker/protocol" + "google.golang.org/grpc/metadata" +) + +// verifyAuthorization ensures the request has an authorization which +// is valid for capability NETWORK_PROXY to `taskName`. +func verifyAuthorization(req *http.Request, verifier pb.Verifier, taskName string) error { + var bearer = req.Header.Get("authorization") + if bearer != "" { + // Pass. + } else if cookie, err := req.Cookie(AuthCookieName); err == nil { + bearer = fmt.Sprintf("Bearer %s", cookie.Value) + } else { + return errors.New("missing authorization") + } + + var _, cancel, claims, err = verifier.Verify( + metadata.NewIncomingContext( + req.Context(), + metadata.Pairs("authorization", bearer), + ), + 0, // TODO(johnny): Should be pf.Capability_NETWORK_PROXY. + ) + if err != nil { + return err + } + cancel() // We don't use the returned context. + + /* TODO(johnny): Inspect claims once UI is updated to use /authorize/user/task API. + if !claims.Selector.Matches(pb.MustLabelSet( + labels.TaskName, taskName, + )) { + return fmt.Errorf("invalid authorization for task %s (%s)", taskName, bearer) + } + */ + _ = claims + + return nil +} + +// startAuthRedirect redirect an interactive user to the dashboard, which will +// obtain a user task authorization and redirect back to us with it. +func startAuthRedirect(w http.ResponseWriter, req *http.Request, err error, dashboard *url.URL, taskName string) { + var query = make(url.Values) + query.Add("orig_url", "https://"+req.Host+req.URL.Path) + query.Add("task", taskName) + query.Add("prefix", taskName) + query.Add("err", err.Error()) // Informational. + + var target = dashboard.JoinPath("/data-plane-auth-req") + target.RawQuery = query.Encode() + + http.Redirect(w, req, target.String(), http.StatusTemporaryRedirect) +} + +// completeAuthRedirect handles path "/auth-redirect" as part of a redirect chain +// back from the dashboard. It expects a token parameter, which is set as a cookie, +// and an original URL which it in-turn redirects to. +func completeAuthRedirect(w http.ResponseWriter, req *http.Request) { + var params = req.URL.Query() + + var token = params.Get("token") + if token == "" { + http.Error(w, "URL is missing required `token` parameter", http.StatusBadRequest) + return + } + var origUrl = params.Get("orig_url") + if origUrl == "" { + http.Error(w, "URL is missing required `orig_url` parameter", http.StatusBadRequest) + return + } + + var cookie = &http.Cookie{ + Name: AuthCookieName, + Value: token, + Secure: true, + HttpOnly: true, + Path: "/", + } + http.SetCookie(w, cookie) + + http.Redirect(w, req, origUrl, http.StatusTemporaryRedirect) +} + +func scrubProxyRequest(req *http.Request, public bool) { + if _, ok := req.Header["User-Agent"]; !ok { + req.Header.Set("User-Agent", "") // Omit auto-added User-Agent. + } + + if public { + return // All done. + } + + // Scrub authentication token(s) from the request. + req.Header.Del("Authorization") + + // There's no `DeleteCookie` function, so we parse them, delete them all, and + // add them back in while filtering out the flow_auth cookie. + var cookies = req.Cookies() + req.Header.Del("Cookie") + + for _, cookie := range cookies { + if cookie.Name != AuthCookieName { + req.AddCookie(cookie) + } + } +} + +// AuthCookieName is the name of the cookie that we use for passing the JWT for interactive logins. +// It's name begins with '__Host-' in order to opt in to some additional security restrictions. +// See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#cookie_prefixes +const AuthCookieName = "__Host-flow_auth" diff --git a/go/network/frontend.go b/go/network/frontend.go new file mode 100644 index 0000000000..5c2fbffc1e --- /dev/null +++ b/go/network/frontend.go @@ -0,0 +1,392 @@ +package network + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "math" + "net" + "net/http" + "net/http/httputil" + "net/url" + "reflect" + "slices" + "strings" + "sync" + "time" + + pf "github.com/estuary/flow/go/protocols/flow" + lru "github.com/hashicorp/golang-lru/v2" + pb "go.gazette.dev/core/broker/protocol" + pc "go.gazette.dev/core/consumer/protocol" + "golang.org/x/net/http2" +) + +// Frontend accepts connections over its configured Listener, +// matches on the TLS ServerName (SNI), and either: +// +// - Passes unmatched connections on to the tapped Listener +// - Attaches the connection to a connector via TCP proxy +// - Serves the connection as HTTP/2 using an authorizing reverse proxy +// - Or, returns an HTTP/1.1 descriptive error about what went wrong +// +// If Frontend is not running in a TLS context, all connections are +// trivially passed to the tapped Listener. +type Frontend struct { + controlAPI *url.URL + dashboard *url.URL + domains []string + networkClient pf.NetworkProxyClient + shardClient pc.ShardClient + verifier pb.Verifier + + listener net.Listener // Tapped listener. + tlsConfig *tls.Config // TLS config for accepted connections. + + // Forwarding channel for connections that are passed on. + fwdCh chan<- net.Conn + fwdErr *error + + // Cache of mappings from parsed to resolved SNIs. + sniCache *lru.Cache[parsedSNI, resolvedSNI] + + // Map of frontend connections currently undergoing TLS handshake. + handshake map[uintptr]*frontendConn + handshakeMu sync.Mutex +} + +// frontendConn is the state of a connection initiated +// by a user into the Frontend. +type frontendConn struct { + id uintptr + ctx context.Context + + // Raw and TLS-wrapped connections to the user. + raw net.Conn + tls *tls.Conn + + pass bool + parsed parsedSNI + resolved resolvedSNI + + // Error while resolving SNI to a mapped task: + // the SNI is invalid with respect to current task config. + sniErr error + // proxyClient dialed during TLS handshake. + // If set, we are acting as a TCP proxy. + dialed *proxyClient + // Error while dialing a shard for TCP proxy: + // this is an internal and usually-temporary error. + dialErr error +} + +func NewFrontend( + tap *Tap, + fqdn string, + controlAPI *url.URL, + dashboard *url.URL, + networkClient pf.NetworkProxyClient, + shardClient pc.ShardClient, + verifier pb.Verifier, +) (*Frontend, error) { + if tap.raw == nil { + return nil, fmt.Errorf("Tap has not tapped a raw net.Listener") + } + + var domains, fqdnParts []string = nil, strings.Split(fqdn, ".") + for i := range fqdnParts { + domains = append(domains, strings.Join(fqdnParts[i:], ".")) + } + var sniCache, err = lru.New[parsedSNI, resolvedSNI](1024) + if err != nil { + panic(err) + } + + var proxy = &Frontend{ + controlAPI: controlAPI, + dashboard: dashboard, + domains: domains, + networkClient: networkClient, + shardClient: shardClient, + verifier: verifier, + listener: tap.raw, + tlsConfig: tap.config, + fwdCh: tap.fwdCh, + fwdErr: &tap.fwdErr, + sniCache: sniCache, + handshake: make(map[uintptr]*frontendConn), + } + if proxy.tlsConfig != nil { + proxy.tlsConfig.GetConfigForClient = proxy.getTLSConfigForClient + } + + return proxy, nil +} + +func (p *Frontend) Serve(ctx context.Context) (_err error) { + defer func() { + // Forward terminal error to callers of Tap.Accept(). + *p.fwdErr = _err + close(p.fwdCh) + }() + + for { + var raw, err = p.listener.Accept() + if err != nil { + return err + } + if p.tlsConfig == nil { + p.fwdCh <- raw // Not serving TLS. + continue + } + go p.serveConn(ctx, raw) + } +} + +func (p *Frontend) serveConn(ctx context.Context, raw net.Conn) { + var conn = &frontendConn{ + id: reflect.ValueOf(raw).Pointer(), + ctx: ctx, + raw: raw, + tls: tls.Server(raw, p.tlsConfig), + } + + // Push `conn` onto the map of current handshakes. + p.handshakeMu.Lock() + p.handshake[conn.id] = conn + p.handshakeMu.Unlock() + + // The TLS handshake machinery will next call into getTLSConfigForClient(). + var err = conn.tls.HandshakeContext(conn.ctx) + + // Clear `conn` from the map of current handshakes. + p.handshakeMu.Lock() + delete(p.handshake, conn.id) + p.handshakeMu.Unlock() + + if err != nil { + handshakeCounter.WithLabelValues(err.Error()).Inc() // `err` is low-variance. + p.serveConnErr(conn.raw, 421, "This service may only be accessed using TLS, such as through an https:// URL.\n") + return + } + if conn.pass { + handshakeCounter.WithLabelValues("OKPass").Inc() + p.fwdCh <- conn.tls // Connection is not for us. + return + } + + if conn.sniErr != nil { + handshakeCounter.WithLabelValues("ErrSNI").Inc() + p.serveConnErr(conn.tls, 404, fmt.Sprintf("Failed to match the connection to a task:\n\t%s\n", conn.sniErr)) + } else if conn.dialErr != nil { + handshakeCounter.WithLabelValues("ErrDial").Inc() + p.serveConnErr(conn.tls, 503, fmt.Sprintf("Failed to connect to a task shard:\n\t%s\n", conn.dialErr)) + } else if conn.dialed != nil { + handshakeCounter.WithLabelValues("OkTCP").Inc() + p.serveConnTCP(conn) + } else { + handshakeCounter.WithLabelValues("OkHTTP").Inc() + p.serveConnHTTP(conn) + } +} + +func (p *Frontend) getTLSConfigForClient(hello *tls.ClientHelloInfo) (*tls.Config, error) { + p.handshakeMu.Lock() + var conn = p.handshake[reflect.ValueOf(hello.Conn).Pointer()] + p.handshakeMu.Unlock() + + // Exact match of the FQDN or a parent domain means it's not for us. + if slices.Contains(p.domains, hello.ServerName) { + conn.pass = true + return nil, nil + } + + var ok bool + var target, service, _ = strings.Cut(hello.ServerName, ".") + + // This block parses the SNI `target` and matches it to shard configuration. + if !slices.Contains(p.domains, service) { + conn.sniErr = fmt.Errorf("TLS ServerName %s is an invalid domain", hello.ServerName) + } else if conn.parsed, conn.sniErr = parseSNI(target); conn.sniErr != nil { + // No need to wrap error. + } else if conn.resolved, ok = p.sniCache.Get(conn.parsed); !ok { + // We didn't hit cache while resolving the parsed SNI. + // We must fetch matching shard specs to inspect their shard ID prefix and port config. + var shards []pc.ListResponse_Shard + shards, conn.sniErr = listShards(hello.Context(), p.shardClient, conn.parsed, "") + + if conn.sniErr != nil { + conn.sniErr = fmt.Errorf("fetching shards: %w", conn.sniErr) + } else if len(shards) == 0 { + conn.sniErr = errors.New("the requested subdomain does not match a known task and port combination") + } else { + conn.resolved = newResolvedSNI(conn.parsed, &shards[0].Spec) + p.sniCache.Add(conn.parsed, conn.resolved) + } + } + + if conn.sniErr == nil && conn.resolved.portProtocol != "" { + // We intend to TCP proxy to the connector. Dial the shard now so that + // we fail-fast during TLS handshake, instead of letting the client + // think it has a good connection. + var addr = conn.raw.RemoteAddr().String() + conn.dialed, conn.dialErr = dialShard( + conn.ctx, p.networkClient, p.shardClient, conn.parsed, conn.resolved, addr) + } + + var nextProtos []string + if conn.sniErr != nil || conn.dialErr != nil { + nextProtos = []string{"http/1.1"} // We'll send a descriptive HTTP/1.1 error. + } else if conn.dialed == nil { + nextProtos = []string{"h2"} // We'll reverse-proxy. The user MUST speak HTTP/2. + } else { + nextProtos = []string{conn.resolved.portProtocol} // We'll TCP proxy. + } + + return &tls.Config{ + Certificates: p.tlsConfig.Certificates, + NextProtos: nextProtos, + }, nil +} + +func (p *Frontend) serveConnTCP(user *frontendConn) { + var task, port, proto = user.resolved.taskName, user.parsed.port, user.resolved.portProtocol + userStartedCounter.WithLabelValues(task, port, proto).Inc() + + // Enable TCP keep-alive to ensure broken user connections are closed. + if tcpConn, ok := user.raw.(*net.TCPConn); ok { + tcpConn.SetKeepAlive(true) + tcpConn.SetKeepAlivePeriod(time.Minute) + } + + var ( + done = make(chan struct{}) + errBack error + errFwd error + shard = user.dialed + ) + + // Backward loop that reads from `shard` and writes to `user`. + // This may be sitting in a call to shard.Read() which races shard.Close(). + go func() { + _, errBack = io.Copy(user.tls, shard) + _ = user.tls.CloseWrite() + close(done) + }() + + // Forward loop that reads from `user` and writes to `shard`. + if _, errFwd = io.Copy(shard, user.tls); errFwd == nil { + _ = shard.rpc.CloseSend() // Allow reads to drain. + } else { + // `shard` write RST or `user` read error. + // Either way, we want to abort reads from `shard` => `user`. + _ = shard.Close() + } + <-done + + // If errBack is: + // - nil, then we read a clean EOF from shard and wrote it all to the user. + // - A user write RST, then errFwd MUST be an error read from the user and shard.Close() was already called. + // - A shard read error, then the shard RPC is already done. + _ = user.tls.Close() + + var status string + if errFwd != nil && errBack != nil { + status = "Err" + } else if errFwd != nil { + status = "ErrUser" + } else if errBack != nil { + status = "ErrShard" + } else { + status = "OK" + } + userHandledCounter.WithLabelValues(task, port, proto, status).Inc() +} + +func (p *Frontend) serveConnHTTP(user *frontendConn) { + var task, port, proto = user.resolved.taskName, user.parsed.port, user.resolved.portProtocol + userStartedCounter.WithLabelValues(task, port, proto).Inc() + + var transport = &http.Transport{ + DialTLSContext: func(ctx context.Context, network string, addr string) (net.Conn, error) { + return dialShard(ctx, p.networkClient, p.shardClient, user.parsed, user.resolved, user.raw.RemoteAddr().String()) + }, + // Connections are fairly cheap because they're "dialed" over an + // established gRPC HTTP/2 transport, but they do require a + // Open / Opened round trip and we'd like to re-use them. + // Note also that the maximum number of connections is implicitly + // bounded by http2.Server's MaxConcurrentStreams (default: 100), + // and the gRPC transport doesn't bound the number of streams. + IdleConnTimeout: 5 * time.Second, + MaxConnsPerHost: 0, // No limit. + MaxIdleConns: 0, // No limit. + MaxIdleConnsPerHost: math.MaxInt, + } + + var reverse = httputil.ReverseProxy{ + Director: func(req *http.Request) { + req.URL.Scheme = "https" + req.URL.Host = user.parsed.hostname + scrubProxyRequest(req, user.resolved.portIsPublic) + }, + ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { + var body = fmt.Sprintf("Service temporarily unavailable: %s\nPlease retry in a moment.", err) + http.Error(w, body, http.StatusServiceUnavailable) + httpHandledCounter.WithLabelValues(task, port, "ErrProxy").Inc() + }, + ModifyResponse: func(r *http.Response) error { + httpHandledCounter.WithLabelValues(task, port, r.Status).Inc() + return nil + }, + Transport: transport, + } + + var handle = func(w http.ResponseWriter, req *http.Request) { + httpStartedCounter.WithLabelValues(task, port, req.Method).Inc() + + if user.resolved.portIsPublic { + reverse.ServeHTTP(w, req) + } else if req.URL.Path == "/auth-redirect" { + completeAuthRedirect(w, req) + httpHandledCounter.WithLabelValues(task, port, "CompleteAuth").Inc() + } else if err := verifyAuthorization(req, p.verifier, user.resolved.taskName); err == nil { + reverse.ServeHTTP(w, req) + } else if req.Method == "GET" && strings.Contains(req.Header.Get("accept"), "html") { + // Presence of "html" in Accept means this is probably a browser. + // Start a redirect chain to obtain an authorization cookie. + startAuthRedirect(w, req, err, p.dashboard, user.resolved.taskName) + httpHandledCounter.WithLabelValues(task, port, "StartAuth").Inc() + } else { + http.Error(w, err.Error(), http.StatusForbidden) + httpHandledCounter.WithLabelValues(task, port, "MissingAuth").Inc() + } + } + + (&http2.Server{ + // IdleTimeout can be generous: it's intended to catch broken TCP transports. + // MaxConcurrentStreams is an important setting left as the default (100). + IdleTimeout: time.Minute, + }).ServeConn(user.tls, &http2.ServeConnOpts{ + Context: user.ctx, + Handler: http.HandlerFunc(handle), + }) + + userHandledCounter.WithLabelValues(task, port, proto, "OK").Inc() +} + +func (f *Frontend) serveConnErr(conn net.Conn, status int, body string) { + // We're terminating this connection and sending a best-effort error. + // We don't know what the client is, or even if they speak HTTP, + // but we do only offer `http/1.1` during ALPN under an error condition. + var resp, _ = httputil.DumpResponse(&http.Response{ + ProtoMajor: 1, ProtoMinor: 1, + StatusCode: status, + Body: io.NopCloser(strings.NewReader(body)), + Close: true, + }, true) + + _, _ = conn.Write(resp) + _ = conn.Close() +} diff --git a/go/network/metrics.go b/go/network/metrics.go new file mode 100644 index 0000000000..73856f7461 --- /dev/null +++ b/go/network/metrics.go @@ -0,0 +1,51 @@ +package network + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var handshakeCounter = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "flow_net_proxy_handshake_total", + Help: "counter of connections which attempted TLS handshake with the connector network proxy frontend", +}, []string{"status"}) + +var userStartedCounter = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "flow_net_proxy_user_started_total", + Help: "counter of started user-initiated connections to the connector network proxy frontend", +}, []string{"task", "port", "proto"}) + +var userHandledCounter = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "flow_net_proxy_user_handled_total", + Help: "counter of handled user-initiated connections to the connector network proxy frontend", +}, []string{"task", "port", "proto", "status"}) + +var shardStartedCounter = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "flow_net_proxy_shard_started_total", + Help: "counter of started shard connector client connections initiated by the network proxy", +}, []string{"task", "port", "proto"}) + +var shardHandledCounter = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "flow_net_proxy_shard_handled_total", + Help: "counter of handled shard connector client connections initiated by the network proxy", +}, []string{"task", "port", "proto", "status"}) + +var httpStartedCounter = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "flow_net_proxy_http_started_total", + Help: "counter of started reverse-proxy connector HTTP requests initiated by the network proxy", +}, []string{"task", "port", "method"}) + +var httpHandledCounter = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "flow_net_proxy_http_handled_total", + Help: "counter of handled reverse-proxy connector HTTP requests initiated by the network proxy", +}, []string{"task", "port", "status"}) + +var bytesReceivedCounter = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "flow_net_proxy_bytes_received_total", + Help: "counter of bytes received from user connections by the connector network proxy frontend", +}, []string{"task", "port", "proto"}) + +var bytesSentCounter = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "flow_net_proxy_bytes_sent_total", + Help: "counter of bytes sent to user connections by the connector network proxy frontend", +}, []string{"task", "port", "proto"}) diff --git a/go/network/proxy_client.go b/go/network/proxy_client.go new file mode 100644 index 0000000000..eeaa1c2c6d --- /dev/null +++ b/go/network/proxy_client.go @@ -0,0 +1,196 @@ +package network + +import ( + "context" + "fmt" + "io" + "math/rand/v2" + "net" + "strconv" + "time" + + pf "github.com/estuary/flow/go/protocols/flow" + "github.com/prometheus/client_golang/prometheus" + pb "go.gazette.dev/core/broker/protocol" + pc "go.gazette.dev/core/consumer/protocol" +) + +// proxyClient is a connection between the frontend +// and the shard assignment currently hosting the connector. +type proxyClient struct { + buf []byte // Unread remainder of last response. + picked pc.ListResponse_Shard // Picked primary shard assignment and route. + rpc pf.NetworkProxy_ProxyClient // Running RPC. + rxCh chan struct{} // Token for a capability to read from `rpc`. + labels []string // Metric labels. + nWrite prometheus.Counter // Accumulates bytes written. + nRead prometheus.Counter // Accumulates bytes read. +} + +func dialShard( + ctx context.Context, + networkClient pf.NetworkProxyClient, + shardClient pc.ShardClient, + parsed parsedSNI, + resolved resolvedSNI, + userAddr string, +) (*proxyClient, error) { + var labels = []string{resolved.taskName, parsed.port, resolved.portProtocol} + shardStartedCounter.WithLabelValues(labels...).Inc() + + var fetched, err = listShards(ctx, shardClient, parsed, resolved.shardIDPrefix) + if err != nil { + shardHandledCounter.WithLabelValues(append(labels, "ErrList")...).Inc() + return nil, fmt.Errorf("failed to list matching task shards: %w", err) + } + + // Pick a random primary. + rand.Shuffle(len(fetched), func(i, j int) { fetched[i], fetched[j] = fetched[j], fetched[i] }) + + var primary = -1 + for i := range fetched { + if fetched[i].Route.Primary != -1 { + primary = i + break + } + } + if primary == -1 { + shardHandledCounter.WithLabelValues(append(labels, "ErrNoPrimary")...).Inc() + return nil, fmt.Errorf("task has no ready primary shard assignment") + } + + var claims = pb.Claims{ + Capability: pf.Capability_NETWORK_PROXY, + Selector: pb.LabelSelector{ + Include: pb.MustLabelSet("id:prefix", resolved.shardIDPrefix), + }, + } + var picked = fetched[primary] + + rpc, err := networkClient.Proxy( + pb.WithDispatchRoute( + pb.WithClaims(ctx, claims), + picked.Route, + picked.Route.Members[picked.Route.Primary], + ), + ) + if err != nil { + shardHandledCounter.WithLabelValues(append(labels, "ErrProxy")...).Inc() + return nil, fmt.Errorf("failed to start network proxy RPC to task shard: %w", err) + } + + var port, _ = strconv.ParseUint(parsed.port, 10, 16) // parseSNI() already verified. + var openErr = rpc.Send(&pf.TaskNetworkProxyRequest{ + Open: &pf.TaskNetworkProxyRequest_Open{ + ShardId: picked.Spec.Id, + TargetPort: uint32(port), + ClientAddr: userAddr, + }, + }) + + opened, err := rpc.Recv() + if err != nil { + err = fmt.Errorf("failed to read opened response from task shard: %w", err) + } else if opened.OpenResponse == nil { + err = fmt.Errorf("task shard proxy RPC is missing expected OpenResponse") + } else if status := opened.OpenResponse.Status; status != pf.TaskNetworkProxyResponse_OK { + err = fmt.Errorf("task shard proxy RPC has non-ready status: %s", status) + } else if openErr != nil { + err = fmt.Errorf("failed to send open request: %w", err) + } + + if err != nil { + rpc.CloseSend() + _, _ = rpc.Recv() + shardHandledCounter.WithLabelValues(append(labels, "ErrOpen")...).Inc() + return nil, err + } + + var rxCh = make(chan struct{}, 1) + rxCh <- struct{}{} + + // Received and sent from the user's perspective. + var nWrite = bytesReceivedCounter.WithLabelValues(labels...) + var nRead = bytesSentCounter.WithLabelValues(labels...) + + return &proxyClient{ + buf: nil, + picked: picked, + rpc: rpc, + rxCh: rxCh, + labels: labels, + nWrite: nWrite, + nRead: nRead, + }, nil +} + +// Write to the shard proxy client. MUST not be called concurrently with Close. +func (pc *proxyClient) Write(b []byte) (n int, err error) { + if err = pc.rpc.Send(&pf.TaskNetworkProxyRequest{Data: b}); err != nil { + return 0, err + } + pc.nWrite.Add(float64(len(b))) + return len(b), nil +} + +// Read from the shard proxy client. MAY be called concurrently with Close. +func (pc *proxyClient) Read(b []byte) (n int, err error) { + if len(pc.buf) == 0 { + if _, ok := <-pc.rxCh; !ok { + return 0, io.EOF // RPC already completed. + } + + if rx, err := pc.rpc.Recv(); err != nil { + close(pc.rxCh) + + if err == io.EOF { + shardHandledCounter.WithLabelValues(append(pc.labels, "OK")...).Inc() + } else { + shardHandledCounter.WithLabelValues(append(pc.labels, "ErrRead")...).Inc() + } + return 0, err + } else { + pc.buf = rx.Data + pc.rxCh <- struct{}{} // Yield token. + pc.nRead.Add(float64(len(rx.Data))) + } + } + + var i = copy(b, pc.buf) + pc.buf = pc.buf[i:] + return i, nil +} + +// Close the proxy client. MAY be called concurrently with Read. +func (pc *proxyClient) Close() error { + // Note that http.Transport in particular will sometimes but not always race + // calls of Read() and Close(). We must ensure the RPC reads a final error as + // part of Close(), because we can't guarantee a current or future call to + // Read() will occur, but there may also be a raced Read() which will receive + // EOF after we CloseSend() -- and if we naively attempted another pc.rpc.Recv() + // it would block forever. + var _ = pc.rpc.CloseSend() + + if _, ok := <-pc.rxCh; !ok { + return nil // Read already completed. + } + close(pc.rxCh) // Future Read()'s return EOF. + + for { + if _, err := pc.rpc.Recv(); err == io.EOF { + shardHandledCounter.WithLabelValues(append(pc.labels, "OK")...).Inc() + return nil + } else if err != nil { + shardHandledCounter.WithLabelValues(append(pc.labels, "ErrClose")...).Inc() + return err + } + } +} + +func (sc *proxyClient) LocalAddr() net.Addr { return nil } +func (sc *proxyClient) RemoteAddr() net.Addr { return nil } +func (sc *proxyClient) SetDeadline(t time.Time) error { return nil } +func (sc *proxyClient) SetReadDeadline(t time.Time) error { return nil } +func (sc *proxyClient) SetWriteDeadline(t time.Time) error { return nil } + +var _ net.Conn = &proxyClient{} diff --git a/go/runtime/proxy.go b/go/network/proxy_server.go similarity index 71% rename from go/runtime/proxy.go rename to go/network/proxy_server.go index 379f9d329f..cbae98a3b5 100644 --- a/go/runtime/proxy.go +++ b/go/network/proxy_server.go @@ -1,28 +1,29 @@ -package runtime +package network import ( "context" "fmt" "io" "net" - "strconv" "sync/atomic" pf "github.com/estuary/flow/go/protocols/flow" "github.com/estuary/flow/go/protocols/ops" - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" + pr "github.com/estuary/flow/go/protocols/runtime" pb "go.gazette.dev/core/broker/protocol" "go.gazette.dev/core/consumer" pc "go.gazette.dev/core/consumer/protocol" "golang.org/x/net/trace" ) -type proxyServer struct { - resolver *consumer.Resolver +// ProxyServer is the "backend" of connector networking. +// It accepts Proxy requests from the Frontend and connects them to the +// corresponding TCP port of a running connector container. +type ProxyServer struct { + Resolver *consumer.Resolver } -func (ps *proxyServer) Proxy(claims pb.Claims, stream pf.NetworkProxy_ProxyServer) (_err error) { +func (ps *ProxyServer) Proxy(claims pb.Claims, stream pf.NetworkProxy_ProxyServer) (_err error) { var ctx = stream.Context() var open, err = stream.Recv() @@ -31,11 +32,8 @@ func (ps *proxyServer) Proxy(claims pb.Claims, stream pf.NetworkProxy_ProxyServe } else if err := validateOpen(open); err != nil { return fmt.Errorf("invalid open proxy message: %w", err) } - var labels = []string{ - open.Open.ShardId.String(), strconv.Itoa(int(open.Open.TargetPort)), - } - resolution, err := ps.resolver.Resolve(consumer.ResolveArgs{ + resolution, err := ps.Resolver.Resolve(consumer.ResolveArgs{ Context: ctx, Claims: claims, MayProxy: false, @@ -58,7 +56,9 @@ func (ps *proxyServer) Proxy(claims pb.Claims, stream pf.NetworkProxy_ProxyServe } // Resolve the target port to the current container. - var container, publisher = resolution.Store.(Application).proxyHook() + var container, publisher = resolution.Store.(interface { + ProxyHook() (*pr.Container, ops.Publisher) + }).ProxyHook() resolution.Done() if tr, ok := trace.FromContext(ctx); ok { @@ -109,8 +109,6 @@ func (ps *proxyServer) Proxy(claims pb.Claims, stream pf.NetworkProxy_ProxyServe "clientAddr", open.Open.ClientAddr, "targetPort", open.Open.TargetPort, ) - proxyConnectionsAcceptedCounter.WithLabelValues(labels...).Inc() - var inbound, outbound uint64 defer func() { @@ -121,18 +119,12 @@ func (ps *proxyServer) Proxy(claims pb.Claims, stream pf.NetworkProxy_ProxyServe "byteOut", outbound, "error", _err, ) - if _err == nil { - proxyConnectionsClosedCounter.WithLabelValues(append(labels, "ok")...).Inc() - } else { - proxyConnectionsClosedCounter.WithLabelValues(append(labels, "error")...).Inc() - } }() // Forward loop that proxies from `client` => `delegate`. go func() { defer delegate.CloseWrite() - var counter = proxyConnBytesInboundCounter.WithLabelValues(labels...) for { if request, err := stream.Recv(); err != nil { err = pf.UnwrapGRPCError(err) @@ -151,7 +143,6 @@ func (ps *proxyServer) Proxy(claims pb.Claims, stream pf.NetworkProxy_ProxyServe return } else { atomic.AddUint64(&inbound, uint64(n)) - counter.Add(float64(n)) } } }() @@ -160,7 +151,6 @@ func (ps *proxyServer) Proxy(claims pb.Claims, stream pf.NetworkProxy_ProxyServe // When this loop completes, so does the Proxy RPC. var buffer = make([]byte, 1<<14) // 16KB. - var counter = proxyConnBytesOutboundCounter.WithLabelValues(labels...) for { if n, err := delegate.Read(buffer); err == io.EOF { return nil @@ -172,7 +162,6 @@ func (ps *proxyServer) Proxy(claims pb.Claims, stream pf.NetworkProxy_ProxyServe return nil } else { outbound += uint64(n) - counter.Add(float64(n)) } } } @@ -200,25 +189,5 @@ func validateOpen(req *pf.TaskNetworkProxyRequest) error { return nil } -// Prometheus metrics for connector TCP networking. -// These metrics match those collected by data-plane-gateway. -var proxyConnectionsAcceptedCounter = promauto.NewCounterVec(prometheus.CounterOpts{ - Name: "net_proxy_conns_accept_total", - Help: "counter of proxy connections that have been accepted", -}, []string{"shard", "port"}) -var proxyConnectionsClosedCounter = promauto.NewCounterVec(prometheus.CounterOpts{ - Name: "net_proxy_conns_closed_total", - Help: "counter of proxy connections that have completed and closed", -}, []string{"shard", "port", "status"}) - -var proxyConnBytesInboundCounter = promauto.NewCounterVec(prometheus.CounterOpts{ - Name: "net_proxy_conn_inbound_bytes_total", - Help: "total bytes proxied from client to container", -}, []string{"shard", "port"}) -var proxyConnBytesOutboundCounter = promauto.NewCounterVec(prometheus.CounterOpts{ - Name: "net_proxy_conn_outbound_bytes_total", - Help: "total bytes proxied from container to client", -}, []string{"shard", "port"}) - // See crates/runtime/src/container.rs const connectorInitPort = 49092 diff --git a/go/network/sni.go b/go/network/sni.go new file mode 100644 index 0000000000..1d34fdf64b --- /dev/null +++ b/go/network/sni.go @@ -0,0 +1,115 @@ +package network + +import ( + "context" + "errors" + "fmt" + "strconv" + "strings" + + "github.com/estuary/flow/go/labels" + pb "go.gazette.dev/core/broker/protocol" + pc "go.gazette.dev/core/consumer/protocol" +) + +// Parsed portions of the TLS ServerName which are used to map to a shard. +type parsedSNI struct { + hostname string + port string + keyBegin string + rClockBegin string +} + +// Resolved task shard metadata which allow us to complete TLS handshake. +type resolvedSNI struct { + portIsPublic bool + portProtocol string + shardIDPrefix string + taskName string +} + +// parseSNI parses a `target` into a parsedSNI. +// We accept two forms of targets: +// * d7f4a9d02b48c1a-6789 (hostname and port) +// * d7f4a9d02b48c1a-00000000-80000000-6789 (hostname, key begin, r-clock begin, and port). +func parseSNI(target string) (parsedSNI, error) { + var parts = strings.Split(target, "-") + var hostname, port, keyBegin, rClockBegin string + + if len(parts) == 2 { + hostname = parts[0] + port = parts[1] + } else if len(parts) == 4 { + hostname = parts[0] + keyBegin = parts[1] + rClockBegin = parts[2] + port = parts[3] + } else { + return parsedSNI{}, fmt.Errorf("expected two or for subdomain components, not %d", len(parts)) + } + + var _, err = strconv.ParseUint(port, 10, 16) + if err != nil { + return parsedSNI{}, fmt.Errorf("failed to parse subdomain port number: %w", err) + } + + return parsedSNI{ + hostname: hostname, + port: port, + keyBegin: keyBegin, + rClockBegin: rClockBegin, + }, nil +} + +func newResolvedSNI(parsed parsedSNI, shard *pc.ShardSpec) resolvedSNI { + var shardIDPrefix = shard.Id.String() + if ind := strings.LastIndexByte(shardIDPrefix, '/'); ind != -1 { + shardIDPrefix = shardIDPrefix[:ind+1] // Including trailing '/'. + } + + var portProtocol = shard.LabelSet.ValueOf(labels.PortProtoPrefix + parsed.port) + var portIsPublic = shard.LabelSet.ValueOf(labels.PortPublicPrefix+parsed.port) == "true" + + // Private ports MUST use the HTTP/1.1 reverse proxy. + if !portIsPublic { + portProtocol = "" + } + + return resolvedSNI{ + shardIDPrefix: shardIDPrefix, + portProtocol: portProtocol, + portIsPublic: portIsPublic, + taskName: shard.LabelSet.ValueOf(labels.TaskName), + } +} + +func listShards(ctx context.Context, shards pc.ShardClient, parsed parsedSNI, shardIDPrefix string) ([]pc.ListResponse_Shard, error) { + var include = []pb.Label{ + {Name: labels.ExposePort, Value: parsed.port}, + {Name: labels.Hostname, Value: parsed.hostname}, + } + if parsed.keyBegin != "" { + include = append(include, pb.Label{Name: labels.KeyBegin, Value: parsed.keyBegin}) + } + if parsed.rClockBegin != "" { + include = append(include, pb.Label{Name: labels.RClockBegin, Value: parsed.rClockBegin}) + } + if shardIDPrefix != "" { + include = append(include, pb.Label{Name: "id", Value: shardIDPrefix, Prefix: true}) + } + + var resp, err = shards.List( + pb.WithDispatchDefault(ctx), + &pc.ListRequest{ + Selector: pb.LabelSelector{Include: pb.LabelSet{Labels: include}}, + }, + ) + if err == nil && resp.Status != pc.Status_OK { + err = errors.New(resp.Status.String()) + } + if err != nil { + return nil, err + } + + return resp.Shards, nil +} diff --git a/go/network/tap.go b/go/network/tap.go new file mode 100644 index 0000000000..2bbf66f10c --- /dev/null +++ b/go/network/tap.go @@ -0,0 +1,50 @@ +package network + +import ( + "crypto/tls" + "net" +) + +// Tap is an adapter which retains a tapped net.Listener and itself acts as a +// net.Listener, with an Accept() that communicates over a forwarding channel. +// It's used for late binding of a Proxy to a pre-created Listener, +// and to hand off connections which are not intended for Proxy. +type Tap struct { + raw net.Listener + config *tls.Config + fwdCh chan net.Conn + fwdErr error +} + +func NewTap() *Tap { + return &Tap{ + raw: nil, // Set by Tap(). + config: nil, // Set by Tap(). + fwdCh: make(chan net.Conn, 4), + fwdErr: nil, + } +} + +func (tap *Tap) Wrap(tapped net.Listener, config *tls.Config) (net.Listener, error) { + tap.raw = tapped + tap.config = config + return tap, nil +} + +func (tap *Tap) Accept() (net.Conn, error) { + if conn, ok := <-tap.fwdCh; ok { + return conn, nil + } else { + return nil, tap.fwdErr + } +} + +func (tap *Tap) Close() error { + return tap.raw.Close() +} + +func (tap *Tap) Addr() net.Addr { + return tap.raw.Addr() +} + +var _ net.Listener = &Tap{} diff --git a/go/runtime/flow_consumer.go b/go/runtime/flow_consumer.go index fa8042e604..e1c8211361 100644 --- a/go/runtime/flow_consumer.go +++ b/go/runtime/flow_consumer.go @@ -10,6 +10,7 @@ import ( "github.com/estuary/flow/go/bindings" "github.com/estuary/flow/go/flow" "github.com/estuary/flow/go/labels" + "github.com/estuary/flow/go/network" "github.com/estuary/flow/go/protocols/capture" "github.com/estuary/flow/go/protocols/derive" pf "github.com/estuary/flow/go/protocols/flow" @@ -36,6 +37,7 @@ type FlowConsumerConfig struct { AllowLocal bool `long:"allow-local" env:"ALLOW_LOCAL" description:"Allow local connectors. True for local stacks, and false otherwise."` BuildsRoot string `long:"builds-root" required:"true" env:"BUILDS_ROOT" description:"Base URL for fetching Flow catalog builds"` ControlAPI pb.Endpoint `long:"control-api" env:"CONTROL_API" description:"Address of the control-plane API"` + Dashboard pb.Endpoint `long:"dashboard" env:"DASHBOARD" description:"Address of the Estuary dashboard"` DataPlaneFQDN string `long:"data-plane-fqdn" env:"DATA_PLANE_FQDN" description:"Fully-qualified domain name of the data-plane to which this reactor belongs"` Network string `long:"network" description:"The Docker network that connector containers are given access to. Defaults to the bridge network"` ProxyRuntimes int `long:"proxy-runtimes" default:"2" description:"The number of proxy connector runtimes that may run concurrently"` @@ -45,7 +47,14 @@ type FlowConsumerConfig struct { // Execute delegates to runconsumer.Cmd.Execute. func (c *FlowConsumerConfig) Execute(args []string) error { - return runconsumer.Cmd{Cfg: c, App: new(FlowConsumer)}.Execute(args) + var app = &FlowConsumer{ + Tap: network.NewTap(), + } + return runconsumer.Cmd{ + Cfg: c, + App: app, + WrapListener: app.Tap.Wrap, + }.Execute(args) } // FlowConsumer implements the Estuary Flow Consumer. @@ -65,6 +74,8 @@ type FlowConsumer struct { // It's important that we use a Context that's scoped to the life of the process, // rather than the lives of individual shards, so we don't lose logs. OpsContext context.Context + // Network listener tap. + Tap *network.Tap } // Application is the interface implemented by Flow shard task stores. @@ -81,9 +92,9 @@ type Application interface { ReplayRange(_ consumer.Shard, _ pb.Journal, begin, end pb.Offset) message.Iterator ReadThrough(pb.Offsets) (pb.Offsets, error) - // proxyHook exposes a current Container and ops.Publisher - // for use by the network proxy server. - proxyHook() (*pr.Container, ops.Publisher) + // ProxyHook exposes a current Container and ops.Publisher + // for use by network.ProxyServer. + ProxyHook() (*pr.Container, ops.Publisher) } var _ consumer.Application = (*FlowConsumer)(nil) @@ -197,7 +208,9 @@ func (f *FlowConsumer) InitApplication(args runconsumer.InitArgs) error { return fmt.Errorf("catalog builds service: %w", err) } - if keyedAuth, ok := args.Service.Authorizer.(*auth.KeyedAuth); ok && !config.Flow.TestAPIs { + var localAuthorizer = args.Service.Authorizer + + if keyedAuth, ok := localAuthorizer.(*auth.KeyedAuth); ok && !config.Flow.TestAPIs { // Wrap the underlying KeyedAuth Authorizer to use the control-plane's Authorize API. args.Service.Authorizer = NewControlPlaneAuthorizer( keyedAuth, @@ -242,7 +255,7 @@ func (f *FlowConsumer) InitApplication(args runconsumer.InitArgs) error { pr.NewVerifiedShufflerServer(shuffle.NewAPI(args.Service.Resolver), f.Service.Verifier)) pf.RegisterNetworkProxyServer(args.Server.GRPCServer, - pf.NewVerifiedNetworkProxyServer(&proxyServer{resolver: args.Service.Resolver}, f.Service.Verifier)) + pf.NewVerifiedNetworkProxyServer(&network.ProxyServer{Resolver: args.Service.Resolver}, f.Service.Verifier)) var connectorProxy = &connectorProxy{ address: args.Server.Endpoint(), @@ -255,5 +268,21 @@ func (f *FlowConsumer) InitApplication(args runconsumer.InitArgs) error { derive.RegisterConnectorServer(args.Server.GRPCServer, connectorProxy) materialize.RegisterConnectorServer(args.Server.GRPCServer, connectorProxy) + networkProxy, err := network.NewFrontend( + f.Tap, + config.Consumer.Host, + config.Flow.ControlAPI.URL(), + config.Flow.Dashboard.URL(), + pf.NewAuthNetworkProxyClient(pf.NewNetworkProxyClient(args.Server.GRPCLoopback), localAuthorizer), + pc.NewAuthShardClient(pc.NewShardClient(args.Server.GRPCLoopback), localAuthorizer), + args.Service.Verifier, + ) + if err != nil { + return fmt.Errorf("failed to build network proxy: %w", err) + } + args.Tasks.Queue("network-proxy-frontend", func() error { + return networkProxy.Serve(args.Tasks.Context()) + }) + return nil } diff --git a/go/runtime/task.go b/go/runtime/task.go index 48e0c02eeb..edf433ae64 100644 --- a/go/runtime/task.go +++ b/go/runtime/task.go @@ -148,7 +148,7 @@ func (t *taskBase[TaskSpec]) initTerm(shard consumer.Shard) error { return nil } -func (t *taskBase[TaskSpec]) proxyHook() (*pr.Container, ops.Publisher) { +func (t *taskBase[TaskSpec]) ProxyHook() (*pr.Container, ops.Publisher) { return t.container.Load(), t.opsPublisher }