From ee3102d1517865f3ba16e3278cc52f12d2c9d2eb Mon Sep 17 00:00:00 2001 From: Ronnie Dutta <61982285+MetRonnie@users.noreply.github.com> Date: Tue, 28 Jan 2025 16:19:44 +0000 Subject: [PATCH] Improve server-client communication error handling --- cylc/flow/network/__init__.py | 70 +++++++++++++---- cylc/flow/network/client.py | 70 ++++++++++------- cylc/flow/network/multi.py | 28 +++++-- cylc/flow/network/replier.py | 54 ++++++++----- cylc/flow/network/server.py | 122 +++++++++++++++++------------- cylc/flow/run_modes/dummy.py | 19 +++-- cylc/flow/run_modes/simulation.py | 20 +++-- cylc/flow/run_modes/skip.py | 21 +++-- cylc/flow/task_outputs.py | 9 ++- tests/integration/test_client.py | 35 +++++++++ tests/integration/test_replier.py | 28 ++++--- tests/integration/test_server.py | 66 +++++++++++----- 12 files changed, 377 insertions(+), 165 deletions(-) diff --git a/cylc/flow/network/__init__.py b/cylc/flow/network/__init__.py index 42b79475ca5..882fe9b99d1 100644 --- a/cylc/flow/network/__init__.py +++ b/cylc/flow/network/__init__.py @@ -18,7 +18,12 @@ import asyncio import getpass import json -from typing import Optional, Tuple +from typing import ( + TYPE_CHECKING, + Optional, + Tuple, + Union, +) import zmq import zmq.asyncio @@ -30,34 +35,71 @@ CylcError, CylcVersionError, ServiceFileError, - WorkflowStopped + WorkflowStopped, ) from cylc.flow.hostuserutil import get_fqdn_by_host from cylc.flow.workflow_files import ( ContactFileFields, - KeyType, - KeyOwner, KeyInfo, + KeyOwner, + KeyType, + get_workflow_srv_dir, load_contact_file, - get_workflow_srv_dir ) + +if TYPE_CHECKING: + # BACK COMPAT: typing_extensions.TypedDict + # FROM: Python 3.7 + # TO: Python 3.11 + from typing_extensions import TypedDict + + API = 5 # cylc API version MSG_TIMEOUT = "TIMEOUT" +if TYPE_CHECKING: + class ResponseDict(TypedDict, total=False): + """Structure of server response messages. -def encode_(message): - """Convert the structure holding a message field from JSON to a string.""" - try: - return json.dumps(message) - except TypeError as exc: - return json.dumps({'errors': [{'message': str(exc)}]}) + Confusingly, has similar format to GraphQL execution result. + But if we change this now we could break compatibility for + issuing commands to/receiving responses from workflows running in + different versions of Cylc 8. + """ + data: object + """For most Cylc commands that issue GQL mutations, the data field will + look like: + data: { + : { + result: [ + { + id: , + response: [, ] + }, + ... + ] + } + } + but this is not 100% consistent unfortunately + """ + error: Union[Exception, str, dict] + """If an error occurred that could not be handled. + (usually a dict {message: str, traceback?: str}). + """ + user: str + cylc_version: str + """Server (i.e. running workflow) Cylc version. + + Going forward, we include this so we can more easily handle any future + back-compat issues.""" -def decode_(message): - """Convert an encoded message string to JSON with an added 'user' field.""" +def load_server_response(message: str) -> 'ResponseDict': + """Convert a JSON message string to dict with an added 'user' field.""" msg = json.loads(message) - msg['user'] = getpass.getuser() # assume this is the user + if 'user' not in msg: + msg['user'] = getpass.getuser() # assume this is the user return msg diff --git a/cylc/flow/network/client.py b/cylc/flow/network/client.py index 099ef8bc0ff..20a8afb0552 100644 --- a/cylc/flow/network/client.py +++ b/cylc/flow/network/client.py @@ -15,18 +15,31 @@ # along with this program. If not, see . """Client for workflow runtime API.""" -from abc import ABCMeta, abstractmethod +from abc import ( + ABCMeta, + abstractmethod, +) import asyncio +import json import os from shutil import which import socket import sys -from typing import Any, Optional, Union, Dict +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Optional, + Union, +) import zmq import zmq.asyncio -from cylc.flow import LOG +from cylc.flow import ( + LOG, + __version__ as CYLC_VERSION, +) from cylc.flow.exceptions import ( ClientError, ClientTimeout, @@ -36,16 +49,17 @@ ) from cylc.flow.hostuserutil import get_fqdn_by_host from cylc.flow.network import ( - encode_, - decode_, + ZMQSocketBase, get_location, - ZMQSocketBase + load_server_response, ) from cylc.flow.network.client_factory import CommsMeth from cylc.flow.network.server import PB_METHOD_MAP -from cylc.flow.workflow_files import ( - detect_old_contact_file, -) +from cylc.flow.workflow_files import detect_old_contact_file + + +if TYPE_CHECKING: + from cylc.flow.network import ResponseDict class WorkflowRuntimeClientBase(metaclass=ABCMeta): @@ -270,7 +284,7 @@ async def async_request( args: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None, req_meta: Optional[Dict[str, Any]] = None - ) -> object: + ) -> Union[bytes, object]: """Send an asynchronous request using asyncio. Has the same arguments and return values as ``serial_request``. @@ -292,12 +306,12 @@ async def async_request( if req_meta: msg['meta'].update(req_meta) LOG.debug('zmq:send %s', msg) - message = encode_(msg) + message = json.dumps(msg) self.socket.send_string(message) # receive response if self.poller.poll(timeout): - res = await self.socket.recv() + res: bytes = await self.socket.recv() else: self.timeout_handler() raise ClientTimeout( @@ -307,26 +321,28 @@ async def async_request( ' --comms-timeout option;' '\n* or check the workflow log.' ) + LOG.debug('zmq:recv %s', res) - if msg['command'] in PB_METHOD_MAP: - response = {'data': res} - else: - response = decode_( - res.decode() if isinstance(res, bytes) else res - ) - LOG.debug('zmq:recv %s', response) + if command in PB_METHOD_MAP: + return res + + response: ResponseDict = load_server_response(res.decode()) try: return response['data'] except KeyError: - error = response.get( - 'error', - {'message': f'Received invalid response: {response}'}, - ) - raise ClientError( - error.get('message'), # type: ignore - error.get('traceback'), # type: ignore - ) from None + error = response.get('error') + if not error: + error = ( + f"Received invalid response for Cylc {CYLC_VERSION}: " + f"{response}" + ) + wflow_cylc_ver = response.get('cylc_version') + if wflow_cylc_ver: + error += ( + f"\n(Workflow is running in Cylc {wflow_cylc_ver})" + ) + raise ClientError(str(error)) from None def get_header(self) -> dict: """Return "header" data to attach to each request for traceability. diff --git a/cylc/flow/network/multi.py b/cylc/flow/network/multi.py index 9c190f68799..a61c6d05bd2 100644 --- a/cylc/flow/network/multi.py +++ b/cylc/flow/network/multi.py @@ -16,12 +16,23 @@ import asyncio import sys -from typing import Callable, Dict, List, Tuple, Optional, Union, Type +from typing import ( + Callable, + Dict, + List, + Optional, + Tuple, + Type, + Union, +) from ansimarkup import ansiprint from cylc.flow.async_util import unordered_map -from cylc.flow.exceptions import CylcError, WorkflowStopped +from cylc.flow.exceptions import ( + CylcError, + WorkflowStopped, +) import cylc.flow.flags from cylc.flow.id_cli import parse_ids_async from cylc.flow.terminal import DIM @@ -220,14 +231,15 @@ def _process_response( def _report( - response: dict, + response: Union[dict, list], ) -> Tuple[Optional[str], Optional[str], bool]: """Report the result of a GraphQL operation. This analyses GraphQL mutation responses to determine the outcome. Args: - response: The GraphQL response. + response: The workflow server response (NOT necessarily conforming to + GraphQL execution result spec). Returns: (stdout, stderr, outcome) @@ -235,6 +247,12 @@ def _report( """ try: ret: List[Tuple[Optional[str], Optional[str], bool]] = [] + if not isinstance(response, dict): + if isinstance(response, list) and response[0].get('error'): + # If operating on workflow running in older Cylc version, + # may get a error response like [{'error': '...'}] + raise Exception(response) + raise Exception(f"Unexpected response: {response}") for mutation_response in response.values(): # extract the result of each mutation result in the response success, msg = mutation_response['result'][0]['response'] @@ -268,7 +286,7 @@ def _report( # response returned is not in the expected format - this shouldn't # happen but we need to protect against it err_msg = '' - if cylc.flow.flags.verbosity > 1: # debug mode + if cylc.flow.flags.verbosity > 0: # verbose mode # print the full result to stderr err_msg += f'\n <{DIM}>response={response}' return ( diff --git a/cylc/flow/network/replier.py b/cylc/flow/network/replier.py index 09bfb55f662..98555a1166f 100644 --- a/cylc/flow/network/replier.py +++ b/cylc/flow/network/replier.py @@ -15,15 +15,27 @@ # along with this program. If not, see . """Server for workflow runtime API.""" +import json from queue import Queue -from typing import TYPE_CHECKING, Optional +from typing import ( + TYPE_CHECKING, + Optional, +) import zmq -from cylc.flow import LOG -from cylc.flow.network import encode_, decode_, ZMQSocketBase +from cylc.flow import ( + LOG, + __version__ as CYLC_VERSION, +) +from cylc.flow.network import ( + ZMQSocketBase, + load_server_response, +) + if TYPE_CHECKING: + from cylc.flow.network import ResponseDict from cylc.flow.network.server import WorkflowRuntimeServer @@ -69,7 +81,7 @@ def _bespoke_stop(self) -> None: LOG.debug('stopping zmq replier...') self.queue.put('STOP') - def listener(self): + def listener(self) -> None: """The server main loop, listen for and serve requests. When called, this method will receive and respond until there are no @@ -90,7 +102,9 @@ def listener(self): try: # Check for messages - msg = self.socket.recv_string(zmq.NOBLOCK) + msg = self.socket.recv_string( # type: ignore[union-attr] + zmq.NOBLOCK + ) except zmq.error.Again: # No messages, break to parent loop/caller. break @@ -99,27 +113,27 @@ def listener(self): continue # attempt to decode the message, authenticating the user in the # process + res: ResponseDict + response: bytes try: - message = decode_(msg) + message = load_server_response(msg) except Exception as exc: # purposefully catch generic exception # failed to decode message, possibly resulting from failed # authentication - LOG.exception('failed to decode message: "%s"', exc) - import traceback - response = encode_( - { - 'error': { - 'message': 'failed to decode message: "%s"' % msg, - 'traceback': traceback.format_exc(), - } - } - ).encode() + LOG.exception(exc) + LOG.error(f'failed to decode message: "{msg}"') + res = { + 'error': {'message': str(exc)}, + 'cylc_version': CYLC_VERSION, + } + response = json.dumps(res).encode() else: # success case - serve the request res = self.server.receiver(message) + data = res.get('data') # send back the string to bytes response - if isinstance(res.get('data'), bytes): - response = res['data'] + if isinstance(data, bytes): + response = data else: - response = encode_(res).encode() - self.socket.send(response) + response = json.dumps(res).encode() + self.socket.send(response) # type: ignore[union-attr] diff --git a/cylc/flow/network/server.py b/cylc/flow/network/server.py index 2c170e61198..9a44d46921e 100644 --- a/cylc/flow/network/server.py +++ b/cylc/flow/network/server.py @@ -19,29 +19,46 @@ from queue import Queue from textwrap import dedent from time import sleep -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Optional, + Union, +) from graphql.execution.executors.asyncio import AsyncioExecutor import zmq from zmq.auth.thread import ThreadAuthenticator -from cylc.flow import LOG, workflow_files +from cylc.flow import ( + LOG, + __version__ as CYLC_VERSION, + workflow_files, +) from cylc.flow.cfgspec.glbl_cfg import glbl_cfg +from cylc.flow.data_messages_pb2 import PbEntireWorkflow +from cylc.flow.data_store_mgr import DELTAS_MAP from cylc.flow.network.authorisation import authorise from cylc.flow.network.graphql import ( - CylcGraphQLBackend, IgnoreFieldMiddleware, instantiate_middleware + CylcGraphQLBackend, + IgnoreFieldMiddleware, + instantiate_middleware, ) from cylc.flow.network.publisher import WorkflowPublisher from cylc.flow.network.replier import WorkflowReplier from cylc.flow.network.resolvers import Resolvers from cylc.flow.network.schema import schema -from cylc.flow.data_store_mgr import DELTAS_MAP -from cylc.flow.data_messages_pb2 import PbEntireWorkflow + if TYPE_CHECKING: - from cylc.flow.scheduler import Scheduler from graphql.execution import ExecutionResult + from cylc.flow.network import ResponseDict + from cylc.flow.scheduler import Scheduler + # maps server methods to the protobuf message (for client/UIS import) PB_METHOD_MAP: Dict[str, Any] = { @@ -267,7 +284,7 @@ async def publish_queued_items(self) -> None: articles = self.publish_queue.get() await self.publisher.publish(*articles) - def receiver(self, message): + def receiver(self, message) -> 'ResponseDict': """Process incoming messages and coordinate response. Wrap incoming messages, dispatch them to exposed methods and/or @@ -285,26 +302,44 @@ def receiver(self, message): args.update({'user': message['user']}) if 'meta' in message: args['meta'] = message['meta'] - except KeyError: + except KeyError as exc: # malformed message - return {'error': { - 'message': 'Request missing required field(s).'}} + return { + 'error': { + 'message': ( + f"Request missing field {exc} required for " + f"Cylc {CYLC_VERSION}" + ) + }, + 'cylc_version': CYLC_VERSION, + } except AttributeError: # no exposed method by that name - return {'error': { - 'message': 'No method by the name "%s"' % message['command']}} + return { + 'error': { + 'message': ( + f"No method by the name '{message['command']}' " + f"at Cylc {CYLC_VERSION}" + ) + }, + 'cylc_version': CYLC_VERSION, + } # generate response try: - response = method(**args) + data = method(**args) except Exception as exc: # includes incorrect arguments (TypeError) - LOG.exception(exc) # note the error server side - import traceback - return {'error': { - 'message': str(exc), 'traceback': traceback.format_exc()}} + LOG.exception(exc) # log the error server side + return { + 'error': {'message': str(exc)}, + 'cylc_version': CYLC_VERSION, + } - return {'data': response} + return { + 'data': data, + 'cylc_version': CYLC_VERSION, + } def register_endpoints(self): """Register all exposed methods.""" @@ -357,7 +392,7 @@ def graphql( variables: Optional[Dict[str, Any]] = None, meta: Optional[Dict[str, Any]] = None ): - """Return the GraphQL schema execution result. + """Return the data field of the GraphQL schema execution result. Args: request_string: GraphQL request passed to Graphene. @@ -367,41 +402,24 @@ def graphql( Returns: object: Execution result, or a list with errors. """ - try: - executed: 'ExecutionResult' = schema.execute( - request_string, - variable_values=variables, - context_value={ - 'resolvers': self.resolvers, - 'meta': meta or {}, - }, - backend=CylcGraphQLBackend(), - middleware=list(instantiate_middleware(self.middleware)), - executor=AsyncioExecutor(), - validate=True, # validate schema (dev only? default is True) - return_promise=False, - ) - except Exception as exc: - return 'ERROR: GraphQL execution error \n%s' % exc + executed: 'ExecutionResult' = schema.execute( + request_string, + variable_values=variables, + context_value={ + 'resolvers': self.resolvers, + 'meta': meta or {}, + }, + backend=CylcGraphQLBackend(), + middleware=list(instantiate_middleware(self.middleware)), + executor=AsyncioExecutor(), + validate=True, # validate schema (dev only? default is True) + return_promise=False, + ) if executed.errors: - errors: List[Any] = [] for error in executed.errors: - LOG.error(error) - if hasattr(error, '__traceback__'): - import traceback - formatted_tb = traceback.format_exception( - type(error), error, error.__traceback__ - ) - LOG.error("".join(formatted_tb)) - errors.append({ - 'error': { - 'message': str(error), - 'traceback': formatted_tb - } - }) - continue - errors.append(getattr(error, 'message', None)) - return errors + LOG.warning(error) + if not executed.data: + raise Exception(executed.errors[0]) return executed.data # UIServer Data Commands diff --git a/cylc/flow/run_modes/dummy.py b/cylc/flow/run_modes/dummy.py index 26d887d87dc..a0c090b7386 100644 --- a/cylc/flow/run_modes/dummy.py +++ b/cylc/flow/run_modes/dummy.py @@ -18,22 +18,31 @@ Dummy mode shares settings with simulation mode. """ -from typing import TYPE_CHECKING, Any, Dict, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Tuple, +) +from cylc.flow.platforms import get_platform +from cylc.flow.run_modes import RunMode from cylc.flow.run_modes.simulation import ( ModeSettings, disable_platforms, get_simulated_run_len, - parse_fail_cycle_points + parse_fail_cycle_points, ) -from cylc.flow.run_modes import RunMode -from cylc.flow.platforms import get_platform if TYPE_CHECKING: + # BACK COMPAT: typing_extensions.Literal + # FROM: Python 3.7 + # TO: Python 3.8 + from typing_extensions import Literal + from cylc.flow.task_job_mgr import TaskJobManager from cylc.flow.task_proxy import TaskProxy - from typing_extensions import Literal CLEAR_THESE_SCRIPTS = [ diff --git a/cylc/flow/run_modes/simulation.py b/cylc/flow/run_modes/simulation.py index 900a2c1fc4f..8bbb8ccefa1 100644 --- a/cylc/flow/run_modes/simulation.py +++ b/cylc/flow/run_modes/simulation.py @@ -18,9 +18,15 @@ from dataclasses import dataclass from logging import INFO -from typing import ( - TYPE_CHECKING, Any, Dict, List, Tuple, Union) from time import time +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Tuple, + Union, +) from metomi.isodatetime.parsers import DurationParser @@ -29,22 +35,26 @@ from cylc.flow.cycling.loader import get_point from cylc.flow.exceptions import PointParsingError from cylc.flow.platforms import FORBIDDEN_WITH_PLATFORM +from cylc.flow.run_modes import RunMode from cylc.flow.task_outputs import TASK_OUTPUT_SUBMITTED from cylc.flow.task_state import ( - TASK_STATUS_RUNNING, TASK_STATUS_FAILED, + TASK_STATUS_RUNNING, TASK_STATUS_SUCCEEDED, ) from cylc.flow.wallclock import get_unix_time_from_time_string -from cylc.flow.run_modes import RunMode if TYPE_CHECKING: + # BACK COMPAT: typing_extensions.Literal + # FROM: Python 3.7 + # TO: Python 3.8 + from typing_extensions import Literal + from cylc.flow.task_events_mgr import TaskEventsManager from cylc.flow.task_job_mgr import TaskJobManager from cylc.flow.task_proxy import TaskProxy from cylc.flow.workflow_db_mgr import WorkflowDatabaseManager - from typing_extensions import Literal def submit_task_job( diff --git a/cylc/flow/run_modes/skip.py b/cylc/flow/run_modes/skip.py index 49736883911..4b0770bd159 100644 --- a/cylc/flow/run_modes/skip.py +++ b/cylc/flow/run_modes/skip.py @@ -17,23 +17,32 @@ """ from logging import INFO from typing import ( - TYPE_CHECKING, Dict, List, Tuple) + TYPE_CHECKING, + Dict, + List, + Tuple, +) from cylc.flow import LOG from cylc.flow.exceptions import WorkflowConfigError +from cylc.flow.run_modes import RunMode from cylc.flow.task_outputs import ( + TASK_OUTPUT_FAILED, + TASK_OUTPUT_STARTED, TASK_OUTPUT_SUBMITTED, TASK_OUTPUT_SUCCEEDED, - TASK_OUTPUT_FAILED, - TASK_OUTPUT_STARTED ) -from cylc.flow.run_modes import RunMode + if TYPE_CHECKING: - from cylc.flow.taskdef import TaskDef + # BACK COMPAT: typing_extensions.Literal + # FROM: Python 3.7 + # TO: Python 3.8 + from typing_extensions import Literal + from cylc.flow.task_job_mgr import TaskJobManager from cylc.flow.task_proxy import TaskProxy - from typing_extensions import Literal + from cylc.flow.taskdef import TaskDef def submit_task_job( diff --git a/cylc/flow/task_outputs.py b/cylc/flow/task_outputs.py index 8548ab405e4..11a3979c860 100644 --- a/cylc/flow/task_outputs.py +++ b/cylc/flow/task_outputs.py @@ -18,12 +18,12 @@ import ast import re from typing import ( + TYPE_CHECKING, Dict, Iterable, Iterator, List, Optional, - TYPE_CHECKING, Tuple, Union, ) @@ -35,10 +35,15 @@ restricted_evaluator, ) + if TYPE_CHECKING: - from cylc.flow.taskdef import TaskDef + # BACK COMPAT: typing_extensions.Literal + # FROM: Python 3.7 + # TO: Python 3.8 from typing_extensions import Literal + from cylc.flow.taskdef import TaskDef + # Standard task output strings, used for triggering. TASK_OUTPUT_EXPIRED = "expired" diff --git a/tests/integration/test_client.py b/tests/integration/test_client.py index 2195d3a112b..3d58076b860 100644 --- a/tests/integration/test_client.py +++ b/tests/integration/test_client.py @@ -15,8 +15,11 @@ # along with this program. If not, see . """Test cylc.flow.client.WorkflowRuntimeClient.""" +import json +from unittest.mock import Mock import pytest +from cylc.flow.exceptions import ClientError from cylc.flow.network.client import WorkflowRuntimeClient from cylc.flow.network.server import PB_METHOD_MAP @@ -88,3 +91,35 @@ async def test_command_validation_failure(harness): 'response': [False, '--pre=all must be used alone'], } ] + + +@pytest.mark.parametrize( + 'sock_response, expected', + [ + pytest.param({'error': 'message'}, r"^message$", id="basic"), + pytest.param( + {'foo': 1}, + r"^Received invalid response for Cylc 8\.[\w.]+: \{'foo': 1[^}]*\}$", + id="no-err-field", + ), + pytest.param( + {'cylc_version': '8.x.y'}, + r"^Received invalid.+\n\(Workflow is running in Cylc 8.x.y\)$", + id="no-err-field-with-version", + ), + ], +) +async def test_async_request_err( + harness, monkeypatch: pytest.MonkeyPatch, sock_response, expected +): + async def mock_recv(): + return json.dumps(sock_response).encode() + + client: WorkflowRuntimeClient + schd, client = harness + with monkeypatch.context() as mp: + mp.setattr(client, 'socket', Mock(recv=mock_recv)) + mp.setattr(client.poller, 'poll', Mock()) + + with pytest.raises(ClientError, match=expected): + await client.async_request('graphql') diff --git a/tests/integration/test_replier.py b/tests/integration/test_replier.py index ce0b53fdaa8..74addb83113 100644 --- a/tests/integration/test_replier.py +++ b/tests/integration/test_replier.py @@ -14,28 +14,38 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -from async_timeout import timeout -from cylc.flow.network import decode_ -from cylc.flow.network.client import WorkflowRuntimeClient import asyncio +import getpass +from async_timeout import timeout import pytest +from cylc.flow import __version__ as CYLC_VERSION +from cylc.flow.network import load_server_response +from cylc.flow.network.client import WorkflowRuntimeClient +from cylc.flow.scheduler import Scheduler + -async def test_listener(one, start, ): +async def test_listener(one: Scheduler, start): """Test listener.""" async with start(one): + # Test listener handles an invalid message from client + # (without directly calling listener): client = WorkflowRuntimeClient(one.workflow) client.socket.send_string(r'Not JSON') - res = await client.socket.recv() - assert 'error' in decode_(res.decode()) + res = load_server_response( + (await client.socket.recv()).decode() + ) + assert res['error'] + assert 'data' not in res + # Check other fields are present: + assert res['cylc_version'] == CYLC_VERSION + assert res['user'] == getpass.getuser() one.server.replier.queue.put('STOP') async with timeout(2): # wait for the server to consume the STOP item from the queue - while True: - if one.server.replier.queue.empty(): - break + while not one.server.replier.queue.empty(): await asyncio.sleep(0.01) # ensure the server is "closed" one.server.replier.queue.put('foobar') diff --git a/tests/integration/test_server.py b/tests/integration/test_server.py index bc7103b8365..6d6e0c939ba 100644 --- a/tests/integration/test_server.py +++ b/tests/integration/test_server.py @@ -14,11 +14,13 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +import logging from typing import Callable from async_timeout import timeout from getpass import getuser import pytest +from cylc.flow import __version__ as CYLC_VERSION from cylc.flow.network.server import PB_METHOD_MAP from cylc.flow.scheduler import Scheduler @@ -89,35 +91,59 @@ async def test_stop(one: Scheduler, start): assert one.server.stopped -async def test_receiver(one: Scheduler, start): +async def test_receiver_basic(one: Scheduler, start, log_filter): """Test the receiver with different message objects.""" async with timeout(5): async with start(one): # start with a message that works - msg = {'command': 'api', 'user': '', 'args': {}} - assert 'error' not in one.server.receiver(msg) - assert 'data' in one.server.receiver(msg) - - # remove the user field - should error - msg2 = dict(msg) - msg2.pop('user') - assert 'error' in one.server.receiver(msg2) - - # remove the command field - should error - msg3 = dict(msg) - msg3.pop('command') - assert 'error' in one.server.receiver(msg3) - - # provide an invalid command - should error - msg4 = {**msg, 'command': 'foobar'} - assert 'error' in one.server.receiver(msg4) + msg = {'command': 'api', 'user': 'bono', 'args': {}} + res = one.server.receiver(msg) + assert not res.get('error') + assert res['data'] + assert res['cylc_version'] == CYLC_VERSION # simulate a command failure with the original message # (the one which worked earlier) - should error def _api(*args, **kwargs): - raise Exception('foo') + raise Exception('oopsie') one.server.api = _api - assert 'error' in one.server.receiver(msg) + res = one.server.receiver(msg) + assert res == { + 'error': {'message': 'oopsie'}, + 'cylc_version': CYLC_VERSION, + } + assert log_filter(logging.ERROR, 'oopsie') + + +@pytest.mark.parametrize( + 'msg, expected', + [ + pytest.param( + {'command': 'api', 'args': {}}, + f"Request missing field 'user' required for Cylc {CYLC_VERSION}", + id='missing-user', + ), + pytest.param( + {'user': 'bono', 'args': {}}, + f"Request missing field 'command' required for Cylc {CYLC_VERSION}", + id='missing-command', + ), + pytest.param( + {'command': 'foobar', 'user': 'bono', 'args': {}}, + f"No method by the name 'foobar' at Cylc {CYLC_VERSION}", + id='bad-command', + ), + ], +) +async def test_receiver_bad_requests(one: Scheduler, start, msg, expected): + """Test the receiver with different bad requests.""" + async with timeout(5): + async with start(one): + res = one.server.receiver(msg) + assert res == { + 'error': {'message': expected}, + 'cylc_version': CYLC_VERSION, + } async def test_publish_before_shutdown(