From 734b3f08ff11c83fb2db860e875ded18a007e15f Mon Sep 17 00:00:00 2001 From: Cameron Hurst Date: Wed, 7 Oct 2020 22:21:34 -0400 Subject: [PATCH] feat: flask asyncio support for dataloaders --- graphql_server/flask/graphqlview.py | 76 +++++++++++++++++------------ 1 file changed, 44 insertions(+), 32 deletions(-) diff --git a/graphql_server/flask/graphqlview.py b/graphql_server/flask/graphqlview.py index 1b33433..a5e2746 100644 --- a/graphql_server/flask/graphqlview.py +++ b/graphql_server/flask/graphqlview.py @@ -1,3 +1,4 @@ +import asyncio import copy from collections.abc import MutableMapping from functools import partial @@ -6,6 +7,7 @@ from flask import Response, render_template_string, request from flask.views import View from graphql.error import GraphQLError +from graphql.pyutils import is_awaitable from graphql.type.schema import GraphQLSchema from graphql_server import ( @@ -41,6 +43,7 @@ class GraphQLView(View): default_query = None header_editor_enabled = None should_persist_headers = None + enable_async = False methods = ["GET", "POST", "PUT", "DELETE"] @@ -53,19 +56,13 @@ def __init__(self, **kwargs): if hasattr(self, key): setattr(self, key, value) - assert isinstance( - self.schema, GraphQLSchema - ), "A Schema is required to be provided to GraphQLView." + assert isinstance(self.schema, GraphQLSchema), "A Schema is required to be provided to GraphQLView." def get_root_value(self): return self.root_value def get_context(self): - context = ( - copy.copy(self.context) - if self.context and isinstance(self.context, MutableMapping) - else {} - ) + context = copy.copy(self.context) if self.context and isinstance(self.context, MutableMapping) else {} if isinstance(context, MutableMapping) and "request" not in context: context.update({"request": request}) return context @@ -73,6 +70,37 @@ def get_context(self): def get_middleware(self): return self.middleware + def result_results(self, request_method, data, catch): + return run_http_query( + self.schema, + request_method, + data, + query_data=request.args, + batch_enabled=self.batch, + catch=catch, + # Execute options + root_value=self.get_root_value(), + context_value=self.get_context(), + middleware=self.get_middleware(), + run_sync=not self.enable_async, + ) + + async def resolve_results_async(self, request_method, data, catch): + execution_results, all_params = run_http_query( + self.schema, + request_method, + data, + query_data=request.args, + batch_enabled=self.batch, + catch=catch, + # Execute options + root_value=self.get_root_value(), + context_value=self.get_context(), + middleware=self.get_middleware(), + run_sync=not self.enable_async, + ) + return [await ex if is_awaitable(ex) else ex for ex in execution_results], all_params + def dispatch_request(self): try: request_method = request.method.lower() @@ -84,18 +112,11 @@ def dispatch_request(self): pretty = self.pretty or show_graphiql or request.args.get("pretty") all_params: List[GraphQLParams] - execution_results, all_params = run_http_query( - self.schema, - request_method, - data, - query_data=request.args, - batch_enabled=self.batch, - catch=catch, - # Execute options - root_value=self.get_root_value(), - context_value=self.get_context(), - middleware=self.get_middleware(), - ) + if self.enable_async: + execution_results, all_params = asyncio.run(self.resolve_results_async(request_method, data, catch)) + else: + execution_results, all_params = self.result_results(request_method, data, catch) + result, status_code = encode_execution_results( execution_results, is_batch=isinstance(data, list), @@ -123,9 +144,7 @@ def dispatch_request(self): header_editor_enabled=self.header_editor_enabled, should_persist_headers=self.should_persist_headers, ) - source = render_graphiql_sync( - data=graphiql_data, config=graphiql_config, options=graphiql_options - ) + source = render_graphiql_sync(data=graphiql_data, config=graphiql_config, options=graphiql_options) return render_template_string(source) return Response(result, status=status_code, content_type="application/json") @@ -150,10 +169,7 @@ def parse_body(self): elif content_type == "application/json": return load_json_body(request.data.decode("utf8")) - elif content_type in ( - "application/x-www-form-urlencoded", - "multipart/form-data", - ): + elif content_type in ("application/x-www-form-urlencoded", "multipart/form-data",): return request.form return {} @@ -166,8 +182,4 @@ def should_display_graphiql(self): def request_wants_html(self): best = request.accept_mimetypes.best_match(["application/json", "text/html"]) - return ( - best == "text/html" - and request.accept_mimetypes[best] - > request.accept_mimetypes["application/json"] - ) + return best == "text/html" and request.accept_mimetypes[best] > request.accept_mimetypes["application/json"]