Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add user_profile_method to upstream SSO provider #3363

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions crates/cli/src/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,15 @@ pub async fn config_sync(
}
};

let user_profile_method = match provider.user_profile_method {
mas_config::UpstreamOAuth2UserProfileMethod::Auto => {
mas_data_model::UpstreamOAuthProviderUserProfileMethod::Auto
}
mas_config::UpstreamOAuth2UserProfileMethod::UserinfoEndpoint => {
mas_data_model::UpstreamOAuthProviderUserProfileMethod::UserinfoEndpoint
}
};

repo.upstream_oauth_provider()
.upsert(
clock,
Expand All @@ -241,13 +250,15 @@ pub async fn config_sync(
brand_name: provider.brand_name,
scope: provider.scope.parse()?,
token_endpoint_auth_method: provider.token_endpoint_auth_method.into(),
user_profile_method,
token_endpoint_signing_alg: provider
.token_endpoint_auth_signing_alg
.clone(),
client_id: provider.client_id,
encrypted_client_secret,
claims_imports: map_claims_imports(&provider.claims_imports),
token_endpoint_override: provider.token_endpoint,
userinfo_endpoint_override: provider.userinfo_endpoint,
authorization_endpoint_override: provider.authorization_endpoint,
jwks_uri_override: provider.jwks_uri,
discovery_mode,
Expand Down
1 change: 1 addition & 0 deletions crates/config/src/sections/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ pub use self::{
EmailImportPreference as UpstreamOAuth2EmailImportPreference,
ImportAction as UpstreamOAuth2ImportAction, PkceMethod as UpstreamOAuth2PkceMethod,
SetEmailVerification as UpstreamOAuth2SetEmailVerification, UpstreamOAuth2Config,
UserProfileMethod as UpstreamOAuth2UserProfileMethod,
},
};
use crate::util::ConfigurationSection;
Expand Down
34 changes: 34 additions & 0 deletions crates/config/src/sections/upstream_oauth2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,26 @@ impl From<TokenAuthMethod> for OAuthClientAuthenticationMethod {
}
}

/// Whether to fetch the user profile from the userinfo endpoint,
/// or to rely on the data returned in the id_token from the token_endpoint
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum UserProfileMethod {
/// Use the userinfo endpoint if `openid` is not included in `scopes`
#[default]
Auto,

/// Always use the userinfo endpoint
UserinfoEndpoint,
}

impl UserProfileMethod {
#[allow(clippy::trivially_copy_pass_by_ref)]
const fn is_default(&self) -> bool {
matches!(self, UserProfileMethod::Auto)
}
}

/// How to handle a claim
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
#[serde(rename_all = "lowercase")]
Expand Down Expand Up @@ -401,6 +421,14 @@ pub struct Provider {
#[serde(skip_serializing_if = "Option::is_none")]
pub token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,

/// Whether to fetch the user profile from the userinfo endpoint,
/// or to rely on the data returned in the id_token from the token_endpoint.
///
/// Defaults to `auto`, which uses the userinfo endpoint if `openid` is not
/// included in `scopes`, and the ID token otherwise.
#[serde(default, skip_serializing_if = "UserProfileMethod::is_default")]
pub user_profile_method: UserProfileMethod,

/// The scopes to request from the provider
pub scope: String,

Expand All @@ -424,6 +452,12 @@ pub struct Provider {
#[serde(skip_serializing_if = "Option::is_none")]
pub authorization_endpoint: Option<Url>,

/// The URL to use for the provider's userinfo endpoint
///
/// Defaults to the `userinfo_endpoint` provided through discovery
#[serde(skip_serializing_if = "Option::is_none")]
pub userinfo_endpoint: Option<Url>,

/// The URL to use for the provider's token endpoint
///
/// Defaults to the `token_endpoint` provided through discovery
Expand Down
1 change: 1 addition & 0 deletions crates/data-model/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub use self::{
UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
UpstreamOAuthProviderImportAction, UpstreamOAuthProviderImportPreference,
UpstreamOAuthProviderPkceMode, UpstreamOAuthProviderSubjectPreference,
UpstreamOAuthProviderUserProfileMethod,
},
user_agent::{DeviceType, UserAgent},
users::{
Expand Down
1 change: 1 addition & 0 deletions crates/data-model/src/upstream_oauth2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub use self::{
PkceMode as UpstreamOAuthProviderPkceMode,
SetEmailVerification as UpsreamOAuthProviderSetEmailVerification,
SubjectPreference as UpstreamOAuthProviderSubjectPreference, UpstreamOAuthProvider,
UserProfileMethod as UpstreamOAuthProviderUserProfileMethod,
},
session::{UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState},
};
47 changes: 47 additions & 0 deletions crates/data-model/src/upstream_oauth2/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,51 @@ impl std::fmt::Display for PkceMode {
}
}

/// Whether to fetch the user profile from the userinfo endpoint,
/// or to rely on the data returned in the id_token from the token_endpoint
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum UserProfileMethod {
/// Use the userinfo endpoint if `openid` is not included in `scopes`
#[default]
Auto,

/// Always use the userinfo endpoint
UserinfoEndpoint,
}

#[derive(Debug, Clone, Error)]
#[error("Invalid user profile method {0:?}")]
pub struct InvalidUserProfileMethodError(String);

impl std::str::FromStr for UserProfileMethod {
type Err = InvalidUserProfileMethodError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"auto" => Ok(Self::Auto),
"userinfo_endpoint" => Ok(Self::UserinfoEndpoint),
s => Err(InvalidUserProfileMethodError(s.to_owned())),
}
}
}

impl UserProfileMethod {
#[must_use]
pub fn as_str(self) -> &'static str {
match self {
Self::Auto => "auto",
Self::UserinfoEndpoint => "userinfo_endpoint",
}
}
}

impl std::fmt::Display for UserProfileMethod {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct UpstreamOAuthProvider {
pub id: Ulid,
Expand All @@ -127,11 +172,13 @@ pub struct UpstreamOAuthProvider {
pub jwks_uri_override: Option<Url>,
pub authorization_endpoint_override: Option<Url>,
pub token_endpoint_override: Option<Url>,
pub userinfo_endpoint_override: Option<Url>,
pub scope: Scope,
pub client_id: String,
pub encrypted_client_secret: Option<String>,
pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
pub token_endpoint_auth_method: OAuthClientAuthenticationMethod,
pub user_profile_method: UserProfileMethod,
pub created_at: DateTime<Utc>,
pub disabled_at: Option<DateTime<Utc>>,
pub claims_imports: ClaimsImports,
Expand Down
21 changes: 20 additions & 1 deletion crates/data-model/src/upstream_oauth2/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ pub enum UpstreamOAuthAuthorizationSessionState {
completed_at: DateTime<Utc>,
link_id: Ulid,
id_token: Option<String>,
userinfo: Option<String>,
},
Consumed {
completed_at: DateTime<Utc>,
consumed_at: DateTime<Utc>,
link_id: Ulid,
id_token: Option<String>,
userinfo: Option<String>,
},
}

Expand All @@ -42,12 +44,14 @@ impl UpstreamOAuthAuthorizationSessionState {
completed_at: DateTime<Utc>,
link: &UpstreamOAuthLink,
id_token: Option<String>,
userinfo: Option<String>,
) -> Result<Self, InvalidTransitionError> {
match self {
Self::Pending => Ok(Self::Completed {
completed_at,
link_id: link.id,
id_token,
userinfo,
}),
Self::Completed { .. } | Self::Consumed { .. } => Err(InvalidTransitionError),
}
Expand All @@ -67,11 +71,13 @@ impl UpstreamOAuthAuthorizationSessionState {
completed_at,
link_id,
id_token,
userinfo,
} => Ok(Self::Consumed {
completed_at,
link_id,
consumed_at,
id_token,
userinfo,
}),
Self::Pending | Self::Consumed { .. } => Err(InvalidTransitionError),
}
Expand Down Expand Up @@ -124,6 +130,16 @@ impl UpstreamOAuthAuthorizationSessionState {
}
}

#[must_use]
pub fn userinfo(&self) -> Option<&str> {
match self {
Self::Pending => None,
Self::Completed { userinfo, .. } | Self::Consumed { userinfo, .. } => {
userinfo.as_deref()
}
}
}

/// Get the time at which the upstream OAuth 2.0 authorization session was
/// consumed.
///
Expand Down Expand Up @@ -201,8 +217,11 @@ impl UpstreamOAuthAuthorizationSession {
completed_at: DateTime<Utc>,
link: &UpstreamOAuthLink,
id_token: Option<String>,
userinfo: Option<String>,
) -> Result<Self, InvalidTransitionError> {
self.state = self.state.complete(completed_at, link, id_token)?;
self.state = self
.state
.complete(completed_at, link, id_token, userinfo)?;
Ok(self)
}

Expand Down
18 changes: 17 additions & 1 deletion crates/handlers/src/upstream_oauth2/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,18 @@ impl<'a> LazyProviderInfos<'a> {
Ok(self.load().await?.token_endpoint())
}

/// Get the userinfo endpoint for the provider.
///
/// Uses [`UpstreamOAuthProvider.userinfo_endpoint_override`] if set, otherwise
/// uses the one from discovery.
pub async fn userinfo_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
if let Some(userinfo_endpoint) = &self.provider.userinfo_endpoint_override {
return Ok(userinfo_endpoint);
}

Ok(self.load().await?.userinfo_endpoint())
}

/// Get the PKCE methods supported by the provider.
///
/// If the mode is set to auto, it will use the ones from discovery,
Expand Down Expand Up @@ -274,7 +286,9 @@ mod tests {
// XXX: sadly, we can't test HTTPS requests with wiremock, so we can only test
// 'insecure' discovery

use mas_data_model::UpstreamOAuthProviderClaimsImports;
use mas_data_model::{
UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderUserProfileMethod,
};
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_storage::{clock::MockClock, Clock};
use oauth2_types::scope::{Scope, OPENID};
Expand Down Expand Up @@ -386,8 +400,10 @@ mod tests {
brand_name: None,
discovery_mode: UpstreamOAuthProviderDiscoveryMode::Insecure,
pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
user_profile_method: UpstreamOAuthProviderUserProfileMethod::Auto,
jwks_uri_override: None,
authorization_endpoint_override: None,
userinfo_endpoint_override: None,
token_endpoint_override: None,
scope: Scope::from_iter([OPENID]),
client_id: "client_id".to_owned(),
Expand Down
Loading
Loading