From 74494df24bd2c03cd59bdbc08b7f78f192641f22 Mon Sep 17 00:00:00 2001
From: Antoine Theberge
Date: Thu, 19 Sep 2024 08:43:40 -0400
Subject: [PATCH 01/14] WIP: transformer ae

---
 command_ae.sh                       |  22 ++++++
 dwi_ml/models/projects/ae_models.py | 101 ++++++++++++++++++++++------
 2 files changed, 101 insertions(+), 22 deletions(-)
 create mode 100644 command_ae.sh

diff --git a/command_ae.sh b/command_ae.sh
new file mode 100644
index 00000000..17f3e3de
--- /dev/null
+++ b/command_ae.sh
@@ -0,0 +1,22 @@
+experiments=experiments
+experiment_name=fibercup_september24
+
+rm -rf $experiments/$experiment_name
+
+ae_train_model.py $experiments \
+    $experiment_name \
+    fibercup.hdf5 \
+    target \
+    -v INFO \
+    --batch_size_training 80 \
+    --batch_size_units nb_streamlines \
+    --nb_subjects_per_batch 1 \
+    --learning_rate 0.001 \
+    --weight_decay 0.05 \
+    --optimizer Adam \
+    --max_epochs 1000 \
+    --max_batches_per_epoch_training 20 \
+    --comet_workspace dwi_ml \
+    --comet_project ae-fibercup \
+    --patience 100 \
+    --use_gpu
diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py
index 0818b6ad..9dc9f37d 100644
--- a/dwi_ml/models/projects/ae_models.py
+++ b/dwi_ml/models/projects/ae_models.py
@@ -1,13 +1,48 @@
 # -*- coding: utf-8 -*-
 import logging
+import math
+
 from typing import List
 
 import torch
+from torch import nn
+from torch import Tensor
 from torch.nn import functional as F
 
 from dwi_ml.models.main_models import MainModelAbstract
 
 
+class PositionalEncoding(nn.Module):
+    """ Modified from
+    https://pytorch.org/tutorials/beginner/transformer_tutorial.html  # noqa E501
+    """
+
+    def __init__(
+        self, d_model: int, dropout: float = 0.1, max_len: int = 5000
+    ):
+        super().__init__()
+        self.dropout = nn.Dropout(p=dropout)
+
+        position = torch.arange(max_len).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2)
+                             * (-math.log(10000.0) / d_model))
+        pe = torch.zeros(max_len, 1, d_model)
+        pe[:, 0, 0::2] = torch.sin(position * div_term)
+        pe[:, 0, 1::2] = torch.cos(position * div_term)
+        self.register_buffer('pe', pe)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        Arguments:
+            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``
+        """
+        x = x.permute(1, 0, 2)
+        x = x + self.pe[:x.size(0)]
+        x = self.dropout(x)
+        x = x.permute(1, 0, 2)
+        return x
+
+
 class ModelAE(MainModelAbstract):
     """
     Recurrent tracking model.
@@ -27,23 +62,45 @@ def __init__(self, kernel_size, latent_space_dims,
                  log_level=logging.root.level):
         super().__init__(experiment_name, step_size, compress_lines,
                          log_level)
 
-        self.kernel_size = kernel_size
-        self.latent_space_dims = latent_space_dims
+        # Embedding size; could this be defined by the user?
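+        #
+        # A minimal sanity sketch for the PositionalEncoding above (a
+        # sketch only, assuming batch-first input; the permutes make the
+        # module shape-preserving):
+        #     pe = PositionalEncoding(d_model=32, max_len=256)
+        #     x = torch.zeros(8, 256, 32)  # (batch, seq_len, embedding_dim)
+        #     assert pe(x).shape == x.shape
+        #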
+ self.embedding_size = 32 + # Embedding layer + self.embedding = nn.Sequential( + *(nn.Linear(3, self.embedding_size), + nn.ReLU())) - self.pad = torch.nn.ReflectionPad1d(1) + # Positional encoding layer + self.pos_encoding = PositionalEncoding( + self.embedding_size, max_len=(256)) + # Transformer encoder layer + layer = nn.TransformerEncoderLayer( + self.embedding_size, 4, batch_first=True) - def pre_pad(m): - return torch.nn.Sequential(self.pad, m) + # Transformer encoder + self.encoder = nn.TransformerEncoder(layer, 4) + self.decoder = nn.TransformerEncoder(layer, 4) + + self.reconstruction_loss = torch.nn.MSELoss() + + self.pad = torch.nn.ReflectionPad1d(1) + self.kernel_size = kernel_size + self.latent_space_dims = latent_space_dims self.fc1 = torch.nn.Linear(8192, self.latent_space_dims) # 8192 = 1024*8 self.fc2 = torch.nn.Linear(self.latent_space_dims, 8192) + self.fc3 = torch.nn.Linear(self.embedding_size, 3) + + def pre_pad(m): + return torch.nn.Sequential(self.pad, m) + """ Encode convolutions """ self.encod_conv1 = pre_pad( - torch.nn.Conv1d(3, 32, self.kernel_size, stride=2, padding=0) + torch.nn.Conv1d(self.embedding_size, 32, + self.kernel_size, stride=2, padding=0) ) self.encod_conv2 = pre_pad( torch.nn.Conv1d(32, 64, self.kernel_size, stride=2, padding=0) @@ -95,7 +152,8 @@ def pre_pad(m): scale_factor=2, mode="linear", align_corners=False ) self.decod_conv6 = pre_pad( - torch.nn.Conv1d(32, 3, self.kernel_size, stride=1, padding=0) + torch.nn.Conv1d(32, 32, + self.kernel_size, stride=1, padding=0) ) @property @@ -143,6 +201,11 @@ def forward(self, def encode(self, x): # x: list of tensors x = torch.stack(x) + + x = self.embedding(x) * math.sqrt(self.embedding_size) + x = self.pos_encoding(x) + x = self.encoder(x) + x = torch.swapaxes(x, 1, 2) h1 = F.relu(self.encod_conv1(x)) @@ -162,6 +225,7 @@ def encode(self, x): return fc1 def decode(self, z): + fc = self.fc2(z) fc_reshape = fc.view( -1, self.encoder_out_size[0], self.encoder_out_size[1] @@ -178,24 +242,17 @@ def decode(self, z): h10 = self.upsampl5(h9) h11 = self.decod_conv6(h10) - return h11 + h11 = h11.permute(0, 2, 1) + + h12 = self.decoder(h11) + + x = self.fc3(h12) + + return x.permute(0, 2, 1) def compute_loss(self, model_outputs, targets, average_results=True): - print("COMPARISON\n") targets = torch.stack(targets) targets = torch.swapaxes(targets, 1, 2) - print(targets[0, :, 0:5]) - print(model_outputs[0, :, 0:5]) - reconstruction_loss = torch.nn.MSELoss(reduction="sum") - mse = reconstruction_loss(model_outputs, targets) - - # loss_function_vae - # See Appendix B from VAE paper: - # Kingma and Welling. Auto-Encoding Variational Bayes. 
ICLR, 2014 - # https://arxiv.org/abs/1312.6114 - # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) - # kld = -0.5 * torch.sum(1 + self.logvar - self.mu.pow(2) - self.logvar.exp()) - # kld_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) - # kld = torch.sum(kld_element).__mul__(-0.5) + mse = self.reconstruction_loss(model_outputs, targets) return mse, 1 From dd16e895237845de0d6a656b98fad4666403ece7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antoine=20Th=C3=A9berge?= Date: Sun, 22 Sep 2024 19:55:27 -0400 Subject: [PATCH 02/14] WIP: ConvNeXt inspired arch --- command_ae.sh | 31 ++-- dwi_ml/models/projects/ae_models.py | 257 ++++++++++++---------------- 2 files changed, 129 insertions(+), 159 deletions(-) diff --git a/command_ae.sh b/command_ae.sh index 17f3e3de..df418a65 100644 --- a/command_ae.sh +++ b/command_ae.sh @@ -1,22 +1,21 @@ experiments=experiments -experiment_name=fibercup_september24 +experiment_name=mouse_september24 rm -rf $experiments/$experiment_name ae_train_model.py $experiments \ $experiment_name \ - fibercup.hdf5 \ - target \ - -v INFO \ - --batch_size_training 80 \ - --batch_size_units nb_streamlines \ - --nb_subjects_per_batch 1 \ - --learning_rate 0.001 \ - --weight_decay 0.05 \ - --optimizer Adam \ - --max_epochs 1000 \ - --max_batches_per_epoch_training 20 \ - --comet_workspace dwi_ml \ - --comet_project ae-fibercup \ - --patience 100 \ - --use_gpu + mouse.hdf5 \ + target \ + -v INFO \ + --batch_size_training 2500 \ + --batch_size_units nb_streamlines \ + --nb_subjects_per_batch 5 \ + --learning_rate 0.0005*200 0.0003*200 0.0001*200 0.00007*200 0.00005 \ + --weight_decay 0.2 \ + --optimizer Adam \ + --max_epochs 2000 \ + --max_batches_per_epoch_training 9999 \ + --comet_workspace dwi-ml \ + --comet_project ae-fibercup \ + --patience 100 --use_gpu diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py index 9dc9f37d..4a5f86fa 100644 --- a/dwi_ml/models/projects/ae_models.py +++ b/dwi_ml/models/projects/ae_models.py @@ -1,46 +1,35 @@ # -*- coding: utf-8 -*- import logging -import math - from typing import List import torch -from torch import nn -from torch import Tensor -from torch.nn import functional as F from dwi_ml.models.main_models import MainModelAbstract -class PositionalEncoding(nn.Module): - """ Modified from - https://pytorch.org/tutorials/beginner/transformer_tutorial.htm://pytorch.org/tutorials/beginner/transformer_tutorial.html # noqa E504 - """ +class ResBlock1d(torch.nn.Module): - def __init__( - self, d_model: int, dropout: float = 0.1, max_len: int = 5000 - ): - super().__init__() - self.dropout = nn.Dropout(p=dropout) + def __init__(self, channels, stride=1): + super(ResBlock1d, self).__init__() - position = torch.arange(max_len).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2) - * (-math.log(10000.0) / d_model)) - pe = torch.zeros(max_len, 1, d_model) - pe[:, 0, 0::2] = torch.sin(position * div_term) - pe[:, 0, 1::2] = torch.cos(position * div_term) - self.register_buffer('pe', pe) + self.block = torch.nn.Sequential( + torch.nn.Conv1d(channels, channels, kernel_size=1, + stride=stride, padding=0), + torch.nn.BatchNorm1d(channels), + torch.nn.GELU(), + torch.nn.Conv1d(channels, channels, kernel_size=3, + stride=stride, padding=1), + torch.nn.BatchNorm1d(channels), + torch.nn.GELU(), + torch.nn.Conv1d(channels, channels, 1, + 1, 0), + torch.nn.BatchNorm1d(channels)) - def forward(self, x: Tensor) -> Tensor: - """ - Arguments: - x: Tensor, shape ``[seq_len, batch_size, embedding_dim]`` - """ 
- x = x.permute(1, 0, 2) - x = x + self.pe[:x.size(0)] - x = self.dropout(x) - x = x.permute(1, 0, 2) - return x + def forward(self, x): + identity = x + xp = self.block(x) + + return xp + identity class ModelAE(MainModelAbstract): @@ -53,6 +42,7 @@ class ModelAE(MainModelAbstract): deterministic (3D vectors) or probabilistic (based on probability distribution parameters). """ + def __init__(self, kernel_size, latent_space_dims, experiment_name: str, # Target preprocessing params for the batch loader + tracker @@ -62,98 +52,100 @@ def __init__(self, kernel_size, latent_space_dims, log_level=logging.root.level): super().__init__(experiment_name, step_size, compress_lines, log_level) - # Embedding size, could be defined by the user ? - self.embedding_size = 32 - # Embedding layer - self.embedding = nn.Sequential( - *(nn.Linear(3, self.embedding_size), - nn.ReLU())) - - # Positional encoding layer - self.pos_encoding = PositionalEncoding( - self.embedding_size, max_len=(256)) - # Transformer encoder layer - layer = nn.TransformerEncoderLayer( - self.embedding_size, 4, batch_first=True) - - # Transformer encoder - self.encoder = nn.TransformerEncoder(layer, 4) - self.decoder = nn.TransformerEncoder(layer, 4) - - self.reconstruction_loss = torch.nn.MSELoss() - - self.pad = torch.nn.ReflectionPad1d(1) self.kernel_size = kernel_size self.latent_space_dims = latent_space_dims + self.reconstruction_loss = torch.nn.MSELoss(reduction="sum") self.fc1 = torch.nn.Linear(8192, self.latent_space_dims) # 8192 = 1024*8 self.fc2 = torch.nn.Linear(self.latent_space_dims, 8192) - self.fc3 = torch.nn.Linear(self.embedding_size, 3) - - def pre_pad(m): - return torch.nn.Sequential(self.pad, m) - """ Encode convolutions """ - self.encod_conv1 = pre_pad( - torch.nn.Conv1d(self.embedding_size, 32, - self.kernel_size, stride=2, padding=0) - ) - self.encod_conv2 = pre_pad( - torch.nn.Conv1d(32, 64, self.kernel_size, stride=2, padding=0) - ) - self.encod_conv3 = pre_pad( - torch.nn.Conv1d(64, 128, self.kernel_size, stride=2, padding=0) - ) - self.encod_conv4 = pre_pad( - torch.nn.Conv1d(128, 256, self.kernel_size, stride=2, padding=0) - ) - self.encod_conv5 = pre_pad( - torch.nn.Conv1d(256, 512, self.kernel_size, stride=2, padding=0) - ) - self.encod_conv6 = pre_pad( - torch.nn.Conv1d(512, 1024, self.kernel_size, stride=1, padding=0) + self.encoder = torch.nn.Sequential( + torch.nn.Conv1d(3, 64, self.kernel_size, stride=2, padding=1), + torch.nn.GELU(), + ResBlock1d(64), + ResBlock1d(64), + ResBlock1d(64), + torch.nn.Conv1d(64, 64, self.kernel_size, stride=2, padding=1), + torch.nn.GELU(), + ResBlock1d(64), + ResBlock1d(64), + ResBlock1d(64), + torch.nn.Conv1d(64, 256, self.kernel_size, stride=2, padding=1), + torch.nn.GELU(), + ResBlock1d(256), + ResBlock1d(256), + ResBlock1d(256), + torch.nn.Conv1d(256, 256, self.kernel_size, stride=2, padding=1), + torch.nn.GELU(), + ResBlock1d(256), + ResBlock1d(256), + ResBlock1d(256), + torch.nn.Conv1d(256, 1024, self.kernel_size, stride=2, padding=1), + torch.nn.GELU(), + ResBlock1d(1024), + ResBlock1d(1024), + ResBlock1d(1024), + torch.nn.Conv1d(1024, 1024, self.kernel_size, stride=1, padding=1), + torch.nn.GELU(), + ResBlock1d(1024), + ResBlock1d(1024), + ResBlock1d(1024), ) """ Decode convolutions """ - self.decod_conv1 = pre_pad( - torch.nn.Conv1d(1024, 512, self.kernel_size, stride=1, padding=0) - ) - self.upsampl1 = torch.nn.Upsample( - scale_factor=2, mode="linear", align_corners=False - ) - self.decod_conv2 = pre_pad( - torch.nn.Conv1d(512, 256, self.kernel_size, 
stride=1, padding=0) - ) - self.upsampl2 = torch.nn.Upsample( - scale_factor=2, mode="linear", align_corners=False - ) - self.decod_conv3 = pre_pad( - torch.nn.Conv1d(256, 128, self.kernel_size, stride=1, padding=0) - ) - self.upsampl3 = torch.nn.Upsample( - scale_factor=2, mode="linear", align_corners=False - ) - self.decod_conv4 = pre_pad( - torch.nn.Conv1d(128, 64, self.kernel_size, stride=1, padding=0) - ) - self.upsampl4 = torch.nn.Upsample( - scale_factor=2, mode="linear", align_corners=False - ) - self.decod_conv5 = pre_pad( - torch.nn.Conv1d(64, 32, self.kernel_size, stride=1, padding=0) - ) - self.upsampl5 = torch.nn.Upsample( - scale_factor=2, mode="linear", align_corners=False - ) - self.decod_conv6 = pre_pad( - torch.nn.Conv1d(32, 32, - self.kernel_size, stride=1, padding=0) + self.decoder = torch.nn.Sequential( + ResBlock1d(1024), + ResBlock1d(1024), + ResBlock1d(1024), + torch.nn.GELU(), + torch.nn.Conv1d( + 1024, 1024, self.kernel_size, stride=1, padding=1), + torch.nn.GELU(), + ResBlock1d(1024), + ResBlock1d(1024), + ResBlock1d(1024), + torch.nn.Upsample(scale_factor=2, mode="linear", + align_corners=False), + torch.nn.Conv1d( + 1024, 256, self.kernel_size, stride=1, padding=1), + torch.nn.GELU(), + ResBlock1d(256), + ResBlock1d(256), + ResBlock1d(256), + torch.nn.Upsample(scale_factor=2, mode="linear", + align_corners=False), + torch.nn.Conv1d( + 256, 256, self.kernel_size, stride=1, padding=1), + torch.nn.GELU(), + ResBlock1d(256), + ResBlock1d(256), + ResBlock1d(256), + torch.nn.Upsample(scale_factor=2, mode="linear", + align_corners=False), + torch.nn.Conv1d( + 256, 64, self.kernel_size, stride=1, padding=1), + torch.nn.GELU(), + ResBlock1d(64), + ResBlock1d(64), + ResBlock1d(64), + torch.nn.Upsample(scale_factor=2, mode="linear", + align_corners=False), + torch.nn.Conv1d( + 64, 64, self.kernel_size, stride=1, padding=1), + torch.nn.GELU(), + ResBlock1d(64), + ResBlock1d(64), + ResBlock1d(64), + torch.nn.Upsample(scale_factor=2, mode="linear", + align_corners=False), + torch.nn.Conv1d( + 64, 3, self.kernel_size, stride=1, padding=1), ) @property @@ -201,58 +193,37 @@ def forward(self, def encode(self, x): # x: list of tensors x = torch.stack(x) - - x = self.embedding(x) * math.sqrt(self.embedding_size) - x = self.pos_encoding(x) - x = self.encoder(x) - x = torch.swapaxes(x, 1, 2) - h1 = F.relu(self.encod_conv1(x)) - h2 = F.relu(self.encod_conv2(h1)) - h3 = F.relu(self.encod_conv3(h2)) - h4 = F.relu(self.encod_conv4(h3)) - h5 = F.relu(self.encod_conv5(h4)) - h6 = self.encod_conv6(h5) - - self.encoder_out_size = (h6.shape[1], h6.shape[2]) + x = self.encoder(x) + self.encoder_out_size = (x.shape[1], x.shape[2]) # Flatten - h7 = h6.view(-1, self.encoder_out_size[0] * self.encoder_out_size[1]) + h7 = x.view(-1, self.encoder_out_size[0] * self.encoder_out_size[1]) fc1 = self.fc1(h7) - return fc1 def decode(self, z): - fc = self.fc2(z) fc_reshape = fc.view( -1, self.encoder_out_size[0], self.encoder_out_size[1] ) - h1 = F.relu(self.decod_conv1(fc_reshape)) - h2 = self.upsampl1(h1) - h3 = F.relu(self.decod_conv2(h2)) - h4 = self.upsampl2(h3) - h5 = F.relu(self.decod_conv3(h4)) - h6 = self.upsampl3(h5) - h7 = F.relu(self.decod_conv4(h6)) - h8 = self.upsampl4(h7) - h9 = F.relu(self.decod_conv5(h8)) - h10 = self.upsampl5(h9) - h11 = self.decod_conv6(h10) - - h11 = h11.permute(0, 2, 1) - - h12 = self.decoder(h11) - - x = self.fc3(h12) - - return x.permute(0, 2, 1) + z = self.decoder(fc_reshape) + return z def compute_loss(self, model_outputs, targets, average_results=True): targets 
= torch.stack(targets) targets = torch.swapaxes(targets, 1, 2) mse = self.reconstruction_loss(model_outputs, targets) + # loss_function_vae + # See Appendix B from VAE paper: + # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 + # https://arxiv.org/abs/1312.6114 + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + # kld = -0.5 * torch.sum(1 + self.logvar - self.mu.pow(2) - self.logvar.exp()) + # kld_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) + # kld = torch.sum(kld_element).__mul__(-0.5) + return mse, 1 From 7b12e6054002106a1e7810902cb3339266c1be11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antoine=20Th=C3=A9berge?= Date: Wed, 25 Sep 2024 11:23:28 -0400 Subject: [PATCH 03/14] ENH: ConvNext-Tiny + latent --- command_ae.sh | 10 +- .../streamlines/data_augmentation.py | 25 ++ dwi_ml/models/projects/ae_models.py | 75 ++--- dwi_ml/models/projects/ae_next_models.py | 258 ++++++++++++++++++ dwi_ml/training/batch_loaders.py | 10 +- dwi_ml/viz/latent_streamlines.py | 183 +++++++++++++ requirements.txt | 2 + scripts_python/ae_train_model.py | 5 +- scripts_python/ae_visualize_bundles.py | 107 ++++++++ scripts_python/ae_visualize_streamlines.py | 50 ++-- 10 files changed, 667 insertions(+), 58 deletions(-) create mode 100644 dwi_ml/models/projects/ae_next_models.py create mode 100644 dwi_ml/viz/latent_streamlines.py create mode 100644 scripts_python/ae_visualize_bundles.py diff --git a/command_ae.sh b/command_ae.sh index df418a65..841e7ddf 100644 --- a/command_ae.sh +++ b/command_ae.sh @@ -1,20 +1,20 @@ experiments=experiments -experiment_name=mouse_september24 +experiment_name=fibercup_september24 rm -rf $experiments/$experiment_name ae_train_model.py $experiments \ $experiment_name \ - mouse.hdf5 \ + fibercup_tracking.hdf5 \ target \ -v INFO \ - --batch_size_training 2500 \ + --batch_size_training 1800 \ --batch_size_units nb_streamlines \ --nb_subjects_per_batch 5 \ - --learning_rate 0.0005*200 0.0003*200 0.0001*200 0.00007*200 0.00005 \ + --learning_rate 0.00001*300 0.000005 \ --weight_decay 0.2 \ --optimizer Adam \ - --max_epochs 2000 \ + --max_epochs 5000 \ --max_batches_per_epoch_training 9999 \ --comet_workspace dwi-ml \ --comet_project ae-fibercup \ diff --git a/dwi_ml/data/processing/streamlines/data_augmentation.py b/dwi_ml/data/processing/streamlines/data_augmentation.py index 48683def..9ac47505 100644 --- a/dwi_ml/data/processing/streamlines/data_augmentation.py +++ b/dwi_ml/data/processing/streamlines/data_augmentation.py @@ -138,3 +138,28 @@ def reverse_streamlines(sft: StatefulTractogram, data_per_streamline=sft.data_per_streamline) return new_sft + + +def normalize_streamlines(sft: StatefulTractogram): + """ Normalize streamlines so that their points are in [0, 1]. + + Parameters + ---------- + sft: StatefulTractogram + Dipy object containing your streamlines + + Returns + ------- + new_sft: StatefulTractogram + Dipy object with reversed streamlines and data_per_point. 
+ """ + + dims = sft.dimensions + + new_streamlines = [s / dims for s in sft.streamlines] + + new_sft = StatefulTractogram.from_sft( + new_streamlines, sft, data_per_point=sft.data_per_point, + data_per_streamline=sft.data_per_streamline) + + return new_sft diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py index 4a5f86fa..2e8b7034 100644 --- a/dwi_ml/models/projects/ae_models.py +++ b/dwi_ml/models/projects/ae_models.py @@ -18,7 +18,7 @@ def __init__(self, channels, stride=1): torch.nn.BatchNorm1d(channels), torch.nn.GELU(), torch.nn.Conv1d(channels, channels, kernel_size=3, - stride=stride, padding=1), + stride=stride, padding=1, padding_mode='reflect'), torch.nn.BatchNorm1d(channels), torch.nn.GELU(), torch.nn.Conv1d(channels, channels, 1, @@ -64,32 +64,38 @@ def __init__(self, kernel_size, latent_space_dims, Encode convolutions """ self.encoder = torch.nn.Sequential( - torch.nn.Conv1d(3, 64, self.kernel_size, stride=2, padding=1), + torch.nn.Conv1d(3, 32, self.kernel_size, stride=2, padding=1, + padding_mode='reflect'), torch.nn.GELU(), - ResBlock1d(64), - ResBlock1d(64), - ResBlock1d(64), - torch.nn.Conv1d(64, 64, self.kernel_size, stride=2, padding=1), + ResBlock1d(32), + ResBlock1d(32), + ResBlock1d(32), + torch.nn.Conv1d(32, 64, self.kernel_size, stride=2, padding=1, + padding_mode='reflect'), torch.nn.GELU(), ResBlock1d(64), ResBlock1d(64), ResBlock1d(64), - torch.nn.Conv1d(64, 256, self.kernel_size, stride=2, padding=1), + torch.nn.Conv1d(64, 128, self.kernel_size, stride=2, padding=1, + padding_mode='reflect'), torch.nn.GELU(), - ResBlock1d(256), - ResBlock1d(256), - ResBlock1d(256), - torch.nn.Conv1d(256, 256, self.kernel_size, stride=2, padding=1), + ResBlock1d(128), + ResBlock1d(128), + ResBlock1d(128), + torch.nn.Conv1d(128, 256, self.kernel_size, stride=2, padding=1, + padding_mode='reflect'), torch.nn.GELU(), ResBlock1d(256), ResBlock1d(256), ResBlock1d(256), - torch.nn.Conv1d(256, 1024, self.kernel_size, stride=2, padding=1), + torch.nn.Conv1d(256, 512, self.kernel_size, stride=2, padding=1, + padding_mode='reflect'), torch.nn.GELU(), - ResBlock1d(1024), - ResBlock1d(1024), - ResBlock1d(1024), - torch.nn.Conv1d(1024, 1024, self.kernel_size, stride=1, padding=1), + ResBlock1d(512), + ResBlock1d(512), + ResBlock1d(512), + torch.nn.Conv1d(512, 1024, self.kernel_size, stride=1, padding=1, + padding_mode='reflect'), torch.nn.GELU(), ResBlock1d(1024), ResBlock1d(1024), @@ -105,15 +111,17 @@ def __init__(self, kernel_size, latent_space_dims, ResBlock1d(1024), torch.nn.GELU(), torch.nn.Conv1d( - 1024, 1024, self.kernel_size, stride=1, padding=1), + 1024, 512, self.kernel_size, stride=1, padding=1, + padding_mode='reflect'), torch.nn.GELU(), - ResBlock1d(1024), - ResBlock1d(1024), - ResBlock1d(1024), + ResBlock1d(512), + ResBlock1d(512), + ResBlock1d(512), torch.nn.Upsample(scale_factor=2, mode="linear", align_corners=False), torch.nn.Conv1d( - 1024, 256, self.kernel_size, stride=1, padding=1), + 512, 256, self.kernel_size, stride=1, padding=1, + padding_mode='reflect'), torch.nn.GELU(), ResBlock1d(256), ResBlock1d(256), @@ -121,15 +129,17 @@ def __init__(self, kernel_size, latent_space_dims, torch.nn.Upsample(scale_factor=2, mode="linear", align_corners=False), torch.nn.Conv1d( - 256, 256, self.kernel_size, stride=1, padding=1), + 256, 128, self.kernel_size, stride=1, padding=1, + padding_mode='reflect'), torch.nn.GELU(), - ResBlock1d(256), - ResBlock1d(256), - ResBlock1d(256), + ResBlock1d(128), + ResBlock1d(128), + ResBlock1d(128), 
torch.nn.Upsample(scale_factor=2, mode="linear", align_corners=False), torch.nn.Conv1d( - 256, 64, self.kernel_size, stride=1, padding=1), + 128, 64, self.kernel_size, stride=1, padding=1, + padding_mode='reflect'), torch.nn.GELU(), ResBlock1d(64), ResBlock1d(64), @@ -137,15 +147,18 @@ def __init__(self, kernel_size, latent_space_dims, torch.nn.Upsample(scale_factor=2, mode="linear", align_corners=False), torch.nn.Conv1d( - 64, 64, self.kernel_size, stride=1, padding=1), + 64, 32, self.kernel_size, stride=1, padding=1, + padding_mode='reflect'), torch.nn.GELU(), - ResBlock1d(64), - ResBlock1d(64), - ResBlock1d(64), + ResBlock1d(32), + ResBlock1d(32), + ResBlock1d(32), torch.nn.Upsample(scale_factor=2, mode="linear", align_corners=False), torch.nn.Conv1d( - 64, 3, self.kernel_size, stride=1, padding=1), + 32, 3, self.kernel_size, stride=1, padding=1, + padding_mode='reflect'), + torch.nn.GELU(), ) @property diff --git a/dwi_ml/models/projects/ae_next_models.py b/dwi_ml/models/projects/ae_next_models.py new file mode 100644 index 00000000..f68def5d --- /dev/null +++ b/dwi_ml/models/projects/ae_next_models.py @@ -0,0 +1,258 @@ +# -*- coding: utf-8 -*- +import logging +from typing import List + +import torch + +from dwi_ml.models.main_models import MainModelAbstract + +class Permute(torch.nn.Module): + """This module returns a view of the tensor input with its dimensions permuted. + From https://github.com/pytorch/vision/blob/main/torchvision/ops/misc.py#L308 # noqa E501 + + Args: + dims (List[int]): The desired ordering of dimensions + """ + + def __init__(self, dims: List[int]): + super().__init__() + self.dims = dims + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.permute(x, self.dims) + + +class LayerNorm1d(torch.nn.LayerNorm): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.permute(0, 2, 1) + x = torch.nn.functional.layer_norm( + x, self.normalized_shape, self.weight, self.bias, self.eps) + x = x.permute(0, 2, 1) + return x + + +class ResBlock1d(torch.nn.Module): + + def __init__(self, channels, stride=1, norm=LayerNorm1d): + super(ResBlock1d, self).__init__() + + self.block = torch.nn.Sequential( + torch.nn.Conv1d(channels, channels, kernel_size=7, groups=channels, + stride=stride, padding=3, padding_mode='reflect'), + norm(channels), + Permute((0, 2, 1)), + torch.nn.Linear( + in_features=channels, out_features=4 * channels, bias=True), + torch.nn.GELU(), + torch.nn.Linear( + in_features=channels * 4, out_features=channels, bias=True), + Permute((0, 2, 1))) + + def forward(self, x): + identity = x + x = self.block(x) + + return x + identity + + +class ModelConvNextAE(MainModelAbstract): + """ + """ + + def __init__(self, kernel_size, latent_space_dims, + experiment_name: str, + # Target preprocessing params for the batch loader + tracker + step_size: float = None, + compress_lines: float = False, + # Other + log_level=logging.root.level): + super().__init__(experiment_name, step_size, compress_lines, log_level) + + self.kernel_size = kernel_size + self.latent_space_dims = latent_space_dims + self.reconstruction_loss = torch.nn.MSELoss(reduction="sum") + + self.fc1 = torch.nn.Linear(8192, + self.latent_space_dims) # 8192 = 1024*8 + self.fc2 = torch.nn.Linear(self.latent_space_dims, 8192) + + """ + Encode convolutions + """ + self.encoder = torch.nn.Sequential( + torch.nn.Conv1d(3, 32, self.kernel_size, stride=2, padding=1, + padding_mode='reflect'), + ResBlock1d(32), + ResBlock1d(32), + ResBlock1d(32), + torch.nn.Conv1d(32, 64, self.kernel_size, 
stride=2, padding=1, + padding_mode='reflect'), + ResBlock1d(64), + ResBlock1d(64), + ResBlock1d(64), + torch.nn.Conv1d(64, 128, self.kernel_size, stride=2, padding=1, + padding_mode='reflect'), + ResBlock1d(128), + ResBlock1d(128), + ResBlock1d(128), + torch.nn.Conv1d(128, 256, self.kernel_size, stride=2, padding=1, + padding_mode='reflect'), + ResBlock1d(256), + ResBlock1d(256), + ResBlock1d(256), + torch.nn.Conv1d(256, 512, self.kernel_size, stride=2, padding=1, + padding_mode='reflect'), + ResBlock1d(512), + ResBlock1d(512), + ResBlock1d(512), + ResBlock1d(512), + ResBlock1d(512), + ResBlock1d(512), + ResBlock1d(512), + ResBlock1d(512), + ResBlock1d(512), + torch.nn.Conv1d(512, 1024, self.kernel_size, stride=1, padding=1, + padding_mode='reflect'), + ResBlock1d(1024), + ResBlock1d(1024), + ResBlock1d(1024), + ) + + """ + Decode convolutions + """ + self.decoder = torch.nn.Sequential( + ResBlock1d(1024), + ResBlock1d(1024), + ResBlock1d(1024), + torch.nn.Conv1d( + 1024, 512, self.kernel_size, stride=1, padding=1, + padding_mode='reflect'), + ResBlock1d(512), + ResBlock1d(512), + ResBlock1d(512), + ResBlock1d(512), + ResBlock1d(512), + ResBlock1d(512), + ResBlock1d(512), + ResBlock1d(512), + ResBlock1d(512), + torch.nn.Upsample(scale_factor=2, mode="linear", + align_corners=False), + torch.nn.Conv1d( + 512, 256, self.kernel_size, stride=1, padding=1, + padding_mode='reflect'), + ResBlock1d(256), + ResBlock1d(256), + ResBlock1d(256), + torch.nn.Upsample(scale_factor=2, mode="linear", + align_corners=False), + torch.nn.Conv1d( + 256, 128, self.kernel_size, stride=1, padding=1, + padding_mode='reflect'), + ResBlock1d(128), + ResBlock1d(128), + ResBlock1d(128), + torch.nn.Upsample(scale_factor=2, mode="linear", + align_corners=False), + torch.nn.Conv1d( + 128, 64, self.kernel_size, stride=1, padding=1, + padding_mode='reflect'), + ResBlock1d(64), + ResBlock1d(64), + ResBlock1d(64), + torch.nn.Upsample(scale_factor=2, mode="linear", + align_corners=False), + torch.nn.Conv1d( + 64, 32, self.kernel_size, stride=1, padding=1, + padding_mode='reflect'), + ResBlock1d(32), + ResBlock1d(32), + ResBlock1d(32), + torch.nn.Upsample(scale_factor=2, mode="linear", + align_corners=False), + torch.nn.Conv1d( + 32, 3, self.kernel_size, stride=1, padding=1, + padding_mode='reflect'), + ) + + @property + def params_for_checkpoint(self): + """All parameters necessary to create again the same model. Will be + used in the trainer, when saving the checkpoint state. Params here + will be used to re-create the model when starting an experiment from + checkpoint. You should be able to re-create an instance of your + model with those params.""" + # p = super().params_for_checkpoint() + p = {'kernel_size': self.kernel_size, + 'latent_space_dims': self.latent_space_dims, + 'experiment_name': self.experiment_name} + return p + + @classmethod + def _load_params(cls, model_dir): + p = super()._load_params(model_dir) + p['kernel_size'] = 3 + p['latent_space_dims'] = 32 + return p + + def forward(self, + input_streamlines: List[torch.tensor], + ): + """Run the model on a batch of sequences. + + Parameters + ---------- + input_streamlines: List[torch.tensor], + Batch of streamlines. Only used if previous directions are added to + the model. Used to compute directions; its last point will not be + used. + + Returns + ------- + model_outputs : List[Tensor] + Output data, ready to be passed to either `compute_loss()` or + `get_tracking_directions()`. 
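+
+        Example
+        -------
+        A rough shape sketch (assumes 256-point resampled streamlines and
+        the defaults kernel_size=3, latent_space_dims=32; names here are
+        illustrative only):
+            model = ModelConvNextAE(3, 32, 'example')
+            lines = [torch.zeros(256, 3) for _ in range(4)]
+            out = model(lines)  # (4, 3, 256): channels-first reconstruction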
+ """ + + x = self.decode(self.encode(input_streamlines)) + return x + + def encode(self, x): + # x: list of tensors + if isinstance(x, list): + x = torch.stack(x) + x = torch.swapaxes(x, 1, 2) + + x = self.encoder(x) + self.encoder_out_size = (x.shape[1], x.shape[2]) + + # Flatten + h7 = x.reshape(-1, self.encoder_out_size[0] * self.encoder_out_size[1]) + + fc1 = self.fc1(h7) + return fc1 + + def decode(self, z): + fc = self.fc2(z) + fc_reshape = fc.view( + -1, self.encoder_out_size[0], self.encoder_out_size[1] + ) + z = self.decoder(fc_reshape) + return z + + def compute_loss(self, model_outputs, targets, average_results=True): + targets = torch.stack(targets) + targets = torch.swapaxes(targets, 1, 2) + mse = self.reconstruction_loss(model_outputs, targets) + + # loss_function_vae + # See Appendix B from VAE paper: + # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 + # https://arxiv.org/abs/1312.6114 + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + # kld = -0.5 * torch.sum(1 + self.logvar - self.mu.pow(2) - self.logvar.exp()) + # kld_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) + # kld = torch.sum(kld_element).__mul__(-0.5) + + return mse, 1 diff --git a/dwi_ml/training/batch_loaders.py b/dwi_ml/training/batch_loaders.py index 61f13640..a98a0d1d 100644 --- a/dwi_ml/training/batch_loaders.py +++ b/dwi_ml/training/batch_loaders.py @@ -50,6 +50,7 @@ from dwi_ml.data.dataset.multi_subject_containers import MultiSubjectDataset from dwi_ml.data.processing.streamlines.data_augmentation import ( + normalize_streamlines, reverse_streamlines, split_streamlines, resample_or_compress) from dwi_ml.data.processing.utils import add_noise_to_tensor from dwi_ml.models.main_models import MainModelOneInput, \ @@ -64,7 +65,9 @@ def __init__(self, dataset: MultiSubjectDataset, model: MainModelAbstract, split_ratio: float = 0., noise_gaussian_size_forward: float = 0., noise_gaussian_size_loss: float = 0., - reverse_ratio: float = 0., log_level=logging.root.level): + reverse_ratio: float = 0., + normalize: bool = False, + log_level=logging.root.level): """ Parameters ---------- @@ -131,6 +134,7 @@ def __init__(self, dataset: MultiSubjectDataset, model: MainModelAbstract, self.noise_gaussian_size_forward = noise_gaussian_size_forward self.noise_gaussian_size_loss = noise_gaussian_size_loss self.split_ratio = split_ratio + self.normalize = normalize self.reverse_ratio = reverse_ratio if self.split_ratio and not 0 <= self.split_ratio <= 1: raise ValueError('Split ratio must be a float between 0 and 1.') @@ -157,6 +161,7 @@ def params_for_checkpoint(self): 'noise_gaussian_size_forward': self.noise_gaussian_size_forward, 'noise_gaussian_size_loss': self.noise_gaussian_size_loss, 'reverse_ratio': self.reverse_ratio, + 'normalize': self.normalize, 'split_ratio': self.split_ratio, } return params @@ -225,6 +230,9 @@ def _data_augmentation_sft(self, sft): reverse_ids = ids[:int(len(ids) * self.reverse_ratio)] sft = reverse_streamlines(sft, reverse_ids) + if self.normalize: + sft = normalize_streamlines(sft) + return sft def add_noise_streamlines_forward(self, batch_streamlines, device): diff --git a/dwi_ml/viz/latent_streamlines.py b/dwi_ml/viz/latent_streamlines.py new file mode 100644 index 00000000..297bced6 --- /dev/null +++ b/dwi_ml/viz/latent_streamlines.py @@ -0,0 +1,183 @@ +import logging + +from typing import Union, List, Tuple +from sklearn.manifold import TSNE +import numpy as np +import torch + +import matplotlib.pyplot as plt + +def plot_latent_streamlines( + 
encoded_streamlines: Union[np.ndarray, torch.Tensor],
+        save_path: str = None,
+        fig_size: Union[List, Tuple] = None,
+        random_state: int = 42,
+        max_subset_size: int = None
+):
+    """
+    Projects and plots the latent space representation
+    of the streamlines using t-SNE dimensionality reduction.
+
+    Parameters
+    ----------
+    encoded_streamlines: Union[np.ndarray, torch.Tensor]
+        Latent space streamlines to plot, of shape (N, latent_space_dim).
+    save_path: str
+        Path to save the figure. If not specified, the figure will be shown.
+    fig_size: List[int] or Tuple[int]
+        2-valued figure size (x, y).
+    random_state: int
+        Random state for t-SNE.
+    max_subset_size: int
+        In case of performance issues, limits the number of streamlines
+        to plot.
+    """
+
+    if isinstance(encoded_streamlines, torch.Tensor):
+        latent_space_streamlines = encoded_streamlines.cpu().numpy()
+    else:
+        latent_space_streamlines = encoded_streamlines
+
+    if max_subset_size is not None:
+        if not (max_subset_size > 0):
+            raise ValueError(
+                "max_subset_size must be an integer greater than 0.")
+
+        # Only sample if we need to reduce the number of latent streamlines
+        # to show on the plot.
+        if (len(latent_space_streamlines) > max_subset_size):
+            sample_indices = np.random.choice(
+                len(latent_space_streamlines), max_subset_size,
+                replace=False)
+            latent_space_streamlines = latent_space_streamlines[
+                sample_indices]  # (max_subset_size, latent_space_dim)
+
+    # Project the data into 2 dimensions.
+    tsne = TSNE(n_components=2, random_state=random_state)
+    X_tsne = tsne.fit_transform(latent_space_streamlines)  # Output (N, 2)
+
+    logging.info("New figure for t-SNE visualisation.")
+    fig, ax = plt.subplots()
+    if fig_size is not None:
+        fig.set_figheight(fig_size[0])
+        fig.set_figwidth(fig_size[1])
+
+    ax.scatter(X_tsne[:, 0], X_tsne[:, 1], alpha=0.9, edgecolors='black',
+               linewidths=0.5)
+
+    if save_path is not None:
+        fig.savefig(save_path)
+    else:
+        plt.show()
+
+
+class BundlesLatentSpaceVisualizer(object):
+    """
+    Utility class that wraps a t-SNE projection of the latent space for
+    multiple bundles. The usage of this class is intended as follows:
+        1. Create an instance of this class,
+        2. Add the latent space streamlines for each bundle using
+           "add_data_to_plot" with its corresponding label.
+        3. Fit and plot the t-SNE projection using the "plot" method.
+
+    The t-SNE projection can only leverage fit_transform() with all the
+    data that needs to be projected at the same time, since it aims to
+    preserve the local structure of the data.
+    """
+    def __init__(self,
+                 save_path: str = None,
+                 fig_size: Union[List, Tuple] = None,
+                 random_state: int = 42,
+                 max_subset_size: int = None
+                 ):
+        """
+        Parameters
+        ----------
+        save_path: str
+            Path to save the figure. If not specified, the figure will be
+            shown.
+        fig_size: List[int] or Tuple[int]
+            2-valued figure size (x, y).
+        random_state: int
+            Random state for t-SNE.
+        max_subset_size: int
+            In case of performance issues, limits the number of streamlines
+            to plot for each bundle.
+        """
+        self.save_path = save_path
+        self.fig_size = fig_size
+        self.random_state = random_state
+        self.max_subset_size = max_subset_size
+
+        self.tsne = TSNE(n_components=2, random_state=self.random_state)
+        self.bundles = {}
+
+    def add_data_to_plot(self, data: np.ndarray, label: str = '_'):
+        """
+        Add unprojected data (no t-SNE, no PCA, etc.).
+        This should be directly the output of the model as a numpy array.
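+
+        A typical flow (a sketch; `latents_af` / `latents_cst` stand for
+        (N, latent_space_dim) arrays returned by the model's encode()):
+            viz = BundlesLatentSpaceVisualizer(save_path='tsne.png')
+            viz.add_data_to_plot(latents_af, label='AF')
+            viz.add_data_to_plot(latents_cst, label='CST')
+            viz.plot()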
+ + Parameters + ---------- + data: str + Unprojected latent space streamlines (N, latent_space_dim). + label: str + Name of the bundle. Used for the legend. + """ + if isinstance(data, torch.Tensor): + latent_space_streamlines = data.cpu().numpy() + else: + latent_space_streamlines = data + + if self.max_subset_size is not None: + if not (self.max_subset_size > 0): + raise ValueError("A max_subset_size of an integer value greater than 0 is required.") + + # Only sample if we need to reduce the number of latent streamlines + # to show on the plot. + if (len(latent_space_streamlines) > self.max_subset_size): + sample_indices = np.random.choice(len(latent_space_streamlines), self.max_subset_size, replace=False) + latent_space_streamlines = latent_space_streamlines[sample_indices] # (max_subset_size, 2) + + self.bundles[label] = latent_space_streamlines + + def plot(self): + """ + Fit and plot the t-SNE projection of the latent space streamlines. + This should be called once after adding all the data to plot using "add_data_to_plot". + """ + nb_streamlines = sum(b.shape[0] for b in self.bundles.values()) + logging.info("Plotting a total of {} streamlines".format(nb_streamlines)) + + bundles_indices = {} + current_start = 0 + for (bname, bdata) in self.bundles.items(): + bundles_indices[bname] = np.arange(current_start, current_start + bdata.shape[0]) + current_start += bdata.shape[0] + + assert current_start == nb_streamlines + + all_streamlines = np.concatenate(list(self.bundles.values()), axis=0) + + logging.info("Fitting TSNE projection.") + all_projected_streamlines = self.tsne.fit_transform(all_streamlines) + + logging.info("New figure for t-SNE visualisation.") + fig, ax = plt.subplots() + if self.fig_size is not None: + fig.set_figheight(self.fig_size[0]) + fig.set_figwidth(self.fig_size[1]) + + for (bname, bdata) in self.bundles.items(): + bindices = bundles_indices[bname] + proj_data = all_projected_streamlines[bindices] + ax.scatter( + proj_data[:, 0], + proj_data[:, 1], + label=bname, + alpha=0.9, + edgecolors='black', + linewidths=0.5, + ) + + ax.legend() + + if self.save_path is not None: + fig.savefig(self.save_path) + else: + plt.show() + diff --git a/requirements.txt b/requirements.txt index 7739110b..c7ac2ff2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,8 @@ requests==2.28.* dipy==1.9.* scilpy==2.0.2 +packaging==23.2.* +platformdirs<4,>=3.1.1 # ------- # Other important dependencies diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py index 20d416b1..80685d0a 100755 --- a/scripts_python/ae_train_model.py +++ b/scripts_python/ae_train_model.py @@ -20,7 +20,7 @@ from dwi_ml.experiment_utils.prints import format_dict_to_str from dwi_ml.experiment_utils.timer import Timer from dwi_ml.io_utils import add_memory_args -from dwi_ml.models.projects.ae_models import ModelAE +from dwi_ml.models.projects.ae_next_models import ModelConvNextAE from dwi_ml.training.trainers import DWIMLAbstractTrainer from dwi_ml.training.utils.batch_samplers import (add_args_batch_sampler, prepare_batch_sampler) @@ -68,7 +68,7 @@ def init_from_args(args, sub_loggers_level): # Final model with Timer("\n\nPreparing model", newline=True, color='yellow'): # INPUTS: verifying args - model = ModelAE( + model = ModelConvNextAE( experiment_name=args.experiment_name, step_size=None, compress_lines=None, kernel_size=3, latent_space_dims=32, @@ -88,6 +88,7 @@ def init_from_args(args, sub_loggers_level): dataset=dataset, model=model, 
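        # (Sketch of intent: the normalize=True flag added below rescales
        # each streamline into [0, 1], i.e. s / sft.dimensions, via
        # normalize_streamlines in the batch loader.)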
streamline_group_name=args.streamline_group_name, # OTHER + normalize=True, rng=args.rng, log_level=sub_loggers_level) logging.info("Loader user-defined parameters: " + diff --git a/scripts_python/ae_visualize_bundles.py b/scripts_python/ae_visualize_bundles.py new file mode 100644 index 00000000..fc225c17 --- /dev/null +++ b/scripts_python/ae_visualize_bundles.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import argparse +import logging +import pathlib +import torch +import numpy as np +from glob import glob +from os.path import expanduser +from dipy.tracking.streamline import set_number_of_points + +from scilpy.io.utils import (add_overwrite_arg, + assert_outputs_exist, + add_reference_arg, + add_verbose_arg) +from scilpy.io.streamlines import load_tractogram_with_reference +from dwi_ml.io_utils import (add_arg_existing_experiment_path, + add_memory_args) +from dwi_ml.models.projects.ae_next_models import ModelConvNextAE +from dwi_ml.viz.latent_streamlines import BundlesLatentSpaceVisualizer + + +def _build_arg_parser(): + p = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter, + description=__doc__) + # Mandatory + # Should only be False for debugging tests. + add_arg_existing_experiment_path(p) + # Add_args_testing_subj_hdf5(p) + + p.add_argument('in_bundles', + help="The 'glob' path to several bundles identified by their file name." + "e.g. FiberCupGroundTruth_filtered_bundle_0.tck") + + p.add_argument('out_path') + + # Options + p.add_argument('--batch_size', type=int) + add_memory_args(p) + + p.add_argument('--pick_at_random', action='store_true') + add_reference_arg(p) + add_overwrite_arg(p) + add_verbose_arg(p) + return p + +def load_bundles(p, args, files_list: list): + bundles = [] + for bundle_file in files_list: + bundle_sft = load_tractogram_with_reference(p, args, bundle_file) + bundle_sft.to_vox() + bundle_sft.to_corner() + bundles.append(bundle_sft) + return bundles + +def main(): + p = _build_arg_parser() + args = p.parse_args() + + # Setting log level to INFO maximum for sub-loggers, else it becomes ugly, + # but we will set trainer to user-defined level. + sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO' + + # General logging (ex, scilpy: Warning) + logging.getLogger().setLevel(level=logging.WARNING) + + # Verify output names + # Check experiment_path exists and best_model folder exists + # Assert_inputs_exist(p, args.hdf5_file) + assert_outputs_exist(p, args, []) + + # Device + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # 1. 
Load model + logging.debug("Loading model.") + model = ModelConvNextAE.load_model_from_params_and_state( + args.experiment_path + '/best_model', log_level=sub_loggers_level).to(device) + + expanded = expanduser(args.in_bundles) + bundles_files = glob(expanded) + if isinstance(bundles_files, str): + bundles_files = [bundles_files] + + bundles_label = [pathlib.Path(l).stem for l in bundles_files] + bundles_sft = load_bundles(p, args, bundles_files) + + logging.info("Running model to compute loss") + + ls_viz = BundlesLatentSpaceVisualizer( + save_path=args.out_path) + + with torch.no_grad(): + for i, bundle_sft in enumerate(bundles_sft): + + # Resample + streamlines = torch.as_tensor(np.asarray(set_number_of_points(bundle_sft.streamlines, 256)), + dtype=torch.float32, device=device) + + latent_streamlines = model.encode(streamlines).cpu().numpy() # output of (N, 32) + ls_viz.add_data_to_plot(latent_streamlines, label=bundles_label[i]) + + ls_viz.plot() + + +if __name__ == '__main__': + main() diff --git a/scripts_python/ae_visualize_streamlines.py b/scripts_python/ae_visualize_streamlines.py index 31526453..b3618ab8 100644 --- a/scripts_python/ae_visualize_streamlines.py +++ b/scripts_python/ae_visualize_streamlines.py @@ -5,6 +5,8 @@ import torch +from tqdm import tqdm + from scilpy.io.utils import (add_overwrite_arg, assert_outputs_exist, add_reference_arg, @@ -12,8 +14,8 @@ from scilpy.io.streamlines import load_tractogram_with_reference from dipy.io.streamline import save_tractogram from dwi_ml.io_utils import (add_arg_existing_experiment_path, - add_memory_args) -from dwi_ml.models.projects.ae_models import ModelAE + add_memory_args) +from dwi_ml.models.projects.ae_next_models import ModelConvNextAE def _build_arg_parser(): @@ -47,6 +49,8 @@ def main(): p = _build_arg_parser() args = p.parse_args() + normalize = True + # Setting log level to INFO maximum for sub-loggers, else it becomes ugly, # but we will set trainer to user-defined level. sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO' @@ -65,8 +69,9 @@ def main(): # 1. Load model logging.debug("Loading model.") - model = ModelAE.load_model_from_params_and_state( - args.experiment_path + '/best_model', log_level=sub_loggers_level) + model = ModelConvNextAE.load_model_from_params_and_state( + args.experiment_path + '/best_model', log_level=sub_loggers_level).to( + device) # model.set_context('training') # 2. 
Compute loss # tester = TesterOneInput(args.experiment_path, @@ -81,25 +86,32 @@ def main(): sft = load_tractogram_with_reference(p, args, args.in_tractogram) sft.to_vox() sft.to_corner() - bundle = sft.streamlines[0:5000] - - logging.info("Running model to compute loss") - - new_sft = sft.from_sft(bundle, sft) - save_tractogram(new_sft, 'orig_5000.trk') + bundle = sft.streamlines - with torch.no_grad(): - streamlines = [ - torch.as_tensor(s, dtype=torch.float32, device=device) - for s in bundle] - tmp_outputs = model(streamlines) - # latent = model.encode(streamlines) + if normalize: + sft.streamlines /= sft.dimensions - streamlines_output = [tmp_outputs[i, :, :].transpose(0, 1).cpu().numpy() for i in range(len(bundle))] + logging.info("Running model to compute loss") + batch_size = 5000 + batches = range(0, len(sft.streamlines), batch_size) + all_streamlines = [] + for i, b in enumerate(tqdm(batches)): + print(i, b) + with torch.no_grad(): + streamlines = [ + torch.as_tensor(s, dtype=torch.float32, device=device) + for s in bundle[i * batch_size:(i+1) * batch_size]] + tmp_outputs = model(streamlines) + # latent = model.encode(streamlines) + scaling = sft.dimensions if normalize else 1.0 + streamlines_output = [tmp_outputs[j, :, :].transpose( + 0, 1).cpu().numpy() * scaling + for j in range(tmp_outputs.shape[0])] + all_streamlines.extend(streamlines_output) # print(streamlines_output[0].shape) - new_sft = sft.from_sft(streamlines_output, sft) - save_tractogram(new_sft, args.out_tractogram) + new_sft = sft.from_sft(all_streamlines, sft) + save_tractogram(new_sft, args.out_tractogram, bbox_valid_check=False) # latent_output = [s.cpu().numpy() for s in latent] From 6853567d5f05c21651705f51e9fcb4d5785b238a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antoine=20Th=C3=A9berge?= Date: Wed, 25 Sep 2024 15:45:42 -0400 Subject: [PATCH 04/14] ENH: reduce padding --- command_ae.sh | 2 +- dwi_ml/models/projects/ae_next_models.py | 30 ++++++++++++------------ scripts_python/ae_train_model.py | 2 +- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/command_ae.sh b/command_ae.sh index 841e7ddf..4e925a56 100644 --- a/command_ae.sh +++ b/command_ae.sh @@ -8,7 +8,7 @@ ae_train_model.py $experiments \ fibercup_tracking.hdf5 \ target \ -v INFO \ - --batch_size_training 1800 \ + --batch_size_training 1100 \ --batch_size_units nb_streamlines \ --nb_subjects_per_batch 5 \ --learning_rate 0.00001*300 0.000005 \ diff --git a/dwi_ml/models/projects/ae_next_models.py b/dwi_ml/models/projects/ae_next_models.py index f68def5d..cc0108cd 100644 --- a/dwi_ml/models/projects/ae_next_models.py +++ b/dwi_ml/models/projects/ae_next_models.py @@ -80,27 +80,27 @@ def __init__(self, kernel_size, latent_space_dims, Encode convolutions """ self.encoder = torch.nn.Sequential( - torch.nn.Conv1d(3, 32, self.kernel_size, stride=2, padding=1, + torch.nn.Conv1d(3, 32, self.kernel_size+1, stride=1, padding=1, padding_mode='reflect'), ResBlock1d(32), ResBlock1d(32), ResBlock1d(32), - torch.nn.Conv1d(32, 64, self.kernel_size, stride=2, padding=1, + torch.nn.Conv1d(32, 64, self.kernel_size, stride=2, padding=0, padding_mode='reflect'), ResBlock1d(64), ResBlock1d(64), ResBlock1d(64), - torch.nn.Conv1d(64, 128, self.kernel_size, stride=2, padding=1, + torch.nn.Conv1d(64, 128, self.kernel_size, stride=2, padding=0, padding_mode='reflect'), ResBlock1d(128), ResBlock1d(128), ResBlock1d(128), - torch.nn.Conv1d(128, 256, self.kernel_size, stride=2, padding=1, + torch.nn.Conv1d(128, 256, self.kernel_size, stride=2, padding=0, 
padding_mode='reflect'), ResBlock1d(256), ResBlock1d(256), ResBlock1d(256), - torch.nn.Conv1d(256, 512, self.kernel_size, stride=2, padding=1, + torch.nn.Conv1d(256, 512, self.kernel_size, stride=2, padding=0, padding_mode='reflect'), ResBlock1d(512), ResBlock1d(512), @@ -111,7 +111,7 @@ def __init__(self, kernel_size, latent_space_dims, ResBlock1d(512), ResBlock1d(512), ResBlock1d(512), - torch.nn.Conv1d(512, 1024, self.kernel_size, stride=1, padding=1, + torch.nn.Conv1d(512, 1024, self.kernel_size, stride=2, padding=0, padding_mode='reflect'), ResBlock1d(1024), ResBlock1d(1024), @@ -125,8 +125,10 @@ def __init__(self, kernel_size, latent_space_dims, ResBlock1d(1024), ResBlock1d(1024), ResBlock1d(1024), + torch.nn.Upsample(scale_factor=2, mode="linear", + align_corners=False), torch.nn.Conv1d( - 1024, 512, self.kernel_size, stride=1, padding=1, + 1024, 512, self.kernel_size+1, stride=1, padding=1, padding_mode='reflect'), ResBlock1d(512), ResBlock1d(512), @@ -140,7 +142,7 @@ def __init__(self, kernel_size, latent_space_dims, torch.nn.Upsample(scale_factor=2, mode="linear", align_corners=False), torch.nn.Conv1d( - 512, 256, self.kernel_size, stride=1, padding=1, + 512, 256, self.kernel_size+1, stride=1, padding=1, padding_mode='reflect'), ResBlock1d(256), ResBlock1d(256), @@ -148,7 +150,7 @@ def __init__(self, kernel_size, latent_space_dims, torch.nn.Upsample(scale_factor=2, mode="linear", align_corners=False), torch.nn.Conv1d( - 256, 128, self.kernel_size, stride=1, padding=1, + 256, 128, self.kernel_size+1, stride=1, padding=1, padding_mode='reflect'), ResBlock1d(128), ResBlock1d(128), @@ -156,7 +158,7 @@ def __init__(self, kernel_size, latent_space_dims, torch.nn.Upsample(scale_factor=2, mode="linear", align_corners=False), torch.nn.Conv1d( - 128, 64, self.kernel_size, stride=1, padding=1, + 128, 64, self.kernel_size+1, stride=1, padding=1, padding_mode='reflect'), ResBlock1d(64), ResBlock1d(64), @@ -164,15 +166,13 @@ def __init__(self, kernel_size, latent_space_dims, torch.nn.Upsample(scale_factor=2, mode="linear", align_corners=False), torch.nn.Conv1d( - 64, 32, self.kernel_size, stride=1, padding=1, + 64, 32, self.kernel_size+1, stride=1, padding=1, padding_mode='reflect'), ResBlock1d(32), ResBlock1d(32), ResBlock1d(32), - torch.nn.Upsample(scale_factor=2, mode="linear", - align_corners=False), torch.nn.Conv1d( - 32, 3, self.kernel_size, stride=1, padding=1, + 32, 3, self.kernel_size+1, stride=1, padding=1, padding_mode='reflect'), ) @@ -192,7 +192,7 @@ def params_for_checkpoint(self): @classmethod def _load_params(cls, model_dir): p = super()._load_params(model_dir) - p['kernel_size'] = 3 + p['kernel_size'] = 2 p['latent_space_dims'] = 32 return p diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py index 80685d0a..175ed989 100755 --- a/scripts_python/ae_train_model.py +++ b/scripts_python/ae_train_model.py @@ -71,7 +71,7 @@ def init_from_args(args, sub_loggers_level): model = ModelConvNextAE( experiment_name=args.experiment_name, step_size=None, compress_lines=None, - kernel_size=3, latent_space_dims=32, + kernel_size=2, latent_space_dims=32, log_level=sub_loggers_level) logging.info("AEmodel final parameters:" + From 35fa5ab0d9052fe628b229c23eca487330ac01e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antoine=20Th=C3=A9berge?= Date: Wed, 25 Sep 2024 16:46:16 -0400 Subject: [PATCH 05/14] ENH: asym decoder --- dwi_ml/models/projects/ae_next_models.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/dwi_ml/models/projects/ae_next_models.py 
b/dwi_ml/models/projects/ae_next_models.py index cc0108cd..8a800191 100644 --- a/dwi_ml/models/projects/ae_next_models.py +++ b/dwi_ml/models/projects/ae_next_models.py @@ -122,8 +122,6 @@ def __init__(self, kernel_size, latent_space_dims, Decode convolutions """ self.decoder = torch.nn.Sequential( - ResBlock1d(1024), - ResBlock1d(1024), ResBlock1d(1024), torch.nn.Upsample(scale_factor=2, mode="linear", align_corners=False), @@ -131,46 +129,30 @@ def __init__(self, kernel_size, latent_space_dims, 1024, 512, self.kernel_size+1, stride=1, padding=1, padding_mode='reflect'), ResBlock1d(512), - ResBlock1d(512), - ResBlock1d(512), - ResBlock1d(512), - ResBlock1d(512), - ResBlock1d(512), - ResBlock1d(512), - ResBlock1d(512), - ResBlock1d(512), torch.nn.Upsample(scale_factor=2, mode="linear", align_corners=False), torch.nn.Conv1d( 512, 256, self.kernel_size+1, stride=1, padding=1, padding_mode='reflect'), ResBlock1d(256), - ResBlock1d(256), - ResBlock1d(256), torch.nn.Upsample(scale_factor=2, mode="linear", align_corners=False), torch.nn.Conv1d( 256, 128, self.kernel_size+1, stride=1, padding=1, padding_mode='reflect'), ResBlock1d(128), - ResBlock1d(128), - ResBlock1d(128), torch.nn.Upsample(scale_factor=2, mode="linear", align_corners=False), torch.nn.Conv1d( 128, 64, self.kernel_size+1, stride=1, padding=1, padding_mode='reflect'), ResBlock1d(64), - ResBlock1d(64), - ResBlock1d(64), torch.nn.Upsample(scale_factor=2, mode="linear", align_corners=False), torch.nn.Conv1d( 64, 32, self.kernel_size+1, stride=1, padding=1, padding_mode='reflect'), ResBlock1d(32), - ResBlock1d(32), - ResBlock1d(32), torch.nn.Conv1d( 32, 3, self.kernel_size+1, stride=1, padding=1, padding_mode='reflect'), @@ -192,7 +174,6 @@ def params_for_checkpoint(self): @classmethod def _load_params(cls, model_dir): p = super()._load_params(model_dir) - p['kernel_size'] = 2 p['latent_space_dims'] = 32 return p From 72457d74650b9378f24286c6b1d6f74913f029fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antoine=20Th=C3=A9berge?= Date: Fri, 27 Sep 2024 13:40:29 -0400 Subject: [PATCH 06/14] ENH: convnext --- dwi_ml/models/projects/ae_models.py | 215 +++++++++++----------------- scripts_python/ae_train_model.py | 31 ++-- 2 files changed, 99 insertions(+), 147 deletions(-) diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py index 2e8b7034..d836fd0f 100644 --- a/dwi_ml/models/projects/ae_models.py +++ b/dwi_ml/models/projects/ae_models.py @@ -3,35 +3,11 @@ from typing import List import torch +from torch.nn import functional as F from dwi_ml.models.main_models import MainModelAbstract -class ResBlock1d(torch.nn.Module): - - def __init__(self, channels, stride=1): - super(ResBlock1d, self).__init__() - - self.block = torch.nn.Sequential( - torch.nn.Conv1d(channels, channels, kernel_size=1, - stride=stride, padding=0), - torch.nn.BatchNorm1d(channels), - torch.nn.GELU(), - torch.nn.Conv1d(channels, channels, kernel_size=3, - stride=stride, padding=1, padding_mode='reflect'), - torch.nn.BatchNorm1d(channels), - torch.nn.GELU(), - torch.nn.Conv1d(channels, channels, 1, - 1, 0), - torch.nn.BatchNorm1d(channels)) - - def forward(self, x): - identity = x - xp = self.block(x) - - return xp + identity - - class ModelAE(MainModelAbstract): """ Recurrent tracking model. @@ -42,7 +18,6 @@ class ModelAE(MainModelAbstract): deterministic (3D vectors) or probabilistic (based on probability distribution parameters). 
""" - def __init__(self, kernel_size, latent_space_dims, experiment_name: str, # Target preprocessing params for the batch loader + tracker @@ -54,7 +29,11 @@ def __init__(self, kernel_size, latent_space_dims, self.kernel_size = kernel_size self.latent_space_dims = latent_space_dims - self.reconstruction_loss = torch.nn.MSELoss(reduction="sum") + + self.pad = torch.nn.ReflectionPad1d(1) + + def pre_pad(m): + return torch.nn.Sequential(self.pad, m) self.fc1 = torch.nn.Linear(8192, self.latent_space_dims) # 8192 = 1024*8 @@ -63,102 +42,60 @@ def __init__(self, kernel_size, latent_space_dims, """ Encode convolutions """ - self.encoder = torch.nn.Sequential( - torch.nn.Conv1d(3, 32, self.kernel_size, stride=2, padding=1, - padding_mode='reflect'), - torch.nn.GELU(), - ResBlock1d(32), - ResBlock1d(32), - ResBlock1d(32), - torch.nn.Conv1d(32, 64, self.kernel_size, stride=2, padding=1, - padding_mode='reflect'), - torch.nn.GELU(), - ResBlock1d(64), - ResBlock1d(64), - ResBlock1d(64), - torch.nn.Conv1d(64, 128, self.kernel_size, stride=2, padding=1, - padding_mode='reflect'), - torch.nn.GELU(), - ResBlock1d(128), - ResBlock1d(128), - ResBlock1d(128), - torch.nn.Conv1d(128, 256, self.kernel_size, stride=2, padding=1, - padding_mode='reflect'), - torch.nn.GELU(), - ResBlock1d(256), - ResBlock1d(256), - ResBlock1d(256), - torch.nn.Conv1d(256, 512, self.kernel_size, stride=2, padding=1, - padding_mode='reflect'), - torch.nn.GELU(), - ResBlock1d(512), - ResBlock1d(512), - ResBlock1d(512), - torch.nn.Conv1d(512, 1024, self.kernel_size, stride=1, padding=1, - padding_mode='reflect'), - torch.nn.GELU(), - ResBlock1d(1024), - ResBlock1d(1024), - ResBlock1d(1024), + self.encod_conv1 = pre_pad( + torch.nn.Conv1d(3, 32, self.kernel_size, stride=2, padding=0) + ) + self.encod_conv2 = pre_pad( + torch.nn.Conv1d(32, 64, self.kernel_size, stride=2, padding=0) + ) + self.encod_conv3 = pre_pad( + torch.nn.Conv1d(64, 128, self.kernel_size, stride=2, padding=0) + ) + self.encod_conv4 = pre_pad( + torch.nn.Conv1d(128, 256, self.kernel_size, stride=2, padding=0) + ) + self.encod_conv5 = pre_pad( + torch.nn.Conv1d(256, 512, self.kernel_size, stride=2, padding=0) + ) + self.encod_conv6 = pre_pad( + torch.nn.Conv1d(512, 1024, self.kernel_size, stride=1, padding=0) ) """ Decode convolutions """ - self.decoder = torch.nn.Sequential( - ResBlock1d(1024), - ResBlock1d(1024), - ResBlock1d(1024), - torch.nn.GELU(), - torch.nn.Conv1d( - 1024, 512, self.kernel_size, stride=1, padding=1, - padding_mode='reflect'), - torch.nn.GELU(), - ResBlock1d(512), - ResBlock1d(512), - ResBlock1d(512), - torch.nn.Upsample(scale_factor=2, mode="linear", - align_corners=False), - torch.nn.Conv1d( - 512, 256, self.kernel_size, stride=1, padding=1, - padding_mode='reflect'), - torch.nn.GELU(), - ResBlock1d(256), - ResBlock1d(256), - ResBlock1d(256), - torch.nn.Upsample(scale_factor=2, mode="linear", - align_corners=False), - torch.nn.Conv1d( - 256, 128, self.kernel_size, stride=1, padding=1, - padding_mode='reflect'), - torch.nn.GELU(), - ResBlock1d(128), - ResBlock1d(128), - ResBlock1d(128), - torch.nn.Upsample(scale_factor=2, mode="linear", - align_corners=False), - torch.nn.Conv1d( - 128, 64, self.kernel_size, stride=1, padding=1, - padding_mode='reflect'), - torch.nn.GELU(), - ResBlock1d(64), - ResBlock1d(64), - ResBlock1d(64), - torch.nn.Upsample(scale_factor=2, mode="linear", - align_corners=False), - torch.nn.Conv1d( - 64, 32, self.kernel_size, stride=1, padding=1, - padding_mode='reflect'), - torch.nn.GELU(), - ResBlock1d(32), - 
ResBlock1d(32), - ResBlock1d(32), - torch.nn.Upsample(scale_factor=2, mode="linear", - align_corners=False), - torch.nn.Conv1d( - 32, 3, self.kernel_size, stride=1, padding=1, - padding_mode='reflect'), - torch.nn.GELU(), + self.decod_conv1 = pre_pad( + torch.nn.Conv1d(1024, 512, self.kernel_size, stride=1, padding=0) + ) + self.upsampl1 = torch.nn.Upsample( + scale_factor=2, mode="linear", align_corners=False + ) + self.decod_conv2 = pre_pad( + torch.nn.Conv1d(512, 256, self.kernel_size, stride=1, padding=0) + ) + self.upsampl2 = torch.nn.Upsample( + scale_factor=2, mode="linear", align_corners=False + ) + self.decod_conv3 = pre_pad( + torch.nn.Conv1d(256, 128, self.kernel_size, stride=1, padding=0) + ) + self.upsampl3 = torch.nn.Upsample( + scale_factor=2, mode="linear", align_corners=False + ) + self.decod_conv4 = pre_pad( + torch.nn.Conv1d(128, 64, self.kernel_size, stride=1, padding=0) + ) + self.upsampl4 = torch.nn.Upsample( + scale_factor=2, mode="linear", align_corners=False + ) + self.decod_conv5 = pre_pad( + torch.nn.Conv1d(64, 32, self.kernel_size, stride=1, padding=0) + ) + self.upsampl5 = torch.nn.Upsample( + scale_factor=2, mode="linear", align_corners=False + ) + self.decod_conv6 = pre_pad( + torch.nn.Conv1d(32, 3, self.kernel_size, stride=1, padding=0) ) @property @@ -208,13 +145,20 @@ def encode(self, x): x = torch.stack(x) x = torch.swapaxes(x, 1, 2) - x = self.encoder(x) - self.encoder_out_size = (x.shape[1], x.shape[2]) + h1 = F.relu(self.encod_conv1(x)) + h2 = F.relu(self.encod_conv2(h1)) + h3 = F.relu(self.encod_conv3(h2)) + h4 = F.relu(self.encod_conv4(h3)) + h5 = F.relu(self.encod_conv5(h4)) + h6 = self.encod_conv6(h5) + + self.encoder_out_size = (h6.shape[1], h6.shape[2]) # Flatten - h7 = x.view(-1, self.encoder_out_size[0] * self.encoder_out_size[1]) + h7 = h6.view(-1, self.encoder_out_size[0] * self.encoder_out_size[1]) fc1 = self.fc1(h7) + return fc1 def decode(self, z): @@ -222,21 +166,24 @@ def decode(self, z): fc_reshape = fc.view( -1, self.encoder_out_size[0], self.encoder_out_size[1] ) - z = self.decoder(fc_reshape) - return z + h1 = F.relu(self.decod_conv1(fc_reshape)) + h2 = self.upsampl1(h1) + h3 = F.relu(self.decod_conv2(h2)) + h4 = self.upsampl2(h3) + h5 = F.relu(self.decod_conv3(h4)) + h6 = self.upsampl3(h5) + h7 = F.relu(self.decod_conv4(h6)) + h8 = self.upsampl4(h7) + h9 = F.relu(self.decod_conv5(h8)) + h10 = self.upsampl5(h9) + h11 = self.decod_conv6(h10) + + return h11 def compute_loss(self, model_outputs, targets, average_results=True): + targets = torch.stack(targets) targets = torch.swapaxes(targets, 1, 2) - mse = self.reconstruction_loss(model_outputs, targets) - - # loss_function_vae - # See Appendix B from VAE paper: - # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 - # https://arxiv.org/abs/1312.6114 - # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) - # kld = -0.5 * torch.sum(1 + self.logvar - self.mu.pow(2) - self.logvar.exp()) - # kld_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) - # kld = torch.sum(kld_element).__mul__(-0.5) - + reconstruction_loss = torch.nn.MSELoss(reduction="sum") + mse = reconstruction_loss(model_outputs, targets) return mse, 1 diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py index 175ed989..55d2f678 100755 --- a/scripts_python/ae_train_model.py +++ b/scripts_python/ae_train_model.py @@ -11,21 +11,22 @@ # comet_ml not used, but comet_ml requires to be imported before torch. 
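# Aside: a minimal shape check, assuming streamlines resampled to 256
# points and kernel_size=3 as in the reverted ModelAE above, confirming
# the hard-coded flatten size 8192 = 1024 * 8: each ReflectionPad1d(1)
# followed by a stride-2 Conv1d halves the sequence length.
import torch

pad = torch.nn.ReflectionPad1d(1)
x = torch.zeros(1, 3, 256)  # (batch, xyz, points)
channels = [3, 32, 64, 128, 256, 512, 1024]
strides = [2, 2, 2, 2, 2, 1]
for c_in, c_out, stride in zip(channels[:-1], channels[1:], strides):
    conv = torch.nn.Conv1d(c_in, c_out, kernel_size=3, stride=stride)
    x = conv(pad(x))  # lengths: 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 8
print(x.shape)  # torch.Size([1, 1024, 8]); flattened: 1024 * 8 = 8192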
# See bug report here https://github.com/Lightning-AI/lightning/issues/5829 # Importing now to solve issues later. -import comet_ml +import comet_ml # noqa F401 import torch -from scilpy.io.utils import assert_inputs_exist, assert_outputs_exist, add_verbose_arg +from scilpy.io.utils import ( + assert_inputs_exist, assert_outputs_exist, add_verbose_arg) from dwi_ml.data.dataset.utils import prepare_multisubjectdataset from dwi_ml.experiment_utils.prints import format_dict_to_str from dwi_ml.experiment_utils.timer import Timer from dwi_ml.io_utils import add_memory_args from dwi_ml.models.projects.ae_next_models import ModelConvNextAE +from dwi_ml.models.projects.ae_models import ModelAE from dwi_ml.training.trainers import DWIMLAbstractTrainer from dwi_ml.training.utils.batch_samplers import (add_args_batch_sampler, prepare_batch_sampler) from dwi_ml.training.utils.batch_loaders import (add_args_batch_loader) -from dwi_ml.training.utils.trainer import add_training_args from dwi_ml.training.batch_loaders import DWIMLStreamlinesBatchLoader from dwi_ml.training.utils.experiment import ( add_mandatory_args_experiment_and_hdf5_path) @@ -39,7 +40,6 @@ def prepare_arg_parser(): add_mandatory_args_experiment_and_hdf5_path(p) add_args_batch_sampler(p) add_args_batch_loader(p) - #training_group = add_training_args(p) add_training_args(p) p.add_argument('streamline_group_name', help="Name of the group in hdf5") @@ -47,10 +47,8 @@ def prepare_arg_parser(): add_verbose_arg(p) # Additional arg for projects - #training_group.add_argument( - # '--clip_grad', type=float, default=None, - # help="Value to which the gradient norms to avoid exploding gradients." - # "\nDefault = None (not clipping).") + p.add_argument('--model', type=str, choices=['finta', 'convnext'], + help='Type of model to train') return p @@ -68,11 +66,18 @@ def init_from_args(args, sub_loggers_level): # Final model with Timer("\n\nPreparing model", newline=True, color='yellow'): # INPUTS: verifying args - model = ModelConvNextAE( - experiment_name=args.experiment_name, - step_size=None, compress_lines=None, - kernel_size=2, latent_space_dims=32, - log_level=sub_loggers_level) + if args.model == 'finta': + model = ModelAE( + experiment_name=args.experiment_name, + step_size=None, compress_lines=None, + kernel_size=3, latent_space_dims=32, + log_level=sub_loggers_level) + else: + model = ModelConvNextAE( + experiment_name=args.experiment_name, + step_size=None, compress_lines=None, + kernel_size=2, latent_space_dims=32, + log_level=sub_loggers_level) logging.info("AEmodel final parameters:" + format_dict_to_str(model.params_for_checkpoint)) From fc9510f29489cc9fdb227363495b5a64d76fa5e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antoine=20Th=C3=A9berge?= Date: Fri, 27 Sep 2024 13:53:50 -0400 Subject: [PATCH 07/14] ENH: remove viz bundles from branch --- scripts_python/ae_visualize_bundles.py | 107 ------------------------- 1 file changed, 107 deletions(-) delete mode 100644 scripts_python/ae_visualize_bundles.py diff --git a/scripts_python/ae_visualize_bundles.py b/scripts_python/ae_visualize_bundles.py deleted file mode 100644 index fc225c17..00000000 --- a/scripts_python/ae_visualize_bundles.py +++ /dev/null @@ -1,107 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -import argparse -import logging -import pathlib -import torch -import numpy as np -from glob import glob -from os.path import expanduser -from dipy.tracking.streamline import set_number_of_points - -from scilpy.io.utils import (add_overwrite_arg, - assert_outputs_exist, - 
add_reference_arg, - add_verbose_arg) -from scilpy.io.streamlines import load_tractogram_with_reference -from dwi_ml.io_utils import (add_arg_existing_experiment_path, - add_memory_args) -from dwi_ml.models.projects.ae_next_models import ModelConvNextAE -from dwi_ml.viz.latent_streamlines import BundlesLatentSpaceVisualizer - - -def _build_arg_parser(): - p = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter, - description=__doc__) - # Mandatory - # Should only be False for debugging tests. - add_arg_existing_experiment_path(p) - # Add_args_testing_subj_hdf5(p) - - p.add_argument('in_bundles', - help="The 'glob' path to several bundles identified by their file name." - "e.g. FiberCupGroundTruth_filtered_bundle_0.tck") - - p.add_argument('out_path') - - # Options - p.add_argument('--batch_size', type=int) - add_memory_args(p) - - p.add_argument('--pick_at_random', action='store_true') - add_reference_arg(p) - add_overwrite_arg(p) - add_verbose_arg(p) - return p - -def load_bundles(p, args, files_list: list): - bundles = [] - for bundle_file in files_list: - bundle_sft = load_tractogram_with_reference(p, args, bundle_file) - bundle_sft.to_vox() - bundle_sft.to_corner() - bundles.append(bundle_sft) - return bundles - -def main(): - p = _build_arg_parser() - args = p.parse_args() - - # Setting log level to INFO maximum for sub-loggers, else it becomes ugly, - # but we will set trainer to user-defined level. - sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO' - - # General logging (ex, scilpy: Warning) - logging.getLogger().setLevel(level=logging.WARNING) - - # Verify output names - # Check experiment_path exists and best_model folder exists - # Assert_inputs_exist(p, args.hdf5_file) - assert_outputs_exist(p, args, []) - - # Device - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - # 1. 
Load model - logging.debug("Loading model.") - model = ModelConvNextAE.load_model_from_params_and_state( - args.experiment_path + '/best_model', log_level=sub_loggers_level).to(device) - - expanded = expanduser(args.in_bundles) - bundles_files = glob(expanded) - if isinstance(bundles_files, str): - bundles_files = [bundles_files] - - bundles_label = [pathlib.Path(l).stem for l in bundles_files] - bundles_sft = load_bundles(p, args, bundles_files) - - logging.info("Running model to compute loss") - - ls_viz = BundlesLatentSpaceVisualizer( - save_path=args.out_path) - - with torch.no_grad(): - for i, bundle_sft in enumerate(bundles_sft): - - # Resample - streamlines = torch.as_tensor(np.asarray(set_number_of_points(bundle_sft.streamlines, 256)), - dtype=torch.float32, device=device) - - latent_streamlines = model.encode(streamlines).cpu().numpy() # output of (N, 32) - ls_viz.add_data_to_plot(latent_streamlines, label=bundles_label[i]) - - ls_viz.plot() - - -if __name__ == '__main__': - main() From 1b16e27bd9e9034c77d7cd8731ebdf77bc8b7a0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antoine=20Th=C3=A9berge?= Date: Fri, 27 Sep 2024 13:56:33 -0400 Subject: [PATCH 08/14] ENH: lazy-save autoencoded streamlines --- scripts_python/ae_visualize_streamlines.py | 58 ++++++++++++---------- 1 file changed, 33 insertions(+), 25 deletions(-) diff --git a/scripts_python/ae_visualize_streamlines.py b/scripts_python/ae_visualize_streamlines.py index b3618ab8..825fe401 100644 --- a/scripts_python/ae_visualize_streamlines.py +++ b/scripts_python/ae_visualize_streamlines.py @@ -3,19 +3,21 @@ import argparse import logging +import nibabel as nib +import numpy as np import torch -from tqdm import tqdm - from scilpy.io.utils import (add_overwrite_arg, assert_outputs_exist, add_reference_arg, add_verbose_arg) from scilpy.io.streamlines import load_tractogram_with_reference -from dipy.io.streamline import save_tractogram +from scilpy.tracking.utils import save_tractogram +from dipy.tracking.streamline import set_number_of_points from dwi_ml.io_utils import (add_arg_existing_experiment_path, add_memory_args) from dwi_ml.models.projects.ae_next_models import ModelConvNextAE +from nibabel.streamlines import detect_format def _build_arg_parser(): @@ -50,7 +52,7 @@ def main(): args = p.parse_args() normalize = True - + tracts_format = detect_format(args.out_tractogram) # Setting log level to INFO maximum for sub-loggers, else it becomes ugly, # but we will set trainer to user-defined level. 
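# Aside: a stand-alone sketch of the lazy pattern this patch introduces.
# model, bundle and dims are stand-ins for the script's actual objects,
# and streamlines are assumed already resampled to 256 points. Batches
# go through the model and streamlines come back out one at a time, so
# the full autoencoded tractogram never sits in memory.
import numpy as np
import torch

def autoencoded(model, bundle, dims, batch_size=5000, device='cpu'):
    for start in range(0, len(bundle), batch_size):
        s = np.asarray(bundle[start:start + batch_size],
                       dtype=np.float32) / dims    # (N, 256, 3) in [0, 1]
        with torch.no_grad():
            out = model(torch.as_tensor(s, device=device))  # (N, 3, 256)
        for streamline in out.cpu().numpy().transpose((0, 2, 1)) * dims:
            yield streamline                       # (256, 3), rescaled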
sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO' @@ -86,32 +88,38 @@ def main(): sft = load_tractogram_with_reference(p, args, args.in_tractogram) sft.to_vox() sft.to_corner() - bundle = sft.streamlines - if normalize: - sft.streamlines /= sft.dimensions + bundle = sft.streamlines logging.info("Running model to compute loss") batch_size = 5000 batches = range(0, len(sft.streamlines), batch_size) - all_streamlines = [] - for i, b in enumerate(tqdm(batches)): - print(i, b) - with torch.no_grad(): - streamlines = [ - torch.as_tensor(s, dtype=torch.float32, device=device) - for s in bundle[i * batch_size:(i+1) * batch_size]] - tmp_outputs = model(streamlines) - # latent = model.encode(streamlines) - scaling = sft.dimensions if normalize else 1.0 - streamlines_output = [tmp_outputs[j, :, :].transpose( - 0, 1).cpu().numpy() * scaling - for j in range(tmp_outputs.shape[0])] - all_streamlines.extend(streamlines_output) - - # print(streamlines_output[0].shape) - new_sft = sft.from_sft(all_streamlines, sft) - save_tractogram(new_sft, args.out_tractogram, bbox_valid_check=False) + + def _autoencode_streamlines(): + for i, b in enumerate(batches): + with torch.no_grad(): + s = np.asarray( + set_number_of_points( + bundle[i * batch_size:(i+1) * batch_size], + 256)) + if normalize: + s /= sft.dimensions + + streamlines = torch.as_tensor( + s, dtype=torch.float32, device=device) + tmp_outputs = model(streamlines).cpu().numpy() + # latent = model.encode(streamlines) + scaling = sft.dimensions if normalize else 1.0 + streamlines_output = tmp_outputs.transpose((0, 2, 1)) * scaling + for strml in streamlines_output: + yield strml, strml[0] + + # Need a nifti image to lazy-save a tractogram + fake_ref = nib.Nifti1Image(np.zeros(sft.dimensions), sft.affine) + + save_tractogram(_autoencode_streamlines(), tracts_format, fake_ref, + len(bundle), args.out_tractogram, 0, 999, False, False, + args.verbose) # latent_output = [s.cpu().numpy() for s in latent] From 6b6c21a5ed91ca30eb0fc1033942702c0618edc7 Mon Sep 17 00:00:00 2001 From: AntoineTheb Date: Mon, 30 Sep 2024 11:46:17 -0400 Subject: [PATCH 09/14] ENH: docstring --- dwi_ml/models/projects/ae_next_models.py | 120 +++++++++++++++++---- dwi_ml/viz/latent_streamlines.py | 3 +- scripts_python/ae_visualize_bundles.py | 110 +++++++++++++++++++ scripts_python/ae_visualize_streamlines.py | 8 +- 4 files changed, 219 insertions(+), 22 deletions(-) create mode 100644 scripts_python/ae_visualize_bundles.py diff --git a/dwi_ml/models/projects/ae_next_models.py b/dwi_ml/models/projects/ae_next_models.py index 8a800191..1bbe7b0f 100644 --- a/dwi_ml/models/projects/ae_next_models.py +++ b/dwi_ml/models/projects/ae_next_models.py @@ -6,6 +6,7 @@ from dwi_ml.models.main_models import MainModelAbstract + class Permute(torch.nn.Module): """This module returns a view of the tensor input with its dimensions permuted. From https://github.com/pytorch/vision/blob/main/torchvision/ops/misc.py#L308 # noqa E501 @@ -23,6 +24,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class LayerNorm1d(torch.nn.LayerNorm): + """ Layer normalization module. Uses the same normalization as torch.nn.LayerNorm + but with the input tensor permuted before and after the normalization. This is + necessary because torch.nn.LayerNorm normalizes the last dimension of the input + tensor, but we want to normalize the middle dimension to fit with convolutional layers. 
+ + From https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py # noqa E501 + """ def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.permute(0, 2, 1) x = torch.nn.functional.layer_norm( @@ -32,8 +40,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class ResBlock1d(torch.nn.Module): + """ ConvNext residual block. + + Uses a 1D convolution with kernel size 7 and groups=channels to ensure that + each channel is processed independently. The block is composed of a layer + normalization, a GELU activation function, a linear layer with 4 times the + number of channels, a GELU activation function, and a linear layer with the + same number of channels as the input. The output is added to the input + tensor. + """ def __init__(self, channels, stride=1, norm=LayerNorm1d): + """ Constructor + + Parameters + ---------- + channels : int + Number of channels in the input tensor. + stride : int + Stride of the convolution. + norm : torch.nn.Module + Normalization layer to use. + """ + super(ResBlock1d, self).__init__() self.block = torch.nn.Sequential( @@ -48,7 +77,10 @@ def __init__(self, channels, stride=1, norm=LayerNorm1d): in_features=channels * 4, out_features=channels, bias=True), Permute((0, 2, 1))) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ Forward pass. + """ + identity = x x = self.block(x) @@ -56,7 +88,18 @@ def forward(self, x): class ModelConvNextAE(MainModelAbstract): - """ + """ ConvNext autoencoder model, modified to work with 1D data. + + The model is composed of an encoder and a decoder. The encoder is composed + of a series of 1D convolutions and residual blocks. The decoder is composed + of residual blocks, upsampling layers, and 1D convolutions. The decoder is + much smaller than the encoder to reduce the number of parameters and + enforce a well defined latent space. + + References: + [1]: Liu, Z., Mao, H., Wu, C. Y., Feichtenhofer, C., Darrell, T., & Xie, S. + (2022). A convnet for the 2020s. In Proceedings of the IEEE/CVF conference + on computer vision and pattern recognition (pp. 11976-11986). """ def __init__(self, kernel_size, latent_space_dims, @@ -72,12 +115,9 @@ def __init__(self, kernel_size, latent_space_dims, self.latent_space_dims = latent_space_dims self.reconstruction_loss = torch.nn.MSELoss(reduction="sum") - self.fc1 = torch.nn.Linear(8192, - self.latent_space_dims) # 8192 = 1024*8 - self.fc2 = torch.nn.Linear(self.latent_space_dims, 8192) - """ - Encode convolutions + Encode convolutions. Follows the same structure as the original + ConvNext-Tiny architecture, but with 1D convolutions. """ self.encoder = torch.nn.Sequential( torch.nn.Conv1d(3, 32, self.kernel_size+1, stride=1, padding=1, @@ -117,9 +157,17 @@ def __init__(self, kernel_size, latent_space_dims, ResBlock1d(1024), ResBlock1d(1024), ) + """ + Latent space + """ + + self.fc1 = torch.nn.Linear(8192, + self.latent_space_dims) # 8192 = 1024*8 + self.fc2 = torch.nn.Linear(self.latent_space_dims, 8192) """ - Decode convolutions + Decode convolutions. Uses upsampling and 1D convolutions instead of + transposed convolutions to avoid checkerboard artifacts. """ self.decoder = torch.nn.Sequential( ResBlock1d(1024), @@ -199,7 +247,20 @@ def forward(self, x = self.decode(self.encode(input_streamlines)) return x - def encode(self, x): + def encode(self, x: List[torch.Tensor]) -> torch.Tensor: + """ Encode the input data. + + Parameters + ---------- + x : list of tensors + List of input tensors. 
+ + Returns + ------- + torch.Tensor + Input data encoded to the latent space. + """ + # x: list of tensors if isinstance(x, list): x = torch.stack(x) @@ -214,7 +275,20 @@ def encode(self, x): fc1 = self.fc1(h7) return fc1 - def decode(self, z): + def decode(self, z: torch.Tensor) -> torch.Tensor: + """ Decode the input data. + + Parameters + ---------- + z : torch.Tensor + Input data in the latent space. + + Returns + ------- + torch.Tensor + Decoded data. + """ + fc = self.fc2(z) fc_reshape = fc.view( -1, self.encoder_out_size[0], self.encoder_out_size[1] @@ -223,17 +297,25 @@ def decode(self, z): return z def compute_loss(self, model_outputs, targets, average_results=True): + """Compute the loss of the model. + + Parameters + ---------- + model_outputs : List[Tensor] + Model outputs. + targets : List[Tensor] + Target data. + average_results : bool + If True, the loss will be averaged across the batch. + + Returns + ------- + loss : Tensor + Loss value. + """ + targets = torch.stack(targets) targets = torch.swapaxes(targets, 1, 2) mse = self.reconstruction_loss(model_outputs, targets) - # loss_function_vae - # See Appendix B from VAE paper: - # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 - # https://arxiv.org/abs/1312.6114 - # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) - # kld = -0.5 * torch.sum(1 + self.logvar - self.mu.pow(2) - self.logvar.exp()) - # kld_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) - # kld = torch.sum(kld_element).__mul__(-0.5) - return mse, 1 diff --git a/dwi_ml/viz/latent_streamlines.py b/dwi_ml/viz/latent_streamlines.py index 297bced6..ffafc9c9 100644 --- a/dwi_ml/viz/latent_streamlines.py +++ b/dwi_ml/viz/latent_streamlines.py @@ -152,7 +152,8 @@ def plot(self): assert current_start == nb_streamlines all_streamlines = np.concatenate(list(self.bundles.values()), axis=0) - + # Set NaNs to 0 + all_streamlines[np.isnan(all_streamlines)] = np.zeros(32) logging.info("Fitting TSNE projection.") all_projected_streamlines = self.tsne.fit_transform(all_streamlines) diff --git a/scripts_python/ae_visualize_bundles.py b/scripts_python/ae_visualize_bundles.py new file mode 100644 index 00000000..8f3260ed --- /dev/null +++ b/scripts_python/ae_visualize_bundles.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import argparse +import logging +import pathlib +import torch +import numpy as np +from glob import glob +from os.path import expanduser +from dipy.tracking.streamline import set_number_of_points + +from scilpy.io.utils import (add_overwrite_arg, + assert_outputs_exist, + add_reference_arg, + add_verbose_arg) +from scilpy.io.streamlines import load_tractogram_with_reference +from dwi_ml.io_utils import (add_arg_existing_experiment_path, + add_memory_args) +from dwi_ml.models.projects.ae_next_models import ModelConvNextAE +from dwi_ml.viz.latent_streamlines import BundlesLatentSpaceVisualizer + + +def _build_arg_parser(): + p = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter, + description=__doc__) + # Mandatory + # Should only be False for debugging tests. + add_arg_existing_experiment_path(p) + # Add_args_testing_subj_hdf5(p) + + p.add_argument('in_bundles', + help="The 'glob' path to several bundles identified by their file name." + "e.g. 
FiberCupGroundTruth_filtered_bundle_0.tck") + + p.add_argument('out_path') + + # Options + p.add_argument('--batch_size', type=int) + add_memory_args(p) + + p.add_argument('--pick_at_random', action='store_true') + add_reference_arg(p) + add_overwrite_arg(p) + add_verbose_arg(p) + return p + +def load_bundles(p, args, files_list: list): + bundles = [] + for bundle_file in files_list: + bundle_sft = load_tractogram_with_reference(p, args, bundle_file) + bundle_sft.to_vox() + bundle_sft.to_corner() + bundles.append(bundle_sft) + return bundles + +def main(): + p = _build_arg_parser() + args = p.parse_args() + + # Setting log level to INFO maximum for sub-loggers, else it becomes ugly, + # but we will set trainer to user-defined level. + sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO' + + # General logging (ex, scilpy: Warning) + logging.getLogger().setLevel(level=logging.WARNING) + + # Verify output names + # Check experiment_path exists and best_model folder exists + # Assert_inputs_exist(p, args.hdf5_file) + assert_outputs_exist(p, args, []) + + # Device + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # 1. Load model + logging.debug("Loading model.") + model = ModelConvNextAE.load_model_from_params_and_state( + args.experiment_path + '/best_model', log_level=sub_loggers_level).to(device) + + expanded = expanduser(args.in_bundles) + bundles_files = glob(expanded) + if isinstance(bundles_files, str): + bundles_files = [bundles_files] + + bundles_label = [pathlib.Path(l).stem for l in bundles_files] + bundles_sft = load_bundles(p, args, bundles_files) + + logging.info("Running model to compute loss") + + ls_viz = BundlesLatentSpaceVisualizer( + save_path=args.out_path) + + with torch.no_grad(): + for i, bundle_sft in enumerate(bundles_sft): + + # Resample + s_resampled = np.asarray( + set_number_of_points(bundle_sft.streamlines, 256)) + streamlines = torch.as_tensor(s_resampled, + dtype=torch.float32, device=device) + + latent_streamlines = model.encode(streamlines).cpu().numpy() # output of (N, 32) + print(latent_streamlines.shape) + ls_viz.add_data_to_plot(latent_streamlines, label=bundles_label[i]) + + ls_viz.plot() + + +if __name__ == '__main__': + main() diff --git a/scripts_python/ae_visualize_streamlines.py b/scripts_python/ae_visualize_streamlines.py index 825fe401..59bd9110 100644 --- a/scripts_python/ae_visualize_streamlines.py +++ b/scripts_python/ae_visualize_streamlines.py @@ -30,11 +30,11 @@ def _build_arg_parser(): p.add_argument('in_tractogram', help="If set, saves the tractogram with the loss per point " - "as a data per point (color)") + "as a data per point (color)") p.add_argument('out_tractogram', help="If set, saves the tractogram with the loss per point " - "as a data per point (color)") + "as a data per point (color)") # Options p.add_argument('--batch_size', type=int) @@ -91,6 +91,10 @@ def main(): bundle = sft.streamlines + bundle = sft.streamlines + + bundle = set_number_of_points(bundle, 256) + logging.info("Running model to compute loss") batch_size = 5000 batches = range(0, len(sft.streamlines), batch_size) From fd1f77e161b8738c45ac7789d7e3d94a11872784 Mon Sep 17 00:00:00 2001 From: AntoineTheb Date: Tue, 1 Oct 2024 12:12:23 -0400 Subject: [PATCH 10/14] ENH: remove latent space viz stuff --- dwi_ml/viz/latent_streamlines.py | 184 ------------------------- scripts_python/ae_visualize_bundles.py | 110 --------------- 2 files changed, 294 deletions(-) delete mode 100644 dwi_ml/viz/latent_streamlines.py delete mode 
100644 scripts_python/ae_visualize_bundles.py diff --git a/dwi_ml/viz/latent_streamlines.py b/dwi_ml/viz/latent_streamlines.py deleted file mode 100644 index ffafc9c9..00000000 --- a/dwi_ml/viz/latent_streamlines.py +++ /dev/null @@ -1,184 +0,0 @@ -import logging - -from typing import Union, List, Tuple -from sklearn.manifold import TSNE -import numpy as np -import torch - -import matplotlib.pyplot as plt - -def plot_latent_streamlines( - encoded_streamlines: Union[np.ndarray, torch.Tensor], - save_path: str = None, - fig_size: Union[List, Tuple] = None, - random_state: int = 42, - max_subset_size: int = None - ): - """ - Projects and plots the latent space representation - of the streamlines using t-SNE dimensionality reduction. - - Parameters - ---------- - encoded_streamlines: Union[np.ndarray, torch.Tensor] - Latent space streamlines to plot of shape (N, latent_space_dim). - save_path: str - Path to save the figure. If not specified, the figure will be shown. - fig_size: List[int] or Tuple[int] - 2-valued figure size (x, y) - random_state: int - Random state for t-SNE. - max_subset_size: int: - In case of performance issues, you can limit the number of streamlines to plot. - """ - - if isinstance(encoded_streamlines, torch.Tensor): - latent_space_streamlines = encoded_streamlines.cpu().numpy() - else: - latent_space_streamlines = encoded_streamlines - - if max_subset_size is not None: - if not (max_subset_size > 0): - raise ValueError("A max_subset_size of an integer value greater than 0 is required.") - - # Only sample if we need to reduce the number of latent streamlines - # to show on the plot. - if (len(latent_space_streamlines) > max_subset_size): - sample_indices = np.random.choice(len(latent_space_streamlines), max_subset_size, replace=False) - latent_space_streamlines = latent_space_streamlines[sample_indices] # (max_subset_size, 2) - - # Project the data into 2 dimensions. - tsne = TSNE(n_components=2, random_state=random_state) - X_tsne = tsne.fit_transform(latent_space_streamlines) # Output (N, 2) - - - logging.info("New figure for t-SNE visualisation.") - fig, ax = plt.subplots() - if fig_size is not None: - fig.set_figheight(fig_size[0]) - fig.set_figwidth(fig_size[1]) - - ax.scatter(X_tsne[:, 0], X_tsne[:, 1], alpha=0.9, edgecolors='black', linewidths=0.5) - - if save_path is not None: - fig.savefig(save_path) - else: - plt.show() - - -class BundlesLatentSpaceVisualizer(object): - """ - Utility class that wraps a t-SNE projection of the latent space for multiple bundles. - The usage of this class is intented as follows: - 1. Create an instance of this class, - 2. Add the latent space streamlines for each bundle using "add_data_to_plot" - with its corresponding label. - 3. Fit and plot the t-SNE projection using the "plot" method. - - t-SNE projection can only leverage the fit_transform() with all the data that needs to - be projected at the same time since it aims to preserve the local structure of the data. - """ - def __init__(self, - save_path: str = None, - fig_size: Union[List, Tuple] = None, - random_state: int = 42, - max_subset_size: int = None - ): - """ - Parameters - ---------- - save_path: str - Path to save the figure. If not specified, the figure will be shown. - fig_size: List[int] or Tuple[int] - 2-valued figure size (x, y) - random_state: List - Random state for t-SNE. - max_subset_size: - In case of performance issues, you can limit the number of streamlines to plot - for each bundle. 
- """ - self.save_path = save_path - self.fig_size = fig_size - self.random_state = random_state - self.max_subset_size = max_subset_size - - self.tsne = TSNE(n_components=2, random_state=self.random_state) - self.bundles = {} - - - def add_data_to_plot(self, data: np.ndarray, label: str = '_'): - """ - Add unprojected data (no t-SNE, no PCA, etc.). - This should be directly the output of the model as a numpy array. - - Parameters - ---------- - data: str - Unprojected latent space streamlines (N, latent_space_dim). - label: str - Name of the bundle. Used for the legend. - """ - if isinstance(data, torch.Tensor): - latent_space_streamlines = data.cpu().numpy() - else: - latent_space_streamlines = data - - if self.max_subset_size is not None: - if not (self.max_subset_size > 0): - raise ValueError("A max_subset_size of an integer value greater than 0 is required.") - - # Only sample if we need to reduce the number of latent streamlines - # to show on the plot. - if (len(latent_space_streamlines) > self.max_subset_size): - sample_indices = np.random.choice(len(latent_space_streamlines), self.max_subset_size, replace=False) - latent_space_streamlines = latent_space_streamlines[sample_indices] # (max_subset_size, 2) - - self.bundles[label] = latent_space_streamlines - - def plot(self): - """ - Fit and plot the t-SNE projection of the latent space streamlines. - This should be called once after adding all the data to plot using "add_data_to_plot". - """ - nb_streamlines = sum(b.shape[0] for b in self.bundles.values()) - logging.info("Plotting a total of {} streamlines".format(nb_streamlines)) - - bundles_indices = {} - current_start = 0 - for (bname, bdata) in self.bundles.items(): - bundles_indices[bname] = np.arange(current_start, current_start + bdata.shape[0]) - current_start += bdata.shape[0] - - assert current_start == nb_streamlines - - all_streamlines = np.concatenate(list(self.bundles.values()), axis=0) - # Set NaNs to 0 - all_streamlines[np.isnan(all_streamlines)] = np.zeros(32) - logging.info("Fitting TSNE projection.") - all_projected_streamlines = self.tsne.fit_transform(all_streamlines) - - logging.info("New figure for t-SNE visualisation.") - fig, ax = plt.subplots() - if self.fig_size is not None: - fig.set_figheight(self.fig_size[0]) - fig.set_figwidth(self.fig_size[1]) - - for (bname, bdata) in self.bundles.items(): - bindices = bundles_indices[bname] - proj_data = all_projected_streamlines[bindices] - ax.scatter( - proj_data[:, 0], - proj_data[:, 1], - label=bname, - alpha=0.9, - edgecolors='black', - linewidths=0.5, - ) - - ax.legend() - - if self.save_path is not None: - fig.savefig(self.save_path) - else: - plt.show() - diff --git a/scripts_python/ae_visualize_bundles.py b/scripts_python/ae_visualize_bundles.py deleted file mode 100644 index 8f3260ed..00000000 --- a/scripts_python/ae_visualize_bundles.py +++ /dev/null @@ -1,110 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -import argparse -import logging -import pathlib -import torch -import numpy as np -from glob import glob -from os.path import expanduser -from dipy.tracking.streamline import set_number_of_points - -from scilpy.io.utils import (add_overwrite_arg, - assert_outputs_exist, - add_reference_arg, - add_verbose_arg) -from scilpy.io.streamlines import load_tractogram_with_reference -from dwi_ml.io_utils import (add_arg_existing_experiment_path, - add_memory_args) -from dwi_ml.models.projects.ae_next_models import ModelConvNextAE -from dwi_ml.viz.latent_streamlines import BundlesLatentSpaceVisualizer - 
- -def _build_arg_parser(): - p = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter, - description=__doc__) - # Mandatory - # Should only be False for debugging tests. - add_arg_existing_experiment_path(p) - # Add_args_testing_subj_hdf5(p) - - p.add_argument('in_bundles', - help="The 'glob' path to several bundles identified by their file name." - "e.g. FiberCupGroundTruth_filtered_bundle_0.tck") - - p.add_argument('out_path') - - # Options - p.add_argument('--batch_size', type=int) - add_memory_args(p) - - p.add_argument('--pick_at_random', action='store_true') - add_reference_arg(p) - add_overwrite_arg(p) - add_verbose_arg(p) - return p - -def load_bundles(p, args, files_list: list): - bundles = [] - for bundle_file in files_list: - bundle_sft = load_tractogram_with_reference(p, args, bundle_file) - bundle_sft.to_vox() - bundle_sft.to_corner() - bundles.append(bundle_sft) - return bundles - -def main(): - p = _build_arg_parser() - args = p.parse_args() - - # Setting log level to INFO maximum for sub-loggers, else it becomes ugly, - # but we will set trainer to user-defined level. - sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO' - - # General logging (ex, scilpy: Warning) - logging.getLogger().setLevel(level=logging.WARNING) - - # Verify output names - # Check experiment_path exists and best_model folder exists - # Assert_inputs_exist(p, args.hdf5_file) - assert_outputs_exist(p, args, []) - - # Device - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - # 1. Load model - logging.debug("Loading model.") - model = ModelConvNextAE.load_model_from_params_and_state( - args.experiment_path + '/best_model', log_level=sub_loggers_level).to(device) - - expanded = expanduser(args.in_bundles) - bundles_files = glob(expanded) - if isinstance(bundles_files, str): - bundles_files = [bundles_files] - - bundles_label = [pathlib.Path(l).stem for l in bundles_files] - bundles_sft = load_bundles(p, args, bundles_files) - - logging.info("Running model to compute loss") - - ls_viz = BundlesLatentSpaceVisualizer( - save_path=args.out_path) - - with torch.no_grad(): - for i, bundle_sft in enumerate(bundles_sft): - - # Resample - s_resampled = np.asarray( - set_number_of_points(bundle_sft.streamlines, 256)) - streamlines = torch.as_tensor(s_resampled, - dtype=torch.float32, device=device) - - latent_streamlines = model.encode(streamlines).cpu().numpy() # output of (N, 32) - print(latent_streamlines.shape) - ls_viz.add_data_to_plot(latent_streamlines, label=bundles_label[i]) - - ls_viz.plot() - - -if __name__ == '__main__': - main() From fbae16dc82a43ef2439b8b9d8a5fc4295c09d8c3 Mon Sep 17 00:00:00 2001 From: AntoineTheb Date: Tue, 1 Oct 2024 12:14:34 -0400 Subject: [PATCH 11/14] ENH: remove .sh --- command_ae.sh | 21 --------------------- 1 file changed, 21 deletions(-) delete mode 100644 command_ae.sh diff --git a/command_ae.sh b/command_ae.sh deleted file mode 100644 index 4e925a56..00000000 --- a/command_ae.sh +++ /dev/null @@ -1,21 +0,0 @@ -experiments=experiments -experiment_name=fibercup_september24 - -rm -rf $experiments/$experiment_name - -ae_train_model.py $experiments \ - $experiment_name \ - fibercup_tracking.hdf5 \ - target \ - -v INFO \ - --batch_size_training 1100 \ - --batch_size_units nb_streamlines \ - --nb_subjects_per_batch 5 \ - --learning_rate 0.00001*300 0.000005 \ - --weight_decay 0.2 \ - --optimizer Adam \ - --max_epochs 5000 \ - --max_batches_per_epoch_training 9999 \ - --comet_workspace dwi-ml \ - --comet_project 
ae-fibercup \ - --patience 100 --use_gpu From 9cd5436956d0dd5acb7b3a4c100c9b4e2b96a2fa Mon Sep 17 00:00:00 2001 From: AntoineTheb Date: Wed, 2 Oct 2024 16:32:14 -0400 Subject: [PATCH 12/14] ENH: cleanup --- dwi_ml/models/projects/ae_next_models.py | 51 ++++----- scripts_python/ae_autoencode_streamlines.py | 109 ++++++++++---------- 2 files changed, 82 insertions(+), 78 deletions(-) mode change 100644 => 100755 scripts_python/ae_autoencode_streamlines.py diff --git a/dwi_ml/models/projects/ae_next_models.py b/dwi_ml/models/projects/ae_next_models.py index 1bbe7b0f..b8218796 100644 --- a/dwi_ml/models/projects/ae_next_models.py +++ b/dwi_ml/models/projects/ae_next_models.py @@ -102,14 +102,24 @@ class ModelConvNextAE(MainModelAbstract): on computer vision and pattern recognition (pp. 11976-11986). """ - def __init__(self, kernel_size, latent_space_dims, - experiment_name: str, - # Target preprocessing params for the batch loader + tracker - step_size: float = None, - compress_lines: float = False, - # Other - log_level=logging.root.level): - super().__init__(experiment_name, step_size, compress_lines, log_level) + def __init__( + self, + kernel_size, + latent_space_dims, + experiment_name: str, + nb_points: int = None, + step_size: float = None, + compress_lines: float = False, + # Other + log_level=logging.root.level + ): + super().__init__( + experiment_name, + step_size=step_size, + nb_points=nb_points, + compress_lines=compress_lines, + log_level=log_level + ) self.kernel_size = kernel_size self.latent_space_dims = latent_space_dims @@ -219,15 +229,10 @@ def params_for_checkpoint(self): 'experiment_name': self.experiment_name} return p - @classmethod - def _load_params(cls, model_dir): - p = super()._load_params(model_dir) - p['latent_space_dims'] = 32 - return p - - def forward(self, - input_streamlines: List[torch.tensor], - ): + def forward( + self, + input_streamlines: List[torch.tensor], + ): """Run the model on a batch of sequences. Parameters @@ -257,7 +262,7 @@ def encode(self, x: List[torch.Tensor]) -> torch.Tensor: Returns ------- - torch.Tensor + z : torch.Tensor Input data encoded to the latent space. """ @@ -272,8 +277,8 @@ def encode(self, x: List[torch.Tensor]) -> torch.Tensor: # Flatten h7 = x.reshape(-1, self.encoder_out_size[0] * self.encoder_out_size[1]) - fc1 = self.fc1(h7) - return fc1 + z = self.fc1(h7) + return z def decode(self, z: torch.Tensor) -> torch.Tensor: """ Decode the input data. @@ -285,7 +290,7 @@ def decode(self, z: torch.Tensor) -> torch.Tensor: Returns ------- - torch.Tensor + x_hat : torch.Tensor Decoded data. """ @@ -293,8 +298,8 @@ def decode(self, z: torch.Tensor) -> torch.Tensor: fc_reshape = fc.view( -1, self.encoder_out_size[0], self.encoder_out_size[1] ) - z = self.decoder(fc_reshape) - return z + x_hat = self.decoder(fc_reshape) + return x_hat def compute_loss(self, model_outputs, targets, average_results=True): """Compute the loss of the model. 
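With the constructor above now forwarding step_size, nb_points and compress_lines to the base class, instantiating the refactored model goes roughly as in the sketch below. Values mirror the training script; nb_points=256 is an assumption (handled by MainModelAbstract), and streamlines is a stand-in for a list of resampled streamline tensors.

model = ModelConvNextAE(
    kernel_size=2,
    latent_space_dims=32,
    experiment_name='my_experiment',  # illustrative name
    nb_points=256,                    # assumed base-class parameter
    step_size=None,
    compress_lines=False,
)
z = model.encode(streamlines)  # (N, 32) latent vectors
x_hat = model.decode(z)        # (N, 3, 256) reconstructions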
diff --git a/scripts_python/ae_autoencode_streamlines.py b/scripts_python/ae_autoencode_streamlines.py old mode 100644 new mode 100755 index 59bd9110..03fa4862 --- a/scripts_python/ae_autoencode_streamlines.py +++ b/scripts_python/ae_autoencode_streamlines.py @@ -16,6 +16,7 @@ from dipy.tracking.streamline import set_number_of_points from dwi_ml.io_utils import (add_arg_existing_experiment_path, add_memory_args) +from dwi_ml.models.projects.ae_models import ModelAE from dwi_ml.models.projects.ae_next_models import ModelConvNextAE from nibabel.streamlines import detect_format @@ -36,22 +37,60 @@ def _build_arg_parser(): help="If set, saves the tractogram with the loss per point " "as a data per point (color)") + # Additional arg for projects + p.add_argument('--model', type=str, choices=['finta', 'convnext'], + default='finta', + help='Type of model to train') + # Options - p.add_argument('--batch_size', type=int) + p.add_argument('--normalize', action='store_true') + p.add_argument('--batch_size', type=int, default=5000) add_memory_args(p) - p.add_argument('--pick_at_random', action='store_true') add_reference_arg(p) add_overwrite_arg(p) add_verbose_arg(p) return p +def autoencode_streamlines( + model, sft, device, batch_size=5000, normalize=False +): + sft.to_vox() + sft.to_corner() + + bundle = set_number_of_points(sft.streamlines, 256) + + logging.info("Running model to compute loss") + batch_size = batch_size + batches = range(0, len(sft.streamlines), batch_size) + + def _autoencode_streamlines(): + for i, b in enumerate(batches): + with torch.no_grad(): + s = np.asarray( + set_number_of_points( + bundle[i * batch_size:(i+1) * batch_size], + 256)) + if normalize: + s /= sft.dimensions + + streamlines = torch.as_tensor( + s, dtype=torch.float32, device=device) + tmp_outputs = model(streamlines).cpu().numpy() + # latent = model.encode(streamlines) + scaling = sft.dimensions if normalize else 1.0 + streamlines_output = tmp_outputs.transpose((0, 2, 1)) * scaling + for strml in streamlines_output: + yield strml, strml[0] + + return _autoencode_streamlines + + def main(): p = _build_arg_parser() args = p.parse_args() - normalize = True tracts_format = detect_format(args.out_tractogram) # Setting log level to INFO maximum for sub-loggers, else it becomes ugly, # but we will set trainer to user-defined level. @@ -71,66 +110,26 @@ def main(): # 1. Load model logging.debug("Loading model.") - model = ModelConvNextAE.load_model_from_params_and_state( - args.experiment_path + '/best_model', log_level=sub_loggers_level).to( - device) - # model.set_context('training') - # 2. 
Compute loss - # tester = TesterOneInput(args.experiment_path, - # model, - # args.batch_size, - # device) - # tester = Tester(args.experiment_path, model, args.batch_size, device) - # sft = tester.load_and_format_data(args.subj_id, - # args.hdf5_file, - # args.subset) + if args.model == 'finta': + architecture = ModelAE + else: + architecture = ModelConvNextAE - sft = load_tractogram_with_reference(p, args, args.in_tractogram) - sft.to_vox() - sft.to_corner() - - bundle = sft.streamlines + model = architecture.load_model_from_params_and_state( + args.experiment_path + '/best_model', log_level=sub_loggers_level + ).to(device) - bundle = sft.streamlines - - bundle = set_number_of_points(bundle, 256) - - logging.info("Running model to compute loss") - batch_size = 5000 - batches = range(0, len(sft.streamlines), batch_size) - - def _autoencode_streamlines(): - for i, b in enumerate(batches): - with torch.no_grad(): - s = np.asarray( - set_number_of_points( - bundle[i * batch_size:(i+1) * batch_size], - 256)) - if normalize: - s /= sft.dimensions + sft = load_tractogram_with_reference(p, args, args.in_tractogram) - streamlines = torch.as_tensor( - s, dtype=torch.float32, device=device) - tmp_outputs = model(streamlines).cpu().numpy() - # latent = model.encode(streamlines) - scaling = sft.dimensions if normalize else 1.0 - streamlines_output = tmp_outputs.transpose((0, 2, 1)) * scaling - for strml in streamlines_output: - yield strml, strml[0] + _autoencode_streamlines = autoencode_streamlines( + model, sft, device, args.batch_size, args.normalize) # Need a nifti image to lazy-save a tractogram fake_ref = nib.Nifti1Image(np.zeros(sft.dimensions), sft.affine) save_tractogram(_autoencode_streamlines(), tracts_format, fake_ref, - len(bundle), args.out_tractogram, 0, 999, False, False, - args.verbose) - - # latent_output = [s.cpu().numpy() for s in latent] - - # outputs, losses = tester.run_model_on_sft( - # sft, uncompress_loss=args.uncompress_loss, - # force_compress_loss=args.force_compress_loss, - # weight_with_angle=args.weight_with_angle) + len(sft.streamlines), args.out_tractogram, + 0, 999, False, False, args.verbose) if __name__ == '__main__': From 25fe9209cd75e1031126132a6a874e23ef5927cb Mon Sep 17 00:00:00 2001 From: AntoineTheb Date: Thu, 3 Oct 2024 11:11:34 -0400 Subject: [PATCH 13/14] ENH: remove unused function --- dwi_ml/models/projects/ae_next_models.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/dwi_ml/models/projects/ae_next_models.py b/dwi_ml/models/projects/ae_next_models.py index b8218796..3ca42ee9 100644 --- a/dwi_ml/models/projects/ae_next_models.py +++ b/dwi_ml/models/projects/ae_next_models.py @@ -216,19 +216,6 @@ def __init__( padding_mode='reflect'), ) - @property - def params_for_checkpoint(self): - """All parameters necessary to create again the same model. Will be - used in the trainer, when saving the checkpoint state. Params here - will be used to re-create the model when starting an experiment from - checkpoint. 
You should be able to re-create an instance of your
-        model with those params."""
-        # p = super().params_for_checkpoint()
-        p = {'kernel_size': self.kernel_size,
-             'latent_space_dims': self.latent_space_dims,
-             'experiment_name': self.experiment_name}
-        return p
-
     def forward(
             self,
             input_streamlines: List[torch.tensor],

From 34701e006540b9ac92f0625bd7663638b19ba762 Mon Sep 17 00:00:00 2001
From: AntoineTheb
Date: Thu, 3 Oct 2024 17:03:36 -0400
Subject: [PATCH 14/14] FIX: documentation

---
 scripts_python/ae_autoencode_streamlines.py | 38 ++++++++++++++++-----
 1 file changed, 30 insertions(+), 8 deletions(-)

diff --git a/scripts_python/ae_autoencode_streamlines.py b/scripts_python/ae_autoencode_streamlines.py
index 03fa4862..7db2912c 100755
--- a/scripts_python/ae_autoencode_streamlines.py
+++ b/scripts_python/ae_autoencode_streamlines.py
@@ -29,18 +29,16 @@ def _build_arg_parser():
     add_arg_existing_experiment_path(p)
     # Add_args_testing_subj_hdf5(p)

-    p.add_argument('in_tractogram',
-                   help="If set, saves the tractogram with the loss per point "
-                        "as a data per point (color)")
+    p.add_argument('in_tractogram', type=str,
+                   help="Tractogram to autoencode.")

-    p.add_argument('out_tractogram',
-                   help="If set, saves the tractogram with the loss per point "
-                        "as a data per point (color)")
+    p.add_argument('out_tractogram', type=str,
+                   help="Autoencoded tractogram.")

     # Additional arg for projects
     p.add_argument('--model', type=str, choices=['finta', 'convnext'],
                    default='finta',
-                   help='Type of model to train')
+                   help='Type of model to use.')

     # Options
     p.add_argument('--normalize', action='store_true')
@@ -56,10 +54,34 @@ def _build_arg_parser():
 def autoencode_streamlines(
     model, sft, device, batch_size=5000, normalize=False
 ):
+    """ Autoencode a tractogram. Streamlines are loaded in batches to make
+    it possible to autoencode large tractograms. This function returns a
+    generator instead of an actual tractogram for the same reason.
+
+    Parameters
+    ----------
+    model: MainModelAbstract
+        The trained autoencoder.
+    sft: StatefulTractogram
+        Tractogram to autoencode.
+    device: torch.device
+        GPU or CPU.
+    batch_size: int
+        Number of streamlines to autoencode at once.
+    normalize: bool
+        Whether to normalize the streamline coordinates before
+        inputting them to the model.
+
+    Returns
+    -------
+    _autoencode_streamlines: generator
+        Generator function that yields autoencoded streamlines one at a
+        time.
+    """
+
     sft.to_vox()
     sft.to_corner()

-    bundle = set_number_of_points(sft.streamlines, 256)
+    bundle = sft.streamlines

     logging.info("Running model to compute loss")
     batch_size = batch_size
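For completeness, the consuming side of autoencode_streamlines() in main(), as wired by the cleanup patch above, goes roughly as follows. This is a sketch assembled from the diffs in this series; the positional save_tractogram signature is reproduced as the patch shows it, not independently verified.

import nibabel as nib
import numpy as np

# The helper returns the inner generator function, which is then called.
_autoencode_streamlines = autoencode_streamlines(
    model, sft, device, args.batch_size, args.normalize)

# A reference anatomy is needed to lazy-save a tractogram.
fake_ref = nib.Nifti1Image(np.zeros(sft.dimensions), sft.affine)
save_tractogram(_autoencode_streamlines(), tracts_format, fake_ref,
                len(sft.streamlines), args.out_tractogram,
                0, 999, False, False, args.verbose)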