Skip to content

Commit

Permalink
feat: add pass through of max_frames and max_empty_frames for track
Browse files Browse the repository at this point in the history
  • Loading branch information
danellecline committed Nov 15, 2024
1 parent 4c1c62d commit 321335d
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 10 deletions.
14 changes: 8 additions & 6 deletions biotrack/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@


class Track:
def __init__(self, label: str, pt: np.array, emb: np.array, frame: int, x_scale: float, y_scale: float, box: np.array = None, score: float = 0., max_empty_frame: int = 30, max_frames: int = 300, id: int = 0):
info(f"Creating tracker {id} at frame {frame} with point {pt} score {score} and emb {emb.shape}. Max empty frame {max_empty_frame} Max frames {max_frames}")
self.max_empty_frame = max_empty_frame
def __init__(self, track_id: int, label: str, pt: np.array, emb: np.array, frame: int, x_scale: float, y_scale: float, box: np.array = None, score: float = 0., **kwargs):
max_empty_frames = kwargs.get("max_empty_frames", 30)
max_frames = kwargs.get("max_frames", 300)
info(f"Creating tracker {track_id} at {frame}:{pt},{score}. Max empty frame {max_empty_frames} Max frames {max_frames}")
self.max_empty_frames = max_empty_frames
self.max_frames = max_frames
self.id = id
self.id = track_id
self.pt = {frame: pt}
self.label = {frame: label}
self.score = {frame: score}
Expand All @@ -31,8 +33,8 @@ def embedding(self):
return self.emb

def is_closed(self, frame_num: int) -> bool:
is_closed = (frame_num - self.last_updated_frame + 1) >= self.max_empty_frame or len(self.pt) >= self.max_frames
info(f"Tracker {self.id} is_closed {is_closed} frame_num {frame_num} last_updated_frame {self.last_updated_frame} max_empty_frame {self.max_empty_frame} max_frames {self.max_frames}")
is_closed = (frame_num - self.last_updated_frame + 1) >= self.max_empty_frames or len(self.pt) >= self.max_frames
info(f"Tracker {self.id} is_closed {is_closed} frame_num {frame_num} last_updated_frame {self.last_updated_frame} max_empty_frame {self.max_empty_frames} max_frames {self.max_frames}")
return is_closed

@property
Expand Down
4 changes: 2 additions & 2 deletions biotrack/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ def update_trackers(self, frame_num: int, points: List[np.array], embeddings: Li
open_tracks[match].update(label, point, emb, frame_num, box=box, score=score)
else:
info(f"Match too high {best_cost} > {max_cost}; creating new track {self.next_track_id} for point {point}")
self.open_trackers.append(Track(label, point, emb, frame_num, self.image_width, self.image_height, box=box, id=self.next_track_id, score=score))
self.open_trackers.append(Track(self.next_track_id, label, point, emb, frame_num, self.image_width, self.image_height, box=box, score=score, **kwargs))
self.next_track_id += 1
else:
self.open_trackers.append(Track(label, point, emb, frame_num, self.image_width, self.image_height, box=box, id=self.next_track_id, score=score))
self.open_trackers.append(Track(self.next_track_id, label, point, emb, frame_num, self.image_width, self.image_height, box=box, score=score, **kwargs))
self.next_track_id += 1

def update_batch(self, frame_range: Tuple[int, int], frames: np.ndarray, detections: Dict, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions examples/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@

i_e = min(i + window_len - 1, num_frames) # handle the end of the video
print(f"Tracking frames {i} to {i_e}")
tracks = tracker.update_batch((i, i_e), frames, detections=detections)
tracks = tracker.update_batch((i, i_e), frames, detections=detections, max_frames=60, max_empty_frames=5)

# Display the tracks for the window
for j in range(len(frames)):
Expand All @@ -69,7 +69,7 @@
color = (255, 255, 255)
thickness = 1
frame = cv2.circle(frame, center, radius, color, thickness)
# Draw the track id with the label, e.g. 1:Unknown
# Draw the track track_id with the label, e.g. 1:Unknown
font = cv2.FONT_HERSHEY_SIMPLEX
fontScale = 1
frame = cv2.putText(frame, f"{track.id}:{label}{score:.2f}", center, font, fontScale, color, thickness, cv2.LINE_AA)
Expand Down

0 comments on commit 321335d

Please sign in to comment.