diff --git a/src/charm.py b/src/charm.py index 53df2de..cbb1aea 100755 --- a/src/charm.py +++ b/src/charm.py @@ -3,6 +3,7 @@ # Licensed under the GPLv3, see LICENSE file for details. import controlsocket +import configchangesocket import json import logging import secrets @@ -14,7 +15,7 @@ from ops.framework import StoredState from ops.charm import InstallEvent, RelationJoinedEvent, RelationDepartedEvent from ops.main import main -from ops.model import ActiveStatus, BlockedStatus, ErrorStatus, Relation +from ops.model import ActiveStatus, BlockedStatus, Relation from pathlib import Path from typing import List @@ -22,6 +23,8 @@ class JujuControllerCharm(CharmBase): + METRICS_SOCKET_PATH = '/var/lib/juju/control.socket' + CONFIG_SOCKET_PATH = '/var/lib/juju/configchange.socket' DB_BIND_ADDR_KEY = 'db-bind-address' ALL_BIND_ADDRS_KEY = 'db-bind-addresses' @@ -30,6 +33,23 @@ class JujuControllerCharm(CharmBase): def __init__(self, *args): super().__init__(*args) + self._observe() + + self._stored.set_default( + db_bind_address='', + last_bind_addresses=[], + all_bind_addresses=dict(), + ) + + # TODO (manadart 2024-03-05): Get these at need. + # No need to instantiate them for every invocatoin. + self._control_socket = controlsocket.ControlSocketClient( + socket_path=self.METRICS_SOCKET_PATH) + self._config_change_socket = configchangesocket.ConfigChangeSocketClient( + socket_path=self.CONFIG_SOCKET_PATH) + + def _observe(self): + """Set up all framework event observers.""" self.framework.observe(self.on.install, self._on_install) self.framework.observe(self.on.collect_unit_status, self._on_collect_status) self.framework.observe(self.on.config_changed, self._on_config_changed) @@ -37,18 +57,12 @@ def __init__(self, *args): self.on.dashboard_relation_joined, self._on_dashboard_relation_joined) self.framework.observe( self.on.website_relation_joined, self._on_website_relation_joined) - - self._stored.set_default( - db_bind_address='', last_bind_addresses=[], all_bind_addresses=dict()) - self.framework.observe( - self.on.dbcluster_relation_changed, self._on_dbcluster_relation_changed) - - self.control_socket = controlsocket.Client( - socket_path='/var/lib/juju/control.socket') self.framework.observe( self.on.metrics_endpoint_relation_created, self._on_metrics_endpoint_relation_created) self.framework.observe( self.on.metrics_endpoint_relation_broken, self._on_metrics_endpoint_relation_broken) + self.framework.observe( + self.on.dbcluster_relation_changed, self._on_dbcluster_relation_changed) def _on_install(self, event: InstallEvent): """Ensure that the controller configuration file exists.""" @@ -64,7 +78,7 @@ def _on_collect_status(self, event: CollectStatusEvent): try: self.api_port() except AgentConfException as e: - event.add_status(ErrorStatus( + event.add_status(BlockedStatus( f'cannot read controller API port from agent configuration: {e}')) event.add_status(ActiveStatus()) @@ -108,7 +122,7 @@ def _on_website_relation_joined(self, event): def _on_metrics_endpoint_relation_created(self, event: RelationJoinedEvent): username = metrics_username(event.relation) password = generate_password() - self.control_socket.add_metrics_user(username, password) + self._control_socket.add_metrics_user(username, password) # Set up Prometheus scrape config try: @@ -141,7 +155,7 @@ def _on_metrics_endpoint_relation_created(self, event: RelationJoinedEvent): def _on_metrics_endpoint_relation_broken(self, event: RelationDepartedEvent): username = metrics_username(event.relation) - self.control_socket.remove_metrics_user(username) + self._control_socket.remove_metrics_user(username) def _on_dbcluster_relation_changed(self, event): """Maintain our own bind address in relation data. @@ -201,6 +215,8 @@ def _ensure_db_bind_address(self, relation): self._stored.db_bind_address = ip def _update_config_file(self, bind_addresses): + logger.info('writing new DB cluster to config file: %s', bind_addresses) + file_path = self._controller_config_path() with open(file_path) as conf_file: conf = yaml.safe_load(conf_file) @@ -212,6 +228,7 @@ def _update_config_file(self, bind_addresses): with open(file_path, 'w') as conf_file: yaml.dump(conf, conf_file) + self._request_config_reload() self._stored.all_bind_addresses = bind_addresses def api_port(self) -> str: @@ -241,8 +258,15 @@ def _agent_conf(self, key: str): return agent_conf.get(key) def _controller_config_path(self) -> str: - unit_num = self.unit.name.split('/')[1] - return f'/var/lib/juju/agents/controller-{unit_num}/agent.conf' + """Interrogate the running controller jujud service to determine + the local controller ID, then use it to construct a config path. + """ + controller_id = self._config_change_socket.get_controller_agent_id() + return f'/var/lib/juju/agents/controller-{controller_id}/agent.conf' + + def _request_config_reload(self): + """Send a reload request to the config reload socket""" + self._config_change_socket.reload_config() def metrics_username(relation: Relation) -> str: @@ -259,7 +283,11 @@ def generate_password() -> str: class AgentConfException(Exception): - """Raised when there are errors reading info from agent.conf.""" + """Raised when there are errors regarding agent configuration.""" + + +class ControllerProcessException(Exception): + """Raised when there are errors regarding detection of controller service or process.""" if __name__ == "__main__": diff --git a/src/configchangesocket.py b/src/configchangesocket.py new file mode 100644 index 0000000..5e7ab42 --- /dev/null +++ b/src/configchangesocket.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 +# Copyright 2023 Canonical Ltd. +# Licensed under the GPLv3, see LICENSE file for details. +import urllib +from typing import Optional + +import unixsocket +import logging + +logger = logging.getLogger(__name__) + + +class ConfigChangeSocketClient(unixsocket.SocketClient): + """ + Client to the Juju config change socket. + """ + def __init__(self, socket_path: str, + opener: Optional[urllib.request.OpenerDirector] = None): + super().__init__(socket_path, opener=opener) + + def get_controller_agent_id(self): + resp = self.request_raw(path='/agent-id', method='GET') + return resp.read().decode('utf-8') + + def reload_config(self): + self.request_raw(path='/reload', method='POST') diff --git a/src/controlsocket.py b/src/controlsocket.py index bcf9307..ab81946 100644 --- a/src/controlsocket.py +++ b/src/controlsocket.py @@ -1,136 +1,25 @@ #!/usr/bin/env python3 # Copyright 2023 Canonical Ltd. # Licensed under the GPLv3, see LICENSE file for details. +import urllib +from typing import Optional -import email.message -import email.parser -import http.client -import json +import unixsocket import logging -import socket -import sys -import urllib.error -import urllib.parse -import urllib.request -from typing import ( - Any, - Dict, - Generator, - Literal, - Optional, - Union, -) logger = logging.getLogger(__name__) -class Client: +class ControlSocketClient(unixsocket.SocketClient): """ - Client to the Juju control socket. - - Defaults to using a Unix socket at socket_path (which must be specified - unless a custom opener is provided). - - Originally copy-pasted from ops.pebble.Client. + Client to Juju control socket. """ - def __init__(self, socket_path: str, - opener: Optional[urllib.request.OpenerDirector] = None, - base_url: str = 'http://localhost', - timeout: float = 5.0): - if not isinstance(socket_path, str): - raise TypeError(f'`socket_path` should be a string, not: {type(socket_path)}') - if opener is None: - opener = self._get_default_opener(socket_path) - self.socket_path = socket_path - self.opener = opener - self.base_url = base_url - self.timeout = timeout - - @classmethod - def _get_default_opener(cls, socket_path: str) -> urllib.request.OpenerDirector: - """Build the default opener to use for requests (HTTP over Unix socket).""" - opener = urllib.request.OpenerDirector() - opener.add_handler(_UnixSocketHandler(socket_path)) - opener.add_handler(urllib.request.HTTPDefaultErrorHandler()) - opener.add_handler(urllib.request.HTTPRedirectHandler()) - opener.add_handler(urllib.request.HTTPErrorProcessor()) - return opener - - # we need to cast the return type depending on the request params - def _request(self, - method: str, - path: str, - query: Optional[Dict[str, Any]] = None, - body: Optional[Dict[str, Any]] = None - ) -> Dict[str, Any]: - """Make a JSON request to the socket with the given HTTP method and path. - - If query dict is provided, it is encoded and appended as a query string - to the URL. If body dict is provided, it is serialied as JSON and used - as the HTTP body (with Content-Type: "application/json"). The resulting - body is decoded from JSON. - """ - headers = {'Accept': 'application/json'} - data = None - if body is not None: - data = json.dumps(body).encode('utf-8') - headers['Content-Type'] = 'application/json' - - response = self._request_raw(method, path, query, headers, data) - self._ensure_content_type(response.headers, 'application/json') - raw_resp: Dict[str, Any] = json.loads(response.read()) - return raw_resp - - @staticmethod - def _ensure_content_type(headers: email.message.Message, - expected: 'Literal["multipart/form-data", "application/json"]'): - """Parse Content-Type header from headers and ensure it's equal to expected. - - Return a dict of any options in the header, e.g., {'boundary': ...}. - """ - ctype = headers.get_content_type() - params = headers.get_params() or {} - options = {key: value for key, value in params if value} - if ctype != expected: - raise ProtocolError(f'expected Content-Type {expected!r}, got {ctype!r}') - return options - - def _request_raw( - self, method: str, path: str, - query: Optional[Dict[str, Any]] = None, - headers: Optional[Dict[str, Any]] = None, - data: Optional[Union[bytes, Generator[bytes, Any, Any]]] = None, - ) -> http.client.HTTPResponse: - """Make a request to the socket; return the raw HTTPResponse object.""" - url = self.base_url + path - if query: - url = f"{url}?{urllib.parse.urlencode(query, doseq=True)}" - - if headers is None: - headers = {} - request = urllib.request.Request(url, method=method, data=data, headers=headers) - - try: - response = self.opener.open(request, timeout=self.timeout) - except urllib.error.HTTPError as e: - code = e.code - status = e.reason - try: - body: Dict[str, Any] = json.loads(e.read()) - message: str = body['error'] - except (OSError, ValueError, KeyError) as e2: - # Will only happen on read error or if the server sends invalid JSON. - body: Dict[str, Any] = {} - message = f'{type(e2).__name__} - {e2}' - raise APIError(body, code, status, message) - except urllib.error.URLError as e: - raise ConnectionError(e.reason) - - return response + opener: Optional[urllib.request.OpenerDirector] = None): + super().__init__(socket_path, opener=opener) def add_metrics_user(self, username: str, password: str): - resp = self._request( + resp = self.json_request( method='POST', path='/metrics-users', body={"username": username, "password": password}, @@ -138,92 +27,8 @@ def add_metrics_user(self, username: str, password: str): logger.debug('result of add_metrics_user request: %r', resp) def remove_metrics_user(self, username: str): - resp = self._request( + resp = self.json_request( method='DELETE', path=f'/metrics-users/{username}', ) logger.debug('result of remove_metrics_user request: %r', resp) - - -class _NotProvidedFlag: - pass - - -_not_provided = _NotProvidedFlag() - - -class _UnixSocketConnection(http.client.HTTPConnection): - """Implementation of HTTPConnection that connects to a named Unix socket.""" - - def __init__(self, host: str, socket_path: str, - timeout: Union[_NotProvidedFlag, float] = _not_provided): - if timeout is _not_provided: - super().__init__(host) - else: - assert isinstance(timeout, (int, float)), timeout # type guard for pyright - super().__init__(host, timeout=timeout) - self.socket_path = socket_path - - def connect(self): - """Override connect to use Unix socket (instead of TCP socket).""" - if not hasattr(socket, 'AF_UNIX'): - raise NotImplementedError(f'Unix sockets not supported on {sys.platform}') - self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - self.sock.connect(self.socket_path) - if self.timeout is not _not_provided: - self.sock.settimeout(self.timeout) - - -class _UnixSocketHandler(urllib.request.AbstractHTTPHandler): - """Implementation of HTTPHandler that uses a named Unix socket.""" - - def __init__(self, socket_path: str): - super().__init__() - self.socket_path = socket_path - - def http_open(self, req: urllib.request.Request): - """Override http_open to use a Unix socket connection (instead of TCP).""" - return self.do_open(_UnixSocketConnection, req, # type:ignore - socket_path=self.socket_path) - - -class Error(Exception): - """Base class of most errors raised by the client.""" - - def __repr__(self): - return f'<{type(self).__module__}.{type(self).__name__} {self.args}>' - - -class ProtocolError(Error): - """Raised when there's a higher-level protocol error talking to the socket.""" - - -class ConnectionError(Error): - """Raised when the client can't connect to the socket.""" - - -class APIError(Error): - """Raised when an HTTP API error occurs talking to the Pebble server.""" - - body: Dict[str, Any] - """Body of the HTTP response, parsed as JSON.""" - - code: int - """HTTP status code.""" - - status: str - """HTTP status string (reason).""" - - message: str - """Human-readable error message from the API.""" - - def __init__(self, body: Dict[str, Any], code: int, status: str, message: str): - """This shouldn't be instantiated directly.""" - super().__init__(message) # Makes str(e) return message - self.body = body - self.code = code - self.status = status - self.message = message - - def __repr__(self): - return f'APIError({self.body!r}, {self.code!r}, {self.status!r}, {self.message!r})' diff --git a/src/unixsocket.py b/src/unixsocket.py new file mode 100644 index 0000000..f3069e6 --- /dev/null +++ b/src/unixsocket.py @@ -0,0 +1,205 @@ +import email.message +import email.parser +import http.client +import json +import socket +import sys +import urllib.error +import urllib.parse +import urllib.request +from typing import ( + Any, + Dict, + Generator, + Literal, + Optional, + Union, +) + + +class SocketClient: + """ + Defaults to using a Unix socket at socket_path (which must be specified + unless a custom opener is provided). + + Originally copy-pasted from ops.pebble.Client. + """ + + def __init__(self, socket_path: str, + opener: Optional[urllib.request.OpenerDirector] = None, + base_url: str = 'http://localhost', + timeout: float = 5.0): + if not isinstance(socket_path, str): + raise TypeError(f'`socket_path` should be a string, not: {type(socket_path)}') + if opener is None: + opener = self._get_default_opener(socket_path) + self.socket_path = socket_path + self.opener = opener + self.base_url = base_url + self.timeout = timeout + + @classmethod + def _get_default_opener(cls, socket_path: str) -> urllib.request.OpenerDirector: + """Build the default opener to use for requests (HTTP over Unix socket).""" + opener = urllib.request.OpenerDirector() + opener.add_handler(_UnixSocketHandler(socket_path)) + opener.add_handler(urllib.request.HTTPDefaultErrorHandler()) + opener.add_handler(urllib.request.HTTPRedirectHandler()) + opener.add_handler(urllib.request.HTTPErrorProcessor()) + return opener + + # we need to cast the return type depending on the request params + def json_request(self, + method: str, + path: str, + query: Optional[Dict[str, Any]] = None, + body: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """Make a JSON request to the socket with the given HTTP method and path. + + If query dict is provided, it is encoded and appended as a query string + to the URL. If body dict is provided, it is serialied as JSON and used + as the HTTP body (with Content-Type: "application/json"). The resulting + body is decoded from JSON. + """ + headers = {'Accept': 'application/json'} + data = None + if body is not None: + data = json.dumps(body).encode('utf-8') + headers['Content-Type'] = 'application/json' + + response = self.request_raw(method, path, query, headers, data) + self._ensure_content_type(response.headers, 'application/json') + raw_resp: Dict[str, Any] = json.loads(response.read()) + return raw_resp + + @staticmethod + def _ensure_content_type(headers: email.message.Message, + expected: 'Literal["multipart/form-data", "application/json"]'): + """Parse Content-Type header from headers and ensure it's equal to expected. + + Return a dict of any options in the header, e.g., {'boundary': ...}. + """ + ctype = headers.get_content_type() + params = headers.get_params() or {} + options = {key: value for key, value in params if value} + if ctype != expected: + raise ProtocolError(f'expected Content-Type {expected!r}, got {ctype!r}') + return options + + def request_raw( + self, method: str, path: str, + query: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, Any]] = None, + data: Optional[Union[bytes, Generator[bytes, Any, Any]]] = None, + ) -> http.client.HTTPResponse: + """Make a request to the socket; return the raw HTTPResponse object.""" + url = self.base_url + path + if query: + url = f"{url}?{urllib.parse.urlencode(query, doseq=True)}" + + if headers is None: + headers = {} + request = urllib.request.Request(url, method=method, data=data, headers=headers) + + try: + response = self.opener.open(request, timeout=self.timeout) + except urllib.error.HTTPError as e: + code = e.code + status = e.reason + try: + body: Dict[str, Any] = json.loads(e.read()) + message: str = body['error'] + except (OSError, ValueError, KeyError) as e2: + # Will only happen on read error or if the server sends invalid JSON. + body: Dict[str, Any] = {} + message = f'{type(e2).__name__} - {e2}' + raise APIError(body, code, status, message) + except urllib.error.URLError as e: + raise ConnectionError(e.reason) + + return response + + +class _NotProvidedFlag: + pass + + +_not_provided = _NotProvidedFlag() + + +class _UnixSocketHandler(urllib.request.AbstractHTTPHandler): + """Implementation of HTTPHandler that uses a named Unix socket.""" + + def __init__(self, socket_path: str): + super().__init__() + self.socket_path = socket_path + + def http_open(self, req: urllib.request.Request): + """Override http_open to use a Unix socket connection (instead of TCP).""" + return self.do_open(_UnixSocketConnection, req, # type:ignore + socket_path=self.socket_path) + + +class _UnixSocketConnection(http.client.HTTPConnection): + """Implementation of HTTPConnection that connects to a named Unix socket.""" + + def __init__(self, host: str, socket_path: str, + timeout: Union[_NotProvidedFlag, float] = _not_provided): + if timeout is _not_provided: + super().__init__(host) + else: + assert isinstance(timeout, (int, float)), timeout # type guard for pyright + super().__init__(host, timeout=timeout) + self.socket_path = socket_path + + def connect(self): + """Override connect to use Unix socket (instead of TCP socket).""" + if not hasattr(socket, 'AF_UNIX'): + raise NotImplementedError(f'Unix sockets not supported on {sys.platform}') + self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.sock.connect(self.socket_path) + if self.timeout is not _not_provided: + self.sock.settimeout(self.timeout) + + +class Error(Exception): + """Base class of most errors raised by the client.""" + + def __repr__(self): + return f'<{type(self).__module__}.{type(self).__name__} {self.args}>' + + +class ProtocolError(Error): + """Raised when there's a higher-level protocol error talking to the socket.""" + + +class ConnectionError(Error): + """Raised when the client can't connect to the socket.""" + + +class APIError(Error): + """Raised when an HTTP API error occurs talking to the Pebble server.""" + + body: Dict[str, Any] + """Body of the HTTP response, parsed as JSON.""" + + code: int + """HTTP status code.""" + + status: str + """HTTP status string (reason).""" + + message: str + """Human-readable error message from the API.""" + + def __init__(self, body: Dict[str, Any], code: int, status: str, message: str): + """This shouldn't be instantiated directly.""" + super().__init__(message) # Makes str(e) return message + self.body = body + self.code = code + self.status = status + self.message = message + + def __repr__(self): + return f'APIError({self.body!r}, {self.code!r}, {self.status!r}, {self.message!r})' diff --git a/tests/test_charm.py b/tests/test_charm.py index 2e9fd78..c178200 100644 --- a/tests/test_charm.py +++ b/tests/test_charm.py @@ -5,10 +5,11 @@ import json import os import unittest + import yaml from charm import JujuControllerCharm, AgentConfException -from ops.model import BlockedStatus, ActiveStatus, ErrorStatus +from ops.model import BlockedStatus, ActiveStatus from ops.testing import Harness from unittest.mock import mock_open, patch @@ -84,8 +85,8 @@ def test_website_relation_joined(self, _, binding): @patch("builtins.open", new_callable=mock_open, read_data=agent_conf) @patch("charm.MetricsEndpointProvider", autospec=True) @patch("charm.generate_password", new=lambda: "passwd") - @patch("controlsocket.Client.add_metrics_user") - @patch("controlsocket.Client.remove_metrics_user") + @patch("controlsocket.ControlSocketClient.add_metrics_user") + @patch("controlsocket.ControlSocketClient.remove_metrics_user") def test_metrics_endpoint_relation(self, mock_remove_user, mock_add_user, mock_metrics_provider, _): harness = self.harness @@ -132,41 +133,44 @@ def test_apiaddresses_not_list(self, _): harness.charm.api_port() @patch("builtins.open", new_callable=mock_open, read_data=agent_conf_apiaddresses_missing) - @patch("controlsocket.Client.add_metrics_user") + @patch("controlsocket.ControlSocketClient.add_metrics_user") def test_apiaddresses_missing_status(self, *_): harness = self.harness harness.add_relation('metrics-endpoint', 'prometheus-k8s') harness.evaluate_status() - self.assertIsInstance(harness.charm.unit.status, ErrorStatus) + self.assertIsInstance(harness.charm.unit.status, BlockedStatus) @patch("builtins.open", new_callable=mock_open, read_data=agent_conf_ipv4) def test_apiaddresses_ipv4(self, _): - harness = self.harness - - self.assertEqual(harness.charm.api_port(), 17070) + self.assertEqual(self.harness.charm.api_port(), 17070) @patch("builtins.open", new_callable=mock_open, read_data=agent_conf_ipv6) def test_apiaddresses_ipv6(self, _): - harness = self.harness - - self.assertEqual(harness.charm.api_port(), 17070) + self.assertEqual(self.harness.charm.api_port(), 17070) @patch("builtins.open", new_callable=mock_open, read_data=agent_conf) + @patch("configchangesocket.ConfigChangeSocketClient.get_controller_agent_id") @patch("ops.model.Model.get_binding") - def test_dbcluster_relation_changed_single_addr(self, binding, _): + @patch("configchangesocket.ConfigChangeSocketClient.reload_config") + def test_dbcluster_relation_changed_single_addr( + self, mock_reload_config, mock_get_binding, mock_get_agent_id, *__): harness = self.harness - binding.return_value = mockBinding(['192.168.1.17']) + mock_get_binding.return_value = mockBinding(['192.168.1.17']) + + mock_get_agent_id.return_value = '0' harness.set_leader() # Have another unit enter the relation. # Its bind address should end up in the application data bindings list. - relation_id = harness.add_relation('dbcluster', harness.charm.app.name) + relation_id = harness.add_relation('dbcluster', harness.charm.app) harness.add_relation_unit(relation_id, 'juju-controller/1') self.harness.update_relation_data( relation_id, 'juju-controller/1', {'db-bind-address': '192.168.1.100'}) + mock_reload_config.assert_called_once() + unit_data = harness.get_relation_data(relation_id, 'juju-controller/0') self.assertEqual(unit_data['db-bind-address'], '192.168.1.17') @@ -183,7 +187,7 @@ def test_dbcluster_relation_changed_multi_addr_error(self, binding, _): harness = self.harness binding.return_value = mockBinding(["192.168.1.17", "192.168.1.18"]) - relation_id = harness.add_relation('dbcluster', harness.charm.app.name) + relation_id = harness.add_relation('dbcluster', harness.charm.app) harness.add_relation_unit(relation_id, 'juju-controller/1') self.harness.update_relation_data( @@ -192,11 +196,17 @@ def test_dbcluster_relation_changed_multi_addr_error(self, binding, _): harness.evaluate_status() self.assertIsInstance(harness.charm.unit.status, BlockedStatus) + @patch("configchangesocket.ConfigChangeSocketClient.get_controller_agent_id") @patch("builtins.open", new_callable=mock_open) @patch("ops.model.Model.get_binding") - def test_dbcluster_relation_changed_write_file(self, binding, mock_open): + @patch("configchangesocket.ConfigChangeSocketClient.reload_config") + def test_dbcluster_relation_changed_write_file( + self, mock_reload_config, mock_get_binding, mock_open, mock_get_agent_id): + harness = self.harness - binding.return_value = mockBinding(['192.168.1.17']) + mock_get_binding.return_value = mockBinding(['192.168.1.17']) + + mock_get_agent_id.return_value = '0' relation_id = harness.add_relation('dbcluster', harness.charm.app) harness.add_relation_unit(relation_id, 'juju-controller/1') @@ -208,12 +218,12 @@ def test_dbcluster_relation_changed_write_file(self, binding, mock_open): self.assertEqual(mock_open.call_count, 2) # First call to read out the YAML - first_call_args, _ = mock_open.call_args_list[0] - self.assertEqual(first_call_args, (file_path,)) + first_open_args, _ = mock_open.call_args_list[0] + self.assertEqual(first_open_args, (file_path,)) # Second call to write the updated YAML. - second_call_args, _ = mock_open.call_args_list[1] - self.assertEqual(second_call_args, (file_path, 'w')) + second_open_args, _ = mock_open.call_args_list[1] + self.assertEqual(second_open_args, (file_path, 'w')) # yaml.dump appears to write the the file incrementally, # so we need to hoover up the call args to reconstruct. @@ -223,6 +233,9 @@ def test_dbcluster_relation_changed_write_file(self, binding, mock_open): self.assertEqual(yaml.safe_load(written), {'db-bind-addresses': bound}) + # The last thing we should have done is send a reload request via the socket.. + mock_reload_config.assert_called_once() + class mockNetwork: def __init__(self, addresses): diff --git a/tests/test_controlsocket.py b/tests/test_sockets.py similarity index 74% rename from tests/test_controlsocket.py rename to tests/test_sockets.py index 9969b9a..6745666 100644 --- a/tests/test_controlsocket.py +++ b/tests/test_sockets.py @@ -5,19 +5,20 @@ import io import unittest import urllib.error -from controlsocket import Client, APIError, ConnectionError +from controlsocket import ControlSocketClient +from configchangesocket import ConfigChangeSocketClient +from unixsocket import APIError, ConnectionError class TestClass(unittest.TestCase): def test_add_metrics_user_success(self): mock_opener = MockOpener(self) - control_socket = Client('fake_socket_path', opener=mock_opener) + control_socket = ControlSocketClient('fake_socket_path', opener=mock_opener) mock_opener.expect( url='http://localhost/metrics-users', method='POST', body=r'{"username": "juju-metrics-r0", "password": "passwd"}', - response=MockResponse( headers=MockHeaders(content_type='application/json'), body=r'{"message":"created user \"juju-metrics-r0\""}' @@ -27,13 +28,12 @@ def test_add_metrics_user_success(self): def test_add_metrics_user_fail(self): mock_opener = MockOpener(self) - control_socket = Client('fake_socket_path', opener=mock_opener) + control_socket = ControlSocketClient('fake_socket_path', opener=mock_opener) mock_opener.expect( url='http://localhost/metrics-users', method='POST', body=r'{"username": "juju-metrics-r0", "password": "passwd"}', - error=urllib.error.HTTPError( url='http://localhost/metrics-users', code=409, @@ -52,13 +52,12 @@ def test_add_metrics_user_fail(self): def test_remove_metrics_user_success(self): mock_opener = MockOpener(self) - control_socket = Client('fake_socket_path', opener=mock_opener) + control_socket = ControlSocketClient('fake_socket_path', opener=mock_opener) mock_opener.expect( url='http://localhost/metrics-users/juju-metrics-r0', method='DELETE', body=None, - response=MockResponse( headers=MockHeaders(content_type='application/json'), body=r'{"message":"deleted user \"juju-metrics-r0\""}' @@ -68,13 +67,12 @@ def test_remove_metrics_user_success(self): def test_remove_metrics_user_fail(self): mock_opener = MockOpener(self) - control_socket = Client('fake_socket_path', opener=mock_opener) + control_socket = ControlSocketClient('fake_socket_path', opener=mock_opener) mock_opener.expect( url='http://localhost/metrics-users/juju-metrics-r0', method='DELETE', body=None, - error=urllib.error.HTTPError( url='http://localhost/metrics-users/juju-metrics-r0', code=404, @@ -93,19 +91,48 @@ def test_remove_metrics_user_fail(self): def test_connection_error(self): mock_opener = MockOpener(self) - control_socket = Client('fake_socket_path', opener=mock_opener) + control_socket = ControlSocketClient('fake_socket_path', opener=mock_opener) mock_opener.expect( url='http://localhost/metrics-users', method='POST', body=r'{"username": "juju-metrics-r0", "password": "passwd"}', - error=urllib.error.URLError('could not connect to socket') ) with self.assertRaisesRegex(ConnectionError, 'could not connect to socket'): control_socket.add_metrics_user('juju-metrics-r0', 'passwd') + def test_get_controller_agent_id(self): + mock_opener = MockOpener(self) + config_reload_socket = ConfigChangeSocketClient('fake_socket_path', opener=mock_opener) + + mock_opener.expect( + url='http://localhost/agent-id', + method='GET', + body=None, + response=MockResponse( + headers=MockHeaders(content_type='application/text'), + body=b'666' + ) + ) + + id = config_reload_socket.get_controller_agent_id() + self.assertEqual(id, '666') + + def test_reload_config(self): + mock_opener = MockOpener(self) + config_reload_socket = ConfigChangeSocketClient('fake_socket_path', opener=mock_opener) + + mock_opener.expect( + url='http://localhost/reload', + method='POST', + body=None, + response=None, + ) + + config_reload_socket.reload_config() + class MockOpener: def __init__(self, test_case): @@ -127,14 +154,14 @@ def open(self, request, timeout): else: self.test.assertEqual(request.data.decode('utf-8'), self.body) - if self.response: - return self.response - else: + if self.error: raise self.error + else: + return self.response class MockResponse: - def __init__(self, headers, body): + def __init__(self, headers, body=None): self.headers = headers self.body = body @@ -143,7 +170,7 @@ def read(self): class MockHeaders: - def __init__(self, content_type, params=None): + def __init__(self, content_type=None, params=None): self.content_type = content_type self.params = params