Skip to content

Commit

Permalink
fix: default to keypoint only (no embedding) cost and fix query visib…
Browse files Browse the repository at this point in the history
…ility matrix
  • Loading branch information
danellecline committed Dec 11, 2024
1 parent c47d157 commit 5eab5fc
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
6 changes: 2 additions & 4 deletions biotrack/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,7 @@ def process_images(self, image_paths: list, min_coverage: float = 10.0) -> tuple
best_score = score
best_coverage = coverage
best_keypoints = kp
debug(f"Found best keypoints: {best_keypoints} for {category_name} with coverage {best_coverage} for {image_path}")
if best_coverage > 20:
break
debug(f"Found best keypoints: {best_keypoints} for {category_name} with coverage {best_coverage} score {score} for {image_path}")

keypoints.append(best_keypoints)
coverages.append(best_coverage)
Expand Down Expand Up @@ -145,7 +143,7 @@ def get_gcam_keypoints(model: torch.nn.Module,
input_tensor = np.transpose(input_tensor, (1, 2, 0))
img_color = np.uint8(input_tensor * 255)
img_color = cv2.cvtColor(img_color, cv2.COLOR_BGR2GRAY)
_, img_thres = cv2.threshold(img_color, 180, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
_, img_thres = cv2.threshold(img_color, 150, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
contours_raw, _ = cv2.findContours(img_thres, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

if display:
Expand Down
10 changes: 7 additions & 3 deletions biotrack/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def update_trackers_queries(self, frame_num: int, keypoints: np.array, labels: L

# Flatten the keypoints and associate the new points with the existing track traces
keypoints = np.array([item for sublist in keypoints for item in sublist[0]])
# assignment, costs = associate_trace_pts(detection_pts=keypoints, trace_pts=t_pts)
assignment, costs = associate_track_pts_emb(detection_pts=keypoints, detection_emb=d_emb, trace_pts=t_pts, tracker_emb=t_emb)
assignment, costs = associate_trace_pts(detection_pts=keypoints, trace_pts=t_pts)
# assignment, costs = associate_track_pts_emb(detection_pts=keypoints, detection_emb=d_emb, trace_pts=t_pts, tracker_emb=t_emb)
if len(assignment) == 0:
return

Expand Down Expand Up @@ -237,6 +237,7 @@ def correct_keypoints(top_kps, crop_paths):
empty_idx = [i for i, kpts in enumerate(keypoints) if len(kpts) == 0]
if len(empty_idx) > 0:
info(f"Removing empty keypoints {empty_idx}")
keypoints = [kpts for i, kpts in enumerate(keypoints) if i not in empty_idx]
embeddings = [emb for i, emb in enumerate(embeddings) if i not in empty_idx]
predicted_classes = [p for i, p in enumerate(predicted_classes) if i not in empty_idx]
predicted_scores = [p for i, p in enumerate(predicted_scores) if i not in empty_idx]
Expand Down Expand Up @@ -311,9 +312,12 @@ def _update_batch(self, frame_range: Tuple[int, int], frames: np.ndarray, det_qu
if ii < f:
continue
for j in range(0, len(labels_in_frame[0]*Track.NUM_KP)):
query_vis[ii][vis_idx + j] = True
idx = min(vis_idx + j, len(queries)-1)
query_vis[ii - frame_range[0]][idx] = True
vis_idx += len(labels_in_frame[0]*Track.NUM_KP)

return self.open_trackers + self.closed_trackers

# Put the queries and frames into tensors and run the model with the backward tracking option which is
# more accurate than the forward/online only tracking
queries_ = np.array(queries)
Expand Down

0 comments on commit 5eab5fc

Please sign in to comment.