diff --git a/bytes/bytes/rabbitmq.py b/bytes/bytes/rabbitmq.py
index 8ae5b83f446..f40a042ef1f 100644
--- a/bytes/bytes/rabbitmq.py
+++ b/bytes/bytes/rabbitmq.py
@@ -53,7 +53,7 @@ def _check_connection(self) -> None:

     @staticmethod
     def _queue_name(event: Event) -> str:
-        return f"{event.organization}__{event.event_id}"
+        return event.event_id


 class NullManager(EventManager):
diff --git a/mula/docs/architecture.md b/mula/docs/architecture.md
index 69a2dbf3c6c..b3d24418acb 100644
--- a/mula/docs/architecture.md
+++ b/mula/docs/architecture.md
@@ -3,19 +3,14 @@
 ## Purpose

 The _scheduler_ is tasked with populating and maintaining priority queues of
-ranked tasks, and can be popped off through HTTP API calls. The scheduler is
+tasks, which can be popped off through HTTP API calls. The scheduler is
 designed to be extensible, such that you're able to create your own rules for
 the population, scheduling, and prioritization of tasks.

-In the implementation of the scheduler within OpenKAT is tasked with
-scheduling and populating the priority queues of 'boefje', 'normalizer' and
+Within OpenKAT, the scheduler is tasked with scheduling and populating the
+priority queues of `boefje`, `normalizer` and
 `report` tasks.

-Because of the use of a priority queue we can differentiate between tasks that
-are to be executed first, e.g. tasks created by the user get precedence over
-tasks that are created by the internal rescheduling processes within the
-scheduler.
-
 In this document we will outline how the scheduler operates within KAT, how
 internal systems function and how external services use it.

@@ -34,43 +29,75 @@ combines data from the `Octopoes`, `Katalogus`, `Bytes` and `RabbitMQ` systems.

 External services used and for what purpose:

-- Octopoes; retrieval of ooi information
-
 - RabbitMQ; messaging queues to notify the scheduler of scan level changes and
   the creation of raw files from bytes

+- Rocky; interfaces with the scheduler through its REST API
+
+- Octopoes; retrieval of OOI information
+
 - Katalogus; retrieval of plugin and organization information

 - Bytes; retrieval of raw file information

-- Rocky; interfaces with the scheduler through its rest api

+```mermaid
+flowchart TB
+    subgraph "External informational services"
+        Octopoes["Octopoes
[system]"] + Katalogus["Katalogus
[system]"] + Bytes["Bytes
[system]"] + end + subgraph "Task creation services" + Rocky["Rocky
[webapp]"] + RabbitMQ["RabbitMQ
[message broker]"] + end + + Scheduler["Scheduler
[system]"] + + subgraph "Task handling services" + TaskRunner["Task Runner
[software system]"] + end + + Rocky-->Scheduler + RabbitMQ-->Scheduler -![scheduler_system.svg](./img/scheduler_system.svg) + Octopoes-->Scheduler + Katalogus-->Scheduler + Bytes-->Scheduler + + + Scheduler--"Pop task of queue"-->TaskRunner +``` ### C3 Component level When we take a closer look at the `scheduler` system itself we can identify -several components. The `SchedulerApp` directs the creation and maintenance -of a multitude of schedulers. - -| Scheduler | Schedulers | -| :-------------------------------- | --------------------------------------: | -| ![scheduler](./img/scheduler.svg) | ![schedulers.svg](./img/schedulers.svg) | +several components. The `App` directs the creation and maintenance +of several schedulers. And the `API` that is responsible for interfacing with +the `Scheduler` system. + +```mermaid +flowchart TB + subgraph "**Scheduler**
[system]" + direction TB + subgraph Server["**API**
[component]
REST API"] + end + subgraph App["**App**
[component]
Main python application"] + end + Server-->App + end +``` -Typically in a OpenKAT installation 3 scheduler will be created per organisation: +Typically in a OpenKAT installation 3 scheduler will be created 1. _boefje scheduler_ 2. _normalizer scheduler_ 3. _report scheduler_ Each scheduler type implements it's own priority queue, and can implement it's -own processes of populating, and prioritization of its tasks. - -![queue.svg](./img/queue.svg) - -Interaction with the scheduler and access to the internals of the -`SchedulerApp` can be accessed by the `Server` which implements a HTTP REST API -interface. +own processes of populating, and prioritization of its tasks. Interaction with +the scheduler and access to the internals of the `App` can be achieved by +interfacing with the `Server`. Which implements a HTTP REST API interface. ## Dataflows @@ -92,7 +119,22 @@ responsible for maintaining a queue of tasks for `Task Runners` to pick up and process. A `Scheduler` is responsible for creating `Task` objects and pushing them onto the queue. -![tasks.svg](./img/tasks.svg) +```mermaid +flowchart LR + subgraph "**Scheduler**
[system]" + direction LR + subgraph Scheduler["**Scheduler**
[component]
"] + direction LR + Process["Task creation process"] + subgraph PriorityQueue["PriorityQueue"] + Task0 + Task1[...] + TaskN + end + end + Process-->PriorityQueue + end +``` The `PriorityQueue` derives its state from the state of the `Task` objects that are persisted in the database. In other words, the current state of the @@ -102,13 +144,16 @@ are persisted in the database. In other words, the current state of the A `Task` object contains the following fields: -- `scheduler_id` - The id of the scheduler for which this task is created -- `schedule_id` - Optional, the id of the `Schedule` that created the task -- `priority` - The priority of the task -- `status` - The status of the task -- `type` - The type of the task -- `data` - A JSON object containing the task data -- `hash` - A unique hash generated by specific fields from the task data +| Field | Description | +| -------------- | ------------------------------------------------------------- | +| `scheduler_id` | The id of the scheduler for which this task is created | +| `schedule_id` | Optional, the id of the `Schedule` that created the task | +| `priority` | The priority of the task | +| `organisation` | The organisation for which the task is created | +| `status` | The status of the task | +| `type` | The type of the task | +| `data` | A JSON object containing the task data | +| `hash` | A unique hash generated by specific fields from the task data | Important to note is the `data` field contains the object that a `Task Runner` will use to execute the task. This field is a JSON field that allows any object @@ -120,6 +165,35 @@ By doing this, it allows the scheduler to wrap whatever object within a `Task`, and as a result we're able to create and extend more types of schedulers that are not specifically bound to a type. +A json representation of a `Task` object, for example a `BoefjeTask` object +as the `data` field: + +```json +{ + "scheduler_id": "1", + "schedule_id": "1", + "priority": 1, + "organisation": "openkat-corp", + "status": "PENDING", + "type": "boefje", + "data": { + "ooi": "internet", + "boefje": { + "id": "dns-zone", + "scan_level": 1 + } + }, + "hash": "a1b2c3d4e5f6g7h8i9j0" +} +``` + +A `Task` is a one-time execution of a task and is a unique instance of task that +is present in the `data` object. This means that you will encounter several +instances of the same task. We generate a unique hash for each task by hashing +specific fields from the `data` object. This hash is used to identify the task +within the `PriorityQueue` and is used to check if the same task is already on +the queue. + This approach ensures that the historical record of each task's execution is distinct, providing a clear and isolated view of each instance of the task's lifecycle. This strategy enables maintaining accurate and unambiguous @@ -153,29 +227,28 @@ that `Scheduler` can create `Schedule` objects for its `Task` objects. A `Schedule` object is a way to define when a `Task` should be executed automatically on a recurring schedule by the `Scheduler`. -A `Schedule` will use the 'blueprint' that is defined in its `data` field (this +A `Schedule` will use the _'blueprint'_ that is defined in its `data` field (this is the same as the `data` field of a `Task`) to generate a `Task` object to be pushed on the queue of a `Scheduler`. 
-![schedules.svg](./img/schedules.svg) - A `Schedule` object contains the following fields: -- `scheduler_id` - The id of the scheduler that created the schedule -- `schedule` - A cron expression that defines when the task should be - executed, this is used to update the value of `deadline_at` -- `deadline_at` - A timestamp that defines when the task should be executed -- `data` - A JSON object containing data for the schedule (this is the same as - the `data` field in the `Task` object) -- `hash` - A unique hash generated by specific fields from the schedule data +| Field | Description | +| -------------- | ------------------------------------------------------------------------------------------------------------------ | +| `scheduler_id` | The id of the scheduler that created the schedule | +| `schedule` | A cron expression that defines when the task should be executed, this is used to update the value of `deadline_at` | +| `deadline_at` | A timestamp that defines when the task should be executed | +| `data` | A JSON object containing data for the schedule (this is the same as the `data` field in the `Task` object) | +| `hash` | A unique hash generated by specific fields from the schedule data | A `Scheduler` can be extended by a process that checks if the `deadline_at` of a `Schedule` has passed, and if so, creates a `Task` object for the `Scheduler` to push onto the queue. -When the `Task` object is pushed onto the queue, the new `deadline_at` value -of the `Schedule` is calculated using the cron expression defined in the -`schedule` field. +Typically when the `Task` object is pushed onto the queue, the new +`deadline_at` value of the `Schedule` is calculated using the cron expression +defined in the `schedule` field. Refer to the specific `Scheduler` for more +information on how this is implemented. ### `BoefjeScheduler` @@ -221,21 +294,46 @@ Before a `BoefjeTask` and pushed on the queue we will check the following: #### Processes -![boefje_scheduler.svg](./img/boefje_scheduler.svg) +```mermaid +flowchart LR + subgraph "**Scheduler**
[system]" + direction LR + subgraph BoefjeScheduler["**BoefjeScheduler**
[component]
"] + direction LR + ProcessManual["Manual"] + ProcessMutations["Mutations"] + ProcessNewBoefjes["NewBoefjes"] + ProcessRescheduling["Rescheduling"] + subgraph PriorityQueue["PriorityQueue"] + Task0 + Task1[...] + TaskN + end + ProcessManual-->PriorityQueue + ProcessMutations-->PriorityQueue + ProcessNewBoefjes-->PriorityQueue + ProcessRescheduling-->PriorityQueue + end + end +``` In order to create a `BoefjeTask` and trigger the dataflow we described above -we have 4 different processes running in threads within a `BoefjeScheduler` +we have 3 different processes running in threads within a `BoefjeScheduler` that can create boefje tasks. Namely: -1. scan profile mutations -2. enabling of boefjes -3. rescheduling of prior tasks -4. manual scan job +| Process | Description | +| ----------------------- | -------------------------------------------------------------------------------------------------- | +| `process_mutations` | scan profile mutations received from RabbitMQ indicating that the scan level of an OOI has changed | +| `process_new_boefjes` | enabling of boefjes will result in gathering of OOI's on which the boefje can be used | +| `process_rescheduling ` | rescheduling of prior tasks | + +Additionally, a boefje task creation can be triggered by a manual scan job that +is created by the user in Rocky. ##### 1. Scan profile mutations When a scan level is increased on an OOI -(`schedulers.boefje.push_tasks_for_scan_profile_mutations`) a message is pushed +(`schedulers.boefje.process_mutations`) a message is pushed on the RabbitMQ `{organization_id}__scan_profile_mutations` queue. The scheduler continuously checks if new messages are posted on the queue. The resulting tasks from this process will get the second highest priority of 2 on the queue. @@ -336,7 +434,22 @@ queue we will check the following: #### Processes -![normalizer_scheduler.svg](./img/normalizer_scheduler.svg) +```mermaid +flowchart LR + subgraph "**Scheduler**
[system]" + direction LR + subgraph NormalizerScheduler["**NormalizerScheduler**
[component]
"] + direction LR + ProcessRawData["RawData"] + subgraph PriorityQueue["PriorityQueue"] + Task0 + Task1[...] + TaskN + end + ProcessRawData-->PriorityQueue + end + end +``` The following processes within a `NormalizerScheduler` will create a `NormalizerTask` tasks: @@ -345,7 +458,7 @@ The following processes within a `NormalizerScheduler` will create a ##### 1. Raw file creation in Bytes -When a raw file is created (`schedulers.normalizer.create_tasks_for_raw_data`) +When a raw file is created (`schedulers.normalizer.process_raw_data`) - The `NormalizerScheduler` retrieves raw files that have been created in Bytes from a message queue. @@ -365,7 +478,22 @@ picked up and processed by the report task runner. #### Processes -![report_scheduler.svg](./img/report_scheduler.svg) +```mermaid +flowchart LR + subgraph "**Scheduler**
[system]" + direction LR + subgraph ReportScheduler["**ReportScheduler**
[component]
"] + direction LR + ProcessRescheduling["Rescheduling"] + subgraph PriorityQueue["PriorityQueue"] + Task0 + Task1[...] + TaskN + end + ProcessRescheduling-->PriorityQueue + end + end +``` The `ReportScheduler` will create a `ReportTask` for the `Task` that is associated with a `Schedule` object. diff --git a/mula/docs/img/boefje_scheduler.svg b/mula/docs/img/boefje_scheduler.svg deleted file mode 100644 index 9f854ad21bf..00000000000 --- a/mula/docs/img/boefje_scheduler.svg +++ /dev/null @@ -1,4 +0,0 @@ - - - -
BoefjeScheduler
queue
mutations
new boefjes
rescheduling
manual
diff --git a/mula/docs/img/normalizer_scheduler.svg b/mula/docs/img/normalizer_scheduler.svg deleted file mode 100644 index 18b53d70fe8..00000000000 --- a/mula/docs/img/normalizer_scheduler.svg +++ /dev/null @@ -1,4 +0,0 @@ - - - -
queue
raw data
received
NormalizerScheduler
diff --git a/mula/docs/img/queue.svg b/mula/docs/img/queue.svg deleted file mode 100644 index 1f7fdbfcdee..00000000000 --- a/mula/docs/img/queue.svg +++ /dev/null @@ -1,4 +0,0 @@ - - - -
Scheduler
API
queue
(process)
diff --git a/mula/docs/img/report_scheduler.svg b/mula/docs/img/report_scheduler.svg deleted file mode 100644 index c6a78c79e97..00000000000 --- a/mula/docs/img/report_scheduler.svg +++ /dev/null @@ -1,4 +0,0 @@ - - - -
queue
rescheduling
ReportScheduler
diff --git a/mula/docs/img/scheduler.svg b/mula/docs/img/scheduler.svg deleted file mode 100644 index 87fc74ee30a..00000000000 --- a/mula/docs/img/scheduler.svg +++ /dev/null @@ -1,4 +0,0 @@ - - - -
Scheduler
API
SchedulerApp
diff --git a/mula/docs/img/scheduler_system.svg b/mula/docs/img/scheduler_system.svg deleted file mode 100644 index ac511569ad2..00000000000 --- a/mula/docs/img/scheduler_system.svg +++ /dev/null @@ -1,4 +0,0 @@ - - - -
Scheduler System
API
SchedulerApp
Rocky
RabbitMQ
KAT-alogus
Octopoes
Bytes
Task Runners
diff --git a/mula/docs/img/schedulers.svg b/mula/docs/img/schedulers.svg deleted file mode 100644 index d804fc43df1..00000000000 --- a/mula/docs/img/schedulers.svg +++ /dev/null @@ -1,4 +0,0 @@ - - - -
Scheduler
API
boefje-org1
boefje-org2
diff --git a/mula/docs/img/schedules.svg b/mula/docs/img/schedules.svg deleted file mode 100644 index 2d4cf387854..00000000000 --- a/mula/docs/img/schedules.svg +++ /dev/null @@ -1,4 +0,0 @@ - - - -
(process)
task
task
task
task
Schedule
Task
diff --git a/mula/docs/img/tasks.svg b/mula/docs/img/tasks.svg deleted file mode 100644 index 0b686533366..00000000000 --- a/mula/docs/img/tasks.svg +++ /dev/null @@ -1,4 +0,0 @@ - - - -
(process)
task
task
task
task
diff --git a/mula/logging.json b/mula/logging.json index 2d7a02642d9..6fd1ed1ce2a 100644 --- a/mula/logging.json +++ b/mula/logging.json @@ -9,13 +9,13 @@ "handlers": { "console": { "class": "logging.StreamHandler", - "level": "INFO", + "level": "DEBUG", "formatter": "default", "stream": "ext://sys.stdout" } }, "root": { - "level": "INFO", + "level": "DEBUG", "handlers": [ "console" ] diff --git a/mula/scheduler/app.py b/mula/scheduler/app.py index d8770730762..f1be55079bf 100644 --- a/mula/scheduler/app.py +++ b/mula/scheduler/app.py @@ -4,7 +4,7 @@ import structlog from opentelemetry import trace -from scheduler import clients, context, schedulers, server +from scheduler import context, schedulers, server from scheduler.utils import thread tracer = trace.get_tracer(__name__) @@ -26,34 +26,21 @@ class App: through a REST API. * Metrics: The collection of application specific metrics. - - Attributes: - logger: - The logger for the class. - ctx: - Application context of shared data (e.g. configuration, external - services connections). - stop_event: A threading.Event object used for communicating a stop - event across threads. - schedulers: - A dict of schedulers, keyed by scheduler id. - server: - The http rest api server instance. """ def __init__(self, ctx: context.AppContext) -> None: """Initialize the application. Args: - ctx: - Application context of shared data (e.g. configuration, - external services connections). + ctx (context.AppContext): Application context of shared data (e.g. + configuration, external services connections). """ self.logger: structlog.BoundLogger = structlog.getLogger(__name__) self.ctx: context.AppContext = ctx + self.server: server.Server | None = None - threading.excepthook = self.unhandled_exception + threading.excepthook = self._unhandled_exception self.stop_event: threading.Event = threading.Event() self.lock: threading.Lock = threading.Lock() @@ -64,147 +51,6 @@ def __init__(self, ctx: context.AppContext) -> None: | schedulers.NormalizerScheduler | schedulers.ReportScheduler, ] = {} - self.server: server.Server | None = None - - @tracer.start_as_current_span("monitor_organisations") - def monitor_organisations(self) -> None: - """Monitor the organisations from the Katalogus service, and add/remove - organisations from the schedulers. - """ - current_schedulers = self.schedulers.copy() - - # We make a difference between the organisation id's that are used - # by the schedulers, and the organisation id's that are in the - # Katalogus service. We will add/remove schedulers based on the - # difference between these two sets. 
- scheduler_orgs: set[str] = { - s.organisation.id for s in current_schedulers.values() if hasattr(s, "organisation") - } - try: - orgs = self.ctx.services.katalogus.get_organisations() - except clients.errors.ExternalServiceError: - self.logger.exception("Failed to get organisations from Katalogus") - return - - katalogus_orgs = {org.id for org in orgs} - - additions = katalogus_orgs.difference(scheduler_orgs) - self.logger.debug("Organisations to add: %s", len(additions), additions=sorted(additions)) - - removals = scheduler_orgs.difference(katalogus_orgs) - self.logger.debug("Organisations to remove: %s", len(removals), removals=sorted(removals)) - - # We need to get scheduler ids of the schedulers that are associated - # with the removed organisations - removal_scheduler_ids: set[str] = { - s.scheduler_id - for s in current_schedulers.values() - if hasattr(s, "organisation") and s.organisation.id in removals - } - - # Remove schedulers for removed organisations - for scheduler_id in removal_scheduler_ids: - if scheduler_id not in self.schedulers: - continue - - self.schedulers[scheduler_id].stop() - - if removals: - self.logger.debug("Removed %s organisations from scheduler", len(removals), removals=sorted(removals)) - - # Add schedulers for organisation - for org_id in additions: - try: - org = self.ctx.services.katalogus.get_organisation(org_id) - except clients.errors.ExternalServiceError as e: - self.logger.error("Failed to get organisation from Katalogus", error=e, org_id=org_id) - continue - - scheduler_boefje = schedulers.BoefjeScheduler( - ctx=self.ctx, scheduler_id=f"boefje-{org.id}", organisation=org, callback=self.remove_scheduler - ) - - scheduler_normalizer = schedulers.NormalizerScheduler( - ctx=self.ctx, scheduler_id=f"normalizer-{org.id}", organisation=org, callback=self.remove_scheduler - ) - - scheduler_report = schedulers.ReportScheduler( - ctx=self.ctx, scheduler_id=f"report-{org.id}", organisation=org, callback=self.remove_scheduler - ) - - with self.lock: - self.schedulers[scheduler_boefje.scheduler_id] = scheduler_boefje - self.schedulers[scheduler_normalizer.scheduler_id] = scheduler_normalizer - self.schedulers[scheduler_report.scheduler_id] = scheduler_report - - scheduler_normalizer.run() - scheduler_boefje.run() - scheduler_report.run() - - if additions: - # Flush katalogus caches when new organisations are added - self.ctx.services.katalogus.flush_caches() - - self.logger.debug("Added %s organisations to scheduler", len(additions), additions=sorted(additions)) - - @tracer.start_as_current_span("collect_metrics") - def collect_metrics(self) -> None: - """Collect application metrics - - This method that allows to collect metrics throughout the application. 
- """ - with self.lock: - for s in self.schedulers.copy().values(): - self.ctx.metrics_qsize.labels(scheduler_id=s.scheduler_id).set(s.queue.qsize()) - - status_counts = self.ctx.datastores.task_store.get_status_counts(s.scheduler_id) - for status, count in status_counts.items(): - self.ctx.metrics_task_status_counts.labels(scheduler_id=s.scheduler_id, status=status).set(count) - - def start_schedulers(self) -> None: - # Initialize the schedulers - try: - orgs = self.ctx.services.katalogus.get_organisations() - except clients.errors.ExternalServiceError as e: - self.logger.error("Failed to get organisations from Katalogus", error=e) - return - - for org in orgs: - boefje_scheduler = schedulers.BoefjeScheduler( - ctx=self.ctx, scheduler_id=f"boefje-{org.id}", organisation=org, callback=self.remove_scheduler - ) - self.schedulers[boefje_scheduler.scheduler_id] = boefje_scheduler - - normalizer_scheduler = schedulers.NormalizerScheduler( - ctx=self.ctx, scheduler_id=f"normalizer-{org.id}", organisation=org, callback=self.remove_scheduler - ) - self.schedulers[normalizer_scheduler.scheduler_id] = normalizer_scheduler - - report_scheduler = schedulers.ReportScheduler( - ctx=self.ctx, scheduler_id=f"report-{org.id}", organisation=org, callback=self.remove_scheduler - ) - self.schedulers[report_scheduler.scheduler_id] = report_scheduler - - # Start schedulers - for scheduler in self.schedulers.values(): - scheduler.run() - - def start_monitors(self) -> None: - thread.ThreadRunner( - name="App-monitor_organisations", - target=self.monitor_organisations, - stop_event=self.stop_event, - interval=self.ctx.config.monitor_organisations_interval, - ).start() - - def start_collectors(self) -> None: - thread.ThreadRunner( - name="App-metrics_collector", target=self.collect_metrics, stop_event=self.stop_event, interval=10 - ).start() - - def start_server(self) -> None: - self.server = server.Server(self.ctx, self.schedulers) - thread.ThreadRunner(name="App-server", target=self.server.run, stop_event=self.stop_event, loop=False).start() def run(self) -> None: """Start the main scheduler application, and run in threads the @@ -215,18 +61,12 @@ def run(self) -> None: * metrics collecting * api server """ - # Start schedulers self.start_schedulers() - # Start monitors - self.start_monitors() - - # Start metrics collecting if self.ctx.config.collect_metrics: self.start_collectors() - # API Server - self.start_server() + self.start_server(self.schedulers) # Main thread while not self.stop_event.is_set(): @@ -241,24 +81,55 @@ def run(self) -> None: # Source: https://stackoverflow.com/a/1489838/1346257 os._exit(1) + def start_schedulers(self) -> None: + boefje = schedulers.BoefjeScheduler(ctx=self.ctx) + self.schedulers[boefje.scheduler_id] = boefje + + normalizer = schedulers.NormalizerScheduler(ctx=self.ctx) + self.schedulers[normalizer.scheduler_id] = normalizer + + report = schedulers.ReportScheduler(ctx=self.ctx) + self.schedulers[report.scheduler_id] = report + + for s in self.schedulers.values(): + s.run() + + def start_collectors(self) -> None: + thread.ThreadRunner( + name="App-metrics_collector", target=self._collect_metrics, stop_event=self.stop_event, interval=10 + ).start() + + def start_server( + self, + schedulers: dict[ + str, + schedulers.Scheduler + | schedulers.BoefjeScheduler + | schedulers.NormalizerScheduler + | schedulers.ReportScheduler, + ], + ) -> None: + self.server = server.Server(self.ctx, schedulers) + thread.ThreadRunner(name="App-server", target=self.server.run, 
stop_event=self.stop_event, loop=False).start() + def shutdown(self) -> None: """Shutdown the scheduler application, and all threads.""" self.logger.info("Shutdown initiated") self.stop_event.set() - # First stop schedulers - for s in self.schedulers.copy().values(): + # Stop all schedulers + for s in self.schedulers.values(): s.stop() # Stop all threads that are still running, except the main thread. # These threads likely have a blocking call and as such are not able # to leverage a stop event. - self.stop_threads() + self._stop_threads() self.logger.info("Shutdown complete") - def stop_threads(self) -> None: + def _stop_threads(self) -> None: """Stop all threads, except the main thread.""" for t in threading.enumerate(): if t is threading.current_thread(): @@ -272,23 +143,23 @@ def stop_threads(self) -> None: t.join(5) - def unhandled_exception(self, args: threading.ExceptHookArgs) -> None: + def _unhandled_exception(self, args: threading.ExceptHookArgs) -> None: """Gracefully shutdown the scheduler application, and all threads when a unhandled exception occurs. """ self.logger.error("Unhandled exception occurred: %s", args.exc_value) self.stop_event.set() - def remove_scheduler(self, scheduler_id: str) -> None: - """Remove a scheduler from the application. This method is passed - as a callback to the scheduler, so that the scheduler can remove - itself from the application. + def _collect_metrics(self) -> None: + """Collect application metrics throughout the application.""" - Args: - scheduler_id: The id of the scheduler to remove. - """ - with self.lock: - if scheduler_id not in self.schedulers: - return + # FIXME:: can be queries instead of a loop + # Collect the queue size of the schedulers, and the status counts of + # the tasks for each scheduler. 
+ for s in self.schedulers.values(): + qsize = self.ctx.datastores.pq_store.qsize(s.scheduler_id) + self.ctx.metrics_qsize.labels(scheduler_id=s.scheduler_id).set(qsize) - self.schedulers.pop(scheduler_id) + status_counts = self.ctx.datastores.task_store.get_status_counts(s.scheduler_id) + for status, count in status_counts.items(): + self.ctx.metrics_task_status_counts.labels(scheduler_id=s.scheduler_id, status=status).set(count) diff --git a/mula/scheduler/config/settings.py b/mula/scheduler/config/settings.py index f095350b8bb..9db1d567e1f 100644 --- a/mula/scheduler/config/settings.py +++ b/mula/scheduler/config/settings.py @@ -130,7 +130,7 @@ class Settings(BaseSettings): ) # Queue settings - pq_maxsize: int = Field(1000, description="How many items a priority queue can hold (0 is infinite)") + pq_maxsize: int = Field(0, description="How many items a priority queue can hold (0 is infinite)") pq_interval: int = Field( 60, description="Interval in seconds of the execution of the `` method of the `scheduler.Scheduler` class" diff --git a/mula/scheduler/models/__init__.py b/mula/scheduler/models/__init__.py index a5390ad6ede..646d51c4f21 100644 --- a/mula/scheduler/models/__init__.py +++ b/mula/scheduler/models/__init__.py @@ -8,5 +8,5 @@ from .plugin import Plugin from .queue import Queue from .schedule import Schedule, ScheduleDB -from .scheduler import Scheduler +from .scheduler import Scheduler, SchedulerType from .task import BoefjeTask, NormalizerTask, ReportTask, Task, TaskDB, TaskStatus diff --git a/mula/scheduler/models/ooi.py b/mula/scheduler/models/ooi.py index e1bf51c4c7f..3761365f03e 100644 --- a/mula/scheduler/models/ooi.py +++ b/mula/scheduler/models/ooi.py @@ -27,3 +27,4 @@ class ScanProfileMutation(BaseModel): operation: MutationOperationType primary_key: str value: OOI | None + client_id: str diff --git a/mula/scheduler/models/organisation.py b/mula/scheduler/models/organisation.py index 58032cb5f30..dc819d297f6 100644 --- a/mula/scheduler/models/organisation.py +++ b/mula/scheduler/models/organisation.py @@ -3,4 +3,4 @@ class Organisation(BaseModel): id: str - name: str + name: str | None = None diff --git a/mula/scheduler/models/schedule.py b/mula/scheduler/models/schedule.py index 4ba0c54bc71..4baf27a7406 100644 --- a/mula/scheduler/models/schedule.py +++ b/mula/scheduler/models/schedule.py @@ -17,17 +17,12 @@ class Schedule(BaseModel): model_config = ConfigDict(from_attributes=True, validate_assignment=True) id: uuid.UUID = Field(default_factory=uuid.uuid4) - scheduler_id: str - + organisation: str hash: str | None = Field(None, max_length=32) - data: dict | None = None - enabled: bool = True - schedule: str | None = None - tasks: list[Task] = [] deadline_at: datetime | None = None @@ -57,21 +52,14 @@ class ScheduleDB(Base): __tablename__ = "schedules" id = Column(GUID, primary_key=True) - scheduler_id = Column(String, nullable=False) - + organisation = Column(String, nullable=False) hash = Column(String(32), nullable=True, unique=True) - data = Column(JSONB, nullable=False) - enabled = Column(Boolean, nullable=False, default=True) - schedule = Column(String, nullable=True) - tasks = relationship("TaskDB", back_populates="schedule") deadline_at = Column(DateTime(timezone=True), nullable=True) - created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now()) - modified_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()) diff --git a/mula/scheduler/models/scheduler.py 
b/mula/scheduler/models/scheduler.py index 9c75c923743..e1d0f7c7b77 100644 --- a/mula/scheduler/models/scheduler.py +++ b/mula/scheduler/models/scheduler.py @@ -1,14 +1,24 @@ +import enum from datetime import datetime from typing import Any -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict + + +class SchedulerType(str, enum.Enum): + """Enum for scheduler types.""" + + UNKNOWN = "unknown" + BOEFJE = "boefje" + NORMALIZER = "normalizer" + REPORT = "report" class Scheduler(BaseModel): - """Representation of a schedulers.Scheduler instance. Used for - unmarshalling of schedulers to a JSON representation.""" + model_config = ConfigDict(from_attributes=True, use_enum_values=True) - id: str | None = None - enabled: bool | None = None - priority_queue: dict[str, Any] | None = None + id: str + type: SchedulerType + item_type: Any + qsize: int = 0 last_activity: datetime | None = None diff --git a/mula/scheduler/models/task.py b/mula/scheduler/models/task.py index dee0014e86c..c438dc87760 100644 --- a/mula/scheduler/models/task.py +++ b/mula/scheduler/models/task.py @@ -46,19 +46,13 @@ class Task(BaseModel): model_config = ConfigDict(from_attributes=True, use_enum_values=True) id: uuid.UUID = Field(default_factory=uuid.uuid4) - - scheduler_id: str | None = None - + scheduler_id: str schedule_id: uuid.UUID | None = None - + organisation: str priority: int | None = 0 - status: TaskStatus = TaskStatus.PENDING - type: str | None = None - hash: str | None = Field(None, max_length=32) - data: dict = Field(default_factory=dict) created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) @@ -69,24 +63,18 @@ class TaskDB(Base): __tablename__ = "tasks" id = Column(GUID, primary_key=True) - scheduler_id = Column(String, nullable=False) - schedule_id = Column(GUID, ForeignKey("schedules.id", ondelete="SET NULL"), nullable=True) - schedule = relationship("ScheduleDB", back_populates="tasks") - + organisation = Column(String, nullable=False) type = Column(String, nullable=False) - hash = Column(String(32), index=True) - priority = Column(Integer) - data = Column(JSONB, nullable=False) - status = Column(Enum(TaskStatus), nullable=False, default=TaskStatus.PENDING) - created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now()) + schedule = relationship("ScheduleDB", back_populates="tasks") + created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now()) modified_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()) diff --git a/mula/scheduler/schedulers/errors.py b/mula/scheduler/schedulers/errors.py new file mode 100644 index 00000000000..d20f03018e0 --- /dev/null +++ b/mula/scheduler/schedulers/errors.py @@ -0,0 +1,21 @@ +import functools + +from scheduler.clients.errors import ExternalServiceError +from scheduler.schedulers.queue.errors import QueueFullError + + +def exception_handler(func): + @functools.wraps(func) + def inner_function(self, *args, **kwargs): + try: + return func(self, *args, **kwargs) + except ExternalServiceError as exc: + self.logger.exception("An exception occurred", exc=exc) + return None + except QueueFullError as exc: + self.logger.exception("Queue is full", exc=exc) + return None + except Exception as exc: + raise exc + + return inner_function diff --git a/mula/scheduler/schedulers/queue/pq.py b/mula/scheduler/schedulers/queue/pq.py index 4a1914451a6..1e7c0b77d29 100644 --- a/mula/scheduler/schedulers/queue/pq.py +++ 
b/mula/scheduler/schedulers/queue/pq.py @@ -97,7 +97,7 @@ def __init__( self.pq_store: storage.stores.PriorityQueueStore = pq_store self.lock: threading.Lock = threading.Lock() - def pop(self, filters: storage.filters.FilterRequest | None = None) -> models.Task | None: + def pop(self, filters: storage.filters.FilterRequest | None = None) -> tuple[list[models.Task], int]: """Remove and return the highest priority item from the queue. Optionally apply filters to the queue. @@ -113,14 +113,13 @@ def pop(self, filters: storage.filters.FilterRequest | None = None) -> models.Ta if self.empty(): raise QueueEmptyError(f"Queue {self.pq_id} is empty.") - item = self.pq_store.pop(self.pq_id, filters) - if item is None: - return None + items, count = self.pq_store.pop(self.pq_id, filters) + if items is None: + return ([], 0) - item.status = models.TaskStatus.DISPATCHED - self.pq_store.update(self.pq_id, item) + self.pq_store.bulk_update_status(self.pq_id, [item.id for item in items], models.TaskStatus.DISPATCHED) - return item + return items, count def push(self, task: models.Task) -> models.Task: """Push an item onto the queue. @@ -202,7 +201,17 @@ def push(self, task: models.Task) -> models.Task: task.status = models.TaskStatus.QUEUED item_db = self.pq_store.push(task) else: - self.pq_store.update(self.pq_id, task) + # Get the item from the queue and update it + stored_item_data = self.get_item_by_identifier(task) + if stored_item_data is None: + raise ItemNotFoundError(f"Item {task} not found in datastore {self.pq_id}") + + # Update the item with the new data + patch_data = task.dict(exclude_unset=True) + updated_task = stored_item_data.model_copy(update=patch_data) + + # Update the item in the queue + self.pq_store.update(self.pq_id, updated_task) item_db = self.get_item_by_identifier(task) if not item_db: diff --git a/mula/scheduler/schedulers/scheduler.py b/mula/scheduler/schedulers/scheduler.py index b5ec2f07182..8237e860e97 100644 --- a/mula/scheduler/schedulers/scheduler.py +++ b/mula/scheduler/schedulers/scheduler.py @@ -18,41 +18,38 @@ class Scheduler(abc.ABC): - """The Scheduler class combines the priority queue. - The scheduler is responsible for populating the queue, and ranking tasks. + """The scheduler base class that all schedulers should inherit from. Attributes: logger: - The logger for the class + The logger instance. ctx: Application context of shared data (e.g. configuration, external services connections). - queue: - A queue.PriorityQueue instance - callback: - A callback function to call when the scheduler is stopped. scheduler_id: - The id of the scheduler. + The id of the scheduler. max_tries: The maximum number of retries for an item to be pushed to the queue. - enabled: - Whether the scheduler is enabled or not. - _last_activity: + create_schedule: + Whether to create a schedule for a task. + last_activity: The last activity of the scheduler. + queue: + A queues.PriorityQueue instance listeners: - A dict of connector.Listener instances, used for listening to - external events. + A dictionary of listeners, typically AMQP listeners on which + event messages are received. + threads: + A list of threads that are running, typically long running + processes. lock: - A threading.Lock instance used for locking + A threading lock stop_event_threads: - A threading.Event object used for communicating a stop - event across threads. - threads: - A dict of ThreadRunner instances, used for running processes - concurrently. + A threading event to stop the running threads. 
""" + TYPE: models.SchedulerType = models.SchedulerType.UNKNOWN ITEM_TYPE: Any = None def __init__( @@ -60,36 +57,16 @@ def __init__( ctx: context.AppContext, scheduler_id: str, queue: PriorityQueue | None = None, - callback: Callable[..., None] | None = None, max_tries: int = -1, create_schedule: bool = False, auto_calculate_deadline: bool = True, ): - """Initialize the Scheduler. - - Args: - ctx: - Application context of shared data (e.g. configuration, external - services connections). - scheduler_id: - The id of the scheduler. - queue: - A queue.PriorityQueue instance - callback: - A callback function to call when the scheduler is stopped. - max_tries: - The maximum number of retries for an item to be pushed to - the queue. - """ - self.logger: structlog.BoundLogger = structlog.getLogger(__name__) self.ctx: context.AppContext = ctx - self.callback: Callable[[], Any] | None = callback # Properties self.scheduler_id: str = scheduler_id self.max_tries: int = max_tries - self.enabled: bool = True self.create_schedule: bool = create_schedule self.auto_calculate_deadline: bool = auto_calculate_deadline self._last_activity: datetime | None = None @@ -106,9 +83,9 @@ def __init__( self.listeners: dict[str, clients.amqp.Listener] = {} # Threads + self.threads: list[thread.ThreadRunner] = [] self.lock: threading.Lock = threading.Lock() self.stop_event_threads: threading.Event = threading.Event() - self.threads: list[thread.ThreadRunner] = [] @abc.abstractmethod def run(self) -> None: @@ -184,6 +161,7 @@ def push_item_to_queue_with_timeout( while not self.is_space_on_queue() and (tries < max_tries or max_tries == -1): self.logger.debug( "Queue %s is full, waiting for space", + self.queue.pq_id, queue_id=self.queue.pq_id, queue_qsize=self.queue.qsize(), scheduler_id=self.scheduler_id, @@ -207,16 +185,6 @@ def push_item_to_queue(self, item: models.Task, create_schedule: bool = True) -> QueueFullError: When the queue is full. InvalidItemError: When the item is invalid. """ - if not self.is_enabled(): - self.logger.warning( - "Scheduler is disabled, not pushing item to queue %s", - self.queue.pq_id, - item_id=item.id, - queue_id=self.queue.pq_id, - scheduler_id=self.scheduler_id, - ) - raise NotAllowedError("Scheduler is disabled") - try: if item.type is None: item.type = self.ITEM_TYPE.type @@ -316,7 +284,9 @@ def post_push(self, item: models.Task, create_schedule: bool = True) -> models.T schedule_db = self.ctx.datastores.schedule_store.get_schedule_by_hash(item.hash) if schedule_db is None: - schedule = models.Schedule(scheduler_id=self.scheduler_id, hash=item.hash, data=item.data) + schedule = models.Schedule( + scheduler_id=self.scheduler_id, hash=item.hash, data=item.data, organisation=item.organisation + ) schedule_db = self.ctx.datastores.schedule_store.create_schedule(schedule) if schedule_db is None: @@ -363,9 +333,10 @@ def post_push(self, item: models.Task, create_schedule: bool = True) -> models.T return item - def pop_item_from_queue(self, filters: storage.filters.FilterRequest | None = None) -> models.Task | None: + def pop_item_from_queue( + self, filters: storage.filters.FilterRequest | None = None + ) -> tuple[list[models.Task], int]: """Pop an item from the queue. - Args: filters: Optional filters to apply when popping an item. @@ -376,38 +347,26 @@ def pop_item_from_queue(self, filters: storage.filters.FilterRequest | None = No NotAllowedError: When the scheduler is disabled. QueueEmptyError: When the queue is empty. 
""" - if not self.is_enabled(): - self.logger.warning( - "Scheduler is disabled, not popping item from queue", - queue_id=self.queue.pq_id, - queue_qsize=self.queue.qsize(), - scheduler_id=self.scheduler_id, - ) - raise NotAllowedError("Scheduler is disabled") - try: - item = self.queue.pop(filters) + items, count = self.queue.pop(filters) except QueueEmptyError as exc: raise exc - if item is not None: + if items is not None: self.logger.debug( - "Popped item %s from queue %s with priority %s", - item.id, + "Popped %s item(s) from queue %s", + count, self.queue.pq_id, - item.priority, - item_id=item.id, queue_id=self.queue.pq_id, scheduler_id=self.scheduler_id, ) - self.post_pop(item) + self.post_pop(items) - return item + return items, count - def post_pop(self, item: models.Task) -> None: + def post_pop(self, items: list[models.Task]) -> None: """After an item is popped from the queue, we execute this function - Args: item: An item from the queue """ @@ -432,54 +391,7 @@ def calculate_deadline(self, task: models.Task) -> datetime: return adjusted_time - def enable(self) -> None: - """Enable the scheduler. - - This will start the scheduler, and start all listeners and threads. - """ - if self.is_enabled(): - self.logger.debug("Scheduler is already enabled") - return - - self.logger.info("Enabling scheduler: %s", self.scheduler_id, scheduler_id=self.scheduler_id) - self.enabled = True - self.stop_event_threads.clear() - self.run() - - self.logger.info("Enabled scheduler: %s", self.scheduler_id, scheduler_id=self.scheduler_id) - - def disable(self) -> None: - """Disable the scheduler. - - This will stop all listeners and threads, and clear the queue, and any - tasks that were on the queue will be set to CANCELLED. - """ - if not self.is_enabled(): - self.logger.warning("Scheduler already disabled: %s", self.scheduler_id, scheduler_id=self.scheduler_id) - return - - self.logger.info("Disabling scheduler: %s", self.scheduler_id) - self.enabled = False - - self.stop_listeners() - self.stop_threads() - self.queue.clear() - - # Get all tasks that were on the queue and set them to CANCELLED - tasks, _ = self.ctx.datastores.task_store.get_tasks( - scheduler_id=self.scheduler_id, status=models.TaskStatus.QUEUED - ) - task_ids = [task.id for task in tasks] - self.ctx.datastores.task_store.cancel_tasks(scheduler_id=self.scheduler_id, task_ids=task_ids) - - self.logger.info("Disabled scheduler: %s", self.scheduler_id, scheduler_id=self.scheduler_id) - - def stop(self, callback: bool = True) -> None: - """Stop the scheduler. - - Args: - callback: Whether to call the callback function. - """ + def stop(self) -> None: self.logger.info("Stopping scheduler: %s", self.scheduler_id, scheduler_id=self.scheduler_id) # First, stop the listeners, when those are running in a thread and @@ -488,9 +400,6 @@ def stop(self, callback: bool = True) -> None: self.stop_listeners() self.stop_threads() - if self.callback and callback: - self.callback(self.scheduler_id) # type: ignore [call-arg] - self.logger.info("Stopped scheduler: %s", self.scheduler_id, scheduler_id=self.scheduler_id) def stop_listeners(self) -> None: @@ -507,14 +416,6 @@ def stop_threads(self) -> None: self.threads = [] - def is_enabled(self) -> bool: - """Check if the scheduler is enabled. - - Returns: - True if the scheduler is enabled, False otherwise. - """ - return self.enabled - def is_space_on_queue(self) -> bool: """Check if there is space on the queue. 
@@ -547,15 +448,8 @@ def dict(self) -> dict[str, Any]: """Get a dict representation of the scheduler.""" return { "id": self.scheduler_id, - "enabled": self.enabled, - "priority_queue": { - "id": self.queue.pq_id, - "item_type": self.queue.item_type.type, - "maxsize": self.queue.maxsize, - "qsize": self.queue.qsize(), - "allow_replace": self.queue.allow_replace, - "allow_updates": self.queue.allow_updates, - "allow_priority_updates": self.queue.allow_priority_updates, - }, + "type": self.TYPE.value, + "item_type": self.ITEM_TYPE.__name__, + "qsize": self.queue.qsize(), "last_activity": self.last_activity, } diff --git a/mula/scheduler/schedulers/schedulers/boefje.py b/mula/scheduler/schedulers/schedulers/boefje.py index 260b5cb40db..e562f8bad4c 100644 --- a/mula/scheduler/schedulers/schedulers/boefje.py +++ b/mula/scheduler/schedulers/schedulers/boefje.py @@ -1,84 +1,42 @@ import uuid -from collections.abc import Callable from concurrent import futures from datetime import datetime, timedelta, timezone from types import SimpleNamespace -from typing import Any +from typing import Any, Literal -import structlog from opentelemetry import trace +from pydantic import ValidationError -from scheduler import clients, context, storage, utils +from scheduler import clients, context, models, utils from scheduler.clients.errors import ExternalServiceError -from scheduler.models import ( - OOI, - Boefje, - BoefjeTask, - MutationOperationType, - Organisation, - Plugin, - ScanProfileMutation, - Task, - TaskStatus, -) -from scheduler.schedulers import Scheduler -from scheduler.schedulers.queue import PriorityQueue, QueueFullError -from scheduler.schedulers.rankers import BoefjeRanker +from scheduler.schedulers import Scheduler, rankers +from scheduler.schedulers.errors import exception_handler from scheduler.storage import filters +from scheduler.storage.errors import StorageError tracer = trace.get_tracer(__name__) class BoefjeScheduler(Scheduler): - """A KAT specific implementation of a Boefje scheduler. It extends - the `Scheduler` class by adding an `organisation` attribute. + """Scheduler implementation for the creation of BoefjeTask models. Attributes: - logger: A logger instance. - organisation: The organisation that this scheduler is for. + ranker: The ranker to calculate the priority of a task. """ - ITEM_TYPE: Any = BoefjeTask + ID: Literal["boefje"] = "boefje" + TYPE: models.SchedulerType = models.SchedulerType.BOEFJE + ITEM_TYPE: Any = models.BoefjeTask - def __init__( - self, - ctx: context.AppContext, - scheduler_id: str, - organisation: Organisation, - queue: PriorityQueue | None = None, - callback: Callable[..., None] | None = None, - ): + def __init__(self, ctx: context.AppContext): """Initializes the BoefjeScheduler. Args: - ctx: The application context. - scheduler_id: The id of the scheduler. - organisation: The organisation that this scheduler is for. - queue: The queue to use for this scheduler. - callback: The callback function to call when a task is completed. + ctx (context.AppContext): Application context of shared data (e.g. + configuration, external services connections). 
""" - self.logger: structlog.BoundLogger = structlog.getLogger(__name__) - self.organisation: Organisation = organisation - - self.queue = queue or PriorityQueue( - pq_id=scheduler_id, - maxsize=ctx.config.pq_maxsize, - item_type=self.ITEM_TYPE, - allow_priority_updates=True, - pq_store=ctx.datastores.pq_store, - ) - - super().__init__( - ctx=ctx, - queue=self.queue, - scheduler_id=scheduler_id, - callback=callback, - create_schedule=True, - auto_calculate_deadline=True, - ) - - # Priority ranker - self.priority_ranker = BoefjeRanker(self.ctx) + super().__init__(ctx=ctx, scheduler_id=self.ID, create_schedule=True, auto_calculate_deadline=True) + self.ranker = rankers.BoefjeRanker(self.ctx) def run(self) -> None: """The run method is called when the scheduler is started. It will @@ -96,242 +54,179 @@ def run(self) -> None: - Rescheduling; when a task has passed its deadline, we need to reschedule it. """ - # Scan profile mutations - self.listeners["scan_profile_mutations"] = clients.ScanProfileMutation( + self.listeners["mutations"] = clients.ScanProfileMutation( dsn=str(self.ctx.config.host_raw_data), - queue=f"{self.organisation.id}__scan_profile_mutations", - func=self.push_tasks_for_scan_profile_mutations, + queue="scan_profile_mutations", + func=self.process_mutations, prefetch_count=self.ctx.config.rabbitmq_prefetch_count, ) - self.run_in_thread( - name=f"BoefjeScheduler-{self.scheduler_id}-mutations", - target=self.listeners["scan_profile_mutations"].listen, - loop=False, - ) - - # New Boefjes - self.run_in_thread( - name=f"BoefjeScheduler-{self.scheduler_id}-new_boefjes", - target=self.push_tasks_for_new_boefjes, - interval=60.0, - ) - - # Rescheduling - self.run_in_thread( - name=f"scheduler-{self.scheduler_id}-reschedule", target=self.push_tasks_for_rescheduling, interval=60.0 - ) + self.run_in_thread(name="BoefjeScheduler-mutations", target=self.listeners["mutations"].listen, loop=False) + self.run_in_thread(name="BoefjeScheduler-new_boefjes", target=self.process_new_boefjes, interval=60.0) + self.run_in_thread(name="BoefjeScheduler-rescheduling", target=self.process_rescheduling, interval=60.0) self.logger.info( - "Boefje scheduler started for %s", - self.organisation.id, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - item_type=self.queue.item_type.__name__, + "Boefje scheduler started", scheduler_id=self.scheduler_id, item_type=self.queue.item_type.__name__ ) - @tracer.start_as_current_span("boefje_push_tasks_for_scan_profile_mutations") - def push_tasks_for_scan_profile_mutations(self, body: bytes) -> None: + @tracer.start_as_current_span("process_mutations") + def process_mutations(self, body: bytes) -> None: """Create tasks for oois that have a scan level change. Args: mutation: The mutation that was received. 
""" - # Convert body into a ScanProfileMutation - mutation = ScanProfileMutation.model_validate_json(body) - - self.logger.debug( - "Received scan level mutation %s for: %s", - mutation.operation, - mutation.primary_key, - ooi_primary_key=mutation.primary_key, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - ) - - # There should be an OOI in value - ooi = mutation.value - if ooi is None: + try: + # Convert body into a ScanProfileMutation + self.logger.info(body) + mutation = models.ScanProfileMutation.model_validate_json(body) self.logger.debug( - "Mutation value is None, skipping", organisation_id=self.organisation.id, scheduler_id=self.scheduler_id - ) - return - - if mutation.operation == MutationOperationType.DELETE: - # When there are tasks of the ooi are on the queue, we need to - # remove them from the queue. - items, _ = self.ctx.datastores.pq_store.get_items( + "Received scan level mutation %s for: %s", + mutation.operation, + mutation.primary_key, + ooi_primary_key=mutation.primary_key, scheduler_id=self.scheduler_id, - filters=filters.FilterRequest( - filters=[filters.Filter(column="data", field="input_ooi", operator="eq", value=ooi.primary_key)] - ), ) - # Delete all items for this ooi, update all tasks for this ooi - # to cancelled. - for item in items: - task = self.ctx.datastores.task_store.get_task(item.id) - if task is None: - continue + # There should be an OOI in value + ooi = mutation.value + if ooi is None: + self.logger.debug("Mutation value is None, skipping", scheduler_id=self.scheduler_id) + return - task.status = TaskStatus.CANCELLED - self.ctx.datastores.task_store.update_task(task) + # When the mutation is a delete operation, we need to remove all tasks + if mutation.operation == models.MutationOperationType.DELETE: + items, _ = self.ctx.datastores.pq_store.get_items( + scheduler_id=self.scheduler_id, + filters=filters.FilterRequest( + filters=[filters.Filter(column="data", field="input_ooi", operator="eq", value=ooi.primary_key)] + ), + ) - return + # Delete all items for this ooi, update all tasks for this ooi + # to cancelled. + for item in items: + task = self.ctx.datastores.task_store.get_task(item.id) + if task is None: + continue - # What available boefjes do we have for this ooi? - boefjes = self.get_boefjes_for_ooi(ooi) - if not boefjes: - self.logger.debug( - "No boefjes available for %s", - ooi.primary_key, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - ) + task.status = models.TaskStatus.CANCELLED + self.ctx.datastores.task_store.update_task(task) + + return + + # What available boefjes do we have for this ooi? + boefjes = self.get_boefjes_for_ooi(ooi, mutation.client_id) + if not boefjes: + self.logger.debug("No boefjes available for %s", ooi.primary_key, scheduler_id=self.scheduler_id) + return + except (StorageError, ValidationError): + self.logger.exception("Error occurred while processing mutation", scheduler_id=self.scheduler_id) return - with futures.ThreadPoolExecutor( - thread_name_prefix=f"BoefjeScheduler-TPE-{self.scheduler_id}-mutations" - ) as executor: - for boefje in boefjes: - # Is the boefje allowed to run on the ooi? 
- if not self.has_boefje_permission_to_run(boefje, ooi): - self.logger.debug( - "Boefje not allowed to run on ooi", - boefje_id=boefje.id, - boefje_name=boefje.name, - ooi_primary_key=ooi.primary_key, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - ) - continue + # Create tasks for the boefjes + boefje_tasks = [] + for boefje in boefjes: + if not self.has_boefje_permission_to_run(boefje, ooi): + self.logger.debug( + "Boefje not allowed to run on ooi", + boefje_id=boefje.id, + ooi_primary_key=ooi.primary_key, + scheduler_id=self.scheduler_id, + ) + continue - create_schedule = True - run_task = True - - # What type of run boefje is it? - if boefje.run_on: - create_schedule = False - run_task = False - if mutation.operation == MutationOperationType.CREATE: - run_task = "create" in boefje.run_on - elif mutation.operation == MutationOperationType.UPDATE: - run_task = "update" in boefje.run_on - - if not run_task: - self.logger.debug( - "Based on boefje run on type, skipping", - boefje_id=boefje.id, - ooi_primary_key=ooi.primary_key, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - ) - continue + create_schedule, run_task = True, True + + # What type of run boefje is it? + if boefje.run_on: + create_schedule = False + run_task = False + if mutation.operation == models.MutationOperationType.CREATE: + run_task = "create" in boefje.run_on + elif mutation.operation == models.MutationOperationType.UPDATE: + run_task = "update" in boefje.run_on + + if not run_task: + self.logger.debug( + "Based on boefje run on type, skipping", + boefje_id=boefje.id, + ooi_primary_key=ooi.primary_key, + organisation_id=mutation.client_id, + scheduler_id=self.scheduler_id, + ) + continue - boefje_task = BoefjeTask( - boefje=Boefje.model_validate(boefje.model_dump()), + boefje_tasks.append( + models.BoefjeTask( + boefje=models.Boefje.model_validate(boefje.model_dump()), input_ooi=ooi.primary_key if ooi else None, - organization=self.organisation.id, + organization=mutation.client_id, ) + ) + with futures.ThreadPoolExecutor( + thread_name_prefix=f"BoefjeScheduler-TPE-{self.scheduler_id}-mutations" + ) as executor: + for boefje_task in boefje_tasks: executor.submit( self.push_boefje_task, boefje_task, + mutation.client_id, create_schedule, - self.push_tasks_for_scan_profile_mutations.__name__, + self.process_mutations.__name__, ) - @tracer.start_as_current_span("boefje_push_tasks_for_new_boefjes") - def push_tasks_for_new_boefjes(self) -> None: + @tracer.start_as_current_span("process_new_boefjes") + def process_new_boefjes(self) -> None: """When new boefjes are added or enabled we find the ooi's that boefjes can run on, and create tasks for it.""" - new_boefjes = None + boefje_tasks = [] + + # TODO:: this should be optimized see #3357 try: - new_boefjes = self.ctx.services.katalogus.get_new_boefjes_by_org_id(self.organisation.id) + orgs = self.ctx.services.katalogus.get_organisations() except ExternalServiceError: - self.logger.error( - "Failed to get new boefjes for organisation: %s from katalogus", - self.organisation.name, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - ) + self.logger.exception("Error occurred while processing new boefjes", scheduler_id=self.scheduler_id) return - if new_boefjes is None or not new_boefjes: - self.logger.debug( - "No new boefjes for organisation: %s", - self.organisation.name, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - ) - return + for org in orgs: + try: + # Get new 
boefjes for organisation + new_boefjes = self.ctx.services.katalogus.get_new_boefjes_by_org_id(org.id) + if not new_boefjes: + self.logger.debug("No new boefjes found for organisation", organisation_id=org.id) + continue - self.logger.debug( - "Received new boefjes", - boefjes=[boefje.name for boefje in new_boefjes], - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - ) + # Get all oois for the new boefjes + for boefje in new_boefjes: + oois = self.get_oois_for_boefje(boefje, org.id) + for ooi in oois: + boefje_task = models.BoefjeTask( + boefje=models.Boefje.model_validate(boefje.dict()), + input_ooi=ooi.primary_key, + organization=org.id, + ) - for boefje in new_boefjes: - if not boefje.consumes: - self.logger.debug( - "No consumes found for boefje: %s", - boefje.name, - boefje_id=boefje.id, - organisation_id=self.organisation.id, + boefje_tasks.append((boefje_task, org.id)) + except ExternalServiceError: + self.logger.exception( + "Error occurred while processing new boefjes", + organisation_id=org.id, scheduler_id=self.scheduler_id, ) continue - oois_by_object_type: list[OOI] = [] - try: - oois_by_object_type = self.ctx.services.octopoes.get_objects_by_object_types( - self.organisation.id, boefje.consumes, list(range(boefje.scan_level, 5)) - ) - except ExternalServiceError as exc: - self.logger.error( - "Could not get oois for organisation: %s from octopoes", - self.organisation.name, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - exc_info=exc, + with futures.ThreadPoolExecutor( + thread_name_prefix=f"BoefjeScheduler-TPE-{self.scheduler_id}-new_boefjes" + ) as executor: + for boefje_task, org_id in boefje_tasks: + executor.submit( + self.push_boefje_task, boefje_task, org_id, self.create_schedule, self.process_new_boefjes.__name__ ) - continue - - with futures.ThreadPoolExecutor( - thread_name_prefix=f"BoefjeScheduler-TPE-{self.scheduler_id}-new_boefjes" - ) as executor: - for ooi in oois_by_object_type: - if not self.has_boefje_permission_to_run(boefje, ooi): - self.logger.debug( - "Boefje not allowed to run on ooi", - boefje_id=boefje.id, - ooi_primary_key=ooi.primary_key, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - ) - continue - - boefje_task = BoefjeTask( - boefje=Boefje.model_validate(boefje.dict()), - input_ooi=ooi.primary_key, - organization=self.organisation.id, - ) - - executor.submit(self.push_boefje_task, boefje_task, self.push_tasks_for_new_boefjes.__name__) - - @tracer.start_as_current_span("boefje_push_tasks_for_rescheduling") - def push_tasks_for_rescheduling(self): - if self.queue.full(): - self.logger.warning( - "Boefjes queue is full, not populating with new tasks", - queue_qsize=self.queue.qsize(), - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - ) - return + @tracer.start_as_current_span("process_rescheduling") + def process_rescheduling(self): try: schedules, _ = self.ctx.datastores.schedule_store.get_schedules( filters=filters.FilterRequest( @@ -342,247 +237,170 @@ def push_tasks_for_rescheduling(self): ] ) ) - except storage.errors.StorageError as exc_db: - self.logger.error( - "Could not get schedules for rescheduling %s", - self.scheduler_id, - scheduler_id=self.scheduler_id, - organisation_id=self.organisation.id, - exc_info=exc_db, - ) - raise exc_db - - if not schedules: - self.logger.debug( - "No schedules tasks found for scheduler: %s", - self.scheduler_id, - scheduler_id=self.scheduler_id, - organisation_id=self.organisation.id, - ) + if not 
schedules: + self.logger.debug( + "No schedules tasks found for scheduler: %s", self.scheduler_id, scheduler_id=self.scheduler_id + ) + return + except StorageError: + self.logger.exception("Error occurred while processing rescheduling", scheduler_id=self.scheduler_id) return with futures.ThreadPoolExecutor( thread_name_prefix=f"BoefjeScheduler-TPE-{self.scheduler_id}-rescheduling" ) as executor: for schedule in schedules: - boefje_task = BoefjeTask.model_validate(schedule.data) - - # Plugin still exists? try: + boefje_task = models.BoefjeTask.model_validate(schedule.data) + + # Plugin still exists? plugin = self.ctx.services.katalogus.get_plugin_by_id_and_org_id( - boefje_task.boefje.id, self.organisation.id + boefje_task.boefje.id, schedule.organisation ) if not plugin: self.logger.info( "Boefje does not exist anymore, skipping and disabling schedule", boefje_id=boefje_task.boefje.id, schedule_id=schedule.id, - organisation_id=self.organisation.id, scheduler_id=self.scheduler_id, ) schedule.enabled = False self.ctx.datastores.schedule_store.update_schedule(schedule) continue - except ExternalServiceError as exc_plugin: - self.logger.error( - "Could not get plugin %s from katalogus", - boefje_task.boefje.id, - boefje_id=boefje_task.boefje.id, - schedule_id=schedule.id, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - exc_info=exc_plugin, - ) - continue - - # Plugin still enabled? - if not plugin.enabled: - self.logger.debug( - "Boefje is disabled, skipping", - boefje_id=boefje_task.boefje.id, - schedule_id=schedule.id, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - ) - schedule.enabled = False - self.ctx.datastores.schedule_store.update_schedule(schedule) - continue - - # Plugin a boefje? - if plugin.type != "boefje": - # We don't disable the schedule, since we should've gotten - # schedules for boefjes only. - self.logger.warning( - "Plugin is not a boefje, skipping", - plugin_id=plugin.id, - schedule_id=schedule.id, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - ) - continue - - # When the boefje task has an ooi, we need to do some additional - # checks. - ooi = None - if boefje_task.input_ooi: - # OOI still exists? - try: - ooi = self.ctx.services.octopoes.get_object(boefje_task.organization, boefje_task.input_ooi) - if not ooi: - self.logger.info( - "OOI does not exist anymore, skipping and disabling schedule", - ooi_primary_key=boefje_task.input_ooi, - schedule_id=schedule.id, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - ) - schedule.enabled = False - self.ctx.datastores.schedule_store.update_schedule(schedule) - continue - except ExternalServiceError as exc_ooi: - self.logger.error( - "Could not get ooi %s from octopoes", - boefje_task.input_ooi, - ooi_primary_key=boefje_task.input_ooi, - schedule_id=schedule.id, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - exc_info=exc_ooi, - ) - continue - # Boefje still consuming ooi type? - if ooi.object_type not in plugin.consumes: + # Plugin still enabled? 
+ if not plugin.enabled: self.logger.debug( - "Boefje does not consume ooi anymore, skipping", + "Boefje is disabled, skipping", boefje_id=boefje_task.boefje.id, - ooi_primary_key=ooi.primary_key, - organisation_id=self.organisation.id, + schedule_id=schedule.id, scheduler_id=self.scheduler_id, ) schedule.enabled = False self.ctx.datastores.schedule_store.update_schedule(schedule) continue - # TODO: do we want to disable the schedule when a - # boefje is not allowed to scan an ooi? - - # Boefje allowed to scan ooi? - if not self.has_boefje_permission_to_run(plugin, ooi): - self.logger.info( - "Boefje not allowed to scan ooi, skipping and disabling schedule", - boefje_id=boefje_task.boefje.id, - ooi_primary_key=ooi.primary_key, + # Plugin a boefje? + if plugin.type != "boefje": + # We don't disable the schedule, since we should've gotten + # schedules for boefjes only. + self.logger.warning( + "Plugin is not a boefje, skipping", + plugin_id=plugin.id, schedule_id=schedule.id, - organisation_id=self.organisation.id, + organisation_id=schedule.organisation, scheduler_id=self.scheduler_id, ) - schedule.enabled = False - self.ctx.datastores.schedule_store.update_schedule(schedule) continue - new_boefje_task = BoefjeTask( - boefje=Boefje.model_validate(plugin.dict()), - input_ooi=ooi.primary_key if ooi else None, - organization=self.organisation.id, - ) + # When the boefje task has an ooi, we need to do some additional + # checks. + ooi = None + if boefje_task.input_ooi: + # OOI still exists? + ooi = self.ctx.services.octopoes.get_object(boefje_task.organization, boefje_task.input_ooi) + if not ooi: + self.logger.info( + "OOI does not exist anymore, skipping and disabling schedule", + ooi_primary_key=boefje_task.input_ooi, + schedule_id=schedule.id, + organisation_id=schedule.organisation, + scheduler_id=self.scheduler_id, + ) + schedule.enabled = False + self.ctx.datastores.schedule_store.update_schedule(schedule) + continue - executor.submit(self.push_boefje_task, new_boefje_task, self.push_tasks_for_rescheduling.__name__) + # Boefje still consuming ooi type? + if ooi.object_type not in plugin.consumes: + self.logger.debug( + "Boefje does not consume ooi anymore, skipping", + boefje_id=boefje_task.boefje.id, + ooi_primary_key=ooi.primary_key, + organisation_id=schedule.organisation, + scheduler_id=self.scheduler_id, + ) + schedule.enabled = False + self.ctx.datastores.schedule_store.update_schedule(schedule) + continue - @tracer.start_as_current_span("boefje_push_task") - def push_boefje_task(self, boefje_task: BoefjeTask, create_schedule: bool = True, caller: str = "") -> None: - """Given a Boefje and OOI create a BoefjeTask and push it onto - the queue. + # TODO: do we want to disable the schedule when a + # boefje is not allowed to scan an ooi? - Args: - boefje: Boefje to run. - ooi: OOI to run Boefje on. - caller: The name of the function that called this function, used for logging. + # Boefje allowed to scan ooi? 
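
To make the rescheduling checks in this hunk easier to follow: conditions that mean the schedule can never succeed again (plugin removed, plugin disabled, input OOI deleted, OOI type no longer consumed) disable the schedule, while a merely unexpected condition (a non-boefje plugin) only skips the current run. A minimal sketch of that cascade, with illustrative stand-in types rather than the scheduler's real models:

```python
# Illustrative sketch of the rescheduling validation cascade; Plugin and
# Schedule here are stand-ins, not the scheduler's actual models.
from dataclasses import dataclass


@dataclass
class Plugin:
    id: str
    enabled: bool
    type: str
    consumes: list[str]


@dataclass
class Schedule:
    plugin: Plugin | None
    ooi_type: str | None  # None when the input OOI no longer exists


def validate(schedule: Schedule) -> str:
    """Return "run", "skip", or "disable" for a schedule that came due."""
    plugin = schedule.plugin
    if plugin is None or not plugin.enabled:
        return "disable"  # plugin removed or switched off: never runnable again
    if plugin.type != "boefje":
        return "skip"  # unexpected, but only logged; the schedule stays enabled
    if schedule.ooi_type is None:
        return "disable"  # the input OOI was deleted
    if schedule.ooi_type not in plugin.consumes:
        return "disable"  # the boefje no longer consumes this OOI type
    return "run"
```
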
+ if not self.has_boefje_permission_to_run(plugin, ooi): + self.logger.info( + "Boefje not allowed to scan ooi, skipping and disabling schedule", + boefje_id=boefje_task.boefje.id, + ooi_primary_key=ooi.primary_key, + schedule_id=schedule.id, + organisation_id=schedule.organisation, + scheduler_id=self.scheduler_id, + ) + schedule.enabled = False + self.ctx.datastores.schedule_store.update_schedule(schedule) + continue - """ - self.logger.debug( - "Pushing boefje task", - task_hash=boefje_task.hash, - boefje_id=boefje_task.boefje.id, - ooi_primary_key=boefje_task.input_ooi, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - caller=caller, - ) + new_boefje_task = models.BoefjeTask( + boefje=models.Boefje.model_validate(plugin.dict()), + input_ooi=ooi.primary_key if ooi else None, + organization=schedule.organisation, + ) + except (StorageError, ValidationError, ExternalServiceError): + self.logger.exception( + "Error occurred while processing rescheduling", + schedule_id=schedule.id, + scheduler_id=self.scheduler_id, + ) + continue - try: - grace_period_passed = self.has_boefje_task_grace_period_passed(boefje_task) - if not grace_period_passed: - self.logger.debug( - "Task has not passed grace period: %s", - boefje_task.hash, - task_hash=boefje_task.hash, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - caller=caller, + executor.submit( + self.push_boefje_task, + new_boefje_task, + schedule.organisation, + self.create_schedule, + self.process_rescheduling.__name__, ) - return - except Exception as exc_grace_period: - self.logger.warning( - "Could not check if grace period has passed: %s", + + @exception_handler + @tracer.start_as_current_span("push_boefje_task") + def push_boefje_task( + self, boefje_task: models.BoefjeTask, organisation_id: str, create_schedule: bool = True, caller: str = "" + ) -> None: + grace_period_passed = self.has_boefje_task_grace_period_passed(boefje_task) + if not grace_period_passed: + self.logger.debug( + "Task has not passed grace period: %s", boefje_task.hash, task_hash=boefje_task.hash, - organisation_id=self.organisation.id, scheduler_id=self.scheduler_id, caller=caller, - exc_info=exc_grace_period, ) return - try: - is_stalled = self.has_boefje_task_stalled(boefje_task) - if is_stalled: - self.logger.debug( - "Task is stalled: %s", - boefje_task.hash, - task_hash=boefje_task.hash, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - caller=caller, - ) - - # Update task in datastore to be failed - task_db = self.ctx.datastores.task_store.get_latest_task_by_hash(boefje_task.hash) - task_db.status = TaskStatus.FAILED - self.ctx.datastores.task_store.update_task(task_db) - except Exception as exc_stalled: - self.logger.warning( - "Could not check if task is stalled: %s", + is_stalled = self.has_boefje_task_stalled(boefje_task) + if is_stalled: + self.logger.debug( + "Task is stalled: %s", boefje_task.hash, - boefje_task_hash=boefje_task.hash, - organisation_id=self.organisation.id, + task_hash=boefje_task.hash, scheduler_id=self.scheduler_id, caller=caller, - exc_info=exc_stalled, ) - return - try: - is_running = self.has_boefje_task_started_running(boefje_task) - if is_running: - self.logger.debug( - "Task is still running: %s", - boefje_task.hash, - task_hash=boefje_task.hash, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - caller=caller, - ) - return - except Exception as exc_running: - self.logger.warning( - "Could not check if task is running: %s", + # 
Update task in datastore to be failed + task_db = self.ctx.datastores.task_store.get_latest_task_by_hash(boefje_task.hash) + task_db.status = models.TaskStatus.FAILED + self.ctx.datastores.task_store.update_task(task_db) + + is_running = self.has_boefje_task_started_running(boefje_task) + if is_running: + self.logger.debug( + "Task is still running: %s", boefje_task.hash, task_hash=boefje_task.hash, - organisation_id=self.organisation.id, scheduler_id=self.scheduler_id, caller=caller, - exc_info=exc_running, ) return @@ -591,39 +409,25 @@ def push_boefje_task(self, boefje_task: BoefjeTask, create_schedule: bool = True "Task is already on queue: %s", boefje_task.hash, task_hash=boefje_task.hash, - organisation_id=self.organisation.id, scheduler_id=self.scheduler_id, caller=caller, exc_info=True, ) return - latest_task = self.ctx.datastores.task_store.get_latest_task_by_hash(boefje_task.hash) - score = self.priority_ranker.rank(SimpleNamespace(latest_task=latest_task, task=boefje_task)) - - task = Task( + task = models.Task( id=boefje_task.id, scheduler_id=self.scheduler_id, + organisation=organisation_id, type=self.ITEM_TYPE.type, - priority=score, hash=boefje_task.hash, data=boefje_task.model_dump(), ) - try: - self.push_item_to_queue_with_timeout(item=task, max_tries=self.max_tries, create_schedule=create_schedule) - except QueueFullError: - self.logger.warning( - "Could not add task to queue, queue was full: %s", - boefje_task.hash, - task_hash=boefje_task.hash, - queue_qsize=self.queue.qsize(), - queue_maxsize=self.queue.maxsize, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - caller=caller, - ) - return + latest_task = self.ctx.datastores.task_store.get_latest_task_by_hash(boefje_task.hash) + task.priority = self.ranker.rank(SimpleNamespace(latest_task=latest_task, task=boefje_task)) + + self.push_item_to_queue_with_timeout(item=task, max_tries=self.max_tries, create_schedule=create_schedule) self.logger.info( "Created boefje task", @@ -631,15 +435,14 @@ def push_boefje_task(self, boefje_task: BoefjeTask, create_schedule: bool = True task_hash=task.hash, boefje_id=boefje_task.boefje.id, ooi_primary_key=boefje_task.input_ooi, - organisation_id=self.organisation.id, scheduler_id=self.scheduler_id, caller=caller, ) - def push_item_to_queue(self, item: Task, create_schedule: bool = True) -> Task: + def push_item_to_queue(self, item: models.Task, create_schedule: bool = True) -> models.Task: """Some boefje scheduler specific logic before pushing the item to the queue.""" - boefje_task = BoefjeTask.model_validate(item.data) + boefje_task = models.BoefjeTask.model_validate(item.data) # Check if id's are unique and correctly set. Same id's are necessary # for the task runner. @@ -651,8 +454,7 @@ def push_item_to_queue(self, item: Task, create_schedule: bool = True) -> Task: return super().push_item_to_queue(item=item, create_schedule=create_schedule) - @tracer.start_as_current_span("boefje_has_boefje_permission_to_run") - def has_boefje_permission_to_run(self, boefje: Plugin, ooi: OOI) -> bool: + def has_boefje_permission_to_run(self, boefje: models.Plugin, ooi: models.OOI) -> bool: """Checks whether a boefje is allowed to run on an ooi. 
Args: @@ -664,22 +466,14 @@ def has_boefje_permission_to_run(self, boefje: Plugin, ooi: OOI) -> bool: """ if boefje.enabled is False: self.logger.debug( - "Boefje: %s is disabled", - boefje.name, - boefje_id=boefje.id, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, + "Boefje: %s is disabled", boefje.name, boefje_id=boefje.id, scheduler_id=self.scheduler_id ) return False boefje_scan_level = boefje.scan_level if boefje_scan_level is None: self.logger.warning( - "No scan level found for boefje: %s", - boefje.id, - boefje_id=boefje.id, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, + "No scan level found for boefje: %s", boefje.id, boefje_id=boefje.id, scheduler_id=self.scheduler_id ) return False @@ -692,7 +486,6 @@ def has_boefje_permission_to_run(self, boefje: Plugin, ooi: OOI) -> bool: "No scan_profile found for ooi: %s", ooi.primary_key, ooi_primary_key=ooi.primary_key, - organisation_id=self.organisation.id, scheduler_id=self.scheduler_id, ) return False @@ -703,7 +496,6 @@ def has_boefje_permission_to_run(self, boefje: Plugin, ooi: OOI) -> bool: "No scan level found for ooi: %s", ooi.primary_key, ooi_primary_key=ooi.primary_key, - organisation_id=self.organisation.id, scheduler_id=self.scheduler_id, ) return False @@ -721,15 +513,13 @@ def has_boefje_permission_to_run(self, boefje: Plugin, ooi: OOI) -> bool: ooi_scan_level, boefje_id=boefje.id, ooi_primary_key=ooi.primary_key, - organisation_id=self.organisation.id, scheduler_id=self.scheduler_id, ) return False return True - @tracer.start_as_current_span("boefje_has_boefje_task_started_running") - def has_boefje_task_started_running(self, task: BoefjeTask) -> bool: + def has_boefje_task_started_running(self, task: models.BoefjeTask) -> bool: """Check if the same task is already running. Args: @@ -739,44 +529,17 @@ def has_boefje_task_started_running(self, task: BoefjeTask) -> bool: True if the task is still running, False otherwise. """ # Is task still running according to the datastore? - task_db = None - try: - task_db = self.ctx.datastores.task_store.get_latest_task_by_hash(task.hash) - except Exception as exc_db: - self.logger.error( - "Could not get latest task by hash: %s", - task.hash, - task_id=task.id, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - exc_info=exc_db, - ) - raise exc_db - - if task_db is not None and task_db.status not in [TaskStatus.FAILED, TaskStatus.COMPLETED]: + task_db = self.ctx.datastores.task_store.get_latest_task_by_hash(task.hash) + if task_db is not None and task_db.status not in [models.TaskStatus.FAILED, models.TaskStatus.COMPLETED]: self.logger.debug( - "Task is still running, according to the datastore", - task_id=task_db.id, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, + "Task is still running, according to the datastore", task_id=task_db.id, scheduler_id=self.scheduler_id ) return True # Is task running according to bytes? 
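
The permission check above ultimately reduces to a scan-level comparison. A hedged sketch, assuming the usual OpenKAT convention that an OOI must be cleared at or above the boefje's scan level (the helper and its arguments are illustrative, not the real signature):

```python
# Sketch of the scan-level gate inside has_boefje_permission_to_run; the
# ">=" rule is an assumption based on the surrounding checks and logs.
def may_scan(boefje_enabled: bool, boefje_scan_level: int | None, ooi_scan_level: int | None) -> bool:
    if not boefje_enabled:
        return False  # disabled boefjes never run
    if boefje_scan_level is None or ooi_scan_level is None:
        return False  # a missing level on either side means no permission
    return ooi_scan_level >= boefje_scan_level  # OOI must be cleared for this intensity


assert may_scan(True, 2, 4) is True
assert may_scan(True, 3, 1) is False  # boefje too intense for this OOI
```
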
- try: - task_bytes = self.ctx.services.bytes.get_last_run_boefje( - boefje_id=task.boefje.id, input_ooi=task.input_ooi, organization_id=task.organization - ) - except ExternalServiceError as exc: - self.logger.error( - "Failed to get last run boefje from bytes", - boefje_id=task.boefje.id, - input_ooi_primary_key=task.input_ooi, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - exc_info=exc, - ) - raise exc + task_bytes = self.ctx.services.bytes.get_last_run_boefje( + boefje_id=task.boefje.id, input_ooi=task.input_ooi, organization_id=task.organization + ) # Task has been finished (failed, or succeeded) according to # the datastore, but we have no results of it in bytes, meaning @@ -785,7 +548,7 @@ def has_boefje_task_started_running(self, task: BoefjeTask) -> bool: if ( task_bytes is None and task_db is not None - and task_db.status in [TaskStatus.COMPLETED, TaskStatus.FAILED] + and task_db.status in [models.TaskStatus.COMPLETED, models.TaskStatus.FAILED] and ( task_db.modified_at is not None and task_db.modified_at @@ -797,24 +560,19 @@ def has_boefje_task_started_running(self, task: BoefjeTask) -> bool: "please review the bytes logs for more information regarding " "this error.", task_id=task_db.id, - organisation_id=self.organisation.id, scheduler_id=self.scheduler_id, ) raise RuntimeError("Task has been finished, but no results found in bytes") if task_bytes is not None and task_bytes.ended_at is None and task_bytes.started_at is not None: self.logger.debug( - "Task is still running, according to bytes", - task_id=task_bytes.id, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, + "Task is still running, according to bytes", task_id=task_bytes.id, scheduler_id=self.scheduler_id ) return True return False - @tracer.start_as_current_span("boefje_is_task_stalled") - def has_boefje_task_stalled(self, task: BoefjeTask) -> bool: + def has_boefje_task_stalled(self, task: models.BoefjeTask) -> bool: """Check if the same task is stalled. Args: @@ -823,23 +581,10 @@ def has_boefje_task_stalled(self, task: BoefjeTask) -> bool: Returns: True if the task is stalled, False otherwise. """ - task_db = None - try: - task_db = self.ctx.datastores.task_store.get_latest_task_by_hash(task.hash) - except Exception as exc_db: - self.logger.warning( - "Could not get latest task by hash: %s", - task.hash, - task_hash=task.hash, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - exc_info=exc_db, - ) - raise exc_db - + task_db = self.ctx.datastores.task_store.get_latest_task_by_hash(task.hash) if ( task_db is not None - and task_db.status == TaskStatus.DISPATCHED + and task_db.status == models.TaskStatus.DISPATCHED and ( task_db.modified_at is not None and datetime.now(timezone.utc) @@ -850,8 +595,7 @@ def has_boefje_task_stalled(self, task: BoefjeTask) -> bool: return False - @tracer.start_as_current_span("boefje_has_boefje_task_grace_period_passed") - def has_boefje_task_grace_period_passed(self, task: BoefjeTask) -> bool: + def has_boefje_task_grace_period_passed(self, task: models.BoefjeTask) -> bool: """Check if the grace period has passed for a task in both the datastore and bytes. @@ -865,24 +609,13 @@ def has_boefje_task_grace_period_passed(self, task: BoefjeTask) -> bool: True if the grace period has passed, False otherwise. """ # Does boefje have an interval specified? 
- plugin = self.ctx.services.katalogus.get_plugin_by_id_and_org_id(task.boefje.id, self.organisation.id) + plugin = self.ctx.services.katalogus.get_plugin_by_id_and_org_id(task.boefje.id, task.organization) if plugin is not None and plugin.interval is not None and plugin.interval > 0: timeout = timedelta(minutes=plugin.interval) else: timeout = timedelta(seconds=self.ctx.config.pq_grace_period) - try: - task_db = self.ctx.datastores.task_store.get_latest_task_by_hash(task.hash) - except Exception as exc_db: - self.logger.warning( - "Could not get latest task by hash: %s", - task.hash, - task_hash=task.hash, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - exc_info=exc_db, - ) - raise exc_db + task_db = self.ctx.datastores.task_store.get_latest_task_by_hash(task.hash) # Has grace period passed according to datastore? if task_db is not None and datetime.now(timezone.utc) - task_db.modified_at < timeout: @@ -890,24 +623,13 @@ def has_boefje_task_grace_period_passed(self, task: BoefjeTask) -> bool: "Task has not passed grace period, according to the datastore", task_id=task_db.id, task_hash=task.hash, - organisation_id=self.organisation.id, scheduler_id=self.scheduler_id, ) return False - try: - task_bytes = self.ctx.services.bytes.get_last_run_boefje( - boefje_id=task.boefje.id, input_ooi=task.input_ooi, organization_id=task.organization - ) - except ExternalServiceError as exc_bytes: - self.logger.error( - "Failed to get last run boefje from bytes", - boefje_id=task.boefje.id, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - exc_info=exc_bytes, - ) - raise exc_bytes + task_bytes = self.ctx.services.bytes.get_last_run_boefje( + boefje_id=task.boefje.id, input_ooi=task.input_ooi, organization_id=task.organization + ) # Did the grace period pass, according to bytes? if ( @@ -919,14 +641,13 @@ def has_boefje_task_grace_period_passed(self, task: BoefjeTask) -> bool: "Task has not passed grace period, according to bytes", task_id=task_bytes.id, task_hash=task.hash, - organisation_id=self.organisation.id, scheduler_id=self.scheduler_id, ) return False return True - def get_boefjes_for_ooi(self, ooi: OOI) -> list[Plugin]: + def get_boefjes_for_ooi(self, ooi: models.OOI, organisation: str) -> list[models.Plugin]: """Get available all boefjes (enabled and disabled) for an ooi. Args: @@ -935,24 +656,13 @@ def get_boefjes_for_ooi(self, ooi: OOI) -> list[Plugin]: Returns: A list of Plugin of type Boefje that can be run on the ooi. 
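
The timing rules in the hunks above can be summarised compactly: a plugin with a positive `interval` gets a grace period of that many minutes, otherwise the global `pq_grace_period` (in seconds) applies, and a task stuck in DISPATCHED beyond the grace period is treated as stalled and marked FAILED. A small sketch under those assumptions (the real code additionally consults bytes for the last completed run):

```python
from datetime import datetime, timedelta, timezone

# Hedged sketch of the grace-period and stalled-task rules shown above.
def grace_period(interval_minutes: int | None, pq_grace_period_seconds: int) -> timedelta:
    if interval_minutes and interval_minutes > 0:
        return timedelta(minutes=interval_minutes)  # boefje-specific interval wins
    return timedelta(seconds=pq_grace_period_seconds)  # global default


def grace_period_passed(modified_at: datetime, timeout: timedelta) -> bool:
    return datetime.now(timezone.utc) - modified_at >= timeout


def is_stalled(status: str, modified_at: datetime, timeout: timedelta) -> bool:
    # Stuck in DISPATCHED past the grace period: mark FAILED, then re-push.
    return status == "dispatched" and grace_period_passed(modified_at, timeout)
```
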
""" - try: - boefjes = self.ctx.services.katalogus.get_boefjes_by_type_and_org_id(ooi.object_type, self.organisation.id) - except ExternalServiceError: - self.logger.error( - "Could not get boefjes for object_type: %s", - ooi.object_type, - object_type=ooi.object_type, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - ) - return [] + boefjes = self.ctx.services.katalogus.get_boefjes_by_type_and_org_id(ooi.object_type, organisation) if boefjes is None: self.logger.debug( "No boefjes found for type: %s", ooi.object_type, input_ooi_primary_key=ooi.primary_key, - organisation_id=self.organisation.id, scheduler_id=self.scheduler_id, ) return [] @@ -963,30 +673,52 @@ def get_boefjes_for_ooi(self, ooi: OOI) -> list[Plugin]: ooi, input_ooi_primary_key=ooi.primary_key, boefjes=[boefje.id for boefje in boefjes], - organisation_id=self.organisation.id, scheduler_id=self.scheduler_id, ) return boefjes - def set_cron(self, item: Task) -> str | None: + def get_oois_for_boefje(self, boefje: models.Plugin, organisation: str) -> list[models.OOI]: + oois = [] + + oois_by_object_type = self.ctx.services.octopoes.get_objects_by_object_types( + organisation, + boefje.consumes, + list(range(boefje.scan_level, 5)), # type: ignore + ) + + # Filter OOIs based on permission + for ooi in oois_by_object_type: + if not self.has_boefje_permission_to_run(boefje, ooi): + self.logger.debug( + "Boefje not allowed to run on ooi", + boefje_id=boefje.id, + ooi_primary_key=ooi.primary_key, + scheduler_id=self.scheduler_id, + ) + continue + oois.append(ooi) + + return oois + + def set_cron(self, item: models.Task) -> str | None: """Override Schedule.set_cron() when a boefje specifies a schedule for execution (cron expression) we schedule for its execution""" # Does a boefje have a schedule defined? plugin = self.ctx.services.katalogus.get_plugin_by_id_and_org_id( - utils.deep_get(item.data, ["boefje", "id"]), self.organisation.id + utils.deep_get(item.data, ["boefje", "id"]), item.organisation ) if plugin is None or plugin.cron is None: return super().set_cron(item) return plugin.cron - def calculate_deadline(self, task: Task) -> datetime: + def calculate_deadline(self, task: models.Task) -> datetime: """Override Scheduler.calculate_deadline() to calculate the deadline for a task and based on the boefje interval.""" # Does the boefje have an interval defined? 
plugin = self.ctx.services.katalogus.get_plugin_by_id_and_org_id( - utils.deep_get(task.data, ["boefje", "id"]), self.organisation.id + utils.deep_get(task.data, ["boefje", "id"]), task.organisation ) if plugin is not None and plugin.interval is not None and plugin.interval > 0: return datetime.now(timezone.utc) + timedelta(minutes=plugin.interval) diff --git a/mula/scheduler/schedulers/schedulers/normalizer.py b/mula/scheduler/schedulers/schedulers/normalizer.py index d1dff6e7c56..ff53d58bd75 100644 --- a/mula/scheduler/schedulers/schedulers/normalizer.py +++ b/mula/scheduler/schedulers/schedulers/normalizer.py @@ -1,62 +1,38 @@ import uuid -from collections.abc import Callable from concurrent import futures from types import SimpleNamespace -from typing import Any +from typing import Any, Literal -import structlog from opentelemetry import trace from scheduler import clients, context, models from scheduler.clients.errors import ExternalServiceError -from scheduler.models import Normalizer, NormalizerTask, Organisation, Plugin, RawDataReceivedEvent, Task, TaskStatus -from scheduler.schedulers import Scheduler -from scheduler.schedulers.queue import PriorityQueue, QueueFullError -from scheduler.schedulers.rankers import NormalizerRanker +from scheduler.schedulers import Scheduler, rankers +from scheduler.schedulers.errors import exception_handler tracer = trace.get_tracer(__name__) class NormalizerScheduler(Scheduler): - """A KAT specific implementation of a Normalizer scheduler. It extends - the `Scheduler` class by adding a `organisation` attribute. + """Scheduler implementation for the creation of NormalizerTask models. Attributes: - logger: A logger instance. - organisation: The organisation that this scheduler is for. + ranker: The ranker to calculate the priority of a task. """ - ITEM_TYPE: Any = NormalizerTask - - def __init__( - self, - ctx: context.AppContext, - scheduler_id: str, - organisation: Organisation, - queue: PriorityQueue | None = None, - callback: Callable[..., None] | None = None, - ): - self.logger: structlog.BoundLogger = structlog.getLogger(__name__) - self.organisation: Organisation = organisation - - self.queue = queue or PriorityQueue( - pq_id=scheduler_id, - maxsize=ctx.config.pq_maxsize, - item_type=self.ITEM_TYPE, - allow_priority_updates=True, - pq_store=ctx.datastores.pq_store, - ) + ID: Literal["normalizer"] = "normalizer" + TYPE: models.SchedulerType = models.SchedulerType.NORMALIZER + ITEM_TYPE: Any = models.NormalizerTask - super().__init__( - ctx=ctx, - queue=self.queue, - scheduler_id=scheduler_id, - callback=callback, - create_schedule=False, - auto_calculate_deadline=False, - ) + def __init__(self, ctx: context.AppContext): + """Initializes the NormalizerScheduler. - self.ranker = NormalizerRanker(ctx=self.ctx) + Args: + ctx (context.AppContext): Application context of shared data (e.g. + configuration, external services connections). + """ + super().__init__(ctx=ctx, scheduler_id=self.ID, create_schedule=False, auto_calculate_deadline=False) + self.ranker = rankers.NormalizerRanker(ctx=self.ctx) def run(self) -> None: """The run method is called when the scheduler is started. It will @@ -68,31 +44,22 @@ def run(self) -> None: for each normalizer that is registered for the mime type of the raw file. 
""" - listener = clients.RawData( + self.listeners["raw_data"] = clients.RawData( dsn=str(self.ctx.config.host_raw_data), - queue=f"{self.organisation.id}__raw_file_received", - func=self.push_tasks_for_received_raw_data, + queue="raw_file_received", + func=self.process_raw_data, prefetch_count=self.ctx.config.rabbitmq_prefetch_count, ) - self.listeners["raw_data"] = listener - - self.run_in_thread( - name=f"NormalizerScheduler-{self.scheduler_id}-raw_file", - target=self.listeners["raw_data"].listen, - loop=False, - ) + self.run_in_thread(name="NormalizerScheduler-raw_file", target=self.listeners["raw_data"].listen, loop=False) self.logger.info( - "Normalizer scheduler started for %s", - self.organisation.id, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - item_type=self.queue.item_type.__name__, + "Normalizer scheduler started", scheduler_id=self.scheduler_id, item_type=self.queue.item_type.__name__ ) - @tracer.start_as_current_span("normalizer_push_task_for_received_raw_data") - def push_tasks_for_received_raw_data(self, body: bytes) -> None: + # TODO: exception handling + @tracer.start_as_current_span("process_raw_data") + def process_raw_data(self, body: bytes) -> None: """Create tasks for the received raw data. Args: @@ -100,134 +67,79 @@ def push_tasks_for_received_raw_data(self, body: bytes) -> None: message queue. """ # Convert body into a RawDataReceivedEvent - latest_raw_data = RawDataReceivedEvent.model_validate_json(body) - + latest_raw_data = models.RawDataReceivedEvent.model_validate_json(body) self.logger.debug( "Received raw data %s", latest_raw_data.raw_data.id, raw_data_id=latest_raw_data.raw_data.id, - organisation_id=self.organisation.id, scheduler_id=self.scheduler_id, ) # Check if the raw data doesn't contain an error mime-type, # we don't need to create normalizers when the raw data returned # an error. 
- for mime_type in latest_raw_data.raw_data.mime_types: - if mime_type.get("value", "").startswith("error/"): - self.logger.debug( - "Skipping raw data %s with error mime type", - latest_raw_data.raw_data.id, - mime_type=mime_type.get("value"), - raw_data_id=latest_raw_data.raw_data.id, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - ) - return - - # Get all normalizers for the mime types of the raw data - normalizers: dict[str, Plugin] = {} - for mime_type in latest_raw_data.raw_data.mime_types: - normalizers_by_mime_type: list[Plugin] = self.get_normalizers_for_mime_type(mime_type.get("value")) - - for normalizer in normalizers_by_mime_type: - normalizers[normalizer.id] = normalizer - - if not normalizers: + if self.has_raw_data_errors(latest_raw_data.raw_data): self.logger.debug( - "No normalizers found for raw data %s", + "Skipping raw data %s with error mime type", latest_raw_data.raw_data.id, raw_data_id=latest_raw_data.raw_data.id, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, ) + return - with futures.ThreadPoolExecutor( - thread_name_prefix=f"NormalizerScheduler-TPE-{self.scheduler_id}-raw_data" - ) as executor: - for normalizer in normalizers.values(): - if not self.has_normalizer_permission_to_run(normalizer): - self.logger.debug( - "Normalizer is not allowed to run: %s", - normalizer.id, - normalizer_id=normalizer.id, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - ) - continue - - normalizer_task = NormalizerTask( - normalizer=Normalizer.model_validate(normalizer.model_dump()), raw_data=latest_raw_data.raw_data - ) - - executor.submit( - self.push_normalizer_task, normalizer_task, self.push_tasks_for_received_raw_data.__name__ - ) - - @tracer.start_as_current_span("normalizer_push_task") - def push_normalizer_task(self, normalizer_task: models.NormalizerTask, caller: str = "") -> None: - """Given a normalizer and raw data, create a task and push it to the - queue. + # TODO: deduplication + # Get all normalizers for the mime types of the raw data + normalizers = [] + for mime_type in latest_raw_data.raw_data.mime_types: + normalizers_by_mime_type = self.get_normalizers_for_mime_type( + mime_type.get("value"), latest_raw_data.organization + ) + normalizers.extend(normalizers_by_mime_type) - Args: - normalizer: The normalizer to create a task for. - raw_data: The raw data to create a task for. - caller: The name of the function that called this function, used for logging. 
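
The `TODO: deduplication` above flags that `normalizers.extend(...)` can collect the same plugin once per matching mime type. The removed implementation avoided this by keying a dict on plugin id; a sketch of restoring that behaviour on top of the new list-based collection:

```python
# Sketch: collapse duplicates gathered across mime types by plugin id,
# as the pre-refactor dict-based implementation effectively did.
def deduplicate(normalizers: list) -> list:
    return list({normalizer.id: normalizer for normalizer in normalizers}.values())
```
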
- """ self.logger.debug( - "Pushing normalizer task", - task_id=normalizer_task.id, - normalizer_id=normalizer_task.normalizer.id, - organisation_id=self.organisation.id, + "Found normalizers for raw data", + raw_data_id=latest_raw_data.raw_data.id, + mime_types=[mime_type.get("value") for mime_type in latest_raw_data.raw_data.mime_types], + normalizers=[normalizer.id for normalizer in normalizers], scheduler_id=self.scheduler_id, - caller=caller, ) - try: - plugin = self.ctx.services.katalogus.get_plugin_by_id_and_org_id( - normalizer_task.normalizer.id, self.organisation.id - ) - if not self.has_normalizer_permission_to_run(plugin): + # Create tasks for the normalizers + normalizer_tasks = [] + for normalizer in normalizers: + if not self.has_normalizer_permission_to_run(normalizer): self.logger.debug( - "Task is not allowed to run: %s", - normalizer_task.id, - task_id=normalizer_task.id, - organisation_id=self.organisation.id, + "Normalizer is not allowed to run: %s", + normalizer.id, + normalizer_id=normalizer.id, scheduler_id=self.scheduler_id, - caller=caller, ) - return - except ExternalServiceError: - self.logger.warning( - "Could not get plugin by id: %s", - normalizer_task.normalizer.id, - task_id=normalizer_task.id, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - caller=caller, + continue + + normalizer_task = models.NormalizerTask( + normalizer=models.Normalizer.model_validate(normalizer.model_dump()), raw_data=latest_raw_data.raw_data ) - return + normalizer_tasks.append(normalizer_task) - try: - if self.has_normalizer_task_started_running(normalizer_task): - self.logger.debug( - "Task is still running: %s", - normalizer_task.id, - task_id=normalizer_task.id, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - caller=caller, + with futures.ThreadPoolExecutor( + thread_name_prefix=f"NormalizerScheduler-TPE-{self.scheduler_id}-raw_data" + ) as executor: + for normalizer_task in normalizer_tasks: + executor.submit( + self.push_normalizer_task, normalizer_task, latest_raw_data.organization, self.create_schedule ) - return - except Exception: - self.logger.warning( - "Could not check if task is running: %s", + + @exception_handler + @tracer.start_as_current_span("push_normalizer_task") + def push_normalizer_task( + self, normalizer_task: models.NormalizerTask, organisation_id: str, create_schedule: bool, caller: str = "" + ) -> None: + if self.has_normalizer_task_started_running(normalizer_task): + self.logger.debug( + "Task is still running: %s", normalizer_task.id, task_id=normalizer_task.id, - organisation_id=self.organisation.id, scheduler_id=self.scheduler_id, caller=caller, - exc_info=True, ) return @@ -236,37 +148,23 @@ def push_normalizer_task(self, normalizer_task: models.NormalizerTask, caller: s "Task is already on queue: %s", normalizer_task.id, task_id=normalizer_task.id, - organisation_id=self.organisation.id, scheduler_id=self.scheduler_id, caller=caller, ) return - score = self.ranker.rank(SimpleNamespace(raw_data=normalizer_task.raw_data, task=normalizer_task)) - - task = Task( + task = models.Task( id=normalizer_task.id, scheduler_id=self.scheduler_id, - type=self.ITEM_TYPE.type, - priority=score, + organisation=organisation_id, + type=normalizer_task.type, hash=normalizer_task.hash, data=normalizer_task.model_dump(), ) - try: - self.push_item_to_queue_with_timeout(item=task, max_tries=self.max_tries) - except QueueFullError: - self.logger.warning( - "Could not add task to queue, queue was full: %s", - task.id, - 
task_id=task.id, - queue_qsize=self.queue.qsize(), - queue_maxsize=self.queue.maxsize, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - caller=caller, - ) - return + task.priority = self.ranker.rank(SimpleNamespace(raw_data=normalizer_task.raw_data, task=normalizer_task)) + + self.push_item_to_queue_with_timeout(task, self.max_tries, create_schedule=create_schedule) self.logger.info( "Created normalizer task", @@ -274,15 +172,14 @@ def push_normalizer_task(self, normalizer_task: models.NormalizerTask, caller: s task_hash=task.hash, normalizer_id=normalizer_task.normalizer.id, raw_data_id=normalizer_task.raw_data.id, - organisation_id=self.organisation.id, scheduler_id=self.scheduler_id, caller=caller, ) - def push_item_to_queue(self, item: Task, create_schedule: bool = True) -> Task: + def push_item_to_queue(self, item: models.Task, create_schedule: bool = True) -> models.Task: """Some normalizer scheduler specific logic before pushing the item to the queue.""" - normalizer_task = NormalizerTask.model_validate(item.data) + normalizer_task = models.NormalizerTask.model_validate(item.data) # Check if id's are unique and correctly set. Same id's are necessary # for the task runner. @@ -294,8 +191,7 @@ def push_item_to_queue(self, item: Task, create_schedule: bool = True) -> Task: return super().push_item_to_queue(item=item, create_schedule=create_schedule) - @tracer.start_as_current_span("normalizer_has_normalizer_permission_to_run") - def has_normalizer_permission_to_run(self, normalizer: Plugin) -> bool: + def has_normalizer_permission_to_run(self, normalizer: models.Plugin) -> bool: """Check if the task is allowed to run. Args: @@ -306,18 +202,13 @@ def has_normalizer_permission_to_run(self, normalizer: Plugin) -> bool: """ if not normalizer.enabled: self.logger.debug( - "Normalizer: %s is disabled", - normalizer.id, - normalizer_id=normalizer.id, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, + "Normalizer: %s is disabled", normalizer.id, normalizer_id=normalizer.id, scheduler_id=self.scheduler_id ) return False return True - @tracer.start_as_current_span("normalizer_has_normalizer_task_started_running") - def has_normalizer_task_started_running(self, task: NormalizerTask) -> bool: + def has_normalizer_task_started_running(self, task: models.NormalizerTask) -> bool: """Check if the same task is already running. Args: @@ -328,33 +219,32 @@ def has_normalizer_task_started_running(self, task: NormalizerTask) -> bool: """ # Get the last tasks that have run or are running for the hash # of this particular NormalizerTask. - try: - task_db = self.ctx.datastores.task_store.get_latest_task_by_hash(task.hash) - except Exception as exc_db: - self.logger.error( - "Could not get latest task by hash: %s", - task.hash, - task_id=task.id, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - exc_info=exc_db, - ) - raise exc_db + task_db = self.ctx.datastores.task_store.get_latest_task_by_hash(task.hash) # Is task still running according to the datastore? 
- if task_db is not None and task_db.status not in [TaskStatus.COMPLETED, TaskStatus.FAILED]: + if task_db is not None and task_db.status not in [models.TaskStatus.COMPLETED, models.TaskStatus.FAILED]: self.logger.debug( "Task is still running, according to the datastore", task_id=task_db.id, task_hash=task.hash, - organisation_id=self.organisation.id, scheduler_id=self.scheduler_id, ) return True return False - def get_normalizers_for_mime_type(self, mime_type: str) -> list[Plugin]: + def has_raw_data_errors(self, raw_data: models.RawData) -> bool: + """Check if the raw data contains errors. + + Args: + raw_data: The raw data to check. + + Returns: + True if the raw data contains errors, False otherwise. + """ + return any(mime_type.get("value", "").startswith("error/") for mime_type in raw_data.mime_types) + + def get_normalizers_for_mime_type(self, mime_type: str, organisation: str) -> list[models.Plugin]: """Get available normalizers for a given mime type. Args: @@ -364,37 +254,17 @@ def get_normalizers_for_mime_type(self, mime_type: str) -> list[Plugin]: A list of Plugins of type normalizer for the given mime type. """ try: - normalizers = self.ctx.services.katalogus.get_normalizers_by_org_id_and_type( - self.organisation.id, mime_type - ) + normalizers = self.ctx.services.katalogus.get_normalizers_by_org_id_and_type(organisation, mime_type) except ExternalServiceError: - self.logger.warning( - "Could not get normalizers for mime_type: %s [mime_type=%s, organisation_id=%s, scheduler_id=%s]", - mime_type, - mime_type, - self.organisation.id, - self.scheduler_id, - ) - return [] - - if normalizers is None: - self.logger.debug( - "No normalizer found for mime_type: %s", + self.logger.error( + "Failed to get normalizers for mime type %s", mime_type, mime_type=mime_type, - organisation_id=self.organisation.id, scheduler_id=self.scheduler_id, ) return [] - self.logger.debug( - "Found %d normalizers for mime_type: %s", - len(normalizers), - mime_type, - mime_type=mime_type, - normalizers=[normalizer.id for normalizer in normalizers], - organisation_=self.organisation.id, - scheduler_id=self.scheduler_id, - ) + if normalizers is None: + return [] return normalizers diff --git a/mula/scheduler/schedulers/schedulers/report.py b/mula/scheduler/schedulers/schedulers/report.py index 4fa4189c58c..14abaf6c563 100644 --- a/mula/scheduler/schedulers/schedulers/report.py +++ b/mula/scheduler/schedulers/schedulers/report.py @@ -1,210 +1,125 @@ -from collections.abc import Callable from concurrent import futures from datetime import datetime, timezone -from typing import Any +from typing import Any, Literal -import structlog from opentelemetry import trace -from scheduler import context, storage -from scheduler.models import Organisation, ReportTask, Task, TaskStatus +from scheduler import context, models from scheduler.schedulers import Scheduler -from scheduler.schedulers.queue import PriorityQueue, QueueFullError +from scheduler.schedulers.errors import exception_handler from scheduler.storage import filters tracer = trace.get_tracer(__name__) class ReportScheduler(Scheduler): - ITEM_TYPE: Any = ReportTask - - def __init__( - self, - ctx: context.AppContext, - scheduler_id: str, - organisation: Organisation, - queue: PriorityQueue | None = None, - callback: Callable[..., None] | None = None, - ): - self.logger: structlog.BoundLogger = structlog.get_logger(__name__) - self.organisation = organisation - self.queue = queue or PriorityQueue( - pq_id=scheduler_id, - maxsize=ctx.config.pq_maxsize, - 
item_type=self.ITEM_TYPE,
-            allow_priority_updates=True,
-            pq_store=ctx.datastores.pq_store,
-        )
+    """Scheduler implementation for the creation of ReportTask models."""

-        super().__init__(
-            ctx=ctx,
-            queue=self.queue,
-            scheduler_id=scheduler_id,
-            callback=callback,
-            create_schedule=True,
-            auto_calculate_deadline=False,
-        )
+    ID: Literal["report"] = "report"
+    TYPE: models.SchedulerType = models.SchedulerType.REPORT
+    ITEM_TYPE: Any = models.ReportTask
+
+    def __init__(self, ctx: context.AppContext):
+        """Initializes the ReportScheduler.
+
+        Args:
+            ctx (context.AppContext): Application context of shared data (e.g.
+                configuration, external services connections).
+        """
+        super().__init__(ctx=ctx, scheduler_id=self.ID, create_schedule=True, auto_calculate_deadline=False)

     def run(self) -> None:
+        """The run method is called when the scheduler is started. It will
+        start the rescheduling process for the ReportTask models that are
+        scheduled.
+        """
         # Rescheduling
         self.run_in_thread(
-            name=f"scheduler-{self.scheduler_id}-reschedule", target=self.push_tasks_for_rescheduling, interval=60.0
+            name=f"scheduler-{self.scheduler_id}-reschedule", target=self.process_rescheduling, interval=60.0
+        )
+        self.logger.info(
+            "Report scheduler started", scheduler_id=self.scheduler_id, item_type=self.queue.item_type.__name__
         )

-    @tracer.start_as_current_span(name="report_push_tasks_for_rescheduling")
-    def push_tasks_for_rescheduling(self):
-        if self.queue.full():
-            self.logger.warning(
-                "Report queue is full, not populating with new tasks",
-                queue_qsize=self.queue.qsize(),
-                organisation_id=self.organisation.id,
-                scheduler_id=self.scheduler_id,
+    @tracer.start_as_current_span(name="process_rescheduling")
+    def process_rescheduling(self):
+        schedules, _ = self.ctx.datastores.schedule_store.get_schedules(
+            filters=filters.FilterRequest(
+                filters=[
+                    filters.Filter(column="scheduler_id", operator="eq", value=self.scheduler_id),
+                    filters.Filter(column="deadline_at", operator="lt", value=datetime.now(timezone.utc)),
+                    filters.Filter(column="enabled", operator="eq", value=True),
+                ]
             )
-            return
+        )

-        try:
-            schedules, _ = self.ctx.datastores.schedule_store.get_schedules(
-                filters=filters.FilterRequest(
-                    filters=[
-                        filters.Filter(column="scheduler_id", operator="eq", value=self.scheduler_id),
-                        filters.Filter(column="deadline_at", operator="lt", value=datetime.now(timezone.utc)),
-                        filters.Filter(column="enabled", operator="eq", value=True),
-                    ]
-                )
-            )
-        except storage.errors.StorageError as exc_db:
-            self.logger.error(
-                "Could not get schedules for rescheduling %s",
-                self.scheduler_id,
-                scheduler_id=self.scheduler_id,
-                organisation_id=self.organisation.id,
-                exc_info=exc_db,
-            )
-            raise exc_db
+        # Create report tasks for the schedules
+        report_tasks = []
+        for schedule in schedules:
+            if not self.has_schedule_permission_to_run(schedule):
+                continue
+
+            report_task = models.ReportTask.model_validate(schedule.data)
+            report_tasks.append(report_task)

         with futures.ThreadPoolExecutor(
             thread_name_prefix=f"ReportScheduler-TPE-{self.scheduler_id}-rescheduling"
         ) as executor:
-            for schedule in schedules:
-                report_task = ReportTask.model_validate(schedule.data)
-
-                # When the schedule has no schedule (cron expression), but a
-                # task is already executed for this schedule we should not run
-                # the task again
-                if schedule.schedule is None:
-                    try:
-                        _, count = self.ctx.datastores.task_store.get_tasks(
-                            scheduler_id=self.scheduler_id,
-                            task_type=report_task.type,
-                            filters=filters.FilterRequest(
filters=[ - filters.Filter(column="hash", operator="eq", value=report_task.hash), - filters.Filter(column="schedule_id", operator="eq", value=str(schedule.id)), - ] - ), - ) - if count > 0: - self.logger.debug( - "Schedule has no schedule, but task already executed", - schedule_id=schedule.id, - scheduler_id=self.scheduler_id, - organisation_id=self.organisation.id, - ) - continue - except storage.errors.StorageError as exc_db: - self.logger.error( - "Could not get latest task by hash %s", - report_task.hash, - scheduler_id=self.scheduler_id, - organisation_id=self.organisation.id, - exc_info=exc_db, - ) - continue - - executor.submit(self.push_report_task, report_task, self.push_tasks_for_rescheduling.__name__) - - def push_report_task(self, report_task: ReportTask, caller: str = "") -> None: - self.logger.debug( - "Pushing report task", - task_hash=report_task.hash, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - caller=caller, - ) - - if self.has_report_task_started_running(report_task): - self.logger.debug( - "Report task already running", - task_hash=report_task.hash, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - caller=caller, - ) - return + for report_task in report_tasks: + executor.submit( + self.push_report_task, + report_task, + report_task.organisation_id, + self.create_schedule, + self.process_rescheduling.__name__, + ) + @exception_handler + @tracer.start_as_current_span("push_report_task") + def push_report_task( + self, report_task: models.ReportTask, organisation_id: str, create_schedule: bool, caller: str = "" + ) -> None: if self.is_item_on_queue_by_hash(report_task.hash): - self.logger.debug( - "Report task already on queue", - task_hash=report_task.hash, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - caller=caller, - ) + self.logger.debug("Report task already on queue", scheduler_id=self.scheduler_id, caller=caller) return - task = Task( + task = models.Task( scheduler_id=self.scheduler_id, + organisation=organisation_id, priority=int(datetime.now().timestamp()), type=self.ITEM_TYPE.type, hash=report_task.hash, data=report_task.model_dump(), ) - try: - self.push_item_to_queue_with_timeout(task, self.max_tries) - except QueueFullError: - self.logger.warning( - "Could not add task %s to queue, queue was full", - report_task.hash, - task_hash=report_task.hash, - queue_qsize=self.queue.qsize(), - queue_maxsize=self.queue.maxsize, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - caller=caller, - ) - return + self.push_item_to_queue_with_timeout(task, self.max_tries) self.logger.info( "Report task pushed to queue", task_id=task.id, - task_hash=report_task.hash, - organisation_id=self.organisation.id, + task_hash=task.hash, scheduler_id=self.scheduler_id, caller=caller, ) - def has_report_task_started_running(self, task: ReportTask) -> bool: - task_db = None - try: - task_db = self.ctx.datastores.task_store.get_latest_task_by_hash(task.hash) - except storage.errors.StorageError as exc_db: - self.logger.error( - "Could not get latest task by hash %s", - task.hash, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - exc_info=exc_db, - ) - raise exc_db - - if task_db is not None and task_db.status not in [TaskStatus.FAILED, TaskStatus.COMPLETED]: - self.logger.debug( - "Task is still running, according to the datastore", - task_id=task_db.id, - organisation_id=self.organisation.id, - scheduler_id=self.scheduler_id, - ) - return True + def 
has_schedule_permission_to_run(self, schedule: models.Schedule) -> bool: + """Check if the schedule has permission to run. - return False + Args: + schedule (models.Schedule): Schedule to check. + + Returns: + bool: True if the schedule has permission to run, False otherwise. + """ + report_task = models.ReportTask.model_validate(schedule.data) + _, count = self.ctx.datastores.task_store.get_tasks( + scheduler_id=self.scheduler_id, + task_type=report_task.type, + filters=filters.FilterRequest( + filters=[ + filters.Filter(column="hash", operator="eq", value=report_task.hash), + filters.Filter(column="schedule_id", operator="eq", value=str(schedule.id)), + ] + ), + ) + return not count diff --git a/mula/scheduler/server/handlers/__init__.py b/mula/scheduler/server/handlers/__init__.py index 302806efaa3..2aea97fa01f 100644 --- a/mula/scheduler/server/handlers/__init__.py +++ b/mula/scheduler/server/handlers/__init__.py @@ -1,6 +1,5 @@ from .health import HealthAPI from .metrics import MetricsAPI -from .queues import QueueAPI from .root import RootAPI from .schedulers import SchedulerAPI from .schedules import ScheduleAPI diff --git a/mula/scheduler/server/handlers/queues.py b/mula/scheduler/server/handlers/queues.py deleted file mode 100644 index 461c897c5e9..00000000000 --- a/mula/scheduler/server/handlers/queues.py +++ /dev/null @@ -1,103 +0,0 @@ -from typing import Any - -import fastapi -import structlog -from fastapi import status - -from scheduler import context, models, schedulers, storage -from scheduler.schedulers.queue import NotAllowedError, QueueEmptyError, QueueFullError -from scheduler.server import serializers -from scheduler.server.errors import BadRequestError, ConflictError, NotFoundError, TooManyRequestsError - - -class QueueAPI: - def __init__(self, api: fastapi.FastAPI, ctx: context.AppContext, s: dict[str, schedulers.Scheduler]) -> None: - self.logger: structlog.BoundLogger = structlog.getLogger(__name__) - self.api: fastapi.FastAPI = api - self.ctx: context.AppContext = ctx - self.schedulers: dict[str, schedulers.Scheduler] = s - - self.api.add_api_route( - path="/queues", - endpoint=self.list, - methods=["GET"], - response_model=list[models.Queue], - response_model_exclude_unset=True, - status_code=status.HTTP_200_OK, - description="List all queues", - ) - - self.api.add_api_route( - path="/queues/{queue_id}", - endpoint=self.get, - methods=["GET"], - response_model=models.Queue, - status_code=status.HTTP_200_OK, - description="Get a queue", - ) - - self.api.add_api_route( - path="/queues/{queue_id}/pop", - endpoint=self.pop, - methods=["POST"], - response_model=models.Task | None, - status_code=status.HTTP_200_OK, - description="Pop an item from a queue", - ) - - self.api.add_api_route( - path="/queues/{queue_id}/push", - endpoint=self.push, - methods=["POST"], - response_model=models.Task | None, - status_code=status.HTTP_201_CREATED, - description="Push an item to a queue", - ) - - def list(self) -> Any: - return [models.Queue(**s.queue.dict(include_pq=False)) for s in self.schedulers.copy().values()] - - def get(self, queue_id: str) -> Any: - s = self.schedulers.get(queue_id) - if s is None: - raise NotFoundError(f"queue not found, by queue_id: {queue_id}") - - return models.Queue(**s.queue.dict()) - - def pop(self, queue_id: str, filters: storage.filters.FilterRequest | None = None) -> Any: - s = self.schedulers.get(queue_id) - if s is None: - raise NotFoundError(f"queue not found, by queue_id: {queue_id}") - - try: - item = s.pop_item_from_queue(filters) - except 
QueueEmptyError: - return None - - if item is None: - raise NotFoundError("could not pop item from queue, check your filters") - - return models.Task(**item.model_dump()) - - def push(self, queue_id: str, item_in: serializers.Task) -> Any: - s = self.schedulers.get(queue_id) - if s is None: - raise NotFoundError(f"queue not found, by queue_id: {queue_id}") - - # Load default values - new_item = models.Task(**item_in.model_dump(exclude_unset=True)) - - # Set values - if new_item.scheduler_id is None: - new_item.scheduler_id = s.scheduler_id - - try: - pushed_item = s.push_item_to_queue(new_item) - except ValueError: - raise BadRequestError("malformed item") - except QueueFullError: - raise TooManyRequestsError("queue is full") - except NotAllowedError: - raise ConflictError("queue is not allowed to push items") - - return pushed_item diff --git a/mula/scheduler/server/handlers/schedulers.py b/mula/scheduler/server/handlers/schedulers.py index 9358dcec45a..65ca2ac1c9a 100644 --- a/mula/scheduler/server/handlers/schedulers.py +++ b/mula/scheduler/server/handlers/schedulers.py @@ -4,12 +4,14 @@ import structlog from fastapi import status -from scheduler import context, models, schedulers -from scheduler.server.errors import BadRequestError, NotFoundError +from scheduler import context, models, schedulers, storage +from scheduler.schedulers.queue import NotAllowedError, QueueFullError +from scheduler.server import serializers, utils +from scheduler.server.errors import BadRequestError, ConflictError, NotFoundError, TooManyRequestsError class SchedulerAPI: - def __init__(self, api: fastapi.FastAPI, ctx: context.AppContext, s: dict[str, schedulers.Scheduler]) -> None: + def __init__(self, api: fastapi.FastAPI, ctx: context.AppContext, s: dict[str, schedulers.Scheduler]): self.logger: structlog.BoundLogger = structlog.getLogger(__name__) self.api: fastapi.FastAPI = api self.ctx: context.AppContext = ctx @@ -19,7 +21,7 @@ def __init__(self, api: fastapi.FastAPI, ctx: context.AppContext, s: dict[str, s path="/schedulers", endpoint=self.list, methods=["GET"], - response_model=list[models.Scheduler], + response_model=list[serializers.Scheduler], status_code=status.HTTP_200_OK, description="List all schedulers", ) @@ -28,51 +30,80 @@ def __init__(self, api: fastapi.FastAPI, ctx: context.AppContext, s: dict[str, s path="/schedulers/{scheduler_id}", endpoint=self.get, methods=["GET"], - response_model=models.Scheduler, + response_model=serializers.Scheduler, status_code=status.HTTP_200_OK, description="Get a scheduler", ) self.api.add_api_route( - path="/schedulers/{scheduler_id}", - endpoint=self.patch, - methods=["PATCH"], - response_model=models.Scheduler, + path="/schedulers/{scheduler_id}/push", + endpoint=self.push, + methods=["POST"], + response_model=models.Task, + status_code=status.HTTP_201_CREATED, + description="Push a task to a scheduler", + ) + + self.api.add_api_route( + path="/schedulers/{scheduler_id}/pop", + endpoint=self.pop, + methods=["POST"], + response_model=utils.PaginatedResponse, status_code=status.HTTP_200_OK, - description="Update a scheduler", + description="Pop a task from a scheduler", ) - def list(self) -> Any: - return [models.Scheduler(**s.dict()) for s in self.schedulers.values()] + def list(self) -> list[serializers.Scheduler]: + return [serializers.Scheduler(**s.dict()) for s in self.schedulers.values()] def get(self, scheduler_id: str) -> Any: s = self.schedulers.get(scheduler_id) if s is None: raise NotFoundError(f"Scheduler {scheduler_id} not found") - return 
models.Scheduler(**s.dict()) + return serializers.Scheduler(**s.dict()) + + def pop( + self, + request: fastapi.Request, + scheduler_id: str, + offset: int = 0, + limit: int = 100, + filters: storage.filters.FilterRequest | None = None, + ) -> utils.PaginatedResponse: + results, count = self.ctx.datastores.pq_store.pop( + scheduler_id=scheduler_id, offset=offset, limit=limit, filters=filters + ) + + # Update status for popped items + self.ctx.datastores.pq_store.bulk_update_status( + scheduler_id, [item.id for item in results], models.TaskStatus.DISPATCHED + ) + + return utils.paginate(request, results, count, offset, limit) - def patch(self, scheduler_id: str, item: models.Scheduler) -> Any: + def push(self, scheduler_id: str, item: serializers.TaskPush) -> Any: s = self.schedulers.get(scheduler_id) if s is None: raise NotFoundError(f"Scheduler {scheduler_id} not found") - stored_scheduler_model = models.Scheduler(**s.dict()) - patch_data = item.model_dump(exclude_unset=True) - if len(patch_data) == 0: - raise BadRequestError("no data to patch") + if item.scheduler_id is not None and item.scheduler_id != scheduler_id: + raise BadRequestError("scheduler_id in item does not match the scheduler_id in the path") - updated_scheduler = stored_scheduler_model.model_copy(update=patch_data) + # Set scheduler_id if not set + if item.scheduler_id is None: + item.scheduler_id = scheduler_id - # We update the patched attributes, since the schedulers are kept - # in memory. - for attr, value in patch_data.items(): - setattr(s, attr, value) + # Load default values + new_item = models.Task(**item.model_dump(exclude_unset=True)) - # Enable or disable the scheduler if needed. - if updated_scheduler.enabled: - s.enable() - elif not updated_scheduler.enabled: - s.disable() + try: + pushed_item = s.push_item_to_queue(new_item) + except ValueError: + raise BadRequestError("malformed item") + except QueueFullError: + raise TooManyRequestsError("queue is full") + except NotAllowedError: + raise ConflictError("queue is not allowed to push items") - return updated_scheduler + return pushed_item diff --git a/mula/scheduler/server/handlers/schedules.py b/mula/scheduler/server/handlers/schedules.py index 895a50c9b24..e67fa0f9bc6 100644 --- a/mula/scheduler/server/handlers/schedules.py +++ b/mula/scheduler/server/handlers/schedules.py @@ -12,13 +12,11 @@ class ScheduleAPI: - def __init__( - self, api: fastapi.FastAPI, ctx: context.AppContext, schedulers: dict[str, schedulers.Scheduler] - ) -> None: - self.logger: structlog.BoundLogger = structlog.get_logger(__name__) - self.api = api - self.ctx = ctx - self.schedulers = schedulers + def __init__(self, api: fastapi.FastAPI, ctx: context.AppContext, s: dict[str, schedulers.Scheduler]): + self.logger: structlog.BoundLogger = structlog.getLogger(__name__) + self.api: fastapi.FastAPI = api + self.ctx: context.AppContext = ctx + self.schedulers: dict[str, schedulers.Scheduler] = s self.api.add_api_route( path="/schedules", @@ -113,8 +111,8 @@ def create(self, schedule: serializers.ScheduleCreate) -> Any: try: new_schedule = models.Schedule(**schedule.model_dump()) - except ValueError: - raise ValidationError("validation error") + except ValueError as exc: + raise ValidationError(exc) s = self.schedulers.get(new_schedule.scheduler_id) if s is None: @@ -123,8 +121,8 @@ def create(self, schedule: serializers.ScheduleCreate) -> Any: # Validate data with task type of the scheduler try: instance = s.ITEM_TYPE.model_validate(new_schedule.data) - except ValueError: - raise 
BadRequestError("validation error") + except ValueError as exc: + raise BadRequestError(exc) # Create hash for schedule with task type new_schedule.hash = instance.hash diff --git a/mula/scheduler/server/serializers/__init__.py b/mula/scheduler/server/serializers/__init__.py index a4d3c0b20c4..ac706a15163 100644 --- a/mula/scheduler/server/serializers/__init__.py +++ b/mula/scheduler/server/serializers/__init__.py @@ -1,2 +1,3 @@ from .schedule import ScheduleCreate, SchedulePatch -from .task import Task, TaskStatus +from .scheduler import Scheduler +from .task import Task, TaskPush, TaskStatus diff --git a/mula/scheduler/server/serializers/schedule.py b/mula/scheduler/server/serializers/schedule.py index 5e3c0a0bbb9..e614b623f50 100644 --- a/mula/scheduler/server/serializers/schedule.py +++ b/mula/scheduler/server/serializers/schedule.py @@ -7,11 +7,9 @@ class ScheduleCreate(BaseModel): model_config = ConfigDict(from_attributes=True) scheduler_id: str - + organisation: str data: dict - schedule: str | None = None - deadline_at: datetime | None = None @@ -20,11 +18,7 @@ class SchedulePatch(BaseModel): model_config = ConfigDict(from_attributes=True) hash: str | None = Field(None, max_length=32) - data: dict | None = None - enabled: bool | None = None - schedule: str | None = None - deadline_at: datetime | None = None diff --git a/mula/scheduler/server/serializers/scheduler.py b/mula/scheduler/server/serializers/scheduler.py new file mode 100644 index 00000000000..f267e98909d --- /dev/null +++ b/mula/scheduler/server/serializers/scheduler.py @@ -0,0 +1,11 @@ +from datetime import datetime + +from pydantic import BaseModel + + +class Scheduler(BaseModel): + id: str + type: str + item_type: str + qsize: int = 0 + last_activity: datetime | None = None diff --git a/mula/scheduler/server/serializers/task.py b/mula/scheduler/server/serializers/task.py index 3a4e6fc3846..cc2aafbfdac 100644 --- a/mula/scheduler/server/serializers/task.py +++ b/mula/scheduler/server/serializers/task.py @@ -34,21 +34,20 @@ class Task(BaseModel): model_config = ConfigDict(from_attributes=True, use_enum_values=True) id: uuid.UUID | None = None - scheduler_id: str | None = None - schedule_id: uuid.UUID | None = None - + organisation: str | None = None priority: int | None = None - status: TaskStatus | None = None - type: str | None = None - hash: str | None = None - data: dict | None = None - created_at: datetime | None = None - modified_at: datetime | None = None + + +class TaskPush(BaseModel): + scheduler_id: str | None = None + organisation: str + priority: int | None = None + data: dict diff --git a/mula/scheduler/server/server.py b/mula/scheduler/server/server.py index b39cf1fca5c..2c08ebcc156 100644 --- a/mula/scheduler/server/server.py +++ b/mula/scheduler/server/server.py @@ -19,7 +19,7 @@ class Server: api: A fastapi.FastAPI object used for exposing API endpoints. """ - def __init__(self, ctx: context.AppContext, s: dict[str, schedulers.Scheduler]): + def __init__(self, ctx: context.AppContext, s: dict[str, schedulers.Scheduler]) -> None: """Initializer of the Server class. 
Args: @@ -45,7 +45,6 @@ def __init__(self, ctx: context.AppContext, s: dict[str, schedulers.Scheduler]): # Set up API endpoints handlers.SchedulerAPI(self.api, self.ctx, s) - handlers.QueueAPI(self.api, self.ctx, s) handlers.ScheduleAPI(self.api, self.ctx, s) handlers.TaskAPI(self.api, self.ctx) handlers.MetricsAPI(self.api, self.ctx) diff --git a/mula/scheduler/storage/connection.py b/mula/scheduler/storage/connection.py index dc381191528..4787afe44bf 100644 --- a/mula/scheduler/storage/connection.py +++ b/mula/scheduler/storage/connection.py @@ -10,7 +10,7 @@ class DBConn: def __init__(self, dsn: str, pool_size: int = 25): - self.logger: structlog.BoundLogger = structlog.get_logger(__name__) + self.logger: structlog.BoundLogger = structlog.getLogger(__name__) self.dsn = dsn self.pool_size = pool_size diff --git a/mula/scheduler/storage/migrations/versions/0009_add_organisation.py b/mula/scheduler/storage/migrations/versions/0009_add_organisation.py new file mode 100644 index 00000000000..3aff6279a7f --- /dev/null +++ b/mula/scheduler/storage/migrations/versions/0009_add_organisation.py @@ -0,0 +1,30 @@ +"""add_organisation + +Revision ID: 0009 +Revises: 0008 +Create Date: 2024-12-10 15:21:27.445743 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "0009" +down_revision = "0008" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("schedules", sa.Column("organisation", sa.String(), nullable=False)) + op.add_column("tasks", sa.Column("organisation", sa.String(), nullable=False)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("tasks", "organisation") + op.drop_column("schedules", "organisation") + # ### end Alembic commands ### diff --git a/mula/scheduler/storage/storage.py b/mula/scheduler/storage/storage.py deleted file mode 100644 index 7fe2f8d1438..00000000000 --- a/mula/scheduler/storage/storage.py +++ /dev/null @@ -1,52 +0,0 @@ -import json -from functools import partial - -import sqlalchemy -import structlog - -from scheduler.config import settings - -from .errors import StorageError - - -class DBConn: - def __init__(self, dsn: str, pool_size: int = 25): - self.logger: structlog.BoundLogger = structlog.getLogger(__name__) - - self.dsn = dsn - self.pool_size = pool_size - - def connect(self) -> None: - db_uri_redacted = sqlalchemy.engine.make_url(name_or_url=self.dsn).render_as_string(hide_password=True) - - pool_size = settings.Settings().db_connection_pool_size - - self.logger.debug( - "Connecting to database %s with pool size %s...", - self.dsn, - pool_size, - dsn=db_uri_redacted, - pool_size=pool_size, - ) - - try: - serializer = partial(json.dumps, default=str) - self.engine = sqlalchemy.create_engine( - self.dsn, - pool_pre_ping=True, - pool_size=pool_size, - pool_recycle=300, - json_serializer=serializer, - connect_args={"options": "-c timezone=utc"}, - ) - except sqlalchemy.exc.SQLAlchemyError as e: - self.logger.error("Failed to connect to database %s: %s", self.dsn, e, dsn=db_uri_redacted) - raise StorageError("Failed to connect to database.") - - self.logger.debug("Connected to database %s.", db_uri_redacted, dsn=db_uri_redacted) - - try: - self.session = sqlalchemy.orm.sessionmaker(bind=self.engine) - except sqlalchemy.exc.SQLAlchemyError as e: - self.logger.error("Failed to create session: %s", e) - raise StorageError("Failed to create 
session.") diff --git a/mula/scheduler/storage/stores/pq.py b/mula/scheduler/storage/stores/pq.py index feb62bd01c7..b7c8951225c 100644 --- a/mula/scheduler/storage/stores/pq.py +++ b/mula/scheduler/storage/stores/pq.py @@ -1,8 +1,10 @@ from uuid import UUID +from sqlalchemy import exc + from scheduler import models from scheduler.storage import DBConn -from scheduler.storage.errors import exception_handler +from scheduler.storage.errors import StorageError, exception_handler from scheduler.storage.filters import FilterRequest, apply_filter from scheduler.storage.utils import retry @@ -15,25 +17,33 @@ def __init__(self, dbconn: DBConn) -> None: @retry() @exception_handler - def pop(self, scheduler_id: str, filters: FilterRequest | None = None) -> models.Task | None: + def pop( + self, scheduler_id: str | None = None, offset: int = 0, limit: int = 100, filters: FilterRequest | None = None + ) -> tuple[list[models.Task], int]: with self.dbconn.session.begin() as session: - query = ( - session.query(models.TaskDB) - .filter(models.TaskDB.status == models.TaskStatus.QUEUED) - .order_by(models.TaskDB.priority.asc()) - .order_by(models.TaskDB.created_at.asc()) - .filter(models.TaskDB.scheduler_id == scheduler_id) - ) + query = session.query(models.TaskDB).filter(models.TaskDB.status == models.TaskStatus.QUEUED) + + if scheduler_id is not None: + query = query.filter(models.TaskDB.scheduler_id == scheduler_id) if filters is not None: query = apply_filter(models.TaskDB, query, filters) - item_orm = query.first() + try: + count = query.count() + item_orm = ( + query.order_by(models.TaskDB.priority.asc()) + .order_by(models.TaskDB.created_at.asc()) + .offset(offset) + .limit(limit) + .all() + ) + except exc.ProgrammingError as e: + raise StorageError(f"Invalid filter: {e}") from e - if item_orm is None: - return None + items = [models.Task.model_validate(item_orm) for item_orm in item_orm] - return models.Task.model_validate(item_orm) + return items, count @retry() @exception_handler @@ -188,3 +198,14 @@ def clear(self, scheduler_id: str) -> None: .filter(models.TaskDB.scheduler_id == scheduler_id) .delete(), ) + + @retry() + @exception_handler + def bulk_update_status(self, scheduler_id: str, item_ids: list[UUID], status: models.TaskStatus) -> None: + with self.dbconn.session.begin() as session: + ( + session.query(models.TaskDB) + .filter(models.TaskDB.scheduler_id == scheduler_id) + .filter(models.TaskDB.id.in_([str(item_id) for item_id in item_ids])) + .update({"status": status.name}, synchronize_session=False), + ) diff --git a/mula/scheduler/storage/stores/schedule.py b/mula/scheduler/storage/stores/schedule.py index a91b680d03b..40ff8bbff50 100644 --- a/mula/scheduler/storage/stores/schedule.py +++ b/mula/scheduler/storage/stores/schedule.py @@ -21,7 +21,7 @@ def get_schedules( self, scheduler_id: str | None = None, schedule_hash: str | None = None, - enabled: bool | None = True, # FIXME: None? 
diff --git a/mula/scheduler/storage/stores/schedule.py b/mula/scheduler/storage/stores/schedule.py
index a91b680d03b..40ff8bbff50 100644
--- a/mula/scheduler/storage/stores/schedule.py
+++ b/mula/scheduler/storage/stores/schedule.py
@@ -21,7 +21,7 @@ def get_schedules(
         self,
         scheduler_id: str | None = None,
         schedule_hash: str | None = None,
-        enabled: bool | None = True,  # FIXME: None?
+        enabled: bool | None = True,
         min_deadline_at: datetime | None = None,
         max_deadline_at: datetime | None = None,
         min_created_at: datetime | None = None,
diff --git a/mula/tests/integration/test_api.py b/mula/tests/integration/test_api.py
index 6eaa82086c2..e234facd265 100644
--- a/mula/tests/integration/test_api.py
+++ b/mula/tests/integration/test_api.py
@@ -64,7 +64,7 @@ def tearDown(self):
         self.dbconn.engine.dispose()


-class APITestCase(APITemplateTestCase):
+class APISchedulerEndpointTestCase(APITemplateTestCase):
     def test_get_schedulers(self):
         response = self.client.get("/schedulers")
         self.assertEqual(response.status_code, 200)
@@ -78,78 +78,12 @@ def test_get_scheduler_malformed_id(self):
         response = self.client.get("/schedulers/123.123")
         self.assertEqual(response.status_code, 404)

-    def test_patch_scheduler(self):
-        self.assertTrue(self.scheduler.is_enabled())
-        response = self.client.patch(f"/schedulers/{self.scheduler.scheduler_id}", json={"enabled": False})
-        self.assertEqual(200, response.status_code)
-        self.assertFalse(response.json().get("enabled"))
-        self.assertFalse(self.scheduler.is_enabled())
-
-    def test_patch_scheduler_attr_not_found(self):
-        response = self.client.patch(f"/schedulers/{self.scheduler.scheduler_id}", json={"not_found": "not found"})
-        self.assertEqual(response.status_code, 400)
-        self.assertEqual(response.json(), {"detail": "Bad request error occurred: no data to patch"})
-
-    def test_patch_scheduler_not_found(self):
-        mock_id = uuid.uuid4()
-        response = self.client.patch(f"/schedulers/{mock_id}", json={"enabled": False})
-        self.assertEqual(response.status_code, 404)
-        self.assertEqual(response.json(), {"detail": f"Resource not found: Scheduler {mock_id} not found"})
-
-    def test_patch_scheduler_disable(self):
-        self.assertTrue(self.scheduler.is_enabled())
-        response = self.client.patch(f"/schedulers/{self.scheduler.scheduler_id}", json={"enabled": False})
-        self.assertEqual(200, response.status_code)
-        self.assertFalse(response.json().get("enabled"))
-        self.assertFalse(self.scheduler.is_enabled())
-
-        # Try to push to queue
-        item = create_task_in(0)
-        response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=item)
-        self.assertNotEqual(response.status_code, 201)
-        self.assertEqual(0, self.scheduler.queue.qsize())
-
-    def test_patch_scheduler_enable(self):
-        # Disable queue first
-        self.assertTrue(self.scheduler.is_enabled())
-        response = self.client.patch(f"/schedulers/{self.scheduler.scheduler_id}", json={"enabled": False})
-        self.assertEqual(200, response.status_code)
-        self.assertFalse(response.json().get("enabled"))
-        self.assertFalse(self.scheduler.is_enabled())
-
-        # Enable again
-        response = self.client.patch(f"/schedulers/{self.scheduler.scheduler_id}", json={"enabled": True})
-        self.assertEqual(200, response.status_code)
-        self.assertTrue(response.json().get("enabled"))
-        self.assertTrue(self.scheduler.is_enabled())
-
-        # Try to push to queue
-        self.assertEqual(0, self.scheduler.queue.qsize())
-        item = create_task_in(1)
-
-        response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=item)
-        self.assertEqual(response.status_code, 201)
-        self.assertEqual(1, self.scheduler.queue.qsize())
-
-    def test_get_queues(self):
-        response = self.client.get("/queues")
-        self.assertEqual(response.status_code, 200)
-
-    def test_get_queue(self):
-        response = self.client.get(f"/queues/{self.scheduler.scheduler_id}")
-        self.assertEqual(response.status_code, 200)
-        self.assertEqual(response.json().get("id"), 
self.scheduler.scheduler_id) - - def test_get_queue_malformed_id(self): - response = self.client.get("/queues/123.123") - self.assertEqual(response.status_code, 404) - def test_push_queue(self): self.assertEqual(0, self.scheduler.queue.qsize()) - item = create_task_in(1) + item = create_task_in(1, self.organisation.id) - response_post = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=item) + response_post = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/push", data=item) self.assertEqual(201, response_post.status_code) self.assertEqual(1, self.scheduler.queue.qsize()) self.assertIsNotNone(response_post.json().get("id")) @@ -166,23 +100,24 @@ def test_push_queue(self): def test_push_incorrect_item_type(self): response = self.client.post( - f"/queues/{self.scheduler.scheduler_id}/push", json={"priority": 0, "item": "not a task"} + f"/schedulers/{self.scheduler.scheduler_id}/push", json={"organisation": self.organisation.id, "data": {}} ) self.assertEqual(response.status_code, 400) + self.assertEqual(response.json(), {"detail": "Bad request error occurred: malformed item"}) def test_push_queue_full(self): # Set maxsize of the queue to 1 self.scheduler.queue.maxsize = 1 # Add one task to the queue - first_item = create_task_in(1) - response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=first_item) + first_item = create_task_in(1, self.organisation.id) + response = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/push", data=first_item) self.assertEqual(response.status_code, 201) self.assertEqual(1, self.scheduler.queue.qsize()) # Try to add another task to the queue through the api - second_item = create_task_in(2) - response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=second_item) + second_item = create_task_in(2, self.organisation.id) + response = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/push", data=second_item) self.assertEqual(response.status_code, 429) self.assertEqual(1, self.scheduler.queue.qsize()) @@ -191,14 +126,14 @@ def test_push_queue_full_high_priority(self): self.scheduler.queue.maxsize = 1 # Add one task to the queue - first_item = create_task_in(1) - response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=first_item) + first_item = create_task_in(1, self.organisation.id) + response = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/push", data=first_item) self.assertEqual(response.status_code, 201) self.assertEqual(1, self.scheduler.queue.qsize()) # Try to add another task to the queue through the api - second_item = create_task_in(1) - response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=second_item) + second_item = create_task_in(1, self.organisation.id) + response = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/push", data=second_item) self.assertEqual(response.status_code, 201) self.assertEqual(2, self.scheduler.queue.qsize()) @@ -212,13 +147,13 @@ def test_push_replace_not_allowed(self): self.scheduler.queue.allow_priority_updates = False # Add one task to the queue - initial_item = create_task_in(1) - response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=initial_item) + initial_item = create_task_in(1, self.organisation.id) + response = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/push", data=initial_item) self.assertEqual(response.status_code, 201) self.assertEqual(1, self.scheduler.queue.qsize()) # Add the same item again through the 
api - response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=initial_item) + response = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/push", data=initial_item) # The queue should still have one item self.assertEqual(response.status_code, 409) @@ -230,13 +165,13 @@ def test_push_replace_allowed(self): self.scheduler.queue.allow_replace = True # Add one task to the queue - initial_item = create_task_in(1) - response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=initial_item) + initial_item = create_task_in(1, self.organisation.id) + response = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/push", data=initial_item) self.assertEqual(response.status_code, 201) self.assertEqual(1, self.scheduler.queue.qsize()) # Add the same item again through the api - response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", json=response.json()) + response = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/push", json=response.json()) # The queue should have one item self.assertEqual(response.status_code, 201) @@ -252,8 +187,8 @@ def test_push_updates_not_allowed(self): self.scheduler.queue.allow_priority_updates = False # Add one task to the queue - initial_item = create_task_in(1) - response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=initial_item) + initial_item = create_task_in(1, self.organisation.id) + response = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/push", data=initial_item) self.assertEqual(response.status_code, 201) self.assertEqual(1, self.scheduler.queue.qsize()) @@ -262,7 +197,9 @@ def test_push_updates_not_allowed(self): updated_item.data["name"] = "updated-name" # Try to update the item through the api - response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=updated_item.model_dump_json()) + response = self.client.post( + f"/schedulers/{self.scheduler.scheduler_id}/push", data=updated_item.model_dump_json() + ) # The queue should still have one item self.assertEqual(response.status_code, 409) @@ -274,8 +211,8 @@ def test_push_updates_allowed(self): self.scheduler.queue.allow_updates = True # Add one task to the queue - initial_item = create_task_in(1) - response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=initial_item) + initial_item = create_task_in(1, self.organisation.id) + response = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/push", data=initial_item) self.assertEqual(response.status_code, 201) self.assertEqual(1, self.scheduler.queue.qsize()) @@ -284,7 +221,9 @@ def test_push_updates_allowed(self): updated_item.data["name"] = "updated-name" # Try to update the item through the api - response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=updated_item.model_dump_json()) + response = self.client.post( + f"/schedulers/{self.scheduler.scheduler_id}/push", data=updated_item.model_dump_json() + ) self.assertEqual(response.status_code, 201) # The queue should have one item @@ -301,8 +240,8 @@ def test_push_priority_updates_not_allowed(self): self.scheduler.queue.allow_priority_updates = False # Add one task to the queue - initial_item = create_task_in(1) - response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=initial_item) + initial_item = create_task_in(1, self.organisation.id) + response = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/push", data=initial_item) self.assertEqual(response.status_code, 201) 
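         # The initial push succeeded, so the queue should now hold one item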
self.assertEqual(1, self.scheduler.queue.qsize()) @@ -311,7 +250,9 @@ def test_push_priority_updates_not_allowed(self): updated_item.priority = 2 # Try to update the item through the api - response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=updated_item.model_dump_json()) + response = self.client.post( + f"/schedulers/{self.scheduler.scheduler_id}/push", data=updated_item.model_dump_json() + ) # The queue should still have one item self.assertEqual(response.status_code, 409) @@ -328,8 +269,8 @@ def test_update_priority_higher(self): self.scheduler.queue.allow_priority_updates = True # Add one task to the queue - initial_item = create_task_in(2) - response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=initial_item) + initial_item = create_task_in(2, self.organisation.id) + response = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/push", data=initial_item) self.assertEqual(response.status_code, 201) # Update priority of the item @@ -337,7 +278,9 @@ def test_update_priority_higher(self): updated_item.priority = 1 # Try to update the item through the api - response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=updated_item.model_dump_json()) + response = self.client.post( + f"/schedulers/{self.scheduler.scheduler_id}/push", data=updated_item.model_dump_json() + ) self.assertEqual(response.status_code, 201) # The queue should have one item @@ -356,8 +299,8 @@ def test_update_priority_lower(self): self.scheduler.queue.allow_priority_updates = True # Add one task to the queue - initial_item = create_task_in(1) - response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=initial_item) + initial_item = create_task_in(1, self.organisation.id) + response = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/push", data=initial_item) self.assertEqual(response.status_code, 201) # Update priority of the item @@ -365,7 +308,9 @@ def test_update_priority_lower(self): updated_item.priority = 2 # Try to update the item through the api - response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=updated_item.model_dump_json()) + response = self.client.post( + f"/schedulers/{self.scheduler.scheduler_id}/push", data=updated_item.model_dump_json() + ) self.assertEqual(response.status_code, 201) # The queue should have one item @@ -376,135 +321,168 @@ def test_update_priority_lower(self): def test_pop_queue(self): # Add one task to the queue - initial_item = create_task_in(1) - response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=initial_item) + initial_item = create_task_in(1, self.organisation.id) + response = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/push", data=initial_item) initial_item_id = response.json().get("id") self.assertEqual(response.status_code, 201) self.assertEqual(1, self.scheduler.queue.qsize()) - response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/pop") + response = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/pop") self.assertEqual(200, response.status_code) - self.assertEqual(initial_item_id, response.json().get("id")) + self.assertEqual(1, response.json().get("count")) + self.assertEqual(initial_item_id, response.json().get("results")[0].get("id")) self.assertEqual(0, self.scheduler.queue.qsize()) + # Status of the item should be DISPATCHED + get_item = self.client.get(f"/tasks/{initial_item_id}") + self.assertEqual(get_item.json().get("status"), 
models.TaskStatus.DISPATCHED.name.lower()) + def test_pop_queue_not_found(self): mock_id = uuid.uuid4() - response = self.client.post(f"/queues/{mock_id}/pop") - self.assertEqual(404, response.status_code) - self.assertEqual({"detail": f"Resource not found: queue not found, by queue_id: {mock_id}"}, response.json()) + response = self.client.post(f"/schedulers/{mock_id}/pop") + self.assertEqual(200, response.status_code) + self.assertEqual(0, response.json().get("count")) - def test_pop_queue_filters(self): + def test_pop_queue_filters_two_items(self): # Add one task to the queue - first_item = create_task_in(1, data=functions.TestModel(id="123", name="test")) - response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=first_item) + first_item = create_task_in(1, self.organisation.id, data=functions.TestModel(id="123", name="test")) + response = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/push", data=first_item) first_item_id = response.json().get("id") self.assertEqual(response.status_code, 201) self.assertEqual(1, self.scheduler.queue.qsize()) # Add second item to the queue - second_item = create_task_in(2, data=functions.TestModel(id="456", name="test")) - response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=second_item) + second_item = create_task_in(2, self.organisation.id, data=functions.TestModel(id="456", name="test")) + response = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/push", data=second_item) second_item_id = response.json().get("id") self.assertEqual(response.status_code, 201) self.assertEqual(2, self.scheduler.queue.qsize()) - # Should get the first item + # Should get two items, and queue should be empty response = self.client.post( - f"/queues/{self.scheduler.scheduler_id}/pop", + f"/schedulers/{self.scheduler.scheduler_id}/pop", json={"filters": [{"column": "data", "field": "name", "operator": "eq", "value": "test"}]}, ) self.assertEqual(200, response.status_code) - self.assertEqual(first_item_id, response.json().get("id")) + self.assertEqual(2, response.json().get("count")) + self.assertEqual(first_item_id, response.json().get("results")[0].get("id")) + self.assertEqual(second_item_id, response.json().get("results")[1].get("id")) + self.assertEqual(0, self.scheduler.queue.qsize()) + + # Status of the items should be DISPATCHED + get_first_item = self.client.get(f"/tasks/{first_item_id}") + get_second_item = self.client.get(f"/tasks/{second_item_id}") + self.assertEqual(get_first_item.json().get("status"), models.TaskStatus.DISPATCHED.name.lower()) + self.assertEqual(get_second_item.json().get("status"), models.TaskStatus.DISPATCHED.name.lower()) + + def test_pop_queue_filters_one_item(self): + # Add one task to the queue + first_item = create_task_in(1, self.organisation.id, data=functions.TestModel(id="123", name="test")) + response = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/push", data=first_item) + first_item_id = response.json().get("id") + self.assertEqual(response.status_code, 201) self.assertEqual(1, self.scheduler.queue.qsize()) - # Should not return any items + # Add second item to the queue + second_item = create_task_in(2, self.organisation.id, data=functions.TestModel(id="456", name="test")) + response = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/push", data=second_item) + second_item_id = response.json().get("id") + self.assertEqual(response.status_code, 201) + self.assertEqual(2, self.scheduler.queue.qsize()) + + # Should get the first 
item, and one item should remain on the queue
         response = self.client.post(
-            f"/queues/{self.scheduler.scheduler_id}/pop",
+            f"/schedulers/{self.scheduler.scheduler_id}/pop",
             json={"filters": [{"column": "data", "field": "id", "operator": "eq", "value": "123"}]},
         )
-        self.assertEqual(404, response.status_code)
-        self.assertEqual(
-            response.json(), {"detail": "Resource not found: could not pop item from queue, check your filters"}
-        )
+        self.assertEqual(200, response.status_code)
+        self.assertEqual(1, response.json().get("count"))
+        self.assertEqual(first_item_id, response.json().get("results")[0].get("id"))
         self.assertEqual(1, self.scheduler.queue.qsize())

-        # Should get the second item
+        # Should get the second item, and the queue should then be empty
         response = self.client.post(
-            f"/queues/{self.scheduler.scheduler_id}/pop",
-            json={"filters": [{"column": "data", "field": "name", "operator": "eq", "value": "test"}]},
+            f"/schedulers/{self.scheduler.scheduler_id}/pop",
+            json={"filters": [{"column": "data", "field": "id", "operator": "eq", "value": "456"}]},
         )
         self.assertEqual(200, response.status_code)
-        self.assertEqual(second_item_id, response.json().get("id"))
+        self.assertEqual(1, response.json().get("count"))
+        self.assertEqual(second_item_id, response.json().get("results")[0].get("id"))
         self.assertEqual(0, self.scheduler.queue.qsize())

     def test_pop_queue_filters_nested(self):
         # Add one task to the queue
-        first_item = create_task_in(1, data=functions.TestModel(id="123", name="test", categories=["foo", "bar"]))
-        response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=first_item)
+        first_item = create_task_in(
+            1, self.organisation.id, data=functions.TestModel(id="123", name="test", categories=["foo", "bar"])
+        )
+        response = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/push", data=first_item)
         first_item_id = response.json().get("id")
         self.assertEqual(response.status_code, 201)
         self.assertEqual(1, self.scheduler.queue.qsize())

         # Add second item to the queue
-        second_item = create_task_in(2, data=functions.TestModel(id="456", name="test", categories=["baz", "bat"]))
-        response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=second_item)
+        second_item = create_task_in(
+            2, self.organisation.id, data=functions.TestModel(id="456", name="test", categories=["baz", "bat"])
+        )
+        response = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/push", data=second_item)
         second_item_id = response.json().get("id")
         self.assertEqual(response.status_code, 201)
         self.assertEqual(2, self.scheduler.queue.qsize())

         # Should get the first item
         response = self.client.post(
-            f"/queues/{self.scheduler.scheduler_id}/pop",
+            f"/schedulers/{self.scheduler.scheduler_id}/pop",
             json={
                 "filters": [{"column": "data", "operator": "@>", "value": json.dumps({"categories": ["foo", "bar"]})}]
             },
         )
         self.assertEqual(200, response.status_code)
-        self.assertEqual(first_item_id, response.json().get("id"))
+        self.assertEqual(first_item_id, response.json().get("results")[0].get("id"))
         self.assertEqual(1, self.scheduler.queue.qsize())

         # Should not return any items
         response = self.client.post(
-            f"/queues/{self.scheduler.scheduler_id}/pop",
+            f"/schedulers/{self.scheduler.scheduler_id}/pop",
             json={
                 "filters": [{"column": "data", "operator": "@>", "value": json.dumps({"categories": ["foo", "bar"]})}]
             },
         )
-
-        self.assertEqual(404, response.status_code)
-        self.assertEqual(
-            response.json(), {"detail": "Resource not found: could not pop item from queue, 
check your filters"} - ) + self.assertEqual(200, response.status_code) + self.assertEqual(0, response.json().get("count")) self.assertEqual(1, self.scheduler.queue.qsize()) # Should get the second item response = self.client.post( - f"/queues/{self.scheduler.scheduler_id}/pop", + f"/schedulers/{self.scheduler.scheduler_id}/pop", json={ "filters": [{"column": "data", "operator": "@>", "value": json.dumps({"categories": ["baz", "bat"]})}] }, ) self.assertEqual(200, response.status_code) - self.assertEqual(second_item_id, response.json().get("id")) + self.assertEqual(second_item_id, response.json().get("results")[0].get("id")) self.assertEqual(0, self.scheduler.queue.qsize()) def test_pop_queue_filters_nested_contained_by(self): # Add one task to the queue - first_item = create_task_in(1, data=functions.TestModel(id="123", name="test", categories=["foo", "bar"])) - response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=first_item) + first_item = create_task_in( + 1, self.organisation.id, data=functions.TestModel(id="123", name="test", categories=["foo", "bar"]) + ) + response = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/push", data=first_item) self.assertEqual(response.status_code, 201) self.assertEqual(1, self.scheduler.queue.qsize()) # Add second item to the queue - second_item = create_task_in(2, data=functions.TestModel(id="456", name="test", categories=["baz", "bat"])) - response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=second_item) + second_item = create_task_in( + 2, self.organisation.id, data=functions.TestModel(id="456", name="test", categories=["baz", "bat"]) + ) + response = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/push", data=second_item) second_item_id = response.json().get("id") self.assertEqual(response.status_code, 201) self.assertEqual(2, self.scheduler.queue.qsize()) # Test contained by response = self.client.post( - f"/queues/{self.scheduler.scheduler_id}/pop", + f"/schedulers/{self.scheduler.scheduler_id}/pop", json={ "filters": [ {"column": "data", "operator": "<@", "field": "categories", "value": json.dumps(["baz", "bat"])} @@ -513,13 +491,14 @@ def test_pop_queue_filters_nested_contained_by(self): ) self.assertEqual(200, response.status_code) - self.assertEqual(second_item_id, response.json().get("id")) + self.assertEqual(second_item_id, response.json().get("results")[0].get("id")) self.assertEqual(1, self.scheduler.queue.qsize()) def test_pop_empty(self): """When queue is empty it should return an empty response""" - response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/pop") + response = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/pop") self.assertEqual(200, response.status_code) + self.assertEqual(0, response.json().get("count")) class APITasksEndpointTestCase(APITemplateTestCase): @@ -529,9 +508,10 @@ def setUp(self): # Add one task to the queue first_item = create_task_in( 1, + self.organisation.id, data=functions.TestModel(id="123", name="test", child=functions.TestModel(id="123.123", name="test.child")), ) - response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=first_item) + response = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/push", data=first_item) initial_item_id = response.json().get("id") self.assertEqual(response.status_code, 201) self.assertEqual(1, self.scheduler.queue.qsize()) @@ -539,8 +519,8 @@ def setUp(self): self.first_item_api = self.client.get(f"/tasks/{initial_item_id}").json() # 
Add second item to the queue - second_item = create_task_in(1, data=functions.TestModel(id="456", name="test")) - response = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=second_item) + second_item = create_task_in(1, self.organisation.id, data=functions.TestModel(id="456", name="test")) + response = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/push", data=second_item) second_item_id = response.json().get("id") self.assertEqual(response.status_code, 201) self.assertEqual(2, self.scheduler.queue.qsize()) @@ -548,8 +528,8 @@ def setUp(self): self.second_item_api = self.client.get(f"/tasks/{second_item_id}").json() def test_create_task(self): - item = create_task_in(1) - response_post = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=item) + item = create_task_in(1, self.organisation.id) + response_post = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/push", data=item) self.assertEqual(201, response_post.status_code) initial_item_id = response_post.json().get("id") @@ -574,9 +554,9 @@ def test_get_tasks(self): def test_get_task(self): # First add a task - item = create_task_in(1) + item = create_task_in(1, self.organisation.id) - response_post = self.client.post(f"/queues/{self.scheduler.scheduler_id}/push", data=item) + response_post = self.client.post(f"/schedulers/{self.scheduler.scheduler_id}/push", data=item) self.assertEqual(201, response_post.status_code) initial_item_id = response_post.json().get("id") @@ -740,20 +720,22 @@ class APIScheduleEndpointTestCase(APITemplateTestCase): def setUp(self): super().setUp() - self.first_item = functions.create_item(self.scheduler.scheduler_id, 1) + self.first_item = functions.create_task(self.scheduler.scheduler_id, self.organisation.id) self.first_schedule = self.mock_ctx.datastores.schedule_store.create_schedule( models.Schedule( scheduler_id=self.scheduler.scheduler_id, + organisation=self.organisation.id, hash=self.first_item.hash, data=self.first_item.data, deadline_at=datetime.now(timezone.utc) + timedelta(days=1), ) ) - self.second_item = functions.create_item(self.scheduler.scheduler_id, 1) + self.second_item = functions.create_task(self.scheduler.scheduler_id, self.organisation.id) self.second_schedule = self.mock_ctx.datastores.schedule_store.create_schedule( models.Schedule( scheduler_id=self.scheduler.scheduler_id, + organisation=self.organisation.id, hash=self.second_item.hash, data=self.second_item.data, deadline_at=datetime.now(timezone.utc) + timedelta(days=2), @@ -886,9 +868,15 @@ def test_list_schedules_min_and_max_created_at(self): self.assertEqual(str(self.first_schedule.id), response.json()["results"][0]["id"]) def test_post_schedule(self): - item = functions.create_item(self.scheduler.scheduler_id, 1) + item = functions.create_task(self.scheduler.scheduler_id, self.organisation.id) response = self.client.post( - "/schedules", json={"scheduler_id": item.scheduler_id, "schedule": "*/5 * * * *", "data": item.data} + "/schedules", + json={ + "scheduler_id": item.scheduler_id, + "organisation": self.organisation.id, + "schedule": "*/5 * * * *", + "data": item.data, + }, ) self.assertEqual(201, response.status_code) self.assertEqual(item.hash, response.json().get("hash")) @@ -904,10 +892,16 @@ def test_post_schedule(self): def test_post_schedule_explicit_deadline_at(self): """When a schedule is created, the deadline_at should be set if it is provided.""" - item = functions.create_item(self.scheduler.scheduler_id, 1) + item = 
functions.create_task(self.scheduler.scheduler_id, self.organisation.id) now = datetime.now(timezone.utc) response = self.client.post( - "/schedules", json={"scheduler_id": item.scheduler_id, "data": item.data, "deadline_at": now.isoformat()} + "/schedules", + json={ + "scheduler_id": item.scheduler_id, + "organisation": self.organisation.id, + "data": item.data, + "deadline_at": now.isoformat(), + }, ) self.assertEqual(201, response.status_code) self.assertIsNone(response.json().get("schedule")) @@ -920,54 +914,92 @@ def test_post_schedule_explicit_deadline_at(self): def test_post_schedule_schedule_and_deadline_at_none(self): """When a schedule is created, both schedule and deadline_at should not be None.""" - item = functions.create_item(self.scheduler.scheduler_id, 1) - response = self.client.post("/schedules", json={"scheduler_id": item.scheduler_id, "data": item.data}) + item = functions.create_task(self.scheduler.scheduler_id, self.organisation.id) + response = self.client.post( + "/schedules", + json={"scheduler_id": item.scheduler_id, "organisation": self.organisation.id, "data": item.data}, + ) self.assertEqual(400, response.status_code) self.assertEqual( {"detail": "Bad request error occurred: Either deadline_at or schedule must be provided"}, response.json() ) def test_post_schedule_invalid_schedule(self): - item = functions.create_item(self.scheduler.scheduler_id, 1) + item = functions.create_task(self.scheduler.scheduler_id, self.organisation.id) response = self.client.post( - "/schedules", json={"scheduler_id": item.scheduler_id, "schedule": "invalid", "data": item.data} + "/schedules", + json={ + "scheduler_id": item.scheduler_id, + "organisation": self.organisation.id, + "schedule": "invalid", + "data": item.data, + }, ) self.assertEqual(400, response.status_code) self.assertIn("validation error", response.json().get("detail")) def test_post_schedule_invalid_scheduler_id(self): - item = functions.create_item(self.scheduler.scheduler_id, 1) + item = functions.create_task(self.scheduler.scheduler_id, self.organisation.id) response = self.client.post( - "/schedules", json={"scheduler_id": "invalid", "schedule": "*/5 * * * *", "data": item.data} + "/schedules", + json={ + "scheduler_id": "invalid", + "organisation": self.organisation.id, + "schedule": "*/5 * * * *", + "data": item.data, + }, ) self.assertEqual(400, response.status_code) self.assertEqual({"detail": "Bad request error occurred: Scheduler invalid not found"}, response.json()) def test_post_schedule_invalid_data(self): - item = functions.create_item(self.scheduler.scheduler_id, 1) + item = functions.create_task(self.scheduler.scheduler_id, self.organisation.id) response = self.client.post( - "/schedules", json={"scheduler_id": item.scheduler_id, "schedule": "*/5 * * * *", "data": "invalid"} + "/schedules", + json={ + "scheduler_id": item.scheduler_id, + "organisation": self.organisation.id, + "schedule": "*/5 * * * *", + "data": "invalid", + }, ) self.assertEqual(422, response.status_code) def test_post_schedule_invalid_data_type(self): - item = functions.create_item(self.scheduler.scheduler_id, 1) + item = functions.create_task(self.scheduler.scheduler_id, self.organisation.id) response = self.client.post( "/schedules", - json={"scheduler_id": item.scheduler_id, "schedule": "*/5 * * * *", "data": {"invalid": "invalid"}}, + json={ + "scheduler_id": item.scheduler_id, + "organisation": self.organisation.id, + "schedule": "*/5 * * * *", + "data": {"invalid": "invalid"}, + }, ) self.assertEqual(400, 
response.status_code) self.assertIn("validation error", response.json().get("detail")) def test_post_schedule_hash_already_exists(self): - item = functions.create_item(self.scheduler.scheduler_id, 1) + item = functions.create_task(self.scheduler.scheduler_id, self.organisation.id) response = self.client.post( - "/schedules", json={"scheduler_id": item.scheduler_id, "schedule": "*/5 * * * *", "data": item.data} + "/schedules", + json={ + "scheduler_id": item.scheduler_id, + "organisation": self.organisation.id, + "schedule": "*/5 * * * *", + "data": item.data, + }, ) self.assertEqual(201, response.status_code) response = self.client.post( - "/schedules", json={"scheduler_id": item.scheduler_id, "schedule": "*/5 * * * *", "data": item.data} + "/schedules", + json={ + "scheduler_id": item.scheduler_id, + "organisation": self.organisation.id, + "schedule": "*/5 * * * *", + "data": item.data, + }, ) self.assertEqual(409, response.status_code) self.assertIn("schedule with the same hash already exists", response.json().get("detail")) diff --git a/mula/tests/integration/test_app.py b/mula/tests/integration/test_app.py index b75f0576883..aa8add3bb5a 100644 --- a/mula/tests/integration/test_app.py +++ b/mula/tests/integration/test_app.py @@ -40,105 +40,15 @@ def tearDown(self): models.Base.metadata.drop_all(self.dbconn.engine) self.dbconn.engine.dispose() - def test_monitor_orgs_add(self): - """Test that when a new organisation is added, a new scheduler is created""" - # Arrange - self.mock_ctx.services.katalogus.organisations = { - "org-1": OrganisationFactory(id="org-1"), - "org-2": OrganisationFactory(id="org-2"), - } - - # Act - self.app.monitor_organisations() - - # Assert: six schedulers should have been created for two organisations - self.assertEqual(6, len(self.app.schedulers.keys())) - self.assertEqual(6, len(self.app.server.schedulers.keys())) - - scheduler_org_ids = {s.organisation.id for s in self.app.schedulers.values()} - self.assertEqual({"org-1", "org-2"}, scheduler_org_ids) - - def test_monitor_orgs_remove(self): - """Test that when an organisation is removed, the scheduler is removed""" - # Arrange - self.mock_ctx.services.katalogus.organisations = { - "org-1": OrganisationFactory(id="org-1"), - "org-2": OrganisationFactory(id="org-2"), - } - - # Act - self.app.monitor_organisations() - - # Assert: six schedulers should have been created for two organisations - self.assertEqual(6, len(self.app.schedulers.keys())) - self.assertEqual(6, len(self.app.server.schedulers.keys())) - - scheduler_org_ids = {s.organisation.id for s in self.app.schedulers.values()} - self.assertEqual({"org-1", "org-2"}, scheduler_org_ids) - - # Arrange - self.mock_ctx.services.katalogus.organisations = {} - - # Act - self.app.monitor_organisations() - - # Assert - self.assertEqual(0, len(self.app.schedulers.keys())) - self.assertEqual(0, len(self.app.server.schedulers.keys())) - - scheduler_org_ids = {s.organisation.id for s in self.app.schedulers.values()} - self.assertEqual(set(), scheduler_org_ids) - - def test_monitor_orgs_add_and_remove(self): - """Test that when an organisation is added and removed, the scheduler - is removed""" - # Arrange - self.mock_ctx.services.katalogus.organisations = { - "org-1": OrganisationFactory(id="org-1"), - "org-2": OrganisationFactory(id="org-2"), - } - - # Act - self.app.monitor_organisations() - - # Assert: six schedulers should have been created for two organisations - self.assertEqual(6, len(self.app.schedulers.keys())) - self.assertEqual(6, 
len(self.app.server.schedulers.keys())) - - scheduler_org_ids = {s.organisation.id for s in self.app.schedulers.values()} - self.assertEqual({"org-1", "org-2"}, scheduler_org_ids) - - # Arrange - self.mock_ctx.services.katalogus.organisations = { - "org-1": OrganisationFactory(id="org-1"), - "org-3": OrganisationFactory(id="org-3"), - } - - # Act - self.app.monitor_organisations() - - # Assert - self.assertEqual(6, len(self.app.schedulers.keys())) - self.assertEqual(6, len(self.app.server.schedulers.keys())) - - scheduler_org_ids = {s.organisation.id for s in self.app.schedulers.values()} - self.assertEqual({"org-1", "org-3"}, scheduler_org_ids) - def test_shutdown(self): """Test that the app shuts down gracefully""" # Arrange self.mock_ctx.services.katalogus.organisations = {"org-1": OrganisationFactory(id="org-1")} - self.app.start_schedulers() - self.app.start_monitors() # Shutdown the app self.app.shutdown() - # Assert that the schedulers have been stopped - for s in self.app.schedulers.copy().values(): - self.assertFalse(s.is_alive()) - # Assert that all threads have been stopped # for thread in self.app.threads: for t in threading.enumerate(): diff --git a/mula/tests/integration/test_boefje_scheduler.py b/mula/tests/integration/test_boefje_scheduler.py index 616a25da600..c2a24787936 100644 --- a/mula/tests/integration/test_boefje_scheduler.py +++ b/mula/tests/integration/test_boefje_scheduler.py @@ -55,10 +55,10 @@ def setUp(self): ) # Scheduler + self.scheduler = schedulers.BoefjeScheduler(self.mock_ctx) + + # Organisation self.organisation = OrganisationFactory() - self.scheduler = schedulers.BoefjeScheduler( - ctx=self.mock_ctx, scheduler_id=self.organisation.id, organisation=self.organisation - ) def tearDown(self): self.scheduler.stop() @@ -87,6 +87,21 @@ def setUp(self): def tearDown(self): mock.patch.stopall() + def test_run(self): + """When the scheduler is started, the run method should be called. + And the scheduler should start the threads. 
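+        The worker threads expected here are the mutations, new_boefjes
+        and rescheduling threads asserted below.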
+ """ + # Act + self.scheduler.run() + + # Assert: threads started + thread_ids = ["BoefjeScheduler-mutations", "BoefjeScheduler-new_boefjes", "BoefjeScheduler-rescheduling"] + for thread in self.scheduler.threads: + self.assertIn(thread.name, thread_ids) + self.assertTrue(thread.is_alive()) + + self.scheduler.stop() + def test_is_allowed_to_run(self): # Arrange scan_profile = ScanProfileFactory(level=0) @@ -155,7 +170,9 @@ def test_has_boefje_task_started_running_datastore_running(self): boefje = BoefjeFactory() boefje_task = models.BoefjeTask(boefje=boefje, input_ooi=ooi.primary_key, organization=self.organisation.id) - task = functions.create_task(scheduler_id=self.scheduler.scheduler_id, data=boefje_task) + task = functions.create_task( + scheduler_id=self.scheduler.scheduler_id, data=boefje_task, organisation=self.organisation.id + ) # Mock self.mock_get_latest_task_by_hash.return_value = task @@ -179,6 +196,7 @@ def test_has_boefje_task_started_running_datastore_not_running(self): task_db_first = models.Task( scheduler_id=self.scheduler.scheduler_id, + organisation=self.organisation.id, priority=1, status=models.TaskStatus.COMPLETED, type=models.BoefjeTask.type, @@ -190,6 +208,7 @@ def test_has_boefje_task_started_running_datastore_not_running(self): task_db_second = models.Task( scheduler_id=self.scheduler.scheduler_id, + organisation=self.organisation.id, priority=1, type=models.BoefjeTask.type, hash=boefje_task.hash, @@ -293,11 +312,12 @@ def test_has_boefje_task_started_running_stalled_before_grace_period(self): task_db = models.Task( scheduler_id=self.scheduler.scheduler_id, + organisation=self.organisation.id, priority=1, + status=models.TaskStatus.DISPATCHED, type=models.BoefjeTask.type, hash=boefje_task.hash, data=boefje_task.model_dump(), - status=models.TaskStatus.DISPATCHED, created_at=datetime.now(timezone.utc), modified_at=datetime.now(timezone.utc), ) @@ -320,6 +340,7 @@ def test_has_boefje_task_started_running_stalled_after_grace_period(self): task_db = models.Task( scheduler_id=self.scheduler.scheduler_id, + organisation=self.organisation.id, priority=1, status=models.TaskStatus.DISPATCHED, type=models.BoefjeTask.type, @@ -350,6 +371,7 @@ def test_has_boefje_task_started_running_mismatch_before_grace_period(self): task_db = models.Task( scheduler_id=self.scheduler.scheduler_id, + organisation=self.organisation.id, priority=1, status=models.TaskStatus.COMPLETED, type=models.BoefjeTask.type, @@ -383,6 +405,7 @@ def test_has_boefje_task_started_running_mismatch_after_grace_period(self): task_db = models.Task( scheduler_id=self.scheduler.scheduler_id, + organisation=self.organisation.id, priority=1, status=models.TaskStatus.COMPLETED, type=models.BoefjeTask.type, @@ -411,6 +434,7 @@ def test_has_boefje_task_grace_period_passed_datastore_passed(self): task_db = models.Task( scheduler_id=self.scheduler.scheduler_id, + organisation=self.organisation.id, priority=1, status=models.TaskStatus.COMPLETED, type=models.BoefjeTask.type, @@ -442,6 +466,7 @@ def test_has_boefje_task_grace_period_passed_datastore_not_passed(self): task_db = models.Task( scheduler_id=self.scheduler.scheduler_id, + organisation=self.organisation.id, priority=1, status=models.TaskStatus.COMPLETED, type=models.BoefjeTask.type, @@ -471,6 +496,7 @@ def test_has_boefje_task_grace_period_passed_bytes_passed(self): task_db = models.Task( scheduler_id=self.scheduler.scheduler_id, + organisation=self.organisation.id, priority=1, status=models.TaskStatus.COMPLETED, type=models.BoefjeTask.type, @@ -506,6 +532,7 @@ 
def test_has_boefje_task_grace_period_passed_bytes_not_passed(self): task_db = models.Task( scheduler_id=self.scheduler.scheduler_id, + organisation=self.organisation.id, priority=1, status=models.TaskStatus.COMPLETED, type=models.BoefjeTask.type, @@ -530,7 +557,7 @@ def test_has_boefje_task_grace_period_passed_bytes_not_passed(self): # Assert self.assertFalse(has_passed) - def test_push_task(self): + def test_push_boefje_task(self): # Arrange scan_profile = ScanProfileFactory(level=0) ooi = OOIFactory(scan_profile=scan_profile) @@ -548,12 +575,12 @@ def test_push_task(self): self.mock_get_plugin.return_value = PluginFactory(scan_level=0, consumes=[ooi.object_type]) # Act - self.scheduler.push_boefje_task(boefje_task) + self.scheduler.push_boefje_task(boefje_task, self.organisation.id) # Assert self.assertEqual(1, self.scheduler.queue.qsize()) - def test_push_task_no_ooi(self): + def test_push_boefje_task_no_ooi(self): # Arrange boefje = BoefjeFactory() @@ -567,7 +594,7 @@ def test_push_task_no_ooi(self): self.mock_get_plugin.return_value = PluginFactory(scan_level=0) # Act - self.scheduler.push_boefje_task(boefje_task) + self.scheduler.push_boefje_task(boefje_task, self.organisation.id) # Assert self.assertEqual(1, self.scheduler.queue.qsize()) @@ -577,7 +604,7 @@ def test_push_task_no_ooi(self): @mock.patch("scheduler.schedulers.BoefjeScheduler.has_boefje_task_grace_period_passed") @mock.patch("scheduler.schedulers.BoefjeScheduler.is_item_on_queue_by_hash") @mock.patch("scheduler.context.AppContext.datastores.task_store.get_latest_task_by_hash") - def test_push_task_queue_full( + def test_push_boefje_task_queue_full( self, mock_get_latest_task_by_hash, mock_is_item_on_queue_by_hash, @@ -609,15 +636,15 @@ def test_push_task_queue_full( self.mock_get_plugin.return_value = PluginFactory(scan_level=0, consumes=[ooi.object_type]) # Act - self.scheduler.push_boefje_task(boefje_task) + self.scheduler.push_boefje_task(boefje_task, self.organisation.id) # Assert self.assertEqual(1, self.scheduler.queue.qsize()) with capture_logs() as cm: - self.scheduler.push_boefje_task(boefje_task) + self.scheduler.push_boefje_task(boefje_task, self.organisation.id) - self.assertIn("Could not add task to queue, queue was full", cm[-1].get("event")) + self.assertIn("Queue is full", cm[-1].get("event")) self.assertEqual(1, self.scheduler.queue.qsize()) @mock.patch("scheduler.schedulers.BoefjeScheduler.has_boefje_task_stalled") @@ -626,7 +653,7 @@ def test_push_task_queue_full( @mock.patch("scheduler.schedulers.BoefjeScheduler.has_boefje_task_grace_period_passed") @mock.patch("scheduler.schedulers.BoefjeScheduler.is_item_on_queue_by_hash") @mock.patch("scheduler.context.AppContext.datastores.task_store.get_tasks_by_hash") - def test_push_task_stalled( + def test_push_boefje_task_stalled( self, mock_get_tasks_by_hash, mock_is_item_on_queue_by_hash, @@ -645,6 +672,7 @@ def test_push_task_stalled( task = models.Task( scheduler_id=self.scheduler.scheduler_id, + organisation=self.organisation.id, priority=1, type=models.BoefjeTask.type, hash=boefje_task.hash, @@ -653,13 +681,11 @@ def test_push_task_stalled( modified_at=datetime.now(timezone.utc), ) - item = functions.create_item(scheduler_id=self.organisation.id, priority=1, task=task) - # Mocks self.mock_get_plugin.return_value = PluginFactory(scan_level=0, consumes=[ooi.object_type]) # Act - self.scheduler.push_item_to_queue(item) + self.scheduler.push_item_to_queue(task) # Assert: task should be on priority queue task_pq = 
models.BoefjeTask(**self.scheduler.queue.peek(0).data) @@ -668,16 +694,16 @@ def test_push_task_stalled( self.assertEqual(boefje_task.boefje.id, task_pq.boefje.id) # Assert: task should be in datastore, and queued - task_db = self.mock_ctx.datastores.task_store.get_task(item.id) - self.assertEqual(task_db.id, item.id) + task_db = self.mock_ctx.datastores.task_store.get_task(task.id) + self.assertEqual(task_db.id, task.id) self.assertEqual(task_db.status, models.TaskStatus.QUEUED) # Act self.scheduler.pop_item_from_queue() # Assert: task should be in datastore, and dispatched - task_db = self.mock_ctx.datastores.task_store.get_task(item.id) - self.assertEqual(task_db.id, item.id) + task_db = self.mock_ctx.datastores.task_store.get_task(task.id) + self.assertEqual(task_db.id, task.id) self.assertEqual(task_db.status, models.TaskStatus.DISPATCHED) # Mocks @@ -690,11 +716,11 @@ def test_push_task_stalled( mock_get_tasks_by_hash.return_value = None # Act - self.scheduler.push_boefje_task(boefje_task) + self.scheduler.push_boefje_task(boefje_task, self.organisation.id) # Assert: task should be in datastore, and failed - task_db = self.mock_ctx.datastores.task_store.get_task(item.id) - self.assertEqual(task_db.id, item.id) + task_db = self.mock_ctx.datastores.task_store.get_task(task.id) + self.assertEqual(task_db.id, task.id) self.assertEqual(task_db.status, models.TaskStatus.FAILED) # Assert: new task should be queued @@ -714,6 +740,7 @@ def test_post_push(self): task = models.Task( scheduler_id=self.scheduler.scheduler_id, + organisation=self.organisation.id, priority=1, type=models.BoefjeTask.type, hash=boefje_task.hash, @@ -722,12 +749,10 @@ def test_post_push(self): modified_at=datetime.now(timezone.utc), ) - item = functions.create_item(scheduler_id=self.organisation.id, priority=1, task=task) - self.mock_get_plugin.return_value = PluginFactory(scan_level=0, consumes=[ooi.object_type]) # Act - self.scheduler.push_item_to_queue(item) + self.scheduler.push_item_to_queue(task) # Task should be on priority queue task_pq = models.BoefjeTask(**self.scheduler.queue.peek(0).data) @@ -736,8 +761,8 @@ def test_post_push(self): self.assertEqual(boefje_task.boefje.id, task_pq.boefje.id) # Task should be in datastore, and queued - task_db = self.mock_ctx.datastores.task_store.get_task(item.id) - self.assertEqual(task_db.id, item.id) + task_db = self.mock_ctx.datastores.task_store.get_task(task.id) + self.assertEqual(task_db.id, task.id) self.assertEqual(task_db.status, models.TaskStatus.QUEUED) # Schedule should be in datastore @@ -763,6 +788,7 @@ def test_post_push_boefje_cron(self): task = models.Task( scheduler_id=self.scheduler.scheduler_id, + organisation=self.organisation.id, priority=1, type=models.BoefjeTask.type, hash=boefje_task.hash, @@ -771,12 +797,10 @@ def test_post_push_boefje_cron(self): modified_at=datetime.now(timezone.utc), ) - item = functions.create_item(scheduler_id=self.organisation.id, priority=1, task=task) - self.mock_get_plugin.return_value = PluginFactory(scan_level=0, consumes=[ooi.object_type], cron=cron) # Act - self.scheduler.push_item_to_queue(item) + self.scheduler.push_item_to_queue(task) # Task should be on priority queue task_pq = models.BoefjeTask(**self.scheduler.queue.peek(0).data) @@ -785,8 +809,8 @@ def test_post_push_boefje_cron(self): self.assertEqual(boefje_task.boefje.id, task_pq.boefje.id) # Task should be in datastore, and queued - task_db = self.mock_ctx.datastores.task_store.get_task(item.id) - self.assertEqual(task_db.id, item.id) + task_db = 
self.mock_ctx.datastores.task_store.get_task(task.id) + self.assertEqual(task_db.id, task.id) self.assertEqual(task_db.status, models.TaskStatus.QUEUED) # Schedule should be in datastore @@ -818,6 +842,7 @@ def test_post_push_boefje_interval(self): task = models.Task( scheduler_id=self.scheduler.scheduler_id, + organisation=self.organisation.id, priority=1, type=models.BoefjeTask.type, hash=boefje_task.hash, @@ -826,12 +851,10 @@ def test_post_push_boefje_interval(self): modified_at=datetime.now(timezone.utc), ) - item = functions.create_item(scheduler_id=self.organisation.id, priority=1, task=task) - self.mock_get_plugin.return_value = PluginFactory(scan_level=0, consumes=[ooi.object_type], interval=1500) # Act - self.scheduler.push_item_to_queue(item) + self.scheduler.push_item_to_queue(task) # Task should be on priority queue task_pq = models.BoefjeTask(**self.scheduler.queue.peek(0).data) @@ -840,8 +863,8 @@ def test_post_push_boefje_interval(self): self.assertEqual(boefje_task.boefje.id, task_pq.boefje.id) # Task should be in datastore, and queued - task_db = self.mock_ctx.datastores.task_store.get_task(item.id) - self.assertEqual(task_db.id, item.id) + task_db = self.mock_ctx.datastores.task_store.get_task(task.id) + self.assertEqual(task_db.id, task.id) self.assertEqual(task_db.status, models.TaskStatus.QUEUED) # Schedule should be in datastore @@ -870,6 +893,7 @@ def test_post_pop(self): task = models.Task( scheduler_id=self.scheduler.scheduler_id, + organisation=self.organisation.id, priority=1, type=models.BoefjeTask.type, hash=boefje_task.hash, @@ -878,13 +902,11 @@ def test_post_pop(self): modified_at=datetime.now(timezone.utc), ) - item = functions.create_item(scheduler_id=self.organisation.id, priority=1, task=task) - # Mocks self.mock_get_plugin.return_value = PluginFactory(scan_level=0, consumes=[ooi.object_type]) # Act - self.scheduler.push_item_to_queue(item) + self.scheduler.push_item_to_queue(task) # Assert: task should be on priority queue task_pq = models.BoefjeTask(**self.scheduler.queue.peek(0).data) @@ -893,109 +915,18 @@ def test_post_pop(self): self.assertEqual(boefje_task.boefje.id, task_pq.boefje.id) # Assert: task should be in datastore, and queued - task_db = self.mock_ctx.datastores.task_store.get_task(item.id) - self.assertEqual(task_db.id, item.id) + task_db = self.mock_ctx.datastores.task_store.get_task(task.id) + self.assertEqual(task_db.id, task.id) self.assertEqual(task_db.status, models.TaskStatus.QUEUED) # Act self.scheduler.pop_item_from_queue() # Assert: task should be in datastore, and queued - task_db = self.mock_ctx.datastores.task_store.get_task(item.id) - self.assertEqual(task_db.id, item.id) + task_db = self.mock_ctx.datastores.task_store.get_task(task.id) + self.assertEqual(task_db.id, task.id) self.assertEqual(task_db.status, models.TaskStatus.DISPATCHED) - def test_disable_scheduler(self): - # Arrange: start scheduler - self.scheduler.run() - - # Arrange: add tasks - scan_profile = ScanProfileFactory(level=0) - ooi = OOIFactory(scan_profile=scan_profile) - boefje_task = models.BoefjeTask( - boefje=BoefjeFactory(), input_ooi=ooi.primary_key, organization=self.organisation.id - ) - - # Mocks - self.mock_get_plugin.return_value = PluginFactory(scan_level=0, consumes=[ooi.object_type]) - - # Act - task = functions.create_task(scheduler_id=self.scheduler.scheduler_id, data=boefje_task) - - item = functions.create_item(scheduler_id=self.organisation.id, priority=1, task=task) - self.scheduler.push_item_to_queue(item) - - # Assert: task should be 
on priority queue - pq_item = self.scheduler.queue.peek(0) - self.assertEqual(1, self.scheduler.queue.qsize()) - self.assertEqual(pq_item.id, item.id) - - # Assert: task should be in datastore, and queued - task_db = self.mock_ctx.datastores.task_store.get_task(item.id) - self.assertEqual(task_db.id, item.id) - self.assertEqual(task_db.status, models.TaskStatus.QUEUED) - - # Assert: listeners should be running - self.assertGreater(len(self.scheduler.listeners), 0) - - # Assert: threads should be running - self.assertGreater(len(self.scheduler.threads), 0) - - # Act - self.scheduler.disable() - - # Listeners should be stopped - self.assertEqual(0, len(self.scheduler.listeners)) - - # Threads should be stopped - self.assertEqual(0, len(self.scheduler.threads)) - - # Queue should be empty - self.assertEqual(0, self.scheduler.queue.qsize()) - - # All tasks on queue should be set to CANCELLED - tasks, _ = self.mock_ctx.datastores.task_store.get_tasks(self.scheduler.scheduler_id) - for task in tasks: - self.assertEqual(task.status, models.TaskStatus.CANCELLED) - - # Scheduler should be disabled - self.assertFalse(self.scheduler.is_enabled()) - - self.scheduler.stop() - - def test_enable_scheduler(self): - self.scheduler.run() - - # Assert: listeners should be running - self.assertGreater(len(self.scheduler.listeners), 0) - - # Assert: threads should be running - self.assertGreater(len(self.scheduler.threads), 0) - - # Disable scheduler first - self.scheduler.disable() - - # Listeners should be stopped - self.assertEqual(0, len(self.scheduler.listeners)) - - # Threads should be stopped - self.assertEqual(0, len(self.scheduler.threads)) - - # Queue should be empty - self.assertEqual(0, self.scheduler.queue.qsize()) - - # Re-enable scheduler - self.scheduler.enable() - - # Threads should be started - self.assertGreater(len(self.scheduler.threads), 0) - - # Scheduler should be enabled - self.assertTrue(self.scheduler.is_enabled()) - - # Stop the scheduler - self.scheduler.stop() - def test_has_boefje_permission_to_run(self): # Arrange scan_profile = ScanProfileFactory(level=0) @@ -1082,21 +1013,20 @@ def setUp(self): def tearDown(self): mock.patch.stopall() - def test_push_tasks_for_scan_profile_mutations(self): + def test_process_mutations(self): """Scan level change""" # Arrange - scan_profile = ScanProfileFactory(level=0) - ooi = OOIFactory(scan_profile=scan_profile) + ooi = OOIFactory(scan_profile=ScanProfileFactory(level=0)) boefje = PluginFactory(scan_level=0, consumes=[ooi.object_type]) mutation = models.ScanProfileMutation( - operation="create", primary_key=ooi.primary_key, value=ooi + operation="create", primary_key=ooi.primary_key, value=ooi, client_id=self.organisation.id ).model_dump_json() # Mocks self.mock_get_boefjes_for_ooi.return_value = [boefje] # Act - self.scheduler.push_tasks_for_scan_profile_mutations(mutation) + self.scheduler.process_mutations(mutation) # Task should be on priority queue item = self.scheduler.queue.peek(0) @@ -1110,43 +1040,45 @@ def test_push_tasks_for_scan_profile_mutations(self): self.assertEqual(task_db.id, item.id) self.assertEqual(task_db.status, models.TaskStatus.QUEUED) - def test_push_tasks_for_scan_profile_mutations_value_empty(self): + def test_process_mutations_value_empty(self): """When the value of a mutation is empty it should not push any tasks""" # Arrange - mutation = models.ScanProfileMutation(operation="create", primary_key="123", value=None).model_dump_json() + mutation = models.ScanProfileMutation( + operation="create", 
primary_key="123", value=None, client_id=self.organisation.id + ).model_dump_json() # Act - self.scheduler.push_tasks_for_scan_profile_mutations(mutation) + self.scheduler.process_mutations(mutation) # Task should not be on priority queue self.assertEqual(0, self.scheduler.queue.qsize()) - def test_push_tasks_for_scan_profile_mutations_no_boefjes_found(self): + def test_process_mutations_no_boefjes_found(self): """When no plugins are found for boefjes, it should return no boefje tasks""" # Arrange scan_profile = ScanProfileFactory(level=0) ooi = OOIFactory(scan_profile=scan_profile) mutation = models.ScanProfileMutation( - operation="create", primary_key=ooi.primary_key, value=ooi + operation="create", primary_key=ooi.primary_key, value=ooi, client_id=self.organisation.id ).model_dump_json() # Mocks self.mock_get_boefjes_for_ooi.return_value = [] # Act - self.scheduler.push_tasks_for_scan_profile_mutations(mutation) + self.scheduler.process_mutations(mutation) # Task should not be on priority queue self.assertEqual(0, self.scheduler.queue.qsize()) - def test_push_tasks_for_scan_profile_mutations_not_allowed_to_run(self): + def test_process_mutations_not_allowed_to_run(self): """When a boefje is not allowed to run, it should not be added to the queue""" # Arrange scan_profile = ScanProfileFactory(level=0) ooi = OOIFactory(scan_profile=scan_profile) boefje = PluginFactory(scan_level=0, consumes=[ooi.object_type]) mutation = models.ScanProfileMutation( - operation="create", primary_key=ooi.primary_key, value=ooi + operation="create", primary_key=ooi.primary_key, value=ooi, client_id=self.organisation.id ).model_dump_json() # Mocks @@ -1154,19 +1086,19 @@ def test_push_tasks_for_scan_profile_mutations_not_allowed_to_run(self): self.mock_has_boefje_permission_to_run.return_value = False # Act - self.scheduler.push_tasks_for_scan_profile_mutations(mutation) + self.scheduler.process_mutations(mutation) # Task should not be on priority queue self.assertEqual(0, self.scheduler.queue.qsize()) - def test_push_tasks_for_scan_profile_mutations_still_running(self): + def test_process_mutations_still_running(self): """When a boefje is still running, it should not be added to the queue""" # Arrange scan_profile = ScanProfileFactory(level=0) ooi = OOIFactory(scan_profile=scan_profile) boefje = PluginFactory(scan_level=0, consumes=[ooi.object_type]) mutation = models.ScanProfileMutation( - operation="create", primary_key=ooi.primary_key, value=ooi + operation="create", primary_key=ooi.primary_key, value=ooi, client_id=self.organisation.id ).model_dump_json() # Mocks @@ -1174,30 +1106,31 @@ def test_push_tasks_for_scan_profile_mutations_still_running(self): self.mock_has_boefje_task_started_running.return_value = True # Act - self.scheduler.push_tasks_for_scan_profile_mutations(mutation) + self.scheduler.process_mutations(mutation) # Task should not be on priority queue self.assertEqual(0, self.scheduler.queue.qsize()) - def test_push_tasks_for_scan_profile_mutations_item_on_queue(self): + def test_process_mutations_item_on_queue(self): """When a boefje is already on the queue, it should not be added to the queue""" # Arrange scan_profile = ScanProfileFactory(level=0) ooi = OOIFactory(scan_profile=scan_profile) boefje = PluginFactory(scan_level=0, consumes=[ooi.object_type]) + mutation1 = models.ScanProfileMutation( - operation="create", primary_key=ooi.primary_key, value=ooi + operation="create", primary_key=ooi.primary_key, value=ooi, client_id=self.organisation.id ).model_dump_json() mutation2 = 
models.ScanProfileMutation(
-            operation="create", primary_key=ooi.primary_key, value=ooi
+            operation="create", primary_key=ooi.primary_key, value=ooi, client_id=self.organisation.id
         ).model_dump_json()

         # Mocks
         self.mock_get_boefjes_for_ooi.return_value = [boefje]

         # Act
-        self.scheduler.push_tasks_for_scan_profile_mutations(mutation1)
-        self.scheduler.push_tasks_for_scan_profile_mutations(mutation2)
+        self.scheduler.process_mutations(mutation1)
+        self.scheduler.process_mutations(mutation2)

         # Task should be on priority queue (only one)
         task_pq = self.scheduler.queue.peek(0)
@@ -1210,7 +1143,7 @@ def test_push_tasks_for_scan_profile_mutations_item_on_queue(self):
         task_db = self.mock_ctx.datastores.task_store.get_task(task_pq.id)
         self.assertEqual(task_db.status, models.TaskStatus.QUEUED)

-    def test_push_tasks_for_scan_profile_mutations_delete(self):
+    def test_process_mutations_delete(self):
         """When an OOI is deleted it should not create tasks"""
         # Arrange
         scan_profile = ScanProfileFactory(level=0)
@@ -1218,19 +1151,22 @@ def test_push_tasks_for_scan_profile_mutations_delete(self):
         boefje = PluginFactory(scan_level=0, consumes=[ooi.object_type])

         mutation1 = models.ScanProfileMutation(
-            operation=models.MutationOperationType.DELETE, primary_key=ooi.primary_key, value=ooi
+            operation=models.MutationOperationType.DELETE,
+            primary_key=ooi.primary_key,
+            value=ooi,
+            client_id=self.organisation.id,
         ).model_dump_json()

         # Mocks
         self.mock_get_boefjes_for_ooi.return_value = [boefje]

         # Act
-        self.scheduler.push_tasks_for_scan_profile_mutations(mutation1)
+        self.scheduler.process_mutations(mutation1)

         # Assert
         self.assertEqual(0, self.scheduler.queue.qsize())

-    def test_push_tasks_for_scan_profile_mutations_delete_on_queue(self):
+    def test_process_mutations_delete_on_queue(self):
         """When an OOI is deleted, tasks associated with that OOI should be
         removed from the queue
         """
         # Arrange
@@ -1240,14 +1176,17 @@ def test_push_tasks_for_scan_profile_mutations_delete_on_queue(self):
         boefje = PluginFactory(scan_level=0, consumes=[ooi.object_type])

         mutation1 = models.ScanProfileMutation(
-            operation=models.MutationOperationType.CREATE, primary_key=ooi.primary_key, value=ooi
+            operation=models.MutationOperationType.CREATE,
+            primary_key=ooi.primary_key,
+            value=ooi,
+            client_id=self.organisation.id,
         ).model_dump_json()

         # Mocks
         self.mock_get_boefjes_for_ooi.return_value = [boefje]

         # Act
-        self.scheduler.push_tasks_for_scan_profile_mutations(mutation1)
+        self.scheduler.process_mutations(mutation1)

         # Assert: task should be on priority queue
         item = self.scheduler.queue.peek(0)
@@ -1258,11 +1197,14 @@ def test_push_tasks_for_scan_profile_mutations_delete_on_queue(self):

         # Arrange
         mutation2 = models.ScanProfileMutation(
-            operation=models.MutationOperationType.DELETE, primary_key=ooi.primary_key, value=ooi
+            operation=models.MutationOperationType.DELETE,
+            primary_key=ooi.primary_key,
+            value=ooi,
+            client_id=self.organisation.id,
         ).model_dump_json()

         # Act
-        self.scheduler.push_tasks_for_scan_profile_mutations(mutation2)
+        self.scheduler.process_mutations(mutation2)

         # Assert
         self.assertIsNone(self.scheduler.queue.peek(0))
@@ -1273,7 +1215,7 @@ def test_push_tasks_for_scan_profile_mutations_delete_on_queue(self):
         task_db = self.mock_ctx.datastores.task_store.get_task(item.id)
         self.assertEqual(task_db.status, models.TaskStatus.CANCELLED)
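For reference, the event shape these mutation tests now construct: `ScanProfileMutation` gains a `client_id`, and the handler is renamed from `push_tasks_for_scan_profile_mutations` to `process_mutations`. A minimal sketch of the new call pattern, assuming `scheduler` is a running `BoefjeScheduler` and `ooi` comes from a factory as in the tests; the organisation id is a stand-in:

```python
from scheduler import models

organisation_id = "org-1"  # stand-in; the tests use OrganisationFactory().id

# client_id is new in this change: it tells the single, shared scheduler
# which organisation the scan-profile event belongs to.
mutation = models.ScanProfileMutation(
    operation=models.MutationOperationType.CREATE,
    primary_key=ooi.primary_key,  # ooi as produced by OOIFactory above
    value=ooi,
    client_id=organisation_id,
).model_dump_json()

scheduler.process_mutations(mutation)  # renamed entry point
```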
-    def test_push_tasks_for_scan_profile_mutations_op_create_run_on_create(self):
+    def test_process_mutations_op_create_run_on_create(self):
         """When a boefje's run_on contains the setting create, and we receive
         a create mutation, it should:
@@ -1285,14 +1227,17 @@ def test_push_tasks_for_scan_profile_mutations_op_create_run_on_create(self):
         ooi = OOIFactory(scan_profile=scan_profile)
         boefje = PluginFactory(scan_level=0, consumes=[ooi.object_type], run_on=["create"])
         mutation = models.ScanProfileMutation(
-            operation=models.MutationOperationType.CREATE, primary_key=ooi.primary_key, value=ooi
+            operation=models.MutationOperationType.CREATE,
+            primary_key=ooi.primary_key,
+            value=ooi,
+            client_id=self.organisation.id,
         ).model_dump_json()

         # Mocks
         self.mock_get_boefjes_for_ooi.return_value = [boefje]

         # Act
-        self.scheduler.push_tasks_for_scan_profile_mutations(mutation)
+        self.scheduler.process_mutations(mutation)

         # Assert: task should be on priority queue
         item = self.scheduler.queue.peek(0)
@@ -1310,7 +1255,7 @@ def test_push_tasks_for_scan_profile_mutations_op_create_run_on_create(self):
         schedule_db = self.mock_ctx.datastores.schedule_store.get_schedule_by_hash(task_db.hash)
         self.assertIsNone(schedule_db)

-    def test_push_tasks_for_scan_profile_mutations_op_create_run_on_create_update(self):
+    def test_process_mutations_op_create_run_on_create_update(self):
         """When a boefje's run_on contains the setting create,update, and we
         receive a create mutation, it should:
@@ -1322,14 +1267,17 @@ def test_push_tasks_for_scan_profile_mutations_op_create_run_on_create_update(se
         ooi = OOIFactory(scan_profile=scan_profile)
         boefje = PluginFactory(scan_level=0, consumes=[ooi.object_type], run_on=["create", "update"])
         mutation = models.ScanProfileMutation(
-            operation=models.MutationOperationType.CREATE, primary_key=ooi.primary_key, value=ooi
+            operation=models.MutationOperationType.CREATE,
+            primary_key=ooi.primary_key,
+            value=ooi,
+            client_id=self.organisation.id,
         ).model_dump_json()

         # Mocks
         self.mock_get_boefjes_for_ooi.return_value = [boefje]

         # Act
-        self.scheduler.push_tasks_for_scan_profile_mutations(mutation)
+        self.scheduler.process_mutations(mutation)

         # Assert: task should be on priority queue
         item = self.scheduler.queue.peek(0)
@@ -1347,7 +1295,7 @@ def test_push_tasks_for_scan_profile_mutations_op_create_run_on_create_update(se
         schedule_db = self.mock_ctx.datastores.schedule_store.get_schedule_by_hash(task_db.hash)
         self.assertIsNone(schedule_db)

-    def test_push_tasks_for_scan_profile_mutations_op_create_run_on_update(self):
+    def test_process_mutations_op_create_run_on_update(self):
         """When a boefje's run_on contains the setting update, and we receive
         a create mutation, it should:
@@ -1359,19 +1307,22 @@ def test_push_tasks_for_scan_profile_mutations_op_create_run_on_update(self):
         ooi = OOIFactory(scan_profile=scan_profile)
         boefje = PluginFactory(scan_level=0, consumes=[ooi.object_type], run_on=["update"])
         mutation = models.ScanProfileMutation(
-            operation=models.MutationOperationType.CREATE, primary_key=ooi.primary_key, value=ooi
+            operation=models.MutationOperationType.CREATE,
+            primary_key=ooi.primary_key,
+            value=ooi,
+            client_id=self.organisation.id,
         ).model_dump_json()

         # Mocks
         self.mock_get_boefjes_for_ooi.return_value = [boefje]

         # Act
-        self.scheduler.push_tasks_for_scan_profile_mutations(mutation)
+        self.scheduler.process_mutations(mutation)

         # Assert: task should NOT be on priority queue
         self.assertEqual(0, self.scheduler.queue.qsize())
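Read together with the update-mutation counterparts further down, these run_on tests pin down a small decision table. A condensed, self-contained sketch of that rule (the real implementation lives in the scheduler and is not shown in this diff):

```python
def fires_immediately(run_on: list[str], operation: str) -> bool:
    # An empty run_on means the boefje is not mutation-driven: it falls back
    # to deadline-based rescheduling via a Schedule instead of firing here.
    return bool(run_on) and operation in run_on

assert fires_immediately(["create"], "create")
assert not fires_immediately(["update"], "create")
assert fires_immediately(["create", "update"], "update")
assert not fires_immediately([], "create")  # handled by process_rescheduling
```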
-    def test_push_tasks_for_scan_profile_mutations_op_create_run_on_none(self):
+    def test_process_mutations_op_create_run_on_none(self):
         """When a boefje's run_on is empty, and we receive a create mutation,
         it should:
@@ -1383,7 +1334,10 @@ def test_push_tasks_for_scan_profile_mutations_op_create_run_on_none(self):
         ooi = OOIFactory(scan_profile=scan_profile)
         boefje = PluginFactory(scan_level=0, consumes=[ooi.object_type], run_on=[])
         mutation = models.ScanProfileMutation(
-            operation=models.MutationOperationType.CREATE, primary_key=ooi.primary_key, value=ooi
+            operation=models.MutationOperationType.CREATE,
+            primary_key=ooi.primary_key,
+            value=ooi,
+            client_id=self.organisation.id,
         ).model_dump_json()

         # Mocks
@@ -1391,7 +1345,7 @@ def test_push_tasks_for_scan_profile_mutations_op_create_run_on_none(self):
         self.mock_set_cron.return_value = "0 0 * * *"

         # Act
-        self.scheduler.push_tasks_for_scan_profile_mutations(mutation)
+        self.scheduler.process_mutations(mutation)

         # Assert: task should be on priority queue
         item = self.scheduler.queue.peek(0)
@@ -1409,7 +1363,7 @@ def test_push_tasks_for_scan_profile_mutations_op_create_run_on_none(self):
         schedule_db = self.mock_ctx.datastores.schedule_store.get_schedule(task_db.schedule_id)
         self.assertIsNotNone(schedule_db)

-    def test_push_tasks_for_scan_profile_mutations_op_update_run_on_create(self):
+    def test_process_mutations_op_update_run_on_create(self):
         """When a boefje's run_on contains the setting create, and we receive
         an update mutation, it should:
@@ -1421,19 +1375,22 @@ def test_push_tasks_for_scan_profile_mutations_op_update_run_on_create(self):
         ooi = OOIFactory(scan_profile=scan_profile)
         boefje = PluginFactory(scan_level=0, consumes=[ooi.object_type], run_on=["create"])
         mutation = models.ScanProfileMutation(
-            operation=models.MutationOperationType.UPDATE, primary_key=ooi.primary_key, value=ooi
+            operation=models.MutationOperationType.UPDATE,
+            primary_key=ooi.primary_key,
+            value=ooi,
+            client_id=self.organisation.id,
         ).model_dump_json()

         # Mocks
         self.mock_get_boefjes_for_ooi.return_value = [boefje]

         # Act
-        self.scheduler.push_tasks_for_scan_profile_mutations(mutation)
+        self.scheduler.process_mutations(mutation)

         # Assert: task should NOT be on priority queue
         self.assertEqual(0, self.scheduler.queue.qsize())
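The run_on=[] case above is the one create path that leaves a `Schedule` behind instead of only a queued task. A sketch of the record those assertions inspect, reusing the field names from the schedule hunks later in this diff; `scheduler`, `organisation_id`, and `boefje_task` stand in for the fixtures built in the test, and the cron string mirrors the mocked `set_cron`:

```python
from scheduler import models

schedule = models.Schedule(
    scheduler_id=scheduler.scheduler_id,
    organisation=organisation_id,       # new mandatory field in this change
    hash=boefje_task.hash,              # dedup key shared with the task
    data=boefje_task.model_dump(),      # enough to rebuild the task later
    schedule="0 0 * * *",               # cron deadline from the mocked set_cron
)
```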
-    def test_push_tasks_scan_profile_mutations_op_update_run_on_create_update(self):
+    def test_process_mutations_op_update_run_on_create_update(self):
         """When a boefje's run_on contains the setting create,update, and we
         receive an update mutation, it should:
@@ -1445,14 +1402,17 @@ def test_push_tasks_scan_profile_mutations_op_update_run_on_create_update(self):
         ooi = OOIFactory(scan_profile=scan_profile)
         boefje = PluginFactory(scan_level=0, consumes=[ooi.object_type], run_on=["create", "update"])
         mutation = models.ScanProfileMutation(
-            operation=models.MutationOperationType.UPDATE, primary_key=ooi.primary_key, value=ooi
+            operation=models.MutationOperationType.UPDATE,
+            primary_key=ooi.primary_key,
+            value=ooi,
+            client_id=self.organisation.id,
         ).model_dump_json()

         # Mocks
         self.mock_get_boefjes_for_ooi.return_value = [boefje]

         # Act
-        self.scheduler.push_tasks_for_scan_profile_mutations(mutation)
+        self.scheduler.process_mutations(mutation)

         # Assert: task should be on priority queue
         item = self.scheduler.queue.peek(0)
@@ -1470,7 +1430,7 @@ def test_push_tasks_scan_profile_mutations_op_update_run_on_create_update(self):
         schedule_db = self.mock_ctx.datastores.schedule_store.get_schedule_by_hash(task_db.hash)
         self.assertIsNone(schedule_db)

-    def test_push_tasks_scan_profile_mutations_op_update_run_on_update(self):
+    def test_process_mutations_op_update_run_on_update(self):
         """When a boefje's run_on contains the setting update, and we receive
         an update mutation, it should:
@@ -1482,14 +1442,17 @@ def test_push_tasks_scan_profile_mutations_op_update_run_on_update(self):
         ooi = OOIFactory(scan_profile=scan_profile)
         boefje = PluginFactory(scan_level=0, consumes=[ooi.object_type], run_on=["update"])
         mutation = models.ScanProfileMutation(
-            operation=models.MutationOperationType.UPDATE, primary_key=ooi.primary_key, value=ooi
+            operation=models.MutationOperationType.UPDATE,
+            primary_key=ooi.primary_key,
+            value=ooi,
+            client_id=self.organisation.id,
         ).model_dump_json()

         # Mocks
         self.mock_get_boefjes_for_ooi.return_value = [boefje]

         # Act
-        self.scheduler.push_tasks_for_scan_profile_mutations(mutation)
+        self.scheduler.process_mutations(mutation)

         # Assert: task should be on priority queue
         item = self.scheduler.queue.peek(0)
@@ -1507,7 +1470,7 @@ def test_push_tasks_scan_profile_mutations_op_update_run_on_update(self):
         schedule_db = self.mock_ctx.datastores.schedule_store.get_schedule_by_hash(task_db.hash)
         self.assertIsNone(schedule_db)

-    def test_push_tasks_scan_profile_mutations_op_update_run_on_none(self):
+    def test_process_mutations_op_update_run_on_none(self):
         """When a boefje's run_on is empty, and we receive an update mutation,
         it should:
@@ -1519,7 +1482,10 @@ def test_push_tasks_scan_profile_mutations_op_update_run_on_none(self):
         ooi = OOIFactory(scan_profile=scan_profile)
         boefje = PluginFactory(scan_level=0, consumes=[ooi.object_type], run_on=[])
         mutation = models.ScanProfileMutation(
-            operation=models.MutationOperationType.UPDATE, primary_key=ooi.primary_key, value=ooi
+            operation=models.MutationOperationType.UPDATE,
+            primary_key=ooi.primary_key,
+            value=ooi,
+            client_id=self.organisation.id,
         ).model_dump_json()

         # Mocks
@@ -1527,7 +1493,7 @@ def test_push_tasks_scan_profile_mutations_op_update_run_on_none(self):
         self.mock_set_cron.return_value = "0 0 * * *"

         # Act
-        self.scheduler.push_tasks_for_scan_profile_mutations(mutation)
+        self.scheduler.process_mutations(mutation)

         # Assert: task should be on priority queue
         item = self.scheduler.queue.peek(0)
@@ -1570,21 +1536,26 @@ def setUp(self):
             "scheduler.context.AppContext.services.octopoes.get_objects_by_object_types"
         ).start()

+        self.mock_get_organisations = mock.patch(
+            "scheduler.context.AppContext.services.katalogus.get_organisations"
+        ).start()
+
     def tearDown(self):
         mock.patch.stopall()

-    def test_push_tasks_for_new_boefjes(self):
+    def test_process_new_boefjes(self):
         # Arrange
         scan_profile = ScanProfileFactory(level=0)
         ooi = OOIFactory(scan_profile=scan_profile)
         boefje = PluginFactory(scan_level=0, consumes=[ooi.object_type])

         # Mocks
+        self.mock_get_organisations.return_value = [self.organisation]
         self.mock_get_objects_by_object_types.return_value = [ooi]
         self.mock_get_new_boefjes_by_org_id.return_value = [boefje]

         # Act
-        self.scheduler.push_tasks_for_new_boefjes()
+        self.scheduler.process_new_boefjes()

         # Task should be on priority queue
         task_pq = self.scheduler.queue.peek(0)
@@ -1598,7 +1569,7 @@ def test_push_tasks_for_new_boefjes(self):
         self.assertEqual(task_db.id, task_pq.id)
         self.assertEqual(task_db.status, models.TaskStatus.QUEUED)

-    def test_push_tasks_for_new_boefjes_request_exception(self):
+    def test_process_new_boefjes_request_exception(self):
         # Arrange
         scan_profile = ScanProfileFactory(level=0)
         ooi = OOIFactory(scan_profile=scan_profile)
@@ -1612,13 +1583,13 @@ def test_push_tasks_for_new_boefjes_request_exception(self):
         self.mock_get_new_boefjes_by_org_id.return_value = [boefje]

         # Act
-        self.scheduler.push_tasks_for_new_boefjes()
-
self.scheduler.push_tasks_for_new_boefjes() + self.scheduler.process_new_boefjes() + self.scheduler.process_new_boefjes() # Task should not be on priority queue self.assertEqual(0, self.scheduler.queue.qsize()) - def test_push_tasks_for_new_boefjes_no_new_boefjes(self): + def test_process_new_boefjes_no_new_boefjes(self): # Arrange scan_profile = ScanProfileFactory(level=0) ooi = OOIFactory(scan_profile=scan_profile) @@ -1628,12 +1599,12 @@ def test_push_tasks_for_new_boefjes_no_new_boefjes(self): self.mock_get_new_boefjes_by_org_id.return_value = [] # Act - self.scheduler.push_tasks_for_new_boefjes() + self.scheduler.process_new_boefjes() # Task should not be on priority queue self.assertEqual(0, self.scheduler.queue.qsize()) - def test_push_tasks_for_new_boefjes_empty_consumes(self): + def test_process_new_boefjes_empty_consumes(self): # Arrange scan_profile = ScanProfileFactory(level=0) ooi = OOIFactory(scan_profile=scan_profile) @@ -1644,12 +1615,12 @@ def test_push_tasks_for_new_boefjes_empty_consumes(self): self.mock_get_new_boefjes_by_org_id.return_value = [boefje] # Act - self.scheduler.push_tasks_for_new_boefjes() + self.scheduler.process_new_boefjes() # Task should not be on priority queue self.assertEqual(0, self.scheduler.queue.qsize()) - def test_push_tasks_for_new_boefjes_empty_consumes_no_ooi(self): + def test_process_new_boefjes_empty_consumes_no_ooi(self): # Arrange boefje = PluginFactory(scan_level=0, consumes=[]) @@ -1658,12 +1629,12 @@ def test_push_tasks_for_new_boefjes_empty_consumes_no_ooi(self): self.mock_get_new_boefjes_by_org_id.return_value = [boefje] # Act - self.scheduler.push_tasks_for_new_boefjes() + self.scheduler.process_new_boefjes() # Task should not be on priority queue self.assertEqual(0, self.scheduler.queue.qsize()) - def test_push_tasks_for_new_boefjes_no_oois_found(self): + def test_process_new_boefjes_no_oois_found(self): # Arrange scan_profile = ScanProfileFactory(level=0) ooi = OOIFactory(scan_profile=scan_profile) @@ -1674,12 +1645,12 @@ def test_push_tasks_for_new_boefjes_no_oois_found(self): self.mock_get_new_boefjes_by_org_id.return_value = [boefje] # Act - self.scheduler.push_tasks_for_new_boefjes() + self.scheduler.process_new_boefjes() # Task should not be on priority queue self.assertEqual(0, self.scheduler.queue.qsize()) - def test_push_tasks_for_new_boefjes_get_objects_request_exception(self): + def test_process_new_boefjes_get_objects_request_exception(self): # Arrange scan_profile = ScanProfileFactory(level=0) ooi = OOIFactory(scan_profile=scan_profile) @@ -1693,13 +1664,13 @@ def test_push_tasks_for_new_boefjes_get_objects_request_exception(self): self.mock_get_new_boefjes_by_org_id.return_value = [boefje] # Act - self.scheduler.push_tasks_for_new_boefjes() - self.scheduler.push_tasks_for_new_boefjes() + self.scheduler.process_new_boefjes() + self.scheduler.process_new_boefjes() # Task should not be on priority queue self.assertEqual(0, self.scheduler.queue.qsize()) - def test_push_tasks_for_new_boefjes_not_allowed_to_run(self): + def test_process_new_boefjes_not_allowed_to_run(self): # Arrange scan_profile = ScanProfileFactory(level=0) ooi = OOIFactory(scan_profile=scan_profile) @@ -1711,12 +1682,12 @@ def test_push_tasks_for_new_boefjes_not_allowed_to_run(self): self.mock_has_boefje_permission_to_run.return_value = False # Act - self.scheduler.push_tasks_for_new_boefjes() + self.scheduler.process_new_boefjes() # Task should not be on priority queue self.assertEqual(0, self.scheduler.queue.qsize()) - def 
test_push_tasks_for_new_boefjes_still_running(self):
+    def test_process_new_boefjes_still_running(self):
         # Arrange
         scan_profile = ScanProfileFactory(level=0)
         ooi = OOIFactory(scan_profile=scan_profile)
@@ -1728,23 +1699,24 @@ def test_push_tasks_for_new_boefjes_still_running(self):
         self.mock_has_boefje_task_started_running.return_value = True

         # Act
-        self.scheduler.push_tasks_for_new_boefjes()
+        self.scheduler.process_new_boefjes()

         # Task should not be on priority queue
         self.assertEqual(0, self.scheduler.queue.qsize())

-    def test_push_tasks_for_new_boefjes_item_on_queue(self):
+    def test_process_new_boefjes_item_on_queue(self):
         # Arrange
         scan_profile = ScanProfileFactory(level=0)
         ooi = OOIFactory(scan_profile=scan_profile)
         boefje = PluginFactory(scan_level=0, consumes=[ooi.object_type])

         # Mocks
+        self.mock_get_organisations.return_value = [self.organisation]
         self.mock_get_objects_by_object_types.return_value = [ooi]
         self.mock_get_new_boefjes_by_org_id.return_value = [boefje]

         # Act
-        self.scheduler.push_tasks_for_new_boefjes()
+        self.scheduler.process_new_boefjes()

         # Task should be on priority queue
         task_pq = self.scheduler.queue.peek(0)
@@ -1759,7 +1731,7 @@ def test_push_tasks_for_new_boefjes_item_on_queue(self):
         self.assertEqual(task_db.status, models.TaskStatus.QUEUED)

         # Act
-        self.scheduler.push_tasks_for_new_boefjes()
+        self.scheduler.process_new_boefjes()

         # Should only be one task on queue
         task_pq = models.BoefjeTask(**self.scheduler.queue.peek(0).data)
@@ -1791,10 +1763,10 @@ def setUp(self):
     def tearDown(self):
         mock.patch.stopall()

-    def test_push_tasks_for_rescheduling_scheduler_id(self):
+    def test_process_rescheduling_scheduler_id(self):
         pass

-    def test_push_tasks_for_rescheduling(self):
+    def test_process_rescheduling(self):
         """When the deadline of a schedule has passed, the resulting task
         should be added to the queue"""
         # Arrange
         scan_profile = ScanProfileFactory(level=0)
@@ -1808,7 +1780,10 @@ def test_push_tasks_for_rescheduling(self):
         )

         schedule = models.Schedule(
-            scheduler_id=self.scheduler.scheduler_id, hash=boefje_task.hash, data=boefje_task.model_dump()
+            scheduler_id=self.scheduler.scheduler_id,
+            organisation=self.organisation.id,
+            hash=boefje_task.hash,
+            data=boefje_task.model_dump(),
         )
         schedule_db = self.mock_ctx.datastores.schedule_store.create_schedule(schedule)
@@ -1819,7 +1794,7 @@ def test_push_tasks_for_rescheduling(self):
         self.mock_get_plugin.return_value = plugin

         # Act
-        self.scheduler.push_tasks_for_rescheduling()
+        self.scheduler.process_rescheduling()

         # Assert: new item should be on queue
         self.assertEqual(1, self.scheduler.queue.qsize())
@@ -1833,7 +1808,7 @@ def test_push_tasks_for_rescheduling(self):
         self.assertIsNotNone(task_db)
         self.assertEqual(peek.id, task_db.id)
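For orientation, a rough sketch of the loop `test_process_rescheduling` drives; all helper names here are hypothetical, since the diff only shows the test side (the test mocks `get_schedules` to return one past-deadline schedule and asserts a task lands on the queue):

```python
from datetime import datetime, timezone

def process_rescheduling(schedule_store, queue):
    # Hypothetical condensation of BoefjeScheduler.process_rescheduling.
    schedules, _count = schedule_store.get_schedules()
    now = datetime.now(timezone.utc)
    for schedule in schedules:
        if schedule.deadline_at and schedule.deadline_at <= now:
            queue.push(schedule.data)  # schedule.data holds the serialized task
```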
""" @@ -1849,7 +1824,10 @@ def test_push_tasks_for_rescheduling_no_ooi(self): ) schedule = models.Schedule( - scheduler_id=self.scheduler.scheduler_id, hash=boefje_task.hash, data=boefje_task.model_dump() + scheduler_id=self.scheduler.scheduler_id, + organisation=self.organisation.id, + hash=boefje_task.hash, + data=boefje_task.model_dump(), ) schedule_db = self.mock_ctx.datastores.schedule_store.create_schedule(schedule) @@ -1860,7 +1838,7 @@ def test_push_tasks_for_rescheduling_no_ooi(self): self.mock_get_plugin.return_value = plugin # Act - self.scheduler.push_tasks_for_rescheduling() + self.scheduler.process_rescheduling() # Assert: new item should be on queue self.assertEqual(1, self.scheduler.queue.qsize()) @@ -1874,7 +1852,7 @@ def test_push_tasks_for_rescheduling_no_ooi(self): self.assertIsNotNone(task_db) self.assertEqual(peek.id, task_db.id) - def test_push_tasks_for_rescheduling_ooi_not_found(self): + def test_process_rescheduling_ooi_not_found(self): """When ooi isn't found anymore for the schedule, we disable the schedule""" # Arrange scan_profile = ScanProfileFactory(level=0) @@ -1888,7 +1866,10 @@ def test_push_tasks_for_rescheduling_ooi_not_found(self): ) schedule = models.Schedule( - scheduler_id=self.scheduler.scheduler_id, hash=boefje_task.hash, data=boefje_task.model_dump() + scheduler_id=self.scheduler.scheduler_id, + organisation=self.organisation.id, + hash=boefje_task.hash, + data=boefje_task.model_dump(), ) schedule_db = self.mock_ctx.datastores.schedule_store.create_schedule(schedule) @@ -1899,7 +1880,7 @@ def test_push_tasks_for_rescheduling_ooi_not_found(self): self.mock_get_plugin.return_value = plugin # Act - self.scheduler.push_tasks_for_rescheduling() + self.scheduler.process_rescheduling() # Assert: item should not be on queue self.assertEqual(0, self.scheduler.queue.qsize()) @@ -1908,7 +1889,7 @@ def test_push_tasks_for_rescheduling_ooi_not_found(self): schedule_db_disabled = self.mock_ctx.datastores.schedule_store.get_schedule(schedule.id) self.assertFalse(schedule_db_disabled.enabled) - def test_push_tasks_for_rescheduling_boefje_not_found(self): + def test_process_rescheduling_boefje_not_found(self): """When boefje isn't found anymore for the schedule, we disable the schedule""" # Arrange scan_profile = ScanProfileFactory(level=0) @@ -1922,7 +1903,10 @@ def test_push_tasks_for_rescheduling_boefje_not_found(self): ) schedule = models.Schedule( - scheduler_id=self.scheduler.scheduler_id, hash=boefje_task.hash, data=boefje_task.model_dump() + scheduler_id=self.scheduler.scheduler_id, + organisation=self.organisation.id, + hash=boefje_task.hash, + data=boefje_task.model_dump(), ) schedule_db = self.mock_ctx.datastores.schedule_store.create_schedule(schedule) @@ -1933,7 +1917,7 @@ def test_push_tasks_for_rescheduling_boefje_not_found(self): self.mock_get_plugin.return_value = None # Act - self.scheduler.push_tasks_for_rescheduling() + self.scheduler.process_rescheduling() # Assert: item should not be on queue self.assertEqual(0, self.scheduler.queue.qsize()) @@ -1942,7 +1926,7 @@ def test_push_tasks_for_rescheduling_boefje_not_found(self): schedule_db_disabled = self.mock_ctx.datastores.schedule_store.get_schedule(schedule.id) self.assertFalse(schedule_db_disabled.enabled) - def test_push_tasks_for_rescheduling_boefje_disabled(self): + def test_process_rescheduling_boefje_disabled(self): """When boefje disabled for the schedule, we disable the schedule""" # Arrange scan_profile = ScanProfileFactory(level=0) @@ -1956,7 +1940,10 @@ def 
-    def test_push_tasks_for_rescheduling_boefje_disabled(self):
+    def test_process_rescheduling_boefje_disabled(self):
         """When the boefje is disabled for the schedule, we disable the schedule"""
         # Arrange
         scan_profile = ScanProfileFactory(level=0)
@@ -1956,7 +1940,10 @@ def test_push_tasks_for_rescheduling_boefje_disabled(self):
         )

         schedule = models.Schedule(
-            scheduler_id=self.scheduler.scheduler_id, hash=boefje_task.hash, data=boefje_task.model_dump()
+            scheduler_id=self.scheduler.scheduler_id,
+            organisation=self.organisation.id,
+            hash=boefje_task.hash,
+            data=boefje_task.model_dump(),
         )
         schedule_db = self.mock_ctx.datastores.schedule_store.create_schedule(schedule)
@@ -1967,7 +1954,7 @@ def test_push_tasks_for_rescheduling_boefje_disabled(self):
         self.mock_get_plugin.return_value = plugin

         # Act
-        self.scheduler.push_tasks_for_rescheduling()
+        self.scheduler.process_rescheduling()

         # Assert: item should not be on queue
         self.assertEqual(0, self.scheduler.queue.qsize())
@@ -1976,7 +1963,7 @@ def test_push_tasks_for_rescheduling_boefje_disabled(self):
         schedule_db_disabled = self.mock_ctx.datastores.schedule_store.get_schedule(schedule.id)
         self.assertFalse(schedule_db_disabled.enabled)

-    def test_push_tasks_for_rescheduling_boefje_doesnt_consume_ooi(self):
+    def test_process_rescheduling_boefje_doesnt_consume_ooi(self):
         """When the boefje doesn't consume the OOI, we disable the schedule"""
         # Arrange
         scan_profile = ScanProfileFactory(level=0)
@@ -1990,7 +1977,10 @@ def test_push_tasks_for_rescheduling_boefje_doesnt_consume_ooi(self):
         )

         schedule = models.Schedule(
-            scheduler_id=self.scheduler.scheduler_id, hash=boefje_task.hash, data=boefje_task.model_dump()
+            scheduler_id=self.scheduler.scheduler_id,
+            organisation=self.organisation.id,
+            hash=boefje_task.hash,
+            data=boefje_task.model_dump(),
         )
         schedule_db = self.mock_ctx.datastores.schedule_store.create_schedule(schedule)
@@ -2001,7 +1991,7 @@ def test_push_tasks_for_rescheduling_boefje_doesnt_consume_ooi(self):
         self.mock_get_plugin.return_value = plugin

         # Act
-        self.scheduler.push_tasks_for_rescheduling()
+        self.scheduler.process_rescheduling()

         # Assert: item should not be on queue
         self.assertEqual(0, self.scheduler.queue.qsize())
@@ -2010,7 +2000,7 @@ def test_push_tasks_for_rescheduling_boefje_doesnt_consume_ooi(self):
         schedule_db_disabled = self.mock_ctx.datastores.schedule_store.get_schedule(schedule.id)
         self.assertFalse(schedule_db_disabled.enabled)

-    def test_push_tasks_for_rescheduling_boefje_cannot_scan_ooi(self):
+    def test_process_rescheduling_boefje_cannot_scan_ooi(self):
         """When the boefje cannot scan the OOI, we disable the schedule"""
         # Arrange
         scan_profile = ScanProfileFactory(level=0)
@@ -2024,7 +2014,10 @@ def test_push_tasks_for_rescheduling_boefje_cannot_scan_ooi(self):
         )

         schedule = models.Schedule(
-            scheduler_id=self.scheduler.scheduler_id, hash=boefje_task.hash, data=boefje_task.model_dump()
+            scheduler_id=self.scheduler.scheduler_id,
+            organisation=self.organisation.id,
+            hash=boefje_task.hash,
+            data=boefje_task.model_dump(),
         )
         schedule_db = self.mock_ctx.datastores.schedule_store.create_schedule(schedule)
@@ -2035,7 +2028,7 @@ def test_push_tasks_for_rescheduling_boefje_cannot_scan_ooi(self):
         self.mock_get_plugin.return_value = plugin

         # Act
-        self.scheduler.push_tasks_for_rescheduling()
+        self.scheduler.process_rescheduling()

         # Assert: item should not be on queue
         self.assertEqual(0, self.scheduler.queue.qsize())
diff --git a/mula/tests/integration/test_normalizer_scheduler.py b/mula/tests/integration/test_normalizer_scheduler.py
index 493b4bd3f54..ad3d72e9ea8 100644
--- a/mula/tests/integration/test_normalizer_scheduler.py
+++ b/mula/tests/integration/test_normalizer_scheduler.py
@@ -41,10 +41,10 @@ def setUp(self):
         )

         # Scheduler
+        self.scheduler = 
schedulers.NormalizerScheduler(self.mock_ctx) + + # Organisation self.organisation = OrganisationFactory() - self.scheduler = schedulers.NormalizerScheduler( - ctx=self.mock_ctx, scheduler_id=self.organisation.id, organisation=self.organisation - ) def tearDown(self): self.scheduler.stop() @@ -64,57 +64,6 @@ def setUp(self): "scheduler.context.AppContext.services.katalogus.get_plugin_by_id_and_org_id" ).start() - def test_disable_scheduler(self): - # Act - self.scheduler.disable() - - # Listeners should be stopped - self.assertEqual(0, len(self.scheduler.listeners)) - - # Threads should be stopped - self.assertEqual(0, len(self.scheduler.threads)) - - # Queue should be empty - self.assertEqual(0, self.scheduler.queue.qsize()) - - # All tasks on queue should be set to CANCELLED - tasks, _ = self.mock_ctx.datastores.task_store.get_tasks(self.scheduler.scheduler_id) - for task in tasks: - self.assertEqual(task.status, models.TaskStatus.CANCELLED) - - # Scheduler should be disabled - self.assertFalse(self.scheduler.is_enabled()) - - def test_enable_scheduler(self): - # Disable scheduler first - self.scheduler.disable() - - # Listeners should be stopped - self.assertEqual(0, len(self.scheduler.listeners)) - - # Threads should be stopped - self.assertEqual(0, len(self.scheduler.threads)) - - # Queue should be empty - self.assertEqual(0, self.scheduler.queue.qsize()) - - # All tasks on queue should be set to CANCELLED - tasks, _ = self.mock_ctx.datastores.task_store.get_tasks(self.scheduler.scheduler_id) - for task in tasks: - self.assertEqual(task.status, models.TaskStatus.CANCELLED) - - # Re-enable scheduler - self.scheduler.enable() - - # Threads should be started - self.assertGreater(len(self.scheduler.threads), 0) - - # Scheduler should be enabled - self.assertTrue(self.scheduler.is_enabled()) - - # Stop the scheduler - self.scheduler.stop() - def test_is_allowed_to_run(self): # Arrange plugin = PluginFactory(type="normalizer", consumes=["text/plain"]) @@ -151,7 +100,7 @@ def test_get_normalizers_for_mime_type(self, mock_get_normalizers_by_org_id_and_ mock_get_normalizers_by_org_id_and_type.return_value = [normalizer] # Act - result = self.scheduler.get_normalizers_for_mime_type("text/plain") + result = self.scheduler.get_normalizers_for_mime_type("text/plain", self.organisation.id) # Assert self.assertEqual(len(result), 1) @@ -166,7 +115,7 @@ def test_get_normalizers_for_mime_type_request_exception(self, mock_get_normaliz ] # Act - result = self.scheduler.get_normalizers_for_mime_type("text/plain") + result = self.scheduler.get_normalizers_for_mime_type("text/plain", self.organisation.id) # Assert self.assertEqual(len(result), 0) @@ -177,7 +126,7 @@ def test_get_normalizers_for_mime_type_response_is_none(self, mock_get_normalize mock_get_normalizers_by_org_id_and_type.return_value = None # Act - result = self.scheduler.get_normalizers_for_mime_type("text/plain") + result = self.scheduler.get_normalizers_for_mime_type("text/plain", self.organisation.id) # Assert self.assertEqual(len(result), 0) @@ -199,7 +148,11 @@ def setUp(self): "scheduler.schedulers.NormalizerScheduler.get_normalizers_for_mime_type" ).start() - def test_push_tasks_for_received_raw_file(self): + self.mock_get_plugin = mock.patch( + "scheduler.context.AppContext.services.katalogus.get_plugin_by_id_and_org_id" + ).start() + + def test_process_raw_data(self): # Arrange ooi = OOIFactory(scan_profile=ScanProfileFactory(level=0)) boefje = BoefjeFactory() @@ -208,7 +161,7 @@ def test_push_tasks_for_received_raw_file(self): # 
Arrange: create the RawDataReceivedEvent raw_data_event = models.RawDataReceivedEvent( raw_data=RawDataFactory(boefje_meta=boefje_meta, mime_types=[{"value": "text/plain"}]), - organization=self.organisation.name, + organization=self.organisation.id, created_at=datetime.datetime.now(), ).model_dump_json() @@ -217,7 +170,7 @@ def test_push_tasks_for_received_raw_file(self): self.mock_get_normalizers_for_mime_type.return_value = [plugin] # Act - self.scheduler.push_tasks_for_received_raw_data(raw_data_event) + self.scheduler.process_raw_data(raw_data_event) # Task should be on priority queue task_pq = self.scheduler.queue.peek(0) @@ -228,7 +181,7 @@ def test_push_tasks_for_received_raw_file(self): self.assertEqual(task_db.id, task_pq.id) self.assertEqual(task_db.status, models.TaskStatus.QUEUED) - def test_push_tasks_for_received_raw_file_no_normalizers_found(self): + def test_process_raw_data_no_normalizers_found(self): # Arrange ooi = OOIFactory(scan_profile=ScanProfileFactory(level=0)) boefje = BoefjeFactory() @@ -236,7 +189,7 @@ def test_push_tasks_for_received_raw_file_no_normalizers_found(self): raw_data_event = models.RawDataReceivedEvent( raw_data=RawDataFactory(boefje_meta=boefje_meta, mime_types=[{"value": "text/plain"}]), - organization=self.organisation.name, + organization=self.organisation.id, created_at=datetime.datetime.now(), ).model_dump_json() @@ -244,19 +197,21 @@ def test_push_tasks_for_received_raw_file_no_normalizers_found(self): self.mock_get_normalizers_for_mime_type.return_value = [] # Act - self.scheduler.push_tasks_for_received_raw_data(raw_data_event) + self.scheduler.process_raw_data(raw_data_event) # Task should not be on priority queue self.assertEqual(0, self.scheduler.queue.qsize()) - def test_push_tasks_for_received_raw_file_not_allowed_to_run(self): + def test_process_raw_data_not_allowed_to_run(self): # Arrange scan_profile = ScanProfileFactory(level=0) ooi = OOIFactory(scan_profile=scan_profile) boefje = BoefjeFactory() boefje_task = models.BoefjeTask(boefje=boefje, input_ooi=ooi.primary_key, organization=self.organisation.id) - task = functions.create_task(scheduler_id=self.scheduler.scheduler_id, data=boefje_task) + task = functions.create_task( + scheduler_id=self.scheduler.scheduler_id, data=boefje_task, organisation=self.organisation.id + ) self.mock_ctx.datastores.task_store.create_task(task) boefje_meta = BoefjeMetaFactory(boefje=boefje, input_ooi=ooi.primary_key) @@ -264,7 +219,7 @@ def test_push_tasks_for_received_raw_file_not_allowed_to_run(self): # Mocks raw_data_event = models.RawDataReceivedEvent( raw_data=RawDataFactory(boefje_meta=boefje_meta, mime_types=[{"value": "text/plain"}]), - organization=self.organisation.name, + organization=self.organisation.id, created_at=datetime.datetime.now(), ).model_dump_json() @@ -272,19 +227,21 @@ def test_push_tasks_for_received_raw_file_not_allowed_to_run(self): self.mock_has_normalizer_permission_to_run.return_value = False # Act - self.scheduler.push_tasks_for_received_raw_data(raw_data_event) + self.scheduler.process_raw_data(raw_data_event) # Task should not be on priority queue self.assertEqual(0, self.scheduler.queue.qsize()) - def test_push_tasks_for_received_raw_file_still_running(self): + def test_process_raw_data_still_running(self): # Arrange scan_profile = ScanProfileFactory(level=0) ooi = OOIFactory(scan_profile=scan_profile) boefje = BoefjeFactory() boefje_task = models.BoefjeTask(boefje=boefje, input_ooi=ooi.primary_key, organization=self.organisation.id) - task = 
functions.create_task(scheduler_id=self.scheduler.scheduler_id, data=boefje_task) + task = functions.create_task( + scheduler_id=self.scheduler.scheduler_id, data=boefje_task, organisation=self.organisation.id + ) self.mock_ctx.datastores.task_store.create_task(task) boefje_meta = BoefjeMetaFactory(boefje=boefje, input_ooi=ooi.primary_key) @@ -292,7 +249,7 @@ def test_push_tasks_for_received_raw_file_still_running(self): # Mocks raw_data_event = models.RawDataReceivedEvent( raw_data=RawDataFactory(boefje_meta=boefje_meta, mime_types=[{"value": "text/plain"}]), - organization=self.organisation.name, + organization=self.organisation.id, created_at=datetime.datetime.now(), ).model_dump_json() @@ -301,19 +258,21 @@ def test_push_tasks_for_received_raw_file_still_running(self): self.mock_has_normalizer_task_started_running.return_value = True # Act - self.scheduler.push_tasks_for_received_raw_data(raw_data_event) + self.scheduler.process_raw_data(raw_data_event) # Task should not be on priority queue self.assertEqual(0, self.scheduler.queue.qsize()) - def test_push_tasks_for_received_raw_file_still_running_exception(self): + def test_process_raw_data_still_running_exception(self): # Arrange scan_profile = ScanProfileFactory(level=0) ooi = OOIFactory(scan_profile=scan_profile) boefje = BoefjeFactory() boefje_task = models.BoefjeTask(boefje=boefje, input_ooi=ooi.primary_key, organization=self.organisation.id) - task = functions.create_task(scheduler_id=self.scheduler.scheduler_id, data=boefje_task) + task = functions.create_task( + scheduler_id=self.scheduler.scheduler_id, data=boefje_task, organisation=self.organisation.id + ) self.mock_ctx.datastores.task_store.create_task(task) boefje_meta = BoefjeMetaFactory(boefje=boefje, input_ooi=ooi.primary_key) @@ -321,7 +280,7 @@ def test_push_tasks_for_received_raw_file_still_running_exception(self): # Mocks raw_data_event = models.RawDataReceivedEvent( raw_data=RawDataFactory(boefje_meta=boefje_meta, mime_types=[{"value": "text/plain"}]), - organization=self.organisation.name, + organization=self.organisation.id, created_at=datetime.datetime.now(), ).model_dump_json() @@ -330,12 +289,12 @@ def test_push_tasks_for_received_raw_file_still_running_exception(self): self.mock_has_normalizer_task_started_running.side_effect = Exception("Something went wrong") # Act - self.scheduler.push_tasks_for_received_raw_data(raw_data_event) + self.scheduler.process_raw_data(raw_data_event) # Task should not be on priority queue self.assertEqual(0, self.scheduler.queue.qsize()) - def test_push_tasks_for_received_raw_file_item_on_queue(self): + def test_process_raw_data_item_on_queue(self): # Arrange ooi = OOIFactory(scan_profile=ScanProfileFactory(level=0)) boefje = BoefjeFactory() @@ -343,13 +302,13 @@ def test_push_tasks_for_received_raw_file_item_on_queue(self): raw_data_event1 = models.RawDataReceivedEvent( raw_data=RawDataFactory(boefje_meta=boefje_meta, mime_types=[{"value": "text/plain"}]), - organization=self.organisation.name, + organization=self.organisation.id, created_at=datetime.datetime.now(), ).model_dump_json() raw_data_event2 = models.RawDataReceivedEvent( raw_data=RawDataFactory(boefje_meta=boefje_meta, mime_types=[{"value": "text/plain"}]), - organization=self.organisation.name, + organization=self.organisation.id, created_at=datetime.datetime.now(), ).model_dump_json() @@ -357,8 +316,8 @@ def test_push_tasks_for_received_raw_file_item_on_queue(self): self.mock_get_normalizers_for_mime_type.return_value = [NormalizerFactory()] # Act - 
self.scheduler.push_tasks_for_received_raw_data(raw_data_event1) - self.scheduler.push_tasks_for_received_raw_data(raw_data_event2) + self.scheduler.process_raw_data(raw_data_event1) + self.scheduler.process_raw_data(raw_data_event2) # Task should be on priority queue (only one) task_pq = self.scheduler.queue.peek(0) @@ -369,31 +328,33 @@ def test_push_tasks_for_received_raw_file_item_on_queue(self): self.assertEqual(task_db.id, task_pq.id) self.assertEqual(task_db.status, models.TaskStatus.QUEUED) - def test_push_tasks_for_received_raw_file_error_mimetype(self): + def test_process_raw_data_error_mimetype(self): # Arrange scan_profile = ScanProfileFactory(level=0) ooi = OOIFactory(scan_profile=scan_profile) boefje = BoefjeFactory() boefje_task = models.BoefjeTask(boefje=boefje, input_ooi=ooi.primary_key, organization=self.organisation.id) - task = functions.create_task(scheduler_id=self.scheduler.scheduler_id, data=boefje_task) + task = functions.create_task( + scheduler_id=self.scheduler.scheduler_id, data=boefje_task, organisation=self.organisation.id + ) self.mock_ctx.datastores.task_store.create_task(task) boefje_meta = BoefjeMetaFactory(boefje=boefje, input_ooi=ooi.primary_key) raw_data_event = models.RawDataReceivedEvent( raw_data=RawDataFactory(boefje_meta=boefje_meta, mime_types=[{"value": "error/unknown"}]), - organization=self.organisation.name, + organization=self.organisation.id, created_at=datetime.datetime.now(), ).model_dump_json() # Act - self.scheduler.push_tasks_for_received_raw_data(raw_data_event) + self.scheduler.process_raw_data(raw_data_event) # Task should not be on priority queue self.assertEqual(0, self.scheduler.queue.qsize()) - def test_push_tasks_for_received_raw_file_queue_full(self): + def test_process_raw_data_queue_full(self): events = [] for _ in range(0, 2): # Arrange @@ -401,14 +362,16 @@ def test_push_tasks_for_received_raw_file_queue_full(self): ooi = OOIFactory(scan_profile=scan_profile) boefje = BoefjeFactory() boefje_task = models.BoefjeTask(boefje=boefje, input_ooi=ooi.primary_key, organization=self.organisation.id) - task = functions.create_task(scheduler_id=self.scheduler.scheduler_id, data=boefje_task) + task = functions.create_task( + scheduler_id=self.scheduler.scheduler_id, data=boefje_task, organisation=self.organisation.id + ) self.mock_ctx.datastores.task_store.create_task(task) boefje_meta = BoefjeMetaFactory(boefje=boefje, input_ooi=ooi.primary_key) raw_data_event = models.RawDataReceivedEvent( raw_data=RawDataFactory(boefje_meta=boefje_meta, mime_types=[{"value": "text/plain"}]), - organization=self.organisation.name, + organization=self.organisation.id, created_at=datetime.datetime.now(), ).model_dump_json() @@ -421,13 +384,13 @@ def test_push_tasks_for_received_raw_file_queue_full(self): self.mock_get_normalizers_for_mime_type.return_value = [NormalizerFactory()] # Act - self.scheduler.push_tasks_for_received_raw_data(events[0]) + self.scheduler.process_raw_data(events[0]) # Assert self.assertEqual(1, self.scheduler.queue.qsize()) with capture_logs() as cm: - self.scheduler.push_tasks_for_received_raw_data(events[1]) + self.scheduler.process_raw_data(events[1]) - self.assertIn("Could not add task to queue, queue was full", cm[-1].get("event")) + self.assertIn("Queue is full", cm[-1].get("event")) self.assertEqual(1, self.scheduler.queue.qsize()) diff --git a/mula/tests/integration/test_pq_store.py b/mula/tests/integration/test_pq_store.py index 0ace0867758..8fd9b6e6d91 100644 --- a/mula/tests/integration/test_pq_store.py +++ 
b/mula/tests/integration/test_pq_store.py @@ -38,7 +38,7 @@ def tearDown(self): def test_push(self): # Arrange - item = functions.create_item(scheduler_id=uuid.uuid4().hex, priority=1) + item = functions.create_task(scheduler_id=uuid.uuid4().hex, organisation=self.organisation.id, priority=1) item.status = models.TaskStatus.QUEUED created_item = self.mock_ctx.datastores.pq_store.push(item) @@ -50,7 +50,7 @@ def test_push(self): self.assertEqual(item_db.id, created_item.id) def test_push_status_not_queued(self): - item = functions.create_item(scheduler_id=uuid.uuid4().hex, priority=1) + item = functions.create_task(scheduler_id=uuid.uuid4().hex, organisation=self.organisation.id, priority=1) item.status = models.TaskStatus.PENDING created_item = self.mock_ctx.datastores.pq_store.push(item) @@ -62,24 +62,26 @@ def test_push_status_not_queued(self): def test_pop(self): # Arrange - item = functions.create_item(scheduler_id=uuid.uuid4().hex, priority=1) + item = functions.create_task(scheduler_id=uuid.uuid4().hex, organisation=self.organisation.id, priority=1) item.status = models.TaskStatus.QUEUED created_item = self.mock_ctx.datastores.pq_store.push(item) - popped_item = self.mock_ctx.datastores.pq_store.pop(item.scheduler_id) + popped_items, count = self.mock_ctx.datastores.pq_store.pop(item.scheduler_id) # Assert - self.assertIsNotNone(popped_item) - self.assertEqual(popped_item.id, created_item.id) + self.assertIsNotNone(popped_items) + self.assertEqual(count, 1) + self.assertEqual(popped_items[0].id, created_item.id) def test_pop_status_not_queued(self): # Arrange - item = functions.create_item(scheduler_id=uuid.uuid4().hex, priority=1) + item = functions.create_task(scheduler_id=uuid.uuid4().hex, organisation=self.organisation.id, priority=1) item.status = models.TaskStatus.PENDING created_item = self.mock_ctx.datastores.pq_store.push(item) - popped_item = self.mock_ctx.datastores.pq_store.pop(item.scheduler_id) + popped_items, count = self.mock_ctx.datastores.pq_store.pop(item.scheduler_id) # Assert self.assertIsNotNone(created_item) - self.assertIsNone(popped_item) + self.assertEqual(count, 0) + self.assertEqual(len(popped_items), 0) diff --git a/mula/tests/integration/test_report_scheduler.py b/mula/tests/integration/test_report_scheduler.py index ee35f7ab25a..269d0dd9759 100644 --- a/mula/tests/integration/test_report_scheduler.py +++ b/mula/tests/integration/test_report_scheduler.py @@ -29,10 +29,10 @@ def setUp(self): ) # Scheduler + self.scheduler = schedulers.ReportScheduler(ctx=self.mock_ctx) + + # Organisation self.organisation = OrganisationFactory() - self.scheduler = schedulers.ReportScheduler( - ctx=self.mock_ctx, scheduler_id=self.organisation.id, organisation=self.organisation - ) def tearDown(self): self.scheduler.stop() @@ -51,48 +51,16 @@ def setUp(self): def tearDown(self): mock.patch.stopall() - def test_enable_scheduler(self): - # Disable scheduler first - self.scheduler.disable() - - # Threads should be stopped - self.assertEqual(0, len(self.scheduler.threads)) - - # Queue should be empty - self.assertEqual(0, self.scheduler.queue.qsize()) - - # Re-enable scheduler - self.scheduler.enable() - - # Threads should be started - self.assertGreater(len(self.scheduler.threads), 0) - - # Scheduler should be enabled - self.assertTrue(self.scheduler.is_enabled()) - - # Stop the scheduler - self.scheduler.stop() - - def test_disable_scheduler(self): - # Disable scheduler - self.scheduler.disable() - - # Threads should be stopped - self.assertEqual(0, 
len(self.scheduler.threads))
-
-        # Queue should be empty
-        self.assertEqual(0, self.scheduler.queue.qsize())
-
-        # Scheduler should be disabled
-        self.assertFalse(self.scheduler.is_enabled())
-
-    def test_push_tasks_for_rescheduling(self):
+    def test_process_rescheduling(self):
         """When the deadline of a schedule has passed, the resulting task
         should be added to the queue"""
         # Arrange
         report_task = models.ReportTask(organisation_id=self.organisation.id, report_recipe_id="123")

         schedule = models.Schedule(
-            scheduler_id=self.scheduler.scheduler_id, hash=report_task.hash, data=report_task.model_dump()
+            scheduler_id=self.scheduler.scheduler_id,
+            hash=report_task.hash,
+            data=report_task.model_dump(),
+            organisation=self.organisation.id,
         )
         schedule_db = self.mock_ctx.datastores.schedule_store.create_schedule(schedule)
@@ -101,7 +69,7 @@ def test_push_tasks_for_rescheduling(self):
         self.mock_get_schedules.return_value = ([schedule_db], 1)

         # Act
-        self.scheduler.push_tasks_for_rescheduling()
+        self.scheduler.process_rescheduling()

         # Assert: new item should be on queue
         self.assertEqual(1, self.scheduler.queue.qsize())
@@ -115,13 +83,16 @@ def test_push_tasks_for_rescheduling(self):
         self.assertIsNotNone(task_db)
         self.assertEqual(peek.id, task_db.id)

-    def test_push_tasks_for_rescheduling_item_on_queue(self):
+    def test_process_rescheduling_item_on_queue(self):
         """When the deadline has passed and the resulting task is already on
         the queue, it should only be added once"""
         # Arrange
         report_task = models.ReportTask(organisation_id=self.organisation.id, report_recipe_id="123")

         schedule = models.Schedule(
-            scheduler_id=self.scheduler.scheduler_id, hash=report_task.hash, data=report_task.model_dump()
+            scheduler_id=self.scheduler.scheduler_id,
+            hash=report_task.hash,
+            data=report_task.model_dump(),
+            organisation=self.organisation.id,
         )
         schedule_db = self.mock_ctx.datastores.schedule_store.create_schedule(schedule)
@@ -130,7 +101,7 @@ def test_push_tasks_for_rescheduling_item_on_queue(self):
         self.mock_get_schedules.return_value = ([schedule_db], 1)

         # Act
-        self.scheduler.push_tasks_for_rescheduling()
+        self.scheduler.process_rescheduling()

         # Assert: new item should be on queue
         self.assertEqual(1, self.scheduler.queue.qsize())
@@ -145,7 +116,7 @@ def test_push_tasks_for_rescheduling_item_on_queue(self):
         self.assertEqual(peek.id, task_db.id)

         # Act: push again
-        self.scheduler.push_tasks_for_rescheduling()
+        self.scheduler.process_rescheduling()

         # Should only be one task on queue
         self.assertEqual(1, self.scheduler.queue.qsize())
diff --git a/mula/tests/integration/test_schedule_store.py b/mula/tests/integration/test_schedule_store.py
index df957b82171..f2f51344e61 100644
--- a/mula/tests/integration/test_schedule_store.py
+++ b/mula/tests/integration/test_schedule_store.py
@@ -6,6 +6,7 @@
 from scheduler import config, models, storage
 from scheduler.storage import filters, stores
+from tests.factories.organisation import OrganisationFactory
 from tests.utils import functions
@@ -28,27 +29,40 @@ def setUp(self):
             }
         )

+        # Organisation
+        self.organisation = OrganisationFactory()
+
     def tearDown(self):
         models.Base.metadata.drop_all(self.dbconn.engine)
         self.dbconn.engine.dispose()

     def test_create_schedule_calculate_deadline_at(self):
         """When a schedule is created, the deadline_at should be calculated."""
-        schedule = models.Schedule(scheduler_id="test_scheduler_id", schedule="* * * * *", data={})
+        schedule = models.Schedule(
+            scheduler_id="test_scheduler_id", organisation=self.organisation.id, schedule="* * * * *", data={}
diff --git a/mula/tests/integration/test_schedule_store.py b/mula/tests/integration/test_schedule_store.py
index df957b82171..f2f51344e61 100644
--- a/mula/tests/integration/test_schedule_store.py
+++ b/mula/tests/integration/test_schedule_store.py
@@ -6,6 +6,7 @@
 from scheduler import config, models, storage
 from scheduler.storage import filters, stores

+from tests.factories.organisation import OrganisationFactory
 from tests.utils import functions


@@ -28,27 +29,40 @@ def setUp(self):
             }
         )

+        # Organisation
+        self.organisation = OrganisationFactory()
+
     def tearDown(self):
         models.Base.metadata.drop_all(self.dbconn.engine)
         self.dbconn.engine.dispose()

     def test_create_schedule_calculate_deadline_at(self):
         """When a schedule is created, the deadline_at should be calculated."""
-        schedule = models.Schedule(scheduler_id="test_scheduler_id", schedule="* * * * *", data={})
+        schedule = models.Schedule(
+            scheduler_id="test_scheduler_id", organisation=self.organisation.id, schedule="* * * * *", data={}
+        )

         self.assertIsNotNone(schedule.deadline_at)

     def test_create_schedule_explicit_deadline_at(self):
         """When a schedule is created, the deadline_at should be set if it is provided."""
         now = datetime.now(timezone.utc)
-        schedule = models.Schedule(scheduler_id="test_scheduler_id", data={}, deadline_at=now)
+        schedule = models.Schedule(
+            scheduler_id="test_scheduler_id", organisation=self.organisation.id, data={}, deadline_at=now
+        )

         self.assertEqual(schedule.deadline_at, now)

     def test_create_schedule_deadline_at_takes_precedence(self):
         """When a schedule is created, the deadline_at should be set if it is provided."""
         now = datetime.now(timezone.utc)
-        schedule = models.Schedule(scheduler_id="test_scheduler_id", schedule="* * * * *", data={}, deadline_at=now)
+        schedule = models.Schedule(
+            scheduler_id="test_scheduler_id",
+            schedule="* * * * *",
+            organisation=self.organisation.id,
+            data={},
+            deadline_at=now,
+        )

         self.assertEqual(schedule.deadline_at, now)

@@ -56,8 +70,10 @@ def test_create_schedule(self):
         # Arrange
         scheduler_id = "test_scheduler_id"

-        task = functions.create_item(scheduler_id, 1)
-        schedule = models.Schedule(scheduler_id=scheduler_id, hash=task.hash, data=task.model_dump())
+        task = functions.create_task(scheduler_id=scheduler_id, organisation=self.organisation.id, priority=1)
+        schedule = models.Schedule(
+            scheduler_id=scheduler_id, organisation=self.organisation.id, hash=task.hash, data=task.model_dump()
+        )

         # Act
         schedule_db = self.mock_ctx.datastores.schedule_store.create_schedule(schedule)
@@ -69,14 +85,18 @@ def test_get_schedules(self):
         # Arrange
         scheduler_one = "test_scheduler_one"
         for i in range(5):
-            task = functions.create_item(scheduler_one, 1)
-            schedule = models.Schedule(scheduler_id=scheduler_one, hash=task.hash, data=task.model_dump())
+            task = functions.create_task(scheduler_id=scheduler_one, organisation=self.organisation.id, priority=1)
+            schedule = models.Schedule(
+                scheduler_id=scheduler_one, organisation=self.organisation.id, hash=task.hash, data=task.model_dump()
+            )
             self.mock_ctx.datastores.schedule_store.create_schedule(schedule)

         scheduler_two = "test_scheduler_two"
         for i in range(5):
-            task = functions.create_item(scheduler_two, 1)
-            schedule = models.Schedule(scheduler_id=scheduler_two, hash=task.hash, data=task.model_dump())
+            task = functions.create_task(scheduler_id=scheduler_two, organisation=self.organisation.id, priority=1)
+            schedule = models.Schedule(
+                scheduler_id=scheduler_two, organisation=self.organisation.id, hash=task.hash, data=task.model_dump()
+            )
             self.mock_ctx.datastores.schedule_store.create_schedule(schedule)

         schedules_scheduler_one, schedules_scheduler_one_count = self.mock_ctx.datastores.schedule_store.get_schedules(
@@ -99,8 +119,10 @@ def test_get_schedules(self):
     def test_get_schedule(self):
         # Arrange
         scheduler_id = "test_scheduler_id"
-        task = functions.create_item(scheduler_id, 1)
-        schedule = models.Schedule(scheduler_id=scheduler_id, hash=task.hash, data=task.model_dump())
+        task = functions.create_task(scheduler_id=scheduler_id, organisation=self.organisation.id, priority=1)
+        schedule = models.Schedule(
+            scheduler_id=scheduler_id, organisation=self.organisation.id, hash=task.hash, data=task.model_dump()
+        )
         schedule_db = self.mock_ctx.datastores.schedule_store.create_schedule(schedule)

         # Act
@@ -113,7 +135,9 @@ def test_get_schedule_by_hash(self):
         # Arrange
         scheduler_id = "test_scheduler_id"
         data = functions.create_test_model()
-        schedule = models.Schedule(scheduler_id=scheduler_id, hash=data.hash, data=data.model_dump())
+        schedule = models.Schedule(
+            scheduler_id=scheduler_id, organisation=self.organisation.id, hash=data.hash, data=data.model_dump()
+        )
         schedule_db = self.mock_ctx.datastores.schedule_store.create_schedule(schedule)

         # Act
@@ -127,8 +151,10 @@ def test_get_schedule_by_hash(self):
     def test_update_schedule(self):
         # Arrange
         scheduler_id = "test_scheduler_id"
-        task = functions.create_item(scheduler_id, 1)
-        schedule = models.Schedule(scheduler_id=scheduler_id, hash=task.hash, data=task.model_dump())
+        task = functions.create_task(scheduler_id=scheduler_id, organisation=self.organisation.id, priority=1)
+        schedule = models.Schedule(
+            scheduler_id=scheduler_id, organisation=self.organisation.id, hash=task.hash, data=task.model_dump()
+        )
         schedule_db = self.mock_ctx.datastores.schedule_store.create_schedule(schedule)

         # Assert
@@ -145,8 +171,10 @@ def test_update_schedule(self):
     def test_delete_schedule(self):
         # Arrange
         scheduler_id = "test_scheduler_id"
-        task = functions.create_item(scheduler_id, 1)
-        schedule = models.Schedule(scheduler_id=scheduler_id, hash=task.hash, data=task.model_dump())
+        task = functions.create_task(scheduler_id=scheduler_id, organisation=self.organisation.id, priority=1)
+        schedule = models.Schedule(
+            scheduler_id=scheduler_id, organisation=self.organisation.id, hash=task.hash, data=task.model_dump()
+        )
         schedule_db = self.mock_ctx.datastores.schedule_store.create_schedule(schedule)

         # Act
@@ -160,8 +188,10 @@ def test_delete_schedule_ondelete(self):
         """When a schedule is deleted, its tasks should NOT be deleted."""
         # Arrange
         scheduler_id = "test_scheduler_id"
-        task = functions.create_item(scheduler_id, 1)
-        schedule = models.Schedule(scheduler_id=scheduler_id, hash=task.hash, data=task.model_dump())
+        task = functions.create_task(scheduler_id=scheduler_id, organisation=self.organisation.id, priority=1)
+        schedule = models.Schedule(
+            scheduler_id=scheduler_id, organisation=self.organisation.id, hash=task.hash, data=task.model_dump()
+        )
         schedule_db = self.mock_ctx.datastores.schedule_store.create_schedule(schedule)

         task.schedule_id = schedule_db.id
@@ -181,8 +211,10 @@ def test_delete_schedule_ondelete(self):
     def test_relationship_schedule_tasks(self):
         # Arrange
         scheduler_id = "test_scheduler_id"
-        task = functions.create_task(scheduler_id)
-        schedule = models.Schedule(scheduler_id=scheduler_id, hash=task.hash, data=task.model_dump())
+        task = functions.create_task(scheduler_id=scheduler_id, organisation=self.organisation.id)
+        schedule = models.Schedule(
+            scheduler_id=scheduler_id, organisation=self.organisation.id, hash=task.hash, data=task.model_dump()
+        )
         schedule_db = self.mock_ctx.datastores.schedule_store.create_schedule(schedule)

         task.schedule_id = schedule_db.id
@@ -198,8 +230,10 @@ def test_relationship_schedule_tasks(self):
     def test_get_tasks_filter_related(self):
         # Arrange
         scheduler_id = "test_scheduler_id"
-        task = functions.create_task(scheduler_id)
-        schedule = models.Schedule(scheduler_id=scheduler_id, hash=task.hash, data=task.model_dump())
+        task = functions.create_task(scheduler_id=scheduler_id, organisation=self.organisation.id)
+        schedule = models.Schedule(
+            scheduler_id=scheduler_id, organisation=self.organisation.id, hash=task.hash, data=task.model_dump()
+        )
         schedule_db = self.mock_ctx.datastores.schedule_store.create_schedule(schedule)

         task.schedule_id = schedule_db.id
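Reviewer note: every `models.Schedule(...)` call site above now passes `organisation`. A minimal sketch of the resulting model shape, assuming a Pydantic model (suggested by the `model_dump()` calls throughout this diff); the real definition lives in `scheduler.models` and the field set here is inferred from the constructor calls, not copied from it:

```python
# Stand-in for scheduler.models.Schedule; fields are an assumption based on
# the constructor calls in the tests above.
from pydantic import BaseModel


class Schedule(BaseModel):
    scheduler_id: str
    organisation: str  # newly required by this change set
    hash: str | None = None
    data: dict = {}


# Old-style call sites that omit organisation now fail validation:
schedule = Schedule(scheduler_id="test_scheduler_id", organisation="org-1", hash="abc", data={})
```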
diff --git a/mula/tests/integration/test_scheduler.py b/mula/tests/integration/test_scheduler.py
index 7483f5b9880..1c685c55354 100644
--- a/mula/tests/integration/test_scheduler.py
+++ b/mula/tests/integration/test_scheduler.py
@@ -5,10 +5,10 @@
 from unittest import mock

 from scheduler import config, models, storage
-from scheduler.schedulers.queue import InvalidItemError, NotAllowedError, QueueEmptyError, QueueFullError
+from scheduler.schedulers.queue import InvalidItemError, QueueEmptyError, QueueFullError
 from scheduler.storage import stores
-from structlog.testing import capture_logs

+from tests.factories import OrganisationFactory
 from tests.mocks import item as mock_item
 from tests.mocks import queue as mock_queue
 from tests.mocks import scheduler as mock_scheduler
@@ -49,6 +49,9 @@ def setUp(self):
             ctx=self.mock_ctx, scheduler_id=identifier, queue=queue, create_schedule=True
         )

+        # Organisation
+        self.organisation = OrganisationFactory()
+
     def tearDown(self):
         self.scheduler.stop()
         models.Base.metadata.drop_all(self.dbconn.engine)
@@ -58,7 +61,9 @@ def test_push_items_to_queue(self):
         # Arrange
         items = []
         for i in range(10):
-            item = functions.create_item(scheduler_id=self.scheduler.scheduler_id, priority=i + 1)
+            item = functions.create_task(
+                scheduler_id=self.scheduler.scheduler_id, organisation=self.organisation.id, priority=i + 1
+            )
             items.append(item)

         # Act
@@ -84,7 +89,9 @@ def test_push_items_to_queue(self):

     def test_push_item_to_queue(self):
         # Arrange
-        item = functions.create_item(scheduler_id=self.scheduler.scheduler_id, priority=1)
+        item = functions.create_task(
+            scheduler_id=self.scheduler.scheduler_id, organisation=self.organisation.id, priority=1
+        )

         # Act
         self.scheduler.push_item_to_queue(item)
@@ -108,7 +115,9 @@ def test_push_item_to_queue_create_schedule_false(self):
         # Arrange
         self.scheduler.create_schedule = False

-        item = functions.create_item(scheduler_id=self.scheduler.scheduler_id, priority=1)
+        item = functions.create_task(
+            scheduler_id=self.scheduler.scheduler_id, organisation=self.organisation.id, priority=1
+        )

         # Act
         self.scheduler.push_item_to_queue(item)
@@ -130,7 +139,9 @@ def test_push_item_to_queue_create_schedule_false(self):

     def test_push_item_to_queue_full(self):
         # Arrange
-        item = functions.create_item(scheduler_id=self.scheduler.scheduler_id, priority=1)
+        item = functions.create_task(
+            scheduler_id=self.scheduler.scheduler_id, organisation=self.organisation.id, priority=1
+        )

         self.scheduler.queue.maxsize = 1
@@ -147,7 +158,9 @@ def test_push_item_to_queue_full(self):

     def test_push_item_to_queue_invalid(self):
         # Arrange
-        item = functions.create_item(scheduler_id=self.scheduler.scheduler_id, priority=1)
+        item = functions.create_task(
+            scheduler_id=self.scheduler.scheduler_id, organisation=self.organisation.id, priority=1
+        )
         item.data = {"invalid": "data"}

         # Assert
@@ -156,16 +169,24 @@ def test_push_item_to_queue_invalid(self):

     def test_pop_item_from_queue(self):
         # Arrange
-        item = functions.create_item(scheduler_id=self.scheduler.scheduler_id, priority=1)
+        item = functions.create_task(
+            scheduler_id=self.scheduler.scheduler_id, organisation=self.organisation.id, priority=1
+        )

         self.scheduler.push_item_to_queue(item)

         # Act
-        popped_item = self.scheduler.pop_item_from_queue()
+        popped_items, count = self.scheduler.pop_item_from_queue()

         # Assert
         self.assertEqual(0, self.scheduler.queue.qsize())
-        self.assertEqual(item.id, popped_item.id)
+        self.assertEqual(1, count)
+        self.assertEqual(1, len(popped_items))
+        self.assertEqual(popped_items[0].id, item.id)
+
+        # Status should be dispatched
+        task_db = self.mock_ctx.datastores.task_store.get_task(str(item.id))
+        self.assertEqual(task_db.status, models.TaskStatus.DISPATCHED)

     def test_pop_item_from_queue_empty(self):
         self.assertEqual(0, self.scheduler.queue.qsize())
@@ -175,7 +196,9 @@ def test_pop_item_from_queue_empty(self):

     def test_post_push(self):
         """When a task is added to the queue, it should be added to the database"""
         # Arrange
-        item = functions.create_item(scheduler_id=self.scheduler.scheduler_id, priority=1)
+        item = functions.create_task(
+            scheduler_id=self.scheduler.scheduler_id, organisation=self.organisation.id, priority=1
+        )

         # Act
         self.scheduler.push_item_to_queue(item)
@@ -207,7 +230,9 @@ def test_post_push(self):

     def test_post_push_schedule_enabled(self):
         # Arrange
-        item = functions.create_item(scheduler_id=self.scheduler.scheduler_id, priority=1)
+        item = functions.create_task(
+            scheduler_id=self.scheduler.scheduler_id, organisation=self.organisation.id, priority=1
+        )

         # Act
         self.scheduler.push_item_to_queue(item)
@@ -237,35 +262,11 @@ def test_post_push_schedule_enabled(self):
         # grace period
         self.assertGreater(schedule_db.deadline_at, datetime.now(timezone.utc))

-    def test_post_push_schedule_disabled(self):
-        # Arrange
-        first_item = functions.create_item(scheduler_id=self.scheduler.scheduler_id, priority=1)
-
-        # Act
-        first_item_db = self.scheduler.push_item_to_queue(first_item)
-
-        initial_schedule_db = self.mock_ctx.datastores.schedule_store.get_schedule(first_item_db.schedule_id)
-
-        # Pop
-        self.scheduler.pop_item_from_queue()
-
-        # Disable this schedule
-        initial_schedule_db.enabled = False
-        self.mock_ctx.datastores.schedule_store.update_schedule(initial_schedule_db)
-
-        # Act
-        second_item = first_item_db.model_copy()
-        second_item.id = uuid.uuid4()
-        second_item_db = self.scheduler.push_item_to_queue(second_item)
-
-        with capture_logs() as cm:
-            self.scheduler.post_push(second_item_db)
-
-        self.assertIn("is disabled, not updating deadline", cm[-1].get("event"))
-
     def test_post_push_schedule_update_schedule(self):
         # Arrange
-        first_item = functions.create_item(scheduler_id=self.scheduler.scheduler_id, priority=1)
+        first_item = functions.create_task(
+            scheduler_id=self.scheduler.scheduler_id, organisation=self.organisation.id, priority=1
+        )

         # Act
         first_item_db = self.scheduler.push_item_to_queue(first_item)
@@ -294,10 +295,16 @@ def test_post_push_schedule_update_schedule(self):

     def test_post_push_schedule_is_not_none(self):
         """When a schedule is provided, it should be used to set the deadline"""
         # Arrange
-        first_item = functions.create_item(scheduler_id=self.scheduler.scheduler_id, priority=1)
+        first_item = functions.create_task(
+            scheduler_id=self.scheduler.scheduler_id, organisation=self.organisation.id, priority=1
+        )

         schedule = models.Schedule(
-            scheduler_id=self.scheduler.scheduler_id, schedule="0 0 * * *", hash=first_item.hash, data=first_item.data
+            scheduler_id=self.scheduler.scheduler_id,
+            organisation=self.organisation.id,
+            schedule="0 0 * * *",
+            hash=first_item.hash,
+            data=first_item.data,
         )

         schedule_db = self.mock_ctx.datastores.schedule_store.create_schedule(schedule)
@@ -317,10 +324,8 @@ def test_post_push_schedule_is_not_none(self):

     def test_post_pop(self):
         """When a task is popped from the queue, it should be removed from the database"""
         # Arrange
-        item = functions.create_item(
-            scheduler_id=self.scheduler.scheduler_id,
-            priority=1,
-            task=functions.create_task(self.scheduler.scheduler_id),
+        item = functions.create_task(
+            scheduler_id=self.scheduler.scheduler_id, organisation=self.organisation.id, priority=1
         )

         # Act
@@ -344,101 +349,3 @@ def test_post_pop(self):
         task_db = self.mock_ctx.datastores.task_store.get_task(str(item.id))
         self.assertEqual(task_db.id, item.id)
         self.assertEqual(task_db.status, models.TaskStatus.DISPATCHED)
-
-    def test_disable_scheduler(self):
-        # Arrange: start scheduler
-        self.scheduler.run()
-
-        # Arrange: add tasks
-        item = functions.create_item(scheduler_id=self.scheduler.scheduler_id, priority=1)
-        self.scheduler.push_item_to_queue(item)
-
-        # Assert: task should be on priority queue
-        pq_item = self.scheduler.queue.peek(0)
-        self.assertEqual(1, self.scheduler.queue.qsize())
-        self.assertEqual(pq_item.id, item.id)
-
-        # Assert: task should be in datastore, and queued
-        task_db = self.mock_ctx.datastores.task_store.get_task(str(item.id))
-        self.assertEqual(task_db.id, item.id)
-        self.assertEqual(task_db.status, models.TaskStatus.QUEUED)
-
-        # Assert: listeners should be running
-        self.assertGreater(len(self.scheduler.listeners), 0)
-
-        # Assert: threads should be running
-        self.assertGreater(len(self.scheduler.threads), 0)
-
-        # Act
-        self.scheduler.disable()
-
-        # Listeners should be stopped
-        self.assertEqual(0, len(self.scheduler.listeners))
-
-        # Threads should be stopped
-        self.assertEqual(0, len(self.scheduler.threads))
-
-        # Queue should be empty
-        self.assertEqual(0, self.scheduler.queue.qsize())
-
-        # All tasks on queue should be set to CANCELLED
-        tasks, _ = self.mock_ctx.datastores.task_store.get_tasks(self.scheduler.scheduler_id)
-        for task in tasks:
-            self.assertEqual(task.status, models.TaskStatus.CANCELLED)
-
-        # Scheduler should be disabled
-        self.assertFalse(self.scheduler.is_enabled())
-
-        with self.assertRaises(NotAllowedError):
-            self.scheduler.push_item_to_queue(item)
-
-    def test_enable_scheduler(self):
-        # Arrange: start scheduler
-        self.scheduler.run()
-
-        # Arrange: add tasks
-        item = functions.create_item(scheduler_id=self.scheduler.scheduler_id, priority=1)
-        self.scheduler.push_item_to_queue(item)
-
-        # Assert: listeners should be running
-        self.assertGreater(len(self.scheduler.listeners), 0)
-
-        # Assert: threads should be running
-        self.assertGreater(len(self.scheduler.threads), 0)
-
-        # Disable scheduler first
-        self.scheduler.disable()
-
-        # Listeners should be stopped
-        self.assertEqual(0, len(self.scheduler.listeners))
-
-        # Threads should be stopped
-        self.assertEqual(0, len(self.scheduler.threads))
-
-        # Queue should be empty
-        self.assertEqual(0, self.scheduler.queue.qsize())
-
-        # All tasks on queue should be set to CANCELLED
-        tasks, _ = self.mock_ctx.datastores.task_store.get_tasks(self.scheduler.scheduler_id)
-        for task in tasks:
-            self.assertEqual(task.status, models.TaskStatus.CANCELLED)
-
-        # Re-enable scheduler
-        self.scheduler.enable()
-
-        # Threads should be started
-        self.assertGreater(len(self.scheduler.threads), 0)
-
-        # Scheduler should be enabled
-        self.assertTrue(self.scheduler.is_enabled())
-
-        # Push item to the queue
-        self.scheduler.push_item_to_queue(item)
-
-        # Assert: task should be on priority queue
-        pq_item = self.scheduler.queue.peek(0)
-        self.assertEqual(1, self.scheduler.queue.qsize())
-        self.assertEqual(pq_item.id, item.id)
-
-        # Stop the scheduler
-        self.scheduler.stop()
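Reviewer note: `pop_item_from_queue` (and `PriorityQueue.pop` in the unit tests below) now returns an `(items, count)` tuple instead of a single task, and popped tasks are persisted as `DISPATCHED`. A self-contained sketch of the batch-pop shape; this is a toy heap, not the scheduler's real queue:

```python
# Toy priority queue illustrating the new (items, count) pop contract.
import heapq


class BatchPQ:
    def __init__(self) -> None:
        self._heap: list[tuple[int, str]] = []

    def push(self, priority: int, task_id: str) -> None:
        heapq.heappush(self._heap, (priority, task_id))

    def pop(self, limit: int = 1) -> tuple[list[str], int]:
        # Lowest priority value wins, matching test_pop_highest_priority below.
        items = [heapq.heappop(self._heap)[1] for _ in range(min(limit, len(self._heap)))]
        return items, len(items)


pq = BatchPQ()
pq.push(1, "task-a")
pq.push(2, "task-b")
items, count = pq.pop()
assert (items, count) == (["task-a"], 1)
```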
diff --git a/mula/tests/integration/test_task_store.py b/mula/tests/integration/test_task_store.py
index c672fc78557..30d7cc81857 100644
--- a/mula/tests/integration/test_task_store.py
+++ b/mula/tests/integration/test_task_store.py
@@ -37,14 +37,14 @@ def tearDown(self):
         self.dbconn.engine.dispose()

     def test_create_task(self):
-        task = functions.create_task(scheduler_id=self.organisation.id)
+        task = functions.create_task(scheduler_id=self.organisation.id, organisation=self.organisation.id)
         created_task = self.mock_ctx.datastores.task_store.create_task(task)
         self.assertIsNotNone(created_task)

     def test_get_tasks(self):
         # Arrange
         for i in range(5):
-            task = functions.create_task(scheduler_id=self.organisation.id)
+            task = functions.create_task(scheduler_id=self.organisation.id, organisation=self.organisation.id)
             self.mock_ctx.datastores.task_store.create_task(task)

         # Act
@@ -57,7 +57,7 @@ def test_get_tasks(self):
     def get_tasks_by_type(self):
         # Arrange
         for i in range(5):
-            task = functions.create_task(scheduler_id=self.organisation.id)
+            task = functions.create_task(scheduler_id=self.organisation.id, organisation=self.organisation.id)
             self.mock_ctx.datastores.task_store.create_task(task)

         # Act
@@ -74,7 +74,9 @@ def test_get_tasks_by_hash(self):
         hashes = []
         data = functions.create_test_model()
         for i in range(5):
-            task = functions.create_task(scheduler_id=self.organisation.id, data=data)
+            task = functions.create_task(
+                scheduler_id=self.organisation.id, organisation=self.organisation.id, data=data
+            )
             self.mock_ctx.datastores.task_store.create_task(task)
             hashes.append(task.hash)

@@ -89,7 +91,7 @@ def test_get_tasks_by_hash(self):

     def test_get_task(self):
         # Arrange
-        task = functions.create_task(scheduler_id=self.organisation.id)
+        task = functions.create_task(scheduler_id=self.organisation.id, organisation=self.organisation.id)
         created_task = self.mock_ctx.datastores.task_store.create_task(task)

         # Act
@@ -103,7 +105,9 @@ def test_get_latest_task_by_hash(self):
         hashes = []
         data = functions.create_test_model()
         for i in range(5):
-            task = functions.create_task(scheduler_id=self.organisation.id, data=data)
+            task = functions.create_task(
+                scheduler_id=self.organisation.id, organisation=self.organisation.id, data=data
+            )
             self.mock_ctx.datastores.task_store.create_task(task)
             hashes.append(task.hash)

@@ -118,7 +122,7 @@ def test_get_latest_task_by_hash(self):

     def test_update_task(self):
         # Arrange
-        task = functions.create_task(scheduler_id=self.organisation.id)
+        task = functions.create_task(scheduler_id=self.organisation.id, organisation=self.organisation.id)
         created_task = self.mock_ctx.datastores.task_store.create_task(task)

         # Act
@@ -131,7 +135,7 @@ def test_update_task(self):

     def test_cancel_task(self):
         # Arrange
-        task = functions.create_task(scheduler_id=self.organisation.id)
+        task = functions.create_task(scheduler_id=self.organisation.id, organisation=self.organisation.id)
         created_task = self.mock_ctx.datastores.task_store.create_task(task)

         # Act
@@ -163,6 +167,7 @@ def test_get_status_counts(self):
                 data = functions.create_test_model()
                 task = models.Task(
                     scheduler_id=self.organisation.id,
+                    organisation=self.organisation.id,
                     priority=1,
                     status=status,
                     type=functions.TestModel.type,
@@ -203,6 +208,7 @@ def test_get_status_count_per_hour(self):
                 data = functions.create_test_model()
                 task = models.Task(
                     scheduler_id=self.organisation.id,
+                    organisation=self.organisation.id,
                     priority=1,
                     status=status,
                     type=functions.TestModel.type,
diff --git a/mula/tests/unit/test_queue.py b/mula/tests/unit/test_queue.py
index 2861d442257..c55e6a7947b 100644
--- a/mula/tests/unit/test_queue.py
+++ b/mula/tests/unit/test_queue.py
@@ -43,7 +43,7 @@ def _check_queue_empty(self):

     def test_push(self):
         """When adding an item to the priority queue, the item should be added"""
-        item = functions.create_item(scheduler_id=self.pq.pq_id, priority=1)
+        item = functions.create_task(scheduler_id=self.pq.pq_id, organisation=self.pq.pq_id, priority=1)

         self.pq.push(item)

         item_db = self.pq_store.get(self.pq.pq_id, item.id)
@@ -57,7 +57,7 @@ def test_push_item_not_found_in_db(self, mock_push):
         """When adding an item to the priority queue, but the item is not
         found in the database, the item shouldn't be added.
         """
-        item = functions.create_item(scheduler_id=self.pq.pq_id, priority=1)
+        item = functions.create_task(scheduler_id=self.pq.pq_id, organisation=self.pq.pq_id, priority=1)

         mock_push.return_value = None
@@ -84,7 +84,7 @@ def test_push_invalid_item(self):
         """When pushing an item that can not be validated, the item shouldn't
         be pushed.
         """
-        item = functions.create_item(scheduler_id=self.pq.pq_id, priority=1)
+        item = functions.create_task(scheduler_id=self.pq.pq_id, organisation=self.pq.pq_id, priority=1)
         item.data = {"invalid": "data"}

         with self.assertRaises(InvalidItemError):
@@ -100,7 +100,7 @@ def test_push_replace_not_allowed(self):
         self.pq.allow_replace = False

         # Add an item to the queue
-        initial_item = functions.create_item(scheduler_id=self.pq.pq_id, priority=1)
+        initial_item = functions.create_task(scheduler_id=self.pq.pq_id, organisation=self.pq.pq_id, priority=1)
         self.pq.push(initial_item)
         self.assertEqual(1, self.pq.qsize())
@@ -119,7 +119,7 @@ def test_push_replace_allowed(self):
         self.pq.allow_replace = True

         # Add an item to the queue
-        initial_item = functions.create_item(scheduler_id=self.pq.pq_id, priority=1)
+        initial_item = functions.create_task(scheduler_id=self.pq.pq_id, organisation=self.pq.pq_id, priority=1)
         self.pq.push(initial_item)
         self.assertEqual(1, self.pq.qsize())
@@ -139,7 +139,7 @@ def test_push_updates_not_allowed(self):
         self.pq.allow_updates = False

         # Add an item to the queue
-        initial_item = functions.create_item(scheduler_id=self.pq.pq_id, priority=1)
+        initial_item = functions.create_task(scheduler_id=self.pq.pq_id, organisation=self.pq.pq_id, priority=1)
         self.pq.push(initial_item)
         self.assertEqual(1, self.pq.qsize())
@@ -164,7 +164,7 @@ def test_push_updates_allowed(self):
         self.pq.allow_updates = True

         # Add an item to the queue
-        initial_item = functions.create_item(scheduler_id=self.pq.pq_id, priority=1)
+        initial_item = functions.create_task(scheduler_id=self.pq.pq_id, organisation=self.pq.pq_id, priority=1)
         self.pq.push(initial_item)
         self.assertEqual(1, self.pq.qsize())
@@ -189,7 +189,7 @@ def test_push_priority_updates_not_allowed(self):
         self.pq.allow_priority_updates = False

         # Add an item to the queue
-        initial_item = functions.create_item(scheduler_id=self.pq.pq_id, priority=1)
+        initial_item = functions.create_task(scheduler_id=self.pq.pq_id, organisation=self.pq.pq_id, priority=1)
         self.pq.push(initial_item)
         self.assertEqual(1, self.pq.qsize())
@@ -215,7 +215,7 @@ def test_push_priority_updates_allowed(self):
         self.pq.allow_priority_updates = True

         # Add an item to the queue
-        initial_item = functions.create_item(scheduler_id=self.pq.pq_id, priority=1)
+        initial_item = functions.create_task(scheduler_id=self.pq.pq_id, organisation=self.pq.pq_id, priority=1)
         self.pq.push(initial_item)
         self.assertEqual(1, self.pq.qsize())
@@ -237,7 +237,7 @@ def test_remove_item(self):
         removed, and the item should be removed from the entry_finder.
""" # Add an item to the queue - item = functions.create_item(scheduler_id=self.pq.pq_id, priority=1) + item = functions.create_task(scheduler_id=self.pq.pq_id, organisation=self.pq.pq_id, priority=1) self.pq.push(item) self.assertEqual(1, self.pq.qsize()) @@ -255,11 +255,11 @@ def test_push_maxsize_not_allowed(self): self.pq.maxsize = 1 # Add an item to the queue - first_item = functions.create_item(scheduler_id=self.pq.pq_id, priority=1) + first_item = functions.create_task(scheduler_id=self.pq.pq_id, organisation=self.pq.pq_id, priority=1) self.pq.push(first_item) # Add another item to the queue - second_item = functions.create_item(scheduler_id=self.pq.pq_id, priority=2) + second_item = functions.create_task(scheduler_id=self.pq.pq_id, organisation=self.pq.pq_id, priority=2) with self.assertRaises(_queue.Full): self.pq.push(second_item) @@ -280,11 +280,11 @@ def test_push_maxsize_allowed(self): self.pq.maxsize = 0 # Add an item to the queue - first_item = functions.create_item(scheduler_id=self.pq.pq_id, priority=1) + first_item = functions.create_task(scheduler_id=self.pq.pq_id, organisation=self.pq.pq_id, priority=1) self.pq.push(first_item) # Add another item to the queue - second_item = functions.create_item(scheduler_id=self.pq.pq_id, priority=2) + second_item = functions.create_task(scheduler_id=self.pq.pq_id, organisation=self.pq.pq_id, priority=2) self.pq.push(second_item) # The queue should now have 2 items @@ -310,11 +310,11 @@ def test_push_maxsize_allowed_high_priority(self): self.pq.maxsize = 1 # Add an item to the queue - first_item = functions.create_item(scheduler_id=self.pq.pq_id, priority=1) + first_item = functions.create_task(scheduler_id=self.pq.pq_id, organisation=self.pq.pq_id, priority=1) self.pq.push(first_item) # Add another item to the queue - second_item = functions.create_item(scheduler_id=self.pq.pq_id, priority=1) + second_item = functions.create_task(scheduler_id=self.pq.pq_id, organisation=self.pq.pq_id, priority=1) self.pq.push(second_item) # The queue should now have 2 items @@ -340,11 +340,11 @@ def test_push_maxsize_not_allowed_low_priority(self): self.pq.maxsize = 1 # Add an item to the queue - first_item = functions.create_item(scheduler_id=self.pq.pq_id, priority=1) + first_item = functions.create_task(scheduler_id=self.pq.pq_id, organisation=self.pq.pq_id, priority=1) self.pq.push(first_item) # Add another item to the queue - second_item = functions.create_item(scheduler_id=self.pq.pq_id, priority=2) + second_item = functions.create_task(scheduler_id=self.pq.pq_id, organisation=self.pq.pq_id, priority=2) with self.assertRaises(_queue.Full): self.pq.push(second_item) @@ -362,15 +362,15 @@ def test_pop(self): it from the queue. """ # Add an item to the queue - first_item = functions.create_item(scheduler_id=self.pq.pq_id, priority=1) + first_item = functions.create_task(scheduler_id=self.pq.pq_id, organisation=self.pq.pq_id, priority=1) self.pq.push(first_item) # The queue should now have 1 item self.assertEqual(1, self.pq.qsize()) # Pop the item - popped_item = self.pq.pop() - self.assertEqual(first_item.data, popped_item.data) + popped_items, _ = self.pq.pop() + self.assertEqual(first_item.data, popped_items[0].data) # The queue should now be empty self.assertEqual(0, self.pq.qsize()) @@ -380,8 +380,8 @@ def test_pop_with_lock(self): thread to pop an item. 
""" # Arrange - first_item = functions.create_item(scheduler_id=self.pq.pq_id, priority=1) - second_item = functions.create_item(scheduler_id=self.pq.pq_id, priority=1) + first_item = functions.create_task(scheduler_id=self.pq.pq_id, organisation=self.pq.pq_id, priority=1) + second_item = functions.create_task(scheduler_id=self.pq.pq_id, organisation=self.pq.pq_id, priority=1) self.pq.push(first_item) self.pq.push(second_item) @@ -392,21 +392,21 @@ def test_pop_with_lock(self): # it will set a timeout so we can test the lock. def first_pop(event): with self.pq.lock: - item = self.pq_store.pop(self.pq.pq_id, None) + items, _ = self.pq_store.pop(self.pq.pq_id, None) event.set() time.sleep(5) - self.pq_store.remove(self.pq.pq_id, item.id) + self.pq_store.remove(self.pq.pq_id, items[0].id) - queue.put(item) + queue.put(items[0]) def second_pop(event): # Wait for thread 1 to set the event before continuing event.wait() - item = self.pq.pop() - queue.put(item) + items, _ = self.pq.pop() + queue.put(items[0]) # Act; with thread 1 we will create a lock on the queue, and then with # thread 2 we try to pop an item while the lock is active. @@ -430,8 +430,8 @@ def test_pop_without_lock(self): NOTE: Here we test the procedure when a lock isn't set. """ # Arrange - first_item = functions.create_item(scheduler_id=self.pq.pq_id, priority=1) - second_item = functions.create_item(scheduler_id=self.pq.pq_id, priority=1) + first_item = functions.create_task(scheduler_id=self.pq.pq_id, organisation=self.pq.pq_id, priority=1) + second_item = functions.create_task(scheduler_id=self.pq.pq_id, organisation=self.pq.pq_id, priority=1) self.pq.push(first_item) self.pq.push(second_item) @@ -441,21 +441,21 @@ def test_pop_without_lock(self): # This function is similar to the pop() function of the queue, but # it will set a timeout. We have omitted the lock here. def first_pop(event): - item = self.pq_store.pop(self.pq.pq_id, None) + items, _ = self.pq_store.pop(self.pq.pq_id, None) event.set() time.sleep(5) - self.pq_store.remove(self.pq.pq_id, item.id) + self.pq_store.remove(self.pq.pq_id, items[0].id) - queue.put(item) + queue.put(items[0]) def second_pop(event): # Wait for thread 1 to set the event before continuing event.wait() - item = self.pq.pop() - queue.put(item) + items, _ = self.pq.pop() + queue.put(items[0]) # Act; with thread 1 we won't create a lock, and then with thread 2 we # try to pop an item while the timeout is active. @@ -484,26 +484,26 @@ def test_pop_highest_priority(self): priority """ # Add an item to the queue - first_item = functions.create_item(scheduler_id=self.pq.pq_id, priority=1) + first_item = functions.create_task(scheduler_id=self.pq.pq_id, organisation=self.pq.pq_id, priority=1) self.pq.push(first_item) # Add another item to the queue - second_item = functions.create_item(scheduler_id=self.pq.pq_id, priority=2) + second_item = functions.create_task(scheduler_id=self.pq.pq_id, organisation=self.pq.pq_id, priority=2) self.pq.push(second_item) # The queue should now have 2 items self.assertEqual(2, self.pq.qsize()) # Pop the item - popped_item = self.pq.pop() - self.assertEqual(first_item.priority, popped_item.priority) + popped_items, _ = self.pq.pop() + self.assertEqual(first_item.priority, popped_items[0].priority) def test_is_item_on_queue(self): """When checking if an item is on the queue, it should return True if the item is on the queue, and False if it isn't. 
""" # Add an item to the queue - item = functions.create_item(scheduler_id=self.pq.pq_id, priority=1) + item = functions.create_task(scheduler_id=self.pq.pq_id, organisation=self.pq.pq_id, priority=1) self.pq.push(item) # Check if the item is on the queue @@ -514,7 +514,7 @@ def test_is_item_not_on_queue(self): the item is on the queue, and False if it isn't. """ # Add an item to the queue - item = functions.create_item(scheduler_id=self.pq.pq_id, priority=1) + item = functions.create_task(scheduler_id=self.pq.pq_id, organisation=self.pq.pq_id, priority=1) # Check if the item is on the queue self.assertFalse(self.pq.is_item_on_queue(item)) diff --git a/mula/tests/utils/functions.py b/mula/tests/utils/functions.py index 8eeeb875d2d..506cbf0bf52 100644 --- a/mula/tests/utils/functions.py +++ b/mula/tests/utils/functions.py @@ -34,23 +34,11 @@ def create_test_model() -> TestModel: return TestModel(id=uuid.uuid4().hex, name=uuid.uuid4().hex) -def create_task_in(priority: int, data: TestModel | None = None) -> str: +def create_task_in(priority: int, organisation: str, data: TestModel | None = None) -> str: if data is None: data = TestModel(id=uuid.uuid4().hex, name=uuid.uuid4().hex) - return json.dumps({"priority": priority, "data": data.model_dump()}) - - -def create_item(scheduler_id: str, priority: int, task: models.Task | None = None) -> models.Task: - if task is None: - task = create_task(scheduler_id) - - item = models.Task(**task.model_dump()) - - if priority is not None: - item.priority = priority - - return item + return json.dumps({"priority": priority, "organisation": organisation, "data": data.model_dump()}) def create_schedule(scheduler_id: str, data: Any | None = None) -> models.Schedule: @@ -58,11 +46,18 @@ def create_schedule(scheduler_id: str, data: Any | None = None) -> models.Schedu return models.Schedule(scheduler_id=scheduler_id, hash=item.hash, data=item.model_dump()) -def create_task(scheduler_id: str, data: Any | None = None) -> models.Task: +def create_task(scheduler_id: str, organisation: str, priority: int = 0, data: Any | None = None) -> models.Task: if data is None: data = TestModel(id=uuid.uuid4().hex, name=uuid.uuid4().hex) - return models.Task(scheduler_id=scheduler_id, type=TestModel.type, hash=data.hash, data=data.model_dump()) + return models.Task( + scheduler_id=scheduler_id, + organisation=organisation, + priority=priority, + type=TestModel.type, + hash=data.hash, + data=data.model_dump(), + ) def create_boefje() -> models.Boefje: diff --git a/octopoes/octopoes/events/manager.py b/octopoes/octopoes/events/manager.py index dc935b4c8a4..ba1918ce4ec 100644 --- a/octopoes/octopoes/events/manager.py +++ b/octopoes/octopoes/events/manager.py @@ -26,6 +26,7 @@ class ScanProfileMutation(BaseModel): operation: OperationType primary_key: str value: AbstractOOI | None = None + client_id: str thread_local = threading.local() @@ -126,7 +127,9 @@ def _publish(self, event: DBEvent) -> None: ) # publish mutations - mutation = ScanProfileMutation(operation=event.operation_type, primary_key=event.primary_key) + mutation = ScanProfileMutation( + operation=event.operation_type, primary_key=event.primary_key, client_id=event.client + ) if event.operation_type != OperationType.DELETE: mutation.value = AbstractOOI( @@ -137,7 +140,7 @@ def _publish(self, event: DBEvent) -> None: self.channel.basic_publish( "", - f"{event.client}__scan_profile_mutations", + "scan_profile_mutations", mutation.model_dump_json().encode(), 
diff --git a/octopoes/octopoes/events/manager.py b/octopoes/octopoes/events/manager.py
index dc935b4c8a4..ba1918ce4ec 100644
--- a/octopoes/octopoes/events/manager.py
+++ b/octopoes/octopoes/events/manager.py
@@ -26,6 +26,7 @@ class ScanProfileMutation(BaseModel):
     operation: OperationType
     primary_key: str
     value: AbstractOOI | None = None
+    client_id: str


 thread_local = threading.local()
@@ -126,7 +127,9 @@ def _publish(self, event: DBEvent) -> None:
         )

         # publish mutations
-        mutation = ScanProfileMutation(operation=event.operation_type, primary_key=event.primary_key)
+        mutation = ScanProfileMutation(
+            operation=event.operation_type, primary_key=event.primary_key, client_id=event.client
+        )

         if event.operation_type != OperationType.DELETE:
             mutation.value = AbstractOOI(
@@ -137,7 +140,7 @@ def _publish(self, event: DBEvent) -> None:

         self.channel.basic_publish(
             "",
-            f"{event.client}__scan_profile_mutations",
+            "scan_profile_mutations",
             mutation.model_dump_json().encode(),
             properties=pika.BasicProperties(delivery_mode=pika.DeliveryMode.Persistent),
         )
@@ -165,4 +168,4 @@ def _try_connect(self):
     def _connect(self) -> None:
         self.channel = self.channel_factory(self.queue_uri)
         self.channel.queue_declare(queue=f"{self.client}__scan_profile_increments", durable=True)
-        self.channel.queue_declare(queue=f"{self.client}__scan_profile_mutations", durable=True)
+        self.channel.queue_declare(queue="scan_profile_mutations", durable=True)
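Reviewer note: the per-organisation `{client}__scan_profile_mutations` queues collapse into one shared durable queue, with the organisation carried as `client_id` in the `ScanProfileMutation` payload instead. A hedged sketch of what a consumer of the shared queue might look like; only the queue name and the payload fields come from this diff, the callback body and routing are illustrative:

```python
import json

import pika

# Standard pika setup; broker address is an assumption for the example.
connection = pika.BlockingConnection(pika.ConnectionParameters("localhost"))
channel = connection.channel()
channel.queue_declare(queue="scan_profile_mutations", durable=True)


def on_mutation(ch, method, properties, body):
    mutation = json.loads(body)
    # client_id (added to ScanProfileMutation above) replaces the queue name
    # as the per-organisation discriminator.
    print(mutation["client_id"], mutation["operation"], mutation["primary_key"])
    ch.basic_ack(delivery_tag=method.delivery_tag)


channel.basic_consume(queue="scan_profile_mutations", on_message_callback=on_mutation)
channel.start_consuming()
```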