diff --git a/biotrack/track.py b/biotrack/track.py
index 690c73b..269427e 100644
--- a/biotrack/track.py
+++ b/biotrack/track.py
@@ -1,6 +1,8 @@
 # biotrack, Apache-2.0 license
 # Filename: biotrack/tracker/track.py
 # Description: Basic track object to contain and update tracks
+from collections import Counter
+
 from biotrack.logger import info, debug
 import numpy as np
 
@@ -19,6 +21,7 @@ def __init__(self, track_id: int, label: str, pt: np.array, emb: np.array, frame
         self.box = {frame: box}
         self.emb = emb
         self.best_label = label
+        self.best_score = score
         self.start_frame = frame
         self.last_updated_frame = frame
         self.x_scale = x_scale
@@ -67,10 +70,9 @@ def get_best(self, rescale=True) -> (int, np.array, str, np.array, float):
         frame_num = self.last_updated_frame
         box = self.box[frame_num]
         pt = self.pt[frame_num]
-        max_score = max(self.score.values())
         if rescale:
             pt, box = self.rescale(pt, box)
-        return frame_num, pt, self.best_label, box, max_score
+        return frame_num, pt, self.best_label, box, self.best_score
 
     def get(self, frame_num: int, rescale=True) -> (np.array, str, np.array, float):
         if frame_num not in self.pt.keys():
@@ -118,20 +120,41 @@ def update(self, label: str, pt: np.array, emb: np.array, frame_num: int, box:np
         self.emb = emb
         self.last_updated_frame = frame_num
 
-        # Update the best_label with that of the highest scoring label
-        # Only choose scores that are greater than 0
-        valid_scores = {k: v for k, v in self.score.items() if v > 0}
-        if len(valid_scores) == 0:
-            self.best_label = "marine organism"
+        # Update the best_label with that which occurs the most that has a score > 0. This is a simple majority vote
+        data = [(pred, score) for pred, score in zip(self.label.values(), self.score.values()) if float(score) > 0.]
+
+        if len(data) > 0:
+            p, s = zip(*data)
+            model_predictions = list(p)
+            model_scores = list(s)
+
+            # Count occurrences of each prediction in the top lists
+            counter = Counter(model_predictions)
+
+            majority_count = (len(data) // 2) + 1
+
+            majority_predictions = [pred for pred, count in counter.items() if count >= majority_count]
+
+            # If there are no majority predictions
+            if len(majority_predictions) == 0:
+                # No label holds a strict majority, so fall back to the
+                # generic label with a zero score
+                self.best_label = "marine organism"
+                self.best_score = 0.0
+            else:
+                self.best_label = majority_predictions[0]
+                best_score = 0.0
+                num_majority = 0
+                # Average the scores of the detections carrying the majority label
+                for pred, score in data:
+                    if pred in majority_predictions:
+                        best_score += float(score)
+                        num_majority += 1
+                self.best_score = best_score / num_majority
         else:
-            max_score = max(valid_scores.values())
-            for key, value in valid_scores.items():
-                if value == max_score:
-                    max_frame = key
-                    break
-            self.best_label = self.label[max_frame]
-
-        pts_pretty = [f"{pt[0]:.2f},{pt[1]:.2f},{label}" for pt, label in zip(self.pt.values(), self.label.values())]
-        best_label = max(set(self.label.values()), key=list(self.label.values()).count)
+            self.best_label = "marine organism"
+            self.best_score = 0.0
+
+        pts_pretty = [f"{pt[0]:.2f},{pt[1]:.2f},{label},{score}" for pt, label, score in zip(self.pt.values(), self.label.values(), self.score.values())]
         total_frames = len(self.pt)
-        info(f"Updating tracker {self.id} total_frames {total_frames} updated start {self.start_frame} to {self.last_updated_frame} {pts_pretty} with label {best_label}")
+        info(f"Updating tracker {self.id} total_frames {total_frames} updated start {self.start_frame} to {self.last_updated_frame} {pts_pretty} with label {self.best_label}, score {self.best_score}")