diff --git a/invenio_jobs/config.py b/invenio_jobs/config.py index fd2df0d..f000353 100644 --- a/invenio_jobs/config.py +++ b/invenio_jobs/config.py @@ -30,6 +30,27 @@ JOBS_FACETS = {} """Facets/aggregations for Jobs results.""" +JOBS_QUEUES = { + "celery": { + "name": "celery", + "title": _("Default"), + "description": _("Default queue"), + }, + "low": { + "name": "low", + "title": _("Low"), + "description": _("Low priority queue"), + }, +} +"""List of available Celery queues. + +This doesn't create any of the queues, but just controls which Celery queue a job +is pushed to. You still need to configure Celery workers to listen to these queues. +""" + +JOBS_DEFAULT_QUEUE = None +"""Default Celery queue.""" + JOBS_SORT_OPTIONS = { "jobs": dict( title=_("Jobs"), diff --git a/invenio_jobs/ext.py b/invenio_jobs/ext.py index 967c98f..4d52458 100644 --- a/invenio_jobs/ext.py +++ b/invenio_jobs/ext.py @@ -8,9 +8,12 @@ """Jobs extension.""" +from celery import current_app as current_celery_app +from flask import current_app from invenio_i18n import gettext as _ from . 
import config +from .models import Task from .resources import ( JobsResource, JobsResourceConfig, @@ -66,6 +69,24 @@ def init_resource(self, app): TasksResourceConfig.build(app), self.tasks_service ) + @property + def queues(self): + """Return the queues.""" + return current_app.config["JOBS_QUEUES"] + + @property + def default_queue(self): + """Return the default queue.""" + return ( + current_app.config.get("JOBS_DEFAULT_QUEUE") + or current_celery_app.conf.task_default_queue + ) + + @property + def tasks(self): + """Return the tasks.""" + return Task.all() + def finalize_app(app): """Finalize app.""" diff --git a/invenio_jobs/models.py b/invenio_jobs/models.py index ae07738..90b0e2d 100644 --- a/invenio_jobs/models.py +++ b/invenio_jobs/models.py @@ -8,9 +8,10 @@ """Models.""" import enum +import uuid from inspect import signature -from celery import current_app +from celery import current_app as current_celery_app from invenio_accounts.models import User from invenio_db import db from sqlalchemy.dialects import postgresql @@ -29,15 +30,15 @@ class Job(db.Model, Timestamp): """Job model.""" - id = db.Column(UUIDType, primary_key=True) + id = db.Column(UUIDType, primary_key=True, default=uuid.uuid4) active = db.Column(db.Boolean, default=True, nullable=False) title = db.Column(db.String(255), nullable=False) description = db.Column(db.Text) - celery_tasks = db.Column(db.String(255)) + task = db.Column(db.String(255)) default_queue = db.Column(db.String(64)) default_args = db.Column(JSON, default=lambda: dict(), nullable=True) - schedule = db.Column(JSON, default=lambda: dict(), nullable=True) + schedule = db.Column(JSON, nullable=True) # TODO: See if we move this to an API class @property @@ -60,7 +61,7 @@ class RunStatusEnum(enum.Enum): class Run(db.Model, Timestamp): """Run model.""" - id = db.Column(UUIDType, primary_key=True) + id = db.Column(UUIDType, primary_key=True, default=uuid.uuid4) job_id = db.Column(UUIDType, db.ForeignKey(Job.id)) job = 
db.relationship(Job, backref=db.backref("runs", lazy="dynamic")) @@ -116,10 +117,10 @@ def all(cls): """Return all tasks.""" if getattr(cls, "_all_tasks", None) is None: # Cache results - cls._all_tasks = [ - cls(task) - for task in current_app.tasks.values() + cls._all_tasks = { + k: cls(task) + for k, task in current_celery_app.tasks.items() # Filter outer Celery internal tasks - if not task.name.startswith("celery.") - ] + if not k.startswith("celery.") + } return cls._all_tasks diff --git a/invenio_jobs/proxies.py b/invenio_jobs/proxies.py new file mode 100644 index 0000000..99aa1f0 --- /dev/null +++ b/invenio_jobs/proxies.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2024 CERN. +# +# Invenio-Jobs is free software; you can redistribute it and/or modify it +# under the terms of the MIT License; see LICENSE file for more details. + +"""Proxies.""" + +from flask import current_app +from werkzeug.local import LocalProxy + +current_jobs = LocalProxy(lambda: current_app.extensions["invenio-jobs"]) +"""Jobs extension.""" + +current_jobs_service = LocalProxy(lambda: current_jobs.service) +"""Jobs service.""" + +current_runs_service = LocalProxy(lambda: current_jobs.runs_service) +"""Runs service.""" diff --git a/invenio_jobs/resources/config.py b/invenio_jobs/resources/config.py index c6999e2..7b92635 100644 --- a/invenio_jobs/resources/config.py +++ b/invenio_jobs/resources/config.py @@ -14,6 +14,8 @@ from invenio_records_resources.resources.records.args import SearchRequestArgsSchema from invenio_records_resources.services.base.config import ConfiguratorMixin +from ..services.errors import JobNotFoundError + class TasksResourceConfig(ResourceConfig, ConfiguratorMixin): """Celery tasks resource config.""" @@ -50,7 +52,9 @@ class JobsResourceConfig(ResourceConfig, ConfiguratorMixin): error_handlers = { **ErrorHandlersMixin.error_handlers, - # TODO: Add custom error handlers here + JobNotFoundError: create_error_handler( + lambda e: 
HTTPJSONException(code=404, description=e.description) + ), } diff --git a/invenio_jobs/services/config.py b/invenio_jobs/services/config.py index 0de7d1b..007c16d 100644 --- a/invenio_jobs/services/config.py +++ b/invenio_jobs/services/config.py @@ -17,6 +17,7 @@ SearchOptions as SearchOptionsBase, ) from invenio_records_resources.services.records.links import pagination_links +from sqlalchemy import asc, desc from ..models import Job, Run, Task from . import results @@ -68,7 +69,15 @@ class TasksServiceConfig(ServiceConfig, ConfiguratorMixin): class JobSearchOptions(SearchOptionsBase): """Job search options.""" - # TODO: See what we need to override + sort_default = "title" + sort_direction_default = "asc" + sort_direction_options = { + "asc": dict(title=_("Ascending"), fn=asc), + "desc": dict(title=_("Descending"), fn=desc), + } + sort_options = {"title": dict(title=_("Title"), fields=["title"])} + + pagination_options = {"default_results_per_page": 25} class JobsServiceConfig(ServiceConfig, ConfiguratorMixin): diff --git a/invenio_jobs/services/errors.py b/invenio_jobs/services/errors.py new file mode 100644 index 0000000..630aa8b --- /dev/null +++ b/invenio_jobs/services/errors.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2024 CERN. +# Copyright (C) 2024 University of Münster. +# +# Invenio-Jobs is free software; you can redistribute it and/or modify it +# under the terms of the MIT License; see LICENSE file for more details. 
+ +"""Service errors.""" + +from invenio_i18n import gettext as _ + + +class JobsError(Exception): + """Base class for Jobs errors.""" + + def __init__(self, description, *args: object): + """Constructor.""" + self.description = description + super().__init__(*args) + + +class JobNotFoundError(JobsError): + """Job not found error.""" + + def __init__(self, id): + """Initialise error.""" + super().__init__( + description=_("Job with ID {id} does not exist.").format(id=id) + ) diff --git a/invenio_jobs/services/results.py b/invenio_jobs/services/results.py index dbdbecf..d0cefb7 100644 --- a/invenio_jobs/services/results.py +++ b/invenio_jobs/services/results.py @@ -7,6 +7,9 @@ """Service results.""" +from collections.abc import Iterable, Sized + +from flask_sqlalchemy import Pagination from invenio_records_resources.services.records.results import RecordItem, RecordList @@ -16,18 +19,29 @@ class Item(RecordItem): @property def id(self): """Get the result id.""" - return self._record.id + return str(self._record.id) class List(RecordList): """List result.""" + @property + def items(self): + """Iterator over the items.""" + if isinstance(self._results, Pagination): + return self._results.items + elif isinstance(self._results, Iterable): + return self._results + return self._results + @property def total(self): """Get total number of hits.""" if hasattr(self._results, "hits"): return self._results.hits.total["value"] - elif isinstance(self._results, (tuple, list)): + if isinstance(self._results, Pagination): + return self._results.total + elif isinstance(self._results, Sized): return len(self._results) else: return None @@ -44,7 +58,7 @@ def aggregations(self): @property def hits(self): """Iterator over the hits.""" - for hit in self._results: + for hit in self.items: # Project the hit projection = self._schema.dump( hit, diff --git a/invenio_jobs/services/schema.py b/invenio_jobs/services/schema.py index 34efdc5..6b6bf60 100644 --- 
a/invenio_jobs/services/schema.py +++ b/invenio_jobs/services/schema.py @@ -11,8 +11,12 @@ from invenio_i18n import lazy_gettext as _ from marshmallow import EXCLUDE, Schema, fields, validate +from marshmallow_oneofschema import OneOfSchema from marshmallow_utils.fields import SanitizedUnicode from marshmallow_utils.permissions import FieldPermissionsMixin +from marshmallow_utils.validators import LazyOneOf + +from ..proxies import current_jobs def _not_blank(**kwargs): @@ -57,6 +61,47 @@ class TaskSchema(Schema, FieldPermissionsMixin): ) +class IntervalScheduleSchema(Schema): + """Schema for an interval schedule based on ``datetime.timedelta``.""" + + type = fields.Constant("interval") + + days = fields.Integer() + seconds = fields.Integer() + microseconds = fields.Integer() + milliseconds = fields.Integer() + minutes = fields.Integer() + hours = fields.Integer() + weeks = fields.Integer() + + +class CrontabScheduleSchema(Schema): + """Schema for a crontab schedule.""" + + type = fields.Constant("crontab") + + minute = fields.String(load_default="*") + hour = fields.String(load_default="*") + day_of_week = fields.String(load_default="*") + day_of_month = fields.String(load_default="*") + month_of_year = fields.String(load_default="*") + + +class ScheduleSchema(OneOfSchema): + """Schema for a schedule.""" + + def get_obj_type(self, obj): + if isinstance(obj, dict) and "type" in obj: + return obj["type"] + return super().get_obj_type(obj) + + type_schemas = { + "interval": IntervalScheduleSchema, + "crontab": CrontabScheduleSchema, + } + type_field_remove = False + + class JobSchema(Schema, FieldPermissionsMixin): """Base schema for a job.""" @@ -67,12 +112,31 @@ class Meta: id = fields.UUID(dump_only=True) + created = fields.DateTime(dump_only=True) + updated = fields.DateTime(dump_only=True) + title = SanitizedUnicode(required=True, validate=_not_blank(max=250)) description = SanitizedUnicode() + active = fields.Boolean(load_default=True) + + task = fields.String( 
+ required=True, + validate=LazyOneOf(choices=lambda: current_jobs.tasks.keys()), + ) + default_queue = fields.String( + validate=LazyOneOf(choices=lambda: current_jobs.queues.keys()), + load_default=lambda: current_jobs.default_queue, + ) + default_args = fields.Dict(load_default=dict) + + schedule = fields.Nested(ScheduleSchema, allow_none=True, load_default=None) + + last_run = fields.Nested(lambda: RunSchema, dump_only=True) + class RunSchema(Schema, FieldPermissionsMixin): - """Base schema for a job.""" + """Base schema for a job run.""" class Meta: """Meta attributes for the schema.""" @@ -81,5 +145,8 @@ class Meta: id = fields.UUID(dump_only=True) + created = fields.DateTime(dump_only=True) + updated = fields.DateTime(dump_only=True) + title = SanitizedUnicode(required=True, validate=_not_blank(max=250)) description = SanitizedUnicode() diff --git a/invenio_jobs/services/services.py b/invenio_jobs/services/services.py index 4ede6fc..73148a8 100644 --- a/invenio_jobs/services/services.py +++ b/invenio_jobs/services/services.py @@ -8,12 +8,19 @@ """Service definitions.""" +import sqlalchemy as sa from invenio_records_resources.services.base import LinksTemplate from invenio_records_resources.services.base.utils import map_search_params from invenio_records_resources.services.records import RecordService -from invenio_records_resources.services.uow import unit_of_work +from invenio_records_resources.services.uow import ( + ModelCommitOp, + ModelDeleteOp, + unit_of_work, +) -from ..models import Task +from ..models import Job +from ..proxies import current_jobs +from .errors import JobNotFoundError class TasksService(RecordService): @@ -23,8 +30,7 @@ def search(self, identity, params): """Search for tasks.""" self.require_permission(identity, "search") - # TODO: Use an API class - tasks = Task.all() + tasks = current_jobs.tasks.values() search_params = map_search_params(self.config.search, params) query_param = search_params["q"] @@ -53,23 +59,103 @@ def 
search(self, identity, params): class JobsService(RecordService): """Jobs service.""" - def search(self, identity, **kwargs): + @unit_of_work() + def create(self, identity, data, uow=None): + """Create a job.""" + self.require_permission(identity, "create") + + # TODO: See if we need extra validation (e.g. tasks, args, etc.) + valid_data, errors = self.schema.load( + data, + context={"identity": identity}, + raise_errors=True, + ) + + job = Job(**valid_data) + uow.register(ModelCommitOp(job)) + return self.result_item(self, identity, job, links_tpl=self.links_item_tpl) + + def search(self, identity, params): """Search for jobs.""" - raise NotImplementedError() + self.require_permission(identity, "search") + + search_params = map_search_params(self.config.search, params) + query_param = search_params["q"] + filters = [] + if query_param: + filters.extend( + [ + Job.title.ilike(f"%{query_param}%"), + Job.description.ilike(f"%{query_param}%"), + ] + ) + + jobs = ( + Job.query.filter(sa.or_(*filters)) + .order_by( + search_params["sort_direction"]( + sa.text(",".join(search_params["sort"])) + ) + ) + .paginate( + page=search_params["page"], + per_page=search_params["size"], + error_out=False, + ) + ) + + return self.result_list( + self, + identity, + jobs, + params=search_params, + links_tpl=LinksTemplate(self.config.links_search, context={"args": params}), + links_item_tpl=self.links_item_tpl, + ) def read(self, identity, id_): """Retrieve a job.""" - raise NotImplementedError() + self.require_permission(identity, "read") + job = self._get_job(id_) + + return self.result_item(self, identity, job, links_tpl=self.links_item_tpl) @unit_of_work() def update(self, identity, id_, data, uow=None): """Update a job.""" - raise NotImplementedError() + self.require_permission(identity, "update") + + job = self._get_job(id_) + + valid_data, errors = self.schema.load( + data, + context={"identity": identity, "job": job}, + raise_errors=True, + ) + + for key, value in 
valid_data.items(): + setattr(job, key, value) + uow.register(ModelCommitOp(job)) + return self.result_item(self, identity, job, links_tpl=self.links_item_tpl) @unit_of_work() def delete(self, identity, id_, uow=None): """Delete a job.""" - raise NotImplementedError() + self.require_permission(identity, "delete") + job = self._get_job(id_) + + # TODO: Check if we can delete the job (e.g. if there are still active Runs) + uow.register(ModelDeleteOp(job)) + + return True + + @classmethod + def _get_job(cls, id): + """Get a job by id.""" + job = Job.query.get(id) + if job is None: + raise JobNotFoundError(id) + return job class RunsService(RecordService): diff --git a/tests/conftest.py b/tests/conftest.py index 20c0955..8935f4d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,11 +11,18 @@ fixtures are available. """ +from types import SimpleNamespace + import pytest +from flask_principal import AnonymousIdentity +from invenio_access.permissions import any_user as any_user_need +from invenio_access.permissions import system_identity from invenio_app.factory import create_api as _create_app from invenio_records_permissions.generators import AnyUser from invenio_records_permissions.policies import BasePermissionPolicy +from invenio_jobs.proxies import current_jobs_service + @pytest.fixture(scope="module") def app_config(app_config): @@ -48,3 +55,63 @@ def extra_entry_points(): "mock_module = mock_module.tasks", ], } + + +# +# Users +# +@pytest.fixture(scope="module") +def anon_identity(): + """Anonymous user.""" + identity = AnonymousIdentity() + identity.provides.add(any_user_need) + return identity + + +@pytest.fixture() +def jobs(db, anon_identity): + """Job fixtures.""" + common_data = { + "task": "tasks.mock_task", + "default_queue": "low", + "default_args": { + "arg1": "value1", + "arg2": "value2", + "kwarg1": "value3", + }, + } + interval_job = current_jobs_service.create( + anon_identity, + { + "title": "Test interval job", + "schedule": { + "type": 
"interval", + "hours": 4, + }, + **common_data, + }, + ) + crontab_job = current_jobs_service.create( + anon_identity, + { + "title": "Test crontab job", + "schedule": { + "type": "crontab", + "minute": "0", + "hour": "0", + }, + **common_data, + }, + ) + simple_job = current_jobs_service.create( + anon_identity, + { + "title": "Test unscheduled job", + **common_data, + }, + ) + return SimpleNamespace( + interval=interval_job, + crontab=crontab_job, + simple=simple_job, + ) diff --git a/tests/resources/test_resources.py b/tests/resources/test_resources.py index c289164..9126740 100644 --- a/tests/resources/test_resources.py +++ b/tests/resources/test_resources.py @@ -54,3 +54,238 @@ def test_tasks_search(client): assert res.status_code == 200 assert res.json["hits"]["total"] == 1 assert mock_task_res == res.json["hits"]["hits"][0] + + +def test_jobs_create(db, client, anon_identity): + """Test job creation.""" + # Test minimal job payload + res = client.post( + "/jobs", + json={ + "title": "Test minimal job", + "task": "tasks.mock_task", + }, + ) + assert res.status_code == 201 + assert res.json == { + "id": res.json["id"], + "title": "Test minimal job", + "description": None, + "active": True, + "task": "tasks.mock_task", + "default_queue": "celery", + "default_args": {}, + "schedule": None, + "last_run": None, + "created": res.json["created"], + "updated": res.json["updated"], + "links": { + "runs": f"https://127.0.0.1:5000/api/jobs/{res.json['id']}/runs", + "self": f"https://127.0.0.1:5000/api/jobs/{res.json['id']}", + }, + } + + # Test full job payload + res = client.post( + "/jobs", + json={ + "title": "Test full job", + "task": "tasks.mock_task", + "description": "Test description", + "active": False, + "default_queue": "low", + "default_args": { + "arg1": "value1", + "arg2": "value2", + "kwarg1": "value3", + }, + "schedule": {"type": "interval", "hours": 4}, + }, + ) + assert res.status_code == 201 + assert res.json == { + "id": res.json["id"], + "title": 
"Test full job", + "description": "Test description", + "active": False, + "task": "tasks.mock_task", + "default_queue": "low", + "default_args": { + "arg1": "value1", + "arg2": "value2", + "kwarg1": "value3", + }, + "schedule": {"type": "interval", "hours": 4}, + "last_run": None, + "created": res.json["created"], + "updated": res.json["updated"], + "links": { + "runs": f"https://127.0.0.1:5000/api/jobs/{res.json['id']}/runs", + "self": f"https://127.0.0.1:5000/api/jobs/{res.json['id']}", + }, + } + + +def test_jobs_update(db, client, jobs): + """Test job updates.""" + # Update existing job + res = client.put( + f"/jobs/{jobs.simple.id}", + json={ + "title": "Test updated job", + "task": "tasks.mock_task", + "description": "Test updated description", + "schedule": {"type": "interval", "hours": 2}, + "active": False, + "default_queue": "celery", + "default_args": { + "arg1": "new_value1", + "arg2": "new_value2", + "kwarg2": False, + }, + }, + ) + assert res.status_code == 200 + updated_job = { + "id": jobs.simple.id, + "title": "Test updated job", + "description": "Test updated description", + "active": False, + "task": "tasks.mock_task", + "default_queue": "celery", + "default_args": { + "arg1": "new_value1", + "arg2": "new_value2", + "kwarg2": False, + }, + "schedule": {"type": "interval", "hours": 2}, + "last_run": None, + "created": jobs.simple["created"], + "updated": res.json["updated"], + "links": { + "runs": f"https://127.0.0.1:5000/api/jobs/{jobs.simple.id}/runs", + "self": f"https://127.0.0.1:5000/api/jobs/{jobs.simple.id}", + }, + } + assert res.json == updated_job + + # Read the job to check the update + res = client.get(f"/jobs/{jobs.simple.id}") + assert res.status_code == 200 + assert res.json == updated_job + + +def test_jobs_search(client, jobs): + """Test jobs search.""" + res = client.get("/jobs") + assert res.status_code == 200 + assert "hits" in res.json + assert res.json["hits"]["total"] == 3 + hits = res.json["hits"]["hits"] + + 
interval_job_res = next((j for j in hits if j["id"] == jobs.interval.id), None) + assert interval_job_res == { + "id": jobs.interval.id, + "title": "Test interval job", + "description": None, + "active": True, + "task": "tasks.mock_task", + "default_queue": "low", + "default_args": { + "arg1": "value1", + "arg2": "value2", + "kwarg1": "value3", + }, + "schedule": { + "type": "interval", + "hours": 4, + }, + "last_run": None, + "created": jobs.interval["created"], + "updated": jobs.interval["updated"], + "links": { + "runs": f"https://127.0.0.1:5000/api/jobs/{jobs.interval.id}/runs", + "self": f"https://127.0.0.1:5000/api/jobs/{jobs.interval.id}", + }, + } + + crontab_job_res = next((j for j in hits if j["id"] == jobs.crontab.id), None) + assert crontab_job_res == { + "id": jobs.crontab.id, + "title": "Test crontab job", + "description": None, + "active": True, + "task": "tasks.mock_task", + "default_queue": "low", + "default_args": { + "arg1": "value1", + "arg2": "value2", + "kwarg1": "value3", + }, + "schedule": { + "type": "crontab", + "minute": "0", + "hour": "0", + "day_of_week": "*", + "day_of_month": "*", + "month_of_year": "*", + }, + "last_run": None, + "created": jobs.crontab["created"], + "updated": jobs.crontab["updated"], + "links": { + "runs": f"https://127.0.0.1:5000/api/jobs/{jobs.crontab.id}/runs", + "self": f"https://127.0.0.1:5000/api/jobs/{jobs.crontab.id}", + }, + } + + simple_job_res = next((j for j in hits if j["id"] == jobs.simple.id), None) + assert simple_job_res == { + "id": jobs.simple.id, + "title": "Test unscheduled job", + "description": None, + "active": True, + "task": "tasks.mock_task", + "default_queue": "low", + "default_args": { + "arg1": "value1", + "arg2": "value2", + "kwarg1": "value3", + }, + "schedule": None, + "last_run": None, + "created": jobs.simple["created"], + "updated": jobs.simple["updated"], + "links": { + "runs": f"https://127.0.0.1:5000/api/jobs/{jobs.simple.id}/runs", + "self": 
f"https://127.0.0.1:5000/api/jobs/{jobs.simple.id}", + }, + } + + # Test filtering + res = client.get("/jobs?q=interval") + assert res.status_code == 200 + assert res.json["hits"]["total"] == 1 + assert interval_job_res == res.json["hits"]["hits"][0] + + +def test_jobs_delete(db, client, jobs): + """Test job deletion.""" + res = client.delete(f"/jobs/{jobs.simple.id}") + assert res.status_code == 204 + + # Shouldn't be able to get again + res = client.get(f"/jobs/{jobs.simple.id}") + assert res.status_code == 404 + assert res.json == { + "message": f"Job with ID {jobs.simple.id} does not exist.", + "status": 404, + } + + # Shouldn't appear in search results + res = client.get("/jobs") + assert res.status_code == 200 + assert "hits" in res.json + assert res.json["hits"]["total"] == 2 + hits = res.json["hits"]["hits"] + assert all(j["id"] != jobs.simple.id for j in hits)