Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: flask asyncio support for dataloaders #66

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 9 additions & 24 deletions graphql_server/__init__.py
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@
from graphql.error import format_error as format_error_default
from graphql.execution import ExecutionResult, execute
from graphql.language import OperationType, parse
from graphql.pyutils import AwaitableOrValue
from graphql.pyutils import AwaitableOrValue, is_awaitable
from graphql.type import GraphQLSchema, validate_schema
from graphql.utilities import get_operation_ast
from graphql.validation import ASTValidationRule, validate
@@ -99,9 +99,7 @@ def run_http_query(

if not is_batch:
if not isinstance(data, (dict, MutableMapping)):
raise HttpQueryError(
400, f"GraphQL params should be a dict. Received {data!r}."
)
raise HttpQueryError(400, f"GraphQL params should be a dict. Received {data!r}.")
data = [data]
elif not batch_enabled:
raise HttpQueryError(400, "Batch GraphQL requests are not enabled.")
@@ -114,15 +112,10 @@ def run_http_query(
if not is_batch:
extra_data = query_data or {}

all_params: List[GraphQLParams] = [
get_graphql_params(entry, extra_data) for entry in data
]
all_params: List[GraphQLParams] = [get_graphql_params(entry, extra_data) for entry in data]

results: List[Optional[AwaitableOrValue[ExecutionResult]]] = [
get_response(
schema, params, catch_exc, allow_only_query, run_sync, **execute_options
)
for params in all_params
get_response(schema, params, catch_exc, allow_only_query, run_sync, **execute_options) for params in all_params
]
return GraphQLResponse(results, all_params)

@@ -160,10 +153,7 @@ def encode_execution_results(
Returns a ServerResponse tuple with the serialized response as the first item and
a status code of 200 or 400 in case any result was invalid as the second item.
"""
results = [
format_execution_result(execution_result, format_error)
for execution_result in execution_results
]
results = [format_execution_result(execution_result, format_error) for execution_result in execution_results]
result, status_codes = zip(*results)
status_code = max(status_codes)

@@ -274,14 +264,11 @@ def get_response(
if operation != OperationType.QUERY.value:
raise HttpQueryError(
405,
f"Can only perform a {operation} operation"
" from a POST request.",
f"Can only perform a {operation} operation" " from a POST request.",
headers={"Allow": "POST"},
)

validation_errors = validate(
schema, document, rules=validation_rules, max_errors=max_errors
)
validation_errors = validate(schema, document, rules=validation_rules, max_errors=max_errors)
if validation_errors:
return ExecutionResult(data=None, errors=validation_errors)

@@ -290,7 +277,7 @@ def get_response(
document,
variable_values=params.variables,
operation_name=params.operation_name,
is_awaitable=assume_not_awaitable if run_sync else None,
is_awaitable=assume_not_awaitable if run_sync else is_awaitable,
**kwargs,
)

@@ -317,9 +304,7 @@ def format_execution_result(
fe = [format_error(e) for e in execution_result.errors] # type: ignore
response = {"errors": fe}

if execution_result.errors and any(
not getattr(e, "path", None) for e in execution_result.errors
):
if execution_result.errors and any(not getattr(e, "path", None) for e in execution_result.errors):
status_code = 400
else:
response["data"] = execution_result.data
68 changes: 40 additions & 28 deletions graphql_server/flask/graphqlview.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import asyncio
import copy
from collections.abc import MutableMapping
from functools import partial
from typing import List

from flask import Response, render_template_string, request
from flask.views import View
from graphql import ExecutionResult
from graphql.error import GraphQLError
from graphql.pyutils import is_awaitable
from graphql.type.schema import GraphQLSchema

from graphql_server import (
@@ -41,6 +44,7 @@ class GraphQLView(View):
default_query = None
header_editor_enabled = None
should_persist_headers = None
enable_async = False

methods = ["GET", "POST", "PUT", "DELETE"]

@@ -53,26 +57,46 @@ def __init__(self, **kwargs):
if hasattr(self, key):
setattr(self, key, value)

assert isinstance(
self.schema, GraphQLSchema
), "A Schema is required to be provided to GraphQLView."
assert isinstance(self.schema, GraphQLSchema), "A Schema is required to be provided to GraphQLView."

def get_root_value(self):
return self.root_value

def get_context(self):
context = (
copy.copy(self.context)
if self.context and isinstance(self.context, MutableMapping)
else {}
)
context = copy.copy(self.context) if self.context and isinstance(self.context, MutableMapping) else {}
if isinstance(context, MutableMapping) and "request" not in context:
context.update({"request": request})
return context

def get_middleware(self):
return self.middleware

def get_async_execution_results(self, request_method, data, catch):
async def await_execution_results():
execution_results, all_params = self.run_http_query(request_method, data, catch)
return [
ex if ex is None or not is_awaitable(ex) else await ex
for ex in execution_results
], all_params

q = asyncio.run(await_execution_results())
return q

def run_http_query(self, request_method, data, catch):
return run_http_query(
self.schema,
request_method,
data,
query_data=request.args,
batch_enabled=self.batch,
catch=catch,
# Execute options
root_value=self.get_root_value(),
context_value=self.get_context(),
middleware=self.get_middleware(),
run_sync=not self.enable_async,
)

def dispatch_request(self):
try:
request_method = request.method.lower()
@@ -84,18 +108,12 @@ def dispatch_request(self):
pretty = self.pretty or show_graphiql or request.args.get("pretty")

all_params: List[GraphQLParams]
execution_results, all_params = run_http_query(
self.schema,
request_method,
data,
query_data=request.args,
batch_enabled=self.batch,
catch=catch,
# Execute options
root_value=self.get_root_value(),
context_value=self.get_context(),
middleware=self.get_middleware(),
)

if self.enable_async:
execution_results, all_params = self.get_async_execution_results(request_method, data, catch)
else:
execution_results, all_params = self.run_http_query(request_method, data, catch)

result, status_code = encode_execution_results(
execution_results,
is_batch=isinstance(data, list),
@@ -123,9 +141,7 @@ def dispatch_request(self):
header_editor_enabled=self.header_editor_enabled,
should_persist_headers=self.should_persist_headers,
)
source = render_graphiql_sync(
data=graphiql_data, config=graphiql_config, options=graphiql_options
)
source = render_graphiql_sync(data=graphiql_data, config=graphiql_config, options=graphiql_options)
return render_template_string(source)

return Response(result, status=status_code, content_type="application/json")
@@ -167,8 +183,4 @@ def should_display_graphiql(self):
@staticmethod
def request_wants_html():
best = request.accept_mimetypes.best_match(["application/json", "text/html"])
return (
best == "text/html"
and request.accept_mimetypes[best]
> request.accept_mimetypes["application/json"]
)
return best == "text/html" and request.accept_mimetypes[best] > request.accept_mimetypes["application/json"]
9 changes: 5 additions & 4 deletions tests/flask/app.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from flask import Flask

from graphql_server.flask import GraphQLView
from tests.flask.schema import Schema
from tests.flask.schema import AsyncSchema, Schema


def create_app(path="/graphql", **kwargs):
server = Flask(__name__)
server.debug = True
server.add_url_rule(
path, view_func=GraphQLView.as_view("graphql", schema=Schema, **kwargs)
)
if kwargs.get("enable_async", None):
server.add_url_rule(path, view_func=GraphQLView.as_view("graphql", schema=AsyncSchema, **kwargs))
else:
server.add_url_rule(path, view_func=GraphQLView.as_view("graphql", schema=Schema, **kwargs))
return server


25 changes: 22 additions & 3 deletions tests/flask/schema.py
Original file line number Diff line number Diff line change
@@ -43,9 +43,28 @@ def resolve_raises(*_):

MutationRootType = GraphQLObjectType(
name="MutationRoot",
fields={
"writeTest": GraphQLField(type_=QueryRootType, resolve=lambda *_: QueryRootType)
},
fields={"writeTest": GraphQLField(type_=QueryRootType, resolve=lambda *_: QueryRootType)},
)

Schema = GraphQLSchema(QueryRootType, MutationRootType)


async def async_resolver(obj, info):
return "async"


AsyncQueryRootType = GraphQLObjectType(
name="QueryRoot",
fields={
"sync": GraphQLField(GraphQLNonNull(GraphQLString), resolve=lambda obj, info: "sync"),
"nsync": GraphQLField(GraphQLNonNull(GraphQLString), resolve=async_resolver),
},
)
AsyncMutationRootType = GraphQLObjectType(
name="MutationRoot",
fields={
"sync": GraphQLField(type_=GraphQLString, resolve=lambda obj, info: "sync"),
"nsync": GraphQLField(type_=GraphQLString, resolve=async_resolver),
},
)
AsyncSchema = GraphQLSchema(AsyncQueryRootType, AsyncMutationRootType)
79 changes: 41 additions & 38 deletions tests/flask/test_graphqlview.py
Original file line number Diff line number Diff line change
@@ -83,9 +83,7 @@ def test_allows_get_with_operation_name(app, client):
)

assert response.status_code == 200
assert response_json(response) == {
"data": {"test": "Hello World", "shared": "Hello Everyone"}
}
assert response_json(response) == {"data": {"test": "Hello World", "shared": "Hello Everyone"}}


def test_reports_validation_errors(app, client):
@@ -272,7 +270,9 @@ def test_supports_post_url_encoded_query_with_string_variables(app, client):
def test_supports_post_json_query_with_get_variable_values(app, client):
response = client.post(
url_string(app, variables=json.dumps({"who": "Dolly"})),
data=json_dump_kwarg(query="query helloWho($who: String){ test(who: $who) }",),
data=json_dump_kwarg(
query="query helloWho($who: String){ test(who: $who) }",
),
content_type="application/json",
)

@@ -283,7 +283,11 @@ def test_supports_post_json_query_with_get_variable_values(app, client):
def test_post_url_encoded_query_with_get_variable_values(app, client):
response = client.post(
url_string(app, variables=json.dumps({"who": "Dolly"})),
data=urlencode(dict(query="query helloWho($who: String){ test(who: $who) }",)),
data=urlencode(
dict(
query="query helloWho($who: String){ test(who: $who) }",
)
),
content_type="application/x-www-form-urlencoded",
)

@@ -320,9 +324,7 @@ def test_allows_post_with_operation_name(app, client):
)

assert response.status_code == 200
assert response_json(response) == {
"data": {"test": "Hello World", "shared": "Hello Everyone"}
}
assert response_json(response) == {"data": {"test": "Hello World", "shared": "Hello Everyone"}}


def test_allows_post_with_get_operation_name(app, client):
@@ -340,18 +342,14 @@ def test_allows_post_with_get_operation_name(app, client):
)

assert response.status_code == 200
assert response_json(response) == {
"data": {"test": "Hello World", "shared": "Hello Everyone"}
}
assert response_json(response) == {"data": {"test": "Hello World", "shared": "Hello Everyone"}}


@pytest.mark.parametrize("app", [create_app(pretty=True)])
def test_supports_pretty_printing(app, client):
response = client.get(url_string(app, query="{test}"))

assert response.data.decode() == (
"{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}"
)
assert response.data.decode() == ("{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}")


@pytest.mark.parametrize("app", [create_app(pretty=False)])
@@ -364,9 +362,7 @@ def test_not_pretty_by_default(app, client):
def test_supports_pretty_printing_by_request(app, client):
response = client.get(url_string(app, query="{test}", pretty="1"))

assert response.data.decode() == (
"{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}"
)
assert response.data.decode() == ("{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}")


def test_handles_field_errors_caught_by_graphql(app, client):
@@ -403,9 +399,7 @@ def test_handles_errors_caused_by_a_lack_of_query(app, client):

assert response.status_code == 400
assert response_json(response) == {
"errors": [
{"message": "Must provide query string.", "locations": None, "path": None}
]
"errors": [{"message": "Must provide query string.", "locations": None, "path": None}]
}


@@ -425,15 +419,11 @@ def test_handles_batch_correctly_if_is_disabled(app, client):


def test_handles_incomplete_json_bodies(app, client):
response = client.post(
url_string(app), data='{"query":', content_type="application/json"
)
response = client.post(url_string(app), data='{"query":', content_type="application/json")

assert response.status_code == 400
assert response_json(response) == {
"errors": [
{"message": "POST body sent invalid JSON.", "locations": None, "path": None}
]
"errors": [{"message": "POST body sent invalid JSON.", "locations": None, "path": None}]
}


@@ -445,9 +435,7 @@ def test_handles_plain_post_text(app, client):
)
assert response.status_code == 400
assert response_json(response) == {
"errors": [
{"message": "Must provide query string.", "locations": None, "path": None}
]
"errors": [{"message": "Must provide query string.", "locations": None, "path": None}]
}


@@ -461,9 +449,7 @@ def test_handles_poorly_formed_variables(app, client):
)
assert response.status_code == 400
assert response_json(response) == {
"errors": [
{"message": "Variables are invalid JSON.", "locations": None, "path": None}
]
"errors": [{"message": "Variables are invalid JSON.", "locations": None, "path": None}]
}


@@ -524,9 +510,7 @@ def test_post_multipart_data(app, client):
)

assert response.status_code == 200
assert response_json(response) == {
"data": {u"writeTest": {u"test": u"Hello World"}}
}
assert response_json(response) == {"data": {"writeTest": {"test": "Hello World"}}}


@pytest.mark.parametrize("app", [create_app(batch=True)])
@@ -575,6 +559,25 @@ def test_batch_allows_post_with_operation_name(app, client):
)

assert response.status_code == 200
assert response_json(response) == [
{"data": {"test": "Hello World", "shared": "Hello Everyone"}}
]
assert response_json(response) == [{"data": {"test": "Hello World", "shared": "Hello Everyone"}}]


@pytest.mark.parametrize(
("query", "result"),
(
("query sync {sync}", {"sync": "sync"}),
("query nsync {nsync}", {"nsync": "async"}),
("mutation sync {sync}", {"sync": "sync"}),
("mutation nsync {nsync}", {"nsync": "async"}),
),
)
@pytest.mark.parametrize("app", [create_app(enable_async=True)])
def test_async_client_handles_sync_calls(app, client, query, result):
response = client.post(
url_string(app),
data=json_dump_kwarg(query=query),
content_type="application/json",
)

assert response.status_code == 200, response.data
assert response_json(response) == {"data": result}