Skip to content

Commit

Permalink
fix: troubleshooting duplicate classifications
Browse files (browse the repository at this point in the history)
  • Loading branch information
mihow committed Sep 26, 2024
1 parent c587a93 commit c843686
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 18 deletions.
17 changes: 15 additions & 2 deletions trapdata/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,24 @@ def _get_source_image(source_images, source_image_id):


@app.post("/pipeline/process")
@app.post("/pipeline/process/")
async def process(data: PipelineRequest) -> PipelineResponse:
# Ensure that the source images are unique, filter out duplicates
source_images_index = {
source_image.id: source_image for source_image in data.source_images
}
incoming_source_images = list(source_images_index.values())
if len(incoming_source_images) != len(data.source_images):
logger.warning(
f"Removed {len(data.source_images) - len(incoming_source_images)} duplicate source images"
)

source_image_results = [
SourceImageResponse(**image.model_dump()) for image in data.source_images
SourceImageResponse(**image.model_dump()) for image in incoming_source_images
]
source_images = [
SourceImage(**image.model_dump()) for image in incoming_source_images
]
source_images = [SourceImage(**image.model_dump()) for image in data.source_images]

start_time = time.time()
detector = MothDetector(
Expand Down
38 changes: 26 additions & 12 deletions trapdata/api/models/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import numpy as np
import torch
from rich import print

from trapdata.common.logs import logger
from trapdata.ml.models.classification import (
Expand Down Expand Up @@ -37,6 +36,9 @@ def __init__(
self.detections = list(detections)
self.results: list[Detection] = []
super().__init__(*args, **kwargs)
logger.info(
f"Initialized {self.__class__.__name__} with {len(self.detections)} detections"
)

def get_dataset(self):
return ClassificationImageDataset(
Expand Down Expand Up @@ -89,19 +91,32 @@ def save_results(
timestamp=datetime.datetime.now(),
)
self.update_classification(detection, classification)
print(detection)
# print(detection)
self.results.extend(self.detections)
logger.info(f"Saving {len(self.results)} detections with classifications")
return self.results

def update_classification(self, detection: Detection, new_classification: Classification) -> None:
def update_classification(
self, detection: Detection, new_classification: Classification
) -> None:
# Remove all existing classifications from this algorithm
detection.classifications = [c for c in detection.classifications if c.algorithm != self.name]
detection.classifications = [
c for c in detection.classifications if c.algorithm != self.name
]
# Add the new classification for this algorithm
detection.classifications.append(new_classification)
logger.debug(
f"Updated classification for detection {detection.bbox}. Total classifications: {len(detection.classifications)}"
)

def run(self) -> list[Detection]:
    """
    Run the classifier and return the processed detections.

    The actual work is delegated to the parent class's ``run()``; this
    override only adds a log line before and after it.

    Returns:
        The ``self.results`` list.
    """
    # Report the queued inputs (`self.detections`) at start-up.
    # `self.results` is the output list that `save_results` extends, so its
    # length before the run does not describe the work about to be done.
    logger.info(
        f"Starting {self.__class__.__name__} run with {len(self.detections)} detections"
    )
    super().run()
    logger.info(
        f"Finished {self.__class__.__name__} run. Processed {len(self.results)} detections"
    )
    return self.results


Expand Down Expand Up @@ -134,8 +149,11 @@ def save_results(
# Specific to binary classification / the filter model
terminal=False,
)
print(detection)
if not self.filter_results or classification.classification == self.positive_binary_label:
# print(detection)
if (
not self.filter_results
or classification.classification == self.positive_binary_label
):
self.update_classification(detection, classification)

self.results.extend(self.detections)
Expand All @@ -149,15 +167,11 @@ class MothClassifierPanama(
pass


class MothClassifierPanama2024(
MothClassifier, PanamaMothSpeciesClassifier2024
):
class MothClassifierPanama2024(MothClassifier, PanamaMothSpeciesClassifier2024):
pass


class MothClassifierUKDenmark(
MothClassifier, UKDenmarkMothSpeciesClassifier2024
):
class MothClassifierUKDenmark(MothClassifier, UKDenmarkMothSpeciesClassifier2024):
pass


Expand Down
4 changes: 1 addition & 3 deletions trapdata/api/models/localization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import datetime
import typing

from rich import print

from trapdata.common.logs import logger
from trapdata.ml.models.localization import (
MothObjectDetector_FasterRCNN_2023,
Expand Down Expand Up @@ -98,7 +96,7 @@ def save_detection(image_id, coords):
timestamp=datetime.datetime.now(),
crop_image_url=crop_url,
)
print(detection)
# print(detection)
return detection

with concurrent.futures.ThreadPoolExecutor() as executor:
Expand Down
2 changes: 1 addition & 1 deletion trapdata/cli/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def run_api():
"""
import uvicorn

uvicorn.run("trapdata.api.api:app", host="0.0.0.0", port=2001, reload=True)
uvicorn.run("trapdata.api.api:app", host="0.0.0.0", port=2000, reload=True)


if __name__ == "__main__":
Expand Down

0 comments on commit c843686

Please sign in to comment.