diff --git a/biotrack/track.py b/biotrack/track.py index c7c3060..542bd9c 100644 --- a/biotrack/track.py +++ b/biotrack/track.py @@ -1,8 +1,7 @@ # biotrack, Apache-2.0 license # Filename: biotrack/track.py # Description: Basic track object to contain and update tracks -from collections import Counter - +from collections import defaultdict from biotrack.logger import info, debug import numpy as np @@ -90,6 +89,10 @@ def rescale(self, pt: np.array, box: np.array) -> (np.array, np.array): box_rescale = box return pt_rescale, box_rescale + @property + def num_frames(self): + return self.last_updated_frame - self.start_frame + 1 + def get_best(self, rescale=True) -> (int, np.array, str, np.array, float): # Get the best box which is a few frames behind the last_updated_frame # This is pretty arbitrary, but sometimes the last box is too blurry or not visible @@ -159,11 +162,16 @@ def update(self, label: str, pt: np.array, emb: np.array, frame_num: int, box:np scores = scores[-10:] labels = labels[-10:] - # Update the best_label with that of the highest scoring recent label - max_score = max(scores) - max_frame = np.argmax(scores) - self.best_label = labels[max_frame] - self.best_score = max_score + # Calculate the weighted score for each label + label_scores = defaultdict(float) + for score, label in zip(scores, labels): + label_scores[label] += float(score) + + # Update the best_label with a weighted probability + label = max(label_scores, key=label_scores.get) + num_best_label = labels.count(label) + self.best_label = label + self.best_score = label_scores[label] / num_best_label pts_pretty = [f"{pt[0]:.2f},{pt[1]:.2f},{label},{score}" for pt, label, score in zip(self.pt.values(), self.label.values(), self.score.values())] total_frames = len(self.pt)