diff --git a/Cargo.lock b/Cargo.lock index e8b1c09bf..4cd06061d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1048,7 +1048,7 @@ dependencies = [ [[package]] name = "defguard" -version = "1.2.0" +version = "1.2.1" dependencies = [ "anyhow", "argon2", @@ -1072,6 +1072,7 @@ dependencies = [ "mime_guess", "model_derive", "openidconnect", + "paste", "pgp", "prost", "prost-build", diff --git a/Cargo.toml b/Cargo.toml index 30388b127..ab341ff6c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "defguard" -version = "1.2.0" +version = "1.2.1" edition = "2021" license-file = "LICENSE.md" homepage = "https://defguard.net/" @@ -40,6 +40,7 @@ model_derive = { path = "model-derive" } openidconnect = { version = "3.5", default-features = false, optional = true, features = [ "reqwest", ] } +paste = "1.0.15" pgp = "0.14" prost = "0.13" pulldown-cmark = "0.12" diff --git a/src/enterprise/directory_sync/google.rs b/src/enterprise/directory_sync/google.rs index cc12a7b58..b7b226740 100644 --- a/src/enterprise/directory_sync/google.rs +++ b/src/enterprise/directory_sync/google.rs @@ -5,7 +5,7 @@ use chrono::{DateTime, TimeDelta, Utc}; use jsonwebtoken::{encode, Algorithm, EncodingKey, Header}; use reqwest::{header::AUTHORIZATION, Url}; -use super::{DirectoryGroup, DirectorySync, DirectorySyncError, DirectoryUser}; +use super::{parse_response, DirectoryGroup, DirectorySync, DirectorySyncError, DirectoryUser}; #[cfg(not(test))] const SCOPES: &str = "openid email profile https://www.googleapis.com/auth/admin.directory.customer.readonly https://www.googleapis.com/auth/admin.directory.group.readonly https://www.googleapis.com/auth/admin.directory.user.readonly"; @@ -108,25 +108,6 @@ struct GroupsResponse { groups: Vec, } -/// Parse a reqwest response and return the JSON body if the response is OK, otherwise map an error to a DirectorySyncError::RequestError -/// The context_message is used to provide more context to the error message. -async fn parse_response( - response: reqwest::Response, - context_message: &str, -) -> Result -where - T: serde::de::DeserializeOwned, -{ - let status = &response.status(); - match status { - &reqwest::StatusCode::OK => Ok(response.json().await?), - _ => Err(DirectorySyncError::RequestError(format!( - "{context_message} Code returned: {status}. Details: {}", - response.text().await? - ))), - } -} - impl GoogleDirectorySync { #[must_use] pub fn new(private_key: &str, client_email: &str, admin_email: &str) -> Self { @@ -184,17 +165,14 @@ impl GoogleDirectorySync { if self.is_token_expired() { return Err(DirectorySyncError::AccessTokenExpired); } - let access_token = self .access_token .as_ref() .ok_or(DirectorySyncError::AccessTokenExpired)?; let mut url = Url::from_str(GROUPS_URL).unwrap(); - url.query_pairs_mut() .append_pair("userKey", user_id) .append_pair("maxResults", "500"); - let client = reqwest::Client::new(); let response = client .get(url) @@ -246,7 +224,8 @@ impl GoogleDirectorySync { "https://admin.googleapis.com/admin/directory/v1/groups/{}/members", group.id ); - let mut url = Url::from_str(&url_str).unwrap(); + let mut url = + Url::parse(&url_str).map_err(|err| DirectorySyncError::InvalidUrl(err.to_string()))?; url.query_pairs_mut() .append_pair("includeDerivedMembership", "true") .append_pair("maxResults", "500"); @@ -364,7 +343,9 @@ impl DirectorySync for GoogleDirectorySync { } async fn test_connection(&self) -> Result<(), DirectorySyncError> { + debug!("Testing connection to Google API."); self.query_test_connection().await?; + info!("Successfully tested connection to Google API, connection is working."); Ok(()) } } diff --git a/src/enterprise/directory_sync/microsoft.rs b/src/enterprise/directory_sync/microsoft.rs new file mode 100644 index 000000000..c9c2036e2 --- /dev/null +++ b/src/enterprise/directory_sync/microsoft.rs @@ -0,0 +1,584 @@ +use std::time::Duration; + +use chrono::{TimeDelta, Utc}; +use reqwest::{header::AUTHORIZATION, Url}; +use serde::Deserialize; + +use super::{parse_response, DirectoryGroup, DirectorySync, DirectorySyncError, DirectoryUser}; + +#[allow(dead_code)] +pub(crate) struct MicrosoftDirectorySync { + access_token: Option, + token_expiry: Option>, + client_id: String, + client_secret: String, + url: String, +} + +#[cfg(not(test))] +const ACCESS_TOKEN_URL: &str = "https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/token"; +#[cfg(not(test))] +const GROUPS_URL: &str = "https://graph.microsoft.com/v1.0/groups?$top=999"; +#[cfg(not(test))] +const USER_GROUPS: &str = "https://graph.microsoft.com/v1.0/users/{user_id}/memberOf?$top=999"; +#[cfg(not(test))] +const GROUP_MEMBERS: &str = "https://graph.microsoft.com/v1.0/groups/{group_id}/members?$select=accountEnabled,displayName,mail,otherMails&$top=999"; +const REQUEST_TIMEOUT: Duration = Duration::from_secs(10); +const ALL_USERS_URL: &str = + "https://graph.microsoft.com/v1.0/users?$select=accountEnabled,displayName,mail,otherMails&$top=999"; +#[cfg(not(test))] +const MICROSOFT_DEFAULT_SCOPE: &str = "https://graph.microsoft.com/.default"; +#[cfg(not(test))] +const GRANT_TYPE: &str = "client_credentials"; + +#[derive(Deserialize)] +struct TokenResponse { + #[serde(rename = "access_token")] + token: String, + expires_in: i64, +} + +#[derive(Deserialize)] +struct GroupDetails { + #[serde(rename = "displayName")] + display_name: String, + id: String, +} + +#[derive(Deserialize)] +struct GroupsResponse { + value: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +struct GroupMembersResponse { + value: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +struct User { + #[serde(rename = "displayName")] + display_name: String, + mail: Option, + #[serde(rename = "accountEnabled")] + account_enabled: bool, + #[serde(rename = "otherMails")] + other_mails: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +struct UsersResponse { + value: Vec, +} + +async fn make_get_request( + url: Url, + token: String, +) -> Result { + let client = reqwest::Client::new(); + let response = client + .get(url) + .header(AUTHORIZATION, format!("Bearer {token}")) + .timeout(REQUEST_TIMEOUT) + .send() + .await?; + Ok(response) +} + +#[cfg(not(test))] +impl MicrosoftDirectorySync { + async fn query_access_token(&self) -> Result { + debug!("Querying Microsoft directory sync access token."); + let tenant_id = self.extract_tenant()?; + let token_url = ACCESS_TOKEN_URL.replace("{tenant_id}", &tenant_id); + let client = reqwest::Client::new(); + let response = client + .post(&token_url) + .form(&[ + ("client_id", &self.client_id), + ("client_secret", &self.client_secret), + ("scope", &MICROSOFT_DEFAULT_SCOPE.to_string()), + ("grant_type", &GRANT_TYPE.to_string()), + ]) + .send() + .await?; + let token_response: TokenResponse = response.json().await?; + debug!("Microsoft directory sync access token queried successfully."); + Ok(token_response) + } + + async fn query_groups(&self) -> Result { + if self.is_token_expired() { + debug!("Microsoft directory sync access token is expired, aborting group query."); + return Err(DirectorySyncError::AccessTokenExpired); + } + let access_token = self + .access_token + .as_ref() + .ok_or(DirectorySyncError::AccessTokenExpired)?; + let url = Url::parse(GROUPS_URL) + .map_err(|err| DirectorySyncError::InvalidUrl(err.to_string()))?; + let response = make_get_request(url, access_token.to_string()).await?; + parse_response(response, "Failed to query all Microsoft groups.").await + } + + async fn query_user_groups(&self, user_id: &str) -> Result { + if self.is_token_expired() { + debug!( + "Microsoft directory sync access token is expired, aborting query of user groups." + ); + return Err(DirectorySyncError::AccessTokenExpired); + } + let access_token = self + .access_token + .as_ref() + .ok_or(DirectorySyncError::AccessTokenExpired)?; + let url = Url::parse(&USER_GROUPS.replace("{user_id}", user_id)) + .map_err(|err| DirectorySyncError::InvalidUrl(err.to_string()))?; + let response = make_get_request(url, access_token.to_string()).await?; + parse_response(response, "Failed to query user groups from Microsoft API.").await + } + + async fn query_group_members( + &self, + group: &DirectoryGroup, + ) -> Result { + if self.is_token_expired() { + debug!( + "Microsoft directory sync access token is expired, aborting group member query." + ); + return Err(DirectorySyncError::AccessTokenExpired); + } + let access_token = self + .access_token + .as_ref() + .ok_or(DirectorySyncError::AccessTokenExpired)?; + + let url = Url::parse(&GROUP_MEMBERS.replace("{group_id}", &group.id)) + .map_err(|err| DirectorySyncError::InvalidUrl(err.to_string()))?; + let response = make_get_request(url, access_token.to_string()).await?; + parse_response( + response, + "Failed to query group members from Microsoft API.", + ) + .await + } + + async fn query_all_users(&self) -> Result { + if self.is_token_expired() { + debug!("Microsoft directory sync access token is expired, aborting all users query."); + return Err(DirectorySyncError::AccessTokenExpired); + } + let access_token = self + .access_token + .as_ref() + .ok_or(DirectorySyncError::AccessTokenExpired)?; + let url = Url::parse(ALL_USERS_URL) + .map_err(|err| DirectorySyncError::InvalidUrl(err.to_string()))?; + let response = make_get_request(url, access_token.to_string()).await?; + parse_response(response, "Failed to query all users in the Microsoft API.").await + } +} + +impl MicrosoftDirectorySync { + pub(crate) const fn new(client_id: String, client_secret: String, url: String) -> Self { + Self { + access_token: None, + client_id, + client_secret, + url, + token_expiry: None, + } + } + + fn extract_tenant(&self) -> Result { + debug!("Extracting tenant ID from Microsoft base URL: {}", self.url); + let parts: Vec<&str> = self.url.split('/').collect(); + debug!("Split Microsoft base URL into the following parts: {parts:?}",); + let tenant_id = + parts + .get(parts.len() - 2) + .ok_or(DirectorySyncError::InvalidProviderConfiguration(format!( + "Couldn't extract tenant ID from the provided Microsoft API base URL: {}", + self.url + )))?; + debug!("Tenant ID extracted successfully: {tenant_id}",); + Ok(tenant_id.to_string()) + } + + async fn refresh_access_token(&mut self) -> Result<(), DirectorySyncError> { + debug!("Refreshing Microsoft directory sync access token."); + let token_response = self.query_access_token().await?; + let expires_in = TimeDelta::seconds(token_response.expires_in); + self.access_token = Some(token_response.token); + self.token_expiry = Some(Utc::now() + expires_in); + debug!( + "Microsoft directory sync access token refreshed, the new token expires at: {:?}", + self.token_expiry + ); + Ok(()) + } + + fn is_token_expired(&self) -> bool { + debug!( + "Checking if Microsoft directory sync token is expired, expiry date: {:?}", + self.token_expiry + ); + let result = self.token_expiry.map_or(true, |expiry| expiry < Utc::now()); + debug!("Token expiry check result: {result}"); + result + } + + async fn query_test_connection(&self) -> Result<(), DirectorySyncError> { + let access_token = self + .access_token + .as_ref() + .ok_or(DirectorySyncError::AccessTokenExpired)?; + let url = Url::parse(&format!("{ALL_USERS_URL}?$top=1")) + .map_err(|err| DirectorySyncError::InvalidUrl(err.to_string()))?; + let response = make_get_request(url, access_token.to_string()).await?; + let _result: UsersResponse = + parse_response(response, "Failed to test connection to Microsoft API.").await?; + Ok(()) + } +} + +impl DirectorySync for MicrosoftDirectorySync { + async fn get_groups(&self) -> Result, DirectorySyncError> { + debug!("Querying all groups from Microsoft API."); + let groups = self + .query_groups() + .await? + .value + .into_iter() + .map(|group| DirectoryGroup { + id: group.id, + name: group.display_name, + }); + debug!("All groups queried successfully."); + Ok(groups.collect()) + } + + async fn get_user_groups( + &self, + user_id: &str, + ) -> Result, DirectorySyncError> { + debug!("Querying groups of user: {user_id}"); + let groups = self + .query_user_groups(user_id) + .await? + .value + .into_iter() + .map(|group| DirectoryGroup { + id: group.id, + name: group.display_name, + }); + debug!("User groups queried successfully."); + Ok(groups.collect()) + } + + async fn get_group_members( + &self, + group: &DirectoryGroup, + ) -> Result, DirectorySyncError> { + debug!("Querying members of group: {}", group.name); + let members = self + .query_group_members(group) + .await? + .value + .into_iter() + .filter_map(|user| { + if let Some(email) = user.mail { + Some(email) + } else { + warn!("User {} doesn't have an email address and will be skipped in synchronization.", user.display_name); + None + } + }); + debug!("Group members queried successfully."); + Ok(members.collect()) + } + + async fn prepare(&mut self) -> Result<(), DirectorySyncError> { + debug!("Preparing Microsoft directory sync..."); + if self.is_token_expired() { + debug!("Access token is expired, refreshing."); + self.refresh_access_token().await?; + debug!("Access token refreshed."); + } else { + debug!("Access token is still valid, skipping refresh."); + } + debug!("Microsoft directory sync prepared."); + Ok(()) + } + + async fn get_all_users(&self) -> Result, DirectorySyncError> { + debug!("Querying all users from Microsoft API."); + let users = self + .query_all_users() + .await? + .value + .into_iter() + .filter_map(|user| { + if let Some(email) = user.mail { + Some(DirectoryUser { email, active: user.account_enabled }) + } else if let Some(mail) = user.other_mails.first() { + warn!("User {} doesn't have a primary email address set, his first additional email address will be used: {mail}", user.display_name); + Some(DirectoryUser { email: mail.clone(), active: user.account_enabled }) + } else { + warn!("User {} doesn't have any email address and will be skipped in synchronization.", user.display_name); + None + } + }); + debug!("All users queried successfully."); + Ok(users.collect()) + } + + async fn test_connection(&self) -> Result<(), DirectorySyncError> { + debug!("Testing connection to Microsoft API."); + self.query_test_connection().await?; + info!("Successfully tested connection to Microsoft API, connection is working."); + Ok(()) + } +} + +#[cfg(test)] +impl MicrosoftDirectorySync { + async fn query_user_groups( + &self, + _user_id: &str, + ) -> Result { + if self.is_token_expired() { + return Err(DirectorySyncError::AccessTokenExpired); + } + let _access_token = self + .access_token + .as_ref() + .ok_or(DirectorySyncError::AccessTokenExpired)?; + + Ok(GroupsResponse { + value: vec![GroupDetails { + display_name: "group1".into(), + id: "1".into(), + }], + }) + } + + async fn query_groups(&self) -> Result { + if self.is_token_expired() { + return Err(DirectorySyncError::AccessTokenExpired); + } + + let _access_token = self + .access_token + .as_ref() + .ok_or(DirectorySyncError::AccessTokenExpired)?; + + Ok(GroupsResponse { + value: vec![ + GroupDetails { + display_name: "group1".into(), + id: "1".into(), + }, + GroupDetails { + display_name: "group2".into(), + id: "2".into(), + }, + GroupDetails { + display_name: "group3".into(), + id: "3".into(), + }, + ], + }) + } + + async fn query_group_members( + &self, + _group: &DirectoryGroup, + ) -> Result { + if self.is_token_expired() { + return Err(DirectorySyncError::AccessTokenExpired); + } + let _access_token = self + .access_token + .as_ref() + .ok_or(DirectorySyncError::AccessTokenExpired)?; + + Ok(GroupMembersResponse { + value: vec![ + User { + display_name: "testuser".into(), + mail: Some("testuser@email.com".into()), + account_enabled: true, + other_mails: vec![], + }, + User { + display_name: "testuserdisabled".into(), + mail: Some("testuserdisabled@email.com".into()), + account_enabled: false, + other_mails: vec![], + }, + User { + display_name: "testuser2".into(), + mail: Some( + "testuser2@email.com + " + .into(), + ), + account_enabled: true, + other_mails: vec![], + }, + ], + }) + } + + async fn query_access_token(&self) -> Result { + Ok(TokenResponse { + token: "test_token_refreshed".into(), + expires_in: 3600, + }) + } + + async fn query_all_users(&self) -> Result { + if self.is_token_expired() { + return Err(DirectorySyncError::AccessTokenExpired); + } + let _access_token = self + .access_token + .as_ref() + .ok_or(DirectorySyncError::AccessTokenExpired)?; + Ok(UsersResponse { + value: vec![ + User { + display_name: "testuser".into(), + mail: Some("testuser@email.com".into()), + account_enabled: true, + other_mails: vec![], + }, + User { + display_name: "testuserdisabled".into(), + mail: Some("testuserdisabled@email.com".into()), + account_enabled: false, + other_mails: vec![], + }, + User { + display_name: "testuser2".into(), + mail: Some("testuser2@email.com".into()), + account_enabled: true, + other_mails: vec![], + }, + ], + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_extract_tenant() { + let provider = MicrosoftDirectorySync::new( + "client_id".to_string(), + "client_secret".to_string(), + "https://login.microsoftonline.com/tenant-id-123/v2.0".to_string(), + ); + let tenant = provider.extract_tenant().unwrap(); + assert_eq!(tenant, "tenant-id-123"); + } + + #[tokio::test] + async fn test_token() { + let mut dirsync = MicrosoftDirectorySync::new( + "id".to_string(), + "secret".to_string(), + "https://login.microsoftonline.com/tenant-id-123/v2.0".to_string(), + ); + + // no token + assert!(dirsync.is_token_expired()); + + // expired token + dirsync.access_token = Some("test_token".into()); + dirsync.token_expiry = Some(Utc::now() - TimeDelta::seconds(10000)); + assert!(dirsync.is_token_expired()); + + // valid token + dirsync.access_token = Some("test_token".into()); + dirsync.token_expiry = Some(Utc::now() + TimeDelta::seconds(10000)); + assert!(!dirsync.is_token_expired()); + + // no token + dirsync.access_token = Some("test_token".into()); + dirsync.token_expiry = Some(Utc::now() - TimeDelta::seconds(10000)); + dirsync.refresh_access_token().await.unwrap(); + assert!(!dirsync.is_token_expired()); + assert_eq!(dirsync.access_token, Some("test_token_refreshed".into())); + } + + #[tokio::test] + async fn test_all_users() { + let mut dirsync = MicrosoftDirectorySync::new( + "id".to_string(), + "secret".to_string(), + "https://login.microsoftonline.com/tenant-id-123/v2.0".to_string(), + ); + dirsync.refresh_access_token().await.unwrap(); + + let users = dirsync.get_all_users().await.unwrap(); + + assert_eq!(users.len(), 3); + assert_eq!(users[1].email, "testuserdisabled@email.com"); + assert!(!users[1].active); + } + + #[tokio::test] + async fn test_groups() { + let mut dirsync = MicrosoftDirectorySync::new( + "id".to_string(), + "secret".to_string(), + "https://login.microsoftonline.com/tenant-id-123/v2.0".to_string(), + ); + dirsync.refresh_access_token().await.unwrap(); + + let groups = dirsync.get_groups().await.unwrap(); + + assert_eq!(groups.len(), 3); + + for (i, group) in groups.iter().enumerate().take(3) { + assert_eq!(group.id, (i + 1).to_string()); + assert_eq!(group.name, format!("group{}", i + 1)); + } + } + + #[tokio::test] + async fn test_user_groups() { + let mut dirsync = MicrosoftDirectorySync::new( + "id".to_string(), + "secret".to_string(), + "https://login.microsoftonline.com/tenant-id-123/v2.0".to_string(), + ); + dirsync.refresh_access_token().await.unwrap(); + + let groups = dirsync.get_user_groups("testuser").await.unwrap(); + assert_eq!(groups.len(), 1); + assert_eq!(groups[0].id, "1"); + assert_eq!(groups[0].name, "group1"); + } + + #[tokio::test] + async fn test_group_members() { + let mut dirsync = MicrosoftDirectorySync::new( + "id".to_string(), + "secret".to_string(), + "https://login.microsoftonline.com/tenant-id-123/v2.0".to_string(), + ); + dirsync.refresh_access_token().await.unwrap(); + + let groups = dirsync.get_groups().await.unwrap(); + let members = dirsync.get_group_members(&groups[0]).await.unwrap(); + + assert_eq!(members.len(), 3); + assert_eq!(members[0], "testuser@email.com"); + } +} diff --git a/src/enterprise/directory_sync/mod.rs b/src/enterprise/directory_sync/mod.rs index 79c3a87bc..8e256c778 100644 --- a/src/enterprise/directory_sync/mod.rs +++ b/src/enterprise/directory_sync/mod.rs @@ -1,5 +1,6 @@ use std::collections::{HashMap, HashSet}; +use paste::paste; use sqlx::error::Error as SqlxError; use sqlx::PgPool; use thiserror::Error; @@ -30,6 +31,10 @@ pub enum DirectorySyncError { NotConfigured, #[error("Couldn't map provider's group to a Defguard group as it doesn't exist. There may be an issue with automatic group creation. Error details: {0}")] DefGuardGroupNotFound(String), + #[error("The provided provider configuration is invalid: {0}")] + InvalidProviderConfiguration(String), + #[error("Couldn't construct URL from the given string: {0}")] + InvalidUrl(String), } impl From for DirectorySyncError { @@ -47,6 +52,7 @@ impl From for DirectorySyncError { } pub mod google; +pub mod microsoft; #[derive(Debug, Serialize, Deserialize)] pub struct DirectoryGroup { @@ -88,6 +94,123 @@ trait DirectorySync { async fn test_connection(&self) -> Result<(), DirectorySyncError>; } +/// This macro generates a boilerplate enum which enables a simple polymorphism for things that implement +/// the DirectorySync trait without having to resolve to fully dynamic dispatch using something like Box. +/// +/// +/// When creating a new provider, make sure that: +/// - The provider main struct is called DirectorySync, e.g. GoogleDirectorySync +/// - The provider implements the [`DirectorySync`] trait +/// - You implemented some way to initialize the provider client and added an initialization step in the [`DirectorySyncClient::build`] function +/// - You added the provider name to the macro invocation below the macro definition +/// - You've implemented your provider logic in a file called the same as your provider but lowercase, e.g. google.rs +/// +// If you have time to refactor the whole thing to use boxes instead, go ahead. +macro_rules! dirsync_clients { + ($($variant:ident),*) => { + paste! { + pub(crate) enum DirectorySyncClient { + $( + $variant([< $variant:lower >]::[< $variant DirectorySync >]), + )* + } + } + + impl DirectorySync for DirectorySyncClient { + async fn get_groups(&self) -> Result, DirectorySyncError> { + match self { + $( + DirectorySyncClient::$variant(client) => client.get_groups().await, + )* + } + } + + async fn get_user_groups(&self, user_id: &str) -> Result, DirectorySyncError> { + match self { + $( + DirectorySyncClient::$variant(client) => client.get_user_groups(user_id).await, + )* + } + } + + async fn get_group_members(&self, group: &DirectoryGroup) -> Result, DirectorySyncError> { + match self { + $( + DirectorySyncClient::$variant(client) => client.get_group_members(group).await, + )* + } + } + + async fn prepare(&mut self) -> Result<(), DirectorySyncError> { + match self { + $( + DirectorySyncClient::$variant(client) => client.prepare().await, + )* + } + } + + async fn get_all_users(&self) -> Result, DirectorySyncError> { + match self { + $( + DirectorySyncClient::$variant(client) => client.get_all_users().await, + )* + } + } + + async fn test_connection(&self) -> Result<(), DirectorySyncError> { + match self { + $( + DirectorySyncClient::$variant(client) => client.test_connection().await, + )* + } + } + } + }; +} + +dirsync_clients!(Google, Microsoft); + +impl DirectorySyncClient { + /// Builds the current directory sync client based on the current provider settings (provider name), if possible. + pub(crate) async fn build(pool: &PgPool) -> Result { + let provider_settings = OpenIdProvider::get_current(pool) + .await? + .ok_or(DirectorySyncError::NotConfigured)?; + + match provider_settings.name.as_str() { + "Google" => { + debug!("Google directory sync provider selected"); + match ( + provider_settings.google_service_account_key.as_ref(), + provider_settings.google_service_account_email.as_ref(), + provider_settings.admin_email.as_ref(), + ) { + (Some(key), Some(email), Some(admin_email)) => { + debug!("Google directory has all the configuration needed, proceeding with creating the sync client"); + let client = google::GoogleDirectorySync::new(key, email, admin_email); + debug!("Google directory sync client created"); + Ok(Self::Google(client)) + } + _ => Err(DirectorySyncError::NotConfigured), + } + } + "Microsoft" => { + debug!("Microsoft directory sync provider selected"); + let client = microsoft::MicrosoftDirectorySync::new( + provider_settings.client_id, + provider_settings.client_secret, + provider_settings.base_url, + ); + debug!("Microsoft directory sync client created"); + Ok(Self::Microsoft(client)) + } + _ => Err(DirectorySyncError::UnsupportedProvider( + provider_settings.name.clone(), + )), + } + } +} + async fn sync_user_groups( directory_sync: &T, user: &User, @@ -148,7 +271,7 @@ pub(crate) async fn test_directory_sync_connection( return Ok(()); } - match get_directory_sync_client(pool).await { + match DirectorySyncClient::build(pool).await { Ok(mut dir_sync) => { dir_sync.prepare().await?; dir_sync.test_connection().await?; @@ -178,7 +301,7 @@ pub(crate) async fn sync_user_groups_if_configured( return Ok(()); } - match get_directory_sync_client(pool).await { + match DirectorySyncClient::build(pool).await { Ok(mut dir_sync) => { dir_sync.prepare().await?; sync_user_groups(&dir_sync, user, pool).await?; @@ -319,37 +442,6 @@ async fn sync_all_users_groups( Ok(()) } -async fn get_directory_sync_client( - pool: &PgPool, -) -> Result { - debug!("Getting directory sync client"); - let provider_settings = OpenIdProvider::get_current(pool) - .await? - .ok_or(DirectorySyncError::NotConfigured)?; - - match provider_settings.name.as_str() { - "Google" => { - debug!("Google directory sync provider selected"); - match ( - provider_settings.google_service_account_key.as_ref(), - provider_settings.google_service_account_email.as_ref(), - provider_settings.admin_email.as_ref(), - ) { - (Some(key), Some(email), Some(admin_email)) => { - debug!("Google directory has all the configuration needed, proceeding with creating the sync client"); - let client = google::GoogleDirectorySync::new(key, email, admin_email); - debug!("Google directory sync client created"); - Ok(client) - } - _ => Err(DirectorySyncError::NotConfigured), - } - } - _ => Err(DirectorySyncError::UnsupportedProvider( - provider_settings.name.clone(), - )), - } -} - fn is_directory_sync_enabled(provider: Option<&OpenIdProvider>) -> bool { debug!("Checking if directory sync is enabled"); if let Some(provider_settings) = provider { @@ -570,7 +662,7 @@ pub(crate) async fn do_directory_sync(pool: &PgPool) -> Result<(), DirectorySync .ok_or(DirectorySyncError::NotConfigured)? .directory_sync_target; - match get_directory_sync_client(pool).await { + match DirectorySyncClient::build(pool).await { Ok(mut dir_sync) => { // TODO: Directory sync's access token is dropped every time, find a way to preserve it // Same goes for Etags, those could be used to reduce the amount of data transferred. Some way @@ -597,6 +689,30 @@ pub(crate) async fn do_directory_sync(pool: &PgPool) -> Result<(), DirectorySync Ok(()) } +/// Parse a reqwest response and return the JSON body if the response is OK, otherwise map an error to a DirectorySyncError::RequestError +/// The context_message is used to provide more context to the error message. +async fn parse_response( + response: reqwest::Response, + context_message: &str, +) -> Result +where + T: serde::de::DeserializeOwned, +{ + let status = &response.status(); + match status { + &reqwest::StatusCode::OK => { + let json: serde_json::Value = response.json().await?; + Ok(serde_json::from_value(json).map_err(|err| { + DirectorySyncError::RequestError(format!("{context_message} Error: {err}")) + })?) + } + _ => Err(DirectorySyncError::RequestError(format!( + "{context_message} Code returned: {status}. Details: {}", + response.text().await? + ))), + } +} + #[cfg(test)] mod test { use secrecy::ExposeSecret; @@ -674,7 +790,7 @@ mod test { DirectorySyncTarget::All, ) .await; - let mut client = get_directory_sync_client(&pool).await.unwrap(); + let mut client = DirectorySyncClient::build(&pool).await.unwrap(); client.prepare().await.unwrap(); let user1 = make_test_user("user1", &pool).await; make_test_user("user2", &pool).await; @@ -704,7 +820,7 @@ mod test { DirectorySyncTarget::All, ) .await; - let mut client = get_directory_sync_client(&pool).await.unwrap(); + let mut client = DirectorySyncClient::build(&pool).await.unwrap(); client.prepare().await.unwrap(); let user1 = make_test_user("user1", &pool).await; @@ -738,7 +854,7 @@ mod test { DirectorySyncTarget::All, ) .await; - let mut client = get_directory_sync_client(&pool).await.unwrap(); + let mut client = DirectorySyncClient::build(&pool).await.unwrap(); client.prepare().await.unwrap(); let user1 = make_test_user("user1", &pool).await; @@ -779,7 +895,7 @@ mod test { User::init_admin_user(&pool, config.default_admin_password.expose_secret()) .await .unwrap(); - let mut client = get_directory_sync_client(&pool).await.unwrap(); + let mut client = DirectorySyncClient::build(&pool).await.unwrap(); client.prepare().await.unwrap(); let user1 = make_test_user("user1", &pool).await; @@ -817,7 +933,7 @@ mod test { DirectorySyncTarget::All, ) .await; - let mut client = get_directory_sync_client(&pool).await.unwrap(); + let mut client = DirectorySyncClient::build(&pool).await.unwrap(); client.prepare().await.unwrap(); let user1 = make_test_user("user1", &pool).await; @@ -860,7 +976,7 @@ mod test { DirectorySyncTarget::All, ) .await; - let mut client = get_directory_sync_client(&pool).await.unwrap(); + let mut client = DirectorySyncClient::build(&pool).await.unwrap(); client.prepare().await.unwrap(); let user1 = make_test_user("user1", &pool).await; @@ -908,7 +1024,7 @@ mod test { DirectorySyncTarget::All, ) .await; - let mut client = get_directory_sync_client(&pool).await.unwrap(); + let mut client = DirectorySyncClient::build(&pool).await.unwrap(); client.prepare().await.unwrap(); make_test_user("testuser", &pool).await; @@ -958,7 +1074,7 @@ mod test { DirectorySyncTarget::All, ) .await; - let mut client = get_directory_sync_client(&pool).await.unwrap(); + let mut client = DirectorySyncClient::build(&pool).await.unwrap(); client.prepare().await.unwrap(); let user = make_test_user("testuser", &pool).await; let user_groups = user.member_of(&pool).await.unwrap(); @@ -981,7 +1097,7 @@ mod test { DirectorySyncTarget::Users, ) .await; - let mut client = get_directory_sync_client(&pool).await.unwrap(); + let mut client = DirectorySyncClient::build(&pool).await.unwrap(); client.prepare().await.unwrap(); let user = make_test_user("testuser", &pool).await; let user_groups = user.member_of(&pool).await.unwrap(); @@ -1002,7 +1118,7 @@ mod test { DirectorySyncTarget::All, ) .await; - let mut client = get_directory_sync_client(&pool).await.unwrap(); + let mut client = DirectorySyncClient::build(&pool).await.unwrap(); client.prepare().await.unwrap(); let user = make_test_user("testuser", &pool).await; make_test_user("user2", &pool).await; @@ -1026,7 +1142,7 @@ mod test { DirectorySyncTarget::Groups, ) .await; - let mut client = get_directory_sync_client(&pool).await.unwrap(); + let mut client = DirectorySyncClient::build(&pool).await.unwrap(); client.prepare().await.unwrap(); let user = make_test_user("testuser", &pool).await; make_test_user("user2", &pool).await; @@ -1050,7 +1166,7 @@ mod test { DirectorySyncTarget::All, ) .await; - let mut client = get_directory_sync_client(&pool).await.unwrap(); + let mut client = DirectorySyncClient::build(&pool).await.unwrap(); client.prepare().await.unwrap(); // Make one admin and check if he's deleted @@ -1093,7 +1209,7 @@ mod test { DirectorySyncTarget::All, ) .await; - let mut client = get_directory_sync_client(&pool).await.unwrap(); + let mut client = DirectorySyncClient::build(&pool).await.unwrap(); client.prepare().await.unwrap(); // a user that's not in the directory diff --git a/web/src/pages/settings/components/OpenIdSettings/components/DirectorySyncSettings.tsx b/web/src/pages/settings/components/OpenIdSettings/components/DirectorySyncSettings.tsx index 8dd29b3d0..a4a79b36a 100644 --- a/web/src/pages/settings/components/OpenIdSettings/components/DirectorySyncSettings.tsx +++ b/web/src/pages/settings/components/OpenIdSettings/components/DirectorySyncSettings.tsx @@ -17,7 +17,7 @@ import useApi from '../../../../../shared/hooks/useApi'; import { useToaster } from '../../../../../shared/hooks/useToaster'; import { titleCase } from '../../../../../shared/utils/titleCase'; -const SUPPORTED_SYNC_PROVIDERS = ['Google']; +const SUPPORTED_SYNC_PROVIDERS = ['Google', 'Microsoft']; export const DirsyncSettings = ({ isLoading }: { isLoading: boolean }) => { const { LL } = useI18nContext(); @@ -89,155 +89,159 @@ export const DirsyncSettings = ({ isLoading }: { isLoading: boolean }) => {
{showDirsync ? ( - providerName === 'Google' ? ( - <> -
- {/* FIXME: Really buggy when using the controller, investigate why */} - setValue('directory_sync_enabled', val)} - // controller={{ control, name: 'directory_sync_enabled' }} - /> -
- ({ - key: val, - displayValue: titleCase(val), - })} - labelExtras={ - {parse(localLL.form.labels.sync_target.helper())} - } - disabled={isLoading} - /> - {parse(localLL.form.labels.sync_interval.helper())} - } - disabled={isLoading} - /> - ({ - key: val, - displayValue: titleCase(val), - })} - labelExtras={ - {parse(localLL.form.labels.user_behavior.helper())} - } - disabled={isLoading} - /> - ({ - key: val, - displayValue: titleCase(val), - })} - labelExtras={ - {parse(localLL.form.labels.admin_behavior.helper())} - } - disabled={isLoading} - /> - {parse(localLL.form.labels.admin_email.helper())} - } - required={dirsyncEnabled} + <> +
+ {/* FIXME: Really buggy when using the controller, investigate why */} + setValue('directory_sync_enabled', val)} + // controller={{ control, name: 'directory_sync_enabled' }} /> - - {parse(localLL.form.labels.service_account_used.helper())} - - } - disabled={isLoading} - required={dirsyncEnabled} - /> -
-
- - {localLL.form.labels.service_account_key_file.helper()} -
-
- { - const file = e.target.files?.[0]; - if (file) { - const reader = new FileReader(); - reader.onload = (e) => { - if (e?.target?.result) { - // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment - const key = JSON.parse(e.target?.result as string); - // eslint-disable-next-line @typescript-eslint/no-unsafe-member-access - setValue('google_service_account_key', key.private_key); - // eslint-disable-next-line @typescript-eslint/no-unsafe-member-access - setValue('google_service_account_email', key.client_email); - setGoogleServiceAccountFileName(file.name); - } - }; - reader.readAsText(file); - } - }} - disabled={isLoading} - /> -
- {' '} -

- {googleServiceAccountFileName - ? `${localLL.form.labels.service_account_key_file.uploaded()}: ${googleServiceAccountFileName}` - : localLL.form.labels.service_account_key_file.uploadPrompt()} -

+
+ ({ + key: val, + displayValue: titleCase(val), + })} + labelExtras={ + {parse(localLL.form.labels.sync_target.helper())} + } + disabled={isLoading} + /> + {parse(localLL.form.labels.sync_interval.helper())} + } + disabled={isLoading} + /> + ({ + key: val, + displayValue: titleCase(val), + })} + labelExtras={ + {parse(localLL.form.labels.user_behavior.helper())} + } + disabled={isLoading} + /> + ({ + key: val, + displayValue: titleCase(val), + })} + labelExtras={ + {parse(localLL.form.labels.admin_behavior.helper())} + } + disabled={isLoading} + /> + {providerName === 'Google' ? ( + <> + {parse(localLL.form.labels.admin_email.helper())} + } + required={dirsyncEnabled} + /> + + {parse(localLL.form.labels.service_account_used.helper())} + + } + disabled={isLoading} + required={dirsyncEnabled} + /> +
+
+ + + {localLL.form.labels.service_account_key_file.helper()} + +
+
+ { + const file = e.target.files?.[0]; + if (file) { + const reader = new FileReader(); + reader.onload = (e) => { + if (e?.target?.result) { + // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment + const key = JSON.parse(e.target?.result as string); + // eslint-disable-next-line @typescript-eslint/no-unsafe-member-access + setValue('google_service_account_key', key.private_key); + // eslint-disable-next-line @typescript-eslint/no-unsafe-member-access + setValue('google_service_account_email', key.client_email); + setGoogleServiceAccountFileName(file.name); + } + }; + reader.readAsText(file); + } + }} + disabled={isLoading} + /> +
+ {' '} +

+ {googleServiceAccountFileName + ? `${localLL.form.labels.service_account_key_file.uploaded()}: ${googleServiceAccountFileName}` + : localLL.form.labels.service_account_key_file.uploadPrompt()} +

+
-
-
- -
- - ) : null + + ) : null} +
+ +
+ ) : (

{localLL.form.directory_sync_settings.notSupported()} diff --git a/web/src/pages/settings/components/OpenIdSettings/components/OpenIdSettingsRootForm.tsx b/web/src/pages/settings/components/OpenIdSettings/components/OpenIdSettingsRootForm.tsx index c2c8ab779..394e05e27 100644 --- a/web/src/pages/settings/components/OpenIdSettings/components/OpenIdSettingsRootForm.tsx +++ b/web/src/pages/settings/components/OpenIdSettings/components/OpenIdSettingsRootForm.tsx @@ -106,7 +106,7 @@ export const OpenIdSettingsRootForm = () => { }); } - if (val.directory_sync_enabled) { + if (val.directory_sync_enabled && val.name === 'Google') { if (val.admin_email.length === 0) { ctx.addIssue({ code: z.ZodIssueCode.custom,