From 0a7283d58ef932ceecb31d7124cd0c7facb127cf Mon Sep 17 00:00:00 2001 From: Robin5605 Date: Wed, 30 Oct 2024 18:18:12 -0500 Subject: [PATCH] Combine get_reported_version into validate_package --- src/mainframe/endpoints/report.py | 32 ++++++------------------ tests/test_report.py | 41 +------------------------------ 2 files changed, 8 insertions(+), 65 deletions(-) diff --git a/src/mainframe/endpoints/report.py b/src/mainframe/endpoints/report.py index fa3b7a5..40eee63 100644 --- a/src/mainframe/endpoints/report.py +++ b/src/mainframe/endpoints/report.py @@ -27,22 +27,6 @@ router = APIRouter(tags=["report"]) -def get_reported_version(scans: Sequence[Scan]) -> Optional[Scan]: - """ - Get the version of this scan that was reported. - - Returns: - `Scan`: The scan record that was reported - `None`: No versions of this package were reported - """ - - for scan in scans: - if scan.reported_at is not None: - return scan - - return None - - def validate_package(name: str, version: str, scans: Sequence[Scan]) -> Scan: """ Checks if the package is valid according to our database. @@ -62,17 +46,15 @@ def validate_package(name: str, version: str, scans: Sequence[Scan]) -> Scan: PackageAlreadyReported: The package was already reported """ - if not scans: - raise PackageNotFound(name=name, version=version) - - if scan := get_reported_version(scans): - raise PackageAlreadyReported(name=scan.name, reported_version=scan.version) + for scan in scans: + if scan.reported_at is not None: + raise PackageAlreadyReported(name=scan.name, reported_version=scan.version) - scan = next((s for s in scans if (s.name, s.version) == (name, version)), None) - if scan is None: - raise PackageNotFound(name=name, version=version) + for scan in scans: + if (scan.name, scan.version) == (name, version): + return scan - return scan + raise PackageNotFound(name=name, version=version) def _validate_inspector_url(name: str, version: str, body_url: Optional[str], scan_url: Optional[str]) -> str: diff --git a/tests/test_report.py b/tests/test_report.py index 1a4f6ec..20bc923 100644 --- a/tests/test_report.py +++ b/tests/test_report.py @@ -8,10 +8,7 @@ from fastapi.encoders import jsonable_encoder from mainframe.custom_exceptions import PackageAlreadyReported, PackageNotFound -from mainframe.endpoints.report import ( - validate_package, - get_reported_version, -) +from mainframe.endpoints.report import validate_package from mainframe.endpoints.report import ( _validate_inspector_url, # pyright: ignore [reportPrivateUsage] ) @@ -29,42 +26,6 @@ from tests.conftest import MockDatabase -def test_get_reported_version(): - scan1 = Scan( - name="package1", - version="1.0.0", - reported_at=datetime.now(), - ) - - scan2 = Scan( - name="package1", - version="1.0.1", - reported_at=None, - ) - - scans = [scan1, scan2] - - assert get_reported_version(scans) == scan1 - - -def test_get_no_reported_version(): - scan1 = Scan( - name="package1", - version="1.0.0", - reported_at=None, - ) - - scan2 = Scan( - name="package1", - version="1.0.1", - reported_at=None, - ) - - scans = [scan1, scan2] - - assert get_reported_version(scans) is None - - def test_validate_package(): scan1 = Scan( name="package1",