refactor: change to mostly normalized coordinates, except where needed
danellecline committed Nov 15, 2024
1 parent 6254c72 commit 02bc124
Showing 3 changed files with 54 additions and 58 deletions.
20 changes: 12 additions & 8 deletions biotrack/track.py
@@ -40,14 +40,18 @@ def last_update_frame(self):
         return self.last_updated_frame

     def rescale(self, pt: np.array, box: np.array) -> (np.array, np.array):
-        pt[0] *= self.x_scale
-        pt[1] *= self.y_scale
-        if len(box) > 0:
-            box[0] *= self.x_scale
-            box[1] *= self.y_scale
-            box[2] *= self.x_scale
-            box[3] *= self.y_scale
-        return pt, box
+        pt_rescale = pt.copy()
+        pt_rescale[0] = pt[0] * self.x_scale
+        pt_rescale[1] = pt[1] * self.y_scale
+        if box is not None:
+            box_rescale = box.copy()
+            box_rescale[0] = box[0] * self.x_scale
+            box_rescale[1] = box[1] * self.y_scale
+            box_rescale[2] = box[2] * self.x_scale
+            box_rescale[3] = box[3] * self.y_scale
+        else:
+            box_rescale = box
+        return pt_rescale, box_rescale

     def get_best(self, rescale=True) -> (int, np.array, str, np.array, float):
         # Get the best box which is a few frames behind the last_updated_frame
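For context, a minimal sketch (not part of the commit) of the round-trip the new rescale performs, assuming a Track built for a 1920x1080 image stores that size as x_scale and y_scale; the sizes and values below are illustrative only:

import numpy as np

x_scale, y_scale = 1920, 1080               # assumed image size held by the Track
pt = np.array([0.5, 0.25])                  # normalized point, 0-1
box = np.array([0.25, 0.25, 0.75, 0.75])    # normalized x1, y1, x2, y2

# rescale copies instead of mutating, so the track keeps its normalized state
pt_px = pt.copy()
pt_px[0] = pt[0] * x_scale                  # 960.0
pt_px[1] = pt[1] * y_scale                  # 270.0
box_px = box * np.array([x_scale, y_scale, x_scale, y_scale])  # pixel box
print(pt_px, box_px)

Copying rather than scaling in place is what lets the tracker hold a single normalized representation and rescale on demand.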
81 changes: 39 additions & 42 deletions biotrack/tracker.py
@@ -26,8 +26,8 @@ def __init__(self, image_width: int, image_height: int):
         self.image_height = image_height
         self.model_width = 640
         self.model_height = 360
-        self.x_scale = image_width / self.model_width
-        self.y_scale = image_height / self.model_height
+        self.image_width = image_width
+        self.image_height = image_height
         self.open_trackers: List[Track] = []
         self.closed_trackers: List[Track] = []
         self.next_track_id = 0  # Unique ID for new tracks
@@ -43,10 +43,10 @@ def __init__(self, image_width: int, image_height: int):

     def update_trackers(self, frame_num: int, points: List[np.array], embeddings: List[np.array], **kwargs):
         """
-        Update the tracker with new detections and crops
+        Update the tracker with new det_query and crops
         :param frame_num: the starting frame number of the batch
-        :param points: points in the format [[x1,y1],[x2,y2],[[x1,y1],[x2,y2],[x3,y3...]
-        a collection of detections detected in each last_updated_frame.
+        :param points: points in the format [[x1,y1],[x2,y2],[x3,y3],...] in the normalized scale 0-1,
+        a collection of det_query detected in each last_updated_frame.
         :param embeddings: a numpy array of embeddings in the format [[emb1],[emb2],[emb3]...]
         :param kwargs:
         :return:
@@ -61,35 +61,26 @@ def update_trackers(self, frame_num: int, points: List[np.array], embeddings: List[np.array], **kwargs):
         t_emb = np.zeros((len(self.open_trackers), ViTWrapper.VECTOR_DIMENSIONS))
         d_emb = embeddings

-        # If there are no detections, return
+        # If there are no det_query, return
         if len(points) == 0:
             return

         open_tracks = [t for t in self.open_trackers if not t.is_closed(frame_num)]
         info(f"Updating {len(open_tracks)} open tracks")

-        # Get predicted detections and embeddings from existing trackers
+        # Get predicted det_query and embeddings from existing trackers
         for i, t in enumerate(open_tracks):
             t_pts[i] = t.predict()
-            # Normalize the detections 0-1 based on the model size
-            t_pts[i][0] = t_pts[i][0] / self.model_width
-            t_pts[i][1] = t_pts[i][1] / self.model_height
             t_emb[i] = t.embedding

-        # Normalize the detections 0-1 based on the model size, and convert to numpy array
         points = np.array(points)
-        points[:, 0] = points[:, 0] / self.model_width
-        points[:, 1] = points[:, 1] / self.model_height
         d_emb = np.array(d_emb)

-        # Associate the new detections with the existing tracks
+        # Associate the new det_query with the existing tracks
         costs = associate(detection_pts=points, detection_emb=d_emb, tracker_pts=t_pts, tracker_emb=t_emb)

         for cost, point, emb, label, score, box in zip(costs, points, d_emb, labels, scores, boxes):
             match = np.argmin(cost, axis=0)
-            # Denormalize the detections
-            point[0] = point[0] * self.model_width
-            point[1] = point[1] * self.model_height
             best_cost = cost[match]
             if len(open_tracks) > 0:
                 info(f"Match {match} all costs {cost} for point {point} num trackers {len(open_tracks)}")
@@ -100,10 +91,10 @@ def update_trackers(self, frame_num: int, points: List[np.array], embeddings: List[np.array], **kwargs):
                     open_tracks[match].update(label, point, emb, frame_num, box=box, score=score)
                 else:
                     info(f"Match too high {best_cost} > {max_cost}; creating new track {self.next_track_id} for point {point}")
-                    self.open_trackers.append(Track(label, point, emb, frame_num, self.x_scale, self.y_scale, box=box, id=self.next_track_id, score=score))
+                    self.open_trackers.append(Track(label, point, emb, frame_num, self.image_width, self.image_height, box=box, id=self.next_track_id, score=score))
                     self.next_track_id += 1
             else:
-                self.open_trackers.append(Track(label, point, emb, frame_num, self.x_scale, self.y_scale, box=box, id=self.next_track_id, score=score))
+                self.open_trackers.append(Track(label, point, emb, frame_num, self.image_width, self.image_height, box=box, id=self.next_track_id, score=score))
                 self.next_track_id += 1

         # Close any tracks ready to close
@@ -117,10 +108,10 @@ def update_trackers(self, frame_num: int, points: List[np.array], embeddings: List[np.array], **kwargs):

     def update_batch(self, frame_range: Tuple[int, int], frames: np.ndarray, detections: Dict, **kwargs):
         """
-        Update the tracker with new frames and detections
+        Update the tracker with new frames and det_query
         :param frame_range: a tuple of the starting and ending frame numbers
         :param frames: numpy array of frames in the format [frame1, frame2, frame3...]
-        :param detections: dictionary of detections in the format {["x": x, "y": y, "xx": x, "xy": xy, "crop_path": crop_path, "frame": frame, "class_name": class_name, "score": score]}
+        :param detections: dictionary of det_query in the format {"x": x, "y": y, "xx": xx, "xy": xy, "crop_path": crop_path, "frame": frame, "class_name": class_name, "score": score}
        :param kwargs:
        :return:
        """
@@ -166,14 +157,13 @@ def update_batch(self, frame_range: Tuple[int, int], frames: np.ndarray, detections: Dict, **kwargs):
                 y = y * scale
                 x += bbox[0]
                 y += bbox[1]
-                # Adjust the x, y to the model scale
-                x /= self.x_scale
-                y /= self.y_scale
+                # Adjust to 0-1 scale
+                x /= self.image_width
+                y /= self.image_height
                 frame_num = d["frame"]
                 label = d["class_name"]
                 score = d["score"]
-                # Adjust the bbox to the model scale
-                bbox = [bbox[0] / self.x_scale, bbox[1] / self.y_scale, bbox[2] / self.x_scale, bbox[3] / self.y_scale]
+                bbox = [bbox[0] / self.image_width, bbox[1] / self.image_height, bbox[2] / self.image_width, bbox[3] / self.image_height]

                 # Add the best keypoint to the query
                 info(f"Adding query for {x}, {y} {label} in frame {frame_num}")
@@ -185,33 +175,33 @@ def update_batch(self, frame_range: Tuple[int, int], frames: np.ndarray, detections: Dict, **kwargs):

         return self._update_batch(frame_range, frames, det_query, crop_query, **kwargs)

-    def _update_batch(self, frame_range: Tuple[int, int], frames: np.ndarray, detections: Dict, crop_paths: Dict, **kwargs):
+    def _update_batch(self, frame_range: Tuple[int, int], frames: np.ndarray, det_query: Dict, crop_query: Dict, **kwargs):
         """
         Update the tracker with the new frames, detections and crops of the detections
         :param frame_range: a tuple of the starting and ending frame numbers
         :param frames: numpy array of frames in the format [frame1, frame2, frame3...]
-        :param detections: a dictionary of detections in the format {frame_num: [[x1,y1,label,score,bbox],[x2,y2,label,score,bbox],[[x1,y1,label,score,bbox]]...]}
-        :param crop_paths: a dictionary of crop paths in the format {frame_num: [crop_path1, crop_path2, crop_path3...]}
+        :param det_query: a dictionary of detections in the format {frame_num: [[x1,y1,label,score,bbox],[x2,y2,label,score,bbox],...]} in the normalized scale 0-1
+        :param crop_query: a dictionary of crop paths in the format {frame_num: [crop_path1, crop_path2, crop_path3...]}
         :param kwargs:
         :return:
         """
-        if len(detections) == 0 or len(crop_paths) == 0:
+        if len(det_query) == 0 or len(crop_query) == 0:
             info("No data for frame")
             return []

-        if len(detections) != len(crop_paths):
-            info(f"Number of detections {len(detections)} and crop paths {len(crop_paths)} do not match")
+        if len(det_query) != len(crop_query):
+            info(f"Number of det_query {len(det_query)} and crop paths {len(crop_query)} do not match")
             return []

         # Compute the embeddings for the new query detection crops
         # Format the queries for the model, each query is [frame_number, x, y]
         q_emb = {}
         queries = []
-        for f, d in detections.items():
+        for f, d in det_query.items():
             # TODO: replace with parallel processing
-            info(f"Computing embeddings for frame {f} {crop_paths[f]}")
+            info(f"Computing embeddings for frame {f} {crop_query[f]}")
             labels = [det[2] for det in d]
-            q_emb[f] = compute_embedding_vits(self.vit_wrapper, crop_paths[f], labels)
+            q_emb[f] = compute_embedding_vits(self.vit_wrapper, crop_query[f], labels)
             for det in d:
                 queries.append([f - frame_range[0], det[0], det[1]])
@@ -220,6 +210,10 @@ def _update_batch(self, frame_range: Tuple[int, int], frames: np.ndarray, detections: Dict, crop_paths: Dict, **kwargs):

         # Put the queries and frames into tensors and run the model with the backward tracking option which is
         # more accurate than the forward/online only tracking
+        # Convert the queries to the model scale
+        queries = np.array(queries)
+        queries[:, 1] *= self.model_width
+        queries[:, 2] *= self.model_height
         queries_t = torch.tensor(queries, dtype=torch.float32)
         frames = torch.tensor(frames, dtype=torch.float32).permute(0, 3, 1, 2)
         if DEFAULT_DEVICE == "cuda":
@@ -230,6 +224,9 @@ def _update_batch(self, frame_range: Tuple[int, int], frames: np.ndarray, detections: Dict, crop_paths: Dict, **kwargs):
         info(f"Running co-tracker model with {len(queries)} queries frames {frame_range}")
         pred_pts, pred_visibilities = self.offline_model(video_chunk, queries=queries_t[None], backward_tracking=True)
         pred_pts, pred_visibilities = pred_pts.cpu().numpy(), pred_visibilities.cpu().numpy()
+        # Convert the predictions back to the normalized scale
+        pred_pts[:, :, :, 0] /= self.model_width
+        pred_pts[:, :, :, 1] /= self.model_height

         # Update with predictions
         for f in range(frame_range[0], frame_range[1], 1):
@@ -239,22 +236,22 @@ def _update_batch(self, frame_range: Tuple[int, int], frames: np.ndarray, detections: Dict, crop_paths: Dict, **kwargs):
             if len(pts) == 0:
                 continue

-            # Filter out the detections that are not visible
+            # Filter out the det_query that are not visible
             pred_visibilities_in_frame = pred_visibilities[:, f - frame_range[0], :]
             filtered_pts = pts[pred_visibilities_in_frame]

-            # Create empty embeddings for the predicted detections since this is just pt tracking
+            # Create empty embeddings for the predicted det_query since this is just pt tracking
             empty_emb = np.zeros((len(filtered_pts), ViTWrapper.VECTOR_DIMENSIONS))
             self.update_trackers(f, filtered_pts, empty_emb, **kwargs)

         # Update with the queries - these seed new tracks and update existing tracks
         for f in range(frame_range[0], frame_range[1]):
-            if f not in detections:
+            if f not in det_query:
                 continue
-            labels_in_frame = [d[2] for d in detections[f]]
-            scores_in_frame = [d[3] for d in detections[f]]
-            boxes_in_frame = [d[4] for d in detections[f]]
-            queries_in_frame = [d[:2] for d in detections[f]]
+            labels_in_frame = [d[2] for d in det_query[f]]
+            scores_in_frame = [d[3] for d in det_query[f]]
+            boxes_in_frame = [d[4] for d in det_query[f]]
+            queries_in_frame = [d[:2] for d in det_query[f]]
             debug(f"Updating with queries {queries_in_frame} in frame {f}")
             self.update_trackers(f, queries_in_frame, q_emb[f], labels=labels_in_frame, scores=scores_in_frame, boxes=boxes_in_frame, **kwargs)

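A sketch (not part of the commit) of the coordinate round-trip these hunks introduce around the co-tracker call, with a dummy array standing in for the model output; model_width and model_height are the 640x360 values from __init__ above, and the query values are invented:

import numpy as np

model_width, model_height = 640, 360        # from Tracker.__init__ above

# Queries stay normalized 0-1 as [frame_offset, x, y] ...
queries = np.array([[0.0, 0.5, 0.25], [2.0, 0.1, 0.9]], dtype=np.float32)

# ... and are scaled up to model pixels only for the model call
queries[:, 1] *= model_width
queries[:, 2] *= model_height

# Dummy stand-in for pred_pts, shaped (batch, frames, queries, 2)
pred_pts = np.tile(queries[:, 1:], (1, 3, 1, 1))

# Predictions are normalized straight back to 0-1
pred_pts[:, :, :, 0] /= model_width
pred_pts[:, :, :, 1] /= model_height
print(pred_pts[0, 0])  # [[0.5 0.25] [0.1 0.9]]

Only the model call itself sees 640x360 pixel coordinates; everything before and after stays normalized.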
11 changes: 3 additions & 8 deletions examples/video.py
@@ -29,26 +29,21 @@
    frames = frame_stack[i : i + window_len]
    frame_full = frame_stack_full[i : i + window_len]

-    # Get all the detections in the window
+    # Get all the detections in the window to pass to the tracker - these are called the queries
    detections = []
    for j in range(len(frames)):
        frame_num = i + j
-        # Load the detections for the last_updated_frame
+        # Load the det_query for the last_updated_frame
        detections_file = detections_path / f"{frame_num}.json"
        if not detections_file.exists():
-            print(f"No detections for frame {frame_num}")
+            print(f"No det_query for frame {frame_num}")
            continue

        data = json.loads(detections_file.read_text())
        for loc in data:
            if j >= num_frames:
                continue

-            # Convert the x, y to the image coordinates and adjust crop path to the full path
-            loc["x"] = loc["x"] * loc["image_width"]
-            loc["y"] = loc["y"] * loc["image_height"]
-            loc["xx"] = loc["xx"] * loc["image_width"]
-            loc["xy"] = loc["xy"] * loc["image_height"]
            loc["crop_path"] = (crops_path / loc["crop_path"]).as_posix()
            loc["frame"] = frame_num
            loc["score"] = loc["confidence"]
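To illustrate what examples/video.py now hands to the tracker, a hypothetical detection record (all values and paths invented): x/y and xx/xy stay in the normalized 0-1 coordinates the per-frame JSON already uses, and only the crop path, frame number, and score are filled in:

from pathlib import Path

crops_path = Path("/data/crops")   # assumed crops directory
frame_num = 12                     # assumed frame index

loc = {
    "x": 0.41, "y": 0.22,          # top-left corner, normalized 0-1
    "xx": 0.55, "xy": 0.38,        # bottom-right corner, normalized 0-1
    "confidence": 0.87,
    "crop_path": "12_0.png",
}

# No denormalization anymore: resolve the crop path and copy metadata only
loc["crop_path"] = (crops_path / loc["crop_path"]).as_posix()
loc["frame"] = frame_num
loc["score"] = loc["confidence"]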
