From 50429522a94206b1f7fc4eec404d8b6d91a92873 Mon Sep 17 00:00:00 2001 From: Oscar Beaumont Date: Thu, 18 Jul 2024 22:29:24 +0800 Subject: [PATCH] Axum overhaul + errors runtime + back compat --- examples/axum/src/api.rs | 30 ++- examples/axum/src/main.rs | 2 +- integrations/axum/Cargo.toml | 1 + integrations/axum/src/ctx_fn.rs | 66 +++++ integrations/axum/src/endpoint.rs | 369 +++++++++++++++++++------- integrations/axum/src/extractors.rs | 1 - integrations/axum/src/file.rs | 5 +- integrations/axum/src/lib.rs | 46 +++- middleware/openapi/src/lib.rs | 2 +- rspc/src/error.rs | 11 + rspc/src/lib.rs | 3 +- rspc/src/notes.rs | 27 -- rspc/src/procedure.rs | 2 +- rspc/src/procedure/builder.rs | 8 +- rspc/src/procedure/error.rs | 96 ++++++- rspc/src/procedure/meta.rs | 2 +- rspc/src/procedure/output.rs | 11 +- rspc/src/procedure/procedure.rs | 6 +- rspc/src/procedure/resolver_output.rs | 11 +- rspc/src/procedure/stream.rs | 34 ++- 20 files changed, 562 insertions(+), 171 deletions(-) create mode 100644 integrations/axum/src/ctx_fn.rs delete mode 100644 integrations/axum/src/extractors.rs create mode 100644 rspc/src/error.rs delete mode 100644 rspc/src/notes.rs diff --git a/examples/axum/src/api.rs b/examples/axum/src/api.rs index 41946be9..496b333a 100644 --- a/examples/axum/src/api.rs +++ b/examples/axum/src/api.rs @@ -1,9 +1,11 @@ -use std::{error, marker::PhantomData, path::PathBuf, sync::Arc}; +use std::{marker::PhantomData, path::PathBuf, sync::Arc}; use rspc::{ - procedure::{Procedure, ProcedureBuilder, ProcedureKind, ResolverInput, ResolverOutput}, + procedure::{Procedure, ProcedureBuilder, ResolverInput, ResolverOutput}, Infallible, }; +use serde::Serialize; +use specta::Type; use specta_typescript::Typescript; use specta_util::TypeCollection; use thiserror::Error; @@ -12,8 +14,13 @@ pub(crate) mod chat; pub(crate) mod invalidation; pub(crate) mod store; -#[derive(Debug, Error)] -pub enum Error {} +#[derive(Debug, Error, Serialize, Type)] +pub enum Error { + #[error("you made a mistake: {0}")] + Mistake(String), +} + +impl rspc::Error for Error {} // `Clone` is only required for usage with Websockets #[derive(Clone)] @@ -31,7 +38,7 @@ pub struct BaseProcedure(PhantomData); impl BaseProcedure { pub fn builder() -> ProcedureBuilder where - TErr: error::Error + Send + 'static, + TErr: rspc::Error, TInput: ResolverInput, TResult: ResolverOutput, { @@ -44,6 +51,19 @@ pub fn mount() -> Router { .procedure("version", { ::builder().query(|_, _: ()| async { Ok(env!("CARGO_PKG_VERSION")) }) }) + .procedure("error", { + #[derive(Debug, serde::Serialize, Type)] + #[serde(tag = "type")] + enum Testing { + A(String), + } + + ::builder().query(|_, _: ()| async { Ok(Testing::A("go away".into())) }) + }) + .procedure("error2", { + ::builder() + .query(|_, _: ()| async { Err::<(), _>(Error::Mistake("skill issue".into())) }) + }) .merge("chat", chat::mount()) .merge("store", store::mount()) // TODO: I dislike this API diff --git a/examples/axum/src/main.rs b/examples/axum/src/main.rs index 791bfbcb..6ca69d66 100644 --- a/examples/axum/src/main.rs +++ b/examples/axum/src/main.rs @@ -24,7 +24,7 @@ async fn main() { .route("/", get(|| async { "Hello, World!" })) .nest( "/rspc", - rspc_axum::Endpoint::new(router.clone(), ctx_fn.clone()).build(), + rspc_axum::Endpoint::new(router.clone(), ctx_fn.clone()), ) .nest("/", rspc_openapi::mount(router, ctx_fn)); diff --git a/integrations/axum/Cargo.toml b/integrations/axum/Cargo.toml index 7f7d2150..62dcac31 100644 --- a/integrations/axum/Cargo.toml +++ b/integrations/axum/Cargo.toml @@ -13,6 +13,7 @@ categories = ["web-programming", "asynchronous"] [features] default = [] ws = ["dep:tokio", "axum/ws"] +file = ["dep:tokio"] [dependencies] rspc = { version = "0.3.0", path = "../../rspc" } diff --git a/integrations/axum/src/ctx_fn.rs b/integrations/axum/src/ctx_fn.rs new file mode 100644 index 00000000..16815f09 --- /dev/null +++ b/integrations/axum/src/ctx_fn.rs @@ -0,0 +1,66 @@ +use axum::{extract::FromRequestParts, http::request::Parts}; +use std::{future::Future, marker::PhantomData}; + +// TODO: Sealed? +pub trait ContextFunction: Clone + Send + Sync + 'static +where + TState: Send + Sync, + TCtx: Send + 'static, +{ + // fn exec(&self, parts: Parts, state: &TState) -> impl Future> + Send; +} + +pub struct ZeroArgMarker; +impl ContextFunction for TFunc +where + TFunc: Fn() -> TCtx + Clone + Send + Sync + 'static, + TState: Send + Sync, + TCtx: Send + 'static, +{ + // async fn exec(&self, _: Parts, _: &TState) -> Result { + // Ok(self.clone()()) + // } +} + +macro_rules! impl_fn { + ($marker:ident; $($generics:ident),*) => { + #[allow(unused_parens)] + pub struct $marker<$($generics),*>(PhantomData<($($generics),*)>); + + impl + Send),*> ContextFunction> for TFunc + where + TFunc: Fn($($generics),*) -> TCtx + Clone + Send + Sync + 'static, + TState: Send + Sync, + TCtx: Send + 'static + { + // async fn exec(&self, mut parts: Parts, state: &TState) -> Result + // { + // $( + // #[allow(non_snake_case)] + // let Ok($generics) = $generics::from_request_parts(&mut parts, &state).await else { + // return Err(ExecError::AxumExtractorError) + // }; + // )* + + // Ok(self.clone()($($generics),*)) + // } + } + }; +} + +impl_fn!(OneArgMarker; T1); +impl_fn!(TwoArgMarker; T1, T2); +impl_fn!(ThreeArgMarker; T1, T2, T3); +impl_fn!(FourArgMarker; T1, T2, T3, T4); +impl_fn!(FiveArgMarker; T1, T2, T3, T4, T5); +impl_fn!(SixArgMarker; T1, T2, T3, T4, T5, T6); +impl_fn!(SevenArgMarker; T1, T2, T3, T4, T5, T6, T7); +impl_fn!(EightArgMarker; T1, T2, T3, T4, T5, T6, T7, T8); +impl_fn!(NineArgMarker; T1, T2, T3, T4, T5, T6, T7, T8, T9); +impl_fn!(TenArgMarker; T1, T2, T3, T4, T5, T6, T7, T8, T9, T10); +impl_fn!(ElevenArgMarker; T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11); +impl_fn!(TwelveArgMarker; T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12); +impl_fn!(ThirteenArgMarker; T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13); +impl_fn!(FourteenArgMarker; T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14); +impl_fn!(FifteenArgMarker; T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15); +impl_fn!(SixteenArgMarker; T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); diff --git a/integrations/axum/src/endpoint.rs b/integrations/axum/src/endpoint.rs index 32b24f21..37132c00 100644 --- a/integrations/axum/src/endpoint.rs +++ b/integrations/axum/src/endpoint.rs @@ -1,141 +1,316 @@ -use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, future::poll_fn, sync::Arc, task::Poll}; use axum::{ body::Bytes, extract::Query, - http::StatusCode, - response::{IntoResponse, Response}, + http::{header, HeaderMap, StatusCode}, routing::{get, post}, Json, }; use futures::StreamExt; -use rspc::{procedure::ProcedureKind, BuiltRouter}; +use rspc::{ + procedure::{Procedure, ProcedureInput, ProcedureKind}, + BuiltRouter, +}; +use serde_json::json; +/// Construct a new [`axum::Router`](axum::Router) to expose a given [`rspc::Router`](rspc::Router). pub struct Endpoint { router: BuiltRouter, - axum: axum::Router<()>, - ctx_fn: Arc TCtx + Send + Sync>, + endpoints: bool, + websocket: Option TCtx>, + batching: bool, } impl Endpoint { - // TODO: Async or `Result` return type for context function - pub fn new( + /// Construct a new [`axum::Router`](axum::Router) with all features enabled. + /// + /// This will enable all features, if you want to configure which features are enabled you can use [`Endpoint::builder`] instead. + /// + /// # Usage + /// + /// ```rust + /// axum::Router::new().nest( + /// "/rspc", + /// rspc_axum::Endpoint::new(rspc::Router::new().build().unwrap(), || ()), + /// ); + /// ``` + pub fn new( router: BuiltRouter, // TODO: Parse this to `Self::build` -> It will make rustfmt result way nicer ctx_fn: impl Fn() -> TCtx + Send + Sync + 'static, - ) -> Self { + ) -> axum::Router + where + S: Clone + Send + Sync + 'static, + // TODO: Error type??? + // F: Future> + Send + Sync + 'static, + TCtx: Clone, + { + Self::builder(router) + .with_endpoints() + .with_websocket() + .with_batching() + .build(ctx_fn) + } + + /// Construct a new [`Endpoint`](Endpoint) with no features enabled. + /// + /// # Usage + /// + /// ```rust + /// axum::Router::new().nest( + /// "/rspc", + /// rspc_axum::Endpoint::builder(rspc::Router::new().build().unwrap()) + /// // Exposes HTTP endpoints for queries and mutations. + /// .with_endpoints() + /// // Exposes a Websocket connection for queries, mutations and subscriptions. + /// .with_websocket() + /// // Enables support for the frontend sending batched queries. + /// .with_batching() + /// .build(|| ()), + /// ); + /// ``` + pub fn builder(router: BuiltRouter) -> Self { Self { router, - axum: axum::Router::new(), - ctx_fn: Arc::new(ctx_fn), + endpoints: false, + websocket: None, + batching: false, } } - // TODO: What to call this??? + /// Enables HTTP endpoints for queries and mutations. + /// + /// This is exposed as `/routerName.procedureName` pub fn with_endpoints(mut self) -> Self { - for (key, procedure) in &self.router.procedures { - let ctx_fn = self.ctx_fn.clone(); - let procedure = procedure.clone(); - self.axum = match procedure.kind() { - ProcedureKind::Query => self.axum.route( - &format!("/{}", key), - // TODO: By moving `procedure` into the closure we hang onto the types for the duration of the program which is probs undesirable. - get(move |query: Query>| async move { - let ctx = (ctx_fn)(); - - let mut stream = procedure - .exec( - ctx, - &mut serde_json::Deserializer::from_str( - query.get("input").map(|v| &**v).unwrap_or("null"), - ), - ) - .map_err(|err| { - // TODO: Error code by matching off `InternalError` - (StatusCode::INTERNAL_SERVER_ERROR, Json(err.to_string())) - .into_response() - })?; - - // TODO: Support for streaming - while let Some(value) = stream.next().await { - return match value.map(|v| v.serialize(serde_json::value::Serializer)) { - Ok(Ok(value)) => Ok(Json(value)), - Ok(Err(err)) => { - // TODO: Error code by matching off `InternalError` - Err((StatusCode::INTERNAL_SERVER_ERROR, Json(err.to_string())) - .into_response()) - } - Err(err) => panic!("{err:?}"), // TODO: Error handling -> How to serialize `TError`??? -> Should this be done in procedure? - }; - } - - Ok::<_, Response>(Json(serde_json::Value::Null)) - }), - ), - ProcedureKind::Mutation => self.axum.route( - &format!("/{}", key), - // TODO: By moving `procedure` into the closure we hang onto the types for the duration of the program which is probs undesirable. - post(move |body: Bytes| async move { - let ctx = (ctx_fn)(); - - let mut stream = procedure - .exec(ctx, &mut serde_json::Deserializer::from_slice(&body)) - .map_err(|err| { - // TODO: Error code by matching off `InternalError` - (StatusCode::INTERNAL_SERVER_ERROR, Json(err.to_string())) - .into_response() - })?; - - // TODO: Support for streaming - while let Some(value) = stream.next().await { - return match value.map(|v| v.serialize(serde_json::value::Serializer)) { - Ok(Ok(value)) => Ok(Json(value)), - Ok(Err(err)) => { - // TODO: Error code by matching off `InternalError` - Err((StatusCode::INTERNAL_SERVER_ERROR, Json(err.to_string())) - .into_response()) - } - Err(err) => panic!("{err:?}"), // TODO: Error handling -> How to serialize `TError`??? -> Should this be done in procedure? - }; - } - - Ok::<_, Response>(Json(serde_json::Value::Null)) - }), - ), - ProcedureKind::Subscription => continue, - }; + Self { + endpoints: true, + ..self } - - self } - // TODO: Put behind feature flag + /// Exposes a Websocket connection for queries, mutations and subscriptions. + /// + /// This is exposed as a `/ws` endpoint. + #[cfg(feature = "ws")] + #[cfg_attr(docsrs, doc(cfg(feature = "ws")))] pub fn with_websocket(self) -> Self where TCtx: Clone, { Self { - axum: self.axum.route( - "/ws", - get(|| async move { - // TODO: Support for websockets - "this is rspc websocket" - }), - ), + websocket: Some(|ctx| ctx.clone()), ..self } } + /// Enables support for the frontend sending batched queries. + /// + /// This is exposed as a `/_batch` endpoint. pub fn with_batching(self) -> Self where TCtx: Clone, { - // TODO: Support for batching & stream batching + Self { + batching: true, + ..self + } + } + + // TODO: Async or `Result` return type for context function + /// Build an [`axum::Router`](axum::Router) with the configured features. + pub fn build(self, ctx_fn: impl Fn() -> TCtx + Send + Sync + 'static) -> axum::Router + where + S: Clone + Send + Sync + 'static, + { + let mut r = axum::Router::new(); + let ctx_fn = Arc::new(ctx_fn); - self + if self.endpoints { + for (key, procedure) in &self.router.procedures { + let ctx_fn = ctx_fn.clone(); + let procedure = procedure.clone(); + r = match procedure.kind() { + ProcedureKind::Query => { + r.route( + &format!("/{}", key), + // TODO: By moving `procedure` into the closure we hang onto the types for the duration of the program which is probs undesirable. + get( + move |query: Query>, + headers: HeaderMap| async move { + let ctx = (ctx_fn)(); + + handle_procedure( + ctx, + &mut serde_json::Deserializer::from_str( + query.get("input").map(|v| &**v).unwrap_or("null"), + ), + headers, + procedure, + ) + .await + }, + ), + ) + } + ProcedureKind::Mutation => r.route( + &format!("/{}", key), + // TODO: By moving `procedure` into the closure we hang onto the types for the duration of the program which is probs undesirable. + post(move |headers: HeaderMap, body: Bytes| async move { + let ctx = (ctx_fn)(); + + handle_procedure( + ctx, + &mut serde_json::Deserializer::from_slice(&body), + headers, + procedure, + ) + .await + }), + ), + ProcedureKind::Subscription => continue, + }; + } + } + + #[cfg(feature = "ws")] + if let Some(clone_ctx) = self.websocket { + use axum::extract::ws::WebSocketUpgrade; + r = r.route( + "/ws", + get(move |ws: WebSocketUpgrade| async move { + let ctx = (ctx_fn)(); + + ws.on_upgrade(move |socket| async move { + todo!(); + + // while let Some(msg) = socket.recv().await {} + }) + }), + ); + } + + if self.batching { + // TODO: Support for batching & stream batching + + // todo!(); + } + + r.with_state(()) } +} + +// Used for `GET` and `POST` endpoints +// TODO: We should probs deserialize into buffer instead of value all over this function!!!! +async fn handle_procedure<'de, TCtx>( + ctx: TCtx, + input: impl ProcedureInput<'de>, + headers: HeaderMap, + procedure: Procedure, +) -> Result, (StatusCode, Json)> { + let is_legacy_client = headers.get("x-rspc").is_none(); + + let mut stream = procedure.exec(ctx, input).map_err(|err| { + if is_legacy_client { + ( + StatusCode::OK, + Json(json!({ + "jsonrpc":"2.0", + "id":null, + "result":{ + "type":"error", + "data": { + "code": 500, + "message": err.to_string(), + "data": null + } + } + })), + ) + } else { + ( + StatusCode::BAD_REQUEST, + Json(json!({ + "_rspc_error": err.to_string() + })), + ) + } + })?; + + if is_legacy_client { + let value = match stream.next().await { + Some(value) => { + if stream.next().await.is_some() { + println!("Streaming was attempted with a legacy rspc client! Ensure your not using `rspc::Stream` unless your clients are up to date."); + } + + value + } + None => { + return Ok(Json(json!({ + "jsonrpc": "2.0", + "id": null, + "result": { + "type": "response", + "data": "todo" + } + }))); + } + }; + + return Ok(Json(serde_json::json!({ + "jsonrpc": "2.0", + "id": null, + "result": match value + .map_err(|err| (err.status(), err.to_string(), err.serialize(serde_json::value::Serializer).unwrap_or_default())) + .and_then(|v| v.serialize(serde_json::value::Serializer).map_err(|err| (500, err.to_string(), serde_json::Value::Null))) { + Ok(value) => { + json!({ + "type": "response", + "data": value, + }) + } + Err((status, message, data)) => { + json!({ + "type": "error", + "data": { + "code": status, + "message": message, + // `data` was technically always `null` in legacy rspc but we'll include it how it was intended. + "data": data, + } + }) + } + } + }))); + } else { + // TODO: Support for streaming + while let Some(value) = stream.next().await { + return match value.map(|v| v.serialize(serde_json::value::Serializer)) { + Ok(Ok(value)) => Ok(Json(value)), + Ok(Err(err)) => Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "_rspc_error": err.to_string() + })), + )), + Err(err) => Err(( + StatusCode::from_u16(err.status()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), + Json( + err.serialize(serde_json::value::Serializer) + .map_err(|err| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "_rspc_error": err.to_string() + })), + ) + })?, + ), + )), + }; + } - pub fn build(self) -> axum::Router { - self.axum.with_state(()) + Ok(Json(serde_json::Value::Null)) } } diff --git a/integrations/axum/src/extractors.rs b/integrations/axum/src/extractors.rs deleted file mode 100644 index 8b137891..00000000 --- a/integrations/axum/src/extractors.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/integrations/axum/src/file.rs b/integrations/axum/src/file.rs index c933d3ff..42f6d299 100644 --- a/integrations/axum/src/file.rs +++ b/integrations/axum/src/file.rs @@ -10,10 +10,11 @@ use rspc::{ use tokio::io::AsyncWrite; // TODO: Clone, Debug, etc +#[doc(hidden)] // TODO: Finish this pub struct File>>(pub T); impl ResolverOutput for File { - fn data_type(type_map: &mut TypeMap) -> DataType { + fn data_type(_: &mut TypeMap) -> DataType { DataType::Any // TODO } @@ -33,7 +34,7 @@ impl<'de, F: AsyncWrite + Send + 'static> ProcedureInput<'de> for File { } impl ResolverInput for File { - fn data_type(type_map: &mut TypeMap) -> DataType { + fn data_type(_: &mut TypeMap) -> DataType { DataType::Any // TODO } diff --git a/integrations/axum/src/lib.rs b/integrations/axum/src/lib.rs index eae7c0e9..291f6abc 100644 --- a/integrations/axum/src/lib.rs +++ b/integrations/axum/src/lib.rs @@ -1,13 +1,55 @@ -//! rspc-axum: Axum integration for [rspc](https://rspc.dev). +//! Expose your [rspc](https://rspc.dev) application as an HTTP and/or WebSocket API using [Axum](https://github.com/tokio-rs/axum). +//! +//! # Example +//! +//! To get started you can copy the following example and run it with `cargo run`. +//! +//! ```rust +//! use axum::{ +//! routing::get, +//! Router, +//! }; +//! +//! #[tokio::main] +//! async fn main() { +//! let router = rspc::Router::new().build().unwrap(); +//! +//! let app = Router::new() +//! .route("/", get(|| async { "Hello, World!" })) +//! .nest( +//! "/rspc", +//! rspc_axum::Endpoint::new(router.clone(), || ()), +//! ) +//! +//! // run our app with hyper, listening globally on port 3000 +//! let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); +//! axum::serve(listener, app).await.unwrap(); +//! } +//! ``` +//! +//! Note: You must enable the `ws` feature to use WebSockets. +//! +//! # Features +//! +//! You can enable any of the following features to enable additional functionality: +//! +//! - `ws`: Support for WebSockets. +//! - `file`: Support for serving files. +//! #![cfg_attr(docsrs, feature(doc_cfg))] #![doc( html_logo_url = "https://github.com/oscartbeaumont/rspc/raw/main/docs/public/logo.png", html_favicon_url = "https://github.com/oscartbeaumont/rspc/raw/main/docs/public/logo.png" )] +mod ctx_fn; mod endpoint; -mod extractors; +#[cfg(feature = "file")] +#[cfg_attr(docsrs, doc(cfg(feature = "file")))] mod file; pub use endpoint::Endpoint; + +#[cfg(feature = "file")] +#[cfg_attr(docsrs, doc(cfg(feature = "file")))] pub use file::File; diff --git a/middleware/openapi/src/lib.rs b/middleware/openapi/src/lib.rs index d6ab5334..a5d877cb 100644 --- a/middleware/openapi/src/lib.rs +++ b/middleware/openapi/src/lib.rs @@ -5,7 +5,7 @@ html_favicon_url = "https://github.com/oscartbeaumont/rspc/raw/main/docs/public/logo.png" )] -use std::{borrow::Cow, collections::HashMap, hash::Hash, sync::Arc}; +use std::{borrow::Cow, collections::HashMap, sync::Arc}; use axum::{ body::Bytes, diff --git a/rspc/src/error.rs b/rspc/src/error.rs new file mode 100644 index 00000000..2ff445b6 --- /dev/null +++ b/rspc/src/error.rs @@ -0,0 +1,11 @@ +use std::error; + +use serde::Serialize; +use specta::Type; + +pub trait Error: error::Error + Send + Serialize + Type + 'static { + // Warning: Returning > 400 will fallback to `500`. As redirects would be invalid and `200` would break matching. + fn status(&self) -> u16 { + 500 + } +} diff --git a/rspc/src/lib.rs b/rspc/src/lib.rs index 2babe4a9..71fafc89 100644 --- a/rspc/src/lib.rs +++ b/rspc/src/lib.rs @@ -13,14 +13,15 @@ #![cfg_attr(docsrs, feature(doc_cfg))] pub mod middleware; -pub mod notes; pub mod procedure; +mod error; mod infallible; mod router; mod state; mod stream; +pub use error::Error; pub use infallible::Infallible; pub use router::{BuiltRouter, Router}; pub use state::State; diff --git a/rspc/src/notes.rs b/rspc/src/notes.rs deleted file mode 100644 index 93ca46c1..00000000 --- a/rspc/src/notes.rs +++ /dev/null @@ -1,27 +0,0 @@ -//! This is a temporary module i'm using to store notes until v1. -//! -//! ## Rust language limitations: -//! - Support for Serde zero-copy deserialization -//! - We need a way to express `where F: Fn(..., I<'_>), I<'a>: Input<'a>` which to my best knowledge is impossible. -//! -//! ## More work needed: -//! - Should `Middleware::setup` return a `Result`? Probs aye? -//! - Specta more safely -//! - [`ResolverOutput`] & [`ResolverInput`] should probs ensure the value returned and the Specta type match -//! - That being said for `dyn Any` that could prove annoying so maybe a `Untyped` escape hatch??? -//! - handling of result types is less efficient that it could be -//! - If it can only return one type can be do 'the downcast/deserialization magic internally. -//! - new Rust diagnostic stuff on all the traits -//! - Handle panics within procedures -//! - `ResolverOutput` trait oddities -//! - `rspc::Stream` within a `rspc::Stream` will panic at runtime -//! - Am I happy with `Output::into_procedure_stream`? It's low key cringe but it might be fine. -//! - `ProcedureInput` vs `ResolverInput` typesafety -//! - You can implement forget to implement `ProcedureInput` for an `ResolverInput` type. -//! - For non-serde types really `ResolverInput` and `ProcedureInput` are the same, can we express that? Probs not without specialization and markers. -//! - Support for Cloudflare Workers/single-threaded async runtimes. I recall this being problematic with `Send + Sync`. -//! - Review all generics on middleware and procedure types to ensure consistent ordering. -//! - Consistency between `TErr` and `TError` -//! - Documentation for everything -//! - Yank all v1 releases once 0.3.0 is out -//! diff --git a/rspc/src/procedure.rs b/rspc/src/procedure.rs index b95acfbd..2204cc4a 100644 --- a/rspc/src/procedure.rs +++ b/rspc/src/procedure.rs @@ -26,7 +26,7 @@ mod resolver_output; mod stream; pub use builder::ProcedureBuilder; -pub use error::InternalError; +pub use error::{InternalError, ResolverError}; pub use exec_input::ProcedureExecInput; pub use input::ProcedureInput; pub use meta::{ProcedureKind, ProcedureMeta}; diff --git a/rspc/src/procedure/builder.rs b/rspc/src/procedure/builder.rs index 116964a1..fe096656 100644 --- a/rspc/src/procedure/builder.rs +++ b/rspc/src/procedure/builder.rs @@ -1,10 +1,10 @@ -use std::{error, fmt, future::Future}; +use std::{fmt, future::Future}; use futures::FutureExt; use crate::{ middleware::{Middleware, MiddlewareHandler}, - State, + Error, State, }; use super::{ProcedureKind, ProcedureMeta, UnbuiltProcedure}; @@ -31,7 +31,7 @@ impl fmt::Debug impl ProcedureBuilder where - TError: error::Error + Send + 'static, + TError: Error, TRootCtx: 'static, TCtx: 'static, TInput: 'static, @@ -92,7 +92,7 @@ where impl ProcedureBuilder> where - TError: error::Error + Send + 'static, + TError: Error, TRootCtx: 'static, TCtx: 'static, TInput: 'static, diff --git a/rspc/src/procedure/error.rs b/rspc/src/procedure/error.rs index 35863cd7..96628de6 100644 --- a/rspc/src/procedure/error.rs +++ b/rspc/src/procedure/error.rs @@ -1,4 +1,13 @@ -use std::{error, fmt}; +use std::{ + any::{type_name, Any, TypeId}, + error, fmt, +}; + +use serde::{Serialize, Serializer}; + +use crate::Error; + +use super::ProcedureOutputSerializeError; pub enum InternalError { /// Attempted to deserialize input but found downcastable input. @@ -33,3 +42,88 @@ impl fmt::Display for InternalError { } impl error::Error for InternalError {} + +trait ErasedError: error::Error + erased_serde::Serialize + Any + Send + 'static { + fn to_box_any(self: Box) -> Box; + + fn to_value(&self) -> Option>; +} +impl ErasedError for T { + fn to_box_any(self: Box) -> Box { + self + } + + fn to_value(&self) -> Option> { + Some(serde_value::to_value(self)) + } +} + +pub struct ResolverError { + status: u16, + type_name: &'static str, + type_id: TypeId, + inner: Box, +} + +impl ResolverError { + pub fn new(value: T) -> Self { + Self { + status: value.status(), + type_name: type_name::(), + type_id: TypeId::of::(), + inner: Box::new(value), + } + } + + pub fn status(&self) -> u16 { + if self.status > 400 || self.status < 600 { + return 500; + } + + self.status + } + + pub fn type_name(&self) -> &'static str { + self.type_name + } + + pub fn type_id(&self) -> TypeId { + self.type_id + } + + pub fn downcast(self) -> Option { + self.inner.to_box_any().downcast().map(|v| *v).ok() + } + + // TODO: Using `ProcedureOutputSerializeError`???? + pub fn serialize( + self, + ser: S, + ) -> Result> { + let value = self + .inner + .to_value() + .ok_or(ProcedureOutputSerializeError::ErrResultNotDeserializable( + self.type_name, + ))? + .expect("serde_value doesn't panic"); // TODO: This is false + + value + .serialize(ser) + .map_err(ProcedureOutputSerializeError::ErrSerializer) + } +} + +impl fmt::Debug for ResolverError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self.inner) + } +} + +impl fmt::Display for ResolverError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.inner) + } +} + +impl error::Error for ResolverError {} diff --git a/rspc/src/procedure/meta.rs b/rspc/src/procedure/meta.rs index 70de7a90..efc757c6 100644 --- a/rspc/src/procedure/meta.rs +++ b/rspc/src/procedure/meta.rs @@ -1,5 +1,5 @@ use core::fmt; -use std::{borrow::Cow, collections::HashMap, sync::Arc}; +use std::{borrow::Cow, sync::Arc}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, specta::Type)] #[specta(rename_all = "camelCase")] diff --git a/rspc/src/procedure/output.rs b/rspc/src/procedure/output.rs index e5212a49..05de5e71 100644 --- a/rspc/src/procedure/output.rs +++ b/rspc/src/procedure/output.rs @@ -36,6 +36,15 @@ pub struct ProcedureOutput { inner: Box, } +impl fmt::Debug for ProcedureOutput { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ProcedureOutput") + .field("type_name", &self.type_name) + .field("type_id", &self.type_id) + .finish() + } +} + impl ProcedureOutput { pub fn new(value: T) -> Self { Self { @@ -75,7 +84,7 @@ impl ProcedureOutput { .ok_or(ProcedureOutputSerializeError::ErrResultNotDeserializable( self.type_name, ))? - .expect("serde_value doesn't panic"); + .expect("serde_value doesn't panic"); // TODO: This is false value .serialize(ser) diff --git a/rspc/src/procedure/procedure.rs b/rspc/src/procedure/procedure.rs index 2e9ae8d1..af7e83bd 100644 --- a/rspc/src/procedure/procedure.rs +++ b/rspc/src/procedure/procedure.rs @@ -1,9 +1,9 @@ -use std::{borrow::Cow, error, fmt, sync::Arc}; +use std::{borrow::Cow, fmt, sync::Arc}; use futures::FutureExt; use specta::{DataType, TypeMap}; -use crate::State; +use crate::{Error, State}; use super::{ exec_input::{AnyInput, InputValueInner}, @@ -53,7 +53,7 @@ where /// Construct a new procedure using [`ProcedureBuilder`]. pub fn builder() -> ProcedureBuilder where - TError: error::Error + Send + 'static, + TError: Error, // Only the first layer (middleware or the procedure) needs to be a valid input/output type I: ResolverInput, R: ResolverOutput, diff --git a/rspc/src/procedure/resolver_output.rs b/rspc/src/procedure/resolver_output.rs index c94313d2..66fde067 100644 --- a/rspc/src/procedure/resolver_output.rs +++ b/rspc/src/procedure/resolver_output.rs @@ -1,9 +1,9 @@ -use std::error; - use futures::{stream::once, Stream, StreamExt}; use serde::Serialize; use specta::{DataType, Generics, Type, TypeMap}; +use crate::Error; + use super::{ProcedureOutput, ProcedureStream}; /// A type which can be returned from a procedure. @@ -17,6 +17,7 @@ use super::{ProcedureOutput, ProcedureStream}; /// For each value the [`Self::into_procedure_stream`] implementation **must** defer to [`Self::into_procedure_result`] to convert the value into a [`ProcedureOutput`]. rspc provides a default implementation that takes care of this for you so don't override it unless you have a good reason. /// /// ## Implementation for custom types +/// /// ```rust /// pub struct MyCoolThing(pub String); /// @@ -43,7 +44,7 @@ pub trait ResolverOutput: Sized + Send + 'static { procedure: impl Stream> + Send + 'static, ) -> ProcedureStream where - TError: error::Error + Send + 'static, + TError: Error, { ProcedureStream::from_stream(procedure.map(|v| v?.into_procedure_result())) } @@ -58,7 +59,7 @@ pub trait ResolverOutput: Sized + Send + 'static { impl ResolverOutput for T where T: Serialize + Type + Send + 'static, - TError: error::Error + Send + 'static, + TError: Error, { fn data_type(type_map: &mut TypeMap) -> DataType { T::inline(type_map, Generics::Definition) @@ -83,7 +84,7 @@ where procedure: impl Stream> + Send + 'static, ) -> ProcedureStream where - TErr: error::Error + Send + 'static, + TErr: Error, { ProcedureStream::from_stream( procedure diff --git a/rspc/src/procedure/stream.rs b/rspc/src/procedure/stream.rs index 1ed3b192..5a80f8a7 100644 --- a/rspc/src/procedure/stream.rs +++ b/rspc/src/procedure/stream.rs @@ -7,20 +7,14 @@ use std::{ use futures::{Stream, TryFutureExt, TryStreamExt}; -use super::output::ProcedureOutput; +use crate::Error; -type BoxError = Box; -fn box_error(err: T) -> BoxError -where - T: error::Error + Send + 'static, -{ - Box::new(err) -} +use super::{output::ProcedureOutput, ResolverError}; enum Inner { - Value(Result), - Future(Pin> + Send>>), - Stream(Pin> + Send>>), + Value(Result), + Future(Pin> + Send>>), + Stream(Pin> + Send>>), } pub struct ProcedureStream(Option); @@ -28,30 +22,34 @@ pub struct ProcedureStream(Option); impl ProcedureStream { pub fn from_value(value: Result) -> Self where - TError: error::Error + Send + 'static, + TError: Error, { - Self(Some(Inner::Value(value.map_err(box_error)))) + Self(Some(Inner::Value(value.map_err(ResolverError::new)))) } pub fn from_future(future: F) -> Self where F: Future> + Send + 'static, - TError: error::Error + Send + 'static, + TError: Error, { - Self(Some(Inner::Future(Box::pin(future.map_err(box_error))))) + Self(Some(Inner::Future(Box::pin( + future.map_err(ResolverError::new), + )))) } pub fn from_stream(stream: S) -> Self where S: Stream> + Send + 'static, - TError: error::Error + Send + 'static, + TError: Error, { - Self(Some(Inner::Stream(Box::pin(stream.map_err(box_error))))) + Self(Some(Inner::Stream(Box::pin( + stream.map_err(ResolverError::new), + )))) } } impl Stream for ProcedureStream { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.0.as_mut() {