Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

drpchttp: add helpers for ALPN support #54

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions drpchttp/alpn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package drpchttp

import (
"context"
"crypto/tls"
"errors"
"fmt"
"net/http"

"storj.io/drpc/drpcserver"
)

// TLSNextProto the ALPN protocol ID used for switching to DRPC.
const TLSNextProto = `drpc/0`

// ErrNextProtosUnconfigured is returned from [ConfigureNextProto] when an
// [http.Server] is not explicitly configured for protocol negotiation.
var ErrNextProtosUnconfigured = errors.New("drpchttp: (*http.Server).TLSNextProto not populated; doing nothing")

// ConfigureNextProto adds a "next protocol" handler to the passed [http.Server]
// that dispatches connections to the passed [drpcserver.Server]. The "fallback"
// [context.Context] is used for connections if a suitable Context cannot be
// derived from the [http] interface. If nil is passed, [context.Background]
// will be used.
//
// This function is only effective if the [http.Server] is serving over a TLS
// connection. If [http.Server.TLSNextProto] is not populated,
// [ErrNextProtosUnconfigured] will be reported. This is done to avoid
// accidentally disabling HTTP/2 support, which is only enabled by default if
// TLSNextProto is not populated. See [golang.org/x/net/http2] for explicit
// HTTP/2 configuration.
//
// If [http.Server.TLSConfig] is populated, [ConfigureTLS] is called
// automatically. Note that it's only used if the [http.Server.ServeTLS] or
// [http.Server.ListenAndServeTLS] methods are used.
func ConfigureNextProto(hs *http.Server, srv *drpcserver.Server, fallback context.Context) error {
const errPrefix = `drpchttp: can't setup ALPN: `
switch {
case hs == nil:
return errors.New(errPrefix + "nil http.Server")
case srv == nil:
return errors.New(errPrefix + "nil drpcserver.Server")
case hs.TLSNextProto == nil:
return ErrNextProtosUnconfigured
}
if fallback == nil {
fallback = context.Background()
}

// This is patterned on the go http2 package.
//
// This handler ignores the passed http.Handler argument and instead hijacks
// the Connection and hands it to the DRPC server.

if cfg := hs.TLSConfig; cfg != nil {
var err error
hs.TLSConfig, err = ConfigureTLS(cfg)
if err != nil {
return fmt.Errorf(errPrefix+"%w", err)
}
}

protoHandler := func(s *http.Server, c *tls.Conn, h http.Handler) {
// According to a comment in x/net/http2, there's an unadvertised method
// on the Handler implementation that returns the Context. Technically
// an internal detail, but use it if we can.
var ctx context.Context
type baseContexter interface {
BaseContext() context.Context
}
switch bc, ok := h.(baseContexter); {
case ok:
ctx = bc.BaseContext()
default:
ctx = fallback
}

// Dunno if there's a better place or way to get a logger or something
// that can handle a returned error.
log := s.ErrorLog

if err := srv.ServeOne(ctx, c); err != nil && log != nil {
log.Printf("drpc error: %v", err)
}
}
hs.TLSNextProto[TLSNextProto] = protoHandler

return nil
}

// ConfigureTLS returns a copy of the passed [tls.Config] modified to enable
// DRPC as a negotiated protocol.
//
// This is needed for client configurations and server configurations that do
// not use [http.Server.ServeTLS].
func ConfigureTLS(cfg *tls.Config) (*tls.Config, error) {
// Should this just modify the passed-in config?
n := cfg.Clone()
n.NextProtos = append(n.NextProtos, TLSNextProto)
return n, nil
}
153 changes: 153 additions & 0 deletions drpchttp/alpn_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
package drpchttp

import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"net"
"net/http"
"net/http/httptest"
"slices"
"strings"
"testing"

"storj.io/drpc"
"storj.io/drpc/drpcconn"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"storj.io/drpc/drpctest"
)

func TestALPN(t *testing.T) {
ctx := drpctest.NewTracker(t)
defer ctx.Close()

// Set up a DPRC server:
//
// A real server would obviously register services.
dsrv := drpcserver.New(drpcmux.New())

// Create a TLS config:
//
// A real server would add "h2", etc
cfg := &tls.Config{
NextProtos: []string{"http/1.1"},
}
// Test the function actually modifies the NextProtos:
cfg, err := ConfigureTLS(cfg)
if err != nil {
t.Fatal(err)
}
if !slices.Contains(cfg.NextProtos, TLSNextProto) {
t.Errorf("NextProtos (%v) does not contain %#q", cfg.NextProtos, TLSNextProto)
}
t.Logf("server tls.Config NextProtos: %v", cfg.NextProtos)

// Create a test HTTP server.
srv := httptest.NewUnstartedServer(http.HandlerFunc(nil))

srv.TLS = cfg
srv.Config.BaseContext = func(_ net.Listener) context.Context { return ctx }
// Configure other protocol hooks to fail the test if called.
srv.Config.TLSNextProto = map[string]func(*http.Server, *tls.Conn, http.Handler){
"": func(_ *http.Server, _ *tls.Conn, _ http.Handler) { t.Error("got non-ALPN request") },
"http/1.1": func(_ *http.Server, _ *tls.Conn, _ http.Handler) { t.Error("got http/1.1 request") },
}
// Test the configuration function actually modifies the TLSNextProto map:
if err := ConfigureNextProto(srv.Config, dsrv, nil); err != nil {
t.Fatal(err)
}
if _, ok := srv.Config.TLSNextProto[TLSNextProto]; !ok {
t.Error("protocol hook not set")
}

srv.StartTLS()
t.Cleanup(srv.Close)

// Do a bunch of client setup:
//
// The server setup does not add the server's root CA (it adds it to a
// created [http.Transport]), so we must do it manually.
addr := srv.Listener.Addr()
roots := x509.NewCertPool()
roots.AddCert(srv.Certificate())
clCfg := &tls.Config{
RootCAs: roots,
}
clCfg, err = ConfigureTLS(clCfg)
if err != nil {
t.Fatal(err)
}
t.Logf("config tls.Config NextProtos: %v", clCfg.NextProtos)
td := tls.Dialer{
NetDialer: &net.Dialer{},
Config: clCfg,
}
// Open a TLS connection.
conn, err := td.DialContext(ctx, addr.Network(), addr.String())
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
if err := conn.Close(); err != nil {
t.Error(err)
}
})

// Now create a DRPC connection over the TLS connection and call a
// nonexistent endpoint.
c := drpcconn.New(conn)
err = c.Invoke(ctx, "/", new(bogusEncoding), new(bogusMsg), nil)
if got, want := err.Error(), `unknown rpc: "/"`; !strings.Contains(got, want) {
t.Errorf("got: %#q, want: %#q", got, want)
}

// Check that the proper protocol was used. Must do this after the [Invoke]
// call because the TLS handshake completes on the first read or write.
t.Logf("negotiated protocol: %q", conn.(*tls.Conn).ConnectionState().NegotiatedProtocol)
}

type bogusMsg struct {
OK bool
}

type bogusEncoding struct{}

func (b *bogusEncoding) Marshal(msg drpc.Message) ([]byte, error) { return json.Marshal(msg) }
func (b *bogusEncoding) Unmarshal(buf []byte, msg drpc.Message) error {
return json.Unmarshal(buf, msg)
}

func ExampleConfigureNextProto() {
// Set up the HTTP server. The ALPN support uses the server's accept loop.
hSrv := new(http.Server)
hSrv.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
// A production server may want to explicitly enable HTTP/2:
/*
h2Srv := new(http2.Server)
// Configure Handler, etc...
http2.ConfigureServer(hSrv, h2Srv)
*/

// Set up the DRPC server.
dSrv := drpcserver.New(nil)

ConfigureNextProto(hSrv, dSrv, context.TODO())

hSrv.ListenAndServeTLS("cert.pem", "key.pem")
}

func ExampleConfigureTLS_client() {
// Set up the TLS config.
cfg, _ := ConfigureTLS(new(tls.Config))
dialer := tls.Dialer{Config: cfg}

// Open a TLS connection.
conn, _ := dialer.DialContext(context.TODO(), "tcp", "[::]:https")
defer conn.Close()

// Now, create a DRPC connection over the TLS connection.
c := drpcconn.New(conn)
c.Close()
}