Commit da55956: Fix errors in mps backend

pablomm committed Aug 30, 2024 (1 parent: ef2223a)

Showing 1 changed file with 1 addition and 7 deletions.
dmf/models/memory.py (8 changes: 1 addition, 7 deletions)

```diff
@@ -45,9 +45,8 @@ def free(*objects: Any) -> None:
 
     # Handle CUDA and MPS cache clearing
     if torch.cuda.is_available():
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
 
     # Explicitly run garbage collection to free up memory
     gc.collect()
```
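For context on the hunk above: `torch.mps.empty_cache()` only exists in newer PyTorch releases (it arrived around 2.0), so an unconditional call can raise `AttributeError` on older builds even when the `torch.backends.mps.is_available()` check passes. A minimal sketch of a version-tolerant cleanup; the `safe_empty_cache` name and the guard structure are illustrative, not this repository's actual fix:

```python
import gc

import torch


def safe_empty_cache() -> None:
    # Hypothetical helper, not part of dmf: clears allocator caches on
    # whichever accelerator backend this PyTorch build actually has.
    if torch.cuda.is_available():
        torch.cuda.ipc_collect()  # release CUDA IPC shared-memory handles
        torch.cuda.empty_cache()  # hand cached blocks back to the driver
    elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        # torch.mps.empty_cache() is missing on older releases, so probe first
        if hasattr(torch.mps, "empty_cache"):
            torch.mps.empty_cache()
    gc.collect()  # drop dangling Python references so freed tensors can be reclaimed
```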
```diff
@@ -81,11 +80,6 @@ def get_memory_stats(device: Optional[Union[str, torch.device]] = None, format_s
         memory_stats["free"] = torch.cuda.memory_free(device)
         memory_stats["occupied"] = torch.cuda.memory_allocated(device)
         memory_stats["reserved"] = torch.cuda.memory_reserved(device)
-    elif device.type == "mps":
-        # MPS memory stats (only supported in PyTorch 1.13+)
-        memory_stats["free"] = torch.mps.current_reserved_memory() - torch.mps.current_allocated_memory()
-        memory_stats["occupied"] = torch.mps.current_allocated_memory()
-        memory_stats["reserved"] = torch.mps.current_reserved_memory()
     else:
         # CPU memory stats using psutil
         import psutil
```
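The deleted branch is likely the error the commit title refers to: `torch.mps` exposes `current_allocated_memory()` and `driver_allocated_memory()`, but `current_reserved_memory()` does not appear in the public `torch.mps` API, so these lines would raise `AttributeError` on MPS devices. If MPS stats were wanted again, a sketch restricted to calls that do exist in PyTorch 2.x might look like the following; the `mps_memory_stats` name is hypothetical, and treating the driver total as the "reserved" pool is an assumption:

```python
import torch


def mps_memory_stats() -> dict:
    # Hypothetical replacement for the deleted branch, using only
    # torch.mps functions present in PyTorch 2.x.
    occupied = torch.mps.current_allocated_memory()  # bytes owned by live tensors
    reserved = torch.mps.driver_allocated_memory()   # bytes the Metal driver holds
    return {
        "free": reserved - occupied,  # headroom inside the driver's allocation
        "occupied": occupied,
        "reserved": reserved,
    }


if torch.backends.mps.is_available():
    print(mps_memory_stats())
```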
