From 2f1406236d3b6ab46a00166a0cf20e1a4ee67fe0 Mon Sep 17 00:00:00 2001 From: boxbeam Date: Thu, 1 Feb 2024 13:56:28 -0500 Subject: [PATCH] feat(webserver): Populate the user field in completions requests if the user authenticated (#1341) * Partially complete * Fix error and populate user field * Add back comment to refactored code * Get user email from access token --- crates/tabby/src/routes/completions.rs | 34 +++++++++- crates/tabby/src/services/completion.rs | 2 +- ee/tabby-db/src/users.rs | 11 ++-- ee/tabby-webserver/src/lib.rs | 5 ++ ee/tabby-webserver/src/schema/auth.rs | 2 +- ee/tabby-webserver/src/service/mod.rs | 86 +++++++++++++------------ 6 files changed, 88 insertions(+), 52 deletions(-) diff --git a/crates/tabby/src/routes/completions.rs b/crates/tabby/src/routes/completions.rs index 72a192218cb..8dbf312a61b 100644 --- a/crates/tabby/src/routes/completions.rs +++ b/crates/tabby/src/routes/completions.rs @@ -1,7 +1,8 @@ use std::sync::Arc; -use axum::{extract::State, Json}; +use axum::{extract::State, headers::Header, Json, TypedHeader}; use hyper::StatusCode; +use tabby_webserver::public::USER_HEADER_FIELD_NAME; use tracing::{instrument, warn}; use crate::services::completion::{CompletionRequest, CompletionResponse, CompletionService}; @@ -23,8 +24,12 @@ use crate::services::completion::{CompletionRequest, CompletionResponse, Complet #[instrument(skip(state, request))] pub async fn completions( State(state): State>, - Json(request): Json, + TypedHeader(MaybeUser(user)): TypedHeader, + Json(mut request): Json, ) -> Result, StatusCode> { + if let Some(user) = user { + request.user.replace(user); + } match state.generate(&request).await { Ok(resp) => Ok(Json(resp)), Err(err) => { @@ -33,3 +38,28 @@ pub async fn completions( } } } + +#[derive(Debug)] +pub struct MaybeUser(Option); + +impl Header for MaybeUser { + fn name() -> &'static axum::http::HeaderName { + &USER_HEADER_FIELD_NAME + } + + fn decode<'i, I>(values: &mut I) -> Result + where + Self: Sized, + I: Iterator, + { + let Some(value) = values.next() else { + return Ok(MaybeUser(None)); + }; + let str = value.to_str().expect("User email is always a valid string"); + Ok(MaybeUser(Some(str.to_string()))) + } + + fn encode>(&self, _values: &mut E) { + todo!() + } +} diff --git a/crates/tabby/src/services/completion.rs b/crates/tabby/src/services/completion.rs index 0bec43cb843..ea188309630 100644 --- a/crates/tabby/src/services/completion.rs +++ b/crates/tabby/src/services/completion.rs @@ -44,7 +44,7 @@ pub struct CompletionRequest { /// A unique identifier representing your end-user, which can help Tabby to monitor & generating /// reports. - user: Option, + pub(crate) user: Option, debug_options: Option, diff --git a/ee/tabby-db/src/users.rs b/ee/tabby-db/src/users.rs index 1f7acda468f..8b7eee2545d 100644 --- a/ee/tabby-db/src/users.rs +++ b/ee/tabby-db/src/users.rs @@ -124,16 +124,15 @@ impl DbConn { ); let users = sqlx::query_as(&query).fetch_all(&self.pool).await?; - Ok(users) } - pub async fn verify_auth_token(&self, token: &str) -> bool { + pub async fn verify_auth_token(&self, token: &str) -> Result { let token = token.to_owned(); - let id = query_scalar!("SELECT id FROM users WHERE auth_token = ?", token) + let email = query_scalar!("SELECT email FROM users WHERE auth_token = ?", token) .fetch_one(&self.pool) .await; - id.is_ok() + email.map_err(Into::into) } pub async fn reset_user_auth_token_by_email(&self, email: &str) -> Result<()> { @@ -222,9 +221,9 @@ mod tests { let user = conn.get_user(id).await.unwrap().unwrap(); - assert!(!conn.verify_auth_token("abcd").await); + assert!(conn.verify_auth_token("abcd").await.is_err()); - assert!(conn.verify_auth_token(&user.auth_token).await); + assert!(conn.verify_auth_token(&user.auth_token).await.is_ok()); conn.reset_user_auth_token_by_email(&user.email) .await diff --git a/ee/tabby-webserver/src/lib.rs b/ee/tabby-webserver/src/lib.rs index b7f09a70ffd..25ada61687c 100644 --- a/ee/tabby-webserver/src/lib.rs +++ b/ee/tabby-webserver/src/lib.rs @@ -9,6 +9,11 @@ mod service; mod ui; pub mod public { + + pub static USER_HEADER_FIELD_NAME: HeaderName = HeaderName::from_static("x-tabby-user"); + + use axum::http::HeaderName; + pub use super::{ handler::attach_webserver, /* used by tabby workers (consumer of /hub api) */ diff --git a/ee/tabby-webserver/src/schema/auth.rs b/ee/tabby-webserver/src/schema/auth.rs index 71abe991128..a774778a33f 100644 --- a/ee/tabby-webserver/src/schema/auth.rs +++ b/ee/tabby-webserver/src/schema/auth.rs @@ -233,7 +233,7 @@ impl RefreshTokenResponse { #[derive(Debug, GraphQLObject)] pub struct VerifyTokenResponse { - claims: JWTPayload, + pub claims: JWTPayload, } impl VerifyTokenResponse { diff --git a/ee/tabby-webserver/src/service/mod.rs b/ee/tabby-webserver/src/service/mod.rs index 42d21f10af8..997934bf3cf 100644 --- a/ee/tabby-webserver/src/service/mod.rs +++ b/ee/tabby-webserver/src/service/mod.rs @@ -21,13 +21,16 @@ use tabby_db::DbConn; use tracing::{info, warn}; use self::{cron::run_cron, email::new_email_service}; -use crate::schema::{ - auth::AuthenticationService, - email::EmailService, - job::JobService, - repository::RepositoryService, - worker::{RegisterWorkerError, Worker, WorkerKind, WorkerService}, - ServiceLocator, +use crate::{ + public::USER_HEADER_FIELD_NAME, + schema::{ + auth::AuthenticationService, + email::EmailService, + job::JobService, + repository::RepositoryService, + worker::{RegisterWorkerError, Worker, WorkerKind, WorkerService}, + ServiceLocator, + }, }; struct ServerContext { @@ -60,41 +63,32 @@ impl ServerContext { } } - async fn authorize_request(&self, request: &Request) -> bool { + async fn authorize_request(&self, request: &Request) -> (bool, Option) { let path = request.uri().path(); - if path.starts_with("/v1/") || path.starts_with("/v1beta/") { - let token = { - let authorization = request - .headers() - .get("authorization") - .map(HeaderValue::to_str) - .and_then(Result::ok); - - if let Some(authorization) = authorization { - let split = authorization.split_once(' '); - match split { - // Found proper bearer - Some(("Bearer", contents)) => Some(contents), - _ => None, - } - } else { - None - } - }; - - if let Some(token) = token { - if self.db_conn.verify_access_token(token).await.is_err() - && !self.db_conn.verify_auth_token(token).await - { - return false; - } - } else { - // Admin system is initialized, but there's no valid token. - return false; - } + if !(path.starts_with("/v1/") || path.starts_with("/v1beta/")) { + return (true, None); + } + let authorization = request + .headers() + .get("authorization") + .map(HeaderValue::to_str) + .and_then(Result::ok); + + let token = authorization + .and_then(|s| s.split_once(' ')) + .map(|(_bearer, token)| token); + + let Some(token) = token else { + // Admin system is initialized, but there is no valid token. + return (false, None); + }; + if let Ok(jwt) = self.db_conn.verify_access_token(token).await { + return (true, Some(jwt.claims.sub)); + } + match self.db_conn.verify_auth_token(token).await { + Ok(email) => (true, Some(email)), + Err(_) => (false, None), } - - true } } @@ -147,10 +141,11 @@ impl WorkerService for ServerContext { async fn dispatch_request( &self, - request: Request, + mut request: Request, next: Next, ) -> axum::response::Response { - if !self.authorize_request(&request).await { + let (auth, user) = self.authorize_request(&request).await; + if !auth { return axum::response::Response::builder() .status(StatusCode::UNAUTHORIZED) .body(Body::empty()) @@ -158,6 +153,13 @@ impl WorkerService for ServerContext { .into_response(); } + if let Some(user) = user { + request.headers_mut().append( + &USER_HEADER_FIELD_NAME, + HeaderValue::from_str(&user).expect("User must be valid header"), + ); + } + let remote_addr = request .extensions() .get::>()