From bd52dc7de7f1fe10b1b1096140c19633c23f1d75 Mon Sep 17 00:00:00 2001 From: Simon Mazenoux Date: Mon, 11 Sep 2023 10:38:21 +0200 Subject: [PATCH] feat: implement kill and delete endpoints --- run_local.sh | 1 + setup.cfg | 1 + src/diracx/cli/internal.py | 4 +- src/diracx/core/config/schema.py | 7 +- src/diracx/core/models.py | 2 +- src/diracx/db/__init__.py | 4 +- src/diracx/db/jobs/db.py | 328 +++++++++++++++++++++ src/diracx/db/jobs/schema.py | 99 +++++++ src/diracx/routers/dependencies.py | 3 + src/diracx/routers/job_manager/__init__.py | 179 ++++++++++- tests/conftest.py | 1 + 11 files changed, 620 insertions(+), 9 deletions(-) diff --git a/run_local.sh b/run_local.sh index 5fe2f2e56..74bee6daf 100755 --- a/run_local.sh +++ b/run_local.sh @@ -17,6 +17,7 @@ export DIRACX_CONFIG_BACKEND_URL="git+file://${tmp_dir}/cs_store/initialRepo" export DIRACX_DB_URL_AUTHDB="sqlite+aiosqlite:///:memory:" export DIRACX_DB_URL_JOBDB="sqlite+aiosqlite:///:memory:" export DIRACX_DB_URL_JOBLOGGINGDB="sqlite+aiosqlite:///:memory:" +export DIRACX_DB_URL_TASKQUEUEDB="sqlite+aiosqlite:///:memory:" export DIRACX_SERVICE_AUTH_TOKEN_KEY="file://${tmp_dir}/signing-key/rs256.key" export DIRACX_SERVICE_AUTH_ALLOWED_REDIRECTS='["http://'$(hostname| tr -s '[:upper:]' '[:lower:]')':8000/docs/oauth2-redirect"]' diff --git a/setup.cfg b/setup.cfg index f70bd5d1d..45e432ba2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -78,6 +78,7 @@ diracx.dbs = JobDB = diracx.db:JobDB JobLoggingDB = diracx.db:JobLoggingDB SandboxMetadataDB = diracx.db:SandboxMetadataDB + TaskQueueDB = diracx.db:TaskQueueDB #DummyDB = diracx.db:DummyDB diracx.services = jobs = diracx.routers.job_manager:router diff --git a/src/diracx/cli/internal.py b/src/diracx/cli/internal.py index 75b9ab0c8..855a3b7dd 100644 --- a/src/diracx/cli/internal.py +++ b/src/diracx/cli/internal.py @@ -48,9 +48,7 @@ def generate_cs( DefaultGroup=user_group, Users={}, Groups={ - user_group: GroupConfig( - JobShare=None, Properties=["NormalUser"], 
Quota=None, Users=[] - ) + user_group: GroupConfig(Properties=["NormalUser"], Quota=None, Users=[]) }, ) config = Config( diff --git a/src/diracx/core/config/schema.py b/src/diracx/core/config/schema.py index 8f3470e05..1ee61f6d5 100644 --- a/src/diracx/core/config/schema.py +++ b/src/diracx/core/config/schema.py @@ -49,7 +49,7 @@ class GroupConfig(BaseModel): AutoAddVOMS: bool = False AutoUploadPilotProxy: bool = False AutoUploadProxy: bool = False - JobShare: Optional[int] + JobShare: int = 1000 Properties: list[SecurityProperty] Quota: Optional[int] Users: list[str] @@ -86,9 +86,14 @@ class JobMonitoringConfig(BaseModel): useESForJobParametersFlag: bool = False +class JobSchedulingConfig(BaseModel): + EnableSharesCorrection: bool = False + + class ServicesConfig(BaseModel): Catalogs: dict[str, Any] | None JobMonitoring: JobMonitoringConfig = JobMonitoringConfig() + JobScheduling: JobSchedulingConfig = JobSchedulingConfig() class OperationsConfig(BaseModel): diff --git a/src/diracx/core/models.py b/src/diracx/core/models.py index fa6048011..81753f34e 100644 --- a/src/diracx/core/models.py +++ b/src/diracx/core/models.py @@ -31,7 +31,7 @@ class SortSpec(TypedDict): class ScalarSearchSpec(TypedDict): parameter: str operator: ScalarSearchOperator - value: str + value: str | int class VectorSearchSpec(TypedDict): diff --git a/src/diracx/db/__init__.py b/src/diracx/db/__init__.py index 3dd13c3c9..e80409d69 100644 --- a/src/diracx/db/__init__.py +++ b/src/diracx/db/__init__.py @@ -1,7 +1,7 @@ -__all__ = ("AuthDB", "JobDB", "JobLoggingDB", "SandboxMetadataDB") +__all__ = ("AuthDB", "JobDB", "JobLoggingDB", "SandboxMetadataDB", "TaskQueueDB") from .auth.db import AuthDB -from .jobs.db import JobDB, JobLoggingDB +from .jobs.db import JobDB, JobLoggingDB, TaskQueueDB from .sandbox_metadata.db import SandboxMetadataDB # from .dummy.db import DummyDB diff --git a/src/diracx/db/jobs/db.py b/src/diracx/db/jobs/db.py index 388ec6be2..9e2467339 100644 --- 
a/src/diracx/db/jobs/db.py +++ b/src/diracx/db/jobs/db.py @@ -9,16 +9,27 @@ from diracx.core.exceptions import InvalidQueryError from diracx.core.models import JobStatusReturn, LimitedJobStatusReturn +from diracx.core.properties import JOB_SHARING, SecurityProperty from diracx.core.utils import JobStatus from ..utils import BaseDB, apply_search_filters from .schema import ( + BannedSitesQueue, + GridCEsQueue, InputData, + JobCommands, JobDBBase, JobJDLs, JobLoggingDBBase, Jobs, + JobsQueue, + JobTypesQueue, LoggingInfo, + PlatformsQueue, + SitesQueue, + TagsQueue, + TaskQueueDBBase, + TaskQueues, ) @@ -260,6 +271,27 @@ async def get_job_status(self, job_id: int) -> LimitedJobStatusReturn: **dict((await self.conn.execute(stmt)).one()._mapping) ) + async def set_job_command(self, job_id: int, command: str, arguments: str = ""): + """Store a command to be passed to the job together with the next heart beat""" + stmt = insert(JobCommands).values( + JobID=job_id, + Command=command, + Arguments=arguments, + ReceptionTime=datetime.now(tz=timezone.utc), + ) + await self.conn.execute(stmt) + + async def get_vo(self, job_id: int) -> str: + """ + Get the VO of the owner of the job + """ + # TODO: Consider having a VO column in the Jobs table + from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd + + stmt = select(JobJDLs.JDL).where(JobJDLs.JobID == job_id) + jdl = (await self.conn.execute(stmt)).scalar_one() + return ClassAd(jdl).getAttributeString("VirtualOrganisation") + MAGIC_EPOC_NUMBER = 1270000000 @@ -405,3 +437,299 @@ async def get_wms_time_stamps(self, job_id): result[event] = str(etime + MAGIC_EPOC_NUMBER) return result + + +class TaskQueueDB(BaseDB): + metadata = TaskQueueDBBase.metadata + + async def get_tq_id_for_job(self, job_id: int) -> int: + """ + Get the task queue info for a given job + """ + stmt = select(TaskQueues.TQId).where(JobsQueue.JobId == job_id) + return (await self.conn.execute(stmt)).scalar_one() + + async def 
get_owner_for_task_queue(self, tq_id: int) -> dict[str, str]: + """ + Get the owner and owner group for a task queue + """ + stmt = select(TaskQueues.Owner, TaskQueues.OwnerGroup, TaskQueues.VO).where( + TaskQueues.TQId == tq_id + ) + return dict((await self.conn.execute(stmt)).one()._mapping) + + async def delete_job(self, job_id: int): + """ + Delete a job from the task queues + Raises NoResultFound if the job_id is not found + """ + stmt = delete(JobsQueue).where(JobsQueue.JobId == job_id) + await self.conn.execute(stmt) + + async def delete_task_queue_if_empty( + self, + tq_id: int, + tq_owner: str, + tq_group: str, + job_share: int, + group_properties: list[SecurityProperty], + enable_shares_correction: bool, + allow_background_tqs: bool, + ): + """ + Try to delete a task queue if it's empty + """ + # Check if the task queue is empty + stmt = ( + select(TaskQueues.TQId) + .where(TaskQueues.Enabled >= 1) + .where(TaskQueues.TQId == tq_id) + .where(~TaskQueues.TQId.in_(select(JobsQueue.TQId))) + ) + rows = await self.conn.execute(stmt) + if not rows.rowcount: + return + + # Deleting the task queue (the other tables will be deleted in cascade) + stmt = delete(TaskQueues).where(TaskQueues.TQId == tq_id) + await self.conn.execute(stmt) + + await self.recalculate_tq_shares_for_entity( + tq_owner, + tq_group, + job_share, + group_properties, + enable_shares_correction, + allow_background_tqs, + ) + + async def recalculate_tq_shares_for_entity( + self, + owner: str, + group: str, + job_share: int, + group_properties: list[SecurityProperty], + enable_shares_correction: bool, + allow_background_tqs: bool, + ): + """ + Recalculate the shares for a user/userGroup combo + """ + if JOB_SHARING in group_properties: + # If group has JobSharing just set prio for that entry, user is irrelevant + return await self.__set_priorities_for_entity( + owner, group, job_share, group_properties, allow_background_tqs + ) + + stmt = ( + select(TaskQueues.Owner, 
func.count(TaskQueues.Owner)) + .where(TaskQueues.OwnerGroup == group) + .group_by(TaskQueues.Owner) + ) + rows = await self.conn.execute(stmt) + # make the rows a list of tuples + # Get owners in this group and the amount of times they appear + # TODO: I guess the rows are already a list of tuples + # maybe refactor + data = [(r[0], r[1]) for r in rows if r] + numOwners = len(data) + # If there are no owners do nothing + if numOwners == 0: + return + # Split the share amongst the number of owners + entities_shares = {row[0]: job_share / numOwners for row in data} + + # TODO: implement the following + # If corrector is enabled let it work its magic + # if enable_shares_correction: + # entities_shares = await self.__shares_corrector.correct_shares( + # entitiesShares, group=group + # ) + + # Keep updating + owners = dict(data) + # IF the user is already known and has more than 1 tq, the rest of the users don't need to be modified + # (The number of owners didn't change) + if owner in owners and owners[owner] > 1: + await self.__set_priorities_for_entity( + owner, + group, + entities_shares[owner], + group_properties, + allow_background_tqs, + ) + return + # Oops the number of owners may have changed so we recalculate the prio for all owners in the group + for owner in owners: + await self.__set_priorities_for_entity( + owner, + group, + entities_shares[owner], + group_properties, + allow_background_tqs, + ) + pass + + async def __set_priorities_for_entity( + self, + owner: str, + group: str, + share, + properties: list[SecurityProperty], + allow_background_tqs: bool, + ): + """ + Set the priority for a user/userGroup combo given a split share + """ + from DIRAC.WorkloadManagementSystem.DB.TaskQueueDB import ( + TQ_MIN_SHARE, + priorityIgnoredFields, + ) + + stmt = ( + select( + TaskQueues.TQId, + func.sum(JobsQueue.RealPriority) / func.count(JobsQueue.RealPriority), + ) + # TODO: uncomment me and understand why mypy is unhappy with join here and not elsewhere + #
.select_from(TaskQueues.join(JobsQueue, TaskQueues.TQId == JobsQueue.TQId)) + .where(TaskQueues.OwnerGroup == group).group_by(TaskQueues.TQId) + ) + if JOB_SHARING not in properties: + stmt = stmt.where(TaskQueues.Owner == owner) + rows = await self.conn.execute(stmt) + # Make a dict of TQId:priority + tqDict: dict[int, float] = {row[0]: row[1] for row in rows} + + if not tqDict: + return + + allowBgTQs = allow_background_tqs + + # TODO: one of the only places the logic could actually be encapsulated + # so refactor + + # Calculate Sum of priorities + totalPrio = 0.0 + for k in tqDict: + if tqDict[k] > 0.1 or not allowBgTQs: + totalPrio += tqDict[k] + # Update prio for each TQ + for tqId in tqDict: + if tqDict[tqId] > 0.1 or not allowBgTQs: + prio = (share / totalPrio) * tqDict[tqId] + else: + prio = TQ_MIN_SHARE + prio = max(prio, TQ_MIN_SHARE) + tqDict[tqId] = prio + + # Generate groups of TQs that will have the same prio=sum(prios) more or less + rows = await self.retrieve_task_queues(list(tqDict)) + # TODO: check the following assumption is correct + allTQsData = rows + tqGroups: dict[str, list] = {} + for tqid in allTQsData: + tqData = allTQsData[tqid] + for field in ("Jobs", "Priority") + priorityIgnoredFields: + if field in tqData: + tqData.pop(field) + tqHash = [] + for f in sorted(tqData): + tqHash.append(f"{f}:{tqData[f]}") + tqHash = "|".join(tqHash) + if tqHash not in tqGroups: + tqGroups[tqHash] = [] + tqGroups[tqHash].append(tqid) + tqGroups = [tqGroups[td] for td in tqGroups] + + # Do the grouping + for tqGroup in tqGroups: + totalPrio = 0 + if len(tqGroup) < 2: + continue + for tqid in tqGroup: + totalPrio += tqDict[tqid] + for tqid in tqGroup: + tqDict[tqid] = totalPrio + + # Group by priorities + prioDict: dict[int, list] = {} + for tqId in tqDict: + prio = tqDict[tqId] + if prio not in prioDict: + prioDict[prio] = [] + prioDict[prio].append(tqId) + + # Execute updates + for prio, tqs in prioDict.items(): + update_stmt = ( +
update(TaskQueues).where(TaskQueues.TQId.in_(tqs)).values(Priority=prio) + ) + await self.conn.execute(update_stmt) + + async def retrieve_task_queues(self, tqIdList=None): + """ + Get all the task queues + """ + if tqIdList is not None and not tqIdList: + # Empty list => Fast-track no matches + return {} + + stmt = ( + select( + TaskQueues.TQId, + TaskQueues.Priority, + func.count(JobsQueue.TQId).label("Jobs"), + TaskQueues.Owner, + TaskQueues.OwnerGroup, + TaskQueues.VO, + TaskQueues.CPUTime, + ) + .select_from(TaskQueues.join(JobsQueue, TaskQueues.TQId == JobsQueue.TQId)) + .select_from( + TaskQueues.join(SitesQueue, TaskQueues.TQId == SitesQueue.TQId) + ) + .select_from( + TaskQueues.join(GridCEsQueue, TaskQueues.TQId == GridCEsQueue.TQId) + ) + .group_by( + TaskQueues.TQId, + TaskQueues.Priority, + TaskQueues.Owner, + TaskQueues.OwnerGroup, + TaskQueues.VO, + TaskQueues.CPUTime, + ) + ) + if tqIdList is not None: + stmt = stmt.where(TaskQueues.TQId.in_(tqIdList)) + + tqData = dict(row._mapping for row in await self.conn.execute(stmt)) + # TODO: the line above should be equivalent to the following commented code, check this is the case + # for record in rows: + # tqId = record[0] + # tqData[tqId] = { + # "Priority": record[1], + # "Jobs": record[2], + # "Owner": record[3], + # "OwnerGroup": record[4], + # "VO": record[5], + # "CPUTime": record[6], + # } + + for tqId in tqData: + # TODO: maybe factorize this handy tuple list + for table, field in { + (SitesQueue, "Sites"), + (GridCEsQueue, "GridCEs"), + (BannedSitesQueue, "BannedSites"), + (PlatformsQueue, "Platforms"), + (JobTypesQueue, "JobTypes"), + (TagsQueue, "Tags"), + }: + stmt = select(table.Value).where(table.TQId == tqId) + tqData[tqId][field] = list( + row[0] for row in await self.conn.execute(stmt) + ) + + return tqData diff --git a/src/diracx/db/jobs/schema.py b/src/diracx/db/jobs/schema.py index d13cf5009..5f66e534d 100644 --- a/src/diracx/db/jobs/schema.py +++ b/src/diracx/db/jobs/schema.py @@ 
-1,7 +1,9 @@ import sqlalchemy.types as types from sqlalchemy import ( + Boolean, DateTime, Enum, + Float, ForeignKey, ForeignKeyConstraint, Index, @@ -17,6 +19,7 @@ JobDBBase = declarative_base() JobLoggingDBBase = declarative_base() +TaskQueueDBBase = declarative_base() class EnumBackedBool(types.TypeDecorator): @@ -198,3 +201,99 @@ class LoggingInfo(JobLoggingDBBase): StatusTimeOrder = Column(Numeric(precision=12, scale=3), default=0) StatusSource = Column(String(32), default="Unknown") __table_args__ = (PrimaryKeyConstraint("JobID", "SeqNum"),) + + +class TaskQueues(TaskQueueDBBase): + __tablename__ = "tq_TaskQueues" + TQId = Column(Integer, primary_key=True) + Owner = Column(String(255), nullable=False) + OwnerDN = Column(String(255)) + OwnerGroup = Column(String(32), nullable=False) + VO = Column(String(32), nullable=False) + CPUTime = Column(Integer, nullable=False) + Priority = Column(Float, nullable=False) + Enabled = Column(Boolean, nullable=False, default=0) + __table_args__ = (Index("TQOwner", "Owner", "OwnerGroup", "CPUTime"),) + + +class JobsQueue(TaskQueueDBBase): + __tablename__ = "tq_Jobs" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + JobId = Column(Integer, primary_key=True) + Priority = Column(Integer, nullable=False) + RealPriority = Column(Float, nullable=False) + __table_args__ = (Index("TaskIndex", "TQId"),) + + +class SitesQueue(TaskQueueDBBase): + __tablename__ = "tq_TQToSites" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value = Column(String(64), primary_key=True) + __table_args__ = ( + Index("SitesTaskIndex", "TQId"), + Index("SitesIndex", "Value"), + ) + + +class GridCEsQueue(TaskQueueDBBase): + __tablename__ = "tq_TQToGridCEs" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value = Column(String(64), primary_key=True) + __table_args__ = ( + 
Index("GridCEsTaskIndex", "TQId"), + Index("GridCEsValueIndex", "Value"), + ) + + +class BannedSitesQueue(TaskQueueDBBase): + __tablename__ = "tq_TQToBannedSites" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value = Column(String(64), primary_key=True) + __table_args__ = ( + Index("BannedSitesTaskIndex", "TQId"), + Index("BannedSitesValueIndex", "Value"), + ) + + +class PlatformsQueue(TaskQueueDBBase): + __tablename__ = "tq_TQToPlatforms" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value = Column(String(64), primary_key=True) + __table_args__ = ( + Index("PlatformsTaskIndex", "TQId"), + Index("PlatformsValueIndex", "Value"), + ) + + +class JobTypesQueue(TaskQueueDBBase): + __tablename__ = "tq_TQToJobTypes" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value = Column(String(64), primary_key=True) + __table_args__ = ( + Index("JobTypesTaskIndex", "TQId"), + Index("JobTypesValueIndex", "Value"), + ) + + +class TagsQueue(TaskQueueDBBase): + __tablename__ = "tq_TQToTags" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value = Column(String(64), primary_key=True) + __table_args__ = ( + Index("TagsTaskIndex", "TQId"), + Index("TagsValueIndex", "Value"), + ) diff --git a/src/diracx/routers/dependencies.py b/src/diracx/routers/dependencies.py index 17d4cf830..8f4211f2a 100644 --- a/src/diracx/routers/dependencies.py +++ b/src/diracx/routers/dependencies.py @@ -5,6 +5,7 @@ "AuthDB", "JobDB", "JobLoggingDB", + "TaskQueueDB", "add_settings_annotation", "AvailableSecurityProperties", ) @@ -19,6 +20,7 @@ from diracx.db import AuthDB as _AuthDB from diracx.db import JobDB as _JobDB from diracx.db import JobLoggingDB as _JobLoggingDB +from diracx.db import TaskQueueDB as _TaskQueueDB T = TypeVar("T") @@ -32,6 +34,7 @@ def 
add_settings_annotation(cls: T) -> T: AuthDB = Annotated[_AuthDB, Depends(_AuthDB.transaction)] JobDB = Annotated[_JobDB, Depends(_JobDB.transaction)] JobLoggingDB = Annotated[_JobLoggingDB, Depends(_JobLoggingDB.transaction)] +TaskQueueDB = Annotated[_TaskQueueDB, Depends(_TaskQueueDB.transaction)] # Miscellaneous Config = Annotated[_Config, Depends(ConfigSource.create)] diff --git a/src/diracx/routers/job_manager/__init__.py b/src/diracx/routers/job_manager/__init__.py index 2f06a4756..bf4636cb6 100644 --- a/src/diracx/routers/job_manager/__init__.py +++ b/src/diracx/routers/job_manager/__init__.py @@ -6,7 +6,7 @@ from http import HTTPStatus from typing import Annotated, Any, TypedDict -from fastapi import Body, Depends, HTTPException, Query +from fastapi import BackgroundTasks, Body, Depends, HTTPException, Query from pydantic import BaseModel, root_validator from sqlalchemy.exc import NoResultFound @@ -16,6 +16,7 @@ JobStatusUpdate, LimitedJobStatusReturn, ScalarSearchOperator, + ScalarSearchSpec, SearchSpec, SetJobStatusReturn, SortSpec, @@ -27,7 +28,7 @@ ) from ..auth import UserInfo, has_properties, verify_dirac_access_token -from ..dependencies import JobDB, JobLoggingDB +from ..dependencies import JobDB, JobLoggingDB, TaskQueueDB from ..fastapi_classes import DiracxRouter MAX_PARAMETRIC_JOBS = 20 @@ -240,6 +241,10 @@ async def delete_bulk_jobs(job_ids: Annotated[list[int], Query()]): @router.post("/kill") async def kill_bulk_jobs( job_ids: Annotated[list[int], Query()], + job_db: JobDB, + job_logging_db: JobLoggingDB, + task_queue_db: TaskQueueDB, + user_info: Annotated[UserInfo, Depends(verify_dirac_access_token)], ): return job_ids @@ -406,6 +411,100 @@ async def get_single_job(job_id: int): return f"This job {job_id}" +@router.delete("/{job_id}") +async def delete_single_job( + job_id: int, + config: Annotated[Config, Depends(ConfigSource.create)], + job_db: JobDB, + job_logging_db: JobLoggingDB, + task_queue_db: TaskQueueDB, + background_task: 
BackgroundTasks, +): + from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise + from DIRAC.StorageManagementSystem.Client.StorageManagerClient import ( + StorageManagerClient, + ) + + res = await job_db.search( + parameters=["Status", "Owner", "OwnerGroup"], + search=[ + ScalarSearchSpec( + parameter="JobID", + operator=ScalarSearchOperator.EQUAL, + value=job_id, + ), + ], + sorts=[], + ) + if not res: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, detail=f"Job {job_id} not found" + ) + + status = res[0]["Status"] + owner = res[0]["Owner"] + owner_group = res[0]["OwnerGroup"] + vo = await job_db.get_vo(job_id) + + # TODO: implement JobPolicy + # validJobList, invalidJobList, nonauthJobList, ownerJobList = self.jobPolicy.evaluateJobRights(jobList, right) + + if status == JobStatus.STAGING: + returnValueOrRaise(StorageManagerClient().killTasksBySourceTaskID([job_id])) + + if status in (JobStatus.RUNNING, JobStatus.MATCHED, JobStatus.STALLED): + await job_db.set_job_command(job_id, "Kill") + + await set_job_status( + job_id, + { + datetime.now(timezone.utc): JobStatusUpdate( + Status=JobStatus.DELETED, + MinorStatus="Checking accounting", + StatusSource="job_manager", + ) + }, + job_db, + job_logging_db, + force=True, # TODO: consider if force should be True or False + ) + + # Delete the job from the task queue + tq_id = await task_queue_db.get_tq_id_for_job(job_id) + await task_queue_db.delete_job(job_id) + background_task.add_task( + task_queue_db.delete_task_queue_if_empty, + tq_id, + owner, + owner_group, + config.Registry[vo].Groups[owner_group].JobShare, + config.Registry[vo].Groups[owner_group].Properties, + config.Operations[vo].Services.JobScheduling.EnableSharesCorrection, + config.Registry[vo].Groups[owner_group].AllowBackgroundTQs, + ) + + # TODO: implement the following + # if it was the last job for the pilot + # result = self.pilotAgentsDB.getPilotsForJobID(jobID) + # if not result["OK"]: + # self.log.error("Failed to get Pilots 
for JobID", result["Message"]) + # return result + # for pilot in result["Value"]: + # res = self.pilotAgentsDB.getJobsForPilot(pilot) + # if not res["OK"]: + # self.log.error("Failed to get jobs for pilot", res["Message"]) + # return res + # if not res["Value"]: # if list of jobs for pilot is empty, delete pilot + # result = self.pilotAgentsDB.getPilotInfo(pilotID=pilot) + # if not result["OK"]: + # self.log.error("Failed to get pilot info", result["Message"]) + # return result + # ret = self.pilotAgentsDB.deletePilot(result["Value"]["PilotJobReference"]) + # if not ret["OK"]: + # self.log.error("Failed to delete pilot from PilotAgentsDB", ret["Message"]) + # return ret + + @router.get("/{job_id}/status") async def get_single_job_status( job_id: int, job_db: JobDB @@ -456,3 +555,79 @@ async def get_single_job_status_history( status_code=HTTPStatus.NOT_FOUND, detail="Job not found" ) from e return {job_id: status} + + +@router.post("/{job_id}/kill") +async def kill_single_job( + job_id: int, + config: Annotated[Config, Depends(ConfigSource.create)], + job_db: JobDB, + job_logging_db: JobLoggingDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + """ + Kill a job. 
+ """ + from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise + from DIRAC.StorageManagementSystem.Client.StorageManagerClient import ( + StorageManagerClient, + ) + + res = await job_db.search( + parameters=["Status", "Owner", "OwnerGroup"], + search=[ + ScalarSearchSpec( + parameter="JobID", + operator=ScalarSearchOperator.EQUAL, + value=job_id, + ), + ], + sorts=[], + ) + if not res: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, detail=f"Job {job_id} not found" + ) + + status = res[0]["Status"] + owner = res[0]["Owner"] + owner_group = res[0]["OwnerGroup"] + vo = await job_db.get_vo(job_id) + + # TODO: implement JobPolicy + # validJobList, invalidJobList, nonauthJobList, ownerJobList = self.jobPolicy.evaluateJobRights(jobList, right) + + if status == JobStatus.STAGING: + returnValueOrRaise(StorageManagerClient().killTasksBySourceTaskID([job_id])) + + if status in (JobStatus.RUNNING, JobStatus.MATCHED, JobStatus.STALLED): + await job_db.set_job_command(job_id, "Kill") + + await set_job_status( + job_id, + { + datetime.now(timezone.utc): JobStatusUpdate( + Status=JobStatus.KILLED, + MinorStatus="Marked for termination", + StatusSource="job_manager", + ) + }, + job_db, + job_logging_db, + force=True, # TODO: consider if force should be True or False + ) + + # Delete the job from the task queue + tq_id = await task_queue_db.get_tq_id_for_job(job_id) + await task_queue_db.delete_job(job_id) + background_task.add_task( + task_queue_db.delete_task_queue_if_empty, + tq_id, + owner, + owner_group, + config.Registry[vo].Groups[owner_group].JobShare, + config.Registry[vo].Groups[owner_group].Properties, + config.Operations[vo].Services.JobScheduling.EnableSharesCorrection, + config.Registry[vo].Groups[owner_group].AllowBackgroundTQs, + ) diff --git a/tests/conftest.py b/tests/conftest.py index 24eacb6b9..4a359424c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -74,6 +74,7 @@ def with_app(test_auth_settings, with_config_repo): 
database_urls={ "JobDB": "sqlite+aiosqlite:///:memory:", "JobLoggingDB": "sqlite+aiosqlite:///:memory:", + "TaskQueueDB": "sqlite+aiosqlite:///:memory:", "AuthDB": "sqlite+aiosqlite:///:memory:", "SandboxMetadataDB": "sqlite+aiosqlite:///:memory:", },