diff --git a/clients/go/admin/auth_interceptor.go b/clients/go/admin/auth_interceptor.go index 340326784..5bf002a49 100644 --- a/clients/go/admin/auth_interceptor.go +++ b/clients/go/admin/auth_interceptor.go @@ -2,6 +2,7 @@ package admin import ( "context" + "errors" "fmt" "net/http" @@ -16,10 +17,12 @@ import ( "google.golang.org/grpc" ) +const ProxyAuthorizationHeader = "proxy-authorization" + // MaterializeCredentials will attempt to build a TokenSource given the anonymously available information exposed by the server. // Once established, it'll invoke PerRPCCredentialsFuture.Store() on perRPCCredentials to populate it with the appropriate values. -func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, perRPCCredentials *PerRPCCredentialsFuture) error { - authMetadataClient, err := InitializeAuthMetadataClient(ctx, cfg) +func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, perRPCCredentials *PerRPCCredentialsFuture, proxyCredentialsFuture *PerRPCCredentialsFuture) error { + authMetadataClient, err := InitializeAuthMetadataClient(ctx, cfg, proxyCredentialsFuture) if err != nil { return fmt.Errorf("failed to initialized Auth Metadata Client. Error: %w", err) } @@ -48,19 +51,70 @@ func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.T return nil } +func GetProxyTokenSource(ctx context.Context, cfg *Config) (oauth2.TokenSource, error) { + tokenSourceProvider, err := NewExternalTokenSourceProvider(cfg.ProxyCommand) + if err != nil { + return nil, fmt.Errorf("failed to initialized proxy authorization token source provider. Err: %w", err) + } + proxyTokenSource, err := tokenSourceProvider.GetTokenSource(ctx) + if err != nil { + return nil, err + } + return proxyTokenSource, nil +} + +func MaterializeProxyAuthCredentials(ctx context.Context, cfg *Config, proxyCredentialsFuture *PerRPCCredentialsFuture) error { + proxyTokenSource, err := GetProxyTokenSource(ctx, cfg) + if err != nil { + return err + } + + wrappedTokenSource := NewCustomHeaderTokenSource(proxyTokenSource, cfg.UseInsecureConnection, ProxyAuthorizationHeader) + proxyCredentialsFuture.Store(wrappedTokenSource) + + return nil +} + func shouldAttemptToAuthenticate(errorCode codes.Code) bool { return errorCode == codes.Unauthenticated } +type proxyAuthTransport struct { + transport http.RoundTripper + proxyCredentialsFuture *PerRPCCredentialsFuture +} + +func (c *proxyAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // check if the proxy credentials future is initialized + if !c.proxyCredentialsFuture.IsInitialized() { + return nil, errors.New("proxy credentials future is not initialized") + } + + metadata, err := c.proxyCredentialsFuture.GetRequestMetadata(context.Background(), "") + if err != nil { + return nil, err + } + token := metadata[ProxyAuthorizationHeader] + req.Header.Add(ProxyAuthorizationHeader, token) + return c.transport.RoundTrip(req) +} + // Set up http client used in oauth2 -func setHTTPClientContext(ctx context.Context, cfg *Config) context.Context { +func setHTTPClientContext(ctx context.Context, cfg *Config, proxyCredentialsFuture *PerRPCCredentialsFuture) context.Context { httpClient := &http.Client{} + transport := &http.Transport{} if len(cfg.HTTPProxyURL.String()) > 0 { // create a transport that uses the proxy - transport := &http.Transport{ - Proxy: http.ProxyURL(&cfg.HTTPProxyURL.URL), + transport.Proxy = http.ProxyURL(&cfg.HTTPProxyURL.URL) + } + + if cfg.ProxyCommand != nil { + httpClient.Transport = &proxyAuthTransport{ + transport: transport, + proxyCredentialsFuture: proxyCredentialsFuture, } + } else { httpClient.Transport = transport } @@ -77,9 +131,9 @@ func setHTTPClientContext(ctx context.Context, cfg *Config) context.Context { // more. It'll fail hard if it couldn't do so (i.e. it will no longer attempt to send an unauthenticated request). Once // a token source has been created, it'll invoke the grpc pipeline again, this time the grpc.PerRPCCredentials should // be able to find and acquire a valid AccessToken to annotate the request with. -func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFuture *PerRPCCredentialsFuture) grpc.UnaryClientInterceptor { +func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFuture *PerRPCCredentialsFuture, proxyCredentialsFuture *PerRPCCredentialsFuture) grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - ctx = setHTTPClientContext(ctx, cfg) + ctx = setHTTPClientContext(ctx, cfg, proxyCredentialsFuture) err := invoker(ctx, method, req, reply, cc, opts...) if err != nil { @@ -89,7 +143,7 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut // If the error we receive from executing the request expects if shouldAttemptToAuthenticate(st.Code()) { logger.Debugf(ctx, "Request failed due to [%v]. Attempting to establish an authenticated connection and trying again.", st.Code()) - newErr := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture) + newErr := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture, proxyCredentialsFuture) if newErr != nil { return fmt.Errorf("authentication error! Original Error: %v, Auth Error: %w", err, newErr) } @@ -102,3 +156,18 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut return err } } + +func NewProxyAuthInterceptor(cfg *Config, proxyCredentialsFuture *PerRPCCredentialsFuture) grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + + err := invoker(ctx, method, req, reply, cc, opts...) + if err != nil { + newErr := MaterializeProxyAuthCredentials(ctx, cfg, proxyCredentialsFuture) + if newErr != nil { + return fmt.Errorf("proxy authorization error! Original Error: %v, Proxy Auth Error: %w", err, newErr) + } + return invoker(ctx, method, req, reply, cc, opts...) + } + return err + } +} diff --git a/clients/go/admin/auth_interceptor_test.go b/clients/go/admin/auth_interceptor_test.go index d5e07a713..bb304d4f6 100644 --- a/clients/go/admin/auth_interceptor_test.go +++ b/clients/go/admin/auth_interceptor_test.go @@ -15,6 +15,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "golang.org/x/oauth2" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -114,7 +115,8 @@ func newAuthMetadataServer(t testing.TB, port int, impl service.AuthMetadataServ func Test_newAuthInterceptor(t *testing.T) { t.Run("Other Error", func(t *testing.T) { f := NewPerRPCCredentialsFuture() - interceptor := NewAuthInterceptor(&Config{}, &mocks.TokenCache{}, f) + p := NewPerRPCCredentialsFuture() + interceptor := NewAuthInterceptor(&Config{}, &mocks.TokenCache{}, f, p) otherError := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { return status.New(codes.Canceled, "").Err() } @@ -146,11 +148,12 @@ func Test_newAuthInterceptor(t *testing.T) { assert.NoError(t, err) f := NewPerRPCCredentialsFuture() + p := NewPerRPCCredentialsFuture() interceptor := NewAuthInterceptor(&Config{ Endpoint: config.URL{URL: *u}, UseInsecureConnection: true, AuthType: AuthTypeClientSecret, - }, &mocks.TokenCache{}, f) + }, &mocks.TokenCache{}, f, p) unauthenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { return status.New(codes.Unauthenticated, "").Err() } @@ -177,11 +180,13 @@ func Test_newAuthInterceptor(t *testing.T) { assert.NoError(t, err) f := NewPerRPCCredentialsFuture() + p := NewPerRPCCredentialsFuture() + interceptor := NewAuthInterceptor(&Config{ Endpoint: config.URL{URL: *u}, UseInsecureConnection: true, AuthType: AuthTypeClientSecret, - }, &mocks.TokenCache{}, f) + }, &mocks.TokenCache{}, f, p) authenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { return nil } @@ -216,11 +221,13 @@ func Test_newAuthInterceptor(t *testing.T) { assert.NoError(t, err) f := NewPerRPCCredentialsFuture() + p := NewPerRPCCredentialsFuture() + interceptor := NewAuthInterceptor(&Config{ Endpoint: config.URL{URL: *u}, UseInsecureConnection: true, AuthType: AuthTypeClientSecret, - }, &mocks.TokenCache{}, f) + }, &mocks.TokenCache{}, f, p) unauthenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { return status.New(codes.Aborted, "").Err() } @@ -246,6 +253,8 @@ func TestMaterializeCredentials(t *testing.T) { assert.NoError(t, err) f := NewPerRPCCredentialsFuture() + p := NewPerRPCCredentialsFuture() + err = MaterializeCredentials(ctx, &Config{ Endpoint: config.URL{URL: *u}, UseInsecureConnection: true, @@ -254,7 +263,7 @@ func TestMaterializeCredentials(t *testing.T) { Scopes: []string{"all"}, Audience: "http://localhost:30081", AuthorizationHeader: "authorization", - }, &mocks.TokenCache{}, f) + }, &mocks.TokenCache{}, f, p) assert.NoError(t, err) }) t.Run("Failed to fetch client metadata", func(t *testing.T) { @@ -271,13 +280,119 @@ func TestMaterializeCredentials(t *testing.T) { assert.NoError(t, err) f := NewPerRPCCredentialsFuture() + p := NewPerRPCCredentialsFuture() + err = MaterializeCredentials(ctx, &Config{ Endpoint: config.URL{URL: *u}, UseInsecureConnection: true, AuthType: AuthTypeClientSecret, TokenURL: fmt.Sprintf("http://localhost:%d/api/v1/token", port), Scopes: []string{"all"}, - }, &mocks.TokenCache{}, f) + }, &mocks.TokenCache{}, f, p) assert.EqualError(t, err, "failed to fetch client metadata. Error: rpc error: code = Unknown desc = expected err") }) } + +func TestNewProxyAuthInterceptor(t *testing.T) { + cfg := &Config{ + ProxyCommand: []string{"echo", "test-token"}, + } + + p := NewPerRPCCredentialsFuture() + + interceptor := NewProxyAuthInterceptor(cfg, p) + + ctx := context.Background() + method := "/test.method" + req := "request" + reply := "reply" + cc := new(grpc.ClientConn) + + errorInvoker := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { + return errors.New("test error") + } + + // Call should return an error and trigger the interceptor to materialize proxy auth credentials + err := interceptor(ctx, method, req, reply, cc, errorInvoker) + assert.Error(t, err) + + // Check if proxyCredentialsFuture contains a proxy auth header token + creds, err := p.Get().GetRequestMetadata(ctx, "") + assert.True(t, p.IsInitialized()) + assert.NoError(t, err) + assert.Equal(t, "Bearer test-token", creds[ProxyAuthorizationHeader]) +} + +type testRoundTripper struct { + RoundTripFunc func(req *http.Request) (*http.Response, error) +} + +func (t *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return t.RoundTripFunc(req) +} + +func TestSetHTTPClientContext(t *testing.T) { + ctx := context.Background() + + t.Run("no proxy command and no proxy url", func(t *testing.T) { + cfg := &Config{} + + newCtx := setHTTPClientContext(ctx, cfg, nil) + + httpClient, ok := newCtx.Value(oauth2.HTTPClient).(*http.Client) + assert.True(t, ok) + + transport, ok := httpClient.Transport.(*http.Transport) + assert.True(t, ok) + assert.Nil(t, transport.Proxy) + }) + + t.Run("proxy url", func(t *testing.T) { + cfg := &Config{ + HTTPProxyURL: config. + URL{URL: url.URL{ + Scheme: "http", + Host: "localhost:8080", + }}, + } + newCtx := setHTTPClientContext(ctx, cfg, nil) + + httpClient, ok := newCtx.Value(oauth2.HTTPClient).(*http.Client) + assert.True(t, ok) + + transport, ok := httpClient.Transport.(*http.Transport) + assert.True(t, ok) + assert.NotNil(t, transport.Proxy) + }) + + t.Run("proxy command adds proxy-authorization header", func(t *testing.T) { + cfg := &Config{ + ProxyCommand: []string{"echo", "test-token-http-client"}, + } + + p := NewPerRPCCredentialsFuture() + err := MaterializeProxyAuthCredentials(ctx, cfg, p) + assert.NoError(t, err) + + newCtx := setHTTPClientContext(ctx, cfg, p) + + httpClient, ok := newCtx.Value(oauth2.HTTPClient).(*http.Client) + assert.True(t, ok) + + pat, ok := httpClient.Transport.(*proxyAuthTransport) + assert.True(t, ok) + + testRoundTripper := &testRoundTripper{ + RoundTripFunc: func(req *http.Request) (*http.Response, error) { + // Check if the ProxyAuthorizationHeader is correctly set + assert.Equal(t, "Bearer test-token-http-client", req.Header.Get(ProxyAuthorizationHeader)) + return &http.Response{StatusCode: http.StatusOK}, nil + }, + } + pat.transport = testRoundTripper + + req, _ := http.NewRequest("GET", "http://example.com", nil) + _, err = httpClient.Do(req) + assert.NoError(t, err) + }) +} diff --git a/clients/go/admin/client.go b/clients/go/admin/client.go index 830c86fe8..69e109388 100644 --- a/clients/go/admin/client.go +++ b/clients/go/admin/client.go @@ -110,9 +110,9 @@ func getAuthenticationDialOption(ctx context.Context, cfg *Config, tokenSourcePr } // InitializeAuthMetadataClient creates a new anonymously Auth Metadata Service client. -func InitializeAuthMetadataClient(ctx context.Context, cfg *Config) (client service.AuthMetadataServiceClient, err error) { +func InitializeAuthMetadataClient(ctx context.Context, cfg *Config, proxyCredentialsFuture *PerRPCCredentialsFuture) (client service.AuthMetadataServiceClient, err error) { // Create an unauthenticated connection to fetch AuthMetadata - authMetadataConnection, err := NewAdminConnection(ctx, cfg) + authMetadataConnection, err := NewAdminConnection(ctx, cfg, proxyCredentialsFuture) if err != nil { return nil, fmt.Errorf("failed to initialized admin connection. Error: %w", err) } @@ -120,11 +120,11 @@ func InitializeAuthMetadataClient(ctx context.Context, cfg *Config) (client serv return service.NewAuthMetadataServiceClient(authMetadataConnection), nil } -func NewAdminConnection(ctx context.Context, cfg *Config, opts ...grpc.DialOption) (*grpc.ClientConn, error) { +func NewAdminConnection(ctx context.Context, cfg *Config, proxyCredentialsFuture *PerRPCCredentialsFuture, opts ...grpc.DialOption) (*grpc.ClientConn, error) { if opts == nil { // Initialize opts list to the potential number of options we will add. Initialization optimizes memory // allocation. - opts = make([]grpc.DialOption, 0, 5) + opts = make([]grpc.DialOption, 0, 7) } if cfg.UseInsecureConnection { @@ -153,6 +153,11 @@ func NewAdminConnection(ctx context.Context, cfg *Config, opts ...grpc.DialOptio opts = append(opts, GetAdditionalAdminClientConfigOptions(cfg)...) + if cfg.ProxyCommand != nil { + opts = append(opts, grpc.WithChainUnaryInterceptor(NewProxyAuthInterceptor(cfg, proxyCredentialsFuture))) + opts = append(opts, grpc.WithPerRPCCredentials(proxyCredentialsFuture)) + } + return grpc.Dial(cfg.Endpoint.String(), opts...) } @@ -172,15 +177,17 @@ func InitializeAdminClient(ctx context.Context, cfg *Config, opts ...grpc.DialOp // for the process. Note that if called with different cfg/dialoptions, it will not refresh the connection. func initializeClients(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, opts ...grpc.DialOption) (*Clientset, error) { credentialsFuture := NewPerRPCCredentialsFuture() + proxyCredentialsFuture := NewPerRPCCredentialsFuture() + opts = append(opts, - grpc.WithChainUnaryInterceptor(NewAuthInterceptor(cfg, tokenCache, credentialsFuture)), + grpc.WithChainUnaryInterceptor(NewAuthInterceptor(cfg, tokenCache, credentialsFuture, proxyCredentialsFuture)), grpc.WithPerRPCCredentials(credentialsFuture)) if cfg.DefaultServiceConfig != "" { opts = append(opts, grpc.WithDefaultServiceConfig(cfg.DefaultServiceConfig)) } - adminConnection, err := NewAdminConnection(ctx, cfg, opts...) + adminConnection, err := NewAdminConnection(ctx, cfg, proxyCredentialsFuture, opts...) if err != nil { logger.Panicf(ctx, "failed to initialized Admin connection. Err: %s", err.Error()) } diff --git a/clients/go/admin/config.go b/clients/go/admin/config.go index e6c5ad06d..3ce329e4a 100644 --- a/clients/go/admin/config.go +++ b/clients/go/admin/config.go @@ -74,6 +74,8 @@ type Config struct { Command []string `json:"command" pflag:",Command for external authentication token generation"` + ProxyCommand []string `json:"proxyCommand" pflag:",Command for external proxy-authorization token generation"` + // Set the gRPC service config formatted as a json string https://github.com/grpc/grpc/blob/master/doc/service_config.md // eg. {"loadBalancingConfig": [{"round_robin":{}}], "methodConfig": [{"name":[{"service": "foo", "method": "bar"}, {"service": "baz"}], "timeout": "1.000000001s"}]} // find the full schema here https://github.com/grpc/grpc-proto/blob/master/grpc/service_config/service_config.proto#L625 diff --git a/clients/go/admin/pkce/auth_flow_orchestrator.go b/clients/go/admin/pkce/auth_flow_orchestrator.go index e2eee411d..a71841d6c 100644 --- a/clients/go/admin/pkce/auth_flow_orchestrator.go +++ b/clients/go/admin/pkce/auth_flow_orchestrator.go @@ -2,6 +2,7 @@ package pkce import ( "context" + "errors" "fmt" "net/http" "net/url" @@ -63,8 +64,15 @@ func (f TokenOrchestrator) FetchTokenFromAuthFlow(ctx context.Context) (*oauth2. serveMux := http.NewServeMux() server := &http.Server{Addr: redirectURL.Host, Handler: serveMux, ReadHeaderTimeout: 0} // Register the call back handler + + // Pass along http client used in oauth2 + httpClient, ok := ctx.Value(oauth2.HTTPClient).(*http.Client) + if !ok { + return nil, errors.New("Unable to retrieve httpClient used in oauth2 from context") + } + serveMux.HandleFunc(redirectURL.Path, getAuthServerCallbackHandler(f.ClientConfig, pkceCodeVerifier, - tokenChannel, errorChannel, stateString)) // the oauth2 callback endpoint + tokenChannel, errorChannel, stateString, httpClient)) // the oauth2 callback endpoint defer server.Close() go func() { diff --git a/clients/go/admin/pkce/handle_app_call_back.go b/clients/go/admin/pkce/handle_app_call_back.go index 0874a5a4c..470f639fa 100644 --- a/clients/go/admin/pkce/handle_app_call_back.go +++ b/clients/go/admin/pkce/handle_app_call_back.go @@ -11,7 +11,7 @@ import ( ) func getAuthServerCallbackHandler(c *oauth.Config, codeVerifier string, tokenChannel chan *oauth2.Token, - errorChannel chan error, stateString string) func(rw http.ResponseWriter, req *http.Request) { + errorChannel chan error, stateString string, client *http.Client) func(rw http.ResponseWriter, req *http.Request) { return func(rw http.ResponseWriter, req *http.Request) { _, _ = rw.Write([]byte(`

Flyte Authentication

`)) @@ -43,7 +43,8 @@ func getAuthServerCallbackHandler(c *oauth.Config, codeVerifier string, tokenCha var opts []oauth2.AuthCodeOption opts = append(opts, oauth2.SetAuthURLParam("code_verifier", codeVerifier)) - token, err := c.Exchange(context.Background(), req.URL.Query().Get("code"), opts...) + ctx := context.WithValue(context.Background(), oauth2.HTTPClient, client) + token, err := c.Exchange(ctx, req.URL.Query().Get("code"), opts...) if err != nil { errorChannel <- fmt.Errorf("error while exchanging auth code due to %v", err) _, _ = rw.Write([]byte(fmt.Sprintf(`

Couldn't get access token due to error: %s

`, err.Error()))) diff --git a/clients/go/admin/pkce/handle_app_call_back_test.go b/clients/go/admin/pkce/handle_app_call_back_test.go index 91eed672c..fbad6b5d3 100644 --- a/clients/go/admin/pkce/handle_app_call_back_test.go +++ b/clients/go/admin/pkce/handle_app_call_back_test.go @@ -25,7 +25,7 @@ func HandleAppCallBackSetup(t *testing.T, state string) (tokenChannel chan *oaut errorChannel = make(chan error, 1) tokenChannel = make(chan *oauth2.Token) testAuthConfig = &oauth.Config{Config: &oauth2.Config{}, DeviceEndpoint: "dummyDeviceEndpoint"} - callBackFn = getAuthServerCallbackHandler(testAuthConfig, "", tokenChannel, errorChannel, state) + callBackFn = getAuthServerCallbackHandler(testAuthConfig, "", tokenChannel, errorChannel, state, &http.Client{}) assert.NotNil(t, callBackFn) req = &http.Request{ Method: http.MethodGet,