-
-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
245 additions
and
260 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
use std::{future::Future, pin::Pin}; | ||
|
||
use crate::{ | ||
errors::Oauth2Error, | ||
oauth_provider::{decode_oauth_id_token, OAuthProvider, OAuthProviderFactory, OAuthResponse}, | ||
Provider, ProviderConfig, TokenResponse, | ||
}; | ||
use base64::prelude::{Engine as _, BASE64_STANDARD}; | ||
|
||
use url::form_urlencoded; | ||
|
||
pub struct DexProvider { | ||
provider_config: ProviderConfig, | ||
} | ||
|
||
/// Get the authorization header for the provider | ||
/// | ||
/// # Arguments | ||
/// * `provider_config` - The provider configuration | ||
/// | ||
/// # Returns | ||
/// The authorization header | ||
fn get_authorization_header(provider_config: &ProviderConfig) -> String { | ||
format!( | ||
"Basic {}", | ||
BASE64_STANDARD.encode(format!( | ||
"{}:{}", | ||
provider_config.app_id, provider_config.app_secret | ||
)) | ||
) | ||
} | ||
|
||
impl OAuthProviderFactory for DexProvider { | ||
fn new() -> Self { | ||
let provider_config = Self::get_provider_config(Provider::Custom); | ||
Self { provider_config } | ||
} | ||
} | ||
impl OAuthProvider for DexProvider { | ||
fn get_redirect_url(&self, callback_url: &str, state: &str) -> String { | ||
let redirect_url = | ||
form_urlencoded::byte_serialize(callback_url.as_bytes()).collect::<String>(); | ||
let scope = form_urlencoded::byte_serialize(self.provider_config.scope.as_bytes()) | ||
.collect::<String>(); | ||
let state = form_urlencoded::byte_serialize(state.as_bytes()).collect::<String>(); | ||
|
||
format!( | ||
"{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state={}", | ||
self.provider_config.authorization_url, | ||
self.provider_config.app_id, | ||
redirect_url, | ||
scope, | ||
state | ||
) | ||
} | ||
|
||
fn exchange_code( | ||
&self, | ||
code: &str, | ||
callback_url: &str, | ||
) -> Pin<Box<dyn Future<Output = Result<OAuthResponse, Oauth2Error>> + Send + Sync>> { | ||
let code = code.to_string(); | ||
let callback_url = callback_url.to_string(); | ||
let provider_config = self.provider_config.clone(); | ||
|
||
Box::pin(async move { | ||
let code = form_urlencoded::byte_serialize(code.as_bytes()).collect::<String>(); | ||
let authorization_header = get_authorization_header(&provider_config); | ||
let response = reqwest::Client::new() | ||
.post(provider_config.token_exchange_url.as_str()) | ||
.header("Authorization", authorization_header) | ||
.header("Content-Type", "application/x-www-form-urlencoded") | ||
.form(&[ | ||
("grant_type", "authorization_code"), | ||
("code", code.as_str()), | ||
("redirect_uri", &callback_url), | ||
("client_id", &provider_config.app_id.as_str()), | ||
]) | ||
.send() | ||
.await | ||
.map_err(|_| Oauth2Error::ExchangeCodeError)?; | ||
|
||
let response = response | ||
.error_for_status() | ||
.map_err(|_| Oauth2Error::ExchangeCodeError)?; | ||
|
||
let body = response | ||
.json::<TokenResponse>() | ||
.await | ||
.map_err(|_| Oauth2Error::ExchangeCodeError)?; | ||
|
||
if let Some(id_token) = body.id_token { | ||
let (username, email) = decode_oauth_id_token(&id_token)?; | ||
Ok(OAuthResponse { | ||
access_token: body.access_token, | ||
username, | ||
email, | ||
}) | ||
} else { | ||
Err(Oauth2Error::ExchangeCodeError) | ||
} | ||
}) | ||
} | ||
|
||
fn get_provider_type(&self) -> crate::Provider { | ||
Provider::Custom | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,8 @@ | ||
use base64::{ | ||
engine::general_purpose::URL_SAFE_NO_PAD, | ||
prelude::{Engine as _, BASE64_STANDARD}, | ||
}; | ||
pub mod dex_provider; | ||
pub mod oauth_provider; | ||
use serde::{Deserialize, Serialize}; | ||
use std::{fs, str::FromStr}; | ||
use url::form_urlencoded; | ||
mod errors; | ||
use errors::Oauth2Error; | ||
|
||
#[derive(Deserialize, Serialize, Clone, Debug)] | ||
pub struct ProviderConfig { | ||
|
@@ -34,7 +30,7 @@ pub enum Provider { | |
} | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
struct Claims { | ||
pub struct Claims { | ||
aud: String, | ||
sub: String, | ||
name: String, | ||
|
@@ -89,229 +85,26 @@ pub struct TokenResponse { | |
pub refresh_token: Option<String>, | ||
} | ||
|
||
/// Get the name of the provider config file | ||
/// from the OAUTH2_CONFIG_FILE environment variable or | ||
/// default to "oauth2.toml" | ||
/// Get the providers config from a config file | ||
/// | ||
/// # Returns | ||
/// The name of the provider config file | ||
pub fn get_provider_config_file() -> String { | ||
std::env::var("OAUTH2_CONFIG_FILE").unwrap_or_else(|_| "oauth2.toml".to_string()) | ||
} | ||
|
||
/// Get the redirect url for the specified provider | ||
/// | ||
/// # Arguments | ||
/// * `provider_config` - The provider configuration | ||
/// | ||
/// # Returns | ||
/// The redirect url | ||
/// | ||
pub fn get_provider_config(config_file: &str) -> Vec<ProviderConfig> { | ||
/// # Returns | ||
/// The providers config | ||
pub fn get_providers_config_from_file(config_file: &str) -> Vec<ProviderConfig> { | ||
let config_file_content = fs::read_to_string(config_file).expect("Failed to read config file"); | ||
let config: Config = toml::from_str(&config_file_content).expect("Failed to parse config file"); | ||
config.provider | ||
} | ||
|
||
/// Get redirect url for the provider | ||
/// | ||
/// # Arguments | ||
/// * `provider_config` - The provider configuration | ||
/// * `callback_url` - The callback url | ||
/// * `state` - The state (for CSRF protection) | ||
/// | ||
/// # Returns | ||
/// The redirect url | ||
pub fn get_redirect_url( | ||
provider_config: &ProviderConfig, | ||
callback_url: &str, | ||
state: &str, | ||
) -> String { | ||
let redirect_url = form_urlencoded::byte_serialize(callback_url.as_bytes()).collect::<String>(); | ||
let scope = | ||
form_urlencoded::byte_serialize(provider_config.scope.as_bytes()).collect::<String>(); | ||
let state = form_urlencoded::byte_serialize(state.as_bytes()).collect::<String>(); | ||
|
||
format!( | ||
"{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state={}", | ||
provider_config.authorization_url, provider_config.app_id, redirect_url, scope, state | ||
) | ||
} | ||
|
||
/// Get the authorization header for the provider | ||
/// | ||
/// # Arguments | ||
/// * `provider_config` - The provider configuration | ||
/// | ||
/// # Returns | ||
/// The authorization header | ||
pub fn get_authorization_header(provider_config: &ProviderConfig) -> String { | ||
format!( | ||
"Basic {}", | ||
BASE64_STANDARD.encode(format!( | ||
"{}:{}", | ||
provider_config.app_id, provider_config.app_secret | ||
)) | ||
) | ||
} | ||
|
||
/// Exchange the code for an access token | ||
/// | ||
/// # Arguments | ||
/// * `provider_config` - The provider configuration | ||
/// * `code` - The code | ||
/// * `callback_url` - The callback url | ||
/// | ||
/// # Returns | ||
/// | ||
/// * The access token | ||
/// * The username | ||
/// * The email | ||
pub async fn exchange_code( | ||
provider_config: &ProviderConfig, | ||
code: &str, | ||
callback_url: &str, | ||
) -> Result<(String,String,String), Oauth2Error> { | ||
let code = form_urlencoded::byte_serialize(code.as_bytes()).collect::<String>(); | ||
//let callback_url = form_urlencoded::byte_serialize(callback_url.as_bytes()).collect::<String>(); | ||
let authorization_header = get_authorization_header(provider_config); | ||
let response = reqwest::Client::new() | ||
.post(provider_config.token_exchange_url.as_str()) | ||
.header("Authorization", authorization_header) | ||
.header("Content-Type", "application/x-www-form-urlencoded") | ||
.form(&[ | ||
("grant_type", "authorization_code"), | ||
("code", code.as_str()), | ||
("redirect_uri", callback_url), | ||
("client_id", provider_config.app_id.as_str()), | ||
]) | ||
.send() | ||
.await | ||
.map_err(|_| Oauth2Error::ExchangeCodeError)?; | ||
|
||
let response = response | ||
.error_for_status() | ||
.map_err(|_| Oauth2Error::ExchangeCodeError)?; | ||
|
||
let body = response | ||
.json::<TokenResponse>() | ||
.await | ||
.map_err(|_| Oauth2Error::ExchangeCodeError)?; | ||
|
||
if body.id_token.is_some() { | ||
// Response contains an id_token which is a JWT token | ||
let id_token = body.id_token.unwrap(); | ||
let resp = decode_id_token(&id_token); | ||
if resp.is_err() { | ||
return Err(Oauth2Error::DecodeIdTokenError); | ||
} | ||
let (name, email) = resp.unwrap(); | ||
return Ok((body.access_token, name, email)); | ||
} | ||
Ok((body.access_token, "".to_string(), "[email protected]".to_string())) | ||
} | ||
|
||
/// Decode the id token | ||
/// # Arguments | ||
/// * `id_token` - The jwt id token | ||
/// | ||
/// # Returns | ||
/// the username and email | ||
pub fn decode_id_token(id_token: &str) -> Result<(String, String), Oauth2Error> { | ||
let parts: Vec<&str> = id_token.split('.').collect(); | ||
let claims = URL_SAFE_NO_PAD | ||
.decode(parts[1]) | ||
.map_err(|_| Oauth2Error::DecodeIdTokenError)?; | ||
let claims: Claims = | ||
serde_json::from_slice(&claims).map_err(|_| Oauth2Error::DecodeIdTokenError)?; | ||
Ok((claims.name, claims.email)) | ||
} | ||
|
||
/// Verify the access token | ||
/// | ||
/// # Arguments | ||
/// * `provider_config` - The provider configuration | ||
/// * `access_token` - The access token | ||
/// Get the name of the provider config file | ||
/// from the OAUTH2_CONFIG_FILE environment variable or | ||
/// default to "oauth2.toml" | ||
/// | ||
/// # Returns | ||
/// (access_token, username, email) | ||
/// The access token if valid, otherwise an error | ||
pub async fn verify_access_token( | ||
provider_config: &ProviderConfig, | ||
access_token: &str, | ||
) -> Result<String, Oauth2Error> { | ||
let client = reqwest::Client::new(); | ||
match provider_config.provider { | ||
Provider::Github => { | ||
let response = client | ||
.get("https://api.github.com/user") | ||
.bearer_auth(access_token) | ||
.send() | ||
.await; | ||
match response { | ||
Ok(response) if response.status().is_success() => Ok(access_token.to_string()), | ||
_ => Err(Oauth2Error::VerifyTokenError), | ||
} | ||
} | ||
Provider::Google => { | ||
let response = client | ||
.get(&format!( | ||
"https://www.googleapis.com/oauth2/v1/tokeninfo?access_token={}", | ||
access_token | ||
)) | ||
.send() | ||
.await; | ||
|
||
match response { | ||
Ok(response) if response.status().is_success() => Ok(access_token.to_string()), | ||
_ => Err(Oauth2Error::VerifyTokenError), | ||
} | ||
} | ||
Provider::Facebook => { | ||
let app_token = format!("{}|{}", provider_config.app_id, provider_config.app_secret); | ||
let response = client | ||
.get(&format!( | ||
"https://graph.facebook.com/debug_token?input_token={}&access_token={}", | ||
access_token, app_token | ||
)) | ||
.send() | ||
.await; | ||
|
||
match response { | ||
Ok(response) if response.status().is_success() => Ok(access_token.to_string()), | ||
_ => Err(Oauth2Error::VerifyTokenError), | ||
} | ||
} | ||
Provider::Gitlab => { | ||
let response = client | ||
.post("https://gitlab.com/oauth/token/info") | ||
.bearer_auth(access_token) | ||
.send() | ||
.await; | ||
|
||
match response { | ||
Ok(response) if response.status().is_success() => Ok(access_token.to_string()), | ||
_ => Err(Oauth2Error::VerifyTokenError), | ||
} | ||
} | ||
// TODO: Add other providers | ||
Provider::Custom => Ok(access_token.to_string()), | ||
_ => Err(Oauth2Error::VerifyTokenError), | ||
} | ||
/// The name of the provider config file | ||
pub fn get_providers_config_file() -> String { | ||
std::env::var("OAUTH2_CONFIG_FILE").unwrap_or_else(|_| "oauth2.toml".to_string()) | ||
} | ||
|
||
/// Decode the access token | ||
/// | ||
/// # Arguments | ||
/// * `access_token` - The access token | ||
/// | ||
/// # Returns | ||
/// The decoded access token | ||
pub fn decode_access_token(access_token: &str) -> Result<String, Oauth2Error> { | ||
Ok(access_token.to_string()) | ||
} | ||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
|
Oops, something went wrong.