Skip to content

Commit

Permalink
refactor to pull authorization from plane headers in single-doc mode …
Browse files Browse the repository at this point in the history
…and to pull it from a token in full-server mode
  • Loading branch information
rolyatmax committed Jan 15, 2025
1 parent 5edd5a6 commit 6c570c8
Showing 1 changed file with 84 additions and 56 deletions.
140 changes: 84 additions & 56 deletions crates/y-sweet/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use axum::{
ws::{Message, WebSocket},
Path, Query, Request, State, WebSocketUpgrade,
},
http::StatusCode,
http::{
header::{HeaderMap, HeaderName},
StatusCode,
},
middleware::{self, Next},
response::{IntoResponse, Response},
routing::{get, post},
Expand Down Expand Up @@ -40,6 +43,8 @@ use y_sweet_core::{
sync_kv::SyncKv,
};

const PLANE_VERIFIED_USER_DATA_HEADER: &str = "x-verified-user-data";

fn current_time_epoch_millis() -> u64 {
let now = std::time::SystemTime::now();
let duration_since_epoch = now.duration_since(std::time::UNIX_EPOCH).unwrap();
Expand Down Expand Up @@ -286,10 +291,10 @@ impl Server {

pub fn check_auth(
&self,
header: Option<TypedHeader<headers::Authorization<headers::authorization::Bearer>>>,
auth_header: Option<TypedHeader<headers::Authorization<headers::authorization::Bearer>>>,
) -> Result<(), AppError> {
if let Some(auth) = &self.authenticator {
if let Some(TypedHeader(headers::Authorization(bearer))) = header {
if let Some(TypedHeader(headers::Authorization(bearer))) = auth_header {
if let Ok(()) =
auth.verify_server_token(bearer.token(), current_time_epoch_millis())
{
Expand All @@ -302,25 +307,6 @@ impl Server {
}
}

pub fn check_doc_auth(
&self,
header: Option<TypedHeader<headers::Authorization<headers::authorization::Bearer>>>,
doc_id: &str,
) -> Result<Authorization, AppError> {
if let Some(auth) = &self.authenticator {
if let Some(TypedHeader(headers::Authorization(bearer))) = header {
if let Ok(authorization) =
auth.verify_doc_token(bearer.token(), doc_id, current_time_epoch_millis())
{
return Ok(authorization);
}
}
Err((StatusCode::UNAUTHORIZED, anyhow!("Unauthorized.")))?
} else {
Ok(Authorization::Full)
}
}

pub async fn redact_error_middleware(req: Request, next: Next) -> impl IntoResponse {
let resp = next.run(req).await;
if resp.status().is_server_error() || resp.status().is_client_error() {
Expand Down Expand Up @@ -426,10 +412,11 @@ struct HandlerParams {
async fn get_doc_as_update(
State(server_state): State<Arc<Server>>,
Path(doc_id): Path<String>,
authorization: Option<TypedHeader<headers::Authorization<headers::authorization::Bearer>>>,
auth_header: Option<TypedHeader<headers::Authorization<headers::authorization::Bearer>>>,
) -> Result<Response, AppError> {
// All authorization types allow reading the document.
let _ = server_state.check_doc_auth(authorization, &doc_id)?;
let token = get_token_from_header(auth_header);
let _ = server_state.verify_doc_token(token.as_deref(), &doc_id)?;

let dwskv = server_state
.get_or_create_doc(&doc_id)
Expand All @@ -444,40 +431,48 @@ async fn get_doc_as_update(
async fn get_doc_as_update_deprecated(
Path(doc_id): Path<String>,
State(server_state): State<Arc<Server>>,
authorization: Option<TypedHeader<headers::Authorization<headers::authorization::Bearer>>>,
auth_header: Option<TypedHeader<headers::Authorization<headers::authorization::Bearer>>>,
) -> Result<Response, AppError> {
tracing::warn!("/doc/:doc_id/as-update is deprecated; call /doc/:doc_id/auth instead and then call as-update on the returned base URL.");
get_doc_as_update(State(server_state), Path(doc_id), authorization).await
get_doc_as_update(State(server_state), Path(doc_id), auth_header).await
}

async fn update_doc_deprecated(
Path(doc_id): Path<String>,
State(server_state): State<Arc<Server>>,
authorization: Option<TypedHeader<headers::Authorization<headers::authorization::Bearer>>>,
auth_header: Option<TypedHeader<headers::Authorization<headers::authorization::Bearer>>>,
body: Bytes,
) -> Result<Response, AppError> {
tracing::warn!("/doc/:doc_id/update is deprecated; call /doc/:doc_id/auth instead and then call update on the returned base URL.");
update_doc(Path(doc_id), State(server_state), authorization, body).await
update_doc(Path(doc_id), State(server_state), auth_header, body).await
}

async fn get_doc_as_update_single(
State(server_state): State<Arc<Server>>,
authorization: Option<TypedHeader<headers::Authorization<headers::authorization::Bearer>>>,
auth_header: Option<TypedHeader<headers::Authorization<headers::authorization::Bearer>>>,
) -> Result<Response, AppError> {
let doc_id = server_state.get_single_doc_id()?;
get_doc_as_update(State(server_state), Path(doc_id), authorization).await
get_doc_as_update(State(server_state), Path(doc_id), auth_header).await
}

async fn update_doc(
Path(doc_id): Path<String>,
State(server_state): State<Arc<Server>>,
authorization: Option<TypedHeader<headers::Authorization<headers::authorization::Bearer>>>,
auth_header: Option<TypedHeader<headers::Authorization<headers::authorization::Bearer>>>,
body: Bytes,
) -> Result<Response, AppError> {
if !matches!(
server_state.check_doc_auth(authorization, &doc_id)?,
Authorization::Full
) {
let token = get_token_from_header(auth_header);
let authorization = server_state.verify_doc_token(token.as_deref(), &doc_id)?;
update_doc_inner(doc_id, server_state, authorization, body).await
}

async fn update_doc_inner(
doc_id: String,
server_state: Arc<Server>,
authorization: Authorization,
body: Bytes,
) -> Result<Response, AppError> {
if !matches!(authorization, Authorization::Full) {
return Err(AppError(StatusCode::FORBIDDEN, anyhow!("Unauthorized.")));
}

Expand All @@ -496,21 +491,23 @@ async fn update_doc(

async fn update_doc_single(
State(server_state): State<Arc<Server>>,
authorization: Option<TypedHeader<headers::Authorization<headers::authorization::Bearer>>>,
headers: HeaderMap,
body: Bytes,
) -> Result<Response, AppError> {
let doc_id = server_state.get_single_doc_id()?;
update_doc(Path(doc_id), State(server_state), authorization, body).await
// the doc server is meant to be run in Plane, so we expect verified plane
// headers to be used for authorization.
let authorization = get_authorization_from_plane_header(headers)?;
update_doc_inner(doc_id, server_state, authorization, body).await
}

async fn handle_socket_upgrade(
ws: WebSocketUpgrade,
Path(doc_id): Path<String>,
Query(params): Query<HandlerParams>,
authorization: Authorization,
State(server_state): State<Arc<Server>>,
) -> Result<Response, AppError> {
let auth = server_state.verify_doc_token(params.token.as_deref(), &doc_id)?;
if !matches!(auth, Authorization::Full) && !server_state.docs.contains_key(&doc_id) {
if !matches!(authorization, Authorization::Full) && !server_state.docs.contains_key(&doc_id) {
return Err(AppError(
StatusCode::NOT_FOUND,
anyhow!("Doc {} not found", doc_id),
Expand All @@ -524,7 +521,9 @@ async fn handle_socket_upgrade(
let awareness = dwskv.awareness();
let cancellation_token = server_state.cancellation_token.clone();

Ok(ws.on_upgrade(move |socket| handle_socket(socket, awareness, auth, cancellation_token)))
Ok(ws.on_upgrade(move |socket| {
handle_socket(socket, awareness, authorization, cancellation_token)
}))
}

async fn handle_socket_upgrade_deprecated(
Expand All @@ -536,7 +535,8 @@ async fn handle_socket_upgrade_deprecated(
tracing::warn!(
"/doc/ws/:doc_id is deprecated; call /doc/:doc_id/auth instead and use the returned URL."
);
handle_socket_upgrade(ws, Path(doc_id), Query(params), State(server_state)).await
let authorization = server_state.verify_doc_token(params.token.as_deref(), &doc_id)?;
handle_socket_upgrade(ws, Path(doc_id), authorization, State(server_state)).await
}

async fn handle_socket_upgrade_full_path(
Expand All @@ -551,14 +551,14 @@ async fn handle_socket_upgrade_full_path(
anyhow!("For Yjs compatibility, the doc_id appears twice in the URL. It must be the same in both places, but we got {} and {}.", doc_id, doc_id2),
));
}

handle_socket_upgrade(ws, Path(doc_id), Query(params), State(server_state)).await
let authorization = server_state.verify_doc_token(params.token.as_deref(), &doc_id)?;
handle_socket_upgrade(ws, Path(doc_id), authorization, State(server_state)).await
}

async fn handle_socket_upgrade_single(
ws: WebSocketUpgrade,
Path(doc_id): Path<String>,
Query(params): Query<HandlerParams>,
headers: HeaderMap,
State(server_state): State<Arc<Server>>,
) -> Result<Response, AppError> {
let single_doc_id = server_state.get_single_doc_id()?;
Expand All @@ -568,7 +568,11 @@ async fn handle_socket_upgrade_single(
anyhow!("Document not found"),
));
}
handle_socket_upgrade(ws, Path(single_doc_id), Query(params), State(server_state)).await

// the doc server is meant to be run in Plane, so we expect verified plane
// headers to be used for authorization.
let authorization = get_authorization_from_plane_header(headers)?;
handle_socket_upgrade(ws, Path(single_doc_id), authorization, State(server_state)).await
}

async fn handle_socket(
Expand Down Expand Up @@ -622,10 +626,10 @@ async fn handle_socket(
}

async fn check_store(
authorization: Option<TypedHeader<headers::Authorization<headers::authorization::Bearer>>>,
auth_header: Option<TypedHeader<headers::Authorization<headers::authorization::Bearer>>>,
State(server_state): State<Arc<Server>>,
) -> Result<Json<Value>, AppError> {
server_state.check_auth(authorization)?;
server_state.check_auth(auth_header)?;

if server_state.store.is_none() {
return Ok(Json(json!({"ok": false, "error": "No store set."})));
Expand All @@ -637,13 +641,13 @@ async fn check_store(
}

async fn check_store_deprecated(
authorization: Option<TypedHeader<headers::Authorization<headers::authorization::Bearer>>>,
auth_header: Option<TypedHeader<headers::Authorization<headers::authorization::Bearer>>>,
State(server_state): State<Arc<Server>>,
) -> Result<Json<Value>, AppError> {
tracing::warn!(
"GET check_store is deprecated, use POST check_store with an empty body instead."
);
check_store(authorization, State(server_state)).await
check_store(auth_header, State(server_state)).await
}

/// Always returns a 200 OK response, as long as we are listening.
Expand All @@ -652,11 +656,11 @@ async fn ready() -> Result<Json<Value>, AppError> {
}

async fn new_doc(
authorization: Option<TypedHeader<headers::Authorization<headers::authorization::Bearer>>>,
auth_header: Option<TypedHeader<headers::Authorization<headers::authorization::Bearer>>>,
State(server_state): State<Arc<Server>>,
Json(body): Json<DocCreationRequest>,
) -> Result<Json<NewDocResponse>, AppError> {
server_state.check_auth(authorization)?;
server_state.check_auth(auth_header)?;

let doc_id = if let Some(doc_id) = body.doc_id {
if !validate_doc_name(doc_id.as_str()) {
Expand All @@ -683,15 +687,13 @@ async fn new_doc(
}

async fn auth_doc(
authorization_header: Option<
TypedHeader<headers::Authorization<headers::authorization::Bearer>>,
>,
auth_header: Option<TypedHeader<headers::Authorization<headers::authorization::Bearer>>>,
TypedHeader(host): TypedHeader<headers::Host>,
State(server_state): State<Arc<Server>>,
Path(doc_id): Path<String>,
body: Option<Json<AuthDocRequest>>,
) -> Result<Json<ClientToken>, AppError> {
server_state.check_auth(authorization_header)?;
server_state.check_auth(auth_header)?;

let Json(AuthDocRequest {
authorization,
Expand Down Expand Up @@ -743,6 +745,32 @@ async fn auth_doc(
}))
}

fn get_token_from_header(
auth_header: Option<TypedHeader<headers::Authorization<headers::authorization::Bearer>>>,
) -> Option<String> {
if let Some(TypedHeader(headers::Authorization(bearer))) = auth_header {
Some(bearer.token().to_string())
} else {
None
}
}

#[derive(Deserialize)]
struct PlaneVerifiedUserData {
authorization: Authorization,
}

fn get_authorization_from_plane_header(headers: HeaderMap) -> Result<Authorization, AppError> {
if let Some(token) = headers.get(HeaderName::from_static(PLANE_VERIFIED_USER_DATA_HEADER)) {
let token_str = token.to_str().map_err(|e| (StatusCode::BAD_REQUEST, e))?;
let user_data: PlaneVerifiedUserData =
serde_json::from_str(token_str).map_err(|e| (StatusCode::BAD_REQUEST, e))?;
Ok(user_data.authorization)
} else {
Err((StatusCode::UNAUTHORIZED, anyhow!("No token provided.")))?
}
}

#[cfg(test)]
mod test {
use super::*;
Expand Down

0 comments on commit 6c570c8

Please sign in to comment.