diff --git a/drpchttp/alpn.go b/drpchttp/alpn.go new file mode 100644 index 0000000..587edb3 --- /dev/null +++ b/drpchttp/alpn.go @@ -0,0 +1,101 @@ +package drpchttp + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net/http" + + "storj.io/drpc/drpcserver" +) + +// TLSNextProto the ALPN protocol ID used for switching to DRPC. +const TLSNextProto = `drpc/0` + +// ErrNextProtosUnconfigured is returned from [ConfigureNextProto] when an +// [http.Server] is not explicitly configured for protocol negotiation. +var ErrNextProtosUnconfigured = errors.New("drpchttp: (*http.Server).TLSNextProto not populated; doing nothing") + +// ConfigureNextProto adds a "next protocol" handler to the passed [http.Server] +// that dispatches connections to the passed [drpcserver.Server]. The "fallback" +// [context.Context] is used for connections if a suitable Context cannot be +// derived from the [http] interface. If nil is passed, [context.Background] +// will be used. +// +// This function is only effective if the [http.Server] is serving over a TLS +// connection. If [http.Server.TLSNextProto] is not populated, +// [ErrNextProtosUnconfigured] will be reported. This is done to avoid +// accidentally disabling HTTP/2 support, which is only enabled by default if +// TLSNextProto is not populated. See [golang.org/x/net/http2] for explicit +// HTTP/2 configuration. +// +// If [http.Server.TLSConfig] is populated, [ConfigureTLS] is called +// automatically. Note that it's only used if the [http.Server.ServeTLS] or +// [http.Server.ListenAndServeTLS] methods are used. +func ConfigureNextProto(hs *http.Server, srv *drpcserver.Server, fallback context.Context) error { + const errPrefix = `drpchttp: can't setup ALPN: ` + switch { + case hs == nil: + return errors.New(errPrefix + "nil http.Server") + case srv == nil: + return errors.New(errPrefix + "nil drpcserver.Server") + case hs.TLSNextProto == nil: + return ErrNextProtosUnconfigured + } + if fallback == nil { + fallback = context.Background() + } + + // This is patterned on the go http2 package. + // + // This handler ignores the passed http.Handler argument and instead hijacks + // the Connection and hands it to the DRPC server. + + if cfg := hs.TLSConfig; cfg != nil { + var err error + hs.TLSConfig, err = ConfigureTLS(cfg) + if err != nil { + return fmt.Errorf(errPrefix+"%w", err) + } + } + + protoHandler := func(s *http.Server, c *tls.Conn, h http.Handler) { + // According to a comment in x/net/http2, there's an unadvertised method + // on the Handler implementation that returns the Context. Technically + // an internal detail, but use it if we can. + var ctx context.Context + type baseContexter interface { + BaseContext() context.Context + } + switch bc, ok := h.(baseContexter); { + case ok: + ctx = bc.BaseContext() + default: + ctx = fallback + } + + // Dunno if there's a better place or way to get a logger or something + // that can handle a returned error. + log := s.ErrorLog + + if err := srv.ServeOne(ctx, c); err != nil && log != nil { + log.Printf("drpc error: %v", err) + } + } + hs.TLSNextProto[TLSNextProto] = protoHandler + + return nil +} + +// ConfigureTLS returns a copy of the passed [tls.Config] modified to enable +// DRPC as a negotiated protocol. +// +// This is needed for client configurations and server configurations that do +// not use [http.Server.ServeTLS]. +func ConfigureTLS(cfg *tls.Config) (*tls.Config, error) { + // Should this just modify the passed-in config? + n := cfg.Clone() + n.NextProtos = append(n.NextProtos, TLSNextProto) + return n, nil +} diff --git a/drpchttp/alpn_test.go b/drpchttp/alpn_test.go new file mode 100644 index 0000000..652012f --- /dev/null +++ b/drpchttp/alpn_test.go @@ -0,0 +1,153 @@ +package drpchttp + +import ( + "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "net" + "net/http" + "net/http/httptest" + "slices" + "strings" + "testing" + + "storj.io/drpc" + "storj.io/drpc/drpcconn" + "storj.io/drpc/drpcmux" + "storj.io/drpc/drpcserver" + "storj.io/drpc/drpctest" +) + +func TestALPN(t *testing.T) { + ctx := drpctest.NewTracker(t) + defer ctx.Close() + + // Set up a DPRC server: + // + // A real server would obviously register services. + dsrv := drpcserver.New(drpcmux.New()) + + // Create a TLS config: + // + // A real server would add "h2", etc + cfg := &tls.Config{ + NextProtos: []string{"http/1.1"}, + } + // Test the function actually modifies the NextProtos: + cfg, err := ConfigureTLS(cfg) + if err != nil { + t.Fatal(err) + } + if !slices.Contains(cfg.NextProtos, TLSNextProto) { + t.Errorf("NextProtos (%v) does not contain %#q", cfg.NextProtos, TLSNextProto) + } + t.Logf("server tls.Config NextProtos: %v", cfg.NextProtos) + + // Create a test HTTP server. + srv := httptest.NewUnstartedServer(http.HandlerFunc(nil)) + + srv.TLS = cfg + srv.Config.BaseContext = func(_ net.Listener) context.Context { return ctx } + // Configure other protocol hooks to fail the test if called. + srv.Config.TLSNextProto = map[string]func(*http.Server, *tls.Conn, http.Handler){ + "": func(_ *http.Server, _ *tls.Conn, _ http.Handler) { t.Error("got non-ALPN request") }, + "http/1.1": func(_ *http.Server, _ *tls.Conn, _ http.Handler) { t.Error("got http/1.1 request") }, + } + // Test the configuration function actually modifies the TLSNextProto map: + if err := ConfigureNextProto(srv.Config, dsrv, nil); err != nil { + t.Fatal(err) + } + if _, ok := srv.Config.TLSNextProto[TLSNextProto]; !ok { + t.Error("protocol hook not set") + } + + srv.StartTLS() + t.Cleanup(srv.Close) + + // Do a bunch of client setup: + // + // The server setup does not add the server's root CA (it adds it to a + // created [http.Transport]), so we must do it manually. + addr := srv.Listener.Addr() + roots := x509.NewCertPool() + roots.AddCert(srv.Certificate()) + clCfg := &tls.Config{ + RootCAs: roots, + } + clCfg, err = ConfigureTLS(clCfg) + if err != nil { + t.Fatal(err) + } + t.Logf("config tls.Config NextProtos: %v", clCfg.NextProtos) + td := tls.Dialer{ + NetDialer: &net.Dialer{}, + Config: clCfg, + } + // Open a TLS connection. + conn, err := td.DialContext(ctx, addr.Network(), addr.String()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + if err := conn.Close(); err != nil { + t.Error(err) + } + }) + + // Now create a DRPC connection over the TLS connection and call a + // nonexistent endpoint. + c := drpcconn.New(conn) + err = c.Invoke(ctx, "/", new(bogusEncoding), new(bogusMsg), nil) + if got, want := err.Error(), `unknown rpc: "/"`; !strings.Contains(got, want) { + t.Errorf("got: %#q, want: %#q", got, want) + } + + // Check that the proper protocol was used. Must do this after the [Invoke] + // call because the TLS handshake completes on the first read or write. + t.Logf("negotiated protocol: %q", conn.(*tls.Conn).ConnectionState().NegotiatedProtocol) +} + +type bogusMsg struct { + OK bool +} + +type bogusEncoding struct{} + +func (b *bogusEncoding) Marshal(msg drpc.Message) ([]byte, error) { return json.Marshal(msg) } +func (b *bogusEncoding) Unmarshal(buf []byte, msg drpc.Message) error { + return json.Unmarshal(buf, msg) +} + +func ExampleConfigureNextProto() { + // Set up the HTTP server. The ALPN support uses the server's accept loop. + hSrv := new(http.Server) + hSrv.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) + // A production server may want to explicitly enable HTTP/2: + /* + h2Srv := new(http2.Server) + // Configure Handler, etc... + http2.ConfigureServer(hSrv, h2Srv) + */ + + // Set up the DRPC server. + dSrv := drpcserver.New(nil) + + ConfigureNextProto(hSrv, dSrv, context.TODO()) + + hSrv.ListenAndServeTLS("cert.pem", "key.pem") +} + +func ExampleConfigureTLS_client() { + // Set up the TLS config. + cfg, _ := ConfigureTLS(new(tls.Config)) + dialer := tls.Dialer{Config: cfg} + + // Open a TLS connection. + conn, _ := dialer.DialContext(context.TODO(), "tcp", "[::]:https") + defer conn.Close() + + // Now, create a DRPC connection over the TLS connection. + c := drpcconn.New(conn) + c.Close() +}