Skip to content

Commit

Permalink
perf: weighted score for best label/score
Browse files Browse the repository at this point in the history
  • Loading branch information
danellecline committed Nov 22, 2024
1 parent 692f9fa commit 4875e18
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions biotrack/track.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# biotrack, Apache-2.0 license
# Filename: biotrack/track.py
# Description: Basic track object to contain and update tracks
from collections import Counter

from collections import defaultdict
from biotrack.logger import info, debug
import numpy as np

Expand Down Expand Up @@ -90,6 +89,10 @@ def rescale(self, pt: np.array, box: np.array) -> (np.array, np.array):
box_rescale = box
return pt_rescale, box_rescale

@property
def num_frames(self):
return self.last_updated_frame - self.start_frame + 1

def get_best(self, rescale=True) -> (int, np.array, str, np.array, float):
# Get the best box which is a few frames behind the last_updated_frame
# This is pretty arbitrary, but sometimes the last box is too blurry or not visible
Expand Down Expand Up @@ -159,11 +162,16 @@ def update(self, label: str, pt: np.array, emb: np.array, frame_num: int, box:np
scores = scores[-10:]
labels = labels[-10:]

# Update the best_label with that of the highest scoring recent label
max_score = max(scores)
max_frame = np.argmax(scores)
self.best_label = labels[max_frame]
self.best_score = max_score
# Calculate the weighted score for each label
label_scores = defaultdict(float)
for score, label in zip(scores, labels):
label_scores[label] += float(score)

# Update the best_label with a weighted probability
label = max(label_scores, key=label_scores.get)
num_best_label = labels.count(label)
self.best_label = label
self.best_score = label_scores[label] / num_best_label

pts_pretty = [f"{pt[0]:.2f},{pt[1]:.2f},{label},{score}" for pt, label, score in zip(self.pt.values(), self.label.values(), self.score.values())]
total_frames = len(self.pt)
Expand Down

0 comments on commit 4875e18

Please sign in to comment.