-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
328 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from typing import TYPE_CHECKING | ||
|
||
import lazy_loader as lazy | ||
|
||
submod_attrs = { | ||
"memory": ["free", "get_memory_stats"], | ||
"seed": ["set_seed"], | ||
"device": ["get_device"], | ||
} | ||
|
||
__getattr__, __dir__, __all__ = lazy.attach(__name__, submod_attrs=submod_attrs) | ||
|
||
if TYPE_CHECKING: | ||
from .memory import free, get_memory_stats | ||
from .seed import set_seed | ||
from .device import get_device | ||
|
||
__all__ = ["free", "get_memory_stats", "set_seed", "get_device"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
|
||
from typing import Optional, Union | ||
try: | ||
import torch | ||
except ImportError: | ||
raise ImportError("PyTorch is not installed. Please, " | ||
"install a suitable version of PyTorch.") | ||
|
||
def get_device(device: Optional[Union[str, torch.device]] = None) -> torch.device: | ||
"""Return the specified device. | ||
Parameters | ||
---------- | ||
device : str or torch.device, optional | ||
The device to use. If None, the default device is selected. | ||
Returns | ||
------- | ||
torch.device | ||
The selected device. | ||
""" | ||
|
||
if device is None: | ||
if torch.cuda.is_available(): | ||
device = torch.device("cuda") | ||
elif torch.backends.mps.is_available(): | ||
device = torch.device("mps") | ||
else: | ||
device = torch.device("cpu") | ||
else: | ||
device = torch.device(device) | ||
|
||
return device |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import gc | ||
from typing import Dict, Optional, Union, Any | ||
|
||
try: | ||
import torch | ||
except ImportError: | ||
raise ImportError("PyTorch is not installed. Please, " | ||
"install a suitable version of PyTorch.") | ||
|
||
from .device import get_device | ||
from ..utils.format_bytes import bytes_to_human_readable | ||
|
||
def free(*objects: Any) -> None: | ||
""" | ||
Free the memory associated with the given objects, including PyTorch models, tensors, and other related objects. | ||
Parameters | ||
---------- | ||
*objects : Any | ||
The objects to free. Typically these are PyTorch models or tensors. | ||
Notes | ||
----- | ||
This function handles CPU, CUDA, and MPS tensors/models by setting gradients to None, deleting the objects, | ||
clearing the CUDA or MPS cache if necessary, and calling garbage collection. | ||
""" | ||
for obj in objects: | ||
if isinstance(obj, torch.nn.Module): | ||
# Free the model parameters' gradients | ||
for param in obj.parameters(): | ||
if param.grad is not None: | ||
param.grad = None | ||
# Move the model to CPU before deleting (optional, depending on use case) | ||
obj.to('cpu') | ||
|
||
elif isinstance(obj, torch.Tensor): | ||
# Free the tensor memory | ||
if obj.grad is not None: | ||
obj.grad = None | ||
# Move tensor to CPU before deletion (optional) | ||
obj = obj.cpu() | ||
|
||
# Delete the object reference | ||
del obj | ||
|
||
# Handle CUDA and MPS cache clearing | ||
if torch.cuda.is_available(): | ||
torch.cuda.empty_cache() | ||
if torch.backends.mps.is_available(): | ||
torch.mps.empty_cache() | ||
|
||
# Explicitly run garbage collection to free up memory | ||
gc.collect() | ||
|
||
|
||
def get_memory_stats(device: Optional[Union[str, torch.device]] = None, format_size: bool=False) -> Dict[str, Any]: | ||
""" | ||
Get memory statistics for the specified device. | ||
Parameters | ||
---------- | ||
device : str or torch.device, optional | ||
The device to get memory statistics for. If None, automatically detects | ||
the available device (CUDA, MPS, or CPU). | ||
format_size : bool, optional | ||
Whether to format the memory sizes in human-readable format (KB, MB, GB, TB). Default is False. | ||
Returns | ||
------- | ||
dict | ||
A dictionary containing memory statistics: free, occupied, reserved, and device. | ||
""" | ||
# Determine the device if not provided | ||
device = get_device(device) | ||
|
||
memory_stats = {"device": str(device)} | ||
|
||
try: | ||
if device.type == "cuda": | ||
# CUDA memory stats | ||
memory_stats["free"] = torch.cuda.memory_free(device) | ||
memory_stats["occupied"] = torch.cuda.memory_allocated(device) | ||
memory_stats["reserved"] = torch.cuda.memory_reserved(device) | ||
elif device.type == "mps": | ||
# MPS memory stats (only supported in PyTorch 1.13+) | ||
memory_stats["free"] = torch.mps.current_reserved_memory() - torch.mps.current_allocated_memory() | ||
memory_stats["occupied"] = torch.mps.current_allocated_memory() | ||
memory_stats["reserved"] = torch.mps.current_reserved_memory() | ||
else: | ||
# CPU memory stats using psutil | ||
import psutil | ||
virtual_mem = psutil.virtual_memory() | ||
memory_stats["free"] = virtual_mem.available | ||
memory_stats["occupied"] = virtual_mem.total - virtual_mem.available | ||
memory_stats["reserved"] = virtual_mem.total | ||
|
||
except Exception: | ||
memory_stats["free"] = None | ||
memory_stats["occupied"] = None | ||
memory_stats["reserved"] = None | ||
|
||
if format_size: | ||
for key in ["free", "occupied", "reserved"]: | ||
memory_stats[key] = bytes_to_human_readable(memory_stats[key]) | ||
|
||
return memory_stats | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import random | ||
import numpy as np | ||
import torch | ||
|
||
def set_seed(seed: int) -> "torch.Generator": | ||
""" | ||
Set the seed for random number generation in Python, NumPy, and PyTorch to ensure reproducibility. | ||
Parameters | ||
---------- | ||
seed : int | ||
The seed value to set for random number generation. | ||
Returns | ||
------- | ||
torch.Generator | ||
The random number generator for PyTorch. | ||
""" | ||
# Set the seed for Python's built-in random module | ||
random.seed(seed) | ||
|
||
# Set the seed for NumPy | ||
np.random.seed(seed) | ||
|
||
# Set the seed for PyTorch based on the available device | ||
if torch.cuda.is_available(): | ||
# Set seed for CUDA devices | ||
torch.cuda.manual_seed(seed) | ||
torch.cuda.manual_seed_all(seed) # For multi-GPU setups | ||
rg = torch.Generator(torch.cuda.current_device()) | ||
else: | ||
# Set seed for CPU-only and MPS, since it's a CPU-based backend | ||
torch.manual_seed(seed) | ||
rg = torch.Generator() | ||
|
||
# Ensure deterministic behavior in cuDNN (if applicable) | ||
if torch.backends.cudnn.is_available(): | ||
torch.backends.cudnn.deterministic = True | ||
torch.backends.cudnn.benchmark = False | ||
|
||
rg.manual_seed(seed) | ||
return rg |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
|
||
|
||
from typing import Optional | ||
|
||
def bytes_to_human_readable(num_bytes: Optional[int], decimal_places: int = 2, units = ["Bytes", "KB", "MB", "GB", "TB"]) -> str: | ||
""" | ||
Convert a number in bytes into a human-readable string with appropriate units (KB, MB, GB, TB). | ||
If num_bytes is None, return None. | ||
Parameters | ||
---------- | ||
num_bytes : int | ||
The number of bytes to convert. Can be a positive or negative integer. | ||
decimal_places : int, optional | ||
The number of decimal places to display for KB, MB, GB, and TB. Default is 2. | ||
Returns | ||
------- | ||
str or None | ||
The human-readable string representing the size in appropriate units, or None if input is None. | ||
""" | ||
if num_bytes is None: | ||
return None | ||
|
||
# Define the units and thresholds (limited to TB) | ||
|
||
factor = 1024.0 | ||
size = abs(num_bytes) | ||
unit_index = 0 | ||
|
||
while size >= factor and unit_index < len(units) - 1: | ||
size /= factor | ||
unit_index += 1 | ||
|
||
# Format the size based on the unit | ||
if units[unit_index] == "Bytes": | ||
size_str = f"{int(size)} {units[unit_index]}" | ||
else: | ||
size_str = f"{size:.{decimal_places}f} {units[unit_index]}" | ||
|
||
# Add a minus sign for negative byte values | ||
if num_bytes < 0: | ||
size_str = f"-{size_str}" | ||
|
||
return size_str |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,4 +11,5 @@ Below are the different modules available within the package: | |
alerts | ||
env | ||
io | ||
models | ||
video |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
Models | ||
====== | ||
|
||
The `dmf.models` module provides utilities for tasks related to PyTorch models. | ||
|
||
This package is included in the `dmf-utils` core package and can be installed using the following command: | ||
|
||
.. code-block:: bash | ||
pip install dmf-utils | ||
Howeve, you need to have installed PyTorch to use the functionalities in this module. | ||
The installation instructions for PyTorch are provided below. | ||
|
||
Content | ||
--------- | ||
|
||
The `dmf.models` module includes the following functions: | ||
|
||
.. autosummary:: | ||
:toctree: autosummary | ||
|
||
dmf.models.free | ||
dmf.models.get_memory_stats | ||
dmf.models.get_device | ||
dmf.models.set_seed | ||
|
||
|
||
Pytorch Installation | ||
-------------------- | ||
|
||
For detailed information check the official `PyTorch installation guide <https://pytorch.org/get-started/locally/>`_. | ||
|
||
Linux + CUDA | ||
~~~~~~~~~~~~ | ||
|
||
1. First, check your cuda version by running: | ||
|
||
.. code-block:: bash | ||
nvcc --version | ||
If you dont have CUDA installed, you can install it by following the instructions on the `NVIDIA CUDA Toolkit <https://developer.nvidia.com/cuda-toolkit>`_ page. | ||
|
||
To install the latest version of PyTorch for your CUDA support, run: | ||
|
||
.. code-block:: bash | ||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 | ||
Make sure to replace `cu118` with the appropriate version matching your CUDA toolkit version. | ||
|
||
macOS + MPS | ||
~~~~~~~~~~~ | ||
|
||
On macOS systems, especially those with Apple Silicon (M1, M2 chips), | ||
you can use the Metal Performance Shaders (MPS) backend by installing PyTorch with: | ||
|
||
.. code-block:: bash | ||
pip install torch torchvision torchaudio | ||
The MPS backend is automatically enabled when using PyTorch on compatible macOS devices. | ||
|
||
CPU-based | ||
~~~~~~~~~ | ||
|
||
For environments without GPU support or when running on systems without CUDA or MPS capabilities, | ||
you can install the CPU-only version of PyTorch: | ||
|
||
.. code-block:: bash | ||
pip install torch torchvision torchaudio |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters