diff --git a/pyproject.toml b/pyproject.toml index 64ed3f8..f0b00b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,8 +48,7 @@ dependencies = [ "prometheus-client", "python-dateutil", "pyyaml", - "sqlalchemy<1.4", - "sqlalchemy-aio>=0.17", + "sqlalchemy>=2", "toolrack>=4", ] optional-dependencies.testing = [ diff --git a/query_exporter/db.py b/query_exporter/db.py index 7406818..a168684 100644 --- a/query_exporter/db.py +++ b/query_exporter/db.py @@ -1,14 +1,24 @@ """Database wrapper.""" import asyncio -from collections.abc import Iterable +from collections.abc import ( + Callable, + Iterable, + Sequence, +) +from concurrent import futures from dataclasses import ( dataclass, field, ) +from functools import partial from itertools import chain import logging import sys +from threading import ( + Thread, + current_thread, +) from time import ( perf_counter, time, @@ -28,16 +38,17 @@ event, text, ) +from sqlalchemy.engine import ( + Connection, + CursorResult, + Engine, + Row, +) from sqlalchemy.exc import ( ArgumentError, NoSuchModuleError, ) -from sqlalchemy_aio import ASYNCIO_STRATEGY -from sqlalchemy_aio.asyncio import AsyncioEngine -from sqlalchemy_aio.base import ( - AsyncConnection, - AsyncResultProxy, -) +from sqlalchemy.sql.elements import TextClause #: Timeout for a query QueryTimeout = int | float @@ -116,7 +127,7 @@ def __init__(self, query_name: str, message: str) -> None: ) -# database errors that mean the query won't ever succeed. Not all possible +# Database errors that mean the query won't ever succeed. Not all possible # fatal errors are tracked here, because some DBAPI errors can happen in # circumstances which can be fatal or not. Since there doesn't seem to be a # reliable way to know, there might be cases when a query will never succeed @@ -140,7 +151,7 @@ def __post_init__(self) -> None: create_db_engine(self.dsn) -def create_db_engine(dsn: str, **kwargs: Any) -> AsyncioEngine: +def create_db_engine(dsn: str, **kwargs: Any) -> Engine: """Create the database engine, validating the DSN""" try: return create_engine(dsn, **kwargs) @@ -161,22 +172,20 @@ class QueryResults(NamedTuple): """Results of a database query.""" keys: list[str] - rows: list[tuple[Any]] + rows: Sequence[Row[Any]] timestamp: float | None = None latency: float | None = None @classmethod - async def from_results(cls, results: AsyncResultProxy) -> Self: + def from_result(cls, result: CursorResult[Any]) -> Self: """Return a QueryResults from results for a query.""" timestamp = time() - conn_info = results._result_proxy.connection.info - latency = conn_info.get("query_latency", None) - return cls( - await results.keys(), - await results.fetchall(), - timestamp=timestamp, - latency=latency, - ) + keys: list[str] = [] + rows: Sequence[Row[Any]] = [] + if result.returns_rows: + keys, rows = list(result.keys()), result.all() + latency = result.connection.info.get("query_latency", None) + return cls(keys, rows, timestamp=timestamp, latency=latency) class MetricResult(NamedTuple): @@ -276,11 +285,149 @@ def _check_query_parameters(self) -> None: raise InvalidQueryParameters(self.name) +class WorkerAction: + """An action to be called in the worker thread.""" + + def __init__( + self, func: Callable[..., Any], *args: Any, **kwargs: Any + ) -> None: + self._func = partial(func, *args, **kwargs) + self._loop = asyncio.get_event_loop() + self._future = self._loop.create_future() + + def __str__(self) -> str: + return self._func.func.__name__ + + def __call__(self) -> None: + """Call the action asynchronously in a thread-safe way.""" + try: + result = self._func() + except Exception as e: + self._call_threadsafe(self._future.set_exception, e) + else: + self._call_threadsafe(self._future.set_result, result) + + async def result(self) -> Any: + """Wait for completion and return the action result.""" + return await self._future + + def _call_threadsafe(self, call: Callable[..., Any], *args: Any) -> None: + self._loop.call_soon_threadsafe(partial(call, *args)) + + +class DataBaseConnection: + """A connection to a database engine.""" + + _conn: Connection | None = None + _worker: Thread | None = None + + def __init__( + self, + dbname: str, + engine: Engine, + logger: logging.Logger = logging.getLogger(), + ) -> None: + self.dbname = dbname + self.engine = engine + self.logger = logger + self._loop = asyncio.get_event_loop() + self._queue: asyncio.Queue[WorkerAction] = asyncio.Queue() + + @property + def connected(self) -> bool: + """Whether the connection is open.""" + return self._conn is not None + + async def open(self) -> None: + """Open the connection.""" + if self.connected: + return + + self._create_worker() + await self._call_in_thread(self._connect) + + async def close(self) -> None: + """Close the connection.""" + if not self.connected: + return + + await self._call_in_thread(self._close) + self._terminate_worker() + + async def execute( + self, + sql: TextClause, + parameters: dict[str, Any] | None = None, + ) -> QueryResults: + """Execute a query, returning results.""" + if parameters is None: + parameters = {} + result = await self._call_in_thread(self._execute, sql, parameters) + query_results: QueryResults = await self._call_in_thread( + QueryResults.from_result, result + ) + return query_results + + def _create_worker(self) -> None: + assert not self._worker + self._worker = Thread( + target=self._run, name=f"DataBase-{self.dbname}", daemon=True + ) + self._worker.start() + + def _terminate_worker(self) -> None: + assert self._worker + self._worker.join() + self._worker = None + + def _connect(self) -> None: + self._conn = self.engine.connect() + + def _execute( + self, sql: TextClause, parameters: dict[str, Any] + ) -> CursorResult[Any]: + assert self._conn + return self._conn.execute(sql, parameters) + + def _close(self) -> None: + assert self._conn + self._conn.detach() + self._conn.close() + self._conn = None + + def _run(self) -> None: + """The worker thread function.""" + + def debug(message: str) -> None: + self.logger.debug(f'worker "{current_thread().name}": {message}') + + debug(f"started with ID {current_thread().native_id}") + while True: + future = asyncio.run_coroutine_threadsafe( + self._queue.get(), self._loop + ) + action = future.result() + debug(f'received action "{action}"') + action() + self._loop.call_soon_threadsafe(self._queue.task_done) + if self._conn is None: + # the connection has been closed, exit the thread + debug("shutting down") + return + + async def _call_in_thread( + self, func: Callable[..., Any], *args: Any, **kwargs: Any + ) -> Any: + """Call a sync action in the worker thread.""" + call = WorkerAction(func, *args, **kwargs) + await self._queue.put(call) + return await call.result() + + class DataBase: """A database to perform Queries.""" - _engine: AsyncioEngine - _conn: AsyncConnection | None = None + _conn: DataBaseConnection _pending_queries: int = 0 def __init__( @@ -291,27 +438,32 @@ def __init__( self.config = config self.logger = logger self._connect_lock = asyncio.Lock() - self._engine = create_db_engine( + execution_options = {} + if self.config.autocommit: + execution_options["isolation_level"] = "AUTOCOMMIT" + engine = create_db_engine( self.config.dsn, - strategy=ASYNCIO_STRATEGY, - execution_options={"autocommit": self.config.autocommit}, + execution_options=execution_options, ) - - self._setup_query_latency_tracking() + self._conn = DataBaseConnection(self.config.name, engine, self.logger) + self._setup_query_latency_tracking(engine) async def __aenter__(self) -> Self: await self.connect() return self async def __aexit__( - self, exc_type: type, exc_value: Exception, traceback: TracebackType + self, + exc_type: type, + exc_value: Exception, + traceback: TracebackType, ) -> None: await self.close() @property def connected(self) -> bool: """Whether the database is connected.""" - return self._conn is not None + return self._conn.connected async def connect(self) -> None: """Connect to the database.""" @@ -320,7 +472,7 @@ async def connect(self) -> None: return try: - self._conn = await self._engine.connect() + await self._conn.open() except Exception as error: raise self._db_error(error, exc_class=DataBaseConnectError) @@ -349,10 +501,11 @@ async def execute(self, query: Query) -> MetricResults: f'running query "{query.name}" on database "{self.config.name}"' ) self._pending_queries += 1 - self._conn: AsyncConnection try: - result = await self._execute_query(query) - return query.results(await QueryResults.from_results(result)) + query_results = await self.execute_sql( + query.sql, parameters=query.parameters, timeout=query.timeout + ) + return query.results(query_results) except TimeoutError: raise self._query_timeout_error( query.name, cast(QueryTimeout, query.timeout) @@ -372,34 +525,19 @@ async def execute_sql( sql: str, parameters: dict[str, Any] | None = None, timeout: QueryTimeout | None = None, - ) -> AsyncResultProxy: + ) -> QueryResults: """Execute a raw SQL query.""" - if parameters is None: - parameters = {} - self._conn: AsyncConnection return await asyncio.wait_for( self._conn.execute(text(sql), parameters), timeout=timeout, ) - async def _execute_query(self, query: Query) -> AsyncResultProxy: - """Execute a query.""" - return await self.execute_sql( - query.sql, parameters=query.parameters, timeout=query.timeout - ) - async def _close(self) -> None: # ensure the connection with the DB is actually closed - self._conn: AsyncConnection - self._conn.sync_connection.detach() await self._conn.close() - self._conn = None - self._pending_queries = 0 self.logger.debug(f'disconnected from database "{self.config.name}"') - def _setup_query_latency_tracking(self) -> None: - engine = self._engine.sync_engine - + def _setup_query_latency_tracking(self, engine: Engine) -> None: @event.listens_for(engine, "before_cursor_execute") # type: ignore def before_cursor_execute( conn, cursor, statement, parameters, context, executemany diff --git a/query_exporter/loop.py b/query_exporter/loop.py index 7f6e54f..75b025b 100644 --- a/query_exporter/loop.py +++ b/query_exporter/loop.py @@ -83,15 +83,20 @@ def expire_series( expired = {} for name, metric_last_seen in self._last_seen.items(): expiration = cast(int, self._expirations[name]) - expired[name] = [ + expired_labels = [ label_values for label_values, last_seen in metric_last_seen.items() if timestamp > last_seen + expiration ] + if expired_labels: + expired[name] = expired_labels + # clear expired series from tracking for name, series_labels in expired.items(): for label_values in series_labels: del self._last_seen[name][label_values] + if not self._last_seen[name]: + del self._last_seen[name] return expired @@ -137,7 +142,7 @@ async def start(self) -> None: call: TimedCall if query.interval: call = PeriodicCall(self._run_query, query) - call.start(query.interval) + call.start(query.interval, now=True) elif query.schedule is not None: call = TimedCall(self._run_query, query) call.start(self._loop_times_iter(query.schedule)) @@ -153,7 +158,7 @@ async def stop(self) -> None: def clear_expired_series(self) -> None: """Clear metric series that have expired.""" - expired_series = self._last_seen.expire_series(self._timestamp()) + expired_series = self._last_seen.expire_series(self._loop.time()) for name, label_values in expired_series.items(): metric = self._registry.get_metric(name) for values in label_values: @@ -170,7 +175,7 @@ async def run_aperiodic_queries(self) -> None: def _loop_times_iter(self, schedule: str) -> Iterator[float | int]: """Wrap a croniter iterator to sync time with the loop clock.""" - cron_iter = croniter(schedule, self._now()) + cron_iter = croniter(schedule, datetime.now(gettz())) while True: cc = next(cron_iter) t = time.time() @@ -198,7 +203,7 @@ async def _execute_query(self, query: Query, dbname: str) -> None: self._increment_queries_count(db, query, "error") if error.fatal: self._logger.debug( - f'removing doomed query "{query.name}" ' + f'removing failed query "{query.name}" ' f'for database "{dbname}"' ) self._doomed_queries[query.name].add(dbname) @@ -264,20 +269,21 @@ def _update_metric( ) metric = self._registry.get_metric(name, labels=all_labels) self._update_metric_value(metric, method, value) - self._last_seen.update(name, all_labels, self._timestamp()) + self._last_seen.update(name, all_labels, self._loop.time()) def _get_metric_method(self, metric: MetricConfig) -> str: - method = { - "counter": "inc", - "gauge": "set", - "histogram": "observe", - "summary": "observe", - "enum": "state", - }[metric.type] if metric.type == "counter" and not metric.config.get( "increment", True ): method = "set" + else: + method = { + "counter": "inc", + "gauge": "set", + "histogram": "observe", + "summary": "observe", + "enum": "state", + }[metric.type] return method def _update_metric_value( @@ -325,11 +331,3 @@ def _update_query_timestamp_metric( timestamp, labels={"query": query.config_name}, ) - - def _now(self) -> datetime: - """Return the current time with local timezone.""" - return datetime.now().replace(tzinfo=gettz()) - - def _timestamp(self) -> float: - """Return the current timestamp.""" - return self._now().timestamp() diff --git a/tests/conftest.py b/tests/conftest.py index 43b6c1f..9eb582c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,6 @@ import asyncio +from collections.abc import Iterator +import re import pytest from toolrack.testing.fixtures import advance_time @@ -58,3 +60,22 @@ async def execute(self, query): mocker.patch.object(DataBase, "execute", execute) yield tracker + + +class AssertRegexpMatch: + """Assert that comparison matches the specified regexp.""" + + def __init__(self, pattern: str, flags: int = 0) -> None: + self._re = re.compile(pattern, flags) + + def __eq__(self, string: str) -> bool: + return bool(self._re.match(string)) + + def __repr__(self) -> str: + return self._re.pattern # pragma: nocover + + +@pytest.fixture +def re_match() -> Iterator[type[AssertRegexpMatch]]: + """Matcher for asserting that a string matches a regexp.""" + yield AssertRegexpMatch diff --git a/tests/db_test.py b/tests/db_test.py index 2fb6680..ae0d504 100644 --- a/tests/db_test.py +++ b/tests/db_test.py @@ -1,16 +1,23 @@ import asyncio +from collections.abc import Iterator import logging import time import pytest -from sqlalchemy import create_engine -from sqlalchemy_aio import ASYNCIO_STRATEGY -from sqlalchemy_aio.base import AsyncConnection +from sqlalchemy import ( + create_engine, + text, +) +from sqlalchemy.engine import ( + Connection, + Engine, +) from query_exporter.config import DataBaseConfig from query_exporter.db import ( DataBase, DataBaseConnectError, + DataBaseConnection, DataBaseError, DataBaseQueryError, InvalidQueryParameters, @@ -22,6 +29,7 @@ QueryMetric, QueryResults, QueryTimeoutExpired, + WorkerAction, create_db_engine, ) @@ -302,31 +310,118 @@ def test_results_wrong_names_with_labels(self): class TestQueryResults: - async def test_from_results(self): - """The from_results method creates a QueryResult.""" - engine = create_engine("sqlite://", strategy=ASYNCIO_STRATEGY) - async with engine.connect() as conn: - result = await conn.execute("SELECT 1 AS a, 2 AS b") - query_results = await QueryResults.from_results(result) + def test_from_result(self): + """The from_result method returns a QueryResult.""" + engine = create_engine("sqlite://") + with engine.connect() as conn: + result = conn.execute(text("SELECT 1 AS a, 2 AS b")) + query_results = QueryResults.from_result(result) assert query_results.keys == ["a", "b"] assert query_results.rows == [(1, 2)] assert query_results.latency is None assert query_results.timestamp < time.time() - async def test_from_results_with_latency(self): - """The from_results method creates a QueryResult.""" - engine = create_engine("sqlite://", strategy=ASYNCIO_STRATEGY) - async with engine.connect() as conn: - result = await conn.execute("SELECT 1 AS a, 2 AS b") + def test_from_empty(self): + """The from_result method returns empty QueryResult.""" + engine = create_engine("sqlite://") + with engine.connect() as conn: + result = conn.execute(text("PRAGMA auto_vacuum = 1")) + query_results = QueryResults.from_result(result) + assert query_results.keys == [] + assert query_results.rows == [] + assert query_results.latency is None + + def test_from_result_with_latency(self): + """The from_result method tracks call latency.""" + engine = create_engine("sqlite://") + with engine.connect() as conn: + result = conn.execute(text("SELECT 1 AS a, 2 AS b")) # simulate latency tracking - conn.sync_connection.info["query_latency"] = 1.2 - query_results = await QueryResults.from_results(result) + conn.info["query_latency"] = 1.2 + query_results = QueryResults.from_result(result) assert query_results.keys == ["a", "b"] assert query_results.rows == [(1, 2)] assert query_results.latency == 1.2 assert query_results.timestamp < time.time() +@pytest.fixture +async def conn() -> Iterator[DataBaseConnection]: + engine = create_engine("sqlite://") + connection = DataBaseConnection("db", engine) + yield connection + await connection.close() + + +class TestWorkerAction: + async def test_call_wait(self): + def func(a: int, b: int) -> int: + return a + b + + action = WorkerAction(func, 10, 20) + action() + assert await action.result() == 30 + + async def test_call_exception(self): + def func() -> None: + raise Exception("fail!") + + action = WorkerAction(func) + action() + with pytest.raises(Exception) as error: + await action.result() + assert str(error.value) == "fail!" + + +class TestDataBaseConnection: + def test_engine(self, conn): + """The connection keeps the SQLAlchemy engine.""" + assert isinstance(conn.engine, Engine) + + async def test_open(self, conn: DataBaseConnection) -> None: + """The open method opens the database connection.""" + await conn.open() + assert conn.connected + assert conn._conn is not None + assert conn._worker.is_alive() + + async def test_open_noop(self, conn: DataBaseConnection) -> None: + """The open method is a no-op if connection is already open.""" + await conn.open() + await conn.open() + assert conn.connected + + async def test_close(self, conn: DataBaseConnection) -> None: + """The close method closes the connection.""" + await conn.open() + await conn.close() + assert not conn.connected + assert conn._conn is None + + async def test_close_noop(self, conn: DataBaseConnection) -> None: + """The close method is a no-op if connection is already closed.""" + await conn.open() + await conn.close() + await conn.close() + assert not conn.connected + + async def test_execute(self, conn: DataBaseConnection) -> None: + """The connection can execute queries.""" + await conn.open() + query_results = await conn.execute(text("SELECT 1 AS a, 2 AS b")) + assert query_results.keys == ["a", "b"] + assert query_results.rows == [(1, 2)] + + async def test_execute_with_params(self, conn: DataBaseConnection) -> None: + """The connection can execute queries with parameters.""" + await conn.open() + query_results = await conn.execute( + text("SELECT :a AS a, :b AS b"), parameters={"a": 1, "b": 2} + ) + assert query_results.keys == ["a", "b"] + assert query_results.rows == [(1, 2)] + + @pytest.fixture def db_config(): return DataBaseConfig( @@ -352,23 +447,32 @@ def test_instantiate(self, db_config): async def test_as_context_manager(self, db): """The database can be used as an async context manager.""" async with db: - result = await db.execute_sql("SELECT 10 AS a, 20 AS b") - assert await result.fetchall() == [(10, 20)] + query_result = await db.execute_sql("SELECT 10 AS a, 20 AS b") + assert query_result.rows == [(10, 20)] # the db is closed at context exit assert not db.connected - async def test_connect(self, caplog, db): + async def test_connect(self, caplog, re_match, db): """The connect connects to the database.""" with caplog.at_level(logging.DEBUG): await db.connect() - assert isinstance(db._conn, AsyncConnection) - assert caplog.messages == ['connected to database "db"'] + assert db.connected + assert isinstance(db._conn._conn, Connection) + assert caplog.messages == [ + re_match(r'worker "DataBase-db": started'), + 'worker "DataBase-db": received action "_connect"', + 'connected to database "db"', + ] - async def test_connect_lock(self, caplog, db): + async def test_connect_lock(self, caplog, re_match, db): """The connect method has a lock to prevent concurrent calls.""" with caplog.at_level(logging.DEBUG): await asyncio.gather(db.connect(), db.connect()) - assert caplog.messages == ['connected to database "db"'] + assert caplog.messages == [ + re_match(r'worker "DataBase-db": started'), + 'worker "DataBase-db": received action "_connect"', + 'connected to database "db"', + ] async def test_connect_error(self): """A DataBaseConnectError is raised if database connection fails.""" @@ -414,15 +518,18 @@ async def test_connect_sql_fail(self, caplog): assert 'failed executing query "WRONG"' in str(error.value) assert 'disconnected from database "db"' in caplog.messages - async def test_close(self, caplog, db): + async def test_close(self, caplog, re_match, db): """The close method closes database connection.""" await db.connect() - connection = db._conn with caplog.at_level(logging.DEBUG): await db.close() - assert caplog.messages == ['disconnected from database "db"'] - assert connection.closed - assert db._conn is None + assert caplog.messages == [ + 'worker "DataBase-db": received action "_close"', + 'worker "DataBase-db": shutting down', + 'disconnected from database "db"', + ] + assert not db.connected + assert db._conn._conn is None async def test_execute_log(self, db, caplog): """A message is logged about the query being executed.""" @@ -435,7 +542,11 @@ async def test_execute_log(self, db, caplog): await db.connect() with caplog.at_level(logging.DEBUG): await db.execute(query) - assert caplog.messages == ['running query "query" on database "db"'] + assert caplog.messages == [ + 'running query "query" on database "db"', + 'worker "DataBase-db": received action "_execute"', + 'worker "DataBase-db": received action "from_result"', + ] await db.close() @pytest.mark.parametrize("connected", [True, False]) @@ -452,9 +563,7 @@ async def test_execute_keep_connected(self, mocker, connected): "SELECT 1.0 AS metric", ) await db.connect() - mock_conn_detach = mocker.patch.object( - db._conn.sync_connection, "detach" - ) + mock_conn_detach = mocker.patch.object(db._conn._conn, "detach") await db.execute(query) assert db.connected == connected if not connected: @@ -491,7 +600,7 @@ async def test_execute_not_connected(self, db): metric_results = await db.execute(query) assert metric_results.results == [MetricResult("metric", 1, {})] # the connection is kept for reuse - assert not db._conn.closed + assert db.connected async def test_execute(self, db): """The execute method executes a query.""" @@ -544,6 +653,14 @@ async def test_execute_with_labels(self, db): MetricResult("metric2", 33, {"label2": "baz"}), ] + async def test_execute_fail(self, caplog, db): + """If the query fails, an exception is raised.""" + query = Query("query", 10, [QueryMetric("metric", [])], "WRONG") + await db.connect() + with pytest.raises(DataBaseQueryError) as error: + await db.execute(query) + assert "syntax error" in str(error.value) + async def test_execute_query_invalid_count(self, caplog, db): """If the number of fields don't match, an error is raised.""" query = Query( @@ -585,7 +702,7 @@ async def test_execute_query_invalid_count_with_labels(self, db): ) assert error.value.fatal - async def test_execute_query_invalid_names_with_labels(self, db): + async def test_execute_invalid_names_with_labels(self, db): """If the names of fields don't match, an error is raised.""" query = Query( "query", @@ -602,7 +719,7 @@ async def test_execute_query_invalid_names_with_labels(self, db): ) assert error.value.fatal - async def test_execute_query_traceback_debug(self, caplog, mocker, db): + async def test_execute_traceback_debug(self, caplog, mocker, db): """Traceback are logged as debug messages.""" query = Query( "query", @@ -610,10 +727,8 @@ async def test_execute_query_traceback_debug(self, caplog, mocker, db): [QueryMetric("metric", [])], "SELECT 1 AS metric", ) - mocker.patch.object(db, "_execute_query").side_effect = Exception( - "boom!" - ) await db.connect() + mocker.patch.object(db, "execute_sql").side_effect = Exception("boom!") with ( caplog.at_level(logging.DEBUG), pytest.raises(DataBaseQueryError) as error, @@ -625,7 +740,7 @@ async def test_execute_query_traceback_debug(self, caplog, mocker, db): 'query "query" on database "db" failed: boom!' in caplog.messages ) # traceback is included in messages - assert "await self._execute_query(query)" in caplog.messages[-1] + assert "await self.execute_sql(" in caplog.messages[-1] async def test_execute_timeout(self, caplog, db): """If the query times out, an error is raised and logged.""" @@ -660,4 +775,18 @@ async def test_execute_sql(self, db): """It's possible to execute raw SQL.""" await db.connect() result = await db.execute_sql("SELECT 10, 20") - assert await result.fetchall() == [(10, 20)] + assert result.rows == [(10, 20)] + + @pytest.mark.parametrize( + "error,message", + [ + ("message", "message"), + (Exception("message"), "message"), + (Exception(), "Exception"), + ], + ) + def test_error_message( + self, db: DataBase, error: str | Exception, message: str + ) -> None: + """An error message is returned both for strings and exceptions.""" + assert db._error_message(error) == message diff --git a/tests/loop_test.py b/tests/loop_test.py index c30113e..a269ab2 100644 --- a/tests/loop_test.py +++ b/tests/loop_test.py @@ -1,10 +1,10 @@ import asyncio from collections import defaultdict -from datetime import datetime +from collections.abc import Callable, Iterator from decimal import Decimal import logging from pathlib import Path -import re +from typing import Any from prometheus_aioexporter import MetricsRegistry import pytest @@ -18,21 +18,8 @@ from query_exporter.db import DataBase -class re_match: - """Assert that comparison matches the specified regexp.""" - - def __init__(self, pattern, flags=0): - self._re = re.compile(pattern, flags) - - def __eq__(self, string): - return bool(self._re.match(string)) - - def __repr__(self): - return self._re.pattern # pragma: nocover - - @pytest.fixture -def config_data(): +def config_data() -> Iterator[dict[str, Any]]: yield { "databases": {"db": {"dsn": "sqlite://"}}, "metrics": {"m": {"type": "gauge"}}, @@ -48,15 +35,17 @@ def config_data(): @pytest.fixture -def registry(): +def registry() -> Iterator[MetricsRegistry]: yield MetricsRegistry() @pytest.fixture -async def make_query_loop(tmp_path, config_data, registry): +async def make_query_loop( + tmp_path: Path, config_data: dict[str, Any], registry: MetricsRegistry +) -> Iterator[Callable[[], MetricsRegistry]]: query_loops = [] - def make_loop(): + def make_loop() -> loop.QueryLoop: config_file = tmp_path / "config.yaml" config_file.write_text(yaml.dump(config_data), "utf-8") logger = logging.getLogger() @@ -75,19 +64,12 @@ def make_loop(): @pytest.fixture -async def query_loop(make_query_loop): +async def query_loop( + make_query_loop: Callable[[], loop.QueryLoop], +) -> Iterator[loop.QueryLoop]: yield make_query_loop() -@pytest.fixture -def metrics_expiration(): - yield { - "m1": 50, - "m2": 100, - "m3": None, - } - - def metric_values(metric, by_labels=()): """Return values for the metric.""" if metric._type == "gauge": @@ -107,7 +89,7 @@ def metric_values(metric, by_labels=()): return values if by_labels else values[suffix] -async def run_queries(db_file: Path, *queries: str): +async def run_queries(db_file: Path, *queries: str) -> None: config = DataBaseConfig(name="db", dsn=f"sqlite:///{db_file}") async with DataBase(config) as db: for query in queries: @@ -115,9 +97,9 @@ async def run_queries(db_file: Path, *queries: str): class TestMetricsLastSeen: - def test_update(self, metrics_expiration): + def test_update(self) -> None: """Last seen times are tracked for each series of metrics with expiration.""" - last_seen = loop.MetricsLastSeen(metrics_expiration) + last_seen = loop.MetricsLastSeen({"m1": 50, "m2": 100}) last_seen.update("m1", {"l1": "v1", "l2": "v2"}, 100) last_seen.update("m1", {"l1": "v3", "l2": "v4"}, 200) last_seen.update("other", {"l3": "v100"}, 300) @@ -128,41 +110,63 @@ def test_update(self, metrics_expiration): } } - def test_update_label_values_sorted_by_name(self, metrics_expiration): + def test_update_label_values_sorted_by_name(self) -> None: """Last values are sorted by label names.""" - last_seen = loop.MetricsLastSeen(metrics_expiration) + last_seen = loop.MetricsLastSeen({"m1": 50}) last_seen.update("m1", {"l2": "v2", "l1": "v1"}, 100) assert last_seen._last_seen == {"m1": {("v1", "v2"): 100}} - def test_expire_series(self, metrics_expiration): + def test_expire_series_not_expired(self) -> None: + """If no entry for a metric is expired, it's not returned.""" + last_seen = loop.MetricsLastSeen({"m1": 50}) + last_seen.update("m1", {"l1": "v1", "l2": "v2"}, 10) + last_seen.update("m1", {"l1": "v3", "l2": "v4"}, 20) + assert last_seen.expire_series(30) == {} + assert last_seen._last_seen == { + "m1": { + ("v1", "v2"): 10, + ("v3", "v4"): 20, + } + } + + def test_expire_series(self) -> None: """Expired metric series are returned and removed.""" - last_seen = loop.MetricsLastSeen(metrics_expiration) + last_seen = loop.MetricsLastSeen({"m1": 50, "m2": 100}) last_seen.update("m1", {"l1": "v1", "l2": "v2"}, 10) last_seen.update("m1", {"l1": "v3", "l2": "v4"}, 100) last_seen.update("m2", {"l3": "v100"}, 100) - expired = last_seen.expire_series(120) - assert expired == {"m1": [("v1", "v2")], "m2": []} + assert last_seen.expire_series(120) == {"m1": [("v1", "v2")]} assert last_seen._last_seen == { "m1": {("v3", "v4"): 100}, "m2": {("v100",): 100}, } + def test_expire_no_labels(self) -> None: + last_seen = loop.MetricsLastSeen({"m1": 50}) + last_seen.update("m1", {}, 10) + expired = last_seen.expire_series(120) + assert expired == {"m1": [()]} + assert last_seen._last_seen == {} + class TestQueryLoop: - async def test_start(self, query_loop): + async def test_start(self, query_tracker, query_loop) -> None: """The start method starts timed calls for queries.""" await query_loop.start() timed_call = query_loop._timed_calls["q"] assert timed_call.running + await query_tracker.wait_results() - async def test_stop(self, query_loop): + async def test_stop(self, query_loop) -> None: """The stop method stops timed calls for queries.""" await query_loop.start() timed_call = query_loop._timed_calls["q"] await query_loop.stop() assert not timed_call.running - async def test_run_query(self, query_tracker, query_loop, registry): + async def test_run_query( + self, query_tracker, query_loop, registry + ) -> None: """Queries are run and update metrics.""" await query_loop.start() await query_tracker.wait_results() @@ -288,14 +292,20 @@ async def test_update_metric_decimal_value( assert value == 100.123 assert isinstance(value, float) - async def test_run_query_log(self, caplog, query_tracker, query_loop): + async def test_run_query_log( + self, caplog, re_match, query_tracker, query_loop + ): """Debug messages are logged on query execution.""" caplog.set_level(logging.DEBUG) await query_loop.start() await query_tracker.wait_queries() assert caplog.messages == [ + re_match(r'worker "DataBase-db": started'), + 'worker "DataBase-db": received action "_connect"', 'connected to database "db"', 'running query "q" on database "db"', + 'worker "DataBase-db": received action "_execute"', + 'worker "DataBase-db": received action "from_result"', 'updating metric "m" set 100.0 {database="db"}', re_match( r'updating metric "query_latency" observe .* \{database="db",query="q"\}' @@ -307,7 +317,7 @@ async def test_run_query_log(self, caplog, query_tracker, query_loop): ] async def test_run_query_log_labels( - self, caplog, query_tracker, config_data, make_query_loop + self, caplog, re_match, query_tracker, config_data, make_query_loop ): """Debug messages include metric labels.""" config_data["metrics"]["m"]["labels"] = ["l"] @@ -317,8 +327,12 @@ async def test_run_query_log_labels( await query_loop.start() await query_tracker.wait_queries() assert caplog.messages == [ + re_match(r'worker "DataBase-db": started'), + 'worker "DataBase-db": received action "_connect"', 'connected to database "db"', 'running query "q" on database "db"', + 'worker "DataBase-db": received action "_execute"', + 'worker "DataBase-db": received action "from_result"', 'updating metric "m" set 100.0 {database="db",l="foo"}', re_match( r'updating metric "query_latency" observe .* \{database="db",query="q"\}' @@ -346,7 +360,7 @@ async def test_run_query_increase_database_error_count( """Count of database errors is incremented on failed connection.""" query_loop = make_query_loop() db = query_loop._databases["db"] - mock_connect = mocker.patch.object(db._engine, "connect") + mock_connect = mocker.patch.object(db._conn.engine, "connect") mock_connect.side_effect = Exception("connection failed") await query_loop.start() await query_tracker.wait_failures() @@ -403,19 +417,23 @@ async def test_run_query_at_interval( assert len(query_tracker.queries) == 2 async def test_run_timed_queries_invalid_result_count( - self, query_tracker, config_data, make_query_loop, advance_time + self, query_tracker, config_data, make_query_loop ): """Timed queries returning invalid elements count are removed.""" config_data["queries"]["q"]["sql"] = "SELECT 100.0 AS a, 200.0 AS b" + config_data["queries"]["q"]["interval"] = 1.0 query_loop = make_query_loop() await query_loop.start() - await advance_time(0) # kick the first run - assert len(query_tracker.queries) == 1 - assert len(query_tracker.results) == 0 - # the query is not run again - await advance_time(5) + timed_call = query_loop._timed_calls["q"] + await asyncio.sleep(1.1) + await query_tracker.wait_failures() + assert len(query_tracker.failures) == 1 assert len(query_tracker.results) == 0 - await advance_time(5) + # the query has been stopped and removed + assert not timed_call.running + await asyncio.sleep(1.1) + await query_tracker.wait_failures() + assert len(query_tracker.failures) == 1 assert len(query_tracker.results) == 0 async def test_run_timed_queries_invalid_result_count_stop_task( @@ -539,8 +557,8 @@ async def test_run_aperiodic_queries_not_removed_if_not_failing_on_all_dbs( async def test_clear_expired_series( self, - mocker, tmp_path, + advance_time, query_tracker, config_data, make_query_loop, @@ -555,7 +573,6 @@ async def test_clear_expired_series( "expiration": 10, } ) - # call metric collection directly config_data["queries"]["q"]["sql"] = "SELECT * FROM test" del config_data["queries"]["q"]["interval"] @@ -566,24 +583,19 @@ async def test_clear_expired_series( 'INSERT INTO test VALUES (20, "bar")', ) query_loop = make_query_loop() - mock_timestamp = mocker.patch.object(query_loop, "_timestamp") - mock_timestamp.return_value = datetime.now().timestamp() await query_loop.run_aperiodic_queries() await query_tracker.wait_results() queries_metric = registry.get_metric("m") - assert metric_values(queries_metric, by_labels=("l",)), { + assert metric_values(queries_metric, by_labels=("l",)) == { ("foo",): 10.0, ("bar",): 20.0, } - await run_queries( - db, - "DELETE FROM test WHERE m = 10", - ) - # mock that more time has passed than expiration - mock_timestamp.return_value += 20 + await run_queries(db, "DELETE FROM test WHERE m = 10") + # go beyond expiration time + await advance_time(20) await query_loop.run_aperiodic_queries() await query_tracker.wait_results() query_loop.clear_expired_series() - assert metric_values(queries_metric, by_labels=("l",)), { + assert metric_values(queries_metric, by_labels=("l",)) == { ("bar",): 20.0, }