Skip to content

Commit

Permalink
chore: change device id to gpu id kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
danellecline committed Dec 17, 2024
1 parent 0aa0c4b commit f743694
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions biotrack/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@
from biotrack.track import Track
from biotrack.logger import create_logger_file, info, debug

DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"


class BioTracker:
def __init__(self, image_width: int, image_height: int, device_id: int = 0, **kwargs):
self.logger = create_logger_file()
Expand All @@ -36,13 +33,13 @@ def __init__(self, image_width: int, image_height: int, device_id: int = 0, **kw
self.track_model = CoTrackerPredictor(checkpoint=None)
self.track_model.model = model.model
self.track_model.step = model.model.window_len // 2

self.device = torch.device(f"cuda:{device_id}" if torch.cuda.is_available() else "cpu")
self.device_id = kwargs.get("gpu_id", 0)
self.device = torch.device(f"cuda:{self.device_id}" if torch.cuda.is_available() else "cpu")
self.track_model.model.to(self.device)

# Initialize the model for computing crop embeddings
model_name = kwargs.get("vits_model", ViTWrapper.DEFAULT_MODEL_NAME)
self.vit_wrapper = ViTWrapper(DEFAULT_DEVICE, device_id=device_id, model_name=model_name)
self.vit_wrapper = ViTWrapper(device_id=device_id, model_name=model_name)

def update_trackers_queries(self, frame_num: int, keypoints: np.array, labels: List[str], scores: np.array, coverages: np.array, boxes: np.array, d_emb: np.array, **kwargs):
"""
Expand Down

0 comments on commit f743694

Please sign in to comment.