diff --git a/src/request.rs b/src/request.rs index 6d39a82..6883168 100644 --- a/src/request.rs +++ b/src/request.rs @@ -1,6 +1,6 @@ use crate::partial::Partial; use async_trait::async_trait; -use axum::extract::FromRequestParts; +use axum::extract::{FromRequestParts, OriginalUri}; use http::{request::Parts, HeaderMap, HeaderValue, StatusCode}; /// Inertia-related information in the request. @@ -10,6 +10,7 @@ use http::{request::Parts, HeaderMap, HeaderValue, StatusCode}; pub(crate) struct Request { pub(crate) is_xhr: bool, pub(crate) version: Option, + /// When using nested services, the `url` will include the full path. pub(crate) url: String, pub(crate) partial: Option, } @@ -33,8 +34,11 @@ where { type Rejection = (StatusCode, HeaderMap); - async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { - let url = parts.uri.path().to_string(); + async fn from_request_parts(mut parts: &mut Parts, state: &S) -> Result { + let original_uri = OriginalUri::from_request_parts(&mut parts, state) + .await + .unwrap_or_else(|e| match e {}); + let url = original_uri.0.path().to_string(); let is_xhr = parts .headers .get("X-Inertia") @@ -227,4 +231,41 @@ mod tests { .unwrap(); assert_eq!(res.status(), StatusCode::OK); } + + #[tokio::test] + async fn it_extracts_urls_for_simple_routes() { + async fn handler(req: Request) { + assert_eq!(req.url, "/test".to_string()); + } + let app = Router::new().route("/test", get(handler)); + let (_, addr) = spawn_test_app(app).await; + + let client = reqwest::Client::new(); + + let res = client + .get(format!("http://{}/test", &addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn it_extracts_original_urls_for_nested_routers() { + async fn handler(req: Request) { + assert_eq!(req.url, "/outer/test".to_string()); + } + let inner = Router::new().route("/test", get(handler)); + let outer = Router::new().nest_service("/outer", inner); + let (_, addr) = spawn_test_app(outer).await; + + let client = reqwest::Client::new(); + + let res = client + .get(format!("http://{}/outer/test", &addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + } }