diff --git a/conda-environment.yml b/conda-environment.yml index 4a598468d67..102e8167cf8 100644 --- a/conda-environment.yml +++ b/conda-environment.yml @@ -5,7 +5,7 @@ dependencies: - ansimarkup >=1.0.0 - async-timeout>=3.0.0 - colorama >=0.4,<1.0 - - graphene >=2.1,<3 + - graphene >=3.4.3,<4 - graphviz # for static graphing # Note: can't pin jinja2 any higher than this until we give up on Cylc 7 back-compat - jinja2 >=3.0,<3.1 diff --git a/cylc/flow/commands.py b/cylc/flow/commands.py index 3759f119a1b..5f8ca6de0d1 100644 --- a/cylc/flow/commands.py +++ b/cylc/flow/commands.py @@ -74,7 +74,6 @@ ) import cylc.flow.flags from cylc.flow.log_level import log_level_to_verbosity -from cylc.flow.network.schema import WorkflowStopMode from cylc.flow.parsec.exceptions import ParsecError from cylc.flow.task_id import TaskID from cylc.flow.workflow_status import RunMode, StopMode @@ -82,6 +81,7 @@ from metomi.isodatetime.parsers import TimePointParser if TYPE_CHECKING: + from enum import Enum from cylc.flow.scheduler import Scheduler # define a type for command implementations @@ -165,7 +165,7 @@ async def set_prereqs_and_outputs( @_command('stop') async def stop( schd: 'Scheduler', - mode: Union[str, 'StopMode'], + mode: Union[str, 'Enum'], cycle_point: Optional[str] = None, # NOTE clock_time YYYY/MM/DD-HH:mm back-compat removed clock_time: Optional[str] = None, @@ -203,10 +203,10 @@ async def stop( schd._update_workflow_state() else: # immediate shutdown - with suppress(KeyError): - # By default, mode from mutation is a name from the - # WorkflowStopMode graphene.Enum, but we need the value - mode = WorkflowStopMode[mode] # type: ignore[misc] + with suppress(AttributeError): + # By default, mode from mutation is a WorkflowStopMode + # graphene.Enum, but we need the value + mode = mode.value # type: ignore try: mode = StopMode(mode) except ValueError: @@ -298,10 +298,10 @@ async def pause(schd: 'Scheduler'): @_command('set_verbosity') -async def set_verbosity(schd: 'Scheduler', level: Union[int, str]): +async def set_verbosity(schd: 'Scheduler', level: 'Enum'): """Set workflow verbosity.""" try: - lvl = int(level) + lvl = int(level.value) LOG.setLevel(lvl) except (TypeError, ValueError) as exc: raise CommandFailedError(exc) from None diff --git a/cylc/flow/flow_mgr.py b/cylc/flow/flow_mgr.py index 1cd1c1e8c70..df94b304d4d 100644 --- a/cylc/flow/flow_mgr.py +++ b/cylc/flow/flow_mgr.py @@ -34,7 +34,7 @@ def add_flow_opts(parser): parser.add_option( - "--flow", action="append", dest="flow", metavar="FLOW", + "--flow", action="append", dest="flow", metavar="FLOW", default=[], help=f'Assign new tasks to all active flows ("{FLOW_ALL}");' f' no flow ("{FLOW_NONE}"); a new flow ("{FLOW_NEW}");' f' or a specific flow (e.g. "2"). The default is "{FLOW_ALL}".' diff --git a/cylc/flow/network/graphql.py b/cylc/flow/network/graphql.py index 1f49a6ee2fd..5c805148c89 100644 --- a/cylc/flow/network/graphql.py +++ b/cylc/flow/network/graphql.py @@ -19,32 +19,27 @@ """ -from functools import partial -from inspect import isclass, iscoroutinefunction +from inspect import isclass import logging -from typing import TYPE_CHECKING, Any, Tuple, Union +from typing import ( + Any, Awaitable, Callable, TypeVar, Tuple, Dict, Union, cast +) from graphene.utils.str_converters import to_snake_case -from graphql.execution.utils import ( - get_operation_root_type, get_field_def +from graphql import ( + ExecutionContext, + TypeInfo, + TypeInfoVisitor, + Visitor, + visit, + get_argument_values, + get_named_type, + introspection_types, ) -from graphql.execution.values import get_argument_values, get_variable_values -from graphql.language.base import parse, print_ast -from graphql.language import ast -from graphql.backend.base import GraphQLBackend, GraphQLDocument -from graphql.backend.core import execute_and_validate -from graphql.utils.base import type_from_ast -from graphql.type.definition import get_named_type -from promise import Promise -from rx import Observable +from graphql.pyutils import AwaitableOrValue, is_awaitable from cylc.flow.network.schema import NODE_MAP -if TYPE_CHECKING: - from graphql.execution import ExecutionResult - from graphql.language.ast import Document - from graphql.type.schema import GraphQLSchema - logger = logging.getLogger(__name__) @@ -52,6 +47,13 @@ NULL_VALUE = None EMPTY_VALUES: Tuple[list, dict] = ([], {}) STRIP_OPS = {'query', 'subscription'} +INTROSPECTS = { + k.lower() + for k in introspection_types +} + +T = TypeVar("T") +U = TypeVar("U") def grow_tree(tree, path, leaves=None): @@ -94,6 +96,38 @@ def instantiate_middleware(middlewares): yield middleware +async def async_callback( + callback: Callable[[U], AwaitableOrValue[U]], + result: AwaitableOrValue[U], +) -> U: + """Await result and apply callback.""" + result = callback(await cast('Awaitable[Any]', result)) + return await result if is_awaitable(result) else result # type: ignore + + +async def async_resolve(result: AwaitableOrValue[U]) -> AwaitableOrValue[U]: + """Reduce the given potentially awaitable values.""" + if is_awaitable(result): + return await cast('Awaitable[Any]', result) + else: + return result + + +def async_next( + callback: Callable[[U], AwaitableOrValue[U]], + result: AwaitableOrValue[U], +) -> AwaitableOrValue[U]: + """Reduce the given potentially awaitable values using a callback function. + + If the callback does not return an awaitable, then this function will also + not return an awaitable. + """ + if is_awaitable(result): + return async_callback(callback, result) + else: + return callback(cast('U', result)) + + def null_setter(result): """Set type to null if result is empty/null-like.""" # Only set empty parents to null. @@ -111,8 +145,8 @@ def null_setter(result): # However, middleware allows for argument of the request doc to set. def strip_null(data): """Recursively strip data structure of nulls.""" - if isinstance(data, Promise): - return data.then(strip_null) + if is_awaitable(data): + return async_next(strip_null, data) if isinstance(data, dict): return { key: strip_null(val) @@ -138,190 +172,83 @@ def attr_strip_null(result): def null_stripper(exe_result): """Strip nulls in accordance with type of execution result.""" - if isinstance(exe_result, Observable): - return exe_result.map(attr_strip_null) - if not exe_result.errors: + if is_awaitable(exe_result): + return async_next(attr_strip_null, exe_result) + if getattr(exe_result, 'errors', None) is None: return attr_strip_null(exe_result) return exe_result -class AstDocArguments: - """Request doc Argument inspection.""" - - def __init__(self, schema, document_ast, variable_values): - self.schema = schema - self.operation_defs = {} - self.fragment_defs = {} - self.visited_fragments = set() - - for defn in document_ast.definitions: - if isinstance(defn, ast.OperationDefinition): - root_type = get_operation_root_type(schema, defn) - definition_variables = defn.variable_definitions or [] - if definition_variables: - def_var_names = { - v.variable.name.value - for v in definition_variables - } - var_names_diff = def_var_names.difference({ - k - for k in variable_values - if k in def_var_names - }) - # check if we are missing some of the definition variables - if var_names_diff: - msg = (f'Please check your query variables. The ' - f'following variables are missing: ' - f'[{", ".join(var_names_diff)}]') - raise ValueError(msg) - self.operation_defs[getattr(defn.name, 'value', root_type)] = { - 'definition': defn, - 'parent_type': root_type, - 'variables': get_variable_values( - schema, - definition_variables, - variable_values - ), - } - elif isinstance(defn, ast.FragmentDefinition): - self.fragment_defs[defn.name.value] = defn - - def has_arg_val(self, arg_name, arg_value): - """Search through document definitions for argument value. - - Args: - arg_name (str): Field argument to search for. - arg_value (Any): Argument value required. +class CylcVisitor(Visitor): - Returns: + def __init__(self, type_info, variable_values, doc_arg) -> None: + super().__init__() + self.type_info = type_info + self.variable_values = variable_values + self.doc_arg = doc_arg + self.arg_flag = False - Boolean - - """ - for components in self.operation_defs.values(): - defn = components['definition'] - if ( - defn.operation not in STRIP_OPS - or getattr( - defn.name, 'value', None) == 'IntrospectionQuery' - ): - continue - if self.args_selection_search( - components['definition'].selection_set, - components['variables'], - components['parent_type'], - arg_name, - arg_value, - ): - return True - return False - - def args_selection_search( - self, selection_set, variables, parent_type, arg_name, arg_value): - """Recursively search through feild/fragment selection set fields.""" - for field in selection_set.selections: - if isinstance(field, ast.FragmentSpread): - if field.name.value in self.visited_fragments: - continue - frag_def = self.fragment_defs[field.name.value] - frag_type = type_from_ast(self.schema, frag_def.type_condition) - if self.args_selection_search( - frag_def.selection_set, variables, - frag_type, arg_name, arg_value): - return True - self.visited_fragments.add(frag_def.name) - continue - field_def = get_field_def( - self.schema, parent_type, field.name.value) - if field_def is None: - continue + def enter(self, node, key, parent, path, ancestors): + if hasattr(node, 'arguments'): + field_def = self.type_info.get_field_def() arg_vals = get_argument_values( - field_def.args, field.arguments, variables) - if arg_vals.get(arg_name) == arg_value: - return True - if field.selection_set is None: - continue - if self.args_selection_search( - field.selection_set, variables, - get_named_type(field_def.type), arg_name, arg_value): - return True - return False - - -def execute_and_validate_and_strip( - schema: 'GraphQLSchema', - document_ast: 'Document', - *args: Any, - **kwargs: Any -) -> Union['ExecutionResult', Observable]: - """Wrapper around graphql ``execute_and_validate()`` that adds - null stripping.""" - result = execute_and_validate(schema, document_ast, *args, **kwargs) - # Search request document to determine if 'stripNull: true' is set - # as and argument. It can not be done in the middleware, as they - # can be Promises/futures (so may not been resolved at this point). - variable_values = kwargs['variable_values'] or {} - doc_args = AstDocArguments(schema, document_ast, variable_values) - if doc_args.has_arg_val(STRIP_ARG, True): - if kwargs.get('return_promise', False) and hasattr(result, 'then'): - return result.then(null_stripper) # type: ignore[union-attr] - return null_stripper(result) - return result - + field_def, + node, + self.variable_values + ) + if arg_vals.get(self.doc_arg['arg']) == self.doc_arg['val']: + self.arg_flag = True + return self.BREAK + return self.IDLE -class CylcGraphQLBackend(GraphQLBackend): - """Return a GraphQL document using the default - graphql executor with optional null-stripping of result. + def leave(self, node, key, parent, path, ancestors): + return self.IDLE - The null value stripping of result is triggered by the presence - of argument & value "stripNull: true" in any field. - This is a modification of GraphQLCoreBackend found within: - https://github.com/graphql-python/graphql-core-legacy - (graphql-core==2.3.2) +class CylcExecutionContext(ExecutionContext): - Args: - - executor (object): Executor used in evaluating the resolvers. - - """ - - def __init__(self, executor=None): - self.execute_params = {"executor": executor} - - def document_from_string(self, schema, document_string): - """Parse string and setup request document for execution. - - Args: - - schema (graphql.GraphQLSchema): - Schema definition object - document_string (str): - Request query/mutation/subscription document. - - Returns: - - graphql.GraphQLDocument + def execute_operation( + self, initial_result_record, root_value + ) -> AwaitableOrValue[Union[Dict[str, Any], Any, None]]: + """Execute the GraphQL document, and apply requested stipping. + Search request document to determine if 'stripNull: true' is set + as and argument. It can not be done in the middleware, as they + can have awaitables and is prior to validation. """ - if isinstance(document_string, ast.Document): - document_ast = document_string - document_string = print_ast(document_ast) - else: - if not isinstance(document_string, str): - logger.error("The query must be a string") - document_ast = parse(document_string) - return GraphQLDocument( - schema=schema, - document_string=document_string, - document_ast=document_ast, - execute=partial( - execute_and_validate_and_strip, - schema, - document_ast, - **self.execute_params + result = super().execute_operation(initial_result_record, root_value) + + # Traverse the document and stop if found + type_info = TypeInfo(self.schema) + cylc_visitor = CylcVisitor( + type_info, + self.variable_values, + { + 'arg': 'strip_null', + 'val': True, + } + ) + visit( + self.operation, + TypeInfoVisitor( + type_info, + cylc_visitor ), + None ) + if not cylc_visitor.arg_flag: + for fragment in self.fragments.values(): + visit( + fragment, + TypeInfoVisitor( + type_info, + cylc_visitor + ), + None + ) + if cylc_visitor.arg_flag: + return async_next(null_stripper, result) # type: ignore + return result # -- Middleware -- @@ -341,21 +268,22 @@ def __init__(self): def resolve(self, next_, root, info, **args): """Middleware resolver; handles field according to operation.""" # GraphiQL introspection is 'query' but not async - if getattr(info.operation.name, 'value', None) == 'IntrospectionQuery': + if INTROSPECTS.intersection({f'{p}' for p in info.path.as_list()}): return next_(root, info, **args) - if info.operation.operation in STRIP_OPS: - path_string = f'{info.path}' + if info.operation.operation.value in STRIP_OPS: + path_list = info.path.as_list() + path_string = f'{path_list}' # Needed for child fields that resolve without args. # Store arguments of parents as leaves of schema tree from path # to respective field. # no need to regrow the tree on every subscription push/delta if args and path_string not in self.tree_paths: - grow_tree(self.args_tree, info.path, args) + grow_tree(self.args_tree, path_list, args) self.tree_paths.add(path_string) if STRIP_ARG not in args: branch = self.args_tree - for section in info.path: + for section in path_list: if section not in branch: break branch = branch[section] @@ -381,7 +309,7 @@ def resolve(self, next_, root, info, **args): ): # Gather fields set in root - parent_path_string = f'{info.path[:-1:]}' + parent_path_string = f'{path_list[:-1:]}' stamp = getattr(root, 'stamp', '') if ( parent_path_string not in self.field_sets @@ -414,25 +342,6 @@ def resolve(self, next_, root, info, **args): ) ): return None - if ( - info.operation.operation in self.ASYNC_OPS - or iscoroutinefunction(next_) - ): - return self.async_null_setter(next_, root, info, **args) - return null_setter(next_(root, info, **args)) - - if ( - info.operation.operation in self.ASYNC_OPS - or iscoroutinefunction(next_) - ): - return self.async_resolve(next_, root, info, **args) - return next_(root, info, **args) + return async_next(null_setter, next_(root, info, **args)) - async def async_resolve(self, next_, root, info, **args): - """Return awaited coroutine""" - return await next_(root, info, **args) - - async def async_null_setter(self, next_, root, info, **args): - """Set type to null after awaited result if empty/null-like.""" - result = await next_(root, info, **args) - return null_setter(result) + return next_(root, info, **args) diff --git a/cylc/flow/network/resolvers.py b/cylc/flow/network/resolvers.py index fc9b67eeef5..db404930f8f 100644 --- a/cylc/flow/network/resolvers.py +++ b/cylc/flow/network/resolvers.py @@ -57,8 +57,9 @@ ) if TYPE_CHECKING: + from enum import Enum from uuid import UUID - from graphql import ResolveInfo + from graphql import GraphQLResolveInfo from cylc.flow.data_store_mgr import DataStoreMgr from cylc.flow.scheduler import Scheduler @@ -545,7 +546,7 @@ async def get_nodes_edges(self, root_nodes, args): edges=sort_elements(edges, args)) async def subscribe_delta( - self, root, info: 'ResolveInfo', args + self, root, info: 'GraphQLResolveInfo', args ) -> AsyncGenerator[Any, None]: """Delta subscription async generator. @@ -676,7 +677,7 @@ async def flow_delta_processed(self, context, op_id): @abstractmethod async def mutator( self, - info: 'ResolveInfo', + info: 'GraphQLResolveInfo', command: str, w_args: Dict[str, Any], kwargs: Dict[str, Any], @@ -697,7 +698,7 @@ def __init__(self, data: 'DataStoreMgr', schd: 'Scheduler') -> None: # Mutations async def mutator( self, - _info: 'ResolveInfo', + _info: 'GraphQLResolveInfo', command: str, w_args: Dict[str, Any], kwargs: Dict[str, Any], @@ -780,7 +781,7 @@ async def _mutation_mapper( def broadcast( self, - mode: str, + mode: 'Enum', cycle_points: Optional[List[str]] = None, namespaces: Optional[List[str]] = None, settings: Optional[List[Dict[str, str]]] = None, @@ -795,15 +796,15 @@ def broadcast( RUNTIME_FIELD_TO_CFG_MAP.get(key, key): value for key, value in dict_.items() } - if mode == 'put_broadcast': + if mode.value == 'put_broadcast': return self.schd.task_events_mgr.broadcast_mgr.put_broadcast( cycle_points, namespaces, settings) - if mode == 'clear_broadcast': + if mode.value == 'clear_broadcast': return self.schd.task_events_mgr.broadcast_mgr.clear_broadcast( point_strings=cycle_points, namespaces=namespaces, cancel_settings=settings) - if mode == 'expire_broadcast': + if mode.value == 'expire_broadcast': return self.schd.task_events_mgr.broadcast_mgr.expire_broadcast( cutoff) raise ValueError('Unsupported broadcast mode') diff --git a/cylc/flow/network/schema.py b/cylc/flow/network/schema.py index ab34def7f75..002f5d8e5cc 100644 --- a/cylc/flow/network/schema.py +++ b/cylc/flow/network/schema.py @@ -28,7 +28,6 @@ List, Optional, Tuple, - Union, cast, ) @@ -66,11 +65,10 @@ from cylc.flow.workflow_status import StopMode if TYPE_CHECKING: - from graphql import ResolveInfo + from graphql import GraphQLResolveInfo from graphql.type.definition import ( + GraphQLType, GraphQLNamedType, - GraphQLList, - GraphQLNonNull, ) from cylc.flow.network.resolvers import BaseResolvers @@ -260,9 +258,7 @@ class SortArgs(InputObjectType): # Resolvers: -def field_name_from_type( - obj_type: 'Union[GraphQLNamedType, GraphQLList, GraphQLNonNull]' -) -> str: +def field_name_from_type(obj_type: 'GraphQLType') -> str: """Return the field name for given a GraphQL type. If the type is a list or non-null, the base field is extracted. @@ -274,13 +270,13 @@ def field_name_from_type( raise ValueError(f"'{named_type.name}' is not a node type") from None -def get_resolvers(info: 'ResolveInfo') -> 'BaseResolvers': +def get_resolvers(info: 'GraphQLResolveInfo') -> 'BaseResolvers': """Return the resolvers from the context.""" return cast('dict', info.context)['resolvers'] def process_resolver_info( - root: Optional[Any], info: 'ResolveInfo', args: Dict[str, Any] + root: Optional[Any], info: 'GraphQLResolveInfo', args: Dict[str, Any] ) -> Tuple[str, Optional[Any]]: """Set and gather info for resolver.""" # Add the subscription id to the resolver context @@ -308,7 +304,7 @@ def get_native_ids(field_ids): return field_ids -async def get_workflows(root, info: 'ResolveInfo', **args): +async def get_workflows(root, info: 'GraphQLResolveInfo', **args): """Get filtered workflows.""" _, workflow = process_resolver_info(root, info, args) @@ -321,7 +317,7 @@ async def get_workflows(root, info: 'ResolveInfo', **args): return await resolvers.get_workflows(args) -async def get_workflow_by_id(root, info: 'ResolveInfo', **args): +async def get_workflow_by_id(root, info: 'GraphQLResolveInfo', **args): """Return single workflow element.""" _, workflow = process_resolver_info(root, info, args) @@ -334,7 +330,7 @@ async def get_workflow_by_id(root, info: 'ResolveInfo', **args): async def get_nodes_all( - root: Optional[Any], info: 'ResolveInfo', **args + root: Optional[Any], info: 'GraphQLResolveInfo', **args ): """Resolver for returning job, task, family nodes""" @@ -373,7 +369,7 @@ async def get_nodes_all( async def get_nodes_by_ids( - root: Optional[Any], info: 'ResolveInfo', **args + root: Optional[Any], info: 'GraphQLResolveInfo', **args ): """Resolver for returning job, task, family node""" field_name, field_ids = process_resolver_info(root, info, args) @@ -410,7 +406,7 @@ async def get_nodes_by_ids( async def get_node_by_id( - root: Optional[Any], info: 'ResolveInfo', **args + root: Optional[Any], info: 'GraphQLResolveInfo', **args ): """Resolver for returning job, task, family node""" @@ -448,7 +444,7 @@ async def get_node_by_id( ) -async def get_edges_all(root, info: 'ResolveInfo', **args): +async def get_edges_all(root, info: 'GraphQLResolveInfo', **args): """Get all edges from the store filtered by args.""" process_resolver_info(root, info, args) @@ -463,7 +459,7 @@ async def get_edges_all(root, info: 'ResolveInfo', **args): return await resolvers.get_edges_all(args) -async def get_edges_by_ids(root, info: 'ResolveInfo', **args): +async def get_edges_by_ids(root, info: 'GraphQLResolveInfo', **args): """Get all edges from the store by id lookup filtered by args.""" _, field_ids = process_resolver_info(root, info, args) @@ -477,7 +473,7 @@ async def get_edges_by_ids(root, info: 'ResolveInfo', **args): return await resolvers.get_edges_by_ids(args) -async def get_nodes_edges(root, info: 'ResolveInfo', **args): +async def get_nodes_edges(root, info: 'GraphQLResolveInfo', **args): """Resolver for returning job, task, family nodes""" process_resolver_info(root, info, args) @@ -517,7 +513,7 @@ def resolve_state_tasks(root, info, **args): if state in data} -async def resolve_broadcasts(root, info: 'ResolveInfo', **args): +async def resolve_broadcasts(root, info: 'GraphQLResolveInfo', **args): """Resolve and parse broadcasts from JSON.""" broadcasts = json.loads( getattr(root, to_snake_case(info.field_name), '{}')) @@ -1437,7 +1433,7 @@ class Meta: async def mutator( root: Optional[Any], - info: 'ResolveInfo', + info: 'GraphQLResolveInfo', *, command: Optional[str] = None, workflows: Optional[List[str]] = None, @@ -1473,6 +1469,7 @@ async def mutator( if kwargs.get('args', False): kwargs.update(kwargs.get('args', {})) kwargs.pop('args') + resolvers = get_resolvers(info) meta = info.context.get('meta') # type: ignore[union-attr] res = await resolvers.mutator(info, command, w_args, kwargs, meta) @@ -1627,13 +1624,12 @@ class WorkflowStopMode(graphene.Enum): """The mode used to stop a running workflow.""" # NOTE: using a different enum because: - # * Graphene requires special enums. # * We only want to offer a subset of stop modes (REQUEST_* only). - Clean = StopMode.REQUEST_CLEAN.value # type: graphene.Enum - Kill = StopMode.REQUEST_KILL.value # type: graphene.Enum - Now = StopMode.REQUEST_NOW.value # type: graphene.Enum - NowNow = StopMode.REQUEST_NOW_NOW.value # type: graphene.Enum + Clean = StopMode.REQUEST_CLEAN # type: graphene.Enum + Kill = StopMode.REQUEST_KILL # type: graphene.Enum + Now = StopMode.REQUEST_NOW # type: graphene.Enum + NowNow = StopMode.REQUEST_NOW_NOW # type: graphene.Enum @property def description(self): @@ -1688,9 +1684,7 @@ class Arguments: workflows = graphene.List(WorkflowID, required=True) mode = BroadcastMode( - # use the enum name as the default value - # https://github.com/graphql-python/graphql-core-legacy/issues/166 - default_value=BroadcastMode.Set.name, # type: ignore + default_value=BroadcastMode.Set, description='What type of broadcast is this?', required=True ) @@ -1924,9 +1918,7 @@ class Meta: class Arguments: workflows = graphene.List(WorkflowID, required=True) - mode = WorkflowStopMode( - default_value=WorkflowStopMode.Clean.name - ) + mode = WorkflowStopMode(default_value=WorkflowStopMode.Clean) cycle_point = CyclePoint( description='Stop after the workflow reaches this cycle.' ) @@ -2225,7 +2217,8 @@ class Mutations(ObjectType): } -def delta_subs(root, info: 'ResolveInfo', **args) -> AsyncGenerator[Any, None]: +def delta_subs( + root, info: 'GraphQLResolveInfo', **args) -> AsyncGenerator[Any, None]: """Generates the root data from the async gen resolver.""" return get_resolvers(info).subscribe_delta(root, info, args) @@ -2237,12 +2230,12 @@ class Meta: the store. ''') workflow = String() - families = graphene.List(String, default_value=[]) - family_proxies = graphene.List(String, default_value=[]) - jobs = graphene.List(String, default_value=[]) - tasks = graphene.List(String, default_value=[]) - task_proxies = graphene.List(String, default_value=[]) - edges = graphene.List(String, default_value=[]) + families = graphene.List(String) + family_proxies = graphene.List(String) + jobs = graphene.List(String) + tasks = graphene.List(String) + task_proxies = graphene.List(String) + edges = graphene.List(String) class Delta(Interface): diff --git a/cylc/flow/network/server.py b/cylc/flow/network/server.py index 2c170e61198..ec58d20c4bc 100644 --- a/cylc/flow/network/server.py +++ b/cylc/flow/network/server.py @@ -21,15 +21,16 @@ from time import sleep from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union -from graphql.execution.executors.asyncio import AsyncioExecutor import zmq from zmq.auth.thread import ThreadAuthenticator +from graphql.pyutils import is_awaitable + from cylc.flow import LOG, workflow_files from cylc.flow.cfgspec.glbl_cfg import glbl_cfg from cylc.flow.network.authorisation import authorise from cylc.flow.network.graphql import ( - CylcGraphQLBackend, IgnoreFieldMiddleware, instantiate_middleware + CylcExecutionContext, IgnoreFieldMiddleware, instantiate_middleware ) from cylc.flow.network.publisher import WorkflowPublisher from cylc.flow.network.replier import WorkflowReplier @@ -245,7 +246,7 @@ def operate(self) -> None: """Orchestrate the receive, send, publish of messages.""" # Note: this cannot be an async method because the response part # of the listener runs the event loop synchronously - # (in graphql AsyncioExecutor) + # (in graphql schema.execute_async) while True: if self.waiting_to_stop: # The self.stop() method is waiting for us to signal that we @@ -368,24 +369,23 @@ def graphql( object: Execution result, or a list with errors. """ try: - executed: 'ExecutionResult' = schema.execute( + executed: 'ExecutionResult' = schema.execute_async( request_string, variable_values=variables, context_value={ 'resolvers': self.resolvers, 'meta': meta or {}, }, - backend=CylcGraphQLBackend(), middleware=list(instantiate_middleware(self.middleware)), - executor=AsyncioExecutor(), - validate=True, # validate schema (dev only? default is True) - return_promise=False, + execution_context_class=CylcExecutionContext, ) + if is_awaitable(executed): + result = self.loop.run_until_complete(executed) except Exception as exc: return 'ERROR: GraphQL execution error \n%s' % exc - if executed.errors: + if getattr(result, 'errors', None): errors: List[Any] = [] - for error in executed.errors: + for error in result.errors: LOG.error(error) if hasattr(error, '__traceback__'): import traceback @@ -402,7 +402,7 @@ def graphql( continue errors.append(getattr(error, 'message', None)) return errors - return executed.data + return result.data # UIServer Data Commands @authorise() diff --git a/cylc/flow/scripts/stop.py b/cylc/flow/scripts/stop.py index ebdb6380cd7..767982f6296 100755 --- a/cylc/flow/scripts/stop.py +++ b/cylc/flow/scripts/stop.py @@ -250,6 +250,8 @@ async def _run( mode = WorkflowStopMode.NowNow.name elif options.now: mode = WorkflowStopMode.Now.name + else: + mode = WorkflowStopMode.Clean.name mutation_kwargs = { 'request_string': MUTATION, @@ -259,7 +261,10 @@ async def _run( 'cyclePoint': stop_cycle, 'clockTime': options.wall_clock, 'task': stop_task, - 'flowNum': options.flow_num + 'flowNum': ( + int(options.flow_num) + if options.flow_num is not None else None + ) } } diff --git a/setup.cfg b/setup.cfg index 4e3b2f65200..7759425ffa6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -65,7 +65,7 @@ install_requires = ansimarkup>=1.0.0 async-timeout>=3.0.0 colorama>=0.4,<1 - graphene>=2.1,<3 + graphene>=3.4.3,<4 # Note: can't pin jinja2 any higher than this until we give up on Cylc 7 back-compat jinja2==3.0.* metomi-isodatetime>=1!3.0.0,<1!3.2.0 @@ -78,8 +78,6 @@ install_requires = # NOTE: exclude two urwid versions that were not compatible with Tui urwid==2.*,!=2.6.2,!=2.6.3 # unpinned transient dependencies used for type checking - rx - promise tomli>=2; python_version < "3.11" [options.packages.find] diff --git a/tests/integration/test_graphql.py b/tests/integration/test_graphql.py index 9b934f3fe40..8be110c3990 100644 --- a/tests/integration/test_graphql.py +++ b/tests/integration/test_graphql.py @@ -199,14 +199,14 @@ async def test_task_proxies(harness): w_tokens.duplicate( cycle='1', task=namespace, - ).id + ) # NOTE: task "d" is not in the n=1 window yet for namespace in ('a', 'b', 'c') ] ret['taskProxies'].sort(key=lambda x: x['id']) assert ret == { 'taskProxies': [ - {'id': id_} + {'id': id_.id} for id_ in ids ] } @@ -214,13 +214,27 @@ async def test_task_proxies(harness): # query "task" ret = await client.async_request( 'graphql', - {'request_string': 'query { taskProxy(id: "%s") { id } }' % ids[0]} + {'request_string': 'query { taskProxy(id: "%s") { id } }' % ids[0].id} ) assert ret == { - 'taskProxy': {'id': ids[0]} + 'taskProxy': {'id': ids[0].id} } + # query "taskProxies" fragment with null stripping + ret = await client.async_request( + 'graphql', + {'request_string': ''' + fragment wf on Workflow { + taskProxies (ids: ["%s"], stripNull: true) { id } + } + query { workflows (ids: ["%s"]) { ...wf } } + ''' % (ids[0].relative_id, ids[0].workflow_id) + } + ) + assert ret == {'workflows': [{'taskProxies': [{'id': ids[0].id}]}]} + + async def test_family_proxies(harness): schd, client, w_tokens = harness diff --git a/tests/integration/test_server.py b/tests/integration/test_server.py index bc7103b8365..1a5b770a1f8 100644 --- a/tests/integration/test_server.py +++ b/tests/integration/test_server.py @@ -55,6 +55,22 @@ def test_graphql(myflow): assert myflow.id == data['workflows'][0]['id'] +def test_graphql_error(myflow): + """Test GraphQL endpoint method.""" + request_string = f''' + query {{ + workflows(ids: ["{myflow.id}"]) {{ + id + notafield + alsonotafield + }} + }} + ''' + errors = call_server_method(myflow.server.graphql, request_string) + for error in errors: + assert 'error' in error + + def test_pb_data_elements(myflow): """Test Protobuf elements endpoint method.""" element_type = 'workflow' diff --git a/tests/unit/network/test_graphql.py b/tests/unit/network/test_graphql.py index e5be079e289..1b9c9a056c9 100644 --- a/tests/unit/network/test_graphql.py +++ b/tests/unit/network/test_graphql.py @@ -18,11 +18,11 @@ import pytest from pytest import param -from graphql import parse +from graphql import parse, TypeInfo, TypeInfoVisitor, visit from cylc.flow.data_messages_pb2 import PbTaskProxy, PbPrerequisite from cylc.flow.network.graphql import ( - AstDocArguments, null_setter, NULL_VALUE, grow_tree + CylcVisitor, null_setter, null_stripper, async_next, NULL_VALUE, grow_tree ) from cylc.flow.network.schema import schema @@ -34,8 +34,8 @@ @pytest.mark.parametrize( 'query,' 'variables,' - 'expected_variables,' - 'expected_error', + 'search_arg,' + 'expected_result', [ pytest.param( ''' @@ -49,9 +49,10 @@ 'workflowID': 'cylc|workflow' }, { - 'workflowID': 'cylc|workflow' + 'arg': 'ids', + 'val': ['cylc|workflow'], }, - None, + True, id="simple query with correct variables" ), pytest.param( @@ -69,9 +70,10 @@ 'workflowID': 'cylc|workflow' }, { - 'workflowID': 'cylc|workflow' + 'arg': 'ids', + 'val': ['cylc|workflow'], }, - None, + True, id="query with a fragment and correct variables" ), pytest.param( @@ -85,46 +87,67 @@ { 'workflowId': 'cylc|workflow' }, - None, - ValueError, + { + 'arg': 'ids', + 'val': None, + }, + False, id="correct variable definition, but missing variable in " "provided values" + ), + pytest.param( + ''' + query ($workflowID: ID) { + workflows (ids: [$workflowID]) { + id + } + } + ''', + { + 'workflowId': 'cylc|workflow' + }, + { + 'arg': 'idfsdf', + 'val': ['cylc|workflow'], + }, + False, + id="correct variable definition, but wrong search argument" ) ] ) def test_query_variables( query: str, variables: dict, - expected_variables: Optional[dict], - expected_error: Optional[Type[Exception]], + search_arg: dict, + expected_result: bool, ): - """Test that query variables are parsed correctly. + """Test that query variables are parsed and found correctly. Args: query: a valid GraphQL query (using our schema) variables: map with variable values for the query - expected_variables: expected parsed variables - expected_error: expected error, if any + search_arg: argument and value to search for + expected_result: was the argument and value found """ def test(): """Inner function to avoid duplication in if/else""" document = parse(query) - document_arguments = AstDocArguments( - schema=schema, - document_ast=document, - variable_values=variables + type_info = TypeInfo(self.schema) + cylc_visitor = CylcVisitor( + type_info, + variables, + search_arg + ) + visit( + self.operation, + TypeInfoVisitor( + type_info, + cylc_visitor + ), + None ) - parsed_variables = next( - iter( - document_arguments.operation_defs.values() - ) - )['variables'] - assert expected_variables == parsed_variables - if expected_error is not None: - with pytest.raises(expected_error): - test() - else: - test() + + assert expected_result == cylc_visitor.arg_flag @pytest.mark.parametrize( @@ -159,6 +182,46 @@ def test_null_setter(pre_result, expected_result): assert post_result == expected_result +@pytest.mark.parametrize( + 'pre_result,' + 'expected_result', + [ + ( + 'foo', + 'foo' + ), + ( + [NULL_VALUE], + [] + ), + ( + {'nothing': NULL_VALUE}, + {}, + ), + ( + TASK_PROXY_PREREQS.prerequisites, + TASK_PROXY_PREREQS.prerequisites + ), + ( + [NULL_VALUE], + [], + ) + ] +) +async def test_null_stripper(pre_result, expected_result): + """Test the null stripping of different result data/types.""" + # non-async + post_result = async_next(null_stripper, pre_result) + assert post_result == expected_result + + async def async_result(result): + return result + + # async + async_post_result = async_next(null_stripper, async_result(pre_result)) + assert await async_post_result == expected_result + + @pytest.mark.parametrize( 'expect, tree, path, leaves', [ diff --git a/tests/unit/test_links.py b/tests/unit/test_links.py index 9f8a228ad68..2adfb8b284f 100644 --- a/tests/unit/test_links.py +++ b/tests/unit/test_links.py @@ -27,7 +27,8 @@ import re from time import sleep import pytest -import urllib +from urllib import request +from urllib.error import HTTPError EXCLUDE = [ r'*//www.gnu.org/licenses/', @@ -60,13 +61,13 @@ def test_embedded_url(link): to run in parallel """ try: - urllib.request.urlopen(link).getcode() - except urllib.error.HTTPError: + request.urlopen(link).getcode() + except HTTPError: # Sleep and retry to reduce risk of flakiness: sleep(10) try: - urllib.request.urlopen(link).getcode() - except urllib.error.HTTPError as exc: + request.urlopen(link).getcode() + except HTTPError as exc: # Allowing 403 - just because a site forbids us doesn't mean the # link is wrong. if exc.code != 403: