Skip to content

Commit

Permalink
run: serialize arguments for celery task
Browse files Browse the repository at this point in the history
  • Loading branch information
kpsherva committed Aug 28, 2024
1 parent 85d54b8 commit cffab7f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
13 changes: 9 additions & 4 deletions invenio_jobs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ class Job(db.Model, Timestamp):
active = db.Column(db.Boolean, default=True, nullable=False)
title = db.Column(db.String(255), nullable=False)
description = db.Column(db.Text)
# default_args = db.Column(JSON, default=lambda: dict(), nullable=True)
task = db.Column(db.String(255))
default_queue = db.Column(db.String(64))
schedule = db.Column(JSON, nullable=True)
Expand Down Expand Up @@ -146,21 +145,26 @@ def create(cls, job, **kwargs):
"""Create a new run."""
if "args" not in kwargs:
kwargs["args"] = cls.generate_args(job)
else:
task_arguments = deepcopy(kwargs["args"].get("args", {}))
kwargs["args"] = cls.generate_args(job, task_arguments=task_arguments)
if "queue" not in kwargs:
kwargs["queue"] = job.default_queue

return cls(job=job, **kwargs)

@classmethod
def generate_args(cls, job):
def generate_args(cls, job, task_arguments=None):
"""Generate new run args.
We allow a templating mechanism to generate the args for the run. It's important
that the Jinja template context only includes "safe" values, i.e. no DB model
classes or Python objects or functions. Otherwise, we risk that users could
execute arbitrary code, or perform harmful DB operations (e.g. delete rows).
"""
args = deepcopy(job.default_args)
if task_arguments:
args = Task.get(job.task).build_task_arguments(job_obj=job, **task_arguments)
else:
args = deepcopy(job.default_args)
args = json.dumps(args, indent=4, sort_keys=True, default=str)
args = json.loads(args)
return args
Expand Down Expand Up @@ -208,3 +212,4 @@ def all(cls):
def get(cls, id_):
"""Get registered task by id."""
return cls(current_jobs.registry.get(id_))

12 changes: 10 additions & 2 deletions invenio_jobs/services/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
"""Service schemas."""

import inspect
import json
from copy import deepcopy
from datetime import timezone

from invenio_i18n import lazy_gettext as _
from invenio_users_resources.services import schemas as user_schemas
from marshmallow import EXCLUDE, Schema, fields, post_load, pre_dump, types, validate
from marshmallow import EXCLUDE, Schema, fields, post_load, pre_dump, types, validate, \
pre_load
from marshmallow_oneofschema import OneOfSchema
from marshmallow_utils.fields import SanitizedUnicode, TZDateTime
from marshmallow_utils.permissions import FieldPermissionsMixin
Expand Down Expand Up @@ -243,10 +245,16 @@ class Meta:
dump_default=lambda: current_jobs.default_queue,
)

@pre_load
def wrap_args(self, obj, many, **kwargs):
"""Workaround for nested args."""
obj["args"] = {"args": obj["args"]}
return obj

@post_load
def pick_args(self, obj, many, **kwargs):
"""Choose custom or default args."""
custom_args = obj.pop("custom_args")
if custom_args:
obj["args"] = custom_args
obj["args"] = json.loads(custom_args)
return obj

0 comments on commit cffab7f

Please sign in to comment.