Skip to content

Commit

Permalink
perf: majority vote for best label and average score
Browse files Browse the repository at this point in the history
  • Loading branch information
danellecline committed Nov 16, 2024
1 parent 6b3f621 commit 4f3d5f3
Showing 1 changed file with 40 additions and 17 deletions.
57 changes: 40 additions & 17 deletions biotrack/track.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# biotrack, Apache-2.0 license
# Filename: biotrack/tracker/track.py
# Description: Basic track object to contain and update tracks
from collections import Counter

from biotrack.logger import info, debug
import numpy as np

Expand All @@ -19,6 +21,7 @@ def __init__(self, track_id: int, label: str, pt: np.array, emb: np.array, frame
self.box = {frame: box}
self.emb = emb
self.best_label = label
self.best_score = score
self.start_frame = frame
self.last_updated_frame = frame
self.x_scale = x_scale
Expand Down Expand Up @@ -67,10 +70,9 @@ def get_best(self, rescale=True) -> (int, np.array, str, np.array, float):
frame_num = self.last_updated_frame
box = self.box[frame_num]
pt = self.pt[frame_num]
max_score = max(self.score.values())
if rescale:
pt, box = self.rescale(pt, box)
return frame_num, pt, self.best_label, box, max_score
return frame_num, pt, self.best_label, box, self.best_score

def get(self, frame_num: int, rescale=True) -> (np.array, str, np.array, float):
if frame_num not in self.pt.keys():
Expand Down Expand Up @@ -118,20 +120,41 @@ def update(self, label: str, pt: np.array, emb: np.array, frame_num: int, box:np
self.emb = emb
self.last_updated_frame = frame_num

# Update the best_label with that of the highest scoring label
# Only choose scores that are greater than 0
valid_scores = {k: v for k, v in self.score.items() if v > 0}
if len(valid_scores) == 0:
self.best_label = "marine organism"
# Update the best_label with that which occurs the most that has a score > 0. This is a simple majority vote
data = [(pred, score) for pred, score in zip(self.label.values(), self.score.values()) if float(score) > 0.]

if len(data) > 0:
p, s = zip(*data)
model_predictions = list(p)
model_scores = list(s)

# Count occurrences of each prediction in the top lists
counter = Counter(model_predictions)

majority_count = (len(data) // 2) + 1

majority_predictions = [pred for pred, count in counter.items() if count >= majority_count]

# If there are no majority predictions
if len(majority_predictions) == 0:
# Pick the prediction with the highest score
# best_pred, max_score = max_score_p(model_predictions, model_scores)
self.best_label = "marine organism"
self.best_score = 0.0
else:
self.best_label = majority_predictions[0]
best_score = 0.0
num_majority = 0
# Sum all the scores for the majority predictions
for pred, score in data:
if pred in majority_predictions:
best_score += float(score)
num_majority += 1
self.best_score /= num_majority
else:
max_score = max(valid_scores.values())
for key, value in valid_scores.items():
if value == max_score:
max_frame = key
break
self.best_label = self.label[max_frame]

pts_pretty = [f"{pt[0]:.2f},{pt[1]:.2f},{label}" for pt, label in zip(self.pt.values(), self.label.values())]
best_label = max(set(self.label.values()), key=list(self.label.values()).count)
self.best_label = "marine organism"
self.best_score = 0.0

pts_pretty = [f"{pt[0]:.2f},{pt[1]:.2f},{label},{score}" for pt, label, score in zip(self.pt.values(), self.label.values(), self.score.values())]
total_frames = len(self.pt)
info(f"Updating tracker {self.id} total_frames {total_frames} updated start {self.start_frame} to {self.last_updated_frame} {pts_pretty} with label {best_label}")
info(f"Updating tracker {self.id} total_frames {total_frames} updated start {self.start_frame} to {self.last_updated_frame} {pts_pretty} with label {self.best_label}, score {self.best_score}")

0 comments on commit 4f3d5f3

Please sign in to comment.