Skip to content

Commit

Permalink
Add user_profile_method to upstream SSO provider
Browse files Browse the repository at this point in the history
  • Loading branch information
MatMaul committed Oct 15, 2024
1 parent 8723e40 commit 4a62a23
Show file tree
Hide file tree
Showing 24 changed files with 361 additions and 1,149 deletions.
11 changes: 11 additions & 0 deletions crates/cli/src/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,15 @@ pub async fn config_sync(
}
};

let user_profile_method = match provider.user_profile_method {
mas_config::UpstreamOAuth2UserProfileMethod::Auto => {
mas_data_model::UpstreamOAuthProviderUserProfileMethod::Auto
}
mas_config::UpstreamOAuth2UserProfileMethod::UserinfoEndpoint => {
mas_data_model::UpstreamOAuthProviderUserProfileMethod::UserinfoEndpoint
}
};

repo.upstream_oauth_provider()
.upsert(
clock,
Expand All @@ -241,13 +250,15 @@ pub async fn config_sync(
brand_name: provider.brand_name,
scope: provider.scope.parse()?,
token_endpoint_auth_method: provider.token_endpoint_auth_method.into(),
user_profile_method,
token_endpoint_signing_alg: provider
.token_endpoint_auth_signing_alg
.clone(),
client_id: provider.client_id,
encrypted_client_secret,
claims_imports: map_claims_imports(&provider.claims_imports),
token_endpoint_override: provider.token_endpoint,
userinfo_endpoint_override: provider.userinfo_endpoint,
authorization_endpoint_override: provider.authorization_endpoint,
jwks_uri_override: provider.jwks_uri,
discovery_mode,
Expand Down
1 change: 1 addition & 0 deletions crates/config/src/sections/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ pub use self::{
EmailImportPreference as UpstreamOAuth2EmailImportPreference,
ImportAction as UpstreamOAuth2ImportAction, PkceMethod as UpstreamOAuth2PkceMethod,
SetEmailVerification as UpstreamOAuth2SetEmailVerification, UpstreamOAuth2Config,
UserProfileMethod as UpstreamOAuth2UserProfileMethod,
},
};
use crate::util::ConfigurationSection;
Expand Down
34 changes: 34 additions & 0 deletions crates/config/src/sections/upstream_oauth2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,26 @@ impl From<TokenAuthMethod> for OAuthClientAuthenticationMethod {
}
}

/// Whether to fetch the user profile from the userinfo endpoint,
/// or to rely on the data returned in the id_token from the token_endpoint
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum UserProfileMethod {
/// Use the userinfo endpoint if `openid` is not included in `scopes`
#[default]
Auto,

/// Always use the userinfo endpoint
UserinfoEndpoint,
}

impl UserProfileMethod {
#[allow(clippy::trivially_copy_pass_by_ref)]
const fn is_default(&self) -> bool {
matches!(self, UserProfileMethod::Auto)
}
}

/// How to handle a claim
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
#[serde(rename_all = "lowercase")]
Expand Down Expand Up @@ -401,6 +421,14 @@ pub struct Provider {
#[serde(skip_serializing_if = "Option::is_none")]
pub token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,

/// Whether to fetch the user profile from the userinfo endpoint,
/// or to rely on the data returned in the id_token from the token_endpoint.
///
/// Defaults to `auto`, which uses the userinfo endpoint if `openid` is not
/// included in `scopes`, and the ID token otherwise.
#[serde(default, skip_serializing_if = "UserProfileMethod::is_default")]
pub user_profile_method: UserProfileMethod,

/// The scopes to request from the provider
pub scope: String,

Expand All @@ -424,6 +452,12 @@ pub struct Provider {
#[serde(skip_serializing_if = "Option::is_none")]
pub authorization_endpoint: Option<Url>,

/// The URL to use for the provider's userinfo endpoint
///
/// Defaults to the `userinfo_endpoint` provided through discovery
#[serde(skip_serializing_if = "Option::is_none")]
pub userinfo_endpoint: Option<Url>,

/// The URL to use for the provider's token endpoint
///
/// Defaults to the `token_endpoint` provided through discovery
Expand Down
1 change: 1 addition & 0 deletions crates/data-model/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub use self::{
UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
UpstreamOAuthProviderImportAction, UpstreamOAuthProviderImportPreference,
UpstreamOAuthProviderPkceMode, UpstreamOAuthProviderSubjectPreference,
UpstreamOAuthProviderUserProfileMethod,
},
user_agent::{DeviceType, UserAgent},
users::{
Expand Down
1 change: 1 addition & 0 deletions crates/data-model/src/upstream_oauth2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub use self::{
PkceMode as UpstreamOAuthProviderPkceMode,
SetEmailVerification as UpsreamOAuthProviderSetEmailVerification,
SubjectPreference as UpstreamOAuthProviderSubjectPreference, UpstreamOAuthProvider,
UserProfileMethod as UpstreamOAuthProviderUserProfileMethod,
},
session::{UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState},
};
47 changes: 47 additions & 0 deletions crates/data-model/src/upstream_oauth2/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,51 @@ impl std::fmt::Display for PkceMode {
}
}

/// Whether to fetch the user profile from the userinfo endpoint,
/// or to rely on the data returned in the id_token from the token_endpoint
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum UserProfileMethod {
/// Use the userinfo endpoint if `openid` is not included in `scopes`
#[default]
Auto,

/// Always use the userinfo endpoint
UserinfoEndpoint,
}

#[derive(Debug, Clone, Error)]
#[error("Invalid user profile method {0:?}")]
pub struct InvalidUserProfileMethodError(String);

impl std::str::FromStr for UserProfileMethod {
type Err = InvalidUserProfileMethodError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"auto" => Ok(Self::Auto),
"userinfo_endpoint" => Ok(Self::UserinfoEndpoint),
s => Err(InvalidUserProfileMethodError(s.to_owned())),
}
}
}

impl UserProfileMethod {
#[must_use]
pub fn as_str(self) -> &'static str {
match self {
Self::Auto => "auto",
Self::UserinfoEndpoint => "userinfo_endpoint",
}
}
}

impl std::fmt::Display for UserProfileMethod {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct UpstreamOAuthProvider {
pub id: Ulid,
Expand All @@ -127,11 +172,13 @@ pub struct UpstreamOAuthProvider {
pub jwks_uri_override: Option<Url>,
pub authorization_endpoint_override: Option<Url>,
pub token_endpoint_override: Option<Url>,
pub userinfo_endpoint_override: Option<Url>,
pub scope: Scope,
pub client_id: String,
pub encrypted_client_secret: Option<String>,
pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
pub token_endpoint_auth_method: OAuthClientAuthenticationMethod,
pub user_profile_method: UserProfileMethod,
pub created_at: DateTime<Utc>,
pub disabled_at: Option<DateTime<Utc>>,
pub claims_imports: ClaimsImports,
Expand Down
18 changes: 17 additions & 1 deletion crates/handlers/src/upstream_oauth2/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,18 @@ impl<'a> LazyProviderInfos<'a> {
Ok(self.load().await?.token_endpoint())
}

/// Get the userinfo endpoint for the provider.
///
/// Uses [`UpstreamOAuthProvider.userinfo_endpoint_override`] if set, otherwise
/// uses the one from discovery.
pub async fn userinfo_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
if let Some(userinfo_endpoint) = &self.provider.userinfo_endpoint_override {
return Ok(userinfo_endpoint);
}

Ok(self.load().await?.userinfo_endpoint())
}

/// Get the PKCE methods supported by the provider.
///
/// If the mode is set to auto, it will use the ones from discovery,
Expand Down Expand Up @@ -276,7 +288,9 @@ mod tests {
use std::sync::atomic::{AtomicUsize, Ordering};

use hyper::{body::Bytes, Request, Response, StatusCode};
use mas_data_model::UpstreamOAuthProviderClaimsImports;
use mas_data_model::{
UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderUserProfileMethod,
};
use mas_http::BoxCloneSyncService;
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_storage::{clock::MockClock, Clock};
Expand Down Expand Up @@ -487,8 +501,10 @@ mod tests {
brand_name: None,
discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
user_profile_method: UpstreamOAuthProviderUserProfileMethod::Auto,
jwks_uri_override: None,
authorization_endpoint_override: None,
userinfo_endpoint_override: None,
token_endpoint_override: None,
scope: Scope::from_iter([OPENID]),
client_id: "client_id".to_owned(),
Expand Down
36 changes: 30 additions & 6 deletions crates/handlers/src/upstream_oauth2/callback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use mas_axum_utils::{
cookies::CookieJar, http_client_factory::HttpClientFactory, sentry::SentryEventID,
};
use mas_data_model::UpstreamOAuthProvider;
use mas_data_model::UpstreamOAuthProviderUserProfileMethod;
use mas_keystore::{Encrypter, Keystore};
use mas_oidc_client::requests::{
authorization_code::AuthorizationValidationData, jose::JwtVerificationData,
Expand Down Expand Up @@ -94,13 +95,14 @@ pub(crate) enum RouteError {
MissingCookie,

#[error(transparent)]
Internal(Box<dyn std::error::Error>),
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
}

impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
impl_from_error_for_route!(mas_oidc_client::error::JwksError);
impl_from_error_for_route!(mas_oidc_client::error::TokenAuthorizationCodeError);
impl_from_error_for_route!(mas_oidc_client::error::UserInfoError);
impl_from_error_for_route!(super::ProviderCredentialsError);
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);

Expand Down Expand Up @@ -212,32 +214,54 @@ pub(crate) async fn get(
redirect_uri,
};

let id_token_verification_data = JwtVerificationData {
let verification_data = JwtVerificationData {
issuer: &provider.issuer,
jwks: &jwks,
// TODO: make that configurable
signing_algorithm: &mas_iana::jose::JsonWebSignatureAlg::Rs256,
client_id: &provider.client_id,
};

let (response, id_token) =
let (response, id_token_map) =
mas_oidc_client::requests::authorization_code::access_token_with_authorization_code(
&http_service,
client_credentials,
lazy_metadata.token_endpoint().await?,
code,
validation_data,
Some(id_token_verification_data),
Some(verification_data),
clock.now(),
&mut rng,
)
.await?;

let (_header, id_token) = id_token.ok_or(RouteError::MissingIDToken)?.into_parts();
let (_header, id_token) = id_token_map
.clone()
.ok_or(RouteError::MissingIDToken)?
.into_parts();

let use_userinfo_endpoint = match provider.user_profile_method {
UpstreamOAuthProviderUserProfileMethod::Auto => !provider.scope.contains("openid"),
UpstreamOAuthProviderUserProfileMethod::UserinfoEndpoint => true,
};

let userinfo = if use_userinfo_endpoint {
let user_info_resp = mas_oidc_client::requests::userinfo::fetch_userinfo(
&http_service,
lazy_metadata.userinfo_endpoint().await?,
response.access_token.as_str(),
Some(verification_data),
&id_token_map.ok_or(RouteError::MissingIDToken)?,
)
.await?;
minijinja::Value::from_serialize(&user_info_resp)
} else {
minijinja::Value::from_serialize(&id_token)
};

let env = {
let mut env = environment();
env.add_global("user", minijinja::Value::from_serialize(&id_token));
env.add_global("user", userinfo);
env
};

Expand Down
3 changes: 3 additions & 0 deletions crates/handlers/src/upstream_oauth2/link.rs
Original file line number Diff line number Diff line change
Expand Up @@ -907,12 +907,15 @@ mod tests {
brand_name: None,
scope: Scope::from_iter([OPENID]),
token_endpoint_auth_method: OAuthClientAuthenticationMethod::None,
user_profile_method:
mas_data_model::UpstreamOAuthProviderUserProfileMethod::Auto,
token_endpoint_signing_alg: None,
client_id: "client".to_owned(),
encrypted_client_secret: None,
claims_imports,
authorization_endpoint_override: None,
token_endpoint_override: None,
userinfo_endpoint_override: None,
jwks_uri_override: None,
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
Expand Down
6 changes: 6 additions & 0 deletions crates/handlers/src/views/login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -371,11 +371,14 @@ mod test {
scope: [OPENID].into_iter().collect(),
token_endpoint_auth_method: OAuthClientAuthenticationMethod::None,
token_endpoint_signing_alg: None,
user_profile_method:
mas_data_model::UpstreamOAuthProviderUserProfileMethod::Auto,
client_id: "client".to_owned(),
encrypted_client_secret: None,
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
authorization_endpoint_override: None,
token_endpoint_override: None,
userinfo_endpoint_override: None,
jwks_uri_override: None,
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
Expand Down Expand Up @@ -406,11 +409,14 @@ mod test {
scope: [OPENID].into_iter().collect(),
token_endpoint_auth_method: OAuthClientAuthenticationMethod::None,
token_endpoint_signing_alg: None,
user_profile_method:
mas_data_model::UpstreamOAuthProviderUserProfileMethod::Auto,
client_id: "client".to_owned(),
encrypted_client_secret: None,
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
authorization_endpoint_override: None,
token_endpoint_override: None,
userinfo_endpoint_override: None,
jwks_uri_override: None,
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
Expand Down
9 changes: 9 additions & 0 deletions crates/oauth2-types/src/oidc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,15 @@ impl VerifiedProviderMetadata {
}
}

/// TODO
#[must_use]
pub fn userinfo_endpoint(&self) -> &Url {
match &self.userinfo_endpoint {
Some(u) => u,
None => unreachable!(),
}
}

/// URL of the authorization server's token endpoint.
#[must_use]
pub fn token_endpoint(&self) -> &Url {
Expand Down

This file was deleted.

Loading

0 comments on commit 4a62a23

Please sign in to comment.