diff --git a/src/core/async_graphql_hyper.rs b/src/core/async_graphql_hyper.rs index 5518ea5da7..4b9df5e488 100644 --- a/src/core/async_graphql_hyper.rs +++ b/src/core/async_graphql_hyper.rs @@ -1,9 +1,10 @@ -use std::any::Any; +use std::collections::HashMap; use std::hash::{Hash, Hasher}; use anyhow::Result; -use async_graphql::parser::types::{ExecutableDocument, OperationType}; -use async_graphql::{BatchResponse, Executor, Value}; +use async_graphql::parser::types::ExecutableDocument; +use async_graphql::{BatchResponse, Value}; +use async_graphql_value::ConstValue; use http::header::{HeaderMap, HeaderValue, CACHE_CONTROL, CONTENT_TYPE}; use http::{Response, StatusCode}; use hyper::Body; @@ -13,32 +14,17 @@ use tailcall_hasher::TailcallHasher; use super::jit::{BatchResponse as JITBatchResponse, JITExecutor}; +// TODO: replace usage with some other implementation. +// This one is used to calculate hash and use the value later +// as a key in the HashMap. But such use could lead to potential +// issues in case of hash collisions #[derive(PartialEq, Eq, Clone, Hash, Debug)] pub struct OperationId(u64); #[async_trait::async_trait] pub trait GraphQLRequestLike: Hash + Send { - fn data(self, data: D) -> Self; - async fn execute(self, executor: &E) -> GraphQLResponse - where - E: Executor; - async fn execute_with_jit(self, executor: JITExecutor) -> GraphQLArcResponse; - fn parse_query(&mut self) -> Option<&ExecutableDocument>; - - fn is_query(&mut self) -> bool { - self.parse_query() - .map(|a| { - let mut is_query = false; - for (_, operation) in a.operations.iter() { - is_query = operation.node.ty == OperationType::Query; - } - is_query - }) - .unwrap_or(false) - } - fn operation_id(&self, headers: &HeaderMap) -> OperationId { let mut hasher = TailcallHasher::default(); let state = &mut hasher; @@ -51,86 +37,101 @@ pub trait GraphQLRequestLike: Hash + Send { } } -#[derive(Debug, Deserialize)] -pub struct GraphQLBatchRequest(pub async_graphql::BatchRequest); -impl GraphQLBatchRequest {} -impl Hash for GraphQLBatchRequest { - //TODO: Fix Hash implementation for BatchRequest, which should ideally batch - // execution of individual requests instead of the whole chunk of requests as - // one. - fn hash(&self, state: &mut H) { - for request in self.0.iter() { - request.query.hash(state); - request.operation_name.hash(state); - for (name, value) in request.variables.iter() { - name.hash(state); - value.to_string().hash(state); - } - } - } +#[derive(Debug, Hash, Serialize, Deserialize)] +#[serde(untagged)] +pub enum BatchWrapper { + Single(T), + Batch(Vec), } -#[async_trait::async_trait] -impl GraphQLRequestLike for GraphQLBatchRequest { - fn data(mut self, data: D) -> Self { - for request in self.0.iter_mut() { - request.data.insert(data.clone()); - } - self - } - - async fn execute_with_jit(self, executor: JITExecutor) -> GraphQLArcResponse { - GraphQLArcResponse::new(executor.execute_batch(self.0).await) - } - /// Shortcut method to execute the request on the executor. - async fn execute(self, executor: &E) -> GraphQLResponse - where - E: Executor, - { - GraphQLResponse(executor.execute_batch(self.0).await) - } +pub type GraphQLBatchRequest = BatchWrapper; - fn parse_query(&mut self) -> Option<&ExecutableDocument> { - None +#[async_trait::async_trait] +impl GraphQLRequestLike for BatchWrapper { + async fn execute_with_jit(self, executor: JITExecutor) -> GraphQLArcResponse { + GraphQLArcResponse::new(executor.execute_batch(self).await) } } -#[derive(Debug, Deserialize)] -pub struct GraphQLRequest(pub async_graphql::Request); +#[derive(Debug, Default, Deserialize, Clone)] +#[serde(rename_all = "camelCase")] +pub struct GraphQLRequest { + #[serde(default)] + pub query: String, + #[serde(default)] + pub operation_name: Option, + #[serde(default)] + pub variables: HashMap, + #[serde(default)] + pub extensions: HashMap, +} -impl GraphQLRequest {} impl Hash for GraphQLRequest { fn hash(&self, state: &mut H) { - self.0.query.hash(state); - self.0.operation_name.hash(state); - for (name, value) in self.0.variables.iter() { + self.query.hash(state); + self.operation_name.hash(state); + for (name, value) in self.variables.iter() { name.hash(state); value.to_string().hash(state); } } } + +impl GraphQLRequest { + pub fn new(query: impl Into) -> Self { + Self { query: query.into(), ..Default::default() } + } +} + #[async_trait::async_trait] impl GraphQLRequestLike for GraphQLRequest { - #[must_use] - fn data(mut self, data: D) -> Self { - self.0.data.insert(data); - self - } async fn execute_with_jit(self, executor: JITExecutor) -> GraphQLArcResponse { - let response = executor.execute(self.0).await; + let response = executor.execute(self).await; GraphQLArcResponse::new(JITBatchResponse::Single(response)) } +} + +#[derive(Debug)] +pub struct ParsedGraphQLRequest { + pub query: String, + pub operation_name: Option, + pub variables: HashMap, + pub extensions: HashMap, + pub parsed_query: ExecutableDocument, +} - /// Shortcut method to execute the request on the schema. - async fn execute(self, executor: &E) -> GraphQLResponse - where - E: Executor, - { - GraphQLResponse(executor.execute(self.0).await.into()) +impl Hash for ParsedGraphQLRequest { + fn hash(&self, state: &mut H) { + self.query.hash(state); + self.operation_name.hash(state); + for (name, value) in self.variables.iter() { + name.hash(state); + value.to_string().hash(state); + } } +} + +impl TryFrom for ParsedGraphQLRequest { + type Error = async_graphql::parser::Error; - fn parse_query(&mut self) -> Option<&ExecutableDocument> { - self.0.parsed_query().ok() + fn try_from(req: GraphQLRequest) -> std::result::Result { + let parsed_query = async_graphql::parser::parse_query(&req.query)?; + + Ok(Self { + query: req.query, + operation_name: req.operation_name, + variables: req.variables, + extensions: req.extensions, + parsed_query, + }) + } +} + +#[async_trait::async_trait] +impl GraphQLRequestLike for ParsedGraphQLRequest { + async fn execute_with_jit(self, executor: JITExecutor) -> GraphQLArcResponse { + let response = executor.execute(self).await; + GraphQLArcResponse::new(JITBatchResponse::Single(response)) } } @@ -148,42 +149,6 @@ impl From for GraphQLResponse { } } -impl From for GraphQLRequest { - fn from(query: GraphQLQuery) -> Self { - let mut request = async_graphql::Request::new(query.query); - - if let Some(operation_name) = query.operation_name { - request = request.operation_name(operation_name); - } - - if let Some(variables) = query.variables { - let value = serde_json::from_str(&variables).unwrap_or_default(); - let variables = async_graphql::Variables::from_json(value); - request = request.variables(variables); - } - - GraphQLRequest(request) - } -} - -#[derive(Debug)] -pub struct GraphQLQuery { - query: String, - operation_name: Option, - variables: Option, -} - -impl GraphQLQuery { - /// Shortcut method to execute the request on the schema. - pub async fn execute(self, executor: &E) -> GraphQLResponse - where - E: Executor, - { - let request: GraphQLRequest = self.into(); - request.execute(executor).await - } -} - static APPLICATION_JSON: Lazy = Lazy::new(|| HeaderValue::from_static("application/json")); @@ -408,6 +373,17 @@ impl GraphQLArcResponse { pub fn into_response(self) -> Result> { self.build_response(StatusCode::OK, self.default_body()?) } + + /// Transforms a plain `GraphQLResponse` into a `Response`. + /// Differs as `to_response` by flattening the response's data + /// `{"data": {"user": {"name": "John"}}}` becomes `{"name": "John"}`. + pub fn into_rest_response(self) -> Result> { + if !self.response.is_ok() { + return self.build_response(StatusCode::INTERNAL_SERVER_ERROR, self.default_body()?); + } + + self.into_response() + } } #[cfg(test)] diff --git a/src/core/http/request_handler.rs b/src/core/http/request_handler.rs index e7ab5efae2..aa6886d129 100644 --- a/src/core/http/request_handler.rs +++ b/src/core/http/request_handler.rs @@ -241,20 +241,23 @@ async fn handle_rest_apis( *request.uri_mut() = request.uri().path().replace(API_URL_PREFIX, "").parse()?; let req_ctx = Arc::new(create_request_context(&request, app_ctx.as_ref())); if let Some(p_request) = app_ctx.endpoints.matches(&request) { + let (req, body) = request.into_parts(); let http_route = format!("{API_URL_PREFIX}{}", p_request.path.as_str()); req_counter.set_http_route(&http_route); let span = tracing::info_span!( "REST", - otel.name = format!("REST {} {}", request.method(), p_request.path.as_str()), + otel.name = format!("REST {} {}", req.method, p_request.path.as_str()), otel.kind = ?SpanKind::Server, - { HTTP_REQUEST_METHOD } = %request.method(), + { HTTP_REQUEST_METHOD } = %req.method, { HTTP_ROUTE } = http_route ); return async { - let graphql_request = p_request.into_request(request).await?; + let graphql_request = p_request.into_request(body).await?; + let operation_id = graphql_request.operation_id(&req.headers); + let exec = JITExecutor::new(app_ctx.clone(), req_ctx.clone(), operation_id) + .flatten_response(true); let mut response = graphql_request - .data(req_ctx.clone()) - .execute(&app_ctx.schema) + .execute_with_jit(exec) .await .set_cache_control( app_ctx.blueprint.server.enable_cache_control_header, diff --git a/src/core/jit/error.rs b/src/core/jit/error.rs index 086b604502..25a6c388ee 100644 --- a/src/core/jit/error.rs +++ b/src/core/jit/error.rs @@ -56,6 +56,12 @@ pub enum Error { Unknown, } +impl From for Error { + fn from(value: async_graphql::ServerError) -> Self { + Self::ServerError(value) + } +} + impl ErrorExtensions for Error { fn extend(&self) -> super::graphql_error::Error { match self { diff --git a/src/core/jit/exec_const.rs b/src/core/jit/exec_const.rs index 1606f05631..29990f2f88 100644 --- a/src/core/jit/exec_const.rs +++ b/src/core/jit/exec_const.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use async_graphql_value::{ConstValue, Value}; +use derive_setters::Setters; use futures_util::future::join_all; use tailcall_valid::Validator; @@ -14,17 +15,20 @@ use crate::core::ir::model::IR; use crate::core::ir::{self, EmptyResolverContext, EvalContext}; use crate::core::jit::synth::Synth; use crate::core::jit::transform::InputResolver; -use crate::core::json::{JsonLike, JsonLikeList}; +use crate::core::json::{JsonLike, JsonLikeList, JsonObjectLike}; use crate::core::Transform; /// A specialized executor that executes with async_graphql::Value +#[derive(Setters)] pub struct ConstValueExecutor { pub plan: OperationPlan, + + flatten_response: bool, } impl From> for ConstValueExecutor { fn from(plan: OperationPlan) -> Self { - Self { plan } + Self { plan, flatten_response: false } } } @@ -56,6 +60,7 @@ impl ConstValueExecutor { let is_introspection_query = req_ctx.server.get_enable_introspection() && self.plan.is_introspection_query; + let flatten_response = self.flatten_response; let variables = &request.variables; // Attempt to skip unnecessary fields @@ -102,13 +107,45 @@ impl ConstValueExecutor { let async_req = async_graphql::Request::from(request).only_introspection(); let async_resp = app_ctx.execute(async_req).await; - resp.merge_with(&async_resp).into() + to_any_response(resp.merge_with(&async_resp), flatten_response) } else { - resp.into() + to_any_response(resp, flatten_response) } } } +fn to_any_response( + resp: Response, + flatten: bool, +) -> AnyResponse> { + if flatten { + if resp.errors.is_empty() { + AnyResponse { + body: Arc::new( + serde_json::to_vec(flatten_response(&resp.data)).unwrap_or_default(), + ), + is_ok: true, + cache_control: resp.cache_control, + } + } else { + AnyResponse { + body: Arc::new(serde_json::to_vec(&resp).unwrap_or_default()), + is_ok: false, + cache_control: resp.cache_control, + } + } + } else { + resp.into() + } +} + +fn flatten_response<'a, T: JsonLike<'a>>(data: &'a T) -> &'a T { + match data.as_object() { + Some(obj) if obj.len() == 1 => flatten_response(obj.iter().next().unwrap().1), + _ => data, + } +} + struct ConstValueExec<'a> { plan: &'a OperationPlan, req_context: &'a RequestContext, diff --git a/src/core/jit/graphql_error.rs b/src/core/jit/graphql_error.rs index 6d2e0a7132..ade6b4af61 100644 --- a/src/core/jit/graphql_error.rs +++ b/src/core/jit/graphql_error.rs @@ -53,6 +53,10 @@ impl From> for GraphQLError { return e.into(); } + if let super::Error::ServerError(e) = inner_value { + return e.into(); + } + let ext = inner_value.extend().extensions; let mut server_error = GraphQLError::new(inner_value.to_string(), Some(position)); server_error.extensions = ext; diff --git a/src/core/jit/graphql_executor.rs b/src/core/jit/graphql_executor.rs index e1c131c374..4450de81a4 100644 --- a/src/core/jit/graphql_executor.rs +++ b/src/core/jit/graphql_executor.rs @@ -1,25 +1,26 @@ use std::collections::BTreeMap; -use std::future::Future; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use async_graphql::{BatchRequest, Value}; +use async_graphql::Value; use async_graphql_value::{ConstValue, Extensions}; +use derive_setters::Setters; use futures_util::stream::FuturesOrdered; use futures_util::StreamExt; use tailcall_hasher::TailcallHasher; use super::{AnyResponse, BatchResponse, Response}; use crate::core::app_context::AppContext; -use crate::core::async_graphql_hyper::OperationId; +use crate::core::async_graphql_hyper::{BatchWrapper, GraphQLRequest, OperationId}; use crate::core::http::RequestContext; -use crate::core::jit::{self, ConstValueExecutor, OPHash, Pos, Positioned}; +use crate::core::jit::{self, ConstValueExecutor, OPHash}; -#[derive(Clone)] +#[derive(Clone, Setters)] pub struct JITExecutor { app_ctx: Arc, req_ctx: Arc, operation_id: OperationId, + flatten_response: bool, } impl JITExecutor { @@ -28,7 +29,7 @@ impl JITExecutor { req_ctx: Arc, operation_id: OperationId, ) -> Self { - Self { app_ctx, req_ctx, operation_id } + Self { app_ctx, req_ctx, operation_id, flatten_response: false } } #[inline(always)] @@ -61,72 +62,85 @@ impl JITExecutor { out.unwrap_or_default() } + /// Calculates hash for the request considering + /// the request is const, i.e. doesn't depend on input. + /// That's basically use only the query itself to calculating the hash #[inline(always)] - fn req_hash(request: &async_graphql::Request) -> OPHash { - let mut hasher = TailcallHasher::default(); - request.query.hash(&mut hasher); + fn const_execution_hash(request: &jit::Request) -> OPHash { + let hasher = &mut TailcallHasher::default(); + + request.query.hash(hasher); OPHash::new(hasher.finish()) } } impl JITExecutor { - pub fn execute( - &self, - request: async_graphql::Request, - ) -> impl Future>> + Send + '_ { - // TODO: hash considering only the query itself ignoring specified operation and - // variables that could differ for the same query - let hash = Self::req_hash(&request); - - async move { - if let Some(response) = self.app_ctx.const_execution_cache.get(&hash) { - return response.clone(); - } - - let jit_request = jit::Request::from(request); - let exec = if let Some(op) = self.app_ctx.operation_plans.get(&hash) { - ConstValueExecutor::from(op.value().clone()) - } else { - let exec = match ConstValueExecutor::try_new(&jit_request, &self.app_ctx) { - Ok(exec) => exec, - Err(error) => { - return Response::::default() - .with_errors(vec![Positioned::new(error, Pos::default())]) - .into() - } - }; - self.app_ctx - .operation_plans - .insert(hash.clone(), exec.plan.clone()); - exec - }; - - let is_const = exec.plan.is_const; - let is_protected = exec.plan.is_protected; - - let response = if exec.plan.can_dedupe() { - self.dedupe_and_exec(exec, jit_request).await - } else { - self.exec(exec, jit_request).await + pub async fn execute(&self, request: T) -> AnyResponse> + where + jit::Request: TryFrom, + T: Hash + Send + 'static, + { + let jit_request = match jit::Request::try_from(request) { + Ok(request) => request, + Err(error) => return Response::::from(error).into(), + }; + + let const_execution_hash = Self::const_execution_hash(&jit_request); + + // check if the request is has been set to const_execution_cache + // and if yes serve the response from the cache since + // the query doesn't depend on input and could be calculated once + // WARN: make sure the value is set to cache only if the plan is actually + // is_const + if let Some(response) = self + .app_ctx + .const_execution_cache + .get(&const_execution_hash) + { + return response.clone(); + } + let exec = if let Some(op) = self.app_ctx.operation_plans.get(&const_execution_hash) { + ConstValueExecutor::from(op.value().clone()) + } else { + let exec = match ConstValueExecutor::try_new(&jit_request, &self.app_ctx) { + Ok(exec) => exec, + Err(error) => return Response::::from(error).into(), }; - - // Cache the response if it's constant and not wrapped with protected. - if is_const && !is_protected { - self.app_ctx - .const_execution_cache - .insert(hash, response.clone()); - } - - response + self.app_ctx + .operation_plans + .insert(const_execution_hash.clone(), exec.plan.clone()); + exec + }; + + let exec = exec.flatten_response(self.flatten_response); + let is_const = exec.plan.is_const; + let is_protected = exec.plan.is_protected; + + let response = if exec.plan.can_dedupe() { + self.dedupe_and_exec(exec, jit_request).await + } else { + self.exec(exec, jit_request).await + }; + + // Cache the response if it's constant and not wrapped with protected. + if is_const && !is_protected { + self.app_ctx + .const_execution_cache + .insert(const_execution_hash, response.clone()); } + + response } /// Execute a GraphQL batch query. - pub async fn execute_batch(&self, batch_request: BatchRequest) -> BatchResponse> { + pub async fn execute_batch( + &self, + batch_request: BatchWrapper, + ) -> BatchResponse> { match batch_request { - BatchRequest::Single(request) => BatchResponse::Single(self.execute(request).await), - BatchRequest::Batch(requests) => { + BatchWrapper::Single(request) => BatchResponse::Single(self.execute(request).await), + BatchWrapper::Batch(requests) => { let futs = FuturesOrdered::from_iter( requests.into_iter().map(|request| self.execute(request)), ); diff --git a/src/core/jit/model.rs b/src/core/jit/model.rs index 9b2950a22a..65fe69dff5 100644 --- a/src/core/jit/model.rs +++ b/src/core/jit/model.rs @@ -20,6 +20,12 @@ use crate::core::scalar::Scalar; #[derive(Debug, Deserialize, Clone)] pub struct Variables(HashMap); +impl From> for Variables { + fn from(value: HashMap) -> Self { + Self(value) + } +} + impl PathString for Variables { fn path_string<'a, T: AsRef>(&'a self, path: &'a [T]) -> Option> { self.get(path[0].as_ref()) @@ -293,6 +299,10 @@ impl Debug for Field { } } +// TODO: replace usage with some other implementation. +// This one is used to calculate hash and use the value later +// as a key in the HashMap. But such use could lead to potential +// issues in case of hash collisions #[derive(Clone, Debug, Hash, PartialEq, Eq)] pub struct OPHash(u64); @@ -575,10 +585,10 @@ impl From for Positioned { #[cfg(test)] mod test { use async_graphql::parser::types::ConstDirective; - use async_graphql::Request; use async_graphql_value::ConstValue; use super::{Directive, OperationPlan}; + use crate::core::async_graphql_hyper::GraphQLRequest; use crate::core::blueprint::Blueprint; use crate::core::config::ConfigModule; use crate::core::jit; @@ -589,8 +599,8 @@ mod test { let module = ConfigModule::from(config); let bp = Blueprint::try_from(&module).unwrap(); - let request = Request::new(query); - let jit_request = jit::Request::from(request); + let request = GraphQLRequest::new(query); + let jit_request = jit::Request::try_from(request).unwrap(); jit_request.create_plan(&bp).unwrap() } diff --git a/src/core/jit/request.rs b/src/core/jit/request.rs index f37c4a721e..261d98e53c 100644 --- a/src/core/jit/request.rs +++ b/src/core/jit/request.rs @@ -1,38 +1,44 @@ use std::collections::HashMap; -use std::ops::DerefMut; +use async_graphql::parser::types::ExecutableDocument; use async_graphql_value::ConstValue; -use serde::Deserialize; use tailcall_valid::Validator; use super::{transform, Builder, OperationPlan, Result, Variables}; +use crate::core::async_graphql_hyper::{GraphQLRequest, ParsedGraphQLRequest}; use crate::core::blueprint::Blueprint; use crate::core::transform::TransformerOps; use crate::core::Transform; -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Clone)] pub struct Request { - #[serde(default)] pub query: String, - #[serde(default, rename = "operationName")] pub operation_name: Option, - #[serde(default)] pub variables: Variables, - #[serde(default)] pub extensions: HashMap, + pub parsed_query: ExecutableDocument, } -// NOTE: This is hot code and should allocate minimal memory -impl From for Request { - fn from(mut value: async_graphql::Request) -> Self { - let variables = std::mem::take(value.variables.deref_mut()); +impl TryFrom for Request { + type Error = super::Error; - Self { + fn try_from(value: GraphQLRequest) -> Result { + let value = ParsedGraphQLRequest::try_from(value)?; + + Self::try_from(value) + } +} + +impl TryFrom for Request { + type Error = super::Error; + fn try_from(value: ParsedGraphQLRequest) -> Result { + Ok(Self { + parsed_query: value.parsed_query, query: value.query, operation_name: value.operation_name, - variables: Variables::from_iter(variables.into_iter().map(|(k, v)| (k.to_string(), v))), - extensions: value.extensions.0, - } + variables: Variables::from(value.variables), + extensions: value.extensions, + }) } } @@ -41,8 +47,7 @@ impl Request { &self, blueprint: &Blueprint, ) -> Result> { - let doc = async_graphql::parser::parse_query(&self.query)?; - let builder = Builder::new(blueprint, &doc); + let builder = Builder::new(blueprint, &self.parsed_query); let plan = builder.build(self.operation_name.as_deref())?; transform::CheckConst::new() @@ -67,6 +72,7 @@ impl Request { operation_name: None, variables: Variables::new(), extensions: HashMap::new(), + parsed_query: async_graphql::parser::parse_query(query).unwrap(), } } diff --git a/src/core/jit/response.rs b/src/core/jit/response.rs index aabe67dd65..d12cb5bfda 100644 --- a/src/core/jit/response.rs +++ b/src/core/jit/response.rs @@ -4,7 +4,7 @@ use derive_setters::Setters; use serde::Serialize; use super::graphql_error::GraphQLError; -use super::Positioned; +use super::{Pos, Positioned}; use crate::core::async_graphql_hyper::CacheControl; use crate::core::jit; use crate::core::json::{JsonLike, JsonObjectLike}; @@ -33,6 +33,12 @@ impl Default for Response { } } +impl From for Response { + fn from(value: jit::Error) -> Self { + Response::default().with_errors(vec![Positioned::new(value, Pos::default())]) + } +} + impl Response { pub fn new(result: Result>) -> Self { match result { diff --git a/src/core/rest/endpoint.rs b/src/core/rest/endpoint.rs index 2234a9f885..8daf9ca31b 100644 --- a/src/core/rest/endpoint.rs +++ b/src/core/rest/endpoint.rs @@ -11,7 +11,6 @@ use super::path::{Path, Segment}; use super::query_params::QueryParams; use super::type_map::TypeMap; use super::{Request, Result}; -use crate::core::async_graphql_hyper::GraphQLRequest; use crate::core::directive::DirectiveCodec; use crate::core::http::Method; use crate::core::rest::typed_variables::{UrlParamType, N}; @@ -83,11 +82,11 @@ impl Endpoint { Ok(endpoints) } - pub fn into_request(self) -> GraphQLRequest { + pub fn into_request(self) -> async_graphql::Request { let variables = Self::get_default_variables(&self); let mut req = async_graphql::Request::new("").variables(variables); req.set_parsed_query(Self::remove_rest_directives(self.doc)); - GraphQLRequest(req) + req } fn get_default_variables(endpoint: &Endpoint) -> Variables { diff --git a/src/core/rest/operation.rs b/src/core/rest/operation.rs index ea65a9f89a..31598de4db 100644 --- a/src/core/rest/operation.rs +++ b/src/core/rest/operation.rs @@ -4,24 +4,26 @@ use async_graphql::dynamic::Schema; use tailcall_valid::{Cause, Valid, Validator}; use super::{Error, Result}; -use crate::core::async_graphql_hyper::{GraphQLRequest, GraphQLRequestLike}; use crate::core::blueprint::{Blueprint, SchemaModifiers}; use crate::core::http::RequestContext; #[derive(Debug)] pub struct OperationQuery { - query: GraphQLRequest, + query: async_graphql::Request, } impl OperationQuery { - pub fn new(query: GraphQLRequest, request_context: Arc) -> Result { + pub fn new( + query: async_graphql::Request, + request_context: Arc, + ) -> Result { let query = query.data(request_context); Ok(Self { query }) } async fn validate(self, schema: &Schema) -> Vec { schema - .execute(self.query.0) + .execute(self.query) .await .errors .iter() diff --git a/src/core/rest/partial_request.rs b/src/core/rest/partial_request.rs index cfd8053483..d6838e9af0 100644 --- a/src/core/rest/partial_request.rs +++ b/src/core/rest/partial_request.rs @@ -1,10 +1,14 @@ +use std::collections::HashMap; +use std::ops::DerefMut; + use async_graphql::parser::types::ExecutableDocument; -use async_graphql::{Name, Variables}; +use async_graphql::Variables; use async_graphql_value::ConstValue; +use hyper::Body; use super::path::Path; -use super::{Request, Result}; -use crate::core::async_graphql_hyper::GraphQLRequest; +use super::Result; +use crate::core::async_graphql_hyper::ParsedGraphQLRequest; /// A partial GraphQLRequest that contains a parsed executable GraphQL document. #[derive(Debug)] @@ -16,17 +20,26 @@ pub struct PartialRequest<'a> { } impl PartialRequest<'_> { - pub async fn into_request(self, request: Request) -> Result { - let mut variables = self.variables; + pub async fn into_request(mut self, body: Body) -> Result { + let variables = std::mem::take(self.variables.deref_mut()); + let mut variables = + HashMap::from_iter(variables.into_iter().map(|(k, v)| (k.to_string(), v))); + if let Some(key) = self.body { - let bytes = hyper::body::to_bytes(request.into_body()).await?; + let bytes = hyper::body::to_bytes(body).await?; let body: ConstValue = serde_json::from_slice(&bytes)?; - variables.insert(Name::new(key), body); + variables.insert(key.to_string(), body); } - let mut req = async_graphql::Request::new("").variables(variables); - req.set_parsed_query(self.doc.clone()); - - Ok(GraphQLRequest(req)) + Ok(ParsedGraphQLRequest { + // use path as query because query is used as part of the hashing + // and we need to have different hashed for different operations + // TODO: is there any way to make it more explicit here? + query: self.path.as_str().to_string(), + operation_name: None, + variables, + extensions: Default::default(), + parsed_query: self.doc.clone(), + }) } }