Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-45138: Finish the worker and API model separation #192

Merged
merged 1 commit into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/vocutouts/models/cutout.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def to_worker_stencil(self) -> WorkerStencil:
)


class CutoutParameters(ParametersModel):
class CutoutParameters(ParametersModel[WorkerCutout]):
"""Parameters to a cutout request."""

ids: list[str] = Field(
Expand Down Expand Up @@ -255,7 +255,7 @@ def from_job_parameters(cls, params: list[UWSJobParameter]) -> Self:
except ValidationError as e:
raise InvalidCutoutParameterError(str(e), params) from e

def to_worker_cutout(self) -> WorkerCutout:
def to_worker_parameters(self) -> WorkerCutout:
"""Convert to the domain model used by the backend worker."""
stencils = [s.to_worker_stencil() for s in self.stencils]
return WorkerCutout(dataset_ids=self.ids, stencils=stencils)
Expand Down
11 changes: 9 additions & 2 deletions src/vocutouts/uws/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import Callable, Coroutine
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Self, TypeAlias
from typing import Generic, Self, TypeAlias, TypeVar

from arq.connections import RedisSettings
from pydantic import BaseModel, SecretStr
Expand All @@ -27,6 +27,9 @@
]
"""Type for a dependency that gathers parameters for a job."""

T = TypeVar("T", bound=BaseModel)
"""Generic type for the worker parameters."""

__all__ = [
"DestructionValidator",
"ExecutionDurationValidator",
Expand All @@ -36,7 +39,7 @@
]


class ParametersModel(BaseModel, ABC):
class ParametersModel(BaseModel, ABC, Generic[T]):
"""Defines the interface for a model suitable for job parameters."""

@classmethod
Expand Down Expand Up @@ -64,6 +67,10 @@ def from_job_parameters(cls, params: list[UWSJobParameter]) -> Self:
Raised if the parameters do not validate.
"""

@abstractmethod
def to_worker_parameters(self) -> T:
"""Convert to the domain model used by the backend worker."""


@dataclass
class UWSConfig:
Expand Down
5 changes: 3 additions & 2 deletions src/vocutouts/uws/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,15 +335,16 @@ async def start(self, user: str, job_id: str, token: str) -> JobMetadata:
raise PermissionDeniedError(f"Access to job {job_id} denied")
if job.phase not in (ExecutionPhase.PENDING, ExecutionPhase.HELD):
raise InvalidPhaseError("Cannot start job in phase {job.phase}")
params = self._validate_parameters(job.parameters)
logger = self._build_logger_for_job(job, params)
params_model = self._validate_parameters(job.parameters)
logger = self._build_logger_for_job(job, params_model)
info = WorkerJobInfo(
job_id=job.job_id,
user=user,
token=token,
timeout=job.execution_duration,
run_id=job.run_id,
)
params = params_model.to_worker_parameters()
metadata = await self._arq.enqueue(self._config.worker, params, info)
await self._storage.mark_queued(job_id, metadata)
logger.info("Started job", arq_job_id=metadata.id)
Expand Down
3 changes: 3 additions & 0 deletions tests/handlers/async_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from httpx import ASGITransport, AsyncClient
from safir.testing.slack import MockSlackWebhook

from vocutouts.models.domain.cutout import WorkerCutout
from vocutouts.uws.models import UWSJobResult

from ..support.uws import MockJobRunner
Expand Down Expand Up @@ -96,6 +97,8 @@ async def test_create_job(client: AsyncClient, runner: MockJobRunner) -> None:
assert r.headers["Location"] == "https://example.com/api/cutout/jobs/2"

async def run_job() -> None:
arq_job = await runner.get_job_metadata("someone", "2")
assert isinstance(arq_job.args[0], WorkerCutout)
await runner.mark_in_progress("someone", "2", delay=0.2)
results = [
UWSJobResult(
Expand Down
4 changes: 2 additions & 2 deletions tests/models/domain/cutout_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ def test_pickle() -> None:
CutoutParameters(
ids=["foo"],
stencils=[CircleStencil.from_string("1 1.42 1")],
).to_worker_cutout(),
).to_worker_parameters(),
CutoutParameters(
ids=["foo"],
stencils=[PolygonStencil.from_string("1 0 1 1 0 1 0 0")],
).to_worker_cutout(),
).to_worker_parameters(),
):
cutout_pickle = pickle.loads(pickle.dumps(cutout))
assert cutout.dataset_ids == cutout_pickle.dataset_ids
Expand Down
32 changes: 29 additions & 3 deletions tests/support/uws.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

from arq.connections import RedisSettings
from fastapi import Form, Query
from pydantic import SecretStr
from safir.arq import ArqMode, MockArqQueue
from pydantic import BaseModel, SecretStr
from safir.arq import ArqMode, JobMetadata, MockArqQueue

from vocutouts.uws.config import ParametersModel, UWSConfig
from vocutouts.uws.dependencies import UWSFactory
Expand All @@ -23,7 +23,11 @@
]


class SimpleParameters(ParametersModel):
class SimpleWorkerParameters(BaseModel):
name: str


class SimpleParameters(ParametersModel[SimpleWorkerParameters]):
name: str

@classmethod
Expand All @@ -32,6 +36,9 @@ def from_job_parameters(cls, params: list[UWSJobParameter]) -> Self:
assert params[0].parameter_id == "name"
return cls(name=params[0].value)

def to_worker_parameters(self) -> SimpleWorkerParameters:
return SimpleWorkerParameters(name=self.name)


async def _get_dependency(
name: Annotated[str, Query()],
Expand Down Expand Up @@ -102,6 +109,25 @@ def __init__(self, factory: UWSFactory, arq_queue: MockArqQueue) -> None:
self._store = factory.create_job_store()
self._arq = arq_queue

async def get_job_metadata(
self, username: str, job_id: str
) -> JobMetadata:
"""Get the arq job metadata for a job.

Parameters
----------
job_id
UWS job ID.

Returns
-------
JobMetadata
arq job metadata.
"""
job = await self._service.get(username, job_id)
assert job.message_id
return await self._arq.get_job_metadata(job.message_id)

async def mark_in_progress(
self, username: str, job_id: str, *, delay: float | None = None
) -> UWSJob:
Expand Down
30 changes: 27 additions & 3 deletions tests/uws/job_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@

from __future__ import annotations

from datetime import datetime, timedelta
from datetime import UTC, datetime, timedelta
from unittest.mock import ANY

import pytest
from arq.constants import default_queue_name
from arq.jobs import JobStatus
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from safir.arq import MockArqQueue
Expand All @@ -17,8 +20,9 @@
from vocutouts.uws.config import UWSConfig
from vocutouts.uws.dependencies import UWSFactory
from vocutouts.uws.models import UWSJob, UWSJobParameter, UWSJobResult
from vocutouts.uws.uwsworker import WorkerJobInfo

from ..support.uws import MockJobRunner
from ..support.uws import MockJobRunner, SimpleWorkerParameters

PENDING_JOB = """
<uws:job
Expand Down Expand Up @@ -95,7 +99,10 @@

@pytest.mark.asyncio
async def test_job_run(
client: AsyncClient, runner: MockJobRunner, uws_factory: UWSFactory
client: AsyncClient,
runner: MockJobRunner,
uws_factory: UWSFactory,
uws_config: UWSConfig,
) -> None:
job_service = uws_factory.create_job_service()

Expand Down Expand Up @@ -158,6 +165,23 @@ async def test_job_run(
)
await runner.mark_in_progress("user", "1")

# Check that the correct data was passed to the backend worker.
metadata = await runner.get_job_metadata("user", "1")
assert metadata.name == uws_config.worker
assert metadata.args[0] == SimpleWorkerParameters(name="Jane")
assert metadata.args[1] == WorkerJobInfo(
job_id="1",
user="user",
token="sometoken",
timeout=ANY,
run_id="some-run-id",
)
assert not metadata.kwargs
now = datetime.now(tz=UTC)
assert now - timedelta(seconds=2) <= metadata.enqueue_time <= now
assert metadata.status == JobStatus.in_progress
assert metadata.queue_name == default_queue_name

# Tell the queue the job is finished.
results = [
UWSJobResult(
Expand Down