diff --git a/cylc/flow/network/__init__.py b/cylc/flow/network/__init__.py
index 916b129e244..96da617db40 100644
--- a/cylc/flow/network/__init__.py
+++ b/cylc/flow/network/__init__.py
@@ -13,279 +13,18 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
-"""Package for network interfaces to Cylc scheduler objects."""
-import asyncio
-import getpass
-import json
-from typing import Optional, Tuple
+"""Cylc networking code.
-import zmq
-import zmq.asyncio
-import zmq.auth
+Contains:
+* Server code (hosted by the scheduler process).
+* Client implementations (used to communicate with the scheduler).
+* Workflow scanning logic.
+* Schema and interface definitions.
+"""
-from cylc.flow import LOG
-from cylc.flow.exceptions import (
- ClientError,
- CylcError,
- CylcVersionError,
- ServiceFileError,
- WorkflowStopped
-)
-from cylc.flow.hostuserutil import get_fqdn_by_host
-from cylc.flow.workflow_files import (
- ContactFileFields,
- KeyType,
- KeyOwner,
- KeyInfo,
- load_contact_file,
- get_workflow_srv_dir
-)
-
-API = 5 # cylc API version
-MSG_TIMEOUT = "TIMEOUT"
-
-
-def encode_(message):
- """Convert the structure holding a message field from JSON to a string."""
- try:
- return json.dumps(message)
- except TypeError as exc:
- return json.dumps({'errors': [{'message': str(exc)}]})
-
-
-def decode_(message):
- """Convert an encoded message string to JSON with an added 'user' field."""
- msg = json.loads(message)
- msg['user'] = getpass.getuser() # assume this is the user
- return msg
-
-
-def get_location(workflow: str) -> Tuple[str, int, int]:
- """Extract host and port from a workflow's contact file.
-
- NB: if it fails to load the workflow contact file, it will exit.
-
- Args:
- workflow: workflow ID
- Returns:
- Tuple (host name, port number, publish port number)
- Raises:
- WorkflowStopped: if the workflow is not running.
- CylcVersionError: if target is a Cylc 7 (or earlier) workflow.
- """
- try:
- contact = load_contact_file(workflow)
- except (IOError, ValueError, ServiceFileError):
- # Contact file does not exist or corrupted, workflow should be dead
- raise WorkflowStopped(workflow)
-
- host = contact[ContactFileFields.HOST]
- host = get_fqdn_by_host(host)
- port = int(contact[ContactFileFields.PORT])
- if ContactFileFields.PUBLISH_PORT in contact:
- pub_port = int(contact[ContactFileFields.PUBLISH_PORT])
- else:
- version = contact.get('CYLC_VERSION', None)
- raise CylcVersionError(version=version)
- return host, port, pub_port
-
-
-class ZMQSocketBase:
- """Initiate the ZMQ socket bind for specified pattern.
-
- NOTE: Security to be provided via zmq.auth (see PR #3359).
-
- Args:
- pattern (enum): ZeroMQ message pattern (zmq.PATTERN).
-
- context (object, optional): instantiated ZeroMQ context, defaults
- to zmq.asyncio.Context().
-
- This class is designed to be inherited by REP Server (REQ/REP)
- and by PUB Publisher (PUB/SUB), as the start-up logic is similar.
-
-
- To tailor this class overwrite it's method on inheritance.
-
- """
-
- def __init__(
- self,
- pattern,
- workflow: str,
- bind: bool = False,
- context: Optional[zmq.Context] = None,
- ):
- self.bind = bind
- if context is None:
- self.context: zmq.Context = zmq.asyncio.Context()
- else:
- self.context = context
- self.pattern = pattern
- self.workflow = workflow
- self.host: Optional[str] = None
- self.port: Optional[int] = None
- self.socket: Optional[zmq.Socket] = None
- self.loop: Optional[asyncio.AbstractEventLoop] = None
- self.stopping = False
-
- def start(self, *args, **kwargs):
- """Create the async loop, and bind socket."""
- # set asyncio loop
- try:
- self.loop = asyncio.get_running_loop()
- except RuntimeError:
- self.loop = asyncio.new_event_loop()
- asyncio.set_event_loop(self.loop)
-
- if self.bind:
- self._socket_bind(*args, **kwargs)
- else:
- self._socket_connect(*args, **kwargs)
-
- # initiate bespoke items
- self._bespoke_start()
-
- # Keeping srv_prv_key_loc as optional arg so as to not break interface
- def _socket_bind(self, min_port, max_port, srv_prv_key_loc=None):
- """Bind socket.
-
- Will use a port range provided to select random ports.
-
- """
- if srv_prv_key_loc is None:
- # Create new KeyInfo object for the server private key
- workflow_srv_dir = get_workflow_srv_dir(self.workflow)
- srv_prv_key_info = KeyInfo(
- KeyType.PRIVATE,
- KeyOwner.SERVER,
- workflow_srv_dir=workflow_srv_dir)
- else:
- srv_prv_key_info = KeyInfo(
- KeyType.PRIVATE,
- KeyOwner.SERVER,
- full_key_path=srv_prv_key_loc)
-
- # create socket
- self.socket = self.context.socket(self.pattern)
- self._socket_options()
-
- try:
- server_public_key, server_private_key = zmq.auth.load_certificate(
- srv_prv_key_info.full_key_path)
- except ValueError:
- raise ServiceFileError(
- f"Failed to find server's public "
- f"key in "
- f"{srv_prv_key_info.full_key_path}."
- )
- except OSError:
- raise ServiceFileError(
- f"IO error opening server's private "
- f"key from "
- f"{srv_prv_key_info.full_key_path}."
- )
- if server_private_key is None: # this can't be caught by exception
- raise ServiceFileError(
- f"Failed to find server's private "
- f"key in "
- f"{srv_prv_key_info.full_key_path}."
- )
- self.socket.curve_publickey = server_public_key
- self.socket.curve_secretkey = server_private_key
- self.socket.curve_server = True
-
- try:
- if min_port == max_port:
- self.port = min_port
- self.socket.bind(f'tcp://*:{min_port}')
- else:
- self.port = self.socket.bind_to_random_port(
- 'tcp://*', min_port, max_port)
- except (zmq.error.ZMQError, zmq.error.ZMQBindError) as exc:
- raise CylcError(f'could not start Cylc ZMQ server: {exc}')
-
- # Keeping srv_public_key_loc as optional arg so as to not break interface
- def _socket_connect(self, host, port, srv_public_key_loc=None):
- """Connect socket to stub."""
- workflow_srv_dir = get_workflow_srv_dir(self.workflow)
- if srv_public_key_loc is None:
- # Create new KeyInfo object for the server public key
- srv_pub_key_info = KeyInfo(
- KeyType.PUBLIC,
- KeyOwner.SERVER,
- workflow_srv_dir=workflow_srv_dir)
-
- else:
- srv_pub_key_info = KeyInfo(
- KeyType.PUBLIC,
- KeyOwner.SERVER,
- full_key_path=srv_public_key_loc)
-
- self.host = host
- self.port = port
- self.socket = self.context.socket(self.pattern)
- self._socket_options()
-
- client_priv_key_info = KeyInfo(
- KeyType.PRIVATE,
- KeyOwner.CLIENT,
- workflow_srv_dir=workflow_srv_dir)
- error_msg = "Failed to find user's private key, so cannot connect."
- try:
- client_public_key, client_priv_key = zmq.auth.load_certificate(
- client_priv_key_info.full_key_path)
- except (OSError, ValueError):
- raise ClientError(error_msg)
- if client_priv_key is None: # this can't be caught by exception
- raise ClientError(error_msg)
- self.socket.curve_publickey = client_public_key
- self.socket.curve_secretkey = client_priv_key
-
- # A client can only connect to the server if it knows its public key,
- # so we grab this from the location it was created on the filesystem:
- try:
- # 'load_certificate' will try to load both public & private keys
- # from a provided file but will return None, not throw an error,
- # for the latter item if not there (as for all public key files)
- # so it is OK to use; there is no method to load only the
- # public key.
- server_public_key = zmq.auth.load_certificate(
- srv_pub_key_info.full_key_path)[0]
- self.socket.curve_serverkey = server_public_key
- except (OSError, ValueError): # ValueError raised w/ no public key
- raise ClientError(
- "Failed to load the workflow's public key, so cannot connect.")
-
- self.socket.connect(f'tcp://{host}:{port}')
-
- def _socket_options(self):
- """Set socket options.
-
- i.e. self.socket.sndhwm
- """
- self.socket.sndhwm = 10000
-
- def _bespoke_start(self):
- """Initiate bespoke items at start."""
- self.stopping = False
-
- def stop(self, stop_loop=True):
- """Stop the server.
-
- Args:
- stop_loop (Boolean): Stop running IOLoop.
-
- """
- self._bespoke_stop()
- if stop_loop and self.loop and self.loop.is_running():
- self.loop.stop()
- if self.socket and not self.socket.closed:
- self.socket.close()
- LOG.debug('...stopped')
-
- def _bespoke_stop(self):
- """Bespoke stop items."""
- LOG.debug('stopping zmq socket...')
- self.stopping = True
+# Cylc API version.
+# This is the Cylc protocol version number that determines whether a client can
+# communicate with a server. This should be changed when breaking changes are
+# made for which backwards compatibility can not be provided.
+API = 5
diff --git a/cylc/flow/network/base.py b/cylc/flow/network/base.py
new file mode 100644
index 00000000000..1842407b448
--- /dev/null
+++ b/cylc/flow/network/base.py
@@ -0,0 +1,237 @@
+# THIS FILE IS PART OF THE CYLC WORKFLOW ENGINE.
+# Copyright (C) NIWA & British Crown (Met Office) & Contributors.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+"""Base ZMQ socket implementation for network server/client implementations."""
+
+import asyncio
+from typing import Optional
+
+import zmq
+import zmq.asyncio
+import zmq.auth
+
+from cylc.flow import LOG
+from cylc.flow.exceptions import (
+ ClientError,
+ CylcError,
+ ServiceFileError,
+)
+from cylc.flow.workflow_files import (
+ KeyType,
+ KeyOwner,
+ KeyInfo,
+ get_workflow_srv_dir,
+)
+
+
+class ZMQSocketBase:
+ """Initiate the ZMQ socket bind for specified pattern.
+
+ NOTE: Security to be provided via zmq.auth (see PR #3359).
+
+ Args:
+ pattern (enum): ZeroMQ message pattern (zmq.PATTERN).
+
+ context (object, optional): instantiated ZeroMQ context, defaults
+ to zmq.asyncio.Context().
+
+ This class is designed to be inherited by REP Server (REQ/REP)
+ and by PUB Publisher (PUB/SUB), as the start-up logic is similar.
+
+
+ To tailor this class overwrite it's method on inheritance.
+
+ """
+
+ def __init__(
+ self,
+ pattern,
+ workflow: str,
+ bind: bool = False,
+ context: Optional[zmq.Context] = None,
+ ):
+ self.bind = bind
+ if context is None:
+ self.context: zmq.Context = zmq.asyncio.Context()
+ else:
+ self.context = context
+ self.pattern = pattern
+ self.workflow = workflow
+ self.host: Optional[str] = None
+ self.port: Optional[int] = None
+ self.socket: Optional[zmq.Socket] = None
+ self.loop: Optional[asyncio.AbstractEventLoop] = None
+ self.stopping = False
+
+ def start(self, *args, **kwargs):
+ """Create the async loop, and bind socket."""
+ # set asyncio loop
+ try:
+ self.loop = asyncio.get_running_loop()
+ except RuntimeError:
+ self.loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(self.loop)
+
+ if self.bind:
+ self._socket_bind(*args, **kwargs)
+ else:
+ self._socket_connect(*args, **kwargs)
+
+ # initiate bespoke items
+ self._bespoke_start()
+
+ # Keeping srv_prv_key_loc as optional arg so as to not break interface
+ def _socket_bind(self, min_port, max_port, srv_prv_key_loc=None):
+ """Bind socket.
+
+ Will use a port range provided to select random ports.
+
+ """
+ if srv_prv_key_loc is None:
+ # Create new KeyInfo object for the server private key
+ workflow_srv_dir = get_workflow_srv_dir(self.workflow)
+ srv_prv_key_info = KeyInfo(
+ KeyType.PRIVATE,
+ KeyOwner.SERVER,
+ workflow_srv_dir=workflow_srv_dir)
+ else:
+ srv_prv_key_info = KeyInfo(
+ KeyType.PRIVATE,
+ KeyOwner.SERVER,
+ full_key_path=srv_prv_key_loc)
+
+ # create socket
+ self.socket = self.context.socket(self.pattern)
+ self._socket_options()
+
+ try:
+ server_public_key, server_private_key = zmq.auth.load_certificate(
+ srv_prv_key_info.full_key_path)
+ except ValueError:
+ raise ServiceFileError(
+ f"Failed to find server's public "
+ f"key in "
+ f"{srv_prv_key_info.full_key_path}."
+ )
+ except OSError:
+ raise ServiceFileError(
+ f"IO error opening server's private "
+ f"key from "
+ f"{srv_prv_key_info.full_key_path}."
+ )
+ if server_private_key is None: # this can't be caught by exception
+ raise ServiceFileError(
+ f"Failed to find server's private "
+ f"key in "
+ f"{srv_prv_key_info.full_key_path}."
+ )
+ self.socket.curve_publickey = server_public_key
+ self.socket.curve_secretkey = server_private_key
+ self.socket.curve_server = True
+
+ try:
+ if min_port == max_port:
+ self.port = min_port
+ self.socket.bind(f'tcp://*:{min_port}')
+ else:
+ self.port = self.socket.bind_to_random_port(
+ 'tcp://*', min_port, max_port)
+ except (zmq.error.ZMQError, zmq.error.ZMQBindError) as exc:
+ raise CylcError(f'could not start Cylc ZMQ server: {exc}')
+
+ # Keeping srv_public_key_loc as optional arg so as to not break interface
+ def _socket_connect(self, host, port, srv_public_key_loc=None):
+ """Connect socket to stub."""
+ workflow_srv_dir = get_workflow_srv_dir(self.workflow)
+ if srv_public_key_loc is None:
+ # Create new KeyInfo object for the server public key
+ srv_pub_key_info = KeyInfo(
+ KeyType.PUBLIC,
+ KeyOwner.SERVER,
+ workflow_srv_dir=workflow_srv_dir)
+
+ else:
+ srv_pub_key_info = KeyInfo(
+ KeyType.PUBLIC,
+ KeyOwner.SERVER,
+ full_key_path=srv_public_key_loc)
+
+ self.host = host
+ self.port = port
+ self.socket = self.context.socket(self.pattern)
+ self._socket_options()
+
+ client_priv_key_info = KeyInfo(
+ KeyType.PRIVATE,
+ KeyOwner.CLIENT,
+ workflow_srv_dir=workflow_srv_dir)
+ error_msg = "Failed to find user's private key, so cannot connect."
+ try:
+ client_public_key, client_priv_key = zmq.auth.load_certificate(
+ client_priv_key_info.full_key_path)
+ except (OSError, ValueError):
+ raise ClientError(error_msg)
+ if client_priv_key is None: # this can't be caught by exception
+ raise ClientError(error_msg)
+ self.socket.curve_publickey = client_public_key
+ self.socket.curve_secretkey = client_priv_key
+
+ # A client can only connect to the server if it knows its public key,
+ # so we grab this from the location it was created on the filesystem:
+ try:
+ # 'load_certificate' will try to load both public & private keys
+ # from a provided file but will return None, not throw an error,
+ # for the latter item if not there (as for all public key files)
+ # so it is OK to use; there is no method to load only the
+ # public key.
+ server_public_key = zmq.auth.load_certificate(
+ srv_pub_key_info.full_key_path)[0]
+ self.socket.curve_serverkey = server_public_key
+ except (OSError, ValueError): # ValueError raised w/ no public key
+ raise ClientError(
+ "Failed to load the workflow's public key, so cannot connect.")
+
+ self.socket.connect(f'tcp://{host}:{port}')
+
+ def _socket_options(self):
+ """Set socket options.
+
+ i.e. self.socket.sndhwm
+ """
+ self.socket.sndhwm = 10000
+
+ def _bespoke_start(self):
+ """Initiate bespoke items at start."""
+ self.stopping = False
+
+ def stop(self, stop_loop=True):
+ """Stop the server.
+
+ Args:
+ stop_loop (Boolean): Stop running IOLoop.
+
+ """
+ self._bespoke_stop()
+ if stop_loop and self.loop and self.loop.is_running():
+ self.loop.stop()
+ if self.socket and not self.socket.closed:
+ self.socket.close()
+ LOG.debug('...stopped')
+
+ def _bespoke_stop(self):
+ """Bespoke stop items."""
+ LOG.debug('stopping zmq socket...')
+ self.stopping = True
diff --git a/cylc/flow/network/client.py b/cylc/flow/network/client.py
index e7e26954d56..6f8206ee786 100644
--- a/cylc/flow/network/client.py
+++ b/cylc/flow/network/client.py
@@ -35,14 +35,14 @@
WorkflowStopped,
)
from cylc.flow.hostuserutil import get_fqdn_by_host
-from cylc.flow.network import (
+from cylc.flow.network.base import ZMQSocketBase
+from cylc.flow.network.client_factory import CommsMeth
+from cylc.flow.network.server import PB_METHOD_MAP
+from cylc.flow.network.util import (
encode_,
decode_,
get_location,
- ZMQSocketBase
)
-from cylc.flow.network.client_factory import CommsMeth
-from cylc.flow.network.server import PB_METHOD_MAP
from cylc.flow.workflow_files import (
detect_old_contact_file,
)
diff --git a/cylc/flow/network/publisher.py b/cylc/flow/network/publisher.py
index 70d40d3cdb9..78574f9e8c5 100644
--- a/cylc/flow/network/publisher.py
+++ b/cylc/flow/network/publisher.py
@@ -21,7 +21,7 @@
import zmq
from cylc.flow import LOG
-from cylc.flow.network import ZMQSocketBase
+from cylc.flow.network.base import ZMQSocketBase
def serialize_data(
diff --git a/cylc/flow/network/replier.py b/cylc/flow/network/replier.py
index 09bfb55f662..a40756c05b4 100644
--- a/cylc/flow/network/replier.py
+++ b/cylc/flow/network/replier.py
@@ -21,7 +21,8 @@
import zmq
from cylc.flow import LOG
-from cylc.flow.network import encode_, decode_, ZMQSocketBase
+from cylc.flow.network.base import ZMQSocketBase
+from cylc.flow.network.util import encode_, decode_
if TYPE_CHECKING:
from cylc.flow.network.server import WorkflowRuntimeServer
diff --git a/cylc/flow/network/subscriber.py b/cylc/flow/network/subscriber.py
index 66bd16f81f8..28d5b5d1bb2 100644
--- a/cylc/flow/network/subscriber.py
+++ b/cylc/flow/network/subscriber.py
@@ -22,8 +22,9 @@
import zmq
-from cylc.flow.network import ZMQSocketBase, get_location
from cylc.flow.data_store_mgr import DELTAS_MAP
+from cylc.flow.network.base import ZMQSocketBase
+from cylc.flow.network.util import get_location
if TYPE_CHECKING:
import zmq.asyncio
diff --git a/cylc/flow/network/util.py b/cylc/flow/network/util.py
new file mode 100644
index 00000000000..6d1a006060d
--- /dev/null
+++ b/cylc/flow/network/util.py
@@ -0,0 +1,77 @@
+# THIS FILE IS PART OF THE CYLC WORKFLOW ENGINE.
+# Copyright (C) NIWA & British Crown (Met Office) & Contributors.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+"""Common networking utilities."""
+
+import getpass
+import json
+from typing import Tuple
+
+from cylc.flow.exceptions import (
+ CylcVersionError,
+ ServiceFileError,
+ WorkflowStopped
+)
+from cylc.flow.hostuserutil import get_fqdn_by_host
+from cylc.flow.workflow_files import (
+ ContactFileFields,
+ load_contact_file,
+)
+
+
+def encode_(message):
+ """Convert the structure holding a message field from JSON to a string."""
+ try:
+ return json.dumps(message)
+ except TypeError as exc:
+ return json.dumps({'errors': [{'message': str(exc)}]})
+
+
+def decode_(message):
+ """Convert an encoded message string to JSON with an added 'user' field."""
+ msg = json.loads(message)
+ msg['user'] = getpass.getuser() # assume this is the user
+ return msg
+
+
+def get_location(workflow: str) -> Tuple[str, int, int]:
+ """Extract host and port from a workflow's contact file.
+
+ NB: if it fails to load the workflow contact file, it will exit.
+
+ Args:
+ workflow: workflow ID
+ Returns:
+ Tuple (host name, port number, publish port number)
+ Raises:
+ WorkflowStopped: if the workflow is not running.
+ CylcVersionError: if target is a Cylc 7 (or earlier) workflow.
+ """
+ try:
+ contact = load_contact_file(workflow)
+ except (IOError, ValueError, ServiceFileError):
+ # Contact file does not exist or corrupted, workflow should be dead
+ raise WorkflowStopped(workflow)
+
+ host = contact[ContactFileFields.HOST]
+ host = get_fqdn_by_host(host)
+ port = int(contact[ContactFileFields.PORT])
+ if ContactFileFields.PUBLISH_PORT in contact:
+ pub_port = int(contact[ContactFileFields.PUBLISH_PORT])
+ else:
+ version = contact.get('CYLC_VERSION', None)
+ raise CylcVersionError(version=version)
+ return host, port, pub_port
diff --git a/cylc/flow/scripts/subscribe.py b/cylc/flow/scripts/subscribe.py
index 5d174718c23..bd42dd6c90d 100755
--- a/cylc/flow/scripts/subscribe.py
+++ b/cylc/flow/scripts/subscribe.py
@@ -34,7 +34,7 @@
WORKFLOW_ID_ARG_DOC,
CylcOptionParser as COP,
)
-from cylc.flow.network import get_location
+from cylc.flow.network.util import get_location
from cylc.flow.network.subscriber import WorkflowSubscriber, process_delta_msg
from cylc.flow.terminal import cli_function
from cylc.flow.data_store_mgr import DELTAS_MAP
diff --git a/tests/integration/test_replier.py b/tests/integration/test_replier.py
index ce0b53fdaa8..7e219e8dd44 100644
--- a/tests/integration/test_replier.py
+++ b/tests/integration/test_replier.py
@@ -14,11 +14,12 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
-from async_timeout import timeout
-from cylc.flow.network import decode_
-from cylc.flow.network.client import WorkflowRuntimeClient
import asyncio
+from cylc.flow.network.client import WorkflowRuntimeClient
+from cylc.flow.network.util import decode_
+
+from async_timeout import timeout
import pytest
diff --git a/tests/integration/test_zmq.py b/tests/integration/test_zmq.py
index 24c8db6d9b0..41ef6be2767 100644
--- a/tests/integration/test_zmq.py
+++ b/tests/integration/test_zmq.py
@@ -18,7 +18,7 @@
import zmq
from cylc.flow.exceptions import CylcError
-from cylc.flow.network import ZMQSocketBase
+from cylc.flow.network.base import ZMQSocketBase
from .key_setup import setup_keys
diff --git a/tests/unit/network/test__init__.py b/tests/unit/network/test_network_util.py
similarity index 80%
rename from tests/unit/network/test__init__.py
rename to tests/unit/network/test_network_util.py
index 71c32cf9bd1..1d40ca3b91e 100644
--- a/tests/unit/network/test__init__.py
+++ b/tests/unit/network/test_network_util.py
@@ -17,10 +17,10 @@
import pytest
-import cylc
+import cylc.flow.network.util
from cylc.flow.exceptions import CylcVersionError
-from cylc.flow.network import get_location
-from cylc.flow.workflow_files import load_contact_file, ContactFileFields
+from cylc.flow.network.util import get_location
+from cylc.flow.workflow_files import ContactFileFields
BASE_CONTACT_DATA = {
@@ -33,7 +33,7 @@
def mpatch_get_fqdn_by_host(monkeypatch):
"""Monkeypatch function used the same by all tests."""
monkeypatch.setattr(
- cylc.flow.network, 'get_fqdn_by_host', lambda _ : 'myhost.x.y.z'
+ cylc.flow.network.util, 'get_fqdn_by_host', lambda _: 'myhost.x.y.z'
)
@@ -42,7 +42,7 @@ def test_get_location_ok(monkeypatch, mpatch_get_fqdn_by_host):
contact_data = BASE_CONTACT_DATA.copy()
contact_data[ContactFileFields.PUBLISH_PORT] = '8042'
monkeypatch.setattr(
- cylc.flow.network, 'load_contact_file', lambda _ : contact_data
+ cylc.flow.network.util, 'load_contact_file', lambda _: contact_data
)
assert get_location('_') == (
'myhost.x.y.z', 42, 8042
@@ -55,7 +55,7 @@ def test_get_location_old_contact_file(monkeypatch, mpatch_get_fqdn_by_host):
contact_data['CYLC_SUITE_PUBLISH_PORT'] = '8042'
contact_data['CYLC_VERSION'] = '5.1.2'
monkeypatch.setattr(
- cylc.flow.network, 'load_contact_file', lambda _ : contact_data
+ cylc.flow.network.util, 'load_contact_file', lambda _: contact_data
)
- with pytest.raises(CylcVersionError, match=r'.*5.1.2.*') as exc:
+ with pytest.raises(CylcVersionError, match=r'.*5.1.2.*'):
get_location('_')