Skip to content

Commit

Permalink
fix: refreshableConfigAuthHeaderMiddleware handles both basic and api…
Browse files Browse the repository at this point in the history
…token
  • Loading branch information
bmoylan committed Dec 18, 2024
1 parent ca1e0d0 commit e210731
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 54 deletions.
53 changes: 34 additions & 19 deletions conjure-go-client/httpclient/authn.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"net/http"

"github.com/palantir/conjure-go-runtime/v2/conjure-go-client/httpclient/internal/refreshingclient"
"github.com/palantir/pkg/refreshable"
)

// TokenProvider accepts a context and returns either:
Expand Down Expand Up @@ -50,17 +49,6 @@ func (h *authTokenMiddleware) RoundTrip(req *http.Request, next http.RoundTrippe
return next.RoundTrip(req)
}

func newAuthTokenMiddlewareFromRefreshable(token refreshable.StringPtr) Middleware {
return &authTokenMiddleware{
provideToken: func(ctx context.Context) (string, error) {
if s := token.CurrentStringPtr(); s != nil {
return *s, nil
}
return "", nil
},
}
}

// BasicAuthProvider accepts a context and returns either:
//
// (1) a nonempty BasicAuth and a nil error, or
Expand All @@ -77,13 +65,40 @@ type BasicAuthProvider func(context.Context) (BasicAuth, error)
// (3) a nil BasicAuth and a non-nil error.
type BasicAuthOptionalProvider func(context.Context) (*BasicAuth, error)

func newBasicAuthMiddlewareFromRefreshable(auth refreshingclient.RefreshableBasicAuthPtr) Middleware {
return MiddlewareFunc(func(req *http.Request, next http.RoundTripper) (*http.Response, error) {
if basicAuth := auth.CurrentBasicAuthPtr(); basicAuth != nil {
setBasicAuth(req.Header, basicAuth.User, basicAuth.Password)
}
return next.RoundTrip(req)
})
type basicAuthMiddleware struct {
provideBasicAuth BasicAuthOptionalProvider
}

func (b basicAuthMiddleware) RoundTrip(req *http.Request, next http.RoundTripper) (*http.Response, error) {
basicAuth, err := b.provideBasicAuth(req.Context())
if err != nil {
return nil, err
}
if basicAuth != nil {
setBasicAuth(req.Header, basicAuth.User, basicAuth.Password)
}
return next.RoundTrip(req)
}

type refreshableConfigAuthHeaderMiddleware struct {
cfg refreshingclient.RefreshableValidatedClientParams
}

// newRefreshableConfigAuthHeaderMiddleware returns a new Middleware that sets the Authorization header using the
// current API token or BasicAuth credentials from the provided RefreshableValidatedClientParams. If the request already
// has an Authorization header (e.g. set by a different Middleware), it will not be overwritten.
func newRefreshableConfigAuthHeaderMiddleware(cfg refreshingclient.RefreshableValidatedClientParams) Middleware {
return &refreshableConfigAuthHeaderMiddleware{cfg: cfg}
}

func (r *refreshableConfigAuthHeaderMiddleware) RoundTrip(req *http.Request, next http.RoundTripper) (*http.Response, error) {
curr := r.cfg.CurrentValidatedClientParams()
if curr.APIToken != nil {
req.Header.Set("Authorization", "Bearer "+*curr.APIToken)
} else if curr.BasicAuth != nil {
setBasicAuth(req.Header, curr.BasicAuth.User, curr.BasicAuth.Password)
}
return next.RoundTrip(req)
}

func setBasicAuth(h http.Header, username, password string) {
Expand Down
34 changes: 21 additions & 13 deletions conjure-go-client/httpclient/client_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ type httpClientBuilder struct {
DisableTraceHeaders bool
}

func (b *httpClientBuilder) Build(ctx context.Context, config RefreshableClientConfig, reloadErrorSubmitter func(error), params ...HTTPClientParam) (RefreshableHTTPClient, refreshingclient.RefreshableValidatedClientParams, error) {
// Build returns a RoundTripper and the refreshable validated client params it's based on.
func (b *httpClientBuilder) Build(ctx context.Context, config RefreshableClientConfig, reloadErrorSubmitter func(error), params ...HTTPClientParam) (http.RoundTripper, refreshingclient.RefreshableValidatedClientParams, error) {
for _, p := range params {
if p == nil {
continue
Expand Down Expand Up @@ -121,10 +122,8 @@ func (b *httpClientBuilder) Build(ctx context.Context, config RefreshableClientC
if !b.DisableRecovery {
transport = wrapTransport(transport, recoveryMiddleware{})
}
transport = wrapTransport(transport, b.Middlewares...)

client := refreshingclient.NewRefreshableHTTPClient(transport, validParams.Timeout())
return client, validParams, nil
return transport, validParams, nil
}

// NewClient returns a configured client ready for use.
Expand Down Expand Up @@ -158,13 +157,11 @@ func newClient(ctx context.Context, config RefreshableClientConfig, b *clientBui
edm = errorDecoderMiddleware{errorDecoder: b.ErrorDecoder}
}

middleware := b.HTTP.Middlewares
b.HTTP.Middlewares = nil

httpClient, validParams, err := b.HTTP.Build(ctx, config, reloadErrorSubmitter)
transport, validParams, err := b.HTTP.Build(ctx, config, reloadErrorSubmitter)
if err != nil {
return nil, err
}
httpClient := refreshingclient.NewRefreshableHTTPClient(transport, validParams.Timeout())

if !b.AllowEmptyURIs {
// Validate that the URIs are not empty.
Expand All @@ -185,9 +182,13 @@ func newClient(ctx context.Context, config RefreshableClientConfig, b *clientBui
return b.URIScorerBuilder(uris)
})

middleware = append(middleware,
newAuthTokenMiddlewareFromRefreshable(validParams.APIToken()),
newBasicAuthMiddlewareFromRefreshable(validParams.BasicAuth()))
// Move the user-configured middlewares from the http.Client to the clientImpl struct
// before httpClientBuilder.Build so they can be wrapped in the correct layer
// of the Client transport stack (outside error decoder and inside body middleware).
middleware := b.HTTP.Middlewares
// Prepend the auth header middleware to the middleware stack.
// If a user-provided middleware sets an Authorization header, it will take precedence over a value from configuration.
middleware = append(middleware, newRefreshableConfigAuthHeaderMiddleware(validParams))

return &clientImpl{
serviceName: validParams.ServiceName(),
Expand Down Expand Up @@ -221,8 +222,15 @@ type RefreshableHTTPClient = refreshingclient.RefreshableHTTPClient
// The RefreshableClientConfig is not accepted as a client param because there must be exactly one
// subscription used to build the ValidatedClientParams in Build().
func NewHTTPClientFromRefreshableConfig(ctx context.Context, config RefreshableClientConfig, params ...HTTPClientParam) (RefreshableHTTPClient, error) {
client, _, err := new(httpClientBuilder).Build(ctx, config, nil, params...)
return client, err
b := &httpClientBuilder{}
transport, validParams, err := b.Build(ctx, config, nil, params...)
if err != nil {
return nil, err
}
transport = wrapTransport(transport, b.Middlewares...)
transport = wrapTransport(transport, newRefreshableConfigAuthHeaderMiddleware(validParams))
httpClient := refreshingclient.NewRefreshableHTTPClient(transport, validParams.Timeout())
return httpClient, nil
}

// Map the final config to a set of validated client params used to build the dialer, retrier, tls config, and transport.
Expand Down
27 changes: 8 additions & 19 deletions conjure-go-client/httpclient/client_params.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,40 +493,29 @@ func WithErrorDecoder(errorDecoder ErrorDecoder) ClientParam {
// WithBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and
// password.
func WithBasicAuth(user, password string) ClientOrHTTPClientParam {
return WithInnerMiddleware(MiddlewareFunc(func(req *http.Request, next http.RoundTripper) (*http.Response, error) {
setBasicAuth(req.Header, user, password)
return next.RoundTrip(req)
}))
return WithInnerMiddleware(&basicAuthMiddleware{provideBasicAuth: func(ctx context.Context) (*BasicAuth, error) {
return &BasicAuth{User: user, Password: password}, nil
}})
}

// WithBasicAuthProvider sets the request's Authorization header to use HTTP Basic Authentication.
// The provider is expected to always return a nonempty BasicAuth value, or an error.
func WithBasicAuthProvider(provider BasicAuthProvider) ClientOrHTTPClientParam {
return WithInnerMiddleware(MiddlewareFunc(func(req *http.Request, next http.RoundTripper) (*http.Response, error) {
basicAuth, err := provider(req.Context())
return WithInnerMiddleware(&basicAuthMiddleware{provideBasicAuth: func(ctx context.Context) (*BasicAuth, error) {
basicAuth, err := provider(ctx)
if err != nil {
return nil, err
}
setBasicAuth(req.Header, basicAuth.User, basicAuth.Password)
return next.RoundTrip(req)
}))
return &basicAuth, nil
}})
}

// WithBasicAuthOptionalProvider sets the request's Authorization header to use HTTP Basic Authentication based on the
// return value of the provided BasicAuthOptionalProvider. If the provider returns a non-nil error, if the returned
// BasicAuth value is non-nil then its values are set on the header, while if the returned BasicAuth value is nil then
// no basic authentication header values are set.
func WithBasicAuthOptionalProvider(provider BasicAuthOptionalProvider) ClientOrHTTPClientParam {
return WithInnerMiddleware(MiddlewareFunc(func(req *http.Request, next http.RoundTripper) (*http.Response, error) {
basicAuth, err := provider(req.Context())
if err != nil {
return nil, err
}
if basicAuth != nil {
setBasicAuth(req.Header, basicAuth.User, basicAuth.Password)
}
return next.RoundTrip(req)
}))
return WithInnerMiddleware(&basicAuthMiddleware{provideBasicAuth: provider})
}

// WithBalancedURIScoring adds middleware that prioritizes sending requests to URIs with the fewest in-flight requests
Expand Down
11 changes: 8 additions & 3 deletions conjure-go-client/httpclient/metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,14 @@ func TestMetricsMiddleware_HTTPClient(t *testing.T) {
}

func TestMetricsMiddleware_ClientTimeout(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
time.Sleep(time.Second)
w.WriteHeader(200)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
select {
case <-req.Context().Done():
// client timed out and closed connection
return
case <-time.After(5 * time.Second):
assert.Fail(t, "timeout waiting for client to close connection")
}
}))
defer srv.Close()

Expand Down

0 comments on commit e210731

Please sign in to comment.