From e2107315605c31b0dd81bfec5ef7862f0c32119e Mon Sep 17 00:00:00 2001 From: Brad Moylan Date: Wed, 18 Dec 2024 13:33:43 -0800 Subject: [PATCH] fix: refreshableConfigAuthHeaderMiddleware handles both basic and apitoken --- conjure-go-client/httpclient/authn.go | 53 ++++++++++++------- .../httpclient/client_builder.go | 34 +++++++----- conjure-go-client/httpclient/client_params.go | 27 +++------- conjure-go-client/httpclient/metrics_test.go | 11 ++-- 4 files changed, 71 insertions(+), 54 deletions(-) diff --git a/conjure-go-client/httpclient/authn.go b/conjure-go-client/httpclient/authn.go index c82edf52..7d5a49dd 100644 --- a/conjure-go-client/httpclient/authn.go +++ b/conjure-go-client/httpclient/authn.go @@ -21,7 +21,6 @@ import ( "net/http" "github.com/palantir/conjure-go-runtime/v2/conjure-go-client/httpclient/internal/refreshingclient" - "github.com/palantir/pkg/refreshable" ) // TokenProvider accepts a context and returns either: @@ -50,17 +49,6 @@ func (h *authTokenMiddleware) RoundTrip(req *http.Request, next http.RoundTrippe return next.RoundTrip(req) } -func newAuthTokenMiddlewareFromRefreshable(token refreshable.StringPtr) Middleware { - return &authTokenMiddleware{ - provideToken: func(ctx context.Context) (string, error) { - if s := token.CurrentStringPtr(); s != nil { - return *s, nil - } - return "", nil - }, - } -} - // BasicAuthProvider accepts a context and returns either: // // (1) a nonempty BasicAuth and a nil error, or @@ -77,13 +65,40 @@ type BasicAuthProvider func(context.Context) (BasicAuth, error) // (3) a nil BasicAuth and a non-nil error. type BasicAuthOptionalProvider func(context.Context) (*BasicAuth, error) -func newBasicAuthMiddlewareFromRefreshable(auth refreshingclient.RefreshableBasicAuthPtr) Middleware { - return MiddlewareFunc(func(req *http.Request, next http.RoundTripper) (*http.Response, error) { - if basicAuth := auth.CurrentBasicAuthPtr(); basicAuth != nil { - setBasicAuth(req.Header, basicAuth.User, basicAuth.Password) - } - return next.RoundTrip(req) - }) +type basicAuthMiddleware struct { + provideBasicAuth BasicAuthOptionalProvider +} + +func (b basicAuthMiddleware) RoundTrip(req *http.Request, next http.RoundTripper) (*http.Response, error) { + basicAuth, err := b.provideBasicAuth(req.Context()) + if err != nil { + return nil, err + } + if basicAuth != nil { + setBasicAuth(req.Header, basicAuth.User, basicAuth.Password) + } + return next.RoundTrip(req) +} + +type refreshableConfigAuthHeaderMiddleware struct { + cfg refreshingclient.RefreshableValidatedClientParams +} + +// newRefreshableConfigAuthHeaderMiddleware returns a new Middleware that sets the Authorization header using the +// current API token or BasicAuth credentials from the provided RefreshableValidatedClientParams. If the request already +// has an Authorization header (e.g. set by a different Middleware), it will not be overwritten. +func newRefreshableConfigAuthHeaderMiddleware(cfg refreshingclient.RefreshableValidatedClientParams) Middleware { + return &refreshableConfigAuthHeaderMiddleware{cfg: cfg} +} + +func (r *refreshableConfigAuthHeaderMiddleware) RoundTrip(req *http.Request, next http.RoundTripper) (*http.Response, error) { + curr := r.cfg.CurrentValidatedClientParams() + if curr.APIToken != nil { + req.Header.Set("Authorization", "Bearer "+*curr.APIToken) + } else if curr.BasicAuth != nil { + setBasicAuth(req.Header, curr.BasicAuth.User, curr.BasicAuth.Password) + } + return next.RoundTrip(req) } func setBasicAuth(h http.Header, username, password string) { diff --git a/conjure-go-client/httpclient/client_builder.go b/conjure-go-client/httpclient/client_builder.go index 34d612c1..489eba0d 100644 --- a/conjure-go-client/httpclient/client_builder.go +++ b/conjure-go-client/httpclient/client_builder.go @@ -81,7 +81,8 @@ type httpClientBuilder struct { DisableTraceHeaders bool } -func (b *httpClientBuilder) Build(ctx context.Context, config RefreshableClientConfig, reloadErrorSubmitter func(error), params ...HTTPClientParam) (RefreshableHTTPClient, refreshingclient.RefreshableValidatedClientParams, error) { +// Build returns a RoundTripper and the refreshable validated client params it's based on. +func (b *httpClientBuilder) Build(ctx context.Context, config RefreshableClientConfig, reloadErrorSubmitter func(error), params ...HTTPClientParam) (http.RoundTripper, refreshingclient.RefreshableValidatedClientParams, error) { for _, p := range params { if p == nil { continue @@ -121,10 +122,8 @@ func (b *httpClientBuilder) Build(ctx context.Context, config RefreshableClientC if !b.DisableRecovery { transport = wrapTransport(transport, recoveryMiddleware{}) } - transport = wrapTransport(transport, b.Middlewares...) - client := refreshingclient.NewRefreshableHTTPClient(transport, validParams.Timeout()) - return client, validParams, nil + return transport, validParams, nil } // NewClient returns a configured client ready for use. @@ -158,13 +157,11 @@ func newClient(ctx context.Context, config RefreshableClientConfig, b *clientBui edm = errorDecoderMiddleware{errorDecoder: b.ErrorDecoder} } - middleware := b.HTTP.Middlewares - b.HTTP.Middlewares = nil - - httpClient, validParams, err := b.HTTP.Build(ctx, config, reloadErrorSubmitter) + transport, validParams, err := b.HTTP.Build(ctx, config, reloadErrorSubmitter) if err != nil { return nil, err } + httpClient := refreshingclient.NewRefreshableHTTPClient(transport, validParams.Timeout()) if !b.AllowEmptyURIs { // Validate that the URIs are not empty. @@ -185,9 +182,13 @@ func newClient(ctx context.Context, config RefreshableClientConfig, b *clientBui return b.URIScorerBuilder(uris) }) - middleware = append(middleware, - newAuthTokenMiddlewareFromRefreshable(validParams.APIToken()), - newBasicAuthMiddlewareFromRefreshable(validParams.BasicAuth())) + // Move the user-configured middlewares from the http.Client to the clientImpl struct + // before httpClientBuilder.Build so they can be wrapped in the correct layer + // of the Client transport stack (outside error decoder and inside body middleware). + middleware := b.HTTP.Middlewares + // Prepend the auth header middleware to the middleware stack. + // If a user-provided middleware sets an Authorization header, it will take precedence over a value from configuration. + middleware = append(middleware, newRefreshableConfigAuthHeaderMiddleware(validParams)) return &clientImpl{ serviceName: validParams.ServiceName(), @@ -221,8 +222,15 @@ type RefreshableHTTPClient = refreshingclient.RefreshableHTTPClient // The RefreshableClientConfig is not accepted as a client param because there must be exactly one // subscription used to build the ValidatedClientParams in Build(). func NewHTTPClientFromRefreshableConfig(ctx context.Context, config RefreshableClientConfig, params ...HTTPClientParam) (RefreshableHTTPClient, error) { - client, _, err := new(httpClientBuilder).Build(ctx, config, nil, params...) - return client, err + b := &httpClientBuilder{} + transport, validParams, err := b.Build(ctx, config, nil, params...) + if err != nil { + return nil, err + } + transport = wrapTransport(transport, b.Middlewares...) + transport = wrapTransport(transport, newRefreshableConfigAuthHeaderMiddleware(validParams)) + httpClient := refreshingclient.NewRefreshableHTTPClient(transport, validParams.Timeout()) + return httpClient, nil } // Map the final config to a set of validated client params used to build the dialer, retrier, tls config, and transport. diff --git a/conjure-go-client/httpclient/client_params.go b/conjure-go-client/httpclient/client_params.go index e4106d3b..34863ba8 100644 --- a/conjure-go-client/httpclient/client_params.go +++ b/conjure-go-client/httpclient/client_params.go @@ -493,23 +493,21 @@ func WithErrorDecoder(errorDecoder ErrorDecoder) ClientParam { // WithBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and // password. func WithBasicAuth(user, password string) ClientOrHTTPClientParam { - return WithInnerMiddleware(MiddlewareFunc(func(req *http.Request, next http.RoundTripper) (*http.Response, error) { - setBasicAuth(req.Header, user, password) - return next.RoundTrip(req) - })) + return WithInnerMiddleware(&basicAuthMiddleware{provideBasicAuth: func(ctx context.Context) (*BasicAuth, error) { + return &BasicAuth{User: user, Password: password}, nil + }}) } // WithBasicAuthProvider sets the request's Authorization header to use HTTP Basic Authentication. // The provider is expected to always return a nonempty BasicAuth value, or an error. func WithBasicAuthProvider(provider BasicAuthProvider) ClientOrHTTPClientParam { - return WithInnerMiddleware(MiddlewareFunc(func(req *http.Request, next http.RoundTripper) (*http.Response, error) { - basicAuth, err := provider(req.Context()) + return WithInnerMiddleware(&basicAuthMiddleware{provideBasicAuth: func(ctx context.Context) (*BasicAuth, error) { + basicAuth, err := provider(ctx) if err != nil { return nil, err } - setBasicAuth(req.Header, basicAuth.User, basicAuth.Password) - return next.RoundTrip(req) - })) + return &basicAuth, nil + }}) } // WithBasicAuthOptionalProvider sets the request's Authorization header to use HTTP Basic Authentication based on the @@ -517,16 +515,7 @@ func WithBasicAuthProvider(provider BasicAuthProvider) ClientOrHTTPClientParam { // BasicAuth value is non-nil then its values are set on the header, while if the returned BasicAuth value is nil then // no basic authentication header values are set. func WithBasicAuthOptionalProvider(provider BasicAuthOptionalProvider) ClientOrHTTPClientParam { - return WithInnerMiddleware(MiddlewareFunc(func(req *http.Request, next http.RoundTripper) (*http.Response, error) { - basicAuth, err := provider(req.Context()) - if err != nil { - return nil, err - } - if basicAuth != nil { - setBasicAuth(req.Header, basicAuth.User, basicAuth.Password) - } - return next.RoundTrip(req) - })) + return WithInnerMiddleware(&basicAuthMiddleware{provideBasicAuth: provider}) } // WithBalancedURIScoring adds middleware that prioritizes sending requests to URIs with the fewest in-flight requests diff --git a/conjure-go-client/httpclient/metrics_test.go b/conjure-go-client/httpclient/metrics_test.go index 25fbab29..bd757663 100644 --- a/conjure-go-client/httpclient/metrics_test.go +++ b/conjure-go-client/httpclient/metrics_test.go @@ -213,9 +213,14 @@ func TestMetricsMiddleware_HTTPClient(t *testing.T) { } func TestMetricsMiddleware_ClientTimeout(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - time.Sleep(time.Second) - w.WriteHeader(200) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + select { + case <-req.Context().Done(): + // client timed out and closed connection + return + case <-time.After(5 * time.Second): + assert.Fail(t, "timeout waiting for client to close connection") + } })) defer srv.Close()