Merge pull request #70 from fernandoGubiMarques/lfr-base
Implement a base model for LFR
Showing 6 changed files with 351 additions and 1 deletion.
@@ -0,0 +1,73 @@
from typing import Sequence

from torch import nn


class MLP(nn.Sequential):
    """
    A multilayer perceptron (MLP) implemented as a subclass of nn.Sequential.

    The MLP is composed of a sequence of linear layers interleaved with ReLU activation
    functions, except for the final layer, which remains purely linear.

    Example
    -------
    >>> mlp = MLP([10, 20, 30, 40])
    >>> print(mlp)
    MLP(
      (0): Linear(in_features=10, out_features=20, bias=True)
      (1): ReLU()
      (2): Linear(in_features=20, out_features=30, bias=True)
      (3): ReLU()
      (4): Linear(in_features=30, out_features=40, bias=True)
    )
    """

    def __init__(
        self,
        layer_sizes: Sequence[int],
        activation_cls: type = nn.ReLU,
        *args,
        **kwargs,
    ):
        """
        Initializes the MLP with the specified layer sizes.

        Parameters
        ----------
        layer_sizes : Sequence[int]
            A sequence of positive integers indicating the size of each layer.
            At least two integers are required, representing the input and output layers.
        activation_cls : type
            The class of the activation function to use between layers. Default is nn.ReLU.
        *args
            Additional positional arguments passed to the activation function.
        **kwargs
            Additional keyword arguments passed to the activation function.

        Raises
        ------
        AssertionError
            If fewer than two layer sizes are provided or if any layer size is not a positive integer.
        AssertionError
            If activation_cls does not inherit from torch.nn.Module.
        """

        assert (
            len(layer_sizes) >= 2
        ), "Multilayer perceptron must have at least 2 layers"
        assert all(
            isinstance(ls, int) and ls > 0 for ls in layer_sizes
        ), "All layer sizes must be positive integers"

        assert issubclass(
            activation_cls, nn.Module
        ), "activation_cls must inherit from torch.nn.Module"

        # Build linear layers interleaved with the activation; the final layer stays linear.
        layers = []
        for i in range(len(layer_sizes) - 2):
            layers.append(nn.Linear(layer_sizes[i], layer_sizes[i + 1]))
            layers.append(activation_cls(*args, **kwargs))
        layers.append(nn.Linear(layer_sizes[-2], layer_sizes[-1]))

        super().__init__(*layers)
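A minimal usage sketch for the MLP above. The layer sizes and the activation arguments are illustrative assumptions, not part of this commit; the `MLP` class defined above is assumed to be in scope.

# Hypothetical usage of the MLP above; layer sizes and activation
# arguments are illustrative, not part of this commit.
import torch
from torch import nn

mlp = MLP([10, 20, 5], activation_cls=nn.LeakyReLU, negative_slope=0.1)
out = mlp(torch.rand(4, 10))  # forward pass; out has shape (4, 5)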
@@ -0,0 +1,6 @@
from .lfr import LearnFromRandomnessModel, RepeatedModuleList

__all__ = [
    "LearnFromRandomnessModel",
    "RepeatedModuleList",
]
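Assuming this `__init__.py` sits in the `minerva.models.ssl` package (the test file below imports from `minerva.models.ssl.lfr`), the re-exports would let callers import both classes directly from the package:

# Hypothetical import path, inferred from the test file's imports.
from minerva.models.ssl import LearnFromRandomnessModel, RepeatedModuleList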
@@ -0,0 +1,208 @@
import lightning as L
import torch


class RepeatedModuleList(torch.nn.ModuleList):
    """
    A module list containing `size` instances of the same module class `cls`.
    """

    def __init__(
        self,
        size: int,
        cls: type,
        *args,
        **kwargs,
    ):
        """
        Initializes the RepeatedModuleList with multiple instances of a given module class.

        Parameters
        ----------
        size : int
            The number of instances to create.
        cls : type
            The module class to instantiate. Must be a subclass of `torch.nn.Module`.
        *args
            Positional arguments to pass to the module class constructor.
        **kwargs
            Keyword arguments to pass to the module class constructor.

        Raises
        ------
        AssertionError
            If `cls` is not a subclass of `torch.nn.Module`.

        Example
        -------
        >>> class SimpleModule(torch.nn.Module):
        ...     def __init__(self, in_features, out_features):
        ...         super().__init__()
        ...         self.linear = torch.nn.Linear(in_features, out_features)
        ...
        >>> repeated_modules = RepeatedModuleList(3, SimpleModule, 10, 5)
        >>> print(repeated_modules)
        RepeatedModuleList(
          (0): SimpleModule(
            (linear): Linear(in_features=10, out_features=5, bias=True)
          )
          (1): SimpleModule(
            (linear): Linear(in_features=10, out_features=5, bias=True)
          )
          (2): SimpleModule(
            (linear): Linear(in_features=10, out_features=5, bias=True)
          )
        )
        """

        assert issubclass(
            cls, torch.nn.Module
        ), f"{cls} does not derive from torch.nn.Module"

        super().__init__([cls(*args, **kwargs) for _ in range(size)])


class LearnFromRandomnessModel(L.LightningModule):
    """
    A PyTorch Lightning model for pretraining with the
    'learning from random data projectors' (LFR) technique.

    References
    ----------
    Yi Sui, Tongzi Wu, Jesse C. Cresswell, Ga Wu, George Stein, Xiao Shi Huang,
    Xiaochen Zhang, Maksims Volkovs.
    "Self-supervised Representation Learning from Random Data Projectors", 2024.
    """

    def __init__(
        self,
        backbone: torch.nn.Module,
        projectors: torch.nn.ModuleList,
        predictors: torch.nn.ModuleList,
        loss_fn: torch.nn.Module,
        learning_rate: float = 1e-3,
        flatten: bool = True,
    ):
        """
        Initialize the LearnFromRandomnessModel.

        Parameters
        ----------
        backbone : torch.nn.Module
            The backbone neural network for feature extraction.
        projectors : torch.nn.ModuleList
            A list of random projector networks, kept frozen during training.
        predictors : torch.nn.ModuleList
            A list of predictor networks, one per projector.
        loss_fn : torch.nn.Module
            The loss function to optimize.
        learning_rate : float
            The learning rate for the optimizer, by default 1e-3.
        flatten : bool
            Whether to flatten the backbone output and the input tensor, by default True.
        """
        super().__init__()
        self.backbone = backbone
        self.projectors = projectors
        self.predictors = predictors
        self.loss_fn = loss_fn
        self.learning_rate = learning_rate
        self.flatten = flatten

        # The random projectors are fixed targets: freeze their parameters
        # and keep them in evaluation mode.
        for param in self.projectors.parameters():
            param.requires_grad = False

        for proj in self.projectors:
            proj.eval()

    def _loss_func(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Calculate the loss between the predictor outputs and the projector targets.

        Parameters
        ----------
        y_hat : torch.Tensor
            The predictor outputs from the forward pass.
        y : torch.Tensor
            The projector outputs used as targets.

        Returns
        -------
        torch.Tensor
            The loss value.
        """
        loss = self.loss_fn(y_hat, y)
        return loss

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through the network.

        Parameters
        ----------
        x : torch.Tensor
            The input data.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            The predicted outputs and the projected inputs, each stacked along dimension 1.
        """
        z: torch.Tensor = self.backbone(x)

        if self.flatten:
            z = z.view(z.size(0), -1)
            x = x.view(x.size(0), -1)

        # Predictors operate on backbone features; projectors operate on the raw input.
        y_pred = torch.stack([predictor(z) for predictor in self.predictors], 1)
        y_proj = torch.stack([projector(x) for projector in self.projectors], 1)

        return y_pred, y_proj

    def _single_step(
        self, batch: torch.Tensor, batch_idx: int, step_name: str
    ) -> torch.Tensor:
        """
        Perform a single training/validation/test step.

        Parameters
        ----------
        batch : torch.Tensor
            The input batch of data.
        batch_idx : int
            The index of the batch.
        step_name : str
            The name of the step (train, val, test).

        Returns
        -------
        torch.Tensor
            The loss value for the batch.
        """
        x = batch
        y_pred, y_proj = self.forward(x)
        loss = self._loss_func(y_pred, y_proj)
        self.log(
            f"{step_name}_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        return loss

    def training_step(self, batch: torch.Tensor, batch_idx: int):
        return self._single_step(batch, batch_idx, step_name="train")

    def validation_step(self, batch: torch.Tensor, batch_idx: int):
        return self._single_step(batch, batch_idx, step_name="val")

    def test_step(self, batch: torch.Tensor, batch_idx: int):
        return self._single_step(batch, batch_idx, step_name="test")

    def configure_optimizers(self):
        return torch.optim.Adam(
            self.parameters(),
            lr=self.learning_rate,
        )
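A minimal pretraining sketch for the module above. It assumes the import path used by the test file below, plus an illustrative linear backbone, projector/predictor sizes, and random data; none of these choices are part of this commit.

# Hypothetical pretraining sketch; backbone, sizes, and data are illustrative.
import lightning as L
import torch
from torch.utils.data import DataLoader

from minerva.models.ssl.lfr import LearnFromRandomnessModel, RepeatedModuleList

backbone = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU())
projectors = RepeatedModuleList(4, torch.nn.Linear, 32, 16)  # frozen random targets
predictors = RepeatedModuleList(4, torch.nn.Linear, 64, 16)  # trained to match them

model = LearnFromRandomnessModel(
    backbone, projectors, predictors, torch.nn.MSELoss(), flatten=True
)

# A plain tensor works as a map-style dataset here, so each batch is a raw
# tensor, which is what _single_step expects.
loader = DataLoader(torch.rand(128, 32), batch_size=16)
L.Trainer(max_epochs=1, accelerator="cpu").fit(model, loader)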
@@ -0,0 +1,59 @@
import torch
from torch.nn import Conv2d, CrossEntropyLoss, Sequential
from torchvision.transforms import Resize

from minerva.models.nets.image.deeplabv3 import DeepLabV3Backbone
from minerva.models.ssl.lfr import LearnFromRandomnessModel, RepeatedModuleList


def test_lfr():

    # Example projector class
    class Projector(Sequential):
        def __init__(self):
            super().__init__(
                Conv2d(3, 16, 5, 2),
                Conv2d(16, 64, 5, 2),
                Conv2d(64, 16, 5, 2),
                Resize((100, 50)),
            )

    # Example predictor class
    class Predictor(Sequential):
        def __init__(self):
            super().__init__(Conv2d(2048, 16, 1), Resize((100, 50)))

    # Declare the model
    model = LearnFromRandomnessModel(
        DeepLabV3Backbone(),
        RepeatedModuleList(5, Projector),
        RepeatedModuleList(5, Predictor),
        CrossEntropyLoss(),
        flatten=False,
    )

    # Test the class instantiation
    assert model is not None

    # Test the forward method
    input_shape = (2, 3, 701, 255)
    expected_output_size = torch.Size([2, 5, 16, 100, 50])
    x = torch.rand(*input_shape)

    y_pred, y_proj = model(x)
    assert (
        y_pred.shape == expected_output_size
    ), f"Expected output shape {expected_output_size}, but got {y_pred.shape}"

    assert (
        y_proj.shape == expected_output_size
    ), f"Expected output shape {expected_output_size}, but got {y_proj.shape}"

    # Test the _loss_func method
    loss = model._loss_func(y_pred, y_proj)
    assert loss is not None
    # TODO: assert the loss result

    # Test the configure_optimizers method
    optimizer = model.configure_optimizers()
    assert optimizer is not None
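One possible way to resolve the TODO above, sketched here rather than part of this commit: check that the reduced loss is a finite scalar tensor.

# Hypothetical follow-up for the TODO: the reduced loss should be a finite scalar.
assert loss.ndim == 0
assert torch.isfinite(loss)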