From 55d493807afb52f6e50e88bc0eba092a7c48b7f0 Mon Sep 17 00:00:00 2001 From: fernandoGubiMarques Date: Fri, 5 Jul 2024 14:43:55 +0000 Subject: [PATCH 1/6] added MLP module --- minerva/models/nets/mlp.py | 52 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 minerva/models/nets/mlp.py diff --git a/minerva/models/nets/mlp.py b/minerva/models/nets/mlp.py new file mode 100644 index 0000000..b6aae28 --- /dev/null +++ b/minerva/models/nets/mlp.py @@ -0,0 +1,52 @@ +from torch import nn + +class MLP(nn.Sequential): + """ + A multilayer perceptron (MLP) implemented as a subclass of nn.Sequential. + + The MLP consists of a series of linear layers interleaved with ReLU activation functions, + except for the last layer which is purely linear. + + Example + ------- + + >>> mlp = MLP(10, 20, 30, 40) + >>> print(mlp) + MLP( + (0): Linear(in_features=10, out_features=20, bias=True) + (1): ReLU() + (2): Linear(in_features=20, out_features=30, bias=True) + (3): ReLU() + (4): Linear(in_features=30, out_features=40, bias=True) + ) + """ + + def __init__(self, *layer_sizes): + """ + Initializes the MLP with the given layer sizes. + + Parameters + ---------- + *layer_sizes: int + A variable number of positive integers specifying the size of each layer. + There must be at least two integers, representing the input and output layers. + + Raises + ------ + AssertionError: If less than two layer sizes are provided. + + AssertionError: If any layer size is not a positive integer. + """ + assert ( + len(layer_sizes) >= 2 + ), "Multilayer perceptron must have at least 2 layers" + assert all( + ls > 0 and isinstance(ls, int) for ls in layer_sizes + ), "All layer sizes must be a positive integer" + + layers = [] + for i in range(len(layer_sizes) - 2): + layers += [nn.Linear(layer_sizes[i], layer_sizes[i + 1]), nn.ReLU()] + layers += [nn.Linear(layer_sizes[-2], layer_sizes[-1])] + + super().__init__(*layers) From 0d6f94f9b67a4209704413140bc7d27761b89e3f Mon Sep 17 00:00:00 2001 From: fernandoGubiMarques Date: Fri, 5 Jul 2024 14:47:21 +0000 Subject: [PATCH 2/6] Added LFR module --- minerva/models/nets/__init__.py | 5 + minerva/models/nets/lfr.py | 202 ++++++++++++++++++++++++++++++++ 2 files changed, 207 insertions(+) create mode 100644 minerva/models/nets/lfr.py diff --git a/minerva/models/nets/__init__.py b/minerva/models/nets/__init__.py index c179923..c6c35d5 100644 --- a/minerva/models/nets/__init__.py +++ b/minerva/models/nets/__init__.py @@ -3,6 +3,8 @@ from .image.setr import SETR_PUP from .image.unet import UNet from .image.wisenet import WiseNet +from .mlp import MLP +from .lfr import LearnFromRandomnessModel, RepeatedModuleList __all__ = [ "SimpleSupervisedModel", @@ -10,4 +12,7 @@ "SETR_PUP", "UNet", "WiseNet", + "MLP", + "LearnFromRandomnessModel", + "RepeatedModuleList" ] diff --git a/minerva/models/nets/lfr.py b/minerva/models/nets/lfr.py new file mode 100644 index 0000000..a5301a2 --- /dev/null +++ b/minerva/models/nets/lfr.py @@ -0,0 +1,202 @@ +import lightning as L +import torch + + +class RepeatedModuleList(torch.nn.ModuleList): + """ + A module list with the same module `cls`, instantiated `size` times. + """ + + def __init__(self, size, cls, *args, **kwargs): + """ + Initializes the RepeatedModuleList with multiple instances of a given module class. + + Parameters + ---------- + size: int + The number of instances to create. + cls: type + The module class to instantiate. Must be a subclass of `torch.nn.Module`. + *args: + Positional arguments to pass to the module class constructor. + **kwargs: + Keyword arguments to pass to the module class constructor. + + Raises + ------ + AssertionError: + If `cls` is not a subclass of `torch.nn.Module`. + + Example + ------- + >>> class SimpleModule(torch.nn.Module): + >>> def __init__(self, in_features, out_features): + >>> super().__init__() + >>> self.linear = torch.nn.Linear(in_features, out_features) + >>> + >>> repeated_modules = RepeatedModuleList(3, SimpleModule, 10, 5) + >>> print(repeated_modules) + RepeatedModuleList( + (0): SimpleModule( + (linear): Linear(in_features=10, out_features=5, bias=True) + ) + (1): SimpleModule( + (linear): Linear(in_features=10, out_features=5, bias=True) + ) + (2): SimpleModule( + (linear): Linear(in_features=10, out_features=5, bias=True) + ) + ) + """ + + assert issubclass( + cls, torch.nn.Module + ), f"{cls} does not derive from torch.nn.Module" + + super().__init__([cls(*args, **kwargs) for _ in range(size)]) + + +class LearnFromRandomnessModel(L.LightningModule): + """ + A PyTorch Lightning model for pretraining with the technique + 'Learning From Random Projectors'. + + References + ---------- + Yi Sui, Tongzi Wu, Jesse C. Cresswell, Ga Wu, George Stein, Xiao Shi Huang, Xiaochen Zhang, Maksims Volkovs. + "Self-supervised Representation Learning From Random Data Projectors", 2024 + """ + + def __init__( + self, + backbone: torch.nn.Module, + projectors: torch.nn.ModuleList, + predictors: torch.nn.ModuleList, + loss_fn: torch.nn.Module, + learning_rate: float = 1e-3, + flatten: bool = True, + ): + """ + Initialize the LFR_Model. + + Parameters + ---------- + backbone : torch.nn.Module + The backbone neural network for feature extraction. + projectors : torch.nn.ModuleList + A list of projector networks. + predictors : torch.nn.ModuleList + A list of predictor networks. + loss_fn : torch.nn.Module + The loss function to optimize. + learning_rate : Optional[float] + The learning rate for the optimizer, by default 1e-3. + flatten : Optional[bool] + Whether to flatten the input tensor or not, by default True. + """ + super().__init__() + self.backbone = backbone + self.projectors = projectors + self.predictors = predictors + self.loss_fn = loss_fn + self.learning_rate = learning_rate + self.flatten = flatten + + for param in self.projectors.parameters(): + param.requires_grad = False + + for proj in self.projectors: + proj.eval() + + def _loss_func(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + Calculate the loss between the output and the input data. + + Parameters + ---------- + y_hat : torch.Tensor + The output data from the forward pass. + y : torch.Tensor + The input data/label. + + Returns + ------- + torch.Tensor + The loss value. + """ + loss = self.loss_fn(y_hat, y) + return loss + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through the network. + + Parameters + ---------- + x : torch.Tensor + The input data. + + Returns + ------- + torch.Tensor + The predicted output and projected input. + """ + z: torch.Tensor = self.backbone(x) + + if self.flatten: + z = z.view(z.size(0), -1) + x = x.view(x.size(0), -1) + + y_pred = torch.stack([predictor(z) for predictor in self.predictors], 1) + y_proj = torch.stack([projector(x) for projector in self.projectors], 1) + + return y_pred, y_proj + + def _single_step( + self, batch: torch.Tensor, batch_idx: int, step_name: str + ) -> torch.Tensor: + """ + Perform a single training/validation/test step. + + Parameters + ---------- + batch : torch.Tensor + The input batch of data. + batch_idx : int + The index of the batch. + step_name : str + The name of the step (train, val, test). + + Returns + ------- + torch.Tensor + The loss value for the batch. + """ + x = batch + y_pred, y_proj = self.forward(x) + loss = self._loss_func(y_pred, y_proj) + self.log( + f"{step_name}_loss", + loss, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + return loss + + def training_step(self, batch: torch.Tensor, batch_idx: int): + return self._single_step(batch, batch_idx, step_name="train") + + def validation_step(self, batch: torch.Tensor, batch_idx: int): + return self._single_step(batch, batch_idx, step_name="val") + + def test_step(self, batch: torch.Tensor, batch_idx: int): + return self._single_step(batch, batch_idx, step_name="test") + + def configure_optimizers(self): + return torch.optim.Adam( + self.parameters(), + lr=self.learning_rate, + ) From 8b359a927912966cf8d862d9df5d21922d4d3e69 Mon Sep 17 00:00:00 2001 From: fernandoGubiMarques Date: Fri, 5 Jul 2024 15:32:30 +0000 Subject: [PATCH 3/6] implemented tests for LFR --- tests/models/nets/test_lfr.py | 60 +++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 tests/models/nets/test_lfr.py diff --git a/tests/models/nets/test_lfr.py b/tests/models/nets/test_lfr.py new file mode 100644 index 0000000..666e050 --- /dev/null +++ b/tests/models/nets/test_lfr.py @@ -0,0 +1,60 @@ +import torch +from torch.nn import Sequential, Conv2d, CrossEntropyLoss +from torchvision.transforms import Resize + +from minerva.models.nets.lfr import RepeatedModuleList, LearnFromRandomnessModel +from minerva.models.nets.image.deeplabv3 import DeepLabV3Backbone + + +def test_lfr(): + + ## Example class for projector + class Projector(Sequential): + def __init__(self): + super().__init__( + Conv2d(3, 16, 5, 2), + Conv2d(16, 64, 5, 2), + Conv2d(64, 16, 5, 2), + Resize((100, 50)), + ) + + ## Example class for predictor + class Predictor(Sequential): + def __init__(self): + super().__init__(Conv2d(2048, 16, 1), Resize((100, 50))) + + # Declare model + model = LearnFromRandomnessModel( + DeepLabV3Backbone(), + RepeatedModuleList(5, Projector), + RepeatedModuleList(5, Predictor), + CrossEntropyLoss(), + flatten=False + ) + + # Test the class instantiation + assert model is not None + + # # Test the forward method + input_shape = (2, 3, 701, 255) + expected_output_size = torch.Size([2, 5, 16, 100, 50]) + x = torch.rand(*input_shape) + + y_pred, y_proj = model(x) + assert ( + y_pred.shape == expected_output_size + ), f"Expected output shape {expected_output_size}, but got {y_pred.shape}" + + assert ( + y_proj.shape == expected_output_size + ), f"Expected output shape {expected_output_size}, but got {y_proj.shape}" + + # Test the _loss_func method + loss = model._loss_func(y_pred, y_proj) + assert loss is not None + # TODO: assert the loss result + + # Test the configure_optimizers method + optimizer = model.configure_optimizers() + assert optimizer is not None + From 056b75c6f71499a57684e041426cfea8363bdd7f Mon Sep 17 00:00:00 2001 From: fernandoGubiMarques Date: Fri, 5 Jul 2024 17:39:10 +0000 Subject: [PATCH 4/6] fixed type annotations in docs and implementation --- minerva/models/nets/lfr.py | 20 +++++++++++++------- minerva/models/nets/mlp.py | 2 +- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/minerva/models/nets/lfr.py b/minerva/models/nets/lfr.py index a5301a2..29ce07f 100644 --- a/minerva/models/nets/lfr.py +++ b/minerva/models/nets/lfr.py @@ -7,7 +7,13 @@ class RepeatedModuleList(torch.nn.ModuleList): A module list with the same module `cls`, instantiated `size` times. """ - def __init__(self, size, cls, *args, **kwargs): + def __init__( + self, + size: int, + cls: type, + *args, + **kwargs + ): """ Initializes the RepeatedModuleList with multiple instances of a given module class. @@ -81,17 +87,17 @@ def __init__( Parameters ---------- - backbone : torch.nn.Module + backbone: torch.nn.Module The backbone neural network for feature extraction. - projectors : torch.nn.ModuleList + projectors: torch.nn.ModuleList A list of projector networks. - predictors : torch.nn.ModuleList + predictors: torch.nn.ModuleList A list of predictor networks. - loss_fn : torch.nn.Module + loss_fn: torch.nn.Module The loss function to optimize. - learning_rate : Optional[float] + learning_rate: float The learning rate for the optimizer, by default 1e-3. - flatten : Optional[bool] + flatten: bool Whether to flatten the input tensor or not, by default True. """ super().__init__() diff --git a/minerva/models/nets/mlp.py b/minerva/models/nets/mlp.py index b6aae28..8452749 100644 --- a/minerva/models/nets/mlp.py +++ b/minerva/models/nets/mlp.py @@ -21,7 +21,7 @@ class MLP(nn.Sequential): ) """ - def __init__(self, *layer_sizes): + def __init__(self, *layer_sizes: int): """ Initializes the MLP with the given layer sizes. From fd2226557742e03917442524153dbb0ae274a551 Mon Sep 17 00:00:00 2001 From: fernandoGubiMarques Date: Fri, 5 Jul 2024 17:52:33 +0000 Subject: [PATCH 5/6] updated readme --- minerva/models/README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/minerva/models/README.md b/minerva/models/README.md index 35752cb..fd21816 100644 --- a/minerva/models/README.md +++ b/minerva/models/README.md @@ -21,4 +21,6 @@ # SSL Models -... \ No newline at end of file +| **Model** | **Authors** | **Task** | **Type** | **Input Shape** | **Python Class** | **Observations** | +|-----------------------------------------|---------------|----------|----------|:---------------:|:--------------------------------------------:|-------------------| +| [LFR](https://arxiv.org/abs/2310.07756) | Yi Sui et al. | Any | Any | Any | minerva.models.nets.LearnFromRandomnessModel | | From 0d0e5f835ef704926ca4c467b7708d202fcba865 Mon Sep 17 00:00:00 2001 From: fernandoGubiMarques Date: Fri, 12 Jul 2024 14:16:39 +0000 Subject: [PATCH 6/6] moved lfr to ssl folder, added parameters to mlp --- minerva/models/nets/__init__.py | 5 +-- minerva/models/nets/mlp.py | 53 ++++++++++++++++++-------- minerva/models/ssl/__init__.py | 6 +++ minerva/models/{nets => ssl}/lfr.py | 0 tests/models/{nets => ssl}/test_lfr.py | 3 +- 5 files changed, 45 insertions(+), 22 deletions(-) create mode 100644 minerva/models/ssl/__init__.py rename minerva/models/{nets => ssl}/lfr.py (100%) rename tests/models/{nets => ssl}/test_lfr.py (95%) diff --git a/minerva/models/nets/__init__.py b/minerva/models/nets/__init__.py index c6c35d5..bbfdb31 100644 --- a/minerva/models/nets/__init__.py +++ b/minerva/models/nets/__init__.py @@ -4,7 +4,6 @@ from .image.unet import UNet from .image.wisenet import WiseNet from .mlp import MLP -from .lfr import LearnFromRandomnessModel, RepeatedModuleList __all__ = [ "SimpleSupervisedModel", @@ -12,7 +11,5 @@ "SETR_PUP", "UNet", "WiseNet", - "MLP", - "LearnFromRandomnessModel", - "RepeatedModuleList" + "MLP" ] diff --git a/minerva/models/nets/mlp.py b/minerva/models/nets/mlp.py index 8452749..b59e2c0 100644 --- a/minerva/models/nets/mlp.py +++ b/minerva/models/nets/mlp.py @@ -1,12 +1,14 @@ from torch import nn +from typing import Sequence + class MLP(nn.Sequential): """ A multilayer perceptron (MLP) implemented as a subclass of nn.Sequential. - - The MLP consists of a series of linear layers interleaved with ReLU activation functions, - except for the last layer which is purely linear. - + + This MLP is composed of a sequence of linear layers interleaved with ReLU activation + functions, except for the final layer which remains purely linear. + Example ------- @@ -21,32 +23,51 @@ class MLP(nn.Sequential): ) """ - def __init__(self, *layer_sizes: int): + def __init__( + self, + layer_sizes: Sequence[int], + activation_cls: type = nn.ReLU, + *args, + **kwargs + ): """ - Initializes the MLP with the given layer sizes. + Initializes the MLP with specified layer sizes. Parameters ---------- - *layer_sizes: int - A variable number of positive integers specifying the size of each layer. - There must be at least two integers, representing the input and output layers. - + layer_sizes : Sequence[int] + A sequence of positive integers indicating the size of each layer. + At least two integers are required, representing the input and output layers. + activation_cls : type + The class of the activation function to use between layers. Default is nn.ReLU. + *args + Additional arguments passed to the activation function. + **kwargs + Additional keyword arguments passed to the activation function. + Raises ------ - AssertionError: If less than two layer sizes are provided. - - AssertionError: If any layer size is not a positive integer. + AssertionError + If fewer than two layer sizes are provided or if any layer size is not a positive integer. + AssertionError + If activation_cls does not inherit from torch.nn.Module. """ + assert ( len(layer_sizes) >= 2 ), "Multilayer perceptron must have at least 2 layers" assert all( ls > 0 and isinstance(ls, int) for ls in layer_sizes - ), "All layer sizes must be a positive integer" + ), "All layer sizes must be positive integers" + + assert issubclass( + activation_cls, nn.Module + ), "activation_cls must inherit from torch.nn.Module" layers = [] for i in range(len(layer_sizes) - 2): - layers += [nn.Linear(layer_sizes[i], layer_sizes[i + 1]), nn.ReLU()] - layers += [nn.Linear(layer_sizes[-2], layer_sizes[-1])] + layers.append(nn.Linear(layer_sizes[i], layer_sizes[i + 1])) + layers.append(activation_cls(*args, **kwargs)) + layers.append(nn.Linear(layer_sizes[-2], layer_sizes[-1])) super().__init__(*layers) diff --git a/minerva/models/ssl/__init__.py b/minerva/models/ssl/__init__.py new file mode 100644 index 0000000..c14ad54 --- /dev/null +++ b/minerva/models/ssl/__init__.py @@ -0,0 +1,6 @@ +from .lfr import LearnFromRandomnessModel, RepeatedModuleList + +__all__ = [ + "LearnFromRandomnessModel", + "RepeatedModuleList" +] \ No newline at end of file diff --git a/minerva/models/nets/lfr.py b/minerva/models/ssl/lfr.py similarity index 100% rename from minerva/models/nets/lfr.py rename to minerva/models/ssl/lfr.py diff --git a/tests/models/nets/test_lfr.py b/tests/models/ssl/test_lfr.py similarity index 95% rename from tests/models/nets/test_lfr.py rename to tests/models/ssl/test_lfr.py index 666e050..c93cfec 100644 --- a/tests/models/nets/test_lfr.py +++ b/tests/models/ssl/test_lfr.py @@ -2,7 +2,7 @@ from torch.nn import Sequential, Conv2d, CrossEntropyLoss from torchvision.transforms import Resize -from minerva.models.nets.lfr import RepeatedModuleList, LearnFromRandomnessModel +from minerva.models.ssl.lfr import RepeatedModuleList, LearnFromRandomnessModel from minerva.models.nets.image.deeplabv3 import DeepLabV3Backbone @@ -57,4 +57,3 @@ def __init__(self): # Test the configure_optimizers method optimizer = model.configure_optimizers() assert optimizer is not None -