Skip to content

Commit

Permalink
estuary-cdk: support oauth2 client credentials grant type flow
Browse files Browse the repository at this point in the history
Currently, only the OAuth2 Authorization Code grant type flow is
supported. Upcoming SaaS connectors require more flexibility & can use
different OAuth2 grant type flows. For example, the Genesys connector
needs a user-configured domain to perform authentication and get an
access token. Our current OAuth2 framework only supports access token
URLs that don't contain user-configured values, and adding that support
would be a larger effort than we want to take on right now. With the
Client Credentials grant type flow, we can inject a user-configured
value in the access token URL at runtime to support this use case.
  • Loading branch information
Alex-Bair committed Nov 11, 2024
1 parent ec3334b commit 38dff0c
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 11 deletions.
2 changes: 2 additions & 0 deletions estuary-cdk/estuary_cdk/capture/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
AccessToken,
BaseOAuth2Credentials,
CaptureBinding,
ClientCredentialsOAuth2Credentials,
ClientCredentialsOAuth2Spec,
OAuth2Spec,
ValidationError,
BasicAuth,
Expand Down
20 changes: 20 additions & 0 deletions estuary-cdk/estuary_cdk/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ class Checkpoint(BaseModel):
ackIntents: dict[str, str]


class ClientCredentialsOAuth2Spec(BaseModel):
accessTokenResponseMap: dict[str, str]
accessTokenUrlTemplate: str


class OAuth2Spec(BaseModel):
provider: str
accessTokenBody: str
Expand Down Expand Up @@ -122,6 +127,21 @@ class ValidationError(Exception):
errors: list[str]


class ClientCredentialsOAuth2Credentials(abc.ABC, BaseModel):
credentials_title: Literal["OAuth Credentials"] = Field(
default="OAuth Credentials",
json_schema_extra={"type": "string"}
)
client_id: str = Field(
title="Client Id",
json_schema_extra={"secret": True},
)
client_secret: str = Field(
title="Client Secret",
json_schema_extra={"secret": True},
)


class BaseOAuth2Credentials(abc.ABC, BaseModel):
credentials_title: Literal["OAuth Credentials"] = Field(
default="OAuth Credentials",
Expand Down
48 changes: 37 additions & 11 deletions estuary-cdk/estuary_cdk/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,14 @@
import time

from . import Mixin
from .flow import BaseOAuth2Credentials, AccessToken, OAuth2Spec, BasicAuth
from .flow import (
AccessToken,
BasicAuth,
BaseOAuth2Credentials,
ClientCredentialsOAuth2Credentials,
ClientCredentialsOAuth2Spec,
OAuth2Spec,
)

DEFAULT_AUTHORIZATION_HEADER = "Authorization"

Expand Down Expand Up @@ -120,8 +127,8 @@ class AccessTokenResponse(BaseModel):
refresh_token: str = ""
scope: str = ""

oauth_spec: OAuth2Spec | None
credentials: BaseOAuth2Credentials | AccessToken | BasicAuth
oauth_spec: OAuth2Spec | ClientCredentialsOAuth2Spec | None
credentials: BaseOAuth2Credentials | ClientCredentialsOAuth2Credentials | AccessToken | BasicAuth
authorization_header: str = DEFAULT_AUTHORIZATION_HEADER
_access_token: AccessTokenResponse | None = None
_fetched_at: int = 0
Expand All @@ -137,7 +144,7 @@ async def fetch_token(self, log: Logger, session: HTTPSession) -> tuple[str, str
).decode(),
)

assert isinstance(self.credentials, BaseOAuth2Credentials)
assert isinstance(self.credentials, BaseOAuth2Credentials) or isinstance(self.credentials, ClientCredentialsOAuth2Credentials)
current_time = time.time()

if self._access_token is not None:
Expand All @@ -158,20 +165,39 @@ async def fetch_token(self, log: Logger, session: HTTPSession) -> tuple[str, str
return ("Bearer", self._access_token.access_token)

async def _fetch_oauth2_token(
self, log: Logger, session: HTTPSession, credentials: BaseOAuth2Credentials
self, log: Logger, session: HTTPSession, credentials: BaseOAuth2Credentials | ClientCredentialsOAuth2Credentials
) -> AccessTokenResponse:
assert self.oauth_spec

headers = {}
form = {}

match credentials:
case BaseOAuth2Credentials():
form = {
"grant_type": "refresh_token",
"client_id": credentials.client_id,
"client_secret": credentials.client_secret,
"refresh_token": credentials.refresh_token,
}
case ClientCredentialsOAuth2Credentials():
form = {
"grant_type": "client_credentials",
}
headers = {
"Authorization": "Basic " + base64.b64encode(
f"{credentials.client_id}:{credentials.client_secret}".encode()
).decode()
}
case _:
raise TypeError(f"Unsupported credentials type: {type(credentials)}.")

response = await session.request(
log,
self.oauth_spec.accessTokenUrlTemplate,
method="POST",
form={
"grant_type": "refresh_token",
"client_id": credentials.client_id,
"client_secret": credentials.client_secret,
"refresh_token": credentials.refresh_token,
},
headers=headers,
form=form,
_with_token=False,
)
return self.AccessTokenResponse.model_validate_json(response)
Expand Down

0 comments on commit 38dff0c

Please sign in to comment.