diff --git a/dwi_ml/data/processing/streamlines/data_augmentation.py b/dwi_ml/data/processing/streamlines/data_augmentation.py
index cd46a185..3900eda1 100644
--- a/dwi_ml/data/processing/streamlines/data_augmentation.py
+++ b/dwi_ml/data/processing/streamlines/data_augmentation.py
@@ -153,3 +153,28 @@ def reverse_streamlines(sft: StatefulTractogram,
         data_per_streamline=sft.data_per_streamline)

     return new_sft
+
+
+def normalize_streamlines(sft: StatefulTractogram):
+    """ Normalize streamlines so that their points lie in [0, 1].
+
+    Each point is divided by the volume dimensions, so the streamlines are
+    expected to be in voxel space.
+
+    Parameters
+    ----------
+    sft: StatefulTractogram
+        Dipy object containing your streamlines.
+
+    Returns
+    -------
+    new_sft: StatefulTractogram
+        Dipy object with normalized streamlines and data_per_point.
+    """
+
+    dims = sft.dimensions
+
+    new_streamlines = [s / dims for s in sft.streamlines]
+
+    new_sft = StatefulTractogram.from_sft(
+        new_streamlines, sft, data_per_point=sft.data_per_point,
+        data_per_streamline=sft.data_per_streamline)
+
+    return new_sft
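The scaling is trivially invertible: multiplying by `sft.dimensions` recovers the original coordinates. A minimal usage sketch (assumes `sft` is an already-loaded StatefulTractogram; the inference script below moves it to voxel space the same way before normalizing):

```python
from dwi_ml.data.processing.streamlines.data_augmentation import (
    normalize_streamlines)

sft.to_vox()      # points now in [0, dims]
sft.to_corner()
normalized = normalize_streamlines(sft)    # points in [0, 1]
restored = [s * sft.dimensions             # undo the scaling
            for s in normalized.streamlines]
```
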
+ """ + + super(ResBlock1d, self).__init__() + + self.block = torch.nn.Sequential( + torch.nn.Conv1d(channels, channels, kernel_size=7, groups=channels, + stride=stride, padding=3, padding_mode='reflect'), + norm(channels), + Permute((0, 2, 1)), + torch.nn.Linear( + in_features=channels, out_features=4 * channels, bias=True), + torch.nn.GELU(), + torch.nn.Linear( + in_features=channels * 4, out_features=channels, bias=True), + Permute((0, 2, 1))) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ Forward pass. + """ + + identity = x + x = self.block(x) + + return x + identity + + +class ModelConvNextAE(MainModelAbstract): + """ ConvNext autoencoder model, modified to work with 1D data. + + The model is composed of an encoder and a decoder. The encoder is composed + of a series of 1D convolutions and residual blocks. The decoder is composed + of residual blocks, upsampling layers, and 1D convolutions. The decoder is + much smaller than the encoder to reduce the number of parameters and + enforce a well defined latent space. + + References: + [1]: Liu, Z., Mao, H., Wu, C. Y., Feichtenhofer, C., Darrell, T., & Xie, S. + (2022). A convnet for the 2020s. In Proceedings of the IEEE/CVF conference + on computer vision and pattern recognition (pp. 11976-11986). + """ + + def __init__( + self, + kernel_size, + latent_space_dims, + experiment_name: str, + nb_points: int = None, + step_size: float = None, + compress_lines: float = False, + # Other + log_level=logging.root.level + ): + super().__init__( + experiment_name, + step_size=step_size, + nb_points=nb_points, + compress_lines=compress_lines, + log_level=log_level + ) + + self.kernel_size = kernel_size + self.latent_space_dims = latent_space_dims + self.reconstruction_loss = torch.nn.MSELoss(reduction="sum") + + """ + Encode convolutions. Follows the same structure as the original + ConvNext-Tiny architecture, but with 1D convolutions. + """ + self.encoder = torch.nn.Sequential( + torch.nn.Conv1d(3, 32, self.kernel_size+1, stride=1, padding=1, + padding_mode='reflect'), + ResBlock1d(32), + ResBlock1d(32), + ResBlock1d(32), + torch.nn.Conv1d(32, 64, self.kernel_size, stride=2, padding=0, + padding_mode='reflect'), + ResBlock1d(64), + ResBlock1d(64), + ResBlock1d(64), + torch.nn.Conv1d(64, 128, self.kernel_size, stride=2, padding=0, + padding_mode='reflect'), + ResBlock1d(128), + ResBlock1d(128), + ResBlock1d(128), + torch.nn.Conv1d(128, 256, self.kernel_size, stride=2, padding=0, + padding_mode='reflect'), + ResBlock1d(256), + ResBlock1d(256), + ResBlock1d(256), + torch.nn.Conv1d(256, 512, self.kernel_size, stride=2, padding=0, + padding_mode='reflect'), + ResBlock1d(512), + ResBlock1d(512), + ResBlock1d(512), + ResBlock1d(512), + ResBlock1d(512), + ResBlock1d(512), + ResBlock1d(512), + ResBlock1d(512), + ResBlock1d(512), + torch.nn.Conv1d(512, 1024, self.kernel_size, stride=2, padding=0, + padding_mode='reflect'), + ResBlock1d(1024), + ResBlock1d(1024), + ResBlock1d(1024), + ) + """ + Latent space + """ + + self.fc1 = torch.nn.Linear(8192, + self.latent_space_dims) # 8192 = 1024*8 + self.fc2 = torch.nn.Linear(self.latent_space_dims, 8192) + + """ + Decode convolutions. Uses upsampling and 1D convolutions instead of + transposed convolutions to avoid checkerboard artifacts. 
+ """ + self.decoder = torch.nn.Sequential( + ResBlock1d(1024), + torch.nn.Upsample(scale_factor=2, mode="linear", + align_corners=False), + torch.nn.Conv1d( + 1024, 512, self.kernel_size+1, stride=1, padding=1, + padding_mode='reflect'), + ResBlock1d(512), + torch.nn.Upsample(scale_factor=2, mode="linear", + align_corners=False), + torch.nn.Conv1d( + 512, 256, self.kernel_size+1, stride=1, padding=1, + padding_mode='reflect'), + ResBlock1d(256), + torch.nn.Upsample(scale_factor=2, mode="linear", + align_corners=False), + torch.nn.Conv1d( + 256, 128, self.kernel_size+1, stride=1, padding=1, + padding_mode='reflect'), + ResBlock1d(128), + torch.nn.Upsample(scale_factor=2, mode="linear", + align_corners=False), + torch.nn.Conv1d( + 128, 64, self.kernel_size+1, stride=1, padding=1, + padding_mode='reflect'), + ResBlock1d(64), + torch.nn.Upsample(scale_factor=2, mode="linear", + align_corners=False), + torch.nn.Conv1d( + 64, 32, self.kernel_size+1, stride=1, padding=1, + padding_mode='reflect'), + ResBlock1d(32), + torch.nn.Conv1d( + 32, 3, self.kernel_size+1, stride=1, padding=1, + padding_mode='reflect'), + ) + + def forward( + self, + input_streamlines: List[torch.tensor], + ): + """Run the model on a batch of sequences. + + Parameters + ---------- + input_streamlines: List[torch.tensor], + Batch of streamlines. Only used if previous directions are added to + the model. Used to compute directions; its last point will not be + used. + + Returns + ------- + model_outputs : List[Tensor] + Output data, ready to be passed to either `compute_loss()` or + `get_tracking_directions()`. + """ + + x = self.decode(self.encode(input_streamlines)) + return x + + def encode(self, x: List[torch.Tensor]) -> torch.Tensor: + """ Encode the input data. + + Parameters + ---------- + x : list of tensors + List of input tensors. + + Returns + ------- + z : torch.Tensor + Input data encoded to the latent space. + """ + + # x: list of tensors + if isinstance(x, list): + x = torch.stack(x) + x = torch.swapaxes(x, 1, 2) + + x = self.encoder(x) + self.encoder_out_size = (x.shape[1], x.shape[2]) + + # Flatten + h7 = x.reshape(-1, self.encoder_out_size[0] * self.encoder_out_size[1]) + + z = self.fc1(h7) + return z + + def decode(self, z: torch.Tensor) -> torch.Tensor: + """ Decode the input data. + + Parameters + ---------- + z : torch.Tensor + Input data in the latent space. + + Returns + ------- + x_hat : torch.Tensor + Decoded data. + """ + + fc = self.fc2(z) + fc_reshape = fc.view( + -1, self.encoder_out_size[0], self.encoder_out_size[1] + ) + x_hat = self.decoder(fc_reshape) + return x_hat + + def compute_loss(self, model_outputs, targets, average_results=True): + """Compute the loss of the model. + + Parameters + ---------- + model_outputs : List[Tensor] + Model outputs. + targets : List[Tensor] + Target data. + average_results : bool + If True, the loss will be averaged across the batch. + + Returns + ------- + loss : Tensor + Loss value. 
+ """ + + targets = torch.stack(targets) + targets = torch.swapaxes(targets, 1, 2) + mse = self.reconstruction_loss(model_outputs, targets) + + return mse, 1 diff --git a/dwi_ml/training/batch_loaders.py b/dwi_ml/training/batch_loaders.py index 623371a9..6794c9ea 100644 --- a/dwi_ml/training/batch_loaders.py +++ b/dwi_ml/training/batch_loaders.py @@ -50,6 +50,7 @@ from dwi_ml.data.dataset.multi_subject_containers import MultiSubjectDataset from dwi_ml.data.processing.streamlines.data_augmentation import ( + normalize_streamlines, reverse_streamlines, split_streamlines, resample_or_compress) from dwi_ml.data.processing.utils import add_noise_to_tensor from dwi_ml.models.main_models import MainModelOneInput, \ @@ -64,7 +65,9 @@ def __init__(self, dataset: MultiSubjectDataset, model: MainModelAbstract, split_ratio: float = 0., noise_gaussian_size_forward: float = 0., noise_gaussian_size_loss: float = 0., - reverse_ratio: float = 0., log_level=logging.root.level): + reverse_ratio: float = 0., + normalize: bool = False, + log_level=logging.root.level): """ Parameters ---------- @@ -131,6 +134,7 @@ def __init__(self, dataset: MultiSubjectDataset, model: MainModelAbstract, self.noise_gaussian_size_forward = noise_gaussian_size_forward self.noise_gaussian_size_loss = noise_gaussian_size_loss self.split_ratio = split_ratio + self.normalize = normalize self.reverse_ratio = reverse_ratio if self.split_ratio and not 0 <= self.split_ratio <= 1: raise ValueError('Split ratio must be a float between 0 and 1.') @@ -157,6 +161,7 @@ def params_for_checkpoint(self): 'noise_gaussian_size_forward': self.noise_gaussian_size_forward, 'noise_gaussian_size_loss': self.noise_gaussian_size_loss, 'reverse_ratio': self.reverse_ratio, + 'normalize': self.normalize, 'split_ratio': self.split_ratio, } return params @@ -229,6 +234,9 @@ def _data_augmentation_sft(self, sft): reverse_ids = ids[:int(len(ids) * self.reverse_ratio)] sft = reverse_streamlines(sft, reverse_ids) + if self.normalize: + sft = normalize_streamlines(sft) + return sft def add_noise_streamlines_forward(self, batch_streamlines, device): diff --git a/requirements.txt b/requirements.txt index 7739110b..c7ac2ff2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,8 @@ requests==2.28.* dipy==1.9.* scilpy==2.0.2 +packaging==23.2.* +platformdirs<4,>=3.1.1 # ------- # Other important dependencies diff --git a/scripts_python/ae_autoencode_streamlines.py b/scripts_python/ae_autoencode_streamlines.py new file mode 100755 index 00000000..7db2912c --- /dev/null +++ b/scripts_python/ae_autoencode_streamlines.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import argparse +import logging + +import nibabel as nib +import numpy as np +import torch + +from scilpy.io.utils import (add_overwrite_arg, + assert_outputs_exist, + add_reference_arg, + add_verbose_arg) +from scilpy.io.streamlines import load_tractogram_with_reference +from scilpy.tracking.utils import save_tractogram +from dipy.tracking.streamline import set_number_of_points +from dwi_ml.io_utils import (add_arg_existing_experiment_path, + add_memory_args) +from dwi_ml.models.projects.ae_models import ModelAE +from dwi_ml.models.projects.ae_next_models import ModelConvNextAE +from nibabel.streamlines import detect_format + + +def _build_arg_parser(): + p = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter, + description=__doc__) + # Mandatory + # Should only be False for debugging tests. 
diff --git a/scripts_python/ae_autoencode_streamlines.py b/scripts_python/ae_autoencode_streamlines.py
new file mode 100755
index 00000000..7db2912c
--- /dev/null
+++ b/scripts_python/ae_autoencode_streamlines.py
@@ -0,0 +1,158 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+Encode then decode (autoencode) the streamlines of a tractogram with a
+trained autoencoder, saving the reconstructed streamlines.
+"""
+import argparse
+import logging
+
+import nibabel as nib
+import numpy as np
+import torch
+
+from scilpy.io.utils import (add_overwrite_arg,
+                             assert_outputs_exist,
+                             add_reference_arg,
+                             add_verbose_arg)
+from scilpy.io.streamlines import load_tractogram_with_reference
+from scilpy.tracking.utils import save_tractogram
+from dipy.tracking.streamline import set_number_of_points
+from dwi_ml.io_utils import (add_arg_existing_experiment_path,
+                             add_memory_args)
+from dwi_ml.models.projects.ae_models import ModelAE
+from dwi_ml.models.projects.ae_next_models import ModelConvNextAE
+from nibabel.streamlines import detect_format
+
+
+def _build_arg_parser():
+    p = argparse.ArgumentParser(
+        formatter_class=argparse.RawTextHelpFormatter,
+        description=__doc__)
+    # Mandatory
+    add_arg_existing_experiment_path(p)
+    # add_args_testing_subj_hdf5(p)
+
+    p.add_argument('in_tractogram', type=str,
+                   help="Tractogram to autoencode.")
+
+    p.add_argument('out_tractogram', type=str,
+                   help="Autoencoded tractogram.")
+
+    # Additional arg for projects
+    p.add_argument('--model', type=str, choices=['finta', 'convnext'],
+                   default='finta',
+                   help='Type of model to use.')
+
+    # Options
+    p.add_argument('--normalize', action='store_true')
+    p.add_argument('--batch_size', type=int, default=5000)
+    add_memory_args(p)
+
+    add_reference_arg(p)
+    add_overwrite_arg(p)
+    add_verbose_arg(p)
+    return p
+
+
+def autoencode_streamlines(
+    model, sft, device, batch_size=5000, normalize=False
+):
+    """ Autoencode a tractogram. Streamlines are processed in batches to
+    make it possible to autoencode large tractograms. For the same reason,
+    this function returns a generator function instead of an actual
+    tractogram.
+
+    Parameters
+    ----------
+    model: MainModelAbstract
+        Autoencoder.
+    sft: StatefulTractogram
+        Tractogram to autoencode.
+    device: torch.device
+        GPU or CPU.
+    batch_size: int
+        Number of streamlines to autoencode at once.
+    normalize: bool
+        Whether to normalize the streamline coordinates before
+        inputting them to the model.
+
+    Returns
+    -------
+    _autoencode_streamlines: Callable
+        Generator function which will autoencode the streamlines.
+    """
+
+    sft.to_vox()
+    sft.to_corner()
+
+    bundle = sft.streamlines
+
+    logging.info("Autoencoding streamlines in batches.")
+    batches = range(0, len(sft.streamlines), batch_size)
+
+    def _autoencode_streamlines():
+        for b in batches:
+            with torch.no_grad():
+                s = np.asarray(
+                    set_number_of_points(
+                        bundle[b:b + batch_size],
+                        256))
+                if normalize:
+                    s /= sft.dimensions
+
+                streamlines = torch.as_tensor(
+                    s, dtype=torch.float32, device=device)
+                tmp_outputs = model(streamlines).cpu().numpy()
+                # latent = model.encode(streamlines)
+                scaling = sft.dimensions if normalize else 1.0
+                streamlines_output = \
+                    tmp_outputs.transpose((0, 2, 1)) * scaling
+                for strml in streamlines_output:
+                    yield strml, strml[0]
+
+    return _autoencode_streamlines
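Usage sketch (assumes `model`, `sft` and `device` are prepared as in `main()` below). The function returns a generator *function*, so it must be called to obtain the generator, which then yields one reconstructed streamline (and its first point) at a time:

```python
gen_fn = autoencode_streamlines(model, sft, device, batch_size=1000)
for reconstructed, first_point in gen_fn():
    ...   # consume lazily, e.g. hand gen_fn() to save_tractogram
```
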
+
+
+def main():
+    p = _build_arg_parser()
+    args = p.parse_args()
+
+    tracts_format = detect_format(args.out_tractogram)
+    # Set sub-loggers to INFO at most, else the output becomes too verbose.
+    sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO'
+
+    # General logging (ex, scilpy: Warning)
+    logging.getLogger().setLevel(level=logging.WARNING)
+
+    # Verify output names.
+    # assert_inputs_exist(p, args.hdf5_file)
+    assert_outputs_exist(p, args, args.out_tractogram)
+
+    # Device
+    device = (torch.device('cuda') if torch.cuda.is_available() and
+              args.use_gpu else None)
+
+    # 1. Load model
+    logging.debug("Loading model.")
+    if args.model == 'finta':
+        architecture = ModelAE
+    else:
+        architecture = ModelConvNextAE
+
+    model = architecture.load_model_from_params_and_state(
+        args.experiment_path + '/best_model', log_level=sub_loggers_level
+    ).to(device)
+
+    sft = load_tractogram_with_reference(p, args, args.in_tractogram)
+
+    _autoencode_streamlines = autoencode_streamlines(
+        model, sft, device, args.batch_size, args.normalize)
+
+    # Need a nifti image to lazy-save a tractogram.
+    fake_ref = nib.Nifti1Image(np.zeros(sft.dimensions), sft.affine)
+
+    save_tractogram(_autoencode_streamlines(), tracts_format, fake_ref,
+                    len(sft.streamlines), args.out_tractogram,
+                    0, 999, False, False, args.verbose)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py
index e5846235..f377b4fa 100755
--- a/scripts_python/ae_train_model.py
+++ b/scripts_python/ae_train_model.py
@@ -21,6 +21,7 @@
 from dwi_ml.experiment_utils.prints import format_dict_to_str
 from dwi_ml.experiment_utils.timer import Timer
 from dwi_ml.io_utils import add_memory_args
+from dwi_ml.models.projects.ae_next_models import ModelConvNextAE
 from dwi_ml.models.projects.ae_models import ModelAE
 from dwi_ml.training.trainers import DWIMLAbstractTrainer
 from dwi_ml.training.utils.batch_samplers import (add_args_batch_sampler,
@@ -45,6 +46,10 @@ def prepare_arg_parser():
     add_memory_args(p, add_lazy_options=True, add_rng=True)
     add_verbose_arg(p)

+    # Additional arg for projects
+    p.add_argument('--model', type=str, choices=['finta', 'convnext'],
+                   default='convnext', help='Type of model to train.')
+
     return p


@@ -61,9 +66,19 @@ def init_from_args(args, sub_loggers_level):
     # Final model
     with Timer("\n\nPreparing model", newline=True, color='yellow'):
         # INPUTS: verifying args
-        model = ModelAE(
-            experiment_name=args.experiment_name,
-            log_level=sub_loggers_level)
+
+        if args.model == 'finta':
+            model = ModelAE(
+                experiment_name=args.experiment_name,
+                step_size=None, compress_lines=None,
+                kernel_size=3, latent_space_dims=32,
+                log_level=sub_loggers_level)
+        else:
+            model = ModelConvNextAE(
+                experiment_name=args.experiment_name,
+                step_size=None, compress_lines=None,
+                kernel_size=2, latent_space_dims=32,
+                log_level=sub_loggers_level)

     logging.info("AEmodel final parameters:" +
                  format_dict_to_str(model.params_for_checkpoint))
@@ -79,6 +94,7 @@ def init_from_args(args, sub_loggers_level):
             dataset=dataset, model=model,
             streamline_group_name=args.streamline_group_name,
             # OTHER
+            normalize=True,
             rng=args.rng,
             log_level=sub_loggers_level)
     logging.info("Loader user-defined parameters: " +
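For reference, a condensed sketch of the model selection performed in `init_from_args` above (the `build_model` helper is hypothetical; the kernel sizes and latent dimensions are the script's hard-coded values):

```python
from dwi_ml.models.projects.ae_models import ModelAE
from dwi_ml.models.projects.ae_next_models import ModelConvNextAE

def build_model(model_type: str, experiment_name: str):
    # '--model finta' selects the original AE; anything else (including
    # the default) selects the ConvNext variant.
    if model_type == 'finta':
        return ModelAE(experiment_name=experiment_name,
                       step_size=None, compress_lines=None,
                       kernel_size=3, latent_space_dims=32)
    return ModelConvNextAE(experiment_name=experiment_name,
                           step_size=None, compress_lines=None,
                           kernel_size=2, latent_space_dims=32)
```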