diff --git a/run_local.sh b/run_local.sh index 5fe2f2e56..f921bba08 100755 --- a/run_local.sh +++ b/run_local.sh @@ -17,6 +17,8 @@ export DIRACX_CONFIG_BACKEND_URL="git+file://${tmp_dir}/cs_store/initialRepo" export DIRACX_DB_URL_AUTHDB="sqlite+aiosqlite:///:memory:" export DIRACX_DB_URL_JOBDB="sqlite+aiosqlite:///:memory:" export DIRACX_DB_URL_JOBLOGGINGDB="sqlite+aiosqlite:///:memory:" +export DIRACX_DB_URL_SANDBOXMETADATADB="sqlite+aiosqlite:///:memory:" +export DIRACX_DB_URL_TASKQUEUEDB="sqlite+aiosqlite:///:memory:" export DIRACX_SERVICE_AUTH_TOKEN_KEY="file://${tmp_dir}/signing-key/rs256.key" export DIRACX_SERVICE_AUTH_ALLOWED_REDIRECTS='["http://'$(hostname| tr -s '[:upper:]' '[:lower:]')':8000/docs/oauth2-redirect"]' diff --git a/setup.cfg b/setup.cfg index abd38583c..28fa6129e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -79,7 +79,8 @@ diracx.db.sql = JobDB = diracx.db.sql:JobDB JobLoggingDB = diracx.db.sql:JobLoggingDB SandboxMetadataDB = diracx.db.sql:SandboxMetadataDB - #DummyDB = diracx.db:DummyDB + TaskQueueDB = diracx.db.sql:TaskQueueDB + #DummyDB = diracx.db.sql:DummyDB diracx.db.os = JobParametersDB = diracx.db.os:JobParametersDB diracx.services = diff --git a/src/diracx/cli/internal.py b/src/diracx/cli/internal.py index 75b9ab0c8..855a3b7dd 100644 --- a/src/diracx/cli/internal.py +++ b/src/diracx/cli/internal.py @@ -48,9 +48,7 @@ def generate_cs( DefaultGroup=user_group, Users={}, Groups={ - user_group: GroupConfig( - JobShare=None, Properties=["NormalUser"], Quota=None, Users=[] - ) + user_group: GroupConfig(Properties=["NormalUser"], Quota=None, Users=[]) }, ) config = Config( diff --git a/src/diracx/core/config/schema.py b/src/diracx/core/config/schema.py index 8f3470e05..1ee61f6d5 100644 --- a/src/diracx/core/config/schema.py +++ b/src/diracx/core/config/schema.py @@ -49,7 +49,7 @@ class GroupConfig(BaseModel): AutoAddVOMS: bool = False AutoUploadPilotProxy: bool = False AutoUploadProxy: bool = False - JobShare: Optional[int] + JobShare: int = 1000 Properties: list[SecurityProperty] Quota: Optional[int] Users: list[str] @@ -86,9 +86,14 @@ class JobMonitoringConfig(BaseModel): useESForJobParametersFlag: bool = False +class JobSchedulingConfig(BaseModel): + EnableSharesCorrection: bool = False + + class ServicesConfig(BaseModel): Catalogs: dict[str, Any] | None JobMonitoring: JobMonitoringConfig = JobMonitoringConfig() + JobScheduling: JobSchedulingConfig = JobSchedulingConfig() class OperationsConfig(BaseModel): diff --git a/src/diracx/core/exceptions.py b/src/diracx/core/exceptions.py index 4ca6eaa96..efb00f4c6 100644 --- a/src/diracx/core/exceptions.py +++ b/src/diracx/core/exceptions.py @@ -36,3 +36,9 @@ class BadConfigurationVersion(ConfigurationError): class InvalidQueryError(DiracError): """It was not possible to build a valid database query from the given input""" + + +class JobNotFound(Exception): + def __init__(self, job_id: int): + self.job_id: int = job_id + super().__init__(f"Job {job_id} not found") diff --git a/src/diracx/core/models.py b/src/diracx/core/models.py index 9171dc4f2..23e130a73 100644 --- a/src/diracx/core/models.py +++ b/src/diracx/core/models.py @@ -29,13 +29,13 @@ class SortSpec(TypedDict): class ScalarSearchSpec(TypedDict): parameter: str operator: ScalarSearchOperator - value: str + value: str | int class VectorSearchSpec(TypedDict): parameter: str operator: VectorSearchOperator - values: list[str] + values: list[str] | list[int] SearchSpec = ScalarSearchSpec | VectorSearchSpec diff --git a/src/diracx/db/__main__.py b/src/diracx/db/__main__.py index b79e0281b..da36eace1 100644 --- a/src/diracx/db/__main__.py +++ b/src/diracx/db/__main__.py @@ -35,6 +35,9 @@ async def init_sql(): db = BaseSQLDB.available_implementations(db_name)[0](db_url) async with db.engine_context(): async with db.engine.begin() as conn: + # set PRAGMA foreign_keys=ON if sqlite + if db._db_url.startswith("sqlite"): + await conn.exec_driver_sql("PRAGMA foreign_keys=ON") await conn.run_sync(db.metadata.create_all) diff --git a/src/diracx/db/sql/__init__.py b/src/diracx/db/sql/__init__.py index 17b542a57..582509b13 100644 --- a/src/diracx/db/sql/__init__.py +++ b/src/diracx/db/sql/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations -__all__ = ("AuthDB", "JobDB", "JobLoggingDB", "SandboxMetadataDB") +__all__ = ("AuthDB", "JobDB", "JobLoggingDB", "SandboxMetadataDB", "TaskQueueDB") from .auth.db import AuthDB -from .jobs.db import JobDB, JobLoggingDB +from .jobs.db import JobDB, JobLoggingDB, TaskQueueDB from .sandbox_metadata.db import SandboxMetadataDB diff --git a/src/diracx/db/sql/jobs/db.py b/src/diracx/db/sql/jobs/db.py index 273fe81cb..6a8bf986e 100644 --- a/src/diracx/db/sql/jobs/db.py +++ b/src/diracx/db/sql/jobs/db.py @@ -5,19 +5,30 @@ from typing import Any from sqlalchemy import delete, func, insert, select, update -from sqlalchemy.exc import NoResultFound +from sqlalchemy.exc import IntegrityError, NoResultFound -from diracx.core.exceptions import InvalidQueryError +from diracx.core.exceptions import InvalidQueryError, JobNotFound from diracx.core.models import JobStatus, JobStatusReturn, LimitedJobStatusReturn +from diracx.core.properties import JOB_SHARING, SecurityProperty from ..utils import BaseSQLDB, apply_search_filters from .schema import ( + BannedSitesQueue, + GridCEsQueue, InputData, + JobCommands, JobDBBase, JobJDLs, JobLoggingDBBase, Jobs, + JobsQueue, + JobTypesQueue, LoggingInfo, + PlatformsQueue, + SitesQueue, + TagsQueue, + TaskQueueDBBase, + TaskQueues, ) @@ -251,12 +262,35 @@ async def insert( } async def get_job_status(self, job_id: int) -> LimitedJobStatusReturn: - stmt = select(Jobs.Status, Jobs.MinorStatus, Jobs.ApplicationStatus).where( - Jobs.JobID == job_id - ) - return LimitedJobStatusReturn( - **dict((await self.conn.execute(stmt)).one()._mapping) - ) + try: + stmt = select(Jobs.Status, Jobs.MinorStatus, Jobs.ApplicationStatus).where( + Jobs.JobID == job_id + ) + return LimitedJobStatusReturn( + **dict((await self.conn.execute(stmt)).one()._mapping) + ) + except NoResultFound as e: + raise JobNotFound(job_id) from e + + async def set_job_command(self, job_id: int, command: str, arguments: str = ""): + """Store a command to be passed to the job together with the next heart beat""" + try: + stmt = insert(JobCommands).values( + JobID=job_id, + Command=command, + Arguments=arguments, + ReceptionTime=datetime.now(tz=timezone.utc), + ) + await self.conn.execute(stmt) + except IntegrityError as e: + raise JobNotFound(job_id) from e + + async def delete_jobs(self, job_ids: list[int]): + """ + Delete jobs from the database + """ + stmt = delete(JobJDLs).where(JobJDLs.JobID.in_(job_ids)) + await self.conn.execute(stmt) MAGIC_EPOC_NUMBER = 1270000000 @@ -378,10 +412,9 @@ async def get_records(self, job_id: int) -> list[JobStatusReturn]: return res - async def delete_records(self, job_id: int): + async def delete_records(self, job_ids: list[int]): """Delete logging records for given jobs""" - - stmt = delete(LoggingInfo).where(LoggingInfo.JobID == job_id) + stmt = delete(LoggingInfo).where(LoggingInfo.JobID.in_(job_ids)) await self.conn.execute(stmt) async def get_wms_time_stamps(self, job_id): @@ -396,9 +429,316 @@ async def get_wms_time_stamps(self, job_id): ).where(LoggingInfo.JobID == job_id) rows = await self.conn.execute(stmt) if not rows.rowcount: - raise NoResultFound(f"No Logging Info for job {job_id}") + raise JobNotFound(job_id) from None for event, etime in rows: result[event] = str(etime + MAGIC_EPOC_NUMBER) return result + + +class TaskQueueDB(BaseSQLDB): + metadata = TaskQueueDBBase.metadata + + async def get_tq_infos_for_jobs( + self, job_ids: list[int] + ) -> set[tuple[int, str, str, str]]: + """ + Get the task queue info for given jobs + """ + stmt = select( + TaskQueues.TQId, TaskQueues.Owner, TaskQueues.OwnerGroup, TaskQueues.VO + ).where(JobsQueue.JobId.in_(job_ids)) + return set( + (int(row[0]), str(row[1]), str(row[2]), str(row[3])) + for row in (await self.conn.execute(stmt)).all() + ) + + async def get_owner_for_task_queue(self, tq_id: int) -> dict[str, str]: + """ + Get the owner and owner group for a task queue + """ + stmt = select(TaskQueues.Owner, TaskQueues.OwnerGroup, TaskQueues.VO).where( + TaskQueues.TQId == tq_id + ) + return dict((await self.conn.execute(stmt)).one()._mapping) + + async def remove_job(self, job_id: int): + """ + Remove a job from the task queues + """ + stmt = delete(JobsQueue).where(JobsQueue.JobId == job_id) + await self.conn.execute(stmt) + + async def remove_jobs(self, job_ids: list[int]): + """ + Remove jobs from the task queues + """ + stmt = delete(JobsQueue).where(JobsQueue.JobId.in_(job_ids)) + await self.conn.execute(stmt) + + async def delete_task_queue_if_empty( + self, + tq_id: int, + tq_owner: str, + tq_group: str, + job_share: int, + group_properties: list[SecurityProperty], + enable_shares_correction: bool, + allow_background_tqs: bool, + ): + """ + Try to delete a task queue if it's empty + """ + # Check if the task queue is empty + stmt = ( + select(TaskQueues.TQId) + .where(TaskQueues.Enabled >= 1) + .where(TaskQueues.TQId == tq_id) + .where(~TaskQueues.TQId.in_(select(JobsQueue.TQId))) + ) + rows = await self.conn.execute(stmt) + if not rows.rowcount: + return + + # Deleting the task queue (the other tables will be deleted in cascade) + stmt = delete(TaskQueues).where(TaskQueues.TQId == tq_id) + await self.conn.execute(stmt) + + await self.recalculate_tq_shares_for_entity( + tq_owner, + tq_group, + job_share, + group_properties, + enable_shares_correction, + allow_background_tqs, + ) + + async def recalculate_tq_shares_for_entity( + self, + owner: str, + group: str, + job_share: int, + group_properties: list[SecurityProperty], + enable_shares_correction: bool, + allow_background_tqs: bool, + ): + """ + Recalculate the shares for a user/userGroup combo + """ + if JOB_SHARING in group_properties: + # If group has JobSharing just set prio for that entry, user is irrelevant + return await self.__set_priorities_for_entity( + owner, group, job_share, group_properties, allow_background_tqs + ) + + stmt = ( + select(TaskQueues.Owner, func.count(TaskQueues.Owner)) + .where(TaskQueues.OwnerGroup == group) + .group_by(TaskQueues.Owner) + ) + rows = await self.conn.execute(stmt) + # make the rows a list of tuples + # Get owners in this group and the amount of times they appear + # TODO: I guess the rows are already a list of tupes + # maybe refactor + data = [(r[0], r[1]) for r in rows if r] + numOwners = len(data) + # If there are no owners do now + if numOwners == 0: + return + # Split the share amongst the number of owners + entities_shares = {row[0]: job_share / numOwners for row in data} + + # TODO: implement the following + # If corrector is enabled let it work it's magic + # if enable_shares_correction: + # entities_shares = await self.__shares_corrector.correct_shares( + # entitiesShares, group=group + # ) + + # Keep updating + owners = dict(data) + # IF the user is already known and has more than 1 tq, the rest of the users don't need to be modified + # (The number of owners didn't change) + if owner in owners and owners[owner] > 1: + await self.__set_priorities_for_entity( + owner, + group, + entities_shares[owner], + group_properties, + allow_background_tqs, + ) + return + # Oops the number of owners may have changed so we recalculate the prio for all owners in the group + for owner in owners: + await self.__set_priorities_for_entity( + owner, + group, + entities_shares[owner], + group_properties, + allow_background_tqs, + ) + + async def __set_priorities_for_entity( + self, + owner: str, + group: str, + share, + properties: list[SecurityProperty], + allow_background_tqs: bool, + ): + """ + Set the priority for a user/userGroup combo given a splitted share + """ + from DIRAC.WorkloadManagementSystem.DB.TaskQueueDB import ( + TQ_MIN_SHARE, + priorityIgnoredFields, + ) + + stmt = ( + select( + TaskQueues.TQId, + func.sum(JobsQueue.RealPriority) / func.count(JobsQueue.RealPriority), + ) + # TODO: uncomment me and understand why mypy is unhappy with join here and not elsewhere + # .select_from(TaskQueues.join(JobsQueue, TaskQueues.TQId == JobsQueue.TQId)) + .where(TaskQueues.OwnerGroup == group).group_by(TaskQueues.TQId) + ) + if JOB_SHARING not in properties: + stmt = stmt.where(TaskQueues.Owner == owner) + rows = await self.conn.execute(stmt) + tqDict: dict[int, float] = {tq_id: priority for tq_id, priority in rows} + + if not tqDict: + return + + allowBgTQs = allow_background_tqs + + # TODO: one of the only place the logic could actually be encapsulated + # so refactor + + # Calculate Sum of priorities + totalPrio = 0.0 + for k in tqDict: + if tqDict[k] > 0.1 or not allowBgTQs: + totalPrio += tqDict[k] + # Update prio for each TQ + for tqId in tqDict: + if tqDict[tqId] > 0.1 or not allowBgTQs: + prio = (share / totalPrio) * tqDict[tqId] + else: + prio = TQ_MIN_SHARE + prio = max(prio, TQ_MIN_SHARE) + tqDict[tqId] = prio + + # Generate groups of TQs that will have the same prio=sum(prios) maomenos + rows = await self.retrieve_task_queues(list(tqDict)) + # TODO: check the following asumption is correct + allTQsData = rows + tqGroups: dict[str, list] = {} + for tqid in allTQsData: + tqData = allTQsData[tqid] + for field in ("Jobs", "Priority") + priorityIgnoredFields: + if field in tqData: + tqData.pop(field) + tqHash = [] + for f in sorted(tqData): + tqHash.append(f"{f}:{tqData[f]}") + tqHash = "|".join(tqHash) + if tqHash not in tqGroups: + tqGroups[tqHash] = [] + tqGroups[tqHash].append(tqid) + groups = [tqGroups[td] for td in tqGroups] + + # Do the grouping + for tqGroup in groups: + totalPrio = 0 + if len(tqGroup) < 2: + continue + for tqid in tqGroup: + totalPrio += tqDict[tqid] + for tqid in tqGroup: + tqDict[tqid] = totalPrio + + # Group by priorities + prioDict: dict[int, list] = {} + for tqId in tqDict: + prio = tqDict[tqId] + if prio not in prioDict: + prioDict[prio] = [] + prioDict[prio].append(tqId) + + # Execute updates + for prio, tqs in prioDict.items(): + update_stmt = ( + update(TaskQueues).where(TaskQueues.TQId.in_(tqs)).values(Priority=prio) + ) + await self.conn.execute(update_stmt) + + async def retrieve_task_queues(self, tq_id_list=None): + """ + Get all the task queues + """ + if tq_id_list is not None and not tq_id_list: + # Empty list => Fast-track no matches + return {} + + stmt = ( + select( + TaskQueues.TQId, + TaskQueues.Priority, + func.count(JobsQueue.TQId).label("Jobs"), + TaskQueues.Owner, + TaskQueues.OwnerGroup, + TaskQueues.VO, + TaskQueues.CPUTime, + ) + .select_from(TaskQueues.join(JobsQueue, TaskQueues.TQId == JobsQueue.TQId)) + .select_from( + TaskQueues.join(SitesQueue, TaskQueues.TQId == SitesQueue.TQId) + ) + .select_from( + TaskQueues.join(GridCEsQueue, TaskQueues.TQId == GridCEsQueue.TQId) + ) + .group_by( + TaskQueues.TQId, + TaskQueues.Priority, + TaskQueues.Owner, + TaskQueues.OwnerGroup, + TaskQueues.VO, + TaskQueues.CPUTime, + ) + ) + if tq_id_list is not None: + stmt = stmt.where(TaskQueues.TQId.in_(tq_id_list)) + + tq_data = dict(dict(row._mapping) for row in await self.conn.execute(stmt)) + # TODO: the line above should be equivalent to the following commented code, check this is the case + # for record in rows: + # tqId = record[0] + # tqData[tqId] = { + # "Priority": record[1], + # "Jobs": record[2], + # "Owner": record[3], + # "OwnerGroup": record[4], + # "VO": record[5], + # "CPUTime": record[6], + # } + + for tq_id in tq_data: + # TODO: maybe factorize this handy tuple list + for table, field in { + (SitesQueue, "Sites"), + (GridCEsQueue, "GridCEs"), + (BannedSitesQueue, "BannedSites"), + (PlatformsQueue, "Platforms"), + (JobTypesQueue, "JobTypes"), + (TagsQueue, "Tags"), + }: + stmt = select(table.Value).where(table.TQId == tq_id) + tq_data[tq_id][field] = list( + row[0] for row in await self.conn.execute(stmt) + ) + + return tq_data diff --git a/src/diracx/db/sql/jobs/schema.py b/src/diracx/db/sql/jobs/schema.py index fe76eedd2..e1b6625bc 100644 --- a/src/diracx/db/sql/jobs/schema.py +++ b/src/diracx/db/sql/jobs/schema.py @@ -1,9 +1,11 @@ import sqlalchemy.types as types from sqlalchemy import ( + BigInteger, + Boolean, DateTime, Enum, + Float, ForeignKey, - ForeignKeyConstraint, Index, Integer, Numeric, @@ -17,6 +19,7 @@ JobDBBase = declarative_base() JobLoggingDBBase = declarative_base() +TaskQueueDBBase = declarative_base() class EnumBackedBool(types.TypeDecorator): @@ -45,19 +48,16 @@ def process_result_value(self, value, dialect) -> bool: raise NotImplementedError(f"Unknown {value=}") -class JobJDLs(JobDBBase): - __tablename__ = "JobJDLs" - JobID = Column(Integer, autoincrement=True) - JDL = Column(Text) - JobRequirements = Column(Text) - OriginalJDL = Column(Text) - __table_args__ = (PrimaryKeyConstraint("JobID"),) - - class Jobs(JobDBBase): __tablename__ = "Jobs" - JobID = Column("JobID", Integer, primary_key=True, default=0) + JobID = Column( + "JobID", + Integer, + ForeignKey("JobJDLs.JobID", ondelete="CASCADE"), + primary_key=True, + default=0, + ) JobType = Column("JobType", String(32), default="user") DIRACSetup = Column("DIRACSetup", String(32), default="test") JobGroup = Column("JobGroup", String(32), default="00000000") @@ -96,7 +96,6 @@ class Jobs(JobDBBase): ) __table_args__ = ( - ForeignKeyConstraint(["JobID"], ["JobJDLs.JobID"]), Index("JobType", "JobType"), Index("JobGroup", "JobGroup"), Index("JobSplitType", "JobSplitType"), @@ -111,33 +110,46 @@ class Jobs(JobDBBase): ) +class JobJDLs(JobDBBase): + __tablename__ = "JobJDLs" + JobID = Column(Integer, autoincrement=True, primary_key=True) + JDL = Column(Text) + JobRequirements = Column(Text) + OriginalJDL = Column(Text) + + class InputData(JobDBBase): __tablename__ = "InputData" - JobID = Column(Integer, primary_key=True) + JobID = Column( + Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + ) LFN = Column(String(255), default="", primary_key=True) Status = Column(String(32), default="AprioriGood") - __table_args__ = (ForeignKeyConstraint(["JobID"], ["Jobs.JobID"]),) class JobParameters(JobDBBase): __tablename__ = "JobParameters" - JobID = Column(Integer, primary_key=True) + JobID = Column( + Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + ) Name = Column(String(100), primary_key=True) Value = Column(Text) - __table_args__ = (ForeignKeyConstraint(["JobID"], ["Jobs.JobID"]),) class OptimizerParameters(JobDBBase): __tablename__ = "OptimizerParameters" - JobID = Column(Integer, primary_key=True) + JobID = Column( + Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + ) Name = Column(String(100), primary_key=True) Value = Column(Text) - __table_args__ = (ForeignKeyConstraint(["JobID"], ["Jobs.JobID"]),) class AtticJobParameters(JobDBBase): __tablename__ = "AtticJobParameters" - JobID = Column(Integer, ForeignKey("Jobs.JobID"), primary_key=True) + JobID = Column( + Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + ) Name = Column(String(100), primary_key=True) Value = Column(Text) RescheduleCycle = Column(Integer) @@ -163,25 +175,25 @@ class SiteMaskLogging(JobDBBase): class HeartBeatLoggingInfo(JobDBBase): __tablename__ = "HeartBeatLoggingInfo" - JobID = Column(Integer, primary_key=True) + JobID = Column( + Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + ) Name = Column(String(100), primary_key=True) Value = Column(Text) HeartBeatTime = Column(DateTime, primary_key=True) - __table_args__ = (ForeignKeyConstraint(["JobID"], ["Jobs.JobID"]),) - class JobCommands(JobDBBase): __tablename__ = "JobCommands" - JobID = Column(Integer, primary_key=True) + JobID = Column( + Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + ) Command = Column(String(100)) Arguments = Column(String(100)) Status = Column(String(64), default="Received") ReceptionTime = Column(DateTime, primary_key=True) ExecutionTime = NullColumn(DateTime) - __table_args__ = (ForeignKeyConstraint(["JobID"], ["Jobs.JobID"]),) - class LoggingInfo(JobLoggingDBBase): __tablename__ = "LoggingInfo" @@ -195,3 +207,99 @@ class LoggingInfo(JobLoggingDBBase): StatusTimeOrder = Column(Numeric(precision=12, scale=3), default=0) StatusSource = Column(String(32), default="Unknown") __table_args__ = (PrimaryKeyConstraint("JobID", "SeqNum"),) + + +class TaskQueues(TaskQueueDBBase): + __tablename__ = "tq_TaskQueues" + TQId = Column(Integer, primary_key=True) + Owner = Column(String(255), nullable=False) + OwnerDN = Column(String(255)) + OwnerGroup = Column(String(32), nullable=False) + VO = Column(String(32), nullable=False) + CPUTime = Column(BigInteger, nullable=False) + Priority = Column(Float, nullable=False) + Enabled = Column(Boolean, nullable=False, default=0) + __table_args__ = (Index("TQOwner", "Owner", "OwnerGroup", "CPUTime"),) + + +class JobsQueue(TaskQueueDBBase): + __tablename__ = "tq_Jobs" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + JobId = Column(Integer, primary_key=True) + Priority = Column(Integer, nullable=False) + RealPriority = Column(Float, nullable=False) + __table_args__ = (Index("TaskIndex", "TQId"),) + + +class SitesQueue(TaskQueueDBBase): + __tablename__ = "tq_TQToSites" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value = Column(String(64), primary_key=True) + __table_args__ = ( + Index("SitesTaskIndex", "TQId"), + Index("SitesIndex", "Value"), + ) + + +class GridCEsQueue(TaskQueueDBBase): + __tablename__ = "tq_TQToGridCEs" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value = Column(String(64), primary_key=True) + __table_args__ = ( + Index("GridCEsTaskIndex", "TQId"), + Index("GridCEsValueIndex", "Value"), + ) + + +class BannedSitesQueue(TaskQueueDBBase): + __tablename__ = "tq_TQToBannedSites" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value = Column(String(64), primary_key=True) + __table_args__ = ( + Index("BannedSitesTaskIndex", "TQId"), + Index("BannedSitesValueIndex", "Value"), + ) + + +class PlatformsQueue(TaskQueueDBBase): + __tablename__ = "tq_TQToPlatforms" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value = Column(String(64), primary_key=True) + __table_args__ = ( + Index("PlatformsTaskIndex", "TQId"), + Index("PlatformsValueIndex", "Value"), + ) + + +class JobTypesQueue(TaskQueueDBBase): + __tablename__ = "tq_TQToJobTypes" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value = Column(String(64), primary_key=True) + __table_args__ = ( + Index("JobTypesTaskIndex", "TQId"), + Index("JobTypesValueIndex", "Value"), + ) + + +class TagsQueue(TaskQueueDBBase): + __tablename__ = "tq_TQToTags" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value = Column(String(64), primary_key=True) + __table_args__ = ( + Index("TagsTaskIndex", "TQId"), + Index("TagsValueIndex", "Value"), + ) diff --git a/src/diracx/db/sql/jobs/status_utility.py b/src/diracx/db/sql/jobs/status_utility.py index f7f20997c..f2ba43f46 100644 --- a/src/diracx/db/sql/jobs/status_utility.py +++ b/src/diracx/db/sql/jobs/status_utility.py @@ -1,15 +1,19 @@ +import asyncio from datetime import datetime, timezone from unittest.mock import MagicMock -from sqlalchemy.exc import NoResultFound +from fastapi import BackgroundTasks +from diracx.core.config.schema import Config +from diracx.core.exceptions import JobNotFound from diracx.core.models import ( JobStatus, JobStatusUpdate, ScalarSearchOperator, SetJobStatusReturn, ) -from diracx.db.sql.jobs.db import JobDB, JobLoggingDB +from diracx.db.sql.jobs.db import JobDB, JobLoggingDB, TaskQueueDB +from diracx.db.sql.sandbox_metadata.db import SandboxMetadataDB async def set_job_status( @@ -21,8 +25,10 @@ async def set_job_status( ) -> SetJobStatusReturn: """Set various status fields for job specified by its jobId. Set only the last status in the JobDB, updating all the status - logging information in the JobLoggingDB. The statusDict has datetime + logging information in the JobLoggingDB. The status dict has datetime as a key and status information dictionary as values + + :raises: JobNotFound if the job is not found in one of the DBs """ from DIRAC.Core.Utilities import TimeUtilities @@ -49,7 +55,7 @@ async def set_job_status( sorts=[], ) if not res: - raise NoResultFound(f"Job {job_id} not found") + raise JobNotFound(job_id) from None currentStatus = res[0]["Status"] startTime = res[0]["StartExecTime"] @@ -60,10 +66,7 @@ async def set_job_status( currentStatus = JobStatus.RUNNING # Get the latest time stamps of major status updates - try: - result = await job_logging_db.get_wms_time_stamps(job_id) - except NoResultFound as e: - raise e + result = await job_logging_db.get_wms_time_stamps(job_id) ##################################################################################################### @@ -146,3 +149,167 @@ async def set_job_status( ) return SetJobStatusReturn(**job_data) + + +class ForgivingTaskGroup(asyncio.TaskGroup): + # Hacky way, check https://stackoverflow.com/questions/75250788/how-to-prevent-python3-11-taskgroup-from-canceling-all-the-tasks + # Basically e're using this because we want to wait for all tasks to finish, even if one of them raises an exception + def _abort(self): + return None + + +async def delete_jobs( + job_ids: list[int], + config: Config, + job_db: JobDB, + job_logging_db: JobLoggingDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + """ + "Delete" jobs by removing them from the task queues, set kill as a job command setting the job status to DELETED. + :raises: BaseExceptionGroup[JobNotFound] for every job that was not found + """ + + await _remove_jobs_from_task_queue(job_ids, config, task_queue_db, background_task) + # TODO: implement StorageManagerClient + # returnValueOrRaise(StorageManagerClient().killTasksBySourceTaskID(job_ids)) + + async with ForgivingTaskGroup() as task_group: + for job_id in job_ids: + task_group.create_task(job_db.set_job_command(job_id, "Kill")) + + task_group.create_task( + set_job_status( + job_id, + { + datetime.now(timezone.utc): JobStatusUpdate( + Status=JobStatus.DELETED, + MinorStatus="Checking accounting", + StatusSource="job_manager", + ) + }, + job_db, + job_logging_db, + force=True, + ) + ) + + +async def kill_jobs( + job_ids: list[int], + config: Config, + job_db: JobDB, + job_logging_db: JobLoggingDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + """ + Kill jobs by removing them from the task queues, set kill as a job command and setting the job status to KILLED. + :raises: BaseExceptionGroup[JobNotFound] for every job that was not found + """ + await _remove_jobs_from_task_queue(job_ids, config, task_queue_db, background_task) + # TODO: implement StorageManagerClient + # returnValueOrRaise(StorageManagerClient().killTasksBySourceTaskID(job_ids)) + + async with ForgivingTaskGroup() as task_group: + for job_id in job_ids: + task_group.create_task(job_db.set_job_command(job_id, "Kill")) + task_group.create_task( + set_job_status( + job_id, + { + datetime.now(timezone.utc): JobStatusUpdate( + Status=JobStatus.KILLED, + MinorStatus="Marked for termination", + StatusSource="job_manager", + ) + }, + job_db, + job_logging_db, + force=True, + ) + ) + + # TODO: Consider using the code below instead, probably more stable but less performant + # errors = [] + # for job_id in job_ids: + # try: + # await job_db.set_job_command(job_id, "Kill") + # await set_job_status( + # job_id, + # { + # datetime.now(timezone.utc): JobStatusUpdate( + # Status=JobStatus.KILLED, + # MinorStatus="Marked for termination", + # StatusSource="job_manager", + # ) + # }, + # job_db, + # job_logging_db, + # force=True, + # ) + # except JobNotFound as e: + # errors.append(e) + + # if errors: + # raise BaseExceptionGroup("Some job ids were not found", errors) + + +async def remove_jobs( + job_ids: list[int], + config: Config, + job_db: JobDB, + job_logging_db: JobLoggingDB, + sandbox_metadata_db: SandboxMetadataDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + """ + Fully remove a job from the WMS databases. + :raises: nothing + """ + + # Remove the staging task from the StorageManager + # TODO: this was not done in the JobManagerHandler, but it was done in the kill method + # I think it should be done here too + # TODO: implement StorageManagerClient + # returnValueOrRaise(StorageManagerClient().killTasksBySourceTaskID([job_id])) + + # TODO: this was also not done in the JobManagerHandler, but it was done in the JobCleaningAgent + # I think it should be done here as well + await sandbox_metadata_db.unassign_sandbox_from_jobs(job_ids) + + # Remove the job from TaskQueueDB + await _remove_jobs_from_task_queue(job_ids, config, task_queue_db, background_task) + + # Remove the job from JobLoggingDB + await job_logging_db.delete_records(job_ids) + + # Remove the job from JobDB + await job_db.delete_jobs(job_ids) + + +async def _remove_jobs_from_task_queue( + job_ids: list[int], + config: Config, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + """ + Remove the job from TaskQueueDB + """ + tq_infos = await task_queue_db.get_tq_infos_for_jobs(job_ids) + await task_queue_db.remove_jobs(job_ids) + for tq_id, owner, owner_group, vo in tq_infos: + # TODO: move to Celery + background_task.add_task( + task_queue_db.delete_task_queue_if_empty, + tq_id, + owner, + owner_group, + config.Registry[vo].Groups[owner_group].JobShare, + config.Registry[vo].Groups[owner_group].Properties, + config.Operations[vo].Services.JobScheduling.EnableSharesCorrection, + config.Registry[vo].Groups[owner_group].AllowBackgroundTQs, + ) diff --git a/src/diracx/db/sql/sandbox_metadata/db.py b/src/diracx/db/sql/sandbox_metadata/db.py index 6900f58a7..87e9a24fa 100644 --- a/src/diracx/db/sql/sandbox_metadata/db.py +++ b/src/diracx/db/sql/sandbox_metadata/db.py @@ -6,11 +6,12 @@ import datetime import sqlalchemy +from sqlalchemy import delete from diracx.db.sql.utils import BaseSQLDB from .schema import Base as SandboxMetadataDBBase -from .schema import sb_Owners, sb_SandBoxes +from .schema import sb_EntityMapping, sb_Owners, sb_SandBoxes class SandboxMetadataDB(BaseSQLDB): @@ -36,7 +37,7 @@ async def _get_put_owner(self, owner: str, owner_group: str) -> int: async def insert( self, owner: str, owner_group: str, sb_SE: str, se_PFN: str, size: int = 0 - ) -> tuple[int, bool]: + ) -> int: """inserts a new sandbox in SandboxMetadataDB this is "equivalent" of DIRAC registerAndGetSandbox @@ -78,3 +79,12 @@ async def delete(self, sandbox_ids: list[int]) -> bool: await self.conn.execute(stmt) return True + + async def unassign_sandbox_from_jobs(self, job_ids: list[int]): + """ + Unassign sandbox from jobs + """ + stmt = delete(sb_EntityMapping).where( + sb_EntityMapping.EntityId.in_(f"Job:{job_id}" for job_id in job_ids) + ) + await self.conn.execute(stmt) diff --git a/src/diracx/routers/dependencies.py b/src/diracx/routers/dependencies.py index 3ab40361f..7a67b94f4 100644 --- a/src/diracx/routers/dependencies.py +++ b/src/diracx/routers/dependencies.py @@ -5,6 +5,8 @@ "AuthDB", "JobDB", "JobLoggingDB", + "SandboxMetadataDB", + "TaskQueueDB", "add_settings_annotation", "AvailableSecurityProperties", ) @@ -19,6 +21,8 @@ from diracx.db.sql import AuthDB as _AuthDB from diracx.db.sql import JobDB as _JobDB from diracx.db.sql import JobLoggingDB as _JobLoggingDB +from diracx.db.sql import SandboxMetadataDB as _SandboxMetadataDB +from diracx.db.sql import TaskQueueDB as _TaskQueueDB T = TypeVar("T") @@ -32,6 +36,10 @@ def add_settings_annotation(cls: T) -> T: AuthDB = Annotated[_AuthDB, Depends(_AuthDB.transaction)] JobDB = Annotated[_JobDB, Depends(_JobDB.transaction)] JobLoggingDB = Annotated[_JobLoggingDB, Depends(_JobLoggingDB.transaction)] +SandboxMetadataDB = Annotated[ + _SandboxMetadataDB, Depends(_SandboxMetadataDB.transaction) +] +TaskQueueDB = Annotated[_TaskQueueDB, Depends(_TaskQueueDB.transaction)] # Miscellaneous Config = Annotated[_Config, Depends(ConfigSource.create)] diff --git a/src/diracx/routers/job_manager/__init__.py b/src/diracx/routers/job_manager/__init__.py index 374c224fc..c2894c4c9 100644 --- a/src/diracx/routers/job_manager/__init__.py +++ b/src/diracx/routers/job_manager/__init__.py @@ -6,11 +6,11 @@ from http import HTTPStatus from typing import Annotated, Any, TypedDict -from fastapi import Body, Depends, HTTPException, Query +from fastapi import BackgroundTasks, Body, Depends, HTTPException, Query from pydantic import BaseModel, root_validator -from sqlalchemy.exc import NoResultFound from diracx.core.config import Config, ConfigSource +from diracx.core.exceptions import JobNotFound from diracx.core.models import ( JobStatus, JobStatusReturn, @@ -23,11 +23,14 @@ ) from diracx.core.properties import JOB_ADMINISTRATOR, NORMAL_USER from diracx.db.sql.jobs.status_utility import ( + delete_jobs, + kill_jobs, + remove_jobs, set_job_status, ) from ..auth import UserInfo, has_properties, verify_dirac_access_token -from ..dependencies import JobDB, JobLoggingDB +from ..dependencies import JobDB, JobLoggingDB, SandboxMetadataDB, TaskQueueDB from ..fastapi_classes import DiracxRouter MAX_PARAMETRIC_JOBS = 20 @@ -233,14 +236,108 @@ def __init__(self, user_info: UserInfo, allInfo: bool = True): @router.delete("/") -async def delete_bulk_jobs(job_ids: Annotated[list[int], Query()]): +async def delete_bulk_jobs( + job_ids: Annotated[list[int], Query()], + config: Annotated[Config, Depends(ConfigSource.create)], + job_db: JobDB, + job_logging_db: JobLoggingDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + # TODO: implement job policy + + try: + await delete_jobs( + job_ids, + config, + job_db, + job_logging_db, + task_queue_db, + background_task, + ) + except* JobNotFound as group_exc: + failed_job_ids: list[int] = list({e.job_id for e in group_exc.exceptions}) # type: ignore + + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail={ + "message": f"Failed to delete {len(failed_job_ids)} jobs out of {len(job_ids)}", + "valid_job_ids": list(set(job_ids) - set(failed_job_ids)), + "failed_job_ids": failed_job_ids, + }, + ) from group_exc + return job_ids @router.post("/kill") async def kill_bulk_jobs( job_ids: Annotated[list[int], Query()], + config: Annotated[Config, Depends(ConfigSource.create)], + job_db: JobDB, + job_logging_db: JobLoggingDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + # TODO: implement job policy + try: + await kill_jobs( + job_ids, + config, + job_db, + job_logging_db, + task_queue_db, + background_task, + ) + except* JobNotFound as group_exc: + failed_job_ids: list[int] = list({e.job_id for e in group_exc.exceptions}) # type: ignore + + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail={ + "message": f"Failed to kill {len(failed_job_ids)} jobs out of {len(job_ids)}", + "valid_job_ids": list(set(job_ids) - set(failed_job_ids)), + "failed_job_ids": failed_job_ids, + }, + ) from group_exc + + return job_ids + + +@router.post("/remove") +async def remove_bulk_jobs( + job_ids: Annotated[list[int], Query()], + config: Annotated[Config, Depends(ConfigSource.create)], + job_db: JobDB, + job_logging_db: JobLoggingDB, + sandbox_metadata_db: SandboxMetadataDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, ): + """ + Fully remove a list of jobs from the WMS databases. + + + WARNING: This endpoint has been implemented for the compatibility with the legacy DIRAC WMS + and the JobCleaningAgent. However, once this agent is ported to diracx, this endpoint should + be removed, and the delete endpoint should be used instead for any other purpose. + """ + # TODO: Remove once legacy DIRAC no longer needs this + + # TODO: implement job policy + # Some tests have already been written in the test_job_manager, + # but they need to be uncommented and are not complete + + await remove_jobs( + job_ids, + config, + job_db, + job_logging_db, + sandbox_metadata_db, + task_queue_db, + background_task, + ) + return job_ids @@ -253,7 +350,7 @@ async def get_job_status_bulk( *(job_db.get_job_status(job_id) for job_id in job_ids) ) return {job_id: status for job_id, status in zip(job_ids, result)} - except NoResultFound as e: + except JobNotFound as e: raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(e)) from e @@ -406,13 +503,105 @@ async def get_single_job(job_id: int): return f"This job {job_id}" +@router.delete("/{job_id}") +async def delete_single_job( + job_id: int, + config: Annotated[Config, Depends(ConfigSource.create)], + job_db: JobDB, + job_logging_db: JobLoggingDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + """ + Delete a job by killing and setting the job status to DELETED. + """ + + # TODO: implement job policy + try: + await delete_jobs( + [job_id], + config, + job_db, + job_logging_db, + task_queue_db, + background_task, + ) + except* JobNotFound as e: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND.value, detail=str(e.exceptions[0]) + ) from e + + return f"Job {job_id} has been successfully deleted" + + +@router.post("/{job_id}/kill") +async def kill_single_job( + job_id: int, + config: Annotated[Config, Depends(ConfigSource.create)], + job_db: JobDB, + job_logging_db: JobLoggingDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + """ + Kill a job. + """ + + # TODO: implement job policy + + try: + await kill_jobs( + [job_id], config, job_db, job_logging_db, task_queue_db, background_task + ) + except* JobNotFound as e: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, detail=str(e.exceptions[0]) + ) from e + + return f"Job {job_id} has been successfully killed" + + +@router.post("/{job_id}/remove") +async def remove_single_job( + job_id: int, + config: Annotated[Config, Depends(ConfigSource.create)], + job_db: JobDB, + job_logging_db: JobLoggingDB, + sandbox_metadata_db: SandboxMetadataDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + """ + Fully remove a job from the WMS databases. + + WARNING: This endpoint has been implemented for the compatibility with the legacy DIRAC WMS + and the JobCleaningAgent. However, once this agent is ported to diracx, this endpoint should + be removed, and the delete endpoint should be used instead. + """ + # TODO: Remove once legacy DIRAC no longer needs this + + # TODO: implement job policy + + await remove_jobs( + [job_id], + config, + job_db, + job_logging_db, + sandbox_metadata_db, + task_queue_db, + background_task, + ) + + return f"Job {job_id} has been successfully removed" + + @router.get("/{job_id}/status") async def get_single_job_status( job_id: int, job_db: JobDB ) -> dict[int, LimitedJobStatusReturn]: try: status = await job_db.get_job_status(job_id) - except NoResultFound as e: + except JobNotFound as e: raise HTTPException( status_code=HTTPStatus.NOT_FOUND, detail=f"Job {job_id} not found" ) from e @@ -439,7 +628,7 @@ async def set_single_job_status( latest_status = await set_job_status( job_id, status, job_db, job_logging_db, force ) - except NoResultFound as e: + except JobNotFound as e: raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(e)) from e return {job_id: latest_status} @@ -451,7 +640,7 @@ async def get_single_job_status_history( ) -> dict[int, list[JobStatusReturn]]: try: status = await job_logging_db.get_records(job_id) - except NoResultFound as e: + except JobNotFound as e: raise HTTPException( status_code=HTTPStatus.NOT_FOUND, detail="Job not found" ) from e diff --git a/tests/conftest.py b/tests/conftest.py index b3a743187..dd35ebc8f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -49,6 +49,7 @@ def pytest_collection_modifyitems(config, items): # --regenerate-client given in cli: allow client re-generation return skip_regen = pytest.mark.skip(reason="need --regenerate-client option to run") + found = False for item in items: if item.name == "test_regenerate_client": item.add_marker(skip_regen) @@ -86,6 +87,7 @@ def with_app(test_auth_settings, with_config_repo): database_urls={ "JobDB": "sqlite+aiosqlite:///:memory:", "JobLoggingDB": "sqlite+aiosqlite:///:memory:", + "TaskQueueDB": "sqlite+aiosqlite:///:memory:", "AuthDB": "sqlite+aiosqlite:///:memory:", "SandboxMetadataDB": "sqlite+aiosqlite:///:memory:", }, @@ -111,6 +113,9 @@ async def create_db_schemas(app=app): assert isinstance(db, BaseSQLDB), (k, db) # Fill the DB schema async with db.engine.begin() as conn: + # set PRAGMA foreign_keys=ON if sqlite + if db._db_url.startswith("sqlite"): + await conn.exec_driver_sql("PRAGMA foreign_keys=ON") await conn.run_sync(db.metadata.create_all) yield diff --git a/tests/db/jobs/test_jobDB.py b/tests/db/jobs/test_jobDB.py index 24eec16e0..5e46352b7 100644 --- a/tests/db/jobs/test_jobDB.py +++ b/tests/db/jobs/test_jobDB.py @@ -4,6 +4,7 @@ import pytest +from diracx.core.exceptions import JobNotFound from diracx.db.sql.jobs.db import JobDB @@ -12,6 +13,9 @@ async def job_db(tmp_path): job_db = JobDB("sqlite+aiosqlite:///:memory:") async with job_db.engine_context(): async with job_db.engine.begin() as conn: + # set PRAGMA foreign_keys=ON if sqlite + if job_db._db_url.startswith("sqlite"): + await conn.exec_driver_sql("PRAGMA foreign_keys=ON") await conn.run_sync(job_db.metadata.create_all) yield job_db @@ -38,3 +42,9 @@ async def test_some_asyncio_code(job_db): async with job_db as job_db: result = await job_db.search(["JobID"], [], []) assert result + + +async def test_set_job_command_invalid_job_id(job_db: JobDB): + async with job_db as job_db: + with pytest.raises(JobNotFound): + await job_db.set_job_command(123456, "test_command") diff --git a/tests/db/jobs/test_jobLoggingDB.py b/tests/db/jobs/test_jobLoggingDB.py index 1949cde15..2a089d356 100644 --- a/tests/db/jobs/test_jobLoggingDB.py +++ b/tests/db/jobs/test_jobLoggingDB.py @@ -23,7 +23,7 @@ async def test_insert_record(job_logging_db: JobLoggingDB): # Act await job_logging_db.insert_record( 1, - status=JobStatus.RECEIVED.value, + status=JobStatus.RECEIVED, minor_status="minor_status", application_status="application_status", date=date, diff --git a/tests/routers/test_job_manager.py b/tests/routers/test_job_manager.py index 4bbdd2b16..9a1f8b882 100644 --- a/tests/routers/test_job_manager.py +++ b/tests/routers/test_job_manager.py @@ -229,81 +229,89 @@ def test_user_without_the_normal_user_property_cannot_submit_job(admin_user_clie assert res.status_code == HTTPStatus.FORBIDDEN, res.json() -def test_get_job_status(normal_user_client: TestClient): - """Test that the job status is returned correctly.""" - # Arrange +@pytest.fixture +def valid_job_id(normal_user_client: TestClient): job_definitions = [TEST_JDL] r = normal_user_client.post("/jobs/", json=job_definitions) assert r.status_code == 200, r.json() - assert len(r.json()) == 1 # Parameters.JOB_ID is 3 - job_id = r.json()[0]["JobID"] + assert len(r.json()) == 1 + return r.json()[0]["JobID"] + + +@pytest.fixture +def valid_job_ids(normal_user_client: TestClient): + job_definitions = [TEST_PARAMETRIC_JDL] + r = normal_user_client.post("/jobs/", json=job_definitions) + assert r.status_code == 200, r.json() + assert len(r.json()) == 3 + return sorted([job_dict["JobID"] for job_dict in r.json()]) + + +@pytest.fixture +def invalid_job_id(): + return 999999996 + +@pytest.fixture +def invalid_job_ids(): + return [999999997, 999999998, 999999999] + + +def test_get_job_status(normal_user_client: TestClient, valid_job_id: int): + """Test that the job status is returned correctly.""" # Act - r = normal_user_client.get(f"/jobs/{job_id}/status") + r = normal_user_client.get(f"/jobs/{valid_job_id}/status") # Assert assert r.status_code == 200, r.json() # TODO: should we return camel case here (and everywhere else) ? - assert r.json()[str(job_id)]["Status"] == JobStatus.RECEIVED.value - assert r.json()[str(job_id)]["MinorStatus"] == "Job accepted" - assert r.json()[str(job_id)]["ApplicationStatus"] == "Unknown" + assert r.json()[str(valid_job_id)]["Status"] == JobStatus.RECEIVED.value + assert r.json()[str(valid_job_id)]["MinorStatus"] == "Job accepted" + assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown" -def test_get_status_of_nonexistent_job(normal_user_client: TestClient): +def test_get_status_of_nonexistent_job( + normal_user_client: TestClient, invalid_job_id: int +): """Test that the job status is returned correctly.""" # Act - r = normal_user_client.get("/jobs/1/status") + r = normal_user_client.get(f"/jobs/{invalid_job_id}/status") # Assert assert r.status_code == 404, r.json() - assert r.json() == {"detail": "Job 1 not found"} + assert r.json() == {"detail": f"Job {invalid_job_id} not found"} -def test_get_job_status_in_bulk(normal_user_client: TestClient): +def test_get_job_status_in_bulk(normal_user_client: TestClient, valid_job_ids: list): """Test that we can get the status of multiple jobs in one request""" - # Arrange - job_definitions = [TEST_PARAMETRIC_JDL] - r = normal_user_client.post("/jobs/", json=job_definitions) - assert r.status_code == 200, r.json() - assert len(r.json()) == 3 # Parameters.JOB_ID is 3 - submitted_job_ids = sorted([job_dict["JobID"] for job_dict in r.json()]) - assert isinstance(submitted_job_ids, list) - assert (isinstance(submitted_job_id, int) for submitted_job_id in submitted_job_ids) - # Act - r = normal_user_client.get("/jobs/status", params={"job_ids": submitted_job_ids}) + r = normal_user_client.get("/jobs/status", params={"job_ids": valid_job_ids}) # Assert - print(r.json()) assert r.status_code == 200, r.json() assert len(r.json()) == 3 # Parameters.JOB_ID is 3 - for job_id in submitted_job_ids: + for job_id in valid_job_ids: assert str(job_id) in r.json() assert r.json()[str(job_id)]["Status"] == JobStatus.SUBMITTING.value assert r.json()[str(job_id)]["MinorStatus"] == "Bulk transaction confirmation" assert r.json()[str(job_id)]["ApplicationStatus"] == "Unknown" -async def test_get_job_status_history(normal_user_client: TestClient): +async def test_get_job_status_history( + normal_user_client: TestClient, valid_job_id: int +): # Arrange - job_definitions = [TEST_JDL] - before = datetime.now(timezone.utc) - r = normal_user_client.post("/jobs/", json=job_definitions) - after = datetime.now(timezone.utc) + r = normal_user_client.get(f"/jobs/{valid_job_id}/status") assert r.status_code == 200, r.json() - assert len(r.json()) == 1 - job_id = r.json()[0]["JobID"] - r = normal_user_client.get(f"/jobs/{job_id}/status") - assert r.status_code == 200, r.json() - assert r.json()[str(job_id)]["Status"] == JobStatus.RECEIVED.value - assert r.json()[str(job_id)]["MinorStatus"] == "Job accepted" - assert r.json()[str(job_id)]["ApplicationStatus"] == "Unknown" + assert r.json()[str(valid_job_id)]["Status"] == JobStatus.RECEIVED.value + assert r.json()[str(valid_job_id)]["MinorStatus"] == "Job accepted" + assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown" NEW_STATUS = JobStatus.CHECKING.value NEW_MINOR_STATUS = "JobPath" - beforebis = datetime.now(timezone.utc) + before = datetime.now(timezone.utc) r = normal_user_client.put( - f"/jobs/{job_id}/status", + f"/jobs/{valid_job_id}/status", json={ datetime.now(tz=timezone.utc).isoformat(): { "Status": NEW_STATUS, @@ -311,83 +319,74 @@ async def test_get_job_status_history(normal_user_client: TestClient): } }, ) - afterbis = datetime.now(timezone.utc) + after = datetime.now(timezone.utc) assert r.status_code == 200, r.json() - assert r.json()[str(job_id)]["Status"] == NEW_STATUS - assert r.json()[str(job_id)]["MinorStatus"] == NEW_MINOR_STATUS + assert r.json()[str(valid_job_id)]["Status"] == NEW_STATUS + assert r.json()[str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS # Act r = normal_user_client.get( - f"/jobs/{job_id}/status/history", + f"/jobs/{valid_job_id}/status/history", ) # Assert assert r.status_code == 200, r.json() assert len(r.json()) == 1 - assert len(r.json()[str(job_id)]) == 2 - assert r.json()[str(job_id)][0]["Status"] == JobStatus.RECEIVED.value - assert r.json()[str(job_id)][0]["MinorStatus"] == "Job accepted" - assert r.json()[str(job_id)][0]["ApplicationStatus"] == "Unknown" + assert len(r.json()[str(valid_job_id)]) == 2 + assert r.json()[str(valid_job_id)][0]["Status"] == JobStatus.RECEIVED.value + assert r.json()[str(valid_job_id)][0]["MinorStatus"] == "Job accepted" + assert r.json()[str(valid_job_id)][0]["ApplicationStatus"] == "Unknown" + assert r.json()[str(valid_job_id)][0]["StatusSource"] == "JobManager" + + assert r.json()[str(valid_job_id)][1]["Status"] == JobStatus.CHECKING.value + assert r.json()[str(valid_job_id)][1]["MinorStatus"] == "JobPath" + assert r.json()[str(valid_job_id)][1]["ApplicationStatus"] == "Unknown" assert ( - before < datetime.fromisoformat(r.json()[str(job_id)][0]["StatusTime"]) < after + before + < datetime.fromisoformat(r.json()[str(valid_job_id)][1]["StatusTime"]) + < after ) - assert r.json()[str(job_id)][0]["StatusSource"] == "JobManager" - - assert r.json()[str(job_id)][1]["Status"] == JobStatus.CHECKING.value - assert r.json()[str(job_id)][1]["MinorStatus"] == "JobPath" - assert r.json()[str(job_id)][1]["ApplicationStatus"] == "Unknown" - assert ( - beforebis - < datetime.fromisoformat(r.json()[str(job_id)][1]["StatusTime"]) - < afterbis - ) - assert r.json()[str(job_id)][1]["StatusSource"] == "Unknown" + assert r.json()[str(valid_job_id)][1]["StatusSource"] == "Unknown" -def test_get_job_status_history_in_bulk(normal_user_client: TestClient): +def test_get_job_status_history_in_bulk( + normal_user_client: TestClient, valid_job_id: int +): # Arrange - job_definitions = [TEST_JDL] - r = normal_user_client.post("/jobs/", json=job_definitions) + r = normal_user_client.get(f"/jobs/{valid_job_id}/status") assert r.status_code == 200, r.json() - assert len(r.json()) == 1 - job_id = r.json()[0]["JobID"] - r = normal_user_client.get(f"/jobs/{job_id}/status") - assert r.status_code == 200, r.json() - assert r.json()[str(job_id)]["Status"] == JobStatus.RECEIVED.value - assert r.json()[str(job_id)]["MinorStatus"] == "Job accepted" - assert r.json()[str(job_id)]["ApplicationStatus"] == "Unknown" + assert r.json()[str(valid_job_id)]["Status"] == JobStatus.RECEIVED.value + assert r.json()[str(valid_job_id)]["MinorStatus"] == "Job accepted" + assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown" # Act - r = normal_user_client.get("/jobs/status/history", params={"job_ids": [job_id]}) + r = normal_user_client.get( + "/jobs/status/history", params={"job_ids": [valid_job_id]} + ) # Assert assert r.status_code == 200, r.json() assert len(r.json()) == 1 - assert r.json()[str(job_id)][0]["Status"] == JobStatus.RECEIVED.value - assert r.json()[str(job_id)][0]["MinorStatus"] == "Job accepted" - assert r.json()[str(job_id)][0]["ApplicationStatus"] == "Unknown" - assert datetime.fromisoformat(r.json()[str(job_id)][0]["StatusTime"]) - assert r.json()[str(job_id)][0]["StatusSource"] == "JobManager" + assert r.json()[str(valid_job_id)][0]["Status"] == JobStatus.RECEIVED.value + assert r.json()[str(valid_job_id)][0]["MinorStatus"] == "Job accepted" + assert r.json()[str(valid_job_id)][0]["ApplicationStatus"] == "Unknown" + assert datetime.fromisoformat(r.json()[str(valid_job_id)][0]["StatusTime"]) + assert r.json()[str(valid_job_id)][0]["StatusSource"] == "JobManager" -def test_set_job_status(normal_user_client: TestClient): +def test_set_job_status(normal_user_client: TestClient, valid_job_id: int): # Arrange - job_definitions = [TEST_JDL] - r = normal_user_client.post("/jobs/", json=job_definitions) - assert r.status_code == 200, r.json() - assert len(r.json()) == 1 - job_id = r.json()[0]["JobID"] - r = normal_user_client.get(f"/jobs/{job_id}/status") + r = normal_user_client.get(f"/jobs/{valid_job_id}/status") assert r.status_code == 200, r.json() - assert r.json()[str(job_id)]["Status"] == JobStatus.RECEIVED.value - assert r.json()[str(job_id)]["MinorStatus"] == "Job accepted" - assert r.json()[str(job_id)]["ApplicationStatus"] == "Unknown" + assert r.json()[str(valid_job_id)]["Status"] == JobStatus.RECEIVED.value + assert r.json()[str(valid_job_id)]["MinorStatus"] == "Job accepted" + assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown" # Act NEW_STATUS = JobStatus.CHECKING.value NEW_MINOR_STATUS = "JobPath" r = normal_user_client.put( - f"/jobs/{job_id}/status", + f"/jobs/{valid_job_id}/status", json={ datetime.now(tz=timezone.utc).isoformat(): { "Status": NEW_STATUS, @@ -398,20 +397,22 @@ def test_set_job_status(normal_user_client: TestClient): # Assert assert r.status_code == 200, r.json() - assert r.json()[str(job_id)]["Status"] == NEW_STATUS - assert r.json()[str(job_id)]["MinorStatus"] == NEW_MINOR_STATUS + assert r.json()[str(valid_job_id)]["Status"] == NEW_STATUS + assert r.json()[str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS - r = normal_user_client.get(f"/jobs/{job_id}/status") + r = normal_user_client.get(f"/jobs/{valid_job_id}/status") assert r.status_code == 200, r.json() - assert r.json()[str(job_id)]["Status"] == NEW_STATUS - assert r.json()[str(job_id)]["MinorStatus"] == NEW_MINOR_STATUS - assert r.json()[str(job_id)]["ApplicationStatus"] == "Unknown" + assert r.json()[str(valid_job_id)]["Status"] == NEW_STATUS + assert r.json()[str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS + assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown" -def test_set_job_status_invalid_job(normal_user_client: TestClient): +def test_set_job_status_invalid_job( + normal_user_client: TestClient, invalid_job_id: int +): # Act r = normal_user_client.put( - "/jobs/1/status", + f"/jobs/{invalid_job_id}/status", json={ datetime.now(tz=timezone.utc).isoformat(): { "Status": JobStatus.CHECKING.value, @@ -422,23 +423,17 @@ def test_set_job_status_invalid_job(normal_user_client: TestClient): # Assert assert r.status_code == 404, r.json() - assert r.json() == {"detail": "Job 1 not found"} + assert r.json() == {"detail": f"Job {invalid_job_id} not found"} def test_set_job_status_offset_naive_datetime_return_bad_request( normal_user_client: TestClient, + valid_job_id: int, ): - # Arrange - job_definitions = [TEST_JDL] - r = normal_user_client.post("/jobs/", json=job_definitions) - assert r.status_code == 200, r.json() - assert len(r.json()) == 1 - job_id = r.json()[0]["JobID"] - # Act date = datetime.utcnow().isoformat(sep=" ") r = normal_user_client.put( - f"/jobs/{job_id}/status", + f"/jobs/{valid_job_id}/status", json={ date: { "Status": JobStatus.CHECKING.value, @@ -453,25 +448,20 @@ def test_set_job_status_offset_naive_datetime_return_bad_request( def test_set_job_status_cannot_make_impossible_transitions( - normal_user_client: TestClient, + normal_user_client: TestClient, valid_job_id: int ): # Arrange - job_definitions = [TEST_JDL] - r = normal_user_client.post("/jobs/", json=job_definitions) - assert r.status_code == 200, r.json() - assert len(r.json()) == 1 - job_id = r.json()[0]["JobID"] - r = normal_user_client.get(f"/jobs/{job_id}/status") + r = normal_user_client.get(f"/jobs/{valid_job_id}/status") assert r.status_code == 200, r.json() - assert r.json()[str(job_id)]["Status"] == JobStatus.RECEIVED.value - assert r.json()[str(job_id)]["MinorStatus"] == "Job accepted" - assert r.json()[str(job_id)]["ApplicationStatus"] == "Unknown" + assert r.json()[str(valid_job_id)]["Status"] == JobStatus.RECEIVED.value + assert r.json()[str(valid_job_id)]["MinorStatus"] == "Job accepted" + assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown" # Act NEW_STATUS = JobStatus.RUNNING.value NEW_MINOR_STATUS = "JobPath" r = normal_user_client.put( - f"/jobs/{job_id}/status", + f"/jobs/{valid_job_id}/status", json={ datetime.now(tz=timezone.utc).isoformat(): { "Status": NEW_STATUS, @@ -482,34 +472,29 @@ def test_set_job_status_cannot_make_impossible_transitions( # Assert assert r.status_code == 200, r.json() - assert r.json()[str(job_id)]["Status"] != NEW_STATUS - assert r.json()[str(job_id)]["MinorStatus"] == NEW_MINOR_STATUS + assert r.json()[str(valid_job_id)]["Status"] != NEW_STATUS + assert r.json()[str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS - r = normal_user_client.get(f"/jobs/{job_id}/status") + r = normal_user_client.get(f"/jobs/{valid_job_id}/status") assert r.status_code == 200, r.json() - assert r.json()[str(job_id)]["Status"] != NEW_STATUS - assert r.json()[str(job_id)]["MinorStatus"] == NEW_MINOR_STATUS - assert r.json()[str(job_id)]["ApplicationStatus"] == "Unknown" + assert r.json()[str(valid_job_id)]["Status"] != NEW_STATUS + assert r.json()[str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS + assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown" -def test_set_job_status_force(normal_user_client: TestClient): +def test_set_job_status_force(normal_user_client: TestClient, valid_job_id: int): # Arrange - job_definitions = [TEST_JDL] - r = normal_user_client.post("/jobs/", json=job_definitions) + r = normal_user_client.get(f"/jobs/{valid_job_id}/status") assert r.status_code == 200, r.json() - assert len(r.json()) == 1 - job_id = r.json()[0]["JobID"] - r = normal_user_client.get(f"/jobs/{job_id}/status") - assert r.status_code == 200, r.json() - assert r.json()[str(job_id)]["Status"] == JobStatus.RECEIVED.value - assert r.json()[str(job_id)]["MinorStatus"] == "Job accepted" - assert r.json()[str(job_id)]["ApplicationStatus"] == "Unknown" + assert r.json()[str(valid_job_id)]["Status"] == JobStatus.RECEIVED.value + assert r.json()[str(valid_job_id)]["MinorStatus"] == "Job accepted" + assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown" # Act NEW_STATUS = JobStatus.RUNNING.value NEW_MINOR_STATUS = "JobPath" r = normal_user_client.put( - f"/jobs/{job_id}/status", + f"/jobs/{valid_job_id}/status", json={ datetime.now(tz=timezone.utc).isoformat(): { "Status": NEW_STATUS, @@ -521,25 +506,19 @@ def test_set_job_status_force(normal_user_client: TestClient): # Assert assert r.status_code == 200, r.json() - assert r.json()[str(job_id)]["Status"] == NEW_STATUS - assert r.json()[str(job_id)]["MinorStatus"] == NEW_MINOR_STATUS + assert r.json()[str(valid_job_id)]["Status"] == NEW_STATUS + assert r.json()[str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS - r = normal_user_client.get(f"/jobs/{job_id}/status") + r = normal_user_client.get(f"/jobs/{valid_job_id}/status") assert r.status_code == 200, r.json() - assert r.json()[str(job_id)]["Status"] == NEW_STATUS - assert r.json()[str(job_id)]["MinorStatus"] == NEW_MINOR_STATUS - assert r.json()[str(job_id)]["ApplicationStatus"] == "Unknown" + assert r.json()[str(valid_job_id)]["Status"] == NEW_STATUS + assert r.json()[str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS + assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown" -def test_set_job_status_bulk(normal_user_client: TestClient): +def test_set_job_status_bulk(normal_user_client: TestClient, valid_job_ids): # Arrange - job_definitions = [TEST_PARAMETRIC_JDL] - r = normal_user_client.post("/jobs/", json=job_definitions) - assert r.status_code == 200, r.json() - assert len(r.json()) == 3 - job_ids = sorted([job_dict["JobID"] for job_dict in r.json()]) - - for job_id in job_ids: + for job_id in valid_job_ids: r = normal_user_client.get(f"/jobs/{job_id}/status") assert r.status_code == 200, r.json() assert r.json()[str(job_id)]["Status"] == JobStatus.SUBMITTING.value @@ -557,13 +536,13 @@ def test_set_job_status_bulk(normal_user_client: TestClient): "MinorStatus": NEW_MINOR_STATUS, } } - for job_id in job_ids + for job_id in valid_job_ids }, ) # Assert assert r.status_code == 200, r.json() - for job_id in job_ids: + for job_id in valid_job_ids: assert r.json()[str(job_id)]["Status"] == NEW_STATUS assert r.json()[str(job_id)]["MinorStatus"] == NEW_MINOR_STATUS @@ -574,10 +553,12 @@ def test_set_job_status_bulk(normal_user_client: TestClient): assert r_get.json()[str(job_id)]["ApplicationStatus"] == "Unknown" -def test_set_job_status_with_invalid_job_id(normal_user_client: TestClient): +def test_set_job_status_with_invalid_job_id( + normal_user_client: TestClient, invalid_job_id: int +): # Act r = normal_user_client.put( - "/jobs/999999999/status", + f"/jobs/{invalid_job_id}/status", json={ datetime.now(tz=timezone.utc).isoformat(): { "Status": JobStatus.CHECKING.value, @@ -588,4 +569,248 @@ def test_set_job_status_with_invalid_job_id(normal_user_client: TestClient): # Assert assert r.status_code == 404, r.json() - assert r.json() == {"detail": "Job 999999999 not found"} + assert r.json() == {"detail": f"Job {invalid_job_id} not found"} + + +# Test delete job + + +def test_delete_job_valid_job_id(normal_user_client: TestClient, valid_job_id: int): + # Act + r = normal_user_client.delete(f"/jobs/{valid_job_id}") + + # Assert + assert r.status_code == 200, r.json() + r = normal_user_client.get(f"/jobs/{valid_job_id}/status") + assert r.status_code == 200, r.json() + assert r.json()[str(valid_job_id)]["Status"] == JobStatus.DELETED + assert r.json()[str(valid_job_id)]["MinorStatus"] == "Checking accounting" + assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown" + + +def test_delete_job_invalid_job_id(normal_user_client: TestClient, invalid_job_id: int): + # Act + r = normal_user_client.delete(f"/jobs/{invalid_job_id}") + + # Assert + assert r.status_code == 404, r.json() + assert r.json() == {"detail": f"Job {invalid_job_id} not found"} + + +def test_delete_bulk_jobs_valid_job_ids( + normal_user_client: TestClient, valid_job_ids: list[int] +): + # Act + r = normal_user_client.delete("/jobs/", params={"job_ids": valid_job_ids}) + + # Assert + assert r.status_code == 200, r.json() + for valid_job_id in valid_job_ids: + r = normal_user_client.get(f"/jobs/{valid_job_id}/status") + assert r.status_code == 200, r.json() + assert r.json()[str(valid_job_id)]["Status"] == JobStatus.DELETED + assert r.json()[str(valid_job_id)]["MinorStatus"] == "Checking accounting" + assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown" + + +def test_delete_bulk_jobs_invalid_job_ids( + normal_user_client: TestClient, invalid_job_ids: list[int] +): + # Act + r = normal_user_client.delete("/jobs/", params={"job_ids": invalid_job_ids}) + + # Assert + assert r.status_code == 404, r.json() + assert r.json() == { + "detail": { + "message": f"Failed to delete {len(invalid_job_ids)} jobs out of {len(invalid_job_ids)}", + "valid_job_ids": [], + "failed_job_ids": invalid_job_ids, + } + } + + +def test_delete_bulk_jobs_mix_of_valid_and_invalid_job_ids( + normal_user_client: TestClient, valid_job_ids: list[int], invalid_job_ids: list[int] +): + # Arrange + job_ids = valid_job_ids + invalid_job_ids + + # Act + r = normal_user_client.delete("/jobs/", params={"job_ids": job_ids}) + + # Assert + assert r.status_code == 404, r.json() + assert r.json() == { + "detail": { + "message": f"Failed to delete {len(invalid_job_ids)} jobs out of {len(job_ids)}", + "valid_job_ids": valid_job_ids, + "failed_job_ids": invalid_job_ids, + } + } + for job_id in valid_job_ids: + r = normal_user_client.get(f"/jobs/{job_id}/status") + assert r.status_code == 200, r.json() + assert r.json()[str(job_id)]["Status"] != JobStatus.DELETED + + +# Test kill job + + +def test_kill_job_valid_job_id(normal_user_client: TestClient, valid_job_id: int): + # Act + r = normal_user_client.post(f"/jobs/{valid_job_id}/kill") + + # Assert + assert r.status_code == 200, r.json() + r = normal_user_client.get(f"/jobs/{valid_job_id}/status") + assert r.status_code == 200, r.json() + assert r.json()[str(valid_job_id)]["Status"] == JobStatus.KILLED + assert r.json()[str(valid_job_id)]["MinorStatus"] == "Marked for termination" + assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown" + + +def test_kill_job_invalid_job_id(normal_user_client: TestClient, invalid_job_id: int): + # Act + r = normal_user_client.post(f"/jobs/{invalid_job_id}/kill") + + # Assert + assert r.status_code == 404, r.json() + assert r.json() == {"detail": f"Job {invalid_job_id} not found"} + + +def test_kill_bulk_jobs_valid_job_ids( + normal_user_client: TestClient, valid_job_ids: list[int] +): + # Act + r = normal_user_client.post("/jobs/kill", params={"job_ids": valid_job_ids}) + + # Assert + assert r.status_code == 200, r.json() + assert r.json() == valid_job_ids + for valid_job_id in valid_job_ids: + r = normal_user_client.get(f"/jobs/{valid_job_id}/status") + assert r.status_code == 200, r.json() + assert r.json()[str(valid_job_id)]["Status"] == JobStatus.KILLED + assert r.json()[str(valid_job_id)]["MinorStatus"] == "Marked for termination" + assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown" + + +def test_kill_bulk_jobs_invalid_job_ids( + normal_user_client: TestClient, invalid_job_ids: list[int] +): + # Act + r = normal_user_client.post("/jobs/kill", params={"job_ids": invalid_job_ids}) + + # Assert + assert r.status_code == 404, r.json() + assert r.json() == { + "detail": { + "message": f"Failed to kill {len(invalid_job_ids)} jobs out of {len(invalid_job_ids)}", + "valid_job_ids": [], + "failed_job_ids": invalid_job_ids, + } + } + + +def test_kill_bulk_jobs_mix_of_valid_and_invalid_job_ids( + normal_user_client: TestClient, valid_job_ids: list[int], invalid_job_ids: list[int] +): + # Arrange + job_ids = valid_job_ids + invalid_job_ids + + # Act + r = normal_user_client.post("/jobs/kill", params={"job_ids": job_ids}) + + # Assert + assert r.status_code == 404, r.json() + assert r.json() == { + "detail": { + "message": f"Failed to kill {len(invalid_job_ids)} jobs out of {len(job_ids)}", + "valid_job_ids": valid_job_ids, + "failed_job_ids": invalid_job_ids, + } + } + for valid_job_id in valid_job_ids: + r = normal_user_client.get(f"/jobs/{valid_job_id}/status") + assert r.status_code == 200, r.json() + # assert the job is not killed + assert r.json()[str(valid_job_id)]["Status"] != JobStatus.KILLED + + +# Test remove job + + +def test_remove_job_valid_job_id(normal_user_client: TestClient, valid_job_id: int): + # Act + r = normal_user_client.post(f"/jobs/{valid_job_id}/remove") + + # Assert + assert r.status_code == 200, r.json() + r = normal_user_client.get(f"/jobs/{valid_job_id}/status") + assert r.status_code == 404, r.json() + + +def test_remove_job_invalid_job_id(normal_user_client: TestClient, invalid_job_id: int): + # Act + r = normal_user_client.post(f"/jobs/{invalid_job_id}/remove") + + # Assert + assert r.status_code == 200, r.json() + + +def test_remove_bulk_jobs_valid_job_ids( + normal_user_client: TestClient, valid_job_ids: list[int] +): + # Act + r = normal_user_client.post("/jobs/remove", params={"job_ids": valid_job_ids}) + + # Assert + assert r.status_code == 200, r.json() + for job_id in valid_job_ids: + r = normal_user_client.get(f"/jobs/{job_id}/status") + assert r.status_code == 404, r.json() + + +# def test_remove_bulk_jobs_invalid_job_ids( +# normal_user_client: TestClient, invalid_job_ids: list[int] +# ): +# # Act +# r = normal_user_client.post("/jobs/remove", params={"job_ids": invalid_job_ids}) + +# # Assert +# assert r.status_code == 404, r.json() +# assert r.json() == { +# "detail": { +# "message": f"Failed to remove {len(invalid_job_ids)} jobs out of {len(invalid_job_ids)}", +# "failed_ids": { +# str(invalid_job_id): f"Job {invalid_job_id} not found" +# for invalid_job_id in invalid_job_ids +# }, +# } +# } + + +# def test_remove_bulk_jobs_mix_of_valid_and_invalid_job_ids( +# normal_user_client: TestClient, valid_job_ids: list[int], invalid_job_ids: list[int] +# ): +# # Arrange +# job_ids = valid_job_ids + invalid_job_ids + +# # Act +# r = normal_user_client.post("/jobs/remove", params={"job_ids": job_ids}) + +# # Assert +# assert r.status_code == 404, r.json() +# assert r.json() == { +# "detail": { +# "message": f"Failed to remove {len(invalid_job_ids)} jobs out of {len(job_ids)}", +# "failed_ids": { +# str(invalid_job_id): f"Job {invalid_job_id} not found" +# for invalid_job_id in invalid_job_ids +# }, +# } +# } +# for job_id in valid_job_ids: +# r = normal_user_client.get(f"/jobs/{job_id}/status") +# assert r.status_code == 404, r.json()