diff --git a/.circleci/config.yml b/.circleci/config.yml deleted file mode 100644 index fa69d6fb..00000000 --- a/.circleci/config.yml +++ /dev/null @@ -1,34 +0,0 @@ -version: 2.1 - -orbs: - python: circleci/python@1.2 - -jobs: - build-and-test-python38: - docker: - - image: cimg/python:3.8 - steps: - - checkout - - python/install-packages: - pkg-manager: pip - app-dir: ~/project/requirements/ - - run: - name: Run tests python 3.8 - command: pytest - build-and-test-python39: - docker: - - image: cimg/python:3.9 - steps: - - checkout - - python/install-packages: - pkg-manager: pip - app-dir: ~/project/requirements/ - - run: - name: Run tests python 3.9 - command: pytest - -workflows: - sample: - jobs: - - build-and-test-python38 - - build-and-test-python39 diff --git a/CITATION.cff b/CITATION.cff index cf3e5baf..81e44fad 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -4,10 +4,13 @@ authors: - family-names: Karkalousos given-names: Dimitrios orcid: https://orcid.org/0000-0001-5983-0322 + - family-names: Zhang + given-names: Chaoping + orcid: https://orcid.org/0000-0002-6004-983X - family-names: Caan given-names: Matthan orcid: https://orcid.org/0000-0002-5162-8880 title: "MRI Data Consistency" url: "https://github.com/wdika/mridc" -version: 0.1.1 -date-released: 2022-25-05 +version: 0.2.0 +date-released: 2022-12-09 diff --git a/Dockerfile b/Dockerfile index b715f2f8..2450f50e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -29,7 +29,7 @@ COPY . . # start building the final container FROM mridc-deps as mridc -ARG MRIDC_VERSION=0.1.1 +ARG MRIDC_VERSION=0.2.0 # Check that MRIDC_VERSION is set. Build will fail without this. Expose MRIDC and base container # version information as runtime environment variable for introspection purposes diff --git a/README.md b/README.md index 7719aa39..f13b3d18 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ # Data Consistency for Magnetic Resonance Imaging [![CodeQL](https://github.com/wdika/mridc/actions/workflows/codeql-analysis.yml/badge.svg)](https://github.com/wdika/mridc/actions/workflows/codeql-analysis.yml) -[![CircleCI](https://circleci.com/gh/wdika/mridc/tree/main.svg?style=svg)](https://circleci.com/gh/wdika/mridc/tree/main) [![codecov](https://codecov.io/gh/wdika/mridc/branch/main/graph/badge.svg?token=KPPQ33DOTF)](https://codecov.io/gh/wdika/mridc) +[![Tox](https://github.com/wdika/mridc/actions/workflows/tox.yml/badge.svg)](https://github.com/wdika/mridc/actions/workflows/tox.yml) Code style: black --- @@ -36,7 +36,7 @@ The following models are implemented for quantitative imaging: 1.[quantitative Cascades of Independently Recurrent Inference Machines (qCIRIM)](https://iopscience.iop.org/article/10.1088/1361-6560/ac6cc2), 2.[quantitative End-to-End Variational Network (qE2EVN)](https://link.springer.com/chapter/10.1007/978-3-030-59713-9_7), 3.[quantitative Independently Recurrent Inference Machines (qIRIM)](http://arxiv.org/abs/2012.07819), -4.[quantitative Recurrent Inference Machines (qRIM)](https://www.sciencedirect.com/science/article/abs/pii/S1361841518306078?via%3Dihub), +4.[quantitative Recurrent Inference Machines (qRIM)](https://www.sciencedirect.com/science/article/abs/pii/S1361841518306078?via%3Dihub). _Note: Currently only the above models are implemented. More models can be added by extending the reconstruction models for quantitative imaging. If you wish to extend the toolbox, please open an issue._ diff --git a/codecov.yml b/codecov.yml index 9dd41480..44bf43af 100644 --- a/codecov.yml +++ b/codecov.yml @@ -13,6 +13,12 @@ coverage: # -------------- # which folders/files to ignore ignore: + - mridc/launch.py + - mridc/core/utils/* + - mridc/utils/arguments.py + - mridc/utils/distributed.py + - mridc/utils/export_utils.py + - mridc/utils/decorators/* - projects/* - setup.py diff --git a/docs/source/mridc.collections.common.parts.rst b/docs/source/mridc.collections.common.parts.rst index b06319e8..86c3f15f 100644 --- a/docs/source/mridc.collections.common.parts.rst +++ b/docs/source/mridc.collections.common.parts.rst @@ -36,6 +36,14 @@ mridc.collections.common.parts.rnn\_utils module :undoc-members: :show-inheritance: +mridc.collections.common.parts.training\_utils module +----------------------------------------------------- + +.. automodule:: mridc.collections.common.parts.training_utils + :members: + :undoc-members: + :show-inheritance: + mridc.collections.common.parts.utils module ------------------------------------------- diff --git a/docs/source/mridc.utils.rst b/docs/source/mridc.utils.rst index dea4ead2..d0b4107c 100644 --- a/docs/source/mridc.utils.rst +++ b/docs/source/mridc.utils.rst @@ -45,6 +45,14 @@ mridc.utils.config\_utils module :undoc-members: :show-inheritance: +mridc.utils.debug\_hook module +------------------------------ + +.. automodule:: mridc.utils.debug_hook + :members: + :undoc-members: + :show-inheritance: + mridc.utils.distributed module ------------------------------ diff --git a/mridc/collections/common/data/dataset.py b/mridc/collections/common/data/dataset.py index 928916be..c04a09f0 100644 --- a/mridc/collections/common/data/dataset.py +++ b/mridc/collections/common/data/dataset.py @@ -4,12 +4,13 @@ # Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/data/dataset.py from abc import ABC -from typing import Any, List +from typing import Any, Dict, List import numpy as np import torch.utils.data as pt_data +from torch.utils.data import Dataset, IterableDataset -__all__ = ["ConcatDataset"] +__all__ = ["ConcatDataset", "ConcatMapDataset"] class ConcatDataset(pt_data.IterableDataset, ABC): @@ -145,3 +146,95 @@ def random_generator(datasets, **kwargs): while True: yield np.random.choice(np.arange(num), p=p) + + +class ConcatMapDataset(Dataset): + """ + A dataset that accepts as argument multiple datasets and then samples from them based on the specified + sampling technique. + + Parameters + ---------- + datasets: A list of datasets to sample from. + shuffle: Whether to shuffle individual datasets. Only works with non-iterable datasets. Defaults to True. + sampling_technique: Sampling technique to choose which dataset to draw a sample from. Defaults to 'random'. + Currently supports 'random' and 'round-robin'. + sampling_probabilities: Probability values for sampling. Only used when sampling_technique = 'random'. + global_rank: Worker rank, used for partitioning map style datasets. Defaults to 0. + world_size: Total number of processes, used for partitioning map style datasets. Defaults to 1. + """ + + def __init__( + self, + datasets: List[Any], + sampling_technique: str = "temperature", + sampling_temperature: int = 5, + sampling_probabilities: List[float] = None, + consumed_samples: int = 0, + ): + super().__init__() + self.datasets = datasets + self.sampling_kwargs: Dict = {} + self.size = 0 + self.sampling_technique = sampling_technique + self.sampling_temperature = sampling_temperature + self.sampling_probabilities = sampling_probabilities + self.consumed_samples = consumed_samples + self.np_rng = np.random.RandomState(consumed_samples) + for dataset in datasets: + self.size += len(dataset) + self.dataset_index = np.zeros(len(self.datasets), dtype=np.uint8) + self.permuted_dataset_indices = [] + for dataset in self.datasets: + permuted_indices = np.arange(len(dataset)) + self.np_rng.shuffle(permuted_indices) + self.permuted_dataset_indices.append(permuted_indices) + if self.sampling_technique == "temperature": + lengths = [len(dataset) for dataset in datasets] + p = np.array(lengths) / np.sum(lengths) + p = np.power(p, 1 / self.sampling_temperature) + p = p / np.sum(p) + self.p = p + elif self.sampling_technique == "random": + if not self.sampling_probabilities: + raise ValueError( + "Random generator expects a 'sampling_probabilities' - a list of probability values corresponding " + "to each dataset." + ) + if len(self.sampling_probabilities) != len(self.datasets): + raise ValueError( + "Length of probabilities list must be equal to the number of datasets. " # type: ignore + f"Found {len(sampling_probabilities)} probs and {len(self.datasets)} datasets." # type: ignore + ) + p = np.array(self.sampling_probabilities) + self.p = p / np.sum(p) + + def __len__(self): + return self.size + + def _get_dataset_index(self, idx): + """Returns the index of the dataset to sample from.""" + if self.sampling_technique in ["temperature", "random"]: + return self.np_rng.choice(np.arange(len(self.datasets)), p=self.p) + elif self.sampling_technique == "round-robin": + return idx % len(self.datasets) + + def __getitem__(self, idx): + # Get the dataset we want to sample from + dataset_index = self._get_dataset_index(idx) + + # Get the index of the sample we want to fetch from the dataset + sample_idx = self.dataset_index[dataset_index] + + # If the sample idx > dataset size, reset to 0. + if sample_idx > len(self.datasets[dataset_index]): + sample_idx = 0 + self.dataset_index[dataset_index] = 0 + + # Sample index -> shuffled sample index + shuffled_sample_idx = self.permuted_dataset_indices[dataset_index][sample_idx] + + sample = self.datasets[dataset_index][shuffled_sample_idx] + self.dataset_index[dataset_index] += 1 + + return sample diff --git a/mridc/collections/common/parts/patch_utils.py b/mridc/collections/common/parts/patch_utils.py index 398b0c4a..1c093e1d 100644 --- a/mridc/collections/common/parts/patch_utils.py +++ b/mridc/collections/common/parts/patch_utils.py @@ -7,4 +7,4 @@ # Library version globals TORCH_VERSION = None -TORCH_VERSION_MIN = version.Version("1.9.0") +TORCH_VERSION_MIN = version.Version("1.8.0") diff --git a/mridc/collections/common/parts/training_utils.py b/mridc/collections/common/parts/training_utils.py new file mode 100644 index 00000000..fde1822c --- /dev/null +++ b/mridc/collections/common/parts/training_utils.py @@ -0,0 +1,28 @@ +# encoding: utf-8 +__author__ = "Dimitrios Karkalousos" + +# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/parts/training_utils.py + +from contextlib import nullcontext + +import torch + +__all__ = ["avoid_bfloat16_autocast_context", "avoid_float16_autocast_context"] + + +def avoid_bfloat16_autocast_context(): + """If the current autocast context is bfloat16, cast it to float32.""" + if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.bfloat16: + return torch.cuda.amp.autocast(dtype=torch.float32) + else: + return nullcontext() + + +def avoid_float16_autocast_context(): + """If the current autocast context is float16, cast it to bfloat16 if available or float32.""" + if not torch.is_autocast_enabled() or torch.get_autocast_gpu_dtype() != torch.float16: + return nullcontext() + if torch.cuda.is_bf16_supported(): + return torch.cuda.amp.autocast(dtype=torch.bfloat16) + else: + return torch.cuda.amp.autocast(dtype=torch.float32) diff --git a/mridc/core/classes/export.py b/mridc/core/classes/export.py index fc823f89..8ffd72f1 100644 --- a/mridc/core/classes/export.py +++ b/mridc/core/classes/export.py @@ -5,6 +5,7 @@ from abc import ABC from os.path import exists +from typing import List, Union import torch from torch.onnx import TrainingMode @@ -47,9 +48,10 @@ def export( do_constant_folding=True, onnx_opset_version=None, training=TrainingMode.EVAL, - check_trace: bool = False, + check_trace: Union[bool, List[torch.Tensor]] = False, dynamic_axes=None, check_tolerance=0.01, + export_modules_as_functions: bool = False, ): """ Export the module to a file. @@ -65,6 +67,7 @@ def export( check_trace: If True, check the trace of the exported model. dynamic_axes: A dictionary of input names and dynamic axes. check_tolerance: The tolerance for the check_trace. + export_modules_as_functions: If True, export modules as functions. """ all_out = [] all_descr = [] @@ -81,6 +84,7 @@ def export( check_trace=check_trace, dynamic_axes=dynamic_axes, check_tolerance=check_tolerance, + export_modules_as_functions=export_modules_as_functions, ) # Propagate input example (default scenario, may need to be overriden) if input_example is not None: @@ -101,6 +105,7 @@ def _export( check_trace: bool = False, dynamic_axes=None, check_tolerance=0.01, + export_modules_as_functions: bool = False, ): """ Helper to export the module to a file. @@ -116,15 +121,13 @@ def _export( check_trace: If True, check the trace of the exported model. dynamic_axes: A dictionary of input names and dynamic axes. check_tolerance: The tolerance for the check_trace. + export_modules_as_functions: If True, export modules as functions. """ my_args = locals().copy() my_args.pop("self") - exportables = [] - for m in self.modules(): # type: ignore - if isinstance(m, Exportable): - exportables.append(m) - qual_name = self.__module__ + "." + self.__class__.__qualname__ + exportables = [m for m in self.modules() if isinstance(m, Exportable)] # type: ignore + qual_name = f"{self.__module__}.{self.__class__.__qualname__}" format = get_export_format(output) output_descr = f"{qual_name} exported to {format}" @@ -191,10 +194,12 @@ def _export( do_constant_folding=do_constant_folding, dynamic_axes=dynamic_axes, opset_version=onnx_opset_version, + export_modules_as_functions=export_modules_as_functions, ) if check_trace: - verify_runtime(output, input_list, input_dict, input_names, output_names, output_example) + check_trace_input = [input_example] if isinstance(check_trace, bool) else check_trace + verify_runtime(self, output, check_trace_input, input_names) else: raise ValueError(f"Encountered unknown export format {format}.") diff --git a/mridc/core/classes/modelPT.py b/mridc/core/classes/modelPT.py index 45127980..a5518d8a 100644 --- a/mridc/core/classes/modelPT.py +++ b/mridc/core/classes/modelPT.py @@ -539,8 +539,8 @@ def setup_optimization(self, optim_config: Optional[Union[DictConfig, Dict]] = N for i in range(len(scheduler_config["name"])) ] - self._scheduler = _schedulers - self._optimizer = [self._optimizer] * len(scheduler_config["name"]) + self._scheduler = _schedulers # type: ignore + self._optimizer = [self._optimizer] * len(scheduler_config["name"]) # type: ignore else: # Try to instantiate scheduler for optimizer self._scheduler = mridc.core.optim.lr_scheduler.prepare_lr_scheduler( # type: ignore diff --git a/mridc/core/conf/schedulers.py b/mridc/core/conf/schedulers.py index 3601f2d9..872a1ee7 100644 --- a/mridc/core/conf/schedulers.py +++ b/mridc/core/conf/schedulers.py @@ -81,6 +81,16 @@ class NoamAnnealingParams(WarmupSchedulerParams): min_lr: float = 0.0 +@dataclass +class NoamHoldAnnealingParams(WarmupHoldSchedulerParams): + """ + Polynomial Hold Decay Annealing parameter config. + It is not derived from Config as it is not a MRIDC object (and in particular it doesn't need a name). + """ + + decay_rate: float = 0.5 + + @dataclass class WarmupAnnealingParams(WarmupSchedulerParams): """Warmup Annealing parameter config""" @@ -205,6 +215,7 @@ def get_scheduler_config(name: str, **kwargs: Optional[Dict[str, Any]]) -> parti "SquareRootConstantSchedulerParams": SquareRootConstantSchedulerParams, "CosineAnnealingParams": CosineAnnealingParams, "NoamAnnealingParams": NoamAnnealingParams, + "NoamHoldAnnealingParams": NoamHoldAnnealingParams, "WarmupAnnealingParams": WarmupAnnealingParams, "PolynomialDecayAnnealingParams": PolynomialDecayAnnealingParams, "PolynomialHoldDecayAnnealingParams": PolynomialHoldDecayAnnealingParams, diff --git a/mridc/core/conf/trainer.py b/mridc/core/conf/trainer.py index d9e9ef01..a4861583 100644 --- a/mridc/core/conf/trainer.py +++ b/mridc/core/conf/trainer.py @@ -18,17 +18,13 @@ class TrainerConfig: """TrainerConfig is a dataclass that holds all the hyperparameters for the training process.""" logger: Any = True - checkpoint_callback: Any = True callbacks: Optional[Any] = None default_root_dir: Optional[str] = None gradient_clip_val: float = 0 - process_position: int = 0 num_nodes: int = 1 gpus: Optional[Any] = None auto_select_gpus: bool = False tpu_cores: Optional[Any] = None - log_gpu_memory: Optional[str] = None - progress_bar_refresh_rate: int = 1 enable_progress_bar: bool = True overfit_batches: Any = 0.0 track_grad_norm: Any = -1 @@ -37,18 +33,16 @@ class TrainerConfig: accumulate_grad_batches: Any = 1 max_epochs: int = 1000 min_epochs: int = 1 - max_steps: Optional[int] = None + max_steps: Optional[int] = -1 min_steps: Optional[int] = None limit_train_batches: Any = 1.0 limit_val_batches: Any = 1.0 limit_test_batches: Any = 1.0 val_check_interval: Any = 1.0 - flush_logs_every_n_steps: int = 100 log_every_n_steps: int = 50 accelerator: Optional[str] = None sync_batchnorm: bool = False precision: Any = 32 - weights_summary: Optional[str] = "full" # ModelSummary.MODE_DEFAULT weights_save_path: Optional[str] = None num_sanity_val_steps: int = 2 resume_from_checkpoint: Optional[str] = None @@ -58,23 +52,20 @@ class TrainerConfig: auto_lr_find: Any = False replace_sampler_ddp: bool = True detect_anomaly: bool = False - terminate_on_nan: bool = False auto_scale_batch_size: Any = False - prepare_data_per_node: bool = True amp_backend: str = "native" amp_level: Optional[str] = None plugins: Optional[Any] = None # Optional[Union[str, list]] move_metrics_to_cpu: bool = False multiple_trainloader_mode: str = "max_size_cycle" limit_predict_batches: float = 1.0 - stochastic_weight_avg: bool = False gradient_clip_algorithm: str = "norm" max_time: Optional[Any] = None # can be one of Union[str, timedelta, Dict[str, int], None] reload_dataloaders_every_n_epochs: int = 0 ipus: Optional[int] = None devices: Any = None strategy: Any = None - enable_checkpointing: bool = True + enable_checkpointing: bool = False enable_model_summary: bool = True diff --git a/mridc/core/optim/lr_scheduler.py b/mridc/core/optim/lr_scheduler.py index efc70d1e..3ae54003 100644 --- a/mridc/core/optim/lr_scheduler.py +++ b/mridc/core/optim/lr_scheduler.py @@ -422,6 +422,15 @@ def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle): return lr +def _noam_hold_annealing(initial_lr, step, warmup_steps, hold_steps, decay_rate, min_lr): + """Anneal learning rate by noam hold.""" + # hold_steps = total number of steps to hold the LR, not the warmup + hold steps. + T_warmup_decay = max(1, warmup_steps**decay_rate) + T_hold_decay = max(1, (step - hold_steps) ** decay_rate) + lr = (initial_lr * T_warmup_decay) / T_hold_decay + return max(lr, min_lr) + + class SquareAnnealing(WarmupPolicy): """Anneal learning rate by square.""" @@ -561,13 +570,82 @@ def get_lr(self): def _noam_annealing(self, initial_lr, step): """Noam learning rate annealing.""" - mult = self._normalize * min(step ** (-0.5), step * (self.warmup_steps ** (-1.5))) + mult = ( + self._normalize * min(step ** (-0.5), step * (self.warmup_steps ** (-1.5))) + if self.warmup_steps > 0 + else self._normalize * step ** (-0.5) + ) out_lr = initial_lr * mult if step > self.warmup_steps: out_lr = max(out_lr, self.min_lr) return out_lr +class NoamHoldAnnealing(WarmupHoldPolicy): + def __init__(self, optimizer, *, max_steps, decay_rate=0.5, min_lr=0.0, last_epoch=-1, **kwargs): + """ + Implementation of the Noam Hold Annealing policy from the SqueezeFormer paper. + + Unlike NoamAnnealing, the peak learning rate can be explicitly set for this scheduler. + The schedule first performs linear warmup, then holds the peak LR, then decays with some schedule for + the remainder of the steps. Therefore the min-lr is still dependent on the hyper parameters selected. + + It's schedule is determined by three factors- + + Warmup Steps: Initial stage, where linear warmup occurs uptil the peak LR is reached. Unlike NoamAnnealing, + the peak LR is explicitly stated here instead of a scaling factor. + + Hold Steps: Intermediate stage, where the peak LR is maintained for some number of steps. In this region, + the high peak LR allows the model to converge faster if training is stable. However the high LR + may also cause instability during training. Should usually be a significant fraction of training + steps (around 30-40% of the entire training steps). + + Decay Steps: Final stage, where the LR rapidly decays with some scaling rate (set by decay rate). + To attain Noam decay, use 0.5, for Squeezeformer recommended decay, use 1.0. The fast decay after + prolonged high LR during hold phase allows for rapid convergence. + + References: + - [Squeezeformer: An Efficient Transformer for Automatic Speech Recognition](https://arxiv.org/abs/2206.00888) + + Parameters + ---------- + optimizer : torch.optim.Optimizer + Optimizer to use for the scheduler. + max_steps : int + Total number of training steps. + decay_rate : float + Decay rate for the final stage of the schedule. Should be between 0 and 1. + min_lr : float + Minimum learning rate to use for the schedule. Should be between 0 and 1. + last_epoch : int + Last epoch to start the schedule from. Should be between 0 and max_steps. + """ + self.decay_rate = decay_rate + super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs) + + def _get_lr(self, step): + """Get the learning rate for the given step.""" + if self.warmup_steps is None or self.warmup_steps == 0: + raise ValueError("Noam scheduler cannot be used without warmup steps") + + if self.hold_steps > 0: + hold_steps = self.hold_steps - self.warmup_steps + else: + hold_steps = 0 + + return [ + _noam_hold_annealing( + initial_lr, + step=step, + warmup_steps=self.warmup_steps, + hold_steps=hold_steps, + decay_rate=self.decay_rate, + min_lr=self.min_lr, + ) + for initial_lr in self.base_lrs + ] + + class WarmupAnnealing(WarmupPolicy): """Warmup learning rate annealing.""" @@ -922,6 +1000,9 @@ def prepare_lr_scheduler( if add_max_args_flag and scheduler_config.get("name", "") != "ExponentialLR": scheduler_args["max_steps"] = max_steps + if scheduler_config.get("name", "") == "CyclicLR": + del scheduler_args["max_steps"] + # Get the scheduler class from the config scheduler_cls = get_scheduler(scheduler_name, **scheduler_args) @@ -960,6 +1041,7 @@ def compute_max_steps( logging.warning( "Please note that drop_last is broken in pytorch 1.6.0. We will fix when pytorch 1.7.0 is released" ) + # TODO: Master version, not in pytorch 1.6.0 steps_per_epoch = _round(sampler_num_samples / batch_size) if isinstance(limit_train_batches, int) or limit_train_batches == 0.0: @@ -977,6 +1059,7 @@ def compute_max_steps( "SquareAnnealing": SquareAnnealing, "CosineAnnealing": CosineAnnealing, "NoamAnnealing": NoamAnnealing, + "NoamHoldAnnealing": NoamHoldAnnealing, "WarmupAnnealing": WarmupAnnealing, "InverseSquareRootAnnealing": InverseSquareRootAnnealing, "T5InverseSquareRootAnnealing": T5InverseSquareRootAnnealing, diff --git a/mridc/core/optim/optimizer_with_master_params.py b/mridc/core/optim/optimizer_with_master_params.py index 4cd3e026..776953b0 100644 --- a/mridc/core/optim/optimizer_with_master_params.py +++ b/mridc/core/optim/optimizer_with_master_params.py @@ -172,7 +172,7 @@ def __init__( self._fp32_grad_accum = fp32_grad_accum self._contiguous_grad_bucket = contiguous_grad_bucket - self._async_grad_allreduce = async_grad_allreduce + self._async_grad_allreduce = async_grad_allreduce and get_data_parallel_world_size() > 1 self._grad_divisor = 1 / get_data_parallel_world_size() if self._async_grad_allreduce: @@ -279,7 +279,7 @@ def _make_param_hook(self, param, main_param, i, grad_chunk_info): def param_hook(*unused): """Gradient accumulation and all-reduce.""" - if param.grad.data is None: + if param.grad is None: return if main_param.grad is None: main_param.grad = param.grad.float() @@ -400,6 +400,11 @@ def reload_model_params(self): @torch.no_grad() def step(self, **kwargs): """Step the optimizer.""" + # While async grad allreduce is enabled, bprop will keep moving forward without waiting for the finish of + # async grad AR works. Hence, to guarantee the correctness of grads reduction, we cannot start weight update + # until all async grad AR works are done. + if self._async_grad_allreduce: + torch.cuda.synchronize() self.optimizer.step(closure=None, **kwargs) # Update params from main params. with torch.no_grad(): @@ -462,7 +467,7 @@ def get_parameters(self): def _get_state(self): """Promote state, so it can be retrieved or set via "optimizer_instance.state.""" - return self.optimizer.state + return self.optimizer.state if hasattr(self, "optimizer") else [] def _set_state(self, value): """Promote state, so it can be retrieved or set via "optimizer_instance.state.""" @@ -475,10 +480,20 @@ def _get_param_groups(self): Promote param_groups, so it can be retrieved or set via "optimizer_instance.param_groups. (for example, to adjust the learning rate) """ - return self.optimizer.param_groups + return self.optimizer.param_groups if hasattr(self, "optimizer") else [] def _set_param_groups(self, value): """Set param_groups.""" self.optimizer.param_groups = value param_groups = property(_get_param_groups, _set_param_groups) + + def _get_defaults(self): + """Promote defaults, so it can be retrieved or set via 'optimizer_instance.default'.""" + return self.optimizer.defaults if hasattr(self, "optimizer") else [] + + def _set_defaults(self, value): + """Set defaults.""" + self.optimizer.defaults = value + + defaults = property(_get_defaults, _set_defaults) diff --git a/mridc/package_info.py b/mridc/package_info.py index 83b01498..b817e22e 100644 --- a/mridc/package_info.py +++ b/mridc/package_info.py @@ -1,8 +1,8 @@ # encoding: utf-8 MAJOR = 0 -MINOR = 1 -PATCH = 1 +MINOR = 2 +PATCH = 0 PRE_RELEASE = "" # Use the following formatting: (major, minor, patch, pre-release) diff --git a/mridc/utils/debug_hook.py b/mridc/utils/debug_hook.py new file mode 100644 index 00000000..007f4b5d --- /dev/null +++ b/mridc/utils/debug_hook.py @@ -0,0 +1,222 @@ +# encoding: utf-8 +__author__ = "Dimitrios Karkalousos" + +# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/utils/debug_hook.py + +import os + +import torch + + +def get_forward_hook(name, trainer, rank, logger, dump_to_file=False): + """ + A forward hook to dump all the module input and output norms. It is called at every time after forward() has + computed an output. Only float type input/output tensor norms are computed. + + For more details about the forward hook, check: + https://pytorch.org/docs/stable/generated/torch.nn.modules.module.register_module_forward_hook.html + + Parameters + ---------- + name : str + tensor name + trainer : PTL trainer + PTL trainer + rank : int + worker rank + logger : PTL log function + PTL log function + dump_to_file : bool, optional + wether dump the csv file to the disk, by default False + + Returns + ------- + forward_hook + """ + if dump_to_file: + os.makedirs("debug_info", exist_ok=True) + fp = open(f"debug_info/forward_{name}_rank{rank}.txt", "w") + header = False + + def forward_hook(module, inputs, outputs): + """Forward hook to dump all of the module input and output norms. It is called at every time after forward() + has computed an output. Only float type input/output tensor norms are computed.""" + nonlocal header + nonlocal fp + if trainer.training: + values = [] + headers = [] + for n, i in enumerate(inputs): + if isinstance(i, torch.Tensor) and i.dtype in [torch.float, torch.half, torch.bfloat16]: + if not header: + headers.append("input") + input_norm = i.data.norm() + values.append(f"{input_norm}") + logger(f"debug_info_forward/{name}_rank{rank}_input{n}", input_norm) + if isinstance(outputs, tuple): + for n, i in enumerate(outputs): + if isinstance(i, torch.Tensor) and i.dtype in [torch.float, torch.half, torch.bfloat16]: + if not header: + headers.append("output") + output_norm = i.data.norm() + values.append(f"{output_norm}") + logger(f"debug_info_forward/{name}_rank{rank}_output{n}", output_norm) + else: + headers.append("output") + values.append(f"{outputs.data.norm()}") + values.append(f"{trainer.global_step}") + if not header: + headers.append("step") + fp.write(",".join(headers) + "\n") + header = True + fp.write(",".join(values) + "\n") + fp.flush() + + return forward_hook + + +def get_backward_hook(name, trainer, rank, logger, dump_to_file=False): + """ + A backward hook to dump all the module input and output grad norms. The hook will be called every time the \ + gradients with respect to module inputs are computed. Only float type input/output grad tensor norms are computed. + + For more details about the backward hook, check: + https://pytorch.org/docs/stable/generated/torch.nn.modules.module.register_module_full_backward_hook.html + + Parameters + ---------- + name : str + tensor name + trainer : PTL trainer + PTL trainer + rank : int + worker rank + logger : PTL log function + PTL log function + dump_to_file : bool, optional + wether dump the csv file to the disk, by default False + + Returns + ------- + backward_hook + """ + if dump_to_file: + os.makedirs("debug_info", exist_ok=True) + fp = open(f"debug_info/backward_{name}_rank{rank}.txt", "w") + header = False + + def backward_hook(module, inputs, outputs): + """Backward hook to dump all the module input and output grad norms. The hook will be called every time the \ + has computed an output. Only float type input/output tensor norms are computed.""" + nonlocal header + nonlocal fp + if trainer.training: + values = [] + headers = [] + for n, i in enumerate(inputs): + if isinstance(i, torch.Tensor) and i.dtype in [torch.float, torch.half, torch.bfloat16]: + if not header: + headers.append("input") + input_norm = i.data.norm() + values.append(f"{input_norm}") + logger(f"debug_info_backward/{name}_rank{rank}_input{n}", input_norm) + if isinstance(outputs, tuple): + for n, i in enumerate(outputs): + if isinstance(i, torch.Tensor) and i.dtype in [torch.float, torch.half, torch.bfloat16]: + if not header: + headers.append("output") + output_norm = i.data.norm() + values.append(f"{output_norm}") + logger(f"debug_info_backward/{name}_rank{rank}_output{n}", output_norm) + else: + headers.append("output") + values.append(f"{outputs.data.norm()}") + values.append(f"{trainer.global_step}") + if not header: + headers.append("step") + fp.write(",".join(headers) + "\n") + header = True + fp.write(",".join(values) + "\n") + fp.flush() + + return backward_hook + + +def get_tensor_hook(module, name, trainer, rank, logger, dump_to_file=False): + """ + A tensor hook to dump all of the tensor weight norms and grad norms at the end of each of the backward steps. + + For more details about the tensor hook, check: + https://pytorch.org/docs/stable/generated/torch.Tensor.register_hook.html + + Parameters + ---------- + module : torch.nn.Module + module to register the hook + name : str + tensor name + trainer : PTL trainer + PTL trainer + rank : int + worker rank + logger : PTL log function + PTL log function + dump_to_file : bool, optional + wether dump the csv file to the disk, by default False + + Returns + ------- + tensor_hook + """ + if dump_to_file: + os.makedirs("debug_info", exist_ok=True) + fp = open(f"debug_info/tensor_{name}_rank{rank}.csv", "w") + header = False + + def tensor_hook(grad): + """Tensor hook to dump all the tensor weight norms and grad norms at the end of each of the backward steps.""" + nonlocal header + nonlocal fp + values = [] + headers = [] + + weight = module.get_parameter(name) + weight_norm = weight.data.norm() + grad_norm = grad.data.norm() + logger(f"debug_info_tensors/{name}_rank{rank}_grad_norm", grad_norm) + logger(f"debug_info_tensors/{name}_rank{rank}_weight_norm", weight_norm) + values.append(f"{weight_norm}") + values.append(f"{grad_norm}") + values.append(f"{trainer.global_step}") + if dump_to_file: + if not header: + headers.append("weight") + headers.append("grad") + headers.append("step") + fp.write(",".join(headers) + "\n") + header = True + fp.write(",".join(values) + "\n") + fp.flush() + return grad + + return tensor_hook + + +def register_debug_hooks(module, trainer, logger, dump_to_file=False): + """ + Register debug hooks. It can + 1. track the module forward step input/output norm + 2. track the module backward step input/output grad norm + 3. track the parameter weight norm and grad norm. + """ + # default rank 0 + rank = 0 + if torch.distributed.is_initialized(): + rank = torch.distributed.get_rank() + for name, tensor in module.named_parameters(): + if name != "": + tensor.register_hook(get_tensor_hook(module, name, trainer, rank, logger, dump_to_file)) + for name, layer in module.named_modules(): + if name != "": + layer.register_forward_hook(get_forward_hook(name, trainer, rank, logger, dump_to_file)) + layer.register_full_backward_hook(get_backward_hook(name, trainer, rank, logger, dump_to_file)) diff --git a/mridc/utils/exp_manager.py b/mridc/utils/exp_manager.py index 8b8922ec..ef744736 100644 --- a/mridc/utils/exp_manager.py +++ b/mridc/utils/exp_manager.py @@ -844,19 +844,31 @@ def on_train_end(self, trainer, pl_module): if trainer.fast_dev_run: return None + # check if we need to save a last checkpoint manually as validation isn't always run based on the interval + if self.save_last and trainer.val_check_interval != 0: + should_save_last_checkpoint = False + if isinstance(trainer.val_check_interval, float) and trainer.val_check_interval % trainer.global_step != 0: + should_save_last_checkpoint = True + if isinstance(trainer.val_check_interval, int) and trainer.global_step % trainer.val_check_interval != 0: + should_save_last_checkpoint = True + if should_save_last_checkpoint: + monitor_candidates = self._monitor_candidates(trainer) + super()._save_last_checkpoint(trainer, monitor_candidates) + # Call parent on_train_end() to save the -last checkpoint super().on_train_end(trainer, pl_module) # Load the best model and then re-save it if self.save_best_model: # wait for all processes to finish - trainer.training_type_plugin.barrier("SaveBestCheckpointConnector.resume_end") + trainer.strategy.barrier("SaveBestCheckpointConnector.resume_end") if self.best_model_path == "": logging.warning( f"{self} was told to save the best checkpoint at the end of training, but no saved checkpoints " "were found. Saving latest model instead." ) else: + self.best_model_path = trainer.strategy.broadcast(self.best_model_path) trainer._checkpoint_connector.restore(self.best_model_path) if self.save_mridc_on_train_end: diff --git a/mridc/utils/export_utils.py b/mridc/utils/export_utils.py index bc469e1a..e7d1333f 100644 --- a/mridc/utils/export_utils.py +++ b/mridc/utils/export_utils.py @@ -30,6 +30,21 @@ class ExportFormat(Enum): _EXT_DICT = {".pt": ExportFormat.TORCHSCRIPT, ".ts": ExportFormat.TORCHSCRIPT, ".onnx": ExportFormat.ONNX} +def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32): + """Cast tensor from from_dtype to to_dtype""" + return x.to(dtype=to_dtype) if x.dtype == from_dtype else x + + +def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32): + """Cast all tensors in x from from_dtype to to_dtype""" + if isinstance(x, torch.Tensor): + return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype) + if isinstance(x, dict): + return {k: cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype) for k in x.keys()} + if isinstance(x, tuple): + return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x) + + class CastToFloat(nn.Module): """Cast input to float""" @@ -39,7 +54,9 @@ def __init__(self, mod): def forward(self, x): """Forward pass""" - return self.mod.forward(x.to(torch.float).to(x.dtype)) if torch.is_autocast_enabled() else self.mod.forward(x) + return ( + self.mod.forward(x.to(torch.float32).to(x.dtype)) if torch.is_autocast_enabled() else self.mod.forward(x) + ) def get_export_format(filename: str): @@ -91,43 +108,30 @@ def parse_input_example(input_example): return input_list, input_dict -def to_onnxrt_input(input_names, input_dict, input_list): - """Transforms input to onnxrt input format""" - return { - k: input_dict[k].cpu().numpy() if k in input_dict else input_list.pop().cpu().numpy() - for k in reversed(input_names) - } +def to_onnxrt_input(ort_input_names, input_names, input_dict, input_list): + """Convert input to onnxrt input""" + odict = {} + for k in reversed(input_names): + if k in input_dict: + val = input_dict[k].cpu().numpy() + else: + val = input_list.pop().cpu().numpy() + if k in ort_input_names: + odict[k] = val + return odict def verify_runtime( + model, output, - input_list, - input_dict, + input_examples, input_names, - output_names, - output_example, check_tolerance=0.01, ): - """ - Verify runtime output with onnxrt. - - Parameters - ---------- - output: The output of the module. - input_list: The input list of the module. - input_dict: The input dict of the module. - input_names: The input names of the module. - output_names: The output names of the module. - output_example: The output example of the module. - check_tolerance: The tolerance for the check. - - Returns - ------- - The runtime output. - """ - # Verify the model can be read, and is valid + """Verify runtime output with onnxrt.""" onnx_model = onnx.load(output) - input_names = [node.name for node in onnx_model.graph.input] + ort_input_names = [node.name for node in onnx_model.graph.input] + # skipcq: PYL-W0622 global ort_available if not ort_available: @@ -137,22 +141,33 @@ def verify_runtime( onnx_session_opt = onnxruntime.SessionOptions() onnx_session_opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL - sess = onnxruntime.InferenceSession( onnx_model.SerializeToString(), sess_options=onnx_session_opt, providers=["CUDAExecutionProvider"] ) - ort_out = sess.run(output_names, to_onnxrt_input(input_names, input_dict, input_list)) all_good = True - for i, out in enumerate(ort_out[0]): + for input_example in input_examples: + input_list, input_dict = parse_input_example(input_example) + output_example = model.forward(*input_list, **input_dict) + ort_input = to_onnxrt_input(ort_input_names, input_names, input_dict, input_list) + all_good = all_good and run_ort_and_compare(sess, ort_input, output_example, check_tolerance) + status = "SUCCESS" if all_good else "FAIL" + logging.info(f"ONNX generated at {output} verified with onnxruntime : {status}") + return all_good + + +def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01): + """Run onnxrt and compare with output example""" + ort_out = sess.run(None, ort_input) + all_good = True + for i, out in enumerate(ort_out): expected = output_example[i] if torch.is_tensor(expected): tout = torch.from_numpy(out) + logging.info(f"Checking output {i}, shape: {expected.shape}:\n{expected}\n{tout}") if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance): all_good = False logging.info(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}") - status = "SUCCESS" if all_good else "FAIL" - logging.info(f"ONNX generated at {output} verified with onnxruntime : {status}") return all_good diff --git a/mridc/utils/get_rank.py b/mridc/utils/get_rank.py index 2a66032d..36d3e10a 100644 --- a/mridc/utils/get_rank.py +++ b/mridc/utils/get_rank.py @@ -2,6 +2,9 @@ __author__ = "Dimitrios Karkalousos" # Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/utils/get_rank.py + +import torch + from mridc.utils.env_var_parsing import get_envint @@ -22,3 +25,8 @@ def is_global_rank_zero(): node_rank = get_envint("NODE_RANK", get_envint("GROUP_RANK", 0)) local_rank = get_envint("LOCAL_RANK", 0) return node_rank == 0 and local_rank == 0 + + +def get_rank(): + """Helper function that returns torch.distributed.get_rank() if DDP has been initialized otherwise returns 0.""" + return 0 if is_global_rank_zero() else torch.distributed.get_rank() diff --git a/tests/core/test_optimizers_schedulers.py b/tests/core/test_optimizers_schedulers.py index a3296938..2fdf2047 100644 --- a/tests/core/test_optimizers_schedulers.py +++ b/tests/core/test_optimizers_schedulers.py @@ -162,16 +162,29 @@ class TestOptimizersSchedulers: INITIAL_LR = 0.1 MIN_LR = 1e-3 MAX_STEPS = 10 + D_MODEL = 16 # fused_adam is looking for CUDA and this test is being run on CPU only tests @pytest.mark.unit def test_get_optimizer(self): """Test that the optimizer is correctly created""" model = TempModel() + if torch.cuda.is_available(): + model.cuda() for opt_name in AVAILABLE_OPTIMIZERS: if opt_name == "fused_adam" and not torch.cuda.is_available(): continue + if opt_name == "distributed_fused_adam": + if not torch.cuda.is_available() or not torch.distributed.is_nccl_available(): + continue + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group( + "nccl", + world_size=1, + rank=0, + store=torch.distributed.HashStore(), + ) opt_cls = get_optimizer(opt_name) if opt_name == "adafactor": # Adafactor's default mode uses relative_step without any lr. @@ -713,6 +726,57 @@ def test_CosineAnnealing(self): if final_lr != self.MIN_LR: raise AssertionError + # Noam scheduler should decay past MAX_STEPS - run two schedulers in parallel to test it + @pytest.mark.unit + def test_NoamAnnealing(self): + model = TempModel() + opt_cls = get_optimizer("novograd") + opt1 = opt_cls(model.parameters(), lr=self.INITIAL_LR) + opt2 = opt_cls(model.parameters(), lr=self.INITIAL_LR) + + # No warmup case + policy1 = optim.lr_scheduler.NoamAnnealing( + opt1, d_model=self.D_MODEL, max_steps=self.MAX_STEPS, min_lr=self.MIN_LR + ) + policy2 = optim.lr_scheduler.NoamAnnealing( + opt2, d_model=self.D_MODEL, max_steps=self.MAX_STEPS * 2, min_lr=self.MIN_LR + ) + initial_lr = policy1.get_last_lr()[0] + + assert initial_lr == self.D_MODEL ** (-0.5) * self.INITIAL_LR + + for _ in range(self.MAX_STEPS * 2): + if policy1.get_last_lr()[0] > self.INITIAL_LR: + raise AssertionError + assert policy1.get_last_lr()[0] <= policy2.get_last_lr()[0] + opt1.step() + opt2.step() + policy1.step() + policy2.step() + + # Warmup steps available + policy1 = optim.lr_scheduler.NoamAnnealing( + opt1, d_model=self.D_MODEL, warmup_steps=5, max_steps=self.MAX_STEPS, min_lr=self.MIN_LR + ) + policy2 = optim.lr_scheduler.NoamAnnealing( + opt2, d_model=self.D_MODEL, warmup_steps=5, max_steps=self.MAX_STEPS * 2, min_lr=self.MIN_LR + ) + initial_lr = policy1.get_last_lr()[0] + + assert initial_lr < self.INITIAL_LR + + for i in range(self.MAX_STEPS * 2): + if i <= 5: + assert policy1.get_last_lr()[0] <= self.INITIAL_LR + else: + assert self.MIN_LR <= policy1.get_last_lr()[0] <= self.INITIAL_LR + assert policy1.get_last_lr()[0] <= policy2.get_last_lr()[0] + + opt1.step() + opt2.step() + policy1.step() + policy2.step() + @pytest.mark.unit def test_PolynomialDecayAnnealing(self): """Test PolynomialDecayAnnealing"""