Merge pull request #70 from fernandoGubiMarques/lfr-base
Implement a base model for LFR
otavioon authored Jul 16, 2024
2 parents e26a2f1 + 0d0e5f8 commit abea98f
Showing 6 changed files with 351 additions and 1 deletion.
4 changes: 3 additions & 1 deletion minerva/models/README.md
@@ -21,4 +21,6 @@

# SSL Models

...
| **Model** | **Authors** | **Task** | **Type** | **Input Shape** | **Python Class** | **Observations** |
|-----------------------------------------|---------------|----------|----------|:---------------:|:--------------------------------------------:|-------------------|
| [LFR](https://arxiv.org/abs/2310.07756) | Yi Sui et al. | Any | Any | Any | minerva.models.ssl.LearnFromRandomnessModel | |
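For reference, the class listed in the table lives in the new `minerva.models.ssl` package added by this PR; a minimal import sketch:

```python
# Import the LFR base model and the helper for replicating projector/predictor modules.
from minerva.models.ssl import LearnFromRandomnessModel, RepeatedModuleList
```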
2 changes: 2 additions & 0 deletions minerva/models/nets/__init__.py
@@ -3,11 +3,13 @@
from .image.setr import SETR_PUP
from .image.unet import UNet
from .image.wisenet import WiseNet
from .mlp import MLP

__all__ = [
"SimpleSupervisedModel",
"DeepLabV3",
"SETR_PUP",
"UNet",
"WiseNet",
"MLP"
]
73 changes: 73 additions & 0 deletions minerva/models/nets/mlp.py
@@ -0,0 +1,73 @@
from torch import nn
from typing import Sequence


class MLP(nn.Sequential):
    """
    A multilayer perceptron (MLP) implemented as a subclass of nn.Sequential.
    The MLP is composed of a sequence of linear layers interleaved with activation
    functions (ReLU by default), except for the final layer, which remains purely linear.

    Example
    -------
    >>> mlp = MLP([10, 20, 30, 40])
    >>> print(mlp)
    MLP(
      (0): Linear(in_features=10, out_features=20, bias=True)
      (1): ReLU()
      (2): Linear(in_features=20, out_features=30, bias=True)
      (3): ReLU()
      (4): Linear(in_features=30, out_features=40, bias=True)
    )
    """

    def __init__(
        self,
        layer_sizes: Sequence[int],
        activation_cls: type = nn.ReLU,
        *args,
        **kwargs
    ):
        """
        Initializes the MLP with the specified layer sizes.

        Parameters
        ----------
        layer_sizes : Sequence[int]
            A sequence of positive integers indicating the size of each layer.
            At least two integers are required, representing the input and output layers.
        activation_cls : type
            The class of the activation function to use between layers. Default is nn.ReLU.
        *args
            Additional arguments passed to the activation function.
        **kwargs
            Additional keyword arguments passed to the activation function.

        Raises
        ------
        AssertionError
            If fewer than two layer sizes are provided or if any layer size is not a positive integer.
        AssertionError
            If activation_cls does not inherit from torch.nn.Module.
        """

        assert (
            len(layer_sizes) >= 2
        ), "Multilayer perceptron must have at least 2 layers"
        assert all(
            ls > 0 and isinstance(ls, int) for ls in layer_sizes
        ), "All layer sizes must be positive integers"

        assert issubclass(
            activation_cls, nn.Module
        ), "activation_cls must inherit from torch.nn.Module"

        layers = []
        for i in range(len(layer_sizes) - 2):
            layers.append(nn.Linear(layer_sizes[i], layer_sizes[i + 1]))
            layers.append(activation_cls(*args, **kwargs))
        layers.append(nn.Linear(layer_sizes[-2], layer_sizes[-1]))

        super().__init__(*layers)
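For context while reviewing, a minimal usage sketch of the new `MLP` class (not part of the diff; the layer sizes and batch size below are arbitrary):

```python
import torch

from minerva.models.nets.mlp import MLP

# 10 -> 20 -> 30 -> 40, with ReLU between hidden layers and a linear output layer.
mlp = MLP([10, 20, 30, 40])

# Forward a random batch of 8 samples and check the output width.
x = torch.rand(8, 10)
out = mlp(x)
assert out.shape == (8, 40)

# A custom activation class (and its keyword arguments) can also be passed through.
leaky_mlp = MLP([10, 20, 5], torch.nn.LeakyReLU, negative_slope=0.1)
```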
6 changes: 6 additions & 0 deletions minerva/models/ssl/__init__.py
@@ -0,0 +1,6 @@
from .lfr import LearnFromRandomnessModel, RepeatedModuleList

__all__ = [
"LearnFromRandomnessModel",
"RepeatedModuleList"
]
208 changes: 208 additions & 0 deletions minerva/models/ssl/lfr.py
@@ -0,0 +1,208 @@
import lightning as L
import torch
from typing import Tuple


class RepeatedModuleList(torch.nn.ModuleList):
    """
    A module list with the same module `cls`, instantiated `size` times.
    """

    def __init__(
        self,
        size: int,
        cls: type,
        *args,
        **kwargs
    ):
        """
        Initializes the RepeatedModuleList with multiple instances of a given module class.

        Parameters
        ----------
        size: int
            The number of instances to create.
        cls: type
            The module class to instantiate. Must be a subclass of `torch.nn.Module`.
        *args:
            Positional arguments to pass to the module class constructor.
        **kwargs:
            Keyword arguments to pass to the module class constructor.

        Raises
        ------
        AssertionError:
            If `cls` is not a subclass of `torch.nn.Module`.

        Example
        -------
        >>> class SimpleModule(torch.nn.Module):
        ...     def __init__(self, in_features, out_features):
        ...         super().__init__()
        ...         self.linear = torch.nn.Linear(in_features, out_features)
        ...
        >>> repeated_modules = RepeatedModuleList(3, SimpleModule, 10, 5)
        >>> print(repeated_modules)
        RepeatedModuleList(
          (0): SimpleModule(
            (linear): Linear(in_features=10, out_features=5, bias=True)
          )
          (1): SimpleModule(
            (linear): Linear(in_features=10, out_features=5, bias=True)
          )
          (2): SimpleModule(
            (linear): Linear(in_features=10, out_features=5, bias=True)
          )
        )
        """

        assert issubclass(
            cls, torch.nn.Module
        ), f"{cls} does not derive from torch.nn.Module"

        super().__init__([cls(*args, **kwargs) for _ in range(size)])


class LearnFromRandomnessModel(L.LightningModule):
    """
    A PyTorch Lightning model for pretraining with the technique
    'Learning From Random Projectors'.

    References
    ----------
    Yi Sui, Tongzi Wu, Jesse C. Cresswell, Ga Wu, George Stein, Xiao Shi Huang,
    Xiaochen Zhang, Maksims Volkovs.
    "Self-supervised Representation Learning From Random Data Projectors", 2024.
    """

    def __init__(
        self,
        backbone: torch.nn.Module,
        projectors: torch.nn.ModuleList,
        predictors: torch.nn.ModuleList,
        loss_fn: torch.nn.Module,
        learning_rate: float = 1e-3,
        flatten: bool = True,
    ):
        """
        Initializes the LearnFromRandomnessModel.

        Parameters
        ----------
        backbone: torch.nn.Module
            The backbone neural network for feature extraction.
        projectors: torch.nn.ModuleList
            A list of projector networks.
        predictors: torch.nn.ModuleList
            A list of predictor networks.
        loss_fn: torch.nn.Module
            The loss function to optimize.
        learning_rate: float
            The learning rate for the optimizer, by default 1e-3.
        flatten: bool
            Whether to flatten the input tensor or not, by default True.
        """
        super().__init__()
        self.backbone = backbone
        self.projectors = projectors
        self.predictors = predictors
        self.loss_fn = loss_fn
        self.learning_rate = learning_rate
        self.flatten = flatten

        # The random projectors are frozen; they only provide fixed targets.
        for param in self.projectors.parameters():
            param.requires_grad = False

        for proj in self.projectors:
            proj.eval()

    def _loss_func(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Calculates the loss between the predicted and the projected data.

        Parameters
        ----------
        y_hat : torch.Tensor
            The output data from the forward pass.
        y : torch.Tensor
            The input data/label.

        Returns
        -------
        torch.Tensor
            The loss value.
        """
        loss = self.loss_fn(y_hat, y)
        return loss

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through the network.

        Parameters
        ----------
        x : torch.Tensor
            The input data.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            The predicted outputs and the projected inputs, each stacked along dimension 1.
        """
        z: torch.Tensor = self.backbone(x)

        if self.flatten:
            z = z.view(z.size(0), -1)
            x = x.view(x.size(0), -1)

        y_pred = torch.stack([predictor(z) for predictor in self.predictors], 1)
        y_proj = torch.stack([projector(x) for projector in self.projectors], 1)

        return y_pred, y_proj

    def _single_step(
        self, batch: torch.Tensor, batch_idx: int, step_name: str
    ) -> torch.Tensor:
        """
        Performs a single training/validation/test step.

        Parameters
        ----------
        batch : torch.Tensor
            The input batch of data.
        batch_idx : int
            The index of the batch.
        step_name : str
            The name of the step (train, val, test).

        Returns
        -------
        torch.Tensor
            The loss value for the batch.
        """
        x = batch
        y_pred, y_proj = self.forward(x)
        loss = self._loss_func(y_pred, y_proj)
        self.log(
            f"{step_name}_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        return loss

    def training_step(self, batch: torch.Tensor, batch_idx: int):
        return self._single_step(batch, batch_idx, step_name="train")

    def validation_step(self, batch: torch.Tensor, batch_idx: int):
        return self._single_step(batch, batch_idx, step_name="val")

    def test_step(self, batch: torch.Tensor, batch_idx: int):
        return self._single_step(batch, batch_idx, step_name="test")

    def configure_optimizers(self):
        return torch.optim.Adam(
            self.parameters(),
            lr=self.learning_rate,
        )
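To show how the pieces added in this PR fit together, here is a minimal end-to-end sketch: an `MLP` backbone with frozen random `MLP` projectors and trainable predictors, trained for one epoch on random data. The dimensions, loss choice, and dataloader are illustrative assumptions, not part of the PR:

```python
import lightning as L
import torch
from torch.utils.data import DataLoader

from minerva.models.nets.mlp import MLP
from minerva.models.ssl import LearnFromRandomnessModel, RepeatedModuleList

# Toy backbone: flat 32-dimensional inputs -> 16-dimensional features.
backbone = MLP([32, 64, 16])

# 4 frozen random projectors (input -> 8) and 4 trainable predictors (features -> 8).
projectors = RepeatedModuleList(4, MLP, [32, 8])
predictors = RepeatedModuleList(4, MLP, [16, 8])

model = LearnFromRandomnessModel(
    backbone,
    projectors,
    predictors,
    loss_fn=torch.nn.MSELoss(),
    learning_rate=1e-3,
    flatten=True,
)

# Random tensors standing in for real data; each batch is a tensor of shape (16, 32).
loader = DataLoader(torch.rand(128, 32), batch_size=16)

trainer = L.Trainer(max_epochs=1, accelerator="cpu", logger=False, enable_checkpointing=False)
trainer.fit(model, loader)

# forward() returns the stacked predictions and projections, here each of shape (batch, 4, 8).
y_pred, y_proj = model(torch.rand(2, 32))
```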
59 changes: 59 additions & 0 deletions tests/models/ssl/test_lfr.py
@@ -0,0 +1,59 @@
import torch
from torch.nn import Sequential, Conv2d, CrossEntropyLoss
from torchvision.transforms import Resize

from minerva.models.ssl.lfr import RepeatedModuleList, LearnFromRandomnessModel
from minerva.models.nets.image.deeplabv3 import DeepLabV3Backbone


def test_lfr():

    ## Example class for projector
    class Projector(Sequential):
        def __init__(self):
            super().__init__(
                Conv2d(3, 16, 5, 2),
                Conv2d(16, 64, 5, 2),
                Conv2d(64, 16, 5, 2),
                Resize((100, 50)),
            )

    ## Example class for predictor
    class Predictor(Sequential):
        def __init__(self):
            super().__init__(Conv2d(2048, 16, 1), Resize((100, 50)))

    # Declare model
    model = LearnFromRandomnessModel(
        DeepLabV3Backbone(),
        RepeatedModuleList(5, Projector),
        RepeatedModuleList(5, Predictor),
        CrossEntropyLoss(),
        flatten=False,
    )

    # Test the class instantiation
    assert model is not None

    # Test the forward method
    input_shape = (2, 3, 701, 255)
    expected_output_size = torch.Size([2, 5, 16, 100, 50])
    x = torch.rand(*input_shape)

    y_pred, y_proj = model(x)
    assert (
        y_pred.shape == expected_output_size
    ), f"Expected output shape {expected_output_size}, but got {y_pred.shape}"

    assert (
        y_proj.shape == expected_output_size
    ), f"Expected output shape {expected_output_size}, but got {y_proj.shape}"

    # Test the _loss_func method
    loss = model._loss_func(y_pred, y_proj)
    assert loss is not None
    # TODO: assert the loss result

    # Test the configure_optimizers method
    optimizer = model.configure_optimizers()
    assert optimizer is not None
