diff --git a/src/vocutouts/models/cutout.py b/src/vocutouts/models/cutout.py index fff5960..ebcfc40 100644 --- a/src/vocutouts/models/cutout.py +++ b/src/vocutouts/models/cutout.py @@ -181,7 +181,7 @@ def to_worker_stencil(self) -> WorkerStencil: ) -class CutoutParameters(ParametersModel): +class CutoutParameters(ParametersModel[WorkerCutout]): """Parameters to a cutout request.""" ids: list[str] = Field( @@ -255,7 +255,7 @@ def from_job_parameters(cls, params: list[UWSJobParameter]) -> Self: except ValidationError as e: raise InvalidCutoutParameterError(str(e), params) from e - def to_worker_cutout(self) -> WorkerCutout: + def to_worker_parameters(self) -> WorkerCutout: """Convert to the domain model used by the backend worker.""" stencils = [s.to_worker_stencil() for s in self.stencils] return WorkerCutout(dataset_ids=self.ids, stencils=stencils) diff --git a/src/vocutouts/uws/config.py b/src/vocutouts/uws/config.py index 3f6567b..48e3539 100644 --- a/src/vocutouts/uws/config.py +++ b/src/vocutouts/uws/config.py @@ -6,7 +6,7 @@ from collections.abc import Callable, Coroutine from dataclasses import dataclass from datetime import datetime, timedelta -from typing import Self, TypeAlias +from typing import Generic, Self, TypeAlias, TypeVar from arq.connections import RedisSettings from pydantic import BaseModel, SecretStr @@ -27,6 +27,9 @@ ] """Type for a dependency that gathers parameters for a job.""" +T = TypeVar("T", bound=BaseModel) +"""Generic type for the worker parameters.""" + __all__ = [ "DestructionValidator", "ExecutionDurationValidator", @@ -36,7 +39,7 @@ ] -class ParametersModel(BaseModel, ABC): +class ParametersModel(BaseModel, ABC, Generic[T]): """Defines the interface for a model suitable for job parameters.""" @classmethod @@ -64,6 +67,10 @@ def from_job_parameters(cls, params: list[UWSJobParameter]) -> Self: Raised if the parameters do not validate. """ + @abstractmethod + def to_worker_parameters(self) -> T: + """Convert to the domain model used by the backend worker.""" + @dataclass class UWSConfig: diff --git a/src/vocutouts/uws/service.py b/src/vocutouts/uws/service.py index a975c9e..2a58279 100644 --- a/src/vocutouts/uws/service.py +++ b/src/vocutouts/uws/service.py @@ -335,8 +335,8 @@ async def start(self, user: str, job_id: str, token: str) -> JobMetadata: raise PermissionDeniedError(f"Access to job {job_id} denied") if job.phase not in (ExecutionPhase.PENDING, ExecutionPhase.HELD): raise InvalidPhaseError("Cannot start job in phase {job.phase}") - params = self._validate_parameters(job.parameters) - logger = self._build_logger_for_job(job, params) + params_model = self._validate_parameters(job.parameters) + logger = self._build_logger_for_job(job, params_model) info = WorkerJobInfo( job_id=job.job_id, user=user, @@ -344,6 +344,7 @@ async def start(self, user: str, job_id: str, token: str) -> JobMetadata: timeout=job.execution_duration, run_id=job.run_id, ) + params = params_model.to_worker_parameters() metadata = await self._arq.enqueue(self._config.worker, params, info) await self._storage.mark_queued(job_id, metadata) logger.info("Started job", arq_job_id=metadata.id) diff --git a/tests/handlers/async_test.py b/tests/handlers/async_test.py index 79886ae..f6bde91 100644 --- a/tests/handlers/async_test.py +++ b/tests/handlers/async_test.py @@ -10,6 +10,7 @@ from httpx import ASGITransport, AsyncClient from safir.testing.slack import MockSlackWebhook +from vocutouts.models.domain.cutout import WorkerCutout from vocutouts.uws.models import UWSJobResult from ..support.uws import MockJobRunner @@ -96,6 +97,8 @@ async def test_create_job(client: AsyncClient, runner: MockJobRunner) -> None: assert r.headers["Location"] == "https://example.com/api/cutout/jobs/2" async def run_job() -> None: + arq_job = await runner.get_job_metadata("someone", "2") + assert isinstance(arq_job.args[0], WorkerCutout) await runner.mark_in_progress("someone", "2", delay=0.2) results = [ UWSJobResult( diff --git a/tests/models/domain/cutout_test.py b/tests/models/domain/cutout_test.py index 0ca54af..0dde54e 100644 --- a/tests/models/domain/cutout_test.py +++ b/tests/models/domain/cutout_test.py @@ -16,11 +16,11 @@ def test_pickle() -> None: CutoutParameters( ids=["foo"], stencils=[CircleStencil.from_string("1 1.42 1")], - ).to_worker_cutout(), + ).to_worker_parameters(), CutoutParameters( ids=["foo"], stencils=[PolygonStencil.from_string("1 0 1 1 0 1 0 0")], - ).to_worker_cutout(), + ).to_worker_parameters(), ): cutout_pickle = pickle.loads(pickle.dumps(cutout)) assert cutout.dataset_ids == cutout_pickle.dataset_ids diff --git a/tests/support/uws.py b/tests/support/uws.py index 2d79111..50682fe 100644 --- a/tests/support/uws.py +++ b/tests/support/uws.py @@ -9,8 +9,8 @@ from arq.connections import RedisSettings from fastapi import Form, Query -from pydantic import SecretStr -from safir.arq import ArqMode, MockArqQueue +from pydantic import BaseModel, SecretStr +from safir.arq import ArqMode, JobMetadata, MockArqQueue from vocutouts.uws.config import ParametersModel, UWSConfig from vocutouts.uws.dependencies import UWSFactory @@ -23,7 +23,11 @@ ] -class SimpleParameters(ParametersModel): +class SimpleWorkerParameters(BaseModel): + name: str + + +class SimpleParameters(ParametersModel[SimpleWorkerParameters]): name: str @classmethod @@ -32,6 +36,9 @@ def from_job_parameters(cls, params: list[UWSJobParameter]) -> Self: assert params[0].parameter_id == "name" return cls(name=params[0].value) + def to_worker_parameters(self) -> SimpleWorkerParameters: + return SimpleWorkerParameters(name=self.name) + async def _get_dependency( name: Annotated[str, Query()], @@ -102,6 +109,25 @@ def __init__(self, factory: UWSFactory, arq_queue: MockArqQueue) -> None: self._store = factory.create_job_store() self._arq = arq_queue + async def get_job_metadata( + self, username: str, job_id: str + ) -> JobMetadata: + """Get the arq job metadata for a job. + + Parameters + ---------- + job_id + UWS job ID. + + Returns + ------- + JobMetadata + arq job metadata. + """ + job = await self._service.get(username, job_id) + assert job.message_id + return await self._arq.get_job_metadata(job.message_id) + async def mark_in_progress( self, username: str, job_id: str, *, delay: float | None = None ) -> UWSJob: diff --git a/tests/uws/job_api_test.py b/tests/uws/job_api_test.py index ac1f6e4..fcda26b 100644 --- a/tests/uws/job_api_test.py +++ b/tests/uws/job_api_test.py @@ -6,9 +6,12 @@ from __future__ import annotations -from datetime import datetime, timedelta +from datetime import UTC, datetime, timedelta +from unittest.mock import ANY import pytest +from arq.constants import default_queue_name +from arq.jobs import JobStatus from fastapi import FastAPI from httpx import ASGITransport, AsyncClient from safir.arq import MockArqQueue @@ -17,8 +20,9 @@ from vocutouts.uws.config import UWSConfig from vocutouts.uws.dependencies import UWSFactory from vocutouts.uws.models import UWSJob, UWSJobParameter, UWSJobResult +from vocutouts.uws.uwsworker import WorkerJobInfo -from ..support.uws import MockJobRunner +from ..support.uws import MockJobRunner, SimpleWorkerParameters PENDING_JOB = """ None: job_service = uws_factory.create_job_service() @@ -158,6 +165,23 @@ async def test_job_run( ) await runner.mark_in_progress("user", "1") + # Check that the correct data was passed to the backend worker. + metadata = await runner.get_job_metadata("user", "1") + assert metadata.name == uws_config.worker + assert metadata.args[0] == SimpleWorkerParameters(name="Jane") + assert metadata.args[1] == WorkerJobInfo( + job_id="1", + user="user", + token="sometoken", + timeout=ANY, + run_id="some-run-id", + ) + assert not metadata.kwargs + now = datetime.now(tz=UTC) + assert now - timedelta(seconds=2) <= metadata.enqueue_time <= now + assert metadata.status == JobStatus.in_progress + assert metadata.queue_name == default_queue_name + # Tell the queue the job is finished. results = [ UWSJobResult(