diff --git a/libs/oauth2/src/dex_provider.rs b/libs/oauth2/src/dex_provider.rs new file mode 100644 index 0000000..01dd234 --- /dev/null +++ b/libs/oauth2/src/dex_provider.rs @@ -0,0 +1,108 @@ +use std::{future::Future, pin::Pin}; + +use crate::{ + errors::Oauth2Error, + oauth_provider::{decode_oauth_id_token, OAuthProvider, OAuthProviderFactory, OAuthResponse}, + Provider, ProviderConfig, TokenResponse, +}; +use base64::prelude::{Engine as _, BASE64_STANDARD}; + +use url::form_urlencoded; + +pub struct DexProvider { + provider_config: ProviderConfig, +} + +/// Get the authorization header for the provider +/// +/// # Arguments +/// * `provider_config` - The provider configuration +/// +/// # Returns +/// The authorization header +fn get_authorization_header(provider_config: &ProviderConfig) -> String { + format!( + "Basic {}", + BASE64_STANDARD.encode(format!( + "{}:{}", + provider_config.app_id, provider_config.app_secret + )) + ) +} + +impl OAuthProviderFactory for DexProvider { + fn new() -> Self { + let provider_config = Self::get_provider_config(Provider::Custom); + Self { provider_config } + } +} +impl OAuthProvider for DexProvider { + fn get_redirect_url(&self, callback_url: &str, state: &str) -> String { + let redirect_url = + form_urlencoded::byte_serialize(callback_url.as_bytes()).collect::(); + let scope = form_urlencoded::byte_serialize(self.provider_config.scope.as_bytes()) + .collect::(); + let state = form_urlencoded::byte_serialize(state.as_bytes()).collect::(); + + format!( + "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state={}", + self.provider_config.authorization_url, + self.provider_config.app_id, + redirect_url, + scope, + state + ) + } + + fn exchange_code( + &self, + code: &str, + callback_url: &str, + ) -> Pin> + Send + Sync>> { + let code = code.to_string(); + let callback_url = callback_url.to_string(); + let provider_config = self.provider_config.clone(); + + Box::pin(async move { + let code = form_urlencoded::byte_serialize(code.as_bytes()).collect::(); + let authorization_header = get_authorization_header(&provider_config); + let response = reqwest::Client::new() + .post(provider_config.token_exchange_url.as_str()) + .header("Authorization", authorization_header) + .header("Content-Type", "application/x-www-form-urlencoded") + .form(&[ + ("grant_type", "authorization_code"), + ("code", code.as_str()), + ("redirect_uri", &callback_url), + ("client_id", &provider_config.app_id.as_str()), + ]) + .send() + .await + .map_err(|_| Oauth2Error::ExchangeCodeError)?; + + let response = response + .error_for_status() + .map_err(|_| Oauth2Error::ExchangeCodeError)?; + + let body = response + .json::() + .await + .map_err(|_| Oauth2Error::ExchangeCodeError)?; + + if let Some(id_token) = body.id_token { + let (username, email) = decode_oauth_id_token(&id_token)?; + Ok(OAuthResponse { + access_token: body.access_token, + username, + email, + }) + } else { + Err(Oauth2Error::ExchangeCodeError) + } + }) + } + + fn get_provider_type(&self) -> crate::Provider { + Provider::Custom + } +} diff --git a/libs/oauth2/src/lib.rs b/libs/oauth2/src/lib.rs index b2ceeb0..458473b 100644 --- a/libs/oauth2/src/lib.rs +++ b/libs/oauth2/src/lib.rs @@ -1,12 +1,8 @@ -use base64::{ - engine::general_purpose::URL_SAFE_NO_PAD, - prelude::{Engine as _, BASE64_STANDARD}, -}; +pub mod dex_provider; +pub mod oauth_provider; use serde::{Deserialize, Serialize}; use std::{fs, str::FromStr}; -use url::form_urlencoded; mod errors; -use errors::Oauth2Error; #[derive(Deserialize, Serialize, Clone, Debug)] pub struct ProviderConfig { @@ -34,7 +30,7 @@ pub enum Provider { } #[derive(Debug, Serialize, Deserialize)] -struct Claims { +pub struct Claims { aud: String, sub: String, name: String, @@ -89,229 +85,26 @@ pub struct TokenResponse { pub refresh_token: Option, } -/// Get the name of the provider config file -/// from the OAUTH2_CONFIG_FILE environment variable or -/// default to "oauth2.toml" +/// Get the providers config from a config file /// -/// # Returns -/// The name of the provider config file -pub fn get_provider_config_file() -> String { - std::env::var("OAUTH2_CONFIG_FILE").unwrap_or_else(|_| "oauth2.toml".to_string()) -} - -/// Get the redirect url for the specified provider -/// -/// # Arguments -/// * `provider_config` - The provider configuration -/// -/// # Returns -/// The redirect url -/// -pub fn get_provider_config(config_file: &str) -> Vec { +/// # Returns +/// The providers config +pub fn get_providers_config_from_file(config_file: &str) -> Vec { let config_file_content = fs::read_to_string(config_file).expect("Failed to read config file"); let config: Config = toml::from_str(&config_file_content).expect("Failed to parse config file"); config.provider } -/// Get redirect url for the provider -/// -/// # Arguments -/// * `provider_config` - The provider configuration -/// * `callback_url` - The callback url -/// * `state` - The state (for CSRF protection) -/// -/// # Returns -/// The redirect url -pub fn get_redirect_url( - provider_config: &ProviderConfig, - callback_url: &str, - state: &str, -) -> String { - let redirect_url = form_urlencoded::byte_serialize(callback_url.as_bytes()).collect::(); - let scope = - form_urlencoded::byte_serialize(provider_config.scope.as_bytes()).collect::(); - let state = form_urlencoded::byte_serialize(state.as_bytes()).collect::(); - - format!( - "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state={}", - provider_config.authorization_url, provider_config.app_id, redirect_url, scope, state - ) -} - -/// Get the authorization header for the provider -/// -/// # Arguments -/// * `provider_config` - The provider configuration -/// -/// # Returns -/// The authorization header -pub fn get_authorization_header(provider_config: &ProviderConfig) -> String { - format!( - "Basic {}", - BASE64_STANDARD.encode(format!( - "{}:{}", - provider_config.app_id, provider_config.app_secret - )) - ) -} - -/// Exchange the code for an access token -/// -/// # Arguments -/// * `provider_config` - The provider configuration -/// * `code` - The code -/// * `callback_url` - The callback url -/// -/// # Returns -/// -/// * The access token -/// * The username -/// * The email -pub async fn exchange_code( - provider_config: &ProviderConfig, - code: &str, - callback_url: &str, -) -> Result<(String,String,String), Oauth2Error> { - let code = form_urlencoded::byte_serialize(code.as_bytes()).collect::(); - //let callback_url = form_urlencoded::byte_serialize(callback_url.as_bytes()).collect::(); - let authorization_header = get_authorization_header(provider_config); - let response = reqwest::Client::new() - .post(provider_config.token_exchange_url.as_str()) - .header("Authorization", authorization_header) - .header("Content-Type", "application/x-www-form-urlencoded") - .form(&[ - ("grant_type", "authorization_code"), - ("code", code.as_str()), - ("redirect_uri", callback_url), - ("client_id", provider_config.app_id.as_str()), - ]) - .send() - .await - .map_err(|_| Oauth2Error::ExchangeCodeError)?; - - let response = response - .error_for_status() - .map_err(|_| Oauth2Error::ExchangeCodeError)?; - - let body = response - .json::() - .await - .map_err(|_| Oauth2Error::ExchangeCodeError)?; - - if body.id_token.is_some() { - // Response contains an id_token which is a JWT token - let id_token = body.id_token.unwrap(); - let resp = decode_id_token(&id_token); - if resp.is_err() { - return Err(Oauth2Error::DecodeIdTokenError); - } - let (name, email) = resp.unwrap(); - return Ok((body.access_token, name, email)); - } - Ok((body.access_token, "".to_string(), "tobefilled@example.org".to_string())) -} - -/// Decode the id token - -/// # Arguments -/// * `id_token` - The jwt id token -/// -/// # Returns -/// the username and email - -pub fn decode_id_token(id_token: &str) -> Result<(String, String), Oauth2Error> { - let parts: Vec<&str> = id_token.split('.').collect(); - let claims = URL_SAFE_NO_PAD - .decode(parts[1]) - .map_err(|_| Oauth2Error::DecodeIdTokenError)?; - let claims: Claims = - serde_json::from_slice(&claims).map_err(|_| Oauth2Error::DecodeIdTokenError)?; - Ok((claims.name, claims.email)) -} - -/// Verify the access token -/// -/// # Arguments -/// * `provider_config` - The provider configuration -/// * `access_token` - The access token +/// Get the name of the provider config file +/// from the OAUTH2_CONFIG_FILE environment variable or +/// default to "oauth2.toml" /// /// # Returns -/// (access_token, username, email) -/// The access token if valid, otherwise an error -pub async fn verify_access_token( - provider_config: &ProviderConfig, - access_token: &str, -) -> Result { - let client = reqwest::Client::new(); - match provider_config.provider { - Provider::Github => { - let response = client - .get("https://api.github.com/user") - .bearer_auth(access_token) - .send() - .await; - match response { - Ok(response) if response.status().is_success() => Ok(access_token.to_string()), - _ => Err(Oauth2Error::VerifyTokenError), - } - } - Provider::Google => { - let response = client - .get(&format!( - "https://www.googleapis.com/oauth2/v1/tokeninfo?access_token={}", - access_token - )) - .send() - .await; - - match response { - Ok(response) if response.status().is_success() => Ok(access_token.to_string()), - _ => Err(Oauth2Error::VerifyTokenError), - } - } - Provider::Facebook => { - let app_token = format!("{}|{}", provider_config.app_id, provider_config.app_secret); - let response = client - .get(&format!( - "https://graph.facebook.com/debug_token?input_token={}&access_token={}", - access_token, app_token - )) - .send() - .await; - - match response { - Ok(response) if response.status().is_success() => Ok(access_token.to_string()), - _ => Err(Oauth2Error::VerifyTokenError), - } - } - Provider::Gitlab => { - let response = client - .post("https://gitlab.com/oauth/token/info") - .bearer_auth(access_token) - .send() - .await; - - match response { - Ok(response) if response.status().is_success() => Ok(access_token.to_string()), - _ => Err(Oauth2Error::VerifyTokenError), - } - } - // TODO: Add other providers - Provider::Custom => Ok(access_token.to_string()), - _ => Err(Oauth2Error::VerifyTokenError), - } +/// The name of the provider config file +pub fn get_providers_config_file() -> String { + std::env::var("OAUTH2_CONFIG_FILE").unwrap_or_else(|_| "oauth2.toml".to_string()) } -/// Decode the access token -/// -/// # Arguments -/// * `access_token` - The access token -/// -/// # Returns -/// The decoded access token -pub fn decode_access_token(access_token: &str) -> Result { - Ok(access_token.to_string()) -} #[cfg(test)] mod tests { use super::*; diff --git a/libs/oauth2/src/oauth_provider.rs b/libs/oauth2/src/oauth_provider.rs new file mode 100644 index 0000000..31e7d64 --- /dev/null +++ b/libs/oauth2/src/oauth_provider.rs @@ -0,0 +1,65 @@ +use crate::{ + errors::Oauth2Error, get_providers_config_file, get_providers_config_from_file, Claims, Provider, ProviderConfig +}; +use std::{future::Future, pin::Pin}; +use base64::prelude::{Engine as _, BASE64_URL_SAFE_NO_PAD}; + +pub struct OAuthResponse { + pub access_token: String, + pub username: String, + pub email: String, +} +pub trait OAuthProviderFactory { + fn new() -> Self; + /// Get the provider config for the given provider name + /// + /// # Arguments + /// * `provider_name` - The name of the provider + /// + /// # Returns + /// The provider config + fn get_provider_config(tprovider: Provider) -> ProviderConfig { + let provider_config = get_providers_config_from_file(get_providers_config_file().as_str()); + provider_config + .iter() + .find(|&provider| provider.provider == tprovider) + .expect("Provider not found") + .clone() + } +} + +pub trait OAuthProvider: Send + Sync{ + /// Get redirect url for the provider + /// + /// # Arguments + /// * `callback_url` - The callback url + /// * `state` - The state code + /// + /// # Returns + /// The redirect url + fn get_redirect_url(&self, callback_url: &str, state: &str) -> String; + fn exchange_code( + &self, + code: &str, + callback_url: &str, + ) -> Pin> + Send + Sync>>; + + /// Get the provider type + fn get_provider_type(&self) -> Provider; +} + +/// Decode the Oauth id token +/// # Arguments +/// * `id_token` - The jwt id token +/// +/// # Returns +/// the username and email +pub fn decode_oauth_id_token(id_token: &str) -> Result<(String, String), Oauth2Error> { + let parts: Vec<&str> = id_token.split('.').collect(); + let claims = BASE64_URL_SAFE_NO_PAD + .decode(parts[1]) + .map_err(|_| Oauth2Error::DecodeIdTokenError)?; + let claims: Claims = + serde_json::from_slice(&claims).map_err(|_| Oauth2Error::DecodeIdTokenError)?; + Ok((claims.name, claims.email)) +} \ No newline at end of file diff --git a/libs/state/src/state.rs b/libs/state/src/state.rs index 3e45977..389adca 100644 --- a/libs/state/src/state.rs +++ b/libs/state/src/state.rs @@ -377,7 +377,7 @@ impl ApiState { if oauth2_providers.is_empty() { log::debug!("get providers from {}", config_file); - let config = oauth2::get_provider_config(config_file); + let config = oauth2::get_providers_config_from_file(config_file); if config.len() == 0 { return None; } @@ -433,25 +433,23 @@ impl ApiState { } let oidc_session = oidc_session.unwrap(); oidc_session.code = Some(authorization_code.clone()); - if oidc_session.provider_config.is_some() + if oidc_session.provider.is_some() && oidc_session.code.is_some() && oidc_session.callback_url.is_some() { - let exchange_result = oauth2::exchange_code( - &oidc_session.clone().provider_config.unwrap(), - &authorization_code, - &oidc_session.clone().callback_url.unwrap(), - ) - .await; + let provider = oidc_session.clone().provider.unwrap(); + let callback_url = oidc_session.clone().callback_url.unwrap(); + let exchange_result = provider.exchange_code(authorization_code.as_str(), callback_url.as_str()).await; + if exchange_result.is_ok() { let access_token = exchange_result.unwrap(); - let username = access_token.1.clone(); + let username = access_token.username.clone(); - oidc_session.auth_token = Some(access_token.0.clone()); + oidc_session.auth_token = Some(access_token.access_token.clone()); oidc_session.name = Some(if username.len()>0 {username.clone()} else {oidc_session.id.clone()}); - oidc_session.email = Some(access_token.2.clone()); + oidc_session.email = Some(access_token.email.clone()); log::debug!("oidc_session_exchange_code {:?}", oidc_session.auth_token); - return Some(access_token.0); + return Some(access_token.access_token); } } None diff --git a/libs/utils/src/types.rs b/libs/utils/src/types.rs index 2039bfd..05b194a 100644 --- a/libs/utils/src/types.rs +++ b/libs/utils/src/types.rs @@ -1,6 +1,8 @@ use std::collections::HashMap; use std::fmt; +use std::sync::Arc; +use oauth2::oauth_provider::OAuthProvider; use rocket_okapi::okapi::schemars; use rocket_okapi::JsonSchema; use serde::de::Visitor; @@ -518,7 +520,7 @@ pub struct OidcUserInfo { pub other: HashMap, } -#[derive(Debug, Serialize, Deserialize, Clone)] +#[derive(Clone)] pub struct OidcState { pub id: String, // is id of the Rustdesk client pub uuid: String, // is uuid of the Rustdesk client @@ -526,7 +528,7 @@ pub struct OidcState { pub auth_token: Option, pub redirect_url: Option, pub callback_url: Option, - pub provider_config: Option, + pub provider: Option>, pub name: Option, pub email: Option, } @@ -539,7 +541,7 @@ impl Default for OidcState { auth_token: None, redirect_url: None, callback_url: None, - provider_config: None, + provider: None, name: None, email: None, } diff --git a/src/lib.rs b/src/lib.rs index 587f253..111006c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,10 +3,12 @@ mod extended_json; use std::collections::HashMap; use std::env; +use std::sync::Arc; use api::ActionResponse; use extended_json::ExtendedJson; - +use oauth2::oauth_provider::OAuthProvider; +use oauth2::oauth_provider::OAuthProviderFactory; use rocket::fairing::{Fairing, Info, Kind}; use rocket::form::validate::Len; use rocket::http::Header; @@ -29,7 +31,9 @@ use rocket::{ }; use state::{ApiState, UserPasswordInfo}; use utils::{ - include_png_as_base64, unwrap_or_return, AbTagRenameRequest, AddUserRequest, AddressBook, EnableUserRequest, OidcSettingsResponse, SoftwareResponse, SoftwareVersionResponse, UpdateUserRequest, UserList + include_png_as_base64, unwrap_or_return, AbTagRenameRequest, AddUserRequest, AddressBook, + EnableUserRequest, OidcSettingsResponse, SoftwareResponse, SoftwareVersionResponse, + UpdateUserRequest, UserList, }; use utils::{ AbGetResponse, AbRequest, AuditRequest, CurrentUserRequest, CurrentUserResponse, @@ -458,7 +462,7 @@ async fn login_options( ) -> Result>, status::Unauthorized<()>> { let mut providers: Vec = Vec::new(); let providers_config = state - .get_oauth2_config(oauth2::get_provider_config_file().as_str()) + .get_oauth2_config(oauth2::get_providers_config_file().as_str()) .await; if providers_config.is_none() { return Err(status::Unauthorized::<()>(())); @@ -472,7 +476,7 @@ async fn login_options( /// OIDC Auth request /// /// This entrypoint is called by the client for getting the authorization url for the Oauth2 provider he chooses -/// +/// /// For testing you can generate a valid uuid field with the following command: `uuidgen | base64` #[openapi(tag = "login")] #[post("/api/oidc/auth", format = "application/json", data = "")] @@ -494,11 +498,10 @@ async fn oidc_auth( }); } let uuid_decoded = uuid_decoded.unwrap(); - let uuid_client = - String::from_utf8(uuid_decoded).unwrap(); + let uuid_client = String::from_utf8(uuid_decoded).unwrap(); let callback_url = format!("{}/api/oidc/callback", get_host(headers.clone())); let providers_config = state - .get_oauth2_config(oauth2::get_provider_config_file().as_str()) + .get_oauth2_config(oauth2::get_providers_config_file().as_str()) .await; if providers_config.is_none() { return Json(OidcAuthUrl { @@ -507,19 +510,32 @@ async fn oidc_auth( }); } let providers_config = providers_config.unwrap(); - let provider = providers_config + let provider_config = providers_config .iter() .find(|config| config.op == request.op); - if provider.is_none() { + if provider_config.is_none() { return Json(OidcAuthUrl { url: "".to_string(), code: "".to_string(), }); } - let provider = provider.unwrap(); - let redirect_url = - oauth2::get_redirect_url(provider, callback_url.as_str(), uuid_code.as_str()); + let provider_config = provider_config.unwrap(); + let provider_trait_object: Arc = { + match provider_config.provider { + oauth2::Provider::Github => todo!(), + oauth2::Provider::Gitlab => todo!(), + oauth2::Provider::Google => todo!(), + oauth2::Provider::Apple => todo!(), + oauth2::Provider::Okta => todo!(), + oauth2::Provider::Facebook => todo!(), + oauth2::Provider::Azure => todo!(), + oauth2::Provider::Auth0 => todo!(), + oauth2::Provider::Custom => Arc::new(oauth2::dex_provider::DexProvider::new()), + } + }; + + let redirect_url = provider_trait_object.get_redirect_url(callback_url.as_str(), uuid_code.as_str()); let _oidc_session = state .insert_oidc_session( uuid_code.clone(), @@ -530,9 +546,9 @@ async fn oidc_auth( auth_token: None, redirect_url: Some(redirect_url.clone()), callback_url: Some(callback_url), - provider_config: Some(provider.clone()), + provider: Some(provider_trait_object), name: None, - email: None + email: None, }, ) .await; @@ -1018,15 +1034,15 @@ async fn users_client( } /// Get the software download url -/// +/// /// # Arguments -/// +/// /// * `key` - The key to the software download link, it can be `osx`, `w64` or `ios` -/// +/// /// # Usage -/// +/// /// * it needs a valid S3 configuration file defined with the `S3_CONFIG_FILE` environment variable -/// +/// ///
 /// [s3config]
 /// Endpoint = "https://compat.objectstorage.eu-london-1.oraclecloud.com"
@@ -1041,7 +1057,10 @@ async fn users_client(
 /// IOSKey = "master/sctgdesk-releases/sctgdesk-1.2.4.ipa"
 /// 
#[openapi(tag = "Software")] -#[get("/api/software/client-download-link/", format = "application/json")] +#[get( + "/api/software/client-download-link/", + format = "application/json" +)] async fn software(key: &str) -> Result, status::NotFound<()>> { log::debug!("software"); let config = get_s3_config_file() @@ -1051,7 +1070,7 @@ async fn software(key: &str) -> Result, status::NotFound< let config = config.unwrap(); match key { "osx" => { - let key = config.clone().s3config.osxkey; + let key = config.clone().s3config.osxkey; let url = get_signed_release_url_with_config(config, key.as_str()) .await .map_err(|e| status::NotFound(Box::new(e))); @@ -1060,7 +1079,7 @@ async fn software(key: &str) -> Result, status::NotFound< Ok(Json(response)) } "w64" => { - let key = config.clone().s3config.windows64_key; + let key = config.clone().s3config.windows64_key; let url = get_signed_release_url_with_config(config, key.as_str()) .await .map_err(|e| status::NotFound(Box::new(e))); @@ -1069,7 +1088,7 @@ async fn software(key: &str) -> Result, status::NotFound< Ok(Json(response)) } "ios" => { - let key = config.clone().s3config.ioskey; + let key = config.clone().s3config.ioskey; let url = get_signed_release_url_with_config(config, key.as_str()) .await .map_err(|e| status::NotFound(Box::new(e))); @@ -1089,7 +1108,7 @@ async fn software_version() -> Json { let version = env::var("MAIN_PKG_VERSION").unwrap(); let response = SoftwareVersionResponse { server: Some(version), - client: None + client: None, }; Json(response) -} \ No newline at end of file +}