Skip to content

Commit

Permalink
build: better build for cuda deploy
Browse files (browse the repository at this point in the history)
  • Loading branch information
danellecline committed Dec 12, 2024
1 parent f66145e commit 47f93ca
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 78 deletions.
65 changes: 46 additions & 19 deletions biotrack/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Description: Main tracker class for tracking objects with points, labels, and embeddings using the CoTracker model and BioClip ViT embeddings
from dbm.dumb import error

import cv2
from PIL import Image

import torch
Expand Down Expand Up @@ -49,7 +50,7 @@ def update_trackers_queries(self, frame_num: int, keypoints: np.array, labels: L
:param boxes: the bounding boxes of the queries in the format [[x1,y1,x2,y2],[x1,y1,x2,y2]...] in the normalized scale 0-1
:param scores: the scores of the queries in the format [score1, score2, score3...] in the normalized scale 0-1
:param coverages: the coverage of the gradcam queries in the format [coverage1, coverage2, coverage3...] in the normalized scale 0-1
:param labels: the labels of the queries in the format [label1, label2, label3...]
:param labels: the labels of the queries in the format [[top_label11, top_label12], [top_label21, top_label22], [top_label31, top_label32],...]
:param frame_num: the starting frame number of the batch
:param keypoints: points in the format [[[x1,y1],[x2,y2],[[x1,y1],[x2,y2],[x3,y3...], [x1,y1],[x2,y2],[[x1,y1],[x2,y2],[x3,y3...], one per track in the normalized scale 0-1
:param kwargs:
Expand All @@ -70,9 +71,9 @@ def update_trackers_queries(self, frame_num: int, keypoints: np.array, labels: L

# Get the newly created tracks and initialize
for j, data in enumerate(zip(new_tracks, keypoints, d_emb, boxes, labels, scores, coverages)):
new_tracks[j], points, emb, box, label, score, coverage = data
new_tracks[j], points, emb, box, labels, scores, coverage = data
for i, pt in enumerate(points[0]):
new_tracks[j].init(i, label, pt, emb, frame_num, box=box, score=score, coverage=coverage)
new_tracks[j].init(i, labels, pt, emb, frame_num, box=box, scores=scores, coverage=coverage)

self.open_trackers.extend(new_tracks)
return
Expand Down Expand Up @@ -101,16 +102,20 @@ def update_trackers_queries(self, frame_num: int, keypoints: np.array, labels: L
error(f"Unassigned points {len(unassigned)} is not a multiple of {Track.NUM_KP}")
return

# Confine the unassigned points to the number of labels - cannot assign the keypoints to more than one label
unassigned = unassigned[:len(labels)*Track.NUM_KP]

for i in range(0, len(unassigned), Track.NUM_KP):
info(f"Creating new track {self.next_track_id} at {frame_num}")
track = Track(self.next_track_id, self.image_width, self.image_height, **kwargs)
unassigned_idx = unassigned[i:i + Track.NUM_KP]
labels = labels[i // Track.NUM_KP]
scores = scores[i // Track.NUM_KP]
box = boxes[i // Track.NUM_KP]
coverage = coverages[i // Track.NUM_KP]
debug(f"Creating new track with {len(unassigned_idx)} points {unassigned_idx} labels {labels} scores {scores} box {box} coverage {coverage}")
for j, d_idx in enumerate(unassigned_idx):
box = boxes[d_idx // Track.NUM_KP]
label = labels[d_idx // Track.NUM_KP]
score = scores[d_idx // Track.NUM_KP]
coverage = coverages[d_idx // Track.NUM_KP]
track.init(j, label, keypoints[d_idx], d_emb[d_idx // Track.NUM_KP], frame_num, box=box, score=score, coverage=coverage)
track.init(j, labels, keypoints[d_idx], d_emb[d_idx // Track.NUM_KP], frame_num, box=box, scores=scores, coverage=coverage)
self.open_trackers.append(track)
self.next_track_id += 1

Expand Down Expand Up @@ -185,6 +190,18 @@ def check(self, frame_num: int):
self.closed_trackers.append(t)
self.open_trackers.pop(i)

# Remove any tracks that have high intersection over union with each other
for i, t1 in enumerate(self.open_trackers):
for j, t2 in enumerate(self.open_trackers):
if i == j:
continue
if t1.is_closed() or t2.is_closed():
continue
if t1.intersection_over_union(t2) > 0.5:
info(f"======>Removing track {t1.track_id} with iou {t1.intersection_over_union(t2):.2f} with track {t2.track_id}")
self.open_trackers.pop(i)
break

def update_batch(self, frame_range: Tuple[int, int], frames: np.ndarray, detections: Dict, **kwargs):
"""
Update the tracker with new frames and det_query
Expand All @@ -194,6 +211,7 @@ def update_batch(self, frame_range: Tuple[int, int], frames: np.ndarray, detecti
:param kwargs:
:return:
"""
imshow = kwargs.get("imshow", False)
def correct_keypoints(top_kps, crop_paths):
correct_kpts = top_kps
for i, data in enumerate(zip(top_kps, crop_paths)):
Expand Down Expand Up @@ -233,29 +251,38 @@ def correct_keypoints(top_kps, crop_paths):
images = [d['crop_path'] for d in detections if d["frame"] == i]
embeddings, predicted_classes, predicted_scores, keypoints, coverages = self.vit_wrapper.process_images(images)
# Remove any data that has no keypoints
# Get the index of the keypoints that are empty
empty_idx = [i for i, kpts in enumerate(keypoints) if len(kpts) == 0]
if len(empty_idx) > 0:
info(f"Removing empty keypoints {empty_idx}")
keypoints = [kpts for i, kpts in enumerate(keypoints) if i not in empty_idx]
embeddings = [emb for i, emb in enumerate(embeddings) if i not in empty_idx]
predicted_classes = [p for i, p in enumerate(predicted_classes) if i not in empty_idx]
predicted_scores = [p for i, p in enumerate(predicted_scores) if i not in empty_idx]
if len(embeddings) == 0: # No data found
keypoints = [kpts for i, kpts in enumerate(keypoints) if i not in empty_idx]
embeddings = [emb for i, emb in enumerate(embeddings) if i not in empty_idx]
predicted_classes = [p for i, p in enumerate(predicted_classes) if i not in empty_idx]
predicted_scores = [p for i, p in enumerate(predicted_scores) if i not in empty_idx]
coverages = [c for i, c in enumerate(coverages) if i not in empty_idx]
images = [img for i, img in enumerate(images) if i not in empty_idx]
if len(keypoints) == 0: # No data found
info(f"No valid keypoints found for frame {i}")
det_query.pop(i)
image_query.pop(i)
continue
correct_kpts = correct_keypoints(keypoints, images)
predicted_classes = [p[0] for p in predicted_classes]
predicted_scores = [p[0] for p in predicted_scores]
info(f"Adding query for {correct_kpts} in frame idx {i}")

if imshow:
# Display the keypoints on the image
for k in correct_kpts:
for kp in k[0]:
cv2.circle(frames[i], (int(kp[0]*self.model_width), int(kp[1]*self.model_height)), 5, (0, 255, 0), -1)
cv2.imshow("Keypoints", frames[i])
cv2.waitKey(-1)
predicted_classes = [[p[0], p[1]] for p in predicted_classes]
predicted_scores = [[p[0], p[1]] for p in predicted_scores]
info(f"Adding query for {correct_kpts} in frame idx {i} predicted_classes {predicted_classes} predicted_scores {predicted_scores} coverages {coverages}")
det_query[i].append([correct_kpts, predicted_classes, predicted_scores, coverages, boxes, embeddings])
image_query[i].append(images)

return self._update_batch(frame_range, frames, det_query, image_query, **kwargs)

def _update_batch(self, frame_range: Tuple[int, int], frames: np.ndarray, det_query: Dict, crop_query: Dict, save:bool = False, **kwargs):
def _update_batch(self, frame_range: Tuple[int, int], frames: np.ndarray, det_query: Dict, crop_query: Dict, save:bool = True, **kwargs):
"""
Update the tracker with the new frames, detections and crops of the detections
:param frame_range: a tuple of the starting and ending frame numbers
Expand Down Expand Up @@ -396,4 +423,4 @@ def get_tracks(self):
Get the open and closed tracks
:return: a list of open and closed tracks
"""
return self.open_trackers + self.closed_trackers
return self.open_trackers + self.closed_trackers
Loading

0 comments on commit 47f93ca

Please sign in to comment.