Skip to content

Commit

Permalink
Merge pull request #238 from vipyrsec/jobs-cte
Browse files Browse the repository at this point in the history
Use a CTE for the `/POST jobs` endpoint
  • Loading branch information
jonathan-d-zhang authored Jul 14, 2024
2 parents bf7ecd2 + 8f9c6f3 commit ccda8a1
Showing 1 changed file with 28 additions and 23 deletions.
51 changes: 28 additions & 23 deletions src/mainframe/endpoints/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import structlog
from fastapi import APIRouter, Depends
from sqlalchemy import and_, or_, select
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import and_, or_, select, update
from sqlalchemy.orm import Session, joinedload, aliased

from mainframe.constants import mainframe_settings
from mainframe.database import get_db
Expand Down Expand Up @@ -41,33 +41,38 @@ def get_jobs(
"""

with session, session.begin():
scans = (
session.scalars(
select(Scan)
.where(
or_(
Scan.status == Status.QUEUED,
and_(
Scan.pending_at
< datetime.now(timezone.utc) - timedelta(seconds=mainframe_settings.job_timeout),
Scan.status == Status.PENDING,
),
)
# Use a CTE to limit the number of rows we fetch
cte = (
select(Scan)
.where(
or_(
Scan.status == Status.QUEUED,
and_(
Scan.pending_at
< datetime.now(timezone.utc) - timedelta(seconds=mainframe_settings.job_timeout),
Scan.status == Status.PENDING,
),
)
.order_by(Scan.pending_at.nulls_first(), Scan.queued_at)
.limit(batch)
.options(joinedload(Scan.download_urls))
)
.unique()
.all()
.order_by(Scan.pending_at.nulls_first(), Scan.queued_at)
.limit(batch)
.options(joinedload(Scan.download_urls))
.with_for_update(skip_locked=True)
.cte()
)

scan_cte = aliased(Scan, cte)

# Uses a Postgres `UPDATE .. FROM`. https://docs.sqlalchemy.org/en/20/tutorial/data_update.html#update-from
scans = session.scalars(
update(Scan)
.where(Scan.scan_id == scan_cte.scan_id)
.values(status=Status.PENDING, pending_at=datetime.now(timezone.utc), pending_by=auth.subject)
.returning(Scan)
)

response_body: list[JobResult] = []
for scan in scans:
scan.status = Status.PENDING
scan.pending_at = datetime.now(timezone.utc)
scan.pending_by = auth.subject

logger.info(
"Job given and status set to pending in database",
package={
Expand Down

0 comments on commit ccda8a1

Please sign in to comment.