Skip to content

Commit

Permalink
refactor oauth2 in a "object" style
Browse files Browse the repository at this point in the history
  • Loading branch information
eltorio committed May 14, 2024
1 parent 0c8929b commit b61a972
Show file tree
Hide file tree
Showing 6 changed files with 245 additions and 260 deletions.
108 changes: 108 additions & 0 deletions libs/oauth2/src/dex_provider.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
use std::{future::Future, pin::Pin};

use crate::{
errors::Oauth2Error,
oauth_provider::{decode_oauth_id_token, OAuthProvider, OAuthProviderFactory, OAuthResponse},
Provider, ProviderConfig, TokenResponse,
};
use base64::prelude::{Engine as _, BASE64_STANDARD};

use url::form_urlencoded;

pub struct DexProvider {
provider_config: ProviderConfig,
}

/// Get the authorization header for the provider
///
/// # Arguments
/// * `provider_config` - The provider configuration
///
/// # Returns
/// The authorization header
fn get_authorization_header(provider_config: &ProviderConfig) -> String {
format!(
"Basic {}",
BASE64_STANDARD.encode(format!(
"{}:{}",
provider_config.app_id, provider_config.app_secret
))
)
}

impl OAuthProviderFactory for DexProvider {
fn new() -> Self {
let provider_config = Self::get_provider_config(Provider::Custom);
Self { provider_config }
}
}
impl OAuthProvider for DexProvider {
fn get_redirect_url(&self, callback_url: &str, state: &str) -> String {
let redirect_url =
form_urlencoded::byte_serialize(callback_url.as_bytes()).collect::<String>();
let scope = form_urlencoded::byte_serialize(self.provider_config.scope.as_bytes())
.collect::<String>();
let state = form_urlencoded::byte_serialize(state.as_bytes()).collect::<String>();

format!(
"{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state={}",
self.provider_config.authorization_url,
self.provider_config.app_id,
redirect_url,
scope,
state
)
}

fn exchange_code(
&self,
code: &str,
callback_url: &str,
) -> Pin<Box<dyn Future<Output = Result<OAuthResponse, Oauth2Error>> + Send + Sync>> {
let code = code.to_string();
let callback_url = callback_url.to_string();
let provider_config = self.provider_config.clone();

Box::pin(async move {
let code = form_urlencoded::byte_serialize(code.as_bytes()).collect::<String>();
let authorization_header = get_authorization_header(&provider_config);
let response = reqwest::Client::new()
.post(provider_config.token_exchange_url.as_str())
.header("Authorization", authorization_header)
.header("Content-Type", "application/x-www-form-urlencoded")
.form(&[
("grant_type", "authorization_code"),
("code", code.as_str()),
("redirect_uri", &callback_url),
("client_id", &provider_config.app_id.as_str()),
])
.send()
.await
.map_err(|_| Oauth2Error::ExchangeCodeError)?;

let response = response
.error_for_status()
.map_err(|_| Oauth2Error::ExchangeCodeError)?;

let body = response
.json::<TokenResponse>()
.await
.map_err(|_| Oauth2Error::ExchangeCodeError)?;

if let Some(id_token) = body.id_token {
let (username, email) = decode_oauth_id_token(&id_token)?;
Ok(OAuthResponse {
access_token: body.access_token,
username,
email,
})
} else {
Err(Oauth2Error::ExchangeCodeError)
}
})
}

fn get_provider_type(&self) -> crate::Provider {
Provider::Custom
}
}
233 changes: 13 additions & 220 deletions libs/oauth2/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
use base64::{
engine::general_purpose::URL_SAFE_NO_PAD,
prelude::{Engine as _, BASE64_STANDARD},
};
pub mod dex_provider;
pub mod oauth_provider;
use serde::{Deserialize, Serialize};
use std::{fs, str::FromStr};
use url::form_urlencoded;
mod errors;
use errors::Oauth2Error;

#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct ProviderConfig {
Expand Down Expand Up @@ -34,7 +30,7 @@ pub enum Provider {
}

#[derive(Debug, Serialize, Deserialize)]
struct Claims {
pub struct Claims {
aud: String,
sub: String,
name: String,
Expand Down Expand Up @@ -89,229 +85,26 @@ pub struct TokenResponse {
pub refresh_token: Option<String>,
}

/// Get the name of the provider config file
/// from the OAUTH2_CONFIG_FILE environment variable or
/// default to "oauth2.toml"
/// Get the providers config from a config file
///
/// # Returns
/// The name of the provider config file
pub fn get_provider_config_file() -> String {
std::env::var("OAUTH2_CONFIG_FILE").unwrap_or_else(|_| "oauth2.toml".to_string())
}

/// Get the redirect url for the specified provider
///
/// # Arguments
/// * `provider_config` - The provider configuration
///
/// # Returns
/// The redirect url
///
pub fn get_provider_config(config_file: &str) -> Vec<ProviderConfig> {
/// # Returns
/// The providers config
pub fn get_providers_config_from_file(config_file: &str) -> Vec<ProviderConfig> {
let config_file_content = fs::read_to_string(config_file).expect("Failed to read config file");
let config: Config = toml::from_str(&config_file_content).expect("Failed to parse config file");
config.provider
}

/// Get redirect url for the provider
///
/// # Arguments
/// * `provider_config` - The provider configuration
/// * `callback_url` - The callback url
/// * `state` - The state (for CSRF protection)
///
/// # Returns
/// The redirect url
pub fn get_redirect_url(
provider_config: &ProviderConfig,
callback_url: &str,
state: &str,
) -> String {
let redirect_url = form_urlencoded::byte_serialize(callback_url.as_bytes()).collect::<String>();
let scope =
form_urlencoded::byte_serialize(provider_config.scope.as_bytes()).collect::<String>();
let state = form_urlencoded::byte_serialize(state.as_bytes()).collect::<String>();

format!(
"{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state={}",
provider_config.authorization_url, provider_config.app_id, redirect_url, scope, state
)
}

/// Get the authorization header for the provider
///
/// # Arguments
/// * `provider_config` - The provider configuration
///
/// # Returns
/// The authorization header
pub fn get_authorization_header(provider_config: &ProviderConfig) -> String {
format!(
"Basic {}",
BASE64_STANDARD.encode(format!(
"{}:{}",
provider_config.app_id, provider_config.app_secret
))
)
}

/// Exchange the code for an access token
///
/// # Arguments
/// * `provider_config` - The provider configuration
/// * `code` - The code
/// * `callback_url` - The callback url
///
/// # Returns
///
/// * The access token
/// * The username
/// * The email
pub async fn exchange_code(
provider_config: &ProviderConfig,
code: &str,
callback_url: &str,
) -> Result<(String,String,String), Oauth2Error> {
let code = form_urlencoded::byte_serialize(code.as_bytes()).collect::<String>();
//let callback_url = form_urlencoded::byte_serialize(callback_url.as_bytes()).collect::<String>();
let authorization_header = get_authorization_header(provider_config);
let response = reqwest::Client::new()
.post(provider_config.token_exchange_url.as_str())
.header("Authorization", authorization_header)
.header("Content-Type", "application/x-www-form-urlencoded")
.form(&[
("grant_type", "authorization_code"),
("code", code.as_str()),
("redirect_uri", callback_url),
("client_id", provider_config.app_id.as_str()),
])
.send()
.await
.map_err(|_| Oauth2Error::ExchangeCodeError)?;

let response = response
.error_for_status()
.map_err(|_| Oauth2Error::ExchangeCodeError)?;

let body = response
.json::<TokenResponse>()
.await
.map_err(|_| Oauth2Error::ExchangeCodeError)?;

if body.id_token.is_some() {
// Response contains an id_token which is a JWT token
let id_token = body.id_token.unwrap();
let resp = decode_id_token(&id_token);
if resp.is_err() {
return Err(Oauth2Error::DecodeIdTokenError);
}
let (name, email) = resp.unwrap();
return Ok((body.access_token, name, email));
}
Ok((body.access_token, "".to_string(), "[email protected]".to_string()))
}

/// Decode the id token
/// # Arguments
/// * `id_token` - The jwt id token
///
/// # Returns
/// the username and email
pub fn decode_id_token(id_token: &str) -> Result<(String, String), Oauth2Error> {
let parts: Vec<&str> = id_token.split('.').collect();
let claims = URL_SAFE_NO_PAD
.decode(parts[1])
.map_err(|_| Oauth2Error::DecodeIdTokenError)?;
let claims: Claims =
serde_json::from_slice(&claims).map_err(|_| Oauth2Error::DecodeIdTokenError)?;
Ok((claims.name, claims.email))
}

/// Verify the access token
///
/// # Arguments
/// * `provider_config` - The provider configuration
/// * `access_token` - The access token
/// Get the name of the provider config file
/// from the OAUTH2_CONFIG_FILE environment variable or
/// default to "oauth2.toml"
///
/// # Returns
/// (access_token, username, email)
/// The access token if valid, otherwise an error
pub async fn verify_access_token(
provider_config: &ProviderConfig,
access_token: &str,
) -> Result<String, Oauth2Error> {
let client = reqwest::Client::new();
match provider_config.provider {
Provider::Github => {
let response = client
.get("https://api.github.com/user")
.bearer_auth(access_token)
.send()
.await;
match response {
Ok(response) if response.status().is_success() => Ok(access_token.to_string()),
_ => Err(Oauth2Error::VerifyTokenError),
}
}
Provider::Google => {
let response = client
.get(&format!(
"https://www.googleapis.com/oauth2/v1/tokeninfo?access_token={}",
access_token
))
.send()
.await;

match response {
Ok(response) if response.status().is_success() => Ok(access_token.to_string()),
_ => Err(Oauth2Error::VerifyTokenError),
}
}
Provider::Facebook => {
let app_token = format!("{}|{}", provider_config.app_id, provider_config.app_secret);
let response = client
.get(&format!(
"https://graph.facebook.com/debug_token?input_token={}&access_token={}",
access_token, app_token
))
.send()
.await;

match response {
Ok(response) if response.status().is_success() => Ok(access_token.to_string()),
_ => Err(Oauth2Error::VerifyTokenError),
}
}
Provider::Gitlab => {
let response = client
.post("https://gitlab.com/oauth/token/info")
.bearer_auth(access_token)
.send()
.await;

match response {
Ok(response) if response.status().is_success() => Ok(access_token.to_string()),
_ => Err(Oauth2Error::VerifyTokenError),
}
}
// TODO: Add other providers
Provider::Custom => Ok(access_token.to_string()),
_ => Err(Oauth2Error::VerifyTokenError),
}
/// The name of the provider config file
pub fn get_providers_config_file() -> String {
std::env::var("OAUTH2_CONFIG_FILE").unwrap_or_else(|_| "oauth2.toml".to_string())
}

/// Decode the access token
///
/// # Arguments
/// * `access_token` - The access token
///
/// # Returns
/// The decoded access token
pub fn decode_access_token(access_token: &str) -> Result<String, Oauth2Error> {
Ok(access_token.to_string())
}
#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading

0 comments on commit b61a972

Please sign in to comment.