From bdb41ea0004acc8b718a1509b8d2ce1f6863a524 Mon Sep 17 00:00:00 2001 From: Cameron Hurst Date: Wed, 6 Jan 2021 15:16:06 -0500 Subject: [PATCH 1/3] fix: enable async cleaned up with hook points --- graphql_server/__init__.py | 33 ++++++++------------------ graphql_server/flask/graphqlview.py | 36 ++++++++++++++++------------- 2 files changed, 29 insertions(+), 40 deletions(-) diff --git a/graphql_server/__init__.py b/graphql_server/__init__.py index 8942332..b96dfc2 100644 --- a/graphql_server/__init__.py +++ b/graphql_server/__init__.py @@ -15,7 +15,7 @@ from graphql.error import format_error as format_error_default from graphql.execution import ExecutionResult, execute from graphql.language import OperationType, parse -from graphql.pyutils import AwaitableOrValue +from graphql.pyutils import AwaitableOrValue, is_awaitable from graphql.type import GraphQLSchema, validate_schema from graphql.utilities import get_operation_ast from graphql.validation import ASTValidationRule, validate @@ -99,9 +99,7 @@ def run_http_query( if not is_batch: if not isinstance(data, (dict, MutableMapping)): - raise HttpQueryError( - 400, f"GraphQL params should be a dict. Received {data!r}." - ) + raise HttpQueryError(400, f"GraphQL params should be a dict. Received {data!r}.") data = [data] elif not batch_enabled: raise HttpQueryError(400, "Batch GraphQL requests are not enabled.") @@ -114,15 +112,10 @@ def run_http_query( if not is_batch: extra_data = query_data or {} - all_params: List[GraphQLParams] = [ - get_graphql_params(entry, extra_data) for entry in data - ] + all_params: List[GraphQLParams] = [get_graphql_params(entry, extra_data) for entry in data] results: List[Optional[AwaitableOrValue[ExecutionResult]]] = [ - get_response( - schema, params, catch_exc, allow_only_query, run_sync, **execute_options - ) - for params in all_params + get_response(schema, params, catch_exc, allow_only_query, run_sync, **execute_options) for params in all_params ] return GraphQLResponse(results, all_params) @@ -160,10 +153,7 @@ def encode_execution_results( Returns a ServerResponse tuple with the serialized response as the first item and a status code of 200 or 400 in case any result was invalid as the second item. """ - results = [ - format_execution_result(execution_result, format_error) - for execution_result in execution_results - ] + results = [format_execution_result(execution_result, format_error) for execution_result in execution_results] result, status_codes = zip(*results) status_code = max(status_codes) @@ -274,14 +264,11 @@ def get_response( if operation != OperationType.QUERY.value: raise HttpQueryError( 405, - f"Can only perform a {operation} operation" - " from a POST request.", + f"Can only perform a {operation} operation" " from a POST request.", headers={"Allow": "POST"}, ) - validation_errors = validate( - schema, document, rules=validation_rules, max_errors=max_errors - ) + validation_errors = validate(schema, document, rules=validation_rules, max_errors=max_errors) if validation_errors: return ExecutionResult(data=None, errors=validation_errors) @@ -290,7 +277,7 @@ def get_response( document, variable_values=params.variables, operation_name=params.operation_name, - is_awaitable=assume_not_awaitable if run_sync else None, + is_awaitable=assume_not_awaitable if run_sync else is_awaitable, **kwargs, ) @@ -317,9 +304,7 @@ def format_execution_result( fe = [format_error(e) for e in execution_result.errors] # type: ignore response = {"errors": fe} - if execution_result.errors and any( - not getattr(e, "path", None) for e in execution_result.errors - ): + if execution_result.errors and any(not getattr(e, "path", None) for e in execution_result.errors): status_code = 400 else: response["data"] = execution_result.data diff --git a/graphql_server/flask/graphqlview.py b/graphql_server/flask/graphqlview.py index a417406..16a8c8b 100644 --- a/graphql_server/flask/graphqlview.py +++ b/graphql_server/flask/graphqlview.py @@ -1,3 +1,4 @@ +import asyncio import copy from collections.abc import MutableMapping from functools import partial @@ -5,7 +6,9 @@ from flask import Response, render_template_string, request from flask.views import View +from graphql import ExecutionResult from graphql.error import GraphQLError +from graphql.pyutils import is_awaitable from graphql.type.schema import GraphQLSchema from graphql_server import ( @@ -41,6 +44,7 @@ class GraphQLView(View): default_query = None header_editor_enabled = None should_persist_headers = None + enable_async = True methods = ["GET", "POST", "PUT", "DELETE"] @@ -53,19 +57,13 @@ def __init__(self, **kwargs): if hasattr(self, key): setattr(self, key, value) - assert isinstance( - self.schema, GraphQLSchema - ), "A Schema is required to be provided to GraphQLView." + assert isinstance(self.schema, GraphQLSchema), "A Schema is required to be provided to GraphQLView." def get_root_value(self): return self.root_value def get_context(self): - context = ( - copy.copy(self.context) - if self.context and isinstance(self.context, MutableMapping) - else {} - ) + context = copy.copy(self.context) if self.context and isinstance(self.context, MutableMapping) else {} if isinstance(context, MutableMapping) and "request" not in context: context.update({"request": request}) return context @@ -73,6 +71,13 @@ def get_context(self): def get_middleware(self): return self.middleware + @staticmethod + def get_async_execution_results(execution_results): + async def await_execution_results(execution_results): + return [ex if ex is None or is_awaitable(ex) else await ex for ex in execution_results] + + return asyncio.run(await_execution_results(execution_results)) + def dispatch_request(self): try: request_method = request.method.lower() @@ -96,6 +101,11 @@ def dispatch_request(self): context_value=self.get_context(), middleware=self.get_middleware(), ) + + if self.enable_async: + if any(is_awaitable(ex) for ex in execution_results): + execution_results = self.get_async_execution_results(execution_results) + result, status_code = encode_execution_results( execution_results, is_batch=isinstance(data, list), @@ -123,9 +133,7 @@ def dispatch_request(self): header_editor_enabled=self.header_editor_enabled, should_persist_headers=self.should_persist_headers, ) - source = render_graphiql_sync( - data=graphiql_data, config=graphiql_config, options=graphiql_options - ) + source = render_graphiql_sync(data=graphiql_data, config=graphiql_config, options=graphiql_options) return render_template_string(source) return Response(result, status=status_code, content_type="application/json") @@ -167,8 +175,4 @@ def should_display_graphiql(self): @staticmethod def request_wants_html(): best = request.accept_mimetypes.best_match(["application/json", "text/html"]) - return ( - best == "text/html" - and request.accept_mimetypes[best] - > request.accept_mimetypes["application/json"] - ) + return best == "text/html" and request.accept_mimetypes[best] > request.accept_mimetypes["application/json"] From 2955648c1d991de92ac9402ee41f5f8c9a86ae70 Mon Sep 17 00:00:00 2001 From: Cameron Hurst Date: Wed, 6 Jan 2021 16:38:15 -0500 Subject: [PATCH 2/3] fix: fix up calls and added async unit tests --- graphql_server/flask/graphqlview.py | 5 +- tests/flask/app.py | 9 ++-- tests/flask/schema.py | 25 +++++++-- tests/flask/test_graphqlview.py | 79 +++++++++++++++-------------- 4 files changed, 71 insertions(+), 47 deletions(-) diff --git a/graphql_server/flask/graphqlview.py b/graphql_server/flask/graphqlview.py index 16a8c8b..9d8b073 100644 --- a/graphql_server/flask/graphqlview.py +++ b/graphql_server/flask/graphqlview.py @@ -44,7 +44,7 @@ class GraphQLView(View): default_query = None header_editor_enabled = None should_persist_headers = None - enable_async = True + enable_async = False methods = ["GET", "POST", "PUT", "DELETE"] @@ -74,7 +74,7 @@ def get_middleware(self): @staticmethod def get_async_execution_results(execution_results): async def await_execution_results(execution_results): - return [ex if ex is None or is_awaitable(ex) else await ex for ex in execution_results] + return [ex if ex is None or not is_awaitable(ex) else await ex for ex in execution_results] return asyncio.run(await_execution_results(execution_results)) @@ -100,6 +100,7 @@ def dispatch_request(self): root_value=self.get_root_value(), context_value=self.get_context(), middleware=self.get_middleware(), + run_sync=not self.enable_async, ) if self.enable_async: diff --git a/tests/flask/app.py b/tests/flask/app.py index ec9e9d0..3d2a83c 100644 --- a/tests/flask/app.py +++ b/tests/flask/app.py @@ -1,15 +1,16 @@ from flask import Flask from graphql_server.flask import GraphQLView -from tests.flask.schema import Schema +from tests.flask.schema import AsyncSchema, Schema def create_app(path="/graphql", **kwargs): server = Flask(__name__) server.debug = True - server.add_url_rule( - path, view_func=GraphQLView.as_view("graphql", schema=Schema, **kwargs) - ) + if kwargs.get("enable_async", None): + server.add_url_rule(path, view_func=GraphQLView.as_view("graphql", schema=AsyncSchema, **kwargs)) + else: + server.add_url_rule(path, view_func=GraphQLView.as_view("graphql", schema=Schema, **kwargs)) return server diff --git a/tests/flask/schema.py b/tests/flask/schema.py index eb51e26..bd73da3 100644 --- a/tests/flask/schema.py +++ b/tests/flask/schema.py @@ -43,9 +43,28 @@ def resolve_raises(*_): MutationRootType = GraphQLObjectType( name="MutationRoot", - fields={ - "writeTest": GraphQLField(type_=QueryRootType, resolve=lambda *_: QueryRootType) - }, + fields={"writeTest": GraphQLField(type_=QueryRootType, resolve=lambda *_: QueryRootType)}, ) Schema = GraphQLSchema(QueryRootType, MutationRootType) + + +async def async_resolver(obj, info): + return "async" + + +AsyncQueryRootType = GraphQLObjectType( + name="QueryRoot", + fields={ + "sync": GraphQLField(GraphQLNonNull(GraphQLString), resolve=lambda obj, info: "sync"), + "nsync": GraphQLField(GraphQLNonNull(GraphQLString), resolve=async_resolver), + }, +) +AsyncMutationRootType = GraphQLObjectType( + name="MutationRoot", + fields={ + "sync": GraphQLField(type_=GraphQLString, resolve=lambda obj, info: "sync"), + "nsync": GraphQLField(type_=GraphQLString, resolve=async_resolver), + }, +) +AsyncSchema = GraphQLSchema(AsyncQueryRootType, AsyncMutationRootType) diff --git a/tests/flask/test_graphqlview.py b/tests/flask/test_graphqlview.py index d8d60b0..a346326 100644 --- a/tests/flask/test_graphqlview.py +++ b/tests/flask/test_graphqlview.py @@ -83,9 +83,7 @@ def test_allows_get_with_operation_name(app, client): ) assert response.status_code == 200 - assert response_json(response) == { - "data": {"test": "Hello World", "shared": "Hello Everyone"} - } + assert response_json(response) == {"data": {"test": "Hello World", "shared": "Hello Everyone"}} def test_reports_validation_errors(app, client): @@ -272,7 +270,9 @@ def test_supports_post_url_encoded_query_with_string_variables(app, client): def test_supports_post_json_query_with_get_variable_values(app, client): response = client.post( url_string(app, variables=json.dumps({"who": "Dolly"})), - data=json_dump_kwarg(query="query helloWho($who: String){ test(who: $who) }",), + data=json_dump_kwarg( + query="query helloWho($who: String){ test(who: $who) }", + ), content_type="application/json", ) @@ -283,7 +283,11 @@ def test_supports_post_json_query_with_get_variable_values(app, client): def test_post_url_encoded_query_with_get_variable_values(app, client): response = client.post( url_string(app, variables=json.dumps({"who": "Dolly"})), - data=urlencode(dict(query="query helloWho($who: String){ test(who: $who) }",)), + data=urlencode( + dict( + query="query helloWho($who: String){ test(who: $who) }", + ) + ), content_type="application/x-www-form-urlencoded", ) @@ -320,9 +324,7 @@ def test_allows_post_with_operation_name(app, client): ) assert response.status_code == 200 - assert response_json(response) == { - "data": {"test": "Hello World", "shared": "Hello Everyone"} - } + assert response_json(response) == {"data": {"test": "Hello World", "shared": "Hello Everyone"}} def test_allows_post_with_get_operation_name(app, client): @@ -340,18 +342,14 @@ def test_allows_post_with_get_operation_name(app, client): ) assert response.status_code == 200 - assert response_json(response) == { - "data": {"test": "Hello World", "shared": "Hello Everyone"} - } + assert response_json(response) == {"data": {"test": "Hello World", "shared": "Hello Everyone"}} @pytest.mark.parametrize("app", [create_app(pretty=True)]) def test_supports_pretty_printing(app, client): response = client.get(url_string(app, query="{test}")) - assert response.data.decode() == ( - "{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}" - ) + assert response.data.decode() == ("{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}") @pytest.mark.parametrize("app", [create_app(pretty=False)]) @@ -364,9 +362,7 @@ def test_not_pretty_by_default(app, client): def test_supports_pretty_printing_by_request(app, client): response = client.get(url_string(app, query="{test}", pretty="1")) - assert response.data.decode() == ( - "{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}" - ) + assert response.data.decode() == ("{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}") def test_handles_field_errors_caught_by_graphql(app, client): @@ -403,9 +399,7 @@ def test_handles_errors_caused_by_a_lack_of_query(app, client): assert response.status_code == 400 assert response_json(response) == { - "errors": [ - {"message": "Must provide query string.", "locations": None, "path": None} - ] + "errors": [{"message": "Must provide query string.", "locations": None, "path": None}] } @@ -425,15 +419,11 @@ def test_handles_batch_correctly_if_is_disabled(app, client): def test_handles_incomplete_json_bodies(app, client): - response = client.post( - url_string(app), data='{"query":', content_type="application/json" - ) + response = client.post(url_string(app), data='{"query":', content_type="application/json") assert response.status_code == 400 assert response_json(response) == { - "errors": [ - {"message": "POST body sent invalid JSON.", "locations": None, "path": None} - ] + "errors": [{"message": "POST body sent invalid JSON.", "locations": None, "path": None}] } @@ -445,9 +435,7 @@ def test_handles_plain_post_text(app, client): ) assert response.status_code == 400 assert response_json(response) == { - "errors": [ - {"message": "Must provide query string.", "locations": None, "path": None} - ] + "errors": [{"message": "Must provide query string.", "locations": None, "path": None}] } @@ -461,9 +449,7 @@ def test_handles_poorly_formed_variables(app, client): ) assert response.status_code == 400 assert response_json(response) == { - "errors": [ - {"message": "Variables are invalid JSON.", "locations": None, "path": None} - ] + "errors": [{"message": "Variables are invalid JSON.", "locations": None, "path": None}] } @@ -524,9 +510,7 @@ def test_post_multipart_data(app, client): ) assert response.status_code == 200 - assert response_json(response) == { - "data": {u"writeTest": {u"test": u"Hello World"}} - } + assert response_json(response) == {"data": {"writeTest": {"test": "Hello World"}}} @pytest.mark.parametrize("app", [create_app(batch=True)]) @@ -575,6 +559,25 @@ def test_batch_allows_post_with_operation_name(app, client): ) assert response.status_code == 200 - assert response_json(response) == [ - {"data": {"test": "Hello World", "shared": "Hello Everyone"}} - ] + assert response_json(response) == [{"data": {"test": "Hello World", "shared": "Hello Everyone"}}] + + +@pytest.mark.parametrize( + ("query", "result"), + ( + ("query sync {sync}", {"sync": "sync"}), + ("query nsync {nsync}", {"nsync": "async"}), + ("mutation sync {sync}", {"sync": "sync"}), + ("mutation nsync {nsync}", {"nsync": "async"}), + ), +) +@pytest.mark.parametrize("app", [create_app(enable_async=True)]) +def test_async_client_handles_sync_calls(app, client, query, result): + response = client.post( + url_string(app), + data=json_dump_kwarg(query=query), + content_type="application/json", + ) + + assert response.status_code == 200, response.data + assert response_json(response) == {"data": result} From 41ec395d4945724905d4f50e0ed5077edef43d7d Mon Sep 17 00:00:00 2001 From: Cameron Hurst Date: Thu, 7 Jan 2021 12:03:53 -0500 Subject: [PATCH 3/3] fix: async loop must exist for dataloaders to work --- graphql_server/flask/graphqlview.py | 49 ++++++++++++++++------------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/graphql_server/flask/graphqlview.py b/graphql_server/flask/graphqlview.py index 9d8b073..96d10f3 100644 --- a/graphql_server/flask/graphqlview.py +++ b/graphql_server/flask/graphqlview.py @@ -71,12 +71,31 @@ def get_context(self): def get_middleware(self): return self.middleware - @staticmethod - def get_async_execution_results(execution_results): - async def await_execution_results(execution_results): - return [ex if ex is None or not is_awaitable(ex) else await ex for ex in execution_results] - - return asyncio.run(await_execution_results(execution_results)) + def get_async_execution_results(self, request_method, data, catch): + async def await_execution_results(): + execution_results, all_params = self.run_http_query(request_method, data, catch) + return [ + ex if ex is None or not is_awaitable(ex) else await ex + for ex in execution_results + ], all_params + + q = asyncio.run(await_execution_results()) + return q + + def run_http_query(self, request_method, data, catch): + return run_http_query( + self.schema, + request_method, + data, + query_data=request.args, + batch_enabled=self.batch, + catch=catch, + # Execute options + root_value=self.get_root_value(), + context_value=self.get_context(), + middleware=self.get_middleware(), + run_sync=not self.enable_async, + ) def dispatch_request(self): try: @@ -89,23 +108,11 @@ def dispatch_request(self): pretty = self.pretty or show_graphiql or request.args.get("pretty") all_params: List[GraphQLParams] - execution_results, all_params = run_http_query( - self.schema, - request_method, - data, - query_data=request.args, - batch_enabled=self.batch, - catch=catch, - # Execute options - root_value=self.get_root_value(), - context_value=self.get_context(), - middleware=self.get_middleware(), - run_sync=not self.enable_async, - ) if self.enable_async: - if any(is_awaitable(ex) for ex in execution_results): - execution_results = self.get_async_execution_results(execution_results) + execution_results, all_params = self.get_async_execution_results(request_method, data, catch) + else: + execution_results, all_params = self.run_http_query(request_method, data, catch) result, status_code = encode_execution_results( execution_results,