diff --git a/estuary-cdk/estuary_cdk/capture/common.py b/estuary-cdk/estuary_cdk/capture/common.py index aa82c513b..b4eff6d82 100644 --- a/estuary-cdk/estuary_cdk/capture/common.py +++ b/estuary-cdk/estuary_cdk/capture/common.py @@ -22,6 +22,8 @@ AccessToken, BaseOAuth2Credentials, CaptureBinding, + ClientCredentialsOAuth2Credentials, + ClientCredentialsOAuth2Spec, OAuth2Spec, ValidationError, BasicAuth, diff --git a/estuary-cdk/estuary_cdk/flow.py b/estuary-cdk/estuary_cdk/flow.py index 288ebf260..c5ef26cf5 100644 --- a/estuary-cdk/estuary_cdk/flow.py +++ b/estuary-cdk/estuary_cdk/flow.py @@ -66,6 +66,11 @@ class Checkpoint(BaseModel): ackIntents: dict[str, str] +class ClientCredentialsOAuth2Spec(BaseModel): + accessTokenResponseMap: dict[str, str] + accessTokenUrlTemplate: str + + class OAuth2Spec(BaseModel): provider: str accessTokenBody: str @@ -122,6 +127,21 @@ class ValidationError(Exception): errors: list[str] +class ClientCredentialsOAuth2Credentials(abc.ABC, BaseModel): + credentials_title: Literal["OAuth Credentials"] = Field( + default="OAuth Credentials", + json_schema_extra={"type": "string"} + ) + client_id: str = Field( + title="Client Id", + json_schema_extra={"secret": True}, + ) + client_secret: str = Field( + title="Client Secret", + json_schema_extra={"secret": True}, + ) + + class BaseOAuth2Credentials(abc.ABC, BaseModel): credentials_title: Literal["OAuth Credentials"] = Field( default="OAuth Credentials", diff --git a/estuary-cdk/estuary_cdk/http.py b/estuary-cdk/estuary_cdk/http.py index d43daa389..8966d7ab6 100644 --- a/estuary-cdk/estuary_cdk/http.py +++ b/estuary-cdk/estuary_cdk/http.py @@ -9,7 +9,14 @@ import time from . import Mixin -from .flow import BaseOAuth2Credentials, AccessToken, OAuth2Spec, BasicAuth +from .flow import ( + AccessToken, + BasicAuth, + BaseOAuth2Credentials, + ClientCredentialsOAuth2Credentials, + ClientCredentialsOAuth2Spec, + OAuth2Spec, +) DEFAULT_AUTHORIZATION_HEADER = "Authorization" @@ -120,8 +127,8 @@ class AccessTokenResponse(BaseModel): refresh_token: str = "" scope: str = "" - oauth_spec: OAuth2Spec | None - credentials: BaseOAuth2Credentials | AccessToken | BasicAuth + oauth_spec: OAuth2Spec | ClientCredentialsOAuth2Spec | None + credentials: BaseOAuth2Credentials | ClientCredentialsOAuth2Credentials | AccessToken | BasicAuth authorization_header: str = DEFAULT_AUTHORIZATION_HEADER _access_token: AccessTokenResponse | None = None _fetched_at: int = 0 @@ -137,7 +144,7 @@ async def fetch_token(self, log: Logger, session: HTTPSession) -> tuple[str, str ).decode(), ) - assert isinstance(self.credentials, BaseOAuth2Credentials) + assert isinstance(self.credentials, BaseOAuth2Credentials) or isinstance(self.credentials, ClientCredentialsOAuth2Credentials) current_time = time.time() if self._access_token is not None: @@ -158,20 +165,39 @@ async def fetch_token(self, log: Logger, session: HTTPSession) -> tuple[str, str return ("Bearer", self._access_token.access_token) async def _fetch_oauth2_token( - self, log: Logger, session: HTTPSession, credentials: BaseOAuth2Credentials + self, log: Logger, session: HTTPSession, credentials: BaseOAuth2Credentials | ClientCredentialsOAuth2Credentials ) -> AccessTokenResponse: assert self.oauth_spec + headers = {} + form = {} + + match credentials: + case BaseOAuth2Credentials(): + form = { + "grant_type": "refresh_token", + "client_id": credentials.client_id, + "client_secret": credentials.client_secret, + "refresh_token": credentials.refresh_token, + } + case ClientCredentialsOAuth2Credentials(): + form = { + "grant_type": "client_credentials", + } + headers = { + "Authorization": "Basic " + base64.b64encode( + f"{credentials.client_id}:{credentials.client_secret}".encode() + ).decode() + } + case _: + raise TypeError(f"Unsupported credentials type: {type(credentials)}.") + response = await session.request( log, self.oauth_spec.accessTokenUrlTemplate, method="POST", - form={ - "grant_type": "refresh_token", - "client_id": credentials.client_id, - "client_secret": credentials.client_secret, - "refresh_token": credentials.refresh_token, - }, + headers=headers, + form=form, _with_token=False, ) return self.AccessTokenResponse.model_validate_json(response)