diff --git a/biotrack/tracker.py b/biotrack/tracker.py index 3cfdada..ae92342 100644 --- a/biotrack/tracker.py +++ b/biotrack/tracker.py @@ -17,9 +17,6 @@ from biotrack.track import Track from biotrack.logger import create_logger_file, info, debug -DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" - - class BioTracker: def __init__(self, image_width: int, image_height: int, device_id: int = 0, **kwargs): self.logger = create_logger_file() @@ -36,13 +33,13 @@ def __init__(self, image_width: int, image_height: int, device_id: int = 0, **kw self.track_model = CoTrackerPredictor(checkpoint=None) self.track_model.model = model.model self.track_model.step = model.model.window_len // 2 - - self.device = torch.device(f"cuda:{device_id}" if torch.cuda.is_available() else "cpu") + self.device_id = kwargs.get("gpu_id", 0) + self.device = torch.device(f"cuda:{self.device_id}" if torch.cuda.is_available() else "cpu") self.track_model.model.to(self.device) # Initialize the model for computing crop embeddings model_name = kwargs.get("vits_model", ViTWrapper.DEFAULT_MODEL_NAME) - self.vit_wrapper = ViTWrapper(DEFAULT_DEVICE, device_id=device_id, model_name=model_name) + self.vit_wrapper = ViTWrapper(device_id=device_id, model_name=model_name) def update_trackers_queries(self, frame_num: int, keypoints: np.array, labels: List[str], scores: np.array, coverages: np.array, boxes: np.array, d_emb: np.array, **kwargs): """