From c8436866fc333af53ffdb17db91826d816452758 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Wed, 25 Sep 2024 23:13:16 -0700 Subject: [PATCH] fix: troubleshooting duplicate classifications --- trapdata/api/api.py | 17 ++++++++++-- trapdata/api/models/classification.py | 38 ++++++++++++++++++--------- trapdata/api/models/localization.py | 4 +-- trapdata/cli/base.py | 2 +- 4 files changed, 43 insertions(+), 18 deletions(-) diff --git a/trapdata/api/api.py b/trapdata/api/api.py index b9be78c7..4969b792 100644 --- a/trapdata/api/api.py +++ b/trapdata/api/api.py @@ -81,11 +81,24 @@ def _get_source_image(source_images, source_image_id): @app.post("/pipeline/process") +@app.post("/pipeline/process/") async def process(data: PipelineRequest) -> PipelineResponse: + # Ensure that the source images are unique, filter out duplicates + source_images_index = { + source_image.id: source_image for source_image in data.source_images + } + incoming_source_images = list(source_images_index.values()) + if len(incoming_source_images) != len(data.source_images): + logger.warning( + f"Removed {len(data.source_images) - len(incoming_source_images)} duplicate source images" + ) + source_image_results = [ - SourceImageResponse(**image.model_dump()) for image in data.source_images + SourceImageResponse(**image.model_dump()) for image in incoming_source_images + ] + source_images = [ + SourceImage(**image.model_dump()) for image in incoming_source_images ] - source_images = [SourceImage(**image.model_dump()) for image in data.source_images] start_time = time.time() detector = MothDetector( diff --git a/trapdata/api/models/classification.py b/trapdata/api/models/classification.py index 7f9b1bc2..dc2266ac 100644 --- a/trapdata/api/models/classification.py +++ b/trapdata/api/models/classification.py @@ -3,7 +3,6 @@ import numpy as np import torch -from rich import print from trapdata.common.logs import logger from trapdata.ml.models.classification import ( @@ -37,6 +36,9 @@ def __init__( self.detections = list(detections) self.results: list[Detection] = [] super().__init__(*args, **kwargs) + logger.info( + f"Initialized {self.__class__.__name__} with {len(self.detections)} detections" + ) def get_dataset(self): return ClassificationImageDataset( @@ -89,19 +91,32 @@ def save_results( timestamp=datetime.datetime.now(), ) self.update_classification(detection, classification) - print(detection) + # print(detection) self.results.extend(self.detections) logger.info(f"Saving {len(self.results)} detections with classifications") return self.results - def update_classification(self, detection: Detection, new_classification: Classification) -> None: + def update_classification( + self, detection: Detection, new_classification: Classification + ) -> None: # Remove all existing classifications from this algorithm - detection.classifications = [c for c in detection.classifications if c.algorithm != self.name] + detection.classifications = [ + c for c in detection.classifications if c.algorithm != self.name + ] # Add the new classification for this algorithm detection.classifications.append(new_classification) + logger.debug( + f"Updated classification for detection {detection.bbox}. Total classifications: {len(detection.classifications)}" + ) def run(self) -> list[Detection]: + logger.info( + f"Starting {self.__class__.__name__} run with {len(self.results)} detections" + ) super().run() + logger.info( + f"Finished {self.__class__.__name__} run. Processed {len(self.results)} detections" + ) return self.results @@ -134,8 +149,11 @@ def save_results( # Specific to binary classification / the filter model terminal=False, ) - print(detection) - if not self.filter_results or classification.classification == self.positive_binary_label: + # print(detection) + if ( + not self.filter_results + or classification.classification == self.positive_binary_label + ): self.update_classification(detection, classification) self.results.extend(self.detections) @@ -149,15 +167,11 @@ class MothClassifierPanama( pass -class MothClassifierPanama2024( - MothClassifier, PanamaMothSpeciesClassifier2024 -): +class MothClassifierPanama2024(MothClassifier, PanamaMothSpeciesClassifier2024): pass -class MothClassifierUKDenmark( - MothClassifier, UKDenmarkMothSpeciesClassifier2024 -): +class MothClassifierUKDenmark(MothClassifier, UKDenmarkMothSpeciesClassifier2024): pass diff --git a/trapdata/api/models/localization.py b/trapdata/api/models/localization.py index f4198ed0..91497f20 100644 --- a/trapdata/api/models/localization.py +++ b/trapdata/api/models/localization.py @@ -2,8 +2,6 @@ import datetime import typing -from rich import print - from trapdata.common.logs import logger from trapdata.ml.models.localization import ( MothObjectDetector_FasterRCNN_2023, @@ -98,7 +96,7 @@ def save_detection(image_id, coords): timestamp=datetime.datetime.now(), crop_image_url=crop_url, ) - print(detection) + # print(detection) return detection with concurrent.futures.ThreadPoolExecutor() as executor: diff --git a/trapdata/cli/base.py b/trapdata/cli/base.py index eecf487f..616c645d 100644 --- a/trapdata/cli/base.py +++ b/trapdata/cli/base.py @@ -93,7 +93,7 @@ def run_api(): """ import uvicorn - uvicorn.run("trapdata.api.api:app", host="0.0.0.0", port=2001, reload=True) + uvicorn.run("trapdata.api.api:app", host="0.0.0.0", port=2000, reload=True) if __name__ == "__main__":