Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Documentation for lightly/models/modules #1700

Merged
23 changes: 20 additions & 3 deletions lightly/models/modules/center.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def __init__(
mode: str = "mean",
momentum: float = 0.9,
) -> None:
"""Initializes the Center module with the specified parameters.

Raises:
ValueError: If an unknown mode is provided.
"""
super().__init__()

center_fn = CENTER_MODE_TO_FUNCTION.get(mode)
Expand All @@ -49,8 +54,10 @@ def __init__(

@property
def value(self) -> Tensor:
"""The current value of the center. Use this property to do any operations based
on the center."""
"""The current value of the center.

Use this property to do any operations based on the center.
"""
return self.center

@torch.no_grad()
Expand All @@ -75,7 +82,17 @@ def _center_mean(self, x: Tensor) -> Tensor:

@torch.no_grad()
def center_mean(x: Tensor, dim: Tuple[int, ...]) -> Tensor:
"""Returns the center of the input tensor by calculating the mean."""
"""Returns the center of the input tensor by calculating the mean.

Args:
x:
Input tensor.
dim:
Dimensions along which the mean is calculated.

Returns:
The center of the input tensor.
"""
batch_center = torch.mean(x, dim=dim, keepdim=True)
if dist.is_available() and dist.is_initialized():
dist.all_reduce(batch_center)
Expand Down
Loading
Loading