From fca6f7e7754b37845e1b87f19a5ce9fe785c6ec4 Mon Sep 17 00:00:00 2001 From: Tushar Mathur Date: Wed, 31 Jan 2024 20:15:48 +0530 Subject: [PATCH] feat: add grpc support for wasm (#1041) Co-authored-by: Sandipsinh Rathod Co-authored-by: Sandipsinh Rathod <62684960+ssddOnTop@users.noreply.github.com> --- Cargo.lock | 2 + Cargo.toml | 1 + cloudflare/Cargo.toml | 1 + cloudflare/package.json | 2 +- cloudflare/src/file.rs | 1 + cloudflare/src/handle.rs | 5 +- cloudflare/src/http.rs | 6 +- cloudflare/wrangler.toml | 2 +- src/blueprint/definitions.rs | 65 ++--- src/blueprint/from_config.rs | 15 +- src/blueprint/mod.rs | 4 +- src/blueprint/operators/const_field.rs | 23 +- src/blueprint/operators/expr.rs | 28 +-- src/blueprint/operators/graphql.rs | 6 +- src/blueprint/operators/grpc.rs | 41 ++-- src/blueprint/operators/http.rs | 24 +- src/blueprint/operators/modify.rs | 6 +- src/blueprint/upstream.rs | 10 +- src/cli/mod.rs | 1 + src/cli/server/server.rs | 15 +- src/cli/tc.rs | 19 +- src/config/config_set.rs | 54 +++++ src/config/mod.rs | 2 + src/config/reader.rs | 226 ++++++++++++++++-- src/grpc/data_loader_request.rs | 38 ++- src/grpc/protobuf.rs | 95 ++++---- src/grpc/request_template.rs | 45 +++- src/grpc/tests/cycle.proto | 4 + src/grpc/tests/duplicate.proto | 4 + src/grpc/tests/nested0.proto | 4 + src/grpc/tests/nested1.proto | 4 + src/lib.rs | 2 +- .../errors/test-grpc-proto-path.graphql | 2 +- tests/graphql_spec.rs | 47 ++-- tests/http_spec.rs | 7 +- tests/server_spec.rs | 6 +- 36 files changed, 580 insertions(+), 237 deletions(-) create mode 100644 src/config/config_set.rs create mode 100644 src/grpc/tests/cycle.proto create mode 100644 src/grpc/tests/duplicate.proto create mode 100644 src/grpc/tests/nested0.proto create mode 100644 src/grpc/tests/nested1.proto diff --git a/Cargo.lock b/Cargo.lock index 24b9068f63..b91914e5c2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -807,6 +807,7 @@ dependencies = [ "hyper", "lazy_static", "log", + "protox", "reqwest", "serde_json", "serde_qs", @@ -3472,6 +3473,7 @@ dependencies = [ "prost", "prost-reflect", "protox", + "protox-parse", "regex", "reqwest", "reqwest-middleware", diff --git a/Cargo.toml b/Cargo.toml index 131e07b109..117365c389 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -75,6 +75,7 @@ async-std = { version = "1.12.0", features = [ ] } ttl_cache = "0.5.1" protox = "0.5.1" +protox-parse = "0.5.0" prost-reflect = { version = "0.12.0", features = ["serde"] } prost = "0.12.3" update-informer = { version = "1.1.0", default-features = false, features = ["github", "reqwest"], optional = true } diff --git a/cloudflare/Cargo.toml b/cloudflare/Cargo.toml index 3046358619..b3bd0309df 100644 --- a/cloudflare/Cargo.toml +++ b/cloudflare/Cargo.toml @@ -23,6 +23,7 @@ async-graphql-value = "7.0.1" serde_json = "1.0.113" serde_qs = "0.12.0" console_error_panic_hook = "0.1.7" +protox = "0.5.1" [profile.release] lto = true diff --git a/cloudflare/package.json b/cloudflare/package.json index b070bff83a..403b38589a 100644 --- a/cloudflare/package.json +++ b/cloudflare/package.json @@ -5,7 +5,7 @@ "scripts": { "deploy": "npx wrangler deploy --minify", "publish": "npx wrangler publish --minify", - "dev": "npx wrangler dev --port 19194 --remote", + "dev": "npx wrangler dev --port 19194", "test": "cargo install -q worker-build && worker-build && vitest --run" }, "devDependencies": { diff --git a/cloudflare/src/file.rs b/cloudflare/src/file.rs index c7381bfee7..a8f29cfd89 100644 --- a/cloudflare/src/file.rs +++ b/cloudflare/src/file.rs @@ -24,6 +24,7 @@ impl 
CloudflareFileIO { // TODO: avoid the unsafe impl unsafe impl Sync for CloudflareFileIO {} +unsafe impl Send for CloudflareFileIO {} async fn get(bucket: Rc, path: String) -> anyhow::Result { let maybe_object = bucket diff --git a/cloudflare/src/handle.rs b/cloudflare/src/handle.rs index 1fe0357438..2f9401e196 100644 --- a/cloudflare/src/handle.rs +++ b/cloudflare/src/handle.rs @@ -7,7 +7,7 @@ use lazy_static::lazy_static; use tailcall::async_graphql_hyper::GraphQLRequest; use tailcall::blueprint::Blueprint; use tailcall::config::reader::ConfigReader; -use tailcall::config::Config; +use tailcall::config::ConfigSet; use tailcall::http::{graphiql, handle_request, AppContext}; use tailcall::EnvIO; @@ -56,13 +56,14 @@ async fn get_config( env_io: Arc, env: Rc, file_path: &str, -) -> anyhow::Result { +) -> anyhow::Result { let bucket_id = env_io .get("BUCKET") .ok_or(anyhow!("BUCKET var is not set"))?; log::debug!("R2 Bucket ID: {}", bucket_id); let file_io = init_file(env.clone(), bucket_id)?; let http_io = init_http(); + let reader = ConfigReader::init(file_io, http_io); let config = reader.read(&file_path).await?; Ok(config) diff --git a/cloudflare/src/http.rs b/cloudflare/src/http.rs index b7d362d6da..e488297d45 100644 --- a/cloudflare/src/http.rs +++ b/cloudflare/src/http.rs @@ -44,9 +44,7 @@ impl HttpIO for CloudflareHttp { } } -pub async fn to_response( - response: hyper::Response, -) -> anyhow::Result { +pub async fn to_response(response: hyper::Response) -> Result { let status = response.status().as_u16(); let headers = response.headers().clone(); let bytes = hyper::body::to_bytes(response).await?; @@ -80,7 +78,7 @@ pub fn to_method(method: worker::Method) -> Result { } } -pub async fn to_request(mut req: worker::Request) -> anyhow::Result> { +pub async fn to_request(mut req: worker::Request) -> Result> { let body = req.text().await.map_err(to_anyhow)?; let method = req.method(); let uri = req.url().map_err(to_anyhow)?.as_str().to_string(); diff --git a/cloudflare/wrangler.toml b/cloudflare/wrangler.toml index fcdf32e304..0c5bf0a3b8 100644 --- a/cloudflare/wrangler.toml +++ b/cloudflare/wrangler.toml @@ -6,7 +6,7 @@ compatibility_date = "2023-03-22" account_id = "59eda2a637301830ad43a6e3e4419346" [build] -command = "cargo install -q worker-build && worker-build --release" +command = "cargo install -q worker-build && worker-build" # the path to config must start with the binding name of respective r2 bucket. 
[vars] diff --git a/src/blueprint/definitions.rs b/src/blueprint/definitions.rs index e3d1965c60..88bb304270 100644 --- a/src/blueprint/definitions.rs +++ b/src/blueprint/definitions.rs @@ -69,7 +69,7 @@ struct ProcessFieldWithinTypeContext<'a> { remaining_path: &'a [String], type_info: &'a config::Type, is_required: bool, - config: &'a Config, + config_set: &'a ConfigSet, invalid_path_handler: &'a InvalidPathHandler, path_resolver_error_handler: &'a PathResolverErrorHandler, original_path: &'a [String], @@ -81,7 +81,7 @@ struct ProcessPathContext<'a> { field: &'a config::Field, type_info: &'a config::Type, is_required: bool, - config: &'a Config, + config_set: &'a ConfigSet, invalid_path_handler: &'a InvalidPathHandler, path_resolver_error_handler: &'a PathResolverErrorHandler, original_path: &'a [String], @@ -93,7 +93,7 @@ fn process_field_within_type(context: ProcessFieldWithinTypeContext) -> Valid Valid Valid Valid Valid Valid { let field = context.field; let type_info = context.type_info; let is_required = context.is_required; - let config = context.config; + let config_set = context.config_set; let invalid_path_handler = context.invalid_path_handler; let path_resolver_error_handler = context.path_resolver_error_handler; if let Some((field_name, remaining_path)) = path.split_first() { @@ -193,7 +193,7 @@ fn process_path(context: ProcessPathContext) -> Valid { let mut modified_field = field.clone(); modified_field.list = false; return process_path(ProcessPathContext { - config, + config_set, type_info, invalid_path_handler, path_resolver_error_handler, @@ -207,7 +207,7 @@ fn process_path(context: ProcessPathContext) -> Valid { .fields .get(field_name) .map(|_| type_info) - .or_else(|| config.find_type(&field.type_of)); + .or_else(|| config_set.find_type(&field.type_of)); if let Some(type_info) = target_type_info { return process_field_within_type(ProcessFieldWithinTypeContext { @@ -216,7 +216,7 @@ fn process_path(context: ProcessPathContext) -> Valid { remaining_path, type_info, is_required, - config, + config_set, invalid_path_handler, path_resolver_error_handler, original_path: context.original_path, @@ -252,9 +252,9 @@ fn to_enum_type_definition( fn to_object_type_definition( name: &str, type_of: &config::Type, - config: &Config, + config_set: &ConfigSet, ) -> Valid { - to_fields(name, type_of, config).map(|fields| { + to_fields(name, type_of, config_set).map(|fields| { Definition::ObjectTypeDefinition(ObjectTypeDefinition { name: name.to_string(), description: type_of.doc.clone(), @@ -266,8 +266,8 @@ fn to_object_type_definition( fn update_args<'a>( hasher: DefaultHasher, -) -> TryFold<'a, (&'a Config, &'a Field, &'a config::Type, &'a str), FieldDefinition, String> { - TryFold::<(&Config, &Field, &config::Type, &str), FieldDefinition, String>::new( +) -> TryFold<'a, (&'a ConfigSet, &'a Field, &'a config::Type, &'a str), FieldDefinition, String> { + TryFold::<(&ConfigSet, &Field, &config::Type, &str), FieldDefinition, String>::new( move |(_, field, typ, name), _| { let mut hasher = hasher.clone(); name.hash(&mut hasher); @@ -333,8 +333,8 @@ fn update_resolver_from_path( /// To solve the problem that by default such fields will be resolved to null value /// and nested resolvers won't be called pub fn update_nested_resolvers<'a>( -) -> TryFold<'a, (&'a Config, &'a Field, &'a config::Type, &'a str), FieldDefinition, String> { - TryFold::<(&Config, &Field, &config::Type, &str), FieldDefinition, String>::new( +) -> TryFold<'a, (&'a ConfigSet, &'a Field, &'a config::Type, &'a str), 
FieldDefinition, String> { + TryFold::<(&ConfigSet, &Field, &config::Type, &str), FieldDefinition, String>::new( move |(config, field, _, name), mut b_field| { if !field.has_resolver() && validate_field_has_resolver(name, field, &config.types).is_succeed() @@ -361,9 +361,9 @@ fn validate_field_type_exist(config: &Config, field: &Field) -> Valid<(), String fn to_fields( object_name: &str, type_of: &config::Type, - config: &Config, + config_set: &ConfigSet, ) -> Valid, String> { - let operation_type = if config.schema.mutation.as_deref().eq(&Some(object_name)) { + let operation_type = if config_set.schema.mutation.as_deref().eq(&Some(object_name)) { GraphQLOperationType::Mutation } else { GraphQLOperationType::Query @@ -390,7 +390,10 @@ fn to_fields( .and(update_expr(&operation_type).trace(config::Expr::trace_name().as_str())) .and(update_modify().trace(config::Modify::trace_name().as_str())) .and(update_nested_resolvers()) - .try_fold(&(config, field, type_of, name), FieldDefinition::default()) + .try_fold( + &(config_set, field, type_of, name), + FieldDefinition::default(), + ) }; // Process fields that are not marked as `omit` @@ -400,7 +403,7 @@ fn to_fields( .iter() .filter(|(_, field)| !field.is_omitted()), |(name, field)| { - validate_field_type_exist(config, field) + validate_field_type_exist(config_set, field) .and(to_field(name, field)) .trace(name) }, @@ -455,7 +458,7 @@ fn to_fields( field: source_field, type_info: type_of, is_required: false, - config, + config_set, invalid_path_handler: &invalid_path_handler, path_resolver_error_handler: &path_resolver_error_handler, original_path: &add_field.path, @@ -481,11 +484,11 @@ fn to_fields( }) } -pub fn to_definitions<'a>() -> TryFold<'a, Config, Vec, String> { - TryFold::, String>::new(|config, _| { - let output_types = config.output_types(); - let input_types = config.input_types(); - Valid::from_iter(config.types.iter(), |(name, type_)| { +pub fn to_definitions<'a>() -> TryFold<'a, ConfigSet, Vec, String> { + TryFold::, String>::new(|config_set, _| { + let output_types = config_set.output_types(); + let input_types = config_set.input_types(); + Valid::from_iter(config_set.types.iter(), |(name, type_)| { let dbl_usage = input_types.contains(name) && output_types.contains(name); if let Some(variants) = &type_.variants { if !variants.is_empty() { @@ -498,11 +501,11 @@ pub fn to_definitions<'a>() -> TryFold<'a, Config, Vec, String> { } else if dbl_usage { Valid::fail("type is used in input and output".to_string()).trace(name) } else { - to_object_type_definition(name, type_, config) + to_object_type_definition(name, type_, config_set) .trace(name) .and_then(|definition| match definition.clone() { Definition::ObjectTypeDefinition(object_type_definition) => { - if config.input_types().contains(name) { + if config_set.input_types().contains(name) { to_input_object_type_definition(object_type_definition).trace(name) } else if type_.interface { to_interface_type_definition(object_type_definition).trace(name) @@ -515,7 +518,7 @@ pub fn to_definitions<'a>() -> TryFold<'a, Config, Vec, String> { } }) .map(|mut types| { - types.extend(config.unions.iter().map(to_union_type_definition)); + types.extend(config_set.unions.iter().map(to_union_type_definition)); types }) }) diff --git a/src/blueprint/from_config.rs b/src/blueprint/from_config.rs index 2f7d80ff6b..76aca5c138 100644 --- a/src/blueprint/from_config.rs +++ b/src/blueprint/from_config.rs @@ -3,15 +3,16 @@ use std::collections::{BTreeMap, HashMap}; use super::{Server, TypeLike}; use 
crate::blueprint::compress::compress; use crate::blueprint::*; -use crate::config::{Arg, Batch, Config, Field}; +use crate::config::{Arg, Batch, Config, ConfigSet, Field}; use crate::json::JsonSchema; use crate::lambda::{Expression, IO}; use crate::try_fold::TryFold; use crate::valid::{Valid, ValidationError}; -pub fn config_blueprint<'a>() -> TryFold<'a, Config, Blueprint, String> { - let server = TryFoldConfig::::new(|config, blueprint| { - Valid::from(Server::try_from(config.server.clone())).map(|server| blueprint.server(server)) +pub fn config_blueprint<'a>() -> TryFold<'a, ConfigSet, Blueprint, String> { + let server = TryFoldConfig::::new(|config_set, blueprint| { + Valid::from(Server::try_from(config_set.server.clone())) + .map(|server| blueprint.server(server)) }); let schema = to_schema().transform::( @@ -105,12 +106,12 @@ where } } -impl TryFrom<&Config> for Blueprint { +impl TryFrom<&ConfigSet> for Blueprint { type Error = ValidationError; - fn try_from(config: &Config) -> Result { + fn try_from(config_set: &ConfigSet) -> Result { config_blueprint() - .try_fold(config, Blueprint::default()) + .try_fold(config_set, Blueprint::default()) .to_result() } } diff --git a/src/blueprint/mod.rs b/src/blueprint/mod.rs index 43c37389ae..b3f2d30988 100644 --- a/src/blueprint/mod.rs +++ b/src/blueprint/mod.rs @@ -20,10 +20,10 @@ pub use server::*; pub use timeout::GlobalTimeout; pub use upstream::*; -use crate::config::{Arg, Config, Field}; +use crate::config::{Arg, ConfigSet, Field}; use crate::try_fold::TryFold; -pub type TryFoldConfig<'a, A> = TryFold<'a, Config, A, String>; +pub type TryFoldConfig<'a, A> = TryFold<'a, ConfigSet, A, String>; pub(crate) trait TypeLike { fn name(&self) -> &str; diff --git a/src/blueprint/operators/const_field.rs b/src/blueprint/operators/const_field.rs index e5c31a1ea2..605107ef7d 100644 --- a/src/blueprint/operators/const_field.rs +++ b/src/blueprint/operators/const_field.rs @@ -2,7 +2,7 @@ use async_graphql_value::ConstValue; use crate::blueprint::*; use crate::config; -use crate::config::{Config, Field}; +use crate::config::Field; use crate::lambda::Expression; use crate::lambda::Expression::Literal; use crate::try_fold::TryFold; @@ -23,14 +23,14 @@ fn validate_data_with_schema( } pub struct CompileConst<'a> { - pub config: &'a config::Config, + pub config_set: &'a config::ConfigSet, pub field: &'a config::Field, pub value: &'a serde_json::Value, pub validate: bool, } pub fn compile_const(inputs: CompileConst) -> Valid { - let config = inputs.config; + let config_set = inputs.config_set; let field = inputs.field; let value = inputs.value; let validate = inputs.validate; @@ -39,7 +39,7 @@ pub fn compile_const(inputs: CompileConst) -> Valid { match ConstValue::from_json(data.to_owned()) { Ok(gql) => { let validation = if validate { - validate_data_with_schema(config, field, gql) + validate_data_with_schema(config_set, field, gql) } else { Valid::succeed(()) }; @@ -50,15 +50,20 @@ pub fn compile_const(inputs: CompileConst) -> Valid { } pub fn update_const_field<'a>( -) -> TryFold<'a, (&'a Config, &'a Field, &'a config::Type, &'a str), FieldDefinition, String> { - TryFold::<(&Config, &Field, &config::Type, &str), FieldDefinition, String>::new( - |(config, field, _, _), b_field| { +) -> TryFold<'a, (&'a ConfigSet, &'a Field, &'a config::Type, &'a str), FieldDefinition, String> { + TryFold::<(&ConfigSet, &Field, &config::Type, &str), FieldDefinition, String>::new( + |(config_set, field, _, _), b_field| { let Some(const_field) = &field.const_field else { return 
Valid::succeed(b_field); }; - compile_const(CompileConst { config, field, value: &const_field.data, validate: true }) - .map(|resolver| b_field.resolver(Some(resolver))) + compile_const(CompileConst { + config_set, + field, + value: &const_field.data, + validate: true, + }) + .map(|resolver| b_field.resolver(Some(resolver))) }, ) } diff --git a/src/blueprint/operators/expr.rs b/src/blueprint/operators/expr.rs index 4c8a88ccdf..986070e66f 100644 --- a/src/blueprint/operators/expr.rs +++ b/src/blueprint/operators/expr.rs @@ -1,6 +1,6 @@ use crate::blueprint::*; use crate::config; -use crate::config::{Config, ExprBody, Field, If}; +use crate::config::{ExprBody, Field, If}; use crate::lambda::{Expression, List, Logic, Math, Relation}; use crate::try_fold::TryFold; use crate::valid::Valid; @@ -8,19 +8,19 @@ use crate::valid::Valid; struct CompilationContext<'a> { config_field: &'a config::Field, operation_type: &'a config::GraphQLOperationType, - config: &'a config::Config, + config_set: &'a config::ConfigSet, } pub fn update_expr( operation_type: &config::GraphQLOperationType, -) -> TryFold<'_, (&Config, &Field, &config::Type, &str), FieldDefinition, String> { - TryFold::<(&Config, &Field, &config::Type, &str), FieldDefinition, String>::new( - |(config, field, _, _), b_field| { +) -> TryFold<'_, (&ConfigSet, &Field, &config::Type, &str), FieldDefinition, String> { + TryFold::<(&ConfigSet, &Field, &config::Type, &str), FieldDefinition, String>::new( + |(config_set, field, _, _), b_field| { let Some(expr) = &field.expr else { return Valid::succeed(b_field); }; - let context = CompilationContext { config, operation_type, config_field: field }; + let context = CompilationContext { config_set, operation_type, config_field: field }; compile(&context, expr.body.clone()).map(|compiled| b_field.resolver(Some(compiled))) }, @@ -51,15 +51,15 @@ fn compile_ab( /// Compiles expr into Expression /// fn compile(ctx: &CompilationContext, expr: ExprBody) -> Valid { - let config = ctx.config; + let config_set = ctx.config_set; let field = ctx.config_field; let operation_type = ctx.operation_type; match expr { // Io Expr - ExprBody::Http(http) => compile_http(config, field, &http), + ExprBody::Http(http) => compile_http(config_set, field, &http), ExprBody::Grpc(grpc) => { let grpc = CompileGrpc { - config, + config_set, field, operation_type, grpc: &grpc, @@ -67,11 +67,11 @@ fn compile(ctx: &CompilationContext, expr: ExprBody) -> Valid compile_graphql(config, operation_type, &gql), + ExprBody::GraphQL(gql) => compile_graphql(config_set, operation_type, &gql), // Safe Expr ExprBody::Const(value) => { - compile_const(CompileConst { config, field, value: &value, validate: false }) + compile_const(CompileConst { config_set, field, value: &value, validate: false }) } // Logic @@ -178,7 +178,7 @@ mod tests { use serde_json::{json, Number}; use super::{compile, CompilationContext}; - use crate::config::{Config, Expr, Field, GraphQLOperationType}; + use crate::config::{ConfigSet, Expr, Field, GraphQLOperationType}; use crate::http::RequestContext; use crate::lambda::{Concurrent, Eval, EvaluationContext, ResolverContextLike}; @@ -217,11 +217,11 @@ mod tests { impl Expr { async fn eval(expr: serde_json::Value) -> anyhow::Result { let expr = serde_json::from_value::(expr)?; - let config = Config::default(); + let config_set = ConfigSet::default(); let field = Field::default(); let operation_type = GraphQLOperationType::Query; let context = CompilationContext { - config: &config, + config_set: &config_set, config_field: 
&field, operation_type: &operation_type, }; diff --git a/src/blueprint/operators/graphql.rs b/src/blueprint/operators/graphql.rs index 4be5959e7d..12dce4109f 100644 --- a/src/blueprint/operators/graphql.rs +++ b/src/blueprint/operators/graphql.rs @@ -1,5 +1,5 @@ use crate::blueprint::FieldDefinition; -use crate::config::{self, Config, Field, GraphQLOperationType}; +use crate::config::{self, ConfigSet, Field, GraphQLOperationType}; use crate::graphql::RequestTemplate; use crate::helpers; use crate::lambda::{Expression, Lambda}; @@ -40,8 +40,8 @@ pub fn compile_graphql( pub fn update_graphql<'a>( operation_type: &'a GraphQLOperationType, -) -> TryFold<'a, (&'a Config, &'a Field, &'a config::Type, &'a str), FieldDefinition, String> { - TryFold::<(&Config, &Field, &config::Type, &'a str), FieldDefinition, String>::new( +) -> TryFold<'a, (&'a ConfigSet, &'a Field, &'a config::Type, &'a str), FieldDefinition, String> { + TryFold::<(&ConfigSet, &Field, &config::Type, &'a str), FieldDefinition, String>::new( |(config, field, type_of, _), b_field| { let Some(graphql) = &field.graphql else { return Valid::succeed(b_field); diff --git a/src/blueprint/operators/grpc.rs b/src/blueprint/operators/grpc.rs index f974d34d39..cccbcc3565 100644 --- a/src/blueprint/operators/grpc.rs +++ b/src/blueprint/operators/grpc.rs @@ -1,10 +1,9 @@ -use std::path::Path; - +use prost_reflect::prost_types::FileDescriptorSet; use prost_reflect::FieldDescriptor; use crate::blueprint::{FieldDefinition, TypeLike}; use crate::config::group_by::GroupBy; -use crate::config::{Config, Field, GraphQLOperationType, Grpc}; +use crate::config::{Config, ConfigSet, Field, GraphQLOperationType, Grpc}; use crate::grpc::protobuf::{ProtobufOperation, ProtobufSet}; use crate::grpc::request_template::RequestTemplate; use crate::json::JsonSchema; @@ -30,9 +29,12 @@ fn to_url(grpc: &Grpc, config: &Config) -> Valid { }) } -fn to_operation(grpc: &Grpc) -> Valid { +fn to_operation( + grpc: &Grpc, + file_descriptor_set: &FileDescriptorSet, +) -> Valid { Valid::from( - ProtobufSet::from_proto_file(Path::new(&grpc.proto_path)) + ProtobufSet::from_proto_file(file_descriptor_set) .map_err(|e| ValidationError::new(e.to_string())), ) .and_then(|set| { @@ -110,27 +112,30 @@ fn validate_group_by( } pub struct CompileGrpc<'a> { - pub config: &'a config::Config, - pub operation_type: &'a config::GraphQLOperationType, - pub field: &'a config::Field, - pub grpc: &'a config::Grpc, + pub config_set: &'a ConfigSet, + pub operation_type: &'a GraphQLOperationType, + pub field: &'a Field, + pub grpc: &'a Grpc, pub validate_with_schema: bool, } pub fn compile_grpc(inputs: CompileGrpc) -> Valid { - let config = inputs.config; + let config_set = inputs.config_set; let operation_type = inputs.operation_type; let field = inputs.field; let grpc = inputs.grpc; let validate_with_schema = inputs.validate_with_schema; - to_url(grpc, config) - .zip(to_operation(grpc)) + to_url(grpc, config_set) + .zip(to_operation( + grpc, + &config_set.extensions.grpc_file_descriptor, + )) .zip(helpers::headers::to_mustache_headers(&grpc.headers)) .zip(helpers::body::to_body(grpc.body.as_deref())) .and_then(|(((url, operation), headers), body)| { let validation = if validate_with_schema { - let field_schema = json_schema_from_field(config, field); + let field_schema = json_schema_from_field(config_set, field); if grpc.group_by.is_empty() { validate_schema(field_schema, &operation, field.name()).unit() } else { @@ -163,22 +168,22 @@ pub fn compile_grpc(inputs: CompileGrpc) -> Valid { pub fn 
update_grpc<'a>( operation_type: &'a GraphQLOperationType, -) -> TryFold<'a, (&'a Config, &'a Field, &'a config::Type, &'a str), FieldDefinition, String> { - TryFold::<(&Config, &Field, &config::Type, &'a str), FieldDefinition, String>::new( - |(config, field, type_of, _name), b_field| { +) -> TryFold<'a, (&'a ConfigSet, &'a Field, &'a config::Type, &'a str), FieldDefinition, String> { + TryFold::<(&ConfigSet, &Field, &config::Type, &'a str), FieldDefinition, String>::new( + |(config_set, field, type_of, _name), b_field| { let Some(grpc) = &field.grpc else { return Valid::succeed(b_field); }; compile_grpc(CompileGrpc { - config, + config_set, operation_type, field, grpc, validate_with_schema: true, }) .map(|resolver| b_field.resolver(Some(resolver))) - .and_then(|b_field| b_field.validate_field(type_of, config).map_to(b_field)) + .and_then(|b_field| b_field.validate_field(type_of, config_set).map_to(b_field)) }, ) } diff --git a/src/blueprint/operators/http.rs b/src/blueprint/operators/http.rs index ce0990548b..135b4a6c83 100644 --- a/src/blueprint/operators/http.rs +++ b/src/blueprint/operators/http.rs @@ -1,6 +1,6 @@ use crate::blueprint::*; use crate::config::group_by::GroupBy; -use crate::config::{Config, Field}; +use crate::config::Field; use crate::endpoint::Endpoint; use crate::http::{Method, RequestTemplate}; use crate::lambda::{Expression, Lambda, IO}; @@ -9,7 +9,7 @@ use crate::valid::{Valid, ValidationError}; use crate::{config, helpers}; pub fn compile_http( - config: &config::Config, + config_set: &config::ConfigSet, field: &config::Field, http: &config::Http, ) -> Valid { @@ -20,12 +20,14 @@ pub fn compile_http( "GroupBy can only be applied if batching is enabled".to_string(), ) .when(|| { - (config.upstream.get_delay() < 1 || config.upstream.get_max_size() < 1) + (config_set.upstream.get_delay() < 1 || config_set.upstream.get_max_size() < 1) && !http.group_by.is_empty() }), ) .and(Valid::from_option( - http.base_url.as_ref().or(config.upstream.base_url.as_ref()), + http.base_url + .as_ref() + .or(config_set.upstream.base_url.as_ref()), "No base URL defined".to_string(), )) .zip(helpers::headers::to_mustache_headers(&http.headers)) @@ -39,8 +41,8 @@ pub fn compile_http( .iter() .map(|(k, v)| (k.clone(), v.clone())) .collect(); - let output_schema = to_json_schema_for_field(field, config); - let input_schema = to_json_schema_for_args(&field.args, config); + let output_schema = to_json_schema_for_field(field, config_set); + let input_schema = to_json_schema_for_args(&field.args, config_set); RequestTemplate::try_from( Endpoint::new(base_url.to_string()) @@ -69,16 +71,16 @@ pub fn compile_http( } pub fn update_http<'a>( -) -> TryFold<'a, (&'a Config, &'a Field, &'a config::Type, &'a str), FieldDefinition, String> { - TryFold::<(&Config, &Field, &config::Type, &'a str), FieldDefinition, String>::new( - |(config, field, type_of, _), b_field| { +) -> TryFold<'a, (&'a ConfigSet, &'a Field, &'a config::Type, &'a str), FieldDefinition, String> { + TryFold::<(&ConfigSet, &Field, &config::Type, &'a str), FieldDefinition, String>::new( + |(config_set, field, type_of, _), b_field| { let Some(http) = &field.http else { return Valid::succeed(b_field); }; - compile_http(config, field, http) + compile_http(config_set, field, http) .map(|resolver| b_field.resolver(Some(resolver))) - .and_then(|b_field| b_field.validate_field(type_of, config).map_to(b_field)) + .and_then(|b_field| b_field.validate_field(type_of, config_set).map_to(b_field)) }, ) } diff --git a/src/blueprint/operators/modify.rs 
b/src/blueprint/operators/modify.rs index cdf8bcb708..f9f058ace2 100644 --- a/src/blueprint/operators/modify.rs +++ b/src/blueprint/operators/modify.rs @@ -1,13 +1,13 @@ use crate::blueprint::*; use crate::config; -use crate::config::{Config, Field}; +use crate::config::Field; use crate::lambda::Lambda; use crate::try_fold::TryFold; use crate::valid::Valid; pub fn update_modify<'a>( -) -> TryFold<'a, (&'a Config, &'a Field, &'a config::Type, &'a str), FieldDefinition, String> { - TryFold::<(&Config, &Field, &config::Type, &'a str), FieldDefinition, String>::new( +) -> TryFold<'a, (&'a ConfigSet, &'a Field, &'a config::Type, &'a str), FieldDefinition, String> { + TryFold::<(&ConfigSet, &Field, &config::Type, &'a str), FieldDefinition, String>::new( |(config, field, type_of, _), mut b_field| { if let Some(modify) = field.modify.as_ref() { if let Some(new_name) = &modify.name { diff --git a/src/blueprint/upstream.rs b/src/blueprint/upstream.rs index 88bec02d34..cd3264fcaf 100644 --- a/src/blueprint/upstream.rs +++ b/src/blueprint/upstream.rs @@ -1,11 +1,13 @@ +use std::ops::Deref; + use super::TryFoldConfig; -use crate::config::{Config, Upstream}; +use crate::config::{ConfigSet, Upstream}; use crate::try_fold::TryFold; use crate::valid::{Valid, ValidationError}; -pub fn to_upstream<'a>() -> TryFold<'a, Config, Upstream, String> { - TryFoldConfig::::new(|config, up| { - let upstream = up.merge_right(config.upstream.clone()); +pub fn to_upstream<'a>() -> TryFold<'a, ConfigSet, Upstream, String> { + TryFoldConfig::::new(|config_set, up| { + let upstream = up.merge_right(config_set.deref().upstream.clone()); if let Some(ref base_url) = upstream.base_url { Valid::from( reqwest::Url::parse(base_url).map_err(|e| ValidationError::new(e.to_string())), diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 4904966f20..cda8baeb3d 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -8,6 +8,7 @@ pub(crate) mod http; pub mod javascript; pub mod server; mod tc; + use std::hash::Hash; use std::sync::Arc; diff --git a/src/cli/server/server.rs b/src/cli/server/server.rs index 1a90c703ee..68542d5d58 100644 --- a/src/cli/server/server.rs +++ b/src/cli/server/server.rs @@ -1,3 +1,4 @@ +use std::ops::Deref; use std::sync::Arc; use anyhow::Result; @@ -8,16 +9,16 @@ use super::http_2::start_http_2; use super::server_config::ServerConfig; use crate::blueprint::{Blueprint, Http}; use crate::cli::CLIError; -use crate::config::Config; +use crate::config::ConfigSet; pub struct Server { - config: Config, + config_set: ConfigSet, server_up_sender: Option>, } impl Server { - pub fn new(config: Config) -> Self { - Self { config, server_up_sender: None } + pub fn new(config_set: ConfigSet) -> Self { + Self { config_set, server_up_sender: None } } pub fn server_up_receiver(&mut self) -> oneshot::Receiver<()> { @@ -30,7 +31,7 @@ impl Server { /// Starts the server in the current Runtime pub async fn start(self) -> Result<()> { - let blueprint = Blueprint::try_from(&self.config).map_err(CLIError::from)?; + let blueprint = Blueprint::try_from(&self.config_set).map_err(CLIError::from)?; let server_config = Arc::new(ServerConfig::new(blueprint.clone())); match blueprint.server.http.clone() { @@ -42,9 +43,9 @@ impl Server { } /// Starts the server in its own multithreaded Runtime - pub async fn fork_start(self) -> anyhow::Result<()> { + pub async fn fork_start(self) -> Result<()> { let runtime = tokio::runtime::Builder::new_multi_thread() - .worker_threads(self.config.server.get_workers()) + 
.worker_threads(self.config_set.deref().server.get_workers()) .enable_all() .build()?; diff --git a/src/cli/tc.rs b/src/cli/tc.rs index 2188c88bde..271645928d 100644 --- a/src/cli/tc.rs +++ b/src/cli/tc.rs @@ -1,4 +1,5 @@ use std::path::Path; +use std::sync::Arc; use std::{env, fs}; use anyhow::Result; @@ -20,29 +21,29 @@ use crate::{print_schema, FileIO}; const FILE_NAME: &str = ".tailcallrc.graphql"; const YML_FILE_NAME: &str = ".graphqlrc.yml"; -pub async fn run() -> anyhow::Result<()> { +pub async fn run() -> Result<()> { let cli = Cli::parse(); logger_init(); update_checker::check_for_update().await; - let file_io: std::sync::Arc = init_file(); + let file_io: Arc = init_file(); let default_http_io = init_http(&Upstream::default(), None); let config_reader = ConfigReader::init(file_io.clone(), default_http_io); match cli.command { Command::Start { file_paths } => { - let config = config_reader.read_all(&file_paths).await?; - log::info!("N + 1: {}", config.n_plus_one().len().to_string()); - let server = Server::new(config); + let config_set = config_reader.read_all(&file_paths).await?; + log::info!("N + 1: {}", config_set.n_plus_one().len().to_string()); + let server = Server::new(config_set); server.fork_start().await?; Ok(()) } Command::Check { file_paths, n_plus_one_queries, schema, operations } => { - let config = (config_reader.read_all(&file_paths)).await?; - let blueprint = Blueprint::try_from(&config).map_err(CLIError::from); + let config_set = (config_reader.read_all(&file_paths)).await?; + let blueprint = Blueprint::try_from(&config_set).map_err(CLIError::from); match blueprint { Ok(blueprint) => { log::info!("{}", "Config successfully validated".to_string()); - display_config(&config, n_plus_one_queries); + display_config(&config_set, n_plus_one_queries); if schema { display_schema(&blueprint); } @@ -56,7 +57,7 @@ pub async fn run() -> anyhow::Result<()> { })) .await .into_iter() - .collect::>>()?; + .collect::>>()?; validate_operations(&blueprint, ops) .await diff --git a/src/config/config_set.rs b/src/config/config_set.rs new file mode 100644 index 0000000000..cff11c4d44 --- /dev/null +++ b/src/config/config_set.rs @@ -0,0 +1,54 @@ +use std::ops::Deref; + +use prost_reflect::prost_types::FileDescriptorSet; + +use crate::config::Config; + +/// A wrapper on top of Config that contains all the resolved extensions. +#[derive(Clone, Debug, Default)] +pub struct ConfigSet { + pub config: Config, + pub extensions: Extensions, +} + +/// Extensions are meta-information required before we can generate the blueprint. +/// Typically, this information cannot be inferred without performing an IO operation, i.e., +/// reading a file, making an HTTP call, etc. 
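+/// Today that covers the protobuf `FileDescriptorSet` compiled from every
+/// `.proto` file referenced by the config's gRPC fields, and the contents
+/// of the linked JS script file.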
+#[derive(Clone, Debug, Default)] +pub struct Extensions { + pub grpc_file_descriptor: FileDescriptorSet, + + /// Contains the contents of the JS file + pub script: Option, +} + +impl Extensions { + pub fn merge_right(mut self, other: &Extensions) -> Self { + self.grpc_file_descriptor + .file + .extend(other.grpc_file_descriptor.file.clone()); + self.script = other.script.clone().or(self.script.take()); + self + } +} + +impl ConfigSet { + pub fn merge_right(mut self, other: &Self) -> Self { + self.config = self.config.merge_right(&other.config); + self.extensions = self.extensions.merge_right(&other.extensions); + self + } +} + +impl Deref for ConfigSet { + type Target = Config; + fn deref(&self) -> &Self::Target { + &self.config + } +} + +impl From for ConfigSet { + fn from(config: Config) -> Self { + ConfigSet { config, ..Default::default() } + } +} diff --git a/src/config/mod.rs b/src/config/mod.rs index f06eb89ed6..18000a83cb 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,10 +1,12 @@ pub use config::*; +pub use config_set::*; pub use expr::*; pub use key_values::*; pub use server::*; pub use source::*; pub use upstream::*; mod config; +mod config_set; mod expr; mod from_document; pub mod group_by; diff --git a/src/config/reader.rs b/src/config/reader.rs index 944513a773..48527c8829 100644 --- a/src/config/reader.rs +++ b/src/config/reader.rs @@ -1,27 +1,34 @@ +use std::collections::{HashMap, VecDeque}; use std::sync::Arc; +use anyhow::Context; use futures_util::future::join_all; use futures_util::TryFutureExt; +use prost_reflect::prost_types::{FileDescriptorProto, FileDescriptorSet}; +use protox::file::{FileResolver, GoogleFileResolver}; use url::Url; -use super::{Script, ScriptOptions}; +use super::{ConfigSet, ExprBody, Extensions, Script, ScriptOptions}; use crate::config::{Config, Source}; use crate::{FileIO, HttpIO}; -/// Reads the configuration from a file or from an HTTP URL and resolves all linked assets. +const NULL_STR: &str = "\0\0\0\0\0\0\0"; + +/// Reads the configuration from a file or from an HTTP URL and resolves all linked extensions to create a ConfigSet. pub struct ConfigReader { - file: Arc, - http: Arc, + file_io: Arc, + http_io: Arc, } +/// Response of a file read operation struct FileRead { content: String, path: String, } impl ConfigReader { - pub fn init(file: Arc, http: Arc) -> Self { - Self { file, http } + pub fn init(file_io: Arc, http_io: Arc) -> Self { + Self { file_io, http_io } } /// Reads a file from the filesystem or from an HTTP URL @@ -29,14 +36,15 @@ impl ConfigReader { // Is an HTTP URL let content = if let Ok(url) = Url::parse(&file.to_string()) { let response = self - .http + .http_io .execute(reqwest::Request::new(reqwest::Method::GET, url)) .await?; String::from_utf8(response.body.to_vec())? } else { // Is a file path - self.file.read(&file.to_string()).await? + + self.file_io.read(&file.to_string()).await? 
}; Ok(FileRead { content, path: file.to_string() }) @@ -56,34 +64,197 @@ impl ConfigReader { } /// Reads the script file and replaces the path with the content - async fn read_script(&self, mut config: Config) -> anyhow::Result { - if let Some(Script::Path(options)) = config.server.script { + async fn ext_script(&self, mut config_set: ConfigSet) -> anyhow::Result { + let config = &mut config_set.config; + if let Some(Script::Path(ref options)) = &config.server.script { let timeout = options.timeout; - let path = options.src; - let script = self.read_file(path.clone()).await?.content; + let script = self.read_file(options.src.clone()).await?.content; config.server.script = Some(Script::File(ScriptOptions { src: script, timeout })); } - Ok(config) + Ok(config_set) } /// Reads a single file and returns the config - pub async fn read(&self, file: T) -> anyhow::Result { + pub async fn read(&self, file: T) -> anyhow::Result { self.read_all(&[file]).await } /// Reads all the files and returns a merged config - pub async fn read_all(&self, files: &[T]) -> anyhow::Result { + pub async fn read_all(&self, files: &[T]) -> anyhow::Result { let files = self.read_files(files).await?; - let mut config = Config::default(); + let mut config_set = ConfigSet::default(); + for file in files.iter() { let source = Source::detect(&file.path)?; let schema = &file.content; - let new_config = Config::from_source(source, schema)?; - let new_config = self.read_script(new_config).await?; - config = config.merge_right(&new_config); + + // Create initial config set + let new_config_set = self.resolve(Config::from_source(source, schema)?).await?; + + // Merge it with the original config set + config_set = config_set.merge_right(&new_config_set); + } + Ok(config_set) + } + + /// Resolves all the links in a Config to create a ConfigSet + pub async fn resolve(&self, config: Config) -> anyhow::Result { + // Create initial config set + let config_set = ConfigSet::from(config); + + // Extend it with the worker script + let config_set = self.ext_script(config_set).await?; + + // Extend it with protobuf definitions for GRPC + let config_set = self.ext_grpc(config_set).await?; + + Ok(config_set) + } + + /// Returns final ConfigSet from Config + pub async fn ext_grpc(&self, mut config_set: ConfigSet) -> anyhow::Result { + let config = &config_set.config; + let mut descriptors: HashMap = HashMap::new(); + let mut grpc_file_descriptor = FileDescriptorSet::default(); + for (_, typ) in config.types.iter() { + for (_, fld) in typ.fields.iter() { + let proto_path = if let Some(grpc) = &fld.grpc { + &grpc.proto_path + } else if let Some(ExprBody::Grpc(grpc)) = fld.expr.as_ref().map(|e| &e.body) { + &grpc.proto_path + } else { + NULL_STR + }; + + if proto_path != NULL_STR { + descriptors = self + .resolve_descriptors(descriptors, proto_path.to_string()) + .await?; + } + } + } + for (_, v) in descriptors { + grpc_file_descriptor.file.push(v); } - Ok(config) + config_set.extensions = Extensions { grpc_file_descriptor, ..Default::default() }; + Ok(config_set) + } + + /// Performs BFS to import all nested proto files + async fn resolve_descriptors( + &self, + mut descriptors: HashMap, + proto_path: String, + ) -> anyhow::Result> { + let parent_proto = self.read_proto(&proto_path).await?; + let mut queue = VecDeque::new(); + queue.push_back(parent_proto.clone()); + + while let Some(file) = queue.pop_front() { + for import in file.dependency.iter() { + let proto = self.read_proto(import).await?; + if descriptors.get(import).is_none() { + 
queue.push_back(proto.clone()); + descriptors.insert(import.clone(), proto); + } + } + } + + descriptors.insert(proto_path, parent_proto); + + Ok(descriptors) + } + + /// Tries to load well-known google proto files and if not found uses normal file and http IO to resolve them + async fn read_proto(&self, path: &str) -> anyhow::Result { + let content = if let Ok(file) = GoogleFileResolver::new().open_file(path) { + file.source() + .context("Unable to extract content of google well-known proto file")? + .to_string() + } else { + self.read_file(path).await?.content + }; + + Ok(protox_parse::parse(path, &content)?) + } +} + +#[cfg(test)] +mod test_proto_config { + use std::collections::HashMap; + use std::path::{Path, PathBuf}; + + use anyhow::{Context, Result}; + + use crate::cli::{init_file, init_http}; + use crate::config::reader::ConfigReader; + + #[tokio::test] + async fn test_resolve() { + // Skipping IO tests as they are covered in reader.rs + let reader = ConfigReader::init(init_file(), init_http(&Default::default(), None)); + reader + .read_proto("google/protobuf/empty.proto") + .await + .unwrap(); + } + + #[tokio::test] + async fn test_nested_imports() -> Result<()> { + let root_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let mut test_dir = root_dir.join(file!()); + test_dir.pop(); // config + test_dir.pop(); // src + + let mut root = test_dir.clone(); + root.pop(); + + test_dir.push("grpc"); // grpc + test_dir.push("tests"); // tests + + let mut test_file = test_dir.clone(); + + test_file.push("nested0.proto"); // nested0.proto + assert!(test_file.exists()); + let test_file = test_file.to_str().unwrap().to_string(); + + let reader = ConfigReader::init(init_file(), init_http(&Default::default(), None)); + let helper_map = reader + .resolve_descriptors(HashMap::new(), test_file) + .await?; + let files = test_dir.read_dir()?; + for file in files { + let file = file?; + let path = file.path(); + let path_str = + path_to_file_name(path.as_path()).context("It must be able to extract path")?; + let source = tokio::fs::read_to_string(path).await?; + let expected = protox_parse::parse(&path_str, &source)?; + let actual = helper_map.get(&expected.name.unwrap()).unwrap(); + + assert_eq!(&expected.dependency, &actual.dependency); + } + + Ok(()) + } + fn path_to_file_name(path: &Path) -> Option { + let components: Vec<_> = path.components().collect(); + + // Find the index of the "src" component + if let Some(src_index) = components.iter().position(|&c| c.as_os_str() == "src") { + // Reconstruct the path from the "src" component onwards + let after_src_components = &components[src_index..]; + let result = after_src_components + .iter() + .fold(PathBuf::new(), |mut acc, comp| { + acc.push(comp); + acc + }); + Some(result.to_str().unwrap().to_string()) + } else { + None + } } } @@ -103,6 +274,9 @@ mod reader_tests { #[tokio::test] async fn test_all() { + let file_io = init_file(); + let http_io = init_http(&Upstream::default(), None); + let mut cfg = Config::default(); cfg.schema.query = Some("Test".to_string()); cfg = cfg.types([("Test", Type::default())].to_vec()); @@ -135,7 +309,7 @@ mod reader_tests { .iter() .map(|x| x.to_string()) .collect(); - let cr = ConfigReader::init(init_file(), init_http(&Upstream::default(), None)); + let cr = ConfigReader::init(file_io, http_io); let c = cr.read_all(&files).await.unwrap(); assert_eq!( ["Post", "Query", "Test", "User"] @@ -153,6 +327,9 @@ mod reader_tests { #[tokio::test] async fn test_local_files() { + let file_io = init_file(); + let http_io 
= init_http(&Upstream::default(), None); + let files: Vec = [ "examples/jsonplaceholder.yml", "examples/jsonplaceholder.graphql", @@ -161,7 +338,7 @@ mod reader_tests { .iter() .map(|x| x.to_string()) .collect(); - let cr = ConfigReader::init(init_file(), init_http(&Upstream::default(), None)); + let cr = ConfigReader::init(file_io, http_io); let c = cr.read_all(&files).await.unwrap(); assert_eq!( ["Post", "Query", "User"] @@ -177,8 +354,11 @@ mod reader_tests { #[tokio::test] async fn test_script_loader() { + let file_io = init_file(); + let http_io = init_http(&Upstream::default(), None); + let cargo_manifest = std::env::var("CARGO_MANIFEST_DIR").unwrap(); - let reader = ConfigReader::init(init_file(), init_http(&Upstream::default(), None)); + let reader = ConfigReader::init(file_io, http_io); let config = reader .read(&format!( diff --git a/src/grpc/data_loader_request.rs b/src/grpc/data_loader_request.rs index 55ba03db14..434d644cb3 100644 --- a/src/grpc/data_loader_request.rs +++ b/src/grpc/data_loader_request.rs @@ -57,15 +57,17 @@ mod tests { use hyper::header::{HeaderName, HeaderValue}; use hyper::HeaderMap; - use once_cell::sync::Lazy; use pretty_assertions::assert_eq; use url::Url; use super::DataLoaderRequest; + use crate::cli::{init_file, init_http}; + use crate::config::reader::ConfigReader; + use crate::config::{Config, Field, Grpc, Type, Upstream}; use crate::grpc::protobuf::{ProtobufOperation, ProtobufSet}; use crate::grpc::request_template::RenderedRequestTemplate; - static PROTOBUF_OPERATION: Lazy = Lazy::new(|| { + async fn get_protobuf_op() -> ProtobufOperation { let root_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); let mut test_file = root_dir.join(file!()); @@ -73,19 +75,35 @@ mod tests { test_file.push("tests"); test_file.push("greetings.proto"); - let protobuf_set = ProtobufSet::from_proto_file(&test_file).unwrap(); + let file_io = init_file(); + let http_io = init_http(&Upstream::default(), None); + let mut config = Config::default(); + let grpc = Grpc { + proto_path: test_file.to_str().unwrap().to_string(), + ..Default::default() + }; + config.types.insert( + "foo".to_string(), + Type::default().fields(vec![("bar", Field::default().grpc(grpc))]), + ); + let reader = ConfigReader::init(file_io, http_io); + let config_set = reader.resolve(config).await.unwrap(); + + let protobuf_set = + ProtobufSet::from_proto_file(&config_set.extensions.grpc_file_descriptor).unwrap(); + let service = protobuf_set.find_service("Greeter").unwrap(); service.find_operation("SayHello").unwrap() - }); + } - #[test] - fn dataloader_req_empty_headers() { + #[tokio::test] + async fn dataloader_req_empty_headers() { let batch_headers = BTreeSet::default(); let tmpl = RenderedRequestTemplate { url: Url::parse("http://localhost:3000/").unwrap(), headers: HeaderMap::new(), - operation: PROTOBUF_OPERATION.clone(), + operation: get_protobuf_op().await, body: "{}".to_owned(), }; @@ -95,8 +113,8 @@ mod tests { assert_eq!(dl_req_1, dl_req_2); } - #[test] - fn dataloader_req_batch_headers() { + #[tokio::test] + async fn dataloader_req_batch_headers() { let batch_headers = BTreeSet::from_iter(["test-header".to_owned()]); let tmpl_1 = RenderedRequestTemplate { url: Url::parse("http://localhost:3000/").unwrap(), @@ -104,7 +122,7 @@ mod tests { HeaderName::from_static("test-header"), HeaderValue::from_static("value1"), )]), - operation: PROTOBUF_OPERATION.clone(), + operation: get_protobuf_op().await, body: "{}".to_owned(), }; let tmpl_2 = tmpl_1.clone(); diff --git a/src/grpc/protobuf.rs 
b/src/grpc/protobuf.rs index 0b3f8cc44d..c350bfd1b7 100644 --- a/src/grpc/protobuf.rs +++ b/src/grpc/protobuf.rs @@ -1,11 +1,10 @@ -use std::env::current_dir; use std::fmt::Debug; -use std::path::{Path, PathBuf}; use anyhow::{anyhow, bail, Context, Result}; use async_graphql::Value; use prost::bytes::BufMut; use prost::Message; +use prost_reflect::prost_types::FileDescriptorSet; use prost_reflect::{ DescriptorPool, DynamicMessage, MessageDescriptor, MethodDescriptor, ServiceDescriptor, }; @@ -71,24 +70,9 @@ impl ProtobufSet { // TODO: load definitions from proto file for now, but in future // it could be more convenient to load FileDescriptorSet instead // either from file or server reflection - pub fn from_proto_file(proto_path: &Path) -> Result { - let proto_path = if proto_path.is_relative() { - let dir = current_dir()?; - - dir.join(proto_path) - } else { - PathBuf::from(proto_path) - }; - - let parent_dir = proto_path - .parent() - .context("Failed to resolve parent dir for proto file")?; - - let file_descriptor_set = protox::compile([proto_path.as_path()], [parent_dir]) - .with_context(|| "Failed to parse or load proto file".to_string())?; - - let descriptor_pool = DescriptorPool::from_file_descriptor_set(file_descriptor_set)?; - + pub fn from_proto_file(file_descriptor_set: &FileDescriptorSet) -> Result { + let descriptor_pool = + DescriptorPool::from_file_descriptor_set(file_descriptor_set.clone())?; Ok(Self { descriptor_pool }) } @@ -209,6 +193,7 @@ impl ProtobufOperation { #[cfg(test)] mod tests { + // TODO: Rewrite protobuf tests use std::path::PathBuf; use anyhow::Result; @@ -217,6 +202,9 @@ mod tests { use serde_json::json; use super::*; + use crate::cli::{init_file, init_http}; + use crate::config::reader::ConfigReader; + use crate::config::{Config, Field, Grpc, Type, Upstream}; static TEST_DIR: Lazy = Lazy::new(|| { let root_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); @@ -235,6 +223,29 @@ mod tests { test_file } + async fn get_proto_file(name: &str) -> Result { + let file_io = init_file(); + let http_io = init_http(&Upstream::default(), None); + let reader = ConfigReader::init(file_io, http_io); + let mut config = Config::default(); + let grpc = Grpc { + proto_path: get_test_file(name) + .to_str() + .context("Failed to parse or load proto file")? + .to_string(), + ..Default::default() + }; + config.types.insert( + "foo".to_string(), + Type::default().fields(vec![("bar", Field::default().grpc(grpc))]), + ); + Ok(reader + .resolve(config) + .await? 
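+            // resolve() runs the same extension pipeline as real config
+            // loading, so the test obtains its descriptors exactly the way
+            // blueprint generation does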
+ .extensions + .grpc_file_descriptor) + } + #[test] fn convert_value() { assert_eq!( @@ -264,23 +275,21 @@ mod tests { ); } - #[test] - fn unknown_file() -> Result<()> { - let proto_file = get_test_file("_unknown.proto"); - let error = ProtobufSet::from_proto_file(&proto_file).unwrap_err(); + #[tokio::test] + async fn unknown_file() -> Result<()> { + let error = get_proto_file("_unknown.proto").await.unwrap_err(); assert_eq!( error.to_string(), - format!("Failed to parse or load proto file") + "No such file or directory (os error 2)".to_string() ); Ok(()) } - #[test] - fn service_not_found() -> Result<()> { - let proto_file = get_test_file("greetings.proto"); - let file = ProtobufSet::from_proto_file(&proto_file)?; + #[tokio::test] + async fn service_not_found() -> Result<()> { + let file = ProtobufSet::from_proto_file(&get_proto_file("greetings.proto").await?)?; let error = file.find_service("_unknown").unwrap_err(); assert_eq!( @@ -291,10 +300,9 @@ mod tests { Ok(()) } - #[test] - fn method_not_found() -> Result<()> { - let proto_file = get_test_file("greetings.proto"); - let file = ProtobufSet::from_proto_file(&proto_file)?; + #[tokio::test] + async fn method_not_found() -> Result<()> { + let file = ProtobufSet::from_proto_file(&get_proto_file("greetings.proto").await?)?; let service = file.find_service("Greeter")?; let error = service.find_operation("_unknown").unwrap_err(); @@ -303,10 +311,9 @@ mod tests { Ok(()) } - #[test] - fn greetings_proto_file() -> Result<()> { - let proto_file = get_test_file("greetings.proto"); - let file = ProtobufSet::from_proto_file(&proto_file)?; + #[tokio::test] + async fn greetings_proto_file() -> Result<()> { + let file = ProtobufSet::from_proto_file(&get_proto_file("greetings.proto").await?)?; let service = file.find_service("Greeter")?; let operation = service.find_operation("SayHello")?; @@ -324,10 +331,9 @@ mod tests { Ok(()) } - #[test] - fn news_proto_file() -> Result<()> { - let proto_file = get_test_file("news.proto"); - let file = ProtobufSet::from_proto_file(&proto_file)?; + #[tokio::test] + async fn news_proto_file() -> Result<()> { + let file = ProtobufSet::from_proto_file(&get_proto_file("news.proto").await?)?; let service = file.find_service("NewsService")?; let operation = service.find_operation("GetNews")?; @@ -349,10 +355,9 @@ mod tests { Ok(()) } - #[test] - fn news_proto_file_multiple_messages() -> Result<()> { - let proto_file = get_test_file("news.proto"); - let file = ProtobufSet::from_proto_file(&proto_file)?; + #[tokio::test] + async fn news_proto_file_multiple_messages() -> Result<()> { + let file = ProtobufSet::from_proto_file(&get_proto_file("news.proto").await?)?; let service = file.find_service("NewsService")?; let multiple_operation = service.find_operation("GetMultipleNews")?; diff --git a/src/grpc/request_template.rs b/src/grpc/request_template.rs index 3c0bf71f42..2774b3af94 100644 --- a/src/grpc/request_template.rs +++ b/src/grpc/request_template.rs @@ -103,15 +103,16 @@ mod tests { use derive_setters::Setters; use hyper::header::{HeaderName, HeaderValue}; use hyper::{HeaderMap, Method}; - use once_cell::sync::Lazy; use pretty_assertions::assert_eq; use super::RequestTemplate; - use crate::config::GraphQLOperationType; + use crate::cli::{init_file, init_http}; + use crate::config::reader::ConfigReader; + use crate::config::{Config, Field, GraphQLOperationType, Grpc, Type, Upstream}; use crate::grpc::protobuf::{ProtobufOperation, ProtobufSet}; use crate::mustache::Mustache; - static PROTOBUF_OPERATION: Lazy = 
Lazy::new(|| { + async fn get_protobuf_op() -> ProtobufOperation { let root_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); let mut test_file = root_dir.join(file!()); @@ -119,11 +120,33 @@ mod tests { test_file.push("tests"); test_file.push("greetings.proto"); - let protobuf_set = ProtobufSet::from_proto_file(&test_file).unwrap(); + let file_io = init_file(); + let http_io = init_http(&Upstream::default(), None); + let reader = ConfigReader::init(file_io, http_io); + let mut config = Config::default(); + let grpc = Grpc { + proto_path: test_file.to_str().unwrap().to_string(), + ..Default::default() + }; + config.types.insert( + "foo".to_string(), + Type::default().fields(vec![("bar", Field::default().grpc(grpc))]), + ); + + let protobuf_set = ProtobufSet::from_proto_file( + &reader + .resolve(config) + .await + .unwrap() + .extensions + .grpc_file_descriptor, + ) + .unwrap(); + let service = protobuf_set.find_service("Greeter").unwrap(); service.find_operation("SayHello").unwrap() - }); + } #[derive(Setters)] struct Context { @@ -149,15 +172,15 @@ mod tests { } } - #[test] - fn request_with_empty_body() { + #[tokio::test] + async fn request_with_empty_body() { let tmpl = RequestTemplate { url: Mustache::parse("http://localhost:3000/").unwrap(), headers: vec![( HeaderName::from_static("test-header"), Mustache::parse("value").unwrap(), )], - operation: PROTOBUF_OPERATION.clone(), + operation: get_protobuf_op().await, body: None, operation_type: GraphQLOperationType::Query, }; @@ -186,12 +209,12 @@ mod tests { } } - #[test] - fn request_with_body() { + #[tokio::test] + async fn request_with_body() { let tmpl = RequestTemplate { url: Mustache::parse("http://localhost:3000/").unwrap(), headers: vec![], - operation: PROTOBUF_OPERATION.clone(), + operation: get_protobuf_op().await, body: Some(Mustache::parse(r#"{ "name": "test" }"#).unwrap()), operation_type: GraphQLOperationType::Query, }; diff --git a/src/grpc/tests/cycle.proto b/src/grpc/tests/cycle.proto new file mode 100644 index 0000000000..61a572906b --- /dev/null +++ b/src/grpc/tests/cycle.proto @@ -0,0 +1,4 @@ +syntax = "proto3"; + +import "src/grpc/tests/nested0.proto"; +import "src/grpc/tests/duplicate.proto"; diff --git a/src/grpc/tests/duplicate.proto b/src/grpc/tests/duplicate.proto new file mode 100644 index 0000000000..ded85bbf79 --- /dev/null +++ b/src/grpc/tests/duplicate.proto @@ -0,0 +1,4 @@ +syntax = "proto3"; + +import "src/grpc/tests/greetings.proto"; +import "src/grpc/tests/news.proto"; diff --git a/src/grpc/tests/nested0.proto b/src/grpc/tests/nested0.proto new file mode 100644 index 0000000000..3a443c6b76 --- /dev/null +++ b/src/grpc/tests/nested0.proto @@ -0,0 +1,4 @@ +syntax = "proto3"; + +import "src/grpc/tests/greetings.proto"; +import "src/grpc/tests/nested1.proto"; \ No newline at end of file diff --git a/src/grpc/tests/nested1.proto b/src/grpc/tests/nested1.proto new file mode 100644 index 0000000000..c26138174c --- /dev/null +++ b/src/grpc/tests/nested1.proto @@ -0,0 +1,4 @@ +syntax = "proto3"; + +import "src/grpc/tests/news.proto"; +import "src/grpc/tests/cycle.proto"; diff --git a/src/lib.rs b/src/lib.rs index 70944ff5be..d46e6bc1d4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -44,7 +44,7 @@ pub trait HttpIO: Sync + Send + 'static { } #[async_trait::async_trait] -pub trait FileIO { +pub trait FileIO: Send + Sync { async fn write<'a>(&'a self, path: &'a str, content: &'a [u8]) -> anyhow::Result<()>; async fn read<'a>(&'a self, path: &'a str) -> anyhow::Result; } diff --git 
a/tests/graphql/errors/test-grpc-proto-path.graphql b/tests/graphql/errors/test-grpc-proto-path.graphql index aa7ff1a4ed..d97a703318 100644 --- a/tests/graphql/errors/test-grpc-proto-path.graphql +++ b/tests/graphql/errors/test-grpc-proto-path.graphql @@ -25,4 +25,4 @@ type News { } #> client-sdl -type Failure @error(message: "Failed to parse or load proto file", trace: ["Query", "news", "@grpc"]) +type Failure @error(message: "No such file or directory (os error 2)", trace: []) diff --git a/tests/graphql_spec.rs b/tests/graphql_spec.rs index 930ea41541..9ddfa0aed1 100644 --- a/tests/graphql_spec.rs +++ b/tests/graphql_spec.rs @@ -15,12 +15,13 @@ use regex::Regex; use serde::{Deserialize, Serialize}; use serde_json::Value; use tailcall::blueprint::Blueprint; -use tailcall::cli::{init_env, init_http, init_in_memory_cache}; -use tailcall::config::Config; +use tailcall::cli::{init_env, init_file, init_http, init_in_memory_cache}; +use tailcall::config::reader::ConfigReader; +use tailcall::config::{Config, ConfigSet}; use tailcall::directive::DirectiveCodec; use tailcall::http::{AppContext, RequestContext}; use tailcall::print_schema; -use tailcall::valid::{Cause, Valid}; +use tailcall::valid::{Cause, Valid, ValidationError}; static INIT: Once = Once::new(); @@ -278,9 +279,10 @@ fn test_config_identity() -> std::io::Result<()> { } // Check server SDL matches expected client SDL -#[test] -fn test_server_to_client_sdl() -> std::io::Result<()> { +#[tokio::test] +async fn test_server_to_client_sdl() -> std::io::Result<()> { let specs = GraphQLSpec::cargo_read("tests/graphql"); + let file_io = init_file(); for spec in specs? { let expected = spec.find_source(Tag::ClientSDL); @@ -288,8 +290,11 @@ fn test_server_to_client_sdl() -> std::io::Result<()> { let content = spec.find_source(Tag::ServerSDL); let content = content.as_str(); let config = Config::from_sdl(content).to_result().unwrap(); + let upstream = config.upstream.clone(); + let reader = ConfigReader::init(file_io.clone(), init_http(&upstream, None)); + let config_set = reader.resolve(config).await.unwrap(); let actual = - print_schema::print_schema((Blueprint::try_from(&config).unwrap()).to_schema()); + print_schema::print_schema((Blueprint::try_from(&config_set).unwrap()).to_schema()); if spec .annotation @@ -320,8 +325,8 @@ async fn test_execution() -> std::io::Result<()> { .to_result() .unwrap(); config.server.query_validation = Some(false); - - let blueprint = Valid::from(Blueprint::try_from(&config)) + let config_set = ConfigSet::from(config); + let blueprint = Valid::from(Blueprint::try_from(&config_set)) .trace(spec.path.to_str().unwrap_or_default()) .to_result() .unwrap(); @@ -371,19 +376,31 @@ async fn test_execution() -> std::io::Result<()> { } // Standardize errors on Client SDL -#[test] -fn test_failures_in_client_sdl() -> std::io::Result<()> { +#[tokio::test] +async fn test_failures_in_client_sdl() -> std::io::Result<()> { let specs = GraphQLSpec::cargo_read("tests/graphql/errors"); + let file_io = init_file(); for spec in specs? 
{ let content = spec.find_source(Tag::ServerSDL); let expected = spec.sdl_errors; let content = content.as_str(); - let config = Config::from_sdl(content); - - let actual = config - .and_then(|config| Valid::from(Blueprint::try_from(&config))) - .to_result(); + println!("{:?}", spec.path); + + let config = Config::from_sdl(content).to_result(); + let actual = match config { + Ok(config) => { + let upstream = config.upstream.clone(); + let reader = ConfigReader::init(file_io.clone(), init_http(&upstream, None)); + match reader.resolve(config).await { + Ok(config_set) => Valid::from(Blueprint::try_from(&config_set)) + .to_result() + .map(|_| ()), + Err(e) => Err(ValidationError::new(e.to_string())), + } + } + Err(e) => Err(e), + }; match actual { Err(cause) => { let actual: Vec = diff --git a/tests/http_spec.rs b/tests/http_spec.rs index 757446f15c..6873fa4743 100644 --- a/tests/http_spec.rs +++ b/tests/http_spec.rs @@ -20,7 +20,7 @@ use tailcall::async_graphql_hyper::{GraphQLBatchRequest, GraphQLRequest}; use tailcall::blueprint::Blueprint; use tailcall::cli::{init_file, init_hook_http, init_http, init_in_memory_cache}; use tailcall::config::reader::ConfigReader; -use tailcall::config::{Config, Source, Upstream}; +use tailcall::config::{Config, ConfigSet, Source, Upstream}; use tailcall::http::{handle_request, AppContext, Method, Response}; use tailcall::{EnvIO, HttpIO}; use url::Url; @@ -199,13 +199,14 @@ impl HttpSpec { } async fn server_context(&self) -> Arc { + let file_io = init_file(); let http_client = init_http(&Upstream::default(), None); let config = match self.config.clone() { ConfigSource::File(file) => { - let reader = ConfigReader::init(init_file(), http_client); + let reader = ConfigReader::init(file_io, http_client); reader.read_all(&[file]).await.unwrap() } - ConfigSource::Inline(config) => config, + ConfigSource::Inline(config) => ConfigSet::from(config), }; let blueprint = Blueprint::try_from(&config).unwrap(); let client = init_hook_http( diff --git a/tests/server_spec.rs b/tests/server_spec.rs index 3ae7a8a8ab..6f4fe775ab 100644 --- a/tests/server_spec.rs +++ b/tests/server_spec.rs @@ -6,8 +6,9 @@ use tailcall::config::reader::ConfigReader; use tailcall::config::Upstream; async fn test_server(configs: &[&str], url: &str) { + let file_io = init_file(); let http_client = init_http(&Upstream::default(), None); - let reader = ConfigReader::init(init_file(), http_client); + let reader = ConfigReader::init(file_io, http_client); let config = reader.read_all(configs).await.unwrap(); let mut server = Server::new(config); let server_up_receiver = server.server_up_receiver(); @@ -91,8 +92,9 @@ async fn server_start_http2_rsa() { #[tokio::test] async fn server_start_http2_nokey() { let configs = &["tests/server/config/server-start-http2-nokey.graphql"]; + let file_io = init_file(); let http_client = init_http(&Upstream::default(), None); - let reader = ConfigReader::init(init_file(), http_client); + let reader = ConfigReader::init(file_io, http_client); let config = reader.read_all(configs).await.unwrap(); let server = Server::new(config); assert!(server.start().await.is_err())
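
Taken together, this patch replaces the old `Config -> Blueprint` pipeline with `Config -> ConfigSet -> Blueprint`, where the `ConfigSet` carries the pre-resolved proto descriptors. A minimal sketch of that flow from a consumer's point of view, assuming a tokio binary that uses the same `init_file`/`init_http` helpers as the tests above (`app.graphql` is a placeholder path):

```rust
use tailcall::blueprint::Blueprint;
use tailcall::cli::{init_file, init_http};
use tailcall::config::reader::ConfigReader;
use tailcall::config::Upstream;
use tailcall::print_schema;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // read_all parses the files and resolves every linked extension
    // (proto descriptors, JS script contents) into a ConfigSet
    let reader = ConfigReader::init(init_file(), init_http(&Upstream::default(), None));
    let config_set = reader.read_all(&["app.graphql"]).await?;

    // Blueprint generation consumes the ConfigSet, so @grpc fields are
    // validated against descriptors that were already resolved up front
    let blueprint = Blueprint::try_from(&config_set)
        .map_err(|e| anyhow::anyhow!("invalid config: {:?}", e))?;
    println!("{}", print_schema::print_schema(blueprint.to_schema()));
    Ok(())
}
```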