diff --git a/src/mainframe/endpoints/package.py b/src/mainframe/endpoints/package.py index 67a78893..5228a546 100644 --- a/src/mainframe/endpoints/package.py +++ b/src/mainframe/endpoints/package.py @@ -4,7 +4,7 @@ import structlog from fastapi import APIRouter, Depends, HTTPException -from letsbuilda.pypi import Package, PyPIServices # type: ignore +from letsbuilda.pypi import Package as PyPIPackage, PyPIServices # type: ignore from letsbuilda.pypi.exceptions import PackageNotFoundError from sqlalchemy import select, tuple_ from sqlalchemy.exc import IntegrityError @@ -17,6 +17,7 @@ from mainframe.models.orm import DownloadURL, Rule, Scan, Status from mainframe.models.schemas import ( Error, + Package, PackageScanResult, PackageScanResultFail, PackageSpecifier, @@ -168,8 +169,10 @@ def lookup_package_info( with session, session.begin(): data = session.scalars(query).unique().all() + packages = [Package.from_db(result) for result in data] + log.info("Package information queried") - return data + return packages def _deduplicate_packages(packages: list[PackageSpecifier], session: Session) -> set[tuple[str, str]]: @@ -178,11 +181,11 @@ def _deduplicate_packages(packages: list[PackageSpecifier], session: Session) -> return name_ver - {(scan.name, scan.version) for scan in scalars.all()} -def _get_packages_metadata(pypi_client: PyPIServices, packages_to_check: set[tuple[str, str]]) -> Iterable[Package]: +def _get_packages_metadata(pypi_client: PyPIServices, packages_to_check: set[tuple[str, str]]) -> Iterable[PyPIPackage]: if not packages_to_check: return - def _get_package_metadata(package: tuple[str, str]) -> Optional[Package]: + def _get_package_metadata(package: tuple[str, str]) -> Optional[PyPIPackage]: try: return pypi_client.get_package_metadata(*package) except PackageNotFoundError: diff --git a/src/mainframe/models/schemas.py b/src/mainframe/models/schemas.py index 3399192b..3624e34b 100644 --- a/src/mainframe/models/schemas.py +++ b/src/mainframe/models/schemas.py @@ -1,7 +1,10 @@ -from enum import Enum +from datetime import datetime from typing import Any, Optional +from enum import Enum -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, field_serializer + +from mainframe.models.orm import Scan class ServerMetadata(BaseModel): @@ -17,6 +20,60 @@ class Error(BaseModel): detail: str +class Package(BaseModel): + """Model representing a package queried from the database.""" + + scan_id: str + name: str + version: Optional[str] + status: Optional[str] + score: Optional[int] + inspector_url: Optional[str] + rules: list[str] = [] + download_urls: list[str] = [] + queued_at: Optional[datetime] + queued_by: Optional[str] + reported_at: Optional[datetime] + reported_by: Optional[str] + pending_at: Optional[datetime] + pending_by: Optional[str] + finished_at: Optional[datetime] + finished_by: Optional[str] + commit_hash: Optional[str] + + @classmethod + def from_db(cls, scan: Scan): + return cls( + scan_id=str(scan.scan_id), + name=scan.name, + version=scan.version, + status=str(scan.status), + score=scan.score, # pyright: ignore + inspector_url=scan.inspector_url, + rules=[rule.name for rule in scan.rules], + download_urls=[url.url for url in scan.download_urls], + reported_at=scan.reported_at, + reported_by=scan.reported_by, + queued_at=scan.queued_at, + queued_by=scan.queued_by, + pending_at=scan.pending_at, + pending_by=scan.pending_by, + finished_at=scan.finished_at, + finished_by=scan.finished_by, + commit_hash=scan.commit_hash, + ) + + @field_serializer( + "queued_at", + "pending_at", + "finished_at", + "reported_at", + ) + def serialize_dt(self, dt: Optional[datetime], _info): # pyright: ignore + if dt: + return int(dt.timestamp()) + + class PackageSpecifier(BaseModel): """ Model used to specify a package by name and version diff --git a/tests/test_package.py b/tests/test_package.py index 1ccba134..6759d8c3 100644 --- a/tests/test_package.py +++ b/tests/test_package.py @@ -1,3 +1,4 @@ +from datetime import datetime from typing import Optional import pytest @@ -19,6 +20,7 @@ from mainframe.json_web_token import AuthenticationData from mainframe.models.orm import Scan, Status from mainframe.models.schemas import ( + Package as ResponsePackage, PackageScanResult, PackageScanResultFail, PackageSpecifier, @@ -261,3 +263,47 @@ def test_submit_duplicate_package( else: assert all(scan.status != Status.QUEUED for scan in test_data) + + +def test_package_from_db(): + """Test the from_db method of Package.""" + + scan = Scan( + name="pyfoo", + version="3.12.2", + score=14, + queued_by="Ryan", + reported_by="Ryan", + queued_at=datetime(2024, 3, 5, 12, 30, 0), + ) + + pkg = ResponsePackage.from_db(scan) + + assert pkg.name == "pyfoo" + assert pkg.version == "3.12.2" + assert pkg.score == 14 + assert pkg.queued_by == "Ryan" + assert pkg.reported_by == "Ryan" + assert pkg.queued_at == datetime(2024, 3, 5, 12, 30, 0) + + +def test_datetime_serialization(): + """Test that the datetime fields are serialized correctly.""" + + scan = Scan( + name="Pyfoo", + version="3.13.0", + queued_at=datetime(2023, 10, 12, 13, 45, 30), + pending_at=datetime(2023, 10, 12, 13, 45, 30), + finished_at=datetime(2023, 10, 12, 13, 45, 30), + reported_at=datetime(2023, 10, 12, 13, 45, 30), + queued_by="Tina", + ) + + pkg = ResponsePackage.from_db(scan).model_dump() + dt = int(datetime(2023, 10, 12, 13, 45, 30).timestamp()) + + assert pkg.get("queued_at") == dt + assert pkg.get("pending_at") == dt + assert pkg.get("finished_at") == dt + assert pkg.get("reported_at") == dt