diff --git a/dmf/models/memory.py b/dmf/models/memory.py
index d07c063..7aa6aec 100644
--- a/dmf/models/memory.py
+++ b/dmf/models/memory.py
@@ -45,9 +45,8 @@ def free(*objects: Any) -> None:
 
     # Handle CUDA and MPS cache clearing
     if torch.cuda.is_available():
+        torch.cuda.ipc_collect()
         torch.cuda.empty_cache()
-    if torch.backends.mps.is_available():
-        torch.mps.empty_cache()
 
     # Explicitly run garbage collection to free up memory
     gc.collect()
@@ -81,11 +80,6 @@ def get_memory_stats(device: Optional[Union[str, torch.device]] = None, format_s
         memory_stats["free"] = torch.cuda.memory_free(device)
         memory_stats["occupied"] = torch.cuda.memory_allocated(device)
         memory_stats["reserved"] = torch.cuda.memory_reserved(device)
-    elif device.type == "mps":
-        # MPS memory stats (only supported in PyTorch 1.13+)
-        memory_stats["free"] = torch.mps.current_reserved_memory() - torch.mps.current_allocated_memory()
-        memory_stats["occupied"] = torch.mps.current_allocated_memory()
-        memory_stats["reserved"] = torch.mps.current_reserved_memory()
     else:
         # CPU memory stats using psutil
         import psutil