diff --git a/.github/workflows/codecov.yml b/.github/workflows/codecov.yml
index cbc700bc..42a18ace 100644
--- a/.github/workflows/codecov.yml
+++ b/.github/workflows/codecov.yml
@@ -30,3 +30,4 @@ jobs:
         flags: pytests # optional
         name: codecov-umbrella # optional
         verbose: true # optional (default = false)
+        version: "v0.1.15"
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index f7339cdf..bfb350f4 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -10,10 +10,6 @@ build:
   os: ubuntu-20.04
   tools:
     python: "3.9"
-    # You can also specify other tool versions:
-    # nodejs: "16"
-    # rust: "1.55"
-    # golang: "1.17"
 
 # Build documentation in the docs/ directory with Sphinx
 sphinx:
diff --git a/CITATION.cff b/CITATION.cff
index aff1fa40..9ba70fcd 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -9,5 +9,5 @@ authors:
     orcid: https://orcid.org/0000-0002-5162-8880
 title: "MRI Data Consistency"
 url: "https://github.com/wdika/mridc"
-version: 0.0.1
-date-released: 2021-29-11
+version: 0.1.0
+date-released: 2022-05-25
diff --git a/README.md b/README.md
index 9d098bd4..5fbfc3e4 100644
--- a/README.md
+++ b/README.md
@@ -3,9 +3,6 @@
 [![CodeQL](https://github.com/wdika/mridc/actions/workflows/codeql-analysis.yml/badge.svg)](https://github.com/wdika/mridc/actions/workflows/codeql-analysis.yml)
 [![CircleCI](https://circleci.com/gh/wdika/mridc/tree/main.svg?style=svg)](https://circleci.com/gh/wdika/mridc/tree/main)
 [![codecov](https://codecov.io/gh/wdika/mridc/branch/main/graph/badge.svg?token=KPPQ33DOTF)](https://codecov.io/gh/wdika/mridc)
-[![DeepSource](https://deepsource.io/gh/wdika/mridc.svg/?label=active+issues&show_trend=true&token=txj87v43GA6vhpbSwPEUTQtX)](https://deepsource.io/gh/wdika/mridc/?ref=repository-badge)
-[![DeepSource](https://deepsource.io/gh/wdika/mridc.svg/?label=resolved+issues&show_trend=true&token=txj87v43GA6vhpbSwPEUTQtX)](https://deepsource.io/gh/wdika/mridc/?ref=repository-badge)
-[![Total alerts](https://img.shields.io/lgtm/alerts/g/wdika/mridc.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/wdika/mridc/alerts/)
 Code style: black
 
 ---
@@ -76,11 +73,11 @@
 Recommended public datasets to use with this repo:
 - [fastMRI](http://arxiv.org/abs/1811.08839),
 - [Fully Sampled Knees](http://old.mridata.org/fullysampled/knees/).
 
-## Documentation
+## API Documentation
 
 [![Documentation Status](https://readthedocs.org/projects/mridc/badge/?version=latest)](https://mridc.readthedocs.io/en/latest/?badge=latest)
 
-Read the docs [here](https://mridc.readthedocs.io/en/latest/index.html)
+Access the API Documentation [here](https://mridc.readthedocs.io/en/latest/modules.html)
 
 ## License
diff --git a/docs/build/doctrees/environment.pickle b/docs/build/doctrees/environment.pickle
new file mode 100644
index 00000000..9a592398
Binary files /dev/null and b/docs/build/doctrees/environment.pickle differ
diff --git a/docs/build/doctrees/index.doctree b/docs/build/doctrees/index.doctree
new file mode 100644
index 00000000..d682a1d1
Binary files /dev/null and b/docs/build/doctrees/index.doctree differ
diff --git a/docs/build/doctrees/modules.doctree b/docs/build/doctrees/modules.doctree
new file mode 100644
index 00000000..a1c70a65
Binary files /dev/null and b/docs/build/doctrees/modules.doctree differ
diff --git a/docs/build/doctrees/mridc.collections.common.callbacks.doctree b/docs/build/doctrees/mridc.collections.common.callbacks.doctree
new file mode 100644
index 00000000..c288415c
Binary files /dev/null and b/docs/build/doctrees/mridc.collections.common.callbacks.doctree differ
diff --git a/docs/build/doctrees/mridc.collections.common.data.doctree b/docs/build/doctrees/mridc.collections.common.data.doctree
new file mode 100644
index 00000000..90813c1d
Binary files /dev/null and b/docs/build/doctrees/mridc.collections.common.data.doctree differ
diff --git a/docs/build/doctrees/mridc.collections.common.doctree b/docs/build/doctrees/mridc.collections.common.doctree
new file mode 100644
index 00000000..71931779
Binary files /dev/null and b/docs/build/doctrees/mridc.collections.common.doctree differ
diff --git a/docs/build/doctrees/mridc.collections.common.losses.doctree b/docs/build/doctrees/mridc.collections.common.losses.doctree
new file mode 100644
index 00000000..47ce6a0e
Binary files /dev/null and b/docs/build/doctrees/mridc.collections.common.losses.doctree differ
diff --git a/docs/build/doctrees/mridc.collections.common.metrics.doctree b/docs/build/doctrees/mridc.collections.common.metrics.doctree
new file mode 100644
index 00000000..802d887b
Binary files /dev/null and b/docs/build/doctrees/mridc.collections.common.metrics.doctree differ
diff --git a/docs/build/doctrees/mridc.collections.common.parts.doctree b/docs/build/doctrees/mridc.collections.common.parts.doctree
new file mode 100644
index 00000000..11d6ba77
Binary files /dev/null and b/docs/build/doctrees/mridc.collections.common.parts.doctree differ
diff --git a/docs/build/doctrees/mridc.collections.doctree b/docs/build/doctrees/mridc.collections.doctree
new file mode 100644
index 00000000..2aa36776
Binary files /dev/null and b/docs/build/doctrees/mridc.collections.doctree differ
diff --git a/docs/build/doctrees/mridc.collections.reconstruction.data.doctree b/docs/build/doctrees/mridc.collections.reconstruction.data.doctree
new file mode 100644
index 00000000..8a9e92cf
Binary files /dev/null and b/docs/build/doctrees/mridc.collections.reconstruction.data.doctree differ
diff --git a/docs/build/doctrees/mridc.collections.reconstruction.doctree b/docs/build/doctrees/mridc.collections.reconstruction.doctree
new file mode 100644
index 00000000..d3090266
Binary files /dev/null and b/docs/build/doctrees/mridc.collections.reconstruction.doctree differ
diff --git a/docs/build/doctrees/mridc.collections.reconstruction.metrics.doctree b/docs/build/doctrees/mridc.collections.reconstruction.metrics.doctree
new file mode 100644
index 00000000..0349eb4b
Binary files /dev/null and b/docs/build/doctrees/mridc.collections.reconstruction.metrics.doctree differ
diff --git a/docs/build/doctrees/mridc.collections.reconstruction.models.cascadenet.doctree b/docs/build/doctrees/mridc.collections.reconstruction.models.cascadenet.doctree
new file mode 100644
index 00000000..64002d5a
Binary files /dev/null and b/docs/build/doctrees/mridc.collections.reconstruction.models.cascadenet.doctree differ
diff --git a/docs/build/doctrees/mridc.collections.reconstruction.models.conv.doctree b/docs/build/doctrees/mridc.collections.reconstruction.models.conv.doctree
new file mode 100644
index 00000000..c8f05499
Binary files /dev/null and b/docs/build/doctrees/mridc.collections.reconstruction.models.conv.doctree differ
diff --git a/docs/build/doctrees/mridc.collections.reconstruction.models.convrecnet.doctree b/docs/build/doctrees/mridc.collections.reconstruction.models.convrecnet.doctree
new file mode 100644
index 00000000..59b8a4ca
Binary files /dev/null and b/docs/build/doctrees/mridc.collections.reconstruction.models.convrecnet.doctree differ
diff --git a/docs/build/doctrees/mridc.collections.reconstruction.models.crossdomain.doctree b/docs/build/doctrees/mridc.collections.reconstruction.models.crossdomain.doctree
new file mode 100644
index 00000000..ec91de1f
Binary files /dev/null and b/docs/build/doctrees/mridc.collections.reconstruction.models.crossdomain.doctree differ
diff --git a/docs/build/doctrees/mridc.collections.reconstruction.models.didn.doctree b/docs/build/doctrees/mridc.collections.reconstruction.models.didn.doctree
new file mode 100644
index 00000000..3794d3a6
Binary files /dev/null and b/docs/build/doctrees/mridc.collections.reconstruction.models.didn.doctree differ
diff --git a/docs/build/doctrees/mridc.collections.reconstruction.models.doctree b/docs/build/doctrees/mridc.collections.reconstruction.models.doctree
new file mode 100644
index 00000000..e00bdd05
Binary files /dev/null and b/docs/build/doctrees/mridc.collections.reconstruction.models.doctree differ
diff --git a/docs/build/doctrees/mridc.collections.reconstruction.models.multidomain.doctree b/docs/build/doctrees/mridc.collections.reconstruction.models.multidomain.doctree
new file mode 100644
index 00000000..f86b607f
Binary files /dev/null and b/docs/build/doctrees/mridc.collections.reconstruction.models.multidomain.doctree differ
diff --git a/docs/build/doctrees/mridc.collections.reconstruction.models.mwcnn.doctree b/docs/build/doctrees/mridc.collections.reconstruction.models.mwcnn.doctree
new file mode 100644
index 00000000..f7a925f2
Binary files /dev/null and b/docs/build/doctrees/mridc.collections.reconstruction.models.mwcnn.doctree differ
diff --git a/docs/build/doctrees/mridc.collections.reconstruction.models.primaldual.doctree b/docs/build/doctrees/mridc.collections.reconstruction.models.primaldual.doctree
new file mode 100644
index 00000000..de81977d
Binary files /dev/null and b/docs/build/doctrees/mridc.collections.reconstruction.models.primaldual.doctree differ
diff --git a/docs/build/doctrees/mridc.collections.reconstruction.models.recurrentvarnet.doctree b/docs/build/doctrees/mridc.collections.reconstruction.models.recurrentvarnet.doctree
new file mode 100644
index 00000000..dbcc72f9
Binary files /dev/null and b/docs/build/doctrees/mridc.collections.reconstruction.models.recurrentvarnet.doctree differ
diff --git a/docs/build/doctrees/mridc.collections.reconstruction.models.rim.doctree b/docs/build/doctrees/mridc.collections.reconstruction.models.rim.doctree
new file mode 100644
index 00000000..5ba86792
Binary files /dev/null and b/docs/build/doctrees/mridc.collections.reconstruction.models.rim.doctree differ
diff --git a/docs/build/doctrees/mridc.collections.reconstruction.models.sigmanet.doctree b/docs/build/doctrees/mridc.collections.reconstruction.models.sigmanet.doctree
new file mode 100644
index 00000000..c624edc2
Binary files /dev/null and b/docs/build/doctrees/mridc.collections.reconstruction.models.sigmanet.doctree differ
diff --git a/docs/build/doctrees/mridc.collections.reconstruction.models.unet_base.doctree b/docs/build/doctrees/mridc.collections.reconstruction.models.unet_base.doctree
new file mode 100644
index 00000000..f1a0d82d
Binary files /dev/null and b/docs/build/doctrees/mridc.collections.reconstruction.models.unet_base.doctree differ
diff --git a/docs/build/doctrees/mridc.collections.reconstruction.models.variablesplittingnet.doctree b/docs/build/doctrees/mridc.collections.reconstruction.models.variablesplittingnet.doctree
new file mode 100644
index 00000000..6ca20e3c
Binary files /dev/null and b/docs/build/doctrees/mridc.collections.reconstruction.models.variablesplittingnet.doctree differ
diff --git a/docs/build/doctrees/mridc.collections.reconstruction.models.varnet.doctree b/docs/build/doctrees/mridc.collections.reconstruction.models.varnet.doctree
new file mode 100644
index 00000000..d1abaca7
Binary files /dev/null and b/docs/build/doctrees/mridc.collections.reconstruction.models.varnet.doctree differ
diff --git a/docs/build/doctrees/mridc.collections.reconstruction.parts.doctree b/docs/build/doctrees/mridc.collections.reconstruction.parts.doctree
new file mode 100644
index 00000000..f091b7f1
Binary files /dev/null and b/docs/build/doctrees/mridc.collections.reconstruction.parts.doctree differ
diff --git a/docs/build/doctrees/mridc.core.classes.doctree b/docs/build/doctrees/mridc.core.classes.doctree
new file mode 100644
index 00000000..500e9cf7
Binary files /dev/null and b/docs/build/doctrees/mridc.core.classes.doctree differ
diff --git a/docs/build/doctrees/mridc.core.conf.doctree b/docs/build/doctrees/mridc.core.conf.doctree
new file mode 100644
index 00000000..91cc95d7
Binary files /dev/null and b/docs/build/doctrees/mridc.core.conf.doctree differ
diff --git a/docs/build/doctrees/mridc.core.connectors.doctree b/docs/build/doctrees/mridc.core.connectors.doctree
new file mode 100644
index 00000000..b96753db
Binary files /dev/null and b/docs/build/doctrees/mridc.core.connectors.doctree differ
diff --git a/docs/build/doctrees/mridc.core.doctree b/docs/build/doctrees/mridc.core.doctree
new file mode 100644
index 00000000..6c400217
Binary files /dev/null and b/docs/build/doctrees/mridc.core.doctree differ
diff --git a/docs/build/doctrees/mridc.core.neural_types.doctree b/docs/build/doctrees/mridc.core.neural_types.doctree
new file mode 100644
index 00000000..3ab6ff74
Binary files /dev/null and b/docs/build/doctrees/mridc.core.neural_types.doctree differ
diff --git a/docs/build/doctrees/mridc.core.optim.doctree b/docs/build/doctrees/mridc.core.optim.doctree
new file mode 100644
index 00000000..7ca21af3
Binary files /dev/null and b/docs/build/doctrees/mridc.core.optim.doctree differ
diff --git a/docs/build/doctrees/mridc.core.utils.doctree b/docs/build/doctrees/mridc.core.utils.doctree
new file mode 100644
index 00000000..26dda57e
Binary files /dev/null and b/docs/build/doctrees/mridc.core.utils.doctree differ
diff --git a/docs/build/doctrees/mridc.doctree b/docs/build/doctrees/mridc.doctree
new file mode 100644
index 00000000..fe5f6997
Binary files /dev/null and b/docs/build/doctrees/mridc.doctree differ
diff --git a/docs/build/doctrees/mridc.utils.decorators.doctree b/docs/build/doctrees/mridc.utils.decorators.doctree
new file mode 100644
index 00000000..1ab55a48
Binary files /dev/null and b/docs/build/doctrees/mridc.utils.decorators.doctree differ
diff --git a/docs/build/doctrees/mridc.utils.doctree b/docs/build/doctrees/mridc.utils.doctree
new file mode 100644
index 00000000..f8224464
Binary files /dev/null and b/docs/build/doctrees/mridc.utils.doctree differ
diff --git a/docs/build/doctrees/mridc.utils.formaters.doctree b/docs/build/doctrees/mridc.utils.formaters.doctree
new file mode 100644
index 00000000..fd565627
Binary files /dev/null and b/docs/build/doctrees/mridc.utils.formaters.doctree differ
diff --git a/docs/build/html/.buildinfo b/docs/build/html/.buildinfo
new file mode 100644
index 00000000..792495fa
--- /dev/null
+++ b/docs/build/html/.buildinfo
@@ -0,0 +1,4 @@
+# Sphinx build info version 1
+# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
+config: 46554dcb68b067ce402ab476ee469479
+tags: 645f666f9bcd5a90fca523b33c5a78b7
diff --git a/docs/build/html/_modules/index.html b/docs/build/html/_modules/index.html
new file mode 100644
index 00000000..6fdbe419
--- /dev/null
+++ b/docs/build/html/_modules/index.html
@@ -0,0 +1,193 @@
+Overview: module code — mridc v.0.0.1 documentation

All modules for which code is available

diff --git a/docs/build/html/_modules/mridc/collections/common/callbacks/callbacks.html b/docs/build/html/_modules/mridc/collections/common/callbacks/callbacks.html
new file mode 100644
index 00000000..765dde4d
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/common/callbacks/callbacks.html
@@ -0,0 +1,126 @@
+mridc.collections.common.callbacks.callbacks — mridc v.0.0.1 documentation

Source code for mridc.collections.common.callbacks.callbacks

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/callbacks.py
+
+import time
+
+from pytorch_lightning.callbacks.base import Callback
+from pytorch_lightning.utilities import rank_zero_only
+
+
+
+class LogEpochTimeCallback(Callback):
+    """Simple callback that logs how long each epoch takes, in seconds, to a pytorch lightning log."""
+
+    def __init__(self):
+        """Initialize the callback."""
+        super().__init__()
+        self.epoch_start = time.time()
+
+    @rank_zero_only
+    def on_train_epoch_start(self, trainer, pl_module):
+        """Called at the start of each epoch."""
+        self.epoch_start = time.time()
+
+    @rank_zero_only
+    def on_train_epoch_end(self, trainer, pl_module):
+        """Called at the end of each epoch."""
+        curr_time = time.time()
+        duration = curr_time - self.epoch_start
+        trainer.logger.log_metrics({"epoch_time": duration}, step=trainer.global_step)
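For context, a minimal usage sketch of this callback (not part of the diff; `MyModule` is a hypothetical stand-in for any LightningModule):

    import pytorch_lightning as pl
    from mridc.collections.common.callbacks.callbacks import LogEpochTimeCallback

    # The callback hooks into the Trainer and logs "epoch_time" once per training epoch.
    trainer = pl.Trainer(max_epochs=2, callbacks=[LogEpochTimeCallback()])
    trainer.fit(MyModule())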
diff --git a/docs/build/html/_modules/mridc/collections/common/data/dataset.html b/docs/build/html/_modules/mridc/collections/common/data/dataset.html
new file mode 100644
index 00000000..b7113b11
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/common/data/dataset.html
@@ -0,0 +1,243 @@
+mridc.collections.common.data.dataset — mridc v.0.0.1 documentation

Source code for mridc.collections.common.data.dataset

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/data/dataset.py
+
+from abc import ABC
+from typing import Any, List
+
+import numpy as np
+import torch.utils.data as pt_data
+
+__all__ = ["ConcatDataset"]
+
+
+
+class ConcatDataset(pt_data.IterableDataset, ABC):
+    """
+    A dataset that accepts as argument multiple datasets and then samples from them based on the specified
+    sampling technique.
+
+    Parameters
+    ----------
+    datasets: A list of datasets to sample from.
+    shuffle: Whether to shuffle individual datasets. Only works with non-iterable datasets. Defaults to True.
+    sampling_technique: Sampling technique to choose which dataset to draw a sample from. Defaults to 'random'.
+        Currently supports 'random' and 'round-robin'.
+    sampling_probabilities: Probability values for sampling. Only used when sampling_technique = 'random'.
+    global_rank: Worker rank, used for partitioning map style datasets. Defaults to 0.
+    world_size: Total number of processes, used for partitioning map style datasets. Defaults to 1.
+    """
+
+    def __init__(
+        self,
+        datasets: List[Any],
+        shuffle: bool = True,
+        sampling_technique: str = "random",
+        sampling_probabilities: List[float] = None,
+        global_rank: int = 0,
+        world_size: int = 1,
+    ):
+        super().__init__()
+
+        self.datasets = datasets
+        self.iterables = [None] * len(datasets)
+        self.shuffle = shuffle
+        self.global_rank = global_rank
+        self.world_size = world_size
+        self.sampling_kwargs = {}
+        if sampling_technique == "random":
+            self.index_generator = ConcatDataset.random_generator
+            self.sampling_kwargs["p"] = sampling_probabilities  # type: ignore
+        elif sampling_technique == "round-robin":
+            self.index_generator = ConcatDataset.round_robin_generator
+        else:
+            supported_sampling_techniques = ["random", "round-robin"]
+            raise ValueError(f"Currently we only support sampling techniques in {supported_sampling_techniques}.")
+        self.length = 0
+
+        if isinstance(datasets[0], pt_data.IterableDataset):
+            self.kind = "iterable"
+        else:
+            self.kind = "map"
+
+        for dataset in datasets:
+            isiterable = isinstance(dataset, pt_data.IterableDataset)
+            if isiterable and self.kind != "iterable" or (not isiterable and self.kind == "iterable"):
+                raise ValueError("All datasets in ConcatDataset must be of the same kind (Iterable or Map).")
+
+            if self.kind == "map":
+                self.length += np.floor_divide(len(dataset), world_size)
+            else:
+                self.length += len(dataset)
+
+    def get_iterable(self, dataset):
+        """Returns an iterable dataset."""
+        if isinstance(dataset, pt_data.IterableDataset):
+            return dataset.__iter__()
+        indices = np.arange(len(dataset))
+        if self.shuffle:
+            np.random.shuffle(indices)
+        return iter(indices)
+
+    def __iter__(self):
+        """Returns an iterator over the dataset."""
+        worker_info = pt_data.get_worker_info()
+        if worker_info is None:
+            max_elements = self.length
+            wid = 0
+            wnum = 1
+        else:
+            wid = worker_info.id
+            wnum = worker_info.num_workers
+            max_elements = len(range(wid, self.length, wnum))
+
+        if self.kind == "map":
+            for idx in range(len(self.datasets)):
+                start_idx = np.floor_divide(len(self.datasets[idx]), self.world_size) * self.global_rank
+                end_idx = start_idx + np.floor_divide(len(self.datasets[idx]), self.world_size)
+                if self.global_rank == self.world_size - 1:
+                    end_idx = len(self.datasets[idx])
+                indices = range(start_idx + wid, end_idx, wnum)
+                self.datasets[idx] = pt_data.Subset(self.datasets[idx], indices)
+
+        for idx, dataset in enumerate(self.datasets):
+            iterable = self.get_iterable(dataset)
+            self.iterables[idx] = iterable  # type: ignore
+
+        n = 0
+        ind_gen = self.index_generator(self.datasets, **self.sampling_kwargs)
+        while n < max_elements:
+            n += 1
+            try:
+                ind = next(ind_gen)
+            except StopIteration:
+                return
+            try:
+                val = next(self.iterables[ind])  # type: ignore
+                if self.kind == "map":
+                    val = self.datasets[ind][val]
+                yield val
+            except StopIteration:
+                self.iterables[ind] = self.get_iterable(self.datasets[ind])  # type: ignore
+                n -= 1
+
+    def __len__(self):
+        """Returns the number of elements in the dataset."""
+        return self.length
+
+    @staticmethod
+    def round_robin_generator(datasets, **kwargs):
+        """Generates indices in a round-robin fashion."""
+        num = len(datasets)
+        while True:
+            yield from range(num)
+
+    @staticmethod
+    def random_generator(datasets, **kwargs):
+        """Generates random indices."""
+        p = kwargs.get("p")
+        if not p:
+            raise ValueError("Random generator expects a 'p' keyword argument for sampling probabilities.")
+
+        num = len(datasets)
+        if len(p) != num:
+            raise ValueError("Length of probabilities list must be equal to the number of datasets.")
+
+        while True:
+            yield np.random.choice(np.arange(num), p=p)
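For context, a minimal usage sketch (not part of the diff; the toy tensors are arbitrary). With sampling_technique="random", sampling_probabilities needs one entry per dataset and must sum to 1:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from mridc.collections.common.data.dataset import ConcatDataset

    ds_a = TensorDataset(torch.zeros(100, 2))  # map-style dataset
    ds_b = TensorDataset(torch.ones(50, 2))    # map-style dataset
    combined = ConcatDataset([ds_a, ds_b], sampling_technique="random", sampling_probabilities=[0.7, 0.3])
    loader = DataLoader(combined, batch_size=4)  # draws ~70% of samples from ds_a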
diff --git a/docs/build/html/_modules/mridc/collections/common/losses/aggregator.html b/docs/build/html/_modules/mridc/collections/common/losses/aggregator.html
new file mode 100644
index 00000000..42b3df90
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/common/losses/aggregator.html
@@ -0,0 +1,152 @@
+mridc.collections.common.losses.aggregator — mridc v.0.0.1 documentation

Source code for mridc.collections.common.losses.aggregator

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/losses/aggregator.py
+
+from typing import List
+
+import torch
+
+__all__ = ["AggregatorLoss"]
+
+from mridc.core.classes.common import typecheck
+from mridc.core.classes.loss import Loss
+from mridc.core.neural_types.elements import LossType
+from mridc.core.neural_types.neural_type import NeuralType
+
+
+
+class AggregatorLoss(Loss):
+    """
+    Sums several losses into one.
+
+    Parameters
+    ----------
+    num_inputs: number of input losses
+    weights: a list of coefficients for merging the losses
+    """
+
+    @property
+    def input_types(self):
+        """Returns definitions of module input ports."""
+        return {f"loss_{str(i + 1)}": NeuralType(elements_type=LossType()) for i in range(self._num_losses)}
+
+    @property
+    def output_types(self):
+        """Returns definitions of module output ports."""
+        return {"loss": NeuralType(elements_type=LossType())}
+
+    def __init__(self, num_inputs: int = 2, weights: List[float] = None):
+        super().__init__()
+        self._num_losses = num_inputs
+        if weights is not None and len(weights) != num_inputs:
+            raise ValueError("Length of weights should be equal to the number of inputs (num_inputs)")
+
+        self._weights = weights
+
+    @typecheck()
+    def forward(self, **kwargs):
+        """Computes the (weighted) sum of the losses."""
+        values = [kwargs[x] for x in sorted(kwargs.keys())]
+        loss = torch.zeros_like(values[0])
+        for loss_idx, loss_value in enumerate(values):
+            if self._weights is not None:
+                loss = loss.add(loss_value, alpha=self._weights[loss_idx])
+            else:
+                loss = loss.add(loss_value)
+        return loss
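For context, a minimal usage sketch (not part of the diff). The input ports are named loss_1, loss_2, ..., so the losses are passed as keyword arguments:

    import torch
    from mridc.collections.common.losses.aggregator import AggregatorLoss

    agg = AggregatorLoss(num_inputs=2, weights=[0.9, 0.1])
    total = agg(loss_1=torch.tensor(0.5), loss_2=torch.tensor(2.0))
    # total = 0.9 * 0.5 + 0.1 * 2.0 = 0.65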
diff --git a/docs/build/html/_modules/mridc/collections/common/losses/ssim.html b/docs/build/html/_modules/mridc/collections/common/losses/ssim.html
new file mode 100644
index 00000000..a64daf75
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/common/losses/ssim.html
@@ -0,0 +1,155 @@
+mridc.collections.common.losses.ssim — mridc v.0.0.1 documentation

Source code for mridc.collections.common.losses.ssim

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Parts of the code have been taken from https://github.com/facebookresearch/fastMRI
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+
+class SSIMLoss(nn.Module):
+    """SSIM loss module."""
+
+    def __init__(self, win_size: int = 7, k1: float = 0.01, k2: float = 0.03):
+        """
+        Parameters
+        ----------
+        win_size: Window size for SSIM calculation.
+        k1: k1 parameter for SSIM calculation.
+        k2: k2 parameter for SSIM calculation.
+        """
+        super().__init__()
+        self.win_size = win_size
+        self.k1, self.k2 = k1, k2
+        self.register_buffer("w", torch.ones(1, 1, win_size, win_size) / win_size**2)
+        NP = win_size**2
+        self.cov_norm = NP / (NP - 1)
+
+    def forward(self, X: torch.Tensor, Y: torch.Tensor, data_range: torch.Tensor):
+        """
+        Parameters
+        ----------
+        X: First input tensor.
+        Y: Second input tensor.
+        data_range: Data range of the input tensors.
+
+        Returns
+        -------
+        SSIM loss.
+        """
+        if not isinstance(self.w, torch.Tensor):
+            raise AssertionError
+
+        data_range = data_range[:, None, None, None]
+        C1 = (self.k1 * data_range) ** 2
+        C2 = (self.k2 * data_range) ** 2
+        ux = F.conv2d(X, self.w)
+        uy = F.conv2d(Y, self.w)
+        uxx = F.conv2d(X * X, self.w)
+        uyy = F.conv2d(Y * Y, self.w)
+        uxy = F.conv2d(X * Y, self.w)
+        vx = self.cov_norm * (uxx - ux * ux)
+        vy = self.cov_norm * (uyy - uy * uy)
+        vxy = self.cov_norm * (uxy - ux * uy)
+        A1, A2, B1, B2 = (2 * ux * uy + C1, 2 * vxy + C2, ux**2 + uy**2 + C1, vx + vy + C2)
+        D = B1 * B2
+        S = (A1 * A2) / D
+
+        return 1 - S.mean()
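For context, a minimal usage sketch (not part of the diff; shapes are illustrative). Inputs are batched single-channel magnitude images and data_range holds the per-sample maximum intensity:

    import torch
    from mridc.collections.common.losses.ssim import SSIMLoss

    ssim = SSIMLoss(win_size=7)
    X = torch.rand(1, 1, 64, 64)  # (batch, channel, height, width)
    Y = torch.rand(1, 1, 64, 64)
    loss = ssim(X, Y, data_range=X.amax(dim=(1, 2, 3)))  # scalar loss, 1 - mean SSIM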
diff --git a/docs/build/html/_modules/mridc/collections/common/metrics/global_average_loss_metric.html b/docs/build/html/_modules/mridc/collections/common/metrics/global_average_loss_metric.html
new file mode 100644
index 00000000..f5980358
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/common/metrics/global_average_loss_metric.html
@@ -0,0 +1,161 @@
+mridc.collections.common.metrics.global_average_loss_metric — mridc v.0.0.1 documentation

Source code for mridc.collections.common.metrics.global_average_loss_metric

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from:
+# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/metrics/global_average_loss_metric.py
+
+import torch
+from torchmetrics import Metric
+
+__all__ = ["GlobalAverageLossMetric"]
+
+
+
+class GlobalAverageLossMetric(Metric):
+    """
+    This class is for averaging loss across multiple processes if a distributed backend is used. The true average
+    is computed, not a running average. It does not accumulate gradients, so the averaged loss cannot be used for
+    optimization.
+
+    .. note::
+        If ``take_avg_loss`` is ``True``, the :meth:`update` method ``loss`` argument has to be a mean loss. If
+        ``take_avg_loss`` is ``False``, then the :meth:`update` method ``loss`` argument has to be a sum of losses.
+        See PyTorch Lightning Metrics for the metric usage instructions.
+
+    Parameters
+    ----------
+    compute_on_step: The method :meth:`forward` only calls ``update()`` and returns ``None`` if this is set to
+        ``False``. Default: ``True``
+    dist_sync_on_step: Synchronize metric state across processes at each method :meth:`forward` call before
+        returning the value at the step.
+    process_group: Specify the process group on which synchronization is called. Default: ``None`` (which selects
+        the entire world).
+    take_avg_loss: If ``True``, values of the :meth:`update` method ``loss`` argument have to be mean losses. If
+        ``False``, values of the :meth:`update` method ``loss`` argument have to be sums of losses. Default: ``True``
+    """
+
+    def __init__(self, compute_on_step=True, dist_sync_on_step=False, process_group=None, take_avg_loss=True):
+        super().__init__(
+            compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group
+        )
+        self.add_state("loss_sum", torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum")
+        self.add_state("num_measurements", torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum")
+        self.take_avg_loss = take_avg_loss
+
+    def update(self, loss, num_measurements):
+        """
+        Updates :attr:`loss_sum` and :attr:`num_measurements`.
+
+        Parameters
+        ----------
+        loss: A float zero-dimensional ``torch.Tensor`` which is either a sum or an average of losses for the
+            processed examples. See the ``take_avg_loss`` parameter of :meth:`__init__`.
+        num_measurements: An integer zero-dimensional ``torch.Tensor`` which contains the number of loss
+            measurements. The sum or mean of the results of these measurements is in the ``loss`` parameter.
+        """
+        if self.take_avg_loss:
+            self.loss_sum += loss.detach() * num_measurements
+        else:
+            self.loss_sum += loss.detach()
+        self.num_measurements += num_measurements
+
+    def compute(self):
+        """Returns the mean loss."""
+        if self.num_measurements.eq(0):
+            return torch.tensor(float("nan"))
+        return self.loss_sum / self.num_measurements
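For context, a minimal usage sketch (not part of the diff; assumes the torchmetrics version pinned by the repo, where compute_on_step is still accepted):

    import torch
    from mridc.collections.common.metrics.global_average_loss_metric import GlobalAverageLossMetric

    metric = GlobalAverageLossMetric(take_avg_loss=True)
    metric.update(loss=torch.tensor(0.4), num_measurements=torch.tensor(8))
    metric.update(loss=torch.tensor(0.6), num_measurements=torch.tensor(2))
    metric.compute()  # (0.4 * 8 + 0.6 * 2) / 10 = 0.44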
diff --git a/docs/build/html/_modules/mridc/collections/common/parts/fft.html b/docs/build/html/_modules/mridc/collections/common/parts/fft.html
new file mode 100644
index 00000000..9e61c7e0
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/common/parts/fft.html
@@ -0,0 +1,278 @@
+mridc.collections.common.parts.fft — mridc v.0.0.1 documentation

Source code for mridc.collections.common.parts.fft

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Parts of the code have been taken from https://github.com/facebookresearch/fastMRI
+
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+
+__all__ = ["fft2c", "ifft2c"]
+
+
+
+def fft2c(
+    data: torch.Tensor,
+    fft_type: str = "orthogonal",
+    fft_normalization: str = "ortho",
+    fft_dim: Union[Optional[int], List[int], None] = None,
+) -> torch.Tensor:
+    """
+    Apply centered 2 dimensional Fast Fourier Transform.
+
+    Parameters
+    ----------
+    data: Complex valued input data containing at least 3 dimensions: dimensions -2 & -1 are spatial dimensions.
+        All other dimensions are assumed to be batch dimensions.
+    fft_type: Specify fft type. This is important if an orthogonal transformation is needed or not.
+    fft_normalization: Normalization mode passed to torch.fft; "ortho" is used here by default. The other PyTorch
+        options are "forward", "backward" (PyTorch's own default), or None.
+    fft_dim: dimensions to apply the FFT
+
+    Returns
+    -------
+    The FFT of the input.
+    """
+    if fft_dim is None:
+        fft_dim = [-2, -1]
+
+    if fft_type == "orthogonal":
+        data = ifftshift(data, dim=[-3, -2])
+
+    data = torch.view_as_real(torch.fft.fft2(torch.view_as_complex(data), dim=fft_dim, norm=fft_normalization))
+
+    if fft_type == "orthogonal":
+        data = fftshift(data, dim=[-3, -2])
+
+    return data
+
+
+def ifft2c(
+    data: torch.Tensor,
+    fft_type: str = "orthogonal",
+    fft_normalization: str = "ortho",
+    fft_dim: Union[Optional[int], List[int], None] = None,
+) -> torch.Tensor:
+    """
+    Apply centered 2 dimensional Inverse Fast Fourier Transform.
+
+    Parameters
+    ----------
+    data: Complex valued input data containing at least 3 dimensions: dimensions -2 & -1 are spatial dimensions.
+        All other dimensions are assumed to be batch dimensions.
+    fft_type: Specify fft type. This is important if an orthogonal transformation is needed or not.
+    fft_normalization: Normalization mode passed to torch.fft; "ortho" is used here by default. The other PyTorch
+        options are "forward", "backward" (PyTorch's own default), or None.
+    fft_dim: dimensions to apply the FFT
+
+    Returns
+    -------
+    The IFFT of the input.
+    """
+    if fft_dim is None:
+        fft_dim = [-2, -1]
+
+    if fft_type == "orthogonal":
+        data = ifftshift(data, dim=[-3, -2])
+
+    data = torch.view_as_real(torch.fft.ifft2(torch.view_as_complex(data), dim=fft_dim, norm=fft_normalization))
+
+    if fft_type == "orthogonal":
+        data = fftshift(data, dim=[-3, -2])
+
+    return data
+
+
+def roll_one_dim(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
+    """
+    Similar to roll but for only one dim.
+
+    Parameters
+    ----------
+    x: A PyTorch tensor.
+    shift: Amount to roll.
+    dim: Which dimension to roll.
+
+    Returns
+    -------
+    Rolled version of x.
+    """
+    shift %= x.size(dim)
+    if shift == 0:
+        return x
+
+    left = x.narrow(dim, 0, x.size(dim) - shift)
+    right = x.narrow(dim, x.size(dim) - shift, shift)
+
+    return torch.cat((right, left), dim=dim)
+
+
+def roll(x: torch.Tensor, shift: List[int], dim: List[int]) -> torch.Tensor:
+    """
+    Similar to np.roll but applies to PyTorch Tensors.
+
+    Parameters
+    ----------
+    x: A PyTorch tensor.
+    shift: Amount to roll.
+    dim: Which dimension to roll.
+
+    Returns
+    -------
+    Rolled version of x.
+    """
+    if len(shift) != len(dim):
+        raise ValueError("len(shift) must match len(dim)")
+
+    for (s, d) in zip(shift, dim):
+        x = roll_one_dim(x, s, d)
+
+    return x
+
+
+def fftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
+    """
+    Similar to np.fft.fftshift but applies to PyTorch Tensors.
+
+    Parameters
+    ----------
+    x: A PyTorch tensor.
+    dim: Which dimension to fftshift.
+
+    Returns
+    -------
+    fftshifted version of x.
+    """
+    if dim is None:
+        # this weird code is necessary for torch.jit.script typing
+        dim = [0] * (x.dim())
+        for i in range(1, x.dim()):
+            dim[i] = i
+
+    # Also necessary for torch.jit.script
+    shift = [0] * len(dim)
+    for i, dim_num in enumerate(dim):
+        shift[i] = np.floor_divide(x.shape[dim_num], 2)
+
+    return roll(x, shift, dim)
+
+
+def ifftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
+    """
+    Similar to np.fft.ifftshift but applies to PyTorch Tensors.
+
+    Parameters
+    ----------
+    x: A PyTorch tensor.
+    dim: Which dimension to ifftshift.
+
+    Returns
+    -------
+    ifftshifted version of x.
+    """
+    if dim is None:
+        # this weird code is necessary for torch.jit.script typing
+        dim = [0] * (x.dim())
+        for i in range(1, x.dim()):
+            dim[i] = i
+
+    # Also necessary for torch.jit.script
+    shift = [0] * len(dim)
+    for i, dim_num in enumerate(dim):
+        shift[i] = np.floor_divide(x.shape[dim_num] + 1, 2)
+
+    return roll(x, shift, dim)
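For context, a round-trip sketch (not part of the diff; shapes are illustrative). Complex data is stored with real and imaginary parts stacked in the last dimension:

    import torch
    from mridc.collections.common.parts.fft import fft2c, ifft2c

    x = torch.randn(4, 32, 32, 2)  # (coil, height, width, real/imag)
    k = fft2c(x, fft_type="orthogonal")
    x_rec = ifft2c(k, fft_type="orthogonal")
    assert torch.allclose(x, x_rec, atol=1e-5)  # the pair is inverse up to numerical error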
diff --git a/docs/build/html/_modules/mridc/collections/common/parts/ptl_overrides.html b/docs/build/html/_modules/mridc/collections/common/parts/ptl_overrides.html
new file mode 100644
index 00000000..8e928980
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/common/parts/ptl_overrides.html
@@ -0,0 +1,111 @@
+mridc.collections.common.parts.ptl_overrides — mridc v.0.0.1 documentation

Source code for mridc.collections.common.parts.ptl_overrides

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/parts/ptl_overrides.py
+
+import torch
+from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
+
+
+
+class MRIDCNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
+    """Native Mixed Precision Plugin for MRIDC."""
+
+    def __init__(self, init_scale: float = 2**32, growth_interval: int = 1000) -> None:
+        super().__init__(precision=16, device=self.device)
+        self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale, growth_interval=growth_interval)
+
diff --git a/docs/build/html/_modules/mridc/collections/common/parts/rnn_utils.html b/docs/build/html/_modules/mridc/collections/common/parts/rnn_utils.html
new file mode 100644
index 00000000..36638b9c
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/common/parts/rnn_utils.html
@@ -0,0 +1,128 @@
+mridc.collections.common.parts.rnn_utils — mridc v.0.0.1 documentation

Source code for mridc.collections.common.parts.rnn_utils

+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+import torch.nn as nn
+
+__all__ = ["rnn_weights_init"]
+
+
+
+def rnn_weights_init(module, std_init_range=0.02, xavier=True):
+    """
+    Initialize different weights in a Transformer model.
+
+    # TODO: check if this is the correct way to initialize RNN weights
+
+    Parameters
+    ----------
+    module: torch.nn.Module to be initialized
+    std_init_range: standard deviation of the normal initializer
+    xavier: if True, the Xavier initializer will be used in Linear layers, as proposed in the AIAYN paper;
+        otherwise, the normal initializer will be used (as in the BERT paper)
+    """
+    if isinstance(module, nn.Linear):
+        if xavier:
+            nn.init.xavier_uniform_(module.weight)
+        else:
+            nn.init.normal_(module.weight, mean=0.0, std=std_init_range)
+        if module.bias is not None:
+            nn.init.constant_(module.bias, 0.0)
+    elif isinstance(module, nn.Embedding):
+        nn.init.normal_(module.weight, mean=0.0, std=std_init_range)
+    elif isinstance(module, nn.LayerNorm):
+        nn.init.constant_(module.weight, 1.0)
+        nn.init.constant_(module.bias, 0.0)
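For context, a minimal usage sketch (not part of the diff): the initializer is meant to be applied recursively with Module.apply, with functools.partial fixing the keyword arguments:

    from functools import partial

    import torch.nn as nn
    from mridc.collections.common.parts.rnn_utils import rnn_weights_init

    model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))
    model.apply(partial(rnn_weights_init, std_init_range=0.02, xavier=True))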
diff --git a/docs/build/html/_modules/mridc/collections/common/parts/utils.html b/docs/build/html/_modules/mridc/collections/common/parts/utils.html
new file mode 100644
index 00000000..61e59818
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/common/parts/utils.html
@@ -0,0 +1,357 @@
+mridc.collections.common.parts.utils — mridc v.0.0.1 documentation

Source code for mridc.collections.common.parts.utils

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Parts of the code have been taken from https://github.com/facebookresearch/fastMRI
+
+from pathlib import Path
+from typing import Dict
+
+import h5py
+import numpy as np
+import torch
+
+__all__ = [
+    "to_tensor",
+    "tensor_to_complex_np",
+    "complex_mul",
+    "complex_conj",
+    "complex_abs",
+    "complex_abs_sq",
+    "rss",
+    "rss_complex",
+    "sense",
+    "coil_combination",
+    "save_reconstructions",
+    "check_stacked_complex",
+]
+
+
+
+def to_tensor(data: np.ndarray) -> torch.Tensor:
+    """
+    Converts a numpy array to a torch tensor.
+
+    For complex arrays, the real and imaginary parts are stacked along the last dimension.
+
+    Parameters
+    ----------
+    data: Input numpy array to be converted to torch.
+
+    Returns
+    -------
+    Torch tensor version of data.
+    """
+    if np.iscomplexobj(data):
+        data = np.stack((data.real, data.imag), axis=-1)
+
+    return torch.from_numpy(data)
+
+
+def tensor_to_complex_np(data: torch.Tensor) -> np.ndarray:
+    """
+    Converts a torch tensor to a numpy array.
+
+    Parameters
+    ----------
+    data: Input torch tensor to be converted to numpy.
+
+    Returns
+    -------
+    Complex Numpy array version of data.
+    """
+    data = data.numpy()
+
+    return data[..., 0] + 1j * data[..., 1]
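For context, a round-trip sketch for the two converters above (not part of the diff):

    import numpy as np
    from mridc.collections.common.parts.utils import tensor_to_complex_np, to_tensor

    arr = (np.random.randn(8, 8) + 1j * np.random.randn(8, 8)).astype(np.complex64)
    t = to_tensor(arr)               # real tensor of shape (8, 8, 2)
    back = tensor_to_complex_np(t)   # complex array of shape (8, 8)
    assert np.allclose(arr, back)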
+ + +
+def complex_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    """
+    Complex multiplication.
+
+    This multiplies two complex tensors assuming that they are both stored as real arrays with the last dimension
+    being the complex dimension.
+
+    Parameters
+    ----------
+    x: A PyTorch tensor with the last dimension of size 2.
+    y: A PyTorch tensor with the last dimension of size 2.
+
+    Returns
+    -------
+    A PyTorch tensor with the last dimension of size 2.
+    """
+    if not x.shape[-1] == y.shape[-1] == 2:
+        raise ValueError("Tensors do not have separate complex dim.")
+
+    re = x[..., 0] * y[..., 0] - x[..., 1] * y[..., 1]
+    im = x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0]
+
+    return torch.stack((re, im), dim=-1)
+
+
+def complex_conj(x: torch.Tensor) -> torch.Tensor:
+    """
+    Complex conjugate.
+
+    This applies the complex conjugate assuming that the input array has the last dimension as the complex
+    dimension.
+
+    Parameters
+    ----------
+    x: A PyTorch tensor with the last dimension of size 2.
+
+    Returns
+    -------
+    A PyTorch tensor with the last dimension of size 2.
+    """
+    if x.shape[-1] != 2:
+        raise ValueError("Tensor does not have separate complex dim.")
+
+    return torch.stack((x[..., 0], -x[..., 1]), dim=-1)
+
+
+def complex_abs(data: torch.Tensor) -> torch.Tensor:
+    """
+    Compute the absolute value of a complex valued input tensor.
+
+    Parameters
+    ----------
+    data: A complex valued tensor, where the size of the final dimension should be 2.
+
+    Returns
+    -------
+    Absolute value of data.
+    """
+    if data.shape[-1] != 2:
+        raise ValueError("Tensor does not have separate complex dim.")
+
+    return (data**2).sum(dim=-1).sqrt()
+
+
+def complex_abs_sq(data: torch.Tensor) -> torch.Tensor:
+    """
+    Compute the squared absolute value of a complex tensor.
+
+    Parameters
+    ----------
+    data: A complex valued tensor, where the size of the final dimension should be 2.
+
+    Returns
+    -------
+    Squared absolute value of data.
+    """
+    if data.shape[-1] != 2:
+        raise ValueError("Tensor does not have separate complex dim.")
+
+    return (data**2).sum(dim=-1)
+
+
+def check_stacked_complex(data: torch.Tensor) -> torch.Tensor:
+    """
+    Check if tensor is stacked complex (real & imag parts stacked along last dim) and convert it to a combined
+    complex tensor.
+
+    Parameters
+    ----------
+    data: A complex valued tensor, where the size of the final dimension might be 2.
+
+    Returns
+    -------
+    A complex valued tensor.
+    """
+    return torch.view_as_complex(data) if data.shape[-1] == 2 else data
+
+
+def rss(data: torch.Tensor, dim: int = 0) -> torch.Tensor:
+    """
+    Compute the Root Sum of Squares (RSS).
+
+    RSS is computed assuming that dim is the coil dimension.
+
+    Parameters
+    ----------
+    data: The input tensor
+    dim: The dimensions along which to apply the RSS transform
+
+    Returns
+    -------
+    The RSS value.
+    """
+    return torch.sqrt((data**2).sum(dim))
+
+
+def rss_complex(data: torch.Tensor, dim: int = 0) -> torch.Tensor:
+    """
+    Compute the Root Sum of Squares (RSS) for complex inputs.
+
+    RSS is computed assuming that dim is the coil dimension.
+
+    Parameters
+    ----------
+    data: The input tensor
+    dim: The dimensions along which to apply the RSS transform
+
+    Returns
+    -------
+    The RSS value.
+    """
+    return torch.sqrt(complex_abs_sq(data).sum(dim))
+
+
+def sense(data: torch.Tensor, sensitivity_maps: torch.Tensor, dim: int = 0) -> torch.Tensor:
+    """
+    The SENSitivity Encoding (SENSE) transform [1]_.
+
+    References
+    ----------
+    .. [1] Pruessmann KP, Weiger M, Scheidegger MB, Boesiger P. SENSE: Sensitivity encoding for fast MRI. Magn
+        Reson Med 1999; 42:952-962.
+
+    Parameters
+    ----------
+    data: The input tensor
+    sensitivity_maps: The sensitivity maps
+    dim: The coil dimension
+
+    Returns
+    -------
+    A coil-combined image.
+    """
+    return complex_mul(data, complex_conj(sensitivity_maps)).sum(dim)
+
+
+def coil_combination(
+    data: torch.Tensor, sensitivity_maps: torch.Tensor, method: str = "SENSE", dim: int = 0
+) -> torch.Tensor:
+    """
+    Coil combination.
+
+    Parameters
+    ----------
+    data: The input tensor.
+    sensitivity_maps: The sensitivity maps.
+    method: The coil combination method.
+    dim: The dimensions along which to apply the coil combination transform.
+
+    Returns
+    -------
+    Coil combined data.
+    """
+    if method == "SENSE":
+        return sense(data, sensitivity_maps, dim)
+    if method == "RSS":
+        return rss(data, dim)
+    raise ValueError("Output type not supported.")
+ + +
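For context, a minimal usage sketch (not part of the diff; shapes are illustrative, with dim=0 as the coil dimension):

    import torch
    from mridc.collections.common.parts.utils import coil_combination

    imgs = torch.randn(8, 32, 32, 2)  # (coil, height, width, real/imag)
    maps = torch.randn(8, 32, 32, 2)  # coil sensitivity maps, same layout
    sense_img = coil_combination(imgs, maps, method="SENSE", dim=0)  # (32, 32, 2)
    rss_img = coil_combination(imgs, maps, method="RSS", dim=0)      # (32, 32, 2)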
+def save_reconstructions(reconstructions: Dict[str, np.ndarray], out_dir: Path):
+    """
+    Save reconstruction images.
+
+    This function writes to h5 files that are appropriate for submission to the leaderboard.
+
+    Parameters
+    ----------
+    reconstructions: A dictionary mapping input filenames to corresponding reconstructions.
+    out_dir: Path to the output directory where the reconstructions should be saved.
+    """
+    out_dir.mkdir(exist_ok=True, parents=True)
+    for fname, recons in reconstructions.items():
+        with h5py.File(out_dir / fname, "w") as hf:
+            hf.create_dataset("reconstruction", data=recons)
+
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/data/mri_data.html b/docs/build/html/_modules/mridc/collections/reconstruction/data/mri_data.html
new file mode 100644
index 00000000..6baf338e
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/data/mri_data.html
@@ -0,0 +1,442 @@
+mridc.collections.reconstruction.data.mri_data — mridc v.0.0.1 documentation

Source code for mridc.collections.reconstruction.data.mri_data

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Parts of the code have been taken from https://github.com/facebookresearch/fastMRI
+
+import logging
+import os
+import random
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
+
+import h5py
+import numpy as np
+import torch
+import yaml  # type: ignore
+from defusedxml.ElementTree import fromstring
+from torch.utils.data import Dataset
+
+
+
+def et_query(root: str, qlist: Sequence[str], namespace: str = "https://www.ismrm.org/ISMRMRD") -> str:
+    """
+    Query an XML element for a list of attributes.
+
+    Parameters
+    ----------
+    root: The root element of the XML tree.
+    qlist: A list of strings, each of which is an attribute name.
+    namespace: The namespace of the XML tree.
+
+    Returns
+    -------
+    A string containing the value of the last attribute in the list.
+    """
+    s = "."
+    prefix = "ismrmrd_namespace"
+
+    ns = {prefix: namespace}
+
+    for el in qlist:
+        s += f"//{prefix}:{el}"
+
+    value = root.find(s, ns)  # type: ignore
+    if value is None:
+        return "0"
+
+    return str(value.text)  # type: ignore
+
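For context, a minimal sketch with a toy header (not part of the diff). Although the signature annotates root as str, what is actually queried is a parsed ElementTree element:

    from defusedxml.ElementTree import fromstring
    from mridc.collections.reconstruction.data.mri_data import et_query

    xml = (
        '<ismrmrdHeader xmlns="https://www.ismrm.org/ISMRMRD">'
        "<encoding><encodedSpace><matrixSize><x>640</x></matrixSize></encodedSpace></encoding>"
        "</ismrmrdHeader>"
    )
    root = fromstring(xml)
    et_query(root, ["encoding", "encodedSpace", "matrixSize", "x"])  # -> "640"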
+class FastMRICombinedSliceDataset(torch.utils.data.Dataset):
+    """A dataset that combines multiple datasets."""
+
+    def __init__(
+        self,
+        roots: Sequence[Path],
+        challenges: Sequence[str],
+        sense_roots: Optional[Sequence[Path]] = None,
+        transforms: Optional[Sequence[Optional[Callable]]] = None,
+        sample_rates: Optional[Sequence[Optional[float]]] = None,
+        volume_sample_rates: Optional[Sequence[Optional[float]]] = None,
+        use_dataset_cache: bool = False,
+        dataset_cache_file: Union[str, Path, os.PathLike] = "dataset_cache.yaml",
+        num_cols: Optional[Tuple[int]] = None,
+    ):
+        """
+        Parameters
+        ----------
+        roots: Paths to the datasets.
+        challenges: "singlecoil" or "multicoil" depending on which challenge to use.
+        sense_roots: Load pre-computed (stored) sensitivity maps.
+        transforms: Optional; A sequence of callable objects that preprocesses the raw data into appropriate form.
+            The transform function should take 'kspace', 'target', 'attributes', 'filename', and 'slice' as inputs.
+            'target' may be null for test data.
+        sample_rates: Optional; A sequence of floats between 0 and 1. This controls what fraction of the slices
+            should be loaded. When creating subsampled datasets either set sample_rates (sample by slices) or
+            volume_sample_rates (sample by volumes) but not both.
+        volume_sample_rates: Optional; A sequence of floats between 0 and 1. This controls what fraction of the
+            volumes should be loaded. When creating subsampled datasets either set sample_rates (sample by slices)
+            or volume_sample_rates (sample by volumes) but not both.
+        use_dataset_cache: Whether to cache dataset metadata. This is very useful for large datasets like the
+            brain data.
+        dataset_cache_file: Optional; A file in which to cache dataset information for faster load times.
+        num_cols: Optional; If provided, only slices with the desired number of columns will be considered.
+        """
+        if sample_rates is not None and volume_sample_rates is not None:
+            raise ValueError(
+                "either set sample_rates (sample by slices) or volume_sample_rates (sample by volumes) but not both"
+            )
+        if transforms is None:
+            transforms = [None] * len(roots)
+        if sample_rates is None:
+            sample_rates = [None] * len(roots)
+        if volume_sample_rates is None:
+            volume_sample_rates = [None] * len(roots)
+        if not len(roots) == len(transforms) == len(challenges) == len(sample_rates) == len(volume_sample_rates):
+            raise ValueError("Lengths of roots, transforms, challenges, sample_rates do not match")
+
+        self.datasets = []
+        self.examples: List[Tuple[Path, int, Dict[str, object]]] = []
+        for i, _ in enumerate(roots):
+            self.datasets.append(
+                FastMRISliceDataset(
+                    root=roots[i],
+                    transform=transforms[i],
+                    sense_root=sense_roots[i] if sense_roots is not None else None,
+                    challenge=challenges[i],
+                    sample_rate=sample_rates[i],
+                    volume_sample_rate=volume_sample_rates[i],
+                    use_dataset_cache=use_dataset_cache,
+                    dataset_cache_file=dataset_cache_file,
+                    num_cols=num_cols,
+                )
+            )
+
+            self.examples += self.datasets[-1].examples
+
+    def __len__(self):
+        return sum(len(dataset) for dataset in self.datasets)
+
+    def __getitem__(self, i):
+        for dataset in self.datasets:
+            if i < len(dataset):
+                return dataset[i]
+            i = i - len(dataset)
+
[docs]class FastMRISliceDataset(Dataset): + """A dataset that loads slices from a single dataset.""" + + def __init__( + self, + root: Union[str, Path, os.PathLike], + challenge: str = "multicoil", + transform: Optional[Callable] = None, + sense_root: Union[str, Path, os.PathLike] = None, + use_dataset_cache: bool = False, + sample_rate: Optional[float] = None, + volume_sample_rate: Optional[float] = None, + dataset_cache_file: Union[str, Path, os.PathLike] = "dataset_cache.yaml", + num_cols: Optional[Tuple[int]] = None, + mask_root: Union[str, Path, os.PathLike] = None, + ): + """ + Parameters + ---------- + root: Path to the dataset. + challenge: "singlecoil" or "multicoil" depending on which challenge to use. + transform: Optional; A sequence of callable objects that preprocesses the raw data into appropriate form. + The transform function should take 'kspace', 'target', 'attributes', 'filename', and 'slice' as inputs. + 'target' may be null for test data. + sense_root: Path to the coil sensitivities maps dataset. + use_dataset_cache: Whether to cache dataset metadata. This is very useful for large datasets like the brain + data. + sample_rate: Optional; A sequence of floats between 0 and 1. This controls what fraction of the slices + should be loaded. When creating subsampled datasets either set sample_rates (sample by slices) or + volume_sample_rates (sample by volumes) but not both. + volume_sample_rate: Optional; A sequence of floats between 0 and 1. This controls what fraction of the + volumes should be loaded. When creating subsampled datasets either set sample_rates (sample by slices) + or volume_sample_rates (sample by volumes) but not both. + dataset_cache_file: Optional; A file in which to cache dataset information for faster load times. + num_cols: Optional; If provided, only slices with the desired number of columns will be considered. + mask_root: Path to stored masks. 
+ """ + if challenge not in ("singlecoil", "multicoil"): + raise ValueError('challenge should be either "singlecoil" or "multicoil"') + + if sample_rate is not None and volume_sample_rate is not None: + raise ValueError( + "either set sample_rate (sample by slices) or volume_sample_rate (sample by volumes) but not both" + ) + + self.sense_root = sense_root + self.mask_root = mask_root + + self.dataset_cache_file = Path(dataset_cache_file) + + self.transform = transform + self.recons_key = "reconstruction_esc" if challenge == "singlecoil" else "reconstruction_rss" + self.examples = [] + + # set default sampling mode if none given + if sample_rate is None: + sample_rate = 1.0 + if volume_sample_rate is None: + volume_sample_rate = 1.0 + + # load dataset cache if we have and user wants to use it + if self.dataset_cache_file.exists() and use_dataset_cache: + with open(self.dataset_cache_file, "rb") as f: + dataset_cache = yaml.safe_load(f) + else: + dataset_cache = {} + + # check if our dataset is in the cache + # if there, use that metadata, if not, then regenerate the metadata + if dataset_cache.get(root) is None or not use_dataset_cache: + files = list(Path(root).iterdir()) + for fname in sorted(files): + metadata, num_slices = self._retrieve_metadata(fname) + self.examples += [(fname, slice_ind, metadata) for slice_ind in range(num_slices)] + + if dataset_cache.get(root) is None and use_dataset_cache: + dataset_cache[root] = self.examples + logging.info(f"Saving dataset cache to {self.dataset_cache_file}.") + with open(self.dataset_cache_file, "wb") as f: # type: ignore + yaml.dump(dataset_cache, f) # type: ignore + else: + logging.info(f"Using dataset cache from {self.dataset_cache_file}.") + self.examples = dataset_cache[root] + + # subsample if desired + if sample_rate < 1.0: # sample by slice + random.shuffle(self.examples) + num_examples = round(len(self.examples) * sample_rate) + self.examples = self.examples[:num_examples] + elif volume_sample_rate < 1.0: # sample by volume + vol_names = sorted(list({f[0].stem for f in self.examples})) + random.shuffle(vol_names) + num_volumes = round(len(vol_names) * volume_sample_rate) + sampled_vols = vol_names[:num_volumes] + self.examples = [example for example in self.examples if example[0].stem in sampled_vols] + + if num_cols: + self.examples = [ex for ex in self.examples if ex[2]["encoding_size"][1] in num_cols] # type: ignore + + @staticmethod + def _retrieve_metadata(fname): + """ + Retrieve metadata from a given file. + + Parameters + ---------- + fname: Path to file. + + Returns + ------- + A dictionary containing the metadata. 
+ """ + with h5py.File(fname, "r") as hf: + if "ismrmrd_header" in hf: + et_root = fromstring(hf["ismrmrd_header"][()]) + + enc = ["encoding", "encodedSpace", "matrixSize"] + enc_size = ( + int(et_query(et_root, enc + ["x"])), + int(et_query(et_root, enc + ["y"])), + int(et_query(et_root, enc + ["z"])), + ) + rec = ["encoding", "reconSpace", "matrixSize"] + recon_size = ( + int(et_query(et_root, rec + ["x"])), + int(et_query(et_root, rec + ["y"])), + int(et_query(et_root, rec + ["z"])), + ) + + params = ["encoding", "encodingLimits", "kspace_encoding_step_1"] + enc_limits_center = int(et_query(et_root, params + ["center"])) + enc_limits_max = int(et_query(et_root, params + ["maximum"])) + 1 + + padding_left = torch.div(enc_size[1], 2, rounding_mode="trunc").item() - enc_limits_center + padding_right = padding_left + enc_limits_max + else: + padding_left = 0 + padding_right = 0 + enc_size = 0 + recon_size = (0, 0) + + num_slices = hf["kspace"].shape[0] + + metadata = { + "padding_left": padding_left, + "padding_right": padding_right, + "encoding_size": enc_size, + "recon_size": recon_size, + } + + return metadata, num_slices + + def __len__(self): + return len(self.examples) + + def __getitem__(self, i: int): + fname, dataslice, metadata = self.examples[i] + with h5py.File(fname, "r") as hf: + kspace = hf["kspace"][dataslice].astype(np.complex64) + + if "sensitivity_map" in hf: + sensitivity_map = hf["sensitivity_map"][dataslice].astype(np.complex64) + elif self.sense_root is not None and self.sense_root != "None": + with h5py.File(Path(self.sense_root) / Path(str(fname).split("/")[-2]) / fname.name, "r") as sf: + sensitivity_map = ( + sf["sensitivity_map"][dataslice] + if "sensitivity_map" in sf or "sensitivity_map" in next(iter(sf.keys())) + else sf["sense"][dataslice] + ) + sensitivity_map = sensitivity_map.squeeze().astype(np.complex64) + else: + sensitivity_map = np.array([]) + + if "mask" in hf: + mask = np.asarray(hf["mask"]) + + if mask.ndim == 3: + mask = mask[dataslice] + + elif self.mask_root is not None and self.mask_root != "None": + mask_path = Path(self.mask_root) / Path(str(fname.name).split(".")[0] + ".npy") + mask = np.load(str(mask_path)) + else: + mask = None + + eta = hf["eta"][dataslice].astype(np.complex64) if "eta" in hf else np.array([]) + + if "reconstruction_sense" in hf: + self.recons_key = "reconstruction_sense" + + target = hf[self.recons_key][dataslice].astype(np.float32) if self.recons_key in hf else None + + attrs = dict(hf.attrs) + attrs |= metadata + + if sensitivity_map.shape != kspace.shape: + sensitivity_map = np.transpose(sensitivity_map, (2, 0, 1)) + + return ( + ( + kspace, + sensitivity_map, + mask, + eta, + target, + attrs, + fname.name, + dataslice, + ) + if self.transform is None + else self.transform( + kspace, + sensitivity_map, + mask, + eta, + target, + attrs, + fname.name, + dataslice, + ) + )
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/data/subsample.html b/docs/build/html/_modules/mridc/collections/reconstruction/data/subsample.html new file mode 100644 index 00000000..21d42d27 --- /dev/null +++ b/docs/build/html/_modules/mridc/collections/reconstruction/data/subsample.html @@ -0,0 +1,778 @@
mridc.collections.reconstruction.data.subsample — mridc v.0.0.1 documentation

Source code for mridc.collections.reconstruction.data.subsample

+# encoding: utf-8
+# Parts of the code have been taken from https://github.com/facebookresearch/fastMRI
+__author__ = "Dimitrios Karkalousos"
+
+import contextlib
+from typing import Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+
+
+
[docs]@contextlib.contextmanager +def temp_seed(rng: np.random.RandomState, seed: Optional[Union[int, Tuple[int, ...]]]): + """ + Temporarily sets the seed of the given random number generator. + + Parameters + ---------- + rng: The random number generator. + seed: The seed to set. + + Returns + ------- + A context manager. + """ + if seed is None: + try: + yield + finally: + pass + else: + state = rng.get_state() + rng.seed(seed) + try: + yield + finally: + rng.set_state(state)
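A minimal usage sketch (editor's addition, not from the module source): because temp_seed restores the generator state on exit, repeated calls with the same seed reproduce the same draws without disturbing the surrounding random stream.

rng = np.random.RandomState()
with temp_seed(rng, seed=1234):
    a = rng.uniform(size=4)
with temp_seed(rng, seed=1234):
    b = rng.uniform(size=4)
assert np.allclose(a, b)  # identical draws; the outer state of rng is unaffected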
+ + +
[docs]class MaskFunc: + """Abstract base class for k-space under-sampling mask functions.""" + + def __init__(self, center_fractions: Sequence[float], accelerations: Sequence[int]): + """ + Initialize the mask function. + + Parameters + ---------- + center_fractions: Fraction of low-frequency columns to be retained. If multiple values are provided, then \ + one of these numbers is chosen uniformly each time. For the 2D setting, this value corresponds to the \ + Full-Width-at-Half-Maximum. + accelerations: Amount of under-sampling. This should have the same length as center_fractions. If multiple \ + values are provided, then one of these is chosen uniformly each time. + """ + if len(center_fractions) != len(accelerations): + raise ValueError("Number of center fractions should match number of accelerations") + + self.center_fractions = center_fractions + self.accelerations = accelerations + self.rng = np.random.RandomState() # pylint: disable=no-member + +
[docs] def __call__( + self, + shape: Sequence[int], + seed: Optional[Union[int, Tuple[int, ...]]] = None, + half_scan_percentage: Optional[float] = 0.0, + scale: Optional[float] = 0.02, + ) -> Tuple[torch.Tensor, int]: + """ + Produce the sub-sampling mask for the given shape; implemented by the subclasses. + + Parameters + ---------- + shape: Shape of the input tensor. + seed: Seed for the random number generator. + half_scan_percentage: Optional; Defines a fraction of the k-space data that is not sampled. + scale: Optional; Defines the scale of the center of the mask. + + Returns + ------- + A tuple of the mask and the acceleration factor. + """ + raise NotImplementedError
+ +
[docs] def choose_acceleration(self): + """Choose acceleration.""" + choice = self.rng.randint(0, len(self.accelerations)) + center_fraction = self.center_fractions[choice] + acceleration = self.accelerations[choice] + + return center_fraction, acceleration
+ + +
[docs]class RandomMaskFunc(MaskFunc): + """ + RandomMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the k-space data has N columns, the mask \ + picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center corresponding to low-frequencies. + 2. The other columns are selected uniformly at random with a probability equal to: \ + prob = (N / acceleration - N_low_freqs) / (N - N_low_freqs). This ensures that the expected number of \ + columns selected is equal to (N / acceleration). + + It is possible to use multiple center_fractions and accelerations, in which case one possible (center_fraction, \ + acceleration) is chosen uniformly at random each time the RandomMaskFunc object is called. + + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], then there is a 50% probability that \ + 4-fold acceleration with 8% center fraction is selected and a 50% probability that 8-fold acceleration with 4% \ + center fraction is selected. + """ + +
[docs] def __call__( + self, + shape: Sequence[int], + seed: Optional[Union[int, Tuple[int, ...]]] = None, + half_scan_percentage: Optional[float] = 0.0, + scale: Optional[float] = 0.02, + ) -> Tuple[torch.Tensor, int]: + """ + Parameters + ---------- + shape: The shape of the mask to be created. The shape should have at least 3 dimensions. Samples are drawn \ + along the second last dimension. + seed: Seed for the random number generator. Setting the seed ensures the same mask is generated each time \ + for the same shape. The random state is reset afterwards. + half_scan_percentage: Optional; Defines a fraction of the k-space data that is not sampled. + scale: Optional; Defines the scale of the center of the mask. + + Returns + ------- + A tuple of the mask and the acceleration factor drawn for this call. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + num_cols = shape[-2] + center_fraction, acceleration = self.choose_acceleration() + + # create the mask + num_low_freqs = int(round(num_cols * center_fraction)) + prob = (num_cols / acceleration - num_low_freqs) / (num_cols - num_low_freqs) + mask = self.rng.uniform(size=num_cols) < prob # type: ignore + pad = torch.div((num_cols - num_low_freqs + 1), 2, rounding_mode="trunc").item() + mask[pad : pad + num_low_freqs] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask, acceleration
+ + +
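A short sketch of drawing a mask (editor's addition; the k-space shape follows the [batch, coils, x, y, 2] convention used elsewhere in this repository):

mask_func = RandomMaskFunc(center_fractions=[0.08, 0.04], accelerations=[4, 8])
mask, acc = mask_func(shape=(1, 15, 640, 368, 2), seed=42)
# mask: torch.Tensor of shape [1, 1, 1, 368, 1], broadcastable along the column dimension
# acc: the acceleration (4 or 8) drawn for this call; the seed makes the draw reproducible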
[docs]class EquispacedMaskFunc(MaskFunc): + """ + EquispacedMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the k-space data has N columns, the mask \ + picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center corresponding to low-frequencies. + 2. The other columns are selected with equal spacing at a proportion that reaches the desired acceleration \ + rate taking into consideration the number of low frequencies. This ensures that the expected number of \ + columns selected is equal to (N / acceleration) + + It is possible to use multiple center_fractions and accelerations, in which case one possible (center_fraction, \ + acceleration) is chosen uniformly at random each time the EquispacedMaskFunc object is called. + + Note that this function may not give equispaced samples (documented in \ + https://github.com/facebookresearch/fastMRI/issues/54), which will require modifications to standard GRAPPA \ + approaches. Nonetheless, this aspect of the function has been preserved to match the public multicoil data. + """ + +
[docs] def __call__( + self, + shape: Sequence[int], + seed: Optional[Union[int, Tuple[int, ...]]] = None, + half_scan_percentage: Optional[float] = 0.0, + scale: Optional[float] = 0.02, + ) -> Tuple[torch.Tensor, int]: + """ + Parameters + ---------- + shape: The shape of the mask to be created. The shape should have at least 3 dimensions. Samples are drawn \ + along the second last dimension. + seed: Seed for the random number generator. Setting the seed ensures the same mask is generated each time for \ + the same shape. The random state is reset afterwards. + half_scan_percentage: Optional; Defines a fraction of the k-space data that is not sampled. + scale: Optional; Defines the scale of the center of the mask. + + Returns + ------- + A tuple of the mask and the acceleration factor drawn for this call. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + center_fraction, acceleration = self.choose_acceleration() + num_cols = shape[-2] + num_low_freqs = int(round(num_cols * center_fraction)) + + # create the mask + mask = np.zeros(num_cols, dtype=np.float32) + pad = torch.div((num_cols - num_low_freqs + 1), 2, rounding_mode="trunc").item() + mask[pad : pad + num_low_freqs] = True # type: ignore + + # determine acceleration rate by adjusting for the number of low frequencies + adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / (num_low_freqs * acceleration - num_cols) + offset = self.rng.randint(0, round(adjusted_accel)) + + accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) + accel_samples = np.around(accel_samples).astype(np.uint) + mask[accel_samples] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask, acceleration
+ + +
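To see what the adjusted rate compensates for, a small worked example (editor's numbers, following the formulas above):

num_cols, acceleration, center_fraction = 320, 4, 0.08
num_low_freqs = round(num_cols * center_fraction)  # 26 center columns
adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / (num_low_freqs * acceleration - num_cols)
# adjusted_accel ≈ 5.44, giving ~59 equispaced columns of which ~54 fall outside the center,
# so about 26 + 54 = 80 = 320 / 4 columns are sampled in expectation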
[docs]class Gaussian1DMaskFunc(MaskFunc): + """ + Creates a 1D sub-sampling mask of a given shape. + + For autocalibration purposes, data points near the k-space center will be fully sampled within an ellipse \ + whose half-axes are set to scale % of the fully sampled region. The remaining points will be sampled \ + according to a Gaussian distribution. + + The center fractions here act as Full-Width at Half-Maximum (FWHM) values. + """ + +
[docs] def __call__( + self, + shape: Union[Sequence[int], np.ndarray], + seed: Optional[Union[int, Tuple[int, ...]]] = None, + half_scan_percentage: Optional[float] = 0.0, + scale: Optional[float] = 0.02, + ) -> Tuple[torch.Tensor, int]: + """ + Parameters + ---------- + shape: The shape of the mask to be created. The shape should have at least 3 dimensions. Samples are drawn \ + along the second last dimension. + seed: Seed for the random number generator. Setting the seed ensures the same mask is generated each time \ + for the same shape. The random state is reset afterwards. + half_scan_percentage: Optional; Defines a fraction of the k-space data that is not sampled. + scale: For autocalibration purposes, data points near the k-space center will be fully sampled within an \ + ellipse whose half-axes are set to scale % of the fully sampled region. + + Returns + ------- + A tuple of the mask and the acceleration factor drawn for this call. + """ + dims = [1 for _ in shape] + self.shape = tuple(shape[-3:-1]) + dims[-2] = self.shape[-1] + + full_width_half_maximum, acceleration = self.choose_acceleration() + if not isinstance(full_width_half_maximum, list): + full_width_half_maximum = [full_width_half_maximum] * 2 + self.full_width_half_maximum = full_width_half_maximum + self.acceleration = acceleration + + self.scale = scale + + mask = self.gaussian_kspace() + mask[tuple(self.gaussian_coordinates())] = 1.0 + + mask = np.fft.ifftshift(np.fft.ifftshift(np.fft.ifftshift(mask, axes=0), axes=0), axes=(0, 1)) + + if half_scan_percentage != 0: + mask[: int(np.round(mask.shape[0] * half_scan_percentage)), :] = 0.0 + + return torch.from_numpy(mask[0].reshape(dims).astype(np.float32)), acceleration
+ +
[docs] def gaussian_kspace(self): + """Creates a Gaussian sampled k-space center.""" + scaled = int(self.shape[0] * self.scale) + center = np.ones((scaled, self.shape[1])) + top_scaled = torch.div((self.shape[0] - scaled), 2, rounding_mode="trunc").item() + bottom_scaled = self.shape[0] - scaled - top_scaled + top = np.zeros((top_scaled, self.shape[1])) + btm = np.zeros((bottom_scaled, self.shape[1])) + return np.concatenate((top, center, btm))
+ +
[docs] def gaussian_coordinates(self): + """Creates Gaussian-sampled k-space coordinates.""" + n_sample = int(self.shape[0] / self.acceleration) + kernel = self.gaussian_kernel() + idxs = np.random.choice(range(self.shape[0]), size=n_sample, replace=False, p=kernel) + xsamples = np.concatenate([np.tile(i, self.shape[1]) for i in idxs]) + ysamples = np.concatenate([range(self.shape[1]) for _ in idxs]) + return xsamples, ysamples
+ +
[docs] def gaussian_kernel(self): + """Creates a Gaussian sampled k-space kernel.""" + kernel = 1 + for fwhm, kern_len in zip(self.full_width_half_maximum, self.shape): + sigma = fwhm / np.sqrt(8 * np.log(2)) + x = np.linspace(-1.0, 1.0, kern_len) + g = np.exp(-(x**2 / (2 * sigma**2))) + kernel = g + break + kernel = kernel / kernel.sum() + return kernel
+ + +
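The sigma computed in gaussian_kernel uses the standard Gaussian identity FWHM = 2 * sqrt(2 * ln 2) * sigma ≈ 2.355 * sigma; a quick check (editor's addition):

fwhm = 0.7
sigma = fwhm / np.sqrt(8 * np.log(2))
assert np.isclose(2 * np.sqrt(2 * np.log(2)) * sigma, fwhm)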
[docs]class Gaussian2DMaskFunc(MaskFunc): + """ + Creates a 2D sub-sampling mask of a given shape. + + For autocalibration purposes, data points near the k-space center will be fully sampled within an ellipse \ + whose half-axes are set to scale % of the fully sampled region. The remaining points will be sampled \ + according to a Gaussian distribution. + + The center fractions here act as Full-Width at Half-Maximum (FWHM) values. + """ + +
[docs] def __call__( + self, + shape: Union[Sequence[int], np.ndarray], + seed: Optional[Union[int, Tuple[int, ...]]] = None, + half_scan_percentage: Optional[float] = 0.0, + scale: Optional[float] = 0.02, + ) -> Tuple[torch.Tensor, int]: + """ + Parameters + ---------- + shape: The shape of the mask to be created. The shape should have at least 3 dimensions. Samples are drawn \ + along the second last dimension. + seed: Seed for the random number generator. Setting the seed ensures the same mask is generated each time for \ + the same shape. The random state is reset afterwards. + half_scan_percentage: Optional; Defines a fraction of the k-space data that is not sampled. + scale: For autocalibration purposes, data points near the k-space center will be fully sampled within an \ + ellipse whose half-axes are set to scale % of the fully sampled region. + + Returns + ------- + A tuple of the mask and the acceleration factor drawn for this call. + """ + dims = [1 for _ in shape] + self.shape = tuple(shape[-3:-1]) + dims[-3:-1] = self.shape + + full_width_half_maximum, acceleration = self.choose_acceleration() + + if not isinstance(full_width_half_maximum, list): + full_width_half_maximum = [full_width_half_maximum] * 2 + self.full_width_half_maximum = full_width_half_maximum + + self.acceleration = acceleration + self.scale = scale + + mask = self.gaussian_kspace() + mask[tuple(self.gaussian_coordinates())] = 1.0 + + if half_scan_percentage != 0: + mask[: int(np.round(mask.shape[0] * half_scan_percentage)), :] = 0.0 + + return torch.from_numpy(mask.reshape(dims).astype(np.float32)), acceleration
+ +
[docs] def gaussian_kspace(self): + """Creates a Gaussian sampled k-space center.""" + a, b = self.scale * self.shape[0], self.scale * self.shape[1] + afocal, bfocal = self.shape[0] / 2, self.shape[1] / 2 + xx, yy = np.mgrid[: self.shape[0], : self.shape[1]] + ellipse = np.power((xx - afocal) / a, 2) + np.power((yy - bfocal) / b, 2) + return (ellipse < 1).astype(float)
+ +
[docs] def gaussian_coordinates(self): + """Creates Gaussian-sampled k-space coordinates.""" + n_sample = int(self.shape[0] * self.shape[1] / self.acceleration) + cartesian_prod = list(np.ndindex(self.shape)) # type: ignore + kernel = self.gaussian_kernel() + idxs = np.random.choice(range(len(cartesian_prod)), size=n_sample, replace=False, p=kernel.flatten()) + return list(zip(*list(map(cartesian_prod.__getitem__, idxs))))
+ +
[docs] def gaussian_kernel(self): + """Creates a Gaussian kernel.""" + kernels = [] + for fwhm, kern_len in zip(self.full_width_half_maximum, self.shape): + sigma = fwhm / np.sqrt(8 * np.log(2)) + x = np.linspace(-1.0, 1.0, kern_len) + g = np.exp(-(x**2 / (2 * sigma**2))) + kernels.append(g) + kernel = np.sqrt(np.outer(kernels[0], kernels[1])) + kernel = kernel / kernel.sum() + return kernel
+ + +
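A sketch of drawing a 2D Gaussian mask (editor's addition; per the class docstring, the center_fractions argument is interpreted as the FWHM):

mask_func = Gaussian2DMaskFunc(center_fractions=[0.7], accelerations=[8])
mask, acc = mask_func(shape=(1, 15, 320, 320, 2))
# mask: torch.Tensor of shape [1, 1, 320, 320, 1]; a fully sampled central ellipse
# (half-axes 2% of the grid by default) plus Gaussian-density random samples elsewhere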
[docs]class Poisson2DMaskFunc(MaskFunc): + """ + Creates a 2D sub-sampling mask of a given shape. + + For autocalibration purposes, data points near the k-space center will be fully sampled within an ellipse \ + whose half-axes are set to scale % of the fully sampled region. The remaining points will be sampled \ + according to a (variable density) Poisson distribution. + + For a given acceleration factor to be accurate, the scale for the fully sampled center should remain at the \ + default 0.02. A predefined list is used to convert the acceleration factor to the appropriate r parameter needed \ + for the variable density calculation. This list has been made to accommodate acceleration factors of 4 up to 21, \ + rounding off to the nearest one available. As such, acceleration factors outside this range cannot be used. + """ + +
[docs] def __call__( + self, + shape: Union[Sequence[int], np.ndarray], + seed: Optional[Union[int, Tuple[int, ...]]] = None, + half_scan_percentage: Optional[float] = 0.0, + scale: Optional[float] = 0.02, + ) -> Tuple[torch.Tensor, int]: + """ + Parameters + ---------- + shape: The shape of the mask to be created. The shape should have at least 3 dimensions. Samples are drawn \ + along the second last dimension. + seed: Seed for the random number generator. Setting the seed ensures the same mask is generated each time \ + for the same shape. The random state is reset afterwards. + half_scan_percentage: Optional; Defines a fraction of the k-space data that is not sampled. + scale: For autocalibration purposes, data points near the k-space center will be fully sampled within an \ + ellipse whose half-axes are set to scale % of the fully sampled region. + + Returns + ------- + A tuple of the mask and the acceleration factor drawn for this call. + """ + dims = [1 for _ in shape] + self.shape = tuple(shape[-3:-1]) + dims[-3:-1] = self.shape + + _, acceleration = self.choose_acceleration() + if acceleration > 21.5 or acceleration < 3.5: + raise ValueError(f"Acceleration {acceleration} is not supported for Poisson 2D masking.") + + self.acceleration = acceleration + self.scale = scale + + # TODO: consider moving this to a yaml file + rfactor = [ + 21.22, + 20.32, + 19.06, + 18.22, + 17.41, + 16.56, + 15.86, + 15.12, + 14.42, + 13.88, + 13.17, + 12.76, + 12.21, + 11.72, + 11.09, + 10.68, + 10.35, + 10.02, + 9.61, + 9.22, + 9.03, + 8.66, + 8.28, + 8.1, + 7.74, + 7.62, + 7.32, + 7.04, + 6.94, + 6.61, + 6.5, + 6.27, + 6.15, + 5.96, + 5.83, + 5.59, + 5.46, + 5.38, + 5.15, + 5.05, + 4.9, + 4.86, + 4.67, + 4.56, + 4.52, + 4.41, + 4.31, + 4.21, + 4.11, + 3.99, + ] + self.r = min(range(len(rfactor)), key=lambda i: abs(rfactor[i] - self.acceleration)) + 40 + + pattern1 = self.poisson_disc2d() + pattern2 = self.centered_circle() + mask = np.logical_or(pattern1, pattern2) + + if half_scan_percentage != 0: + mask[: int(np.round(mask.shape[0] * half_scan_percentage)), :] = 0.0 + + return (torch.from_numpy(mask.reshape(dims).astype(np.float32)), acceleration)
+ +
[docs] def poisson_disc2d(self): + """Creates a 2D Poisson disc pattern.""" + # Number of tries before discarding a reference point for new samples + k = 10 + + # Shape of the pattern to be drawn + pattern_shape = (self.shape[0] - 1, self.shape[1] - 1) + + # Initialize the pattern + center = np.array([1.0 * pattern_shape[0] / 2, 1.0 * pattern_shape[1] / 2]) + width, height = pattern_shape + + # Cell side length (equal to r_min) + a = 1 + + # Number of cells in the x- and y-directions of the grid + nx, ny = int(width / a), int(height / a) + + # A list of coordinates in the grid of cells + coords_list = [(ix, iy) for ix in range(nx + 1) for iy in range(ny + 1)] + + # Initialize the dictionary of cells: each key is a cell's coordinates; the corresponding value is the list + # of indices of the points that might cause a conflict when adding a new point in that cell. + cells = {coords: [] for coords in coords_list} + centernorm = np.linalg.norm(center) + + def calc_r(coords): + """Calculate r for the given coordinates.""" + return ((np.linalg.norm(np.asarray(coords) - center) / centernorm) * 240 + 50) / self.r + + def get_cell_coords(pt): + """Get the coordinates of the cell that pt = (x,y) falls in.""" + return int(np.floor_divide(pt[0], a)), int(np.floor_divide(pt[1], a)) + + def mark_neighbours(idx): + """Add sample index to the cells within r(point) range of the point.""" + coords = samples[idx] + if idx in cells[get_cell_coords(coords)]: + # This point is already marked on the grid, so we can skip + return + + # Mark the point on the grid + rx = calc_r(coords) + xvals = np.arange(coords[0] - rx, coords[0] + rx) + yvals = np.arange(coords[1] - rx, coords[1] + rx) + + # Get the coordinates of the cells that the point falls in + xvals = xvals[(xvals >= 0) & (xvals <= width)] + yvals = yvals[(yvals >= 0) & (yvals <= height)] + + def dist(x, y): + """Check whether (x, y) lies within rx of the point.""" + return np.sqrt((coords[0] - x) ** 2 + (coords[1] - y) ** 2) < rx + + xx, yy = np.meshgrid(xvals, yvals, sparse=False) + + # Mark the points in the grid + pts = np.vstack((xx.ravel(), yy.ravel())).T + pts = pts[dist(pts[:, 0], pts[:, 1])] + + return [cells[get_cell_coords(pt)].append(idx) for pt in pts] + + def point_valid(pt): + """Check if the point is valid.""" + rx = calc_r(pt) + if rx < 1: + if np.linalg.norm(pt - center) < self.scale * width: + return False + rx = 1 + + # Get the coordinates of the cells that the point falls in + neighbour_idxs = cells[get_cell_coords(pt)] + for n in neighbour_idxs: + n_coords = samples[n] + + # Distance between our candidate point, pt, and this nearby point. + distance = np.sqrt((n_coords[0] - pt[0]) ** 2 + (n_coords[1] - pt[1]) ** 2) + if distance < rx: + # The points are too close, so pt is not a candidate. + return False + + # All points tested: if we're here, pt is valid. + return True + + def get_point(k, refpt): + """ + Try to find a candidate point relative to refpt to emit in the sample. We draw up to k points from the + annulus of inner radius r, outer radius 2r around the reference point, refpt. If none of them are suitable, + return False; otherwise, return the point. + """ + i = 0 + rx = calc_r(refpt) + while i < k: + rho, theta = np.random.uniform(rx, 2 * rx), np.random.uniform(0, 2 * np.pi) + pt = refpt[0] + rho * np.cos(theta), refpt[1] + rho * np.sin(theta) + if not (0 < pt[0] < width and 0 < pt[1] < height): + # Off the grid, try again. + continue + if point_valid(pt): + return pt + i += 1 + + # We failed to find a suitable point in the vicinity of refpt.
+ return False + + # Pick a random point to start with. + pt = (np.random.uniform(0, width), np.random.uniform(0, height)) + samples = [pt] + cursample = 0 + mark_neighbours(0) + + # Set active, in the sense that we're going to look for more points in its neighbourhood. + active = [0] + + # As long as there are points in the active list, keep trying to find samples. + while active: + # choose a random "reference" point from the active list. + idx = np.random.choice(active) + refpt = samples[idx] + + # Try to pick a new point relative to the reference point. + pt = get_point(k, refpt) + if pt: + # Point pt is valid: add it to the samples list and mark it as active + samples.append(pt) + cursample += 1 + active.append(cursample) + mark_neighbours(cursample) + else: + # We had to give up looking for valid points near refpt, so remove it from the list of "active" points. + active.remove(idx) + + samples = np.rint(np.array(samples)).astype(int) + samples = np.unique(samples[:, 0] + 1j * samples[:, 1]) + samples = np.column_stack((samples.real, samples.imag)).astype(int) + + poisson_pattern = np.zeros((pattern_shape[0] + 1, pattern_shape[1] + 1), dtype=bool) + poisson_pattern[samples[:, 0], samples[:, 1]] = True + + return poisson_pattern
+ +
[docs] def centered_circle(self): + """Creates a boolean centered circle image using the scale as a radius.""" + center_x = int((self.shape[0] - 1) / 2) + center_y = int((self.shape[1] - 1) / 2) + + X, Y = np.indices(self.shape) + radius = int(self.shape[0] * self.scale) + return ((X - center_x) ** 2 + (Y - center_y) ** 2) < radius**2
+ + +
[docs]def create_mask_for_mask_type( + mask_type_str: str, center_fractions: Sequence[float], accelerations: Sequence[int] +) -> MaskFunc: + """ + Creates a MaskFunc object for the given mask type. + + Parameters + ---------- + mask_type_str: The string representation of the mask type. + center_fractions: The center fractions for the mask. + accelerations: The accelerations for the mask. + + Returns + ------- + A MaskFunc object. + """ + if mask_type_str == "random1d": + return RandomMaskFunc(center_fractions, accelerations) + if mask_type_str == "equispaced1d": + return EquispacedMaskFunc(center_fractions, accelerations) + if mask_type_str == "gaussian1d": + return Gaussian1DMaskFunc(center_fractions, accelerations) + if mask_type_str == "gaussian2d": + return Gaussian2DMaskFunc(center_fractions, accelerations) + if mask_type_str == "poisson2d": + return Poisson2DMaskFunc(center_fractions, accelerations) + raise NotImplementedError(f"{mask_type_str} not supported")
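The factory above is the usual entry point; a minimal sketch (editor's addition):

mask_func = create_mask_for_mask_type("equispaced1d", center_fractions=[0.08], accelerations=[4])
mask, acc = mask_func(shape=(1, 15, 640, 368, 2), seed=42)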
+
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/metrics/evaluate.html b/docs/build/html/_modules/mridc/collections/reconstruction/metrics/evaluate.html new file mode 100644 index 00000000..e183c16f --- /dev/null +++ b/docs/build/html/_modules/mridc/collections/reconstruction/metrics/evaluate.html @@ -0,0 +1,402 @@
mridc.collections.reconstruction.metrics.evaluate — mridc v.0.0.1 documentation

Source code for mridc.collections.reconstruction.metrics.evaluate

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Parts of the code have been taken from https://github.com/facebookresearch/fastMRI
+import os
+import pathlib
+from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
+from os.path import exists
+
+import h5py
+import numpy as np
+import pandas as pd
+import torch
+from runstats import Statistics
+from skimage.filters import threshold_otsu
+from skimage.metrics import peak_signal_noise_ratio, structural_similarity
+from skimage.morphology import convex_hull_image
+from tqdm import tqdm
+
+from mridc.collections.common.parts.fft import ifft2c
+from mridc.collections.common.parts.utils import complex_conj, complex_mul, tensor_to_complex_np, to_tensor
+from mridc.collections.reconstruction.parts.utils import center_crop
+
+
+
[docs]def mse(gt: np.ndarray, pred: np.ndarray) -> float: + """Compute Mean Squared Error (MSE)""" + return np.mean((gt - pred) ** 2) # type: ignore
+ + +
[docs]def nmse(gt: np.ndarray, pred: np.ndarray) -> float: + """Compute Normalized Mean Squared Error (NMSE)""" + return np.linalg.norm(gt - pred) ** 2 / np.linalg.norm(gt) ** 2
+ + +
[docs]def psnr(gt: np.ndarray, pred: np.ndarray, maxval: np.ndarray = None) -> float: + """Compute Peak Signal to Noise Ratio metric (PSNR)""" + if maxval is None: + maxval = np.max(gt) + return peak_signal_noise_ratio(gt, pred, data_range=maxval)
+ + +
[docs]def ssim(gt: np.ndarray, pred: np.ndarray, maxval: np.ndarray = None) -> float: + """Compute Structural Similarity Index Metric (SSIM)""" + if gt.ndim != 3: + raise ValueError("Unexpected number of dimensions in ground truth.") + if gt.ndim != pred.ndim: + raise ValueError("Ground truth dimensions do not match prediction.") + + maxval = np.max(gt) if maxval is None else maxval + + _ssim = sum( + structural_similarity(gt[slice_num], pred[slice_num], data_range=maxval) for slice_num in range(gt.shape[0]) + ) + + return _ssim / gt.shape[0]
+ + +METRIC_FUNCS = dict(MSE=mse, NMSE=nmse, PSNR=psnr, SSIM=ssim) + + +
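A tiny sketch of the metric helpers (editor's addition; gt and pred are assumed to be magnitude volumes of shape [slices, x, y], which is what ssim expects):

gt = np.random.rand(10, 64, 64).astype(np.float32)
pred = gt + 0.01 * np.random.rand(10, 64, 64).astype(np.float32)
scores = {name: func(gt, pred) for name, func in METRIC_FUNCS.items()}
# e.g. scores["SSIM"] is the mean slice-wise structural similarity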
[docs]class Metrics: + """Maintains running statistics for a given collection of metrics.""" + + def __init__(self, metric_funcs, output_path, method): + """ + Parameters + ---------- + metric_funcs (dict): A dict where the keys are metric names and the values are Python functions for evaluating + that metric. + output_path: path to the output directory + method: reconstruction method + """ + self.metrics_scores = {metric: Statistics() for metric in metric_funcs} + self.output_path = output_path + self.method = method + +
[docs] def push(self, target, recons): + """ + Pushes a new batch of metrics to the running statistics. + + Parameters + ---------- + target: target image + recons: reconstructed image + + Returns + ------- + None; the running statistics of each metric are updated in place. + """ + for metric, func in METRIC_FUNCS.items(): + self.metrics_scores[metric].push(func(target, recons))
+ +
[docs] def means(self): + """Mean of the means of each metric.""" + return {metric: stat.mean() for metric, stat in self.metrics_scores.items()}
+ +
[docs] def stddevs(self): + """Standard deviation of the means of each metric.""" + return {metric: stat.stddev() for metric, stat in self.metrics_scores.items()}
+ +
[docs] def __repr__(self): + """Representation of the metrics.""" + means = self.means() + stddevs = self.stddevs() + metric_names = sorted(list(means)) + + res = " ".join(f"{name} = {means[name]:.4g} +/- {2 * stddevs[name]:.4g}" for name in metric_names) + "\n" + + with open(f"{self.output_path}metrics.txt", "a") as output: + output.write(f"{self.method}: {res}") + + return res
+ + +
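Typical use of the running-statistics container (editor's sketch; gt and recons are hypothetical volumes, the method label is illustrative, and output_path is assumed to end with a path separator because __repr__ concatenates it with "metrics.txt"):

metrics = Metrics(METRIC_FUNCS, output_path="results/", method="CIRIM")
metrics.push(gt, recons)  # accumulate one volume's scores into the running statistics
print(metrics)            # prints mean +/- 2 * stddev per metric and appends the line to results/metrics.txt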
[docs]def evaluate( + arguments, + reconstruction_key, + mask_background, + output_path, + method, + acc, + no_params, + slice_start, + slice_end, +): + """ + Evaluates the reconstructions. + + Parameters + ---------- + arguments: The CLI arguments. + reconstruction_key: The key of the reconstruction to evaluate. + mask_background: The background mask. + output_path: The output path. + method: The reconstruction method. + acc: The acceleration factor. + no_params: The number of parameters. + slice_start: The start slice. (optional) + slice_end: The end slice. (optional) + + Returns + ------- + dict: A dict where the keys are metric names and the values are the mean of the metric. + """ + _metrics = Metrics(METRIC_FUNCS, output_path, method) if arguments.type == "mean_std" else {} + + for tgt_file in tqdm(arguments.target_path.iterdir()): + if exists(arguments.predictions_path / tgt_file.name): + with h5py.File(tgt_file, "r") as target, h5py.File( + arguments.predictions_path / tgt_file.name, "r" + ) as recons: + kspace = target["kspace"][()] + + if arguments.sense_path is not None: + sense = h5py.File(arguments.sense_path / tgt_file.name, "r")["sensitivity_map"][()] + elif "sensitivity_map" in target: + sense = target["sensitivity_map"][()] + + sense = sense.squeeze().astype(np.complex64) + + if sense.shape != kspace.shape: + sense = np.transpose(sense, (0, 3, 1, 2)) + + target = np.abs( + tensor_to_complex_np( + torch.sum( + complex_mul( + ifft2c( + to_tensor(kspace), + fft_type="orthogonal" + if "fastmri" in str(arguments.sense_path).lower() + else "other", + ), + complex_conj(to_tensor(sense)), + ), + 1, + ) + ) + ) + + recons = recons[reconstruction_key][()] + + if recons.ndim == 4: + recons = recons.squeeze(1) + + if arguments.crop_size is not None: + crop_size = arguments.crop_size + crop_size[0] = min(target.shape[-2], int(crop_size[0])) + crop_size[1] = min(target.shape[-1], int(crop_size[1])) + crop_size[0] = min(recons.shape[-2], int(crop_size[0])) + crop_size[1] = min(recons.shape[-1], int(crop_size[1])) + + target = center_crop(target, crop_size) + recons = center_crop(recons, crop_size) + + if mask_background: + for sl in range(target.shape[0]): + mask = convex_hull_image( + np.where(np.abs(target[sl]) > threshold_otsu(np.abs(target[sl])), 1, 0) # type: ignore + ) + target[sl] = target[sl] * mask + recons[sl] = recons[sl] * mask + + if slice_start is not None: + target = target[slice_start:] + recons = recons[slice_start:] + + if slice_end is not None: + target = target[:slice_end] + recons = recons[:slice_end] + + for sl in range(target.shape[0]): + target[sl] = target[sl] / np.max(np.abs(target[sl])) + recons[sl] = recons[sl] / np.max(np.abs(recons[sl])) + + target = np.abs(target) + recons = np.abs(recons) + + if arguments.type == "mean_std": + _metrics.push(target, recons) + else: + _target = np.expand_dims(target, 1) + _recons = np.expand_dims(recons, 1) + for sl in range(target.shape[0]): + _metrics["FNAME"] = tgt_file.name + _metrics["SLICE"] = sl + _metrics["ACC"] = acc + _metrics["METHOD"] = method + _metrics["MSE"] = [mse(target[sl], recons[sl])] + _metrics["NMSE"] = [nmse(target[sl], recons[sl])] + _metrics["PSNR"] = [psnr(target[sl], recons[sl])] + _metrics["SSIM"] = [ssim(_target[sl], _recons[sl])] + _metrics["PARAMS"] = no_params + + if not exists(arguments.output_path): + pd.DataFrame(columns=_metrics.keys()).to_csv(arguments.output_path, index=False, mode="w") + pd.DataFrame(_metrics).to_csv(arguments.output_path, index=False, header=False, mode="a") + + return 
_metrics
+ + +if __name__ == "__main__": + parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) + parser.add_argument("target_path", type=pathlib.Path, help="Path to the ground truth data") + parser.add_argument("predictions_path", type=pathlib.Path, help="Path to reconstructions") + parser.add_argument("output_path", type=str, help="Path to save the metrics") + parser.add_argument("--sense_path", type=pathlib.Path, help="Path to the sense data") + parser.add_argument( + "--challenge", + choices=["singlecoil", "multicoil", "multicoil_sense", "multicoil_other"], + default="multicoil_other", + help="Which challenge", + ) + parser.add_argument("--crop_size", nargs="+", default=None, help="Set crop size.") + parser.add_argument("--method", type=str, required=True, help="Model's name to evaluate") + parser.add_argument("--acceleration", type=int, required=True, default=None) + parser.add_argument("--no_params", type=str, required=True, default=None) + parser.add_argument( + "--acquisition", + choices=["CORPD_FBK", "CORPDFS_FBK", "AXT1", "AXT1PRE", "AXT1POST", "AXT2", "AXFLAIR"], + default=None, + help="If set, only volumes of the specified acquisition type are used for " + "evaluation. By default, all volumes are included.", + ) + parser.add_argument( + "--fill_pred_path", action="store_true", help="Find reconstructions folder in predictions path" + ) + parser.add_argument("--mask_background", action="store_true", help="Toggle to mask background") + parser.add_argument("--type", choices=["mean_std", "all_slices"], default="mean_std", help="Output type.") + parser.add_argument("--slice_start", type=int, help="Select to skip first slices") + parser.add_argument("--slice_end", type=int, help="Select to skip last slices") + + args = parser.parse_args() + + if args.fill_pred_path: + dir = "" + for root, dirs, files in os.walk(args.predictions_path, topdown=False): + for name in dirs: + dir = os.path.join(root, name) + args.predictions_path = pathlib.Path(f"{dir}/reconstructions/") + + if args.challenge == "multicoil": + recons_key = "reconstruction_rss" + elif args.challenge == "multicoil_sense": + recons_key = "reconstruction_sense" + elif args.challenge == "singlecoil": + recons_key = "reconstruction_esc" + else: + recons_key = "reconstruction" + + metrics = evaluate( + args, + recons_key, + args.mask_background, + args.output_path, + args.method, + args.acceleration, + args.no_params, + args.slice_start, + args.slice_end, + ) + + if args.type == "mean_std": + print(metrics) + elif args.type == "all_slices": + print("Done, csv file saved!") +
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/base.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/base.html new file mode 100644 index 00000000..32904008 --- /dev/null +++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/base.html @@ -0,0 +1,882 @@
mridc.collections.reconstruction.models.base — mridc v.0.0.1 documentation

Source code for mridc.collections.reconstruction.models.base

+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+import os
+from abc import ABC
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, Optional, Tuple
+
+import h5py
+import numpy as np
+import torch
+import wandb
+from omegaconf import DictConfig
+from pytorch_lightning import Trainer
+from torch import nn
+from torch.utils.data import DataLoader
+from torchmetrics.metric import Metric
+
+from mridc.collections.common.parts.fft import ifft2c
+from mridc.collections.common.parts.utils import rss_complex
+from mridc.collections.reconstruction.data.mri_data import FastMRISliceDataset
+from mridc.collections.reconstruction.data.subsample import create_mask_for_mask_type
+from mridc.collections.reconstruction.metrics.evaluate import mse, nmse, psnr, ssim
+from mridc.collections.reconstruction.models.unet_base.unet_block import NormUnet
+from mridc.collections.reconstruction.parts.transforms import MRIDataTransforms
+from mridc.collections.reconstruction.parts.utils import batched_mask_center
+from mridc.core.classes.modelPT import ModelPT
+from mridc.utils.model_utils import convert_model_config_to_dict_config, maybe_update_config_version
+
+__all__ = ["BaseMRIReconstructionModel", "BaseSensitivityModel"]
+
+
+class DistributedMetricSum(Metric):
+    """
+    A metric that sums the values of a metric across all workers.
+    Taken from: https://github.com/facebookresearch/fastMRI/blob/main/fastmri/pl_modules/mri_module.py
+    """
+
+    def __init__(self, dist_sync_on_step=True):
+        super().__init__(dist_sync_on_step=dist_sync_on_step)
+
+        self.add_state("quantity", default=torch.tensor(0.0), dist_reduce_fx="sum")
+
+    def update(self, batch: torch.Tensor):  # type: ignore
+        """Update the metric with a batch of data."""
+        self.quantity += batch
+
+    def compute(self):
+        """Compute the metric value."""
+        return self.quantity
+
+
+
[docs]class BaseMRIReconstructionModel(ModelPT, ABC): + """Base class of all MRIReconstruction models.""" + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + + # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable + self.world_size = 1 + if trainer is not None: + self.world_size = trainer.num_nodes * trainer.num_devices + + cfg = convert_model_config_to_dict_config(cfg) + cfg = maybe_update_config_version(cfg) + + # init superclass + super().__init__(cfg=cfg, trainer=trainer) + + self.MSE = DistributedMetricSum() + self.NMSE = DistributedMetricSum() + self.SSIM = DistributedMetricSum() + self.PSNR = DistributedMetricSum() + self.TotExamples = DistributedMetricSum() + + # Set evaluation metrics dictionaries + self.mse_vals: Dict = defaultdict(dict) + self.nmse_vals: Dict = defaultdict(dict) + self.ssim_vals: Dict = defaultdict(dict) + self.psnr_vals: Dict = defaultdict(dict) + + # skipcq: PYL-R0201 +
[docs] def process_loss(self, target, pred, _loss_fn): + """ + Processes the loss. + + Parameters + ---------- + target: Target data. + torch.Tensor, shape [batch_size, n_x, n_y, 2] + pred: Final prediction(s). + list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or + torch.Tensor, shape [batch_size, n_x, n_y, 2] + _loss_fn: Loss function. + torch.nn.Module, default torch.nn.L1Loss() + + Returns + ------- + loss: torch.FloatTensor, shape [1] + If self.accumulate_loss is True, returns an accumulative result of all intermediate losses. + """ + target = torch.abs(target / torch.max(torch.abs(target))) + if "ssim" in str(_loss_fn).lower(): + max_value = np.array(torch.max(torch.abs(target)).item()).astype(np.float32) + + def loss_fn(x, y): + """Calculate the ssim loss.""" + return _loss_fn( + x.unsqueeze(dim=1), + torch.abs(y / torch.max(torch.abs(y))).unsqueeze(dim=1), + data_range=torch.tensor(max_value).unsqueeze(dim=0).to(x.device), + ) + + else: + + def loss_fn(x, y): + """Calculate other loss.""" + return _loss_fn(x, torch.abs(y / torch.max(torch.abs(y)))) + + return loss_fn(target, pred)
+ +
[docs] @staticmethod + def process_inputs(y, mask, init_pred): + """ + Processes the inputs to the method. + + Parameters + ---------- + y: Subsampled k-space data. + list of torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + mask: Sampling mask. + list of torch.Tensor, shape [1, 1, n_x, n_y, 1] + init_pred: Initial prediction. + list of torch.Tensor, shape [batch_size, n_x, n_y, 2] + + Returns + ------- + y: Subsampled k-space data. + randomly selected y + mask: Sampling mask. + randomly selected mask + init_pred: Initial prediction. + passed through unchanged + r: Random index. + """ + if isinstance(y, list): + r = np.random.randint(len(y)) + y = y[r] + mask = mask[r] + else: + r = 0 + return y, mask, init_pred, r
+ +
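When training on several accelerations at once, y and mask arrive as lists and one (y, mask) pair is drawn per step; a sketch (editor's addition; y_4x, y_8x, mask_4x, mask_8x and init_pred are hypothetical tensors):

y, mask, init_pred, r = BaseMRIReconstructionModel.process_inputs(
    y=[y_4x, y_8x], mask=[mask_4x, mask_8x], init_pred=init_pred
)
# r is the index of the drawn pair; note that init_pred is returned unchanged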
[docs] def training_step(self, batch: Dict[float, torch.Tensor], batch_idx: int) -> Dict[str, torch.Tensor]: + """ + Performs a training step. + + Parameters + ---------- + batch: Batch of data. + Dict[str, torch.Tensor], with keys, + + 'y': subsampled kspace, + torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + 'sensitivity_maps': sensitivity_maps, + torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + 'mask': sampling mask, + torch.Tensor, shape [1, 1, n_x, n_y, 1] + 'init_pred': initial prediction. For example zero-filled or PICS. + torch.Tensor, shape [batch_size, n_x, n_y, 2] + 'target': target data, + torch.Tensor, shape [batch_size, n_x, n_y, 2] + 'fname': filename, + str, shape [batch_size] + 'slice_idx': slice_idx, + torch.Tensor, shape [batch_size] + 'acc': acceleration factor, + torch.Tensor, shape [batch_size] + 'max_value': maximum value of the magnitude image space, + torch.Tensor, shape [batch_size] + 'crop_size': crop size, + torch.Tensor, shape [n_x, n_y] + batch_idx: Batch index. + int + + Returns + ------- + Dict[str, torch.Tensor], with keys, + 'loss': loss, + torch.Tensor, shape [1] + 'log': log, + dict, shape [1] + """ + y, sensitivity_maps, mask, init_pred, target, _, _, acc = batch + y, mask, init_pred, r = self.process_inputs(y, mask, init_pred) + preds = self.forward(y, sensitivity_maps, mask, init_pred, target) + + if self.accumulate_estimates: + try: + preds = next(preds) + except StopIteration: + pass + + train_loss = sum(self.process_loss(target, preds, _loss_fn=self.train_loss_fn)) + else: + train_loss = self.process_loss(target, preds, _loss_fn=self.train_loss_fn) + + acc = r if r != 0 else acc + tensorboard_logs = { + f"train_loss_{acc}x": train_loss.item(), # type: ignore + "lr": self._optimizer.param_groups[0]["lr"], # type: ignore + } + return {"loss": train_loss, "log": tensorboard_logs}
+ +
[docs] def validation_step(self, batch: Dict[float, torch.Tensor], batch_idx: int) -> Dict: + """ + Performs a validation step. + + Parameters + ---------- + batch: Batch of data. Dict[str, torch.Tensor], with keys, + 'y': subsampled kspace, + torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + 'sensitivity_maps': sensitivity_maps, + torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + 'mask': sampling mask, + torch.Tensor, shape [1, 1, n_x, n_y, 1] + 'init_pred': initial prediction. For example zero-filled or PICS. + torch.Tensor, shape [batch_size, n_x, n_y, 2] + 'target': target data, + torch.Tensor, shape [batch_size, n_x, n_y, 2] + 'fname': filename, + str, shape [batch_size] + 'slice_idx': slice_idx, + torch.Tensor, shape [batch_size] + 'acc': acceleration factor, + torch.Tensor, shape [batch_size] + 'max_value': maximum value of the magnitude image space, + torch.Tensor, shape [batch_size] + 'crop_size': crop size, + torch.Tensor, shape [n_x, n_y] + batch_idx: Batch index. + int + + Returns + ------- + Dict[str, torch.Tensor], with keys, + 'loss': loss, + torch.Tensor, shape [1] + 'log': log, + dict, shape [1] + """ + y, sensitivity_maps, mask, init_pred, target, fname, slice_num, _ = batch + y, mask, init_pred, _ = self.process_inputs(y, mask, init_pred) + preds = self.forward(y, sensitivity_maps, mask, init_pred, target) + + if self.accumulate_estimates: + try: + preds = next(preds) + except StopIteration: + pass + + val_loss = sum(self.process_loss(target, preds, _loss_fn=self.eval_loss_fn)) + else: + val_loss = self.process_loss(target, preds, _loss_fn=self.eval_loss_fn) + + # Cascades + if isinstance(preds, list): + preds = preds[-1] + + # Time-steps + if isinstance(preds, list): + preds = preds[-1] + + key = f"{fname[0]}_images_idx_{int(slice_num)}" # type: ignore + output = torch.abs(preds).detach().cpu() + target = torch.abs(target).detach().cpu() + output = output / output.max() # type: ignore + target = target / target.max() # type: ignore + error = torch.abs(target - output) + self.log_image(f"{key}/target", target) + self.log_image(f"{key}/reconstruction", output) + self.log_image(f"{key}/error", error) + + target = target.numpy() # type: ignore + output = output.numpy() # type: ignore + self.mse_vals[fname][slice_num] = torch.tensor(mse(target, output)).view(1) + self.nmse_vals[fname][slice_num] = torch.tensor(nmse(target, output)).view(1) + self.ssim_vals[fname][slice_num] = torch.tensor(ssim(target, output, maxval=output.max() - output.min())).view( + 1 + ) + self.psnr_vals[fname][slice_num] = torch.tensor(psnr(target, output, maxval=output.max() - output.min())).view( + 1 + ) + + return {"val_loss": val_loss}
+ +
[docs] def test_step(self, batch: Dict[float, torch.Tensor], batch_idx: int) -> Tuple[str, int, torch.Tensor]: + """ + Performs a test step. + + Parameters + ---------- + batch: Batch of data. Dict[str, torch.Tensor], with keys, + 'y': subsampled kspace, + torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + 'sensitivity_maps': sensitivity_maps, + torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + 'mask': sampling mask, + torch.Tensor, shape [1, 1, n_x, n_y, 1] + 'init_pred': initial prediction. For example zero-filled or PICS. + torch.Tensor, shape [batch_size, n_x, n_y, 2] + 'target': target data, + torch.Tensor, shape [batch_size, n_x, n_y, 2] + 'fname': filename, + str, shape [batch_size] + 'slice_idx': slice_idx, + torch.Tensor, shape [batch_size] + 'acc': acceleration factor, + torch.Tensor, shape [batch_size] + 'max_value': maximum value of the magnitude image space, + torch.Tensor, shape [batch_size] + 'crop_size': crop size, + torch.Tensor, shape [n_x, n_y] + batch_idx: Batch index. + int + + Returns + ------- + name: Name of the volume. + str + slice_num: Slice number. + int + pred: Predicted data. + torch.Tensor, shape [batch_size, n_x, n_y, 2] + """ + y, sensitivity_maps, mask, init_pred, target, fname, slice_num, _ = batch + y, mask, init_pred, _ = self.process_inputs(y, mask, init_pred) + preds = self.forward(y, sensitivity_maps, mask, init_pred, target) + + if self.accumulate_estimates: + try: + preds = next(preds) + except StopIteration: + pass + + # Cascades + if isinstance(preds, list): + preds = preds[-1] + + # Time-steps + if isinstance(preds, list): + preds = preds[-1] + + slice_num = int(slice_num) + name = str(fname[0]) # type: ignore + key = f"{name}_images_idx_{slice_num}" # type: ignore + + output = torch.abs(preds).detach().cpu() + output = output / output.max() # type: ignore + + target = torch.abs(target).detach().cpu() + target = target / target.max() # type: ignore + + error = torch.abs(target - output) + + self.log_image(f"{key}/target", target) + self.log_image(f"{key}/reconstruction", output) + self.log_image(f"{key}/error", error) + + return name, slice_num, preds.detach().cpu().numpy()
+ +
[docs] def log_image(self, name, image): + """ + Logs an image. + + Parameters + ---------- + name: Name of the image. + str + image: Image to log. + torch.Tensor, shape [batch_size, n_x, n_y, 2] + """ + if "wandb" in self.logger.__module__.lower(): + self.logger.experiment.log({name: wandb.Image(image.numpy())}) + else: + self.logger.experiment.add_image(name, image, global_step=self.global_step)
+ +
[docs] def validation_epoch_end(self, outputs): + """ + Called at the end of validation epoch to aggregate outputs. + + Parameters + ---------- + outputs: List of outputs of the validation batches. + list of dicts + + Returns + ------- + metrics: Dictionary of metrics. + dict + """ + self.log("val_loss", torch.stack([x["val_loss"] for x in outputs]).mean()) + + # Log metrics. + # Taken from: https://github.com/facebookresearch/fastMRI/blob/main/fastmri/pl_modules/mri_module.py + mse_vals = defaultdict(dict) + nmse_vals = defaultdict(dict) + ssim_vals = defaultdict(dict) + psnr_vals = defaultdict(dict) + + for k in self.mse_vals.keys(): + mse_vals[k].update(self.mse_vals[k]) + for k in self.nmse_vals.keys(): + nmse_vals[k].update(self.nmse_vals[k]) + for k in self.ssim_vals.keys(): + ssim_vals[k].update(self.ssim_vals[k]) + for k in self.psnr_vals.keys(): + psnr_vals[k].update(self.psnr_vals[k]) + + # apply means across image volumes + metrics = {"MSE": 0, "NMSE": 0, "SSIM": 0, "PSNR": 0} + local_examples = 0 + for fname in mse_vals: + local_examples += 1 + metrics["MSE"] = metrics["MSE"] + torch.mean(torch.cat([v.view(-1) for _, v in mse_vals[fname].items()])) + metrics["NMSE"] = metrics["NMSE"] + torch.mean( + torch.cat([v.view(-1) for _, v in nmse_vals[fname].items()]) + ) + metrics["SSIM"] = metrics["SSIM"] + torch.mean( + torch.cat([v.view(-1) for _, v in ssim_vals[fname].items()]) + ) + metrics["PSNR"] = metrics["PSNR"] + torch.mean( + torch.cat([v.view(-1) for _, v in psnr_vals[fname].items()]) + ) + + # reduce across ddp via sum + metrics["MSE"] = self.MSE(metrics["MSE"]) + metrics["NMSE"] = self.NMSE(metrics["NMSE"]) + metrics["SSIM"] = self.SSIM(metrics["SSIM"]) + metrics["PSNR"] = self.PSNR(metrics["PSNR"]) + + tot_examples = self.TotExamples(torch.tensor(local_examples)) + for metric, value in metrics.items(): + self.log(f"{metric}", value / tot_examples)
+ +
[docs] def test_epoch_end(self, outputs): + """ + Called at the end of test epoch to aggregate outputs. + + Parameters + ---------- + outputs: List of outputs of the test batches. + list of dicts + + Returns + ------- + Saves the reconstructed images to .h5 files. + """ + reconstructions = defaultdict(list) + for fname, slice_num, output in outputs: + reconstructions[fname].append((slice_num, output)) + + for fname in reconstructions: + reconstructions[fname] = np.stack([out for _, out in sorted(reconstructions[fname])]) # type: ignore + + out_dir = Path(os.path.join(self.logger.log_dir, "reconstructions")) + out_dir.mkdir(exist_ok=True, parents=True) + for fname, recons in reconstructions.items(): + with h5py.File(out_dir / fname, "w") as hf: + hf.create_dataset("reconstruction", data=recons)
+ +
[docs] def setup_training_data(self, train_data_config: Optional[DictConfig]): + """ + Sets up the training data. + + Parameters + ---------- + train_data_config: Training data configuration. + dict + + Returns + ------- + train_data: Training data. + torch.utils.data.DataLoader + """ + self._train_dl = self._setup_dataloader_from_config(cfg=train_data_config)
+ +
[docs] def setup_validation_data(self, val_data_config: Optional[DictConfig]): + """ + Sets up the validation data. + + Parameters + ---------- + val_data_config: Validation data configuration. + dict + + Returns + ------- + val_data: Validation data. + torch.utils.data.DataLoader + """ + self._validation_dl = self._setup_dataloader_from_config(cfg=val_data_config)
+ +
[docs] def setup_test_data(self, test_data_config: Optional[DictConfig]): + """ + Sets up the test data. + + Parameters + ---------- + test_data_config: Test data configuration. + dict + + Returns + ------- + test_data: Test data. + torch.utils.data.DataLoader + """ + self._test_dl = self._setup_dataloader_from_config(cfg=test_data_config)
+ + @staticmethod + def _setup_dataloader_from_config(cfg: DictConfig) -> DataLoader: + """ + Sets up the dataloader from the configuration (yaml) file. + + Parameters + ---------- + cfg: Configuration file. + dict + + Returns + ------- + dataloader: DataLoader. + torch.utils.data.DataLoader + """ + if cfg.get("dataset_type") != "FastMRI": + raise ValueError(f"Unknown dataset type: {cfg.get('dataset_type')}") + + mask_args = cfg.get("mask_args") + mask_type = mask_args.get("type") + shift_mask = mask_args.get("shift_mask") + + if mask_type is not None and mask_type != "None": + accelerations = mask_args.get("accelerations") + center_fractions = mask_args.get("center_fractions") + mask_center_scale = mask_args.get("scale") + + mask_func = ( + [ + create_mask_for_mask_type(mask_type, [cf] * 2, [acc] * 2) + for acc, cf in zip(accelerations, center_fractions) + ] + if len(accelerations) > 2 + else [create_mask_for_mask_type(mask_type, center_fractions, accelerations)] + ) + else: + mask_func = None # type: ignore + mask_center_scale = 0.02 + + dataset = FastMRISliceDataset( + root=cfg.get("data_path"), + sense_root=cfg.get("sense_data_path"), + challenge=cfg.get("challenge"), + transform=MRIDataTransforms( + mask_func=mask_func, + shift_mask=shift_mask, + mask_center_scale=mask_center_scale, + normalize_inputs=cfg.get("normalize_inputs"), + crop_size=cfg.get("crop_size"), + crop_before_masking=cfg.get("crop_before_masking"), + kspace_zero_filling_size=cfg.get("kspace_zero_filling_size"), + fft_type=cfg.get("fft_type"), + use_seed=cfg.get("use_seed"), + ), + sample_rate=cfg.get("sample_rate"), + ) + if cfg.shuffle: + sampler = torch.utils.data.RandomSampler(dataset) + else: + sampler = torch.utils.data.SequentialSampler(dataset) + + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=cfg.get("batch_size"), + sampler=sampler, + num_workers=cfg.get("num_workers", 2), + pin_memory=cfg.get("pin_memory", False), + drop_last=cfg.get("drop_last", False), + )
+ + +
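Pieced together from the cfg.get calls above, a training-data config would look roughly like this (an editor's sketch, not an official example; paths and values are illustrative):

from omegaconf import OmegaConf

train_data_cfg = OmegaConf.create(
    {
        "dataset_type": "FastMRI",
        "data_path": "data/multicoil_train",  # hypothetical path
        "sense_data_path": None,
        "challenge": "multicoil",
        "mask_args": {
            "type": "gaussian1d",
            "shift_mask": False,
            "accelerations": [4, 8],
            "center_fractions": [0.7, 0.7],
            "scale": 0.02,
        },
        "normalize_inputs": True,
        "crop_size": None,
        "crop_before_masking": True,
        "kspace_zero_filling_size": None,
        "fft_type": "orthogonal",
        "use_seed": True,
        "sample_rate": 1.0,
        "shuffle": True,
        "batch_size": 1,
        "num_workers": 2,
    }
)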
[docs]class BaseSensitivityModel(nn.Module, ABC): + """ + Model for learning sensitivity estimation from k-space data. + This model applies an IFFT to multichannel k-space data and then a U-Net to the coil images to estimate coil + sensitivities. + """ + + def __init__( + self, + chans: int = 8, + num_pools: int = 4, + in_chans: int = 2, + out_chans: int = 2, + drop_prob: float = 0.0, + padding_size: int = 15, + mask_type: str = "2D", # TODO: make this generalizable + fft_type: str = "orthogonal", + normalize: bool = True, + mask_center: bool = True, + ): + """ + Initializes the model. + + Parameters + ---------- + chans: Number of output channels of the first U-Net convolution layer. + int + num_pools: Number of U-Net downsampling/upsampling operations. + int + in_chans: Number of channels in the input data. + int + out_chans: Number of channels in the output data. + int + drop_prob: Dropout probability. + float + padding_size: Size of the zero-padding. + int + mask_type: Type of mask to use. + str + fft_type: Type of FFT to use. + str + normalize: Whether to normalize the input data. + bool + mask_center: Whether to mask the center of the image. + bool + """ + super().__init__() + + self.mask_type = mask_type + self.fft_type = fft_type + + self.norm_unet = NormUnet( + chans, + num_pools, + in_chans=in_chans, + out_chans=out_chans, + drop_prob=drop_prob, + padding_size=padding_size, + normalize=normalize, + ) + + self.mask_center = mask_center + self.normalize = normalize + 
[docs] @staticmethod + def chans_to_batch_dim(x: torch.Tensor) -> Tuple[torch.Tensor, int]: + """ + Converts the number of channels in a tensor to the batch dimension. + + Parameters + ---------- + x: Tensor to convert. + torch.Tensor + + Returns + ------- + Tuple of the converted tensor and the original last dimension. + Tuple[torch.Tensor, int] + """ + b, c, h, w, comp = x.shape + + return x.view(b * c, 1, h, w, comp), b
+ +
[docs] @staticmethod + def batch_chans_to_chan_dim(x: torch.Tensor, batch_size: int) -> torch.Tensor: + """ + Converts the number of channels in a tensor to the channel dimension. + + Parameters + ---------- + x: Tensor to convert. + torch.Tensor + batch_size: Original batch size. + int + + Returns + ------- + Converted tensor. + torch.Tensor + """ + bc, _, h, w, comp = x.shape + c = torch.div(bc, batch_size, rounding_mode="trunc") + + return x.view(batch_size, c, h, w, comp)
+ +
[docs] @staticmethod + def divide_root_sum_of_squares(x: torch.Tensor) -> torch.Tensor: + """ + Divide the input by the root of the sum of squares of the magnitude of each complex number. + + Parameters + ---------- + x: Tensor to divide. + torch.Tensor + + Returns + ------- + RSS output tensor. + torch.Tensor + """ + return x / rss_complex(x, dim=1).unsqueeze(-1).unsqueeze(1)
+ +
[docs] @staticmethod + def get_pad_and_num_low_freqs( + mask: torch.Tensor, num_low_frequencies: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get the padding around the fully sampled k-space center and the number of low-frequency lines to keep. + + Parameters + ---------- + mask: Mask to use. + torch.Tensor + num_low_frequencies: Number of low frequencies to keep. + int + + Returns + ------- + Tuple of the padding and the number of low frequencies to keep. + Tuple[torch.Tensor, torch.Tensor] + """ + if num_low_frequencies is None or num_low_frequencies == 0: + # get low frequency line locations and mask them out + squeezed_mask = mask[:, 0, 0, :, 0].to(torch.int8) + cent = torch.div(squeezed_mask.shape[1], 2, rounding_mode="trunc") + # running argmin returns the first non-zero + left = torch.argmin(squeezed_mask[:, :cent].flip(1), dim=1) + right = torch.argmin(squeezed_mask[:, cent:], dim=1) + num_low_frequencies_tensor = torch.max( + 2 * torch.min(left, right), torch.ones_like(left) + ) # force a symmetric center unless 1 + else: + num_low_frequencies_tensor = num_low_frequencies * torch.ones( + mask.shape[0], dtype=mask.dtype, device=mask.device + ) + + pad = torch.div(mask.shape[-2] - num_low_frequencies_tensor + 1, 2, rounding_mode="trunc") + + return pad, num_low_frequencies_tensor
+ +
[docs]    def forward(
+        self,
+        masked_kspace: torch.Tensor,
+        mask: torch.Tensor,
+        num_low_frequencies: Optional[int] = None,
+    ) -> torch.Tensor:
+        """
+        Forward pass of the model.
+
+        Parameters
+        ----------
+        masked_kspace: Subsampled k-space data.
+            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
+        mask: Sampling mask.
+            torch.Tensor, shape [batch_size, 1, n_x, n_y, 1]
+        num_low_frequencies: Number of low frequencies to keep.
+            int
+
+        Returns
+        -------
+        Estimated coil sensitivity maps.
+            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
+        """
+        if self.mask_center:
+            pad, num_low_freqs = self.get_pad_and_num_low_freqs(mask, num_low_frequencies)
+            masked_kspace = batched_mask_center(masked_kspace, pad, pad + num_low_freqs, mask_type=self.mask_type)
+
+        # convert to image space
+        images, batches = self.chans_to_batch_dim(ifft2c(masked_kspace))
+
+        # estimate sensitivities
+        images = self.batch_chans_to_chan_dim(self.norm_unet(images), batches)
+        if self.normalize:
+            images = self.divide_root_sum_of_squares(images)
+        return images
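A minimal usage sketch for the sensitivity model above, assuming dummy shapes from the forward() docstring (all values are illustrative):

    sens_net = BaseSensitivityModel(chans=8, num_pools=4)
    masked_kspace = torch.randn(1, 8, 64, 64, 2)  # [batch_size, n_coils, n_x, n_y, 2]
    mask = torch.zeros(1, 1, 64, 64, 1)
    mask[..., 28:36, :] = 1  # a fully-sampled center band along n_y
    sens_maps = sens_net(masked_kspace, mask)  # [1, 8, 64, 64, 2]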
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/cascadenet/ccnn_block.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/cascadenet/ccnn_block.html
new file mode 100644
index 00000000..0e24c692
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/cascadenet/ccnn_block.html
@@ -0,0 +1,207 @@
+mridc.collections.reconstruction.models.cascadenet.ccnn_block — mridc v.0.0.1 documentation
Source code for mridc.collections.reconstruction.models.cascadenet.ccnn_block
+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+import torch
+
+from mridc.collections.common.parts.fft import fft2c, ifft2c
+from mridc.collections.common.parts.utils import complex_conj, complex_mul
+
+
+
[docs]class CascadeNetBlock(torch.nn.Module): + """ + Model block for CascadeNet & Convolution Recurrent Neural Network. + + This model applies a combination of soft data consistency with the input model as a regularizer. + A series of these blocks can be stacked to form the full variational network. + """ + + def __init__(self, model: torch.nn.Module, fft_type: str = "orthogonal", no_dc: bool = False): + """ + Initializes the model block. + + Parameters + ---------- + model: Model to apply soft data consistency. + torch.nn.Module + fft_type: Type of FFT to use. + str + no_dc: Flag to disable the soft data consistency. + bool + """ + super().__init__() + + self.model = model + self.fft_type = fft_type + self.no_dc = no_dc + self.dc_weight = torch.nn.Parameter(torch.ones(1)) + +
[docs]    def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
+        """
+        Expand the coil-combined image to multicoil k-space using the sensitivity maps (SENSE forward operator).
+
+        Parameters
+        ----------
+        x: Input coil-combined image.
+            torch.Tensor, shape [batch_size, height, width, 2]
+        sens_maps: Sensitivity maps.
+            torch.Tensor, shape [batch_size, n_coils, height, width, 2]
+
+        Returns
+        -------
+        Multicoil k-space data.
+            torch.Tensor, shape [batch_size, n_coils, height, width, 2]
+        """
+        return fft2c(complex_mul(x, sens_maps), fft_type=self.fft_type)
+ +
[docs]    def sens_reduce(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
+        """
+        Combine multicoil k-space into a coil-combined image using the conjugate sensitivity maps (SENSE adjoint).
+
+        Parameters
+        ----------
+        x: Input multicoil k-space data.
+            torch.Tensor, shape [batch_size, n_coils, height, width, 2]
+        sens_maps: Sensitivity maps.
+            torch.Tensor, shape [batch_size, n_coils, height, width, 2]
+
+        Returns
+        -------
+        Coil-combined image.
+            torch.Tensor, shape [batch_size, 1, height, width, 2]
+        """
+        x = ifft2c(x, fft_type=self.fft_type)
+        return complex_mul(x, complex_conj(sens_maps)).sum(dim=1, keepdim=True)
+ +
[docs]    def forward(
+        self,
+        pred: torch.Tensor,
+        ref_kspace: torch.Tensor,
+        sens_maps: torch.Tensor,
+        mask: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Forward pass of the model block.
+
+        Parameters
+        ----------
+        pred: Predicted k-space data.
+            torch.Tensor, shape [batch_size, n_coils, height, width, 2]
+        ref_kspace: Reference k-space data.
+            torch.Tensor, shape [batch_size, n_coils, height, width, 2]
+        sens_maps: Sensitivity maps.
+            torch.Tensor, shape [batch_size, n_coils, height, width, 2]
+        mask: Mask to apply to the data.
+            torch.Tensor, shape [batch_size, 1, height, width, 1]
+
+        Returns
+        -------
+        Updated k-space prediction.
+            torch.Tensor, shape [batch_size, n_coils, height, width, 2]
+        """
+        zero = torch.zeros(1, 1, 1, 1, 1).to(pred)
+        soft_dc = torch.where(mask.bool(), pred - ref_kspace, zero) * self.dc_weight
+
+        eta = self.sens_reduce(pred, sens_maps)
+        eta = self.model(eta.squeeze(1).permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
+        eta = self.sens_expand(eta, sens_maps)
+
+        if not self.no_dc:
+            eta = pred - soft_dc - eta
+
+        return eta
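A minimal usage sketch for the block above, with a toy convolution standing in for the regularization model; all shapes are assumptions consistent with the docstrings:

    block = CascadeNetBlock(torch.nn.Conv2d(2, 2, kernel_size=3, padding=1))
    kspace = torch.randn(1, 8, 32, 32, 2)  # [batch_size, n_coils, height, width, 2]
    sens_maps = torch.randn(1, 8, 32, 32, 2)
    mask = (torch.rand(1, 1, 32, 32, 1) > 0.5).float()
    pred = block(kspace.clone(), kspace, sens_maps, mask)  # updated k-space, same shape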
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/ccnn.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/ccnn.html
new file mode 100644
index 00000000..7753634b
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/ccnn.html
@@ -0,0 +1,221 @@
+mridc.collections.reconstruction.models.ccnn — mridc v.0.0.1 documentation
Source code for mridc.collections.reconstruction.models.ccnn
+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+from abc import ABC
+
+import torch
+from omegaconf import DictConfig, OmegaConf
+from pytorch_lightning import Trainer
+from torch.nn import L1Loss
+
+from mridc.collections.common.losses.ssim import SSIMLoss
+from mridc.collections.common.parts.fft import ifft2c
+from mridc.collections.common.parts.utils import coil_combination
+from mridc.collections.reconstruction.models.base import BaseMRIReconstructionModel, BaseSensitivityModel
+from mridc.collections.reconstruction.models.cascadenet.ccnn_block import CascadeNetBlock
+from mridc.collections.reconstruction.models.conv.conv2d import Conv2d
+from mridc.collections.reconstruction.parts.utils import center_crop_to_smallest
+from mridc.core.classes.common import typecheck
+
+__all__ = ["CascadeNet"]
+
+
+
[docs]class CascadeNet(BaseMRIReconstructionModel, ABC): + """ + Implementation of the Deep Cascade of Convolutional Neural Networks, as presented in Schlemper, J., \ + Caballero, J., Hajnal, J. V., Price, A., & Rueckert, D. + + References + ---------- + + .. + + Schlemper, J., Caballero, J., Hajnal, J. V., Price, A., & Rueckert, D., A Deep Cascade of Convolutional \ + Neural Networks for MR Image Reconstruction. Information Processing in Medical Imaging (IPMI), 2017. \ + Available at: https://arxiv.org/pdf/1703.00555.pdf + + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + # init superclass + super().__init__(cfg=cfg, trainer=trainer) + + cfg_dict = OmegaConf.to_container(cfg, resolve=True) + + self.fft_type = cfg_dict.get("fft_type") + + # Cascades of CascadeCNN blocks + self.cascades = torch.nn.ModuleList( + [ + CascadeNetBlock( + Conv2d( + in_channels=2, + out_channels=2, + hidden_channels=cfg_dict.get("hidden_channels"), + n_convs=cfg_dict.get("n_convs"), + batchnorm=cfg_dict.get("batchnorm"), + ), + fft_type=self.fft_type, + no_dc=cfg_dict.get("no_dc"), + ) + for _ in range(cfg_dict.get("num_cascades")) + ] + ) + + self.output_type = cfg_dict.get("output_type") + + # Initialize the sensitivity network if use_sens_net is True + self.use_sens_net = cfg_dict.get("use_sens_net") + if self.use_sens_net: + self.sens_net = BaseSensitivityModel( + cfg_dict.get("sens_chans"), + cfg_dict.get("sens_pools"), + fft_type=self.fft_type, + mask_type=cfg_dict.get("sens_mask_type"), + normalize=cfg_dict.get("sens_normalize"), + ) + + # initialize weights if not using pretrained ccnn + # TODO if not cfg_dict.get("pretrained", False) + self.train_loss_fn = SSIMLoss() if cfg_dict.get("train_loss_fn") == "ssim" else L1Loss() + self.eval_loss_fn = SSIMLoss() if cfg_dict.get("eval_loss_fn") == "ssim" else L1Loss() + + self.accumulate_estimates = False + self.dc_weight = torch.nn.Parameter(torch.ones(1)) + +
[docs] @typecheck() + def forward( + self, + y: torch.Tensor, + sensitivity_maps: torch.Tensor, + mask: torch.Tensor, + init_pred: torch.Tensor, + target: torch.Tensor, + ) -> torch.Tensor: + """ + Forward pass of the network. + + Parameters + ---------- + y: Subsampled k-space data. + torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + sensitivity_maps: Coil sensitivity maps. + torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + mask: Sampling mask. + torch.Tensor, shape [1, 1, n_x, n_y, 1] + init_pred: Initial prediction. + torch.Tensor, shape [batch_size, n_x, n_y, 2] + target: Target data to compute the loss. + torch.Tensor, shape [batch_size, n_x, n_y, 2] + + Returns + ------- + pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or torch.Tensor, shape [batch_size, n_x, n_y, 2] + If self.accumulate_loss is True, returns a list of all intermediate estimates. + If False, returns the final estimate. + """ + sensitivity_maps = self.sens_net(y, mask) if self.use_sens_net else sensitivity_maps + pred = y.clone() + for cascade in self.cascades: + pred = cascade(pred, y, sensitivity_maps, mask) + pred = torch.view_as_complex( + coil_combination(ifft2c(pred, fft_type=self.fft_type), sensitivity_maps, method=self.output_type, dim=1) + ) + _, pred = center_crop_to_smallest(target, pred) + return pred
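A hypothetical configuration sketch; the keys mirror the cfg_dict.get(...) calls above, the values are assumptions, and additional keys consumed by BaseMRIReconstructionModel may be required:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "fft_type": "orthogonal",
            "hidden_channels": 64,
            "n_convs": 5,
            "batchnorm": False,
            "no_dc": False,
            "num_cascades": 10,
            "output_type": "SENSE",
            "use_sens_net": False,
            "train_loss_fn": "ssim",
            "eval_loss_fn": "ssim",
        }
    )
    # model = CascadeNet(cfg)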
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/cirim.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/cirim.html
new file mode 100644
index 00000000..d3635193
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/cirim.html
@@ -0,0 +1,331 @@
+mridc.collections.reconstruction.models.cirim — mridc v.0.0.1 documentation
Source code for mridc.collections.reconstruction.models.cirim
+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+import math
+from abc import ABC
+from typing import Generator, Union
+
+import numpy as np
+import torch
+from omegaconf import DictConfig, OmegaConf
+from pytorch_lightning import Trainer
+from torch.nn import L1Loss
+
+from mridc.collections.common.losses.ssim import SSIMLoss
+from mridc.collections.common.parts.fft import ifft2c
+from mridc.collections.common.parts.rnn_utils import rnn_weights_init
+from mridc.collections.common.parts.utils import coil_combination
+from mridc.collections.reconstruction.models.base import BaseMRIReconstructionModel, BaseSensitivityModel
+from mridc.collections.reconstruction.models.rim.rim_block import RIMBlock
+from mridc.collections.reconstruction.parts.utils import center_crop_to_smallest
+from mridc.core.classes.common import typecheck
+
+__all__ = ["CIRIM"]
+
+
+
[docs]class CIRIM(BaseMRIReconstructionModel, ABC): + """ + Implementation of the Cascades of Independently Recurrent Inference Machines, as presented in \ + Karkalousos, D. et al. + + References + ---------- + + .. + + Karkalousos, D. et al. (2021) ‘Assessment of Data Consistency through Cascades of Independently Recurrent \ + Inference Machines for fast and robust accelerated MRI reconstruction’. Available at: \ + https://arxiv.org/abs/2111.15498v1 + + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + # init superclass + super().__init__(cfg=cfg, trainer=trainer) + + # Cascades of RIM blocks + cfg_dict = OmegaConf.to_container(cfg, resolve=True) + + self.recurrent_filters = cfg_dict.get("recurrent_filters") + + # make time-steps size divisible by 8 for fast fp16 training + self.time_steps = 8 * math.ceil(cfg_dict.get("time_steps") / 8) + + self.no_dc = cfg_dict.get("no_dc") + self.fft_type = cfg_dict.get("fft_type") + self.num_cascades = cfg_dict.get("num_cascades") + + self.cirim = torch.nn.ModuleList( + [ + RIMBlock( + recurrent_layer=cfg_dict.get("recurrent_layer"), + conv_filters=cfg_dict.get("conv_filters"), + conv_kernels=cfg_dict.get("conv_kernels"), + conv_dilations=cfg_dict.get("conv_dilations"), + conv_bias=cfg_dict.get("conv_bias"), + recurrent_filters=self.recurrent_filters, + recurrent_kernels=cfg_dict.get("recurrent_kernels"), + recurrent_dilations=cfg_dict.get("recurrent_dilations"), + recurrent_bias=cfg_dict.get("recurrent_bias"), + depth=cfg_dict.get("depth"), + time_steps=self.time_steps, + conv_dim=cfg_dict.get("conv_dim"), + no_dc=self.no_dc, + fft_type=self.fft_type, + ) + for _ in range(self.num_cascades) + ] + ) + + # Keep estimation through the cascades if keep_eta is True or re-estimate it if False. + self.keep_eta = cfg_dict.get("keep_eta") + self.output_type = cfg_dict.get("output_type") + + # Initialize the sensitivity network if use_sens_net is True + self.use_sens_net = cfg_dict.get("use_sens_net") + if self.use_sens_net: + self.sens_net = BaseSensitivityModel( + cfg_dict.get("sens_chans"), + cfg_dict.get("sens_pools"), + fft_type=self.fft_type, + mask_type=cfg_dict.get("sens_mask_type"), + normalize=cfg_dict.get("sens_normalize"), + mask_center=cfg_dict.get("sens_mask_center"), + ) + + # initialize weights if not using pretrained cirim + if not cfg_dict.get("pretrained", False): + std_init_range = 1 / self.recurrent_filters[0] ** 0.5 + self.cirim.apply(lambda module: rnn_weights_init(module, std_init_range)) + + self.train_loss_fn = SSIMLoss() if cfg_dict.get("train_loss_fn") == "ssim" else L1Loss() + self.eval_loss_fn = SSIMLoss() if cfg_dict.get("eval_loss_fn") == "ssim" else L1Loss() + + self.dc_weight = torch.nn.Parameter(torch.ones(1)) + self.accumulate_estimates = True + +
[docs] @typecheck() + def forward( + self, + y: torch.Tensor, + sensitivity_maps: torch.Tensor, + mask: torch.Tensor, + init_pred: torch.Tensor, + target: torch.Tensor, + ) -> Union[Generator, torch.Tensor]: + """ + Forward pass of the network. + + Parameters + ---------- + y: Subsampled k-space data. + torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + sensitivity_maps: Coil sensitivity maps. + torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + mask: Sampling mask. + torch.Tensor, shape [1, 1, n_x, n_y, 1] + init_pred: Initial prediction. + torch.Tensor, shape [batch_size, n_x, n_y, 2] + target: Target data to compute the loss. + torch.Tensor, shape [batch_size, n_x, n_y, 2] + + Returns + ------- + pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or torch.Tensor, shape [batch_size, n_x, n_y, 2] + If self.accumulate_loss is True, returns a list of all intermediate estimates. + If False, returns the final estimate. + """ + sensitivity_maps = self.sens_net(y, mask) if self.use_sens_net else sensitivity_maps + prediction = y.clone() + init_pred = None if init_pred is None or init_pred.dim() < 4 else init_pred + hx = None + sigma = 1.0 + cascades_etas = [] + for i, cascade in enumerate(self.cirim): + # Forward pass through the cascades + prediction, hx = cascade( + prediction, + y, + sensitivity_maps, + mask, + init_pred, + hx, + sigma, + keep_eta=False if i == 0 else self.keep_eta, + ) + time_steps_etas = [self.process_intermediate_pred(pred, sensitivity_maps, target) for pred in prediction] + cascades_etas.append(time_steps_etas) + yield cascades_etas
+ +
[docs] def process_intermediate_pred(self, pred, sensitivity_maps, target, do_coil_combination=False): + """ + Process the intermediate prediction. + + Parameters + ---------- + pred: Intermediate prediction. + torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + sensitivity_maps: Coil sensitivity maps. + torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + target: Target data to crop to size. + torch.Tensor, shape [batch_size, n_x, n_y, 2] + do_coil_combination: Whether to do coil combination. + bool, default False + + Returns + ------- + pred: torch.Tensor, shape [batch_size, n_x, n_y, 2] + Processed prediction. + """ + # Take the last time step of the eta + if not self.no_dc or do_coil_combination: + pred = ifft2c(pred, fft_type=self.fft_type) + pred = coil_combination(pred, sensitivity_maps, method=self.output_type, dim=1) + pred = torch.view_as_complex(pred) + _, pred = center_crop_to_smallest(target, pred) + return pred
+ +
[docs] def process_loss(self, target, pred, _loss_fn): + """ + Process the loss. + + Parameters + ---------- + target: Target data. + torch.Tensor, shape [batch_size, n_x, n_y, 2] + pred: Final prediction(s). + list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or + torch.Tensor, shape [batch_size, n_x, n_y, 2] + _loss_fn: Loss function. + torch.nn.Module, default torch.nn.L1Loss() + + Returns + ------- + loss: torch.FloatTensor, shape [1] + If self.accumulate_loss is True, returns an accumulative result of all intermediate losses. + """ + target = torch.abs(target / torch.max(torch.abs(target))) + if "ssim" in str(_loss_fn).lower(): + max_value = np.array(torch.max(torch.abs(target)).item()).astype(np.float32) + + def loss_fn(x, y): + """Calculate the ssim loss.""" + return _loss_fn( + x.unsqueeze(dim=1), + torch.abs(y / torch.max(torch.abs(y))).unsqueeze(dim=1), + data_range=torch.tensor(max_value).unsqueeze(dim=0).to(x.device), + ) + + else: + + def loss_fn(x, y): + """Calculate other loss.""" + return _loss_fn(x, torch.abs(y / torch.max(torch.abs(y)))) + + if self.accumulate_estimates: + cascades_loss = [] + for cascade_pred in pred: + time_steps_loss = [loss_fn(target, time_step_pred) for time_step_pred in cascade_pred] + _loss = [ + x * torch.logspace(-1, 0, steps=self.time_steps).to(time_steps_loss[0]) for x in time_steps_loss + ] + cascades_loss.append(sum(sum(_loss) / self.time_steps)) + yield sum(list(cascades_loss)) / len(self.cirim) + else: + return loss_fn(target, pred)
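The logspace weighting used in process_loss above emphasizes later time steps; a simplified, self-contained sketch of that intent with dummy per-step losses:

    time_steps = 8
    weights = torch.logspace(-1, 0, steps=time_steps)  # 0.1 ... 1.0
    step_losses = torch.rand(time_steps)               # dummy per-time-step losses
    weighted = (step_losses * weights).sum() / time_steps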
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/conv/conv2d.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/conv/conv2d.html
new file mode 100644
index 00000000..eddd672a
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/conv/conv2d.html
@@ -0,0 +1,165 @@
+mridc.collections.reconstruction.models.conv.conv2d — mridc v.0.0.1 documentation
Source code for mridc.collections.reconstruction.models.conv.conv2d
+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NKI-AI/direct/blob/main/direct/nn/conv/conv.py
+# Copyright (c) DIRECT Contributors
+
+import torch.nn as nn
+
+
+
[docs]class Conv2d(nn.Module): + """ + Implementation of a simple cascade of 2D convolutions. + If batchnorm is set to True, batch normalization layer is applied after each convolution. + """ + + def __init__(self, in_channels, out_channels, hidden_channels, n_convs=3, activation=nn.PReLU(), batchnorm=False): + """ + Inits Conv2d. + + Parameters + ---------- + in_channels: Number of input channels. + int + out_channels: Number of output channels. + int + hidden_channels: Number of hidden channels. + int + n_convs: Number of convolutional layers. + int + activation: Activation function. + torch.nn.Module + batchnorm: If True a batch normalization layer is applied after every convolution. + bool + """ + super().__init__() + + self.conv = [] + for idx in range(n_convs): + self.conv.append( + nn.Conv2d( + in_channels if idx == 0 else hidden_channels, + hidden_channels if idx != n_convs - 1 else out_channels, + kernel_size=3, + padding=1, + ) + ) + if batchnorm: + self.conv.append(nn.BatchNorm2d(hidden_channels if idx != n_convs - 1 else out_channels, eps=1e-4)) + if idx != n_convs - 1: + self.conv.append(activation) + self.conv = nn.Sequential(*self.conv) + +
[docs]    def forward(self, x):
+        """
+        Performs the forward pass of Conv2d.
+
+        Parameters
+        ----------
+        x: Input tensor.
+
+        Returns
+        -------
+        Convolved output.
+        """
+        if x.dim() == 5:
+            x = x.squeeze(1)
+        if x.shape[-1] == 2:
+            x = x.permute(0, 3, 1, 2)
+        return self.conv(x)
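A minimal usage sketch; a complex-last input is permuted to channels-first as in forward() above (shapes are assumptions):

    import torch

    net = Conv2d(in_channels=2, out_channels=2, hidden_channels=64, n_convs=3)
    x = torch.randn(1, 32, 32, 2)  # complex-last input
    y = net(x)                     # [1, 2, 32, 32]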
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/conv/gruconv2d.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/conv/gruconv2d.html
new file mode 100644
index 00000000..8a0b7c8a
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/conv/gruconv2d.html
@@ -0,0 +1,206 @@
+mridc.collections.reconstruction.models.conv.gruconv2d — mridc v.0.0.1 documentation
Source code for mridc.collections.reconstruction.models.conv.gruconv2d
+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+from typing import Optional
+
+import torch.nn as nn
+from torch import Tensor
+
+from mridc.collections.reconstruction.models.rim.conv_layers import ConvNonlinear
+from mridc.collections.reconstruction.models.rim.rnn_cells import ConvGRUCell
+
+
+
[docs]class GRUConv2d(nn.Module):
+    """
+    Implementation of a GRU followed by a number of 2D convolutions inspired by [1]_.
+
+    References
+    ----------
+    .. [1] C. Qin, J. Schlemper, J. Caballero, A. N. Price, J. V. Hajnal and D. Rueckert, "Convolutional Recurrent
+        Neural Networks for Dynamic MR Image Reconstruction," in IEEE Transactions on Medical Imaging, vol. 38,
+        no. 1, pp. 280-290, Jan. 2019, doi: 10.1109/TMI.2018.2863670.
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        hidden_channels,
+        n_convs=3,
+        activation="ReLU",
+        batchnorm=False,
+    ):
+        """
+        Inits GRUConv2d.
+
+        Parameters
+        ----------
+        in_channels: Number of input channels.
+            int
+        out_channels: Number of output channels.
+            int
+        hidden_channels: Number of hidden channels.
+            int
+        n_convs: Number of convolutional layers.
+            int
+        activation: Name of the activation function passed to ConvNonlinear.
+            str
+        batchnorm: If True a batch normalization layer is applied after every convolution.
+            bool
+        """
+        super().__init__()
+
+        self.layers = nn.ModuleList()
+        self.layers.append(
+            ConvGRUCell(
+                in_channels,
+                hidden_channels,
+                conv_dim=2,
+                kernel_size=3,
+                dilation=1,
+                bias=False,
+            )
+        )
+        for _ in range(n_convs):
+            self.layers.append(
+                ConvNonlinear(
+                    hidden_channels,
+                    hidden_channels,
+                    conv_dim=2,
+                    kernel_size=3,
+                    dilation=1,
+                    bias=False,
+                    nonlinear=activation,
+                )
+            )
+        self.layers.append(
+            nn.Sequential(
+                ConvNonlinear(
+                    hidden_channels,
+                    out_channels,
+                    conv_dim=2,
+                    kernel_size=3,
+                    dilation=1,
+                    bias=False,
+                    nonlinear=activation,
+                )
+            )
+        )
+
+        self.hidden_channels = hidden_channels
+
[docs]    def forward(self, x, hx: Optional[Tensor] = None):
+        """
+        Performs the forward pass of GRUConv2d.
+
+        Parameters
+        ----------
+        x: Input tensor.
+            torch.Tensor
+        hx: Initial hidden state.
+            torch.Tensor
+
+        Returns
+        -------
+        Convolved output.
+        """
+        if hx is None:
+            hx = x.new_zeros((x.size(0), self.hidden_channels, *x.size()[2:]))
+
+        for i, layer in enumerate(self.layers):
+            x = layer(x, hx) if i == 0 else layer(x)
+        return x
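A minimal usage sketch; the hidden state is created lazily when hx is None (shapes and sizes are assumptions):

    import torch

    net = GRUConv2d(in_channels=2, out_channels=2, hidden_channels=64)
    x = torch.randn(1, 2, 32, 32)  # [batch_size, channels, height, width]
    out = net(x)  # pass hx explicitly to reuse the recurrent state across calls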
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/convrecnet/crnn_block.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/convrecnet/crnn_block.html
new file mode 100644
index 00000000..8c3dcd04
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/convrecnet/crnn_block.html
@@ -0,0 +1,224 @@
+mridc.collections.reconstruction.models.convrecnet.crnn_block — mridc v.0.0.1 documentation
Source code for mridc.collections.reconstruction.models.convrecnet.crnn_block
+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+from typing import Any, List, Union
+
+import torch
+
+from mridc.collections.common.parts.fft import fft2c, ifft2c
+from mridc.collections.common.parts.utils import complex_conj, complex_mul
+
+
+
[docs]class DataConsistencyLayer(torch.nn.Module):
+    """
+    Data consistency layer for the CRNN.
+    This layer enforces consistency between the predicted k-space and the measured (reference) k-space at the
+    sampled locations, weighted by a learnable parameter.
+    """
+
+    def __init__(self):
+        """Initializes the data consistency layer."""
+        super().__init__()
+        self.dc_weight = torch.nn.Parameter(torch.ones(1))
+
[docs] def forward(self, pred_kspace, ref_kspace, mask): + """Forward pass of the data consistency layer.""" + zero = torch.zeros(1, 1, 1, 1, 1).to(pred_kspace) + return torch.where(mask.bool(), pred_kspace - ref_kspace, zero) * self.dc_weight
+ + +
[docs]class RecurrentConvolutionalNetBlock(torch.nn.Module):
+    """
+    Model block for Recurrent Convolution Neural Network inspired by [1]_.
+
+    References
+    ----------
+    .. [1] C. Qin, J. Schlemper, J. Caballero, A. N. Price, J. V. Hajnal and D. Rueckert, "Convolutional Recurrent
+        Neural Networks for Dynamic MR Image Reconstruction," in IEEE Transactions on Medical Imaging, vol. 38,
+        no. 1, pp. 280-290, Jan. 2019, doi: 10.1109/TMI.2018.2863670.
+    """
+
+    def __init__(
+        self, model: torch.nn.Module, num_iterations: int = 10, fft_type: str = "orthogonal", no_dc: bool = False
+    ):
+        """
+        Initialize the model block.
+
+        Parameters
+        ----------
+        model: Model to apply soft data consistency.
+        num_iterations: Number of iterations.
+        fft_type: Type of FFT to use.
+        no_dc: Flag to disable the soft data consistency.
+        """
+        super().__init__()
+
+        self.model = model
+        self.num_iterations = num_iterations
+        self.fft_type = fft_type
+        self.no_dc = no_dc
+
+        self.dc_weight = torch.nn.Parameter(torch.ones(1))
+
[docs]    def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
+        """
+        Expand the coil-combined image to multicoil k-space using the sensitivity maps (SENSE forward operator).
+
+        Parameters
+        ----------
+        x: Input coil-combined image.
+        sens_maps: Sensitivity maps.
+
+        Returns
+        -------
+        Multicoil k-space data.
+        """
+        return fft2c(complex_mul(x, sens_maps), fft_type=self.fft_type)
+ +
[docs]    def sens_reduce(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
+        """
+        Combine multicoil k-space into a coil-combined image using the conjugate sensitivity maps (SENSE adjoint).
+
+        Parameters
+        ----------
+        x: Input multicoil k-space data.
+        sens_maps: Sensitivity maps.
+
+        Returns
+        -------
+        Coil-combined image.
+        """
+        x = ifft2c(x, fft_type=self.fft_type)
+        return complex_mul(x, complex_conj(sens_maps)).sum(1)
+ +
[docs]    def forward(
+        self,
+        ref_kspace: torch.Tensor,
+        sens_maps: torch.Tensor,
+        mask: torch.Tensor,
+    ) -> List[Union[torch.Tensor, Any]]:
+        """
+        Forward pass of the model.
+
+        Parameters
+        ----------
+        ref_kspace: Reference k-space data.
+        sens_maps: Sensitivity maps.
+        mask: Mask to apply to the data.
+
+        Returns
+        -------
+        List of intermediate k-space predictions, one per iteration.
+        """
+        zero = torch.zeros(1, 1, 1, 1, 1).to(ref_kspace)
+        pred = ref_kspace.clone()
+
+        preds = []
+        for _ in range(self.num_iterations):
+            soft_dc = torch.where(mask.bool(), pred - ref_kspace, zero) * self.dc_weight
+
+            eta = self.sens_reduce(pred, sens_maps)
+            eta = self.model(eta.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + eta
+            eta = self.sens_expand(eta.unsqueeze(1), sens_maps)
+
+            if not self.no_dc:
+                # TODO: Check if this is correct
+                eta = pred - soft_dc - eta
+            pred = eta
+
+            preds.append(eta)
+
+        return preds
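A minimal usage sketch for the block above, with a toy convolution in place of the recurrent model that CRNNet plugs in; shapes are assumptions:

    block = RecurrentConvolutionalNetBlock(
        torch.nn.Conv2d(2, 2, kernel_size=3, padding=1), num_iterations=3
    )
    ref_kspace = torch.randn(1, 8, 32, 32, 2)
    sens_maps = torch.randn(1, 8, 32, 32, 2)
    mask = (torch.rand(1, 1, 32, 32, 1) > 0.5).float()
    preds = block(ref_kspace, sens_maps, mask)  # list of 3 intermediate k-space estimates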
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/crnn.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/crnn.html
new file mode 100644
index 00000000..9f9db9da
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/crnn.html
@@ -0,0 +1,280 @@
+mridc.collections.reconstruction.models.crnn — mridc v.0.0.1 documentation
Source code for mridc.collections.reconstruction.models.crnn
+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+from abc import ABC
+from typing import Generator, Union
+
+import numpy as np
+import torch
+from omegaconf import DictConfig, OmegaConf
+from pytorch_lightning import Trainer
+from torch.nn import L1Loss
+
+from mridc.collections.common.losses.ssim import SSIMLoss
+from mridc.collections.common.parts.fft import ifft2c
+from mridc.collections.common.parts.utils import coil_combination
+from mridc.collections.reconstruction.models.base import BaseMRIReconstructionModel, BaseSensitivityModel
+from mridc.collections.reconstruction.models.conv.gruconv2d import GRUConv2d
+from mridc.collections.reconstruction.models.convrecnet.crnn_block import RecurrentConvolutionalNetBlock
+from mridc.collections.reconstruction.parts.utils import center_crop_to_smallest
+from mridc.core.classes.common import typecheck
+
+__all__ = ["CRNNet"]
+
+
+
[docs]class CRNNet(BaseMRIReconstructionModel, ABC): + """ + Implementation of the Convolutional Recurrent Neural Network, inspired by C. Qin, J. Schlemper, J. Caballero, \ + A. N. Price, J. V. Hajnal and D. Rueckert. + + References + ---------- + + .. + + C. Qin, J. Schlemper, J. Caballero, A. N. Price, J. V. Hajnal and D. Rueckert, "Convolutional Recurrent \ + Neural Networks for Dynamic MR Image Reconstruction," in IEEE Transactions on Medical Imaging, vol. 38, \ + no. 1, pp. 280-290, Jan. 2019, doi: 10.1109/TMI.2018.2863670. + + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + # init superclass + super().__init__(cfg=cfg, trainer=trainer) + + cfg_dict = OmegaConf.to_container(cfg, resolve=True) + + self.no_dc = cfg_dict.get("no_dc") + self.fft_type = cfg_dict.get("fft_type") + self.num_iterations = cfg_dict.get("num_iterations") + + self.crnn = RecurrentConvolutionalNetBlock( + GRUConv2d( + in_channels=2, + out_channels=2, + hidden_channels=cfg_dict.get("hidden_channels"), + n_convs=cfg_dict.get("n_convs"), + batchnorm=cfg_dict.get("batchnorm"), + ), + num_iterations=self.num_iterations, + fft_type=self.fft_type, + no_dc=self.no_dc, + ) + + self.output_type = cfg_dict.get("output_type") + + # Initialize the sensitivity network if use_sens_net is True + self.use_sens_net = cfg_dict.get("use_sens_net") + if self.use_sens_net: + self.sens_net = BaseSensitivityModel( + cfg_dict.get("sens_chans"), + cfg_dict.get("sens_pools"), + fft_type=self.fft_type, + mask_type=cfg_dict.get("sens_mask_type"), + normalize=cfg_dict.get("sens_normalize"), + ) + + # initialize weights if not using pretrained ccnn + # TODO if not ccnn_cfg_dict.get("pretrained", False) + + self.train_loss_fn = SSIMLoss() if cfg_dict.get("train_loss_fn") == "ssim" else L1Loss() + self.eval_loss_fn = SSIMLoss() if cfg_dict.get("eval_loss_fn") == "ssim" else L1Loss() + + self.accumulate_estimates = True + +
[docs] @typecheck() + def forward( + self, + y: torch.Tensor, + sensitivity_maps: torch.Tensor, + mask: torch.Tensor, + init_pred: torch.Tensor, + target: torch.Tensor, + ) -> Union[Generator, torch.Tensor]: + """ + Forward pass of the network. + + Parameters + ---------- + y: Subsampled k-space data. + torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + sensitivity_maps: Coil sensitivity maps. + torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + mask: Sampling mask. + torch.Tensor, shape [1, 1, n_x, n_y, 1] + init_pred: Initial prediction. + torch.Tensor, shape [batch_size, n_x, n_y, 2] + target: Target data to compute the loss. + torch.Tensor, shape [batch_size, n_x, n_y, 2] + + Returns + ------- + pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or torch.Tensor, shape [batch_size, n_x, n_y, 2] + If self.accumulate_loss is True, returns a list of all intermediate estimates. + If False, returns the final estimate. + """ + sensitivity_maps = self.sens_net(y, mask) if self.use_sens_net else sensitivity_maps + pred = self.crnn(y, sensitivity_maps, mask) + yield [self.process_intermediate_pred(x, sensitivity_maps, target) for x in pred]
+ +
[docs] def process_intermediate_pred(self, pred, sensitivity_maps, target): + """ + Process the intermediate prediction. + + Parameters + ---------- + pred: Intermediate prediction. + torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + sensitivity_maps: Coil sensitivity maps. + torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + target: Target data to crop to size. + torch.Tensor, shape [batch_size, n_x, n_y, 2] + + Returns + ------- + pred: torch.Tensor, shape [batch_size, n_x, n_y, 2] + Processed prediction. + """ + pred = ifft2c(pred, fft_type=self.fft_type) + pred = coil_combination(pred, sensitivity_maps, method=self.output_type, dim=1) + pred = torch.view_as_complex(pred) + _, pred = center_crop_to_smallest(target, pred) + return pred
+ +
[docs] def process_loss(self, target, pred, _loss_fn): + """ + Process the loss. + + Parameters + ---------- + target: Target data. + torch.Tensor, shape [batch_size, n_x, n_y, 2] + pred: Final prediction(s). + list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or + torch.Tensor, shape [batch_size, n_x, n_y, 2] + _loss_fn: Loss function. + torch.nn.Module, default torch.nn.L1Loss() + + Returns + ------- + loss: torch.FloatTensor, shape [1] + If self.accumulate_loss is True, returns an accumulative result of all intermediate losses. + """ + target = torch.abs(target / torch.max(torch.abs(target))) + + if "ssim" in str(_loss_fn).lower(): + max_value = np.array(torch.max(torch.abs(target)).item()).astype(np.float32) + + def loss_fn(x, y): + """Calculate the ssim loss.""" + return _loss_fn( + x.unsqueeze(dim=1), + torch.abs(y / torch.max(torch.abs(y))).unsqueeze(dim=1), + data_range=torch.tensor(max_value).unsqueeze(dim=0).to(x.device), + ) + + else: + + def loss_fn(x, y): + """Calculate other loss.""" + return _loss_fn(x, torch.abs(y / torch.max(torch.abs(y)))) + + iterations_loss = [loss_fn(target, iteration_pred) for iteration_pred in pred] + _loss = [x * torch.logspace(-1, 0, steps=self.num_iterations).to(iterations_loss[0]) for x in iterations_loss] + yield sum(sum(_loss) / self.num_iterations)
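Since accumulate_estimates is True, forward() is a generator that yields the list of per-iteration estimates. A hedged consumption sketch, assuming model is a configured CRNNet and the inputs match the forward() docstring:

    preds = next(model.forward(y, sensitivity_maps, mask, init_pred, target))
    final = preds[-1]  # estimate of the last iteration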
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/crossdomain/crossdomain.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/crossdomain/crossdomain.html
new file mode 100644
index 00000000..77460deb
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/crossdomain/crossdomain.html
@@ -0,0 +1,271 @@
+mridc.collections.reconstruction.models.crossdomain.crossdomain — mridc v.0.0.1 documentation
Source code for mridc.collections.reconstruction.models.crossdomain.crossdomain
+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NKI-AI/direct/blob/main/direct/nn/crossdomain/crossdomain.py
+# Copyright (c) DIRECT Contributors
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+
+from mridc.collections.common.parts.fft import fft2c, ifft2c
+from mridc.collections.common.parts.utils import complex_conj, complex_mul
+
+
+
[docs]class CrossDomainNetwork(nn.Module):
+    """Performs optimisation in both the k-space ("K") and image ("I") domains, according to domain_sequence."""
+
+    def __init__(
+        self,
+        image_model_list: nn.Module,
+        kspace_model_list: Optional[Union[nn.Module, None]] = None,
+        domain_sequence: str = "KIKI",
+        image_buffer_size: int = 1,
+        kspace_buffer_size: int = 1,
+        normalize_image: bool = False,
+        fft_type: str = "orthogonal",
+        **kwargs,
+    ):
+        """
+        Inits CrossDomainNetwork.
+
+        Parameters
+        ----------
+        image_model_list: Image domain model list.
+            torch.nn.Module
+        kspace_model_list: K-space domain model list. If set to None, a correction step is applied.
+            torch.nn.Module, Default: None.
+        domain_sequence: Domain sequence containing only "K" (k-space domain) and/or "I" (image domain).
+            str, Default: "KIKI".
+        image_buffer_size: Image buffer size.
+            int, Default: 1.
+        kspace_buffer_size: K-space buffer size.
+            int, Default: 1.
+        normalize_image: If True, input is normalized.
+            bool, Default: False.
+        fft_type: Type of FFT.
+            str, Default: "orthogonal".
+        kwargs: Keyword arguments.
+            dict
+        """
+        super().__init__()
+
+        self.fft_type = fft_type
+
+        domain_sequence = list(domain_sequence.strip())  # type: ignore
+        if not set(domain_sequence).issubset({"K", "I"}):
+            raise ValueError(f"Invalid domain sequence. Got {domain_sequence}. Should only contain 'K' and 'I'.")
+
+        if kspace_model_list is not None and len(kspace_model_list) != domain_sequence.count("K"):
+            raise ValueError("K-space domain steps do not match k-space model list length.")
+
+        if len(image_model_list) != domain_sequence.count("I"):
+            raise ValueError("Image domain steps do not match image model list length.")
+
+        self.domain_sequence = domain_sequence
+
+        self.kspace_model_list = kspace_model_list
+        self.kspace_buffer_size = kspace_buffer_size
+
+        self.image_model_list = image_model_list
+        self.image_buffer_size = image_buffer_size
+
+        self._coil_dim = 1
+        self._complex_dim = -1
+        self._spatial_dims = (2, 3)
+
[docs] def kspace_correction(self, block_idx, image_buffer, kspace_buffer, sampling_mask, sensitivity_map, masked_kspace): + """Performs k-space correction.""" + forward_buffer = [ + self._forward_operator(image.clone(), sampling_mask, sensitivity_map) + for image in torch.split(image_buffer, 2, self._complex_dim) + ] + forward_buffer = torch.cat(forward_buffer, self._complex_dim) + + kspace_buffer = torch.cat([kspace_buffer, forward_buffer, masked_kspace], self._complex_dim) + + if self.kspace_model_list is not None: + kspace_buffer = self.kspace_model_list[block_idx](kspace_buffer.permute(0, 1, 4, 2, 3)).permute( + 0, 1, 3, 4, 2 + ) + else: + kspace_buffer = kspace_buffer[..., :2] - kspace_buffer[..., 2:4] + + return kspace_buffer
+ +
[docs] def image_correction(self, block_idx, image_buffer, kspace_buffer, sampling_mask, sensitivity_map): + """Performs image correction.""" + backward_buffer = [ + self._backward_operator(kspace.clone(), sampling_mask, sensitivity_map) + for kspace in torch.split(kspace_buffer, 2, self._complex_dim) + ] + backward_buffer = torch.cat(backward_buffer, self._complex_dim) + + image_buffer = torch.cat([image_buffer, backward_buffer], self._complex_dim).permute(0, 3, 1, 2) + image_buffer = self.image_model_list[block_idx](image_buffer).permute(0, 2, 3, 1) + + return image_buffer
+ + def _forward_operator(self, image, sampling_mask, sensitivity_map): + """Forward operator.""" + return torch.where( + sampling_mask == 0, + torch.tensor([0.0], dtype=image.dtype).to(image.device), + fft2c( + complex_mul(image.unsqueeze(1), sensitivity_map), + fft_type=self.fft_type, + ).type(image.type()), + ) + + def _backward_operator(self, kspace, sampling_mask, sensitivity_map): + """Backward operator.""" + kspace = torch.where(sampling_mask == 0, torch.tensor([0.0], dtype=kspace.dtype).to(kspace.device), kspace) + return ( + complex_mul( + ifft2c(kspace.float(), fft_type=self.fft_type), + complex_conj(sensitivity_map), + ) + .sum(1) + .type(kspace.type()) + ) + +
[docs] def forward( + self, + masked_kspace: torch.Tensor, + sensitivity_map: torch.Tensor, + sampling_mask: torch.Tensor, + ) -> torch.Tensor: + """ + Computes the forward pass of CrossDomainNetwork. + + Parameters + ---------- + masked_kspace: Subsampled k-space data. + torch.tenor, shape [batch_size, n_coil, height, width, 2] + sensitivity_map: Sensitivity map. + torch.tenor, shape [batch_size, n_coil, height, width, 2] + sampling_mask: Sampling mask. + torch.tenor, shape [batch_size, 1, height, width, 1] + + Returns + ------- + Output image. + torch.tenor, shape [batch_size, height, width, 2] + """ + input_image = self._backward_operator(masked_kspace, sampling_mask, sensitivity_map) + + image_buffer = torch.cat([input_image] * self.image_buffer_size, self._complex_dim).to(masked_kspace.device) + kspace_buffer = torch.cat([masked_kspace] * self.kspace_buffer_size, self._complex_dim).to( + masked_kspace.device + ) + + kspace_block_idx, image_block_idx = 0, 0 + for block_domain in self.domain_sequence: + if block_domain == "K": + kspace_buffer = self.kspace_correction( + kspace_block_idx, image_buffer, kspace_buffer, sampling_mask, sensitivity_map, masked_kspace + ) + kspace_block_idx += 1 + else: + image_buffer = self.image_correction( + image_block_idx, image_buffer, kspace_buffer, sampling_mask, sensitivity_map + ) + image_block_idx += 1 + + return image_buffer[..., :2]
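A minimal instantiation sketch for a "KIKI" sequence with kspace_model_list=None, so the built-in correction step handles the "K" steps; the image models are toy convolutions whose channel counts follow the buffer arithmetic above (all values are assumptions):

    image_models = nn.ModuleList(
        [nn.Conv2d(4, 2, kernel_size=3, padding=1) for _ in range(2)]  # one per "I"
    )
    net = CrossDomainNetwork(image_model_list=image_models, domain_sequence="KIKI")
    masked_kspace = torch.randn(1, 8, 32, 32, 2)
    sens_maps = torch.randn(1, 8, 32, 32, 2)
    mask = (torch.rand(1, 1, 32, 32, 1) > 0.5).float()
    out = net(masked_kspace, sens_maps, mask)  # [1, 32, 32, 2]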
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/crossdomain/multicoil.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/crossdomain/multicoil.html
new file mode 100644
index 00000000..3382f3bb
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/crossdomain/multicoil.html
@@ -0,0 +1,170 @@
+mridc.collections.reconstruction.models.crossdomain.multicoil — mridc v.0.0.1 documentation
Source code for mridc.collections.reconstruction.models.crossdomain.multicoil
+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NKI-AI/direct/blob/main/direct/nn/crossdomain/multicoil.py
+# Copyright (c) DIRECT Contributors
+
+import torch
+import torch.nn as nn
+
+
+
[docs]class MultiCoil(nn.Module):
+    """
+    Makes the forward pass of multi-coil data of shape (N, N_coils, H, W, C) through a model.
+    If coil_to_batch is set to True, the coil dimension is moved to the batch dimension. Otherwise, each coil's
+    data is passed through the model individually.
+    """
+
+    def __init__(self, model: nn.Module, coil_dim: int = 1, coil_to_batch: bool = False):
+        """Inits MultiCoil.
+
+        Parameters
+        ----------
+        model: Any nn.Module that takes as input 4D data (N, H, W, C). Typically, a convolutional-like model.
+            torch.nn.Module
+        coil_dim: Coil dimension.
+            int, Default: 1.
+        coil_to_batch: If True batch and coil dimensions are merged when forwarded by the model and unmerged when
+            outputted. Otherwise, input is forwarded to the model per coil.
+            bool, Default: False.
+        """
+        super().__init__()
+
+        self.model = model
+        self.coil_to_batch = coil_to_batch
+        self._coil_dim = coil_dim
+
+    def _compute_model_per_coil(self, data: torch.Tensor) -> torch.Tensor:
+        """Computes the model per coil."""
+        output = []
+
+        for idx in range(data.size(self._coil_dim)):
+            subselected_data = data.select(self._coil_dim, idx)
+            if subselected_data.shape[-1] == 2 and subselected_data.dim() == 4:
+                output.append(self.model(subselected_data.permute(0, 3, 1, 2)))
+            else:
+                output.append(self.model(subselected_data.unsqueeze(1)).squeeze(1))
+        output = torch.stack(output, dim=self._coil_dim)
+        return output
+
[docs]    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Performs the forward pass of MultiCoil.
+
+        Parameters
+        ----------
+        x: Multi-coil input.
+            torch.Tensor, shape (N, N_coils, H, W, C) when coil_to_batch is False, or (N, N_coils, C, H, W) when
+            coil_to_batch is True.
+
+        Returns
+        -------
+        Multi-coil output with the model's channels stacked per coil.
+            torch.Tensor, shape (N, N_coils, C, H, W)
+        """
+        if self.coil_to_batch:
+            x = x.clone()
+
+            batch, coil, channels, height, width = x.size()
+            x = x.reshape(batch * coil, channels, height, width).contiguous()
+            x = self.model(x).permute(0, 2, 3, 1)
+            x = x.reshape(batch, coil, height, width, -1).permute(0, 1, 4, 2, 3)
+        else:
+            x = self._compute_model_per_coil(x).contiguous()
+
+        return x
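A minimal usage sketch; note that the per-coil branch stacks the model's channels-first output along the coil dimension (shapes are assumptions):

    net = MultiCoil(nn.Conv2d(2, 2, kernel_size=3, padding=1))
    x = torch.randn(1, 8, 32, 32, 2)  # (N, N_coils, H, W, C)
    y = net(x)                        # (1, 8, 2, 32, 32)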
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/didn/didn.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/didn/didn.html
new file mode 100644
index 00000000..fb0dbc4f
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/didn/didn.html
@@ -0,0 +1,464 @@
+mridc.collections.reconstruction.models.didn.didn — mridc v.0.0.1 documentation
Source code for mridc.collections.reconstruction.models.didn.didn
+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NKI-AI/direct/blob/main/direct/nn/didn/didn.py
+# Copyright (c) DIRECT Contributors
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+
[docs]class Subpixel(nn.Module): + """ + Subpixel convolution layer for up-scaling of low resolution features at super-resolution as implemented in \ + Yu, Songhyun, et al. + + References + ---------- + + .. + Yu, Songhyun, et al. “Deep Iterative Down-Up CNN for Image Denoising.” 2019 IEEE/CVF Conference on Computer \ + Vision and Pattern Recognition Workshops (CVPRW), 2019, pp. 2095–103. IEEE Xplore, \ + https://doi.org/10.1109/CVPRW.2019.00262. + + """ + + def __init__(self, in_channels, out_channels, upscale_factor, kernel_size, padding=0): + """ + Inits Subpixel. + + Parameters + ---------- + in_channels: Number of input channels. + out_channels: Number of output channels. + upscale_factor: Subpixel upscale factor. + kernel_size: Convolution kernel size. + padding: Padding size. Default: 0. + """ + super().__init__() + self.conv = nn.Conv2d( + in_channels, out_channels * upscale_factor**2, kernel_size=kernel_size, padding=padding + ) + self.pixelshuffle = nn.PixelShuffle(upscale_factor) + +
[docs] def forward(self, x): + """Computes Subpixel convolution on input torch.Tensor ``x``.""" + return self.pixelshuffle(self.conv(x))
+ + +
[docs]class ReconBlock(nn.Module): + """ + Reconstruction Block of DIDN model as implemented in Yu, Songhyun, et al. + + References + ---------- + + .. + Yu, Songhyun, et al. “Deep Iterative Down-Up CNN for Image Denoising.” 2019 IEEE/CVF Conference on Computer \ + Vision and Pattern Recognition Workshops (CVPRW), 2019, pp. 2095–103. IEEE Xplore, \ + https://doi.org/10.1109/CVPRW.2019.00262. + + """ + + def __init__(self, in_channels, num_convs): + """ + Inits ReconBlock. + + Parameters + ---------- + in_channels: Number of input channels. + num_convs: Number of convolution blocks. + """ + super().__init__() + self.convs = nn.ModuleList( + [ + nn.Sequential( + *[ + nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, padding=1), + nn.PReLU(), + ] + ) + for _ in range(num_convs - 1) + ] + ) + self.convs.append(nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, padding=1)) + self.num_convs = num_convs + +
[docs] def forward(self, input_data): + """ + Computes num_convs convolutions followed by PReLU activation on `input_data`. + + Parameters + ---------- + input_data: Input tensor. + """ + output = input_data.clone() + for idx in range(self.num_convs): + output = self.convs[idx](output) + + return input_data + output
+ + +
[docs]class DUB(nn.Module): + """ + Down-up block (DUB) for DIDN model as implemented in Yu, Songhyun, et al. + + References + ---------- + + .. + Yu, Songhyun, et al. “Deep Iterative Down-Up CNN for Image Denoising.” 2019 IEEE/CVF Conference on Computer \ + Vision and Pattern Recognition Workshops (CVPRW), 2019, pp. 2095–103. IEEE Xplore, \ + https://doi.org/10.1109/CVPRW.2019.00262. + + """ + + def __init__( + self, + in_channels, + out_channels, + ): + """ + Inits DUB. + + Parameters + ---------- + in_channels: Number of input channels. + out_channels: Number of output channels. + """ + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + + # Scale 1 + self.conv1_1 = nn.Sequential(*[nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), nn.PReLU()] * 2) + self.down1 = nn.Conv2d(in_channels, in_channels * 2, kernel_size=3, stride=2, padding=1) + # Scale 2 + self.conv2_1 = nn.Sequential( + *[nn.Conv2d(in_channels * 2, in_channels * 2, kernel_size=3, padding=1), nn.PReLU()] + ) + self.down2 = nn.Conv2d(in_channels * 2, in_channels * 4, kernel_size=3, stride=2, padding=1) + # Scale 3 + self.conv3_1 = nn.Sequential( + *[ + nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size=3, padding=1), + nn.PReLU(), + ] + ) + self.up1 = nn.Sequential(*[Subpixel(in_channels * 4, in_channels * 2, 2, 1, 0)]) + # Scale 2 + self.conv_agg_1 = nn.Conv2d(in_channels * 4, in_channels * 2, kernel_size=1) + self.conv2_2 = nn.Sequential( + *[ + nn.Conv2d(in_channels * 2, in_channels * 2, kernel_size=3, padding=1), + nn.PReLU(), + ] + ) + self.up2 = nn.Sequential(*[Subpixel(in_channels * 2, in_channels, 2, 1, 0)]) + # Scale 1 + self.conv_agg_2 = nn.Conv2d(in_channels * 2, in_channels, kernel_size=1) + self.conv1_2 = nn.Sequential(*[nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), nn.PReLU()] * 2) + self.conv_out = nn.Sequential(*[nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), nn.PReLU()]) + +
[docs]    @staticmethod
+    def pad(x):
+        """
+        Pads the input to even height and width if they are odd.
+
+        Parameters
+        ----------
+        x: Input to pad.
+
+        Returns
+        -------
+        Padded tensor.
+        """
+        padding = [0, 0, 0, 0]
+
+        if x.shape[-2] % 2 != 0:
+            padding[3] = 1  # pad bottom - height
+        if x.shape[-1] % 2 != 0:
+            padding[1] = 1  # pad right - width
+        if sum(padding) != 0:
+            x = F.pad(x, padding, "reflect")
+        return x
+ +
[docs] @staticmethod + def crop_to_shape(x, shape): + """ + Crops ``x`` to specified shape. + + Parameters + ---------- + x: Input tensor with shape (\*, H, W). + shape: Crop shape corresponding to H, W. + + Returns + ------- + Cropped tensor. + """ + h, w = x.shape[-2:] + + if h > shape[0]: + x = x[:, :, : shape[0], :] + if w > shape[1]: + x = x[:, :, :, : shape[1]] + return x
+ +
[docs] def forward(self, x): + """ + Parameters + ---------- + x: Input tensor. + + Returns + ------- + DUB output. + """ + x1 = self.pad(x.clone()) + x1 = x1 + self.conv1_1(x1) + x2 = self.down1(x1) + x2 = x2 + self.conv2_1(x2) + out = self.down2(x2) + out = out + self.conv3_1(out) + out = self.up1(out) + out = torch.cat([x2, self.crop_to_shape(out, x2.shape[-2:])], dim=1) + out = self.conv_agg_1(out) + out = out + self.conv2_2(out) + out = self.up2(out) + out = torch.cat([x1, self.crop_to_shape(out, x1.shape[-2:])], dim=1) + out = self.conv_agg_2(out) + out = out + self.conv1_2(out) + out = x + self.crop_to_shape(self.conv_out(out), x.shape[-2:]) + return out
+ + +
[docs]class DIDN(nn.Module): + """ + Deep Iterative Down-up convolutional Neural network (DIDN) implementation as in Yu, Songhyun, et al. + + References + ---------- + + .. + Yu, Songhyun, et al. “Deep Iterative Down-Up CNN for Image Denoising.” 2019 IEEE/CVF Conference on Computer \ + Vision and Pattern Recognition Workshops (CVPRW), 2019, pp. 2095–103. IEEE Xplore, \ + https://doi.org/10.1109/CVPRW.2019.00262. + + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int = 128, + num_dubs: int = 6, + num_convs_recon: int = 9, + skip_connection: bool = False, + ): + """ + Inits DIDN. + + Parameters + ---------- + in_channels: Number of input channels. + int + out_channels: Number of output channels. + int + hidden_channels: Number of hidden channels. First convolution out_channels. + int, Default: 128. + num_dubs: Number of DUB networks. + int, Default: 6. + num_convs_recon: Number of ReconBlock convolutions. + int, Default: 9. + skip_connection: Use skip connection. + bool, Default: False. + """ + super().__init__() + self.conv_in = nn.Sequential( + *[nn.Conv2d(in_channels=in_channels, out_channels=hidden_channels, kernel_size=3, padding=1), nn.PReLU()] + ) + self.down = nn.Conv2d( + in_channels=hidden_channels, + out_channels=hidden_channels, + kernel_size=3, + stride=2, + padding=1, + ) + self.dubs = nn.ModuleList( + [DUB(in_channels=hidden_channels, out_channels=hidden_channels) for _ in range(num_dubs)] + ) + self.recon_block = ReconBlock(in_channels=hidden_channels, num_convs=num_convs_recon) + self.recon_agg = nn.Conv2d(in_channels=hidden_channels * num_dubs, out_channels=hidden_channels, kernel_size=1) + self.conv = nn.Sequential( + *[ + nn.Conv2d( + in_channels=hidden_channels, + out_channels=hidden_channels, + kernel_size=3, + padding=1, + ), + nn.PReLU(), + ] + ) + self.up2 = Subpixel(hidden_channels, hidden_channels, 2, 1) + self.conv_out = nn.Conv2d( + in_channels=hidden_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + ) + self.num_dubs = num_dubs + self.skip_connection = (in_channels == out_channels) and skip_connection + +
[docs] @staticmethod + def crop_to_shape(x, shape): + """ + Crops ``x`` to specified shape. + + Parameters + ---------- + x: Input tensor with shape (\*, H, W). + shape: Crop shape corresponding to H, W. + + Returns + ------- + Cropped tensor. + """ + h, w = x.shape[-2:] + + if h > shape[0]: + x = x[:, :, : shape[0], :] + if w > shape[1]: + x = x[:, :, :, : shape[1]] + return x
+ +
[docs] def forward(self, x, channel_dim=1): + """ + Takes as input a torch.Tensor `x` and computes DIDN(x). + + Parameters + ---------- + x: Input tensor. + channel_dim: Channel dimension. Default: 1. + + Returns + ------- + DIDN output tensor. + """ + out = self.conv_in(x) + out = self.down(out) + + dub_outs = [] + for dub in self.dubs: + out = dub(out) + dub_outs.append(out) + + out = [self.recon_block(dub_out) for dub_out in dub_outs] + out = self.recon_agg(torch.cat(out, dim=channel_dim)) + out = self.conv(out) + out = self.up2(out) + out = self.conv_out(out) + out = self.crop_to_shape(out, x.shape[-2:]) + + if self.skip_connection: + out = x + out + return out
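A minimal usage sketch; inputs are padded and cropped internally, so odd spatial sizes are handled (values are assumptions):

    net = DIDN(in_channels=2, out_channels=2, hidden_channels=16, num_dubs=2, num_convs_recon=3)
    x = torch.randn(1, 2, 33, 41)
    y = net(x)  # [1, 2, 33, 41]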
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/dunet.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/dunet.html
new file mode 100644
index 00000000..9f4c3b1b
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/dunet.html
@@ -0,0 +1,256 @@
+mridc.collections.reconstruction.models.dunet — mridc v.0.0.1 documentation
Source code for mridc.collections.reconstruction.models.dunet
+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+from abc import ABC
+
+import torch
+from omegaconf import DictConfig, OmegaConf
+from pytorch_lightning import Trainer
+from torch.nn import L1Loss
+
+from mridc.collections.common.losses.ssim import SSIMLoss
+from mridc.collections.common.parts.fft import ifft2c
+from mridc.collections.common.parts.utils import complex_conj, complex_mul
+from mridc.collections.reconstruction.models.base import BaseMRIReconstructionModel, BaseSensitivityModel
+from mridc.collections.reconstruction.models.didn.didn import DIDN
+from mridc.collections.reconstruction.models.sigmanet.dc_layers import (
+    DataGDLayer,
+    DataIDLayer,
+    DataProxCGLayer,
+    DataVSLayer,
+)
+from mridc.collections.reconstruction.models.sigmanet.sensitivity_net import SensitivityNetwork
+from mridc.collections.reconstruction.models.unet_base.unet_block import NormUnet
+from mridc.collections.reconstruction.parts.utils import center_crop_to_smallest
+from mridc.core.classes.common import typecheck
+
+__all__ = ["DUNet"]
+
+
+
[docs]class DUNet(BaseMRIReconstructionModel, ABC): + """ + Implementation of the Down-Up NET, inspired by Hammernik, K, Schlemper, J, Qin, C, et al. + + References + ---------- + + .. + + Hammernik, K, Schlemper, J, Qin, C, et al. Systematic evaluation of iterative deep neural networks for fast \ + parallel MRI reconstruction with sensitivity-weighted coil combination. Magn Reson Med. 2021; 86: 1859– 1872. \ + https://doi.org/10.1002/mrm.28827 + + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + # init superclass + super().__init__(cfg=cfg, trainer=trainer) + + cfg_dict = OmegaConf.to_container(cfg, resolve=True) + self.fft_type = cfg_dict.get("fft_type") + + reg_model_architecture = cfg_dict.get("reg_model_architecture") + if reg_model_architecture == "DIDN": + reg_model = DIDN( + in_channels=2, + out_channels=2, + hidden_channels=cfg_dict.get("didn_hidden_channels"), + num_dubs=cfg_dict.get("didn_num_dubs"), + num_convs_recon=cfg_dict.get("didn_num_convs_recon"), + ) + elif reg_model_architecture in ["UNET", "NORMUNET"]: + reg_model = NormUnet( + cfg_dict.get("unet_num_filters"), + cfg_dict.get("unet_num_pool_layers"), + in_chans=2, + out_chans=2, + drop_prob=cfg_dict.get("unet_dropout_probability"), + padding_size=cfg_dict.get("unet_padding_size"), + normalize=cfg_dict.get("unet_normalize"), + ) + else: + raise NotImplementedError( + f"DUNET is currently implemented for reg_model_architecture == 'DIDN' or 'UNet'." + f"Got reg_model_architecture == {reg_model_architecture}." + ) + + data_consistency_term = cfg_dict.get("data_consistency_term") + + if data_consistency_term == "GD": + dc_layer = DataGDLayer(lambda_init=cfg_dict.get("data_consistency_lambda_init"), fft_type=self.fft_type) + elif data_consistency_term == "PROX": + dc_layer = DataProxCGLayer( + lambda_init=cfg_dict.get("data_consistency_lambda_init"), fft_type=self.fft_type + ) + elif data_consistency_term == "VS": + dc_layer = DataVSLayer( + alpha_init=cfg_dict.get("data_consistency_alpha_init"), + beta_init=cfg_dict.get("data_consistency_beta_init"), + fft_type=self.fft_type, + ) + else: + dc_layer = DataIDLayer() + + self.model = SensitivityNetwork( + cfg_dict.get("num_iter"), + reg_model, + dc_layer, + shared_params=cfg_dict.get("shared_params"), + save_space=False, + reset_cache=False, + ) + + self._coil_dim = 1 + + # Initialize the sensitivity network if use_sens_net is True + self.use_sens_net = cfg_dict.get("use_sens_net") + if self.use_sens_net: + self.sens_net = BaseSensitivityModel( + cfg_dict.get("sens_chans"), + cfg_dict.get("sens_pools"), + fft_type=self.fft_type, + mask_type=cfg_dict.get("sens_mask_type"), + normalize=cfg_dict.get("sens_normalize"), + ) + + self.train_loss_fn = SSIMLoss() if cfg_dict.get("train_loss_fn") == "ssim" else L1Loss() + self.eval_loss_fn = SSIMLoss() if cfg_dict.get("eval_loss_fn") == "ssim" else L1Loss() + self.output_type = cfg_dict.get("output_type") + + self.dc_weight = torch.nn.Parameter(torch.ones(1)) + self.accumulate_estimates = False + +
[docs] @typecheck() + def forward( + self, + y: torch.Tensor, + sensitivity_maps: torch.Tensor, + mask: torch.Tensor, + init_pred: torch.Tensor, + target: torch.Tensor, + ) -> torch.Tensor: + """ + Forward pass of the network. + + Parameters + ---------- + y: Subsampled k-space data. + torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + sensitivity_maps: Coil sensitivity maps. + torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + mask: Sampling mask. + torch.Tensor, shape [1, 1, n_x, n_y, 1] + init_pred: Initial prediction. + torch.Tensor, shape [batch_size, n_x, n_y, 2] + target: Target data to compute the loss. + torch.Tensor, shape [batch_size, n_x, n_y, 2] + + Returns + ------- + pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or torch.Tensor, shape [batch_size, n_x, n_y, 2] + If self.accumulate_loss is True, returns a list of all intermediate estimates. + If False, returns the final estimate. + """ + sensitivity_maps = self.sens_net(y, mask) if self.use_sens_net else sensitivity_maps + init_pred = torch.sum(complex_mul(ifft2c(y, fft_type=self.fft_type), complex_conj(sensitivity_maps)), 1) + image = self.model(init_pred, y, sensitivity_maps, mask) + image = torch.sum(complex_mul(image, complex_conj(sensitivity_maps)), 1) + image = torch.view_as_complex(image) + _, image = center_crop_to_smallest(target, image) + return image
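For orientation, a minimal sketch of the coil-combined initial estimate computed in forward() above, using toy shapes and plain torch complex tensors in place of mridc's trailing real/imag dimension and its fft2c/complex_mul helpers:

import torch

batch, coils, nx, ny = 1, 4, 32, 32
y = torch.randn(batch, coils, nx, ny, dtype=torch.complex64)      # subsampled k-space
smaps = torch.randn(batch, coils, nx, ny, dtype=torch.complex64)  # coil sensitivities

coil_images = torch.fft.ifft2(y, norm="ortho")                    # per-coil images
init_pred = (coil_images * smaps.conj()).sum(dim=1)               # SENSE adjoint combine
print(init_pred.shape)                                            # torch.Size([1, 32, 32])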
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/jointicnet.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/jointicnet.html
new file mode 100644
index 00000000..7f914544
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/jointicnet.html
@@ -0,0 +1,341 @@
+mridc.collections.reconstruction.models.jointicnet — mridc v.0.0.1 documentation

Source code for mridc.collections.reconstruction.models.jointicnet

+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+from abc import ABC
+
+import torch
+from omegaconf import DictConfig, OmegaConf
+from pytorch_lightning import Trainer
+from torch.nn import L1Loss
+
+from mridc.collections.common.losses.ssim import SSIMLoss
+from mridc.collections.common.parts.fft import fft2c, ifft2c
+from mridc.collections.common.parts.utils import complex_conj, complex_mul
+from mridc.collections.reconstruction.models.base import BaseMRIReconstructionModel, BaseSensitivityModel
+from mridc.collections.reconstruction.models.unet_base.unet_block import NormUnet
+from mridc.collections.reconstruction.parts.utils import center_crop_to_smallest
+from mridc.core.classes.common import typecheck
+
+__all__ = ["JointICNet"]
+
+
+
[docs]class JointICNet(BaseMRIReconstructionModel, ABC): + """ + Implementation of the Joint Deep Model-Based MR Image and Coil Sensitivity Reconstruction Network (Joint-ICNet), \ + as presented in Jun, Yohan, et al. + + References + ---------- + + .. + + Jun, Yohan, et al. “Joint Deep Model-Based MR Image and Coil Sensitivity Reconstruction Network (Joint-ICNet) \ + for Fast MRI.” 2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), IEEE, 2021, pp. \ + 5266–75. DOI.org (Crossref), https://doi.org/10.1109/CVPR46437.2021.00523. + + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + # init superclass + super().__init__(cfg=cfg, trainer=trainer) + + cfg_dict = OmegaConf.to_container(cfg, resolve=True) + + self.num_iter = cfg_dict.get("num_iter") + self.fft_type = cfg_dict.get("fft_type") + + self.kspace_model = NormUnet( + cfg_dict.get("kspace_unet_num_filters"), + cfg_dict.get("kspace_unet_num_pool_layers"), + in_chans=2, + out_chans=2, + drop_prob=cfg_dict.get("kspace_unet_dropout_probability"), + padding_size=cfg_dict.get("kspace_unet_padding_size"), + normalize=cfg_dict.get("kspace_unet_normalize"), + ) + + self.image_model = NormUnet( + cfg_dict.get("imspace_unet_num_filters"), + cfg_dict.get("imspace_unet_num_pool_layers"), + in_chans=2, + out_chans=2, + drop_prob=cfg_dict.get("imspace_unet_dropout_probability"), + padding_size=cfg_dict.get("imspace_unet_padding_size"), + normalize=cfg_dict.get("imspace_unet_normalize"), + ) + + self.sens_net = BaseSensitivityModel( + cfg_dict.get("sens_unet_num_filters"), + cfg_dict.get("sens_unet_num_pool_layers"), + mask_center=cfg_dict.get("sens_unet_mask_center"), + fft_type=self.fft_type, + mask_type=cfg_dict.get("sens_mask_type"), + drop_prob=cfg_dict.get("sens_unet_dropout_probability"), + padding_size=cfg_dict.get("sens_unet_padding_size"), + normalize=cfg_dict.get("sens_unet_normalize"), + ) + + self.conv_out = torch.nn.Conv2d(in_channels=2, out_channels=2, kernel_size=1) + + self.reg_param_I = torch.nn.Parameter(torch.ones(self.num_iter)) + self.reg_param_F = torch.nn.Parameter(torch.ones(self.num_iter)) + self.reg_param_C = torch.nn.Parameter(torch.ones(self.num_iter)) + + self.lr_image = torch.nn.Parameter(torch.ones(self.num_iter)) + self.lr_sens = torch.nn.Parameter(torch.ones(self.num_iter)) + + self._coil_dim = 1 + self._spatial_dims = (2, 3) + + self.train_loss_fn = SSIMLoss() if cfg_dict.get("train_loss_fn") == "ssim" else L1Loss() + self.eval_loss_fn = SSIMLoss() if cfg_dict.get("eval_loss_fn") == "ssim" else L1Loss() + self.output_type = cfg_dict.get("output_type") + + self.accumulate_estimates = False + +
[docs] def update_C(self, idx, DC_sens, sensitivity_maps, image, y, mask) -> torch.Tensor: + """ + Update the coil sensitivity maps. + + .. math:: + C = (1 - 2 * \lambda_{k}^{C} * ni_{k}) * C_{k} + + C = 2 * \lambda_{k}^{C} * ni_{k} * D_{C}(F^-1(b)) + + A(x_{k}) = M * F * (C * x_{k}) + + C = 2 * ni_{k} * F^-1(M.T * (M * F * (C * x_{k}) - b)) * x_{k}^* + + Parameters + ---------- + idx: int + The current iteration index. + DC_sens: torch.Tensor [batch_size, num_coils, num_sens_maps, num_rows, num_cols] + The initial coil sensitivity maps. + sensitivity_maps: torch.Tensor [batch_size, num_coils, num_sens_maps, num_rows, num_cols] + The coil sensitivity maps. + image: torch.Tensor [batch_size, num_coils, num_rows, num_cols] + The predicted image. + y: torch.Tensor [batch_size, num_coils, num_rows, num_cols] + The subsampled k-space data. + mask: torch.Tensor [batch_size, 1, num_rows, num_cols] + The subsampled mask. + + Returns + ------- + sensitivity_maps: torch.Tensor [batch_size, num_coils, num_sens_maps, num_rows, num_cols] + The updated coil sensitivity maps. + """ + # (1 - 2 * lambda_{k}^{C} * ni_{k}) * C_{k} + sense_term_1 = (1 - 2 * self.reg_param_C[idx] * self.lr_sens[idx]) * sensitivity_maps + # 2 * lambda_{k}^{C} * ni_{k} * D_{C}(F^-1(b)) + sense_term_2 = 2 * self.reg_param_C[idx] * self.lr_sens[idx] * DC_sens + # A(x_{k}) = M * F * (C * x_{k}) + sense_term_3_A = fft2c(complex_mul(image.unsqueeze(1), sensitivity_maps), fft_type=self.fft_type) + sense_term_3_A = torch.where(mask == 0, torch.tensor([0.0], dtype=y.dtype).to(y.device), sense_term_3_A) + # 2 * ni_{k} * F^-1(M.T * (M * F * (C * x_{k}) - b)) * x_{k}^* + sense_term_3_mask = torch.where( + mask == 1, + torch.tensor([0.0], dtype=y.dtype).to(y.device), + sense_term_3_A - y, + ) + + sense_term_3_backward = ifft2c(sense_term_3_mask, fft_type=self.fft_type) + sense_term_3 = 2 * self.lr_sens[idx] * sense_term_3_backward * complex_conj(image).unsqueeze(1) + sensitivity_maps = sense_term_1 + sense_term_2 - sense_term_3 + return sensitivity_maps
+ +
[docs]    def update_X(self, idx, image, sensitivity_maps, y, mask):
+        """
+        Update the image.
+
+        .. math::
+            x_{k} = (1 - 2 * \lambda_{{k}_{I}} * mi_{k} - 2 * \lambda_{{k}_{F}} * mi_{k}) * x_{k}
+
+            x_{k} = 2 * mi_{k} * (\lambda_{{k}_{I}} * D_I(x_{k}) + \lambda_{{k}_{F}} * F^-1(D_F(f)))
+
+            A(x_{k}) - b = M * F * (C * x_{k}) - b
+
+            x_{k} = 2 * mi_{k} * A^* * (A(x_{k}) - b)
+
+        Parameters
+        ----------
+        idx: int
+            The current iteration index.
+        image: torch.Tensor [batch_size, num_coils, num_rows, num_cols]
+            The predicted image.
+        sensitivity_maps: torch.Tensor [batch_size, num_coils, num_sens_maps, num_rows, num_cols]
+            The coil sensitivity maps.
+        y: torch.Tensor [batch_size, num_coils, num_rows, num_cols]
+            The subsampled k-space data.
+        mask: torch.Tensor [batch_size, 1, num_rows, num_cols]
+            The subsampled mask.
+
+        Returns
+        -------
+        image: torch.Tensor [batch_size, num_coils, num_rows, num_cols]
+            The updated image.
+        """
+        # (1 - 2 * lambda_{k}_{I} * mi_{k} - 2 * lambda_{k}_{F} * mi_{k}) * x_{k}
+        image_term_1 = (
+            1 - 2 * self.reg_param_I[idx] * self.lr_image[idx] - 2 * self.reg_param_F[idx] * self.lr_image[idx]
+        ) * image
+        # D_I(x_{k})
+        image_term_2_DI = self.image_model(image.unsqueeze(1)).squeeze(1).contiguous()
+        # F^-1(D_F(F(x_{k})))
+        image_term_2_DF = ifft2c(
+            self.kspace_model(fft2c(image, fft_type=self.fft_type).unsqueeze(1)).squeeze(1).contiguous(),
+            fft_type=self.fft_type,
+        )
+        # 2 * mi_{k} * (lambda_{k}_{I} * D_I(x_{k}) + lambda_{k}_{F} * F^-1(D_F(f)))
+        image_term_2 = (
+            2
+            * self.lr_image[idx]
+            * (self.reg_param_I[idx] * image_term_2_DI + self.reg_param_F[idx] * image_term_2_DF)
+        )
+        # A(x_{k}) - b = M * F * (C * x_{k}) - b
+        image_term_3_A = fft2c(complex_mul(image.unsqueeze(1), sensitivity_maps), fft_type=self.fft_type)
+        image_term_3_A = torch.where(mask == 0, torch.tensor([0.0], dtype=y.dtype).to(y.device), image_term_3_A) - y
+        # 2 * mi_{k} * A^* * (A(x_{k}) - b)
+        image_term_3_Aconj = complex_mul(
+            ifft2c(image_term_3_A, fft_type=self.fft_type), complex_conj(sensitivity_maps)
+        ).sum(1)
+        image_term_3 = 2 * self.lr_image[idx] * image_term_3_Aconj
+        image = image_term_1 + image_term_2 - image_term_3
+        return image
+ +
[docs] @typecheck() + def forward( + self, + y: torch.Tensor, + sensitivity_maps: torch.Tensor, + mask: torch.Tensor, + init_pred: torch.Tensor, + target: torch.Tensor, + ) -> torch.Tensor: + """ + Forward pass of the network. + + Parameters + ---------- + y: Subsampled k-space data. + torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + sensitivity_maps: Coil sensitivity maps. + torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + mask: Sampling mask. + torch.Tensor, shape [1, 1, n_x, n_y, 1] + init_pred: Initial prediction. + torch.Tensor, shape [batch_size, n_x, n_y, 2] + target: Target data to compute the loss. + torch.Tensor, shape [batch_size, n_x, n_y, 2] + + Returns + ------- + pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or torch.Tensor, shape [batch_size, n_x, n_y, 2] + If self.accumulate_loss is True, returns a list of all intermediate estimates. + If False, returns the final estimate. + """ + DC_sens = self.sens_net(y, mask) + sensitivity_maps = DC_sens.clone() + image = complex_mul(ifft2c(y, fft_type=self.fft_type), complex_conj(sensitivity_maps)).sum(self._coil_dim) + for idx in range(self.num_iter): + sensitivity_maps = self.update_C(idx, DC_sens, sensitivity_maps, image, y, mask) + image = self.update_X(idx, image, sensitivity_maps, y, mask) + image = torch.view_as_complex(image) + _, image = center_crop_to_smallest(target, image) + return image
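Both update_C and update_X apply the same masked data-consistency operator A and its adjoint. A toy sketch with plain torch complex tensors (shapes and data are illustrative only; mridc itself uses a trailing real/imag dimension and its own fft helpers):

import torch

batch, coils, nx, ny = 1, 4, 32, 32
x = torch.randn(batch, nx, ny, dtype=torch.complex64)             # current image estimate
smaps = torch.randn(batch, coils, nx, ny, dtype=torch.complex64)  # sensitivities C
b = torch.randn(batch, coils, nx, ny, dtype=torch.complex64)      # measured k-space
mask = torch.rand(1, 1, nx, ny) > 0.5                             # sampling mask M

forward_op = torch.fft.fft2(x.unsqueeze(1) * smaps, norm="ortho")       # A(x) = F(C * x)
residual = torch.where(mask, forward_op - b, torch.zeros_like(b))       # M * (A(x) - b)
grad_x = (torch.fft.ifft2(residual, norm="ortho") * smaps.conj()).sum(1)  # A^* residual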
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/kikinet.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/kikinet.html
new file mode 100644
index 00000000..7efd9c88
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/kikinet.html
@@ -0,0 +1,291 @@
+mridc.collections.reconstruction.models.kikinet — mridc v.0.0.1 documentation

Source code for mridc.collections.reconstruction.models.kikinet

+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+from abc import ABC
+
+import torch
+from omegaconf import DictConfig, OmegaConf
+from pytorch_lightning import Trainer
+from torch.nn import L1Loss
+
+from mridc.collections.common.losses.ssim import SSIMLoss
+from mridc.collections.common.parts.fft import fft2c, ifft2c
+from mridc.collections.common.parts.utils import complex_conj, complex_mul
+from mridc.collections.reconstruction.models.base import BaseMRIReconstructionModel, BaseSensitivityModel
+from mridc.collections.reconstruction.models.conv.conv2d import Conv2d
+from mridc.collections.reconstruction.models.crossdomain.multicoil import MultiCoil
+from mridc.collections.reconstruction.models.didn.didn import DIDN
+from mridc.collections.reconstruction.models.mwcnn.mwcnn import MWCNN
+from mridc.collections.reconstruction.models.unet_base.unet_block import NormUnet
+from mridc.collections.reconstruction.parts.utils import center_crop_to_smallest
+from mridc.core.classes.common import typecheck
+
+__all__ = ["KIKINet"]
+
+
+
[docs]class KIKINet(BaseMRIReconstructionModel, ABC): + """ + Based on KIKINet implementation [1]. Modified to work with multi-coil k-space data, as presented in Eo, Taejoon, \ + et al. + + References + ---------- + + .. + + Eo, Taejoon, et al. “KIKI-Net: Cross-Domain Convolutional Neural Networks for Reconstructing Undersampled \ + Magnetic Resonance Images.” Magnetic Resonance in Medicine, vol. 80, no. 5, Nov. 2018, pp. 2188–201. PubMed, \ + https://doi.org/10.1002/mrm.27201. + + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + # init superclass + super().__init__(cfg=cfg, trainer=trainer) + + cfg_dict = OmegaConf.to_container(cfg, resolve=True) + + self.num_iter = cfg_dict.get("num_iter") + self.no_dc = cfg_dict.get("no_dc") + + kspace_model_architecture = cfg_dict.get("kspace_model_architecture") + + if kspace_model_architecture == "CONV": + kspace_model = Conv2d( + in_channels=2, + out_channels=2, + hidden_channels=cfg_dict.get("kspace_conv_hidden_channels"), + n_convs=cfg_dict.get("kspace_conv_n_convs"), + batchnorm=cfg_dict.get("kspace_conv_batchnorm"), + ) + elif kspace_model_architecture == "DIDN": + kspace_model = DIDN( + in_channels=2, + out_channels=2, + hidden_channels=cfg_dict.get("kspace_didn_hidden_channels"), + num_dubs=cfg_dict.get("kspace_didn_num_dubs"), + num_convs_recon=cfg_dict.get("kspace_didn_num_convs_recon"), + ) + elif kspace_model_architecture in ["UNET", "NORMUNET"]: + kspace_model = NormUnet( + cfg_dict.get("kspace_unet_num_filters"), + cfg_dict.get("kspace_unet_num_pool_layers"), + in_chans=2, + out_chans=2, + drop_prob=cfg_dict.get("kspace_unet_dropout_probability"), + padding_size=cfg_dict.get("kspace_unet_padding_size"), + normalize=cfg_dict.get("kspace_unet_normalize"), + ) + else: + raise NotImplementedError( + f"KIKINet is currently implemented for kspace_model_architecture == 'CONV' or 'DIDN' or 'UNet'." + f"Got kspace_model_architecture == {kspace_model_architecture}." + ) + + image_model_architecture = cfg_dict.get("imspace_model_architecture") + + if image_model_architecture == "MWCNN": + image_model = MWCNN( + input_channels=2, + first_conv_hidden_channels=cfg_dict.get("image_mwcnn_hidden_channels"), + num_scales=cfg_dict.get("image_mwcnn_num_scales"), + bias=cfg_dict.get("image_mwcnn_bias"), + batchnorm=cfg_dict.get("image_mwcnn_batchnorm"), + ) + elif image_model_architecture in ["UNET", "NORMUNET"]: + image_model = NormUnet( + cfg_dict.get("imspace_unet_num_filters"), + cfg_dict.get("imspace_unet_num_pool_layers"), + in_chans=2, + out_chans=2, + drop_prob=cfg_dict.get("imspace_unet_dropout_probability"), + padding_size=cfg_dict.get("imspace_unet_padding_size"), + normalize=cfg_dict.get("imspace_unet_normalize"), + ) + else: + raise NotImplementedError( + f"KIKINet is currently implemented only with image_model_architecture == 'MWCNN' or 'UNet'." + f"Got {image_model_architecture}." 
+ ) + + self.fft_type = cfg_dict.get("fft_type") + self._coil_dim = 1 + + self.image_model_list = torch.nn.ModuleList([image_model] * self.num_iter) + self.kspace_model_list = torch.nn.ModuleList([MultiCoil(kspace_model, self._coil_dim)] * self.num_iter) + + # Initialize the sensitivity network if use_sens_net is True + self.use_sens_net = cfg_dict.get("use_sens_net") + if self.use_sens_net: + self.sens_net = BaseSensitivityModel( + cfg_dict.get("sens_chans"), + cfg_dict.get("sens_pools"), + fft_type=self.fft_type, + mask_type=cfg_dict.get("sens_mask_type"), + normalize=cfg_dict.get("sens_normalize"), + ) + + self.train_loss_fn = SSIMLoss() if cfg_dict.get("train_loss_fn") == "ssim" else L1Loss() + self.eval_loss_fn = SSIMLoss() if cfg_dict.get("eval_loss_fn") == "ssim" else L1Loss() + self.output_type = cfg_dict.get("output_type") + + self.dc_weight = torch.nn.Parameter(torch.ones(1)) + self.accumulate_estimates = False + +
[docs]    @typecheck()
+    def forward(
+        self,
+        y: torch.Tensor,
+        sensitivity_maps: torch.Tensor,
+        mask: torch.Tensor,
+        init_pred: torch.Tensor,
+        target: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Forward pass of the network.
+
+        Parameters
+        ----------
+        y: Subsampled k-space data.
+            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
+        sensitivity_maps: Coil sensitivity maps.
+            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
+        mask: Sampling mask.
+            torch.Tensor, shape [1, 1, n_x, n_y, 1]
+        init_pred: Initial prediction.
+            torch.Tensor, shape [batch_size, n_x, n_y, 2]
+        target: Target data to compute the loss.
+            torch.Tensor, shape [batch_size, n_x, n_y, 2]
+
+        Returns
+        -------
+        pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or torch.Tensor, shape [batch_size, n_x, n_y, 2]
+            If self.accumulate_loss is True, returns a list of all intermediate estimates.
+            If False, returns the final estimate.
+        """
+        sensitivity_maps = self.sens_net(y, mask) if self.use_sens_net else sensitivity_maps
+        kspace = y.clone()
+        zero = torch.zeros(1, 1, 1, 1, 1).to(kspace)
+
+        for idx in range(self.num_iter):
+            soft_dc = torch.where(mask.bool(), kspace - y, zero) * self.dc_weight
+
+            kspace = self.kspace_model_list[idx](kspace)
+            if kspace.shape[-1] != 2:
+                kspace = kspace.permute(0, 1, 3, 4, 2).to(target)
+                # rebuild the trailing (real, imag) dimension via a complex round trip,
+                # so the tensor has a contiguous layout after the permute above, which
+                # the downstream complex-valued ops require
+                kspace = torch.view_as_real(kspace[..., 0] + 1j * kspace[..., 1])
+
+            image = complex_mul(ifft2c(kspace, fft_type=self.fft_type), complex_conj(sensitivity_maps)).sum(1)
+            image = self.image_model_list[idx](image.unsqueeze(1)).squeeze(1)
+
+            if not self.no_dc:
+                image = fft2c(complex_mul(image.unsqueeze(1), sensitivity_maps), fft_type=self.fft_type).type(
+                    image.type()
+                )
+                image = kspace - soft_dc - image
+                image = complex_mul(ifft2c(image, fft_type=self.fft_type), complex_conj(sensitivity_maps)).sum(1)
+
+            if idx < self.num_iter - 1:
+                kspace = fft2c(complex_mul(image.unsqueeze(1), sensitivity_maps), fft_type=self.fft_type).type(
+                    image.type()
+                )
+
+        image = torch.view_as_complex(image)
+        _, image = center_crop_to_smallest(target, image)
+        return image
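The loop above alternates k-space and image-space corrections (the K-I-K-I pattern). A conceptual skeleton of one iteration with identity functions standing in for the trained CNNs, using plain torch complex tensors and toy data rather than mridc's conventions:

import torch

coils, nx, ny = 4, 32, 32
kspace = torch.randn(1, coils, nx, ny, dtype=torch.complex64)
smaps = torch.randn(1, coils, nx, ny, dtype=torch.complex64)
kspace_net = image_net = lambda t: t  # identity stand-ins for the CNNs

kspace = kspace_net(kspace)                                            # K step
image = (torch.fft.ifft2(kspace, norm="ortho") * smaps.conj()).sum(1)  # to image space
image = image_net(image)                                               # I step
kspace = torch.fft.fft2(image.unsqueeze(1) * smaps, norm="ortho")      # back to k-space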
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/lpd.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/lpd.html
new file mode 100644
index 00000000..1720db32
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/lpd.html
@@ -0,0 +1,302 @@
+mridc.collections.reconstruction.models.lpd — mridc v.0.0.1 documentation

Source code for mridc.collections.reconstruction.models.lpd

+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+from abc import ABC
+
+import torch
+from omegaconf import DictConfig, OmegaConf
+from pytorch_lightning import Trainer
+from torch.nn import L1Loss
+
+from mridc.collections.common.losses.ssim import SSIMLoss
+from mridc.collections.common.parts.fft import fft2c, ifft2c
+from mridc.collections.common.parts.utils import complex_conj, complex_mul
+from mridc.collections.reconstruction.models.base import BaseMRIReconstructionModel, BaseSensitivityModel
+from mridc.collections.reconstruction.models.conv.conv2d import Conv2d
+from mridc.collections.reconstruction.models.didn.didn import DIDN
+from mridc.collections.reconstruction.models.mwcnn.mwcnn import MWCNN
+from mridc.collections.reconstruction.models.primaldual.pd import DualNet, PrimalNet
+from mridc.collections.reconstruction.models.unet_base.unet_block import NormUnet
+from mridc.collections.reconstruction.parts.utils import center_crop_to_smallest
+from mridc.core.classes.common import typecheck
+
+__all__ = ["LPDNet"]
+
+
+
[docs]class LPDNet(BaseMRIReconstructionModel, ABC):
+    """
+    Implementation of the Learned Primal Dual network, inspired by Adler, Jonas, and Ozan Öktem.
+
+    References
+    ----------
+
+    ..
+
+        Adler, Jonas, and Ozan Öktem. “Learned Primal-Dual Reconstruction.” IEEE Transactions on Medical Imaging, \
+        vol. 37, no. 6, June 2018, pp. 1322–32. arXiv.org, https://doi.org/10.1109/TMI.2018.2799231.
+
+    """
+
+    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
+        # init superclass
+        super().__init__(cfg=cfg, trainer=trainer)
+
+        cfg_dict = OmegaConf.to_container(cfg, resolve=True)
+
+        self.num_iter = cfg_dict.get("num_iter")
+        self.num_primal = cfg_dict.get("num_primal")
+        self.num_dual = cfg_dict.get("num_dual")
+
+        primal_model_architecture = cfg_dict.get("primal_model_architecture")
+
+        if primal_model_architecture == "MWCNN":
+            primal_model = torch.nn.Sequential(
+                *[
+                    MWCNN(
+                        input_channels=2 * (self.num_primal + 1),
+                        first_conv_hidden_channels=cfg_dict.get("primal_mwcnn_hidden_channels"),
+                        num_scales=cfg_dict.get("primal_mwcnn_num_scales"),
+                        bias=cfg_dict.get("primal_mwcnn_bias"),
+                        batchnorm=cfg_dict.get("primal_mwcnn_batchnorm"),
+                    ),
+                    torch.nn.Conv2d(2 * (self.num_primal + 1), 2 * self.num_primal, kernel_size=1),
+                ]
+            )
+        elif primal_model_architecture in ["UNET", "NORMUNET"]:
+            primal_model = NormUnet(
+                cfg_dict.get("primal_unet_num_filters"),
+                cfg_dict.get("primal_unet_num_pool_layers"),
+                in_chans=2 * (self.num_primal + 1),
+                out_chans=2 * self.num_primal,
+                drop_prob=cfg_dict.get("primal_unet_dropout_probability"),
+                padding_size=cfg_dict.get("primal_unet_padding_size"),
+                normalize=cfg_dict.get("primal_unet_normalize"),
+            )
+        else:
+            raise NotImplementedError(
+                f"LPDNet is currently implemented for primal_model_architecture == 'MWCNN', 'UNET' or 'NORMUNET'. "
+                f"Got primal_model_architecture == {primal_model_architecture}."
+            )
+
+        dual_model_architecture = cfg_dict.get("dual_model_architecture")
+
+        if dual_model_architecture == "CONV":
+            dual_model = Conv2d(
+                in_channels=2 * (self.num_dual + 2),
+                out_channels=2 * self.num_dual,
+                hidden_channels=cfg_dict.get("kspace_conv_hidden_channels"),
+                n_convs=cfg_dict.get("kspace_conv_n_convs"),
+                batchnorm=cfg_dict.get("kspace_conv_batchnorm"),
+            )
+        elif dual_model_architecture == "DIDN":
+            dual_model = DIDN(
+                in_channels=2 * (self.num_dual + 2),
+                out_channels=2 * self.num_dual,
+                hidden_channels=cfg_dict.get("kspace_didn_hidden_channels"),
+                num_dubs=cfg_dict.get("kspace_didn_num_dubs"),
+                num_convs_recon=cfg_dict.get("kspace_didn_num_convs_recon"),
+            )
+        elif dual_model_architecture in ["UNET", "NORMUNET"]:
+            dual_model = NormUnet(
+                cfg_dict.get("dual_unet_num_filters"),
+                cfg_dict.get("dual_unet_num_pool_layers"),
+                in_chans=2 * (self.num_dual + 2),
+                out_chans=2 * self.num_dual,
+                drop_prob=cfg_dict.get("dual_unet_dropout_probability"),
+                padding_size=cfg_dict.get("dual_unet_padding_size"),
+                normalize=cfg_dict.get("dual_unet_normalize"),
+            )
+        else:
+            raise NotImplementedError(
+                f"LPDNet is currently implemented for dual_model_architecture == 'CONV', 'DIDN', 'UNET' or 'NORMUNET'. "
+                f"Got dual_model_architecture == {dual_model_architecture}."
+            )
+
+        self.primal_net = torch.nn.ModuleList(
+            [PrimalNet(self.num_primal, primal_architecture=primal_model) for _ in range(self.num_iter)]
+        )
+        self.dual_net = torch.nn.ModuleList(
+            [DualNet(self.num_dual, dual_architecture=dual_model) for _ in range(self.num_iter)]
+        )
+
+        self.fft_type = cfg_dict.get("fft_type")
+        self._coil_dim = 1
+
+        # Initialize the sensitivity network if use_sens_net is True
+        self.use_sens_net = cfg_dict.get("use_sens_net")
+        if self.use_sens_net:
+            self.sens_net = BaseSensitivityModel(
+                cfg_dict.get("sens_chans"),
+                cfg_dict.get("sens_pools"),
+                fft_type=self.fft_type,
+                mask_type=cfg_dict.get("sens_mask_type"),
+                normalize=cfg_dict.get("sens_normalize"),
+            )
+
+        self.train_loss_fn = SSIMLoss() if cfg_dict.get("train_loss_fn") == "ssim" else L1Loss()
+        self.eval_loss_fn = SSIMLoss() if cfg_dict.get("eval_loss_fn") == "ssim" else L1Loss()
+        self.output_type = cfg_dict.get("output_type")
+
+        self.accumulate_estimates = False
+
[docs] @typecheck() + def forward( + self, + y: torch.Tensor, + sensitivity_maps: torch.Tensor, + mask: torch.Tensor, + init_pred: torch.Tensor, + target: torch.Tensor, + ) -> torch.Tensor: + """ + Forward pass of the network. + + Parameters + ---------- + y: Subsampled k-space data. + torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + sensitivity_maps: Coil sensitivity maps. + torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + mask: Sampling mask. + torch.Tensor, shape [1, 1, n_x, n_y, 1] + init_pred: Initial prediction. + torch.Tensor, shape [batch_size, n_x, n_y, 2] + target: Target data to compute the loss. + torch.Tensor, shape [batch_size, n_x, n_y, 2] + + Returns + ------- + pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or torch.Tensor, shape [batch_size, n_x, n_y, 2] + If self.accumulate_loss is True, returns a list of all intermediate estimates. + If False, returns the final estimate. + """ + sensitivity_maps = self.sens_net(y, mask) if self.use_sens_net else sensitivity_maps + + input_image = complex_mul( + ifft2c(torch.where(mask == 0, torch.tensor([0.0], dtype=y.dtype).to(y.device), y), fft_type=self.fft_type), + complex_conj(sensitivity_maps), + ).sum(1) + dual_buffer = torch.cat([y] * self.num_dual, -1).to(y.device) + primal_buffer = torch.cat([input_image] * self.num_primal, -1).to(y.device) + + for idx in range(self.num_iter): + # Dual + f_2 = primal_buffer[..., 2:4].clone() + f_2 = torch.where( + mask == 0, + torch.tensor([0.0], dtype=f_2.dtype).to(f_2.device), + fft2c(complex_mul(f_2.unsqueeze(1), sensitivity_maps), fft_type=self.fft_type).type(f_2.type()), + ) + dual_buffer = self.dual_net[idx](dual_buffer, f_2, y) + + # Primal + h_1 = dual_buffer[..., 0:2].clone() + h_1 = complex_mul( + ifft2c( + torch.where(mask == 0, torch.tensor([0.0], dtype=h_1.dtype).to(h_1.device), h_1), + fft_type=self.fft_type, + ), + complex_conj(sensitivity_maps), + ).sum(1) + primal_buffer = self.primal_net[idx](primal_buffer, h_1) + + output = primal_buffer[..., 0:2] + output = (output**2).sum(-1).sqrt() + _, output = center_crop_to_smallest(target, output) + return output
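In forward() above, the num_primal image copies (and num_dual k-space copies) live concatenated along the trailing real/imag axis, so slot i occupies channels [2*i : 2*i + 2]. A minimal sketch of that buffer bookkeeping, with toy shapes:

import torch

num_primal, nx, ny = 5, 32, 32
image = torch.randn(1, nx, ny, 2)                    # one image, last dim = (re, im)
primal_buffer = torch.cat([image] * num_primal, -1)  # shape [1, nx, ny, 2 * num_primal]

f_2 = primal_buffer[..., 2:4]                        # second primal slot, as in the loop
print(primal_buffer.shape, f_2.shape)                # [1, 32, 32, 10] [1, 32, 32, 2]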
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/multidomain/multidomain.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/multidomain/multidomain.html
new file mode 100644
index 00000000..65a6389d
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/multidomain/multidomain.html
@@ -0,0 +1,386 @@
+mridc.collections.reconstruction.models.multidomain.multidomain — mridc v.0.0.1 documentation

Source code for mridc.collections.reconstruction.models.multidomain.multidomain

+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from:https://github.com/NKI-AI/direct/blob/main/direct/nn/multidomainnet/multidomain.py
+# Copyright (c) DIRECT Contributors
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mridc.collections.common.parts.fft import fft2c, ifft2c
+from mridc.collections.common.parts.utils import complex_conj, complex_mul
+
+
+
[docs]class MultiDomainConv2d(nn.Module): + """Multi-domain convolution layer.""" + + def __init__( + self, + fft_type, + in_channels, + out_channels, + **kwargs, + ): + super().__init__() + + self.image_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels // 2, **kwargs) + self.kspace_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels // 2, **kwargs) + self.fft_type = fft_type + self._channels_dim = 1 + self._spatial_dims = [1, 2] + +
[docs] def forward(self, image): + """Forward method for the MultiDomainConv2d class.""" + kspace = [ + fft2c(im, fft_type=self.fft_type, fft_dim=self._spatial_dims) + for im in torch.split(image.permute(0, 2, 3, 1).contiguous(), 2, -1) + ] + kspace = torch.cat(kspace, -1).permute(0, 3, 1, 2) + kspace = self.kspace_conv(kspace) + + backward = [ + ifft2c(ks.float(), fft_type=self.fft_type, fft_dim=self._spatial_dims).type(image.type()) + for ks in torch.split(kspace.permute(0, 2, 3, 1).contiguous(), 2, -1) + ] + backward = torch.cat(backward, -1).permute(0, 3, 1, 2) + + image = self.image_conv(image) + image = torch.cat([image, backward], dim=self._channels_dim) + return image
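The forward method above moves channel pairs between image and k-space domains. A minimal sketch of that split-transform-restack trick, using torch.fft in place of mridc's fft2c helper (toy shapes):

import torch

x = torch.randn(1, 8, 16, 16)                       # (N, C, H, W), C even
pairs = torch.split(x.permute(0, 2, 3, 1).contiguous(), 2, -1)
print(len(pairs), pairs[0].shape)                   # 4 pairs of shape (1, 16, 16, 2)
k = [torch.view_as_real(torch.fft.fft2(torch.view_as_complex(p.contiguous()),
                                       dim=(1, 2), norm="ortho")) for p in pairs]
k = torch.cat(k, -1).permute(0, 3, 1, 2)            # back to (N, C, H, W) k-space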
+ + +
[docs]class MultiDomainConvTranspose2d(nn.Module): + """Multi-Domain convolutional transpose layer.""" + + def __init__( + self, + fft_type, + in_channels, + out_channels, + **kwargs, + ): + super().__init__() + + self.image_conv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels // 2, **kwargs) + self.kspace_conv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels // 2, **kwargs) + self.fft_type = fft_type + self._channels_dim = 1 + self._spatial_dims = [1, 2] + +
[docs] def forward(self, image): + """Forward method for the MultiDomainConvTranspose2d class.""" + kspace = [ + fft2c(im, fft_type=self.fft_type, fft_dim=self._spatial_dims) + for im in torch.split(image.permute(0, 2, 3, 1).contiguous(), 2, -1) + ] + kspace = torch.cat(kspace, -1).permute(0, 3, 1, 2) + kspace = self.kspace_conv(kspace) + + backward = [ + ifft2c(ks.float(), fft_type=self.fft_type, fft_dim=self._spatial_dims).type(image.type()) + for ks in torch.split(kspace.permute(0, 2, 3, 1).contiguous(), 2, -1) + ] + backward = torch.cat(backward, -1).permute(0, 3, 1, 2) + + image = self.image_conv(image) + return torch.cat([image, backward], dim=self._channels_dim)
+ + +
[docs]class MultiDomainConvBlock(nn.Module): + """ + A multi-domain convolutional block that consists of two multi-domain convolution layers each followed by instance + normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, fft_type, in_channels: int, out_channels: int, dropout_probability: float): + """ + Parameters + ---------- + in_channels: Number of input channels. + out_channels: Number of output channels. + dropout_probability: Dropout probability. + """ + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.dropout_probability = dropout_probability + + self.layers = nn.Sequential( + MultiDomainConv2d(fft_type, in_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_channels), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(dropout_probability), + MultiDomainConv2d(fft_type, out_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_channels), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(dropout_probability), + ) + +
[docs] def forward(self, _input: torch.Tensor): + """Forward method for the MultiDomainConvBlock class.""" + return self.layers(_input)
+ + def __repr__(self): + return ( + f"MultiDomainConvBlock(in_channels={self.in_channels}, out_channels={self.out_channels}, " + f"dropout_probability={self.dropout_probability})" + )
+ + +
[docs]class TransposeMultiDomainConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose layers followed by instance + normalization and LeakyReLU activation. + """ + + def __init__(self, fft_type, in_channels: int, out_channels: int): + """ + Parameters + ---------- + in_channels: Number of input channels. + out_channels: Number of output channels. + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.layers = nn.Sequential( + MultiDomainConvTranspose2d(fft_type, in_channels, out_channels, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_channels), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + +
[docs] def forward(self, input_data: torch.Tensor): + """Forward method for the TransposeMultiDomainConvBlock class.""" + return self.layers(input_data)
+
+    def __repr__(self):
+        return f"TransposeMultiDomainConvBlock(in_channels={self.in_channels}, out_channels={self.out_channels})"
+ + +
[docs]class StandardizationLayer(nn.Module):
+    """
+    Multi-channel data standardization method. Inspired by AIRS model submission to the Fast MRI 2020 challenge.
+    Given individual coil images :math:`\{x_i\}_{i=1}^{N_c}` and sensitivity coil maps :math:`\{S_i\}_{i=1}^{N_c}` \
+    it returns
+
+    .. math::
+
+        [(x_{sense}, {x_{res}}_1), ..., (x_{sense}, {x_{res}}_{N_c})]
+
+    where
+
+    :math:`{x_{res}}_i = x_i - S_i \times x_{sense}` and
+
+    :math:`x_{sense} = \sum_{i=1}^{N_c} {S_i}^{*} \times x_i`.
+    """
+
+    def __init__(self, coil_dim=1, channel_dim=-1):
+        super().__init__()
+        self.coil_dim = coil_dim
+        self.channel_dim = channel_dim
+
[docs] def forward(self, coil_images: torch.Tensor, sensitivity_map: torch.Tensor) -> torch.Tensor: + """Forward pass.""" + combined_image = complex_mul(coil_images, complex_conj(sensitivity_map)).sum(self.coil_dim) + residual_image = combined_image.unsqueeze(self.coil_dim) - complex_mul( + combined_image.unsqueeze(self.coil_dim), sensitivity_map + ) + return torch.cat( + [ + torch.cat( + [combined_image, residual_image.select(self.coil_dim, idx)], + self.channel_dim, + ).unsqueeze(self.coil_dim) + for idx in range(coil_images.size(self.coil_dim)) + ], + self.coil_dim, + )
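A numeric sketch of the docstring's standardization formula with plain torch complex tensors (toy data). Note that forward() above subtracts S_i times the combined image from the combined image itself rather than from the per-coil image x_i, so the code differs slightly from the docstring; the sketch follows the docstring:

import torch

coils, nx, ny = 3, 8, 8
x = torch.randn(1, coils, nx, ny, dtype=torch.complex64)   # coil images x_i
s = torch.randn(1, coils, nx, ny, dtype=torch.complex64)   # sensitivity maps S_i

x_sense = (x * s.conj()).sum(dim=1, keepdim=True)          # combined image
residual = x - s * x_sense                                 # per-coil residuals
print(x_sense.shape, residual.shape)                       # [1, 1, 8, 8] [1, 3, 8, 8]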
+ + +
[docs]class MultiDomainUnet2d(nn.Module): + """ + Unet modification to be used with Multi-domain network as in AIRS Medical submission to the Fast MRI 2020 + challenge. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_filters: int, + num_pool_layers: int, + dropout_probability: float, + fft_type: str = "orthogonal", + ): + """ + Parameters + ---------- + in_channels: Number of input channels to the u-net. + out_channels: Number of output channels to the u-net. + num_filters: Number of output channels of the first convolutional layer. + num_pool_layers: Number of down-sampling and up-sampling layers (depth). + dropout_probability: Dropout probability. + fft_type: FFT type. + """ + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.num_filters = num_filters + self.num_pool_layers = num_pool_layers + self.dropout_probability = dropout_probability + self.fft_type = fft_type + + self.down_sample_layers = nn.ModuleList( + [MultiDomainConvBlock(fft_type, in_channels, num_filters, dropout_probability)] + ) + ch = num_filters + for _ in range(num_pool_layers - 1): + self.down_sample_layers += [MultiDomainConvBlock(fft_type, ch, ch * 2, dropout_probability)] + ch *= 2 + self.conv = MultiDomainConvBlock(fft_type, ch, ch * 2, dropout_probability) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv += [TransposeMultiDomainConvBlock(fft_type, ch * 2, ch)] + self.up_conv += [MultiDomainConvBlock(fft_type, ch * 2, ch, dropout_probability)] + ch //= 2 + + self.up_transpose_conv += [TransposeMultiDomainConvBlock(fft_type, ch * 2, ch)] + self.up_conv += [ + nn.Sequential( + MultiDomainConvBlock(fft_type, ch * 2, ch, dropout_probability), + nn.Conv2d(ch, self.out_channels, kernel_size=1, stride=1), + ) + ] + +
[docs] def forward(self, input_data: torch.Tensor): + """Forward pass of the u-net.""" + stack = [] + output = input_data + + # Apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + output = self.conv(output) + + # Apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # Reflect pad on the right/bottom if needed to handle odd input dimensions. + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # Padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # Padding bottom + if sum(padding) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/multidomainnet.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/multidomainnet.html
new file mode 100644
index 00000000..5d4fe258
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/multidomainnet.html
@@ -0,0 +1,225 @@
+mridc.collections.reconstruction.models.multidomainnet — mridc v.0.0.1 documentation

Source code for mridc.collections.reconstruction.models.multidomainnet

+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+from abc import ABC
+
+import torch
+from omegaconf import DictConfig, OmegaConf
+from pytorch_lightning import Trainer
+from torch.nn import L1Loss
+
+from mridc.collections.common.losses.ssim import SSIMLoss
+from mridc.collections.common.parts.fft import ifft2c
+from mridc.collections.common.parts.utils import coil_combination
+from mridc.collections.reconstruction.models.base import BaseMRIReconstructionModel, BaseSensitivityModel
+from mridc.collections.reconstruction.models.multidomain.multidomain import MultiDomainUnet2d, StandardizationLayer
+from mridc.collections.reconstruction.parts.utils import center_crop_to_smallest
+from mridc.core.classes.common import typecheck
+
+__all__ = ["MultiDomainNet"]
+
+
+
[docs]class MultiDomainNet(BaseMRIReconstructionModel, ABC): + """Feature-level multi-domain module. Inspired by AIRS Medical submission to the FastMRI 2020 challenge.""" + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + # init superclass + super().__init__(cfg=cfg, trainer=trainer) + + self._coil_dim = 1 + self._complex_dim = -1 + + cfg_dict = OmegaConf.to_container(cfg, resolve=True) + standardization = cfg_dict["standardization"] + if standardization: + self.standardization = StandardizationLayer(self._coil_dim, self._complex_dim) + + self.fft_type = cfg_dict.get("fft_type") + + self.unet = MultiDomainUnet2d( + in_channels=4 if standardization else 2, # if standardization, in_channels is 4 due to standardized input + out_channels=2, + num_filters=cfg_dict["num_filters"], + num_pool_layers=cfg_dict["num_pool_layers"], + dropout_probability=cfg_dict["dropout_probability"], + fft_type=self.fft_type, + ) + + # Initialize the sensitivity network if use_sens_net is True + self.use_sens_net = cfg_dict.get("use_sens_net") + if self.use_sens_net: + self.sens_net = BaseSensitivityModel( + cfg_dict.get("sens_chans"), + cfg_dict.get("sens_pools"), + fft_type=self.fft_type, + mask_type=cfg_dict.get("sens_mask_type"), + normalize=cfg_dict.get("sens_normalize"), + ) + + self.train_loss_fn = SSIMLoss() if cfg_dict.get("train_loss_fn") == "ssim" else L1Loss() + self.eval_loss_fn = SSIMLoss() if cfg_dict.get("eval_loss_fn") == "ssim" else L1Loss() + self.output_type = cfg_dict.get("output_type") + + self.accumulate_estimates = False + + def _compute_model_per_coil(self, model, data): + """ + Compute the model per coil. + + Parameters + ---------- + model: torch.nn.Module + The model to be computed. + data: torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + The data to be computed. + + Returns + ------- + torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + The computed output. + """ + output = [] + for idx in range(data.size(self._coil_dim)): + subselected_data = data.select(self._coil_dim, idx) + output.append(model(subselected_data)) + output = torch.stack(output, dim=self._coil_dim) + return output + +
[docs] @typecheck() + def forward( + self, + y: torch.Tensor, + sensitivity_maps: torch.Tensor, + mask: torch.Tensor, + init_pred: torch.Tensor, + target: torch.Tensor, + ) -> torch.Tensor: + """ + Forward pass of the network. + + Parameters + ---------- + y: Subsampled k-space data. + torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + sensitivity_maps: Coil sensitivity maps. + torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + mask: Sampling mask. + torch.Tensor, shape [1, 1, n_x, n_y, 1] + init_pred: Initial prediction. + torch.Tensor, shape [batch_size, n_x, n_y, 2] + target: Target data to compute the loss. + torch.Tensor, shape [batch_size, n_x, n_y, 2] + + Returns + ------- + pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or torch.Tensor, shape [batch_size, n_x, n_y, 2] + If self.accumulate_loss is True, returns a list of all intermediate estimates. + If False, returns the final estimate. + """ + sensitivity_maps = self.sens_net(y, mask) if self.use_sens_net else sensitivity_maps + image = ifft2c(y, fft_type=self.fft_type) + + if hasattr(self, "standardization"): + image = self.standardization(image, sensitivity_maps) + + output_image = self._compute_model_per_coil(self.unet, image.permute(0, 1, 4, 2, 3)).permute(0, 1, 3, 4, 2) + output_image = coil_combination(output_image, sensitivity_maps, method=self.output_type, dim=1) + output_image = torch.view_as_complex(output_image) + _, output_image = center_crop_to_smallest(target, output_image) + return output_image
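A minimal sketch of the per-coil application pattern in _compute_model_per_coil above: the coil axis is iterated with select() and the outputs re-stacked (toy module and shapes):

import torch

model = torch.nn.Conv2d(2, 2, kernel_size=3, padding=1)
data = torch.randn(1, 5, 2, 16, 16)  # (batch, coils, channels, H, W)
out = torch.stack([model(data.select(1, c)) for c in range(data.size(1))], dim=1)
print(out.shape)                     # torch.Size([1, 5, 2, 16, 16])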
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/mwcnn/mwcnn.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/mwcnn/mwcnn.html
new file mode 100644
index 00000000..382a6d64
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/mwcnn/mwcnn.html
@@ -0,0 +1,581 @@
+mridc.collections.reconstruction.models.mwcnn.mwcnn — mridc v.0.0.1 documentation

Source code for mridc.collections.reconstruction.models.mwcnn.mwcnn

+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NKI-AI/direct/blob/main/direct/nn/mwcnn/mwcnn.py
+# Copyright (c) DIRECT Contributors
+
+from collections import OrderedDict
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+
[docs]class DWT(nn.Module): + """ + 2D Discrete Wavelet Transform as implemented in Liu, Pengju, et al. + + References + ---------- + + .. + + Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. \ + arXiv.org, http://arxiv.org/abs/1805.07071. + + """ + + def __init__(self): + """Inits DWT.""" + super().__init__() + self.requires_grad = False + +
[docs] @staticmethod + def forward(x: torch.Tensor) -> torch.Tensor: + """ + Computes DWT(`x`) given tensor `x`. + + Parameters + ---------- + x: Input tensor. + + Returns + ------- + DWT of `x`. + """ + x01 = x[:, :, 0::2, :] / 2 + x02 = x[:, :, 1::2, :] / 2 + x1 = x01[:, :, :, 0::2] + x2 = x02[:, :, :, 0::2] + x3 = x01[:, :, :, 1::2] + x4 = x02[:, :, :, 1::2] + x_LL = x1 + x2 + x3 + x4 + x_HL = -x1 - x2 + x3 + x4 + x_LH = -x1 + x2 - x3 + x4 + x_HH = x1 - x2 - x3 + x4 + + return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)
+ + +
[docs]class IWT(nn.Module): + """ + 2D Inverse Wavelet Transform as implemented in Liu, Pengju, et al. + + References + ---------- + + .. + + Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. \ + arXiv.org, http://arxiv.org/abs/1805.07071. + + """ + + def __init__(self): + """Inits IWT.""" + super().__init__() + self.requires_grad = False + self._r = 2 + +
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Computes IWT(`x`) given tensor `x`. + + Parameters + ---------- + x: Input tensor. + + Returns + ------- + IWT of `x`. + """ + batch, in_channel, in_height, in_width = x.size() + out_channel, out_height, out_width = int(in_channel / (self._r**2)), self._r * in_height, self._r * in_width + + x1 = x[:, 0:out_channel, :, :] / 2 + x2 = x[:, out_channel : out_channel * 2, :, :] / 2 + x3 = x[:, out_channel * 2 : out_channel * 3, :, :] / 2 + x4 = x[:, out_channel * 3 : out_channel * 4, :, :] / 2 + + h = torch.zeros([batch, out_channel, out_height, out_width], dtype=x.dtype).to(x.device) + + h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4 + h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4 + h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4 + h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4 + + return h
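A standalone sanity check that the Haar subband arithmetic of DWT and IWT above is a perfect round trip; the algebra is re-implemented inline so the snippet runs without mridc:

import torch

x = torch.randn(2, 3, 8, 8)
x01, x02 = x[:, :, 0::2, :] / 2, x[:, :, 1::2, :] / 2
x1, x2 = x01[:, :, :, 0::2], x02[:, :, :, 0::2]
x3, x4 = x01[:, :, :, 1::2], x02[:, :, :, 1::2]
subbands = torch.cat((x1 + x2 + x3 + x4,      # LL
                      -x1 - x2 + x3 + x4,     # HL
                      -x1 + x2 - x3 + x4,     # LH
                      x1 - x2 - x3 + x4), 1)  # HH

c = x.shape[1]
s1, s2, s3, s4 = [subbands[:, i * c:(i + 1) * c] / 2 for i in range(4)]
recon = torch.zeros_like(x)
recon[:, :, 0::2, 0::2] = s1 - s2 - s3 + s4
recon[:, :, 1::2, 0::2] = s1 - s2 + s3 - s4
recon[:, :, 0::2, 1::2] = s1 + s2 - s3 - s4
recon[:, :, 1::2, 1::2] = s1 + s2 + s3 + s4
assert torch.allclose(recon, x, atol=1e-6)    # exact inverse, up to float rounding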
+ + +
[docs]class ConvBlock(nn.Module): + """ + Convolution Block for MWCNN as implemented in Liu, Pengju, et al. + + References + ---------- + + .. + + Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. \ + arXiv.org, http://arxiv.org/abs/1805.07071. + + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + bias: bool = True, + batchnorm: bool = False, + activation: nn.Module = nn.ReLU(True), + scale: Optional[float] = 1.0, + ): + """ + Inits ConvBlock. + + Parameters + ---------- + in_channels: Number of input channels. + int + out_channels: Number of output channels. + int + kernel_size: Conv kernel size. + int + bias: Use convolution bias. + bool, Default: True. + batchnorm: Use batch normalization. + bool, Default: False. + activation: Activation function. + torch.nn.Module, Default: nn.ReLU(True). + scale: Scale factor for convolution. + float (optional), Default: 1.0. + """ + super().__init__() + + net = [ + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + bias=bias, + padding=kernel_size // 2, + ) + ] + + if batchnorm: + net.append(nn.BatchNorm2d(num_features=out_channels, eps=1e-4, momentum=0.95)) + net.append(activation) + + self.net = nn.Sequential(*net) + self.scale = scale + +
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass of ConvBlock. + + Parameters + ---------- + x: Input with shape (N, C, H, W). + + Returns + ------- + Output with shape (N, C', H', W'). + """ + return self.net(x) * self.scale
+ + +
[docs]class DilatedConvBlock(nn.Module):
+    """
+    Double dilated Convolution Block for MWCNN as implemented in Liu, Pengju, et al.
+
+    References
+    ----------
+
+    ..
+
+        Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. \
+        arXiv.org, http://arxiv.org/abs/1805.07071.
+
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        dilations: Tuple[int, int],
+        kernel_size: int,
+        out_channels: Optional[int] = None,
+        bias: bool = True,
+        batchnorm: bool = False,
+        activation: nn.Module = nn.ReLU(True),
+        scale: Optional[float] = 1.0,
+    ):
+        """
+        Inits DilatedConvBlock.
+
+        Parameters
+        ----------
+        in_channels: Number of input channels.
+            int
+        dilations: Number of dilations.
+            Tuple[int, int], Default: (1, 1).
+        kernel_size: Conv kernel size.
+            int
+        out_channels: Number of output channels.
+            int (optional), Default: None.
+        bias: Use convolution bias.
+            bool, Default: True.
+        batchnorm: Use batch normalization.
+            bool, Default: False.
+        activation: Activation function.
+            torch.nn.Module, Default: nn.ReLU(True).
+        scale: Scale factor for convolution.
+            float (optional), Default: 1.0.
+        """
+        super().__init__()
+        net = [
+            nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                kernel_size=kernel_size,
+                bias=bias,
+                dilation=dilations[0],
+                padding=kernel_size // 2 + dilations[0] - 1,
+            )
+        ]
+
+        if batchnorm:
+            net.append(nn.BatchNorm2d(num_features=in_channels, eps=1e-4, momentum=0.95))
+        net.append(activation)
+        if out_channels is None:
+            out_channels = in_channels
+        net.append(
+            nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=kernel_size,
+                bias=bias,
+                dilation=dilations[1],
+                padding=kernel_size // 2 + dilations[1] - 1,
+            )
+        )
+        if batchnorm:
+            # the second convolution outputs out_channels features, so normalize those
+            net.append(nn.BatchNorm2d(num_features=out_channels, eps=1e-4, momentum=0.95))
+        net.append(activation)
+
+        self.net = nn.Sequential(*net)
+        self.scale = scale
+
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass of DilatedConvBlock. + + Parameters + ---------- + x: Input with shape (N, C, H, W). + + Returns + ------- + Output with shape (N, C', H', W'). + """ + return self.net(x) * self.scale
+ + +
[docs]class MWCNN(nn.Module): + """ + Multi-level Wavelet CNN (MWCNN) implementation as implemented in Liu, Pengju, et al. + + References + ---------- + + .. + + Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. \ + arXiv.org, http://arxiv.org/abs/1805.07071. + + """ + + def __init__( + self, + input_channels: int, + first_conv_hidden_channels: int, + num_scales: int = 4, + bias: bool = True, + batchnorm: bool = False, + activation: nn.Module = nn.ReLU(True), + ): + """ + Inits MWCNN. + + Parameters + ---------- + input_channels: Input channels dimension. + int + first_conv_hidden_channels: First convolution output channels dimension. + int + num_scales: Number of scales. + int, Default: 4. + bias: Convolution bias. If True, adds a learnable bias to the output. + bool, Default: True. + batchnorm: If True, a batchnorm layer is added after each convolution. + bool, Default: False. + activation: Activation function applied after each convolution. + torch.nn.Module, Default: nn.ReLU(). + """ + super().__init__() + self._kernel_size = 3 + self.DWT = DWT() + self.IWT = IWT() + + self.down = nn.ModuleList() + for idx in range(num_scales): + in_channels = input_channels if idx == 0 else first_conv_hidden_channels * 2 ** (idx + 1) + out_channels = first_conv_hidden_channels * 2**idx + dilations = (2, 1) if idx != num_scales - 1 else (2, 3) + self.down.append( + nn.Sequential( + OrderedDict( + [ + ( + f"convblock{idx}", + ConvBlock( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=self._kernel_size, + bias=bias, + batchnorm=batchnorm, + activation=activation, + ), + ), + ( + f"dilconvblock{idx}", + DilatedConvBlock( + in_channels=out_channels, + dilations=dilations, + kernel_size=self._kernel_size, + bias=bias, + batchnorm=batchnorm, + activation=activation, + ), + ), + ] + ) + ) + ) + self.up = nn.ModuleList() + for idx in range(num_scales)[::-1]: + in_channels = first_conv_hidden_channels * 2**idx + out_channels = input_channels if idx == 0 else first_conv_hidden_channels * 2 ** (idx + 1) + dilations = (2, 1) if idx != num_scales - 1 else (3, 2) + self.up.append( + nn.Sequential( + OrderedDict( + [ + ( + f"invdilconvblock{num_scales - 2 - idx}", + DilatedConvBlock( + in_channels=in_channels, + dilations=dilations, + kernel_size=self._kernel_size, + bias=bias, + batchnorm=batchnorm, + activation=activation, + ), + ), + ( + f"invconvblock{num_scales - 2 - idx}", + ConvBlock( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=self._kernel_size, + bias=bias, + batchnorm=batchnorm, + activation=activation, + ), + ), + ] + ) + ) + ) + self.num_scales = num_scales + +
[docs]    @staticmethod
+    def pad(x):
+        """
+        Reflect-pad the input to even spatial dimensions.
+
+        Parameters
+        ----------
+        x: Input tensor.
+
+        Returns
+        -------
+        Padded tensor.
+        """
+        padding = [0, 0, 0, 0]
+
+        if x.shape[-2] % 2 != 0:
+            padding[3] = 1  # Padding bottom - height
+        if x.shape[-1] % 2 != 0:
+            padding[1] = 1  # Padding right - width
+        if sum(padding) != 0:
+            x = F.pad(x, padding, "reflect")
+        return x
+ +
[docs] @staticmethod + def crop_to_shape(x, shape): + """ + Crop the input to the given shape. + + Parameters + ---------- + x: Input tensor. + shape: Tuple of (height, width). + + Returns + ------- + Cropped tensor. + """ + h, w = x.shape[-2:] + + if h > shape[0]: + x = x[:, :, : shape[0], :] + if w > shape[1]: + x = x[:, :, :, : shape[1]] + return x
+ +
[docs] def forward(self, input_tensor: torch.Tensor, res: bool = False) -> torch.Tensor: + """ + Computes forward pass of MWCNN. + + Parameters + ---------- + input_tensor: Input tensor. + torch.tensor + res: If True, residual connection is applied to the output. + bool, Default: False. + + Returns + ------- + Output tensor. + """ + res_values = [] + x = self.pad(input_tensor.clone()) + for idx in range(self.num_scales): + if idx == 0: + x = self.pad(self.down[idx](x)) + res_values.append(x) + elif idx == self.num_scales - 1: + x = self.down[idx](self.DWT(x)) + else: + x = self.pad(self.down[idx](self.DWT(x))) + res_values.append(x) + + for idx in range(self.num_scales): + if idx != self.num_scales - 1: + x = ( + self.crop_to_shape(self.IWT(self.up[idx](x)), res_values[self.num_scales - 2 - idx].shape[-2:]) + + res_values[self.num_scales - 2 - idx] + ) + else: + x = self.crop_to_shape(self.up[idx](x), input_tensor.shape[-2:]) + if res: + x += input_tensor + return x
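A hypothetical usage sketch, assuming the mridc package is installed (the import path matches this module's location in the repo); an odd input size exercises pad() and crop_to_shape():

import torch
from mridc.collections.reconstruction.models.mwcnn.mwcnn import MWCNN

net = MWCNN(input_channels=2, first_conv_hidden_channels=16, num_scales=4)
x = torch.randn(1, 2, 99, 99)  # odd sizes are padded internally and cropped back
out = net(x, res=True)         # res=True adds the input back onto the output
print(out.shape)               # torch.Size([1, 2, 99, 99])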
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/pics.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/pics.html
new file mode 100644
index 00000000..f21339e8
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/pics.html
@@ -0,0 +1,269 @@
+mridc.collections.reconstruction.models.pics — mridc v.0.0.1 documentation

Source code for mridc.collections.reconstruction.models.pics

+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+from abc import ABC
+from typing import Any, Dict, Tuple, Union
+
+# import bart
+import numpy as np
+import torch
+from omegaconf import DictConfig, OmegaConf
+from pytorch_lightning import Trainer
+
+from mridc.collections.reconstruction.models.base import BaseMRIReconstructionModel, BaseSensitivityModel
+from mridc.collections.reconstruction.parts.utils import center_crop_to_smallest
+from mridc.core.classes.common import typecheck
+
+__all__ = ["PICS"]
+
+
+
[docs]class PICS(BaseMRIReconstructionModel, ABC): + """ + Parallel-Imaging Compressed Sensing (PICS) reconstruction using the BART by Uecker, M. et al. + + References + ---------- + + .. + + Uecker, M. et al. (2015) ‘Berkeley Advanced Reconstruction Toolbox’, Proc. Intl. Soc. Mag. Reson. Med., 23. + + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + # init superclass + super().__init__(cfg=cfg, trainer=trainer) + + cfg_dict = OmegaConf.to_container(cfg, resolve=True) + + self.reg_wt = cfg_dict.get("reg_wt") + self.num_iters = cfg_dict.get("num_iters") + self._device = cfg_dict.get("device") + self.fft_type = cfg_dict.get("fft_type") + + # Initialize the sensitivity network if use_sens_net is True + self.use_sens_net = cfg_dict.get("use_sens_net") + if self.use_sens_net: + self.sens_net = BaseSensitivityModel( + cfg_dict.get("sens_chans"), + cfg_dict.get("sens_pools"), + fft_type=self.fft_type, + mask_type=cfg_dict.get("sens_mask_type"), + normalize=cfg_dict.get("sens_normalize"), + ) + +
[docs] @staticmethod + def process_inputs(y, mask): + """ + Process the inputs to the method. + + Parameters + ---------- + y: Subsampled k-space data. + list of torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + mask: Sampling mask. + list of torch.Tensor, shape [1, 1, n_x, n_y, 1] + + Returns + ------- + y: Subsampled k-space data. + randomly selected y + mask: Sampling mask. + randomly selected mask + r: Random index. + """ + if isinstance(y, list): + r = np.random.randint(len(y)) + y = y[r] + mask = mask[r] + else: + r = 0 + return y, mask, r
+ +
[docs]    @typecheck()
+    def forward(
+        self,
+        y: torch.Tensor,
+        sensitivity_maps: torch.Tensor,
+        mask: torch.Tensor,
+        target: torch.Tensor = None,
+    ) -> Union[list, Any]:
+        """
+        Forward pass of PICS.
+
+        Parameters
+        ----------
+        y: Subsampled k-space data.
+            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
+        sensitivity_maps: Coil sensitivity maps.
+            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
+        mask: Sampling mask.
+            torch.Tensor, shape [1, 1, n_x, n_y, 1]
+        target: Target data to compute the loss.
+            torch.Tensor, shape [batch_size, n_x, n_y, 2]
+
+        Returns
+        -------
+        pred: torch.Tensor, shape [batch_size, n_x, n_y, 2]
+            Predicted data.
+        """
+        sensitivity_maps = self.sens_net(y, mask) if self.use_sens_net else sensitivity_maps
+        # NOTE: the BART call below is commented out; without it (and the bart import
+        # at the top of this module) the prediction remains all zeros.
+        pred = torch.zeros_like(sensitivity_maps)
+        # if "cuda" in str(self._device):
+        #     pred = bart.bart(1, f"pics -d0 -g -S -R W:7:0:{self.reg_wt} -i {self.num_iters}", y, sensitivity_maps)[0]
+        # else:
+        #     pred = bart.bart(1, f"pics -d0 -S -R W:7:0:{self.reg_wt} -i {self.num_iters}", y, sensitivity_maps)[0]
+        _, pred = center_crop_to_smallest(target, pred)
+        return pred
+ +
+    def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Tuple[str, int, torch.Tensor]:
+        """
+        Test step.
+
+        Parameters
+        ----------
+        batch: Batch of data.
+            Dict of torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
+        batch_idx: Batch index.
+            int
+
+        Returns
+        -------
+        name: Name of the volume.
+            str
+        slice_num: Slice number.
+            int
+        pred: Predicted data.
+            torch.Tensor, shape [batch_size, n_x, n_y, 2]
+        """
+        y, sensitivity_maps, mask, _, target, fname, slice_num, _ = batch
+        y, mask, _ = self.process_inputs(y, mask)
+
+        y = torch.view_as_complex(y).permute(0, 2, 3, 1).detach().cpu().numpy()
+
+        if sensitivity_maps is None and not self.sens_net:
+            raise ValueError(
+                "Sensitivity maps are required for PICS. "
+                "Please set use_sens_net to True if precomputed sensitivity maps are not available."
+            )
+
+        sensitivity_maps = torch.view_as_complex(sensitivity_maps)
+        if self.fft_type != "orthogonal":
+            sensitivity_maps = torch.fft.fftshift(sensitivity_maps, dim=(-2, -1))
+        sensitivity_maps = sensitivity_maps.permute(0, 2, 3, 1).detach().cpu().numpy()  # type: ignore
+
+        prediction = torch.from_numpy(self.forward(y, sensitivity_maps, mask, target)).unsqueeze(0)
+        if self.fft_type != "orthogonal":
+            prediction = torch.fft.fftshift(prediction, dim=(-2, -1))
+
+        slice_num = int(slice_num)
+        name = str(fname[0])  # type: ignore
+        key = f"{name}_images_idx_{slice_num}"  # type: ignore
+        output = torch.abs(prediction).detach().cpu()
+        target = torch.abs(target).detach().cpu()
+        output = output / output.max()  # type: ignore
+        target = target / target.max()  # type: ignore
+        error = torch.abs(target - output)
+        self.log_image(f"{key}/target", target)
+        self.log_image(f"{key}/reconstruction", output)
+        self.log_image(f"{key}/error", error)
+
+        return name, slice_num, prediction.detach().cpu().numpy()
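A minimal usage sketch for the PICS model above, assuming only the config keys read in its `__init__` (`reg_wt`, `num_iters`, `device`, `fft_type`, `use_sens_net`); the values are illustrative rather than repository defaults, and any additional keys required by the `BaseMRIReconstructionModel` superclass are omitted, so treat this as a sketch, not a runnable default. Note that the actual `bart.bart` call is commented out in `forward`:

# Illustrative config only; keys mirror the cfg_dict.get(...) calls in __init__.
from omegaconf import OmegaConf

from mridc.collections.reconstruction.models.pics import PICS

cfg = OmegaConf.create(
    {
        "reg_wt": 0.005,  # wavelet regularization weight (illustrative)
        "num_iters": 60,  # number of BART pics iterations (illustrative)
        "device": "cpu",
        "fft_type": "orthogonal",
        "use_sens_net": False,  # assumes precomputed sensitivity maps
    }
)
model = PICS(cfg=cfg)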
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/primaldual/pd.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/primaldual/pd.html
new file mode 100644
index 00000000..64834781
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/primaldual/pd.html
@@ -0,0 +1,208 @@
+mridc.collections.reconstruction.models.primaldual.pd — mridc v.0.0.1 documentation

Source code for mridc.collections.reconstruction.models.primaldual.pd

+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NKI-AI/direct/blob/main/direct/nn/lpd/lpd.py
+# Copyright (c) DIRECT Contributors
+
+import torch
+import torch.nn as nn
+
+
+
+class DualNet(nn.Module):
+    """Dual Network for the Learned Primal Dual Network."""
+
+    def __init__(self, num_dual, **kwargs):
+        """
+        Inits DualNet.
+
+        Parameters
+        ----------
+        num_dual: Number of dual channels for the LPD algorithm.
+        kwargs: Keyword arguments; either a custom `dual_architecture` module or `n_hidden`.
+        """
+        super().__init__()
+
+        if kwargs.get("dual_architecture") is None:
+            n_hidden = kwargs.get("n_hidden")
+            if n_hidden is None:
+                raise ValueError("n_hidden is required for DualNet")
+
+            self.dual_block = nn.Sequential(
+                *[
+                    nn.Conv2d(2 * (num_dual + 2), n_hidden, kernel_size=3, padding=1),
+                    nn.PReLU(),
+                    nn.Conv2d(n_hidden, n_hidden, kernel_size=3, padding=1),
+                    nn.PReLU(),
+                    nn.Conv2d(n_hidden, 2 * num_dual, kernel_size=3, padding=1),
+                ]
+            )
+        else:
+            self.dual_block = kwargs.get("dual_architecture")
+    @staticmethod
+    def compute_model_per_coil(model, data):
+        """
+        Computes model per coil.
+
+        Parameters
+        ----------
+        model: Model to compute.
+        data: Multi-coil input.
+
+        Returns
+        -------
+        Multi-coil output.
+        """
+        output = []
+        for idx in range(data.size(1)):
+            subselected_data = data.select(1, idx)
+            output.append(model(subselected_data))
+        output = torch.stack(output, dim=1)
+        return output
+ +
+    def forward(self, h, forward_f, g):
+        """
+        Forward pass of the dual network.
+
+        Parameters
+        ----------
+        h: Current dual variable.
+        forward_f: Forward-projected primal variable.
+        g: Measured k-space data.
+
+        Returns
+        -------
+        Updated dual variable.
+        """
+        inp = torch.cat([h, forward_f, g], dim=-1).permute(0, 1, 4, 2, 3)
+        return self.compute_model_per_coil(self.dual_block, inp).permute(0, 1, 3, 4, 2)
+ + +
+class PrimalNet(nn.Module):
+    """Primal Network for the Learned Primal Dual Network."""
+
+    def __init__(self, num_primal, **kwargs):
+        """
+        Inits PrimalNet.
+
+        Parameters
+        ----------
+        num_primal: Number of primal channels for the LPD algorithm.
+        kwargs: Keyword arguments; either a custom `primal_architecture` module or `n_hidden`.
+        """
+        super().__init__()
+
+        if kwargs.get("primal_architecture") is None:
+            n_hidden = kwargs.get("n_hidden")
+            if n_hidden is None:
+                raise ValueError("Missing argument n_hidden.")
+            self.primal_block = nn.Sequential(
+                *[
+                    nn.Conv2d(2 * (num_primal + 1), n_hidden, kernel_size=3, padding=1),
+                    nn.PReLU(),
+                    nn.Conv2d(n_hidden, n_hidden, kernel_size=3, padding=1),
+                    nn.PReLU(),
+                    nn.Conv2d(n_hidden, 2 * num_primal, kernel_size=3, padding=1),
+                ]
+            )
+        else:
+            self.primal_block = kwargs.get("primal_architecture")
[docs] def forward(self, f, backward_h): + """ + Forward pass of primal network. + + Parameters + ---------- + f: Forward function. + backward_h: Backward function. + + Returns + ------- + Primal function. + """ + inp = torch.cat([f, backward_h], dim=-1).permute(0, 3, 1, 2) + return self.primal_block(inp).permute(0, 2, 3, 1)
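A quick shape check for the two subnetworks, assuming the tensor layouts implied by the permutes above (DualNet on 5D coil tensors `[batch, coils, height, width, channels]`, PrimalNet on 4D image tensors `[batch, height, width, channels]`); all sizes are arbitrary:

import torch

from mridc.collections.reconstruction.models.primaldual.pd import DualNet, PrimalNet

num_dual, num_primal, n_hidden = 5, 5, 16
dual = DualNet(num_dual, n_hidden=n_hidden)
primal = PrimalNet(num_primal, n_hidden=n_hidden)

B, C, H, W = 1, 4, 32, 32
h = torch.randn(B, C, H, W, 2 * num_dual)  # dual buffer
forward_f = torch.randn(B, C, H, W, 2)  # forward-projected primal
g = torch.randn(B, C, H, W, 2)  # measured k-space
print(dual(h, forward_f, g).shape)  # torch.Size([1, 4, 32, 32, 10])

f = torch.randn(B, H, W, 2 * num_primal)  # primal buffer
backward_h = torch.randn(B, H, W, 2)  # backprojected dual
print(primal(f, backward_h).shape)  # torch.Size([1, 32, 32, 10])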
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/recurrentvarnet/conv2gru.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/recurrentvarnet/conv2gru.html
new file mode 100644
index 00000000..1bcb68ee
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/recurrentvarnet/conv2gru.html
@@ -0,0 +1,259 @@
+mridc.collections.reconstruction.models.recurrentvarnet.conv2gru — mridc v.0.0.1 documentation

Source code for mridc.collections.reconstruction.models.recurrentvarnet.conv2gru

+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NKI-AI/direct/blob/main/direct/nn/recurrent/recurrent.py
+# Copyright (c) DIRECT Contributors
+
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+
+class Conv2dGRU(nn.Module):
+    """2D Convolutional GRU Network."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        out_channels: Optional[int] = None,
+        num_layers: int = 2,
+        gru_kernel_size=1,
+        orthogonal_initialization: bool = True,
+        instance_norm: bool = False,
+        dense_connect: int = 0,
+        replication_padding: bool = True,
+    ):
+        """
+        Inits Conv2dGRU.
+
+        Parameters
+        ----------
+        in_channels: Number of input channels.
+            int
+        hidden_channels: Number of hidden channels.
+            int
+        out_channels: Number of output channels. If None, same as in_channels.
+            int (optional), Default: None.
+        num_layers: Number of layers.
+            int, Default: 2.
+        gru_kernel_size: Size of the GRU kernel.
+            int, Default: 1.
+        orthogonal_initialization: Orthogonal initialization is used if set to True.
+            bool, Default: True.
+        instance_norm: Instance norm is used if set to True.
+            bool, Default: False.
+        dense_connect: Number of dense connections.
+            int, Default: 0.
+        replication_padding: If set to True, replication padding is applied.
+            bool, Default: True.
+        """
+        super().__init__()
+
+        if out_channels is None:
+            out_channels = in_channels
+
+        self.num_layers = num_layers
+        self.hidden_channels = hidden_channels
+        self.dense_connect = dense_connect
+
+        self.reset_gates = nn.ModuleList([])
+        self.update_gates = nn.ModuleList([])
+        self.out_gates = nn.ModuleList([])
+        self.conv_blocks = nn.ModuleList([])
+
+        # Create convolutional blocks
+        for idx in range(num_layers + 1):
+            in_ch = in_channels if idx == 0 else (1 + min(idx, dense_connect)) * hidden_channels
+            out_ch = hidden_channels if idx < num_layers else out_channels
+            padding = 0 if replication_padding else (2 if idx == 0 else 1)
+            block = []
+            if replication_padding:
+                if idx == 1:
+                    block.append(nn.ReplicationPad2d(2))
+                else:
+                    block.append(nn.ReplicationPad2d(2 if idx == 0 else 1))
+            block.append(
+                nn.Conv2d(
+                    in_channels=in_ch,
+                    out_channels=out_ch,
+                    kernel_size=5 if idx == 0 else 3,
+                    dilation=(2 if idx == 1 else 1),
+                    padding=padding,
+                )
+            )
+            self.conv_blocks.append(nn.Sequential(*block))
+
+        # Create GRU blocks
+        for _ in range(num_layers):
+            for gru_part in [self.reset_gates, self.update_gates, self.out_gates]:
+                block = []
+                if instance_norm:
+                    block.append(nn.InstanceNorm2d(2 * hidden_channels))
+                block.append(
+                    nn.Conv2d(
+                        in_channels=2 * hidden_channels,
+                        out_channels=hidden_channels,
+                        kernel_size=gru_kernel_size,
+                        padding=gru_kernel_size // 2,
+                    )
+                )
+                gru_part.append(nn.Sequential(*block))
+
+        if orthogonal_initialization:
+            for reset_gate, update_gate, out_gate in zip(self.reset_gates, self.update_gates, self.out_gates):
+                nn.init.orthogonal_(reset_gate[-1].weight)
+                nn.init.orthogonal_(update_gate[-1].weight)
+                nn.init.orthogonal_(out_gate[-1].weight)
+                nn.init.constant_(reset_gate[-1].bias, -1.0)
+                nn.init.constant_(update_gate[-1].bias, 0.0)
+                nn.init.constant_(out_gate[-1].bias, 0.0)
+    def forward(
+        self,
+        cell_input: torch.Tensor,
+        previous_state: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Computes Conv2dGRU forward pass given tensors `cell_input` and `previous_state`.
+
+        Parameters
+        ----------
+        cell_input: Reconstruction input.
+        previous_state: Tensor of previous states.
+
+        Returns
+        -------
+        Output and new states.
+        """
+        new_states: List[torch.Tensor] = []
+        conv_skip: List[torch.Tensor] = []
+
+        if previous_state is None:
+            batch_size, spatial_size = cell_input.size(0), (cell_input.size(2), cell_input.size(3))
+            state_size = [batch_size, self.hidden_channels] + list(spatial_size) + [self.num_layers]
+            previous_state = torch.zeros(*state_size, dtype=cell_input.dtype).to(cell_input.device)
+
+        for idx in range(self.num_layers):
+            if len(conv_skip) > 0:
+                cell_input = F.relu(
+                    self.conv_blocks[idx](torch.cat([*conv_skip[-self.dense_connect :], cell_input], dim=1)),
+                    inplace=True,
+                )
+            else:
+                cell_input = F.relu(self.conv_blocks[idx](cell_input), inplace=True)
+            if self.dense_connect > 0:
+                conv_skip.append(cell_input)
+
+            stacked_inputs = torch.cat([cell_input, previous_state[:, :, :, :, idx]], dim=1)
+
+            update = torch.sigmoid(self.update_gates[idx](stacked_inputs))
+            reset = torch.sigmoid(self.reset_gates[idx](stacked_inputs))
+            delta = torch.tanh(
+                self.out_gates[idx](torch.cat([cell_input, previous_state[:, :, :, :, idx] * reset], dim=1))
+            )
+            cell_input = previous_state[:, :, :, :, idx] * (1 - update) + delta * update
+            new_states.append(cell_input)
+            cell_input = F.relu(cell_input, inplace=False)
+        if len(conv_skip) > 0:
+            out = self.conv_blocks[self.num_layers](torch.cat([*conv_skip[-self.dense_connect :], cell_input], dim=1))
+        else:
+            out = self.conv_blocks[self.num_layers](cell_input)
+
+        return out, torch.stack(new_states, dim=-1)
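A small smoke test for Conv2dGRU under stated assumptions: channels-first inputs `[batch, channels, height, width]`, and `previous_state=None` on the first call (it is then created internally, as handled in `forward`); sizes are illustrative:

import torch

from mridc.collections.reconstruction.models.recurrentvarnet.conv2gru import Conv2dGRU

gru = Conv2dGRU(in_channels=2, hidden_channels=8, num_layers=2)
x = torch.randn(1, 2, 32, 32)  # channels-first input
out, state = gru(x, None)  # first call: hidden state is created internally
out, state = gru(out, state)  # later calls reuse the stacked hidden state
print(out.shape, state.shape)  # torch.Size([1, 2, 32, 32]) torch.Size([1, 8, 32, 32, 2])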
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/recurrentvarnet/recurentvarnet.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/recurrentvarnet/recurentvarnet.html
new file mode 100644
index 00000000..83e83841
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/recurrentvarnet/recurentvarnet.html
@@ -0,0 +1,317 @@
+mridc.collections.reconstruction.models.recurrentvarnet.recurentvarnet — mridc v.0.0.1 documentation

Source code for mridc.collections.reconstruction.models.recurrentvarnet.recurentvarnet

+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NKI-AI/direct/blob/main/direct/nn/recurrentvarnet/recurrentvarnet.py
+# Copyright (c) DIRECT Contributors
+
+from typing import Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mridc.collections.common.parts.fft import fft2c, ifft2c
+from mridc.collections.common.parts.utils import complex_conj, complex_mul
+from mridc.collections.reconstruction.models.recurrentvarnet.conv2gru import Conv2dGRU
+
+
+
[docs]class RecurrentInit(nn.Module): + """ + Recurrent State Initializer (RSI) module of Recurrent Variational Network as presented in Yiasemis, George, et al. + The RSI module learns to initialize the recurrent hidden state :math:`h_0`, input of the first + RecurrentVarNetBlock of the RecurrentVarNet. + + References + ---------- + + .. + + Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to \ + the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org, \ + http://arxiv.org/abs/2111.09639. + + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + channels: Tuple[int, ...], + dilations: Tuple[int, ...], + depth: int = 2, + multiscale_depth: int = 1, + ): + """ + Inits RecurrentInit. + + Parameters + ---------- + in_channels: Input channels. + int + out_channels: Number of hidden channels of the recurrent unit of RecurrentVarNet Block. + int + channels: Channels :math:`n_d` in the convolutional layers of initializer. + Tuple[int, ...] + dilations: Dilations :math:`p` of the convolutional layers of the initializer. + Tuple[int, ...] + depth: RecurrentVarNet Block number of layers :math:`n_l`. + int + multiscale_depth: Number of feature layers to aggregate for the output, if 1, multi-scale context aggregation + is disabled. + int + """ + super().__init__() + + self.conv_blocks = nn.ModuleList() + self.out_blocks = nn.ModuleList() + self.depth = depth + self.multiscale_depth = multiscale_depth + tch = in_channels + for (curr_channels, curr_dilations) in zip(channels, dilations): + block = [ + nn.ReplicationPad2d(curr_dilations), + nn.Conv2d(tch, curr_channels, 3, padding=0, dilation=curr_dilations), + ] + tch = curr_channels + self.conv_blocks.append(nn.Sequential(*block)) + tch = np.sum(channels[-multiscale_depth:]) + for _ in range(depth): + block = [nn.Conv2d(tch, out_channels, 1, padding=0)] + self.out_blocks.append(nn.Sequential(*block)) + +
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Computes initialization for recurrent unit given input `x`. + + Parameters + ---------- + x: Initialization for RecurrentInit. + + Returns + ------- + Initial recurrent hidden state from input `x`. + """ + features = [] + for block in self.conv_blocks: + x = F.relu(block(x), inplace=True) + if self.multiscale_depth > 1: + features.append(x) + if self.multiscale_depth > 1: + x = torch.cat(features[-self.multiscale_depth :], dim=1) + output_list = [] + for block in self.out_blocks: + y = F.relu(block(x), inplace=True) + output_list.append(y) + return torch.stack(output_list, dim=-1)
+ + +
[docs]class RecurrentVarNetBlock(nn.Module): + """ + Recurrent Variational Network Block :math:`\mathcal{H}_{\theta_{t}}` as presented in Yiasemis, George, et al. + + + References + ---------- + + .. + + Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to \ + the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org, \ + http://arxiv.org/abs/2111.09639. + + """ + + def __init__( + self, + in_channels: int = 2, + hidden_channels: int = 64, + num_layers: int = 4, + fft_type: str = "orthogonal", + ): + """ + Inits RecurrentVarNetBlock. + + Parameters + ---------- + in_channels: Input channel number. + int, Default is 2 for complex data. + hidden_channels: Hidden channels. + int, Default: 64. + num_layers: Number of layers of :math:`n_l` recurrent unit. + int, Default: 4. + fft_type: FFT type. + str, Default: "orthogonal". + """ + super().__init__() + self.fft_type = fft_type + + self.learning_rate = nn.Parameter(torch.tensor([1.0])) # :math:`\alpha_t` + self.regularizer = Conv2dGRU( + in_channels=in_channels, + hidden_channels=hidden_channels, + num_layers=num_layers, + replication_padding=True, + ) # Recurrent Unit of RecurrentVarNet Block :math:`\mathcal{H}_{\theta_t}` + +
[docs] def forward( + self, + current_kspace: torch.Tensor, + masked_kspace: torch.Tensor, + sampling_mask: torch.Tensor, + sensitivity_map: torch.Tensor, + hidden_state: Union[None, torch.Tensor], + coil_dim: int = 1, + complex_dim: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Computes forward pass of RecurrentVarNetBlock. + + Parameters + ---------- + current_kspace: Current k-space prediction. + torch.Tensor, shape [batch_size, n_coil, height, width, 2] + masked_kspace: Subsampled k-space. + torch.Tensor, shape [batch_size, n_coil, height, width, 2] + sampling_mask: Sampling mask. + torch.Tensor, shape [batch_size, 1, height, width, 1] + sensitivity_map: Coil sensitivities. + torch.Tensor, shape [batch_size, n_coil, height, width, 2] + hidden_state: ConvGRU hidden state. + None or torch.Tensor, shape [batch_size, n_l, height, width, hidden_channels] + coil_dim: Coil dimension. + int, Default: 1. + complex_dim: Complex dimension. + int, Default: -1. + + Returns + ------- + new_kspace: New k-space prediction. + torch.Tensor, shape [batch_size, n_coil, height, width, 2] + hidden_state: Next hidden state. + list of torch.Tensor, shape [batch_size, hidden_channels, height, width, num_layers] + """ + kspace_error = torch.where( + sampling_mask == 0, + torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device), + current_kspace - masked_kspace, + ) + + recurrent_term = torch.cat( + [ + complex_mul(ifft2c(kspace, fft_type=self.fft_type), complex_conj(sensitivity_map)).sum(coil_dim) + for kspace in torch.split(current_kspace, 2, complex_dim) + ], + dim=complex_dim, + ).permute(0, 3, 1, 2) + + recurrent_term, hidden_state = self.regularizer(recurrent_term, hidden_state) # :math:`w_t`, :math:`h_{t+1}` + recurrent_term = recurrent_term.permute(0, 2, 3, 1) + + recurrent_term = torch.cat( + [ + fft2c(complex_mul(image.unsqueeze(coil_dim), sensitivity_map), fft_type=self.fft_type) + for image in torch.split(recurrent_term, 2, complex_dim) + ], + dim=complex_dim, + ) + + new_kspace = current_kspace - self.learning_rate * kspace_error + recurrent_term + + return new_kspace, hidden_state
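One refinement step of a RecurrentVarNetBlock on random data, assuming the `[batch, coil, height, width, 2]` real/imaginary layout used throughout mridc; on the first step the current k-space prediction equals the measured k-space and the hidden state is `None`:

import torch

from mridc.collections.reconstruction.models.recurrentvarnet.recurentvarnet import RecurrentVarNetBlock

block = RecurrentVarNetBlock(in_channels=2, hidden_channels=8, num_layers=2)
B, C, H, W = 1, 4, 32, 32
kspace = torch.randn(B, C, H, W, 2)  # real/imaginary in the last dimension
mask = (torch.rand(B, 1, H, W, 1) > 0.5).float()
sense = torch.randn(B, C, H, W, 2)

# First step: current prediction equals the measured k-space.
new_kspace, hidden = block(kspace, kspace, mask, sense, hidden_state=None)
print(new_kspace.shape, hidden.shape)  # torch.Size([1, 4, 32, 32, 2]) torch.Size([1, 8, 32, 32, 2])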
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/rim/conv_layers.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/rim/conv_layers.html
new file mode 100644
index 00000000..de3b0214
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/rim/conv_layers.html
@@ -0,0 +1,217 @@
+mridc.collections.reconstruction.models.rim.conv_layers — mridc v.0.0.1 documentation

Source code for mridc.collections.reconstruction.models.rim.conv_layers

+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+import torch
+import torch.nn as nn
+
+
+
+class ConvRNNStack(nn.Module):
+    """A stack of convolutional RNNs."""
+
+    def __init__(self, convs, rnn):
+        """
+        Parameters
+        ----------
+        convs: Convolutional layer(s) applied before the RNN.
+        rnn: Recurrent layer.
+        """
+        super(ConvRNNStack, self).__init__()
+        self.convs = convs
+        self.rnn = rnn
+    def forward(self, x, hidden):
+        """
+        Parameters
+        ----------
+        x: [batch_size, seq_len, input_size]
+        hidden: [num_layers * num_directions, batch_size, hidden_size]
+
+        Returns
+        -------
+        output: [batch_size, seq_len, hidden_size]
+        """
+        return self.rnn(self.convs(x), hidden)
+ + +
[docs]class ConvNonlinear(nn.Module): + """A convolutional layer with nonlinearity.""" + + def __init__(self, input_size, features, conv_dim, kernel_size, dilation, bias, nonlinear="relu"): + """ + Initializes the convolutional layer. + + Parameters + ---------- + input_size: number of input channels. + features: number of output channels. + conv_dim: number of dimensions of the convolutional layer. + kernel_size: size of the convolutional kernel. + dilation: dilation of the convolutional kernel. + bias: whether to use bias. + nonlinear: nonlinearity of the convolutional layer. + """ + super(ConvNonlinear, self).__init__() + + self.input_size = input_size + self.features = features + self.kernel_size = kernel_size + self.dilation = dilation + self.bias = bias + self.conv_dim = conv_dim + self.conv_class = self.determine_conv_class(conv_dim) + + if nonlinear is not None and nonlinear.upper() == "RELU": + self.nonlinear = torch.nn.ReLU() + elif nonlinear is None: + self.nonlinear = lambda x: x + else: + raise ValueError("Please specify a proper nonlinearity") + + self.padding = [ + torch.nn.ReplicationPad1d(torch.div(dilation * (kernel_size - 1), 2, rounding_mode="trunc").item()), + torch.nn.ReplicationPad2d(torch.div(dilation * (kernel_size - 1), 2, rounding_mode="trunc").item()), + torch.nn.ReplicationPad3d(torch.div(dilation * (kernel_size - 1), 2, rounding_mode="trunc").item()), + ][conv_dim - 1] + + self.conv_layer = self.conv_class( + in_channels=input_size, + out_channels=features, + kernel_size=kernel_size, + padding=0, + dilation=dilation, + bias=bias, + ) + + self.reset_parameters() + +
[docs] def reset_parameters(self): + """Resets the parameters of the convolutional layer.""" + torch.nn.init.kaiming_normal_(self.conv_layer.weight, nonlinearity="relu") + + if self.conv_layer.bias is not None: + nn.init.zeros_(self.conv_layer.bias)
+ +
[docs] @staticmethod + def determine_conv_class(n_dim): + """Determines the convolutional layer class.""" + if n_dim == 1: + return nn.Conv1d + if n_dim == 2: + return nn.Conv2d + if n_dim == 3: + return nn.Conv3d + raise ValueError(f"Convolution of: {n_dim} dims is not implemented")
+ +
[docs] def extra_repr(self): + """Extra information about the layer.""" + s = "{input_size}, {features}" + if "bias" in self.__dict__ and self.bias is not True: + s += ", bias={bias}" + if "nonlinear" in self.__dict__ and self.nonlinear != "tanh": + s += ", nonlinearity={nonlinear}" + return s.format(**self.__dict__)
+ +
[docs] def check_forward_input(self, _input): + """Checks input for correct size and shape.""" + if _input.size(1) != self.input_size: + raise RuntimeError(f"input has inconsistent input_size: got {_input.size(1)}, expected {self.input_size}")
+ +
[docs] def forward(self, _input): + """Forward pass of the convolutional layer.""" + return self.nonlinear(self.conv_layer(self.padding(_input)))
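A quick check that ConvNonlinear preserves spatial size via its replication padding (padding = dilation * (kernel_size - 1) / 2), on illustrative channels-first input:

import torch

from mridc.collections.reconstruction.models.rim.conv_layers import ConvNonlinear

conv = ConvNonlinear(input_size=4, features=8, conv_dim=2, kernel_size=3, dilation=2, bias=True)
x = torch.randn(1, 4, 32, 32)
print(conv(x).shape)  # torch.Size([1, 8, 32, 32]); replication padding preserves spatial size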
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/rim/rim_block.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/rim/rim_block.html
new file mode 100644
index 00000000..129d80f5
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/rim/rim_block.html
@@ -0,0 +1,290 @@
+mridc.collections.reconstruction.models.rim.rim_block — mridc v.0.0.1 documentation

Source code for mridc.collections.reconstruction.models.rim.rim_block

+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+from typing import Any, Tuple, Union
+
+import torch
+
+from mridc.collections.common.parts.fft import fft2c, ifft2c
+from mridc.collections.common.parts.utils import complex_conj, complex_mul
+from mridc.collections.reconstruction.models.rim.conv_layers import ConvNonlinear, ConvRNNStack
+from mridc.collections.reconstruction.models.rim.rnn_cells import ConvGRUCell, ConvMGUCell, IndRNNCell
+from mridc.collections.reconstruction.models.rim.utils import log_likelihood_gradient
+
+
+
[docs]class RIMBlock(torch.nn.Module): + """RIMBlock is a block of Recurrent Inference Machines (RIMs).""" + + def __init__( + self, + recurrent_layer=None, + conv_filters=None, + conv_kernels=None, + conv_dilations=None, + conv_bias=None, + recurrent_filters=None, + recurrent_kernels=None, + recurrent_dilations=None, + recurrent_bias=None, + depth: int = 2, + time_steps: int = 8, + conv_dim: int = 2, + no_dc: bool = False, + fft_type: str = "orthogonal", + ): + """ + Initialize the RIMBlock. + + Parameters + ---------- + recurrent_layer: Type of recurrent layer. + conv_filters: Number of filters in the convolutional layers. + conv_kernels: Kernel size of the convolutional layers. + conv_dilations: Dilation of the convolutional layers. + conv_bias: Bias of the convolutional layers. + recurrent_filters: Number of filters in the recurrent layers. + recurrent_kernels: Kernel size of the recurrent layers. + recurrent_dilations: Dilation of the recurrent layers. + recurrent_bias: Bias of the recurrent layers. + depth: Number of layers in the block. + time_steps: Number of time steps in the block. + conv_dim: Dimension of the convolutional layers. + no_dc: If True, the DC component is removed from the input. + fft_type: Type of FFT. + """ + super(RIMBlock, self).__init__() + + self.input_size = depth * 2 + self.time_steps = time_steps + + self.layers = torch.nn.ModuleList() + for ( + (conv_features, conv_k_size, conv_dilation, l_conv_bias, nonlinear), + (rnn_features, rnn_k_size, rnn_dilation, rnn_bias, rnn_type), + ) in zip( + zip(conv_filters, conv_kernels, conv_dilations, conv_bias, ["relu", "relu", None]), + zip( + recurrent_filters, + recurrent_kernels, + recurrent_dilations, + recurrent_bias, + [recurrent_layer, recurrent_layer, None], + ), + ): + conv_layer = None + + if conv_features != 0: + conv_layer = ConvNonlinear( + self.input_size, + conv_features, + conv_dim=conv_dim, + kernel_size=conv_k_size, + dilation=conv_dilation, + bias=l_conv_bias, + nonlinear=nonlinear, + ) + self.input_size = conv_features + + if rnn_features != 0 and rnn_type is not None: + if rnn_type.upper() == "GRU": + rnn_type = ConvGRUCell + elif rnn_type.upper() == "MGU": + rnn_type = ConvMGUCell + elif rnn_type.upper() == "INDRNN": + rnn_type = IndRNNCell + else: + raise ValueError("Please specify a proper recurrent layer type.") + + rnn_layer = rnn_type( + self.input_size, + rnn_features, + conv_dim=2, + kernel_size=rnn_k_size, + dilation=rnn_dilation, + bias=rnn_bias, + ) + + self.input_size = rnn_features + + self.layers.append(ConvRNNStack(conv_layer, rnn_layer)) + + self.final_layer = torch.nn.Sequential(conv_layer) + + self.recurrent_filters = recurrent_filters + self.fft_type = fft_type + + self.no_dc = no_dc + + if not self.no_dc: + self.dc_weight = torch.nn.Parameter(torch.ones(1)) + self.zero = torch.zeros(1, 1, 1, 1, 1) + +
[docs] def forward( + self, + pred: torch.Tensor, + masked_kspace: torch.Tensor, + sense: torch.Tensor, + mask: torch.Tensor, + eta: torch.Tensor = None, + hx: torch.Tensor = None, + sigma: float = 1.0, + keep_eta: bool = False, + ) -> Tuple[Any, Union[list, torch.Tensor, None]]: + """ + Forward pass of the RIMBlock. + + Parameters + ---------- + pred: Predicted k-space. + masked_kspace: Subsampled k-space. + sense: Coil sensitivity maps. + mask: Sample mask. + eta: Initial guess for the eta. + hx: Initial guess for the hidden state. + sigma: Noise level. + keep_eta: Whether to keep the eta. + + Returns + ------- + Reconstructed image and hidden states. + """ + if hx is None: + hx = [ + masked_kspace.new_zeros((masked_kspace.size(0), f, *masked_kspace.size()[2:-1])) + for f in self.recurrent_filters + if f != 0 + ] + + if isinstance(pred, list): + pred = pred[-1].detach() + + if eta is None or eta.ndim < 3: + eta = ( + pred + if keep_eta + else torch.sum( + complex_mul(ifft2c(pred, fft_type=self.fft_type), complex_conj(sense)), + 1, + ) + ) + + etas = [] + for _ in range(self.time_steps): + grad_eta = log_likelihood_gradient( + eta, masked_kspace, sense, mask, sigma=sigma, fft_type=self.fft_type + ).contiguous() + + for h, convrnn in enumerate(self.layers): + hx[h] = convrnn(grad_eta, hx[h]) + grad_eta = hx[h] + + eta = eta + self.final_layer(grad_eta).permute(0, 2, 3, 1) + etas.append(eta) + + eta = etas + + if self.no_dc: + return eta, None + + soft_dc = torch.where(mask, pred - masked_kspace, self.zero.to(masked_kspace)) * self.dc_weight + current_kspace = [ + masked_kspace - soft_dc - fft2c(complex_mul(e.unsqueeze(1), sense), fft_type=self.fft_type) for e in eta + ] + + return current_kspace, None
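An instantiation-only sketch for RIMBlock, assuming the per-layer hyperparameter lists are consumed as unpacked in `__init__` above (two conv+RNN stacks followed by a final conv; the trailing zeros terminate the recurrent entries). The values are illustrative, not repository defaults:

from mridc.collections.reconstruction.models.rim.rim_block import RIMBlock

# Three conv entries (two hidden, one final) and two recurrent entries plus a
# terminating zero, matching the ["relu", "relu", None] unpacking in __init__.
block = RIMBlock(
    recurrent_layer="GRU",
    conv_filters=[64, 64, 2],
    conv_kernels=[5, 3, 3],
    conv_dilations=[1, 2, 1],
    conv_bias=[True, True, False],
    recurrent_filters=[64, 64, 0],
    recurrent_kernels=[1, 1, 0],
    recurrent_dilations=[1, 1, 0],
    recurrent_bias=[True, True, False],
    depth=2,
    time_steps=8,
    conv_dim=2,
)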
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/rim/rnn_cells.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/rim/rnn_cells.html
new file mode 100644
index 00000000..ad156782
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/rim/rnn_cells.html
@@ -0,0 +1,465 @@
+mridc.collections.reconstruction.models.rim.rnn_cells — mridc v.0.0.1 documentation

Source code for mridc.collections.reconstruction.models.rim.rnn_cells

+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+import torch
+import torch.nn as nn
+
+
+
[docs]class ConvGRUCellBase(nn.Module): + """ + Base class for Conv Gated Recurrent Unit (GRU) cells. + # TODO: add paper reference + """ + + def __init__(self, input_size, hidden_size, conv_dim, kernel_size, dilation, bias): + super(ConvGRUCellBase, self).__init__() + + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + self.conv_dim = conv_dim + self.conv_class = self.determine_conv_class(conv_dim) + + self.ih = nn.Conv2d( + input_size, + 3 * hidden_size, + kernel_size, + padding=torch.div(dilation * (kernel_size - 1), 2, rounding_mode="trunc").item(), + dilation=dilation, + bias=bias, + ) + self.hh = nn.Conv2d( + hidden_size, + 3 * hidden_size, + kernel_size, + padding=torch.div(dilation * (kernel_size - 1), 2, rounding_mode="trunc").item(), + dilation=dilation, + bias=False, + ) + + self.reset_parameters() + +
[docs] def reset_parameters(self): + """Initialize parameters following the way proposed in the paper.""" + self.ih.weight.data = self.orthotogonalize_weights(self.ih.weight.data) + self.hh.weight.data = self.orthotogonalize_weights(self.hh.weight.data) + + if self.bias is True: + nn.init.zeros_(self.ih.bias)
+ +
[docs] @staticmethod + def orthotogonalize_weights(weights, chunks=1): + """Orthogonalize the weights of a convolutional layer.""" + return torch.cat([nn.init.orthogonal_(w) for w in weights.chunk(chunks, 0)], 0)
+ +
[docs] @staticmethod + def determine_conv_class(n_dim): + """Determine the convolutional class to use.""" + if n_dim == 1: + return nn.Conv1d + if n_dim == 2: + return nn.Conv2d + if n_dim == 3: + return nn.Conv3d + raise NotImplementedError("No convolution of this dimensionality implemented")
+ +
[docs] def extra_repr(self): + """Extra information to be printed when printing the model.""" + s = "{input_size}, {hidden_size}" + if "bias" in self.__dict__ and self.bias is not True: + s += ", bias={bias}" + if "nonlinearity" in self.__dict__ and self.nonlinearity != "tanh": + s += ", nonlinearity={nonlinearity}" + return s.format(**self.__dict__)
+ +
[docs] def check_forward_input(self, _input): + """Check forward input.""" + if _input.size(1) != self.input_size: + raise RuntimeError(f"input has inconsistent input_size: got {_input.size(1)}, expected {self.input_size}")
+ +
[docs] def check_forward_hidden(self, _input, hx, hidden_label=""): + """Check forward hidden.""" + if _input.size(0) != hx.size(0): + raise RuntimeError( + f"Input batch size {_input.size(0)} doesn't match hidden{hidden_label} batch size {hx.size(0)}" + ) + + if hx.size(1) != self.hidden_size: + raise RuntimeError( + f"hidden{hidden_label} has inconsistent hidden_size: got {hx.size(1)}, expected {self.hidden_size}" + )
+ + +
[docs]class ConvGRUCell(ConvGRUCellBase): + """A Convolutional GRU cell.""" + + def __init__(self, input_size, hidden_size, conv_dim, kernel_size, dilation=1, bias=True): + """ + Initialize the ConvGRUCell. + + Parameters + ---------- + input_size: The number of channels in the input. + hidden_size: The number of channels in the hidden state. + conv_dim: The number of dimensions of the convolutional layer. + kernel_size: The size of the convolutional kernel. + dilation: The dilation of the convolutional kernel. + bias: Whether to add a bias. + """ + super(ConvGRUCell, self).__init__(input_size, hidden_size, conv_dim, kernel_size, dilation, bias) + +
[docs] def forward(self, _input, hx): + """Forward pass of the ConvGRUCell.""" + ih = self.ih(_input).chunk(3, 1) + hh = self.hh(hx).chunk(3, 1) + + r = torch.sigmoid(ih[0] + hh[0]) + z = torch.sigmoid(ih[1] + hh[1]) + n = torch.tanh(ih[2] + r * hh[2]) + + hx = n * (1 - z) + z * hx + + return hx
+ + +
[docs]class ConvMGUCellBase(nn.Module): + """ + A base class for a Convolutional Minimal Gated Unit cell. + # TODO: add paper reference + """ + + def __init__(self, input_size, hidden_size, conv_dim, kernel_size, dilation, bias): + """ + Initialize the ConvMGUCellBase. + + Parameters + ---------- + input_size: The number of channels in the input. + hidden_size: The number of channels in the hidden state. + conv_dim: The number of dimensions of the convolutional layer. + kernel_size: The size of the convolutional kernel. + dilation: The dilation of the convolutional kernel. + bias: Whether to add a bias. + """ + super(ConvMGUCellBase, self).__init__() + + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + self.conv_dim = conv_dim + self.conv_class = self.determine_conv_class(conv_dim) + + self.ih = nn.Conv2d( + input_size, + 2 * hidden_size, + kernel_size, + padding=torch.div(dilation * (kernel_size - 1), 2, rounding_mode="trunc").item(), + dilation=dilation, + bias=bias, + ) + self.hh = nn.Conv2d( + hidden_size, + 2 * hidden_size, + kernel_size, + padding=torch.div(dilation * (kernel_size - 1), 2, rounding_mode="trunc").item(), + dilation=dilation, + bias=False, + ) + + self.reset_parameters() + +
[docs] def reset_parameters(self): + """Reset the parameters.""" + self.ih.weight.data = self.orthotogonalize_weights(self.ih.weight.data) + self.hh.weight.data = self.orthotogonalize_weights(self.hh.weight.data) + + nn.init.xavier_uniform_(self.ih.weight, nn.init.calculate_gain("relu")) + nn.init.xavier_uniform_(self.hh.weight) + + if self.bias is True: + nn.init.zeros_(self.ih.bias)
+ +
[docs] @staticmethod + def orthotogonalize_weights(weights, chunks=1): + """Orthogonalize the weights.""" + return torch.cat([nn.init.orthogonal_(w) for w in weights.chunk(chunks, 0)], 0)
+ +
[docs] @staticmethod + def determine_conv_class(n_dim): + """Determine the convolutional class.""" + if n_dim == 1: + return nn.Conv1d + if n_dim == 2: + return nn.Conv2d + if n_dim == 3: + return nn.Conv3d + raise ValueError(f"Convolution of: {n_dim} dims is not implemented")
+ +
[docs] def extra_repr(self): + """Extra information about the ConvMGUCellBase.""" + s = "{input_size}, {hidden_size}" + if "bias" in self.__dict__ and self.bias is not True: + s += ", bias={bias}" + if "nonlinearity" in self.__dict__ and self.nonlinearity != "tanh": + s += ", nonlinearity={nonlinearity}" + return s.format(**self.__dict__)
+ +
[docs] def check_forward_input(self, _input): + """Check the forward input.""" + if _input.size(1) != self.input_size: + raise RuntimeError(f"input has inconsistent input_size: got {_input.size(1)}, expected {self.input_size}")
+ +
[docs] def check_forward_hidden(self, _input, hx, hidden_label=""): + """Check the forward hidden.""" + if _input.size(0) != hx.size(0): + raise RuntimeError( + f"Input batch size {_input.size(0)} doesn't match hidden{hidden_label} batch size {hx.size(0)}" + ) + + if hx.size(1) != self.hidden_size: + raise RuntimeError( + f"hidden{hidden_label} has inconsistent hidden_size: got {hx.size(1)}, expected {self.hidden_size}" + )
+ + +
[docs]class ConvMGUCell(ConvMGUCellBase): + """Convolutional Minimal Gated Unit cell.""" + + def __init__(self, input_size, hidden_size, conv_dim, kernel_size, dilation=1, bias=True): + """ + Initialize the ConvMGUCell. + + Parameters + ---------- + input_size: The input size. + hidden_size: The hidden size. + conv_dim: The convolutional dimension. + kernel_size: The kernel size. + dilation: The dilation. + bias: Whether to use a bias. + """ + super(ConvMGUCell, self).__init__(input_size, hidden_size, conv_dim, kernel_size, dilation, bias) + +
[docs] def forward(self, _input, hx): + """Forward the ConvMGUCell.""" + ih = self.ih(_input).chunk(2, dim=1) + hh = self.hh(hx).chunk(2, dim=1) + + f = torch.sigmoid(ih[0] + hh[0]) + c = torch.tanh(ih[1] + f * hh[1]) + + return c + f * (hx - c)
+ + +
[docs]class IndRNNCellBase(nn.Module): + """ + Base class for Independently RNN cells as presented in [1]_. + + References + ---------- + .. [1] Li, S. et al. (2018) ‘Independently Recurrent Neural Network (IndRNN): Building A Longer and Deeper RNN’, Proceedings of the IEEE Computer Society Conference on Computer Vision and Pattern Recognition, (1), pp. 5457–5466. doi: 10.1109/CVPR.2018.00572. + """ + + def __init__(self, input_size, hidden_size, conv_dim, kernel_size, dilation, bias): + """ + Initialize the IndRNNCellBase. + + Parameters + ---------- + input_size: The input size. + hidden_size: The hidden size. + conv_dim: The convolutional dimension. + kernel_size: The kernel size. + dilation: The dilation. + bias: Whether to use a bias. + """ + super(IndRNNCellBase, self).__init__() + + self.input_size = input_size + self.hidden_size = hidden_size + self.kernel_size = kernel_size + self.bias = bias + self.conv_dim = conv_dim + self.conv_class = self.determine_conv_class(conv_dim) + + self.ih = nn.Conv2d( + input_size, + hidden_size, + kernel_size, + padding=torch.div(dilation * (kernel_size - 1), 2, rounding_mode="trunc").item(), + dilation=dilation, + bias=bias, + ) + self.hh = nn.Parameter( + nn.init.normal_(torch.empty(1, hidden_size, 1, 1), std=1.0 / (hidden_size * (1 + kernel_size**2))) + ) + + self.reset_parameters() + +
[docs] def reset_parameters(self): + """Reset the parameters.""" + self.ih.weight.data = self.orthotogonalize_weights(self.ih.weight.data) + + nn.init.normal_(self.ih.weight, std=1.0 / (self.hidden_size * (1 + self.kernel_size**2))) + + if self.bias is True: + nn.init.zeros_(self.ih.bias)
+ +
[docs] @staticmethod + def orthotogonalize_weights(weights, chunks=1): + """Orthogonalize the weights.""" + return torch.cat([nn.init.orthogonal_(w) for w in weights.chunk(chunks, 0)], 0)
+ +
[docs] @staticmethod + def determine_conv_class(n_dim): + """Determine the convolutional class.""" + if n_dim == 1: + return nn.Conv1d + if n_dim == 2: + return nn.Conv2d + if n_dim == 3: + return nn.Conv3d + raise NotImplementedError("No convolution of this dimensionality implemented")
+ +
[docs] def extra_repr(self): + """Extra information about the module, used for printing.""" + s = "{input_size}, {hidden_size}" + if "bias" in self.__dict__ and self.bias is not True: + s += ", bias={bias}" + if "nonlinearity" in self.__dict__ and self.nonlinearity != "tanh": + s += ", nonlinearity={nonlinearity}" + return s.format(**self.__dict__)
+ +
[docs] def check_forward_input(self, _input): + """Check forward input.""" + if _input.size(1) != self.input_size: + raise RuntimeError(f"input has inconsistent input_size: got {_input.size(1)}, expected {self.input_size}")
+ +
[docs] def check_forward_hidden(self, _input, hx, hidden_label=""): + """Check forward hidden.""" + if _input.size(0) != hx.size(0): + raise RuntimeError( + f"Input batch size {_input.size(0)} doesn't match hidden{hidden_label} batch size {hx.size(0)}" + ) + + if hx.size(1) != self.hidden_size: + raise RuntimeError( + f"hidden{hidden_label} has inconsistent hidden_size: got {hx.size(1)}, expected {self.hidden_size}" + )
+ + +
[docs]class IndRNNCell(IndRNNCellBase): + """Independently Recurrent Neural Network cell.""" + + def __init__(self, input_size, hidden_size, conv_dim, kernel_size, dilation=1, bias=True): + """ + Parameters + ---------- + input_size: The number of expected features in the input. + hidden_size: The number of features in the hidden state. + conv_dim: The dimension of the convolutional layer. + kernel_size: The size of the convolved kernel. + dilation: The spacing between the kernel points. + bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. + """ + super(IndRNNCell, self).__init__(input_size, hidden_size, conv_dim, kernel_size, dilation, bias) + +
[docs] def forward(self, _input, hx): + """Forward propagate the RNN cell.""" + return nn.ReLU()(self.ih(_input) + self.hh * hx)
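The three cells share one call convention, `cell(input, hidden) -> new hidden`; a smoke test on random channels-first feature maps with illustrative sizes:

import torch

from mridc.collections.reconstruction.models.rim.rnn_cells import ConvGRUCell, ConvMGUCell, IndRNNCell

x = torch.randn(1, 4, 32, 32)  # input feature map
h = torch.zeros(1, 8, 32, 32)  # hidden state
for cell_cls in (ConvGRUCell, ConvMGUCell, IndRNNCell):
    cell = cell_cls(input_size=4, hidden_size=8, conv_dim=2, kernel_size=3)
    print(cell_cls.__name__, cell(x, h).shape)  # each returns the new hidden state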
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/rim/utils.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/rim/utils.html
new file mode 100644
index 00000000..c62c5397
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/rim/utils.html
@@ -0,0 +1,145 @@
+mridc.collections.reconstruction.models.rim.utils — mridc v.0.0.1 documentation

Source code for mridc.collections.reconstruction.models.rim.utils

+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+import torch
+
+from mridc.collections.common.parts.fft import fft2c, ifft2c
+
+
+
+def log_likelihood_gradient(
+    eta: torch.Tensor,
+    masked_kspace: torch.Tensor,
+    sense: torch.Tensor,
+    mask: torch.Tensor,
+    sigma: float,
+    fft_type: str = "orthogonal",
+) -> torch.Tensor:
+    """
+    Computes the gradient of the log-likelihood function.
+
+    Parameters
+    ----------
+    eta: Initial guess for the reconstruction.
+    masked_kspace: Subsampled k-space data.
+    sense: Coil sensitivity maps.
+    mask: Sampling mask.
+    sigma: Noise level.
+    fft_type: Type of FFT to use.
+
+    Returns
+    -------
+    Gradient of the log-likelihood function.
+    """
+    eta_real, eta_imag = map(lambda x: torch.unsqueeze(x, 0), eta.chunk(2, -1))
+    sense_real, sense_imag = sense.chunk(2, -1)
+
+    re_se = eta_real * sense_real - eta_imag * sense_imag
+    im_se = eta_real * sense_imag + eta_imag * sense_real
+
+    pred = ifft2c(mask * (fft2c(torch.cat((re_se, im_se), -1), fft_type=fft_type) - masked_kspace), fft_type=fft_type)
+
+    pred_real, pred_imag = pred.chunk(2, -1)
+
+    re_out = torch.sum(pred_real * sense_real + pred_imag * sense_imag, 1) / (sigma**2.0)
+    im_out = torch.sum(pred_imag * sense_real - pred_real * sense_imag, 1) / (sigma**2.0)
+
+    eta_real = eta_real.squeeze(0)
+    eta_imag = eta_imag.squeeze(0)
+
+    return torch.cat((eta_real, eta_imag, re_out, im_out), 0).unsqueeze(0).squeeze(-1)
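A shape walk-through for `log_likelihood_gradient`, assuming batch size 1 and the `[batch, coil, height, width, 2]` k-space layout used elsewhere in mridc; the output stacks the current estimate (2 channels) with its data-term gradient (2 channels):

import torch

from mridc.collections.reconstruction.models.rim.utils import log_likelihood_gradient

H = W = 32
eta = torch.randn(1, H, W, 2)  # current image estimate
masked_kspace = torch.randn(1, 4, H, W, 2)  # 4 coils
sense = torch.randn(1, 4, H, W, 2)
mask = (torch.rand(1, 1, H, W, 1) > 0.5).float()

grad = log_likelihood_gradient(eta, masked_kspace, sense, mask, sigma=1.0)
print(grad.shape)  # torch.Size([1, 4, 32, 32]): eta (2 ch) stacked with its gradient (2 ch)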
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/rvn.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/rvn.html
new file mode 100644
index 00000000..5d3dc585
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/rvn.html
@@ -0,0 +1,289 @@
+mridc.collections.reconstruction.models.rvn — mridc v.0.0.1 documentation

Source code for mridc.collections.reconstruction.models.rvn

+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+import math
+from abc import ABC
+from typing import Optional
+
+import torch
+from omegaconf import DictConfig, OmegaConf
+from pytorch_lightning import Trainer
+from torch.nn import L1Loss
+
+from mridc.collections.common.losses.ssim import SSIMLoss
+from mridc.collections.common.parts.fft import fft2c, ifft2c
+from mridc.collections.common.parts.rnn_utils import rnn_weights_init
+from mridc.collections.common.parts.utils import coil_combination, complex_conj, complex_mul
+from mridc.collections.reconstruction.models.base import BaseMRIReconstructionModel, BaseSensitivityModel
+from mridc.collections.reconstruction.models.recurrentvarnet.recurentvarnet import RecurrentInit, RecurrentVarNetBlock
+from mridc.collections.reconstruction.parts.utils import center_crop_to_smallest
+from mridc.core.classes.common import typecheck
+
+__all__ = ["RecurrentVarNet"]
+
+
+
+class RecurrentVarNet(BaseMRIReconstructionModel, ABC):
+    """
+    Implementation of the Recurrent Variational Network, as presented in Yiasemis, George, et al.
+
+    References
+    ----------
+
+    ..
+
+        Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to \
+        the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org, \
+        http://arxiv.org/abs/2111.09639.
+
+    """
+
+    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
+        # init superclass
+        super().__init__(cfg=cfg, trainer=trainer)
+
+        cfg_dict = OmegaConf.to_container(cfg, resolve=True)
+
+        self.in_channels = cfg_dict.get("in_channels")
+        self.recurrent_hidden_channels = cfg_dict.get("recurrent_hidden_channels")
+        self.recurrent_num_layers = cfg_dict.get("recurrent_num_layers")
+        self.no_parameter_sharing = cfg_dict.get("no_parameter_sharing")
+
+        # make time-steps size divisible by 8 for fast fp16 training
+        self.num_steps = 8 * math.ceil(cfg_dict.get("num_steps") / 8)
+
+        self.learned_initializer = cfg_dict.get("learned_initializer")
+        self.initializer_initialization = cfg_dict.get("initializer_initialization")
+        self.initializer_channels = cfg_dict.get("initializer_channels")
+        self.initializer_dilations = cfg_dict.get("initializer_dilations")
+
+        if (
+            self.learned_initializer
+            and self.initializer_initialization is not None
+            and self.initializer_channels is not None
+            and self.initializer_dilations is not None
+        ):
+            if self.initializer_initialization not in [
+                "sense",
+                "input_image",
+                "zero_filled",
+            ]:
+                raise ValueError(
+                    "Unknown initializer_initialization. Expected `sense`, `input_image` or `zero_filled`."
+                    f"Got {self.initializer_initialization}."
+                )
+            self.initializer = RecurrentInit(
+                self.in_channels,
+                self.recurrent_hidden_channels,
+                channels=self.initializer_channels,
+                dilations=self.initializer_dilations,
+                depth=self.recurrent_num_layers,
+                multiscale_depth=cfg_dict.get("initializer_multiscale"),
+            )
+        else:
+            self.initializer = None  # type: ignore
+
+        self.fft_type = cfg_dict.get("fft_type")
+        self.output_type = cfg_dict.get("output_type")
+
+        # Cascades of RecurrentVarNet blocks
+        self.block_list: torch.nn.Module = torch.nn.ModuleList()
+        for _ in range(self.num_steps if self.no_parameter_sharing else 1):
+            self.block_list.append(
+                RecurrentVarNetBlock(
+                    in_channels=self.in_channels,
+                    hidden_channels=self.recurrent_hidden_channels,
+                    num_layers=self.recurrent_num_layers,
+                    fft_type=self.fft_type,
+                )
+            )
+
+        # Initialize the sensitivity network if use_sens_net is True
+        self.use_sens_net = cfg_dict.get("use_sens_net")
+        if self.use_sens_net:
+            self.sens_net = BaseSensitivityModel(
+                cfg_dict.get("sens_chans"),
+                cfg_dict.get("sens_pools"),
+                fft_type=self.fft_type,
+                mask_type=cfg_dict.get("sens_mask_type"),
+                normalize=cfg_dict.get("sens_normalize"),
+            )
+
+        std_init_range = 1 / self.recurrent_hidden_channels**0.5
+
+        # initialize weights if not using pretrained weights
+        if not cfg_dict.get("pretrained", False):
+            self.block_list.apply(lambda module: rnn_weights_init(module, std_init_range))
+
+        self.train_loss_fn = SSIMLoss() if cfg_dict.get("train_loss_fn") == "ssim" else L1Loss()
+        self.eval_loss_fn = SSIMLoss() if cfg_dict.get("eval_loss_fn") == "ssim" else L1Loss()
+
+        self.accumulate_estimates = False
+    @typecheck()
+    def forward(
+        self,
+        y: torch.Tensor,
+        sensitivity_maps: torch.Tensor,
+        mask: torch.Tensor,
+        init_pred: torch.Tensor,
+        target: torch.Tensor,
+        **kwargs,
+    ) -> torch.Tensor:
+        """
+        Forward pass of the network.
+
+        Parameters
+        ----------
+        y: Subsampled k-space data.
+            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
+        sensitivity_maps: Coil sensitivity maps.
+            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
+        mask: Sampling mask.
+            torch.Tensor, shape [1, 1, n_x, n_y, 1]
+        init_pred: Initial prediction.
+            torch.Tensor, shape [batch_size, n_x, n_y, 2]
+        target: Target data to compute the loss.
+            torch.Tensor, shape [batch_size, n_x, n_y, 2]
+
+        Returns
+        -------
+        pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or torch.Tensor, shape [batch_size, n_x, n_y, 2]
+            If self.accumulate_estimates is True, returns a list of all intermediate estimates.
+            If False, returns the final estimate.
+        """
+        sensitivity_maps = self.sens_net(y, mask) if self.use_sens_net else sensitivity_maps
+
+        previous_state: Optional[torch.Tensor] = None
+
+        if self.initializer is not None:
+            if self.initializer_initialization == "sense":
+                initializer_input_image = (
+                    complex_mul(ifft2c(y, fft_type=self.fft_type), complex_conj(sensitivity_maps)).sum(1).unsqueeze(1)
+                )
+            elif self.initializer_initialization == "input_image":
+                if "initial_image" not in kwargs:
+                    raise ValueError(
+                        "`initial_image` is required as input if initializer_initialization "
+                        f"is {self.initializer_initialization}."
+                    )
+                initializer_input_image = kwargs["initial_image"].unsqueeze(1)
+            elif self.initializer_initialization == "zero_filled":
+                initializer_input_image = ifft2c(y, fft_type=self.fft_type)
+
+            previous_state = self.initializer(
+                fft2c(initializer_input_image, fft_type=self.fft_type).sum(1).permute(0, 3, 1, 2)
+            )
+
+        kspace_prediction = y.clone()
+
+        for step in range(self.num_steps):
+            block = self.block_list[step] if self.no_parameter_sharing else self.block_list[0]
+            kspace_prediction, previous_state = block(
+                kspace_prediction,
+                y,
+                mask,
+                sensitivity_maps,
+                previous_state,
+            )
+
+        eta = ifft2c(kspace_prediction, fft_type=self.fft_type)
+        eta = coil_combination(eta, sensitivity_maps, method=self.output_type, dim=1)
+        eta = torch.view_as_complex(eta)
+        _, eta = center_crop_to_smallest(target, eta)
+        return eta
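An illustrative (not canonical) config for RecurrentVarNet, covering the keys read in its `__init__` above; all values are assumptions for the sketch, and any keys required by the `BaseMRIReconstructionModel` superclass (optimizer, data paths, and so on) are omitted:

from omegaconf import OmegaConf

from mridc.collections.reconstruction.models.rvn import RecurrentVarNet

cfg = OmegaConf.create(
    {
        "in_channels": 2,
        "recurrent_hidden_channels": 64,
        "recurrent_num_layers": 4,
        "no_parameter_sharing": True,
        "num_steps": 8,  # rounded up to a multiple of 8 in __init__
        "learned_initializer": True,
        "initializer_initialization": "sense",
        "initializer_channels": [32, 32, 64, 64],
        "initializer_dilations": [1, 1, 2, 4],
        "initializer_multiscale": 1,
        "fft_type": "orthogonal",
        "output_type": "SENSE",
        "use_sens_net": False,
        "pretrained": False,
        "train_loss_fn": "ssim",
        "eval_loss_fn": "ssim",
    }
)
model = RecurrentVarNet(cfg=cfg)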
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/sigmanet/dc_layers.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/sigmanet/dc_layers.html
new file mode 100644
index 00000000..5ec3b94c
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/sigmanet/dc_layers.html
@@ -0,0 +1,439 @@
+mridc.collections.reconstruction.models.sigmanet.dc_layers — mridc v.0.0.1 documentation

Source code for mridc.collections.reconstruction.models.sigmanet.dc_layers

+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from:
+# https://github.com/khammernik/sigmanet/blob/master/reconstruction/common/mytorch/models/datalayer.py
+
+import torch
+
+from mridc.collections.common.parts.fft import fft2c, ifft2c
+from mridc.collections.common.parts.utils import complex_abs, complex_conj, complex_mul
+
+
+
+class DataIDLayer(torch.nn.Module):
+    """Identity placeholder for the data layer; applies no data-consistency correction."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+ + +
[docs]class DataGDLayer(torch.nn.Module): + """DataLayer computing the gradient on the L2 dataterm.""" + + def __init__(self, lambda_init, learnable=True, fft_type="orthogonal"): + """ + Parameters + ---------- + lambda_init: Init value of data term weight lambda. + learnable: If True, the data term weight lambda is learnable. + fft_type: Type of FFT to use. + """ + super(DataGDLayer, self).__init__() + self.lambda_init = lambda_init + self.data_weight = torch.nn.Parameter(torch.Tensor(1)) + self.data_weight.data = torch.tensor( + lambda_init, + dtype=self.data_weight.dtype, + ) + self.data_weight.requires_grad = learnable + + self.fft_type = fft_type + +
+    def forward(self, x, y, smaps, mask):
+        """
+        Forward pass: one gradient step on the L2 data term.
+
+        Parameters
+        ----------
+        x: Input image.
+        y: Subsampled k-space data.
+        smaps: Coil sensitivity maps.
+        mask: Sampling mask.
+
+        Returns
+        -------
+        Updated image after a weighted gradient step on the data term.
+        """
+        A_x_y = (
+            torch.sum(
+                fft2c(complex_mul(x.unsqueeze(-5).expand_as(smaps), smaps), fft_type=self.fft_type) * mask,
+                -4,
+                keepdim=True,
+            )
+            - y
+        )
+        gradD_x = torch.sum(complex_mul(ifft2c(A_x_y * mask), complex_conj(smaps)), dim=(-5))
+        return x - self.data_weight * gradD_x
+ + +
[docs]class DataProxCGLayer(torch.nn.Module): + """Solving the prox wrt. dataterm using Conjugate Gradient as proposed by Aggarwal et al.""" + + def __init__(self, lambda_init, tol=1e-6, iter=10, learnable=True, fft_type="orthogonal"): + super(DataProxCGLayer, self).__init__() + + self.lambdaa = torch.nn.Parameter(torch.Tensor(1)) + self.lambdaa.data = torch.tensor(lambda_init) + self.lambdaa_init = lambda_init + self.lambdaa.requires_grad = learnable + + self.tol = tol + self.iter = iter + + self.op = ConjugateGradient + self.fft_type = fft_type + +
+    def forward(self, x, f, smaps, mask):
+        """
+        Forward pass: solve the prox of the data term with conjugate gradient.
+
+        Parameters
+        ----------
+        x: Input image.
+        f: Subsampled k-space data.
+        smaps: Coil sensitivity maps.
+        mask: Sampling mask.
+
+        Returns
+        -------
+        Proximal mapping of the input image with respect to the data term.
+        """
+        return self.op.apply(
+            x,
+            self.lambdaa,
+            f,
+            smaps,
+            mask,
+            self.tol,
+            self.iter,
+            self.fft_type,
+        )
+ +
+    def set_learnable(self, flag):
+        """Set whether the regularization weight lambda is learnable."""
+        self.lambdaa.requires_grad = flag
+ + +
+class ConjugateGradient(torch.autograd.Function):
+    """Conjugate Gradient solver for the prox of the data term."""
+
+    @staticmethod
+    def complexDot(data1, data2):
+        """Complex dot product of two tensors."""
+        nBatch = data1.shape[0]
+        mult = complex_mul(data1, complex_conj(data2))
+        re, im = torch.unbind(mult, dim=-1)
+        return torch.stack([torch.sum(re.view(nBatch, -1), dim=-1), torch.sum(im.view(nBatch, -1), dim=-1)], -1)
+
+    @staticmethod
+    def solve(x0, M, tol, max_iter):
+        """Solve the linear system Mx=b using conjugate gradient."""
+        nBatch = x0.shape[0]
+        x = torch.zeros(x0.shape).to(x0.device)
+        r = x0.clone()
+        p = x0.clone()
+        x0x0 = (x0.pow(2)).view(nBatch, -1).sum(-1)
+        rr = torch.stack([(r.pow(2)).view(nBatch, -1).sum(-1), torch.zeros(nBatch).to(x0.device)], dim=-1)
+
+        it = 0
+        while torch.min(rr[..., 0] / x0x0) > tol and it < max_iter:
+            it += 1
+            q = M(p)
+
+            data1 = rr
+            data2 = ConjugateGradient.complexDot(p, q)
+
+            re1, im1 = torch.unbind(data1, -1)
+            re2, im2 = torch.unbind(data2, -1)
+            alpha = torch.stack([re1 * re2 + im1 * im2, im1 * re2 - re1 * im2], -1) / complex_abs(data2) ** 2
+
+            x += complex_mul(alpha.reshape(nBatch, 1, 1, 1, -1), p.clone())
+            r -= complex_mul(alpha.reshape(nBatch, 1, 1, 1, -1), q.clone())
+            rr_new = torch.stack([(r.pow(2)).view(nBatch, -1).sum(-1), torch.zeros(nBatch).to(x0.device)], dim=-1)
+            beta = torch.stack([rr_new[..., 0] / rr[..., 0], torch.zeros(nBatch).to(x0.device)], dim=-1)
+            p = r.clone() + complex_mul(beta.reshape(nBatch, 1, 1, 1, -1), p)
+            rr = rr_new.clone()
+        return x
+
+    @staticmethod
+    def forward(ctx, z, lambdaa, y, smaps, mask, tol, max_iter, fft_type):
+        """
+        Forward pass of the conjugate gradient solver.
+
+        Parameters
+        ----------
+        ctx: Context object.
+        z: Input image.
+        lambdaa: Regularization parameter.
+        y: Subsampled k-space data.
+        smaps: Coil sensitivity maps.
+        mask: Sampling mask.
+        tol: Tolerance for the stopping criterion.
+        max_iter: Maximum number of iterations.
+        fft_type: FFT type.
+
+        Returns
+        -------
+        z: Output image.
+        """
+        ctx.tol = tol
+        ctx.max_iter = max_iter
+        ctx.fft_type = fft_type
+
+        def A(x):
+            x = fft2c(complex_mul(x.expand_as(smaps), smaps), fft_type=fft_type) * mask
+            return torch.sum(x, dim=-4, keepdim=True)
+
+        def AT(x):
+            return torch.sum(complex_mul(ifft2c(x * mask, fft_type=fft_type), complex_conj(smaps)), dim=-5)
+
+        def M(p):
+            return lambdaa * AT(A(p)) + p
+
+        x0 = lambdaa * AT(y) + z
+        ctx.save_for_backward(AT(y), x0, smaps, mask, lambdaa)
+
+        return ConjugateGradient.solve(x0, M, ctx.tol, ctx.max_iter)
+
+    @staticmethod
+    def backward(ctx, grad_x):
+        """
+        Backward pass of the conjugate gradient solver.
+
+        Parameters
+        ----------
+        ctx: Context object.
+        grad_x: Gradient of the output image.
+
+        Returns
+        -------
+        grad_z: Gradient of the input image.
+        """
+        ATy, rhs, smaps, mask, lambdaa = ctx.saved_tensors
+
+        def A(x):
+            x = fft2c(complex_mul(x.expand_as(smaps), smaps), fft_type=ctx.fft_type) * mask
+            return torch.sum(x, dim=-4, keepdim=True)
+
+        def AT(x):
+            return torch.sum(complex_mul(ifft2c(x * mask, fft_type=ctx.fft_type), complex_conj(smaps)), dim=-5)
+
+        def M(p):
+            return lambdaa * AT(A(p)) + p
+
+        Qe = ConjugateGradient.solve(grad_x, M, ctx.tol, ctx.max_iter)
+        QQe = ConjugateGradient.solve(Qe, M, ctx.tol, ctx.max_iter)
+
+        grad_z = Qe
+
+        grad_lambdaa = (
+            complex_mul(ifft2c(Qe, fft_type=ctx.fft_type), complex_conj(ATy)).sum()
+            - complex_mul(ifft2c(QQe, fft_type=ctx.fft_type), complex_conj(rhs)).sum()
+        )
+
+        return grad_z, grad_lambdaa, None, None, None, None, None, None
+
+
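Stripped of the stacked real/imaginary bookkeeping above, `solve` is ordinary conjugate gradient. A self-contained real-valued sketch of the same iteration (all names below are illustrative and independent of mridc):

```python
import torch

def cg_solve(b, M, tol=1e-16, max_iter=100):
    """Plain real-valued conjugate gradient for M(x) = b with M symmetric positive definite."""
    x = torch.zeros_like(b)
    r = b.clone()  # residual, since x starts at zero
    p = b.clone()  # search direction
    rr = r.dot(r)
    bb = b.dot(b)
    for _ in range(max_iter):
        if rr / bb <= tol:
            break
        q = M(p)
        alpha = rr / p.dot(q)
        x = x + alpha * p
        r = r - alpha * q
        rr_new = r.dot(r)
        p = r + (rr_new / rr) * p  # new direction, conjugate to the previous ones
        rr = rr_new
    return x

# Toy SPD system: A = L @ L.T + I.
torch.manual_seed(0)
L = torch.randn(8, 8, dtype=torch.float64)
A = L @ L.T + torch.eye(8, dtype=torch.float64)
b = torch.randn(8, dtype=torch.float64)
x = cg_solve(b, lambda v: A @ v)
print(torch.allclose(A @ x, b))  # True
```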
+class DataVSLayer(torch.nn.Module):
+    """DataLayer using variable splitting formulation."""
+
+    def __init__(self, alpha_init, beta_init, learnable=True, fft_type="orthogonal"):
+        """
+        Parameters
+        ----------
+        alpha_init: Init value of data consistency block (DCB).
+        beta_init: Init value of weighted averaging block (WAB).
+        learnable: If True, the parameters of the model are learnable.
+        fft_type: Type of FFT to use. Can be "orthogonal".
+        """
+        super(DataVSLayer, self).__init__()
+        self.alpha = torch.nn.Parameter(torch.Tensor(1))
+        self.alpha.data = torch.tensor(alpha_init, dtype=self.alpha.dtype)
+
+        self.beta = torch.nn.Parameter(torch.Tensor(1))
+        self.beta.data = torch.tensor(beta_init, dtype=self.beta.dtype)
+
+        self.learnable = learnable
+        self.set_learnable(learnable)
+
+        self.fft_type = fft_type
+
+    def forward(self, x, y, smaps, mask):
+        """
+        Forward pass of the data-consistency block.
+
+        Parameters
+        ----------
+        x: Input image.
+        y: Subsampled k-space data.
+        smaps: Coil sensitivity maps.
+        mask: Sampling mask.
+
+        Returns
+        -------
+        Output image.
+        """
+        A_x = torch.sum(
+            fft2c(complex_mul(x.unsqueeze(-5).expand_as(smaps), smaps), fft_type=self.fft_type), -4, keepdim=True
+        )
+        k_dc = (1 - mask) * A_x + mask * (self.alpha * A_x + (1 - self.alpha) * y)
+        x_dc = torch.sum(complex_mul(ifft2c(k_dc, fft_type=self.fft_type), complex_conj(smaps)), dim=-5)
+        return self.beta * x + (1 - self.beta) * x_dc
+
+    def set_learnable(self, flag):
+        """
+        Set the learnable flag of the parameters.
+
+        Parameters
+        ----------
+        flag: If True, the parameters of the model are learnable.
+        """
+        self.learnable = flag
+        self.alpha.requires_grad = self.learnable
+        self.beta.requires_grad = self.learnable
+
+
+class DCLayer(torch.nn.Module):
+    """Data Consistency layer from DC-CNN, applied mainly to single-coil data."""
+
+    def __init__(self, lambda_init=0.0, learnable=True, fft_type="orthogonal"):
+        """
+        Parameters
+        ----------
+        lambda_init: Init value of data consistency block (DCB).
+        learnable: If True, the parameters of the model are learnable.
+        fft_type: Type of FFT to use. Can be "orthogonal".
+        """
+        super(DCLayer, self).__init__()
+        self.lambda_ = torch.nn.Parameter(torch.Tensor(1))
+        self.lambda_.data = torch.tensor(lambda_init, dtype=self.lambda_.dtype)
+
+        self.learnable = learnable
+        self.set_learnable(learnable)
+
+        self.fft_type = fft_type
+
+    def forward(self, x, y, mask):
+        """
+        Forward pass of the data-consistency block.
+
+        Parameters
+        ----------
+        x: Input image.
+        y: Subsampled k-space data.
+        mask: Sampling mask.
+
+        Returns
+        -------
+        Output image.
+        """
+        A_x = fft2c(x, fft_type=self.fft_type)
+        k_dc = (1 - mask) * A_x + mask * (self.lambda_ * A_x + (1 - self.lambda_) * y)
+        return ifft2c(k_dc, fft_type=self.fft_type)
+
+    def set_learnable(self, flag):
+        """
+        Set the learnable flag of the parameters.
+
+        Parameters
+        ----------
+        flag: If True, the parameters of the model are learnable.
+        """
+        self.learnable = flag
+        self.lambda_.requires_grad = self.learnable
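The k-space replacement rule used by `DCLayer` is easy to verify numerically with plain `torch.fft` (illustrative names, not mridc API); with `lambda_ = 0` the sampled lines are replaced exactly by the measurements:

```python
import torch

# Single-coil soft data consistency: keep predicted k-space where the mask is 0,
# and blend prediction with measured data where the mask is 1.
def data_consistency(x, y, mask, lam=0.0):
    A_x = torch.fft.fft2(x, norm="ortho")
    k_dc = (1 - mask) * A_x + mask * (lam * A_x + (1 - lam) * y)
    return torch.fft.ifft2(k_dc, norm="ortho")

x = torch.randn(1, 64, 64, dtype=torch.complex64)   # network prediction
gt = torch.randn(1, 64, 64, dtype=torch.complex64)  # "measured" image
mask = (torch.rand(1, 64, 64) < 0.3).to(torch.complex64)
y = mask * torch.fft.fft2(gt, norm="ortho")          # subsampled k-space

out = data_consistency(x, y, mask, lam=0.0)
# With lam=0 the sampled k-space lines equal the measurements exactly:
assert torch.allclose(mask * torch.fft.fft2(out, norm="ortho"), y, atol=1e-5)
```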
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/sigmanet/sensitivity_net.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/sigmanet/sensitivity_net.html
new file mode 100644
index 00000000..9c299a39
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/sigmanet/sensitivity_net.html
@@ -0,0 +1,405 @@
Source code for mridc.collections.reconstruction.models.sigmanet.sensitivity_net

+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from:
+# https://github.com/khammernik/sigmanet/blob/master/reconstruction/common/mytorch/models/sn.py
+import numpy as np
+import torch
+
+
+
[docs]def matrix_invert(xx, xy, yx, yy): + """Invert a 2x2 matrix.""" + det = xx * yy - xy * yx + return yy.div(det), -xy.div(det), -yx.div(det), xx.div(det)
+ + +
+class ComplexInstanceNorm(torch.nn.Module):
+    """Motivated by 'Deep Complex Networks' (https://arxiv.org/pdf/1705.09792.pdf)"""
+
+    def __init__(self):
+        super(ComplexInstanceNorm, self).__init__()
+        self.mean = 0
+        self.cov_xx_half = 1 / np.sqrt(2)
+        self.cov_xy_half = 0
+        self.cov_yx_half = 0
+        self.cov_yy_half = 1 / np.sqrt(2)
+
+    def complex_instance_norm(self, x, eps=1e-5):
+        """Operates on images x of size [nBatch, nSmaps, nFE, nPE, 2]."""
+        x_combined = torch.sum(x, dim=1, keepdim=True)
+        mean = x_combined.mean(dim=(1, 2, 3), keepdim=True)
+        x_m = x - mean
+        self.mean = mean
+        self.complex_pseudocovariance(x_m)
+
+    def complex_pseudocovariance(self, data):
+        """Data variable has to be already mean-free! Operates on images x of size [nBatch, nSmaps, nFE, nPE, 2]."""
+        if data.size(-1) != 2:
+            raise AssertionError
+        shape = data.shape
+
+        # compute number of elements
+        N = shape[2] * shape[3]
+
+        # separate real/imaginary channel
+        re, im = torch.unbind(data, dim=-1)
+
+        # dim is now the length of the original shape - 1 (because the real/imaginary channels are separated)
+        dim = list(range(1, len(shape) - 1))
+
+        # compute covariance entries. cxy = cyx
+        cxx = (re * re).sum(dim=dim, keepdim=True) / (N - 1)
+        cyy = (im * im).sum(dim=dim, keepdim=True) / (N - 1)
+        cxy = (re * im).sum(dim=dim, keepdim=True) / (N - 1)
+
+        # Eigenvalue decomposition C = V*S*inv(V)
+        # compute eigenvalues
+        s1 = (cxx + cyy) / 2 - torch.sqrt((cxx + cyy) ** 2 / 4 - cxx * cyy + cxy**2)
+        s2 = (cxx + cyy) / 2 + torch.sqrt((cxx + cyy) ** 2 / 4 - cxx * cyy + cxy**2)
+
+        # compute eigenvectors
+        v1x = s1 - cyy
+        v1y = cxy
+        v2x = s2 - cyy
+        v2y = cxy
+
+        # normalize eigenvectors
+        norm1 = torch.sqrt(torch.sum(v1x * v1x + v1y * v1y, dim=dim, keepdim=True))
+        norm2 = torch.sqrt(torch.sum(v2x * v2x + v2y * v2y, dim=dim, keepdim=True))
+
+        v1x = v1x.div(norm1)
+        v1y = v1y.div(norm1)
+
+        v2x = v2x.div(norm2)
+        v2y = v2y.div(norm2)
+
+        # now we need the sqrt of the covariance matrix.
+        # C^{-0.5} = V * sqrt(S) * inv(V)
+        det = v1x * v2y - v2x * v1y
+        s1 = torch.sqrt(s1).div(det)
+        s2 = torch.sqrt(s2).div(det)
+
+        self.cov_xx_half = v1x * v2y * s1 - v1y * v2x * s2
+        self.cov_yy_half = v1x * v2y * s2 - v1y * v2x * s1
+        self.cov_xy_half = v1x * v2x * (s2 - s1)
+        self.cov_yx_half = v1y * v2y * (s1 - s2)
+
+    def forward(self, input):
+        """Operates on images x of size [nBatch, nSmaps, nFE, nPE, 2]."""
+        return self.normalize(input)
+
+    def set_normalization(self, input):
+        """Set the normalization parameters for a given input."""
+        mean = torch.tensor([torch.mean(input).item()]).to(input)
+        self.complex_pseudocovariance(input - mean)
+        self.mean = mean.unsqueeze(1).unsqueeze(1).unsqueeze(1)
+        self.cov_xx_half = self.cov_xx_half.view(-1, 1, 1, 1)
+        self.cov_xy_half = self.cov_xy_half.view(-1, 1, 1, 1)
+        self.cov_yx_half = self.cov_yx_half.view(-1, 1, 1, 1)
+        self.cov_yy_half = self.cov_yy_half.view(-1, 1, 1, 1)
+
+    def normalize(self, x):
+        """Normalize the input x."""
+        x_m = x - self.mean
+        re, im = torch.unbind(x_m, dim=-1)
+
+        cov_xx_half_inv, cov_xy_half_inv, cov_yx_half_inv, cov_yy_half_inv = matrix_invert(
+            self.cov_xx_half, self.cov_xy_half, self.cov_yx_half, self.cov_yy_half
+        )
+        x_norm_re = cov_xx_half_inv * re + cov_xy_half_inv * im
+        x_norm_im = cov_yx_half_inv * re + cov_yy_half_inv * im
+        img = torch.stack([x_norm_re, x_norm_im], dim=-1)
+        img = img.clamp(-6, 6)
+        return img
+
+    def unnormalize(self, x):
+        """Unnormalize the input x."""
+        re, im = torch.unbind(x, dim=-1)
+        x_unnorm_re = self.cov_xx_half * re + self.cov_xy_half * im
+        x_unnorm_im = self.cov_yx_half * re + self.cov_yy_half * im
+        return torch.stack([x_unnorm_re, x_unnorm_im], dim=-1) + self.mean
+
+
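A quick round-trip check of the whitening: assuming the mridc package is installed and this module is importable under the path shown above, `normalize` followed by `unnormalize` should recover the input up to the ±6 clamp (which random normal data virtually never hits):

```python
import torch
from mridc.collections.reconstruction.models.sigmanet.sensitivity_net import ComplexInstanceNorm

norm = ComplexInstanceNorm()
x = torch.randn(1, 4, 32, 32, 2)  # [nBatch, nSmaps, nFE, nPE, re/im]

norm.set_normalization(x)   # estimate mean and covariance square root
x_n = norm.normalize(x)     # whiten: C^{-1/2} (x - mean)
x_back = norm.unnormalize(x_n)

print(torch.allclose(x, x_back, atol=1e-4))  # should print True
```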
+class ComplexNormWrapper(torch.nn.Module):
+    """Wrapper for complex normalization."""
+
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+        self.complex_instance_norm = ComplexInstanceNorm()
+
+    def forward(self, input):
+        # compute complex instance norm on sample of size [nBatch, nSmaps, nFE, nPE, 2]
+        self.complex_instance_norm.set_normalization(input)
+        output = self.complex_instance_norm.normalize(input)
+
+        # re-shape data from [nBatch, nSmaps, nFE, nPE, 2] to [nBatch*nSmaps, 2, nFE, nPE]
+        shp = output.shape
+        output = output.view(shp[0] * shp[1], *shp[2:]).permute(0, 3, 1, 2)
+
+        # apply denoising
+        output = self.model(output)
+
+        # re-shape data from [nBatch*nSmaps, 2, nFE, nPE] to [nBatch, nSmaps, nFE, nPE, 2]
+        output = output.permute(0, 2, 3, 1).view(*shp)
+
+        # unnormalize
+        output = self.complex_instance_norm.unnormalize(output)
+        return output
+
+
+class SensitivityNetwork(torch.nn.Module):
+    """Sensitivity network with data term based on forward and adjoint containing the sensitivity maps."""
+
+    def __init__(
+        self,
+        num_iter,
+        model,
+        datalayer,
+        shared_params=True,
+        save_space=False,
+        reset_cache=False,
+    ):
+        """
+        Parameters
+        ----------
+        num_iter: Number of iterations.
+        model: Model to be used for the forward and adjoint.
+        datalayer: Data layer to be used for the forward and adjoint.
+        shared_params: If True, the parameters of the model are shared across all cascades.
+        save_space: If True, use the memory-saving forward pass (forward_save_space).
+        reset_cache: If True, clear the CUDA and cuFFT plan caches after each iteration of the memory-saving forward pass.
+        """
+        super().__init__()
+
+        self.shared_params = shared_params
+
+        self.num_iter = 1 if self.shared_params else num_iter
+        self.num_iter_total = num_iter
+
+        self.is_trainable = [True] * num_iter
+
+        # setup the modules
+        self.gradR = torch.nn.ModuleList([ComplexNormWrapper(model) for _ in range(self.num_iter)])
+        self.gradD = torch.nn.ModuleList([datalayer for _ in range(self.num_iter)])
+
+        self.save_space = save_space
+        if self.save_space:
+            self.forward = self.forward_save_space
+        self.reset_cache = reset_cache
+
+    def forward(self, x, y, smaps, mask):
+        """
+        Parameters
+        ----------
+        x: Input data.
+        y: Subsampled k-space data.
+        smaps: Coil sensitivity maps.
+        mask: Sampling mask.
+
+        Returns
+        -------
+        Output data.
+        """
+        x_all = [x]
+        x_half_all = []
+        if self.shared_params:
+            num_iter = self.num_iter_total
+        else:
+            num_iter = min(np.where(self.is_trainable)[0][-1] + 1, self.num_iter)
+
+        for i in range(num_iter):
+            x_thalf = x - self.gradR[i % self.num_iter](x)
+            x = self.gradD[i % self.num_iter](x_thalf, y, smaps, mask)
+            x_all.append(x)
+            x_half_all.append(x_thalf)
+
+        return x_all[-1]
+
+    def forward_save_space(self, x, y, smaps, mask):
+        """
+        Parameters
+        ----------
+        x: Input data.
+        y: Subsampled k-space data.
+        smaps: Coil sensitivity maps.
+        mask: Sampling mask.
+
+        Returns
+        -------
+        Output data.
+        """
+        if self.shared_params:
+            num_iter = self.num_iter_total
+        else:
+            num_iter = min(np.where(self.is_trainable)[0][-1] + 1, self.num_iter)
+
+        for i in range(num_iter):
+            x_thalf = x - self.gradR[i % self.num_iter](x)
+            x = self.gradD[i % self.num_iter](x_thalf, y, smaps, mask)
+
+            # would run out of memory at test time if this is False for some cases
+            if self.reset_cache:
+                torch.cuda.empty_cache()
+                torch.backends.cuda.cufft_plan_cache.clear()
+
+        return x
+
+    def freeze(self, i):
+        """Freeze the parameters of cascade i."""
+        for param in self.gradR[i].parameters():
+            param.requires_grad = False
+        self.is_trainable[i] = False
+
+    def unfreeze(self, i):
+        """Unfreeze the parameters of cascade i."""
+        for param in self.gradR[i].parameters():
+            param.requires_grad = True
+        self.is_trainable[i] = True
+
+    def freeze_all(self):
+        """Freeze the parameters of all cascades."""
+        for i in range(self.num_iter):
+            self.freeze(i)
+
+    def unfreeze_all(self):
+        """Unfreeze the parameters of all cascades."""
+        for i in range(self.num_iter):
+            self.unfreeze(i)
+
+    def copy_params(self, src_i, trg_j):
+        """Copy i-th cascade net parameters to j-th cascade net parameters."""
+        src_params = self.gradR[src_i].parameters()
+        trg_params = self.gradR[trg_j].parameters()
+
+        for trg_param, src_param in zip(trg_params, src_params):
+            trg_param.data.copy_(src_param.data)
+
+    def stage_training_init(self):
+        """Start stage-wise training: only the first cascade is trainable."""
+        self.freeze_all()
+        self.unfreeze(0)
+        print(self.is_trainable)
+
+    def stage_training_transition_i(self, copy=False):
+        """Advance stage-wise training by one cascade, optionally copying parameters forward."""
+        if self.shared_params:
+            return
+
+        # if all unlocked, don't do anything
+        if not np.all(self.is_trainable):
+            for i in range(self.num_iter):
+                # if last cascade is reached, unlock all
+                if i == self.num_iter - 1:
+                    self.unfreeze_all()
+                    break
+
+                # freeze current i, unlock next. copy parameter if specified
+                if self.is_trainable[i]:
+                    self.freeze(i)
+                    self.unfreeze(i + 1)
+                    if copy:
+                        self.copy_params(i, i + 1)
+                    break
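The stage-wise training helpers only touch `requires_grad` flags and the `is_trainable` list. A minimal sketch of how a training script might drive them, assuming mridc is installed; the toy denoiser and the identity data layer below are placeholders, not mridc components:

```python
import torch
from mridc.collections.reconstruction.models.sigmanet.sensitivity_net import SensitivityNetwork

class IdentityDC(torch.nn.Module):
    """Placeholder data layer: returns the image unchanged."""
    def forward(self, x, y, smaps, mask):
        return x

denoiser = torch.nn.Conv2d(2, 2, kernel_size=3, padding=1)  # toy regularizer
net = SensitivityNetwork(num_iter=3, model=denoiser, datalayer=IdentityDC(), shared_params=False)

net.stage_training_init()                    # [True, False, False]
net.stage_training_transition_i(copy=True)   # [False, True, False]
net.stage_training_transition_i(copy=True)   # [False, False, True]
net.stage_training_transition_i()            # last cascade reached -> all unlocked
print(net.is_trainable)                      # [True, True, True]
```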
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/unet.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/unet.html
new file mode 100644
index 00000000..2804abba
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/unet.html
@@ -0,0 +1,204 @@
Source code for mridc.collections.reconstruction.models.unet

+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+from abc import ABC
+
+import torch
+from omegaconf import DictConfig, OmegaConf
+from pytorch_lightning import Trainer
+from torch.nn import L1Loss
+
+from mridc.collections.common.losses.ssim import SSIMLoss
+from mridc.collections.common.parts.fft import ifft2c
+from mridc.collections.common.parts.utils import coil_combination
+from mridc.collections.reconstruction.models.base import BaseMRIReconstructionModel, BaseSensitivityModel
+from mridc.collections.reconstruction.models.unet_base.unet_block import NormUnet
+from mridc.collections.reconstruction.parts.utils import center_crop_to_smallest
+from mridc.core.classes.common import typecheck
+
+__all__ = ["UNet"]
+
+
+
+class UNet(BaseMRIReconstructionModel, ABC):
+    """
+    Implementation of the UNet, as presented in O. Ronneberger, P. Fischer, and Thomas Brox.
+
+    References
+    ----------
+    .. O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks for biomedical image
+       segmentation. In International Conference on Medical image computing and computer-assisted intervention,
+       pages 234–241. Springer, 2015.
+    """
+
+    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
+        # init superclass
+        super().__init__(cfg=cfg, trainer=trainer)
+
+        cfg_dict = OmegaConf.to_container(cfg, resolve=True)
+
+        self.fft_type = cfg_dict.get("fft_type")
+
+        self.unet = NormUnet(
+            chans=cfg_dict.get("channels"),
+            num_pools=cfg_dict.get("pooling_layers"),
+            padding_size=cfg_dict.get("padding_size"),
+            normalize=cfg_dict.get("normalize"),
+        )
+
+        self.output_type = cfg_dict.get("output_type")
+
+        # Initialize the sensitivity network if use_sens_net is True
+        self.use_sens_net = cfg_dict.get("use_sens_net")
+        if self.use_sens_net:
+            self.sens_net = BaseSensitivityModel(
+                cfg_dict.get("sens_chans"),
+                cfg_dict.get("sens_pools"),
+                fft_type=self.fft_type,
+                mask_type=cfg_dict.get("sens_mask_type"),
+                normalize=cfg_dict.get("sens_normalize"),
+            )
+
+        # initialize weights if not using pretrained unet
+        # TODO if not cfg_dict.get("pretrained", False):
+
+        self.train_loss_fn = SSIMLoss() if cfg_dict.get("train_loss_fn") == "ssim" else L1Loss()
+        self.eval_loss_fn = SSIMLoss() if cfg_dict.get("eval_loss_fn") == "ssim" else L1Loss()
+
+        self.accumulate_estimates = False
+
+    @typecheck()
+    def forward(
+        self,
+        y: torch.Tensor,
+        sensitivity_maps: torch.Tensor,
+        mask: torch.Tensor,
+        init_pred: torch.Tensor,
+        target: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Forward pass of the network.
+
+        Parameters
+        ----------
+        y: Subsampled k-space data.
+            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
+        sensitivity_maps: Coil sensitivity maps.
+            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
+        mask: Sampling mask.
+            torch.Tensor, shape [1, 1, n_x, n_y, 1]
+        init_pred: Initial prediction.
+            torch.Tensor, shape [batch_size, n_x, n_y, 2]
+        target: Target data to compute the loss.
+            torch.Tensor, shape [batch_size, n_x, n_y, 2]
+
+        Returns
+        -------
+        pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or torch.Tensor, shape [batch_size, n_x, n_y, 2]
+            If self.accumulate_loss is True, returns a list of all intermediate estimates.
+            If False, returns the final estimate.
+        """
+        sensitivity_maps = self.sens_net(y, mask) if self.use_sens_net else sensitivity_maps
+        eta = torch.view_as_complex(
+            coil_combination(ifft2c(y, fft_type=self.fft_type), sensitivity_maps, method=self.output_type, dim=1)
+        )
+        _, eta = center_crop_to_smallest(target, eta)
+        return torch.view_as_complex(self.unet(torch.view_as_real(eta.unsqueeze(1)))).squeeze(1)
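The constructor pulls everything from a flat config. An illustrative config covering the keys read above (the values are placeholders, not recommended settings, and the Lightning base class may require additional keys before instantiation succeeds):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "fft_type": "orthogonal",
        "channels": 32,
        "pooling_layers": 4,
        "padding_size": 15,
        "normalize": True,
        "output_type": "SENSE",
        "use_sens_net": False,
        "train_loss_fn": "ssim",
        "eval_loss_fn": "l1",
    }
)
# model = UNet(cfg)  # assumes the imports at the top of this module
```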
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/unet_base/unet_block.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/unet_base/unet_block.html
new file mode 100644
index 00000000..1f21d6f2
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/unet_base/unet_block.html
@@ -0,0 +1,404 @@
Source code for mridc.collections.reconstruction.models.unet_base.unet_block

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Parts of the code have been taken from https://github.com/facebookresearch/fastMRI
+import math
+from typing import List, Tuple
+
+import torch
+
+
+
+class NormUnet(torch.nn.Module):
+    """
+    Normalized U-Net model.
+
+    This is the same as a regular U-Net, but with normalization applied to the input before the U-Net.
+    This keeps the values more numerically stable during training.
+    """
+
+    def __init__(
+        self,
+        chans: int,
+        num_pools: int,
+        in_chans: int = 2,
+        out_chans: int = 2,
+        drop_prob: float = 0.0,
+        padding_size: int = 15,
+        normalize: bool = True,
+        norm_groups: int = 2,
+    ):
+        """
+        Parameters
+        ----------
+        chans: Number of output channels of the first convolution layer.
+        num_pools: Number of down-sampling and up-sampling layers.
+        in_chans: Number of channels in the input to the U-Net model.
+        out_chans: Number of channels in the output to the U-Net model.
+        drop_prob: Dropout probability.
+        padding_size: Size of the padding.
+        normalize: Whether to normalize the input.
+        norm_groups: Number of groups to use for group normalization.
+        """
+        super().__init__()
+
+        self.unet = Unet(
+            in_chans=in_chans, out_chans=out_chans, chans=chans, num_pool_layers=num_pools, drop_prob=drop_prob
+        )
+
+        self.padding_size = padding_size
+        self.normalize = normalize
+        self.norm_groups = norm_groups
+
+    @staticmethod
+    def complex_to_chan_dim(x: torch.Tensor) -> torch.Tensor:
+        """Move the complex (last) dimension of the input into the channel dimension."""
+        b, c, h, w, two = x.shape
+        if two != 2:
+            raise AssertionError
+        return x.permute(0, 4, 1, 2, 3).reshape(b, 2 * c, h, w)
+
+    @staticmethod
+    def chan_complex_to_last_dim(x: torch.Tensor) -> torch.Tensor:
+        """Move the complex part out of the channel dimension back to the last dimension."""
+        b, c2, h, w = x.shape
+        if c2 % 2 != 0:
+            raise AssertionError
+        c = torch.div(c2, 2, rounding_mode="trunc")
+        return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous()
+
+    def norm(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Normalize the input."""
+        # group norm
+        b, c, h, w = x.shape
+        x = x.reshape(b, self.norm_groups, -1)
+
+        mean = x.mean(-1, keepdim=True)
+        std = x.std(-1, keepdim=True)
+
+        x = (x - mean) / std
+        x = x.reshape(b, c, h, w)
+
+        return x, mean, std
+
+    def unnorm(self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
+        """Unnormalize the input."""
+        b, c, h, w = x.shape
+        input_data = x.reshape(b, self.norm_groups, -1)
+        return (input_data * std + mean).reshape(b, c, h, w)
+
+    def pad(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]:
+        """Pad the input with zeros to make it square."""
+        _, _, h, w = x.shape
+        w_mult = ((w - 1) | self.padding_size) + 1
+        h_mult = ((h - 1) | self.padding_size) + 1
+        w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)]
+        h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)]
+        # TODO: fix this type when PyTorch fixes theirs
+        # the documentation lies - this actually takes a list
+        # https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L3457
+        # https://github.com/pytorch/pytorch/pull/16949
+        x = torch.nn.functional.pad(x, w_pad + h_pad)
+
+        return x, (h_pad, w_pad, h_mult, w_mult)
+
+    @staticmethod
+    def unpad(x: torch.Tensor, h_pad: List[int], w_pad: List[int], h_mult: int, w_mult: int) -> torch.Tensor:
+        """Unpad the input."""
+        return x[..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass of the network."""
+        iscomplex = False
+        if x.shape[-1] == 2:
+            x = self.complex_to_chan_dim(x)
+            iscomplex = True
+
+        mean = 1.0
+        std = 1.0
+
+        if self.normalize:
+            x, mean, std = self.norm(x)
+
+        x, pad_sizes = self.pad(x)
+        x = self.unet(x)
+        x = self.unpad(x, *pad_sizes)
+
+        if self.normalize:
+            x = self.unnorm(x, mean, std)
+
+        if iscomplex:
+            x = self.chan_complex_to_last_dim(x)
+
+        return x
+
+
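The bit trick in `pad`, `((n - 1) | padding_size) + 1`, rounds a spatial size up to the next multiple of `padding_size + 1`; this only works when `padding_size + 1` is a power of two (16 for the default `padding_size=15`), which is what lets the U-Net's pooling stages divide evenly. A standalone check:

```python
def round_up(n, padding_size=15):
    # ((n - 1) | padding_size) + 1 sets all low bits of n - 1, then +1
    # rounds n up to the next multiple of padding_size + 1 (here 16).
    return ((n - 1) | padding_size) + 1

for n in (320, 321, 368, 1):
    print(n, "->", round_up(n))
# 320 -> 320, 321 -> 336, 368 -> 368, 1 -> 16
```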
+class Unet(torch.nn.Module):
+    """
+    PyTorch implementation of a U-Net model, as presented in [1]_.
+
+    References
+    ----------
+    .. [1] O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks for biomedical image
+       segmentation. In International Conference on Medical image computing and computer-assisted intervention,
+       pages 234–241. Springer, 2015.
+    """
+
+    def __init__(
+        self, in_chans: int, out_chans: int, chans: int = 32, num_pool_layers: int = 4, drop_prob: float = 0.0
+    ):
+        """
+        Parameters
+        ----------
+        in_chans: Number of channels in the input to the U-Net model.
+        out_chans: Number of channels in the output to the U-Net model.
+        chans: Number of output channels of the first convolution layer.
+        num_pool_layers: Number of down-sampling and up-sampling layers.
+        drop_prob: Dropout probability.
+        """
+        super().__init__()
+
+        self.in_chans = in_chans
+        self.out_chans = out_chans
+        self.chans = chans
+        self.num_pool_layers = num_pool_layers
+        self.drop_prob = drop_prob
+
+        self.down_sample_layers = torch.nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)])
+        ch = chans
+        for _ in range(num_pool_layers - 1):
+            self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob))
+            ch *= 2
+        self.conv = ConvBlock(ch, ch * 2, drop_prob)
+
+        self.up_conv = torch.nn.ModuleList()
+        self.up_transpose_conv = torch.nn.ModuleList()
+        for _ in range(num_pool_layers - 1):
+            self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch))
+            self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob))
+            ch //= 2
+
+        self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch))
+        self.up_conv.append(
+            torch.nn.Sequential(
+                ConvBlock(ch * 2, ch, drop_prob), torch.nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1)
+            )
+        )
+
+    def forward(self, image: torch.Tensor) -> torch.Tensor:
+        """
+        Parameters
+        ----------
+        image: Input 4D tensor of shape `(N, in_chans, H, W)`.
+
+        Returns
+        -------
+        Output tensor of shape `(N, out_chans, H, W)`.
+        """
+        stack = []
+        output = image
+
+        # apply down-sampling layers
+        for layer in self.down_sample_layers:
+            output = layer(output)
+            stack.append(output)
+            output = torch.nn.functional.avg_pool2d(output, kernel_size=2, stride=2, padding=0)
+
+        output = self.conv(output)
+
+        # apply up-sampling layers
+        for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv):
+            downsample_layer = stack.pop()
+            output = transpose_conv(output)
+
+            # reflect pad on the right/bottom if needed to handle odd input dimensions
+            padding = [0, 0, 0, 0]
+            if output.shape[-1] != downsample_layer.shape[-1]:
+                padding[1] = 1  # padding right
+            if output.shape[-2] != downsample_layer.shape[-2]:
+                padding[3] = 1  # padding bottom
+            if torch.sum(torch.tensor(padding)) != 0:
+                output = torch.nn.functional.pad(output, padding, "reflect")
+
+            output = torch.cat([output, downsample_layer], dim=1)
+            output = conv(output)
+
+        return output
+
+
+class ConvBlock(torch.nn.Module):
+    """
+    A Convolutional Block that consists of two convolution layers each followed by instance normalization,
+    LeakyReLU activation and dropout.
+    """
+
+    def __init__(self, in_chans: int, out_chans: int, drop_prob: float):
+        """
+        Parameters
+        ----------
+        in_chans: Number of channels in the input.
+        out_chans: Number of channels in the output.
+        drop_prob: Dropout probability.
+        """
+        super().__init__()
+
+        self.in_chans = in_chans
+        self.out_chans = out_chans
+        self.drop_prob = drop_prob
+
+        self.layers = torch.nn.Sequential(
+            torch.nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False),
+            torch.nn.InstanceNorm2d(out_chans),
+            torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
+            torch.nn.Dropout2d(drop_prob),
+            torch.nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
+            torch.nn.InstanceNorm2d(out_chans),
+            torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
+            torch.nn.Dropout2d(drop_prob),
+        )
+
+    def forward(self, image: torch.Tensor) -> torch.Tensor:
+        """
+        Parameters
+        ----------
+        image: Input 4D tensor of shape `(N, in_chans, H, W)`.
+
+        Returns
+        -------
+        Output tensor of shape `(N, out_chans, H, W)`.
+        """
+        return self.layers(image)
+
+
+class TransposeConvBlock(torch.nn.Module):
+    """
+    A Transpose Convolutional Block that consists of one transpose convolution layer followed by instance
+    normalization and LeakyReLU activation.
+    """
+
+    def __init__(self, in_chans: int, out_chans: int):
+        """
+        Parameters
+        ----------
+        in_chans: Number of channels in the input.
+        out_chans: Number of channels in the output.
+        """
+        super().__init__()
+
+        self.in_chans = in_chans
+        self.out_chans = out_chans
+
+        self.layers = torch.nn.Sequential(
+            torch.nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False),
+            torch.nn.InstanceNorm2d(out_chans),
+            torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
+        )
+
+    def forward(self, image: torch.Tensor) -> torch.Tensor:
+        """
+        Parameters
+        ----------
+        image: Input 4D tensor of shape `(N, in_chans, H, W)`.
+
+        Returns
+        -------
+        Output tensor of shape `(N, out_chans, H*2, W*2)`.
+        """
+        return self.layers(image)
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/variablesplittingnet/vsnet_block.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/variablesplittingnet/vsnet_block.html
new file mode 100644
index 00000000..9a8b7921
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/variablesplittingnet/vsnet_block.html
@@ -0,0 +1,226 @@
Source code for mridc.collections.reconstruction.models.variablesplittingnet.vsnet_block

+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+from typing import Any, List, Union
+
+import torch
+
+from mridc.collections.common.parts.fft import fft2c, ifft2c
+from mridc.collections.common.parts.utils import complex_conj, complex_mul
+
+
+
+class DataConsistencyLayer(torch.nn.Module):
+    """
+    Data consistency layer for the VSNet.
+
+    This layer replaces the predicted k-space with the reference k-space at the sampled positions, scaled by a
+    learnable data-consistency weight.
+    """
+
+    def __init__(self):
+        """Initializes the data consistency layer."""
+        super().__init__()
+        self.dc_weight = torch.nn.Parameter(torch.ones(1))
+
+    def forward(self, pred_kspace, ref_kspace, mask):
+        """Forward pass of the data consistency layer."""
+        return ((1 - mask) * pred_kspace + mask * ref_kspace) * self.dc_weight
+
+
+class WeightedAverageTerm(torch.nn.Module):
+    """Weighted average term for the VSNet."""
+
+    def __init__(self):
+        super().__init__()
+        self.param = torch.nn.Parameter(torch.ones(1))
+
+    def forward(self, x, Sx):
+        return self.param * x + (1 - self.param) * Sx
+
+
+class VSNetBlock(torch.nn.Module):
+    """
+    Model block for the Variable-Splitting Network inspired by [1]_.
+
+    References
+    ----------
+    .. [1] Duan, J. et al. (2019) ‘Vs-net: Variable splitting network for accelerated parallel MRI reconstruction’,
+       Lecture Notes in Computer Science (including subseries Lecture Notes in Artificial Intelligence and Lecture
+       Notes in Bioinformatics), 11767 LNCS, pp. 713–722. doi: 10.1007/978-3-030-32251-9_78.
+    """
+
+    def __init__(
+        self,
+        denoiser_block: torch.nn.ModuleList,
+        data_consistency_block: torch.nn.ModuleList,
+        weighted_average_block: torch.nn.ModuleList,
+        num_cascades: int = 8,
+        fft_type: str = "orthogonal",
+    ):
+        """
+        Parameters
+        ----------
+        denoiser_block: Model to apply denoising.
+        data_consistency_block: Model to apply data consistency.
+        weighted_average_block: Model to apply weighted average.
+        num_cascades: Number of cascades.
+        fft_type: Type of FFT to use.
+        """
+        super().__init__()
+
+        self.denoiser_block = denoiser_block
+        self.data_consistency_block = data_consistency_block
+        self.weighted_average_block = weighted_average_block
+        self.num_cascades = num_cascades
+        self.fft_type = fft_type
+
+    def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
+        """
+        Expand the image to multicoil k-space using the sensitivity maps.
+
+        Parameters
+        ----------
+        x: Input data.
+        sens_maps: Coil sensitivity maps.
+
+        Returns
+        -------
+        SENSE reconstruction expanded to the same size as the input sens_maps.
+        """
+        return fft2c(complex_mul(x, sens_maps), fft_type=self.fft_type)
+
+    def sens_reduce(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
+        """
+        Reduce the multicoil k-space to a coil-combined image using the sensitivity maps.
+
+        Parameters
+        ----------
+        x: Input data.
+        sens_maps: Coil sensitivity maps.
+
+        Returns
+        -------
+        SENSE coil-combined reconstruction.
+        """
+        x = ifft2c(x, fft_type=self.fft_type)
+        return complex_mul(x, complex_conj(sens_maps)).sum(1)
+
+    def forward(
+        self,
+        kspace: torch.Tensor,
+        sens_maps: torch.Tensor,
+        mask: torch.Tensor,
+    ) -> List[Union[torch.Tensor, Any]]:
+        """
+        Parameters
+        ----------
+        kspace: Reference k-space data.
+        sens_maps: Coil sensitivity maps.
+        mask: Mask to apply to the data.
+
+        Returns
+        -------
+        Reconstructed image.
+        """
+        for idx in range(self.num_cascades):
+            pred = self.sens_reduce(kspace, sens_maps)
+            pred = self.denoiser_block[idx](pred.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
+            pred = self.sens_expand(pred, sens_maps)
+            sx = self.data_consistency_block[idx](pred, kspace, mask)
+            sx = self.sens_reduce(sx, sens_maps)
+            kspace = self.weighted_average_block[idx](kspace + pred, sx)
+        return kspace
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/varnet/vn_block.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/varnet/vn_block.html
new file mode 100644
index 00000000..36ca4fa4
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/varnet/vn_block.html
@@ -0,0 +1,191 @@
Source code for mridc.collections.reconstruction.models.varnet.vn_block

+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+import torch
+
+from mridc.collections.common.parts.fft import fft2c, ifft2c
+from mridc.collections.common.parts.utils import complex_conj, complex_mul
+
+
+
+class VarNetBlock(torch.nn.Module):
+    """
+    Model block for end-to-end variational network.
+
+    This model applies a combination of soft data consistency with the input model as a regularizer.
+    A series of these blocks can be stacked to form the full variational network.
+    """
+
+    def __init__(self, model: torch.nn.Module, fft_type: str = "orthogonal", no_dc: bool = False):
+        """
+        Initialize the model block.
+
+        Parameters
+        ----------
+        model: Model to apply soft data consistency.
+        fft_type: Type of FFT to use.
+        no_dc: Whether to skip the soft data consistency term.
+        """
+        super().__init__()
+
+        self.model = model
+        self.fft_type = fft_type
+        self.no_dc = no_dc
+        self.dc_weight = torch.nn.Parameter(torch.ones(1))
+
+    def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
+        """
+        Expand the image to multicoil k-space using the sensitivity maps.
+
+        Parameters
+        ----------
+        x: Input data.
+        sens_maps: Coil sensitivity maps.
+
+        Returns
+        -------
+        SENSE reconstruction expanded to the same size as the input sens_maps.
+        """
+        return fft2c(complex_mul(x, sens_maps), fft_type=self.fft_type)
+
+    def sens_reduce(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
+        """
+        Reduce the multicoil k-space to a coil-combined image using the sensitivity maps.
+
+        Parameters
+        ----------
+        x: Input data.
+        sens_maps: Coil sensitivity maps.
+
+        Returns
+        -------
+        SENSE coil-combined reconstruction.
+        """
+        x = ifft2c(x, fft_type=self.fft_type)
+        return complex_mul(x, complex_conj(sens_maps)).sum(dim=1, keepdim=True)
+
+    def forward(
+        self,
+        pred: torch.Tensor,
+        ref_kspace: torch.Tensor,
+        sens_maps: torch.Tensor,
+        mask: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Parameters
+        ----------
+        pred: Predicted k-space data.
+        ref_kspace: Reference k-space data.
+        sens_maps: Coil sensitivity maps.
+        mask: Mask to apply to the data.
+
+        Returns
+        -------
+        Updated k-space estimate.
+        """
+        zero = torch.zeros(1, 1, 1, 1, 1).to(pred)
+        soft_dc = torch.where(mask.bool(), pred - ref_kspace, zero) * self.dc_weight
+
+        eta = self.sens_reduce(pred, sens_maps)
+        eta = self.model(eta)
+        eta = self.sens_expand(eta, sens_maps)
+
+        if not self.no_dc:
+            eta = pred - soft_dc - eta
+
+        return eta
+
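A single block maps a k-space estimate to an updated k-space estimate. A usage sketch, assuming mridc is installed; shapes follow the `[batch, coils, n_x, n_y, re/im]` convention documented in this repo:

```python
import torch
from mridc.collections.reconstruction.models.unet_base.unet_block import NormUnet
from mridc.collections.reconstruction.models.varnet.vn_block import VarNetBlock

block = VarNetBlock(NormUnet(chans=8, num_pools=2))
kspace = torch.randn(1, 4, 64, 64, 2)           # [batch, coils, n_x, n_y, re/im]
sens_maps = torch.randn(1, 4, 64, 64, 2)
mask = (torch.rand(1, 1, 64, 64, 1) < 0.3).float()

pred = block(kspace.clone(), kspace, sens_maps, mask)
print(pred.shape)  # same k-space shape as the input
```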
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/vn.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/vn.html
new file mode 100644
index 00000000..d058c21b
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/vn.html
@@ -0,0 +1,224 @@
Source code for mridc.collections.reconstruction.models.vn

+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+from abc import ABC
+
+import torch
+from omegaconf import DictConfig, OmegaConf
+from pytorch_lightning import Trainer
+from torch.nn import L1Loss
+
+from mridc.collections.common.losses.ssim import SSIMLoss
+from mridc.collections.common.parts.fft import ifft2c
+from mridc.collections.common.parts.utils import coil_combination
+from mridc.collections.reconstruction.models.base import BaseMRIReconstructionModel, BaseSensitivityModel
+from mridc.collections.reconstruction.models.unet_base.unet_block import NormUnet
+from mridc.collections.reconstruction.models.varnet.vn_block import VarNetBlock
+from mridc.collections.reconstruction.parts.utils import center_crop_to_smallest
+from mridc.core.classes.common import typecheck
+
+__all__ = ["VarNet"]
+
+
+
+class VarNet(BaseMRIReconstructionModel, ABC):
+    """
+    Implementation of the End-to-end Variational Network (VN), as presented in Sriram, A. et al.
+
+    References
+    ----------
+    .. Sriram, A. et al. (2020) ‘End-to-End Variational Networks for Accelerated MRI Reconstruction’.
+       Available at: https://github.com/facebookresearch/fastMRI.
+    """
+
+    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
+        # init superclass
+        super().__init__(cfg=cfg, trainer=trainer)
+
+        cfg_dict = OmegaConf.to_container(cfg, resolve=True)
+
+        self.no_dc = cfg_dict.get("no_dc")
+        self.fft_type = cfg_dict.get("fft_type")
+        self.num_cascades = cfg_dict.get("num_cascades")
+
+        # Cascades of VN blocks
+        self.cascades = torch.nn.ModuleList(
+            [
+                VarNetBlock(
+                    NormUnet(
+                        chans=cfg_dict.get("channels"),
+                        num_pools=cfg_dict.get("pooling_layers"),
+                        padding_size=cfg_dict.get("padding_size"),
+                        normalize=cfg_dict.get("normalize"),
+                    ),
+                    fft_type=self.fft_type,
+                    no_dc=self.no_dc,
+                )
+                for _ in range(self.num_cascades)
+            ]
+        )
+
+        self.output_type = cfg_dict.get("output_type")
+
+        # Initialize the sensitivity network if use_sens_net is True
+        self.use_sens_net = cfg_dict.get("use_sens_net")
+        if self.use_sens_net:
+            self.sens_net = BaseSensitivityModel(
+                cfg_dict.get("sens_chans"),
+                cfg_dict.get("sens_pools"),
+                fft_type=self.fft_type,
+                mask_type=cfg_dict.get("sens_mask_type"),
+                normalize=cfg_dict.get("sens_normalize"),
+            )
+
+        # initialize weights if not using pretrained vn
+        # TODO if not cfg_dict.get("pretrained", False)
+
+        self.train_loss_fn = SSIMLoss() if cfg_dict.get("train_loss_fn") == "ssim" else L1Loss()
+        self.eval_loss_fn = SSIMLoss() if cfg_dict.get("eval_loss_fn") == "ssim" else L1Loss()
+
+        self.dc_weight = torch.nn.Parameter(torch.ones(1))
+        self.accumulate_estimates = False
+
+    @typecheck()
+    def forward(
+        self,
+        y: torch.Tensor,
+        sensitivity_maps: torch.Tensor,
+        mask: torch.Tensor,
+        init_pred: torch.Tensor,
+        target: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Forward pass of the network.
+
+        Parameters
+        ----------
+        y: Subsampled k-space data.
+            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
+        sensitivity_maps: Coil sensitivity maps.
+            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
+        mask: Sampling mask.
+            torch.Tensor, shape [1, 1, n_x, n_y, 1]
+        init_pred: Initial prediction.
+            torch.Tensor, shape [batch_size, n_x, n_y, 2]
+        target: Target data to compute the loss.
+            torch.Tensor, shape [batch_size, n_x, n_y, 2]
+
+        Returns
+        -------
+        pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or torch.Tensor, shape [batch_size, n_x, n_y, 2]
+            If self.accumulate_loss is True, returns a list of all intermediate estimates.
+            If False, returns the final estimate.
+        """
+        sensitivity_maps = self.sens_net(y, mask) if self.use_sens_net else sensitivity_maps
+        estimation = y.clone()
+
+        for cascade in self.cascades:
+            # Forward pass through the cascades
+            estimation = cascade(estimation, y, sensitivity_maps, mask)
+
+        estimation = ifft2c(estimation, fft_type=self.fft_type)
+        estimation = coil_combination(estimation, sensitivity_maps, method=self.output_type, dim=1)
+        estimation = torch.view_as_complex(estimation)
+        _, estimation = center_crop_to_smallest(target, estimation)
+        return estimation
+
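As with the other models, VarNet is configured through a flat dict. An illustrative config covering the keys read in `__init__` (placeholder values; the Lightning base class may expect further keys):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "num_cascades": 8,
        "no_dc": False,
        "fft_type": "orthogonal",
        "channels": 18,
        "pooling_layers": 4,
        "padding_size": 11,
        "normalize": True,
        "output_type": "SENSE",
        "use_sens_net": False,
        "train_loss_fn": "ssim",
        "eval_loss_fn": "ssim",
    }
)
# model = VarNet(cfg)  # assumes the imports at the top of this module
```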
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/vsnet.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/vsnet.html
new file mode 100644
index 00000000..f542bea5
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/vsnet.html
@@ -0,0 +1,250 @@
Source code for mridc.collections.reconstruction.models.vsnet

+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+from abc import ABC
+
+import torch
+from omegaconf import DictConfig, OmegaConf
+from pytorch_lightning import Trainer
+from torch.nn import L1Loss
+
+from mridc.collections.common.losses.ssim import SSIMLoss
+from mridc.collections.common.parts.fft import ifft2c
+from mridc.collections.common.parts.utils import coil_combination
+from mridc.collections.reconstruction.models.base import BaseMRIReconstructionModel, BaseSensitivityModel
+from mridc.collections.reconstruction.models.conv.conv2d import Conv2d
+from mridc.collections.reconstruction.models.mwcnn.mwcnn import MWCNN
+from mridc.collections.reconstruction.models.unet_base.unet_block import NormUnet
+from mridc.collections.reconstruction.models.variablesplittingnet.vsnet_block import (
+    DataConsistencyLayer,
+    VSNetBlock,
+    WeightedAverageTerm,
+)
+from mridc.collections.reconstruction.parts.utils import center_crop_to_smallest
+from mridc.core.classes.common import typecheck
+
+__all__ = ["VSNet"]
+
+
+
+class VSNet(BaseMRIReconstructionModel, ABC):
+    """
+    Implementation of the Variable-Splitting Net, as presented in Duan, J. et al.
+
+    References
+    ----------
+    .. Duan, J. et al. (2019) ‘Vs-net: Variable splitting network for accelerated parallel MRI reconstruction’,
+       Lecture Notes in Computer Science (including subseries Lecture Notes in Artificial Intelligence and
+       Lecture Notes in Bioinformatics), 11767 LNCS, pp. 713–722. doi: 10.1007/978-3-030-32251-9_78.
+    """
+
+    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
+        # init superclass
+        super().__init__(cfg=cfg, trainer=trainer)
+
+        cfg_dict = OmegaConf.to_container(cfg, resolve=True)
+
+        num_cascades = cfg_dict.get("num_cascades")
+        self.fft_type = cfg_dict.get("fft_type")
+
+        image_model_architecture = cfg_dict.get("imspace_model_architecture")
+        if image_model_architecture == "CONV":
+            image_model = Conv2d(
+                in_channels=2,
+                out_channels=2,
+                hidden_channels=cfg_dict.get("imspace_conv_hidden_channels"),
+                n_convs=cfg_dict.get("imspace_conv_n_convs"),
+                batchnorm=cfg_dict.get("imspace_conv_batchnorm"),
+            )
+        elif image_model_architecture == "MWCNN":
+            image_model = MWCNN(
+                input_channels=2,
+                first_conv_hidden_channels=cfg_dict.get("image_mwcnn_hidden_channels"),
+                num_scales=cfg_dict.get("image_mwcnn_num_scales"),
+                bias=cfg_dict.get("image_mwcnn_bias"),
+                batchnorm=cfg_dict.get("image_mwcnn_batchnorm"),
+            )
+        elif image_model_architecture in ["UNET", "NORMUNET"]:
+            image_model = NormUnet(
+                cfg_dict.get("imspace_unet_num_filters"),
+                cfg_dict.get("imspace_unet_num_pool_layers"),
+                in_chans=2,
+                out_chans=2,
+                drop_prob=cfg_dict.get("imspace_unet_dropout_probability"),
+                padding_size=cfg_dict.get("imspace_unet_padding_size"),
+                normalize=cfg_dict.get("imspace_unet_normalize"),
+            )
+        else:
+            raise NotImplementedError(
+                "VSNet is currently implemented only with image_model_architecture == 'CONV', 'MWCNN' or 'UNet'. "
+                f"Got {image_model_architecture}."
+            )
+
+        image_model = torch.nn.ModuleList([image_model] * num_cascades)
+        data_consistency_model = torch.nn.ModuleList([DataConsistencyLayer()] * num_cascades)
+        weighted_average_model = torch.nn.ModuleList([WeightedAverageTerm()] * num_cascades)
+
+        self.model = VSNetBlock(
+            denoiser_block=image_model,
+            data_consistency_block=data_consistency_model,
+            weighted_average_block=weighted_average_model,
+            num_cascades=num_cascades,
+            fft_type=self.fft_type,
+        )
+
+        self._coil_dim = 1
+
+        # Initialize the sensitivity network if use_sens_net is True
+        self.use_sens_net = cfg_dict.get("use_sens_net")
+        if self.use_sens_net:
+            self.sens_net = BaseSensitivityModel(
+                cfg_dict.get("sens_chans"),
+                cfg_dict.get("sens_pools"),
+                fft_type=self.fft_type,
+                mask_type=cfg_dict.get("sens_mask_type"),
+                normalize=cfg_dict.get("sens_normalize"),
+            )
+
+        self.train_loss_fn = SSIMLoss() if cfg_dict.get("train_loss_fn") == "ssim" else L1Loss()
+        self.eval_loss_fn = SSIMLoss() if cfg_dict.get("eval_loss_fn") == "ssim" else L1Loss()
+        self.output_type = cfg_dict.get("output_type")
+
+        self.accumulate_estimates = False
+
+    @typecheck()
+    def forward(
+        self,
+        y: torch.Tensor,
+        sensitivity_maps: torch.Tensor,
+        mask: torch.Tensor,
+        init_pred: torch.Tensor,
+        target: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Forward pass of the network.
+
+        Parameters
+        ----------
+        y: Subsampled k-space data.
+            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
+        sensitivity_maps: Coil sensitivity maps.
+            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
+        mask: Sampling mask.
+            torch.Tensor, shape [1, 1, n_x, n_y, 1]
+        init_pred: Initial prediction.
+            torch.Tensor, shape [batch_size, n_x, n_y, 2]
+        target: Target data to compute the loss.
+            torch.Tensor, shape [batch_size, n_x, n_y, 2]
+
+        Returns
+        -------
+        pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or torch.Tensor, shape [batch_size, n_x, n_y, 2]
+            If self.accumulate_loss is True, returns a list of all intermediate estimates.
+            If False, returns the final estimate.
+        """
+        sensitivity_maps = self.sens_net(y, mask) if self.use_sens_net else sensitivity_maps
+        image = self.model(y, sensitivity_maps, mask)
+        image = torch.view_as_complex(
+            coil_combination(ifft2c(image, fft_type=self.fft_type), sensitivity_maps, method=self.output_type, dim=1)
+        )
+        _, image = center_crop_to_smallest(target, image)
+        return image
+
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/xpdnet.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/xpdnet.html
new file mode 100644
index 00000000..00e5ac0e
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/xpdnet.html
@@ -0,0 +1,314 @@
Source code for mridc.collections.reconstruction.models.xpdnet

+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+from abc import ABC
+
+import torch
+from omegaconf import DictConfig, OmegaConf
+from pytorch_lightning import Trainer
+from torch.nn import L1Loss
+
+from mridc.collections.common.losses.ssim import SSIMLoss
+from mridc.collections.reconstruction.models.base import BaseMRIReconstructionModel, BaseSensitivityModel
+from mridc.collections.reconstruction.models.conv.conv2d import Conv2d
+from mridc.collections.reconstruction.models.crossdomain.crossdomain import CrossDomainNetwork
+from mridc.collections.reconstruction.models.crossdomain.multicoil import MultiCoil
+from mridc.collections.reconstruction.models.didn.didn import DIDN
+from mridc.collections.reconstruction.models.mwcnn.mwcnn import MWCNN
+from mridc.collections.reconstruction.models.unet_base.unet_block import NormUnet
+from mridc.collections.reconstruction.parts.utils import center_crop_to_smallest
+from mridc.core.classes.common import typecheck
+
+__all__ = ["XPDNet"]
+
+
+
+class XPDNet(BaseMRIReconstructionModel, ABC):
+    """
+    Implementation of the XPDNet, as presented in Ramzi, Zaccharie, et al.
+
+    References
+    ----------
+    .. Ramzi, Zaccharie, et al. “XPDNet for MRI Reconstruction: An Application to the 2020 FastMRI Challenge.”
+       ArXiv:2010.07290 [Physics, Stat], July 2021. arXiv.org, http://arxiv.org/abs/2010.07290.
+    """
+
+    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
+        # init superclass
+        super().__init__(cfg=cfg, trainer=trainer)
+
+        cfg_dict = OmegaConf.to_container(cfg, resolve=True)
+
+        num_primal = cfg_dict.get("num_primal")
+        num_dual = cfg_dict.get("num_dual")
+        num_iter = cfg_dict.get("num_iter")
+
+        kspace_model_architecture = cfg_dict.get("kspace_model_architecture")
+        dual_conv_hidden_channels = cfg_dict.get("dual_conv_hidden_channels")
+        dual_conv_num_dubs = cfg_dict.get("dual_conv_num_dubs")
+        dual_conv_batchnorm = cfg_dict.get("dual_conv_batchnorm")
+        dual_didn_hidden_channels = cfg_dict.get("dual_didn_hidden_channels")
+        dual_didn_num_dubs = cfg_dict.get("dual_didn_num_dubs")
+        dual_didn_num_convs_recon = cfg_dict.get("dual_didn_num_convs_recon")
+
+        if cfg_dict.get("use_primal_only"):
+            kspace_model_list = None
+            num_dual = 1
+        elif kspace_model_architecture == "CONV":
+            kspace_model_list = torch.nn.ModuleList(
+                [
+                    MultiCoil(
+                        Conv2d(
+                            2 * (num_dual + num_primal + 1),
+                            2 * num_dual,
+                            dual_conv_hidden_channels,
+                            dual_conv_num_dubs,
+                            batchnorm=dual_conv_batchnorm,
+                        )
+                    )
+                    for _ in range(num_iter)
+                ]
+            )
+        elif kspace_model_architecture == "DIDN":
+            kspace_model_list = torch.nn.ModuleList(
+                [
+                    MultiCoil(
+                        DIDN(
+                            in_channels=2 * (num_dual + num_primal + 1),
+                            out_channels=2 * num_dual,
+                            hidden_channels=dual_didn_hidden_channels,
+                            num_dubs=dual_didn_num_dubs,
+                            num_convs_recon=dual_didn_num_convs_recon,
+                        )
+                    )
+                    for _ in range(num_iter)
+                ]
+            )
+        elif kspace_model_architecture in ["UNET", "NORMUNET"]:
+            kspace_model_list = torch.nn.ModuleList(
+                [
+                    MultiCoil(
+                        NormUnet(
+                            cfg_dict.get("kspace_unet_num_filters"),
+                            cfg_dict.get("kspace_unet_num_pool_layers"),
+                            in_chans=2 * (num_dual + num_primal + 1),
+                            out_chans=2 * num_dual,
+                            drop_prob=cfg_dict.get("kspace_unet_dropout_probability"),
+                            padding_size=cfg_dict.get("kspace_unet_padding_size"),
+                            normalize=cfg_dict.get("kspace_unet_normalize"),
+                        ),
+                        coil_to_batch=True,
+                    )
+                    for _ in range(num_iter)
+                ]
+            )
+        else:
+            raise NotImplementedError(
+                "XPDNet is currently implemented for kspace_model_architecture == 'CONV', 'DIDN', 'UNET' or "
+                f"'NORMUNET'. Got kspace_model_architecture == {kspace_model_architecture}."
+            )
+
+        image_model_architecture = cfg_dict.get("image_model_architecture")
+        mwcnn_hidden_channels = cfg_dict.get("mwcnn_hidden_channels")
+        mwcnn_num_scales = cfg_dict.get("mwcnn_num_scales")
+        mwcnn_bias = cfg_dict.get("mwcnn_bias")
+        mwcnn_batchnorm = cfg_dict.get("mwcnn_batchnorm")
+
+        if image_model_architecture == "MWCNN":
+            image_model_list = torch.nn.ModuleList(
+                [
+                    torch.nn.Sequential(
+                        MWCNN(
+                            input_channels=2 * (num_primal + num_dual),
+                            first_conv_hidden_channels=mwcnn_hidden_channels,
+                            num_scales=mwcnn_num_scales,
+                            bias=mwcnn_bias,
+                            batchnorm=mwcnn_batchnorm,
+                        ),
+                        torch.nn.Conv2d(2 * (num_primal + num_dual), 2 * num_primal, kernel_size=3, padding=1),
+                    )
+                    for _ in range(num_iter)
+                ]
+            )
+        elif image_model_architecture in ["UNET", "NORMUNET"]:
+            image_model_list = torch.nn.ModuleList(
+                [
+                    NormUnet(
+                        cfg_dict.get("imspace_unet_num_filters"),
+                        cfg_dict.get("imspace_unet_num_pool_layers"),
+                        in_chans=2 * (num_primal + num_dual),
+                        out_chans=2 * num_primal,
+                        drop_prob=cfg_dict.get("imspace_unet_dropout_probability"),
+                        padding_size=cfg_dict.get("imspace_unet_padding_size"),
+                        normalize=cfg_dict.get("imspace_unet_normalize"),
+                    )
+                    for _ in range(num_iter)
+                ]
+            )
+        else:
+            raise NotImplementedError(f"Image model architecture {image_model_architecture} not found for XPDNet.")
+
+        self.fft_type = cfg_dict.get("fft_type")
+
+        self.xpdnet = CrossDomainNetwork(
+            fft_type=self.fft_type,
+            image_model_list=image_model_list,
+            kspace_model_list=kspace_model_list,
+            domain_sequence="KI" * num_iter,
+            image_buffer_size=num_primal,
+            kspace_buffer_size=num_dual,
+            normalize_image=cfg_dict.get("normalize_image"),
+        )
+
+        # Initialize the sensitivity network if use_sens_net is True
+        self.use_sens_net = cfg_dict.get("use_sens_net")
+        if self.use_sens_net:
+            self.sens_net = BaseSensitivityModel(
+                cfg_dict.get("sens_chans"),
+                cfg_dict.get("sens_pools"),
+                fft_type=self.fft_type,
+                mask_type=cfg_dict.get("sens_mask_type"),
+                normalize=cfg_dict.get("sens_normalize"),
+            )
+
+        self.train_loss_fn = SSIMLoss() if cfg_dict.get("train_loss_fn") == "ssim" else L1Loss()
+        self.eval_loss_fn = SSIMLoss() if cfg_dict.get("eval_loss_fn") == "ssim" else L1Loss()
+        self.output_type = cfg_dict.get("output_type")
+
+        self.accumulate_estimates = False
+
+    @typecheck()
+    def forward(
+        self,
+        y: torch.Tensor,
+        sensitivity_maps: torch.Tensor,
+        mask: torch.Tensor,
+        init_pred: torch.Tensor,
+        target: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Forward pass of the network.
+
+        Parameters
+        ----------
+        y: Subsampled k-space data.
+            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
+        sensitivity_maps: Coil sensitivity maps.
+            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
+        mask: Sampling mask.
+            torch.Tensor, shape [1, 1, n_x, n_y, 1]
+        init_pred: Initial prediction.
+            torch.Tensor, shape [batch_size, n_x, n_y, 2]
+        target: Target data to compute the loss.
+            torch.Tensor, shape [batch_size, n_x, n_y, 2]
+
+        Returns
+        -------
+        pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or torch.Tensor, shape [batch_size, n_x, n_y, 2]
+            If self.accumulate_loss is True, returns a list of all intermediate estimates.
+            If False, returns the final estimate.
+        """
+        sensitivity_maps = self.sens_net(y, mask) if self.use_sens_net else sensitivity_maps
+        eta = self.xpdnet(y, sensitivity_maps, mask)
+        eta = (eta**2).sum(-1).sqrt()
+        _, eta = center_crop_to_smallest(target, eta)
+        return eta
+
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/models/zf.html b/docs/build/html/_modules/mridc/collections/reconstruction/models/zf.html
new file mode 100644
index 00000000..aa59373d
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/models/zf.html
@@ -0,0 +1,253 @@
Source code for mridc.collections.reconstruction.models.zf

+# coding=utf-8
+__author__ = "Dimitrios Karkalousos"
+
+from abc import ABC
+from typing import Any, Dict, Tuple, Union
+
+import numpy as np
+import torch
+from omegaconf import DictConfig, OmegaConf
+from pytorch_lightning import Trainer
+
+from mridc.collections.common.parts.fft import ifft2c
+from mridc.collections.common.parts.utils import check_stacked_complex, coil_combination
+from mridc.collections.reconstruction.models.base import BaseMRIReconstructionModel, BaseSensitivityModel
+from mridc.collections.reconstruction.parts.utils import center_crop_to_smallest
+from mridc.core.classes.common import typecheck
+
+__all__ = ["ZF"]
+
+
+
[docs]class ZF(BaseMRIReconstructionModel, ABC): + """ + Zero-Filled reconstruction using either root-sum-of-squares (RSS) or SENSE (SENSitivity Encoding), as presented \ + in Pruessmann KP, Weiger M, Scheidegger MB, Boesiger P. + + References + ---------- + + .. + + Pruessmann KP, Weiger M, Scheidegger MB, Boesiger P. SENSE: Sensitivity encoding for fast MRI. Magn Reson \ + Med 1999; 42:952-962. + + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + # init superclass + super().__init__(cfg=cfg, trainer=trainer) + + zf_cfg_dict = OmegaConf.to_container(cfg, resolve=True) + + self.zf_method = zf_cfg_dict.get("zf_method") + self.fft_type = zf_cfg_dict.get("fft_type") + + # Initialize the sensitivity network if use_sens_net is True + self.use_sens_net = zf_cfg_dict.get("use_sens_net") + if self.use_sens_net: + self.sens_net = BaseSensitivityModel( + zf_cfg_dict.get("sens_chans"), + zf_cfg_dict.get("sens_pools"), + fft_type=self.fft_type, + mask_type=zf_cfg_dict.get("sens_mask_type"), + normalize=zf_cfg_dict.get("sens_normalize"), + ) + +
[docs] @staticmethod + def process_inputs(y, mask): + """ + Process the inputs to the method. + + Parameters + ---------- + y: Subsampled k-space data. + list of torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] + mask: Sampling mask. + list of torch.Tensor, shape [1, 1, n_x, n_y, 1] + + Returns + ------- + y: Subsampled k-space data. + randomly selected y + mask: Sampling mask. + randomly selected mask + r: Random index. + """ + if isinstance(y, list): + r = np.random.randint(len(y)) + y = y[r] + mask = mask[r] + else: + r = 0 + return y, mask, r
+ +
+    @typecheck()
+    def forward(
+        self,
+        y: torch.Tensor,
+        sensitivity_maps: torch.Tensor,
+        mask: torch.Tensor,
+        target: torch.Tensor = None,
+    ) -> Union[list, Any]:
+        """
+        Forward pass of the zero-filled method.
+
+        Parameters
+        ----------
+        y: Subsampled k-space data.
+            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
+        sensitivity_maps: Coil sensitivity maps.
+            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
+        mask: Sampling mask.
+            torch.Tensor, shape [1, 1, n_x, n_y, 1]
+        target: Target data to compute the loss.
+            torch.Tensor, shape [batch_size, n_x, n_y, 2]
+
+        Returns
+        -------
+        pred: torch.Tensor, shape [batch_size, n_x, n_y, 2]
+            Predicted data.
+        """
+        # Estimate the sensitivity maps with the sensitivity network, if enabled.
+        sensitivity_maps = self.sens_net(y, mask) if self.use_sens_net else sensitivity_maps
+        pred = coil_combination(
+            ifft2c(y, fft_type=self.fft_type), sensitivity_maps, method=self.zf_method.upper(), dim=1
+        )
+        pred = check_stacked_complex(pred)
+        _, pred = center_crop_to_smallest(target, pred)
+        return pred
+
+    def test_step(self, batch: Dict[float, torch.Tensor], batch_idx: int) -> Tuple[str, int, torch.Tensor]:
+        """
+        Test step.
+
+        Parameters
+        ----------
+        batch: Batch of data.
+            Dict of torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
+        batch_idx: Batch index.
+            int
+
+        Returns
+        -------
+        name: Name of the volume.
+            str
+        slice_num: Slice number.
+            int
+        pred: Predicted data.
+            torch.Tensor, shape [batch_size, n_x, n_y, 2]
+        """
+        y, sensitivity_maps, mask, init_pred, target, fname, slice_num, _ = batch
+        y, mask, _ = self.process_inputs(y, mask)
+        prediction = self.forward(y, sensitivity_maps, mask, target)
+
+        slice_num = int(slice_num)
+        name = str(fname[0])  # type: ignore
+        key = f"{name}_images_idx_{slice_num}"  # type: ignore
+        output = torch.abs(prediction).detach().cpu()
+        target = torch.abs(target).detach().cpu()
+        output = output / output.max()  # type: ignore
+        target = target / target.max()  # type: ignore
+        error = torch.abs(target - output)
+        self.log_image(f"{key}/target", target)
+        self.log_image(f"{key}/reconstruction", output)
+        self.log_image(f"{key}/error", error)
+
+        return name, slice_num, prediction.detach().cpu().numpy()
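
A minimal instantiation sketch for the ZF model above. The config keys mirror those read in ZF.__init__; the values (and the omitted trainer) are illustrative assumptions, and the base reconstruction model may require additional keys in practice.

    # A minimal sketch, not a verified configuration: keys mirror ZF.__init__.
    from omegaconf import OmegaConf

    from mridc.collections.reconstruction.models.zf import ZF

    cfg = OmegaConf.create(
        {
            "zf_method": "SENSE",      # coil combination method, upper-cased in forward()
            "fft_type": "orthogonal",  # FFT normalization passed to ifft2c
            "use_sens_net": False,     # skip the BaseSensitivityModel branch
        }
    )
    model = ZF(cfg=cfg)  # trainer is optional and omitted here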
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/parts/transforms.html b/docs/build/html/_modules/mridc/collections/reconstruction/parts/transforms.html
new file mode 100644
index 00000000..72f08b87
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/parts/transforms.html
@@ -0,0 +1,425 @@

Source code for mridc.collections.reconstruction.parts.transforms

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Parts of the code have been taken from https://github.com/facebookresearch/fastMRI
+
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from mridc.collections.common.parts.fft import fft2c, ifft2c
+from mridc.collections.common.parts.utils import complex_conj, complex_mul, to_tensor
+from mridc.collections.reconstruction.data.subsample import MaskFunc
+from mridc.collections.reconstruction.parts.utils import apply_mask, center_crop, complex_center_crop
+
+__all__ = ["MRIDataTransforms"]
+
+
+
+class MRIDataTransforms:
+    """MRI preprocessing data transforms."""
+
+    def __init__(
+        self,
+        mask_func: Optional[List[MaskFunc]] = None,
+        shift_mask: bool = False,
+        mask_center_scale: Optional[float] = 0.02,
+        half_scan_percentage: float = 0.0,
+        crop_size: Optional[Tuple[int, int]] = None,
+        kspace_crop: bool = False,
+        crop_before_masking: bool = True,
+        kspace_zero_filling_size: Optional[Tuple] = None,
+        normalize_inputs: bool = False,
+        fft_type: str = "orthogonal",
+        use_seed: bool = True,
+    ):
+        """
+        Initialize the data transform.
+
+        Parameters
+        ----------
+        mask_func: The function that masks the kspace.
+        shift_mask: Whether to shift the mask.
+        mask_center_scale: The scale of the center of the mask.
+        half_scan_percentage: The percentage of the scan to be used.
+        crop_size: The size of the crop.
+        kspace_crop: Whether to crop the kspace.
+        crop_before_masking: Whether to crop before masking.
+        kspace_zero_filling_size: The size of padding in kspace -> zero filling.
+        normalize_inputs: Whether to normalize the inputs.
+        fft_type: The type of the FFT.
+        use_seed: Whether to use the seed.
+        """
+        self.mask_func = mask_func
+        self.shift_mask = shift_mask
+        self.mask_center_scale = mask_center_scale
+        self.half_scan_percentage = half_scan_percentage
+        self.crop_size = crop_size
+        self.kspace_crop = kspace_crop
+        self.crop_before_masking = crop_before_masking
+        self.kspace_zero_filling_size = kspace_zero_filling_size
+        self.normalize_inputs = normalize_inputs
+        self.fft_type = fft_type
+        self.use_seed = use_seed
+
+    def __call__(
+        self,
+        kspace: np.ndarray,
+        sensitivity_map: np.ndarray,
+        mask: np.ndarray,
+        eta: np.ndarray,
+        target: np.ndarray,
+        attrs: Dict,
+        fname: str,
+        slice_idx: int,
+    ) -> Tuple[
+        Union[Union[List[Union[torch.Tensor, Any]], torch.Tensor], Any],
+        Union[Optional[torch.Tensor], Any],
+        Union[List, Any],
+        Union[Optional[torch.Tensor], Any],
+        Union[torch.Tensor, Any],
+        str,
+        int,
+        Union[Union[List, torch.Tensor], Any],
+    ]:
+        """
+        Apply the data transform.
+
+        Parameters
+        ----------
+        kspace: The kspace.
+        sensitivity_map: The sensitivity map.
+        mask: The mask.
+        eta: The initial estimation.
+        target: The target.
+        attrs: The attributes.
+        fname: The file name.
+        slice_idx: The slice number.
+
+        Returns
+        -------
+        The transformed data.
+        """
+        kspace = to_tensor(kspace)
+
+        # This condition is necessary in case of auto estimation of sense maps.
+        if sensitivity_map is not None and sensitivity_map.size != 0:
+            sensitivity_map = to_tensor(sensitivity_map)
+
+        # Apply zero-filling on kspace
+        if self.kspace_zero_filling_size is not None and self.kspace_zero_filling_size not in ("", "None"):
+            padding_top = np.floor_divide(abs(int(self.kspace_zero_filling_size[0]) - kspace.shape[1]), 2)
+            padding_bottom = padding_top
+            padding_left = np.floor_divide(abs(int(self.kspace_zero_filling_size[1]) - kspace.shape[2]), 2)
+            padding_right = padding_left
+
+            kspace = torch.view_as_complex(kspace)
+            kspace = torch.nn.functional.pad(
+                kspace, pad=(padding_left, padding_right, padding_top, padding_bottom), mode="constant", value=0
+            )
+            kspace = torch.view_as_real(kspace)
+
+            sensitivity_map = fft2c(sensitivity_map, self.fft_type)
+            sensitivity_map = torch.view_as_complex(sensitivity_map)
+            sensitivity_map = torch.nn.functional.pad(
+                sensitivity_map,
+                pad=(padding_left, padding_right, padding_top, padding_bottom),
+                mode="constant",
+                value=0,
+            )
+            sensitivity_map = torch.view_as_real(sensitivity_map)
+            sensitivity_map = ifft2c(sensitivity_map, self.fft_type)
+
+        if eta is not None and eta.size != 0:
+            eta = to_tensor(eta)
+        else:
+            eta = torch.tensor([])
+
+        # TODO: add RSS target option
+        if sensitivity_map is not None and sensitivity_map.size != 0:
+            target = torch.sum(complex_mul(ifft2c(kspace, fft_type=self.fft_type), complex_conj(sensitivity_map)), 0)
+            target = torch.view_as_complex(target)
+        elif target is not None and target.size != 0:
+            target = to_tensor(target)
+        elif "target" in attrs or "target_rss" in attrs:
+            target = torch.tensor(attrs["target"])
+        else:
+            raise ValueError("No target found")
+
+        target = torch.abs(target / torch.max(torch.abs(target)))
+
+        seed = None if not self.use_seed else tuple(map(ord, fname))
+        acq_start = attrs["padding_left"] if "padding_left" in attrs else 0
+        acq_end = attrs["padding_right"] if "padding_right" in attrs else 0
+
+        # This should be outside the condition because it needs to be returned in the end, even if cropping is off.
+        # crop_size = torch.tensor([attrs["recon_size"][0], attrs["recon_size"][1]])
+        crop_size = target.shape
+
+        if self.crop_size is not None and self.crop_size not in ("", "None"):
+            # Check for smallest size against the target shape.
+            h = int(self.crop_size[0]) if int(self.crop_size[0]) <= target.shape[0] else target.shape[0]
+            w = int(self.crop_size[1]) if int(self.crop_size[1]) <= target.shape[1] else target.shape[1]
+
+            # Check for smallest size against the stored recon shape in metadata.
+            if crop_size[0] != 0:
+                h = h if h <= crop_size[0] else crop_size[0]
+            if crop_size[1] != 0:
+                w = w if w <= crop_size[1] else crop_size[1]
+
+            self.crop_size = (int(h), int(w))
+
+            target = center_crop(target, self.crop_size)
+            if sensitivity_map is not None and sensitivity_map.size != 0:
+                sensitivity_map = (
+                    ifft2c(
+                        complex_center_crop(fft2c(sensitivity_map, fft_type=self.fft_type), self.crop_size),
+                        fft_type=self.fft_type,
+                    )
+                    if self.kspace_crop
+                    else complex_center_crop(sensitivity_map, self.crop_size)
+                )
+
+            if eta is not None and eta.ndim > 2:
+                eta = (
+                    ifft2c(
+                        complex_center_crop(fft2c(eta, fft_type=self.fft_type), self.crop_size),
+                        fft_type=self.fft_type,
+                    )
+                    if self.kspace_crop
+                    else complex_center_crop(eta, self.crop_size)
+                )
+
+        # Cropping before masking will maintain the shape of original kspace intact for masking.
+        if self.crop_size is not None and self.crop_size not in ("", "None") and self.crop_before_masking:
+            kspace = (
+                complex_center_crop(kspace, self.crop_size)
+                if self.kspace_crop
+                else fft2c(
+                    complex_center_crop(ifft2c(kspace, fft_type=self.fft_type), self.crop_size),
+                    fft_type=self.fft_type,
+                )
+            )
+
+        if self.mask_func is not None:
+            # Check for multiple masks/accelerations.
+            if isinstance(self.mask_func, list):
+                masked_kspaces = []
+                masks = []
+                accs = []
+                for m in self.mask_func:
+                    _masked_kspace, _mask, _acc = apply_mask(
+                        kspace,
+                        m,
+                        seed,
+                        (acq_start, acq_end),
+                        shift=self.shift_mask,
+                        half_scan_percentage=self.half_scan_percentage,
+                        center_scale=self.mask_center_scale,
+                    )
+                    masked_kspaces.append(_masked_kspace)
+                    masks.append(_mask.byte())
+                    accs.append(_acc)
+                masked_kspace = masked_kspaces
+                mask = masks
+                acc = accs
+            else:
+                masked_kspace, mask, acc = apply_mask(
+                    kspace,
+                    self.mask_func[0],  # type: ignore
+                    seed,
+                    (acq_start, acq_end),
+                    shift=self.shift_mask,
+                    half_scan_percentage=self.half_scan_percentage,
+                    center_scale=self.mask_center_scale,
+                )
+                mask = mask.byte()
+        else:
+            masked_kspace = kspace
+            acc = torch.tensor([np.around(mask.size / mask.sum())]) if mask is not None else torch.tensor([1])
+
+            if mask is not None:
+                mask = torch.from_numpy(mask)
+                if mask.shape[0] == masked_kspace.shape[2]:  # type: ignore
+                    mask = mask.permute(1, 0)
+                elif mask.shape[0] != masked_kspace.shape[1]:  # type: ignore
+                    mask = torch.ones(
+                        [masked_kspace.shape[-3], masked_kspace.shape[-2]], dtype=torch.float32  # type: ignore
+                    )
+            else:
+                mask = torch.ones(
+                    [masked_kspace.shape[-3], masked_kspace.shape[-2]], dtype=torch.float32  # type: ignore
+                )
+
+            if mask.ndim == 1:
+                mask = np.expand_dims(mask, axis=0)
+
+            if mask.shape[-2] == 1:  # 1D mask
+                mask = torch.from_numpy(mask).unsqueeze(0).unsqueeze(-1)
+            else:  # 2D mask
+                # Crop loaded mask.
+                if self.crop_size is not None and self.crop_size not in ("", "None"):
+                    mask = center_crop(mask, self.crop_size)
+
+                mask = mask.unsqueeze(0).unsqueeze(-1)
+
+            if self.shift_mask:
+                mask = torch.fft.fftshift(mask, dim=[-3, -2])
+
+            masked_kspace = masked_kspace * mask
+            mask = mask.byte()
+
+        # Cropping after masking.
+        if self.crop_size is not None and self.crop_size not in ("", "None") and not self.crop_before_masking:
+            masked_kspace = (
+                complex_center_crop(masked_kspace, self.crop_size)
+                if self.kspace_crop
+                else fft2c(
+                    complex_center_crop(ifft2c(masked_kspace, fft_type=self.fft_type), self.crop_size),
+                    fft_type=self.fft_type,
+                )
+            )
+
+            mask = center_crop(mask.squeeze(-1), self.crop_size).unsqueeze(-1)
+
+        # Normalize by the max value.
+        if self.normalize_inputs:
+            if isinstance(self.mask_func, list):
+                masked_kspaces = []
+                for y in masked_kspace:
+                    if self.fft_type in ("orthogonal", "orthogonal_norm_only"):
+                        imspace = ifft2c(y, fft_type=self.fft_type)
+                        imspace = imspace / torch.max(torch.abs(imspace))
+                        masked_kspaces.append(fft2c(imspace, fft_type=self.fft_type))
+                    elif self.fft_type == "fft_norm_only":
+                        imspace = ifft2c(y, fft_type=self.fft_type)
+                        masked_kspaces.append(fft2c(imspace, fft_type=self.fft_type))
+                    elif self.fft_type == "backward_norm":
+                        imspace = ifft2c(y, fft_type=self.fft_type, fft_normalization="backward")
+                        masked_kspaces.append(fft2c(imspace, fft_type=self.fft_type, fft_normalization="backward"))
+                    else:
+                        imspace = torch.fft.ifftn(torch.view_as_complex(y), dim=[-2, -1], norm=None)
+                        imspace = imspace / torch.max(torch.abs(imspace))
+                        masked_kspaces.append(torch.view_as_real(torch.fft.fftn(imspace, dim=[-2, -1], norm=None)))
+                masked_kspace = masked_kspaces
+            else:
+                if self.fft_type in ("orthogonal", "orthogonal_norm_only"):
+                    imspace = ifft2c(masked_kspace, fft_type=self.fft_type)
+                    imspace = imspace / torch.max(torch.abs(imspace))
+                    masked_kspace = fft2c(imspace, fft_type=self.fft_type)
+                elif self.fft_type == "fft_norm_only":
+                    masked_kspace = fft2c(ifft2c(masked_kspace, fft_type=self.fft_type), fft_type=self.fft_type)
+                elif self.fft_type == "backward_norm":
+                    masked_kspace = fft2c(
+                        ifft2c(masked_kspace, fft_type=self.fft_type, fft_normalization="backward"),
+                        fft_type=self.fft_type,
+                        fft_normalization="backward",
+                    )
+                else:
+                    imspace = torch.fft.ifftn(torch.view_as_complex(masked_kspace), dim=[-2, -1], norm=None)
+                    imspace = imspace / torch.max(torch.abs(imspace))
+                    masked_kspace = torch.view_as_real(torch.fft.fftn(imspace, dim=[-2, -1], norm=None))
+
+            if sensitivity_map.size != 0:
+                sensitivity_map = sensitivity_map / torch.max(torch.abs(sensitivity_map))
+
+            if eta.size != 0 and eta.ndim > 2:
+                eta = eta / torch.max(torch.abs(eta))
+
+            target = target / torch.max(torch.abs(target))
+
+        return masked_kspace, sensitivity_map, mask, eta, target, fname, slice_idx, acc
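
A usage sketch for MRIDataTransforms with fully sampled inputs and no masking function. The array shapes, dummy attrs keys, and file name are illustrative assumptions, not values prescribed by the library.

    # A minimal sketch, assuming mridc is installed; shapes are illustrative.
    import numpy as np

    from mridc.collections.reconstruction.parts.transforms import MRIDataTransforms

    transform = MRIDataTransforms(mask_func=None, fft_type="orthogonal")

    n_coils, n_x, n_y = 4, 320, 320
    kspace = (np.random.randn(n_coils, n_x, n_y) + 1j * np.random.randn(n_coils, n_x, n_y)).astype(np.complex64)
    sens = (np.random.randn(n_coils, n_x, n_y) + 1j * np.random.randn(n_coils, n_x, n_y)).astype(np.complex64)

    masked_kspace, sensitivity_map, mask, eta, target, fname, slice_idx, acc = transform(
        kspace=kspace,
        sensitivity_map=sens,
        mask=None,             # no stored mask: a ones-mask is created internally
        eta=np.array([]),      # no initial estimation
        target=np.array([]),   # target is derived from kspace and the sensitivity map
        attrs={"padding_left": 0, "padding_right": 0},
        fname="example.h5",
        slice_idx=0,
    )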
diff --git a/docs/build/html/_modules/mridc/collections/reconstruction/parts/utils.html b/docs/build/html/_modules/mridc/collections/reconstruction/parts/utils.html
new file mode 100644
index 00000000..c0b9f4ce
--- /dev/null
+++ b/docs/build/html/_modules/mridc/collections/reconstruction/parts/utils.html
@@ -0,0 +1,302 @@

Source code for mridc.collections.reconstruction.parts.utils

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Parts of the code have been taken from https://github.com/facebookresearch/fastMRI
+
+from typing import Any, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+
+__all__ = [
+    "apply_mask",
+    "mask_center",
+    "batched_mask_center",
+    "center_crop",
+    "complex_center_crop",
+    "center_crop_to_smallest",
+]
+
+from mridc.collections.reconstruction.data.subsample import MaskFunc
+
+
+
+def apply_mask(
+    data: torch.Tensor,
+    mask_func: MaskFunc,
+    seed: Optional[Union[int, Tuple[int, ...]]] = None,
+    padding: Optional[Sequence[int]] = None,
+    shift: bool = False,
+    half_scan_percentage: Optional[float] = 0.0,
+    center_scale: Optional[float] = 0.02,
+) -> Tuple[Any, Any, Any]:
+    """
+    Subsample given k-space by multiplying with a mask.
+
+    Parameters
+    ----------
+    data: The input k-space data. This should have at least 3 dimensions, where dimensions -3 and -2 are the
+        spatial dimensions, and the final dimension has size 2 (for complex values).
+    mask_func: A function that takes a shape (tuple of ints) and a random number seed and returns a mask.
+    seed: Seed for the random number generator.
+    padding: Padding value to apply for mask.
+    shift: Toggle to shift mask when subsampling. Applicable on 2D data.
+    half_scan_percentage: Percentage of kspace to be dropped.
+    center_scale: Scale of the center of the mask. Applicable on Gaussian masks.
+
+    Returns
+    -------
+    Tuple of subsampled k-space, mask, and acceleration factor.
+    """
+    shape = np.array(data.shape)
+    shape[:-3] = 1
+    mask, acc = mask_func(shape, seed, half_scan_percentage=half_scan_percentage, scale=center_scale)
+
+    if padding is not None and padding[0] != 0:
+        mask[:, :, : padding[0]] = 0
+        mask[:, :, padding[1] :] = 0  # padding value inclusive on right of zeros
+
+    if shift:
+        mask = torch.fft.fftshift(mask, dim=(1, 2))
+
+    masked_data = data * mask + 0.0  # the + 0.0 removes the sign of the zeros
+
+    return masked_data, mask, acc
+
+
+def mask_center(
+    x: torch.Tensor, mask_from: Optional[int], mask_to: Optional[int], mask_type: str = "2D"
+) -> torch.Tensor:
+    """
+    Initialize a mask that keeps only the center region of the input real image or batch of real images.
+
+    Parameters
+    ----------
+    x: The input real image or batch of real images.
+    mask_from: Part of center to start filling.
+    mask_to: Part of center to end filling.
+    mask_type: Type of mask to apply. Can be either "1D" or "2D".
+
+    Returns
+    -------
+    A mask with the center filled.
+    """
+    mask = torch.zeros_like(x)
+
+    if isinstance(mask_from, list):
+        mask_from = mask_from[0]
+
+    if isinstance(mask_to, list):
+        mask_to = mask_to[0]
+
+    if mask_type == "1D":
+        mask[:, :, :, mask_from:mask_to] = x[:, :, :, mask_from:mask_to]
+    elif mask_type == "2D":
+        mask[:, :, mask_from:mask_to] = x[:, :, mask_from:mask_to]
+
+    return mask
+
+
+def batched_mask_center(
+    x: torch.Tensor, mask_from: torch.Tensor, mask_to: torch.Tensor, mask_type: str = "2D"
+) -> torch.Tensor:
+    """
+    Initializes a mask with the center filled in. Can operate with different masks for each batch element.
+
+    Parameters
+    ----------
+    x: The input real image or batch of real images.
+    mask_from: Part of center to start filling.
+    mask_to: Part of center to end filling.
+    mask_type: Type of mask to apply. Can be either "1D" or "2D".
+
+    Returns
+    -------
+    A mask with the center filled.
+    """
+    if mask_from.shape != mask_to.shape:
+        raise ValueError("mask_from and mask_to must match shapes.")
+    if mask_from.ndim != 1:
+        raise ValueError("mask_from and mask_to must have 1 dimension.")
+    if mask_from.shape[0] not in (1, x.shape[0]) or x.shape[0] != mask_to.shape[0]:
+        raise ValueError("mask_from and mask_to must have batch_size length.")
+
+    if mask_from.shape[0] == 1:
+        mask = mask_center(x, int(mask_from), int(mask_to), mask_type=mask_type)
+    else:
+        mask = torch.zeros_like(x)
+        for i, (start, end) in enumerate(zip(mask_from, mask_to)):
+            mask[i, :, :, start:end] = x[i, :, :, start:end]
+
+    return mask
+
+
+def center_crop(data: torch.Tensor, shape: Tuple[int, int]) -> torch.Tensor:
+    """
+    Apply a center crop to the input real image or batch of real images.
+
+    Parameters
+    ----------
+    data: The input tensor to be center cropped. It should have at least 2 dimensions and the cropping is applied
+        along the last two dimensions.
+    shape: The output shape. The shape should be smaller than the corresponding dimensions of data.
+
+    Returns
+    -------
+    The center cropped image.
+    """
+    if not (0 < shape[0] <= data.shape[-2] and 0 < shape[1] <= data.shape[-1]):
+        raise ValueError("Invalid shapes.")
+
+    w_from = torch.div((data.shape[-2] - shape[0]), 2, rounding_mode="trunc")
+    h_from = torch.div((data.shape[-1] - shape[1]), 2, rounding_mode="trunc")
+    w_to = w_from + shape[0]
+    h_to = h_from + shape[1]
+
+    return data[..., w_from:w_to, h_from:h_to]  # type: ignore
+
+
+def complex_center_crop(data: torch.Tensor, shape: Tuple[int, int]) -> torch.Tensor:
+    """
+    Apply a center crop to the input image or batch of complex images.
+
+    Parameters
+    ----------
+    data: The complex input tensor to be center cropped. It should have at least 3 dimensions and the cropping is
+        applied along dimensions -3 and -2 and the last dimension should have a size of 2.
+    shape: The output shape. The shape should be smaller than the corresponding dimensions of data.
+
+    Returns
+    -------
+    The center cropped image.
+    """
+    if not (0 < shape[0] <= data.shape[-3] and 0 < shape[1] <= data.shape[-2]):
+        raise ValueError("Invalid shapes.")
+
+    w_from = torch.div((data.shape[-3] - shape[0]), 2, rounding_mode="trunc")
+    h_from = torch.div((data.shape[-2] - shape[1]), 2, rounding_mode="trunc")
+    w_to = w_from + shape[0]
+    h_to = h_from + shape[1]
+
+    return data[..., w_from:w_to, h_from:h_to, :]  # type: ignore
+
+
+def center_crop_to_smallest(
+    x: Union[torch.Tensor, np.ndarray], y: Union[torch.Tensor, np.ndarray]
+) -> Tuple[Union[torch.Tensor, np.ndarray], Union[torch.Tensor, np.ndarray]]:
+    """
+    Apply a center crop on the larger image to the size of the smaller.
+
+    The minimum is taken over dim=-1 and dim=-2. If x is smaller than y at dim=-1 and y is smaller than x at dim=-2,
+    then the returned dimension will be a mixture of the two.
+
+    Parameters
+    ----------
+    x: The first image.
+    y: The second image.
+
+    Returns
+    -------
+    Tuple of tensors x and y, each cropped to the minimum size.
+    """
+    smallest_width = min(x.shape[-1], y.shape[-1])
+    smallest_height = min(x.shape[-2], y.shape[-2])
+    x = center_crop(x, (smallest_height, smallest_width))
+    y = center_crop(y, (smallest_height, smallest_width))
+
+    return x, y
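
A short sketch of the cropping helpers defined above; the tensor shapes are illustrative.

    import torch

    from mridc.collections.reconstruction.parts.utils import center_crop, center_crop_to_smallest

    x = torch.randn(1, 320, 320)
    y = torch.randn(1, 256, 230)

    x_cropped = center_crop(x, (128, 128))    # -> torch.Size([1, 128, 128])
    x_s, y_s = center_crop_to_smallest(x, y)  # both -> torch.Size([1, 256, 230])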
diff --git a/docs/build/html/_modules/mridc/core/classes/common.html b/docs/build/html/_modules/mridc/core/classes/common.html
new file mode 100644
index 00000000..ef3458c7
--- /dev/null
+++ b/docs/build/html/_modules/mridc/core/classes/common.html
@@ -0,0 +1,888 @@

Source code for mridc.core.classes.common

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Interfaces common to all Neural Modules and Models.
+import hashlib
+import inspect
+import traceback
+from abc import ABC
+from contextlib import contextmanager
+from dataclasses import dataclass, field
+from enum import Enum
+from functools import total_ordering
+from pathlib import Path
+from typing import Dict, List, Optional, Union
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/classes/common.py
+import hydra
+import torch
+import wrapt
+from omegaconf import DictConfig, OmegaConf
+from pytorch_lightning import Trainer
+
+import mridc.utils
+from mridc.core.conf.trainer import TrainerConfig
+from mridc.core.connectors.save_restore_connector import SaveRestoreConnector
+from mridc.core.neural_types.comparison import NeuralTypeComparisonResult
+from mridc.core.neural_types.neural_type import NeuralType
+from mridc.utils import logging
+from mridc.utils.cloud import maybe_download_from_cloud
+
+_HAS_HYDRA = True
+
+__all__ = ["Typing", "FileIO", "Model", "PretrainedModelInfo", "Serialization", "is_typecheck_enabled", "typecheck"]
+
+_TYPECHECK_ENABLED = True
+
+
+
+def is_typecheck_enabled():
+    """Getter method for typechecking state."""
+    return _TYPECHECK_ENABLED
+
+
+@dataclass(frozen=True)
+class TypecheckMetadata:
+    """
+    Metadata class for input/output neural types.
+
+    # Primary attributes
+    original_types: Preserve the dictionary of type information provided.
+    ignore_collections: For backward compatibility, container support can be disabled explicitly
+        using this flag. When set to True, all nesting is ignored and nest-depth checks are skipped.
+
+    # Derived attributes
+    mandatory_types: Sub-dictionary of `original_types` which contains only those types which
+        are mandatory to include when calling the function.
+    base_types: Dictionary of flattened `str: NeuralType` definitions, disregarding the nest level
+        details into appropriate arguments.
+    container_depth: Dictionary mapping `str: int` - such that the valid depth of the nest of this
+        neural type is recorded.
+    has_container_types: Bool flag declaring if any of the neural types declares a container nest
+        in its signature.
+    is_singular_container_type: Bool flag declaring if this is a single Neural Type with a container
+        nest in its signature. Required for supporting python list expansion in return statement.
+    """
+
+    original_types: Dict[str, NeuralType]
+    ignore_collections: bool
+
+    mandatory_types: Dict[str, NeuralType] = field(init=False)
+    base_types: Dict[str, NeuralType] = field(init=False)
+
+    container_depth: Dict[str, int] = field(init=False)
+    has_container_types: bool = field(init=False)
+    is_singular_container_type: bool = field(init=False)
+
+    def __post_init__(self):
+        has_container_types = any(isinstance(type_val, (list, tuple)) for type_val in self.original_types.values())
+
+        self.has_container_types = has_container_types
+
+        # If only one NeuralType is declared, and it declares a container nest, set to True
+        self.is_singular_container_type = self.has_container_types and len(self.original_types) == 1
+
+        # If container nests are declared, flatten the nest into `base_types`
+        # Also compute the nest depth for each of the NeuralTypes
+        if self.has_container_types:
+            self.base_types = {}
+            self.container_depth = {}
+
+            for type_key, type_val in self.original_types.items():
+                depth = 0
+                while isinstance(type_val, (list, tuple)):
+                    if len(type_val) > 1:
+                        raise TypeError(
+                            f"Neural Type `{type_key}`: {type_val} definition contains more than one element when"
+                            "declaring the nested container structure.\n"
+                            "Please ensure that you have only 1 NeuralType inside of the entire nested structure "
+                            "definition."
+                        )
+
+                    type_val = type_val[0]
+                    depth += 1
+
+                self.base_types[type_key] = type_val
+                self.container_depth[type_key] = depth
+        else:
+            # Otherwise, simply preserve the original_types and set depth of nest to 0.
+            self.base_types = self.original_types
+            self.container_depth = {type_key: 0 for type_key in self.base_types.keys()}
+
+        # Compute subset of original_types which are mandatory in the call argspec
+        self.mandatory_types = {
+            type_key: type_val for type_key, type_val in self.base_types.items() if not type_val.optional
+        }
+
+class Typing(ABC):
+    """An interface which endows a module with neural types."""
+
+    @property
+    def input_types(self) -> Optional[Dict[str, NeuralType]]:
+        """Define these to enable input neural type checks"""
+        return None
+
+    @property
+    def output_types(self) -> Optional[Dict[str, NeuralType]]:
+        """Define these to enable output neural type checks"""
+        return None
+
+    def _validate_input_types(self, input_types=None, ignore_collections=False, **kwargs):
+        """
+        This function does a few things.
+
+        1) It ensures that len(self.input_types <non-optional>) <= len(kwargs) <= len(self.input_types).
+        2) For each (keyword name, keyword value) passed as input to the wrapped function:
+            - Check if the keyword name exists in the list of valid self.input_types names.
+            - Check if keyword value has the `neural_type` property.
+                - If it does, then perform a comparative check and assert that neural types
+                  are compatible (SAME or GREATER).
+            - Check if keyword value is a container type (list or tuple). If yes,
+              then perform the elementwise test of neural type above on each element
+              of the nested structure, recursively.
+
+        Parameters
+        ----------
+        input_types: Either the `input_types` defined at class level, or the local function overridden type definition.
+        ignore_collections: For backward compatibility, container support can be disabled explicitly using this flag.
+            When set to True, all nesting is ignored and nest-depth checks are skipped.
+        kwargs: Dictionary of argument_name:argument_value pairs passed to the wrapped function upon call.
+        """
+        # TODO: Properly implement this
+        if input_types is None:
+            return
+        # Precompute metadata
+        metadata = TypecheckMetadata(original_types=input_types, ignore_collections=ignore_collections)
+
+        total_input_types = len(input_types)
+        mandatory_input_types = len(metadata.mandatory_types)
+
+        # Allow number of input arguments to be <= total input neural types.
+        if len(kwargs) < mandatory_input_types or len(kwargs) > total_input_types:
+            raise TypeError(
+                f"Number of input arguments provided ({len(kwargs)}) is not as expected. Function has "
+                f"{total_input_types} total inputs with {mandatory_input_types} mandatory inputs."
+            )
+
+        for key, value in kwargs.items():
+            # Check if keys exists in the defined input types
+            if key not in input_types:
+                raise TypeError(
+                    f"Input argument {key} has no corresponding input_type match. "
+                    f"Existing input_types = {input_types.keys()}"
+                )
+
+            # Perform neural type check
+            if hasattr(value, "neural_type") and metadata.base_types[key].compare(value.neural_type) not in (
+                NeuralTypeComparisonResult.SAME,
+                NeuralTypeComparisonResult.GREATER,
+            ):
+                error_msg = [
+                    f"{input_types[key].compare(value.neural_type)} :",
+                    f"Input type expected : {input_types[key]}",
+                    f"Input type found : {value.neural_type}",
+                    f"Argument: {key}",
+                ]
+                for i, dict_tuple in enumerate(metadata.base_types[key].elements_type.type_parameters.items()):
+                    error_msg.insert(i + 2, f"  input param_{i} : {dict_tuple[0]}: {dict_tuple[1]}")
+                error_msg.extend(
+                    f"  input param_{i} : {dict_tuple[0]}: {dict_tuple[1]}"
+                    for i, dict_tuple in enumerate(value.neural_type.elements_type.type_parameters.items())
+                )
+
+                raise TypeError("\n".join(error_msg))
+
+            # Perform input n dim check
+            if hasattr(value, "shape"):
+                value_shape = value.shape
+                type_shape = metadata.base_types[key].axes
+                name = key
+
+                if type_shape is not None and len(value_shape) != len(tuple(type_shape)):
+                    raise TypeError(
+                        f"Input shape mismatch occurred for {name} in module {self.__class__.__name__} : \n"
+                        f"Input shape expected = {metadata.base_types[name].axes} | \n"
+                        f"Input shape found : {value_shape}"
+                    )
+
+            elif isinstance(value, (list, tuple)):
+                for val in value:
+                    # This initiates a DFS, tracking the depth count as it goes along the nested structure.
+                    # Initial depth is 1 as we consider the current loop to be the 1st step inside the nest.
+                    self.__check_neural_type(val, metadata, depth=1, name=key)
+
+    def _attach_and_validate_output_types(self, out_objects, ignore_collections=False, output_types=None):
+        """
+        This function does a few things.
+
+        1) It ensures that len(out_object) == len(self.output_types).
+        2) If the output is a tensor (or list/tuple of list/tuple ... of tensors), it
+           attaches a neural_type to it. For objects without the neural_type attribute,
+           such as python objects (dictionaries and lists, primitive data types, structs),
+           no neural_type is attached.
+
+        Note: tensor.neural_type is only checked during _validate_input_types which is
+        called prior to forward().
+
+        Parameters
+        ----------
+        output_types: Either the `output_types` defined at class level, or the local function overridden type
+            definition.
+        ignore_collections: For backward compatibility, container support can be disabled explicitly using this flag.
+            When set to True, all nesting is ignored and nest-depth checks are skipped.
+        out_objects: The outputs of the wrapped function.
+        """
+        # TODO: Properly implement this
+        if output_types is None:
+            return
+        # Precompute metadata
+        metadata = TypecheckMetadata(original_types=output_types, ignore_collections=ignore_collections)
+        out_types_list = list(metadata.base_types.items())
+        mandatory_out_types_list = list(metadata.mandatory_types.items())
+
+        # First convert all outputs to list/tuple format to check correct number of outputs
+        if isinstance(out_objects, (list, tuple)):
+            out_container = out_objects  # can be any rank nested structure
+        else:
+            out_container = [out_objects]
+
+        # If this neural type has a *single output*, with *support for nested outputs*,
+        # then *do not* perform any check on the number of output items against the number
+        # of neural types (in this case, 1).
+        # This is done as python will *not* wrap a single returned list into a tuple of length 1,
+        # instead opting to keep the list intact. Therefore len(out_container) in such a case
+        # is the length of all the elements of that list - each of which has the same corresponding
+        # neural type (defined as the singular container type).
+        if metadata.is_singular_container_type:
+            pass
+
+        # In all other cases, python will wrap multiple outputs into an outer tuple.
+        # Allow number of output arguments to be <= total output neural types and >= mandatory outputs.
+        elif len(out_container) > len(out_types_list) or len(out_container) < len(mandatory_out_types_list):
+            raise TypeError(
+                "Number of output arguments provided ({}) is not as expected. It should be larger than {} and "
+                "less than {}.\n"
+                "This can be either because insufficient/extra number of output NeuralTypes were provided,"
+                "or the provided NeuralTypes {} should enable container support "
+                "(add '[]' to the NeuralType definition)".format(
+                    len(out_container), len(out_types_list), len(mandatory_out_types_list), output_types
+                )
+            )
+
+        # Attach types recursively, if possible
+        if not isinstance(out_objects, tuple) and not isinstance(out_objects, list):
+            # Here, out_objects is a single object which can potentially be attached with a NeuralType
+            try:
+                out_objects.neural_type = out_types_list[0][1]
+            except AttributeError:
+                pass
+
+            # Perform output n dim check
+            if hasattr(out_objects, "shape"):
+                value_shape = out_objects.shape
+                type_shape = out_types_list[0][1].axes
+
+                if type_shape is not None and len(value_shape) != len(type_shape):
+                    name = out_types_list[0][0]
+
+                    raise TypeError(
+                        f"Output shape mismatch occurred for {name} in module {self.__class__.__name__} : \n"
+                        f"Output shape expected = {type_shape} | \n"
+                        f"Output shape found : {value_shape}"
+                    )
+
+        elif metadata.is_singular_container_type:
+            depth = 0 if len(out_objects) == 1 and type(out_objects) is tuple else 1
+            for res in out_objects:
+                self.__attach_neural_type(res, metadata, depth=depth, name=out_types_list[0][0])
+        else:
+            # If more than one item is returned in a return statement, python will wrap
+            # the output with an outer tuple. Therefore there must be a 1:1 correspondence
+            # of the output_neural type (with or without nested structure) to the actual output
+            # (whether it is a single object or a nested structure of objects).
+            # Therefore in such a case, we "start" the DFS at depth 0 - since the recursion is
+            # being applied on 1 neural type : 1 output struct (single or nested output).
+            # Since we are guaranteed that the outer tuple will be built by python,
+            # assuming initial depth of 0 is appropriate.
+            for ind, res in enumerate(out_objects):
+                self.__attach_neural_type(res, metadata, depth=0, name=out_types_list[ind][0])
+
+    def __check_neural_type(self, obj, metadata, depth, name=None):
+        """Checks if the object is of the correct type, and has a matching nest depth."""
+        if isinstance(obj, (tuple, list)):
+            for elem in obj:
+                self.__check_neural_type(elem, metadata, depth + 1, name=name)
+            return  # after processing nest, return to avoid testing nest itself
+
+        type_val = metadata.base_types[name]
+
+        # If nest depth doesn't match neural type structure depth, raise an error
+        if not metadata.ignore_collections and depth != metadata.container_depth[name]:
+            raise TypeError(
+                "While checking input neural types,\n"
+                "Nested depth of value did not match container specification:\n"
+                f"Current nested depth of NeuralType '{name}' ({type_val}): {depth}\n"
+                f"Expected nested depth : {metadata.container_depth[name]}"
+            )
+
+        if hasattr(obj, "neural_type") and type_val.compare(obj.neural_type) not in (
+            NeuralTypeComparisonResult.SAME,
+            NeuralTypeComparisonResult.GREATER,
+        ):
+            raise TypeError(
+                f"{type_val.compare(obj.neural_type)} : \n"
+                f"Input type expected = {type_val} | \n"
+                f"Input type found : {obj.neural_type}"
+            )
+
+        # Perform input n dim check
+        if hasattr(obj, "shape"):
+            value_shape = obj.shape
+            type_shape = type_val.axes
+
+            if type_shape is not None and len(value_shape) != len(type_shape):
+                raise TypeError(
+                    f"Input shape mismatch occurred for {name} in module {self.__class__.__name__} : \n"
+                    f"Input shape expected = {type_shape} | \n"
+                    f"Input shape found : {value_shape}"
+                )
+
+    def __attach_neural_type(self, obj, metadata, depth, name=None):
+        """Attach NeuralType to the object."""
+        if isinstance(obj, (tuple, list)):
+            for elem in obj:
+                self.__attach_neural_type(elem, metadata, depth=depth + 1, name=name)
+            return  # after processing nest, return to avoid argument insertion into nest itself
+
+        type_val = metadata.base_types[name]
+
+        # If nest depth doesn't match neural type structure depth, raise an error
+        if not metadata.ignore_collections and depth != metadata.container_depth[name]:
+            raise TypeError(
+                "While attaching output neural types,\n"
+                "Nested depth of value did not match container specification:\n"
+                f"Current nested depth of NeuralType '{name}' ({type_val}): {depth}\n"
+                f"Expected nested depth : {metadata.container_depth[name]}"
+            )
+
+        try:
+            obj.neural_type = type_val
+        except AttributeError:
+            pass
+
+        # Perform output n dim check
+        if hasattr(obj, "shape"):
+            value_shape = obj.shape
+            type_shape = type_val.axes
+
+            if type_shape is not None and len(value_shape) != len(type_shape):
+                raise TypeError(
+                    f"Output shape mismatch occurred for {name} in module {self.__class__.__name__} : \n"
+                    f"Output shape expected = {type_shape} | \n"
+                    f"Output shape found : {value_shape}"
+                )
+
+
+class Serialization(ABC):
+    """Base class for serialization."""
+
+    @classmethod
+    def from_config_dict(cls, config: "DictConfig", trainer: Optional[Trainer] = None):
+        """Instantiates object using DictConfig-based configuration"""
+        # Resolve the config dict
+        if _HAS_HYDRA:
+            if isinstance(config, DictConfig):
+                config = OmegaConf.to_container(config, resolve=True)
+                config = OmegaConf.create(config)
+                OmegaConf.set_struct(config, True)
+
+            config = mridc.utils.model_utils.maybe_update_config_version(config)  # type: ignore
+
+        # Hydra 0.x API
+        if ("cls" in config or "target" in config) and "params" in config and _HAS_HYDRA:
+            # regular hydra-based instantiation
+            instance = hydra.utils.instantiate(config=config)
+        elif "_target_" in config and _HAS_HYDRA:
+            # regular hydra-based instantiation
+            instance = hydra.utils.instantiate(config=config)
+        else:
+            instance = None
+            prev_error = ""
+
+            # Attempt class path resolution from config `target` class (if it exists)
+            if "target" in config:
+                target_cls = config["target"]  # No guarantee that this is an omegaconf class
+                imported_cls = None
+                try:
+                    # try to import the target class
+                    imported_cls = mridc.utils.model_utils.import_class_by_path(target_cls)  # type: ignore
+
+                    # use subclass instead
+                    if issubclass(cls, imported_cls):
+                        imported_cls = cls
+
+                    if Serialization._inspect_signature_for_trainer(imported_cls):
+                        if trainer is None:
+                            # Create a dummy PL trainer object
+                            cfg_trainer = TrainerConfig(
+                                gpus=1, accelerator="ddp", num_nodes=1, logger=False, checkpoint_callback=False
+                            )
+                            trainer = Trainer(cfg_trainer)
+                        instance = imported_cls(cfg=config, trainer=trainer)  # type: ignore
+                    else:
+                        instance = imported_cls(cfg=config)  # type: ignore
+
+                except Exception as e:
+                    tb = traceback.format_exc()
+                    prev_error = f"Model instantiation failed.\nTarget class: {target_cls}\nError: {e}\n{tb}"
+                    logging.debug(prev_error + "\n falling back to 'cls'.")
+
+            # target class resolution was unsuccessful, fall back to current `cls`
+            if instance is None:
+                try:
+                    if Serialization._inspect_signature_for_trainer(cls):
+                        instance = cls(cfg=config, trainer=trainer)  # type: ignore
+                    else:
+                        instance = cls(cfg=config)  # type: ignore
+                except Exception as e:
+                    # report saved errors, if any, and raise the current error
+                    if prev_error:
+                        logging.error(f"{prev_error}")
+                    raise e from e
+
+        if not hasattr(instance, "_cfg"):
+            instance._cfg = config
+        return instance
+
+    def to_config_dict(self) -> "DictConfig":
+        """Returns object's configuration as a config dictionary"""
+        if hasattr(self, "_cfg") and self._cfg is not None:  # type: ignore
+            # Resolve the config dict
+            if _HAS_HYDRA and isinstance(self._cfg, DictConfig):  # type: ignore
+                config = OmegaConf.to_container(self._cfg, resolve=True)  # type: ignore
+                config = OmegaConf.create(config)
+                OmegaConf.set_struct(config, True)
+
+                config = mridc.utils.model_utils.maybe_update_config_version(config)  # type: ignore
+
+                self._cfg = config
+
+            return self._cfg
+        raise NotImplementedError(
+            "to_config_dict() can currently only return object._cfg but current object does not have it."
+        )
+
+    @classmethod
+    def _inspect_signature_for_trainer(cls, check_cls):
+        """Inspects the signature of the class to see if it accepts a trainer argument."""
+        if hasattr(check_cls, "__init__"):
+            signature = inspect.signature(check_cls.__init__)
+            return "trainer" in signature.parameters
+        return False
+
+
+class FileIO(ABC):
+    """Base class for file IO."""
+
+    def save_to(self, save_path: str):
+        """Saves module/model with weights"""
+        raise NotImplementedError()
+
+    @classmethod
+    def restore_from(
+        cls,
+        restore_path: str,
+        override_config_path: Optional[str] = None,
+        map_location: Optional[torch.device] = None,
+        strict: bool = True,
+        return_config: bool = False,
+        trainer: Optional[Trainer] = None,
+        save_restore_connector: SaveRestoreConnector = None,
+    ):
+        """Restores module/model with weights"""
+        raise NotImplementedError()
+
+    @classmethod
+    def from_config_file(cls, path2yaml_file: str):
+        """
+        Instantiates an instance of mridc Model from YAML config file. Weights will be initialized randomly.
+
+        Parameters
+        ----------
+        path2yaml_file: path to yaml file with model configuration
+
+        Returns
+        -------
+        Model instance.
+        """
+        if issubclass(cls, Serialization):
+            conf = OmegaConf.load(path2yaml_file)
+            return cls.from_config_dict(config=conf)
+        raise NotImplementedError()
+
+    def to_config_file(self, path2yaml_file: str):
+        """
+        Saves current instance's configuration to YAML config file. Weights will not be saved.
+
+        Parameters
+        ----------
+        path2yaml_file: path to yaml file where model configuration will be saved.
+        """
+        if hasattr(self, "_cfg"):
+            self._cfg = mridc.utils.model_utils.maybe_update_config_version(self._cfg)  # type: ignore
+            with open(path2yaml_file, "w", encoding="utf-8") as fout:
+                OmegaConf.save(config=self._cfg, f=fout, resolve=True)
+        else:
+            raise NotImplementedError()
+
+
+@total_ordering
+@dataclass
+class PretrainedModelInfo:
+    """Class to store information about a pretrained model."""
+
+    pretrained_model_name: str
+    description: str
+    location: str
+    class_: Union["Model", None] = None
+    aliases: Union[List[str], None] = None
+
+    def __repr__(self):
+        base = self.__class__.__name__
+        extras = (
+            "pretrained_model_name={pretrained_model_name},\n\t"
+            "description={description},\n\t"
+            "location={location}".format(**self.__dict__)
+        )
+
+        if self.class_ is not None:
+            extras = "{extras},\n\tclass_={class_}".format(extras=extras, **self.__dict__)
+
+        return f"{base}(\n\t{extras}\n)"
+
+    def __hash__(self):
+        return hash(self.location)
+
+    def __eq__(self, other):
+        # another object is equal to self iff its hash equals hash(self), or the model names match
+        return hash(self) == hash(other) or self.pretrained_model_name == other.pretrained_model_name
+
+    def __lt__(self, other):
+        return self.pretrained_model_name < other.pretrained_model_name
+
+
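A small sketch constructing a PretrainedModelInfo entry; the model name, description, and URL below are illustrative assumptions, not registered models.

    from mridc.core.classes.common import PretrainedModelInfo

    info = PretrainedModelInfo(
        pretrained_model_name="zf_example",  # hypothetical name
        description="Example zero-filled reconstruction checkpoint.",  # illustrative
        location="https://example.com/models/zf_example.mridc",  # hypothetical URL
    )
    print(info)  # __repr__ prints the name, description and location
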
+class Model(Typing, Serialization, FileIO, ABC):  # type: ignore
+    """Abstract class offering the interface which should be implemented by all mridc models."""
+
+    @classmethod
+    def list_available_models(cls) -> Optional[PretrainedModelInfo]:
+        """
+        Should list all pre-trained models available.
+
+        Note: There is no check that requires model names and aliases to be unique. In the case of a collision,
+        whatever model (or alias) is listed first in the returned list will be instantiated.
+
+        Returns
+        -------
+        A list of PretrainedModelInfo entries.
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    def get_available_model_names(cls) -> List[str]:
+        """
+        Returns the list of model names available. To get the complete model description use list_available_models().
+
+        Returns
+        -------
+        A list of model names.
+        """
+        return (
+            [model.pretrained_model_name for model in cls.list_available_models()]  # type: ignore
+            if cls.list_available_models() is not None
+            else []
+        )
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_name: str,
+        refresh_cache: bool = False,
+        override_config_path: Optional[str] = None,
+        map_location: Optional[torch.device] = None,
+        strict: bool = True,
+        return_config: bool = False,
+        trainer: Optional[Trainer] = None,
+        save_restore_connector: SaveRestoreConnector = None,
+    ):
+        """
+        Instantiates an instance of mridc. Use restore_from() to instantiate from a local .mridc file.
+
+        Parameters
+        ----------
+        model_name: String key which will be used to find the module.
+        refresh_cache: If set to True, then when fetching from cloud, this will re-fetch the file from cloud even
+            if it is already found in a cache locally.
+        override_config_path: Path to a yaml config that will override the internal config file.
+        map_location: Optional torch.device() to map the instantiated model to a device. By default (None), it will
+            select a GPU if available, falling back to CPU otherwise.
+        strict: Passed to torch.load_state_dict. By default, True.
+        return_config: If set to true, will return just the underlying config of the restored model as an
+            OmegaConf/DictConfig object without instantiating the model.
+        trainer: Optional Trainer object to use for restoring the model.
+        save_restore_connector: Optional SaveRestoreConnector object to use for restoring the model.
+
+        Returns
+        -------
+        A model instance of a particular model class or its underlying config (if return_config is set).
+        """
+        if save_restore_connector is None:
+            save_restore_connector = SaveRestoreConnector()
+
+        location_in_the_cloud = None
+        description = None
+        models = cls.list_available_models()
+        if models is not None:
+            for pretrained_model_info in models:
+                found = False
+                if pretrained_model_info.pretrained_model_name == model_name:
+                    found = True
+                elif pretrained_model_info.aliases is not None:
+                    for alias in pretrained_model_info.aliases:
+                        if alias == model_name:
+                            found = True
+                            break
+                if found:
+                    location_in_the_cloud = pretrained_model_info.location
+                    description = pretrained_model_info.description
+                    class_ = pretrained_model_info.class_
+                    break
+
+        if location_in_the_cloud is None:
+            raise FileNotFoundError(
+                f"Model {model_name} was not found. "
+                f"Check cls.list_available_models() for the list of all available models."
+            )
+        filename = location_in_the_cloud.split("/")[-1]
+        url = location_in_the_cloud.replace(filename, "")
+        cache_dir = Path.joinpath(mridc.utils.model_utils.resolve_cache_dir(), f"{filename[:-5]}")  # type: ignore
+        # If either the description or the location in the cloud changes, this will force a re-download of the model.
+        cache_subfolder = hashlib.sha512((location_in_the_cloud + description).encode("utf-8")).hexdigest()
+        # If the file exists in cache_folder/subfolder, it will be re-used, unless refresh_cache is True.
+        mridc_model_file_in_cache = maybe_download_from_cloud(
+            url=url, filename=filename, cache_dir=cache_dir, subfolder=cache_subfolder, refresh_cache=refresh_cache
+        )
+        logging.info("Instantiating model from pre-trained checkpoint")
+        if class_ is None:
+            class_ = cls
+
+        return class_.restore_from(
+            restore_path=mridc_model_file_in_cache,
+            override_config_path=override_config_path,
+            map_location=map_location,
+            strict=strict,
+            return_config=return_config,
+            trainer=trainer,
+            save_restore_connector=save_restore_connector,
+        )
+
+
+class typecheck:
+    """Decorator to check the type of the input arguments."""
+
+    class TypeState(Enum):
+        """
+        Placeholder to denote the default value of type information provided.
+        If the constructor of this decorator is used to override the class level type definition, this enum value
+        indicates that types will be overridden.
+        """
+
+        UNINITIALIZED = 0
+
+    def __init__(
+        self,
+        input_types: Union[TypeState, Optional[Dict[str, NeuralType]]] = TypeState.UNINITIALIZED,
+        output_types: Union[TypeState, Optional[Dict[str, NeuralType]]] = TypeState.UNINITIALIZED,
+        ignore_collections: bool = False,
+    ):
+        """
+        A decorator which performs input-output neural type checks, and attaches neural types to the output of the
+        function that it wraps.
+
+        Requires that the class inherit from `mridc.core.Typing` in order to perform type checking, and will raise
+        an error if that is not the case.
+
+        # Usage (Class level type support)
+
+        @typecheck()
+        def fn(self, arg1, arg2, ...):
+            ...
+
+        # Usage (Function level type support)
+
+        @typecheck(input_types=..., output_types=...)
+        def fn(self, arg1, arg2, ...):
+            ...
+
+        Points to be noted:
+
+        1) The brackets () in `@typecheck()` are necessary. Without them you will encounter
+           a TypeError: __init__() takes 1 positional argument but X were given.
+        2) The function can take any number of positional arguments during definition.
+           When you call this function, all arguments must be passed using kwargs only.
+        """
+        self.input_types = input_types
+        self.output_types = output_types
+
+        self.input_override = input_types != self.TypeState.UNINITIALIZED
+        self.output_override = output_types != self.TypeState.UNINITIALIZED
+        self.ignore_collections = ignore_collections
+
+    @wrapt.decorator(enabled=is_typecheck_enabled)
+    def __call__(self, wrapped, instance: Typing, args, kwargs):
+        if instance is None:
+            raise RuntimeError("Only classes which inherit mridc.core.Typing can use this decorator !")
+
+        if not isinstance(instance, Typing):
+            raise RuntimeError("Only classes which inherit mridc.core.Typing can use this decorator !")
+
+        if hasattr(instance, "input_ports") or hasattr(instance, "output_ports"):
+            raise RuntimeError(
+                "Typing requires override of `input_types()` and `output_types()`, "
+                "not `input_ports()` and `output_ports()`"
+            )
+
+        # Preserve type information
+        if self.input_types is typecheck.TypeState.UNINITIALIZED:
+            self.input_types = instance.input_types
+
+        if self.output_types is typecheck.TypeState.UNINITIALIZED:
+            self.output_types = instance.output_types
+
+        # Resolve global type or local overridden type
+        input_types = self.input_types if self.input_override else instance.input_types
+        if self.output_override:
+            output_types = self.output_types
+        else:
+            output_types = instance.output_types
+
+        # If types are not defined, skip type checks and just call the wrapped method
+        if input_types is None and output_types is None:
+            return wrapped(*args, **kwargs)
+
+        # Check that all arguments are kwargs
+        if input_types is not None and len(args) > 0:
+            raise TypeError("All arguments must be passed by kwargs only for typed methods")
+
+        # Perform rudimentary input checks here
+        instance._validate_input_types(input_types=input_types, ignore_collections=self.ignore_collections, **kwargs)
+
+        # Call the method - this can be forward, or any other callable method
+        outputs = wrapped(*args, **kwargs)
+
+        instance._attach_and_validate_output_types(
+            output_types=output_types, ignore_collections=self.ignore_collections, out_objects=outputs
+        )
+
+        return outputs
+
+    @staticmethod
+    def set_typecheck_enabled(enabled: bool = True):
+        """Set the global typecheck flag."""
+        global _TYPECHECK_ENABLED
+        _TYPECHECK_ENABLED = enabled
+
+    @staticmethod
+    @contextmanager
+    def disable_checks():
+        """Temporarily disable type checks."""
+        typecheck.set_typecheck_enabled(enabled=False)
+        try:
+            yield
+        finally:
+            typecheck.set_typecheck_enabled(enabled=True)
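
A minimal sketch of the class-level usage described in the docstring above. With input_types/output_types left at the Typing defaults (None), the decorator calls straight through; defining NeuralTypes on the class would enable the actual checks.

    import torch

    from mridc.core.classes.common import Typing, typecheck

    class Doubler(torch.nn.Module, Typing):
        """Toy module; it inherits Typing so @typecheck() is permitted."""

        @typecheck()
        def forward(self, *, x: torch.Tensor) -> torch.Tensor:
            return 2 * x

    m = Doubler()
    y = m(x=torch.ones(3))  # typed methods take kwargs only

    with typecheck.disable_checks():  # temporarily bypass all type checking
        y = m(x=torch.ones(3))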
diff --git a/docs/build/html/_modules/mridc/core/classes/dataset.html b/docs/build/html/_modules/mridc/core/classes/dataset.html
new file mode 100644
index 00000000..f2d969d5
--- /dev/null
+++ b/docs/build/html/_modules/mridc/core/classes/dataset.html
@@ -0,0 +1,201 @@

Source code for mridc.core.classes.dataset

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/classes/dataset.py
+from abc import ABC
+from dataclasses import dataclass
+from typing import Optional
+
+from torch.utils import data
+
+__all__ = ["Dataset", "DatasetConfig", "IterableDataset"]
+
+from mridc.core.classes.common import Serialization, Typing, typecheck
+
+
+
+class Dataset(data.Dataset, Typing, Serialization, ABC):
+    """Dataset with output ports. Please note: Subclasses of Dataset should *not* implement input_types."""
+
+    @staticmethod
+    def _collate_fn(batch):
+        """
+        A default implementation of a collation function.
+        Users should override this method to define custom data loaders.
+        """
+        return data.dataloader.default_collate(batch)
+
+    @typecheck()
+    def collate_fn(self, batch):
+        """
+        This is the method that users pass as the `collate_fn` functor to the DataLoader.
+        The method optionally performs neural type checking and adds types to the outputs.
+
+        Please note: Subclasses of Dataset should *not* implement `input_types`.
+
+        # Usage:
+
+        .. code-block::
+
+            dataloader = torch.utils.data.DataLoader(
+                ....,
+                collate_fn=dataset.collate_fn,
+                ....
+            )
+
+        Returns
+        -------
+        Collated batch, with or without types.
+        """
+        if self.input_types is not None:
+            raise TypeError("Datasets should not implement `input_types` as they are not checked")
+
+        # Simply forward the inner `_collate_fn`
+        return self._collate_fn(batch)
+
+
+class IterableDataset(data.IterableDataset, Typing, Serialization, ABC):
+    """
+    Iterable Dataset with output ports.
+    Please note: Subclasses of IterableDataset should *not* implement input_types.
+    """
+
+    @staticmethod
+    def _collate_fn(batch):
+        """
+        A default implementation of a collation function.
+        Users should override this method to define custom data loaders.
+        """
+        return data.dataloader.default_collate(batch)
+
+    @typecheck()
+    def collate_fn(self, batch):
+        """
+        This is the method that users pass as the `collate_fn` functor to the DataLoader.
+        The method optionally performs neural type checking and adds types to the outputs.
+
+        # Usage:
+
+        .. code-block::
+
+            dataloader = torch.utils.data.DataLoader(
+                ....,
+                collate_fn=dataset.collate_fn,
+                ....
+            )
+
+        Returns
+        -------
+        Collated batch, with or without types.
+        """
+        if self.input_types is not None:
+            raise TypeError("Datasets should not implement `input_types` as they are not checked")
+
+        # Simply forward the inner `_collate_fn`
+        return self._collate_fn(batch)
+
+
+@dataclass
+class DatasetConfig:
+    """Dataset configuration."""
+
+    batch_size: int = 32
+    drop_last: bool = False
+    shuffle: bool = False
+    num_workers: Optional[int] = 0
+    pin_memory: bool = True
+
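A runnable sketch of the collate_fn wiring shown above; the toy dataset and batch size are illustrative.

    import torch
    from torch.utils.data import DataLoader

    from mridc.core.classes.dataset import Dataset, DatasetConfig

    class SquaresDataset(Dataset):
        """Toy dataset returning (x, x**2) pairs; it defines no input_types."""

        def __len__(self):
            return 8

        def __getitem__(self, idx):
            x = torch.tensor(float(idx))
            return x, x**2

    ds = SquaresDataset()
    cfg = DatasetConfig(batch_size=4)
    loader = DataLoader(ds, batch_size=cfg.batch_size, collate_fn=ds.collate_fn)
    for xs, ys in loader:
        print(xs, ys)
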
diff --git a/docs/build/html/_modules/mridc/core/classes/export.html b/docs/build/html/_modules/mridc/core/classes/export.html
new file mode 100644
index 00000000..2488d802
--- /dev/null
+++ b/docs/build/html/_modules/mridc/core/classes/export.html
@@ -0,0 +1,305 @@

Source code for mridc.core.classes.export

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/classes/exportable.py
+
+from abc import ABC
+
+import torch
+from torch.onnx import TrainingMode
+
+from mridc.core.classes.common import typecheck
+from mridc.core.utils.neural_type_utils import get_dynamic_axes, get_io_names
+from mridc.utils import logging
+from mridc.utils.export_utils import (
+    ExportFormat,
+    get_export_format,
+    parse_input_example,
+    replace_for_export,
+    verify_runtime,
+    wrap_forward_method,
+)
+
+__all__ = ["ExportFormat", "Exportable"]
+
+
+
[docs]class Exportable(ABC): + """ + This Interface should be implemented by particular classes derived from mridc.core.NeuralModule or + mridc.core.ModelPT. It gives these entities ability to be exported for deployment to formats such as ONNX. + """ + + @property + def input_module(self): + return self + + @property + def output_module(self): + return self + +
+    def export(
+        self,
+        output: str,
+        input_example=None,
+        verbose=False,
+        export_params=True,
+        do_constant_folding=True,
+        onnx_opset_version=None,
+        try_script: bool = False,
+        training=TrainingMode.EVAL,
+        check_trace: bool = False,
+        use_dynamic_axes: bool = True,
+        dynamic_axes=None,
+        check_tolerance=0.01,
+    ):
+        """
+        Export the module to a file.
+
+        Parameters
+        ----------
+        output: The output file path.
+        input_example: A dictionary of input names and values.
+        verbose: If True, print out the export process.
+        export_params: If True, export the parameters of the module.
+        do_constant_folding: If True, do constant folding.
+        onnx_opset_version: The ONNX opset version to use.
+        try_script: If True, try to export as TorchScript.
+        training: Training mode for the export.
+        check_trace: If True, check the trace of the exported model.
+        use_dynamic_axes: If True, use dynamic axes for the export.
+        dynamic_axes: A dictionary of input names and dynamic axes.
+        check_tolerance: The tolerance for the check_trace.
+        """
+        my_args = locals().copy()
+        my_args.pop("self")
+
+        exportables = []
+        for m in self.modules():  # type: ignore
+            if isinstance(m, Exportable):
+                exportables.append(m)
+
+        qual_name = self.__module__ + "." + self.__class__.__qualname__
+        format = get_export_format(output)
+        output_descr = f"{qual_name} exported to {format}"
+
+        # Pytorch's default for None is too low, can't pass None through
+        if onnx_opset_version is None:
+            onnx_opset_version = 13
+
+        try:
+            # Disable typechecks
+            typecheck.set_typecheck_enabled(enabled=False)
+
+            # Allow user to completely override forward method to export
+            forward_method, old_forward_method = wrap_forward_method(self)
+
+            # Set module mode
+            with torch.onnx.select_model_mode_for_export(
+                self, training
+            ), torch.inference_mode(), torch.jit.optimized_execution(True):
+
+                if input_example is None:
+                    input_example = self.input_module.input_example()
+
+                # Remove i/o examples from args we propagate to enclosed Exportables
+                my_args.pop("output")
+                my_args.pop("input_example")
+
+                # Run (possibly overridden) prepare methods before calling forward()
+                for ex in exportables:
+                    ex._prepare_for_export(**my_args, noreplace=True)
+                self._prepare_for_export(output=output, input_example=input_example, **my_args)
+
+                input_list, input_dict = parse_input_example(input_example)
+                input_names = self.input_names
+                output_names = self.output_names
+                output_example = tuple(self.forward(*input_list, **input_dict))  # type: ignore
+
+                jitted_model = None
+                if try_script:
+                    try:
+                        jitted_model = torch.jit.script(self)
+                    except Exception as e:
+                        logging.error(f"jit.script() failed!\n{e}")
+
+                if format == ExportFormat.TORCHSCRIPT:
+                    if jitted_model is None:
+                        jitted_model = torch.jit.trace_module(
+                            self,
+                            {"forward": tuple(input_list) + tuple(input_dict.values())},
+                            strict=True,
+                            check_trace=check_trace,
+                            check_tolerance=check_tolerance,
+                        )
+                    if not self.training:  # type: ignore
+                        jitted_model = torch.jit.optimize_for_inference(jitted_model)
+                    if verbose:
+                        logging.info(f"JIT code:\n{jitted_model.code}")
+                    jitted_model.save(output)
+                elif format == ExportFormat.ONNX:
+                    if jitted_model is None:
+                        jitted_model = self
+
+                    # dynamic axis is a mapping from input/output_name => list of "dynamic" indices
+                    if dynamic_axes is None and use_dynamic_axes:
+                        dynamic_axes = get_dynamic_axes(self.input_module.input_types, input_names)
+                        dynamic_axes.update(get_dynamic_axes(self.output_module.output_types, output_names))
+
+                    torch.onnx.export(
+                        jitted_model,
+                        input_example,
+                        output,
+                        input_names=input_names,
+                        output_names=output_names,
+                        verbose=verbose,
+                        export_params=export_params,
+                        do_constant_folding=do_constant_folding,
+                        dynamic_axes=dynamic_axes,
+                        opset_version=onnx_opset_version,
+                    )
+
+                    if check_trace:
+                        verify_runtime(output, input_list, input_dict, input_names, output_names, output_example)
+
+                else:
+                    raise ValueError(f"Encountered unknown export format {format}.")
+        finally:
+            typecheck.set_typecheck_enabled(enabled=True)
+            if forward_method:
+                type(self).forward = old_forward_method  # type: ignore
+            self._export_teardown()
+        return [output], [output_descr]
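For orientation, a minimal usage sketch of the export API above; the model class and file names are illustrative, not part of mridc:

.. code-block:: python

    # Hypothetical Exportable model restored from a .mridc archive.
    model = MyReconstructionModel.restore_from("model.mridc")
    model.freeze()

    # The export format is inferred from the output path by get_export_format();
    # export() returns ([output_path], [description]).
    paths, descriptions = model.export("model.onnx", check_trace=True)

    # TorchScript export, optionally trying torch.jit.script() first.
    model.export("model.pt", try_script=True)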
+
+    @property
+    def disabled_deployment_input_names(self):
+        """Implement this method to return a set of input names disabled for export"""
+        return set()
+
+    @property
+    def disabled_deployment_output_names(self):
+        """Implement this method to return a set of output names disabled for export"""
+        return set()
+
+    @property
+    def supported_export_formats(self):
+        """Implement this method to return a set of export formats supported. Default is all types."""
+        return {ExportFormat.ONNX, ExportFormat.TORCHSCRIPT}
+
+    def _prepare_for_export(self, **kwargs):
+        """
+        Override this method to prepare module for export. This is in-place operation.
+        Base version does common necessary module replacements (Apex etc)
+        """
+        if "noreplace" not in kwargs:
+            replace_for_export(self)
+
+    def _export_teardown(self):
+        """
+        Override this method for any teardown code after export.
+        """
+
+    @property
+    def input_names(self):
+        """Implement this method to return a list of input names"""
+        return get_io_names(self.input_module.input_types, self.disabled_deployment_input_names)
+
+    @property
+    def output_names(self):
+        """Implement this method to return a list of output names"""
+        return get_io_names(self.output_module.output_types, self.disabled_deployment_output_names)
diff --git a/docs/build/html/_modules/mridc/core/classes/loss.html b/docs/build/html/_modules/mridc/core/classes/loss.html
new file mode 100644
index 00000000..ecefbadb
--- /dev/null
+++ b/docs/build/html/_modules/mridc/core/classes/loss.html
@@ -0,0 +1,110 @@
+mridc.core.classes.loss — mridc v.0.0.1 documentation

Source code for mridc.core.classes.loss

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/classes/loss.py
+
+import torch
+
+__all__ = ["Loss"]
+
+from mridc.core.classes.common import Serialization, Typing
+
+
+
+class Loss(torch.nn.modules.loss._Loss, Typing, Serialization):
+    """Inherit this class to implement custom loss."""
diff --git a/docs/build/html/_modules/mridc/core/classes/modelPT.html b/docs/build/html/_modules/mridc/core/classes/modelPT.html
new file mode 100644
index 00000000..00b198ac
--- /dev/null
+++ b/docs/build/html/_modules/mridc/core/classes/modelPT.html
@@ -0,0 +1,1318 @@
+mridc.core.classes.modelPT — mridc v.0.0.1 documentation

Source code for mridc.core.classes.modelPT

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/classes/modelPT.py
+
+import copy
+import inspect
+import os
+import uuid
+from abc import abstractmethod
+from os import path
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import hydra
+import torch
+from omegaconf import DictConfig, OmegaConf, open_dict
+from pytorch_lightning import LightningModule, Trainer
+from pytorch_lightning.utilities import rank_zero_only
+
+from mridc.core.classes.common import Model
+
+__all__ = ["ModelPT"]
+
+import mridc.core.optim
+from mridc import package_info
+from mridc.core.connectors.save_restore_connector import SaveRestoreConnector
+from mridc.utils import logging
+from mridc.utils.app_state import AppState
+from mridc.utils.get_rank import is_global_rank_zero
+import mridc.utils
+
+
+
+class ModelPT(LightningModule, Model):
+    """Interface for Pytorch-lightning based mridc models"""
+
+    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
+        """
+        Base class from which all mridc models should inherit.
+
+        Internal global flags that determine core functionality of ModelPT.
+        _MODEL_IS_RESTORED:
+            This flag determines the context of the model - whether the model is currently being
+            restored or not.
+            - When set, it can be assumed that the model will disable all automatic methods -
+            setup_training_data(), setup_validation/test_data() and their multi equivalents.
+            - If a model is being restored from an archive file (tarfile), it can be assumed that
+            under this context, the cwd is *inside* the tarfile itself.
+        _MODEL_RESTORE_PATH:
+            A string path to a file from which the model is being restored.
+            This file can either be a PyTorch Lightning Checkpoint, or an archive (tarfile) that contains
+            artifact objects.
+            If it is an archive file, during restoration, the cwd will be temporarily moved to inside the
+            archive itself.
+
+        Parameters
+        ----------
+        cfg: configuration object. The cfg object should have (optionally) the following sub-configs:
+            - train_ds - to instantiate training dataset
+            - validation_ds - to instantiate validation dataset
+            - test_ds - to instantiate testing dataset
+            - optim - to instantiate optimizer with learning rate scheduler
+        trainer: Pytorch Lightning Trainer instance
+        """
+        if trainer is not None and not isinstance(trainer, Trainer):
+            raise ValueError(
+                f"trainer constructor argument must be either None or pytorch_lightning.Trainer. "
+                f"But got {type(trainer)} instead."
+            )
+        super().__init__()
+
+        # set global vars in AppState
+        app_state = AppState()
+
+        # Convert config to a DictConfig
+        cfg = mridc.utils.model_utils.convert_model_config_to_dict_config(cfg)
+
+        # Convert config to support Hydra 1.0+ instantiation
+        cfg = mridc.utils.model_utils.maybe_update_config_version(cfg)
+
+        if "model" in cfg:
+            raise ValueError(
+                "Creating model config node is forbidden due to collision problem when loading from checkpoint."
+            )
+
+        if "target" not in cfg:
+            # This is for Jarvis service.
+            OmegaConf.set_struct(cfg, False)
+            cfg.target = "{0}.{1}".format(self.__class__.__module__, self.__class__.__name__)
+            OmegaConf.set_struct(cfg, True)
+
+        if "mridc_version" not in cfg:
+            with open_dict(cfg):
+                cfg.mridc_version = package_info.__version__
+
+        self._cfg = cfg
+
+        self.save_hyperparameters("cfg")
+        self._train_dl = None
+        self._validation_dl = None
+        self._test_dl = None
+        self._optimizer_param_groups = None
+        self._optimizer = None
+        self._scheduler = None
+        self.trainer = trainer  # reference required for self.*_rank
+        self._trainer = self.trainer  # alias for backward compatibility
+        self._save_restore_connector = SaveRestoreConnector()
+
+        self._set_model_guid()
+
+        # Set device_id in AppState
+        if torch.cuda.is_available() and torch.cuda.current_device() is not None:
+            app_state.device_id = torch.cuda.current_device()
+
+        if self._cfg is not None and not self._is_model_being_restored():
+            if "train_ds" in self._cfg and self._cfg.train_ds is not None:
+                self.setup_training_data(self._cfg.train_ds)
+
+            if "validation_ds" in self._cfg and self._cfg.validation_ds is not None:
+                self.setup_multiple_validation_data(self._cfg.validation_ds)  # type: ignore
+
+            if "test_ds" in self._cfg and self._cfg.test_ds is not None:
+                self.setup_multiple_test_data(test_data_config=None)  # type: ignore
+
+        else:
+            if "train_ds" in self._cfg and self._cfg.train_ds is not None:  # type: ignore
+                logging.warning(
+                    f"If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() "
+                    f"method and provide a valid configuration file to setup the train data loader.\n"
+                    f"Train config : \n{OmegaConf.to_yaml(self._cfg.train_ds)}"  # type: ignore
+                )
+            if "validation_ds" in self._cfg and self._cfg.validation_ds is not None:  # type: ignore
+                logging.warning(
+                    f"If you intend to do validation, please call the ModelPT.setup_validation_data() or "
+                    f"ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to "
+                    f"setup the validation data loader(s). \n"
+                    f"Validation config : \n{OmegaConf.to_yaml(self._cfg.validation_ds)}"  # type: ignore
+                )
+            if "test_ds" in self._cfg and self._cfg.test_ds is not None:  # type: ignore
+                logging.warning(
+                    f"Please call the ModelPT.setup_test_data() or ModelPT.setup_multiple_test_data() method "
+                    f"and provide a valid configuration file to setup the test data loader(s).\n"
+                    f"Test config : \n{OmegaConf.to_yaml(self._cfg.test_ds)}"  # type: ignore
+                )
+
+        # ModelPT wrappers over subclass implementations
+        self._training_step = mridc.utils.model_utils.wrap_training_step(self.training_step)  # type: ignore
+    def __init_subclass__(cls) -> None:
+        """This method is called when a subclass is created."""
+        cls._save_restore_connector = SaveRestoreConnector()
+ +
[docs] def register_artifact(self, config_path: str, src: str, verify_src_exists: bool = True): + """ + Register model artifacts with this function. These artifacts (files) will be included inside .mridc file when + model.save_to("model.mridc") is called. + + How it works: + 1. It always returns existing absolute path which can be used during Model constructor call EXCEPTION: \ + src is None or "" in which case nothing will be done and src will be returned + 2. It will add (config_path, model_utils.ArtifactItem()) pair to self.artifacts + + If "src" is local existing path, then it will be returned in absolute path form. + elif "src" starts with "mridc_file:unique_artifact_name" .mridc will be untarred to a temporary folder \ + location and an actual existing path will be returned else an error will be raised. + + WARNING: use .register_artifact calls in your models' constructors. + The returned path is not guaranteed to exist after you have exited your model's constructor. + + Parameters + ---------- + config_path: Artifact key. Usually corresponds to the model config. + src: Path to artifact. + verify_src_exists: If set to False, then the artifact is optional and register_artifact will return None \ + even if src is not found. Defaults to True. + + Returns + ------- + If src is not None or empty it always returns absolute path which is guaranteed to exist during model \ + instance life. + """ + if src is None or not src: + return src + + if not hasattr(self, "artifacts"): + self.artifacts: Dict[str, mridc.utils.model_utils.ArtifactItem] = {} + + if self.artifacts is None: + self.artifacts = {} + + if config_path in self.artifacts: + logging.warning( + f"You tried to register an artifact under config key={config_path} but an artifact for " + f"it has already been registered." + ) + + return self._save_restore_connector.register_artifact(self, config_path, src, verify_src_exists)
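A sketch of the intended call pattern from a model constructor; the config key and attribute names here are hypothetical:

.. code-block:: python

    class MyModel(ModelPT):
        def __init__(self, cfg, trainer=None):
            super().__init__(cfg=cfg, trainer=trainer)
            # Register a file so it is packaged into the .mridc archive on save_to();
            # the returned absolute path is only guaranteed to exist during the
            # life of this model instance.
            vocab_file = self.register_artifact("tokenizer.vocab_file", cfg.vocab_file)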
+ +
+    def save_to(self, save_path: str):
+        """
+        Saves model instance (weights and configuration) into .mridc file. You can use "restore_from" method to fully
+        restore instance from .mridc file. .mridc file is an archive (tar.gz) with the following:
+        - model_config.yaml - model configuration in .yaml format. You can deserialize this into cfg argument for \
+        model's constructor
+        - model_weights.ckpt - model checkpoint
+
+        Parameters
+        ----------
+        save_path: Path to .mridc file where model instance should be saved.
+        """
+
+        def maybe_make_save_dir(_path: "Path"):
+            """Creates directory if it does not exist"""
+            if not _path.parent.exists():
+                _path.parent.mkdir(parents=True)
+
+        save_path = Path(save_path).expanduser().resolve()  # type: ignore
+        app_state = AppState()
+        if app_state.model_parallel_size is not None:
+            if app_state.model_parallel_size > 1 and type(self._save_restore_connector) is SaveRestoreConnector:
+                raise ValueError(
+                    "Default mridc SaveRestoreConnector will not work in model parallel mode. You should use a "
+                    "connector which supports model parallel mode. You can also use a custom one."
+                )
+            if app_state.data_parallel_rank == 0:
+                maybe_make_save_dir(Path(save_path))
+            # connector checks for ranks properly, no need to check here
+            self._save_restore_connector.save_to(self, str(save_path))  # downstream tasks expect str, not Path
+        elif is_global_rank_zero():
+            maybe_make_save_dir(Path(save_path))
+            self._save_restore_connector.save_to(self, str(save_path))  # downstream tasks expect str, not Path
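A minimal save/restore round trip, assuming a hypothetical ModelPT subclass and paths:

.. code-block:: python

    # Serializes model_config.yaml and model_weights.ckpt into one tar.gz archive;
    # missing parent directories are created automatically on rank zero.
    model.save_to("checkpoints/model.mridc")

    # Later, rebuild the same model (weights + config) from the archive.
    restored = MyModel.restore_from("checkpoints/model.mridc")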
+ +
+    @classmethod
+    def restore_from(  # type: ignore
+        cls,
+        restore_path: str,
+        override_config_path: Optional[Union[OmegaConf, str]] = None,
+        map_location: Optional[torch.device] = None,
+        strict: bool = True,
+        return_config: bool = False,
+        save_restore_connector: SaveRestoreConnector = None,
+        trainer: Optional[Trainer] = None,
+    ):
+        """
+        Restores model instance (weights and configuration) from .mridc file.
+
+        Parameters
+        ----------
+        restore_path: Path to .mridc file from which the model should be instantiated.
+        override_config_path: Path to a yaml config that will override the internal config file, or an \
+        OmegaConf/DictConfig object representing the model config.
+        map_location: Optional torch.device() to map the instantiated model to a device. By default (None), it will \
+        select a GPU if available, falling back to CPU otherwise.
+        strict: Passed to load_state_dict. By default, True.
+        return_config: If set to true, will return just the underlying config of the restored model as an \
+        OmegaConf/DictConfig object without instantiating the model.
+        trainer: Optional, a pytorch lightning Trainer object that will be forwarded to the instantiated model's \
+        constructor.
+        save_restore_connector: Can be overridden to add custom save and restore logic.
+
+        Example
+        -------
+        .. code-block::
+
+            model = mridc.collections.asr.models.EncDecCTCModel.restore_from('asr.mridc')
+            assert isinstance(model, mridc.collections.asr.models.EncDecCTCModel)
+
+        Returns
+        -------
+        An instance of type cls or its underlying config (if return_config is set).
+        """
+        if save_restore_connector is None:
+            save_restore_connector = SaveRestoreConnector()
+
+        restore_path = os.path.abspath(os.path.expanduser(restore_path))
+        if not path.exists(restore_path):
+            raise FileNotFoundError(f"Can't find {restore_path}")
+
+        app_state = AppState()
+        app_state.model_restore_path = restore_path
+
+        cls.update_save_restore_connector(save_restore_connector)
+        instance = cls._save_restore_connector.restore_from(
+            cls, restore_path, override_config_path, map_location, strict, return_config, trainer
+        )
+        if isinstance(instance, ModelPT):
+            instance._save_restore_connector = save_restore_connector
+        return instance
+ +
[docs] @classmethod + def load_from_checkpoint( + cls, + checkpoint_path: str, + *args, + map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, + hparams_file: Optional[str] = None, + strict: bool = True, + **kwargs, + ): + """ + Loads ModelPT from checkpoint, with some maintenance of restoration. + For documentation, please refer to LightningModule.load_from_checkpoint() documentation. + """ + checkpoint = None + + try: + cls._set_model_restore_state(is_being_restored=True) + + checkpoint = super().load_from_checkpoint( + checkpoint_path=checkpoint_path, + *args, # type: ignore + map_location=map_location, + hparams_file=hparams_file, + strict=strict, + **kwargs, + ) + + finally: + cls._set_model_restore_state(is_being_restored=False) + return checkpoint
+ +
+    @abstractmethod
+    def setup_training_data(self, train_data_config: Union[DictConfig, Dict]):
+        """Sets up the data loader to be used in training."""
+ +
+    @abstractmethod
+    def setup_validation_data(self, val_data_config: Union[DictConfig, Dict]):
+        """Sets up the data loader to be used in validation."""
+ +
+    def setup_test_data(self, test_data_config: Union[DictConfig, Dict]):
+        """(Optionally) Sets up the data loader to be used in testing."""
+        raise NotImplementedError()
+ +
+    def setup_multiple_validation_data(self, val_data_config: Union[DictConfig, Dict]):
+        """(Optionally) Sets up the data loader to be used in validation, with support for multiple data loaders."""
+        # Set some placeholders, overridden by the helper method
+        self._val_dl_idx = 0
+        self.validation_names = None
+
+        # preserve config
+        self._update_dataset_config(dataset_name="validation", config=val_data_config)
+
+        try:
+            self._multi_dataset_mode = True
+            mridc.utils.model_utils.resolve_validation_dataloaders(model=self)
+        finally:
+            self._multi_dataset_mode = False
+
+        if (
+            self.validation_names is None
+            and self._validation_dl is not None
+            and type(self._validation_dl) in [list, tuple]
+        ):
+            self.validation_names = [f"val_{idx}_" for idx in range(len(self._validation_dl))]
+ +
+    def setup_multiple_test_data(self, test_data_config: Union[DictConfig, Dict]):
+        """(Optionally) Sets up the data loader to be used in testing, with support for multiple data loaders."""
+        # Set some placeholders, overridden by the helper method
+        self._test_dl_idx = 0
+        self.test_names = None
+        self._test_dl = None  # type: ignore
+
+        # preserve config
+        self._update_dataset_config(dataset_name="test", config=test_data_config)
+
+        try:
+            self._multi_dataset_mode = True
+            mridc.utils.model_utils.resolve_test_dataloaders(model=self)
+        finally:
+            self._multi_dataset_mode = False
+
+        if self.test_names is None and self._test_dl is not None and type(self._test_dl) in [list, tuple]:
+            self.test_names = [f"test_{idx}_" for idx in range(len(self._test_dl))]
+ +
[docs] def setup_optimization(self, optim_config: Optional[Union[DictConfig, Dict]] = None): + """ + Prepares an optimizer from a string name and its optional config parameters. + + Parameters + ---------- + optim_config: A dictionary containing the following keys: + - lr: mandatory key for learning rate. Will raise ValueError if not provided. + - optimizer: string name pointing to one of the available optimizers in the registry. If not provided, \ + defaults to "adam". + - opt_args: Optional list of strings, in the format "arg_name=arg_value". The list of "arg_value" will \ + be parsed and a dictionary of optimizer kwargs will be built and supplied to instantiate the optimizer. + + Returns + ------- + An instance of an optimizer. + """ + if self._optimizer_param_groups is None: + self.setup_optimizer_param_groups() + + # If config was not explicitly provided, use default + if optim_config is None: + # See if internal config has 'optim' namespace + if self._cfg is not None and hasattr(self._cfg, "optim"): + optim_config = self._cfg.optim + + # If config is still None, or internal config has no Optim, return without instantiation + if optim_config is None: + logging.info("No optimizer config provided, therefore no optimizer was created") + return + # Preserve the configuration + if not isinstance(optim_config, DictConfig): + optim_config = OmegaConf.create(optim_config) + + # See if internal config has `optim` namespace before preservation + if self._cfg is not None and hasattr(self._cfg, "optim"): + if self._cfg.optim is None: + self._cfg.optim = copy.deepcopy(optim_config) + else: + with open_dict(self._cfg.optim): + self._cfg.optim = copy.deepcopy(optim_config) + + # Setup optimizer and scheduler + if optim_config is not None and isinstance(optim_config, DictConfig): + optim_config = OmegaConf.to_container(optim_config, resolve=True) + + if self._trainer is None: + logging.warning("Trainer wasn't specified in model constructor. Make sure that you really wanted it.") + + if "sched" in optim_config and self._trainer is not None: + if not isinstance(self._trainer.accumulate_grad_batches, int): + raise ValueError("We do not currently support gradient accumulation that is not an integer.") + if self._trainer.max_steps is None or self.trainer.max_steps < 0: # type: ignore + # Store information needed to calculate max_steps + optim_config["sched"]["t_max_epochs"] = self._trainer.max_epochs + optim_config["sched"]["t_accumulate_grad_batches"] = self._trainer.accumulate_grad_batches + optim_config["sched"]["t_limit_train_batches"] = self._trainer.limit_train_batches + if self._trainer.accelerator is None: + optim_config["sched"]["t_num_workers"] = self._trainer.num_devices or 1 + elif self._trainer.accelerator == "ddp_cpu": + optim_config["sched"]["t_num_workers"] = self._trainer.num_devices * self._trainer.num_nodes + elif self._trainer.accelerator == "ddp": + optim_config["sched"]["t_num_workers"] = self._trainer.num_devices * self._trainer.num_nodes + else: + logging.warning( + f"The lightning trainer received accelerator: {self._trainer.accelerator}. We " + "recommend to use 'ddp' instead." 
+ ) + optim_config["sched"]["t_num_workers"] = self._trainer.num_devices * self._trainer.num_nodes + else: + optim_config["sched"]["max_steps"] = self._trainer.max_steps + + # Force into DictConfig from nested structure + optim_config = OmegaConf.create(optim_config) + # Get back nested dict so we its mutable + optim_config = OmegaConf.to_container(optim_config, resolve=True) + + # Extract scheduler config if inside optimizer config + if "sched" in optim_config: + scheduler_config = optim_config.pop("sched") + else: + scheduler_config = None + + # Check if caller provided optimizer name, default to Adam otherwise + optimizer_cls = optim_config.get("_target_", None) + + if optimizer_cls is None: + # Try to get optimizer name for dynamic resolution, defaulting to Adam + optimizer_name = optim_config.get("name", "adam") + elif inspect.isclass(optimizer_cls): + optimizer_name = optimizer_cls.__name__.lower() + else: + # resolve the class name (lowercase) from the class path if not provided + optimizer_name = optimizer_cls.split(".")[-1].lower() + + # We are guaranteed to have lr since it is required by the argparser + # But maybe user forgot to pass it to this function + lr = optim_config.get("lr", None) + + # Check if caller has optimizer kwargs, default to empty dictionary + if "args" in optim_config: + optimizer_args = optim_config.pop("args") + optimizer_args = mridc.core.optim.optimizers.parse_optimizer_args(optimizer_name, optimizer_args) + else: + optimizer_args = copy.deepcopy(optim_config) + + # Remove extra parameters from optimizer_args nest + # Assume all other parameters are to be passed into optimizer constructor + optimizer_args.pop("name", None) + optimizer_args.pop("cls", None) + optimizer_args.pop("lr", None) + + # Adaptive schedulers don't need `lr` + if lr is not None: + optimizer_args["lr"] = lr + + # Actually instantiate the optimizer + if optimizer_cls is not None: + if inspect.isclass(optimizer_cls): + optimizer = optimizer_cls(self._optimizer_param_groups, **optimizer_args) + logging.info("Optimizer config = %s", str(optimizer)) + + self._optimizer = optimizer + + else: + # Attempt class path resolution + try: + optimizer_cls = OmegaConf.create({"_target_": optimizer_cls}) + if lr is not None: + optimizer_config = {"lr": lr} + else: + optimizer_config = {} + optimizer_config.update(optimizer_args) + + optimizer_instance = hydra.utils.instantiate( + optimizer_cls, self._optimizer_param_groups, **optimizer_config + ) # type: DictConfig + + logging.info("Optimizer config = %s", str(optimizer_instance)) + + self._optimizer = optimizer_instance + + except Exception as e: + logging.error( + "Could not instantiate class path - {} with kwargs {}".format( + optimizer_cls, str(optimizer_config) + ) + ) + raise e + + else: + optimizer = mridc.core.optim.optimizers.get_optimizer(optimizer_name) + optimizer = optimizer(self._optimizer_param_groups, **optimizer_args) + + logging.info("Optimizer config = %s", str(optimizer)) + + self._optimizer = optimizer + + # Try to instantiate scheduler for optimizer + self._scheduler = mridc.core.optim.lr_scheduler.prepare_lr_scheduler( # type: ignore + optimizer=self._optimizer, scheduler_config=scheduler_config, train_dataloader=self._train_dl + ) + + # Return the optimizer with/without scheduler + # This return allows multiple optimizers or schedulers to be created + return self._optimizer, self._scheduler
+ +
[docs] def setup_optimizer_param_groups(self): + """ + Used to create param groups for the optimizer. As an example, this can be used to specify per-layer learning + rates: + + .. code-block:: + + optim.SGD([ + {'params': model.base.parameters()}, + {'params': model.classifier.parameters(), 'lr': 1e-3} + ], lr=1e-2, momentum=0.9) + + See https://pytorch.org/docs/stable/optim.html for more information. By default, ModelPT will use + self.parameters(). Override this method to add custom param groups. + """ + param_groups = None + if hasattr(self, "parameters"): + param_groups = [{"params": self.parameters()}] + self._optimizer_param_groups = param_groups
+ +
[docs] def configure_optimizers(self): + """Configure optimizers and schedulers for training.""" + self.setup_optimization() + + if self._scheduler is None: + return self._optimizer + + return [self._optimizer], [self._scheduler]
+ +
[docs] def train_dataloader(self): + """Return the training dataloader.""" + return self._train_dl if self._train_dl is not None else None
+ +
[docs] def val_dataloader(self): + """Return the validation dataloader.""" + return self._validation_dl if self._validation_dl is not None else None
+ +
[docs] def test_dataloader(self): + """Return the test dataloader.""" + return self._test_dl if self._test_dl is not None else None
+ +
[docs] def validation_epoch_end( + self, outputs: Union[List[Dict[str, torch.Tensor]], List[List[Dict[str, torch.Tensor]]]] + ) -> Optional[Dict[str, Dict[str, torch.Tensor]]]: + """ + Default DataLoader for Validation set which automatically supports multiple data loaders + via `multi_validation_epoch_end`. + If multi dataset support is not required, override this method entirely in base class. + In such a case, there is no need to implement `multi_validation_epoch_end` either. + + .. note:: + If more than one data loader exists, and they all provide `val_loss`, + only the `val_loss` of the first data loader will be used by default. + This default can be changed by passing the special key `val_dl_idx: int` + inside the `validation_ds` config. + + Parameters + ---------- + outputs: Single or nested list of tensor outputs from one or more data loaders. + + Returns + ------- + A dictionary containing the union of all items from individual data_loaders, along with merged logs from all + data loaders. + """ + # Case where we dont provide data loaders + if outputs is not None and len(outputs) == 0: + return {} + + # Case where we provide exactly 1 data loader + if type(outputs[0]) is dict: + output_dict = self.multi_validation_epoch_end(outputs, dataloader_idx=0) # type: ignore + + if output_dict is not None and "log" in output_dict: + self.log_dict(output_dict.pop("log"), on_epoch=True) # type: ignore + + return output_dict + + output_dict = {"log": {}} + + # The output is a list of list of dicts, outer list corresponds to dataloader idx + for dataloader_idx, val_outputs in enumerate(outputs): # type: ignore + # Get prefix and dispatch call to multi epoch end + dataloader_prefix = self.get_validation_dataloader_prefix(dataloader_idx) + dataloader_logs = self.multi_validation_epoch_end(val_outputs, dataloader_idx=dataloader_idx) + + # If result was not provided, generate empty dict + dataloader_logs: Dict[Any, Any] = dataloader_logs or {} # type: ignore + + # Perform `val_loss` resolution first (if provided outside logs) + if ("val_loss" in dataloader_logs and "val_loss" not in output_dict) and ( # type: ignore + dataloader_idx == self._val_dl_idx + ): + output_dict["val_loss"] = dataloader_logs["val_loss"] # type: ignore + + # For every item in the result dictionary + for k, v in dataloader_logs.items(): # type: ignore + # If the key is `log` + if k == "log": + # Parse every element of the log, and attach the prefix name of the data loader + log_dict = {} + + for k_log, v_log in v.items(): + # If we are logging the metric, but dont provide it at result level, + # store it twice - once in log and once in result level. 
+ # Also mark log with prefix name to avoid log level clash with other data loaders + if k_log not in output_dict["log"] and dataloader_idx == self._val_dl_idx: # type: ignore + new_k_log = k_log + + # Also insert duplicate key with prefix for ease of comparison / avoid name clash + log_dict[dataloader_prefix + k_log] = v_log + + else: + # Simply prepend prefix to key and save + new_k_log = dataloader_prefix + k_log + + # Store log value + log_dict[new_k_log] = v_log + + # Update log storage of individual data loader + output_logs = output_dict["log"] # type: ignore + output_logs.update(log_dict) + + # Update global log storage + output_dict["log"] = output_logs # type: ignore + + else: + # If any values are stored outside 'log', simply prefix name and store + new_k = dataloader_prefix + k + output_dict[new_k] = v # type: ignore + + if "log" in output_dict: # type: ignore + self.log_dict(output_dict.pop("log"), on_epoch=True) # type: ignore + + # return everything else + return output_dict
+ +
+    def test_epoch_end(
+        self, outputs: Union[List[Dict[str, torch.Tensor]], List[List[Dict[str, torch.Tensor]]]]
+    ) -> Optional[Dict[str, Dict[str, torch.Tensor]]]:
+        """
+        Default DataLoader for Test set which automatically supports multiple data loaders
+        via `multi_test_epoch_end`.
+        If multi dataset support is not required, override this method entirely in base class.
+        In such a case, there is no need to implement `multi_test_epoch_end` either.
+
+        .. note::
+            If more than one data loader exists, and they all provide `test_loss`,
+            only the `test_loss` of the first data loader will be used by default.
+            This default can be changed by passing the special key `test_dl_idx: int`
+            inside the `test_ds` config.
+
+        Parameters
+        ----------
+        outputs: Single or nested list of tensor outputs from one or more data loaders.
+
+        Returns
+        -------
+        A dictionary containing the union of all items from individual data_loaders, along with merged logs from all
+        data loaders.
+        """
+        # Case where we don't provide data loaders
+        if outputs is not None and len(outputs) == 0:
+            return {}
+
+        # Case where we provide exactly 1 data loader
+        if type(outputs[0]) is dict:
+            output_dict = self.multi_test_epoch_end(outputs, dataloader_idx=0)  # type: ignore
+
+            if output_dict is not None and "log" in output_dict:
+                self.log_dict(output_dict.pop("log"), on_epoch=True)  # type: ignore
+
+            return output_dict
+
+        output_dict = {"log": {}}
+
+        # The output is a list of lists of dicts, the outer list corresponds to the dataloader idx
+        for dataloader_idx, test_outputs in enumerate(outputs):  # type: ignore
+            # Get prefix and dispatch call to multi epoch end
+            dataloader_prefix = self.get_test_dataloader_prefix(dataloader_idx)
+            dataloader_logs = self.multi_test_epoch_end(test_outputs, dataloader_idx=dataloader_idx)
+
+            # If result was not provided, generate empty dict
+            dataloader_logs = dataloader_logs or {}  # type: ignore
+
+            # Perform `test_loss` resolution first (if provided outside logs)
+            if (
+                "test_loss" in dataloader_logs
+                and "test_loss" not in output_dict  # type: ignore
+                and dataloader_idx == self._test_dl_idx
+            ):  # type: ignore
+                output_dict["test_loss"] = dataloader_logs["test_loss"]  # type: ignore
+
+            # For every item in the result dictionary
+            for k, v in dataloader_logs.items():
+                # If the key is `log`
+                if k == "log":
+                    # Parse every element of the log, and attach the prefix name of the data loader
+                    log_dict = {}
+                    for k_log, v_log in v.items():
+                        # If we are logging the loss, but don't provide it at result level,
+                        # store it twice - once in log and once in result level.
+                        # Also mark log with prefix name to avoid log level clash with other data loaders
+                        if k_log not in output_dict["log"] and dataloader_idx == self._test_dl_idx:  # type: ignore
+                            new_k_log = k_log
+
+                            # Also insert duplicate key with prefix for ease of comparison / avoid name clash
+                            log_dict[dataloader_prefix + k_log] = v_log
+
+                        else:
+                            # Simply prepend prefix to key and save
+                            new_k_log = dataloader_prefix + k_log
+
+                        log_dict[new_k_log] = v_log
+
+                    # Update log storage of individual data loader
+                    output_logs = output_dict.get("log", {})  # type: ignore
+                    output_logs.update(log_dict)
+
+                    # Update global log storage
+                    output_dict["log"] = output_logs  # type: ignore
+
+                else:
+                    # If any values are stored outside 'log', simply prefix name and store
+                    new_k = dataloader_prefix + k
+                    output_dict[new_k] = v  # type: ignore
+
+        if "log" in output_dict:  # type: ignore
+            self.log_dict(output_dict.pop("log"), on_epoch=True)  # type: ignore
+
+        # return everything else
+        return output_dict
+ +
+    @staticmethod
+    def multi_validation_epoch_end(
+        outputs: Union[object, List[Dict[str, torch.Tensor]], None], dataloader_idx: int = 0
+    ) -> None:
+        """
+        Adds support for multiple validation datasets. Should be overridden by subclass, to obtain appropriate logs
+        for each of the dataloaders.
+
+        Parameters
+        ----------
+        outputs: Same as that provided by LightningModule.validation_epoch_end() for a single dataloader.
+        dataloader_idx: int representing the index of the dataloader.
+
+        Returns
+        -------
+        A dictionary of values, optionally containing a sub-dict `log`, such that the values in the log will be
+        pre-pended by the dataloader prefix.
+        """
+        logging.warning(
+            "Multi data loader support has been enabled, but `multi_validation_epoch_end(outputs, dataloader_idx)` "
+            "has not been implemented.\n"
+            "If you require multi data loader support for validation sets, please override this method.\n"
+            "If you do not require multi data loader support, please instead override "
+            "`validation_epoch_end(outputs)`."
+        )
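A sketch of a subclass override that aggregates a per-dataloader loss, as the warning above suggests (the base method is shown as a staticmethod; a subclass may provide a regular method instead):

.. code-block:: python

    def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
        # Average the per-batch losses of this particular dataloader.
        val_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        return {"val_loss": val_loss, "log": {"val_loss": val_loss}}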
+ +
+    @staticmethod
+    def multi_test_epoch_end(outputs: Union[object, List[Dict[str, torch.Tensor]]], dataloader_idx: int = 0) -> None:
+        """
+        Adds support for multiple test datasets. Should be overridden by subclass, to obtain appropriate logs for each
+        of the dataloaders.
+
+        Parameters
+        ----------
+        outputs: Same as that provided by LightningModule.test_epoch_end() for a single dataloader.
+        dataloader_idx: int representing the index of the dataloader.
+
+        Returns
+        -------
+        A dictionary of values, optionally containing a sub-dict `log`, such that the values in the log will be
+        pre-pended by the dataloader prefix.
+        """
+        logging.warning(
+            "Multi data loader support has been enabled, but `multi_test_epoch_end(outputs, dataloader_idx)` has not "
+            "been implemented.\n"
+            "If you require multi data loader support for test sets, please override this method.\n"
+            "If you do not require multi data loader support, please instead override `test_epoch_end(outputs)`."
+        )
+ +
[docs] def get_validation_dataloader_prefix(self, dataloader_idx: int = 0) -> str: + """Get the name of one or more data loaders, which will be prepended to all logs.""" + return self.validation_names[dataloader_idx] # type: ignore
+ +
[docs] def get_test_dataloader_prefix(self, dataloader_idx: int = 0) -> str: + """Get the name of one or more data loaders, which will be prepended to all logs.""" + return self.test_names[dataloader_idx] # type: ignore
+ +
[docs] def load_part_of_state_dict(self, state_dict, include, exclude, load_from_string): + """Load part of the state dict.""" + excluded_param_names = [] + # create dict + dict_to_load = {} + for k, v in state_dict.items(): + should_add = any(p in k for p in include) + # except for if any string from exclude is present + for e in exclude: + if e in k: + excluded_param_names.append(k) + should_add = False + break + if should_add: + dict_to_load[k] = v + + # Restore checkpoint part into current model + self.load_state_dict(dict_to_load, strict=False) # type: ignore + logging.info(f"Model checkpoint partially restored from {load_from_string}") + + if excluded_param_names: + logging.info( + f"The following parameters were excluded from loading from {load_from_string} : {excluded_param_names}" + ) + logging.info("Make sure that this is what you wanted!")
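A usage sketch; the substring filters and source model are illustrative:

.. code-block:: python

    # Copy only encoder weights from another model, skipping anything whose
    # parameter name contains "norm" (plain substring matching, as above).
    model.load_part_of_state_dict(
        state_dict=other_model.state_dict(),
        include=["encoder"],
        exclude=["norm"],
        load_from_string="other_model.mridc",
    )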
+ +
+    @rank_zero_only
+    def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: str = "cpu"):
+        """
+        Initializes a given model with the parameters obtained via specific config arguments. The state dict of the \
+        provided model will be updated with `strict=False` setting to prevent requirement of exact model parameters \
+        matching.
+
+        Initializations
+
+        init_from_mridc_model: Str path to a .mridc model, which will be instantiated in order to extract the state \
+        dict.
+
+        init_from_pretrained_model: Str name of a pretrained model checkpoint (obtained via cloud). The model will \
+        be downloaded (or a cached copy will be used), instantiated and then its state dict will be extracted.
+
+        init_from_ptl_ckpt: Str name of a Pytorch Lightning checkpoint file. It will be loaded and the state dict \
+        will be extracted.
+
+        Parameters
+        ----------
+        cfg: The config used to instantiate the model. It need only contain one of the above keys.
+        map_location: str or torch.device() which represents where the intermediate state dict (from the pretrained \
+        model or checkpoint) will be loaded.
+        """
+        args = ["init_from_mridc_model", "init_from_pretrained_model", "init_from_ptl_ckpt"]
+        arg_matches = [(1 if arg in cfg and arg is not None else 0) for arg in args]
+
+        if sum(arg_matches) == 0:
+            # model weights do not need to be restored
+            return
+
+        if sum(arg_matches) > 1:
+            raise ValueError(
+                f"Cannot pass more than one model initialization arguments to config!\n"
+                f"Found : {[args[idx] for idx, arg_present in enumerate(arg_matches) if arg_present]}"
+            )
+
+        if "init_from_mridc_model" in cfg and cfg.init_from_mridc_model is not None:  # type: ignore
+            with open_dict(cfg):  # type: ignore
+                if isinstance(cfg.init_from_mridc_model, str):  # type: ignore
+                    model_path = cfg.init_from_mridc_model  # type: ignore
+                    # Restore model
+                    restored_model = self.restore_from(
+                        model_path, map_location=map_location, strict=cfg.get("init_strict", True)  # type: ignore
+                    )
+                    # Restore checkpoint into current model
+                    self.load_state_dict(restored_model.state_dict(), strict=False)
+                    logging.info(f"Model checkpoint restored from mridc file with path : `{model_path}`")
+                elif isinstance(cfg.init_from_mridc_model, (DictConfig, dict)):  # type: ignore
+                    model_load_dict = cfg.init_from_mridc_model  # type: ignore
+                    for model_load_cfg in model_load_dict.values():
+                        model_path = model_load_cfg.path
+                        # Restore model
+                        restored_model = self.restore_from(
+                            model_path, map_location=map_location, strict=cfg.get("init_strict", True)  # type: ignore
+                        )
+
+                        include = model_load_cfg.pop("include", [""])
+                        exclude = model_load_cfg.pop("exclude", [])
+
+                        self.load_part_of_state_dict(
+                            restored_model.state_dict(), include, exclude, f"mridc file with path `{model_path}`"
+                        )
+                else:
+                    raise TypeError("Invalid type: init_from_mridc_model is not a string or a dict!")
+
+        if "init_from_pretrained_model" in cfg and cfg.init_from_pretrained_model is not None:  # type: ignore
+            with open_dict(cfg):  # type: ignore
+                # Restore model
+                if isinstance(cfg.init_from_pretrained_model, str):  # type: ignore
+                    model_name = cfg.pop("init_from_pretrained_model")  # type: ignore
+
+                    # Check if model is being resumed or not - only works if `Trainer` is attached to model
+                    if hasattr(self, "trainer") and self.trainer is not None:
+                        trainer = self.trainer
+                        if (
+                            hasattr(trainer, "resume_from_checkpoint")
+                            and trainer.checkpoint_connector.resume_checkpoint_path is not None
+                        ):
+                            logging.info(
+                                "Model training is being resumed via Pytorch Lightning.\n"
+                                "Initialization from pretrained model (via cloud) will be skipped."
+                            )
+                            return
+
+                    restored_model = self.from_pretrained(
+                        model_name, map_location=map_location, strict=cfg.get("init_strict", True)  # type: ignore
+                    )
+
+                    # Restore checkpoint into current model
+                    self.load_state_dict(restored_model.state_dict(), strict=False)
+                    logging.info(f"Model checkpoint restored from pretrained checkpoint with name : `{model_name}`")
+                elif isinstance(cfg.init_from_pretrained_model, (DictConfig, dict)):  # type: ignore
+                    model_load_dict = cfg.init_from_pretrained_model  # type: ignore
+                    for model_load_cfg in model_load_dict.values():
+                        model_name = model_load_cfg.name
+                        # Restore model
+                        restored_model = self.from_pretrained(
+                            model_name, map_location=map_location, strict=cfg.get("init_strict", True)  # type: ignore
+                        )
+
+                        include = model_load_cfg.pop("include", [""])
+                        exclude = model_load_cfg.pop("exclude", [])
+
+                        self.load_part_of_state_dict(
+                            restored_model.state_dict(),
+                            include,
+                            exclude,
+                            f"pretrained checkpoint with name `{model_name}`",
+                        )
+                else:
+                    raise TypeError("Invalid type: init_from_pretrained_model is not a string or a dict!")
+
+        if "init_from_ptl_ckpt" in cfg and cfg.init_from_ptl_ckpt is not None:  # type: ignore
+            with open_dict(cfg):  # type: ignore
+                if isinstance(cfg.init_from_ptl_ckpt, str):  # type: ignore
+                    # Restore checkpoint
+                    ckpt_path = cfg.pop("init_from_ptl_ckpt")  # type: ignore
+                    ckpt = torch.load(ckpt_path, map_location=map_location)
+
+                    # Restore checkpoint into current model
+                    self.load_state_dict(ckpt["state_dict"], strict=False)
+                    logging.info(
+                        f"Model checkpoint restored from pytorch lightning checkpoint with path : `{ckpt_path}`"
+                    )
+                elif isinstance(cfg.init_from_ptl_ckpt, (DictConfig, dict)):  # type: ignore
+                    model_load_dict = cfg.init_from_ptl_ckpt  # type: ignore
+                    for model_load_cfg in model_load_dict.values():
+                        ckpt_path = model_load_cfg.path
+                        # Restore model
+                        ckpt = torch.load(ckpt_path, map_location=map_location)
+
+                        include = model_load_cfg.pop("include", [""])
+                        exclude = model_load_cfg.pop("exclude", [])
+
+                        self.load_part_of_state_dict(
+                            ckpt["state_dict"],
+                            include,
+                            exclude,
+                            f"pytorch lightning checkpoint with path `{ckpt_path}`",
+                        )
+                else:
+                    raise TypeError("Invalid type: init_from_ptl_ckpt is not a string or a dict!")
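An illustrative config for the method above; exactly one of the three init_from_* keys may be set, and the checkpoint path is hypothetical:

.. code-block:: python

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({"init_from_ptl_ckpt": "checkpoints/last.ckpt"})
    model.maybe_init_from_pretrained_checkpoint(cfg)  # loads with strict=False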
+ +
[docs] def teardown(self, stage: str): + """Called at the end of fit and test.""" + if stage == "fit" and "PL_TRAINER_GPUS" in os.environ: + os.environ.pop("PL_TRAINER_GPUS") + + super().teardown(stage)
+ +
+    @classmethod
+    def extract_state_dict_from(
+        cls,
+        restore_path: str,
+        save_dir: str,
+        split_by_module: bool = False,
+        save_restore_connector: SaveRestoreConnector = None,
+    ):
+        """
+        Extract the state dict(s) from a provided .mridc tarfile and save it to a directory.
+
+        Parameters
+        ----------
+        restore_path: path to .mridc file from which state dict(s) should be extracted
+        save_dir: directory in which the saved state dict(s) should be stored
+        split_by_module: bool flag, which determines whether the output checkpoint should be for the entire Model, or
+        the individual modules that comprise the Model
+        save_restore_connector: Can be overridden to add custom save and restore logic.
+
+        Example
+        -------
+        To convert the .mridc tarfile into a single Model level PyTorch checkpoint
+
+        .. code-block::
+
+            state_dict = mridc.collections.asr.models.EncDecCTCModel.extract_state_dict_from('asr.mridc', \
+            './asr_ckpts')
+
+        To restore a model from a Model level checkpoint
+
+        .. code-block::
+
+            model = mridc.collections.asr.models.EncDecCTCModel(cfg)  # or any other method of restoration
+            model.load_state_dict(torch.load("./asr_ckpts/model_weights.ckpt"))
+
+        To convert the .mridc tarfile into multiple Module level PyTorch checkpoints
+
+        .. code-block::
+
+            state_dict = mridc.collections.asr.models.EncDecCTCModel.extract_state_dict_from('asr.mridc', \
+            './asr_ckpts', split_by_module=True)
+
+        To restore a module from a Module level checkpoint
+
+        .. code-block::
+
+            model = mridc.collections.asr.models.EncDecCTCModel(cfg)  # or any other method of restoration
+            # load the individual components
+            model.preprocessor.load_state_dict(torch.load("./asr_ckpts/preprocessor.ckpt"))
+            model.encoder.load_state_dict(torch.load("./asr_ckpts/encoder.ckpt"))
+            model.decoder.load_state_dict(torch.load("./asr_ckpts/decoder.ckpt"))
+
+        Returns
+        -------
+        The state dict that was loaded from the original .mridc checkpoint.
+        """
+        if save_restore_connector is None:
+            save_restore_connector = SaveRestoreConnector()
+
+        if not path.exists(restore_path):
+            raise FileNotFoundError(f"Can't find {restore_path}")
+
+        cls.update_save_restore_connector(save_restore_connector)
+        return cls._save_restore_connector.extract_state_dict_from(restore_path, save_dir, split_by_module)
+ +
[docs] def prepare_test(self, trainer: "Trainer") -> bool: + """ + Helper method to check whether the model can safely be tested on a dataset after training (or loading a + checkpoint). + + .. code-block:: + + trainer = Trainer() + if model.prepare_test(trainer): + trainer.test(model) + + Returns + ------- + Bool which declares the model safe to test. Provides warnings if it has to return False to guide the user. + """ + if not hasattr(self._cfg, "test_ds"): + logging.info("No `test_ds` config found within the manifest.") + return False + + if trainer is not None and trainer.num_devices > 1: + # Replace ddp multi-gpu until PTL has a fix + DDP_WARN = """\n\nDuring testing, it is currently advisable to construct a new Trainer " + "with single GPU and no DDP to obtain accurate results. + "Following pattern should be used: " + "trainer = Trainer(devices=1, accelerator='gpu') + "if model.prepare_test(trainer):" + " trainer.test(model)\n\n""" + + logging.warning(DDP_WARN) + return False + + # Assign trainer to the model + self.set_trainer(trainer) + return True
+ +
[docs] def set_trainer(self, trainer: Trainer): + """Set an instance of Trainer object.""" + self.trainer = trainer + self._trainer = trainer + self.set_world_size(self._trainer)
+ +
[docs] def set_world_size(self, trainer: Trainer): + """Determines the world size from the PyTorch Lightning Trainer and then updates AppState.""" + # Update AppState with world information from trainer + if isinstance(trainer, Trainer): + app_state = AppState() + if self._trainer.num_devices and self._trainer.num_nodes: # type: ignore + app_state.world_size = self._trainer.num_devices * self._trainer.num_nodes # type: ignore + else: + logging.warning("World size can only be set by PyTorch Lightning Trainer.")
+ + def _update_dataset_config(self, dataset_name: str, config: Optional[Union[DictConfig, Dict]]): + """ + Update the config (if not None) of the dataset by given name. Preserves said config after updating. + + Parameters + ---------- + dataset_name: str name of the dataset whose config is being updated. Can be one of `train`, `validation` and + `test`. + config: Optional DictConfig or dict. If None is passed, this method simply returns. If dict is passed, it is + cast into a DictConfig. The internal config is updated with the passed config. + """ + if hasattr(self, "_multi_dataset_mode") and self._multi_dataset_mode is True: + return + + if config is not None: + if not isinstance(config, DictConfig): + config = OmegaConf.create(config) + + if dataset_name in {"train", "validation", "test"}: + OmegaConf.set_struct(self.cfg, False) + + key_name = f"{dataset_name}_ds" + self.cfg[key_name] = config + + OmegaConf.set_struct(self.cfg, True) + + # Update hyperparameters by calling property setter + self.cfg = self._cfg + else: + raise ValueError("`dataset_name` when updating config must be one of [train, validation, test]") + + @property + def num_weights(self): + """Utility property that returns the total number of parameters of the Model.""" + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + @property + def cfg(self): + """ + Property that holds the finalized internal config of the model. + + .. note:: + Changes to this config are not reflected in the state of the model. + Please create a new model using an updated config to properly update the model. + """ + return self._cfg + + @cfg.setter + def cfg(self, cfg): + """ + Property that holds the finalized internal config of the model. + + .. note:: + Changes to this config are not reflected in the state of the model. + Please create a new model using an updated config to properly update the model. + """ + self._cfg = cfg + self._set_hparams(OmegaConf.create({"cfg": self._cfg})) + + # TODO: Remove this when we have a better way to handle this + if hasattr(self, "_hparams_initial") and "cfg" in self._hparams_initial: + self._hparams_initial["cfg"] = OmegaConf.to_object(self._cfg) + + @staticmethod + def _is_model_being_restored() -> bool: + """Checks if the model is being restored from a checkpoint.""" + app_state = AppState() + return app_state.is_model_being_restored + + @staticmethod + def _set_model_restore_state(is_being_restored: bool, folder: str = None): + """Sets the state of the model to be restored.""" + app_state = AppState() + app_state.is_model_being_restored = is_being_restored + app_state.mridc_file_folder = folder # type: ignore + + def _set_model_guid(self): + """Sets the model guid.""" + if not hasattr(self, "model_guid"): + appstate = AppState() + + # Generate a unique uuid for the instance + # also determine if the model is being restored or not, and preserve the path + self.model_guid = str(uuid.uuid4()) + if self._is_model_being_restored(): + restore_path = appstate.model_restore_path + else: + restore_path = None + + appstate.register_model_guid(self.model_guid, restoration_path=restore_path) + +
[docs] @classmethod + def update_save_restore_connector(cls, save_restore_connector): + """Update the save_restore_connector of the model.""" + if hasattr(cls, "_save_restore_connector"): + cls._save_restore_connector = save_restore_connector + else: + setattr(cls, "_save_restore_connector", save_restore_connector)
+
diff --git a/docs/build/html/_modules/mridc/core/classes/module.html b/docs/build/html/_modules/mridc/core/classes/module.html
new file mode 100644
index 00000000..d4b6e83f
--- /dev/null
+++ b/docs/build/html/_modules/mridc/core/classes/module.html
@@ -0,0 +1,157 @@
+mridc.core.classes.module — mridc v.0.0.1 documentation

Source code for mridc.core.classes.module

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/classes/module.py
+from abc import ABC
+from contextlib import contextmanager
+
+from torch.nn import Module
+
+__all__ = ["NeuralModule"]
+
+from mridc.core.classes.common import FileIO, Serialization, Typing
+
+
+
+class NeuralModule(Module, Typing, Serialization, FileIO, ABC):
+    """Abstract class offering interface shared between all PyTorch Neural Modules."""
+
+    @property
+    def num_weights(self):
+        """Utility property that returns the total number of parameters of NeuralModule."""
+        return sum(p.numel() for p in self.parameters() if p.requires_grad)
+    @staticmethod
+    def input_example(max_batch=None, max_dim=None):
+        """
+        Override this method if random inputs won't work.
+
+        Parameters
+        ----------
+        max_batch: Maximum batch size to generate
+        max_dim: Maximum dimension to generate
+
+        Returns
+        -------
+        A tuple sample of valid input data.
+        """
+        return None
+ +
+    def freeze(self) -> None:
+        r"""Freeze all params for inference."""
+        for param in self.parameters():
+            param.requires_grad = False
+
+        self.eval()
+ +
+    def unfreeze(self) -> None:
+        """Unfreeze all parameters for training."""
+        for param in self.parameters():
+            param.requires_grad = True
+
+        self.train()
+ +
+    @contextmanager
+    def as_frozen(self):
+        """Context manager which temporarily freezes a module, yields control and finally unfreezes the module."""
+        self.freeze()
+
+        try:
+            yield
+        finally:
+            self.unfreeze()
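A usage sketch of the context manager, assuming a hypothetical module and batch:

.. code-block:: python

    # Parameters are frozen (and the module put in eval mode) only inside the block.
    with module.as_frozen():
        predictions = module(batch)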
+
diff --git a/docs/build/html/_modules/mridc/core/conf/base_config.html b/docs/build/html/_modules/mridc/core/conf/base_config.html
new file mode 100644
index 00000000..df4ce5a8
--- /dev/null
+++ b/docs/build/html/_modules/mridc/core/conf/base_config.html
@@ -0,0 +1,110 @@
+mridc.core.conf.base_config — mridc v.0.0.1 documentation

Source code for mridc.core.conf.base_config

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+from dataclasses import dataclass
+from typing import Optional
+
+__all__ = ["Config"]
+
+
+
+@dataclass
+class Config:
+    """Abstract mridc Configuration class."""
+
+    name: Optional[str] = None
+
diff --git a/docs/build/html/_modules/mridc/core/conf/dataloader.html b/docs/build/html/_modules/mridc/core/conf/dataloader.html
new file mode 100644
index 00000000..26706d80
--- /dev/null
+++ b/docs/build/html/_modules/mridc/core/conf/dataloader.html
@@ -0,0 +1,130 @@
+mridc.core.conf.dataloader — mridc v.0.0.1 documentation

Source code for mridc.core.conf.dataloader

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/config/pytorch.py
+
+from dataclasses import dataclass
+from typing import Any, Optional
+
+from omegaconf import MISSING
+
+__all__ = ["DataLoaderConfig"]
+
+
+
+@dataclass
+class DataLoaderConfig:
+    """
+    Configuration of PyTorch DataLoader.
+
+    .. note::
+        For the details on the function/meanings of the arguments, please refer to:
+        https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
+    """
+
+    batch_size: int = MISSING
+    shuffle: bool = False
+    sampler: Optional[Any] = None
+    batch_sampler: Optional[Any] = None
+    num_workers: int = 0
+    collate_fn: Optional[Any] = None
+    pin_memory: bool = False
+    drop_last: bool = False
+    timeout: int = 0
+    worker_init_fn: Optional[Any] = None
+    multiprocessing_context: Optional[Any] = None
+
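A minimal instantiation sketch; since batch_size defaults to MISSING, it must be supplied explicitly:

.. code-block:: python

    from omegaconf import OmegaConf

    loader_cfg = DataLoaderConfig(batch_size=4, shuffle=True, num_workers=2)
    print(OmegaConf.to_yaml(OmegaConf.structured(loader_cfg)))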
diff --git a/docs/build/html/_modules/mridc/core/conf/hydra_runner.html b/docs/build/html/_modules/mridc/core/conf/hydra_runner.html
new file mode 100644
index 00000000..f5cac620
--- /dev/null
+++ b/docs/build/html/_modules/mridc/core/conf/hydra_runner.html
@@ -0,0 +1,210 @@
+mridc.core.conf.hydra_runner — mridc v.0.0.1 documentation

Source code for mridc.core.conf.hydra_runner

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/config/hydra_runner.py
+
+import functools
+import os
+import sys
+from argparse import Namespace
+from typing import Any, Callable, Optional
+
+from hydra._internal.utils import _run_hydra, get_args_parser
+from hydra.core.config_store import ConfigStore
+from hydra.types import TaskFunction
+from omegaconf import DictConfig, OmegaConf
+
+# multiply two interpolated values in the config
+OmegaConf.register_new_resolver("multiply", lambda x, y: x * y)
+
+
+
+def hydra_runner(
+    config_path: Optional[str] = ".", config_name: Optional[str] = None, schema: Optional[Any] = None
+) -> Callable[[TaskFunction], Any]:
+    """
+    Decorator used for passing the Config paths to main function.
+    Optionally registers a schema used for validation/providing default values.
+
+    Parameters
+    ----------
+    config_path: Path to the config file.
+    config_name: Name of the config file.
+    schema: Schema used for validation/providing default values.
+
+    Returns
+    -------
+    A decorator that passes the config paths to the main function.
+    """
+
+    def decorator(task_function: TaskFunction) -> Callable[[], None]:
+        """Decorator that passes the config paths to the main function."""
+
+        @functools.wraps(task_function)
+        def wrapper(cfg_passthrough: Optional[DictConfig] = None) -> Any:
+            """Wrapper that passes the config paths to the main function."""
+            # Check if a config was passed.
+            if cfg_passthrough is not None:
+                return task_function(cfg_passthrough)
+            args = get_args_parser()
+
+            # Parse arguments in order to retrieve overrides
+            parsed_args: Namespace = args.parse_args()
+
+            # Get overriding args in dot string format
+            overrides = parsed_args.overrides  # type: list
+
+            # Disable the creation of .hydra subdir
+            # https://hydra.cc/docs/tutorials/basic/running_your_app/working_directory
+            overrides.append("hydra.output_subdir=null")
+            # Hydra logging outputs only to stdout (no log file).
+            # https://hydra.cc/docs/configure_hydra/logging
+            overrides.append("hydra/job_logging=stdout")
+
+            # Set run.dir ONLY for ExpManager "compatibility" - to be removed.
+            overrides.append("hydra.run.dir=.")
+
+            # Check if user set the schema.
+            if schema is not None:
+                # Create config store.
+                cs = ConfigStore.instance()
+
+                # Get the correct ConfigStore "path name" to "inject" the schema.
+                if parsed_args.config_name is not None:
+                    path, name = os.path.split(parsed_args.config_name)
+                    # Make sure the path is not set - as this will disable validation scheme.
+                    if path != "":
+                        sys.stderr.write(
+                            "ERROR Cannot set config file path using `--config-name` when "
+                            "using schema. Please set path using `--config-path` and file name using "
+                            "`--config-name` separately.\n"
+                        )
+                        sys.exit(1)
+                else:
+                    name = config_name
+
+                # Register the configuration as a node under the name in the group.
+                cs.store(name=name, node=schema)  # group=group,
+
+            # Wrap a callable object with name `parse_args`
+            # This is to mimic the ArgParser.parse_args() API.
+            class _argparse_wrapper:
+                """Wrapper for ArgParser.parse_args()."""
+
+                def __init__(self, arg_parser):
+                    self.arg_parser = arg_parser
+                    self._actions = arg_parser._actions
+
+                @staticmethod
+                def parse_args(args=None, namespace=None):
+                    """Parse arguments."""
+                    return parsed_args
+
+            # no return value from run_hydra() as it may sometime actually run the task_function
+            # multiple times (--multirun)
+            _run_hydra(
+                args_parser=_argparse_wrapper(args),  # type: ignore
+                task_function=task_function,
+                config_path=config_path,
+                config_name=config_name,
+            )
+
+        return wrapper
+
+    return decorator
diff --git a/docs/build/html/_modules/mridc/core/conf/modelPT.html b/docs/build/html/_modules/mridc/core/conf/modelPT.html
new file mode 100644
index 00000000..43b92c7f
--- /dev/null
+++ b/docs/build/html/_modules/mridc/core/conf/modelPT.html
@@ -0,0 +1,278 @@
+mridc.core.conf.modelPT — mridc v.0.0.1 documentation

Source code for mridc.core.conf.modelPT

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/config/modelPT.py
+
+from dataclasses import dataclass, field
+from typing import Any, Dict, Optional
+
+from omegaconf import MISSING
+
+from mridc.core.classes.dataset import DatasetConfig
+from mridc.core.conf.optimizers import OptimizerParams
+from mridc.core.conf.schedulers import SchedulerParams
+from mridc.core.conf.trainer import TrainerConfig
+from mridc.utils.exp_manager import ExpManagerConfig
+
+
+
[docs]@dataclass +class SchedConfig: + """Configuration for the scheduler.""" + + name: str = MISSING + min_lr: float = 0.0 + last_epoch: int = -1
+ + +
[docs]@dataclass +class OptimConfig: + """Configuration for the optimizer.""" + + name: str = MISSING + sched: Optional[SchedConfig] = None
+ + +
[docs]@dataclass +class ModelConfig: + """Configuration for the model.""" + + train_ds: Optional[DatasetConfig] = None + validation_ds: Optional[DatasetConfig] = None + test_ds: Optional[DatasetConfig] = None + optim: Optional[OptimConfig] = None
+ + +
[docs]@dataclass +class HydraConfig: + """Configuration for the hydra framework.""" + + run: Dict[str, Any] = field(default_factory=lambda: {"dir": "."}) + job_logging: Dict[str, Any] = field(default_factory=lambda: {"root": {"handlers": None}})
+ + +
[docs]@dataclass +class MRIDCConfig: + """Configuration for the mridc framework.""" + + name: str = MISSING + model: ModelConfig = MISSING + trainer: TrainerConfig = TrainerConfig( + strategy="ddp", + enable_checkpointing=False, + logger=False, + log_every_n_steps=1, + accelerator="gpu", + ) + exp_manager: Optional[Any] = ExpManagerConfig() + hydra: HydraConfig = HydraConfig()
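As a quick sanity check, the dataclass above can be materialized as an OmegaConf structured config; a minimal sketch, where the import path is assumed from this page's module name:

from omegaconf import OmegaConf

from mridc.core.conf.modelPT import MRIDCConfig  # import path assumed from this page

# "name" and "model" default to MISSING and render as "???" until supplied.
base_cfg = OmegaConf.structured(MRIDCConfig)
print(OmegaConf.to_yaml(base_cfg))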
+ + +
[docs]class ModelConfigBuilder: + """Builder for the ModelConfig class.""" + + def __init__(self, model_cfg: ModelConfig): + """ + Base class for any Model Config Builder. + A Model Config Builder is a utility class that accepts a ModelConfig dataclass, and via a set of utility + methods (that are implemented by the subclassed ModelConfigBuilder), builds a finalized ModelConfig that can be + supplied to a MRIDCModel dataclass as the `model` component. + + Subclasses *must* implement the private method `_finalize_cfg`. + Inside this method, they must update `self.model_cfg` with all interdependent config + options that need to be set (either updated by user explicitly or with their default value). + The updated model config must then be preserved in `self.model_cfg`. + + Example + ------- + # Create the config builder + config_builder = <subclass>ModelConfigBuilder() + # Update the components of the config that are modifiable + config_builder.set_X(X) + config_builder.set_Y(Y) + # Create a "finalized" config dataclass that will contain all the updates + # that were specified by the builder + model_config = config_builder.build() + # Use model config as is (or further update values), then create a new Model + model = mridc.<domain>.models.<ModelName>Model(cfg=model_config, trainer=Trainer()) + Supported build methods: + - set_train_ds: All model configs can accept a subclass of `DatasetConfig` as their + training conf. Subclasses can override this method to enable auto-complete + by replacing `Optional[DatasetConfig]` with `Optional[<subclass of DatasetConfig>]`. + - set_validation_ds: All model configs can accept a subclass of `DatasetConfig` as their + validation conf. Subclasses can override this method to enable auto-complete + by replacing `Optional[DatasetConfig]` with `Optional[<subclass of DatasetConfig>]`. + - set_test_ds: All model configs can accept a subclass of `DatasetConfig` as their + test conf. Subclasses can override this method to enable auto-complete + by replacing `Optional[DatasetConfig]` with `Optional[<subclass of DatasetConfig>]`. + - set_optim: A build method that supports changes to the Optimizer (and optionally, + the Scheduler) used for training the model. The function accepts two inputs - + `cfg`: A subclass of `OptimizerParams` - any OptimizerParams subclass can be used, + in order to select an appropriate Optimizer. Examples: AdamParams. + `sched_cfg`: A subclass of `SchedulerParams` - any SchedulerParams subclass can be used, + in order to select an appropriate Scheduler. Examples: CosineAnnealingParams. + Note that this argument is optional. + - build(): The method which should return a "finalized" ModelConfig dataclass. + Subclasses *should* always override this method, and update the signature + of this method with the return type of the Dataclass, so that it enables + autocomplete for the user. + Example: + def build(self) -> EncDecCTCConfig: + return super().build() + Any additional build methods must be added by subclasses of ModelConfigBuilder. + + Parameters + ---------- + model_cfg: The model config dataclass to be updated. + + Returns + ------- + The updated model config dataclass. + """ + self.model_cfg = model_cfg + self.train_ds_cfg = None + self.validation_ds_cfg = None + self.test_ds_cfg = None + self.optim_cfg = None + +
[docs] def set_train_ds(self, cfg: Optional[DatasetConfig] = None): + """Set the training dataset configuration.""" + self.model_cfg.train_ds = cfg
+ +
[docs] def set_validation_ds(self, cfg: Optional[DatasetConfig] = None): + """Set the validation dataset configuration.""" + self.model_cfg.validation_ds = cfg
+ +
[docs] def set_test_ds(self, cfg: Optional[DatasetConfig] = None): + """Set the test dataset configuration.""" + self.model_cfg.test_ds = cfg
+ +
[docs] def set_optim(self, cfg: OptimizerParams, sched_cfg: Optional[SchedulerParams] = None): + """Set the optimizer configuration.""" + + @dataclass + class WrappedOptimConfig(OptimConfig, cfg.__class__): # type: ignore + """A wrapper class for the OptimizerParams dataclass.""" + + # Setup optim + optim_name = cfg.__class__.__name__.replace("Params", "").lower() + wrapped_cfg = WrappedOptimConfig(name=optim_name, sched=None, **vars(cfg)) # type: ignore + + if sched_cfg is not None: + + @dataclass + class WrappedSchedConfig(SchedConfig, sched_cfg.__class__): # type: ignore + """A wrapper class for the SchedulerParams dataclass.""" + + # Setup scheduler + sched_name = sched_cfg.__class__.__name__.replace("Params", "") + wrapped_sched_cfg = WrappedSchedConfig(name=sched_name, **vars(sched_cfg)) + + wrapped_cfg.sched = wrapped_sched_cfg + + self.model_cfg.optim = wrapped_cfg
+ + def _finalize_cfg(self): + """Finalize the model configuration.""" + raise NotImplementedError() + +
[docs] def build(self) -> ModelConfig: + """Validate config""" + self._finalize_cfg() + + return self.model_cfg
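A sketch of the intended call pattern; `MyModelConfigBuilder` is a hypothetical subclass that implements `_finalize_cfg` (the base class raises NotImplementedError):

from mridc.core.conf.optimizers import AdamParams
from mridc.core.conf.schedulers import CosineAnnealingParams

builder = MyModelConfigBuilder(ModelConfig())  # hypothetical subclass
builder.set_optim(cfg=AdamParams(lr=1e-3), sched_cfg=CosineAnnealingParams(max_steps=1000))
model_config = builder.build()  # runs the subclass's _finalize_cfg() and returns the updated ModelConfig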
+
diff --git a/docs/build/html/_modules/mridc/core/conf/optimizers.html b/docs/build/html/_modules/mridc/core/conf/optimizers.html
new file mode 100644
index 00000000..69a732e8
--- /dev/null
+++ b/docs/build/html/_modules/mridc/core/conf/optimizers.html
@@ -0,0 +1,372 @@
+mridc.core.conf.optimizers — mridc v.0.0.1 documentation

Source code for mridc.core.conf.optimizers

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/config/optimizers.py
+
+from dataclasses import dataclass
+from functools import partial
+from typing import Any, Dict, Optional, Tuple, Union
+
+from omegaconf import MISSING, OmegaConf
+
+__all__ = [
+    "OptimizerParams",
+    "AdamParams",
+    "NovogradParams",
+    "SGDParams",
+    "AdadeltaParams",
+    "AdamaxParams",
+    "AdagradParams",
+    "AdamWParams",
+    "RMSpropParams",
+    "RpropParams",
+    "get_optimizer_config",
+    "register_optimizer_params",
+]
+
+
+
+@dataclass
+class OptimizerParams:
+    """Base optimizer params with no values. Users can choose to explicitly override them via command-line arguments."""
+
+    lr: Optional[float] = MISSING
+ + +
+@dataclass
+class SGDParams(OptimizerParams):
+    """
+    Default configuration for SGD optimizer.
+
+    .. note::
+        For the details on the function/meanings of the arguments, please refer to:
+        https://pytorch.org/docs/stable/optim.html?highlight=sgd#torch.optim.SGD
+    """
+
+    momentum: float = 0
+    dampening: float = 0
+    weight_decay: float = 0
+    nesterov: bool = False
+ + +
[docs]@dataclass +class AdamParams(OptimizerParams): + """ + Default configuration for Adam optimizer. + + .. note:: + For the details on the function/meanings of the arguments, please refer to: + https://pytorch.org/docs/stable/optim.html?highlight=adam#torch.optim.Adam + """ + + # betas: Tuple[float, float] = (0.9, 0.999) + eps: float = 1e-08 + weight_decay: float = 0 + amsgrad: bool = False
+ + +
[docs]@dataclass +class AdamWParams(OptimizerParams): + """ + Default configuration for AdamW optimizer. + + .. note:: + For the details on the function/meanings of the arguments, please refer to: + https://pytorch.org/docs/stable/optim.html#torch.optim.AdamW + """ + + betas: Tuple[float, float] = (0.9, 0.999) + eps: float = 1e-08 + weight_decay: float = 0 + amsgrad: bool = False
+ + +
[docs]@dataclass +class AdadeltaParams(OptimizerParams): + """ + Default configuration for Adadelta optimizer. + + .. note:: + For the details on the function/meanings of the arguments, please refer to: + https://pytorch.org/docs/stable/optim.html#torch.optim.Adadelta + """ + + rho: float = 0.9 + eps: float = 1e-6 + weight_decay: float = 0
+ + +
[docs]@dataclass +class AdamaxParams(OptimizerParams): + """ + Default configuration for Adamax optimizer. + + .. note:: + For the details on the function/meanings of the arguments, please refer to: + https://pytorch.org/docs/stable/optim.html#torch.optim.Adamax + """ + + betas: Tuple[float, float] = (0.9, 0.999) + eps: float = 1e-8 + weight_decay: float = 0
+ + +
[docs]@dataclass +class AdagradParams(OptimizerParams): + """ + Default configuration for Adagrad optimizer. + + .. note:: + For the details on the function/meanings of the arguments, please refer to: + https://pytorch.org/docs/stable/optim.html#torch.optim.Adagrad + """ + + lr_decay: float = 0 + weight_decay: float = 0 + initial_accumulator_value: float = 0 + eps: float = 1e-10
+ + +
[docs]@dataclass +class RMSpropParams(OptimizerParams): + """ + Default configuration for RMSprop optimizer. + + .. note:: + For the details on the function/meanings of the arguments, please refer to: + https://pytorch.org/docs/stable/optim.html#torch.optim.RMSprop + """ + + alpha: float = 0.99 + eps: float = 1e-8 + weight_decay: float = 0 + momentum: float = 0 + centered: bool = False
+ + +
+@dataclass
+class RpropParams(OptimizerParams):
+    """
+    Default configuration for Rprop optimizer.
+
+    .. note::
+        For the details on the function/meanings of the arguments, please refer to:
+        https://pytorch.org/docs/stable/optim.html#torch.optim.Rprop
+    """
+
+    etas: Tuple[float, float] = (0.5, 1.2)
+    step_sizes: Tuple[float, float] = (1e-6, 50)
+ + +
[docs]@dataclass +class NovogradParams(OptimizerParams): + """ + Configuration of the Novograd optimizer. It has been proposed in "Stochastic Gradient Methods with Layer-wise + Adaptive Moments for Training of Deep Networks" (https://arxiv.org/abs/1905.11286). The OptimizerParams is a Base + Optimizer params with no values. User can choose to explicitly override it via command line arguments. + """ + + betas: Tuple[float, float] = (0.95, 0.98) + eps: float = 1e-8 + weight_decay: float = 0 + grad_averaging: bool = False + amsgrad: bool = False + lr: float = 1e-3 + luc: bool = False + luc_trust: float = 1e-3 + luc_eps: float = 1e-8
+
+
+@dataclass
+class AdafactorParams(OptimizerParams):
+    """
+    Configuration of the Adafactor optimizer.
+    It has been proposed in "Adafactor: Adaptive Learning Rates with Sublinear Memory Cost"
+    (https://arxiv.org/abs/1804.04235)
+
+    Parameters
+    ----------
+    lr: Learning rate.
+        (float, optional), (default: 1e-3)
+    beta1: Coefficient used for computing running averages of gradient and its square.
+        (float, optional), (default: None)
+    eps: Terms added to the denominator to improve numerical stability.
+        (Tuple[float, float], optional)
+    weight_decay: Weight decay (L2 penalty).
+        (float, optional), (default: 0)
+    scale_parameter: Whether to scale the parameter.
+        (bool, optional), (default: True)
+    relative_step: Whether to use relative step sizes.
+        (bool, optional), (default: False)
+    warmup_init: Whether to warm up the learning rate linearly.
+        (bool, optional), (default: False)
+    """
+
+    beta1: Optional[float] = None
+    eps: Tuple[float, float] = (1e-30, 1e-3)
+    clip_threshold: float = 1.0
+    decay_rate: float = 0.8
+    weight_decay: float = 0
+    scale_parameter: bool = True
+    relative_step: bool = False
+    warmup_init: bool = False
+
[docs]def register_optimizer_params(name: str, optimizer_params: OptimizerParams): + """ + Checks if the optimizer param name exists in the registry, and if it doesn't, adds it. + This allows custom optimizer params to be added and called by name during instantiation. + + Parameters + ---------- + name: Name of the optimizer. Will be used as key to retrieve the optimizer. + optimizer_params: Optimizer class + """ + if name in AVAILABLE_OPTIMIZER_PARAMS: + raise ValueError(f"Cannot override pre-existing optimizers. Conflicting optimizer name = {name}") + + AVAILABLE_OPTIMIZER_PARAMS[name] = optimizer_params # type: ignore
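For illustration, registering a hypothetical params class; `LambParams` is not part of the module:

@dataclass
class LambParams(OptimizerParams):
    """Hypothetical params for a LAMB-style optimizer."""

    betas: Tuple[float, float] = (0.9, 0.999)
    weight_decay: float = 0.01


register_optimizer_params(name="lamb_params", optimizer_params=LambParams)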
+ + +
+def get_optimizer_config(
+    name: str, **kwargs: Optional[Dict[str, Any]]
+) -> Union[Dict[str, Optional[Dict[str, Any]]], partial]:
+    """
+    Convenience method to obtain an OptimizerParams class and partially instantiate it with optimizer kwargs.
+
+    Parameters
+    ----------
+    name: Name of the OptimizerParams in the registry.
+    kwargs: Optional kwargs of the optimizer used during instantiation.
+
+    Returns
+    -------
+    A partially instantiated OptimizerParams.
+    """
+    if name is None:
+        return kwargs
+
+    if name not in AVAILABLE_OPTIMIZER_PARAMS:
+        raise ValueError(
+            f"Cannot resolve optimizer parameters '{name}'. Available optimizer parameters are : "
+            f"{AVAILABLE_OPTIMIZER_PARAMS.keys()}"
+        )
+
+    optimizer_params = AVAILABLE_OPTIMIZER_PARAMS[name]
+
+    if kwargs is not None and kwargs:
+        kwargs = OmegaConf.create(kwargs)
+        # Merge against the dataclass defaults to validate the provided kwargs; the merged result is discarded.
+        OmegaConf.merge(optimizer_params(), kwargs)
+
+    optimizer_params = partial(optimizer_params, **kwargs)  # type: ignore
+    return optimizer_params  # type: ignore
+ + +AVAILABLE_OPTIMIZER_PARAMS = { + "optim_params": OptimizerParams, + "adam_params": AdamParams, + "novograd_params": NovogradParams, + "sgd_params": SGDParams, + "adadelta_params": AdadeltaParams, + "adamax_params": AdamaxParams, + "adagrad_params": AdagradParams, + "adamw_params": AdamWParams, + "rmsprop_params": RMSpropParams, + "rprop_params": RpropParams, + "adafactor_params": AdafactorParams, +} +
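Putting the two pieces together, a minimal usage sketch of the registry:

# Look up the AdamW defaults, overriding eps; the call returns a partial
# that finishes instantiating AdamWParams with any remaining kwargs.
adamw_factory = get_optimizer_config("adamw_params", eps=1e-6)
adamw_params = adamw_factory(lr=3e-4)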
diff --git a/docs/build/html/_modules/mridc/core/conf/schedulers.html b/docs/build/html/_modules/mridc/core/conf/schedulers.html
new file mode 100644
index 00000000..89e15752
--- /dev/null
+++ b/docs/build/html/_modules/mridc/core/conf/schedulers.html
@@ -0,0 +1,307 @@
+mridc.core.conf.schedulers — mridc v.0.0.1 documentation

Source code for mridc.core.conf.schedulers

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/config/schedulers.py
+
+from dataclasses import dataclass
+from functools import partial
+from typing import Any, Dict, Optional
+
+
+
[docs]@dataclass +class SchedulerParams: + """Base configuration for all schedulers.""" + + last_epoch: int = -1
+ + +
+@dataclass
+class SquareRootConstantSchedulerParams(SchedulerParams):
+    """
+    Configuration for the square root constant scheduler.
+    It is not derived from Config as it is not a mridc object (and in particular it doesn't need a name).
+    """
+
+    constant_steps: Optional[float] = None
+    constant_ratio: Optional[float] = None
+ + +
+@dataclass
+class WarmupSchedulerParams(SchedulerParams):
+    """Base configuration for schedulers with a warmup phase."""
+
+    max_steps: int = 0
+    warmup_steps: Optional[float] = None
+    warmup_ratio: Optional[float] = None
+ + +
+@dataclass
+class WarmupHoldSchedulerParams(WarmupSchedulerParams):
+    """Configuration for schedulers with warmup and hold phases."""
+
+    hold_steps: Optional[float] = None
+    hold_ratio: Optional[float] = None
+    min_lr: float = 0.0
+ + +
+@dataclass
+class WarmupAnnealingHoldSchedulerParams(WarmupSchedulerParams):
+    """Configuration for schedulers with warmup, annealing and hold phases."""
+
+    constant_steps: Optional[float] = None
+    constant_ratio: Optional[float] = None
+    min_lr: float = 0.0
+ + +
[docs]@dataclass +class SquareAnnealingParams(WarmupSchedulerParams): + """Square Annealing parameter config""" + + min_lr: float = 1e-5
+ + +
[docs]@dataclass +class SquareRootAnnealingParams(WarmupSchedulerParams): + """Square Root Annealing parameter config""" + + min_lr: float = 0.0
+ + +
[docs]@dataclass +class CosineAnnealingParams(WarmupAnnealingHoldSchedulerParams): + """Cosine Annealing parameter config""" + + min_lr: float = 0.0
+ + +
+@dataclass
+class NoamAnnealingParams(WarmupSchedulerParams):
+    """Noam Annealing parameter config"""
+
+    min_lr: float = 0.0
+ + +
[docs]@dataclass +class WarmupAnnealingParams(WarmupSchedulerParams): + """Warmup Annealing parameter config""" + + warmup_ratio: Optional[float] = None
+ + +
[docs]@dataclass +class InverseSquareRootAnnealingParams(WarmupSchedulerParams): + """Inverse Square Root Annealing parameter config"""
+ + +
[docs]@dataclass +class PolynomialDecayAnnealingParams(WarmupSchedulerParams): + """Polynomial Decay Annealing parameter config""" + + power: float = 1.0 + cycle: bool = False
+ + +
[docs]@dataclass +class PolynomialHoldDecayAnnealingParams(WarmupSchedulerParams): + """Polynomial Hold Decay Annealing parameter config""" + + power: float = 1.0 + cycle: bool = False
+ + +
[docs]@dataclass +class StepLRParams(SchedulerParams): + """Config for StepLR.""" + + step_size: float = 0.1 + gamma: float = 0.1
+ + +
[docs]@dataclass +class ExponentialLRParams(SchedulerParams): + """Config for ExponentialLR.""" + + gamma: float = 0.9
+ + +
[docs]@dataclass +class ReduceLROnPlateauParams: + """Config for ReduceLROnPlateau.""" + + mode: str = "min" + factor: float = 0.1 + patience: int = 10 + verbose: bool = False + threshold: float = 1e-4 + threshold_mode: str = "rel" + cooldown: int = 0 + min_lr: float = 0 + eps: float = 1e-8
+ + +
[docs]@dataclass +class CyclicLRParams(SchedulerParams): + """Config for CyclicLR.""" + + base_lr: float = 0.001 + max_lr: float = 0.1 + step_size_up: int = 2000 + step_size_down: Optional[int] = None + mode: str = "triangular" + gamma: float = 1.0 + scale_mode: str = "cycle" + # scale_fn is not supported + cycle_momentum: bool = True + base_momentum: float = 0.8 + max_momentum: float = 0.9
+ + +
+def register_scheduler_params(name: str, scheduler_params: SchedulerParams):
+    """
+    Checks if the scheduler config name exists in the registry, and if it doesn't, adds it.
+    This allows custom schedulers to be added and called by name during instantiation.
+
+    Parameters
+    ----------
+    name: Name of the scheduler. Will be used as key to retrieve the scheduler params.
+    scheduler_params: SchedulerParams class
+    """
+    if name in AVAILABLE_SCHEDULER_PARAMS:
+        raise ValueError(f"Cannot override pre-existing scheduler params. Conflicting scheduler name = {name}")
+
+    AVAILABLE_SCHEDULER_PARAMS[name] = scheduler_params  # type: ignore
+ + +
+def get_scheduler_config(name: str, **kwargs: Optional[Dict[str, Any]]) -> partial:
+    """
+    Convenience method to obtain a SchedulerParams class and partially instantiate it with scheduler kwargs.
+
+    Parameters
+    ----------
+    name: Name of the SchedulerParams in the registry.
+    kwargs: Optional kwargs of the scheduler used during instantiation.
+
+    Returns
+    -------
+    A partially instantiated SchedulerParams.
+    """
+    if name not in AVAILABLE_SCHEDULER_PARAMS:
+        raise ValueError(
+            f"Cannot resolve scheduler parameters '{name}'. Available scheduler parameters are : "
+            f"{AVAILABLE_SCHEDULER_PARAMS.keys()}"
+        )
+
+    return partial(AVAILABLE_SCHEDULER_PARAMS[name], **kwargs)
+ + +AVAILABLE_SCHEDULER_PARAMS = { + "SchedulerParams": SchedulerParams, + "WarmupPolicyParams": WarmupSchedulerParams, + "WarmupHoldPolicyParams": WarmupHoldSchedulerParams, + "WarmupAnnealingHoldSchedulerParams": WarmupAnnealingHoldSchedulerParams, + "SquareAnnealingParams": SquareAnnealingParams, + "SquareRootAnnealingParams": SquareRootAnnealingParams, + "InverseSquareRootAnnealingParams": InverseSquareRootAnnealingParams, + "SquareRootConstantSchedulerParams": SquareRootConstantSchedulerParams, + "CosineAnnealingParams": CosineAnnealingParams, + "NoamAnnealingParams": NoamAnnealingParams, + "WarmupAnnealingParams": WarmupAnnealingParams, + "PolynomialDecayAnnealingParams": PolynomialDecayAnnealingParams, + "PolynomialHoldDecayAnnealingParams": PolynomialHoldDecayAnnealingParams, +} +
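Usage mirrors the optimizer registry; a minimal sketch:

cosine_factory = get_scheduler_config("CosineAnnealingParams", warmup_steps=500)
cosine_params = cosine_factory(max_steps=10_000, min_lr=1e-6)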
+ +
diff --git a/docs/build/html/_modules/mridc/core/conf/trainer.html b/docs/build/html/_modules/mridc/core/conf/trainer.html
new file mode 100644
index 00000000..13164172
--- /dev/null
+++ b/docs/build/html/_modules/mridc/core/conf/trainer.html
@@ -0,0 +1,178 @@
+mridc.core.conf.trainer — mridc v.0.0.1 documentation

Source code for mridc.core.conf.trainer

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/config/pytorch_lightning.py
+
+from dataclasses import dataclass
+from typing import Any, Optional
+
+from hydra.core.config_store import ConfigStore
+
+__all__ = ["TrainerConfig"]
+
+cs = ConfigStore.instance()
+
+
+
[docs]@dataclass +class TrainerConfig: + """TrainerConfig is a dataclass that holds all the hyperparameters for the training process.""" + + logger: Any = True + checkpoint_callback: Any = True + callbacks: Optional[Any] = None + default_root_dir: Optional[str] = None + gradient_clip_val: float = 0 + process_position: int = 0 + num_nodes: int = 1 + gpus: Optional[Any] = None + auto_select_gpus: bool = False + tpu_cores: Optional[Any] = None + log_gpu_memory: Optional[str] = None + progress_bar_refresh_rate: int = 1 + enable_progress_bar: bool = True + overfit_batches: Any = 0.0 + track_grad_norm: Any = -1 + check_val_every_n_epoch: int = 1 + fast_dev_run: bool = False + accumulate_grad_batches: Any = 1 + max_epochs: int = 1000 + min_epochs: int = 1 + max_steps: Optional[int] = None + min_steps: Optional[int] = None + limit_train_batches: Any = 1.0 + limit_val_batches: Any = 1.0 + limit_test_batches: Any = 1.0 + val_check_interval: Any = 1.0 + flush_logs_every_n_steps: int = 100 + log_every_n_steps: int = 50 + accelerator: Optional[str] = None + sync_batchnorm: bool = False + precision: Any = 32 + weights_summary: Optional[str] = "full" # ModelSummary.MODE_DEFAULT + weights_save_path: Optional[str] = None + num_sanity_val_steps: int = 2 + resume_from_checkpoint: Optional[str] = None + profiler: Optional[Any] = None + benchmark: bool = False + deterministic: bool = False + auto_lr_find: Any = False + replace_sampler_ddp: bool = True + detect_anomaly: bool = False + terminate_on_nan: bool = False + auto_scale_batch_size: Any = False + prepare_data_per_node: bool = True + amp_backend: str = "native" + amp_level: Optional[str] = None + plugins: Optional[Any] = None # Optional[Union[str, list]] + move_metrics_to_cpu: bool = False + multiple_trainloader_mode: str = "max_size_cycle" + limit_predict_batches: float = 1.0 + stochastic_weight_avg: bool = False + gradient_clip_algorithm: str = "norm" + max_time: Optional[Any] = None # can be one of Union[str, timedelta, Dict[str, int], None] + reload_dataloaders_every_n_epochs: int = 0 + ipus: Optional[int] = None + devices: Any = None + strategy: Any = None + enable_checkpointing: bool = True + enable_model_summary: bool = True
+ + +# Register the trainer config. +cs.store(group="trainer", name="trainer", node=TrainerConfig) +
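TrainerConfig is registered under the "trainer" group above; outside of Hydra composition, the same defaults can also be materialized directly. A minimal sketch:

from omegaconf import OmegaConf

# Build a config from the dataclass defaults, overriding two fields.
trainer_cfg = OmegaConf.structured(TrainerConfig(max_epochs=50, precision=16))
print(trainer_cfg.strategy)  # None by default; e.g. "ddp" for multi-GPU runs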
diff --git a/docs/build/html/_modules/mridc/core/connectors/save_restore_connector.html b/docs/build/html/_modules/mridc/core/connectors/save_restore_connector.html
new file mode 100644
index 00000000..4ad32ba4
--- /dev/null
+++ b/docs/build/html/_modules/mridc/core/connectors/save_restore_connector.html
@@ -0,0 +1,594 @@
+mridc.core.connectors.save_restore_connector — mridc v.0.0.1 documentation

Source code for mridc.core.connectors.save_restore_connector

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/connectors/save_restore_connector.py
+
+import os
+import shutil
+import tarfile
+import tempfile
+import uuid
+from typing import Optional, Union
+
+import torch
+from omegaconf import DictConfig, OmegaConf
+from omegaconf.omegaconf import open_dict
+from pytorch_lightning.trainer.trainer import Trainer
+
+import mridc.utils
+from mridc.utils import logging
+from mridc.utils.app_state import AppState
+from mridc.utils.get_rank import is_global_rank_zero
+
+
+
[docs]class SaveRestoreConnector: + """This class is used to save and restore the model state.""" + + def __init__(self) -> None: + self._model_config_yaml = "model_config.yaml" + self._model_weights_ckpt = "model_weights.ckpt" + +
+    def save_to(self, model, save_path: str):
+        """
+        Saves model instance (weights and configuration) into a .mridc file.
+        You can use the "restore_from" method to fully restore the instance from a .mridc file.
+        A .mridc file is an archive (tar.gz) with the following:
+            - model_config.yaml - model configuration in .yaml format. You can deserialize this into the cfg
+              argument for the model's constructor
+            - model_weights.ckpt - model checkpoint
+
+        Parameters
+        ----------
+        model: ModelPT object to be saved.
+        save_path: Path to .mridc file where model instance should be saved
+        """
+        if is_global_rank_zero():
+            with tempfile.TemporaryDirectory() as tmpdir:
+                config_yaml = os.path.join(tmpdir, self.model_config_yaml)
+                model_weights = os.path.join(tmpdir, self.model_weights_ckpt)
+                model.to_config_file(path2yaml_file=config_yaml)
+                if hasattr(model, "artifacts") and model.artifacts is not None:
+                    self._handle_artifacts(model, mridc_file_folder=tmpdir)
+                    # We should not update self._cfg here - the model can still be in use
+                    self._update_artifact_paths(model, path2yaml_file=config_yaml)
+                self._save_state_dict_to_disk(model.state_dict(), model_weights)
+                self._make_mridc_file_from_folder(filename=save_path, source_dir=tmpdir)
+        else:
+            return
+ +
+    def load_config_and_state_dict(
+        self,
+        calling_cls,
+        restore_path: str,
+        override_config_path: Optional[Union[OmegaConf, str]] = None,
+        map_location: Optional[torch.device] = None,
+        strict: bool = True,
+        return_config: bool = False,
+        trainer: Trainer = None,
+    ):
+        """
+        Restores model instance (weights and configuration) from a .mridc file.
+
+        Parameters
+        ----------
+        calling_cls: Class of the model to be restored.
+        restore_path: path to .mridc file from which model should be instantiated
+        override_config_path: path to a yaml config that will override the internal config file or an
+            OmegaConf/DictConfig object representing the model config.
+        map_location: Optional torch.device() to map the instantiated model to a device. By default (None), it will
+            select a GPU if available, falling back to CPU otherwise.
+        strict: Passed to load_state_dict. By default, True.
+        return_config: If set to true, will return just the underlying config of the restored model as an OmegaConf
+            DictConfig object without instantiating the model.
+        trainer: Optional trainer object to be used for model parallelism.
+
+        Example
+        -------
+        ```
+        model = mridc.collections.asr.models.EncDecCTCModel.restore_from('asr.mridc')
+        assert isinstance(model, mridc.collections.asr.models.EncDecCTCModel)
+        ```
+
+        Returns
+        -------
+        An instance of type cls or its underlying config (if return_config is set).
+        """
+        # Get path where the command is executed - the artifacts will be "retrieved" there
+        # (original .mridc behavior)
+        cwd = os.getcwd()
+
+        if map_location is None:
+            if torch.cuda.is_available():
+                map_location = torch.device("cuda")
+            else:
+                map_location = torch.device("cpu")
+
+        app_state = AppState()
+        with tempfile.TemporaryDirectory() as tmpdir:
+            try:
+                self._unpack_mridc_file(path2file=restore_path, out_folder=tmpdir)
+                os.chdir(tmpdir)
+                if override_config_path is None:
+                    config_yaml = os.path.join(tmpdir, self.model_config_yaml)
+                else:
+                    # can be str path or OmegaConf / DictConfig object
+                    config_yaml = override_config_path
+                if not isinstance(config_yaml, (OmegaConf, DictConfig)):
+                    conf = OmegaConf.load(config_yaml)
+                else:
+                    conf = config_yaml
+                if override_config_path is not None:
+                    # Resolve the override config
+                    conf = OmegaConf.to_container(conf, resolve=True)
+                    conf = OmegaConf.create(conf)
+                # If override is top level config, extract just `model` from it
+                if "model" in conf:
+                    conf = conf.model
+
+                if return_config:
+                    instance = conf
+                    return instance
+                if app_state.model_parallel_rank is not None and app_state.model_parallel_size > 1:
+                    model_weights = self._inject_model_parallel_rank_for_ckpt(tmpdir, self.model_weights_ckpt)
+                else:
+                    model_weights = os.path.join(tmpdir, self.model_weights_ckpt)
+                OmegaConf.set_struct(conf, True)
+                os.chdir(cwd)
+                # get the class
+                calling_cls._set_model_restore_state(is_being_restored=True, folder=tmpdir)  # type: ignore
+                instance = calling_cls.from_config_dict(config=conf, trainer=trainer)
+                instance = instance.to(map_location)
+                # add load_state_dict override
+                if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
+                    model_weights = self._inject_model_parallel_rank_for_ckpt(tmpdir, self.model_weights_ckpt)
+                instance.load_state_dict(
+                    self._load_state_dict_from_disk(model_weights, map_location=map_location), strict=strict
+                )
+                logging.info(f"Model {instance.__class__.__name__} was successfully restored from {restore_path}.")
+                instance._set_model_restore_state(is_being_restored=False)  # type: ignore
+            finally:
+                os.chdir(cwd)
+
+        return instance
+ +
[docs] @staticmethod + def load_instance_with_state_dict(instance, state_dict, strict): + """Loads the state dict into the instance.""" + instance.load_state_dict(state_dict, strict=strict) + instance._set_model_restore_state(is_being_restored=False) # type: ignore
+ +
+    def restore_from(
+        self,
+        calling_cls,
+        restore_path: str,
+        override_config_path: Optional[Union[OmegaConf, str]] = None,
+        map_location: Optional[torch.device] = None,
+        strict: bool = True,
+        return_config: bool = False,
+        trainer: Trainer = None,
+    ):
+        """
+        Restores model instance (weights and configuration) from a .mridc file.
+
+        Parameters
+        ----------
+        calling_cls: The class of the model to be restored.
+        restore_path: path to .mridc file from which model should be instantiated
+        override_config_path: path to a yaml config that will override the internal config file or an
+            OmegaConf/DictConfig object representing the model config.
+        map_location: Optional torch.device() to map the instantiated model to a device. By default (None), it will
+            select a GPU if available, falling back to CPU otherwise.
+        strict: Passed to load_state_dict. By default, True.
+        return_config: If set to true, will return just the underlying config of the restored model as an
+            OmegaConf/DictConfig object without instantiating the model.
+        trainer: Optional trainer object to be used for restoring the model.
+
+        Returns
+        -------
+        An instance of type cls or its underlying config (if return_config is set).
+        """
+        # Get path where the command is executed - the artifacts will be "retrieved" there (original .mridc behavior)
+        loaded_params = self.load_config_and_state_dict(
+            calling_cls,
+            restore_path,
+            override_config_path,
+            map_location,
+            strict,
+            return_config,
+            trainer,
+        )
+
+        if not isinstance(loaded_params, tuple):
+            return loaded_params
+
+        _, instance, state_dict = loaded_params
+        self.load_instance_with_state_dict(instance, state_dict, strict)
+        logging.info(f"Model {instance.__class__.__name__} was successfully restored from {restore_path}.")
+        return instance
+ +
+    def extract_state_dict_from(self, restore_path: str, save_dir: str, split_by_module: bool = False):
+        """
+        Extract the state dict(s) from a provided .mridc tarfile and save it to a directory.
+
+        Parameters
+        ----------
+        restore_path: path to .mridc file from which state dict(s) should be extracted
+        save_dir: directory in which the saved state dict(s) should be stored
+        split_by_module: bool flag, which determines whether the output checkpoint should be for the entire Model,
+            or the individual modules that comprise the Model.
+
+        Example
+        -------
+        To convert the .mridc tarfile into a single Model level PyTorch checkpoint
+        ::
+            state_dict = mridc.collections.asr.models.EncDecCTCModel.extract_state_dict_from('asr.mridc',
+            './asr_ckpts')
+        To restore a model from a Model level checkpoint
+        ::
+            model = mridc.collections.asr.models.EncDecCTCModel(cfg)  # or any other method of restoration
+            model.load_state_dict(torch.load("./asr_ckpts/model_weights.ckpt"))
+        To convert the .mridc tarfile into multiple Module level PyTorch checkpoints
+        ::
+            state_dict = mridc.collections.asr.models.EncDecCTCModel.extract_state_dict_from('asr.mridc',
+            './asr_ckpts', split_by_module=True)
+        To restore a module from a Module level checkpoint
+        ::
+            model = mridc.collections.asr.models.EncDecCTCModel(cfg)  # or any other method of restoration
+            # load the individual components
+            model.preprocessor.load_state_dict(torch.load("./asr_ckpts/preprocessor.ckpt"))
+            model.encoder.load_state_dict(torch.load("./asr_ckpts/encoder.ckpt"))
+            model.decoder.load_state_dict(torch.load("./asr_ckpts/decoder.ckpt"))
+
+        Returns
+        -------
+        The state dict that was loaded from the original .mridc checkpoint.
+        """
+        cwd = os.getcwd()
+
+        save_dir = os.path.abspath(save_dir)
+        if not os.path.exists(save_dir):
+            os.makedirs(save_dir, exist_ok=True)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            try:
+                self._unpack_mridc_file(path2file=restore_path, out_folder=tmpdir)
+                os.chdir(tmpdir)
+                model_weights = os.path.join(tmpdir, self.model_weights_ckpt)
+                state_dict = self._load_state_dict_from_disk(model_weights)
+
+                if not split_by_module:
+                    filepath = os.path.join(save_dir, self.model_weights_ckpt)
+                    self._save_state_dict_to_disk(state_dict, filepath)
+
+                else:
+                    key_set = {key.split(".")[0] for key in state_dict.keys()}
+                    for primary_key in key_set:
+                        inner_keys = [key for key in state_dict.keys() if key.split(".")[0] == primary_key]
+                        state_dict_subset = {
+                            ".".join(inner_key.split(".")[1:]): state_dict[inner_key] for inner_key in inner_keys
+                        }
+                        filepath = os.path.join(save_dir, f"{primary_key}.ckpt")
+                        self._save_state_dict_to_disk(state_dict_subset, filepath)
+
+                logging.info(f"Checkpoints from {restore_path} were successfully extracted into {save_dir}.")
+            finally:
+                os.chdir(cwd)
+
+        return state_dict
+ +
+    @staticmethod
+    def register_artifact(model, config_path: str, src: str, verify_src_exists: bool = True):
+        """
+        Register model artifacts with this function. These artifacts (files) will be included inside the .mridc file
+        when model.save_to("mymodel.mridc") is called.
+
+        How it works:
+        1. It always returns an existing absolute path which can be used during Model constructor call. EXCEPTION:
+           src is None or "" in which case nothing will be done and src will be returned.
+        2. It will add a (config_path, model_utils.ArtifactItem()) pair to self.artifacts. If "src" is a local
+           existing path, then it will be returned in absolute path form. elif "src" starts with
+           "mridc:unique_artifact_name", the .mridc file will be untarred to a temporary folder location and an
+           actual existing path will be returned. else an error will be raised.
+
+        WARNING: use .register_artifact calls in your models' constructors.
+        The returned path is not guaranteed to exist after you have exited your model's constructor.
+
+        Parameters
+        ----------
+        model: ModelPT object to register artifact for.
+        config_path: Artifact key. Usually corresponds to the model config.
+        src: Path to artifact.
+        verify_src_exists: If set to False, then the artifact is optional and register_artifact will return None
+            even if src is not found. Defaults to True.
+
+        Returns
+        -------
+        If src is not None or empty it always returns an absolute path which is guaranteed to exist during model
+        instance life.
+        """
+        app_state = AppState()
+
+        artifact_item = mridc.utils.model_utils.ArtifactItem()  # type: ignore
+
+        # This is for backward compatibility, if the src objects exists simply inside the tarfile
+        # without its key having been overridden, this pathway will be used.
+        src_obj_name = os.path.basename(src)
+        if app_state.mridc_file_folder is not None:
+            src_obj_path = os.path.abspath(os.path.join(app_state.mridc_file_folder, src_obj_name))
+        else:
+            src_obj_path = src_obj_name
+
+        # src is a local existing path - register artifact and return exact same path for usage by the model
+        if os.path.exists(os.path.abspath(src)):
+            return_path = os.path.abspath(src)
+            artifact_item.path_type = mridc.utils.model_utils.ArtifactPathType.LOCAL_PATH  # type: ignore
+
+        elif src.startswith("mridc:"):
+            # Strip the "mridc:" prefix (6 characters) to recover the artifact name; slicing with a
+            # hard-coded 5 would leave a leading ":" in the joined path.
+            return_path = os.path.abspath(os.path.join(app_state.mridc_file_folder, src[6:]))
+            artifact_item.path_type = mridc.utils.model_utils.ArtifactPathType.TAR_PATH  # type: ignore
+
+        elif os.path.exists(src_obj_path):
+            return_path = src_obj_path
+            artifact_item.path_type = mridc.utils.model_utils.ArtifactPathType.TAR_PATH  # type: ignore
+        elif verify_src_exists:
+            raise FileNotFoundError(
+                f"src path does not exist or it is not a path in mridc file. src value I got was: {src}. "
+                f"Absolute: {os.path.abspath(src)}"
+            )
+        else:
+            # artifact is optional and we simply return None
+            return None
+
+        if not os.path.exists(return_path):
+            raise AssertionError
+
+        artifact_item.path = os.path.abspath(src)
+        model.artifacts[config_path] = artifact_item
+        # we were called by ModelPT
+        if hasattr(model, "cfg"):
+            with open_dict(model._cfg):
+                OmegaConf.update(model.cfg, config_path, return_path)
+        return return_path
+ + def _handle_artifacts(self, model, mridc_file_folder): + """ + This method is called by ModelPT.save_to() and ModelPT.load_from(). It will handle all artifacts and save them + to the mridc file. + + Parameters + ---------- + model: ModelPT object to register artifact for. + mridc_file_folder: Path to the mridc file. + """ + tarfile_artifacts = [] + app_state = AppState() + for conf_path, artiitem in model.artifacts.items(): + if artiitem.path_type == mridc.utils.model_utils.ArtifactPathType.LOCAL_PATH: + if not os.path.exists(artiitem.path): + raise FileNotFoundError(f"Artifact {conf_path} not found at location: {artiitem.path}") + + # Generate new uniq artifact name and copy it to mridc_file_folder + # Note uuid.uuid4().hex is guaranteed to be 32 character long + artifact_base_name = os.path.basename(artiitem.path) + artifact_uniq_name = f"{uuid.uuid4().hex}_{artifact_base_name}" + shutil.copy2(artiitem.path, os.path.join(mridc_file_folder, artifact_uniq_name)) + + # Update artifacts registry + artiitem.hashed_path = f"mridc:{artifact_uniq_name}" + model.artifacts[conf_path] = artiitem + + elif artiitem.path_type == mridc.utils.model_utils.ArtifactPathType.TAR_PATH: + # process all tarfile artifacts in one go, so preserve key-value pair + tarfile_artifacts.append((conf_path, artiitem)) + + else: + raise ValueError("Directly referencing artifacts from other mridc files isn't supported yet") + + # Process current tarfile artifacts by unpacking the previous tarfile and extract the artifacts + # that are currently required. + model_metadata = app_state.get_model_metadata_from_guid(model.model_guid) + if tarfile_artifacts and model_metadata.restoration_path is not None: + # Need to step into mridc archive to extract file + # Get path where the command is executed - the artifacts will be "retrieved" there + # (original .mridc behavior) + cwd = os.getcwd() + try: + # Step into the mridc archive to try and find the file + with tempfile.TemporaryDirectory() as archive_dir: + self._unpack_mridc_file(path2file=model_metadata.restoration_path, out_folder=archive_dir) + os.chdir(archive_dir) + for conf_path, artiitem in tarfile_artifacts: + # Get basename and copy it to mridc_file_folder + if "mridc:" in artiitem.path: + artifact_base_name = artiitem.path.split("mridc:")[1] + else: + artifact_base_name = os.path.basename(artiitem.path) + # no need to hash here as we are in tarfile_artifacts which are already hashed + artifact_uniq_name = artifact_base_name + shutil.copy2(artifact_base_name, os.path.join(mridc_file_folder, artifact_uniq_name)) + + # Update artifacts registry + new_artiitem = mridc.utils.model_utils.ArtifactItem() + new_artiitem.path = f"mridc:{artifact_uniq_name}" + new_artiitem.path_type = mridc.utils.model_utils.ArtifactPathType.TAR_PATH + model.artifacts[conf_path] = new_artiitem + finally: + # change back working directory + os.chdir(cwd) + + @staticmethod + def _update_artifact_paths(model, path2yaml_file): + """ + This method is called by ModelPT.save_to() and ModelPT.load_from() to update the artifact paths in the + model. 
+ """ + if model.artifacts is not None and len(model.artifacts) > 0: + conf = OmegaConf.load(path2yaml_file) + for conf_path, item in model.artifacts.items(): + if item.hashed_path is None: + OmegaConf.update(conf, conf_path, item.path) + else: + OmegaConf.update(conf, conf_path, item.hashed_path) + with open(path2yaml_file, "w", encoding="utf-8") as fout: + OmegaConf.save(config=conf, f=fout, resolve=True) + + @staticmethod + def _inject_model_parallel_rank_for_ckpt(dirname, basename): + """ + This method is called by ModelPT.save_to() and ModelPT.load_from() to inject the parallel rank of the process + into the checkpoint file name. + """ + model_weights = os.path.join(dirname, basename) + model_weights = mridc.utils.model_utils.inject_model_parallel_rank(model_weights) + return model_weights + + @staticmethod + def _make_mridc_file_from_folder(filename, source_dir): + """This method is called by ModelPT.save_to() and ModelPT.load_from() to create a mridc file from a folder.""" + dirname = os.path.dirname(filename) + os.makedirs(dirname, exist_ok=True) + with tarfile.open(filename, "w") as tar: + tar.add(source_dir, arcname=".") + + @staticmethod + def _unpack_mridc_file(path2file: str, out_folder: str) -> str: + """This method is called by ModelPT.save_to() and ModelPT.load_from() to unpack a mridc file.""" + if not os.path.exists(path2file): + raise FileNotFoundError(f"{path2file} does not exist") + # we start with an assumption of uncompressed tar, which should be true for versions 1.7.0 and above + tar_header = "r:" + try: + tar_test = tarfile.open(path2file, tar_header) + tar_test.close() + except tarfile.ReadError: + # can be older checkpoint => try compressed tar + tar_header = "r:gz" + tar = tarfile.open(path2file, tar_header) + tar.extractall(path=out_folder) + tar.close() + return out_folder + + @staticmethod + def _save_state_dict_to_disk(state_dict, filepath): + """This method is called by ModelPT.save_to() and ModelPT.load_from() to save the state dict to disk.""" + torch.save(state_dict, filepath) + + @staticmethod + def _load_state_dict_from_disk(model_weights, map_location=None): + """This method is called by ModelPT.save_to() and ModelPT.load_from() to load the state dict from disk.""" + return torch.load(model_weights, map_location=map_location) + + @property + def model_config_yaml(self) -> str: + """This property is used to get the path to the model config yaml file.""" + return self._model_config_yaml + + @model_config_yaml.setter + def model_config_yaml(self, path: str): + """This property is used to set the path to the model config yaml file.""" + self._model_config_yaml = path + + @property + def model_weights_ckpt(self) -> str: + """This property is used to get the path to the model weights ckpt file.""" + return self._model_weights_ckpt + + @model_weights_ckpt.setter + def model_weights_ckpt(self, path: str): + """This property is used to set the path to the model weights ckpt file.""" + self._model_weights_ckpt = path
+
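A hypothetical save/restore round trip with the connector; `model` and `MyModel` stand in for a ModelPT subclass instance and its class, and the checkpoint path is illustrative:

import torch

connector = SaveRestoreConnector()
# Packs model_config.yaml and model_weights.ckpt into a single tar archive.
connector.save_to(model, "checkpoints/my_model.mridc")
# Rebuilds the model from the archived config, then loads the archived weights.
restored = connector.restore_from(MyModel, "checkpoints/my_model.mridc", map_location=torch.device("cpu"))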
diff --git a/docs/build/html/_modules/mridc/core/neural_types/axes.html b/docs/build/html/_modules/mridc/core/neural_types/axes.html
new file mode 100644
index 00000000..94e411cf
--- /dev/null
+++ b/docs/build/html/_modules/mridc/core/neural_types/axes.html
@@ -0,0 +1,201 @@
+mridc.core.neural_types.axes — mridc v.0.0.1 documentation

Source code for mridc.core.neural_types.axes

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/neural_types/axes.py
+
+from enum import Enum
+from typing import Optional
+
+__all__ = ["AxisKindAbstract", "AxisKind", "AxisType"]
+
+
+
+class AxisKindAbstract(Enum):
+    """
+    This is an abstract Enum that represents what a varying axis dimension means. In practice, you will almost
+    always use the AxisKind Enum. This Enum should be inherited by your OWN Enum if you aren't satisfied with
+    AxisKind; your own Enum can then be used instead of AxisKind.
+    """
+ + +
+class AxisKind(AxisKindAbstract):
+    """
+    This Enum represents what a varying axis dimension means. For example, does this dimension correspond to width,
+    batch, time, etc. The "Dimension" and "Channel" kinds are the same and used to represent a general axis. "Any"
+    axis will accept any axis kind fed to it.
+    """
+
+    # TODO (wdika): change names of the enums
+    Batch = 0
+    Time = 1
+    Dimension = 2
+    Channel = 2
+    Width = 3
+    Height = 4
+    Any = 5
+    Sequence = 6
+    FlowGroup = 7
+    Singleton = 8  # Used to represent an axis that has size 1
[docs] def __repr__(self): + """Returns short string representation of the AxisKind""" + return self.__str__()
+ +
[docs] def __str__(self): + """Returns short string representation of the AxisKind""" + return str(self.name).lower()
+ +
[docs] def t_with_string(self, text): + """It checks if text is 't_<any string>'""" + return text.startswith("t_") and text.endswith("_") and text[2:-1] == self.__str__()
+ +
[docs] @staticmethod + def from_str(label): + """Returns AxisKind instance based on short string representation""" + _label = label.lower().strip() + if _label in ("b", "n", "batch"): + return AxisKind.Batch + if _label == "t" or _label == "time" or (len(_label) > 2 and _label.startswith("t_")): + return AxisKind.Time + if _label in ("d", "c", "channel"): + return AxisKind.Dimension + if _label in ("w", "width"): + return AxisKind.Width + if _label in ("h", "height"): + return AxisKind.Height + if _label in ("s", "singleton"): + return AxisKind.Singleton + if _label in ("seq", "sequence"): + return AxisKind.Sequence + if _label == "flowgroup": + return AxisKind.FlowGroup + if _label == "any": + return AxisKind.Any + raise ValueError(f"Can't create AxisKind from {label}")
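For example, the shorthand labels resolve as follows:

assert AxisKind.from_str("b") is AxisKind.Batch
assert AxisKind.from_str("t_spectrogram") is AxisKind.Time  # any "t_<suffix>" counts as Time
assert AxisKind.from_str("c") is AxisKind.Dimension  # "Channel" is an alias of "Dimension"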
+ + +
+class AxisType:
+    """This class represents axis semantics and (optionally) its dimensionality.
+
+    Parameters
+    ----------
+    kind: what kind of axis it is, e.g. Batch, Height, etc.
+        AxisKindAbstract
+    size: specify if the axis should have a fixed size. By default, it is set to None and you typically do not
+        want to set it for Batch and Time.
+        (int, optional)
+    is_list: whether this is a list or a tensor axis.
+        (bool, default=False)
+    """
+
+    def __init__(self, kind: AxisKindAbstract, size: Optional[int] = None, is_list=False):
+        if size is not None and is_list:
+            raise ValueError("The axis can't be a list and have a fixed size")
+        self.kind = kind
+        self.size = size
+        self.is_list = is_list
[docs] def __repr__(self): + """Returns short string representation of the AxisType""" + if self.size is None: + representation = str(self.kind) + else: + representation = f"{str(self.kind)}:{self.size}" + if self.is_list: + representation += "_listdim" + return representation
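Two quick examples of the resulting representations:

batch_axis = AxisType(kind=AxisKind.Batch)  # repr() -> "batch"
height_axis = AxisType(kind=AxisKind.Height, size=256)  # repr() -> "height:256"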
+
diff --git a/docs/build/html/_modules/mridc/core/neural_types/comparison.html b/docs/build/html/_modules/mridc/core/neural_types/comparison.html
new file mode 100644
index 00000000..636d9d58
--- /dev/null
+++ b/docs/build/html/_modules/mridc/core/neural_types/comparison.html
@@ -0,0 +1,118 @@
+mridc.core.neural_types.comparison — mridc v.0.0.1 documentation

Source code for mridc.core.neural_types.comparison

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/neural_types/comparison.py
+
+from enum import Enum
+
+__all__ = ["NeuralTypeComparisonResult"]
+
+
+
+class NeuralTypeComparisonResult(Enum):
+    """The result of comparing two neural type objects for compatibility, i.e. the result of A.compare(B)."""
+
+    SAME = 0
+    LESS = 1  # A is B
+    GREATER = 2  # B is A
+    DIM_INCOMPATIBLE = 3  # Resize connector might fix incompatibility
+    TRANSPOSE_SAME = 4  # A transpose and/or converting between lists and tensors will make them the same
+    CONTAINER_SIZE_MISMATCH = 5  # A and B contain a different number of elements
+    INCOMPATIBLE = 6  # A and B are incompatible
+    SAME_TYPE_INCOMPATIBLE_PARAMS = 7  # A and B are of the same type but parametrized differently
+    UNCHECKED = 8  # type comparison wasn't done
+
diff --git a/docs/build/html/_modules/mridc/core/neural_types/elements.html b/docs/build/html/_modules/mridc/core/neural_types/elements.html
new file mode 100644
index 00000000..0e928c5d
--- /dev/null
+++ b/docs/build/html/_modules/mridc/core/neural_types/elements.html
@@ -0,0 +1,354 @@
+mridc.core.neural_types.elements — mridc v.0.0.1 documentation

Source code for mridc.core.neural_types.elements

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/neural_types/elements.py
+
+from abc import ABC, ABCMeta
+from typing import Dict, Optional, Tuple
+
+__all__ = [
+    "ElementType",
+    "VoidType",
+    "ChannelType",
+    "MRISignal",
+    "RecurrentsType",
+    "LabelsType",
+    "LogprobsType",
+    "ProbsType",
+    "LossType",
+    "RegressionValuesType",
+    "CategoricalValuesType",
+    "PredictionsType",
+    "LengthsType",
+    "MaskType",
+    "Target",
+    "ReconstructionTarget",
+    "ImageFeatureValue",
+    "Index",
+    "ImageValue",
+    "NormalizedImageValue",
+    "StringLabel",
+    "StringType",
+    "Length",
+    "IntType",
+    "FloatType",
+    "NormalDistributionSamplesType",
+    "NormalDistributionMeanType",
+    "NormalDistributionLogVarianceType",
+    "LogDeterminantType",
+    "SequenceToSequenceAlignmentType",
+]
+
+from mridc.core.neural_types.comparison import NeuralTypeComparisonResult
+
+
+
[docs]class ElementType(ABC): + """Abstract class defining semantics of the tensor elements. We are relying on Python for inheritance checking""" + +
[docs] def __str__(self): + """Override this method to provide a human readable representation of the type""" + return self.__doc__
+ +
[docs] def __repr__(self): + """Override this method to provide a human readable representation of the type""" + return self.__class__.__name__
+
+    @property
+    def type_parameters(self) -> Dict:
+        """
+        Override this property to parametrize your type. For example, you can specify 'storage' type such as float,
+        int, bool with 'dtype' keyword. Another example is if you want to represent a signal with a particular
+        property (say, sample frequency), then you can put sample_freq->value in there. When two types are compared,
+        their type_parameters must match.
+        """
+        return {}
+
+    @property
+    def fields(self) -> Optional[Tuple]:
+        """
+        This should be used to logically represent tuples/structures. For example, if you want to represent a
+        bounding box (x, y, width, height) you can put a tuple with names ('x', 'y', 'w', 'h') in here. Under the
+        hood this should be converted to the last tensor dimension of fixed size = len(fields). When two types are
+        compared their fields must match.
+        """
+        return None
+
[docs] def compare(self, second) -> NeuralTypeComparisonResult: + """Override this method to provide a comparison between two types.""" + # First, check general compatibility + first_t = type(self) + second_t = type(second) + + if first_t == second_t: + result = NeuralTypeComparisonResult.SAME + elif issubclass(first_t, second_t): + result = NeuralTypeComparisonResult.LESS + elif issubclass(second_t, first_t): + result = NeuralTypeComparisonResult.GREATER + else: + result = NeuralTypeComparisonResult.INCOMPATIBLE + + if result != NeuralTypeComparisonResult.SAME: + return result + # now check that all parameters match + check_params = set(self.type_parameters.keys()) == set(second.type_parameters.keys()) + if not check_params: + return NeuralTypeComparisonResult.SAME_TYPE_INCOMPATIBLE_PARAMS + for k1, v1 in self.type_parameters.items(): + if v1 is None or second.type_parameters[k1] is None: + # Treat None as Void + continue + if v1 != second.type_parameters[k1]: + return NeuralTypeComparisonResult.SAME_TYPE_INCOMPATIBLE_PARAMS + # check that all fields match + if self.fields == second.fields: + return NeuralTypeComparisonResult.SAME + return NeuralTypeComparisonResult.INCOMPATIBLE
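As a concrete illustration, using LabelsType and PredictionsType (defined further down this page, where PredictionsType subclasses LabelsType), the comparison is asymmetric:

assert LabelsType().compare(PredictionsType()) == NeuralTypeComparisonResult.GREATER  # B is A
assert PredictionsType().compare(LabelsType()) == NeuralTypeComparisonResult.LESS  # A is B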
+ + +
+class VoidType(ElementType):
+    """
+    Void-like type which is compatible with everything. It is good practice to use this type only when necessary,
+    for example when you need template-like functionality.
+    """
+
+    def compare(self, second: ABCMeta) -> NeuralTypeComparisonResult:
+        """Void type is compatible with everything."""
+        return NeuralTypeComparisonResult.SAME
+ + +# TODO: Consider moving these files elsewhere +
[docs]class ChannelType(ElementType): + """Element to represent convolutional input/output channel."""
+ + +
[docs]class RecurrentsType(ElementType): + """Element type to represent recurrent layers"""
+ + +
[docs]class LengthsType(ElementType): + """Element type representing lengths of something"""
+ + +
[docs]class ProbsType(ElementType): + """Element type to represent probabilities. For example, outputs of softmax layers."""
+ + +
[docs]class LogprobsType(ElementType): + """Element type to represent log-probabilities. For example, outputs of log softmax layers."""
+ + +
[docs]class LossType(ElementType): + """Element type to represent outputs of Loss modules"""
+ + +
+class MRISignal(ElementType):
+    """
+    Element type to represent the encoded representation returned by the MRI model.
+
+    Parameters
+    ----------
+    freq: sampling frequency of a signal. Note that two signals will only be the same if their freq is the same.
+    """
+
+    def __init__(self, freq: Optional[int] = None):
+        self._params = {"freq": freq}
+
+    @property
+    def type_parameters(self):
+        """Returns the type parameters of the element type."""
+        return self._params
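Because freq is a type parameter, two MRISignal instances with different frequencies compare as parametrically incompatible. A small sketch of that behavior:

sig_500 = MRISignal(freq=500)
sig_250 = MRISignal(freq=250)
# Same class, different type parameters:
assert sig_500.compare(sig_250) == NeuralTypeComparisonResult.SAME_TYPE_INCOMPATIBLE_PARAMS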
+ + +
[docs]class LabelsType(ElementType): + """Element type to represent labels of something. For example, labels of a dataset."""
+ + +
[docs]class PredictionsType(LabelsType): + """Element type to represent some sort of predictions returned by model"""
+ + +
[docs]class RegressionValuesType(PredictionsType): + """Element type to represent labels for regression task"""
+ + +
[docs]class CategoricalValuesType(PredictionsType): + """Element type to represent labels for categorical classification task"""
+ + +
[docs]class MaskType(PredictionsType): + """Element type to represent a boolean mask"""
+ + +
[docs]class Index(ElementType): + """Type representing an element being an index of the sample."""
+ + +
[docs]class Target(ElementType): + """Type representing an element being a target value."""
+ + +
[docs]class ReconstructionTarget(Target): + """ + Type representing an element being target value in the reconstruction task, i.e. identifier of a desired + class. + """
+ + +
[docs]class ImageValue(ElementType):
+    """Type representing an element/value of a single image channel."""
+ + +
[docs]class NormalizedImageValue(ImageValue): + """Type representing an element/value of a single image channel normalized to <0-1> range."""
+ + +
[docs]class ImageFeatureValue(ImageValue):
+    """Type representing an element (single value) of (image) feature maps."""
+ + +
[docs]class StringType(ElementType): + """Element type representing a single string"""
+ + +
[docs]class StringLabel(StringType): + """Type representing a label being a string with class name (e.g. the "hamster" class in CIFAR100)."""
+
+
+class BoolType(ElementType):
+    """Element type representing a single boolean"""
+
+
[docs]class IntType(ElementType): + """Element type representing a single integer"""
+ + +
[docs]class FloatType(ElementType): + """Element type representing a single float"""
+ + +
[docs]class Length(IntType): + """Type representing an element storing a "length" (e.g. length of a list)."""
+
+
+class ProbabilityDistributionSamplesType(ElementType):
+    """Element to represent tensors that are meant to be sampled from a valid probability distribution"""
+
+
[docs]class NormalDistributionSamplesType(ProbabilityDistributionSamplesType):
+    """Element to represent tensors that are meant to be sampled from a valid normal distribution"""
+ + +
[docs]class SequenceToSequenceAlignmentType(ElementType): + """ + Class to represent the alignment from seq-to-seq attention outputs. Generally a mapping from encoder time steps + to decoder time steps. + """
+ + +
[docs]class NormalDistributionMeanType(ElementType): + """Element to represent the mean of a normal distribution"""
+ + +
[docs]class NormalDistributionLogVarianceType(ElementType): + """Element to represent the log variance of a normal distribution"""
+ + +
[docs]class LogDeterminantType(ElementType): + """Element for representing log determinants usually used in flow models"""
+
diff --git a/docs/build/html/_modules/mridc/core/neural_types/neural_type.html b/docs/build/html/_modules/mridc/core/neural_types/neural_type.html
new file mode 100644
index 00000000..c5c1eb13
--- /dev/null
+++ b/docs/build/html/_modules/mridc/core/neural_types/neural_type.html
@@ -0,0 +1,306 @@
+mridc.core.neural_types.neural_type — mridc v.0.0.1 documentation

Source code for mridc.core.neural_types.neural_type

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/neural_types/neural_type.py
+
+from typing import Optional, Tuple
+
+__all__ = ["NeuralType", "NeuralTypeError", "NeuralPortNameMismatchError", "NeuralPortNmTensorMismatchError"]
+
+from mridc.core.neural_types.axes import AxisKind, AxisType
+from mridc.core.neural_types.comparison import NeuralTypeComparisonResult
+from mridc.core.neural_types.elements import ElementType, VoidType
+
+
+
[docs]class NeuralType:
+    """
+    This is the main class for representing the neural type concept. It is used to represent *the types* of inputs
+    and outputs.
+
+    Parameters
+    ----------
+    axes: a tuple of AxisType objects representing the semantics of what varying each axis means. You can use a
+    short, string-based form here. For example: ('B', 'C', 'H', 'W') would correspond to an NCHW format frequently
+    used in computer vision. ('B', 'T', 'D') is frequently used for signal processing and means
+    [batch, time, dimension/channel].
+    elements_type: an instance of ElementType class representing the semantics of what is stored inside the tensor.
+    For example: logits (LogitsType), log probabilities (LogprobType), etc.
+    optional: By default, this is false. If set to True, it would mean that input to the port of this type can be
+    optional.
+    """
+
+    def __str__(self):
+        if self.axes is not None:
+            return f"axes: {self.axes}; elements_type: {self.elements_type.__class__.__name__}"
+        return f"axes: None; elements_type: {self.elements_type.__class__.__name__}"
+
+    def __init__(self, axes: Optional[Tuple] = None, elements_type: ElementType = VoidType(), optional=False):
+        if not isinstance(elements_type, ElementType):
+            raise ValueError(
+                "elements_type of NeuralType must be an instance of a class derived from ElementType. "
+                "Did you pass a class instead?"
+            )
+        self.elements_type = elements_type
+        if axes is not None:
+            NeuralType.__check_sanity(axes)
+            axes_list = []
+            for axis in axes:
+                if isinstance(axis, str):
+                    axes_list.append(AxisType(AxisKind.from_str(axis), None))
+                elif isinstance(axis, AxisType):
+                    axes_list.append(axis)
+                else:
+                    raise ValueError("axis type must be either str or AxisType instance")
+            self.axes = tuple(axes_list)  # type: ignore
+        else:
+            self.axes = None  # type: ignore
+        self.optional = optional
+
[docs] def compare(self, second) -> NeuralTypeComparisonResult: + """ + Performs neural type comparison of self with second. When you chain two modules' inputs/outputs via __call__ + method, this comparison will be called to ensure neural type compatibility. + """ + # First, handle dimensionality + axes_a = self.axes + axes_b = second.axes + + # "Big void" type + if isinstance(self.elements_type, VoidType) and self.axes is None: + return NeuralTypeComparisonResult.SAME + + if self.axes is None: + if second.axes is None: + return self.elements_type.compare(second.elements_type) + return NeuralTypeComparisonResult.INCOMPATIBLE + + dimensions_pass = NeuralType.__compare_axes(axes_a, axes_b) # type: ignore + element_comparison_result = self.elements_type.compare(second.elements_type) + + # SAME DIMS + if dimensions_pass == 0: + return element_comparison_result + # TRANSPOSE_SAME DIMS + if dimensions_pass == 1 and element_comparison_result == NeuralTypeComparisonResult.SAME: + return NeuralTypeComparisonResult.TRANSPOSE_SAME + if ( + dimensions_pass == 1 + or dimensions_pass == 2 + and element_comparison_result != NeuralTypeComparisonResult.SAME + ): + return NeuralTypeComparisonResult.INCOMPATIBLE + if dimensions_pass == 2: + return NeuralTypeComparisonResult.DIM_INCOMPATIBLE + return NeuralTypeComparisonResult.INCOMPATIBLE
+ +
[docs] def compare_and_raise_error(self, parent_type_name, port_name, second_object): + """Method compares definition of one type with another and raises an error if not compatible.""" + type_compatibility = self.compare(second_object) + if type_compatibility not in (NeuralTypeComparisonResult.SAME, NeuralTypeComparisonResult.GREATER): + raise NeuralPortNmTensorMismatchError( + parent_type_name, port_name, str(self), str(second_object.ntype), type_compatibility + )
+ +
[docs] def __eq__(self, other): + """Checks if two NeuralTypes are equal.""" + return self.compare(other) if isinstance(other, NeuralType) else False
+
+    @staticmethod
+    def __check_sanity(axes):
+        """Check that list dimensions come before any tensor dimension."""
+        are_strings = True
+        for axis in axes:
+            if not isinstance(axis, str):
+                are_strings = False
+            if isinstance(axis, str) and not are_strings:
+                raise ValueError("Either use full class names or all strings")
+        if are_strings:
+            return
+        checks_passed = True
+        saw_tensor_dim = False
+        for axis in axes:
+            if not axis.is_list:
+                saw_tensor_dim = True
+            elif saw_tensor_dim:  # a list dim preceded by a tensor dim
+                checks_passed = False
+        if not checks_passed:
+            raise ValueError(
+                "You have list dimension after Tensor dimension. All list dimensions must precede Tensor dimensions"
+            )
+
+    @staticmethod
+    def __compare_axes(axes_a, axes_b) -> int:
+        """
+        Compares axes_a and axes_b
+        Args:
+            axes_a: first axes tuple
+            axes_b: second axes tuple
+        Returns:
+            0 - if they are exactly the same
+            1 - if they are "TRANSPOSE_SAME"
+            2 - if they are "DIM_INCOMPATIBLE"
+            3 - if they are different
+        """
+        if axes_a is None:
+            return 0 if axes_b is None else 3
+        if axes_b is None:
+            return 3
+        if len(axes_a) != len(axes_b):
+            return 3
+        # After these ifs we know that len(axes_a) == len(axes_b)
+
+        same = True
+        kinds_a = {}
+        kinds_b = {}
+        for axis_a, axis_b in zip(axes_a, axes_b):
+            kinds_a[axis_a.kind] = axis_a.size
+            kinds_b[axis_b.kind] = axis_b.size
+            if axis_a.kind == AxisKind.Any:
+                same = True
+            elif (
+                axis_a.kind != axis_b.kind
+                or axis_a.is_list != axis_b.is_list
+                or (axis_a.size != axis_b.size and axis_a.size is not None)
+            ):
+                same = False
+        if same:
+            return 0
+        # can be TRANSPOSE_SAME, DIM_INCOMPATIBLE
+        if kinds_a.keys() == kinds_b.keys():
+            return next((2 for key, value in kinds_a.items() if kinds_b[key] != value), 1)
+        return 3
+
[docs] def __repr__(self): + """Returns string representation of NeuralType.""" + axes = str(self.axes) if self.axes is not None else "None" + if self.elements_type is not None: + element_type = repr(self.elements_type) + else: + element_type = "None" + + data = f"axis={axes}, element_type={element_type}" + + if self.optional: + data = f"{data}, optional={self.optional}" + + return f"{self.__class__.__name__}({data})"
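A minimal usage sketch for NeuralType comparison, assuming the imports below resolve as in this module; the axes tuples are illustrative.

from mridc.core.neural_types.comparison import NeuralTypeComparisonResult
from mridc.core.neural_types.elements import ChannelType
from mridc.core.neural_types.neural_type import NeuralType

nchw = NeuralType(("B", "C", "H", "W"), ChannelType())
nhwc = NeuralType(("B", "H", "W", "C"), ChannelType())

# Identical axes and element types compare as SAME.
assert nchw.compare(NeuralType(("B", "C", "H", "W"), ChannelType())) == NeuralTypeComparisonResult.SAME

# Same axis kinds in a different order compare as TRANSPOSE_SAME.
assert nchw.compare(nhwc) == NeuralTypeComparisonResult.TRANSPOSE_SAME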
+ + +
[docs]class NeuralTypeError(Exception): + """Base class for neural type related exceptions."""
+ + +
[docs]class NeuralPortNameMismatchError(NeuralTypeError): + """Exception raised when neural module is called with incorrect port names.""" + + def __init__(self, input_port_name): + super().__init__() + self.message = "Wrong input port name: {0}".format(input_port_name)
+ + +
[docs]class NeuralPortNmTensorMismatchError(NeuralTypeError): + """Exception raised when a port is fed with a NmTensor of incompatible type.""" + + def __init__(self, class_name, port_name, first_type, second_type, type_compatibility): + super().__init__() + self.message = "\nIn {}. \nPort: {} and a NmTensor it was fed are \n".format( + class_name, port_name + ) + "of incompatible neural types:\n\n{} \n\n and \n\n{}".format(first_type, second_type) + + self.message += "\n\nType comparison result: {}".format(type_compatibility)
+
diff --git a/docs/build/html/_modules/mridc/core/optim/adafactor.html b/docs/build/html/_modules/mridc/core/optim/adafactor.html
new file mode 100644
index 00000000..776c5b47
--- /dev/null
+++ b/docs/build/html/_modules/mridc/core/optim/adafactor.html
@@ -0,0 +1,315 @@
+mridc.core.optim.adafactor — mridc v.0.0.1 documentation

Source code for mridc.core.optim.adafactor

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from:
+# https://github.com/wdika/NeMo/blob/9d095ff261319301e4711edf7530a6bb7cf6c8b6/nemo/core/optim/adafactor.py
+
+import math
+
+import torch
+from torch.optim.optimizer import Optimizer
+
+__all__ = ["Adafactor"]
+
+
+
[docs]class Adafactor(Optimizer): + """ + Implements Adafactor algorithm. + + This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost` + (see https://arxiv.org/abs/1804.04235) + Note that this optimizer internally adjusts the learning rate depending on the *scale_parameter*, *relative_step* + and *warmup_init* options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` + and `relative_step=False`. + + Parameters + ---------- + params: Iterable of parameters to optimize or dicts defining parameter groups. + iterable + lr: External learning rate. + float (optional), (default: None) + eps: Regularization constants for square gradient and parameter scale respectively. + tuple (float, float), (default: (1e-30, 1e-3)) + clip_threshold: Threshold of root-mean-square of final gradient update. + float, (default: 1.0) + decay_rate: Coefficient used to compute running averages of square gradient. + float, (default: -0.8) + beta1: Coefficient used for computing running averages of gradient + float, (default: None) + weight_decay: Weight decay (L2 penalty). + float (optional), (default: 0) + scale_parameter: If True, learning rate is scaled by root-mean-square of parameter. + bool (default: True) + relative_step: If True, time-dependent learning rate is computed instead of external learning rate. + bool (default: True) + warmup_init: Time-dependent learning rate computation depends on whether warm-up initialization is being used. + bool (default: False) + + Returns + ------- + Adafactor Optimizer + """ + + def __init__( + self, + params, + lr=None, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + scale_parameter=True, + relative_step=True, + warmup_init=False, + min_step=1e-2, + ): + if lr is not None and relative_step: + raise ValueError("Cannot combine manual lr and relative_step options") + if warmup_init and not relative_step: + raise ValueError("warmup_init requires relative_step=True") + self.min_step = min_step + + defaults = dict( + lr=lr, + eps=eps, + clip_threshold=clip_threshold, + decay_rate=decay_rate, + beta1=beta1, + weight_decay=weight_decay, + scale_parameter=scale_parameter, + relative_step=relative_step, + warmup_init=warmup_init, + min_step=min_step, + ) + super().__init__(params, defaults) + + @property + def supports_memory_efficient_fp16(self): + """Whether optimizer supports memory efficient fp16""" + return True + + @property + def supports_flat_params(self): + """Whether the optimizer supports flat parameters.""" + return False + + def _get_lr(self, param_group, param_state): + """Returns the learning rate for the current layer.""" + rel_step_sz = param_group["lr"] + if param_group["relative_step"]: + min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else self.min_step + rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) + param_scale = 1.0 + if param_group["scale_parameter"]: + param_scale = max(param_group["eps"][1], param_state["RMS"]) + return param_scale * rel_step_sz + +
[docs] def step(self, closure=None): + """ + Performs a single optimization step. + + Parameters + ---------- + closure: A closure that reevaluates the model and returns the loss. + callable (optional) + """ + loss = closure() if closure is not None else None + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients.") + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = _get_options(group, grad_shape) + # State Initialization + if len(state) == 0: + state["step"] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + if factored: + state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) + state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) + else: + state["exp_avg_sq"] = torch.zeros_like(grad) + + state["RMS"] = 0 + else: + if use_first_moment: + state["exp_avg"] = state["exp_avg"].to(grad) + if factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) + state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) + else: + state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + + p_data_fp32 = p.data + if p.data.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state["step"] += 1 + state["RMS"] = _rms(p_data_fp32) + group["lr"] = self._get_lr(group, state) + + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad**2) + group["eps"][0] + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t) + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t) + + # Approximation of exponential moving average of square of gradient + update = _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + + exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_((_rms(update) / group["clip_threshold"]).clamp_(min=1.0)) + update.mul_(group["lr"]) + + if use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=1 - group["beta1"]) + update = exp_avg + + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"]) + + p_data_fp32.add_(-update) + + if p.data.dtype in {torch.float16, torch.bfloat16}: + p.data.copy_(p_data_fp32) + + return loss
+
+
+# Module-level helpers, so the bare calls in `step` above resolve.
+def _get_options(param_group, param_shape):
+    """Returns the factorization options for the current parameter group."""
+    factored = len(param_shape) >= 2
+    use_first_moment = param_group["beta1"] is not None
+    return factored, use_first_moment
+
+
+def _rms(tensor):
+    """Compute the root-mean-square of a tensor."""
+    return tensor.norm(2) / (tensor.numel() ** 0.5)
+
+
+def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
+    """
+    Compute the square of the gradient, but approximate the sqrt using the exponential moving average of the
+    squared gradient.
+    """
+    r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
+    c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
+    return torch.mul(r_factor, c_factor)
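A minimal training-step sketch with Adafactor under its internal relative-step schedule; the model, data, and shapes are illustrative only.

import torch

from mridc.core.optim.adafactor import Adafactor

model = torch.nn.Linear(16, 4)
# With relative_step=True, lr must be None; the rate is derived from the step count.
optimizer = Adafactor(model.parameters(), lr=None, relative_step=True, scale_parameter=True)

x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()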
+
diff --git a/docs/build/html/_modules/mridc/core/optim/lr_scheduler.html b/docs/build/html/_modules/mridc/core/optim/lr_scheduler.html
new file mode 100644
index 00000000..ae0ada41
--- /dev/null
+++ b/docs/build/html/_modules/mridc/core/optim/lr_scheduler.html
@@ -0,0 +1,1091 @@
+mridc.core.optim.lr_scheduler — mridc v.0.0.1 documentation

Source code for mridc.core.optim.lr_scheduler

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/optim/lr_scheduler.py
+
+import copy
+import dataclasses
+import math
+import warnings
+from functools import partial
+from typing import Any, Dict, Optional, Union
+
+import hydra
+import torch.optim as optim
+import torch.optim.lr_scheduler as pt_scheduler
+import torch.utils.data.dataloader as dataloader
+from omegaconf import DictConfig, OmegaConf
+from torch.optim.lr_scheduler import _LRScheduler  # type: ignore
+
+from mridc.core.conf.schedulers import SchedulerParams, get_scheduler_config, register_scheduler_params
+from mridc.utils import logging
+from mridc.utils.model_utils import maybe_update_config_version
+
+
+
[docs]class WarmupPolicy(_LRScheduler): + """Adds warmup kwargs and warmup logic to lr policy. All arguments should be passed as kwargs for clarity. + + Parameters + ---------- + warmup_steps: Number of training steps in warmup stage. + warmup_ratio: Ratio of warmup steps to total steps. + max_steps: Total number of steps while training or `None` for infinite training. + + Returns + ------- + lr: Learning rate for current step. + """ + + def __init__(self, optimizer, *, warmup_steps=None, warmup_ratio=None, max_steps=None, min_lr=0.0, last_epoch=-1): + """ + Parameters + ---------- + optimizer: optimizer + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + max_steps: Total number of steps while training or `None` for infinite training + min_lr: Minimum learning rate + last_epoch: Last epoch + """ + if warmup_steps is not None and warmup_ratio is not None: + raise AssertionError("Either use particular number of step or ratio") + if warmup_ratio is not None and max_steps is None: + raise AssertionError("If there is a ratio, there should be a total steps") + + # It is necessary to assign all attributes *before* __init__, + # as class is wrapped by an inner class. + self.max_steps = max_steps + if warmup_steps is not None: + self.warmup_steps = warmup_steps + elif warmup_ratio is not None: + self.warmup_steps = int(warmup_ratio * max_steps) + else: + self.warmup_steps = 0 + + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + +
[docs] def get_lr(self): + """Get learning rate at current step.""" + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning + ) + + step = self.last_epoch + + if 0 < self.warmup_steps >= step: + return self._get_warmup_lr(step) + + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + return self._get_lr(step)
+ + def _get_warmup_lr(self, step): + """Linear warmup""" + lr_val = (step + 1) / (self.warmup_steps + 1) + return [initial_lr * lr_val for initial_lr in self.base_lrs] + + def _get_lr(self, step): + """Simple const lr policy""" + return self.base_lrs
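The linear warmup in _get_warmup_lr above scales the base rate by (step + 1) / (warmup_steps + 1); a small numeric sketch with illustrative values:

base_lr, warmup_steps = 1e-3, 99

for step in (0, 49, 99):
    lr = base_lr * (step + 1) / (warmup_steps + 1)
    print(step, lr)  # 0 -> 1e-05, 49 -> 0.0005, 99 -> 0.001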
+ + +
[docs]class SquareRootConstantPolicy(_LRScheduler):
+    """Maintains a constant learning rate of 1/sqrt(constant_steps). All arguments should be passed as kwargs for
+    clarity.
+
+    Parameters
+    ----------
+    constant_steps: Number of training steps in the constant stage
+    constant_ratio: Ratio of constant steps to total steps
+    max_steps: Total number of steps while training or `None` for infinite training
+    """
+
+    def __init__(
+        self, optimizer, *, constant_steps=None, constant_ratio=None, max_steps=None, min_lr=0.0, last_epoch=-1
+    ):
+        """
+        Parameters
+        ----------
+        optimizer: optimizer
+        constant_steps: Number of training steps in constant stage
+        constant_ratio: Ratio of constant steps to total steps
+        max_steps: Total number of steps while training or `None` for infinite training
+        min_lr: Minimum learning rate
+        last_epoch: Last epoch
+        """
+        if constant_steps is not None and constant_ratio is not None:
+            raise AssertionError("Either use particular number of step or ratio")
+
+        if constant_ratio is not None and max_steps is None:
+            raise AssertionError("If there is a ratio, there should be a total steps")
+
+        # It is necessary to assign all attributes *before* __init__, as class is wrapped by an inner class.
+        self.max_steps = max_steps
+        if constant_steps is not None:
+            self.constant_steps = constant_steps
+        elif constant_ratio is not None:
+            self.constant_steps = int(constant_ratio * max_steps)
+        else:
+            self.constant_steps = 0
+
+        # Use the resolved step count (constant_steps may be None when constant_ratio is given),
+        # and guard against division by zero when no constant stage is configured.
+        self.constant_lr = 1 / (self.constant_steps**0.5) if self.constant_steps > 0 else min_lr
+        self.min_lr = min_lr
+        super().__init__(optimizer, last_epoch)
+
[docs] def get_lr(self): + """Get learning rate at current step.""" + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning + ) + + step = self.last_epoch + + if step <= self.constant_steps: + return [self.constant_lr for _ in self.base_lrs] + + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + return self._get_lr(step)
+ + def _get_lr(self, step): + """Simple const lr policy""" + return self.base_lrs
+ + +
[docs]class WarmupHoldPolicy(WarmupPolicy):
+    """
+    Variant of WarmupPolicy which maintains a high learning rate for a defined number of steps. All arguments should
+    be passed as kwargs for clarity.
+
+    Parameters
+    ----------
+    warmup_steps: Number of training steps in warmup stage
+    warmup_ratio: Ratio of warmup steps to total steps
+    hold_steps: Number of training steps to hold the learning rate after warm up
+    hold_ratio: Ratio of hold steps to total steps
+    max_steps: Total number of steps while training or `None` for infinite training
+
+    Returns
+    -------
+    The learning rate is linearly increased over the warmup steps, held constant at the base learning rate for the
+    hold steps, and then follows the policy's decay.
+    """
+
+    def __init__(
+        self,
+        optimizer,
+        *,
+        warmup_steps=None,
+        warmup_ratio=None,
+        hold_steps=None,
+        hold_ratio=None,
+        max_steps=None,
+        min_lr=0.0,
+        last_epoch=-1,
+    ):
+        """
+        Parameters
+        ----------
+        optimizer: optimizer
+        warmup_steps: Number of training steps in warmup stage.
+        warmup_ratio: Ratio of warmup steps to total steps.
+        hold_steps: Number of training steps to hold the learning rate after warm up.
+        hold_ratio: Ratio of hold steps to total steps.
+        max_steps: Total number of steps while training or `None` for infinite training.
+        min_lr: Minimum learning rate.
+        last_epoch: Last epoch.
+        """
+        if hold_steps is not None and hold_ratio is not None:
+            raise AssertionError("Either use particular number of step or ratio")
+        if hold_ratio is not None and max_steps is None:
+            raise AssertionError("If there is a ratio, there should be a total steps")
+
+        self.min_lr = min_lr
+        self._last_warmup_lr = 0.0
+
+        # Necessary to duplicate as class attributes are hidden in inner class
+        self.max_steps = max_steps
+        if warmup_steps is not None:
+            self.warmup_steps = warmup_steps
+        elif warmup_ratio is not None:
+            self.warmup_steps = int(warmup_ratio * max_steps)
+        else:
+            self.warmup_steps = 0
+
+        if hold_steps is not None:
+            self.hold_steps = hold_steps + self.warmup_steps
+        elif hold_ratio is not None:
+            self.hold_steps = int(hold_ratio * max_steps) + self.warmup_steps
+        else:
+            self.hold_steps = 0
+
+        super().__init__(
+            optimizer,
+            warmup_steps=warmup_steps,
+            warmup_ratio=warmup_ratio,
+            max_steps=max_steps,
+            last_epoch=last_epoch,
+            min_lr=min_lr,
+        )
+
[docs]    def get_lr(self):
+        """Get learning rate at current step."""
+        if not self._get_lr_called_within_step:
+            warnings.warn(
+                "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning
+            )
+
+        step = self.last_epoch
+
+        # Warmup phase
+        if 0 < self.warmup_steps >= step:
+            return self._get_warmup_lr(step)
+
+        # Hold phase: keep the base learning rate between the end of warmup and the end of hold
+        if self.warmup_steps <= step < self.hold_steps:
+            return self.base_lrs
+
+        if step > self.max_steps:
+            return [self.min_lr for _ in self.base_lrs]
+
+        return self._get_lr(step)
+ + +
[docs]class WarmupAnnealHoldPolicy(_LRScheduler): + """ + Adds warmup kwargs and warmup logic to lr policy. All arguments should be passed as kwargs for clarity. + + Parameters + ---------- + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + max_steps: Total number of steps while training or `None` for infinite training + min_lr: Minimum lr to hold the learning rate after decay at. + constant_steps: Number of steps to keep lr constant at. + constant_ratio: Ratio of steps to keep lr constant. + """ + + def __init__( + self, + optimizer, + *, + warmup_steps=None, + warmup_ratio=None, + constant_steps=None, + constant_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1, + ): + """ + Parameters + ---------- + optimizer: Optimizer + warmup_steps: Number of training steps in warmup stage. + warmup_ratio: Ratio of warmup steps to total steps. + constant_steps: Number of steps to keep lr constant at. + constant_ratio: Ratio of steps to keep lr constant. + max_steps: Total number of steps while training or `None` for infinite training. + min_lr: Minimum lr to hold the learning rate after decay at. + last_epoch: The index of last epoch. + """ + if warmup_steps is not None and warmup_ratio is not None: + raise AssertionError("Either use particular number of step or ratio") + if constant_steps is not None and constant_ratio is not None: + raise AssertionError("Either use constant_steps or constant_ratio") + if warmup_ratio is not None and max_steps is None: + raise AssertionError("If there is a ratio, there should be a total steps") + + # It is necessary to assign all attributes *before* __init__, as class is wrapped by an inner class. + self.max_steps = max_steps + + if warmup_steps is not None: + self.warmup_steps = warmup_steps + elif warmup_ratio is not None: + self.warmup_steps = int(warmup_ratio * max_steps) + else: + self.warmup_steps = 0 + + if constant_steps is not None: + self.constant_steps = constant_steps + elif constant_ratio is not None: + self.constant_steps = int(constant_ratio * max_steps) + else: + self.constant_steps = 0 + + self.decay_steps = max_steps - (self.constant_steps + self.warmup_steps) + + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + +
[docs] def get_lr(self): + """Get learning rate at current step.""" + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning + ) + + step = self.last_epoch + + # Warmup steps + if 0 < self.warmup_steps >= step: + return self._get_warmup_lr(step) + + # Constant steps after warmup and decay + if self.constant_steps > 0 and (self.warmup_steps + self.decay_steps) < step <= self.max_steps: + return self._get_constant_lr(step) + + # Min lr after max steps of updates + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + return self._get_lr(step)
+ + def _get_warmup_lr(self, step): + """Get learning rate at warmup stage.""" + lr_val = (step + 1) / (self.warmup_steps + 1) + return [initial_lr * lr_val for initial_lr in self.base_lrs] + + def _get_constant_lr(self, step): + """Get learning rate at constant stage.""" + return [self.min_lr for _ in self.base_lrs] + + def _get_lr(self, step): + """Simple const lr policy""" + return self.base_lrs
+ + +def _sqrt_annealing(initial_lr, step, max_steps, min_lr): + """Anneal learning rate by sqrt.""" + mult = ((max_steps - step) / max_steps) ** 0.5 + out_lr = initial_lr * mult + out_lr = max(out_lr, min_lr) + return out_lr + + +def _square_annealing(initial_lr, step, max_steps, min_lr): + """Anneal learning rate by square.""" + mult = ((max_steps - step) / max_steps) ** 2 + out_lr = initial_lr * mult + out_lr = max(out_lr, min_lr) + return out_lr + + +def _cosine_annealing(initial_lr, step, max_steps, min_lr): + """Anneal learning rate by cosine.""" + mult = 0.5 * (1 + math.cos(math.pi * step / max_steps)) + return (initial_lr - min_lr) * mult + min_lr + + +def _linear_warmup_with_cosine_annealing(max_lr, warmup_steps, step, decay_steps, min_lr): + """Anneal learning rate by linear warmup and cosine annealing.""" + if max_lr <= min_lr: + raise AssertionError + # Use linear warmup for the initial part. + if warmup_steps > 0 and step <= warmup_steps: + return max_lr * float(step) / float(warmup_steps) + + # For any steps larger than `decay_steps`, use `min_lr`. + if step > warmup_steps + decay_steps: + return min_lr + + # If we are done with the warmup period, use the decay style. + num_steps_ = step - warmup_steps + decay_steps_ = decay_steps + decay_ratio = float(num_steps_) / float(decay_steps_) + if decay_ratio < 0.0: + raise AssertionError + if decay_ratio > 1.0: + raise AssertionError + delta_lr = max_lr - min_lr + + coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0) + + return min_lr + coeff * delta_lr + + +def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle): + """Polynomial decay of learning rate.""" + if cycle: + multiplier = 1.0 if step == 0 else math.ceil(step / decay_steps) + decay_steps *= multiplier + else: + step = min(step, decay_steps) + p = step / decay_steps + lr = (initial_lr - min_lr) * math.pow(1.0 - p, power) + lr += min_lr + return lr + + +
[docs]class SquareAnnealing(WarmupPolicy): + """Anneal learning rate by square.""" + + def __init__(self, optimizer, *, max_steps, min_lr=1e-5, last_epoch=-1, **kwargs): + super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs) + + def _get_lr(self, step): + """Get learning rate at current step.""" + return [ + _square_annealing( + initial_lr=initial_lr, + step=step - self.warmup_steps, + max_steps=self.max_steps - self.warmup_steps, + min_lr=self.min_lr, + ) + for initial_lr in self.base_lrs + ]
+ + +
[docs]class SquareRootAnnealing(WarmupPolicy): + """Anneal learning rate by square root.""" + + def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, **kwargs): + super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs) + + def _get_lr(self, step): + """Get learning rate at current step.""" + return [ + _sqrt_annealing( + initial_lr=initial_lr, + step=step, + max_steps=self.max_steps, + min_lr=self.min_lr, + ) + for initial_lr in self.base_lrs + ]
+ + +
[docs]class CosineAnnealing(WarmupAnnealHoldPolicy): + """Anneal learning rate by cosine.""" + + def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, **kwargs): + super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs) + + def _get_lr(self, step): + """Get learning rate at current step.""" + for initial_lr in self.base_lrs: + if initial_lr < self.min_lr: + raise ValueError( + f"{self} received an initial learning rate that was lower than the minimum learning rate." + ) + + return ( + [ + _cosine_annealing( + initial_lr=initial_lr, + step=step - self.warmup_steps, + max_steps=self.max_steps - self.warmup_steps, + min_lr=self.min_lr, + ) + for initial_lr in self.base_lrs + ] + if self.constant_steps is None or self.constant_steps == 0 + else self._get_linear_warmup_with_cosine_annealing_lr(step) + ) + + def _get_warmup_lr(self, step): + """Get the warmup learning rate for the given step.""" + if self.constant_steps is None or self.constant_steps == 0: + return super()._get_warmup_lr(step) + + # Use linear warmup for the initial part. + return self._get_linear_warmup_with_cosine_annealing_lr(step) + + def _get_constant_lr(self, step): + """Only called when constant_steps is not None and not 0.""" + return self._get_linear_warmup_with_cosine_annealing_lr(step) + + def _get_linear_warmup_with_cosine_annealing_lr(self, step): + """Cosine Schedule, slightly different warmup schedule + constant LR at the end.""" + return [ + _linear_warmup_with_cosine_annealing( + max_lr=self.base_lrs[0], + warmup_steps=self.warmup_steps, + step=step, + decay_steps=self.decay_steps, + min_lr=self.min_lr, + ) + for _ in self.base_lrs + ]
+ + +
[docs]class NoamAnnealing(_LRScheduler): + """Noam learning rate annealing.""" + + def __init__( + self, optimizer, *, d_model, warmup_steps=None, warmup_ratio=None, max_steps=None, min_lr=0.0, last_epoch=-1 + ): + self._normalize = d_model ** (-0.5) + if warmup_steps is not None and warmup_ratio is not None: + raise AssertionError("Either use particular number of step or ratio") + if warmup_ratio is not None and max_steps is None: + raise AssertionError("If there is a ratio, there should be a total steps") + + # It is necessary to assign all attributes *before* __init__, + # as class is wrapped by an inner class. + self.max_steps = max_steps + if warmup_steps is not None: + self.warmup_steps = warmup_steps + elif warmup_ratio is not None: + self.warmup_steps = int(warmup_ratio * max_steps) + else: + self.warmup_steps = 0 + + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + +
[docs] def get_lr(self): + """Get learning rate at current step.""" + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning + ) + + step = max(1, self.last_epoch) + + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + for initial_lr in self.base_lrs: + if initial_lr < self.min_lr: + raise ValueError( + f"{self} received an initial learning rate that was lower than the minimum learning rate." + ) + + return [self._noam_annealing(initial_lr=initial_lr, step=step) for initial_lr in self.base_lrs]
+ + def _noam_annealing(self, initial_lr, step): + """Noam learning rate annealing.""" + mult = self._normalize * min(step ** (-0.5), step * (self.warmup_steps ** (-1.5))) + out_lr = initial_lr * mult + if step > self.warmup_steps: + out_lr = max(out_lr, self.min_lr) + return out_lr
+ + +
[docs]class WarmupAnnealing(WarmupPolicy): + """Warmup learning rate annealing.""" + + def __init__(self, optimizer, *, max_steps, last_epoch=-1, min_lr=0.0, **kwargs): + super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs) + + def _get_lr(self, step): + """Get learning rate at current step.""" + delta_lr = self.base_lrs[0] - self.min_lr + mult = (step - self.warmup_steps) / (self.max_steps - self.warmup_steps) + return [self.min_lr + (1 - mult) * delta_lr for _ in self.base_lrs]
+ + +
[docs]class InverseSquareRootAnnealing(WarmupPolicy): + """Inverse square root learning rate annealing.""" + + def __init__(self, optimizer, *, max_steps, last_epoch=-1, min_lr=0.0, **kwargs): + super().__init__(optimizer=optimizer, max_steps=max_steps, **kwargs, last_epoch=last_epoch, min_lr=min_lr) + + def _get_lr(self, step): + """Get learning rate at current step.""" + denom = ((step + 1) / (self.warmup_steps + 1)) ** 0.5 + return [initial_lr / denom for initial_lr in self.base_lrs]
+ + +
[docs]class T5InverseSquareRootAnnealing(SquareRootConstantPolicy): + """Inverse square root learning rate annealing.""" + + def __init__(self, optimizer, *, max_steps, last_epoch=-1, min_lr=0.0, **kwargs): + super().__init__(optimizer=optimizer, max_steps=max_steps, **kwargs, last_epoch=last_epoch, min_lr=min_lr) + + def _get_lr(self, step): + """Get learning rate at current step.""" + return [1 / (step**0.5) for _ in self.base_lrs]
+ + +
[docs]class PolynomialDecayAnnealing(WarmupPolicy): + """Polynomial decay learning rate annealing.""" + + def __init__(self, optimizer, *, max_steps, min_lr=0.0, power=1.0, cycle=False, last_epoch=-1, **kwargs): + self.power = power + self.cycle = cycle + + super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs) + + def _get_lr(self, step): + """Get learning rate at current step.""" + return [ + _poly_decay( + initial_lr, + step=step - self.warmup_steps, + decay_steps=self.max_steps - self.warmup_steps, + power=self.power, + min_lr=self.min_lr, + cycle=self.cycle, + ) + for initial_lr in self.base_lrs + ]
+ + +
[docs]class PolynomialHoldDecayAnnealing(WarmupHoldPolicy): + """Polynomial decay learning rate annealing.""" + + def __init__(self, optimizer, *, max_steps, min_lr=0.0, power=1.0, cycle=False, last_epoch=-1, **kwargs): + self.power = power + self.cycle = cycle + + super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs) + + def _get_lr(self, step): + """Get learning rate at current step.""" + return [ + _poly_decay( + initial_lr, + step=step - self.hold_steps, + decay_steps=self.max_steps - max(self.warmup_steps, self.hold_steps), + power=self.power, + min_lr=self.min_lr, + cycle=self.cycle, + ) + for initial_lr in self.base_lrs + ]
+ + +
[docs]def register_scheduler(name: str, scheduler: _LRScheduler, scheduler_params: SchedulerParams):
+    """
+    Checks if the scheduler name exists in the registry, and if it doesn't, adds it.
+    This allows custom schedulers to be added and called by name during instantiation.
+
+    Parameters
+    ----------
+    name: Name of the scheduler. Will be used as key to retrieve the scheduler.
+    scheduler: Scheduler class (inherits from _LRScheduler)
+    scheduler_params: The parameters as a dataclass of the scheduler
+    """
+    if name in AVAILABLE_SCHEDULERS:
+        raise ValueError(f"Cannot override pre-existing schedulers. Conflicting scheduler name = {name}")
+
+    AVAILABLE_SCHEDULERS[name] = scheduler
+
+    sched_name = f"{scheduler.__name__}_params"
+    register_scheduler_params(name=sched_name, scheduler_params=scheduler_params)
+ + +
[docs]def get_scheduler(name: str, **kwargs: Optional[Dict[str, Any]]) -> _LRScheduler:
+    """
+    Convenience method to obtain an _LRScheduler class and partially instantiate it with optimizer kwargs.
+
+    Parameters
+    ----------
+    name: Name of the scheduler in the registry.
+    kwargs: Optional kwargs of the scheduler used during instantiation.
+
+    Returns
+    -------
+    A partially instantiated _LRScheduler
+    """
+    if name not in AVAILABLE_SCHEDULERS:
+        raise ValueError(
+            f"Cannot resolve scheduler '{name}'. Available schedulers are: {AVAILABLE_SCHEDULERS.keys()}"
+        )
+
+    scheduler_cls = AVAILABLE_SCHEDULERS[name]
+    return partial(scheduler_cls, **kwargs)
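A short sketch of the partial-instantiation pattern: get_scheduler returns a partially applied class to which the optimizer is supplied later. The optimizer and hyperparameters here are illustrative.

import torch

from mridc.core.optim.lr_scheduler import get_scheduler

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)

scheduler_cls = get_scheduler("CosineAnnealing", max_steps=1000, min_lr=1e-5)
scheduler = scheduler_cls(optimizer)  # CosineAnnealing(optimizer, max_steps=1000, min_lr=1e-5)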
+ + +
[docs]def prepare_lr_scheduler(
+    optimizer: optim.Optimizer,
+    scheduler_config: Union[Dict[str, Any], DictConfig, None],
+    train_dataloader: Optional[dataloader.DataLoader] = None,
+) -> Optional[Dict[str, Any]]:
+    """
+    Constructs an LR Scheduler (optionally) for a given optimizer, based on a config with the following schema.
+
+    Parameters
+    ----------
+    optimizer: The optimizer to use for the scheduler.
+        name: <name of optimizer>
+        lr: <maximal learning rate>
+        # <additional optimizer arguments>
+        args:
+            name: auto  # special keyword, resolves to correct optimizer config for given optimizer name
+            # cls: mridc.core.config.optimizers.NovogradParams  # explicit instantiation by class path
+            params:  # optional override parameters for the optimizer config
+                betas: [0.8, 0.5]
+                weight_decay: 0.001
+    scheduler_config: The scheduler config.
+        name: <name of scheduler>
+        iters_per_batch: null  # computed at runtime; mandatory to have
+        max_steps: null  # computed at runtime or explicitly set here; mandatory to have
+        # pytorch lightning args <mandatory>
+        monitor: val_loss
+        reduce_on_plateau: false
+        # <scheduler config override>
+        args:
+            name: auto  # special keyword, resolves to correct scheduler config for given scheduler name
+            # cls: mridc.core.config.schedulers.CosineAnnealingParams  # explicit instantiation by class path
+            params:  # optional override parameters for the scheduler config
+                warmup_steps: null
+                warmup_ratio: null
+                min_lr: 0.0
+                last_epoch: -1
+    train_dataloader: Optional requirement, must be passed if "iters_per_batch" is defined instead of "max_steps".
+        Used to compute effective "max_steps".
+
+    Returns
+    -------
+    A dictionary containing the LR Scheduler implementation if the config was successfully parsed, along with other
+    parameters required by Pytorch Lightning, otherwise None.
+ """ + if scheduler_config is not None: + scheduler_config = maybe_update_config_version(scheduler_config) + + # Build nested dictionary for convenience out of structured objects + if isinstance(scheduler_config, DictConfig): + scheduler_config = OmegaConf.to_container(scheduler_config, resolve=True) + + elif dataclasses.is_dataclass(scheduler_config): + # Recursively transform data classes to basic dictionaries + scheduler_config = OmegaConf.create(scheduler_config) + scheduler_config = OmegaConf.to_container(scheduler_config, resolve=True) + + # Test to see if config follows above schema + + add_max_args_flag = True + interval = "step" + if scheduler_config is not None: + if "args" in scheduler_config: + scheduler_args = scheduler_config.pop("args") + else: + scheduler_args = copy.deepcopy(scheduler_config) + + # Remove extra parameters from scheduler_args nest + # Assume all other parameters are to be passed into scheduler constructor + + if "name" in scheduler_args and scheduler_args["name"] == "ReduceLROnPlateau": + add_max_args_flag = False + interval = "epoch" + + scheduler_args.pop("name", None) + scheduler_args.pop("t_max_epochs", None) + scheduler_args.pop("t_accumulate_grad_batches", None) + scheduler_args.pop("t_limit_train_batches", None) + scheduler_args.pop("t_num_workers", None) + scheduler_args.pop("monitor", None) + scheduler_args.pop("reduce_on_plateau", None) + + else: + # Return gracefully in case `sched` was not supplied; inform user + logging.info("Scheduler not initialized as no `sched` config supplied to setup_optimizer()") + return None + + # Try instantiation of scheduler params from config class path + if "_target_" in scheduler_args: + scheduler_args_cfg = OmegaConf.create(scheduler_args) + scheduler_conf = hydra.utils.instantiate(scheduler_args_cfg) + scheduler_args = vars(scheduler_conf) + + # Get name of the scheduler + scheduler_name = scheduler_conf.__class__.__name__ + + if "Params" in scheduler_name: + scheduler_name = scheduler_name.replace("Params", "") + + else: + # Class path instantiation failed; try resolving "name" component + + # Get name of the scheduler + if "name" in scheduler_config: + scheduler_name = scheduler_config["name"] + else: + logging.warning( + "Could not resolve classpath for Scheduler Config, and `name` " + "was not provided either. \n" + "Scheduler cannot be instantiated !" 
+            )
+            return None
+
+        # If class path was not provided, perhaps `name` is provided for resolution
+        if "name" in scheduler_args:
+            # If `auto` is passed as name for resolution of the scheduler name,
+            # then look up the scheduler name and resolve its parameter config
+            if scheduler_args["name"] == "auto":
+                scheduler_params_name = f"{scheduler_name}Params"
+            else:
+                scheduler_params_name = scheduler_args["name"]
+
+            # Get override arguments provided in the config yaml file / Dict Config
+            scheduler_params_override = scheduler_args.get("params", {})
+
+            # If params is itself a dict config object provided explicitly in Dict Config
+            # Resolve to dictionary for convenience
+            if isinstance(scheduler_params_override, DictConfig):
+                scheduler_params_override = OmegaConf.to_container(scheduler_params_override, resolve=True)
+
+            # Get and instantiate the Config dataclass for this scheduler
+            scheduler_params_cls = get_scheduler_config(scheduler_params_name, **scheduler_params_override)
+            scheduler_params = scheduler_params_cls  # instantiate the parameters object
+            scheduler_args = vars(scheduler_params)  # extract just the dictionary from the Config object
+
+    # Extract value to monitor in losses, if provided.
+    if "monitor" in scheduler_config:
+        monitor = scheduler_config.get("monitor")
+    else:
+        # Default to train loss
+        monitor = "loss"
+
+    # Store exact max_steps if it is provided
+    if "max_steps" in scheduler_config and scheduler_config["max_steps"] is not None:
+        max_steps = scheduler_config["max_steps"]
+
+    elif "t_max_epochs" in scheduler_config:
+        # Compute effective max_steps if t_max_epochs is provided
+        if train_dataloader is None:
+            logging.warning(
+                "As `t_max_epochs` is provided/computed, it is required to pass the train dataloader in order\n"
+                "to compute effective maximum number of steps.\n"
+                "Scheduler will not be instantiated !"
+            )
+            return None
+
+        # Raise exception if neither `max_steps` nor `t_max_epochs` is provided
+        if scheduler_config.get("t_max_epochs", None) is None:
+            logging.warning(
+                "`t_max_epochs` cannot be None when `max_steps` is not provided.\n"
+                "This can occur when `train dataloader` is not available to correctly "
+                "prepare the scheduler.\n"
+                "Scheduler will not be instantiated !"
+ ) + return None + + # Get iters_per_batch + max_epochs = scheduler_config.get("t_max_epochs") + accumulate_grad_batches = scheduler_config.get("t_accumulate_grad_batches") + limit_train_batches = scheduler_config.get("t_limit_train_batches") + num_workers = scheduler_config.get("t_num_workers") + + # Compute effective num max_steps + num_samples = len(train_dataloader.dataset) # type: ignore + + # we may need to override ModelPT setup_optimization + if train_dataloader.batch_size is not None: + batch_size = train_dataloader.batch_size + elif hasattr(train_dataloader, "batch_sampler") and train_dataloader.batch_sampler is not None: + if train_dataloader.batch_sampler.micro_batch_size is not None: + batch_size = train_dataloader.batch_sampler.micro_batch_size + else: + raise ValueError(f"Could not find batch_size from batch_sampler: {train_dataloader.batch_sampler}") + else: + raise ValueError(f"Could not find batch_size from train_dataloader: {train_dataloader}") + drop_last = train_dataloader.drop_last + + max_steps = compute_max_steps( + max_epochs=max_epochs, + accumulate_grad_batches=accumulate_grad_batches, + limit_train_batches=limit_train_batches, + num_workers=num_workers, + num_samples=num_samples, + batch_size=batch_size, + drop_last=drop_last, + ) + + else: + logging.warning( + "Neither `max_steps` nor `iters_per_batch` were provided to `optim.sched`, " + "cannot compute effective `max_steps` !\n" + "Scheduler will not be instantiated !" + ) + return None + + # Inject max_steps (effective or provided) into the scheduler config + if add_max_args_flag and scheduler_config.get("name", "") != "ExponentialLR": + scheduler_args["max_steps"] = max_steps + + # Get the scheduler class from the config + scheduler_cls = get_scheduler(scheduler_name, **scheduler_args) + + # Instantiate the LR schedule + schedule = scheduler_cls(optimizer, **scheduler_args) + + logging.info( + 'Scheduler "%s" \nwill be used during training (effective maximum steps = %d) - \nParameters : \n(%s)', + str(schedule), + max_steps, + OmegaConf.to_yaml(OmegaConf.create(scheduler_args)), + ) + + # Wrap the schedule in PTL arguments to perform stepwise computation + # Rather than epoch level computation + reduce_lr_on_plateau = isinstance(schedule, optim.lr_scheduler.ReduceLROnPlateau) + + return { + "scheduler": schedule, + "interval": interval, + "frequency": 1, + "monitor": monitor, + "reduce_on_plateau": reduce_lr_on_plateau, + }
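A minimal sketch of calling prepare_lr_scheduler with a plain dict config following the schema above; the field values are illustrative and the exact set of accepted config types may vary.

import torch

from mridc.core.optim.lr_scheduler import prepare_lr_scheduler

optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)

scheduler_config = {
    "name": "CosineAnnealing",
    "max_steps": 1000,
    "min_lr": 1e-5,
    "monitor": "val_loss",
}
schedule = prepare_lr_scheduler(optimizer, scheduler_config)
# schedule["scheduler"] is the CosineAnnealing instance; schedule["interval"] is "step".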
+ + +
[docs]def compute_max_steps( + max_epochs, accumulate_grad_batches, limit_train_batches, num_workers, num_samples, batch_size, drop_last +): + """Compute effective max_steps from the provided parameters.""" + _round = math.floor if drop_last else math.ceil + + sampler_num_samples = math.ceil(num_samples / max(1, num_workers)) + + if drop_last and num_workers > 1: + logging.warning( + "Please note that drop_last is broken in pytorch 1.6.0. We will fix when pytorch 1.7.0 is released" + ) + + steps_per_epoch = _round(sampler_num_samples / batch_size) + if isinstance(limit_train_batches, int) or limit_train_batches == 0.0: + steps_per_epoch = min(steps_per_epoch, int(limit_train_batches)) + elif steps_per_epoch != float("inf"): + # limit_train_batches is a percentage of batches per epoch + steps_per_epoch = int(steps_per_epoch * limit_train_batches) + + return math.ceil(steps_per_epoch / accumulate_grad_batches) * max_epochs
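A worked example of compute_max_steps with illustrative numbers:

from mridc.core.optim.lr_scheduler import compute_max_steps

max_steps = compute_max_steps(
    max_epochs=5,
    accumulate_grad_batches=1,
    limit_train_batches=1.0,
    num_workers=1,
    num_samples=10_000,
    batch_size=32,
    drop_last=False,
)
print(max_steps)  # ceil(10_000 / 32) = 313 steps per epoch -> 313 * 5 = 1565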
+ + +AVAILABLE_SCHEDULERS = { + "WarmupPolicy": WarmupPolicy, + "WarmupHoldPolicy": WarmupHoldPolicy, + "SquareAnnealing": SquareAnnealing, + "CosineAnnealing": CosineAnnealing, + "NoamAnnealing": NoamAnnealing, + "WarmupAnnealing": WarmupAnnealing, + "InverseSquareRootAnnealing": InverseSquareRootAnnealing, + "T5InverseSquareRootAnnealing": T5InverseSquareRootAnnealing, + "SquareRootAnnealing": SquareRootAnnealing, + "PolynomialDecayAnnealing": PolynomialDecayAnnealing, + "PolynomialHoldDecayAnnealing": PolynomialHoldDecayAnnealing, + "StepLR": pt_scheduler.StepLR, + "ExponentialLR": pt_scheduler.ExponentialLR, + "ReduceLROnPlateau": pt_scheduler.ReduceLROnPlateau, + "CyclicLR": pt_scheduler.CyclicLR, +} +
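A sketch of registering a custom schedule so it can be resolved by name. ConstantLR and ConstantLRParams are hypothetical, shown only to illustrate the registry contract; this assumes SchedulerParams can be subclassed as a dataclass, as in the config module.

from dataclasses import dataclass

from torch.optim.lr_scheduler import _LRScheduler

from mridc.core.conf.schedulers import SchedulerParams
from mridc.core.optim.lr_scheduler import get_scheduler, register_scheduler


class ConstantLR(_LRScheduler):
    """Hypothetical scheduler that keeps the base learning rate unchanged."""

    def get_lr(self):
        return self.base_lrs


@dataclass
class ConstantLRParams(SchedulerParams):
    """Hypothetical (empty) parameter dataclass for ConstantLR."""


register_scheduler("ConstantLR", ConstantLR, ConstantLRParams)
# The new name can now be resolved, e.g. get_scheduler("ConstantLR").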
diff --git a/docs/build/html/_modules/mridc/core/optim/novograd.html b/docs/build/html/_modules/mridc/core/optim/novograd.html
new file mode 100644
index 00000000..306cb205
--- /dev/null
+++ b/docs/build/html/_modules/mridc/core/optim/novograd.html
@@ -0,0 +1,240 @@
+mridc.core.optim.novograd — mridc v.0.0.1 documentation

Source code for mridc.core.optim.novograd

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/optim/novograd.py
+
+import torch
+from torch.optim.optimizer import Optimizer
+
+__all__ = ["Novograd"]
+
+
+def _check_valid_opt_params(lr, eps, betas):
+    """Check if the given learning rate and epsilon are valid."""
+    if lr < 0:
+        raise ValueError(f"Invalid learning rate: {lr}")
+    if eps < 0:
+        raise ValueError(f"Invalid epsilon value: {eps}")
+    if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0):
+        raise ValueError(f"Betas have to be between 0 and 1: {betas}")
+
+
+
[docs]class Novograd(Optimizer):
+    """
+    Implements Novograd algorithm.
+    It has been proposed in "Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep
+    Networks" (https://arxiv.org/abs/1905.11286).
+
+    Parameters
+    ----------
+    params: Iterable of parameters to optimize or dicts defining parameter groups.
+        iterable
+    lr: Learning rate.
+        float, (default: 1e-3)
+    betas: Coefficients used for computing running averages of gradient and its square.
+        (Tuple[float, float], optional), (default: (0.95, 0.98))
+    eps: Term added to the denominator to improve numerical stability.
+        (float, optional), (default: 1e-8)
+    weight_decay: Weight decay (L2 penalty).
+        (float, optional), (default: 0)
+    amsgrad: Whether to use the AMSGrad variant of this algorithm from the paper "On the Convergence of Adam and
+    Beyond".
+        (boolean, optional), (default: False)
+    """
+
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.95, 0.98),
+        eps=1e-8,
+        weight_decay=0,
+        grad_averaging=False,
+        amsgrad=False,
+        luc=False,
+        luc_trust=1e-3,
+        luc_eps=1e-8,
+    ):
+        _check_valid_opt_params(lr, eps, betas)
+        defaults = dict(
+            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, grad_averaging=grad_averaging, amsgrad=amsgrad
+        )
+        self.luc = luc
+        self.luc_trust = luc_trust
+        self.luc_eps = luc_eps
+        super(Novograd, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(Novograd, self).__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault("amsgrad", False)
+
[docs] def step(self, closure=None): + """ + Performs a single optimization step. + + Parameters + ---------- + closure: A closure that reevaluates the model and returns the loss. + + Returns + ------- + loss: Loss (if provided) + """ + loss = closure() if closure is not None else None + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError("Sparse gradients are not supported.") + amsgrad = group["amsgrad"] + state = self.state[p] + + # State initialization + if not state: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros([]).to(state["exp_avg"].device) + if amsgrad: + # Maintains max of all exp moving avg of squared grad + state["max_exp_avg_sq"] = torch.zeros([]).to(state["exp_avg"].device) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + if amsgrad: + max_exp_avg_sq = state["max_exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + norm = grad.norm().pow(2) + + if exp_avg_sq == 0: + exp_avg_sq.copy_(norm) + else: + exp_avg_sq.mul_(beta2).add_(norm, alpha=1.0 - beta2) + + if amsgrad: + # Maintains max of all 2nd moment running avg till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max for normalizing running avg. of gradient + denom = max_exp_avg_sq.sqrt().add_(group["eps"]) + else: + denom = exp_avg_sq.sqrt().add_(group["eps"]) + + grad.div_(denom) + if group["weight_decay"] != 0: + grad.add_(p.data, alpha=group["weight_decay"]) + if group["grad_averaging"]: + grad.mul_(1 - beta1) + exp_avg.mul_(beta1).add_(grad) + + if self.luc: + # Clip update so that updates are less than eta*weights + data_norm = torch.norm(p.data) + grad_norm = torch.norm(exp_avg.data) + luc_factor = self.luc_trust * data_norm / (grad_norm + self.luc_eps) + luc_factor = min(luc_factor, group["lr"]) + p.data.add_(exp_avg, alpha=-luc_factor) + else: + p.data.add_(exp_avg, alpha=-group["lr"]) + + return loss
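A minimal optimizer-loop sketch with Novograd; the model and data are illustrative.

import torch

from mridc.core.optim.novograd import Novograd

model = torch.nn.Linear(16, 4)
optimizer = Novograd(model.parameters(), lr=1e-3, betas=(0.95, 0.98), weight_decay=1e-4)

x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()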
+
diff --git a/docs/build/html/_modules/mridc/core/optim/optimizer_with_master_params.html b/docs/build/html/_modules/mridc/core/optim/optimizer_with_master_params.html
new file mode 100644
index 00000000..2a96cc4c
--- /dev/null
+++ b/docs/build/html/_modules/mridc/core/optim/optimizer_with_master_params.html
@@ -0,0 +1,487 @@
+mridc.core.optim.optimizer_with_master_params — mridc v.0.0.1 documentation

Source code for mridc.core.optim.optimizer_with_master_params

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/optim/optimizer_with_master_params.py
+
+from contextlib import contextmanager
+
+import torch
+
+from mridc.utils import logging
+
+try:
+    from apex.multi_tensor_apply import multi_tensor_applier
+    from apex.transformer.parallel_state import get_data_parallel_world_size, get_data_parallel_group
+    from apex.transformer.tensor_parallel import copy_tensor_model_parallel_attributes
+    import amp_C
+
+    HAVE_APEX = True
+
+except ImportError:
+
+    HAVE_APEX = False
+
+
+def _zero_grad_group_helper(group, set_to_none):
+    """Zero out the gradient for a group of parameters. Note: copied from torch.optim.optimizer."""
+    for param in group:
+        if param.grad is not None:
+            if set_to_none:
+                param.grad = None
+            else:
+                if param.grad.grad_fn is not None:
+                    param.grad.detach_()
+                else:
+                    param.grad.requires_grad_(False)
+                param.grad.zero_()
+
+
+def _multi_tensor_copy_this_to_that(this, that, overflow_buf):
+    """
+    Use multi-tensor-applier to copy values from one list to another. We don't have a bfloat16 implementation, so for
+    now if the overflow_buf is not provided, we default back to simple loop copy to be compatible with bfloat16.
+    """
+    if overflow_buf:
+        # Scaling with factor `1.0` is equivalent to copy.
+        multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [this, that], 1.0)
+    else:
+        # FIXME: use multi-tensor applier for bf16
+        for this_, that_ in zip(this, that):
+            that_.copy_(this_)
+
+
+
[docs]class GradBucket: + """Persistent buffer for main gradients that remains allocated between training iterations.""" + + def __init__(self, numel): + if not HAVE_APEX: + raise ImportError("Apex was not found. Using model parallel models will error out.") + + self.numel = numel + self.data = torch.zeros(self.numel, dtype=torch.float, device=torch.cuda.current_device(), requires_grad=False) + +
[docs] def zero(self): + """Reset the buffer to zero.""" + self.data.zero_()
+ +
[docs] def allreduce_buffer(self): + """Synchronous buffer data allreduce""" + self.data.div_(get_data_parallel_world_size()) + torch.distributed.all_reduce(self.data, group=get_data_parallel_group()) # type: ignore
+ +
[docs] def get(self, shape, start_index): + """Return a tensor with the input `shape` as a view into the 1-D data starting at `start_index`.""" + end_index = start_index + shape.numel() + if end_index > self.numel: + raise AssertionError("requested tensor is out of the buffer range.") + buffer_tensor = self.data[start_index:end_index] + buffer_tensor = buffer_tensor.view(shape) + return buffer_tensor
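The buffer/view mechanics above are easiest to see in plain torch; a minimal sketch of the same idea, without the Apex/CUDA requirements of the real class:

import torch

# A flat fp32 buffer plus a reshaped view into a slice of it: writes through
# the view mutate the shared storage, which is what GradBucket.get relies on.
flat = torch.zeros(12)
shape = torch.Size([3, 4])
start = 0
view = flat[start:start + shape.numel()].view(shape)
view.fill_(1.0)
assert flat.sum().item() == 12.0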
+ + +
[docs]class MainParamsOptimizerWrapper(torch.optim.Optimizer): + """ + Float16 optimizer wrapper for half precision (fp16 and bf16) data types. + This optimizer wrapper holds main parameters and gradients in fp32 to support + stable convergence. + + Parameters + ---------- + optimizer: base optimizer such as Adam or SGD. + fp32_grad_accum: to enable the use of fp32 in gradient accumulation and allreduce. + contiguous_grad_bucket: to enable allocating the master gradients in the contiguous memory space to reduce memory + fragmentation. + async_grad_allreduce: enable asynchronous gradient allreduce that is executed along with the training step back prop. + """ + + def __init__( + self, + optimizer, + fp32_grad_accum=False, + contiguous_grad_bucket=False, + async_grad_allreduce=False, + ): + super().__init__(optimizer.param_groups) + if not HAVE_APEX: + raise ImportError("Apex was not found. Using model parallel models will error out.") + + self.optimizer = optimizer + if not self.optimizer: + raise AssertionError("no optimizer is provided.") + if contiguous_grad_bucket and not fp32_grad_accum: + raise AssertionError("contiguous gradient buffer assumes using fp32 grad.") + if async_grad_allreduce: + if not fp32_grad_accum: + raise AssertionError( + "async allreduce applies to master gradients only, " + "which is supposed to be accumulated after grad op." + ) + if not contiguous_grad_bucket: + raise AssertionError( + "currently async_grad_allreduce is supported only " "with contiguous_grad_bucket." + ) + + self._fp32_grad_accum = fp32_grad_accum + self._contiguous_grad_bucket = contiguous_grad_bucket + self._async_grad_allreduce = async_grad_allreduce + self._require_backward_grad_sync = False + + # Dummy tensor needed for apex multi-apply tensor. + self._dummy_overflow_buf = None + + # Create persistent buffers for main gradients in contiguous memory space + # - Chunked element-wise and allreduce ops without creating a temporary buffer for merged operation + # - Low memory fragmentation + self._main_grad_buffers = None + if self._contiguous_grad_bucket: + self._main_grad_buffers = {} + # get the size of buffers + num_elements = {} + for i, param_group in enumerate(self.optimizer.param_groups): + for param in param_group["params"]: + if param.requires_grad: + num_elements[i] = num_elements.get(i, 0) + param.data.nelement() + + # Allocate gradient memory buffers for each data type + self._main_grad_buffers[i] = GradBucket(num_elements[i]) + + # Three groups of parameters: + self.float16_groups = [] # original float16 parameters + self.fp32_from_float16_groups = [] # fp32 copy of float16 parameters + self.fp32_from_fp32_groups = [] # original fp32 parameters + + # gradient function hooks + if self._fp32_grad_accum: + self.grad_accs = [] + + # For all the groups in the original optimizer: + for i, param_group in enumerate(self.optimizer.param_groups): + float16_params_this_group = [] + fp32_params_this_group = [] + fp32_from_float16_params_this_group = [] + # For all the parameters in this group: + for j, param in enumerate(param_group["params"]): + if param.requires_grad: + # float16 params: + if param.type() in ["torch.cuda.HalfTensor", "torch.cuda.BFloat16Tensor"]: + float16_params_this_group.append(param) + + # Allocate the main parameter + main_param = param.detach().clone().float() + + # Copy tensor model parallel attributes. 
+ copy_tensor_model_parallel_attributes(main_param, param) + if hasattr(param, "shared"): + main_param.shared = param.shared + + # Assign the grad buffer offset to main parameters + if self._contiguous_grad_bucket: + num_elements[i] -= param.data.nelement() + main_param.grad = self._main_grad_buffers[i].get(param.data.shape, num_elements[i]) + param.main_grad = main_param.grad + + # Replace the optimizer params with the new fp32 copy. + param_group["params"][j] = main_param + fp32_from_float16_params_this_group.append(main_param) + # Reset existing state dict key to the new main param. + if param in self.optimizer.state: + self.optimizer.state[main_param] = self.optimizer.state.pop(param) + elif param.type() == "torch.cuda.FloatTensor": + fp32_params_this_group.append(param) + param_group["params"][j] = param + + else: + raise TypeError( + "Wrapped parameters must be one of torch.cuda.FloatTensor, torch.cuda.HalfTensor, " + f"or torch.cuda.BFloat16Tensor. Received {param.type()}" + ) + + # Add gradient accumulation hook for fp32 grad accumulation + if self._fp32_grad_accum: + # Expand so we get access to grad_fn. + param_tmp = param.expand_as(param) + # Get the gradient accumulator function. + grad_acc = param_tmp.grad_fn.next_functions[0][0] + grad_acc.register_hook(self._make_param_hook(param, main_param)) + self.grad_accs.append(grad_acc) + + self.float16_groups.append(float16_params_this_group) + self.fp32_from_float16_groups.append(fp32_from_float16_params_this_group) + self.fp32_from_fp32_groups.append(fp32_params_this_group) + + # Leverage state_dict() and load_state_dict() to + # recast preexisting per-param state tensors + self.optimizer.load_state_dict(self.optimizer.state_dict()) + + def _make_param_hook(self, param, main_param): + """Create the grad accumulation and all-reduce hook for back prop.""" + + def param_hook(*unused): + """Gradient accumulation and all-reduce.""" + if param.grad.data is not None: + if main_param.grad is None: + main_param.grad = param.grad.float() + else: + main_param.grad.add_(param.grad.data) + # Deallocate grad memory. + param.grad = None + + # Asynchronous gradients allreduce across data_parallel ranks + if self._require_backward_grad_sync: + main_param.grad.div_(get_data_parallel_world_size()) + torch.distributed.all_reduce( # type: ignore + main_param.grad, group=get_data_parallel_group(), async_op=True + ) + + return param_hook + +
[docs] def zero_grad(self, set_to_none=True): + """ + We only need to zero the model related parameters, i.e., float16_groups & fp32_from_fp32_groups. We + additionally zero fp32_from_float16_groups as a memory optimization to reduce fragmentation; in the case of + set_to_none==True, the space used by this field can be safely deallocated at this point. + """ + for group in self.float16_groups: + _zero_grad_group_helper(group, set_to_none) + if self._contiguous_grad_bucket: + for i in self._main_grad_buffers: + self._main_grad_buffers[i].zero() + else: + for group in self.fp32_from_float16_groups: + _zero_grad_group_helper(group, set_to_none) + for group in self.fp32_from_fp32_groups: + _zero_grad_group_helper(group, set_to_none)
+ +
[docs] def copy_model_grads_to_main_grads(self): + """Copy model grads to main grads.""" + # This only needs to be done for the float16 group. + for model_group, main_group in zip(self.float16_groups, self.fp32_from_float16_groups): + for model_param, main_param in zip(model_group, main_group): + if model_param.grad is not None: + main_param.grad = model_param.grad.float() + + # Safe to deallocate model's grad after copying. + # (If using contiguous buffers, main_grad's memory should + # persist and therefore should not be deallocated.) + model_param.grad = None
+ + def _get_model_and_main_params_data_float16(self): + """Get model and main params data in float16.""" + model_data = [] + main_data = [] + half_dtype = None + for model_group, main_group in zip(self.float16_groups, self.fp32_from_float16_groups): + for model_param, main_param in zip(model_group, main_group): + if half_dtype is None: + half_dtype = model_param.data.dtype + model_data.append(model_param.data) + main_data.append(main_param.data) + return model_data, main_data, half_dtype + + def _set_overflow_buffer(self, half_dtype): + """Set overflow buffer.""" + if half_dtype == torch.float16: + if self._dummy_overflow_buf is None: + self._dummy_overflow_buf = torch.cuda.IntTensor([0]) # type: ignore + else: + self._dummy_overflow_buf.fill_(0) + + def _copy_main_params_to_model_params(self): + """Copy main params to model params.""" + # Only needed for the float16 params. + model_data, main_data, half_dtype = self._get_model_and_main_params_data_float16() + self._set_overflow_buffer(half_dtype) + _multi_tensor_copy_this_to_that(this=main_data, that=model_data, overflow_buf=self._dummy_overflow_buf) + + def _copy_model_params_to_main_params(self): + """Copy model params to main params.""" + # Only needed for the float16 params. + model_data, main_data, half_dtype = self._get_model_and_main_params_data_float16() + self._set_overflow_buffer(half_dtype) + _multi_tensor_copy_this_to_that(this=model_data, that=main_data, overflow_buf=self._dummy_overflow_buf) + +
[docs] def reload_model_params(self): + """Reload model params.""" + self._copy_model_params_to_main_params()
+ +
[docs] @torch.no_grad() + def step(self, **kwargs): + """Step the optimizer.""" + self.optimizer.step(closure=None, **kwargs) + # Update params from main params. + with torch.no_grad(): + self._copy_main_params_to_model_params() + # Successful update. + return True
+ +
[docs] def state_dict(self): + """Return the state of the optimizer.""" + return {"optimizer": self.optimizer.state_dict(), "fp32_from_fp16_params": self.fp32_from_float16_groups}
+ +
[docs] def load_state_dict(self, state_dict): + """Load the state of the optimizer.""" + # Optimizer. + optimizer_key = "optimizer" + if optimizer_key not in state_dict: + optimizer_key = "optimizer_state_dict" + logging.info("***WARNING*** loading optimizer from " "an old checkpoint ...") + self.optimizer.load_state_dict(state_dict[optimizer_key]) + + # Copy data for the main params. + fp32_from_float16_params_key = "fp32_from_fp16_params" + if fp32_from_float16_params_key not in state_dict: + fp32_from_float16_params_key = "fp32_from_fp16" + for current_group, saved_group in zip(self.fp32_from_float16_groups, state_dict[fp32_from_float16_params_key]): + for current_param, saved_param in zip(current_group, saved_group): + current_param.data.copy_(saved_param.data)
+ +
[docs] def allreduce_main_grads(self): + """All reduce main grads.""" + for i in self._main_grad_buffers: + self._main_grad_buffers[i].allreduce_buffer()
+ +
[docs] @contextmanager + def grad_sync(self): + """A context manager to enable gradient synchronization across data-parallel ranks.""" + old_require_backward_grad_sync = self._require_backward_grad_sync + self._require_backward_grad_sync = True + try: + yield + finally: + self._require_backward_grad_sync = old_require_backward_grad_sync
+ + @property + def async_master_grads_allreudce(self): + """Return whether to use async allreduce for master grads.""" + return self._async_grad_allreduce + + @property + def fp32_grad_accumulation(self): + """Return whether to accumulate gradients in fp32.""" + return self._fp32_grad_accum + +
[docs] def get_parameters(self): + """Return the parameters of the optimizer.""" + params = [] + for param_group in self.optimizer.param_groups: + params.extend(iter(param_group["params"])) + return params
+ + def _get_state(self): + """Promote state, so it can be retrieved or set via "optimizer_instance.state.""" + return self.optimizer.state + + def _set_state(self, value): + """Promote state, so it can be retrieved or set via "optimizer_instance.state.""" + self.optimizer.state = value + + state = property(_get_state, _set_state) + + def _get_param_groups(self): + """ + Promote param_groups, so it can be retrieved or set via "optimizer_instance.param_groups. + (for example, to adjust the learning rate) + """ + return self.optimizer.param_groups + + def _set_param_groups(self, value): + """Set param_groups.""" + self.optimizer.param_groups = value + + param_groups = property(_get_param_groups, _set_param_groups)
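Stripped of the Apex dependencies, parallel groups, and hooks, the core fp32-master-weights pattern the wrapper implements looks roughly like this; a sketch of the idea, not the class's actual code path:

import torch

# An fp16 model parameter and its fp32 master copy.
model_param = torch.randn(4, dtype=torch.float16, requires_grad=True)
master_param = model_param.detach().clone().float()
opt = torch.optim.SGD([master_param], lr=0.1)

loss = (model_param.float() ** 2).sum()
loss.backward()

master_param.grad = model_param.grad.float()      # accumulate gradients in fp32
opt.step()                                        # the optimizer only sees fp32
model_param.data.copy_(master_param.data.half())  # copy the update back to fp16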
+
+ +
+
+ +
+
+
+
+ + + + diff --git a/docs/build/html/_modules/mridc/core/optim/optimizers.html b/docs/build/html/_modules/mridc/core/optim/optimizers.html new file mode 100644 index 00000000..c503df48 --- /dev/null +++ b/docs/build/html/_modules/mridc/core/optim/optimizers.html @@ -0,0 +1,253 @@ + + + + + + mridc.core.optim.optimizers — mridc v.0.0.1 documentation + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+
+
+
+
+
+ +

Source code for mridc.core.optim.optimizers

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/optim/optimizers.py
+
+import copy
+from functools import partial
+from typing import Any, Dict, Optional, Union
+
+import hydra
+import torch
+import torch.optim as optim
+from omegaconf import DictConfig, OmegaConf
+from torch.optim import adadelta, adagrad, adamax, rmsprop, rprop
+from torch.optim.optimizer import Optimizer
+
+from mridc.core.conf.optimizers import OptimizerParams, get_optimizer_config, register_optimizer_params
+from mridc.core.optim.adafactor import Adafactor
+from mridc.core.optim.novograd import Novograd
+from mridc.utils.model_utils import maybe_update_config_version
+
+AVAILABLE_OPTIMIZERS = {
+    "sgd": optim.SGD,
+    "adam": optim.Adam,
+    "adamw": optim.AdamW,
+    "adadelta": adadelta.Adadelta,
+    "adamax": adamax.Adamax,
+    "adagrad": adagrad.Adagrad,
+    "rmsprop": rmsprop.RMSprop,
+    "rprop": rprop.Rprop,
+    "novograd": Novograd,
+    "adafactor": Adafactor,
+}
+
+__all__ = ["AVAILABLE_OPTIMIZERS", "get_optimizer", "register_optimizer", "parse_optimizer_args"]
+
+
+
[docs]def parse_optimizer_args( + optimizer_name: str, optimizer_kwargs: Union[DictConfig, Dict[str, Any]] +) -> Union[Dict[str, Any], DictConfig]: + """ + Parses the optimizer keyword arguments into a dictionary that can be used to instantiate the chosen Optimizer. + Accepts a dict or DictConfig and resolves it stepwise: a `_target_` class path is instantiated via hydra, a + `name` key is resolved against the registered optimizer parameter configs (with "auto" mapping to + "<optimizer_name>_params"), and a plain dictionary without either key is returned as-is. + + Parameters + ---------- + optimizer_name: string name of the optimizer, used for auto resolution of params. + optimizer_kwargs: A dict or DictConfig of optimizer kwargs. If neither a `_target_` nor a `name` key is present, + the (possibly version-migrated) dictionary is assumed to be the final parsed value and simply returned. + + Returns + ------- + A dictionary of the parsed arguments. + """ + kwargs: Dict[Any, Any] = {} + + if optimizer_kwargs is None: + return kwargs + + optimizer_kwargs = copy.deepcopy(optimizer_kwargs) + optimizer_kwargs = maybe_update_config_version(optimizer_kwargs) + + if isinstance(optimizer_kwargs, DictConfig): + optimizer_kwargs = OmegaConf.to_container(optimizer_kwargs, resolve=True) + + # If it is a dictionary, perform stepwise resolution + if hasattr(optimizer_kwargs, "keys"): + # Attempt class path resolution + if "_target_" in optimizer_kwargs: # captures (target, _target_) + optimizer_kwargs_config = OmegaConf.create(optimizer_kwargs) + optimizer_instance = hydra.utils.instantiate(optimizer_kwargs_config) # type: DictConfig + optimizer_instance = vars(optimizer_instance) # type: ignore + return optimizer_instance + + # If class path was not provided, perhaps `name` is provided for resolution + if "name" in optimizer_kwargs: + # If `auto` is passed as name for resolution of optimizer name, + # then lookup optimizer name and resolve its parameter config + if optimizer_kwargs["name"] == "auto": + optimizer_params_name = f"{optimizer_name}_params" + optimizer_kwargs.pop("name") + else: + optimizer_params_name = optimizer_kwargs.pop("name") + + # Override arguments provided in the config yaml file + if "params" in optimizer_kwargs: + # If optimizer kwarg overrides are wrapped in yaml `params` + optimizer_params_override = optimizer_kwargs.get("params") + else: + # If the kwargs themselves are a DictConfig + optimizer_params_override = optimizer_kwargs + + if isinstance(optimizer_params_override, DictConfig): + optimizer_params_override = OmegaConf.to_container(optimizer_params_override, resolve=True) + + optimizer_params_cls = get_optimizer_config(optimizer_params_name, **optimizer_params_override) + + # Whether we were provided just a Config object or a partial class instantiation of a Config, + # instantiate the parameters object and return its vars as a dictionary. + optimizer_params = vars(optimizer_params_cls) + return optimizer_params + + # simply return the dictionary that was provided + return optimizer_kwargs + + return kwargs
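A sketch of the dictionary path through this function, assuming the built-in "novograd_params" config is registered; the keys and values are illustrative:

kwargs = parse_optimizer_args(
    optimizer_name="novograd",
    optimizer_kwargs={"name": "auto", "betas": (0.95, 0.98), "weight_decay": 1e-4},
)
# "name: auto" resolves to the registered "novograd_params" config, with the
# remaining keys overriding its defaults; the result is a plain kwargs dict.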
+ + +
[docs]def register_optimizer(name: str, optimizer: Optimizer, optimizer_params: OptimizerParams): + """ + Checks if the optimizer name exists in the registry, and if it doesn't, adds it. + This allows custom optimizers to be added and called by name during instantiation. + + Parameters + ---------- + name: Name of the optimizer. Will be used as key to retrieve the optimizer. + optimizer: Optimizer class. + optimizer_params: The parameters as a dataclass of the optimizer. + """ + if name in AVAILABLE_OPTIMIZERS: + raise ValueError(f"Cannot override pre-existing optimizers. Conflicting optimizer name = {name}") + + AVAILABLE_OPTIMIZERS[name] = optimizer + + optim_name = f"{optimizer.__name__}_params" + register_optimizer_params(name=optim_name, optimizer_params=optimizer_params)
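A hypothetical registration, assuming OptimizerParams can be subclassed as a dataclass the way the built-in parameter configs are; MySGD and MySGDParams are made-up names:

from dataclasses import dataclass

import torch

@dataclass
class MySGDParams(OptimizerParams):
    momentum: float = 0.9  # illustrative extra field

class MySGD(torch.optim.SGD):
    """Placeholder custom optimizer."""

register_optimizer(name="my_sgd", optimizer=MySGD, optimizer_params=MySGDParams)
# "my_sgd" can now be resolved by name via get_optimizer / parse_optimizer_args.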
+ + +
[docs]def get_optimizer(name: str, **kwargs: Optional[Dict[str, Any]]) -> partial: + """ + Convenience method to obtain an Optimizer class and partially instantiate it with optimizer kwargs. + + Parameters + ---------- + name: Name of the Optimizer in the registry. + kwargs: Optional kwargs of the optimizer used during instantiation. + + Returns + ------- + A partially instantiated Optimizer. + """ + if name not in AVAILABLE_OPTIMIZERS: + raise ValueError(f"Cannot resolve optimizer '{name}'. Available optimizers are: {AVAILABLE_OPTIMIZERS.keys()}") + if name == "fused_adam" and not torch.cuda.is_available(): + raise ValueError("CUDA must be available to use fused_adam.") + + optimizer = AVAILABLE_OPTIMIZERS[name] + optimizer = partial(optimizer, **kwargs) + return optimizer
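Usage sketch: the returned partial binds the kwargs, leaving only the parameters to supply at call time (the toy model is illustrative):

import torch.nn as nn

model = nn.Linear(8, 1)
opt_partial = get_optimizer("novograd", lr=1e-3, weight_decay=1e-4)
optimizer = opt_partial(model.parameters())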
+
+ +
+
+ +
+
+
+
+ + + + diff --git a/docs/build/html/_modules/mridc/core/utils/neural_type_utils.html b/docs/build/html/_modules/mridc/core/utils/neural_type_utils.html new file mode 100644 index 00000000..7420ca7e --- /dev/null +++ b/docs/build/html/_modules/mridc/core/utils/neural_type_utils.html @@ -0,0 +1,190 @@ + + + + + + mridc.core.utils.neural_type_utils — mridc v.0.0.1 documentation + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+
+
+
+
+
+ +

Source code for mridc.core.utils.neural_type_utils

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/utils/neural_type_utils.py
+
+from collections import defaultdict
+
+from mridc.core.neural_types.axes import AxisKind
+from mridc.core.neural_types.neural_type import NeuralType
+
+
+
[docs]def get_io_names(types, disabled_names): + """ + This method will return a list of input and output names for a given NeuralType. + + Parameters + ---------- + types: The NeuralType of the module or model to be inspected. + disabled_names: A list of names that should be excluded from the result. + + Returns + ------- + A list of input and output names. + """ + names = list(types.keys()) + for name in disabled_names: + if name in names: + names.remove(name) + return names
+ + +
[docs]def extract_dynamic_axes(name: str, ntype: NeuralType): + """ + This method will extract BATCH and TIME dimension ids from each provided input/output name argument. + + For example, if module/model accepts argument named "input_signal" with type corresponding to [Batch, Time, Dim] + shape, then the returned result should contain "input_signal" -> [0, 1] because Batch and Time are dynamic axes + as they can change from call to call during inference. + + Parameters + ---------- + name: Name of input or output parameter + ntype: Corresponding Neural Type + + Returns + ------- + A dictionary with input/output name as key and a list of dynamic axes as value. + """ + + def unpack_nested_neural_type(neural_type): + """ + This method will unpack nested NeuralTypes. + + Parameters + ---------- + neural_type: The NeuralType to be unpacked. + + Returns + ------- + A list of all the nested NeuralTypes. + """ + if type(neural_type) in (list, tuple): + return unpack_nested_neural_type(neural_type[0]) + return neural_type + + dynamic_axes = defaultdict(list) + if type(ntype) in (list, tuple): + ntype = unpack_nested_neural_type(ntype) + + if ntype.axes: + for ind, axis in enumerate(ntype.axes): + if axis.kind in [AxisKind.Batch, AxisKind.Time, AxisKind.Width, AxisKind.Height]: + dynamic_axes[name].append(ind) + return dynamic_axes
+ + +
[docs]def get_dynamic_axes(types, names): + """ + This method will return a dictionary with input/output names as keys and a list of dynamic axes as values. + + Parameters + ---------- + types: The NeuralType of the module or model to be inspected. + names: A list of names that should be inspected. + + Returns + ------- + A dictionary with input/output names as keys and a list of dynamic axes as values. + """ + dynamic_axes = defaultdict(list) + for name in names: + if name in types: + dynamic_axes |= extract_dynamic_axes(name, types[name]) + return dynamic_axes
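Since extract_dynamic_axes only touches the `.axes` attribute and each axis's `.kind`, its behaviour can be sketched with stand-in types; the Fake* classes below are illustrative duck-typed substitutes, not mridc's NeuralType, and AxisKind.Dimension is assumed to exist as in the NeMo type system this module mirrors:

class FakeAxis:
    def __init__(self, kind):
        self.kind = kind

class FakeType:
    def __init__(self, kinds):
        self.axes = [FakeAxis(k) for k in kinds]

# An input declared as [Batch, Time, Dimension]: Batch and Time are dynamic.
ntype = FakeType([AxisKind.Batch, AxisKind.Time, AxisKind.Dimension])
print(extract_dynamic_axes("input_signal", ntype))  # {'input_signal': [0, 1]}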
+
+ +
+
+ +
+
+
+
+ + + + diff --git a/docs/build/html/_modules/mridc/core/utils/numba_utils.html b/docs/build/html/_modules/mridc/core/utils/numba_utils.html new file mode 100644 index 00000000..419c1ff5 --- /dev/null +++ b/docs/build/html/_modules/mridc/core/utils/numba_utils.html @@ -0,0 +1,243 @@ + + + + + + mridc.core.utils.numba_utils — mridc v.0.0.1 documentation + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+
+
+
+
+
+ +

Source code for mridc.core.utils.numba_utils

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/utils/numba_utils.py
+
+import contextlib
+import logging as pylogger
+import operator
+import os
+
+# Prevent Numba CUDA logs from showing at info level
+from mridc.utils.model_utils import check_lib_version
+
+cuda_logger = pylogger.getLogger("numba.cuda.cudadrv.driver")
+cuda_logger.setLevel(pylogger.ERROR)  # only show error
+
+__NUMBA_DEFAULT_MINIMUM_VERSION__ = "0.53.0"
+__NUMBA_MINIMUM_VERSION__ = os.environ.get("MRIDC_NUMBA_MINVER", __NUMBA_DEFAULT_MINIMUM_VERSION__)
+
+NUMBA_INSTALLATION_MESSAGE = (
+    "Could not import `numba`.\n"
+    "Please install numba in one of the following ways."
+    "1) If using conda, simply install it with conda using `conda install -c numba numba`\n"
+    "2) If using pip (not recommended), `pip install --upgrade numba`\n"
+    "followed by `export NUMBAPRO_LIBDEVICE='/usr/local/cuda/nvvm/libdevice/'` and \n"
+    "`export NUMBAPRO_NVVM='/usr/local/cuda/nvvm/lib64/libnvvm.so'`.\n"
+    "It is advised to always install numba using conda only, "
+    "as pip installations might interfere with other libraries such as llvmlite.\n"
+    "If pip install does not work, you can also try adding `--ignore-installed` to the pip command,\n"
+    "but this is not advised."
+)
+
+STRICT_NUMBA_COMPAT_CHECK = True
+
+# Get environment key if available
+if "STRICT_NUMBA_COMPAT_CHECK" in os.environ:
+    check_str = os.environ.get("STRICT_NUMBA_COMPAT_CHECK")
+    check_bool = str(check_str).lower() in {"yes", "true", "t", "1"}
+    STRICT_NUMBA_COMPAT_CHECK = check_bool
+
+
+
[docs]def is_numba_compat_strict() -> bool: + """ + Returns strictness level of numba cuda compatibility checks. + If value is true, numba cuda compatibility matrix must be satisfied. + If value is false, only cuda availability is checked, not compatibility. + Numba Cuda may still compile and run without issues in such a case, or it may fail. + """ + return STRICT_NUMBA_COMPAT_CHECK
+ + +
[docs]def set_numba_compat_strictness(strict: bool): + """ + Sets the strictness level of numba cuda compatibility checks. + If value is true, numba cuda compatibility matrix must be satisfied. + If value is false, only cuda availability is checked, not compatibility. + Numba Cuda may still compile and run without issues in such a case, or it may fail. + + Parameters + ---------- + strict: Whether to enforce strict compatibility checks or relax them. + """ + global STRICT_NUMBA_COMPAT_CHECK + STRICT_NUMBA_COMPAT_CHECK = strict
+ + +
[docs]@contextlib.contextmanager +def with_numba_compat_strictness(strict: bool): + """Context manager to temporarily set numba cuda compatibility strictness.""" + initial_strictness = is_numba_compat_strict() + set_numba_compat_strictness(strict=strict) + try: + yield + finally: + set_numba_compat_strictness(strict=initial_strictness)
+ + +
[docs]def numba_cpu_is_supported(min_version: str) -> bool: + """ + Tests if an appropriate version of numba is installed. + + Parameters + ---------- + min_version: The minimum version of numba that is required. + + Returns + ------- + bool, whether numba CPU supported with this current installation or not. + """ + module_available, _ = check_lib_version("numba", checked_version=min_version, operator=operator.ge) + + # If numba is not installed + if module_available is None: + return False + return True
+ + +
[docs]def numba_cuda_is_supported(min_version: str) -> bool: + """ + Tests if an appropriate version of numba is installed, and if it is, + whether cuda is supported properly within it. + + Parameters + ---------- + min_version: The minimum version of numba that is required. + + Returns + ------- + Whether cuda is supported with this current installation or not. + """ + # If numba is not installed, CUDA cannot be supported. + if not numba_cpu_is_supported(min_version): + return False + + from numba import cuda + + if not hasattr(cuda, "is_supported_version"): + # This numba version is too old to report CUDA compatibility; treat it as unsupported. + return False + + try: + cuda_available = cuda.is_available() + cuda_compatible = cuda.is_supported_version() if cuda_available else False + if is_numba_compat_strict(): + return cuda_available and cuda_compatible + return cuda_available + + except OSError: + # dlopen(libcudart.dylib) might fail if CUDA was never installed in the first place. + return False
+ + +
[docs]def skip_numba_cuda_test_if_unsupported(min_version: str): + """ + Helper method to skip pytest test case if numba cuda is not supported. + + Parameters + ---------- + min_version: The minimum version of numba that is required. + """ + numba_cuda_support = numba_cuda_is_supported(min_version) + if not numba_cuda_support: + import pytest + + pytest.skip(f"Numba cuda test is being skipped. Minimum version required : {min_version}")
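Typical use inside a test module; the test name and body are hypothetical:

def test_my_cuda_kernel():
    # Skips via pytest when the installed numba/CUDA stack is not compatible.
    skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)
    ...  # CUDA-dependent assertions would go here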
+
+ +
+
+ +
+
+
+
+ + + + diff --git a/docs/build/html/_modules/mridc/launch.html b/docs/build/html/_modules/mridc/launch.html new file mode 100644 index 00000000..8a5db541 --- /dev/null +++ b/docs/build/html/_modules/mridc/launch.html @@ -0,0 +1,195 @@ + + + + + + mridc.launch — mridc v.0.0.1 documentation + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +

Source code for mridc.launch

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+import pytorch_lightning as pl
+import torch
+from omegaconf import DictConfig, OmegaConf
+
+from mridc.collections.reconstruction.models.ccnn import CascadeNet
+from mridc.collections.reconstruction.models.cirim import CIRIM
+from mridc.collections.reconstruction.models.crnn import CRNNet
+from mridc.collections.reconstruction.models.dunet import DUNet
+from mridc.collections.reconstruction.models.jointicnet import JointICNet
+from mridc.collections.reconstruction.models.kikinet import KIKINet
+from mridc.collections.reconstruction.models.lpd import LPDNet
+from mridc.collections.reconstruction.models.multidomainnet import MultiDomainNet
+from mridc.collections.reconstruction.models.pics import PICS
+from mridc.collections.reconstruction.models.rvn import RecurrentVarNet
+from mridc.collections.reconstruction.models.unet import UNet
+from mridc.collections.reconstruction.models.vn import VarNet
+from mridc.collections.reconstruction.models.vsnet import VSNet
+from mridc.collections.reconstruction.models.xpdnet import XPDNet
+from mridc.collections.reconstruction.models.zf import ZF
+from mridc.core.conf.hydra_runner import hydra_runner
+from mridc.utils import logging
+from mridc.utils.exp_manager import exp_manager
+
+
+
[docs]@hydra_runner(config_path=".", config_name="config") +def main(cfg: DictConfig) -> None: + """ + Main function for training and running a model. + + Parameters + ---------- + cfg: Configuration (yaml) file. + DictConfig + """ + logging.info(f"Config: {OmegaConf.to_yaml(cfg)}") + + trainer = pl.Trainer(**cfg.trainer) + exp_manager(trainer, cfg.get("exp_manager", None)) + + model_name = (cfg.model["model_name"]).upper() + + if model_name == "CASCADENET": + model = CascadeNet(cfg.model, trainer=trainer) + elif model_name == "CIRIM": + model = CIRIM(cfg.model, trainer=trainer) + elif model_name == "CRNNET": + model = CRNNet(cfg.model, trainer=trainer) + elif model_name == "DUNET": + model = DUNet(cfg.model, trainer=trainer) + elif model_name in ("E2EVN", "VN"): + model = VarNet(cfg.model, trainer=trainer) + elif model_name == "JOINTICNET": + model = JointICNet(cfg.model, trainer=trainer) + elif model_name == "KIKINET": + model = KIKINet(cfg.model, trainer=trainer) + elif model_name == "LPDNET": + model = LPDNet(cfg.model, trainer=trainer) + elif model_name == "MULTIDOMAINNET": + model = MultiDomainNet(cfg.model, trainer=trainer) + elif model_name == "PICS": + model = PICS(cfg.model, trainer=trainer) + elif model_name == "RVN": + model = RecurrentVarNet(cfg.model, trainer=trainer) + elif model_name == "UNET": + model = UNet(cfg.model, trainer=trainer) + elif model_name == "VSNET": + model = VSNet(cfg.model, trainer=trainer) + elif model_name == "XPDNET": + model = XPDNet(cfg.model, trainer=trainer) + elif model_name == "ZF": + model = ZF(cfg.model, trainer=trainer) + else: + raise NotImplementedError( + f"{model_name} is not implemented in MRIDC. You can use one of the following methods: " + f"CASCADENET, CIRIM, CRNNET, DUNET, E2EVN, JOINTICNET, KIKINET, LPDNET, MULTIDOMAINNET, PICS, RVN, UNET, " + f"VSNET, XPDNET, or ZF (Zero-Filled). \n" + f"If you implemented a new model, please consider adding it through a PR on GitHub." + ) + + if cfg.get("pretrained", None): + checkpoint = cfg.get("checkpoint", None) + logging.info(f"Loading pretrained model from {checkpoint}") + model.load_state_dict(torch.load(checkpoint)["state_dict"]) + + if cfg.get("mode", None) == "train": + logging.info("Validating") + trainer.validate(model) + logging.info("Training") + trainer.fit(model) + else: + logging.info("Evaluating") + trainer.test(model)
+ + +if __name__ == "__main__": + main() +
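The shape of the config main() reads, sketched with OmegaConf; the values are placeholders, and a real model config also needs the model's own fields (datasets, optimizer, and so on):

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "trainer": {"max_epochs": 1, "accelerator": "gpu", "devices": 1},
        "model": {"model_name": "CIRIM"},  # dispatched by the if/elif chain above
        "exp_manager": None,
        "pretrained": False,               # set True (plus "checkpoint") to warm-start
        "mode": "train",                   # anything else runs trainer.test
    }
)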
+ +
+
+ +
+
+
+
+ + + + diff --git a/docs/build/html/_modules/mridc/utils/app_state.html b/docs/build/html/_modules/mridc/utils/app_state.html new file mode 100644 index 00000000..0e619e62 --- /dev/null +++ b/docs/build/html/_modules/mridc/utils/app_state.html @@ -0,0 +1,447 @@ + + + + + + mridc.utils.app_state — mridc v.0.0.1 documentation + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +

Source code for mridc.utils.app_state

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/tree/main/nemo/utils
+
+from dataclasses import dataclass
+from threading import Lock
+from typing import Dict, Optional
+
+from mridc.utils.metaclasses import Singleton
+
+
+
[docs]@dataclass() +class ModelMetadataRegistry: + """A registry for model metadata.""" + + guid: str + gidx: int + restoration_path: Optional[str] = None
+ + +
[docs]class AppState(metaclass=Singleton): + """A singleton class that holds the state of the application.""" + + def __init__(self): + """Initializes the AppState class.""" + # method call lock + self.model_parallel_rank = None + self.__lock = Lock() + + # TODO: should we store global config in hydra_runner? + self._app_cfg = None + + # World info + self._device_id = None + self._local_rank = None + self._global_rank = None + self._model_parallel_rank = None + self._tensor_model_parallel_rank = None + self._pipeline_model_parallel_rank = None + self._data_parallel_rank = None + + self._world_size = None + self._model_parallel_size = None + self._tensor_model_parallel_size = None + self._tensor_model_parallel_group = None + self._pipeline_model_parallel_size = None + self._pipeline_model_parallel_group = None + self._pipeline_model_parallel_split_rank = None + self._model_parallel_group = None + self._data_parallel_size = None + self._data_parallel_group = None + + self._random_seed = None + + # Logging info + self._log_dir = None + self._exp_dir = None + self._name = None + self._checkpoint_name = None + self._version = None + self._create_checkpoint_callback = None + self._checkpoint_callback_params = None + + # Save and Restore (.mridc) + self._tmpdir_name = None + self._is_model_being_restored = False + self._mridc_file_folder = None + self._model_restore_path = None + self._all_model_restore_paths = [] + self._model_guid_map = {} # type: Dict[str, ModelMetadataRegistry] + + @property + def device_id(self): + """Property returns the device_id.""" + return self._device_id + + @device_id.setter + def device_id(self, id: int): + """Property sets the device_id.""" + self._device_id = id + + @property + def world_size(self): + """Property returns the total number of GPUs.""" + return self._world_size + + @world_size.setter + def world_size(self, size: int): + """Property sets the total number of GPUs.""" + self._world_size = size + + @property + def model_parallel_size(self): + """Property returns the number of GPUs in each model parallel group.""" + return self._model_parallel_size + + @model_parallel_size.setter + def model_parallel_size(self, size: int): + """Property sets the number of GPUs in each model parallel group.""" + self._model_parallel_size = size + + @property + def tensor_model_parallel_size(self): + """Property returns the number of GPUs in each model parallel group.""" + return self._tensor_model_parallel_size + + @tensor_model_parallel_size.setter + def tensor_model_parallel_size(self, size): + """Property sets the number of GPUs in each model parallel group.""" + self._tensor_model_parallel_size = size + + @property + def pipeline_model_parallel_size(self): + """Property returns the number of GPUs in each model parallel group.""" + return self._pipeline_model_parallel_size + + @pipeline_model_parallel_size.setter + def pipeline_model_parallel_size(self, size): + """Property sets the number of GPUs in each model parallel group.""" + self._pipeline_model_parallel_size = size + + @property + def data_parallel_size(self): + """Property returns the number of GPUs in each data parallel group.""" + return self._data_parallel_size + + @data_parallel_size.setter + def data_parallel_size(self, size: int): + """Property sets the number of GPUs in each data parallel group.""" + self._data_parallel_size = size + + @property + def local_rank(self): + """Property returns the local rank.""" + return self._local_rank + + @local_rank.setter + def local_rank(self, rank: int): + """Property 
sets the local rank.""" + self._local_rank = rank + + @property + def global_rank(self): + """Property returns the global rank.""" + return self._global_rank + + @global_rank.setter + def global_rank(self, rank: int): + """Property sets the global rank.""" + self._global_rank = rank + + @property + def tensor_model_parallel_rank(self): + """Property returns the tensor model parallel rank.""" + return self._tensor_model_parallel_rank + + @tensor_model_parallel_rank.setter + def tensor_model_parallel_rank(self, rank): + """Property sets the tensor model parallel rank.""" + self._tensor_model_parallel_rank = rank + + @property + def tensor_model_parallel_group(self): + """Property returns the tensor model parallel group.""" + return self._tensor_model_parallel_group + + @tensor_model_parallel_group.setter + def tensor_model_parallel_group(self, group): + """Property sets the tensor model parallel group.""" + self._tensor_model_parallel_group = group + + @property + def pipeline_model_parallel_rank(self): + """Property returns the pipeline model parallel rank.""" + return self._pipeline_model_parallel_rank + + @pipeline_model_parallel_rank.setter + def pipeline_model_parallel_rank(self, rank): + """Property sets the pipeline model parallel rank.""" + self._pipeline_model_parallel_rank = rank + + @property + def pipeline_model_parallel_split_rank(self): + """Property returns the pipeline model parallel split rank.""" + return self._pipeline_model_parallel_split_rank + + @pipeline_model_parallel_split_rank.setter + def pipeline_model_parallel_split_rank(self, rank): + """Property sets the pipeline model parallel split rank.""" + self._pipeline_model_parallel_split_rank = rank + + @property + def pipeline_model_parallel_group(self): + """Property returns the pipeline model parallel group.""" + return self._pipeline_model_parallel_group + + @pipeline_model_parallel_group.setter + def pipeline_model_parallel_group(self, group): + """Property sets the pipeline model parallel group.""" + self._pipeline_model_parallel_group = group + + @property + def data_parallel_rank(self): + """Property returns the data parallel rank.""" + return self._data_parallel_rank + + @data_parallel_rank.setter + def data_parallel_rank(self, rank: int): + """Property sets the data parallel rank.""" + self._data_parallel_rank = rank + + @property + def data_parallel_group(self): + """Property returns the data parallel group.""" + return self._data_parallel_group + + @data_parallel_group.setter + def data_parallel_group(self, group): + """Property sets the data parallel group.""" + self._data_parallel_group = group + + @property + def random_seed(self): + """Property returns the random seed.""" + return self._random_seed + + @random_seed.setter + def random_seed(self, seed: int): + """Property sets the random seed.""" + self._random_seed = seed + + @property + def log_dir(self): + """Returns the log_dir set by exp_manager.""" + return self._log_dir + + @log_dir.setter + def log_dir(self, dir): + """Sets the log_dir property.""" + self._log_dir = dir + + @property + def exp_dir(self): + """Returns the exp_dir set by exp_manager.""" + return self._exp_dir + + @exp_dir.setter + def exp_dir(self, dir): + """Sets the exp_dir property.""" + self._exp_dir = dir + + @property + def name(self): + """Returns the name set by exp_manager.""" + return self._name + + @name.setter + def name(self, name: str): + """Sets the name property.""" + self._name = name + + @property + def checkpoint_name(self): + """Returns the checkpoint name set by exp_manager.""" + return self._checkpoint_name + + @checkpoint_name.setter + def checkpoint_name(self, name: str): + """Sets the checkpoint_name property.""" + self._checkpoint_name = name + + @property + def version(self): + """Returns the version set by exp_manager.""" + return self._version + + @version.setter + def version(self, version: str): + """Sets the version property.""" + self._version = version + + @property + def create_checkpoint_callback(self): + """Returns the create_checkpoint_callback set by exp_manager.""" + return self._create_checkpoint_callback + + @create_checkpoint_callback.setter + def create_checkpoint_callback(self, create_checkpoint_callback: bool): + """Sets the create_checkpoint_callback property.""" + self._create_checkpoint_callback = create_checkpoint_callback + + @property + def checkpoint_callback_params(self): + """Returns the checkpoint_callback_params set by exp_manager.""" + return self._checkpoint_callback_params + + @checkpoint_callback_params.setter + def checkpoint_callback_params(self, params: dict): + """Sets the checkpoint_callback_params property.""" + self._checkpoint_callback_params = params + + @property + def model_restore_path(self): + """Returns the model_restore_path set by exp_manager.""" + return self._all_model_restore_paths[-1] if len(self._all_model_restore_paths) > 0 else None + + @model_restore_path.setter + def model_restore_path(self, path): + """Sets the model_restore_path property.""" + with self.__lock: + self._model_restore_path = path + self._all_model_restore_paths.append(path) +
[docs] def register_model_guid(self, guid: str, restoration_path: Optional[str] = None): + """Maps a guid to its restore path (None or last absolute path).""" + with self.__lock: + if guid in self._model_guid_map: + idx = self._model_guid_map[guid].gidx + else: + idx = len(self._model_guid_map) + self._model_guid_map[guid] = ModelMetadataRegistry(guid, idx, restoration_path=restoration_path)
+ +
[docs] def reset_model_guid_registry(self): + """Resets the model guid registry.""" + with self.__lock: + self._model_guid_map.clear()
+ +
[docs] def get_model_metadata_from_guid(self, guid) -> ModelMetadataRegistry: + """Returns the global model idx and restoration path.""" + return self._model_guid_map[guid]
+ + @property + def is_model_being_restored(self) -> bool: + """Returns whether a model is being restored.""" + return self._is_model_being_restored + + @is_model_being_restored.setter + def is_model_being_restored(self, is_restored: bool): + """Sets whether a model is being restored.""" + self._is_model_being_restored = is_restored + + @property + def mridc_file_folder(self) -> str: + """Returns the mridc_file_folder set by exp_manager.""" + return self._mridc_file_folder + + @mridc_file_folder.setter + def mridc_file_folder(self, path: str): + """Sets the mridc_file_folder property.""" + self._mridc_file_folder = path
+
+ +
+
+ +
+
+
+
+ + + + diff --git a/docs/build/html/_modules/mridc/utils/arguments.html b/docs/build/html/_modules/mridc/utils/arguments.html new file mode 100644 index 00000000..dd80b12a --- /dev/null +++ b/docs/build/html/_modules/mridc/utils/arguments.html @@ -0,0 +1,214 @@ + + + + + + mridc.utils.arguments — mridc v.0.0.1 documentation + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +

Source code for mridc.utils.arguments

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/utils/arguments.py
+
+from argparse import ArgumentParser
+from typing import Any, Dict, List, Optional, Union
+
+
+
[docs]def add_optimizer_args( + parent_parser: ArgumentParser, + optimizer: str = "adam", + default_lr: float = None, + default_opt_args: Optional[Union[Dict[str, Any], List[str]]] = None, +) -> ArgumentParser: + """ + Extends existing argparse with default optimizer args. + + # Example of adding optimizer args to command line: + python train_script.py ... --optimizer "novograd" --lr 0.01 --opt_args betas=0.95,0.5 weight_decay=0.001 + + Parameters + ---------- + parent_parser: Custom CLI parser that will be extended. + ArgumentParser + optimizer: Default optimizer required. + str, default "adam" + default_lr: Default learning rate. + float, default None + default_opt_args: Default optimizer arguments. + Optional[Union[Dict[str, Any], List[str]]], default None + + Returns + ------- + Parser extended by Optimizers arguments. + ArgumentParser + """ + if default_opt_args is None: + default_opt_args = [] + + parser = ArgumentParser(parents=[parent_parser], add_help=True, conflict_handler="resolve") + + parser.add_argument("--optimizer", type=str, default=optimizer, help="Name of the optimizer. Defaults to Adam.") + parser.add_argument("--lr", type=float, default=default_lr, help="Learning rate of the optimizer.") + parser.add_argument( + "--opt_args", + default=default_opt_args, + nargs="+", + type=str, + help="Overriding arguments for the optimizer. \n Must follow the pattern : \n name=value separated by spaces." + "Example: --opt_args weight_decay=0.001 eps=1e-8 betas=0.9,0.999", + ) + + return parser
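A usage sketch with a plain base parser; the flag values are illustrative:

from argparse import ArgumentParser

base = ArgumentParser(add_help=False)
base.add_argument("--data_dir", type=str)

parser = add_optimizer_args(base, optimizer="novograd", default_lr=0.01)
args = parser.parse_args(["--data_dir", "/tmp/data", "--lr", "0.001", "--opt_args", "weight_decay=0.001"])
print(args.optimizer, args.lr, args.opt_args)  # novograd 0.001 ['weight_decay=0.001']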
+ + +
[docs]def add_scheduler_args(parent_parser: ArgumentParser) -> ArgumentParser: + """ + Extends existing argparse with default scheduler args. + + Parameters + ---------- + parent_parser: Custom CLI parser that will be extended. + ArgumentParser + + Returns + ------- + Parser extended by Schedulers arguments. + ArgumentParser + """ + parser = ArgumentParser(parents=[parent_parser], add_help=False, conflict_handler="resolve") + parser.add_argument("--warmup_steps", type=int, required=False, default=None, help="Number of warmup steps") + parser.add_argument( + "--warmup_ratio", + type=float, + required=False, + default=None, + help="Number of warmup steps as a percentage of total training steps", + ) + parser.add_argument("--hold_steps", type=int, required=False, default=None, help="Number of hold LR steps") + parser.add_argument( + "--hold_ratio", + type=float, + required=False, + default=None, + help="Number of hold LR steps as a percentage of total training steps", + ) + parser.add_argument("--min_lr", type=float, required=False, default=0.0, help="Minimum learning rate") + parser.add_argument( + "--last_epoch", type=int, required=False, default=-1, help="Last epoch id. -1 indicates training from scratch" + ) + return parser
+ + +
[docs]def add_recon_args(parent_parser: ArgumentParser) -> ArgumentParser: + """ + Extends existing argparse with default reconstruction args. + + Parameters + ---------- + parent_parser: Custom CLI parser that will be extended. + ArgumentParser + + Returns + ------- + Parser extended by Reconstruction arguments. + ArgumentParser + """ + parser = ArgumentParser(parents=[parent_parser], add_help=False, conflict_handler="resolve") + parser.add_argument( + "--data_dir", type=str, required=False, help="Data directory for the training and/or evaluation dataset" + ) + parser.add_argument("--config_file", type=str, required=False, default=None, help="Recon model configuration file") + parser.add_argument( + "--pretrained_model_name", default="recon-base-uncased", type=str, required=False, help="Pretrained model name" + ) + parser.add_argument("--do_lower_case", action="store_true", required=False, help="Lowercase the input data") + return parser
+
+ +
+
+ +
+
+
+
+ + + + diff --git a/docs/build/html/_modules/mridc/utils/cloud.html b/docs/build/html/_modules/mridc/utils/cloud.html new file mode 100644 index 00000000..04b7ec12 --- /dev/null +++ b/docs/build/html/_modules/mridc/utils/cloud.html @@ -0,0 +1,173 @@ + + + + + + mridc.utils.cloud — mridc v.0.0.1 documentation + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +

Source code for mridc.utils.cloud

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/utils/cloud.py
+
+import os
+from pathlib import Path
+from time import sleep
+
+import wget
+
+from mridc.utils import logging
+
+
+
[docs]def maybe_download_from_cloud(url, filename, subfolder=None, cache_dir=None, refresh_cache=False) -> str: + """ + Download a file from a URL if it does not exist in the cache. + + Parameters + ---------- + url: URL to download the file from. + str + filename: What to download. The request will be issued to url + filename. + str + subfolder: Subfolder within cache_dir. The file will be stored in cache_dir/subfolder. Subfolder can be empty. + str + cache_dir: A cache directory where to download. If not present, this function will attempt to create it. + str, If None (default), then it will be $HOME/.cache/torch/mridc + refresh_cache: If True and a cached file is present, delete it and re-fetch. + bool + + Returns + ------- + If successful - absolute local path to the downloaded file else empty string. + """ + if cache_dir is None: + cache_location = Path.joinpath(Path.home(), ".cache/torch/mridc") + else: + cache_location = cache_dir + if subfolder is not None: + destination = Path.joinpath(cache_location, subfolder) + else: + destination = cache_location + + if not os.path.exists(destination): + os.makedirs(destination, exist_ok=True) + + destination_file = Path.joinpath(destination, filename) + + if os.path.exists(destination_file): + logging.info(f"Found existing object {destination_file}.") + if refresh_cache: + logging.info("Asked to refresh the cache.") + logging.info(f"Deleting file: {destination_file}") + os.remove(destination_file) + else: + logging.info(f"Re-using file from: {destination_file}") + return str(destination_file) + # download file + wget_uri = url + filename + logging.info(f"Downloading from: {wget_uri} to {str(destination_file)}") + # NGC links do not work every time, so we retry with a short wait + i = 0 + max_attempts = 3 + while i < max_attempts: + i += 1 + try: + wget.download(wget_uri, str(destination_file)) + if os.path.exists(destination_file): + return str(destination_file) + return "" + except Exception as e: + logging.info(f"Download from cloud failed. Attempt {i} of {max_attempts}") + logging.info(f"Error: {e}") + sleep(0.05) + continue + raise ValueError("Not able to download url right now, please try again.")
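A usage sketch with placeholder URL and filename; since the request goes to url + filename, the base URL should end with a slash:

path = maybe_download_from_cloud(
    url="https://example.com/models/",  # placeholder base URL
    filename="weights.ckpt",            # placeholder file name
    subfolder="demo",
    refresh_cache=False,
)
print(path or "download failed")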
+
+ +
+
+ +
+
+
+
+ + + + diff --git a/docs/build/html/_modules/mridc/utils/config_utils.html b/docs/build/html/_modules/mridc/utils/config_utils.html new file mode 100644 index 00000000..af978b29 --- /dev/null +++ b/docs/build/html/_modules/mridc/utils/config_utils.html @@ -0,0 +1,353 @@ + + + + + + mridc.utils.config_utils — mridc v.0.0.1 documentation + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+
+
+
+
+
+ +

Source code for mridc.utils.config_utils

+# encoding: utf-8
+import sys
+
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/utils/config_utils.py
+
+import copy
+import inspect
+from dataclasses import is_dataclass
+from typing import Dict, List, Optional, Set
+
+from omegaconf import DictConfig, OmegaConf, open_dict
+
+from mridc.core.conf.modelPT import MRIDCConfig
+from mridc.utils import logging
+
+_HAS_HYDRA = True
+
+
+
[docs]def update_model_config(model_cls: MRIDCConfig, update_cfg: "DictConfig", drop_missing_subconfigs: bool = True): + """ + Helper function that updates the default values of a ModelPT config class with the values in a DictConfig that \ + mirrors the structure of the config class. Assumes the `update_cfg` is a DictConfig (either generated manually, \ + via hydra or instantiated via yaml/model.cfg). This update_cfg is then used to override the default values \ + preset inside the ModelPT config class. If `drop_missing_subconfigs` is set, certain sub-configs of the \ + ModelPT config class will be removed if they are not found in the mirrored `update_cfg`. The following \ + sub-configs are subject to potential removal: + - `train_ds` + - `validation_ds` + - `test_ds` + - `optim` (+ nested `sched`) + + Parameters + ---------- + model_cls: A subclass of MRIDC, that details in entirety all the parameters that constitute the MRIDC Model. + update_cfg: A DictConfig that mirrors the structure of the MRIDCConfig data class. Used to update the default \ + values of the config class. + drop_missing_subconfigs: Bool which determines whether to drop certain sub-configs from the MRIDCConfig class, \ + if the corresponding sub-config is missing from `update_cfg`. + + Returns + ------- + A DictConfig with updated values that can be used to instantiate the MRIDC Model along with supporting \ + infrastructure. + """ + if not _HAS_HYDRA: + logging.error("This function requires Hydra/Omegaconf and it was not installed.") + sys.exit(1) + if not (is_dataclass(model_cls) or isinstance(model_cls, DictConfig)): + raise ValueError("`model_cfg` must be a dataclass or a structured OmegaConf object") + + if not isinstance(update_cfg, DictConfig): + update_cfg = OmegaConf.create(update_cfg) + + if is_dataclass(model_cls): + model_cls = OmegaConf.structured(model_cls) + + # Update optional configs + model_cls = _update_subconfig( + model_cls, update_cfg, subconfig_key="train_ds", drop_missing_subconfigs=drop_missing_subconfigs + ) + model_cls = _update_subconfig( + model_cls, update_cfg, subconfig_key="validation_ds", drop_missing_subconfigs=drop_missing_subconfigs + ) + model_cls = _update_subconfig( + model_cls, update_cfg, subconfig_key="test_ds", drop_missing_subconfigs=drop_missing_subconfigs + ) + model_cls = _update_subconfig( + model_cls, update_cfg, subconfig_key="optim", drop_missing_subconfigs=drop_missing_subconfigs + ) + + # Add optim and sched additional keys to model cls + model_cls = _add_subconfig_keys(model_cls, update_cfg, subconfig_key="optim") + + # Perform full merge of model config class and update config + # Remove ModelPT artifact `target` + if "target" in update_cfg.model and "target" not in model_cls.model: # type: ignore + with open_dict(update_cfg.model): + update_cfg.model.pop("target") + + # Remove ModelPT artifact `mridc_version` + if "mridc_version" in update_cfg.model and "mridc_version" not in model_cls.model: # type: ignore + with open_dict(update_cfg.model): + update_cfg.model.pop("mridc_version") + + return OmegaConf.merge(model_cls, update_cfg)
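A sketch of a typical call, assuming MRIDCConfig instantiates with defaults as a structured dataclass; the override keys are illustrative:

from omegaconf import OmegaConf

defaults = MRIDCConfig()
overrides = OmegaConf.create({"model": {"optim": {"name": "novograd", "lr": 1e-3}}})
merged = update_model_config(defaults, overrides, drop_missing_subconfigs=True)
print(OmegaConf.to_yaml(merged))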
+ + +def _update_subconfig( + model_cfg: "DictConfig", update_cfg: "DictConfig", subconfig_key: str, drop_missing_subconfigs: bool +): + """ + Updates the MRIDCConfig DictConfig such that: + 1) If the sub-config key exists in the `update_cfg`, but does not exist in ModelPT config: + - Add the sub-config from update_cfg to ModelPT config + 2) If the sub-config key does not exist in `update_cfg`, but exists in ModelPT config: + - Remove the sub-config from the ModelPT config; iff the `drop_missing_subconfigs` flag is set. + + Parameters + ---------- + model_cfg: A DictConfig instantiated from the MRIDCConfig subclass. + update_cfg: A DictConfig that mirrors the structure of `model_cfg`, used to update its default values. + subconfig_key: A str key used to check and update the sub-config. + drop_missing_subconfigs: A bool flag, whether to allow deletion of the MRIDCConfig sub-config, if its mirror + sub-config does not exist in the `update_cfg`. + + Returns + ------- + The updated DictConfig for the MRIDCConfig + """ + if not _HAS_HYDRA: + logging.error("This function requires Hydra/Omegaconf and it was not installed.") + sys.exit(1) + with open_dict(model_cfg.model): + # If the update config has the key but the model config doesn't, + # add the update config's sub-config to the model config + if subconfig_key in update_cfg.model and subconfig_key not in model_cfg.model: + model_cfg.model[subconfig_key] = update_cfg.model[subconfig_key] + + # If the update config does not have the key but the model config does, + # remove the model config's sub-config to match the layout of the update config + if subconfig_key not in update_cfg.model and subconfig_key in model_cfg.model and drop_missing_subconfigs: + model_cfg.model.pop(subconfig_key) + + return model_cfg + + +def _add_subconfig_keys(model_cfg: "DictConfig", update_cfg: "DictConfig", subconfig_key: str): + """ + For certain sub-configs, the default values specified by the MRIDCConfig class are insufficient. + Supporting every potential value in the merge with `update_cfg` would require explicit definition of all + possible cases. + An example of such a case is Optimizers, and their equivalent Schedulers. All optimizers share a few basic details + - such as name and lr, but almost all require additional parameters - such as weight decay. + It is impractical to create a config for every single optimizer + every single scheduler combination. + In such a case, we perform a dual merge. The Optim and Sched Dataclass contain the bare minimum essential + components. The extra values are provided via update_cfg. + In order to enable the merge, we first need to update the update sub-config to incorporate the keys, with dummy + temporary values (merge update config with model config). This is done on a copy of the update sub-config, as the + actual override values might be overridden by the MRIDCConfig defaults. + Then we perform a merge of this temporary sub-config with the actual override config in a later step (merge + model_cfg with original update_cfg, done outside this function). + + Parameters + ---------- + model_cfg: A DictConfig instantiated from the MRIDCConfig subclass. + update_cfg: A DictConfig that mirrors the structure of `model_cfg`, used to update its default values. + subconfig_key: A str key used to check and update the sub-config. + + Returns + ------- + A ModelPT DictConfig with additional keys added to the sub-config. 
+ """ + if not _HAS_HYDRA: + logging.error("This function requires Hydra/Omegaconf and it was not installed.") + sys.exit(1) + with open_dict(model_cfg.model): + # Create copy of original model sub config + if subconfig_key in update_cfg.model: + if subconfig_key not in model_cfg.model: + # create the key as a placeholder + model_cfg.model[subconfig_key] = None + + subconfig = copy.deepcopy(model_cfg.model[subconfig_key]) + update_subconfig = copy.deepcopy(update_cfg.model[subconfig_key]) + + # Add the keys and update temporary values, will be updated during full merge + subconfig = OmegaConf.merge(update_subconfig, subconfig) + # Update sub config + model_cfg.model[subconfig_key] = subconfig + + return model_cfg + + +
[docs]def assert_dataclass_signature_match(
+    cls: "class_type",  # type: ignore
+    datacls: "dataclass",  # type: ignore
+    ignore_args: Optional[List[str]] = None,
+    remap_args: Optional[Dict[str, str]] = None,
+):
+    """
+    Analyses the signature of a provided class and its respective data class,
+    asserting that the dataclass signature matches the class __init__ signature.
+    Note:
+        This is not a value based check. This function only checks if all argument
+        names exist on both class and dataclass and logs mismatches.
+
+    Parameters
+    ----------
+    cls: Any class type - but not an instance of a class. Pass type(x) where x is an instance
+        if class type is not easily available.
+    datacls: A corresponding dataclass for the above class.
+    ignore_args: (Optional) A list of string argument names which are forcibly ignored,
+        even if mismatched in the signature. Useful when a dataclass is a superset of the
+        arguments of a class.
+    remap_args: (Optional) A dictionary, mapping an argument name that exists (in either the
+        class or its dataclass), to another name. Useful when argument names are mismatched between
+        a class and its dataclass due to indirect instantiation via a helper method.
+
+    Returns
+    -------
+    A tuple containing information about the analysis:
+    1) A bool value which is True if the signatures matched exactly / after ignoring values.
+        False otherwise.
+    2) A set of argument names that exist in the class, but *do not* exist in the dataclass.
+        If exact signature match occurs, this will be None instead.
+    3) A set of argument names that exist in the data class, but *do not* exist in the class itself.
+        If exact signature match occurs, this will be None instead.
+    """
+    class_sig = inspect.signature(cls.__init__)
+
+    class_params = dict(**class_sig.parameters)
+    class_params.pop("self")
+
+    dataclass_sig = inspect.signature(datacls)
+
+    dataclass_params = dict(**dataclass_sig.parameters)
+    dataclass_params.pop("_target_", None)
+
+    class_params = set(class_params.keys())  # type: ignore
+    dataclass_params = set(dataclass_params.keys())  # type: ignore
+
+    if remap_args is not None:
+        for original_arg, new_arg in remap_args.items():
+            if original_arg in class_params:
+                class_params.remove(original_arg)  # type: ignore
+                class_params.add(new_arg)  # type: ignore
+                logging.info(f"Remapped {original_arg} -> {new_arg} in {cls.__name__}")
+
+            if original_arg in dataclass_params:
+                dataclass_params.remove(original_arg)  # type: ignore
+                dataclass_params.add(new_arg)  # type: ignore
+                logging.info(f"Remapped {original_arg} -> {new_arg} in {datacls.__name__}")
+
+    if ignore_args is not None:
+        ignore_args = set(ignore_args)  # type: ignore
+
+        class_params = class_params - ignore_args  # type: ignore
+        dataclass_params = dataclass_params - ignore_args  # type: ignore
+        logging.info(f"Removing ignored arguments - {ignore_args}")
+
+    intersection: Set[type] = set.intersection(class_params, dataclass_params)  # type: ignore
+    subset_cls = class_params - intersection  # type: ignore
+    subset_datacls = dataclass_params - intersection  # type: ignore
+
+    if (len(class_params) != len(dataclass_params)) or len(subset_cls) > 0 or len(subset_datacls) > 0:
+        logging.error(f"Class {cls.__name__} arguments do not match Dataclass {datacls.__name__}!")
+
+        if len(subset_cls) > 0:
+            logging.error(f"Class {cls.__name__} has additional arguments :\n{subset_cls}")
+
+        if len(subset_datacls):
+            logging.error(f"Dataclass {datacls.__name__} has additional arguments :\n{subset_datacls}")
+
+        return False, subset_cls, subset_datacls
+    return True, None, None
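A usage sketch of the check; `Conv` and `ConvConfig` are made-up names, not repo code:

from dataclasses import dataclass

class Conv:
    def __init__(self, channels: int, kernel_size: int = 3):
        self.channels, self.kernel_size = channels, kernel_size

@dataclass
class ConvConfig:
    _target_: str = "Conv"  # dropped automatically by the check
    channels: int = 64
    kernel_size: int = 3

matched, cls_extra, datacls_extra = assert_dataclass_signature_match(Conv, ConvConfig)
assert matched and cls_extra is None and datacls_extra is None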
diff --git a/docs/build/html/_modules/mridc/utils/decorators/deprecated.html b/docs/build/html/_modules/mridc/utils/decorators/deprecated.html
new file mode 100644
index 00000000..7cc00372
--- /dev/null
+++ b/docs/build/html/_modules/mridc/utils/decorators/deprecated.html
@@ -0,0 +1,175 @@
+ mridc.utils.decorators.deprecated — mridc v.0.0.1 documentation

Source code for mridc.utils.decorators.deprecated

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/utils/decorators/deprecated.py
+
+__all__ = ["deprecated"]
+
+import functools
+import inspect
+from typing import Dict
+
+import wrapt
+
+# Remember which deprecation warnings have been printed already.
+from mridc.utils import logging
+
+_PRINTED_WARNING: Dict = {}
+
+
+
[docs]def deprecated(wrapped=None, version=None, explanation=None):
+    """
+    This is a decorator which can be used to mark functions as deprecated. It will result in a warning being emitted
+    when the function is used.
+
+    Parameters
+    ----------
+    wrapped: The function to be decorated.
+        function
+    version: The version of the package where the function was deprecated.
+        str
+    explanation: The explanation of the deprecation.
+        str
+
+    Returns
+    -------
+    The decorated function.
+    """
+    if wrapped is None:
+        return functools.partial(deprecated, version=version, explanation=explanation)
+
+    @wrapt.decorator
+    def wrapper(_wrapped, instance, args, kwargs):
+        """
+        Prints the adequate warning (only once per function) when required, then calls the wrapped function with its
+        original arguments.
+
+        Parameters
+        ----------
+        _wrapped: The function being wrapped.
+        instance: The instance the wrapped function is bound to, if any (required by wrapt).
+        args: The arguments passed to the wrapped function.
+        kwargs: The keyword arguments passed to the wrapped function.
+
+        Returns
+        -------
+        The result of calling the wrapped function.
+        """
+        # Check if we already warned about that function.
+        if _wrapped.__name__ not in _PRINTED_WARNING:
+            # Add to list so we won't print it again.
+            _PRINTED_WARNING[_wrapped.__name__] = True
+
+            # Prepare the warning message.
+            entity_name = "Class" if inspect.isclass(wrapped) else "Function"
+            msg = f"{entity_name} '{_wrapped.__name__}' is deprecated."
+
+            # Optionally, add version and explanation.
+            if version is not None:
+                msg = f"{msg} It is going to be removed in the {version} version."
+
+            if explanation is not None:
+                msg = f"{msg} {explanation}"
+
+            # Display the deprecated warning.
+            logging.warning(msg)
+
+        # Call the function.
+        return _wrapped(*args, **kwargs)
+
+    return wrapper(wrapped)
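A usage sketch of the decorator (the function name and message are illustrative):

@deprecated(version="0.2.0", explanation="Use `new_normalize` instead.")
def old_normalize(x):
    return [v / max(x) for v in x]

old_normalize([1, 2, 4])
# Logs once: "Function 'old_normalize' is deprecated. It is going to be removed
# in the 0.2.0 version. Use `new_normalize` instead."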
diff --git a/docs/build/html/_modules/mridc/utils/decorators/experimental.html b/docs/build/html/_modules/mridc/utils/decorators/experimental.html
new file mode 100644
index 00000000..65799bc5
--- /dev/null
+++ b/docs/build/html/_modules/mridc/utils/decorators/experimental.html
@@ -0,0 +1,140 @@
+ mridc.utils.decorators.experimental — mridc v.0.0.1 documentation

Source code for mridc.utils.decorators.experimental

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/utils/decorators/experimental.py
+
+__all__ = ["experimental"]
+
+from mridc.utils import logging
+
+
+
[docs]def experimental(cls): + """ + Decorator to mark a class as experimental. + + Parameters + ---------- + cls: The class to be decorated. + class + + Returns + ------- + The decorated class. + """ + + def wrapped(x): + """ + Wrapper function. + + Parameters + ---------- + x: The class to be decorated. + class + + Returns + ------- + The decorated class with the experimental flag set. + """ + logging.warning( + f"Module {x} is experimental, not ready for production and is not fully supported. Use at your own risk." + ) + + return x + + return wrapped(x=cls)
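A usage sketch (the class name is made up):

@experimental
class PrototypeReconstructor:
    pass

# At decoration time this logs:
# "Module <class '...PrototypeReconstructor'> is experimental, not ready for
#  production and is not fully supported. Use at your own risk."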
diff --git a/docs/build/html/_modules/mridc/utils/decorators/port_docs.html b/docs/build/html/_modules/mridc/utils/decorators/port_docs.html
new file mode 100644
index 00000000..4ed6bad0
--- /dev/null
+++ b/docs/build/html/_modules/mridc/utils/decorators/port_docs.html
@@ -0,0 +1,207 @@
+ mridc.utils.decorators.port_docs — mridc v.0.0.1 documentation

Source code for mridc.utils.decorators.port_docs

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/utils/decorators/port_docs.py
+
+# The "add_port_docs" decorator is needed to nicely generate neural types in Sphynx for input and output ports
+
+__all__ = ["add_port_docs"]
+
+import functools
+import sys
+
+import wrapt
+
+
+def _normalize_docstring(docstring):
+    """
+    Normalize docstring indentation. Replaces tabs with spaces, removes leading and trailing blank lines, and removes
+     any indentation.
+
+    Copied from PEP-257: https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation
+
+    Parameters
+    ----------
+    docstring: The docstring to normalize.
+        str
+
+    Returns
+    -------
+    The normalized docstring.
+    """
+    if not docstring:
+        return ""
+    # Convert tabs to spaces (following the normal Python rules)
+    # and split into a list of lines:
+    lines = docstring.expandtabs().splitlines()
+    # Determine minimum indentation (first line doesn't count):
+    # (we use sys.maxsize because sys.maxint doesn't exist in Python 3)
+    indent = sys.maxsize
+    for line in lines[1:]:
+        if stripped := line.lstrip():
+            indent = min(indent, len(line) - len(stripped))
+    # Remove indentation (first line is special):
+    trimmed = [lines[0].strip()]
+    if indent < sys.maxsize:
+        trimmed.extend(line[indent:].rstrip() for line in lines[1:])
+    # Strip off trailing and leading blank lines:
+    while trimmed and not trimmed[-1]:
+        trimmed.pop()
+    while trimmed and not trimmed[0]:
+        trimmed.pop(0)
+    # Return a single string:
+    return "\n".join(trimmed)
+
+
+
[docs]def add_port_docs(wrapped=None, instance=None, value=""): + """ + Adds port documentation to the wrapped function. + + Parameters + ---------- + wrapped: The function to decorate. + function + instance: The instance of the function. + object + value: The value of the port. + object + + Returns + ------- + The decorated function. + """ + if wrapped is None: + return functools.partial(add_port_docs, value=value) + + @wrapt.decorator + def wrapper(wrapped, instance=None, args=None, kwargs=None): + """ + Adds port documentation to the wrapped function. + + Parameters + ---------- + wrapped: The function to decorate. + instance: The instance of the function. + args: The arguments of the function. + kwargs: The keyword arguments of the function. + + Returns + ------- + The decorated function. + """ + return wrapped(*args, **kwargs) + + decorated = wrapper(wrapped) + try: + port_2_ntype = decorated(instance) + except AttributeError: + port_2_ntype = None + + port_description = "" + if port_2_ntype is not None: + for port, ntype in port_2_ntype.items(): + port_description += "* *" + port + "* : " + str(ntype) + port_description += "\n\n" + + __doc__ = _normalize_docstring(wrapped.__doc__) + "\n\n" + str(port_description) + __doc__ = _normalize_docstring(__doc__) + + wrapt.FunctionWrapper.__setattr__(decorated, "__doc__", __doc__) + + return decorated
diff --git a/docs/build/html/_modules/mridc/utils/distributed.html b/docs/build/html/_modules/mridc/utils/distributed.html
new file mode 100644
index 00000000..2be3bb70
--- /dev/null
+++ b/docs/build/html/_modules/mridc/utils/distributed.html
@@ -0,0 +1,148 @@
+ mridc.utils.distributed — mridc v.0.0.1 documentation

Source code for mridc.utils.distributed

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/utils/distributed.py
+
+import os
+
+import torch
+
+from mridc.utils import logging
+
+
+
[docs]def initialize_distributed(args, backend="nccl"): + """ + Initialize distributed training. + + Parameters + ---------- + args: The arguments object. + backend: The backend to use. + default: "nccl" + + Returns + ------- + local_rank: The local rank of the process. + rank: The rank of the process. + world_size: The number of processes. + """ + # Get local rank in case it is provided. + local_rank = args.local_rank + + # Get rank and world size. + rank = int(os.getenv("RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) + + logging.info( + f"Initializing torch.distributed with local_rank: {local_rank}, rank: {rank}, world_size: {world_size}" + ) + + # Set the device id. + device = rank % torch.cuda.device_count() + if local_rank is not None: + device = local_rank + torch.cuda.set_device(device) + + # Call the init process. + init_method = "tcp://" + master_ip = os.getenv("MASTER_ADDR", "localhost") + master_port = os.getenv("MASTER_PORT", "6000") + init_method += f"{master_ip}:{master_port}" + torch.distributed.init_process_group(backend=backend, world_size=world_size, rank=rank, init_method=init_method) + return local_rank, rank, world_size
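A sketch of how a launch script might call this, assuming a CUDA device is available; the env vars mirror those exported by launchers such as torchrun:

import argparse
import os

os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "6000")

args = argparse.Namespace(local_rank=0)  # usually parsed from --local_rank
local_rank, rank, world_size = initialize_distributed(args, backend="nccl")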
diff --git a/docs/build/html/_modules/mridc/utils/env_var_parsing.html b/docs/build/html/_modules/mridc/utils/env_var_parsing.html
new file mode 100644
index 00000000..12afdca9
--- /dev/null
+++ b/docs/build/html/_modules/mridc/utils/env_var_parsing.html
@@ -0,0 +1,279 @@
+ mridc.utils.env_var_parsing — mridc v.0.0.1 documentation

Source code for mridc.utils.env_var_parsing

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/utils/env_var_parsing.py
+
+import decimal
+import json
+import os
+
+from dateutil import parser  # type: ignore
+
+__all__ = [
+    "get_env",
+    "get_envbool",
+    "get_envint",
+    "get_envfloat",
+    "get_envdecimal",
+    "get_envdate",
+    "get_envdatetime",
+    "get_envlist",
+    "get_envdict",
+    "CoercionError",
+    "RequiredSettingMissingError",
+]
+
+
+
[docs]class CoercionError(Exception): + """Custom error raised when a value cannot be coerced.""" + + def __init__(self, key, value, func): + msg = f"Unable to coerce '{key}={value}' using {func.__name__}." + super(CoercionError, self).__init__(msg)
+ + +
[docs]class RequiredSettingMissingError(Exception): + """Custom error raised when a required env var is missing.""" + + def __init__(self, key): + msg = f"Required env var '{key}' is missing." + super(RequiredSettingMissingError, self).__init__(msg)
+ + +def _get_env(key, default=None, coerce=lambda x: x, required=False): + """ + Return env var coerced into a type other than string. This function extends the standard os.getenv function to \ + enable the coercion of values into data types other than string (all env vars are strings by default). + + Parameters + ---------- + key: The name of the env var to retrieve. + default: The default value to return if the env var is not set. NB the default value is **not** coerced, and is \ + assumed to be of the correct type. + coerce: A function that takes a string and returns a value of the desired type. + required: If True, raises a RequiredSettingMissingError if the env var is not set. + + Returns + ------- + The value of the env var coerced into the desired type. + """ + try: + value = os.environ[key] + except KeyError as e: + if required is True: + raise RequiredSettingMissingError(key) from e + return default + + try: + return coerce(value) + except Exception as exc: + raise CoercionError(key, value, coerce) from exc + + +# standard type coercion functions +def _bool(value): + """Return env var cast as boolean.""" + if isinstance(value, bool): + return value + + return value is not None and value.lower() not in ( + "false", + "0", + "no", + "n", + "f", + "none", + ) + + +def _int(value): + """Return env var cast as integer.""" + return int(value) + + +def _float(value): + """Return env var cast as float.""" + return float(value) + + +def _decimal(value): + """Return env var cast as Decimal.""" + return decimal.Decimal(value) + + +def _dict(value): + """Return env var as a dict.""" + return json.loads(value) + + +def _datetime(value): + """Return env var as a datetime.""" + return parser.parse(value) + + +def _date(value): + """Return env var as a date.""" + return parser.parse(value).date() + + +
[docs]def get_env(key, *default, **kwargs): + """ + Return env var. This is the parent function of all other get_foo functions, and is responsible for unpacking \ + args/kwargs into the values that _get_env expects (it is the root function that actually interacts with environ). + + Parameters + ---------- + key: string, the env var name to look up. + default: (optional) the value to use if the env var does not exist. If this value is not supplied, then the \ + env var is considered to be required, and a RequiredSettingMissingError error will be raised if it does not exist. + kwargs: + coerce: a func that may be supplied to coerce the value into something else. This is used by the default \ + get_foo functions to cast strings to builtin types, but could be a function that returns a custom class. + + Returns + ------- + The env var, coerced if required, and a default if supplied. + """ + if len(default) not in (0, 1): + raise AssertionError("Too many args supplied.") + func = kwargs.get("coerce", lambda x: x) + required = len(default) == 0 + default = None if required else default[0] + return _get_env(key, default=default, coerce=func, required=required)
+ + +
[docs]def get_envbool(key, *default): + """Return env var cast as boolean.""" + return get_env(key, *default, coerce=_bool)
+ + +
[docs]def get_envint(key, *default): + """Return env var cast as integer.""" + return get_env(key, *default, coerce=_int)
+ + +
[docs]def get_envfloat(key, *default): + """Return env var cast as float.""" + return get_env(key, *default, coerce=_float)
+ + +
[docs]def get_envdecimal(key, *default): + """Return env var cast as Decimal.""" + return get_env(key, *default, coerce=_decimal)
+ + +
[docs]def get_envdate(key, *default): + """Return env var as a date.""" + return get_env(key, *default, coerce=_date)
+ + +
[docs]def get_envdatetime(key, *default): + """Return env var as a datetime.""" + return get_env(key, *default, coerce=_datetime)
+ + +
[docs]def get_envlist(key, *default, **kwargs): + """Return env var as a list.""" + separator = kwargs.get("separator", " ") + return get_env(key, *default, coerce=lambda x: x.split(separator))
+ + +
[docs]def get_envdict(key, *default): + """Return env var as a dict.""" + return get_env(key, *default, coerce=_dict)
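A usage sketch of the getters; the variable names are illustrative:

import os

os.environ["MRIDC_NUM_WORKERS"] = "8"
os.environ["MRIDC_DEBUG"] = "no"
os.environ["MRIDC_GPUS"] = "0,1,2"

get_envint("MRIDC_NUM_WORKERS", 4)        # -> 8
get_envbool("MRIDC_DEBUG", True)          # -> False ("no" is treated as falsy)
get_envlist("MRIDC_GPUS", separator=",")  # -> ["0", "1", "2"]
# get_envint("MRIDC_MISSING")             # no default: raises RequiredSettingMissingError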
diff --git a/docs/build/html/_modules/mridc/utils/exceptions.html b/docs/build/html/_modules/mridc/utils/exceptions.html
new file mode 100644
index 00000000..c108b334
--- /dev/null
+++ b/docs/build/html/_modules/mridc/utils/exceptions.html
@@ -0,0 +1,129 @@
+ mridc.utils.exceptions — mridc v.0.0.1 documentation

Source code for mridc.utils.exceptions

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/utils/exceptions.py
+
+
+
[docs]class MRIDCBaseException(Exception): + """MRIDC Base Exception. All exceptions created in MRIDC should inherit from this class"""
+ + +
[docs]class LightningNotInstalledException(MRIDCBaseException): + """Exception for when lightning is not installed""" + + def __init__(self, obj): + message = ( + f" You are trying to use {obj} without installing all of pytorch_lightning, hydra, and " + f"omegaconf. Please install those packages before trying to access {obj}." + ) + super().__init__(message)
+ + +
[docs]class CheckInstall: + """Class to check if a package is installed.""" + + def __init__(self, *args, **kwargs): + raise LightningNotInstalledException(self) + + def __call__(self, *args, **kwargs): + raise LightningNotInstalledException(self) + + def __getattr__(self, *args, **kwargs): + raise LightningNotInstalledException(self)
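A sketch of the guarded-import pattern these classes support (the aliasing shown is illustrative, not repo code):

try:
    from pytorch_lightning import Trainer
except ImportError:
    Trainer = CheckInstall  # type: ignore

# If lightning is missing, any attempt to construct, call, or access the
# stand-in raises LightningNotInstalledException with an actionable message.
trainer = Trainer()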
diff --git a/docs/build/html/_modules/mridc/utils/exp_manager.html b/docs/build/html/_modules/mridc/utils/exp_manager.html
new file mode 100644
index 00000000..864c9b58
--- /dev/null
+++ b/docs/build/html/_modules/mridc/utils/exp_manager.html
@@ -0,0 +1,1082 @@
+ mridc.utils.exp_manager — mridc v.0.0.1 documentation

Source code for mridc.utils.exp_manager

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/utils/exp_manager.py
+import os
+import re
+import subprocess
+import sys
+import time
+from copy import deepcopy
+from dataclasses import dataclass
+
+from pathlib import Path
+from shutil import copy, move
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from hydra.core.hydra_config import HydraConfig
+from hydra.utils import get_original_cwd
+from omegaconf import DictConfig, OmegaConf, open_dict
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import Callback, ModelCheckpoint
+from pytorch_lightning.callbacks.timer import Timer
+from pytorch_lightning.loggers import LoggerCollection as _LoggerCollection, TensorBoardLogger, WandbLogger
+from pytorch_lightning.strategies.ddp import DDPStrategy
+
+import mridc.utils
+from mridc.constants import MRIDC_ENV_VARNAME_TESTING, MRIDC_ENV_VARNAME_VERSION
+from mridc.utils import logging, timers
+from mridc.utils.app_state import AppState
+from mridc.utils.env_var_parsing import get_envbool
+from mridc.utils.exceptions import MRIDCBaseException
+from mridc.utils.get_rank import is_global_rank_zero
+from mridc.utils.lightning_logger_patch import add_filehandlers_to_pl_logger
+
+
+
[docs]class NotFoundError(MRIDCBaseException): + """Raised when a file or folder is not found"""
+ + +
[docs]class LoggerMisconfigurationError(MRIDCBaseException): + """Raised when a mismatch between trainer.logger and exp_manager occurs""" + + def __init__(self, message): + message = ( + message + "You can disable lightning's trainer from creating a logger by passing logger=False to its " + "constructor. " + ) + super().__init__(message)
+ + +
[docs]class CheckpointMisconfigurationError(MRIDCBaseException): + """Raised when a mismatch between trainer.callbacks and exp_manager occurs"""
+ + +
[docs]@dataclass +class CallbackParams: + """Parameters for a callback""" + + filepath: Optional[str] = None # Deprecated + dirpath: Optional[str] = None # If None, exp_manager will attempt to handle the filepath + filename: Optional[str] = None # If None, exp_manager will attempt to handle the filepath + monitor: Optional[str] = "val_loss" + verbose: Optional[bool] = True + save_last: Optional[bool] = True + save_top_k: Optional[int] = 3 + save_weights_only: Optional[bool] = False + mode: Optional[str] = "min" + every_n_epochs: Optional[int] = 1 + prefix: Optional[str] = None # If None, exp_manager will attempt to handle the filepath + postfix: str = ".mridc" + save_best_model: bool = False + always_save_mridc: bool = False + save_mridc_on_train_end: Optional[bool] = True # Automatically save .mridc file during on_train_end hook + model_parallel_size: Optional[int] = None # tensor parallel size * pipeline parallel size
+ + +
[docs]@dataclass +class StepTimingParams: + """Parameters for the step timing callback.""" + + reduction: Optional[str] = "mean" + # if True torch.cuda.synchronize() is called on start/stop + sync_cuda: Optional[bool] = False + # if positive, defines the size of a sliding window for computing mean + buffer_size: Optional[int] = 1
+ + +
[docs]@dataclass +class ExpManagerConfig: + """Configuration for the experiment manager.""" + + # Log dir creation parameters + explicit_log_dir: Optional[str] = None + exp_dir: Optional[str] = None + name: Optional[str] = None + version: Optional[str] = None + use_datetime_version: Optional[bool] = True + resume_if_exists: Optional[bool] = False + resume_past_end: Optional[bool] = False + resume_ignore_no_checkpoint: Optional[bool] = False + # Logging parameters + create_tensorboard_logger: Optional[bool] = True + summary_writer_kwargs: Optional[Dict[Any, Any]] = None + create_wandb_logger: Optional[bool] = False + wandb_logger_kwargs: Optional[Dict[Any, Any]] = None + # Checkpointing parameters + create_checkpoint_callback: Optional[bool] = True + checkpoint_callback_params: Optional[CallbackParams] = CallbackParams() + # Additional exp_manager arguments + files_to_copy: Optional[List[str]] = None + # logs timing of train/val/test steps + log_step_timing: Optional[bool] = True + step_timing_kwargs: Optional[StepTimingParams] = StepTimingParams() + # Configures creation of log files for different ranks + log_local_rank_0_only: Optional[bool] = False + log_global_rank_0_only: Optional[bool] = False + model_parallel_size: Optional[int] = None
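For illustration, a plain dict matching the ExpManagerConfig fields above (values are made up) that could be passed as the exp_manager cfg:

exp_manager_cfg = {
    "exp_dir": "./mridc_experiments",
    "name": "example_run",
    "use_datetime_version": True,
    "resume_if_exists": False,
    "create_tensorboard_logger": True,
    "create_wandb_logger": False,
    "create_checkpoint_callback": True,
    "checkpoint_callback_params": {"monitor": "val_loss", "mode": "min", "save_top_k": 3},
}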
+ + +
[docs]class TimingCallback(Callback): + """Logs execution time of train/val/test steps""" + + def __init__(self, timer_kwargs=None): + """Initialize TimingCallback""" + if timer_kwargs is None: + timer_kwargs = {} + self.timer = timers.NamedTimer(**timer_kwargs) + + def _on_batch_start(self, name): + """Called at the beginning of each batch""" + # reset only if we do not return mean of a sliding window + if self.timer.buffer_size <= 0: + self.timer.reset(name) + + self.timer.start(name) + + def _on_batch_end(self, name, pl_module): + """Called at the end of each batch""" + self.timer.stop(name) + pl_module.log(name, self.timer[name], on_step=True, on_epoch=False) + +
[docs] def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, **kwargs): + """Called at the beginning of each training batch""" + self._on_batch_start("train_step_timing")
+ +
[docs] def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, **kwargs): + """Logs the time taken by the training batch""" + self._on_batch_end("train_step_timing", pl_module)
+ +
[docs] def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+        """Called at the beginning of each validation batch"""
+        self._on_batch_start("validation_step_timing")
+ +
[docs] def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + """Logs the time taken by the validation step""" + self._on_batch_end("validation_step_timing", pl_module)
+ +
[docs] def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + """Logs execution time of test steps""" + self._on_batch_start("test_step_timing")
+ +
[docs] def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + """Logs execution time of test steps""" + self._on_batch_end("test_step_timing", pl_module)
+ +
[docs] def on_before_backward(self, trainer, pl_module, loss):
+        """Starts timing of the backward pass"""
+        self._on_batch_start("train_backward_timing")
+ +
[docs] def on_after_backward(self, trainer, pl_module):
+        """Stops timing of the backward pass and logs it"""
+        self._on_batch_end("train_backward_timing", pl_module)
+ + +
[docs]def exp_manager(trainer: Trainer, cfg: Optional[Union[DictConfig, Dict]] = None) -> Optional[Path]:
+    """
+    exp_manager is a helper function used to manage folders for experiments. It follows the pytorch lightning \
+    paradigm of exp_dir/model_or_experiment_name/version. If the lightning trainer has a logger, exp_manager will \
+    get exp_dir, name, and version from the logger. Otherwise, it will use the exp_dir and name arguments to create \
+    the logging directory. exp_manager also allows for explicit folder creation via explicit_log_dir.
+
+    The version can be a datetime string or an integer. Datetime version can be disabled if use_datetime_version \
+    is set to False. It optionally creates TensorBoardLogger, WandBLogger, ModelCheckpoint objects from pytorch \
+    lightning. It copies sys.argv, and git information if available to the logging directory. It creates a log file \
+    for each process to log their output into.
+
+    exp_manager additionally has a resume feature (resume_if_exists) which can be used to continue training from \
+    the constructed log_dir. When you need to continue the training repeatedly (like on a cluster where you need \
+    multiple consecutive jobs), you need to avoid creating the version folders. Therefore, from v1.0.0, when \
+    resume_if_exists is set to True, creating the version folders is ignored.
+
+    Parameters
+    ----------
+    trainer: The lightning trainer object.
+    cfg: Can have the following keys:
+        - explicit_log_dir: Can be used to override exp_dir/name/version folder creation. Defaults to None, which \
+        will use exp_dir, name, and version to construct the logging directory.
+        - exp_dir: The base directory to create the logging directory. Defaults to None, which logs to \
+        ./mridc_experiments.
+        - name: The name of the experiment. Defaults to None which turns into "default" via name = name or "default".
+        - version: The version of the experiment. Defaults to None which uses either a datetime string or lightning's \
+        TensorboardLogger system of using version_{int}.
+        - use_datetime_version: Whether to use a datetime string for version. Defaults to True.
+        - resume_if_exists: Whether this experiment is resuming from a previous run. If True, it sets \
+        trainer._checkpoint_connector.resume_from_checkpoint_fit_path so that the trainer should auto-resume. \
+        exp_manager will move files under log_dir to log_dir/run_{int}. Defaults to False. From v1.0.0, when \
+        resume_if_exists is True, we would not create version folders to make it easier to find the log folder for \
+        next runs.
+        - resume_past_end: exp_manager errors out if resume_if_exists is True and a checkpoint matching \*end.ckpt \
+        exists, indicating that a previous training run fully completed. This behaviour can be disabled, in which \
+        case the \*end.ckpt will be loaded by setting resume_past_end to True. Defaults to False.
+        - resume_ignore_no_checkpoint: exp_manager errors out if resume_if_exists is True and no checkpoint could be \
+        found. This behaviour can be disabled, in which case exp_manager will print a message and continue without \
+        restoring, by setting resume_ignore_no_checkpoint to True. Defaults to False.
+        - create_tensorboard_logger: Whether to create a tensorboard logger and attach it to the pytorch lightning \
+        trainer. Defaults to True.
+        - summary_writer_kwargs: A dictionary of kwargs that can be passed to lightning's TensorboardLogger class. \
+        Note that log_dir is passed by exp_manager and cannot exist in this dict. Defaults to None.
+ - create_wandb_logger: Whether to create a Weights and Biases logger and attach it to the pytorch lightning \ + trainer. Defaults to False. + - wandb_logger_kwargs: A dictionary of kwargs that can be passed to lightning's WandBLogger class. Note that \ + name and project are required parameters if create_wandb_logger is True. Defaults to None. + - create_checkpoint_callback: Whether to create a ModelCheckpoint callback and attach it to the pytorch \ + lightning trainer. The ModelCheckpoint saves the top 3 models with the best "val_loss", the most recent \ + checkpoint under \*last.ckpt, and the final checkpoint after training completes under \*end.ckpt. \ + Defaults to True. + - files_to_copy: A list of files to copy to the experiment logging directory. Defaults to None which copies \ + no files. + - log_local_rank_0_only: Whether to only create log files for local rank 0. Defaults to False. Set this to \ + True if you are using DDP with many GPUs and do not want many log files in your exp dir. + - log_global_rank_0_only: Whether to only create log files for global rank 0. Defaults to False. Set this to \ + True if you are using DDP with many GPUs and do not want many log files in your exp dir. + + Returns + ------- + The final logging directory where logging files are saved. Usually the concatenation of exp_dir, name, and version. + """ + # Add rank information to logger + # Note: trainer.global_rank and trainer.is_global_zero are not set until trainer.fit, so have to hack around it + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + global_rank = trainer.node_rank * trainer.num_devices + local_rank + logging.rank = global_rank + + if cfg is None: + logging.error("exp_manager did not receive a cfg argument. It will be disabled.") + return None + + if trainer.fast_dev_run: + logging.info("Trainer was called with fast_dev_run. exp_manager will return without any functionality.") + return None + + # Ensure passed cfg is compliant with ExpManagerConfig + schema = OmegaConf.structured(ExpManagerConfig) + if isinstance(cfg, dict): + cfg = OmegaConf.create(cfg) + elif not isinstance(cfg, DictConfig): + raise ValueError(f"cfg was type: {type(cfg)}. 
Expected either a dict or a DictConfig") + cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True)) + cfg = OmegaConf.merge(schema, cfg) + + error_checks(trainer, cfg) # Ensures that trainer options are compliant with MRIDC and exp_manager arguments + + log_dir, exp_dir, name, version = get_log_dir( + trainer=trainer, + exp_dir=cfg.exp_dir, + name=cfg.name, + version=cfg.version, + explicit_log_dir=cfg.explicit_log_dir, + use_datetime_version=cfg.use_datetime_version, + resume_if_exists=cfg.resume_if_exists, + ) + + if cfg.resume_if_exists: + check_resume(trainer, str(log_dir), cfg.resume_past_end, cfg.resume_ignore_no_checkpoint) + + checkpoint_name = name + # If name returned from get_log_dir is "", use cfg.name for checkpointing + if checkpoint_name is None or checkpoint_name == "": + checkpoint_name = cfg.name or "default" + cfg.name = name # Used for configure_loggers so that the log_dir is properly set even if name is "" + cfg.version = version + + # update app_state with log_dir, exp_dir, etc + app_state = AppState() + app_state.log_dir = log_dir + app_state.exp_dir = exp_dir + app_state.name = name + app_state.version = version + app_state.checkpoint_name = checkpoint_name + app_state.create_checkpoint_callback = cfg.create_checkpoint_callback + app_state.checkpoint_callback_params = cfg.checkpoint_callback_params + + # Create the logging directory if it does not exist + os.makedirs(log_dir, exist_ok=True) # Cannot limit creation to global zero as all ranks write to own log file + logging.info(f"Experiments will be logged at {log_dir}") + trainer._default_root_dir = log_dir + + if cfg.log_local_rank_0_only is True and cfg.log_global_rank_0_only is True: + raise ValueError( + "Cannot set both log_local_rank_0_only and log_global_rank_0_only to True." + "Please set either one or neither." + ) + + # This is set if the env var MRIDC_TESTING is set to True. + mridc_testing = get_envbool(MRIDC_ENV_VARNAME_TESTING, False) + + log_file = log_dir / f"mridc_log_globalrank-{global_rank}_localrank-{local_rank}.txt" + # Handle logging to file + # Logs local rank 0 only + if local_rank == 0 and cfg.log_local_rank_0_only and not mridc_testing: + logging.add_file_handler(log_file) + elif global_rank == 0 and cfg.log_global_rank_0_only and mridc_testing: + logging.add_file_handler(log_file) + else: + logging.add_file_handler(log_file) + + # For some reason, LearningRateLogger requires trainer to have a logger. Safer to create logger on all ranks + # not just global rank 0. 
+ if cfg.create_tensorboard_logger or cfg.create_wandb_logger: + configure_loggers( + trainer, + [Path(exp_dir)], + cfg.name, + cfg.version, + cfg.create_tensorboard_logger, + cfg.summary_writer_kwargs, + cfg.create_wandb_logger, + cfg.wandb_logger_kwargs, + ) + + # add loggers timing callbacks + if cfg.log_step_timing: + timing_callback = TimingCallback(timer_kwargs=cfg.step_timing_kwargs or {}) + trainer.callbacks.insert(0, timing_callback) + + if cfg.create_checkpoint_callback: + configure_checkpointing( + trainer, log_dir, checkpoint_name, cfg.resume_if_exists, cfg.checkpoint_callback_params + ) + + if is_global_rank_zero(): + # Move files_to_copy to folder and add git information if present + if cfg.files_to_copy: + for _file in cfg.files_to_copy: + copy(Path(_file), log_dir) + + # Create files for cmd args and git info + with open(log_dir / "cmd-args.log", "w", encoding="utf-8") as _file: + _file.write(" ".join(sys.argv)) + + # Try to get git hash + git_repo, git_hash = get_git_hash() + if git_repo: + with open(log_dir / "git-info.log", "w", encoding="utf-8") as _file: + _file.write(f"commit hash: {git_hash}") + _file.write(get_git_diff()) + + # Add err_file logging to global_rank zero + logging.add_err_file_handler(log_dir / "mridc_error_log.txt") + + # Add lightning file logging to global_rank zero + add_filehandlers_to_pl_logger(log_dir / "lightning_logs.txt", log_dir / "mridc_error_log.txt") + + return log_dir
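A minimal wiring sketch, assuming pytorch_lightning is installed; logger and checkpointing are disabled on the Trainer so exp_manager can create its own:

from pytorch_lightning import Trainer

trainer = Trainer(max_epochs=10, logger=False, enable_checkpointing=False)
log_dir = exp_manager(
    trainer,
    {"exp_dir": "./mridc_experiments", "name": "example_run", "create_tensorboard_logger": True},
)
print(f"Logging to {log_dir}")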
+ + +
[docs]def error_checks(trainer: Trainer, cfg: Optional[Union[DictConfig, Dict]] = None):
+    """
+    Checks that the passed trainer is compliant with MRIDC and exp_manager's passed configuration. Checks that:
+        - Throws error when hydra has changed the working directory. This causes issues with lightning's DDP
+        - Throws error when trainer has loggers defined but create_tensorboard_logger or create_wandb_logger is True
+        - Prints error messages when 1) run on multi-node and not Slurm, and 2) run on multi-gpu without DDP
+    """
+    if HydraConfig.initialized() and get_original_cwd() != os.getcwd():
+        raise ValueError(
+            "Hydra changed the working directory. This interferes with ExpManager's functionality. Please pass "
+            "hydra.run.dir=. to your python script."
+        )
+
+    if trainer.logger is not None and (cfg.create_tensorboard_logger or cfg.create_wandb_logger):  # type: ignore
+        raise LoggerMisconfigurationError(
+            "The pytorch lightning trainer that was passed to exp_manager contained a logger, and either "
+            "create_tensorboard_logger or create_wandb_logger was set to True. These can only be used if trainer does "
+            "not already have a logger."
+        )
+
+    if trainer.num_nodes > 1 and not check_slurm(trainer):  # type: ignore
+        logging.error(
+            "You are running multi-node training without SLURM handling the processes."
+            " Please note that this is not tested in MRIDC and could result in errors."
+        )
+
+    if trainer.num_devices > 1 and not isinstance(trainer.strategy, DDPStrategy):  # type: ignore
+        logging.error(
+            "You are running multi-gpu without DDP. Please note that this is not tested in MRIDC and could result in "
+            "errors."
+        )
+ + +
[docs]def check_resume(
+    trainer: Trainer,
+    log_dir: str,
+    resume_past_end: bool = False,
+    resume_ignore_no_checkpoint: bool = False,
+):
+    """
+    Checks that resume=True was used correctly with the arguments passed to exp_manager. Sets
+    trainer._checkpoint_connector.resume_from_checkpoint_fit_path as necessary.
+
+    Parameters
+    ----------
+    trainer: The trainer that is being used.
+    log_dir: The directory where the logs are being saved.
+    resume_past_end: Whether to resume from the end of the experiment.
+    resume_ignore_no_checkpoint: Whether to ignore if there is no checkpoint to resume from.
+
+    Raises
+    ------
+    NotFoundError: If resume is True, resume_ignore_no_checkpoint is False, and checkpoints could not be found.
+    ValueError: If resume is True, and more than one checkpoint was found.
+    """
+    if not log_dir:
+        raise ValueError(f"Resuming requires the log_dir {log_dir} to be passed to exp_manager")
+
+    checkpoint_dir = Path(Path(log_dir) / "checkpoints")
+
+    checkpoint = None
+    end_checkpoints = list(checkpoint_dir.rglob("*end.ckpt"))
+    last_checkpoints = list(checkpoint_dir.rglob("*last.ckpt"))
+    if not checkpoint_dir.exists():
+        if not resume_ignore_no_checkpoint:
+            raise NotFoundError(f"There was no checkpoint folder at checkpoint_dir :{checkpoint_dir}. Cannot resume.")
+        logging.warning(f"There was no checkpoint folder at checkpoint_dir :{checkpoint_dir}. Training from scratch.")
+        return
+    if end_checkpoints:
+        if not resume_past_end:
+            raise ValueError(
+                f"Found {end_checkpoints[0]} indicating that the last training run has already completed."
+            )
+        if len(end_checkpoints) > 1:
+            if "mp_rank" in str(end_checkpoints[0]):
+                checkpoint = end_checkpoints[0]
+            else:
+                raise ValueError(f"Multiple checkpoints {end_checkpoints} match *end.ckpt.")
+        logging.info(f"Resuming from {end_checkpoints[0]}")
+    elif not last_checkpoints:
+        if not resume_ignore_no_checkpoint:
+            raise NotFoundError(f"There were no checkpoints found in {checkpoint_dir}. Cannot resume.")
+        logging.warning(f"There were no checkpoints found in {checkpoint_dir}. Training from scratch.")
+        return
+    elif len(last_checkpoints) > 1:
+        if "mp_rank" not in str(last_checkpoints[0]) and "tp_rank" not in str(last_checkpoints[0]):
+            raise ValueError(f"Multiple checkpoints {last_checkpoints} match *last.ckpt.")
+        checkpoint = last_checkpoints[0]
+        checkpoint = mridc.utils.model_utils.uninject_model_parallel_rank(checkpoint)  # type: ignore
+    else:
+        logging.info(f"Resuming from {last_checkpoints[0]}")
+        checkpoint = last_checkpoints[0]
+
+    trainer._checkpoint_connector.resume_from_checkpoint_fit_path = str(checkpoint)
+
+    if is_global_rank_zero():
+        if files_to_move := [child for child in Path(log_dir).iterdir() if child.is_file()]:
+            # Move old files to a new folder
+            other_run_dirs = Path(log_dir).glob("run_*")
+            run_count = sum(bool(fold.is_dir()) for fold in other_run_dirs)
+            new_run_dir = Path(Path(log_dir) / f"run_{run_count}")
+            new_run_dir.mkdir()
+            for _file in files_to_move:
+                move(str(_file), str(new_run_dir))
+ + +
[docs]def check_explicit_log_dir(
+    trainer: Trainer, explicit_log_dir: List[Union[Path, str]], exp_dir: str, name: str, version: str
+) -> Tuple[Path, str, str, str]:
+    """
+    Checks that the passed arguments are compatible with explicit_log_dir.
+
+    Parameters
+    ----------
+    trainer: The trainer to check.
+    explicit_log_dir: The explicit log dir to check.
+    exp_dir: The experiment directory to check.
+    name: The experiment name to check.
+    version: The experiment version to check.
+
+    Returns
+    -------
+    The log_dir, exp_dir, name, and version that should be used.
+
+    Raises
+    ------
+    LoggerMisconfigurationError
+    """
+    if trainer.logger is not None:
+        raise LoggerMisconfigurationError(
+            "The pytorch lightning trainer that was passed to exp_manager contained a logger and explicit_log_dir: "
+            f"{explicit_log_dir} was passed to exp_manager. Please remove the logger from the lightning trainer."
+        )
+    # Checking only (explicit_log_dir) vs (exp_dir and version).
+    # The `name` will be used as the actual name of checkpoint/archive.
+    if exp_dir or version:
+        logging.error(
+            f"exp_manager received explicit_log_dir: {explicit_log_dir} and at least one of exp_dir: {exp_dir}, "
+            f"or version: {version}. Please note that exp_dir, name, and version will be ignored."
+        )
+    if is_global_rank_zero() and Path(str(explicit_log_dir)).exists():
+        logging.warning(f"Exp_manager is logging to {explicit_log_dir}, but it already exists.")
+    return Path(str(explicit_log_dir)), str(explicit_log_dir), "", ""
+ + +
[docs]def get_log_dir(
+    trainer: Trainer,
+    exp_dir: str = None,
+    name: str = None,
+    version: str = None,
+    explicit_log_dir: str = None,
+    use_datetime_version: bool = True,
+    resume_if_exists: bool = False,
+) -> Tuple[Path, str, str, str]:
+    """
+    Obtains the log_dir used for exp_manager.
+
+    Parameters
+    ----------
+    trainer: The trainer to check.
+    exp_dir: The experiment directory to check.
+    name: The experiment name to check.
+    version: The experiment version to check.
+    explicit_log_dir: The explicit log dir to check.
+    use_datetime_version: Whether to use datetime versioning.
+    resume_if_exists: Whether to resume if the log_dir already exists.
+
+    Raises
+    ------
+    LoggerMisconfigurationError: If trainer is incompatible with arguments
+    NotFoundError: If resume is True, resume_ignore_no_checkpoint is False, and checkpoints could not be found.
+    ValueError: If resume is True, and more than one checkpoint was found.
+    """
+    if explicit_log_dir:  # If explicit log_dir was passed, short circuit
+        return check_explicit_log_dir(trainer, [Path(explicit_log_dir)], exp_dir, name, version)  # type: ignore
+
+    # Default exp_dir to ./mridc_experiments if None was passed
+    _exp_dir = exp_dir
+    if exp_dir is None:
+        _exp_dir = str(Path.cwd() / "mridc_experiments")
+
+    # If the user has already defined a logger for the trainer, use the logger defaults for logging directory
+    if trainer.logger is not None:
+        if trainer.logger.save_dir:
+            if exp_dir:
+                raise LoggerMisconfigurationError(
+                    "The pytorch lightning trainer that was passed to exp_manager contained a logger, the logger's "
+                    f"save_dir was not None, and exp_dir ({exp_dir}) was not None. If trainer.logger.save_dir "
+                    "exists, exp_manager will use trainer.logger.save_dir as the logging directory and exp_dir "
+                    "must be None."
+                )
+            _exp_dir = trainer.logger.save_dir
+        if name:
+            raise LoggerMisconfigurationError(
+                "The pytorch lightning trainer that was passed to exp_manager contained a logger, and name: "
+                f"{name} was also passed to exp_manager. If the trainer contains a "
+                "logger, exp_manager will use trainer.logger.name, and name passed to exp_manager must be None."
+            )
+        name = trainer.logger.name
+        version = f"version_{trainer.logger.version}"
+    # Use user-defined exp_dir, project_name, exp_name, and versioning options
+    else:
+        name = name or "default"
+        version = version or os.environ.get(MRIDC_ENV_VARNAME_VERSION)
+
+        if not version:
+            if resume_if_exists:
+                logging.warning(
+                    "No version folders would be created under the log folder as 'resume_if_exists' is enabled."
+                )
+                version = None
+            elif is_global_rank_zero():
+                if use_datetime_version:
+                    version = time.strftime("%Y-%m-%d_%H-%M-%S")
+                else:
+                    tensorboard_logger = TensorBoardLogger(save_dir=_exp_dir, name=name, version=version)
+                    version = f"version_{tensorboard_logger.version}"
+            os.environ[MRIDC_ENV_VARNAME_VERSION] = "" if version is None else version
+
+    log_dir = Path(str(_exp_dir)) / Path(str(name)) / Path("" if version is None else str(version))
+    return log_dir, str(_exp_dir), str(name), str(version)
+ + +
[docs]def get_git_hash(): + """ + Helper function that tries to get the commit hash if running inside a git folder. + + Returns + ------- + Bool: Whether the git subprocess ran without error. + String: git subprocess output or error message + """ + try: + return True, subprocess.check_output(["git", "rev-parse", "HEAD"], stderr=subprocess.STDOUT).decode() + except subprocess.CalledProcessError as err: + return False, "{}\n".format(err.output.decode("utf-8"))
+ + +
[docs]def get_git_diff(): + """ + Helper function that tries to get the git diff if running inside a git folder. + + Returns + ------- + Bool: Whether the git subprocess ran without error. + String: git subprocess output or error message + """ + try: + return subprocess.check_output(["git", "diff"], stderr=subprocess.STDOUT).decode() + except subprocess.CalledProcessError as err: + return "{}\n".format(err.output.decode("utf-8"))
+ + +
[docs]class LoggerList(_LoggerCollection): + """A thin wrapper on Lightning's LoggerCollection such that name and version are better aligned with exp_manager""" + + def __init__(self, _logger_iterable, mridc_name=None, mridc_version=""): + super().__init__(_logger_iterable) + self._mridc_name = mridc_name + self._mridc_version = mridc_version + + @property + def name(self) -> str: + """The name of the experiment.""" + return self._mridc_name + + @property + def version(self) -> str: + """The version of the experiment. If the logger was created with a version, this will be the version.""" + return self._mridc_version
+ + +
[docs]def configure_loggers(
+    trainer: Trainer,
+    exp_dir: List[Union[Path, str]],
+    name: str,
+    version: str,
+    create_tensorboard_logger: bool,
+    summary_writer_kwargs: dict,
+    create_wandb_logger: bool,
+    wandb_kwargs: dict,
+):
+    """
+    Creates TensorboardLogger and/or WandBLogger and attaches them to the trainer. Raises ValueError if
+    summary_writer_kwargs or wandb_kwargs are misconfigured.
+
+    Parameters
+    ----------
+    trainer: The trainer to attach the loggers to.
+    exp_dir: The experiment directory.
+    name: The name of the experiment.
+    version: The version of the experiment.
+    create_tensorboard_logger: Whether to create a TensorboardLogger.
+    summary_writer_kwargs: The kwargs to pass to the TensorboardLogger.
+    create_wandb_logger: Whether to create a Weights & Biases logger.
+    wandb_kwargs: The kwargs to pass to the Weights & Biases logger.
+
+    Returns
+    -------
+    LoggerList: A list of loggers.
+    """
+    # Potentially create tensorboard logger and/or WandBLogger
+    logger_list = []
+    if create_tensorboard_logger:
+        if summary_writer_kwargs is None:
+            summary_writer_kwargs = {}
+        elif "log_dir" in summary_writer_kwargs:
+            raise ValueError(
+                "You cannot pass `log_dir` as part of `summary_writer_kwargs`. `log_dir` is handled by lightning's "
+                "TensorBoardLogger logger."
+            )
+        tensorboard_logger = TensorBoardLogger(
+            save_dir=exp_dir[0], name=name, version=version, **summary_writer_kwargs
+        )
+        logger_list.append(tensorboard_logger)
+        logging.info("TensorboardLogger has been set up")
+
+    if create_wandb_logger:
+        if wandb_kwargs is None:
+            wandb_kwargs = {}
+        if "name" not in wandb_kwargs and "project" not in wandb_kwargs:
+            raise ValueError("name and project are required for wandb_logger")
+        wandb_logger = WandbLogger(save_dir=exp_dir[0], version=version, **wandb_kwargs)
+
+        logger_list.append(wandb_logger)
+        logging.info("WandBLogger has been set up")
+
+    logger_list = (
+        LoggerList(logger_list, mridc_name=name, mridc_version=version) if len(logger_list) > 1 else logger_list[0]
+    )
+    trainer._logger_connector.configure_logger(logger_list)
+ + +
[docs]class MRIDCModelCheckpoint(ModelCheckpoint): + """Light wrapper around Lightning's ModelCheckpoint to force a saved checkpoint on train_end""" + + def __init__( + self, + always_save_mridc=False, + save_mridc_on_train_end=True, + save_best_model=False, + postfix=".mridc", + n_resume=False, + model_parallel_size=None, + **kwargs, + ): + """ + + Parameters + ---------- + always_save_mridc: Whether to save the model even if it is not the best model. Default: False. + save_mridc_on_train_end: Whether to save the model at the end of training. Default: True. + save_best_model: Whether to save the model if it is the best model. Default: False. + postfix: The postfix to add to the model name. Default: ".mridc". + n_resume: Whether to resume training from a checkpoint. Default: False. + model_parallel_size: The size of the model parallel group. Default: None. + kwargs: The kwargs to pass to ModelCheckpoint. + """ + # Parse and store "extended" parameters: save_best model and postfix. + self.always_save_mridc = always_save_mridc + self.save_mridc_on_train_end = save_mridc_on_train_end + self.save_best_model = save_best_model + self.previous_model_path = None + self.last_model_path: Union[Any, str] = None + if self.save_best_model and not self.save_mridc_on_train_end: + logging.warning( + ( + "Found save_best_model is True and save_mridc_on_train_end is False. " + "Set save_mridc_on_train_end to True to automatically save the best model." + ) + ) + self.postfix = postfix + self.previous_best_path = "" + self.model_parallel_size = model_parallel_size + + # `prefix` is deprecated + self.prefix = kwargs.pop("prefix") if "prefix" in kwargs else "" + # Call the parent class constructor with the remaining kwargs. + super().__init__(**kwargs) + + if self.save_top_k != -1 and n_resume: + logging.debug("Checking previous runs") + self.mridc_topk_check_previous_run() + +
[docs] def mridc_topk_check_previous_run(self):
+        """Check if there are previous runs with the same topk value."""
+        self.best_k_models = {}
+        self.kth_best_model_path = ""
+        self.best_model_score = None
+        self.best_model_path = ""
+
+        checkpoints = list(Path(self.dirpath).rglob("*.ckpt"))
+        for checkpoint in checkpoints:
+            if "mp_rank" in str(checkpoint) or "tp_rank" in str(checkpoint):
+                checkpoint = mridc.utils.model_utils.uninject_model_parallel_rank(checkpoint)
+            checkpoint = str(checkpoint)
+            if checkpoint.endswith("-last.ckpt"):
+                continue
+            index = checkpoint.find(self.monitor) + len(self.monitor) + 1  # Find monitor in str + 1 for '='
+            if index != -1:
+                if match := re.search("[A-z]", checkpoint[index:]):
+                    value = checkpoint[index : index + match.start() - 1]  # -1 due to separator hyphen
+                    self.best_k_models[checkpoint] = float(value)
+        if not self.best_k_models:
+            return  # No saved checkpoints yet
+
+        _reverse = self.mode != "min"
+
+        best_k_models = sorted(self.best_k_models, key=self.best_k_models.get, reverse=_reverse)
+
+        # This section should be ok as rank zero will delete all excess checkpoints, since all other ranks are
+        # instantiated after rank zero. models_to_delete should be 0 for all other ranks.
+        if self.model_parallel_size is not None:
+            models_to_delete = len(best_k_models) - self.model_parallel_size * self.save_top_k
+        else:
+            models_to_delete = len(best_k_models) - self.save_top_k
+        logging.debug(f"Number of models to delete: {models_to_delete}")
+        for _ in range(models_to_delete):
+            model = best_k_models.pop(-1)
+            self.best_k_models.pop(model)
+            self._del_model_without_trainer(model)
+            logging.debug(f"Removed checkpoint: {model}")
+
+        self.kth_best_model_path = best_k_models[-1]
+        self.best_model_path = best_k_models[0]
+        self.best_model_score = self.best_k_models[self.best_model_path]
+ +
[docs] def on_save_checkpoint(self, trainer, pl_module, checkpoint): + """ + Override the default on_save_checkpoint to save the best model if needed. + + Parameters + ---------- + trainer: The trainer object. + pl_module: The PyTorch-Lightning module. + checkpoint: The checkpoint object. + """ + output = super().on_save_checkpoint(trainer, pl_module, checkpoint) + if not self.always_save_mridc: + return output + # Load the best model and then re-save it + app_state = AppState() + + if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1: + raise ValueError("always_save_mridc is not implemented for model parallel models.") + + # since we are creating tarfile artifacts we need to update .mridc path + app_state.model_restore_path = os.path.abspath( + os.path.expanduser(os.path.join(self.dirpath, self.prefix + self.postfix)) + ) + + if self.save_best_model: + if not os.path.exists(self.best_model_path): + return output + + if self.best_model_path == self.previous_best_path: + return output + + self.previous_model_path = self.best_model_path + old_state_dict = deepcopy(pl_module.state_dict()) + checkpoint = torch.load(self.best_model_path, map_location="cpu") + if "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + # get a new instance of the model + pl_module.load_state_dict(checkpoint, strict=True) + pl_module.save_to(save_path=app_state.model_restore_path) + pl_module.load_state_dict(old_state_dict, strict=True) + else: + pl_module.save_to(save_path=app_state.model_restore_path) + return output
+ +
[docs] def on_train_end(self, trainer, pl_module): + """ + This is called at the end of training. + + Parameters + ---------- + trainer: The trainer object. + pl_module: The PyTorch-Lightning module. + """ + if trainer.fast_dev_run: + return None + + # Call parent on_train_end() to save the -last checkpoint + super().on_train_end(trainer, pl_module) + + # Load the best model and then re-save it + if self.save_best_model: + # wait for all processes to finish + trainer.training_type_plugin.barrier("SaveBestCheckpointConnector.resume_end") + if self.best_model_path == "": + logging.warning( + f"{self} was told to save the best checkpoint at the end of training, but no saved checkpoints " + "were found. Saving latest model instead." + ) + else: + trainer._checkpoint_connector.restore(self.best_model_path) + + if self.save_mridc_on_train_end: + pl_module.save_to(save_path=os.path.join(self.dirpath, self.prefix + self.postfix))
+ + def _del_model_without_trainer(self, filepath: str) -> None: + """ + Delete a model without a trainer. + + Parameters + ---------- + filepath: The path to the model to delete. + """ + app_state = AppState() + if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1: + # filepath needs to be updated to include mp_rank + filepath = mridc.utils.model_utils.inject_model_parallel_rank(filepath) # type: ignore + + # each model parallel rank needs to remove its model + if is_global_rank_zero() or (app_state.model_parallel_size is not None and app_state.data_parallel_rank == 0): + try: + self._fs.rm(filepath) + logging.info(f"Removed checkpoint: {filepath}") + except FileNotFoundError: + logging.info(f"Tried to remove checkpoint: {filepath} but failed.")
+ + +
[docs]def configure_checkpointing(trainer: Trainer, log_dir: Path, name: str, resume: bool, params: "DictConfig"):
+    """Adds ModelCheckpoint to trainer. Raises CheckpointMisconfigurationError if trainer already has a ModelCheckpoint
+    callback or if trainer.weights_save_path was passed to Trainer.
+    """
+    for callback in trainer.callbacks:
+        if isinstance(callback, ModelCheckpoint):
+            raise CheckpointMisconfigurationError(
+                "The pytorch lightning trainer that was passed to exp_manager contained a ModelCheckpoint "
+                "and create_checkpoint_callback was set to True. Please either set create_checkpoint_callback "
+                "to False, or remove ModelCheckpoint from the lightning trainer"
+            )
+    if Path(trainer.weights_save_path) != Path.cwd():
+        raise CheckpointMisconfigurationError(
+            "The pytorch lightning trainer was passed weights_save_path. This variable is ignored by exp_manager"
+        )
+
+    # Create the callback and attach it to trainer
+    if "filepath" in params:
+        if params.filepath is not None:
+            logging.warning("filepath is deprecated. Please switch to dirpath and filename instead")
+            if params.dirpath is None:
+                params.dirpath = Path(params.filepath).parent
+            if params.filename is None:
+                params.filename = Path(params.filepath).name
+        with open_dict(params):
+            del params["filepath"]
+    if params.dirpath is None:
+        params.dirpath = Path(log_dir / "checkpoints")
+    if params.filename is None:
+        params.filename = f"{name}--{{{params.monitor}:.4f}}-{{epoch}}"
+    if params.prefix is None:
+        params.prefix = name
+    MRIDCModelCheckpoint.CHECKPOINT_NAME_LAST = f"{params.filename}-last"
+
+    logging.debug(params.dirpath)
+    logging.debug(params.filename)
+    logging.debug(params.prefix)
+
+    if "val" in params.monitor:
+        if (
+            trainer.max_epochs is not None
+            and trainer.max_epochs != -1
+            and trainer.max_epochs < trainer.check_val_every_n_epoch
+        ):
+            logging.error(
+                "The checkpoint callback was told to monitor a validation value but trainer.max_epochs("
+                f"{trainer.max_epochs}) was less than trainer.check_val_every_n_epoch("
+                f"{trainer.check_val_every_n_epoch}). It is very likely this run will fail with "
+                f"ModelCheckpoint(monitor='{params.monitor}') not found in the returned metrics. Please ensure that "
+                f"validation is run within trainer.max_epochs."
+            )
+        elif trainer.max_steps is not None:
+            logging.warning(
+                "The checkpoint callback was told to monitor a validation value and trainer's max_steps was set to "
+                f"{trainer.max_steps}. Please ensure that max_steps will run for at least "
+                f"{trainer.check_val_every_n_epoch} epochs to ensure that checkpointing will not error out."
+            )
+
+    checkpoint_callback = MRIDCModelCheckpoint(n_resume=resume, **params)
+    checkpoint_callback.last_model_path = trainer._checkpoint_connector.resume_from_checkpoint_fit_path or ""
+    if "mp_rank" in checkpoint_callback.last_model_path or "tp_rank" in checkpoint_callback.last_model_path:
+        checkpoint_callback.last_model_path = mridc.utils.model_utils.uninject_model_parallel_rank(  # type: ignore
+            checkpoint_callback.last_model_path
+        )
+    trainer.callbacks.append(checkpoint_callback)
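To make the default filename template above concrete, a small sketch (assuming name="default" and monitor="val_loss"; the rendered template is what PyTorch Lightning later fills in with metric values):

    # Sketch: how the default filename template expands (hypothetical values).
    name, monitor = "default", "val_loss"
    template = f"{name}--{{{monitor}:.4f}}-{{epoch}}"
    print(template)  # default--{val_loss:.4f}-{epoch}
    # Lightning formats it at save time, e.g. with {"val_loss": 0.1234, "epoch": 3}:
    #   default--val_loss=0.1234-epoch=3.ckpt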
+ + +
[docs]def check_slurm(trainer):
+    """
+    Checks whether the trainer is running on a SLURM cluster with SLURM managing the tasks.
+
+    Parameters
+    ----------
+    trainer: The trainer to check.
+
+    Returns
+    -------
+    True if SLURM is managing the tasks for this trainer, False otherwise.
+    """
+    try:
+        return trainer.accelerator_connector.is_slurm_managing_tasks
+    except AttributeError:
+        return False
+ + +
[docs]class StatelessTimer(Timer): + """Extension of PTL timers to be per run.""" + +
[docs] def state_dict(self) -> Dict[str, Any]: # type: ignore + """Saves the state of the timer.""" + return {}
+ +
[docs]    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        """Loads the state of the timer. The timer is stateless, so restoring is intentionally a no-op."""
+
diff --git a/docs/build/html/_modules/mridc/utils/export_utils.html b/docs/build/html/_modules/mridc/utils/export_utils.html
new file mode 100644
index 00000000..0f14bd9f
--- /dev/null
+++ b/docs/build/html/_modules/mridc/utils/export_utils.html
@@ -0,0 +1,363 @@
+ mridc.utils.export_utils — mridc v.0.0.1 documentation

Source code for mridc.utils.export_utils

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/utils/export_utils.py
+import os
+from enum import Enum
+from typing import Callable, Dict, Optional, Type
+
+import onnx
+import torch
+import torch.nn as nn
+
+from mridc.utils import logging
+
+try:
+    import onnxruntime
+
+    ort_available = True
+except ImportError:
+    ort_available = False
+
+
+
[docs]class ExportFormat(Enum): + """Which format to use when exporting a Neural Module for deployment""" + + ONNX = (1,) + TORCHSCRIPT = (2,)
+ + +_EXT_DICT = {".pt": ExportFormat.TORCHSCRIPT, ".ts": ExportFormat.TORCHSCRIPT, ".onnx": ExportFormat.ONNX} + + +
[docs]class CastToFloat(nn.Module): + """Cast input to float""" + + def __init__(self, mod): + super().__init__() + self.mod = mod + +
[docs]    def forward(self, x):
+        """Forward pass: run the wrapped module in float32 and cast the result back to the input dtype."""
+        return self.mod.forward(x.to(torch.float)).to(x.dtype) if torch.is_autocast_enabled() else self.mod.forward(x)
+ + +
[docs]def get_export_format(filename: str): + """Get export format from filename""" + _, ext = os.path.splitext(filename) + try: + return _EXT_DICT[ext] + except KeyError as e: + raise ValueError(f"Export file {filename} extension does not correspond to any export format!") from e
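A usage sketch for the extension lookup above (assuming the `_EXT_DICT` and `get_export_format` definitions on this page):

    print(get_export_format("model.onnx"))  # ExportFormat.ONNX
    print(get_export_format("model.ts"))    # ExportFormat.TORCHSCRIPT
    try:
        get_export_format("model.bin")      # unknown extension
    except ValueError as e:
        print(e)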
+ + +
[docs]def augment_filename(output: str, prepend: str): + """Augment output filename with prepend""" + path, filename = os.path.split(output) + filename = f"{prepend}-{filename}" + return os.path.join(path, filename)
+ + +
[docs]def forward_method(self): + """Forward method for export""" + if hasattr(self, "forward_for_export"): + return self.forward_for_export + return self.forward
+ + +
[docs]def wrap_forward_method(self):
+    """Swaps in `forward_for_export` (when defined) as the type's forward method, returning both the export forward
+    and the original forward so that the caller can restore it after export."""
+    tp = type(self)
+    old_forward_method = None
+    if hasattr(tp, "forward_for_export"):
+        forward_method = tp.forward_for_export
+        old_forward_method = tp.forward
+        tp.forward = forward_method
+    else:
+        forward_method = None
+    return forward_method, old_forward_method
+ + +
[docs]def parse_input_example(input_example): + """Parse input example to onnxrt input format""" + input_list = list(input_example) + input_dict = {} + # process possible kwargs + if isinstance(input_list[-1], dict): + input_dict = input_list[-1] + input_list = input_list[:-1] + return input_list, input_dict
+ + +
[docs]def to_onnxrt_input(input_names, input_dict, input_list): + """Transforms input to onnxrt input format""" + return { + k: input_dict[k].cpu().numpy() if k in input_dict else input_list.pop().cpu().numpy() + for k in reversed(input_names) + }
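A toy sketch of how the two helpers above cooperate (hypothetical tensors and input names): a trailing dict in the input example becomes kwargs, and the remaining positional tensors are matched to input names in reverse order:

    import torch

    input_example = (torch.zeros(2, 3), {"mask": torch.ones(2, 3)})
    input_list, input_dict = parse_input_example(input_example)
    feed = to_onnxrt_input(["x", "mask"], input_dict, input_list)
    print(sorted(feed))  # ['mask', 'x'] -- numpy arrays ready for an onnxruntime session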
+ + +
[docs]def verify_runtime( + output, + input_list, + input_dict, + input_names, + output_names, + output_example, + check_tolerance=0.01, +): + """ + Verify runtime output with onnxrt. + + Parameters + ---------- + output: The output of the module. + input_list: The input list of the module. + input_dict: The input dict of the module. + input_names: The input names of the module. + output_names: The output names of the module. + output_example: The output example of the module. + check_tolerance: The tolerance for the check. + + Returns + ------- + The runtime output. + """ + # Verify the model can be read, and is valid + onnx_model = onnx.load(output) + input_names = [node.name for node in onnx_model.graph.input] + # skipcq: PYL-W0622 + global ort_available + if not ort_available: + logging.warning(f"ONNX generated at {output}, not verified - please install onnxruntime_gpu package.\n") + onnx.checker.check_model(onnx_model, full_check=True) + return + + onnx_session_opt = onnxruntime.SessionOptions() + onnx_session_opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + + sess = onnxruntime.InferenceSession( + onnx_model.SerializeToString(), sess_options=onnx_session_opt, providers=["CUDAExecutionProvider"] + ) + ort_out = sess.run(output_names, to_onnxrt_input(input_names, input_dict, input_list)) + all_good = True + + for i, out in enumerate(ort_out[0]): + expected = output_example[i] + if torch.is_tensor(expected): + tout = torch.from_numpy(out) + if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance): + all_good = False + logging.info(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}") + status = "SUCCESS" if all_good else "FAIL" + logging.info(f"ONNX generated at {output} verified with onnxruntime : {status}") + return all_good
+ + +
[docs]def simple_replace(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]: + """ + Generic function generator to replace BaseT module with DestT. BaseT and DestT should have same attributes. + No weights are copied. + + Parameters + ---------- + BaseT: The base type of the module. + DestT: The destination type of the module. + + Returns + ------- + A function to replace BaseT with DestT. + """ + + def expansion_fn(mod: nn.Module) -> Optional[nn.Module]: + """Swap function to replace BaseT module with DestT""" + if not isinstance(mod, BaseT): + return None + args = [getattr(mod, name, None) for name in mod.__constants__] + return DestT(*args) + + return expansion_fn
+ + +
[docs]def wrap_module(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]:
+    """
+    Generic function generator that wraps a module in DestT; the original module is passed to the DestT constructor
+    (see CastToFloat above). No weights are copied.
+
+    Parameters
+    ----------
+    BaseT: The base type of the module.
+    DestT: The destination (wrapper) type of the module.
+
+    Returns
+    -------
+    A function to wrap a module in DestT.
+    """
+
+    def expansion_fn(mod: nn.Module) -> Optional[nn.Module]:
+        """Expansion function to wrap a module in DestT"""
+        return DestT(mod)
+
+    return expansion_fn
+ + +
[docs]def swap_modules(model: nn.Module, mapping: Dict[str, nn.Module]):
+    """
+    This function swaps nested modules as specified by "dot paths" in mod with a desired replacement. This allows
+    for swapping nested modules through arbitrary levels of children.
+    NOTE: This occurs in place; if you want to preserve the model, make sure to copy it first.
+    """
+    for path, new_mod in mapping.items():
+        expanded_path = path.split(".")
+        parent_mod = model
+        for sub_path in expanded_path[:-1]:
+            parent_mod = parent_mod._modules[sub_path]  # noqa
+        parent_mod._modules[expanded_path[-1]] = new_mod  # noqa
+
+    return model
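A toy usage sketch for `swap_modules` (hypothetical model): the dot path addresses a child nested one level down, and the swap happens in place:

    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.ReLU()))
    swap_modules(model, {"1.0": nn.Tanh()})
    print(model[1][0])  # Tanh()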
+ + +
[docs]def replace_modules( + model: nn.Module, expansions: Dict[str, Callable[[nn.Module], Optional[nn.Module]]] = None +) -> nn.Module: + """ + Top-level function to replace modules in model, specified by class name with a desired replacement. + NOTE: This occurs in place, if you want to preserve model then make sure to copy it first. + + Parameters + ---------- + model: Top-level model to replace modules in. + expansions: A dictionary of module class names to functions to replace them with. + + Returns + ------- + The model with replaced modules. + """ + mapping: Dict[str, nn.Module] = {} + for name, m in model.named_modules(): + m_type = type(m).__name__ + if m_type in expansions: # type: ignore + if swapped := expansions[m_type](m): # type: ignore + mapping[name] = swapped + logging.warning(f"Swapped {len(mapping)} modules") + swap_modules(model, mapping) + return model
+ + +default_replacements = { + "BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat), + "BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat), + "LayerNorm": wrap_module(nn.LayerNorm, CastToFloat), +} + + +
[docs]def replace_for_export(model: nn.Module) -> nn.Module:
+    """
+    Top-level function to replace the default set of modules in model.
+    NOTE: This occurs in place; if you want to preserve the model, make sure to copy it first.
+
+    Parameters
+    ----------
+    model: Top-level model to replace modules in.
+
+    Returns
+    -------
+    The model with replaced modules.
+    """
+    return replace_modules(model, default_replacements)
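A usage sketch (hypothetical model): with `default_replacements` above, norm layers get wrapped in `CastToFloat` so they run in float32 under autocast during export:

    import torch.nn as nn

    model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
    replace_for_export(model)
    print(type(model[1]).__name__)  # CastToFloat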
+
diff --git a/docs/build/html/_modules/mridc/utils/formaters/base.html b/docs/build/html/_modules/mridc/utils/formaters/base.html
new file mode 100644
index 00000000..d024aca7
--- /dev/null
+++ b/docs/build/html/_modules/mridc/utils/formaters/base.html
@@ -0,0 +1,238 @@
+ mridc.utils.formaters.base — mridc v.0.0.1 documentation

Source code for mridc.utils.formaters.base

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/utils/formatters/base.py
+
+import logging
+
+from mridc.utils.formaters.colors import Fore as ForegroundColors
+from mridc.utils.formaters.utils import check_color_support, to_unicode
+
+__all__ = ["BaseMRIDCFormatter", "DebugMRIDCFormatter"]
+
+
+class BaseFormatter(logging.Formatter):
+    """
+    Base class for all formatters, adapted from Tornado. Key features of this formatter are:
+        * Color support when logging to a terminal that supports it.
+        * Timestamps on every log line.
+        * Robust against str/bytes encoding problems.
+    """
+
+    DEFAULT_FORMAT = "%(color)s[%(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s"
+
+    DEFAULT_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
+
+    DEFAULT_COLORS = {
+        logging.DEBUG: ForegroundColors.CYAN,
+        logging.INFO: ForegroundColors.GREEN,
+        logging.WARNING: ForegroundColors.YELLOW,
+        logging.ERROR: ForegroundColors.MAGENTA,
+        logging.CRITICAL: ForegroundColors.RED,
+    }
+
+    def __init__(self, color=True, fmt=None, datefmt=None, colors=None):
+        """
+
+        Parameters
+        ----------
+        color: Enable color support.
+            bool, default: True
+        fmt: Log message format. It will be applied to the attributes dict of log records. The text between
+        ``%(color)s`` and ``%(end_color)s`` will be colored depending on the level if color support is on.
+            str, default: None
+        datefmt: Datetime format. Used for formatting the ``%(asctime)s`` placeholder in ``fmt``.
+            str, default: None
+        colors: Dictionary mapping logging level to terminal color code.
+            dict, default: None
+        """
+        if fmt is None:
+            fmt = self.DEFAULT_FORMAT
+
+        if datefmt is None:
+            datefmt = self.DEFAULT_DATE_FORMAT
+
+        if colors is None:
+            colors = self.DEFAULT_COLORS
+
+        logging.Formatter.__init__(self, datefmt=datefmt)
+
+        self._fmt = fmt
+        self._colors = {}
+        self._normal = ""
+
+        if color and check_color_support():
+            self._colors = colors
+            self._normal = ForegroundColors.RESET
+
+    def format(self, record):
+        """
+        Formats a record.
+
+        Parameters
+        ----------
+        record: Log record to be formatted.
+            LogRecord
+
+        Returns
+        -------
+        The formatted record as a string.
+            str
+        """
+        try:
+            message = record.getMessage()
+            if not isinstance(message, str):
+                raise AssertionError
+            # Encoding notes:  The logging module prefers to work with character
+            # strings, but only enforces that log messages are instances of
+            # basestring.  In python 2, non-ascii bytestrings will make
+            # their way through the logging framework until they blow up with
+            # an unhelpful decoding error (with this formatter it happens
+            # when we attach the prefix, but there are other opportunities for
+            # exceptions further along in the framework).
+            #
+            # If a byte string makes it this far, convert it to unicode to
+            # ensure it will make it out to the logs.  Use repr() as a fallback
+            # to ensure that all byte strings can be converted successfully,
+            # but don't do it by default so we don't add extra quotes to ascii
+            # bytestrings.  This is a bit of a hacky place to do this, but
+            # it's worth it since the encoding errors that would otherwise
+            # result are so useless (and tornado is fond of using utf8-encoded
+            # byte strings wherever possible).
+            record.message = to_unicode(message)
+
+        except Exception as e:
+            record.message = "Bad message (%r): %r" % (e, record.__dict__)
+
+        record.asctime = self.formatTime(record, self.datefmt)
+
+        if record.levelno in self._colors:
+            record.color = self._colors[record.levelno]
+            record.end_color = self._normal
+        else:
+            record.color = record.end_color = ""
+
+        formatted = self._fmt % record.__dict__
+
+        if record.exc_info and not record.exc_text:
+            record.exc_text = self.formatException(record.exc_info)
+
+        if record.exc_text:
+            # exc_text contains multiple lines.  We need to _safe_unicode
+            # each line separately so that non-utf8 bytes don't cause
+            # all the newlines to turn into '\n'.
+            lines = [formatted.rstrip()]
+            lines.extend(to_unicode(ln) for ln in record.exc_text.split("\n"))
+
+            formatted = "\n".join(lines)
+        return formatted.replace("\n", "\n    ")
+
+
+
[docs]class BaseMRIDCFormatter(BaseFormatter): + """Base formatter for MRIDC logs.""" + + DEFAULT_FORMAT = "%(color)s[MRIDC %(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s"
+ + +
[docs]class DebugMRIDCFormatter(BaseFormatter): + """Debug formatter for MRIDC logs.""" + + DEFAULT_FORMAT = ( + "%(color)s[MRIDC %(levelname)1.1s %(asctime)s %(module)s:%(lineno)d rank:%(rank)s]%(end_color)s %(message)s" + )
+
diff --git a/docs/build/html/_modules/mridc/utils/formaters/colors.html b/docs/build/html/_modules/mridc/utils/formaters/colors.html
new file mode 100644
index 00000000..d083aece
--- /dev/null
+++ b/docs/build/html/_modules/mridc/utils/formaters/colors.html
@@ -0,0 +1,339 @@
+ mridc.utils.formaters.colors — mridc v.0.0.1 documentation

Source code for mridc.utils.formaters.colors

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/utils/formatters/colors.py
+
+CSI = "\033["
+OSC = "\033]"
+BEL = "\007"
+
+
+
[docs]def code_to_chars(code): + """ + Convert ANSI color code to string of characters. + + Parameters + ---------- + code: ANSI color code. + int + + Returns + ------- + String of characters. + str + """ + return CSI + str(code) + "m"
+ + +
[docs]def set_title(title): + """ + Set terminal title. + + Parameters + ---------- + title: Title. + str + + Returns + ------- + String of characters. + str + """ + return f"{OSC}2;{title}{BEL}"
+ + +
[docs]def clear_screen(mode=2): + """ + Clear terminal screen. + + Parameters + ---------- + mode: Mode. + int + + Returns + ------- + String of characters. + str + """ + return CSI + str(mode) + "J"
+ + +
[docs]def clear_line(mode=2): + """ + Clear terminal line. + + Parameters + ---------- + mode: Mode. + int + + Returns + ------- + String of characters. + str + """ + return CSI + str(mode) + "K"
+ + +
[docs]class AnsiCodes: + """ANSI color codes.""" + + def __init__(self): + # the subclasses declare class attributes which are numbers. + # Upon instantiation we define instance attributes, which are the same + # as the class attributes but wrapped with the ANSI escape sequence + for name in dir(self): + if not name.startswith("_"): + value = getattr(self, name) + setattr(self, name, code_to_chars(value))
+ + +
[docs]class AnsiCursor: + """ANSI cursor codes.""" + +
[docs] @staticmethod + def UP(n=1): + """ + Move the cursor up n lines. + + Parameters + ---------- + n: Number of lines. + int + + Returns + ------- + String of characters. + str + """ + return CSI + str(n) + "A"
+ +
[docs] @staticmethod + def DOWN(n=1): + """ + Move the cursor down n lines. + + Parameters + ---------- + n: Number of lines. + int + + Returns + ------- + String of characters. + str + """ + return CSI + str(n) + "B"
+ +
[docs]    @staticmethod
+    def FORWARD(n=1):
+        """
+        Move the cursor forward n columns.
+
+        Parameters
+        ----------
+        n: Number of columns.
+            int
+
+        Returns
+        -------
+        String of characters.
+            str
+        """
+        return CSI + str(n) + "C"
+ +
[docs]    @staticmethod
+    def BACK(n=1):
+        """
+        Move the cursor back n columns.
+
+        Parameters
+        ----------
+        n: Number of columns.
+            int
+
+        Returns
+        -------
+        String of characters.
+            str
+        """
+        return CSI + str(n) + "D"
+ +
[docs] @staticmethod + def POS(x=1, y=1): + """ + Move the cursor to the specified position. + + Parameters + ---------- + x: X position. + int + y: Y position. + int + + Returns + ------- + String of characters. + str + """ + return CSI + str(y) + ";" + str(x) + "H"
+ + +
[docs]class AnsiFore(AnsiCodes): + """ANSI color codes for foreground text.""" + + BLACK = 30 + RED = 31 + GREEN = 32 + YELLOW = 33 + BLUE = 34 + MAGENTA = 35 + CYAN = 36 + WHITE = 37 + RESET = 39 + + # These are fairly well supported, but not part of the standard. + LIGHTBLACK_EX = 90 + LIGHTRED_EX = 91 + LIGHTGREEN_EX = 92 + LIGHTYELLOW_EX = 93 + LIGHTBLUE_EX = 94 + LIGHTMAGENTA_EX = 95 + LIGHTCYAN_EX = 96 + LIGHTWHITE_EX = 97
+ + +
[docs]class AnsiBack(AnsiCodes): + """ANSI color codes for background text.""" + + BLACK = 40 + RED = 41 + GREEN = 42 + YELLOW = 43 + BLUE = 44 + MAGENTA = 45 + CYAN = 46 + WHITE = 47 + RESET = 49 + + # These are fairly well supported, but not part of the standard. + LIGHTBLACK_EX = 100 + LIGHTRED_EX = 101 + LIGHTGREEN_EX = 102 + LIGHTYELLOW_EX = 103 + LIGHTBLUE_EX = 104 + LIGHTMAGENTA_EX = 105 + LIGHTCYAN_EX = 106 + LIGHTWHITE_EX = 107
+ + +
[docs]class AnsiStyle(AnsiCodes): + """ANSI color codes for text styles.""" + + BRIGHT = 1 + DIM = 2 + NORMAL = 22 + RESET_ALL = 0
+ + +Fore = AnsiFore() +Back = AnsiBack() +Style = AnsiStyle() +Cursor = AnsiCursor() +
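After instantiation, each attribute holds a full escape sequence rather than a bare number; an illustrative sketch:

    print(repr(Fore.GREEN))                # '\x1b[32m'
    print(Fore.GREEN + "ok" + Fore.RESET)  # "ok" rendered green on ANSI terminals
    print(repr(Cursor.POS(3, 5)))          # '\x1b[5;3H'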
diff --git a/docs/build/html/_modules/mridc/utils/formaters/utils.html b/docs/build/html/_modules/mridc/utils/formaters/utils.html
new file mode 100644
index 00000000..e6a3adea
--- /dev/null
+++ b/docs/build/html/_modules/mridc/utils/formaters/utils.html
@@ -0,0 +1,147 @@
+ mridc.utils.formaters.utils — mridc v.0.0.1 documentation

Source code for mridc.utils.formaters.utils

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/utils/formatters/utils.py
+
+import sys
+
+__all__ = ["check_color_support", "to_unicode"]
+
+from mridc.constants import MRIDC_ENV_VARNAME_ENABLE_COLORING
+from mridc.utils.env_var_parsing import get_envbool
+
+
+
[docs]def check_color_support():
+    """
+    Checks whether the attached terminal supports colored output.
+
+    Returns
+    -------
+    True if the terminal supports color, False otherwise.
+        bool
+    """
+    # Colors can be forced with an env variable
+    return bool(not sys.platform.lower().startswith("win") and get_envbool(MRIDC_ENV_VARNAME_ENABLE_COLORING, False))
+ + +
[docs]def to_unicode(value): + """ + Converts a string to unicode. If the string is already unicode, it is returned as is. If it is a byte string, it is + decoded using utf-8. + + Parameters + ---------- + value: The string to convert. + str + + Returns + ------- + The converted string. + str + """ + try: + if isinstance(value, (str, type(None))): + return value + + if not isinstance(value, bytes): + raise TypeError("Expected bytes, unicode, or None; got %r" % type(value)) + + return value.decode("utf-8") + + except UnicodeDecodeError: + return repr(value)
+
diff --git a/docs/build/html/_modules/mridc/utils/get_rank.html b/docs/build/html/_modules/mridc/utils/get_rank.html
new file mode 100644
index 00000000..45e8e00a
--- /dev/null
+++ b/docs/build/html/_modules/mridc/utils/get_rank.html
@@ -0,0 +1,120 @@
+ mridc.utils.get_rank — mridc v.0.0.1 documentation

Source code for mridc.utils.get_rank

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/utils/get_rank.py
+from mridc.utils.env_var_parsing import get_envint
+
+
+
[docs]def is_global_rank_zero():
+    """Helper function to determine if the current process is global_rank 0 (the main process)."""
+    # Try to get the PyTorch RANK env var. RANK is set by torch.distributed.launch.
+    rank = get_envint("RANK", None)
+    if rank is not None:
+        return rank == 0
+
+    # Try to get the SLURM global rank env var. SLURM_PROCID is set by SLURM.
+    slurm_rank = get_envint("SLURM_PROCID", None)
+    if slurm_rank is not None:
+        return slurm_rank == 0
+
+    # If neither the PyTorch nor the SLURM env vars are set, check the NODE_RANK/GROUP_RANK and LOCAL_RANK env vars;
+    # assume global_rank is zero if undefined.
+    node_rank = get_envint("NODE_RANK", get_envint("GROUP_RANK", 0))
+    local_rank = get_envint("LOCAL_RANK", 0)
+    return node_rank == 0 and local_rank == 0
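A small sketch of the precedence implemented above (RANK first, then SLURM_PROCID, then NODE_RANK/GROUP_RANK plus LOCAL_RANK), assuming the definitions on this page:

    import os

    os.environ.pop("RANK", None)      # no torch.distributed rank set
    os.environ["SLURM_PROCID"] = "0"  # SLURM reports global process id 0
    print(is_global_rank_zero())      # True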
+
diff --git a/docs/build/html/_modules/mridc/utils/lightning_logger_patch.html b/docs/build/html/_modules/mridc/utils/lightning_logger_patch.html
new file mode 100644
index 00000000..0433ef70
--- /dev/null
+++ b/docs/build/html/_modules/mridc/utils/lightning_logger_patch.html
@@ -0,0 +1,146 @@
+ mridc.utils.lightning_logger_patch — mridc v.0.0.1 documentation

Source code for mridc.utils.lightning_logger_patch

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/utils/lightning_logger_patch.py
+
+import logging as _logging
+from logging.handlers import MemoryHandler
+from typing import Any, Dict
+
+import pytorch_lightning as pl
+
+HANDLERS: Dict[Any, Any] = {}
+PATCHED = False
+
+
+
[docs]def add_memory_handlers_to_pl_logger(): + """ + Adds two MemoryHandlers to pytorch_lightning's logger. These two handlers are essentially message buffers. This + function is called in mridc.utils.__init__.py. These handlers are used in add_filehandlers_to_pl_logger to flush + buffered messages to files. + """ + if not HANDLERS: + HANDLERS["memory_err"] = MemoryHandler(-1) + HANDLERS["memory_err"].addFilter(lambda record: record.levelno > _logging.INFO) + HANDLERS["memory_all"] = MemoryHandler(-1) + pl._logger.addHandler(HANDLERS["memory_err"]) + pl._logger.addHandler(HANDLERS["memory_all"])
+ + +
[docs]def add_filehandlers_to_pl_logger(all_log_file, err_log_file): + """ + Adds two filehandlers to pytorch_lightning's logger. Called in mridc.utils.exp_manager(). The first filehandler + logs all messages to all_log_file while the second filehandler logs all WARNING and higher messages to + err_log_file. If "memory_err" and "memory_all" exist in HANDLERS, then those buffers are flushed to err_log_file + and all_log_file respectively, and then closed. + """ + HANDLERS["file"] = _logging.FileHandler(all_log_file) + pl._logger.addHandler(HANDLERS["file"]) + HANDLERS["file_err"] = _logging.FileHandler(err_log_file) + HANDLERS["file_err"].addFilter(lambda record: record.levelno > _logging.INFO) + pl._logger.addHandler(HANDLERS["file_err"]) + + if HANDLERS.get("memory_all"): + HANDLERS["memory_all"].setTarget(HANDLERS["file"]) + HANDLERS["memory_all"].close() + del HANDLERS["memory_all"] + if HANDLERS.get("memory_err"): + HANDLERS["memory_err"].setTarget(HANDLERS["file_err"]) + HANDLERS["memory_err"].close() + del HANDLERS["memory_err"]
+
diff --git a/docs/build/html/_modules/mridc/utils/metaclasses.html b/docs/build/html/_modules/mridc/utils/metaclasses.html
new file mode 100644
index 00000000..364d54f9
--- /dev/null
+++ b/docs/build/html/_modules/mridc/utils/metaclasses.html
@@ -0,0 +1,124 @@
+ mridc.utils.metaclasses — mridc v.0.0.1 documentation

Source code for mridc.utils.metaclasses

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/utils/metaclasses.py
+
+import threading
+from typing import Any, Dict
+
+
+
[docs]class Singleton(type):
+    """Implementation of a generic, thread-safe singleton meta-class. Can be used as a meta-class, i.e. it ensures
+    that only a single instance of each class that uses it is ever created."""
+
+    # List of instances - one per class.
+    __instances: Dict[Any, Any] = {}
+    # Lock used for accessing the instance.
+    __lock = threading.Lock()
[docs] def __call__(cls, *args, **kwargs): + """Returns singleton instance. A thread safe implementation.""" + if cls not in cls.__instances: + # Enter critical section. + with cls.__lock: + # Check once again. + if cls not in cls.__instances: + # Create a new object instance - one per class. + cls.__instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + # Return the instance. + return cls.__instances[cls]
+
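A usage sketch (hypothetical class; AppState in this codebase is created the same way): every instantiation returns the same shared object:

    class Config(metaclass=Singleton):
        def __init__(self):
            self.value = 0

    a, b = Config(), Config()
    a.value = 42
    print(b.value, a is b)  # 42 True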
diff --git a/docs/build/html/_modules/mridc/utils/model_utils.html b/docs/build/html/_modules/mridc/utils/model_utils.html
new file mode 100644
index 00000000..736d3b67
--- /dev/null
+++ b/docs/build/html/_modules/mridc/utils/model_utils.html
@@ -0,0 +1,731 @@
+ mridc.utils.model_utils — mridc v.0.0.1 documentation

Source code for mridc.utils.model_utils

+# encoding: utf-8
+import sys
+
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/utils/model_utils.py
+import copy
+import os
+from dataclasses import dataclass, is_dataclass
+from distutils.version import LooseVersion as Version  # the base Version class cannot parse version strings
+from enum import Enum
+from pathlib import Path
+from typing import Any, List, Optional, Set, Tuple, Union
+
+import wrapt
+from omegaconf import DictConfig, ListConfig, OmegaConf
+from omegaconf.errors import OmegaConfBaseException
+from pytorch_lightning import LightningModule
+
+import mridc
+from mridc.constants import MRIDC_ENV_CACHE_DIR
+from mridc.core.classes.common import PretrainedModelInfo
+from mridc.core.classes.modelPT import ModelPT
+from mridc.core.conf.modelPT import MRIDCConfig
+from mridc.utils.app_state import AppState
+from mridc.utils import logging
+
+_HAS_HYDRA = True
+
+_VAL_TEST_FASTPATH_KEY = "ds_item"
+
+__all__ = [
+    "ArtifactPathType",
+    "ArtifactItem",
+    "resolve_dataset_name_from_cfg",
+    "parse_dataset_as_name",
+    "unique_names_check",
+    "resolve_validation_dataloaders",
+    "wrap_training_step",
+    "convert_model_config_to_dict_config",
+    "_convert_config",
+    "maybe_update_config_version",
+    "import_class_by_path",
+    "resolve_subclass_pretrained_model_info",
+    "check_lib_version",
+    "resolve_cache_dir",
+    "inject_model_parallel_rank",
+    "uninject_model_parallel_rank",
+]
+
+
+
[docs]class ArtifactPathType(Enum): + """ + ArtifactPathType refers to the type of the path that the artifact is located at. + LOCAL_PATH: A user local filepath that exists on the file system. + TAR_PATH: A (generally flattened) filepath that exists inside of an archive (that may have its own full path). + """ + + LOCAL_PATH = 0 + TAR_PATH = 1
+ + +
[docs]@dataclass(init=False) +class ArtifactItem: + """ArtifactItem is a dataclass that holds the information of an artifact.""" + + path: str + path_type: ArtifactPathType + hashed_path: Optional[str] = None
+ + +
[docs]def resolve_dataset_name_from_cfg(cfg: "DictConfig") -> Union[Union[str, int, Enum, float, bool, None], Any]: + """ + Parses items of the provided sub-config to find the first potential key that resolves to an existing file or + directory. + + # Fast-path Resolution + In order to handle cases where we need to resolve items that are not paths, a fastpath key can be provided as + defined in the global `_VAL_TEST_FASTPATH_KEY`. + + This key can be used in two ways : + ## _VAL_TEST_FASTPATH_KEY points to another key in the config + If this _VAL_TEST_FASTPATH_KEY points to another key in this config itself, then we assume we want to loop through + the values of that key. This allows for any key in the config to become a fastpath key. + + Example + ------- + validation_ds: + + .. code-block:: + + splits: "val" + ... + <_VAL_TEST_FASTPATH_KEY>: "splits" <-- this points to the key name "splits" + + Then we can write the following when overriding in hydra: + ```python + python train_file.py ... model.validation_ds.splits=[val1, val2, dev1, dev2] ... + ``` + ## _VAL_TEST_FASTPATH_KEY itself acts as the resolved key + If this _VAL_TEST_FASTPATH_KEY does not point to another key in the config, then it is assumed that the items of + this key itself are used for resolution. + + Example + ------- + validation_ds: + + .. code-block:: + + <_VAL_TEST_FASTPATH_KEY>: "val" <-- this points to the key name "splits" + + Then we can write the following when overriding in hydra: + ```python + python train_file.py ... model.validation_ds.<_VAL_TEST_FASTPATH_KEY>=[val1, val2, dev1, dev2] ... + ``` + # IMPORTANT NOTE: + It <can> potentially mismatch if there exist more than 2 valid paths, and the first path does *not* resolve the + path of the data file (but does resolve to some other valid path). To avoid this side effect, place the data path + as the first item on the config file. + + Parameters + ---------- + cfg: Sub-config of the config file. + + Returns + ------- + A str representing the `key` of the config which hosts the filepath(s), or None in case path could not be resolved. + """ + if _VAL_TEST_FASTPATH_KEY in cfg and cfg[_VAL_TEST_FASTPATH_KEY] is not None: + fastpath_key = cfg[_VAL_TEST_FASTPATH_KEY] + + if isinstance(fastpath_key, str) and fastpath_key in cfg: + return cfg[fastpath_key] + return _VAL_TEST_FASTPATH_KEY + + for key, value in cfg.items(): + if type(value) in [list, tuple, ListConfig]: + # Count the number of valid paths in the list + values_are_paths = 0 + for val_i in value: + val_i = str(val_i) + + if os.path.exists(val_i) or os.path.isdir(val_i): + values_are_paths += 1 + else: + # reset counter and break inner loop + break + + if values_are_paths == len(value): + return key + + elif os.path.exists(str(value)) or os.path.isdir(str(value)): + return key + + return None
+ + +
[docs]def parse_dataset_as_name(name: str) -> str:
+    """
+    Constructs a valid prefix-name from a provided file path.
+
+    Parameters
+    ----------
+    name: Path to some valid data/manifest file, or a python object that will be used as a name for the data loader
+    (via str() cast).
+
+    Returns
+    -------
+    A valid prefix-name for the data loader.
+    """
+    name = Path(name).stem if os.path.exists(name) or os.path.isdir(name) else name
+    # cleanup name
+    name = name.replace("-", "_")
+
+    if "manifest" in name:
+        name = name.replace("manifest", "")
+
+    if "dataset" in name:
+        name = name.replace("dataset", "")
+
+    # Test if the manifest/dataset name was simply `manifest.json` or `dataset.json`: invalid names.
+    if name == "":
+        raise ValueError(
+            "Provided dataset / manifest filename was `manifest.json` or `dataset.json`.\n"
+            "Such a name is invalid, since multiple datasets/manifests can share the same name,\n"
+            "thereby overriding their results during logging. Please pick a more descriptive filename \n"
+            "for the provided dataset / manifest file."
+        )
+
+    if name[-1] != "_":
+        name = f"{name}_"
+
+    return name
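Illustrative inputs and outputs (hypothetical names, assuming the paths do not exist on disk):

    print(parse_dataset_as_name("some-name"))            # some_name_
    print(parse_dataset_as_name("val_manifest"))         # val_
    print(parse_dataset_as_name("brain-train-dataset"))  # brain_train_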
+ + +
[docs]def unique_names_check(name_list: Optional[List[str]]):
+    """
+    Performs a uniqueness check on the name list resolved, so that it can warn users about non-unique keys.
+
+    Parameters
+    ----------
+    name_list: List of strings resolved for data loaders.
+    """
+    if name_list is None:
+        return
+
+    # Name uniqueness checks
+    names = set()
+    for name in name_list:
+        if name in names:
+            logging.warning(
+                "Name resolution has found more than one data loader having the same name !\n"
+                "In such cases, logs will not be properly generated. "
+                "Please rename the item to have unique names.\n"
+                f"Resolved name : {name}"
+            )
+        else:
+            names.add(name)  # we need just hash key check, value is just a placeholder
+ + +
[docs]def resolve_validation_dataloaders(model: ModelPT): + """ + Helper method that operates on the ModelPT class to automatically support multiple dataloaders for the validation + set. It does so by first resolving the path to one/more data files via `resolve_dataset_name_from_cfg()`. + If this resolution fails, it assumes the data loader is prepared to manually support / not support multiple data + loaders and simply calls the appropriate setup method. + If resolution succeeds: + - Checks if provided path is to a single file or a list of files. + If a single file is provided, simply tags that file as such and loads it via the setup method. + If multiple files are provided: + - Inject a new manifest path at index "i" into the resolved key. + - Calls the appropriate setup method to set the data loader. + - Collects the initialized data loader in a list and preserves it. + - Once all data loaders are processed, assigns the list of loaded loaders to the ModelPT. + - Finally, assigns a list of unique names resolved from the file paths to the ModelPT. + + Parameters + ---------- + model: ModelPT subclass, which requires >=1 Validation Dataloaders to be setup. + """ + if not _HAS_HYDRA: + logging.error("This function requires Hydra/OmegaConf and it was not installed.") + sys.exit(1) + cfg = copy.deepcopy(model._cfg) + dataloaders: List[Any] = [] + + # process val_loss_idx + if "val_dl_idx" in cfg.validation_ds: + cfg = OmegaConf.to_container(cfg) + val_dl_idx = cfg["validation_ds"].pop("val_dl_idx") + cfg = OmegaConf.create(cfg) + else: + val_dl_idx = 0 + + # Set val_loss_idx + model._val_dl_idx = val_dl_idx + + ds_key = resolve_dataset_name_from_cfg(cfg.validation_ds) + + if ds_key is None: + logging.debug( + f"Could not resolve file path from provided config - {cfg.validation_ds}. " + "Disabling support for multi-dataloaders." + ) + + model.setup_validation_data(cfg.validation_ds) + return + + ds_values = cfg.validation_ds[ds_key] + + if isinstance(ds_values, (list, tuple, ListConfig)): + + for ds_value in ds_values: + cfg.validation_ds[ds_key] = ds_value + model.setup_validation_data(cfg.validation_ds) + dataloaders.append(model.validation_dl) + + model.validation_dl = dataloaders # type: ignore + model.validation_names = [parse_dataset_as_name(ds) for ds in ds_values] # type: ignore + + unique_names_check(name_list=model.validation_names) + return + model.setup_validation_data(cfg.validation_ds) + model.validation_names = [parse_dataset_as_name(ds_values)] + + unique_names_check(name_list=model.validation_names)
+
+
+def resolve_test_dataloaders(model: "ModelPT"):
+    """
+    Helper method that operates on the ModelPT class to automatically support
+    multiple dataloaders for the test set.
+    It does so by first resolving the path to one/more data files via `resolve_dataset_name_from_cfg()`.
+    If this resolution fails, it assumes the data loader is prepared to manually support / not support
+    multiple data loaders and simply calls the appropriate setup method.
+    If resolution succeeds:
+        Checks if provided path is to a single file or a list of files.
+        If a single file is provided, simply tags that file as such and loads it via the setup method.
+        If multiple files are provided:
+            Inject a new manifest path at index "i" into the resolved key.
+            Calls the appropriate setup method to set the data loader.
+            Collects the initialized data loader in a list and preserves it.
+        Once all data loaders are processed, assigns the list of loaded loaders to the ModelPT.
+        Finally, assigns a list of unique names resolved from the file paths to the ModelPT.
+
+    Parameters
+    ----------
+    model: ModelPT subclass, which requires >=1 Test Dataloaders to be setup.
+    """
+    if not _HAS_HYDRA:
+        logging.error("This function requires Hydra/OmegaConf and it was not installed.")
+        sys.exit(1)
+    cfg = copy.deepcopy(model._cfg)
+    dataloaders: List[Any] = []
+
+    # process test_dl_idx
+    if "test_dl_idx" in cfg.test_ds:
+        cfg = OmegaConf.to_container(cfg)
+        test_dl_idx = cfg["test_ds"].pop("test_dl_idx")
+        cfg = OmegaConf.create(cfg)
+    else:
+        test_dl_idx = 0
+
+    # Set test_dl_idx
+    model._test_dl_idx = test_dl_idx
+
+    ds_key = resolve_dataset_name_from_cfg(cfg.test_ds)
+
+    if ds_key is None:
+        logging.debug(
+            f"Could not resolve file path from provided config - {cfg.test_ds}. "
+            "Disabling support for multi-dataloaders."
+        )
+
+        model.setup_test_data(cfg.test_ds)
+        return
+
+    ds_values = cfg.test_ds[ds_key]
+
+    if isinstance(ds_values, (list, tuple, ListConfig)):
+
+        for ds_value in ds_values:
+            cfg.test_ds[ds_key] = ds_value
+            model.setup_test_data(cfg.test_ds)
+            dataloaders.append(model.test_dl)
+
+        model.test_dl = dataloaders  # type: ignore
+        model.test_names = [parse_dataset_as_name(ds) for ds in ds_values]  # type: ignore
+
+        unique_names_check(name_list=model.test_names)
+        return
+    model.setup_test_data(cfg.test_ds)
+    model.test_names = [parse_dataset_as_name(ds_values)]
+
+    unique_names_check(name_list=model.test_names)
+
+
[docs]@wrapt.decorator +def wrap_training_step(wrapped, instance: LightningModule, args, kwargs): + """ + Wraps the training step of the LightningModule. + + Parameters + ---------- + wrapped: The wrapped function. + instance: The LightningModule instance. + args: The arguments passed to the wrapped function. + kwargs: The keyword arguments passed to the wrapped function. + + Returns + ------- + The return value of the wrapped function. + """ + output_dict = wrapped(*args, **kwargs) + + if isinstance(output_dict, dict) and output_dict is not None and "log" in output_dict: + log_dict = output_dict.pop("log") + instance.log_dict(log_dict, on_step=True) + + return output_dict
+ + +
[docs]def convert_model_config_to_dict_config(cfg: Union[DictConfig, MRIDCConfig]) -> DictConfig: + """ + Converts its input into a standard DictConfig. + + Possible input values are: + - DictConfig + - A dataclass which is a subclass of MRIDCConfig + + Parameters + ---------- + cfg: A dict-like object. + + Returns + ------- + The equivalent DictConfig. + """ + if not _HAS_HYDRA: + logging.error("This function requires Hydra/OmegaConf and it was not installed.") + sys.exit(1) + if not isinstance(cfg, (OmegaConf, DictConfig)) and is_dataclass(cfg): + cfg = OmegaConf.structured(cfg) + + if not isinstance(cfg, DictConfig): + raise ValueError(f"cfg constructor argument must be of type DictConfig/dict but got {type(cfg)} instead.") + + config = OmegaConf.to_container(cfg, resolve=True) + config = OmegaConf.create(config) + return config
+ + +def _convert_config(cfg: "OmegaConf"): + """Recursive function converting the configuration from old hydra format to the new one.""" + if not _HAS_HYDRA: + logging.error("This function requires Hydra/OmegaConf and it was not installed.") + sys.exit(1) + + # Get rid of cls -> _target_. + if "cls" in cfg and "_target_" not in cfg: + cfg._target_ = cfg.pop("cls") # type: ignore + + # Get rid of params. + if "params" in cfg: + params = cfg.pop("params") # type: ignore + for param_key, param_val in params.items(): + cfg[param_key] = param_val + + # Recursion. + try: + for _, sub_cfg in cfg.items(): # type: ignore + if isinstance(sub_cfg, DictConfig): + _convert_config(sub_cfg) # type: ignore + except OmegaConfBaseException as e: + logging.warning(f"Skipped conversion for config/subconfig:\n{cfg}\n Reason: {e}.") + + +
[docs]def maybe_update_config_version(cfg: "DictConfig"): + """ + Recursively convert Hydra 0.x configs to Hydra 1.x configs. + Changes include: + - `cls` -> `_target_`. + - `params` -> drop params and shift all arguments to parent. + - `target` -> `_target_` cannot be performed due to ModelPT injecting `target` inside class. + + Parameters + ---------- + cfg: Any Hydra compatible DictConfig + + Returns + ------- + An updated DictConfig that conforms to Hydra 1.x format. + """ + if not _HAS_HYDRA: + logging.error("This function requires Hydra/OmegaConf and it was not installed.") + sys.exit(1) + if cfg is not None and not isinstance(cfg, DictConfig): + try: + temp_cfg = OmegaConf.create(cfg) + cfg = temp_cfg + except OmegaConfBaseException: + # Cannot be cast to DictConfig, skip updating. + return cfg + + # Make a copy of model config. + cfg = copy.deepcopy(cfg) + OmegaConf.set_struct(cfg, False) + + # Convert config. + _convert_config(cfg) # type: ignore + + # Update model config. + OmegaConf.set_struct(cfg, True) + + return cfg
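A before/after sketch of the conversion (hypothetical config; `mridc.SomeModel` is a placeholder target):

    from omegaconf import OmegaConf

    old = OmegaConf.create({"model": {"cls": "mridc.SomeModel", "params": {"lr": 1e-3}}})
    new = maybe_update_config_version(old)
    print(OmegaConf.to_yaml(new))
    # model:
    #   _target_: mridc.SomeModel
    #   lr: 0.001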
+ + +
[docs]def import_class_by_path(path: str): + """Recursive import of class by path string.""" + paths = path.split(".") + path = ".".join(paths[:-1]) + class_name = paths[-1] + mod = __import__(path, fromlist=[class_name]) + mod = getattr(mod, class_name) + return mod
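For example, resolving a class from a dotted path:

    Linear = import_class_by_path("torch.nn.Linear")
    print(Linear(2, 2))  # Linear(in_features=2, out_features=2, bias=True)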
+ + +
[docs]def resolve_subclass_pretrained_model_info(base_class) -> Union[List[PretrainedModelInfo], Set[Any]]:
+    """
+    Recursively traverses the inheritance graph of subclasses to extract all pretrained model info.
+    First constructs a set of unique pretrained model info by performing DFS over the inheritance graph.
+    All model info belonging to the same class is added together.
+
+    Parameters
+    ----------
+    base_class: The root class, whose subclass graph will be traversed.
+
+    Returns
+    -------
+    A list of unique pretrained model infos belonging to all the inherited subclasses of this baseclass.
+    """
+    list_of_models = set()
+
+    def recursive_subclass_walk(cls):
+        """
+        Recursively traverses the inheritance graph of subclasses to extract all pretrained model info.
+
+        Parameters
+        ----------
+        cls: The class to be traversed.
+
+        Returns
+        -------
+        A list of unique pretrained model infos belonging to all the inherited subclasses of this baseclass.
+        """
+        for subclass in cls.__subclasses__():
+            # step into its immediate subclass
+            recursive_subclass_walk(subclass)
+
+            subclass_models = subclass.list_available_models()
+
+            if subclass_models is not None and len(subclass_models) > 0:
+                # Inject subclass info into pretrained model info, if not already overridden by subclass.
+                for model_info in subclass_models:
+                    # If subclass manually injects class_, don't override.
+                    if model_info.class_ is None:
+                        model_info.class_ = subclass
+
+                for model_info in subclass_models:
+                    list_of_models.add(model_info)
+
+    recursive_subclass_walk(base_class)
+    list_of_models = list(sorted(list_of_models))  # type: ignore
+    return list_of_models
+ + +
[docs]def check_lib_version(lib_name: str, checked_version: str, operator) -> Tuple[Optional[bool], str]: + """ + Checks if a library is installed, and if it is, checks the operator(lib.__version__, checked_version) as a result. + This bool result along with a string analysis of result is returned. + If the library is not installed at all, then returns None instead, along with a string explaining + that the library is not installed + + Parameters + ---------- + lib_name: lower case str name of the library that must be imported. + checked_version: semver string that is compared against lib.__version__. + operator: binary callable function func(a, b) -> bool; that compares lib.__version__ against version in some + manner. Must return a boolean. + + Returns + ------- + A tuple of results: + - Bool or None. Bool if the library could be imported, and the result of + operator(lib.__version__, checked_version) or False if __version__ is not implemented in lib. + None is passed if the library is not installed at all. + - A string analysis of the check. + """ + try: + if "." in lib_name: + mod = import_class_by_path(lib_name) + else: + mod = __import__(lib_name) + + if hasattr(mod, "__version__"): + lib_ver = Version(mod.__version__) # type: ignore + match_ver = Version(checked_version) # type: ignore + + if operator(lib_ver, match_ver): + msg = f"Lib {lib_name} version is satisfied !" + return True, msg + msg = ( + f"Lib {lib_name} version ({lib_ver}) is not {operator.__name__} than required version " + f"{checked_version}.\n" + f"Please upgrade the lib using either pip or conda to the latest version." + ) + return False, msg + msg = ( + f"Lib {lib_name} does not implement __version__ in its init file. " + f"Could not check version compatibility." + ) + return False, msg + except ImportError: + pass + + msg = f"Lib {lib_name} has not been installed. Please use pip or conda to install this package." + return None, msg
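A usage sketch with the standard `operator` module (the result depends on the installed torch version):

    import operator

    ok, msg = check_lib_version("torch", "1.0.0", operator.ge)
    print(ok)   # True if torch >= 1.0.0, False if older, None if not installed
    print(msg)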
+ + +
[docs]def resolve_cache_dir() -> Path: + """ + Utility method to resolve a cache directory for MRIDC that can be overridden by an environment variable. + Example: + MRIDC_CACHE_DIR="~/mridc_cache_dir/" python mridc_example_script.py + + Returns + ------- + A Path object, resolved to the absolute path of the cache directory. If no override is provided, uses an inbuilt + default which adapts to mridc versions strings. + """ + override_dir = os.environ.get(MRIDC_ENV_CACHE_DIR, "") + return ( + Path.joinpath(Path.home(), f".cache/torch/MRIDC/MRIDC_{mridc.__version__}") + if override_dir == "" + else Path(override_dir).resolve() + )
+ + +
[docs]def uninject_model_parallel_rank(filepath): + """Uninjects tensor/pipeline model parallel ranks from the filepath.""" + filepath = str(filepath) + if "mp_rank" in filepath or "tp_rank" in filepath: + dirname = os.path.dirname(os.path.dirname(filepath)) + basename = os.path.basename(filepath) + filepath = os.path.join(dirname, basename) + return filepath
+ + +
[docs]def inject_model_parallel_rank(filepath):
+    """Injects tensor/pipeline model parallel ranks into the filepath. Does nothing if not using model parallelism."""
+    filepath = uninject_model_parallel_rank(filepath)
+    app_state = AppState()
+    if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
+        # filepath needs to be updated to include mp_rank
+        dirname = os.path.dirname(filepath)
+        basename = os.path.basename(filepath)
+        if app_state.pipeline_model_parallel_size is None or app_state.pipeline_model_parallel_size == 1:
+            filepath = f"{dirname}/mp_rank_{app_state.tensor_model_parallel_rank:02d}/{basename}"
+        else:
+            filepath = (
+                f"{dirname}/tp_rank_{app_state.tensor_model_parallel_rank:02d}_pp_rank_"
+                f"{app_state.pipeline_model_parallel_rank:03d}/{basename}"
+            )
+        return filepath
+    return filepath
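A path round-trip sketch (hypothetical checkpoint path): with a tensor model parallel size of 2 and no pipeline parallelism, rank 1 would map `/ckpts/model.ckpt` to `/ckpts/mp_rank_01/model.ckpt`, and `uninject_model_parallel_rank` reverses the mapping:

    print(uninject_model_parallel_rank("/ckpts/mp_rank_01/model.ckpt"))
    # /ckpts/model.ckpt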
+
diff --git a/docs/build/html/_modules/mridc/utils/mridc_logging.html b/docs/build/html/_modules/mridc/utils/mridc_logging.html
new file mode 100644
index 00000000..48c47007
--- /dev/null
+++ b/docs/build/html/_modules/mridc/utils/mridc_logging.html
@@ -0,0 +1,505 @@
+ mridc.utils.mridc_logging — mridc v.0.0.1 documentation

Source code for mridc.utils.mridc_logging

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/utils/nemo_logging.py
+
+import enum
+import logging as _logging
+import sys
+import threading
+import warnings
+from contextlib import contextmanager
+from logging.handlers import MemoryHandler
+
+__all__ = ["Logger", "LogMode"]
+
+from mridc.constants import MRIDC_ENV_VARNAME_REDIRECT_LOGS_TO_STDERR, MRIDC_ENV_VARNAME_TESTING
+from mridc.utils.env_var_parsing import get_envbool
+from mridc.utils.formaters.base import BaseMRIDCFormatter, DebugMRIDCFormatter
+from mridc.utils.get_rank import is_global_rank_zero
+from mridc.utils.metaclasses import Singleton
+
+
+
[docs]class LogMode(enum.IntEnum): + """Enum for the different logging modes.""" + + EACH = 0 # Log the message each time + ONCE = 1 # Log the message only once. The same message will not be logged again.
+ + +
[docs]class Logger(metaclass=Singleton): + """Singleton class for logging.""" + + # Level 0 + NOTSET = _logging.NOTSET + + # Level 10 + DEBUG = _logging.DEBUG + + # Level 20 + INFO = _logging.INFO + + # Level 30 + WARNING = _logging.WARNING + + # Level 40 + ERROR = _logging.ERROR + + # Level 50 + CRITICAL = _logging.CRITICAL + + _level_names = {0: "NOTSET", 10: "DEBUG", 20: "INFO", 30: "WARNING", 40: "ERROR", 50: "CRITICAL"} + + def __init__(self, capture_warnings=True): + + self._logger = None + # Multi-GPU runs run in separate processes, thread locks shouldn't be needed + self._logger_lock = threading.Lock() + self._handlers = {} + self.old_warnings_showwarning = None + self._define_logger(capture_warnings) + self.once_logged = set() + self.rank = 0 if is_global_rank_zero() else "UNK" + + def _define_logger(self, capture_warnings=True): + """Creates the logger if not already created. Called in init""" + # Use double-checked locking to avoid taking lock unnecessarily. + if self._logger is not None: + return self._logger + + with self._logger_lock: + try: + self._logger = _logging.getLogger("mridc_logger") + # By default, silence all loggers except the logger for rank 0 + self.remove_stream_handlers() + # If MRIDC_TESTING is set, add a streamhandler to all ranks + if get_envbool(MRIDC_ENV_VARNAME_TESTING, False): + old_factory = _logging.getLogRecordFactory() + + def record_factory(*args, **kwargs): + record = old_factory(*args, **kwargs) + record.rank = self.rank + return record + + _logging.setLogRecordFactory(record_factory) + self.add_stream_handlers(formatter=DebugMRIDCFormatter) + elif is_global_rank_zero(): + self.add_stream_handlers() + + # Add memoryhandlers, essentially buffers. They are used to save messages that we will flush to file + # once the appropriate file handlers are added. + if is_global_rank_zero(): + # Add a memoryhandler for error messages. Only logged on rank 0 + self._handlers["memory_err"] = MemoryHandler(-1) + self._handlers["memory_err"].addFilter(lambda record: record.levelno > _logging.INFO) + formatter = BaseMRIDCFormatter + self._handlers["memory_err"].setFormatter(formatter()) + self._logger.addHandler(self._handlers["memory_err"]) + # Add a memoryhandler for all messages on all ranks + self._handlers["memory_all"] = MemoryHandler(-1) + formatter = BaseMRIDCFormatter + self._handlers["memory_all"].setFormatter(formatter()) + self._logger.addHandler(self._handlers["memory_all"]) + + finally: + level = Logger.INFO + if get_envbool(MRIDC_ENV_VARNAME_TESTING, False): + level = Logger.DEBUG + self.set_verbosity(verbosity_level=level) + self.captureWarnings(capture_warnings) + + self._logger.propagate = False + +
[docs]    def remove_stream_handlers(self):
+        """Removes the StreamHandlers that log to stdout and stderr from the logger."""
+        if self._logger is None:
+            raise RuntimeError("Impossible to set handlers if the Logger is not predefined")
+
+        # ======== Remove Handler if already existing ========
+
+        try:
+            self._logger.removeHandler(self._handlers["stream_stdout"])
+            del self._handlers["stream_stdout"]
+        except KeyError:
+            pass
+
+        try:
+            self._logger.removeHandler(self._handlers["stream_stderr"])
+            del self._handlers["stream_stderr"]
+        except KeyError:
+            pass
+ +
[docs] def add_stream_handlers(self, formatter=BaseMRIDCFormatter): + """ + Add StreamHandler that log to stdout and stderr to the logger. INFO and lower logs are streamed to stdout + while WARNING and higher are streamed to stderr. If the MRIDC_ENV_VARNAME_REDIRECT_LOGS_TO_STDERR environment + variable is set, all logs are sent to stderr instead. + """ + if self._logger is None: + raise RuntimeError("Impossible to set handlers if the Logger is not predefined") + + # Add the output handler. + if get_envbool(MRIDC_ENV_VARNAME_REDIRECT_LOGS_TO_STDERR, False): + self._handlers["stream_stdout"] = _logging.StreamHandler(sys.stderr) + + else: + self._handlers["stream_stdout"] = _logging.StreamHandler(sys.stdout) + self._handlers["stream_stdout"].addFilter(lambda record: record.levelno <= _logging.INFO) + + self._handlers["stream_stderr"] = _logging.StreamHandler(sys.stderr) + self._handlers["stream_stderr"].addFilter(lambda record: record.levelno > _logging.INFO) + + self._handlers["stream_stdout"].setFormatter(formatter()) + self._logger.addHandler(self._handlers["stream_stdout"]) + + try: + self._handlers["stream_stderr"].setFormatter(formatter()) + self._logger.addHandler(self._handlers["stream_stderr"]) + except KeyError: + pass
+ +
[docs] def reset_stream_handler(self, formatter=BaseMRIDCFormatter): + """Removes then adds stream handlers.""" + self.remove_stream_handlers() + self.add_stream_handlers(formatter=formatter)
+ +
[docs] def add_file_handler(self, log_file): + """ + Add a FileHandler to logger that logs all messages to a file. If the logger had a MemoryHandler at + self._handlers["memory_all"], those buffered messages are flushed to the new file, and the MemoryHandler is + closed. + """ + if self._logger is None: + raise RuntimeError("Impossible to set handlers if the Logger is not predefined") + + self._handlers["file"] = _logging.FileHandler(log_file) + formatter = BaseMRIDCFormatter + self._handlers["file"].setFormatter(formatter()) + self._logger.addHandler(self._handlers["file"]) + + if self._handlers.get("memory_all"): + self._handlers["memory_all"].setTarget(self._handlers["file"]) + self._handlers["memory_all"].close() # flush and remove + del self._handlers["memory_all"]
+ +
[docs] def add_err_file_handler(self, log_file): + """ + Add a FileHandler to logger that logs all WARNING and higher messages to a file. If the logger had a + MemoryHandler at self._handlers["memory_err"], those buffered messages are flushed to the new file, and the + MemoryHandler is closed. + """ + if self._logger is None: + raise RuntimeError("Impossible to set handlers if the Logger is not predefined") + + self._handlers["file_err"] = _logging.FileHandler(log_file) + self._handlers["file_err"].addFilter(lambda record: record.levelno > _logging.INFO) + + formatter = BaseMRIDCFormatter + self._handlers["file_err"].setFormatter(formatter()) + self._logger.addHandler(self._handlers["file_err"]) + + if self._handlers.get("memory_err"): + self._handlers["memory_err"].setTarget(self._handlers["file_err"]) + self._handlers["memory_err"].close() # flush and remove + del self._handlers["memory_err"]
+ +
[docs] def getEffectiveLevel(self): + """Return how much logging output will be produced.""" + if self._logger is not None: + return self._logger.getEffectiveLevel()
+ +
[docs] def get_verbosity(self): + """See getEffectiveLevel""" + return self.getEffectiveLevel()
+ +
[docs] def setLevel(self, verbosity_level): + """Sets the threshold for what messages will be logged.""" + if self._logger is not None: + self._logger.setLevel(verbosity_level) + + for handler in self._logger.handlers: + handler.setLevel(verbosity_level)
+ +
[docs] def set_verbosity(self, verbosity_level): + """See setLevel""" + self.setLevel(verbosity_level)
+ +
[docs] @contextmanager + def patch_stderr_handler(self, stream): + """Sends messages that should log to stderr to stream instead. Useful for unittests""" + if self._logger is None: + raise RuntimeError("Impossible to patch logging handlers if handler does not exist") + try: + old_stream = self._handlers["stream_stderr"].stream + if old_stream is None: + raise ValueError + + # Port backwards set_stream() from python 3.7 + self._handlers["stream_stderr"].acquire() + try: + self._handlers["stream_stderr"].flush() + self._handlers["stream_stderr"].stream = stream + finally: + self._handlers["stream_stderr"].release() + + yield stream + except (KeyError, ValueError) as e: + raise RuntimeError("Impossible to patch logging handlers if handler does not exist") from e + + finally: + # Port backwards set_stream() from python 3.7 + self._handlers["stream_stderr"].acquire() + try: + self._handlers["stream_stderr"].flush() + self._handlers["stream_stderr"].stream = old_stream + finally: + self._handlers["stream_stderr"].release()
+ +
[docs] @contextmanager + def patch_stdout_handler(self, stream): + """Sends messages that should log to stdout to stream instead. Useful for unittests""" + if self._logger is None: + raise RuntimeError("Impossible to patch logging handlers if handler does not exist") + try: + old_stream = self._handlers["stream_stdout"].stream + if old_stream is None: + raise ValueError + + # Port backwards set_stream() from python 3.7 + self._handlers["stream_stdout"].acquire() + try: + self._handlers["stream_stdout"].flush() + self._handlers["stream_stdout"].stream = stream + finally: + self._handlers["stream_stdout"].release() + + yield stream + except (KeyError, ValueError) as e: + raise RuntimeError("Impossible to patch logging handlers if handler does not exist") from e + + finally: + # Port backwards set_stream() from python 3.7 + self._handlers["stream_stdout"].acquire() + try: + self._handlers["stream_stdout"].flush() + self._handlers["stream_stdout"].stream = old_stream + finally: + self._handlers["stream_stdout"].release()
+ +
[docs] @contextmanager + def temp_verbosity(self, verbosity_level): + """Sets a temporary threshold for what messages will be logged.""" + if self._logger is not None: + + old_verbosity = self.get_verbosity() + + try: + self.set_verbosity(verbosity_level) + yield + + finally: + self.set_verbosity(old_verbosity) + + else: + try: + yield + + finally: + pass
+ +
[docs] def captureWarnings(self, capture): + """ + If capture is true, redirect all warnings to the logging package. + If capture is False, ensure that warnings are not redirected to logging but to their original destinations. + """ + if self._logger is not None: + + if capture and self.old_warnings_showwarning is None: + # Backup Method + self.old_warnings_showwarning = warnings.showwarning + warnings.showwarning = self._showwarning + + elif not capture and self.old_warnings_showwarning is not None: + # Restore Method + warnings.showwarning = self.old_warnings_showwarning + self.old_warnings_showwarning = None
+ + def _showwarning(self, message, category, filename, lineno, file=None, line=None): + """ + Implementation of warnings.showwarning that redirects warnings to the logging package. + It calls warnings.formatwarning and logs the resulting string with severity logging.WARNING. + """ + s = warnings.formatwarning(message, category, filename, lineno, line) + self.warning("%s", s) + + def _logged_once(self, msg, mode): + """ + Returns True if the given message has already been logged in the given mode; a new message logged with + LogMode.ONCE is recorded so that later calls return True. + + Parameters + ---------- + msg: The message to check. + mode: The LogMode to check. + + Returns + ------- + True if the message has been logged at least once in the given mode. + """ + if mode == LogMode.ONCE: + # Compare messages with the first PREFIX_LEN characters stripped. + PREFIX_LEN = 12 + if msg[PREFIX_LEN:] in self.once_logged: + return True + self.once_logged.add(msg[PREFIX_LEN:]) + return False + +
[docs] def debug(self, msg, *args, mode=LogMode.EACH, **kwargs): + """ + Log 'msg % args' with severity 'DEBUG'. + To pass exception information, use the keyword argument exc_info with a true value, e.g. + logger.debug("Houston, we have %s", "thorny problem", exc_info=1) + """ + if self._logger is not None and self._logger.isEnabledFor(Logger.DEBUG) and not self._logged_once(msg, mode): + self._logger._log(Logger.DEBUG, msg, args, **kwargs)
+ +
[docs] def info(self, msg, *args, mode=LogMode.EACH, **kwargs): + """ + Log 'msg % args' with severity 'INFO'. + To pass exception information, use the keyword argument exc_info with a true value, e.g. + logger.info("Houston, we have %s", "interesting problem", exc_info=1) + """ + if self._logger is not None and self._logger.isEnabledFor(Logger.INFO) and not self._logged_once(msg, mode): + self._logger._log(Logger.INFO, msg, args, **kwargs)
+ +
[docs] def warning(self, msg, *args, mode=LogMode.EACH, **kwargs): + """ + Log 'msg % args' with severity 'WARNING'. + To pass exception information, use the keyword argument exc_info with a true value, e.g. + logger.warning("Houston, we have %s", "bit of a problem", exc_info=1) + """ + if self._logger is not None and self._logger.isEnabledFor(Logger.WARNING) and not self._logged_once(msg, mode): + self._logger._log(Logger.WARNING, msg, args, **kwargs)
+ +
[docs] def error(self, msg, *args, mode=LogMode.EACH, **kwargs): + """ + Log 'msg % args' with severity 'ERROR'. + To pass exception information, use the keyword argument exc_info with a true value, e.g. + logger.error("Houston, we have %s", "major problem", exc_info=1) + """ + if self._logger is not None and self._logger.isEnabledFor(Logger.ERROR) and not self._logged_once(msg, mode): + self._logger._log(Logger.ERROR, msg, args, **kwargs)
+ +
[docs] def critical(self, msg, *args, mode=LogMode.EACH, **kwargs) -> None: + """ + Log 'msg % args' with severity 'CRITICAL'. + To pass exception information, use the keyword argument exc_info with a true value, e.g. + logger.critical("Houston, we have %s", "major disaster", exc_info=1) + + Parameters + ---------- + msg: the message to log + *args: the arguments to the message + mode: the mode to log the message in + **kwargs: the keyword arguments to the message + """ + if ( + self._logger is not None + and self._logger.isEnabledFor(Logger.CRITICAL) + and not self._logged_once(msg, mode) + ): + self._logger._log(Logger.CRITICAL, msg, args, **kwargs)
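Taken together, the handler methods above support buffering early records in the MemoryHandlers and flushing them once the log directory is known. A sketch under the same import assumption as above; the file paths are illustrative:

    from mridc.utils.mridc_logging import Logger

    logger = Logger()
    logger.set_verbosity(Logger.DEBUG)
    # Flushes the buffered "memory_all" records into the new file handler.
    logger.add_file_handler("/tmp/run.log")
    # WARNING and higher also go to a separate error log ("memory_err" is flushed here).
    logger.add_err_file_handler("/tmp/run_err.log")
    logger.debug("goes to stdout (on rank 0) and /tmp/run.log")
    logger.error("additionally lands in /tmp/run_err.log")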
diff --git a/docs/build/html/_modules/mridc/utils/timers.html b/docs/build/html/_modules/mridc/utils/timers.html
new file mode 100644
index 00000000..e3629ac6
--- /dev/null
+++ b/docs/build/html/_modules/mridc/utils/timers.html
@@ -0,0 +1,241 @@
+ mridc.utils.timers — mridc v.0.0.1 documentation

Source code for mridc.utils.timers

+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/utils/timers.py
+
+import time
+
+import numpy as np
+import torch
+
+__all__ = ["NamedTimer"]
+
+
+
[docs]class NamedTimer: + """ + A timer class that supports multiple named timers. + A named timer can be used multiple times, in which case the average dt will be returned. + A named timer cannot be started if it is already currently running. + Use case: measuring execution of multiple code blocks. + """ + + _REDUCTION_TYPE = ["mean", "sum", "min", "max", "none"] + + def __init__(self, reduction="mean", sync_cuda=False, buffer_size=-1): + """ + + Parameters + ---------- + reduction: Reduction over multiple timings of the same timer (none - returns the list instead of a scalar). + sync_cuda: If True torch.cuda.synchronize() is called for start/stop + buffer_size: If positive, limits the number of stored measures per name + """ + if reduction not in self._REDUCTION_TYPE: + raise ValueError(f"Unknown reduction={reduction} please use one of {self._REDUCTION_TYPE}") + + self._reduction = reduction + self._sync_cuda = sync_cuda + self._buffer_size = buffer_size + + self.reset() + + def __getitem__(self, k): + return self.get(k) + + @property + def buffer_size(self): + """Returns the buffer size of the timer.""" + return self._buffer_size + + @property + def _reduction_fn(self): + """Returns the reduction function for the timer.""" + if self._reduction == "none": + + def fn(x): + return x + + else: + fn = getattr(np, self._reduction) + + return fn + +
[docs] def reset(self, name=None): + """ + Resets all timers, or a specific named timer. + + Parameters + ---------- + name: Timer name to reset (if None all timers are reset) + """ + if name is None: + self.timers = {} + else: + self.timers[name] = {}
+ +
[docs] def start(self, name=""): + """ + Starts measuring a named timer. + + Parameters + ---------- + name: timer name to start + """ + timer_data = self.timers.get(name, {}) + + if "start" in timer_data: + raise RuntimeError(f"Cannot start timer = '{name}' since it is already active") + + # synchronize pytorch cuda execution if supported + if self._sync_cuda and torch.cuda.is_initialized(): + torch.cuda.synchronize() + + timer_data["start"] = time.time() + + self.timers[name] = timer_data
+ +
[docs] def stop(self, name=""): + """ + Stops measuring a named timer. + + Parameters + ---------- + name: timer name to stop + """ + timer_data = self.timers.get(name) + if (timer_data is None) or ("start" not in timer_data): + raise RuntimeError(f"Cannot end timer = '{name}' since it is not active") + + # synchronize pytorch cuda execution if supported + if self._sync_cuda and torch.cuda.is_initialized(): + torch.cuda.synchronize() + + # compute dt and make timer inactive + dt = time.time() - timer_data.pop("start") + + # store dt + timer_data["dt"] = timer_data.get("dt", []) + [dt] + + # enforce buffer_size if positive + if self._buffer_size > 0: + timer_data["dt"] = timer_data["dt"][-self._buffer_size :] + + self.timers[name] = timer_data
+ +
[docs] def active_timers(self): + """Return list of all active named timers""" + return [k for k, v in self.timers.items() if "start" in v]
+ +
[docs] def get(self, name=""): + """ + Returns the value of a named timer + + Parameters + ---------- + name: timer name to return + """ + dt_list = self.timers[name].get("dt", []) + + return self._reduction_fn(dt_list)
+ +
[docs] def export(self): + """Exports a dictionary with average/all dt per named timer""" + fn = self._reduction_fn + + return {k: fn(v["dt"]) for k, v in self.timers.items() if "dt" in v}
+
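For reference, a minimal usage sketch of NamedTimer; time.sleep stands in for the measured work, and the import path follows the mridc.utils.timers module documented in the .rst sources below:

    import time

    from mridc.utils.timers import NamedTimer

    timer = NamedTimer(reduction="mean", buffer_size=100)
    for _ in range(3):
        timer.start("batch")
        time.sleep(0.01)  # placeholder for the timed code block
        timer.stop("batch")

    print(timer.get("batch"))  # mean dt over the stored measurements
    print(timer.export())      # {"batch": <mean dt>} for every timer with data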
+ + + + diff --git a/docs/build/html/_sources/index.rst.txt b/docs/build/html/_sources/index.rst.txt new file mode 100644 index 00000000..e25c6b0e --- /dev/null +++ b/docs/build/html/_sources/index.rst.txt @@ -0,0 +1,15 @@ +.. mridc documentation master file, created by + sphinx-quickstart on Wed May 25 16:45:16 2022. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to mridc's documentation! +================================= +.. include:: ../../README.md + :parser: myst_parser.sphinx_ + +.. toctree:: + :maxdepth: 4 + :caption: API Documentation: + + modules.rst diff --git a/docs/build/html/_sources/modules.rst.txt b/docs/build/html/_sources/modules.rst.txt new file mode 100644 index 00000000..af286e6a --- /dev/null +++ b/docs/build/html/_sources/modules.rst.txt @@ -0,0 +1,7 @@ +mridc +===== + +.. toctree:: + :maxdepth: 4 + + mridc diff --git a/docs/build/html/_sources/mridc.collections.common.callbacks.rst.txt b/docs/build/html/_sources/mridc.collections.common.callbacks.rst.txt new file mode 100644 index 00000000..c3aaa45d --- /dev/null +++ b/docs/build/html/_sources/mridc.collections.common.callbacks.rst.txt @@ -0,0 +1,21 @@ +mridc.collections.common.callbacks package +========================================== + +Submodules +---------- + +mridc.collections.common.callbacks.callbacks module +--------------------------------------------------- + +.. automodule:: mridc.collections.common.callbacks.callbacks + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.collections.common.callbacks + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.collections.common.data.rst.txt b/docs/build/html/_sources/mridc.collections.common.data.rst.txt new file mode 100644 index 00000000..dbd9e857 --- /dev/null +++ b/docs/build/html/_sources/mridc.collections.common.data.rst.txt @@ -0,0 +1,21 @@ +mridc.collections.common.data package +===================================== + +Submodules +---------- + +mridc.collections.common.data.dataset module +-------------------------------------------- + +.. automodule:: mridc.collections.common.data.dataset + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.collections.common.data + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.collections.common.losses.rst.txt b/docs/build/html/_sources/mridc.collections.common.losses.rst.txt new file mode 100644 index 00000000..3b4e880a --- /dev/null +++ b/docs/build/html/_sources/mridc.collections.common.losses.rst.txt @@ -0,0 +1,29 @@ +mridc.collections.common.losses package +======================================= + +Submodules +---------- + +mridc.collections.common.losses.aggregator module +------------------------------------------------- + +.. automodule:: mridc.collections.common.losses.aggregator + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.common.losses.ssim module +------------------------------------------- + +.. automodule:: mridc.collections.common.losses.ssim + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. 
automodule:: mridc.collections.common.losses + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.collections.common.metrics.rst.txt b/docs/build/html/_sources/mridc.collections.common.metrics.rst.txt new file mode 100644 index 00000000..06dd9f87 --- /dev/null +++ b/docs/build/html/_sources/mridc.collections.common.metrics.rst.txt @@ -0,0 +1,21 @@ +mridc.collections.common.metrics package +======================================== + +Submodules +---------- + +mridc.collections.common.metrics.global\_average\_loss\_metric module +--------------------------------------------------------------------- + +.. automodule:: mridc.collections.common.metrics.global_average_loss_metric + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.collections.common.metrics + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.collections.common.parts.rst.txt b/docs/build/html/_sources/mridc.collections.common.parts.rst.txt new file mode 100644 index 00000000..b06319e8 --- /dev/null +++ b/docs/build/html/_sources/mridc.collections.common.parts.rst.txt @@ -0,0 +1,53 @@ +mridc.collections.common.parts package +====================================== + +Submodules +---------- + +mridc.collections.common.parts.fft module +----------------------------------------- + +.. automodule:: mridc.collections.common.parts.fft + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.common.parts.patch\_utils module +-------------------------------------------------- + +.. automodule:: mridc.collections.common.parts.patch_utils + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.common.parts.ptl\_overrides module +---------------------------------------------------- + +.. automodule:: mridc.collections.common.parts.ptl_overrides + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.common.parts.rnn\_utils module +------------------------------------------------ + +.. automodule:: mridc.collections.common.parts.rnn_utils + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.common.parts.utils module +------------------------------------------- + +.. automodule:: mridc.collections.common.parts.utils + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.collections.common.parts + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.collections.common.rst.txt b/docs/build/html/_sources/mridc.collections.common.rst.txt new file mode 100644 index 00000000..c0a988df --- /dev/null +++ b/docs/build/html/_sources/mridc.collections.common.rst.txt @@ -0,0 +1,22 @@ +mridc.collections.common package +================================ + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + mridc.collections.common.callbacks + mridc.collections.common.data + mridc.collections.common.losses + mridc.collections.common.metrics + mridc.collections.common.parts + +Module contents +--------------- + +.. 
automodule:: mridc.collections.common + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.collections.reconstruction.data.rst.txt b/docs/build/html/_sources/mridc.collections.reconstruction.data.rst.txt new file mode 100644 index 00000000..e2e984f6 --- /dev/null +++ b/docs/build/html/_sources/mridc.collections.reconstruction.data.rst.txt @@ -0,0 +1,29 @@ +mridc.collections.reconstruction.data package +============================================= + +Submodules +---------- + +mridc.collections.reconstruction.data.mri\_data module +------------------------------------------------------ + +.. automodule:: mridc.collections.reconstruction.data.mri_data + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.reconstruction.data.subsample module +------------------------------------------------------ + +.. automodule:: mridc.collections.reconstruction.data.subsample + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.collections.reconstruction.data + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.collections.reconstruction.metrics.rst.txt b/docs/build/html/_sources/mridc.collections.reconstruction.metrics.rst.txt new file mode 100644 index 00000000..06da8eab --- /dev/null +++ b/docs/build/html/_sources/mridc.collections.reconstruction.metrics.rst.txt @@ -0,0 +1,21 @@ +mridc.collections.reconstruction.metrics package +================================================ + +Submodules +---------- + +mridc.collections.reconstruction.metrics.evaluate module +-------------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.metrics.evaluate + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.collections.reconstruction.metrics + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.collections.reconstruction.models.cascadenet.rst.txt b/docs/build/html/_sources/mridc.collections.reconstruction.models.cascadenet.rst.txt new file mode 100644 index 00000000..1406af85 --- /dev/null +++ b/docs/build/html/_sources/mridc.collections.reconstruction.models.cascadenet.rst.txt @@ -0,0 +1,21 @@ +mridc.collections.reconstruction.models.cascadenet package +========================================================== + +Submodules +---------- + +mridc.collections.reconstruction.models.cascadenet.ccnn\_block module +--------------------------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.cascadenet.ccnn_block + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.collections.reconstruction.models.cascadenet + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.collections.reconstruction.models.conv.rst.txt b/docs/build/html/_sources/mridc.collections.reconstruction.models.conv.rst.txt new file mode 100644 index 00000000..4ac025fe --- /dev/null +++ b/docs/build/html/_sources/mridc.collections.reconstruction.models.conv.rst.txt @@ -0,0 +1,29 @@ +mridc.collections.reconstruction.models.conv package +==================================================== + +Submodules +---------- + +mridc.collections.reconstruction.models.conv.conv2d module +---------------------------------------------------------- + +.. 
automodule:: mridc.collections.reconstruction.models.conv.conv2d + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.reconstruction.models.conv.gruconv2d module +------------------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.conv.gruconv2d + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.collections.reconstruction.models.conv + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.collections.reconstruction.models.convrecnet.rst.txt b/docs/build/html/_sources/mridc.collections.reconstruction.models.convrecnet.rst.txt new file mode 100644 index 00000000..1038005e --- /dev/null +++ b/docs/build/html/_sources/mridc.collections.reconstruction.models.convrecnet.rst.txt @@ -0,0 +1,21 @@ +mridc.collections.reconstruction.models.convrecnet package +========================================================== + +Submodules +---------- + +mridc.collections.reconstruction.models.convrecnet.crnn\_block module +--------------------------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.convrecnet.crnn_block + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.collections.reconstruction.models.convrecnet + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.collections.reconstruction.models.crossdomain.rst.txt b/docs/build/html/_sources/mridc.collections.reconstruction.models.crossdomain.rst.txt new file mode 100644 index 00000000..b4341827 --- /dev/null +++ b/docs/build/html/_sources/mridc.collections.reconstruction.models.crossdomain.rst.txt @@ -0,0 +1,29 @@ +mridc.collections.reconstruction.models.crossdomain package +=========================================================== + +Submodules +---------- + +mridc.collections.reconstruction.models.crossdomain.crossdomain module +---------------------------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.crossdomain.crossdomain + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.reconstruction.models.crossdomain.multicoil module +-------------------------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.crossdomain.multicoil + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.collections.reconstruction.models.crossdomain + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.collections.reconstruction.models.didn.rst.txt b/docs/build/html/_sources/mridc.collections.reconstruction.models.didn.rst.txt new file mode 100644 index 00000000..365e6b2e --- /dev/null +++ b/docs/build/html/_sources/mridc.collections.reconstruction.models.didn.rst.txt @@ -0,0 +1,21 @@ +mridc.collections.reconstruction.models.didn package +==================================================== + +Submodules +---------- + +mridc.collections.reconstruction.models.didn.didn module +-------------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.didn.didn + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. 
automodule:: mridc.collections.reconstruction.models.didn + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.collections.reconstruction.models.multidomain.rst.txt b/docs/build/html/_sources/mridc.collections.reconstruction.models.multidomain.rst.txt new file mode 100644 index 00000000..dab55d8e --- /dev/null +++ b/docs/build/html/_sources/mridc.collections.reconstruction.models.multidomain.rst.txt @@ -0,0 +1,21 @@ +mridc.collections.reconstruction.models.multidomain package +=========================================================== + +Submodules +---------- + +mridc.collections.reconstruction.models.multidomain.multidomain module +---------------------------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.multidomain.multidomain + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.collections.reconstruction.models.multidomain + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.collections.reconstruction.models.mwcnn.rst.txt b/docs/build/html/_sources/mridc.collections.reconstruction.models.mwcnn.rst.txt new file mode 100644 index 00000000..2fa986b3 --- /dev/null +++ b/docs/build/html/_sources/mridc.collections.reconstruction.models.mwcnn.rst.txt @@ -0,0 +1,21 @@ +mridc.collections.reconstruction.models.mwcnn package +===================================================== + +Submodules +---------- + +mridc.collections.reconstruction.models.mwcnn.mwcnn module +---------------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.mwcnn.mwcnn + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.collections.reconstruction.models.mwcnn + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.collections.reconstruction.models.primaldual.rst.txt b/docs/build/html/_sources/mridc.collections.reconstruction.models.primaldual.rst.txt new file mode 100644 index 00000000..72ce9d35 --- /dev/null +++ b/docs/build/html/_sources/mridc.collections.reconstruction.models.primaldual.rst.txt @@ -0,0 +1,21 @@ +mridc.collections.reconstruction.models.primaldual package +========================================================== + +Submodules +---------- + +mridc.collections.reconstruction.models.primaldual.pd module +------------------------------------------------------------ + +.. automodule:: mridc.collections.reconstruction.models.primaldual.pd + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.collections.reconstruction.models.primaldual + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.collections.reconstruction.models.recurrentvarnet.rst.txt b/docs/build/html/_sources/mridc.collections.reconstruction.models.recurrentvarnet.rst.txt new file mode 100644 index 00000000..d82f2f7c --- /dev/null +++ b/docs/build/html/_sources/mridc.collections.reconstruction.models.recurrentvarnet.rst.txt @@ -0,0 +1,29 @@ +mridc.collections.reconstruction.models.recurrentvarnet package +=============================================================== + +Submodules +---------- + +mridc.collections.reconstruction.models.recurrentvarnet.conv2gru module +----------------------------------------------------------------------- + +.. 
automodule:: mridc.collections.reconstruction.models.recurrentvarnet.conv2gru + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.reconstruction.models.recurrentvarnet.recurentvarnet module +----------------------------------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.recurrentvarnet.recurentvarnet + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.collections.reconstruction.models.recurrentvarnet + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.collections.reconstruction.models.rim.rst.txt b/docs/build/html/_sources/mridc.collections.reconstruction.models.rim.rst.txt new file mode 100644 index 00000000..58a6db10 --- /dev/null +++ b/docs/build/html/_sources/mridc.collections.reconstruction.models.rim.rst.txt @@ -0,0 +1,45 @@ +mridc.collections.reconstruction.models.rim package +=================================================== + +Submodules +---------- + +mridc.collections.reconstruction.models.rim.conv\_layers module +--------------------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.rim.conv_layers + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.reconstruction.models.rim.rim\_block module +------------------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.rim.rim_block + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.reconstruction.models.rim.rnn\_cells module +------------------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.rim.rnn_cells + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.reconstruction.models.rim.utils module +-------------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.rim.utils + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.collections.reconstruction.models.rim + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.collections.reconstruction.models.rst.txt b/docs/build/html/_sources/mridc.collections.reconstruction.models.rst.txt new file mode 100644 index 00000000..c2710de4 --- /dev/null +++ b/docs/build/html/_sources/mridc.collections.reconstruction.models.rst.txt @@ -0,0 +1,162 @@ +mridc.collections.reconstruction.models package +=============================================== + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + mridc.collections.reconstruction.models.cascadenet + mridc.collections.reconstruction.models.conv + mridc.collections.reconstruction.models.convrecnet + mridc.collections.reconstruction.models.crossdomain + mridc.collections.reconstruction.models.didn + mridc.collections.reconstruction.models.multidomain + mridc.collections.reconstruction.models.mwcnn + mridc.collections.reconstruction.models.primaldual + mridc.collections.reconstruction.models.recurrentvarnet + mridc.collections.reconstruction.models.rim + mridc.collections.reconstruction.models.sigmanet + mridc.collections.reconstruction.models.unet_base + mridc.collections.reconstruction.models.variablesplittingnet + mridc.collections.reconstruction.models.varnet + +Submodules +---------- + +mridc.collections.reconstruction.models.base module +--------------------------------------------------- + +.. 
automodule:: mridc.collections.reconstruction.models.base + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.reconstruction.models.ccnn module +--------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.ccnn + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.reconstruction.models.cirim module +---------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.cirim + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.reconstruction.models.crnn module +--------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.crnn + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.reconstruction.models.dunet module +---------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.dunet + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.reconstruction.models.jointicnet module +--------------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.jointicnet + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.reconstruction.models.kikinet module +------------------------------------------------------ + +.. automodule:: mridc.collections.reconstruction.models.kikinet + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.reconstruction.models.lpd module +-------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.lpd + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.reconstruction.models.multidomainnet module +------------------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.multidomainnet + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.reconstruction.models.pics module +--------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.pics + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.reconstruction.models.rvn module +-------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.rvn + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.reconstruction.models.unet module +--------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.unet + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.reconstruction.models.vn module +------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.vn + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.reconstruction.models.vsnet module +---------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.vsnet + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.reconstruction.models.xpdnet module +----------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.xpdnet + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.reconstruction.models.zf module +------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.zf + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. 
automodule:: mridc.collections.reconstruction.models + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.collections.reconstruction.models.sigmanet.rst.txt b/docs/build/html/_sources/mridc.collections.reconstruction.models.sigmanet.rst.txt new file mode 100644 index 00000000..9c3de5e3 --- /dev/null +++ b/docs/build/html/_sources/mridc.collections.reconstruction.models.sigmanet.rst.txt @@ -0,0 +1,29 @@ +mridc.collections.reconstruction.models.sigmanet package +======================================================== + +Submodules +---------- + +mridc.collections.reconstruction.models.sigmanet.dc\_layers module +------------------------------------------------------------------ + +.. automodule:: mridc.collections.reconstruction.models.sigmanet.dc_layers + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.reconstruction.models.sigmanet.sensitivity\_net module +------------------------------------------------------------------------ + +.. automodule:: mridc.collections.reconstruction.models.sigmanet.sensitivity_net + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.collections.reconstruction.models.sigmanet + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.collections.reconstruction.models.unet_base.rst.txt b/docs/build/html/_sources/mridc.collections.reconstruction.models.unet_base.rst.txt new file mode 100644 index 00000000..deacd5a8 --- /dev/null +++ b/docs/build/html/_sources/mridc.collections.reconstruction.models.unet_base.rst.txt @@ -0,0 +1,21 @@ +mridc.collections.reconstruction.models.unet\_base package +========================================================== + +Submodules +---------- + +mridc.collections.reconstruction.models.unet\_base.unet\_block module +--------------------------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.unet_base.unet_block + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.collections.reconstruction.models.unet_base + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.collections.reconstruction.models.variablesplittingnet.rst.txt b/docs/build/html/_sources/mridc.collections.reconstruction.models.variablesplittingnet.rst.txt new file mode 100644 index 00000000..8f982eee --- /dev/null +++ b/docs/build/html/_sources/mridc.collections.reconstruction.models.variablesplittingnet.rst.txt @@ -0,0 +1,21 @@ +mridc.collections.reconstruction.models.variablesplittingnet package +==================================================================== + +Submodules +---------- + +mridc.collections.reconstruction.models.variablesplittingnet.vsnet\_block module +-------------------------------------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.variablesplittingnet.vsnet_block + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. 
automodule:: mridc.collections.reconstruction.models.variablesplittingnet + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.collections.reconstruction.models.varnet.rst.txt b/docs/build/html/_sources/mridc.collections.reconstruction.models.varnet.rst.txt new file mode 100644 index 00000000..2210165f --- /dev/null +++ b/docs/build/html/_sources/mridc.collections.reconstruction.models.varnet.rst.txt @@ -0,0 +1,21 @@ +mridc.collections.reconstruction.models.varnet package +====================================================== + +Submodules +---------- + +mridc.collections.reconstruction.models.varnet.vn\_block module +--------------------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.models.varnet.vn_block + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.collections.reconstruction.models.varnet + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.collections.reconstruction.parts.rst.txt b/docs/build/html/_sources/mridc.collections.reconstruction.parts.rst.txt new file mode 100644 index 00000000..1aa734e4 --- /dev/null +++ b/docs/build/html/_sources/mridc.collections.reconstruction.parts.rst.txt @@ -0,0 +1,29 @@ +mridc.collections.reconstruction.parts package +============================================== + +Submodules +---------- + +mridc.collections.reconstruction.parts.transforms module +-------------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.parts.transforms + :members: + :undoc-members: + :show-inheritance: + +mridc.collections.reconstruction.parts.utils module +--------------------------------------------------- + +.. automodule:: mridc.collections.reconstruction.parts.utils + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.collections.reconstruction.parts + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.collections.reconstruction.rst.txt b/docs/build/html/_sources/mridc.collections.reconstruction.rst.txt new file mode 100644 index 00000000..f24c3c4b --- /dev/null +++ b/docs/build/html/_sources/mridc.collections.reconstruction.rst.txt @@ -0,0 +1,21 @@ +mridc.collections.reconstruction package +======================================== + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + mridc.collections.reconstruction.data + mridc.collections.reconstruction.metrics + mridc.collections.reconstruction.models + mridc.collections.reconstruction.parts + +Module contents +--------------- + +.. automodule:: mridc.collections.reconstruction + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.collections.rst.txt b/docs/build/html/_sources/mridc.collections.rst.txt new file mode 100644 index 00000000..362d583a --- /dev/null +++ b/docs/build/html/_sources/mridc.collections.rst.txt @@ -0,0 +1,19 @@ +mridc.collections package +========================= + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + mridc.collections.common + mridc.collections.reconstruction + +Module contents +--------------- + +.. 
automodule:: mridc.collections + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.core.classes.rst.txt b/docs/build/html/_sources/mridc.core.classes.rst.txt new file mode 100644 index 00000000..9ffff854 --- /dev/null +++ b/docs/build/html/_sources/mridc.core.classes.rst.txt @@ -0,0 +1,61 @@ +mridc.core.classes package +========================== + +Submodules +---------- + +mridc.core.classes.common module +-------------------------------- + +.. automodule:: mridc.core.classes.common + :members: + :undoc-members: + :show-inheritance: + +mridc.core.classes.dataset module +--------------------------------- + +.. automodule:: mridc.core.classes.dataset + :members: + :undoc-members: + :show-inheritance: + +mridc.core.classes.export module +-------------------------------- + +.. automodule:: mridc.core.classes.export + :members: + :undoc-members: + :show-inheritance: + +mridc.core.classes.loss module +------------------------------ + +.. automodule:: mridc.core.classes.loss + :members: + :undoc-members: + :show-inheritance: + +mridc.core.classes.modelPT module +--------------------------------- + +.. automodule:: mridc.core.classes.modelPT + :members: + :undoc-members: + :show-inheritance: + +mridc.core.classes.module module +-------------------------------- + +.. automodule:: mridc.core.classes.module + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.core.classes + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.core.conf.rst.txt b/docs/build/html/_sources/mridc.core.conf.rst.txt new file mode 100644 index 00000000..efac091d --- /dev/null +++ b/docs/build/html/_sources/mridc.core.conf.rst.txt @@ -0,0 +1,69 @@ +mridc.core.conf package +======================= + +Submodules +---------- + +mridc.core.conf.base\_config module +----------------------------------- + +.. automodule:: mridc.core.conf.base_config + :members: + :undoc-members: + :show-inheritance: + +mridc.core.conf.dataloader module +--------------------------------- + +.. automodule:: mridc.core.conf.dataloader + :members: + :undoc-members: + :show-inheritance: + +mridc.core.conf.hydra\_runner module +------------------------------------ + +.. automodule:: mridc.core.conf.hydra_runner + :members: + :undoc-members: + :show-inheritance: + +mridc.core.conf.modelPT module +------------------------------ + +.. automodule:: mridc.core.conf.modelPT + :members: + :undoc-members: + :show-inheritance: + +mridc.core.conf.optimizers module +--------------------------------- + +.. automodule:: mridc.core.conf.optimizers + :members: + :undoc-members: + :show-inheritance: + +mridc.core.conf.schedulers module +--------------------------------- + +.. automodule:: mridc.core.conf.schedulers + :members: + :undoc-members: + :show-inheritance: + +mridc.core.conf.trainer module +------------------------------ + +.. automodule:: mridc.core.conf.trainer + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. 
automodule:: mridc.core.conf + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.core.connectors.rst.txt b/docs/build/html/_sources/mridc.core.connectors.rst.txt new file mode 100644 index 00000000..e7bfd68b --- /dev/null +++ b/docs/build/html/_sources/mridc.core.connectors.rst.txt @@ -0,0 +1,21 @@ +mridc.core.connectors package +============================= + +Submodules +---------- + +mridc.core.connectors.save\_restore\_connector module +----------------------------------------------------- + +.. automodule:: mridc.core.connectors.save_restore_connector + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.core.connectors + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.core.neural_types.rst.txt b/docs/build/html/_sources/mridc.core.neural_types.rst.txt new file mode 100644 index 00000000..43deaebd --- /dev/null +++ b/docs/build/html/_sources/mridc.core.neural_types.rst.txt @@ -0,0 +1,45 @@ +mridc.core.neural\_types package +================================ + +Submodules +---------- + +mridc.core.neural\_types.axes module +------------------------------------ + +.. automodule:: mridc.core.neural_types.axes + :members: + :undoc-members: + :show-inheritance: + +mridc.core.neural\_types.comparison module +------------------------------------------ + +.. automodule:: mridc.core.neural_types.comparison + :members: + :undoc-members: + :show-inheritance: + +mridc.core.neural\_types.elements module +---------------------------------------- + +.. automodule:: mridc.core.neural_types.elements + :members: + :undoc-members: + :show-inheritance: + +mridc.core.neural\_types.neural\_type module +-------------------------------------------- + +.. automodule:: mridc.core.neural_types.neural_type + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.core.neural_types + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.core.optim.rst.txt b/docs/build/html/_sources/mridc.core.optim.rst.txt new file mode 100644 index 00000000..bb352f50 --- /dev/null +++ b/docs/build/html/_sources/mridc.core.optim.rst.txt @@ -0,0 +1,53 @@ +mridc.core.optim package +======================== + +Submodules +---------- + +mridc.core.optim.adafactor module +--------------------------------- + +.. automodule:: mridc.core.optim.adafactor + :members: + :undoc-members: + :show-inheritance: + +mridc.core.optim.lr\_scheduler module +------------------------------------- + +.. automodule:: mridc.core.optim.lr_scheduler + :members: + :undoc-members: + :show-inheritance: + +mridc.core.optim.novograd module +-------------------------------- + +.. automodule:: mridc.core.optim.novograd + :members: + :undoc-members: + :show-inheritance: + +mridc.core.optim.optimizer\_with\_master\_params module +------------------------------------------------------- + +.. automodule:: mridc.core.optim.optimizer_with_master_params + :members: + :undoc-members: + :show-inheritance: + +mridc.core.optim.optimizers module +---------------------------------- + +.. automodule:: mridc.core.optim.optimizers + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. 
automodule:: mridc.core.optim + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.core.rst.txt b/docs/build/html/_sources/mridc.core.rst.txt new file mode 100644 index 00000000..bfbcf8c0 --- /dev/null +++ b/docs/build/html/_sources/mridc.core.rst.txt @@ -0,0 +1,23 @@ +mridc.core package +================== + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + mridc.core.classes + mridc.core.conf + mridc.core.connectors + mridc.core.neural_types + mridc.core.optim + mridc.core.utils + +Module contents +--------------- + +.. automodule:: mridc.core + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.core.utils.rst.txt b/docs/build/html/_sources/mridc.core.utils.rst.txt new file mode 100644 index 00000000..a87d1362 --- /dev/null +++ b/docs/build/html/_sources/mridc.core.utils.rst.txt @@ -0,0 +1,29 @@ +mridc.core.utils package +======================== + +Submodules +---------- + +mridc.core.utils.neural\_type\_utils module +------------------------------------------- + +.. automodule:: mridc.core.utils.neural_type_utils + :members: + :undoc-members: + :show-inheritance: + +mridc.core.utils.numba\_utils module +------------------------------------ + +.. automodule:: mridc.core.utils.numba_utils + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.core.utils + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.rst.txt b/docs/build/html/_sources/mridc.rst.txt new file mode 100644 index 00000000..09088364 --- /dev/null +++ b/docs/build/html/_sources/mridc.rst.txt @@ -0,0 +1,47 @@ +mridc package +============= + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + mridc.collections + mridc.core + mridc.utils + +Submodules +---------- + +mridc.constants module +---------------------- + +.. automodule:: mridc.constants + :members: + :undoc-members: + :show-inheritance: + +mridc.launch module +------------------- + +.. automodule:: mridc.launch + :members: + :undoc-members: + :show-inheritance: + +mridc.package\_info module +-------------------------- + +.. automodule:: mridc.package_info + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.utils.decorators.rst.txt b/docs/build/html/_sources/mridc.utils.decorators.rst.txt new file mode 100644 index 00000000..5b55330b --- /dev/null +++ b/docs/build/html/_sources/mridc.utils.decorators.rst.txt @@ -0,0 +1,37 @@ +mridc.utils.decorators package +============================== + +Submodules +---------- + +mridc.utils.decorators.deprecated module +---------------------------------------- + +.. automodule:: mridc.utils.decorators.deprecated + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.decorators.experimental module +------------------------------------------ + +.. automodule:: mridc.utils.decorators.experimental + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.decorators.port\_docs module +---------------------------------------- + +.. automodule:: mridc.utils.decorators.port_docs + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. 
automodule:: mridc.utils.decorators + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.utils.formaters.rst.txt b/docs/build/html/_sources/mridc.utils.formaters.rst.txt new file mode 100644 index 00000000..ef8fed13 --- /dev/null +++ b/docs/build/html/_sources/mridc.utils.formaters.rst.txt @@ -0,0 +1,37 @@ +mridc.utils.formaters package +============================= + +Submodules +---------- + +mridc.utils.formaters.base module +--------------------------------- + +.. automodule:: mridc.utils.formaters.base + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.formaters.colors module +----------------------------------- + +.. automodule:: mridc.utils.formaters.colors + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.formaters.utils module +---------------------------------- + +.. automodule:: mridc.utils.formaters.utils + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.utils.formaters + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/mridc.utils.rst.txt b/docs/build/html/_sources/mridc.utils.rst.txt new file mode 100644 index 00000000..dea4ead2 --- /dev/null +++ b/docs/build/html/_sources/mridc.utils.rst.txt @@ -0,0 +1,142 @@ +mridc.utils package +=================== + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + mridc.utils.decorators + mridc.utils.formaters + +Submodules +---------- + +mridc.utils.app\_state module +----------------------------- + +.. automodule:: mridc.utils.app_state + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.arguments module +---------------------------- + +.. automodule:: mridc.utils.arguments + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.cloud module +------------------------ + +.. automodule:: mridc.utils.cloud + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.config\_utils module +-------------------------------- + +.. automodule:: mridc.utils.config_utils + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.distributed module +------------------------------ + +.. automodule:: mridc.utils.distributed + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.env\_var\_parsing module +------------------------------------ + +.. automodule:: mridc.utils.env_var_parsing + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.exceptions module +----------------------------- + +.. automodule:: mridc.utils.exceptions + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.exp\_manager module +------------------------------- + +.. automodule:: mridc.utils.exp_manager + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.export\_utils module +-------------------------------- + +.. automodule:: mridc.utils.export_utils + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.get\_rank module +---------------------------- + +.. automodule:: mridc.utils.get_rank + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.lightning\_logger\_patch module +------------------------------------------- + +.. automodule:: mridc.utils.lightning_logger_patch + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.metaclasses module +------------------------------ + +.. automodule:: mridc.utils.metaclasses + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.model\_utils module +------------------------------- + +.. 
automodule:: mridc.utils.model_utils + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.mridc\_logging module +--------------------------------- + +.. automodule:: mridc.utils.mridc_logging + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.timers module +------------------------- + +.. automodule:: mridc.utils.timers + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.utils + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_static/basic.css b/docs/build/html/_static/basic.css new file mode 100644 index 00000000..1c769a56 --- /dev/null +++ b/docs/build/html/_static/basic.css @@ -0,0 +1,906 @@ +/* + * basic.css + * ~~~~~~~~~ + * + * Sphinx stylesheet -- basic theme. + * + * :copyright: Copyright 2007-2022 by the Sphinx team, see AUTHORS. + * :license: BSD, see LICENSE for details. + * + */ + +/* -- main layout ----------------------------------------------------------- */ + +div.clearer { + clear: both; +} + +div.section::after { + display: block; + content: ''; + clear: left; +} + +/* -- relbar ---------------------------------------------------------------- */ + +div.related { + width: 100%; + font-size: 90%; +} + +div.related h3 { + display: none; +} + +div.related ul { + margin: 0; + padding: 0 0 0 10px; + list-style: none; +} + +div.related li { + display: inline; +} + +div.related li.right { + float: right; + margin-right: 5px; +} + +/* -- sidebar --------------------------------------------------------------- */ + +div.sphinxsidebarwrapper { + padding: 10px 5px 0 10px; +} + +div.sphinxsidebar { + float: left; + width: 230px; + margin-left: -100%; + font-size: 90%; + word-wrap: break-word; + overflow-wrap : break-word; +} + +div.sphinxsidebar ul { + list-style: none; +} + +div.sphinxsidebar ul ul, +div.sphinxsidebar ul.want-points { + margin-left: 20px; + list-style: square; +} + +div.sphinxsidebar ul ul { + margin-top: 0; + margin-bottom: 0; +} + +div.sphinxsidebar form { + margin-top: 10px; +} + +div.sphinxsidebar input { + border: 1px solid #98dbcc; + font-family: sans-serif; + font-size: 1em; +} + +div.sphinxsidebar #searchbox form.search { + overflow: hidden; +} + +div.sphinxsidebar #searchbox input[type="text"] { + float: left; + width: 80%; + padding: 0.25em; + box-sizing: border-box; +} + +div.sphinxsidebar #searchbox input[type="submit"] { + float: left; + width: 20%; + border-left: none; + padding: 0.25em; + box-sizing: border-box; +} + + +img { + border: 0; + max-width: 100%; +} + +/* -- search page ----------------------------------------------------------- */ + +ul.search { + margin: 10px 0 0 20px; + padding: 0; +} + +ul.search li { + padding: 5px 0 5px 20px; + background-image: url(file.png); + background-repeat: no-repeat; + background-position: 0 7px; +} + +ul.search li a { + font-weight: bold; +} + +ul.search li p.context { + color: #888; + margin: 2px 0 0 30px; + text-align: left; +} + +ul.keywordmatches li.goodmatch a { + font-weight: bold; +} + +/* -- index page ------------------------------------------------------------ */ + +table.contentstable { + width: 90%; + margin-left: auto; + margin-right: auto; +} + +table.contentstable p.biglink { + line-height: 150%; +} + +a.biglink { + font-size: 1.3em; +} + +span.linkdescr { + font-style: italic; + padding-top: 5px; + font-size: 90%; +} + +/* -- general index --------------------------------------------------------- */ + +table.indextable { + width: 100%; +} + +table.indextable td { + text-align: 
left; + vertical-align: top; +} + +table.indextable ul { + margin-top: 0; + margin-bottom: 0; + list-style-type: none; +} + +table.indextable > tbody > tr > td > ul { + padding-left: 0em; +} + +table.indextable tr.pcap { + height: 10px; +} + +table.indextable tr.cap { + margin-top: 10px; + background-color: #f2f2f2; +} + +img.toggler { + margin-right: 3px; + margin-top: 3px; + cursor: pointer; +} + +div.modindex-jumpbox { + border-top: 1px solid #ddd; + border-bottom: 1px solid #ddd; + margin: 1em 0 1em 0; + padding: 0.4em; +} + +div.genindex-jumpbox { + border-top: 1px solid #ddd; + border-bottom: 1px solid #ddd; + margin: 1em 0 1em 0; + padding: 0.4em; +} + +/* -- domain module index --------------------------------------------------- */ + +table.modindextable td { + padding: 2px; + border-collapse: collapse; +} + +/* -- general body styles --------------------------------------------------- */ + +div.body { + min-width: 450px; + max-width: 800px; +} + +div.body p, div.body dd, div.body li, div.body blockquote { + -moz-hyphens: auto; + -ms-hyphens: auto; + -webkit-hyphens: auto; + hyphens: auto; +} + +a.headerlink { + visibility: hidden; +} + +a.brackets:before, +span.brackets > a:before{ + content: "["; +} + +a.brackets:after, +span.brackets > a:after { + content: "]"; +} + +h1:hover > a.headerlink, +h2:hover > a.headerlink, +h3:hover > a.headerlink, +h4:hover > a.headerlink, +h5:hover > a.headerlink, +h6:hover > a.headerlink, +dt:hover > a.headerlink, +caption:hover > a.headerlink, +p.caption:hover > a.headerlink, +div.code-block-caption:hover > a.headerlink { + visibility: visible; +} + +div.body p.caption { + text-align: inherit; +} + +div.body td { + text-align: left; +} + +.first { + margin-top: 0 !important; +} + +p.rubric { + margin-top: 30px; + font-weight: bold; +} + +img.align-left, figure.align-left, .figure.align-left, object.align-left { + clear: left; + float: left; + margin-right: 1em; +} + +img.align-right, figure.align-right, .figure.align-right, object.align-right { + clear: right; + float: right; + margin-left: 1em; +} + +img.align-center, figure.align-center, .figure.align-center, object.align-center { + display: block; + margin-left: auto; + margin-right: auto; +} + +img.align-default, figure.align-default, .figure.align-default { + display: block; + margin-left: auto; + margin-right: auto; +} + +.align-left { + text-align: left; +} + +.align-center { + text-align: center; +} + +.align-default { + text-align: center; +} + +.align-right { + text-align: right; +} + +/* -- sidebars -------------------------------------------------------------- */ + +div.sidebar, +aside.sidebar { + margin: 0 0 0.5em 1em; + border: 1px solid #ddb; + padding: 7px; + background-color: #ffe; + width: 40%; + float: right; + clear: right; + overflow-x: auto; +} + +p.sidebar-title { + font-weight: bold; +} + +div.admonition, div.topic, blockquote { + clear: left; +} + +/* -- topics ---------------------------------------------------------------- */ + +div.topic { + border: 1px solid #ccc; + padding: 7px; + margin: 10px 0 10px 0; +} + +p.topic-title { + font-size: 1.1em; + font-weight: bold; + margin-top: 10px; +} + +/* -- admonitions ----------------------------------------------------------- */ + +div.admonition { + margin-top: 10px; + margin-bottom: 10px; + padding: 7px; +} + +div.admonition dt { + font-weight: bold; +} + +p.admonition-title { + margin: 0px 10px 5px 0px; + font-weight: bold; +} + +div.body p.centered { + text-align: center; + margin-top: 25px; +} + +/* -- content of 
sidebars/topics/admonitions -------------------------------- */ + +div.sidebar > :last-child, +aside.sidebar > :last-child, +div.topic > :last-child, +div.admonition > :last-child { + margin-bottom: 0; +} + +div.sidebar::after, +aside.sidebar::after, +div.topic::after, +div.admonition::after, +blockquote::after { + display: block; + content: ''; + clear: both; +} + +/* -- tables ---------------------------------------------------------------- */ + +table.docutils { + margin-top: 10px; + margin-bottom: 10px; + border: 0; + border-collapse: collapse; +} + +table.align-center { + margin-left: auto; + margin-right: auto; +} + +table.align-default { + margin-left: auto; + margin-right: auto; +} + +table caption span.caption-number { + font-style: italic; +} + +table caption span.caption-text { +} + +table.docutils td, table.docutils th { + padding: 1px 8px 1px 5px; + border-top: 0; + border-left: 0; + border-right: 0; + border-bottom: 1px solid #aaa; +} + +table.footnote td, table.footnote th { + border: 0 !important; +} + +th { + text-align: left; + padding-right: 5px; +} + +table.citation { + border-left: solid 1px gray; + margin-left: 1px; +} + +table.citation td { + border-bottom: none; +} + +th > :first-child, +td > :first-child { + margin-top: 0px; +} + +th > :last-child, +td > :last-child { + margin-bottom: 0px; +} + +/* -- figures --------------------------------------------------------------- */ + +div.figure, figure { + margin: 0.5em; + padding: 0.5em; +} + +div.figure p.caption, figcaption { + padding: 0.3em; +} + +div.figure p.caption span.caption-number, +figcaption span.caption-number { + font-style: italic; +} + +div.figure p.caption span.caption-text, +figcaption span.caption-text { +} + +/* -- field list styles ----------------------------------------------------- */ + +table.field-list td, table.field-list th { + border: 0 !important; +} + +.field-list ul { + margin: 0; + padding-left: 1em; +} + +.field-list p { + margin: 0; +} + +.field-name { + -moz-hyphens: manual; + -ms-hyphens: manual; + -webkit-hyphens: manual; + hyphens: manual; +} + +/* -- hlist styles ---------------------------------------------------------- */ + +table.hlist { + margin: 1em 0; +} + +table.hlist td { + vertical-align: top; +} + +/* -- object description styles --------------------------------------------- */ + +.sig { + font-family: 'Consolas', 'Menlo', 'DejaVu Sans Mono', 'Bitstream Vera Sans Mono', monospace; +} + +.sig-name, code.descname { + background-color: transparent; + font-weight: bold; +} + +.sig-name { + font-size: 1.1em; +} + +code.descname { + font-size: 1.2em; +} + +.sig-prename, code.descclassname { + background-color: transparent; +} + +.optional { + font-size: 1.3em; +} + +.sig-paren { + font-size: larger; +} + +.sig-param.n { + font-style: italic; +} + +/* C++ specific styling */ + +.sig-inline.c-texpr, +.sig-inline.cpp-texpr { + font-family: unset; +} + +.sig.c .k, .sig.c .kt, +.sig.cpp .k, .sig.cpp .kt { + color: #0033B3; +} + +.sig.c .m, +.sig.cpp .m { + color: #1750EB; +} + +.sig.c .s, .sig.c .sc, +.sig.cpp .s, .sig.cpp .sc { + color: #067D17; +} + + +/* -- other body styles ----------------------------------------------------- */ + +ol.arabic { + list-style: decimal; +} + +ol.loweralpha { + list-style: lower-alpha; +} + +ol.upperalpha { + list-style: upper-alpha; +} + +ol.lowerroman { + list-style: lower-roman; +} + +ol.upperroman { + list-style: upper-roman; +} + +:not(li) > ol > li:first-child > :first-child, +:not(li) > ul > li:first-child > :first-child { + margin-top: 
0px; +} + +:not(li) > ol > li:last-child > :last-child, +:not(li) > ul > li:last-child > :last-child { + margin-bottom: 0px; +} + +ol.simple ol p, +ol.simple ul p, +ul.simple ol p, +ul.simple ul p { + margin-top: 0; +} + +ol.simple > li:not(:first-child) > p, +ul.simple > li:not(:first-child) > p { + margin-top: 0; +} + +ol.simple p, +ul.simple p { + margin-bottom: 0; +} + +dl.footnote > dt, +dl.citation > dt { + float: left; + margin-right: 0.5em; +} + +dl.footnote > dd, +dl.citation > dd { + margin-bottom: 0em; +} + +dl.footnote > dd:after, +dl.citation > dd:after { + content: ""; + clear: both; +} + +dl.field-list { + display: grid; + grid-template-columns: fit-content(30%) auto; +} + +dl.field-list > dt { + font-weight: bold; + word-break: break-word; + padding-left: 0.5em; + padding-right: 5px; +} + +dl.field-list > dt:after { + content: ":"; +} + +dl.field-list > dd { + padding-left: 0.5em; + margin-top: 0em; + margin-left: 0em; + margin-bottom: 0em; +} + +dl { + margin-bottom: 15px; +} + +dd > :first-child { + margin-top: 0px; +} + +dd ul, dd table { + margin-bottom: 10px; +} + +dd { + margin-top: 3px; + margin-bottom: 10px; + margin-left: 30px; +} + +dl > dd:last-child, +dl > dd:last-child > :last-child { + margin-bottom: 0; +} + +dt:target, span.highlighted { + background-color: #fbe54e; +} + +rect.highlighted { + fill: #fbe54e; +} + +dl.glossary dt { + font-weight: bold; + font-size: 1.1em; +} + +.versionmodified { + font-style: italic; +} + +.system-message { + background-color: #fda; + padding: 5px; + border: 3px solid red; +} + +.footnote:target { + background-color: #ffa; +} + +.line-block { + display: block; + margin-top: 1em; + margin-bottom: 1em; +} + +.line-block .line-block { + margin-top: 0; + margin-bottom: 0; + margin-left: 1.5em; +} + +.guilabel, .menuselection { + font-family: sans-serif; +} + +.accelerator { + text-decoration: underline; +} + +.classifier { + font-style: oblique; +} + +.classifier:before { + font-style: normal; + margin: 0 0.5em; + content: ":"; + display: inline-block; +} + +abbr, acronym { + border-bottom: dotted 1px; + cursor: help; +} + +/* -- code displays --------------------------------------------------------- */ + +pre { + overflow: auto; + overflow-y: hidden; /* fixes display issues on Chrome browsers */ +} + +pre, div[class*="highlight-"] { + clear: both; +} + +span.pre { + -moz-hyphens: none; + -ms-hyphens: none; + -webkit-hyphens: none; + hyphens: none; + white-space: nowrap; +} + +div[class*="highlight-"] { + margin: 1em 0; +} + +td.linenos pre { + border: 0; + background-color: transparent; + color: #aaa; +} + +table.highlighttable { + display: block; +} + +table.highlighttable tbody { + display: block; +} + +table.highlighttable tr { + display: flex; +} + +table.highlighttable td { + margin: 0; + padding: 0; +} + +table.highlighttable td.linenos { + padding-right: 0.5em; +} + +table.highlighttable td.code { + flex: 1; + overflow: hidden; +} + +.highlight .hll { + display: block; +} + +div.highlight pre, +table.highlighttable pre { + margin: 0; +} + +div.code-block-caption + div { + margin-top: 0; +} + +div.code-block-caption { + margin-top: 1em; + padding: 2px 5px; + font-size: small; +} + +div.code-block-caption code { + background-color: transparent; +} + +table.highlighttable td.linenos, +span.linenos, +div.highlight span.gp { /* gp: Generic.Prompt */ + user-select: none; + -webkit-user-select: text; /* Safari fallback only */ + -webkit-user-select: none; /* Chrome/Safari */ + -moz-user-select: none; /* Firefox */ + 
-ms-user-select: none; /* IE10+ */ +} + +div.code-block-caption span.caption-number { + padding: 0.1em 0.3em; + font-style: italic; +} + +div.code-block-caption span.caption-text { +} + +div.literal-block-wrapper { + margin: 1em 0; +} + +code.xref, a code { + background-color: transparent; + font-weight: bold; +} + +h1 code, h2 code, h3 code, h4 code, h5 code, h6 code { + background-color: transparent; +} + +.viewcode-link { + float: right; +} + +.viewcode-back { + float: right; + font-family: sans-serif; +} + +div.viewcode-block:target { + margin: -1px -10px; + padding: 0 10px; +} + +/* -- math display ---------------------------------------------------------- */ + +img.math { + vertical-align: middle; +} + +div.body div.math p { + text-align: center; +} + +span.eqno { + float: right; +} + +span.eqno a.headerlink { + position: absolute; + z-index: 1; +} + +div.math:hover a.headerlink { + visibility: visible; +} + +/* -- printout stylesheet --------------------------------------------------- */ + +@media print { + div.document, + div.documentwrapper, + div.bodywrapper { + margin: 0 !important; + width: 100%; + } + + div.sphinxsidebar, + div.related, + div.footer, + #top-link { + display: none; + } +} diff --git a/docs/build/html/_static/css/badge_only.css b/docs/build/html/_static/css/badge_only.css new file mode 100644 index 00000000..b1a0fbfe --- /dev/null +++ b/docs/build/html/_static/css/badge_only.css @@ -0,0 +1 @@ +.fa:before{-webkit-font-smoothing:antialiased}.clearfix{*zoom:1}.clearfix:after,.clearfix:before{display:table;content:""}.clearfix:after{clear:both}@font-face{font-family:FontAwesome;font-style:normal;font-weight:400;src:url(fonts/fontawesome-webfont.eot?674f50d287a8c48dc19ba404d20fe713?#iefix) format("embedded-opentype"),url(fonts/fontawesome-webfont.woff2?af7ae505a9eed503f8b8e6982036873e) format("woff2"),url(fonts/fontawesome-webfont.woff?fee66e712a8a08eef5805a46892932ad) format("woff"),url(fonts/fontawesome-webfont.ttf?b06871f281fee6b241d60582ae9369b9) format("truetype"),url(fonts/fontawesome-webfont.svg?912ec66d7572ff821749319396470bde#FontAwesome) format("svg")}.fa:before{font-family:FontAwesome;font-style:normal;font-weight:400;line-height:1}.fa:before,a .fa{text-decoration:inherit}.fa:before,a .fa,li .fa{display:inline-block}li .fa-large:before{width:1.875em}ul.fas{list-style-type:none;margin-left:2em;text-indent:-.8em}ul.fas li .fa{width:.8em}ul.fas li .fa-large:before{vertical-align:baseline}.fa-book:before,.icon-book:before{content:"\f02d"}.fa-caret-down:before,.icon-caret-down:before{content:"\f0d7"}.fa-caret-up:before,.icon-caret-up:before{content:"\f0d8"}.fa-caret-left:before,.icon-caret-left:before{content:"\f0d9"}.fa-caret-right:before,.icon-caret-right:before{content:"\f0da"}.rst-versions{position:fixed;bottom:0;left:0;width:300px;color:#fcfcfc;background:#1f1d1d;font-family:Lato,proxima-nova,Helvetica Neue,Arial,sans-serif;z-index:400}.rst-versions a{color:#2980b9;text-decoration:none}.rst-versions .rst-badge-small{display:none}.rst-versions .rst-current-version{padding:12px;background-color:#272525;display:block;text-align:right;font-size:90%;cursor:pointer;color:#27ae60}.rst-versions .rst-current-version:after{clear:both;content:"";display:block}.rst-versions .rst-current-version .fa{color:#fcfcfc}.rst-versions .rst-current-version .fa-book,.rst-versions .rst-current-version .icon-book{float:left}.rst-versions .rst-current-version.rst-out-of-date{background-color:#e74c3c;color:#fff}.rst-versions 
.rst-current-version.rst-active-old-version{background-color:#f1c40f;color:#000}.rst-versions.shift-up{height:auto;max-height:100%;overflow-y:scroll}.rst-versions.shift-up .rst-other-versions{display:block}.rst-versions .rst-other-versions{font-size:90%;padding:12px;color:grey;display:none}.rst-versions .rst-other-versions hr{display:block;height:1px;border:0;margin:20px 0;padding:0;border-top:1px solid #413d3d}.rst-versions .rst-other-versions dd{display:inline-block;margin:0}.rst-versions .rst-other-versions dd a{display:inline-block;padding:6px;color:#fcfcfc}.rst-versions.rst-badge{width:auto;bottom:20px;right:20px;left:auto;border:none;max-width:300px;max-height:90%}.rst-versions.rst-badge .fa-book,.rst-versions.rst-badge .icon-book{float:none;line-height:30px}.rst-versions.rst-badge.shift-up .rst-current-version{text-align:right}.rst-versions.rst-badge.shift-up .rst-current-version .fa-book,.rst-versions.rst-badge.shift-up .rst-current-version .icon-book{float:left}.rst-versions.rst-badge>.rst-current-version{width:auto;height:30px;line-height:30px;padding:0 6px;display:block;text-align:center}@media screen and (max-width:768px){.rst-versions{width:85%;display:none}.rst-versions.shift{display:block}} diff --git a/docs/build/html/_static/css/fonts/Roboto-Slab-Bold.woff b/docs/build/html/_static/css/fonts/Roboto-Slab-Bold.woff new file mode 100644 index 00000000..6cb60000 Binary files /dev/null and b/docs/build/html/_static/css/fonts/Roboto-Slab-Bold.woff differ diff --git a/docs/build/html/_static/css/fonts/Roboto-Slab-Bold.woff2 b/docs/build/html/_static/css/fonts/Roboto-Slab-Bold.woff2 new file mode 100644 index 00000000..7059e231 Binary files /dev/null and b/docs/build/html/_static/css/fonts/Roboto-Slab-Bold.woff2 differ diff --git a/docs/build/html/_static/css/fonts/Roboto-Slab-Regular.woff b/docs/build/html/_static/css/fonts/Roboto-Slab-Regular.woff new file mode 100644 index 00000000..f815f63f Binary files /dev/null and b/docs/build/html/_static/css/fonts/Roboto-Slab-Regular.woff differ diff --git a/docs/build/html/_static/css/fonts/Roboto-Slab-Regular.woff2 b/docs/build/html/_static/css/fonts/Roboto-Slab-Regular.woff2 new file mode 100644 index 00000000..f2c76e5b Binary files /dev/null and b/docs/build/html/_static/css/fonts/Roboto-Slab-Regular.woff2 differ diff --git a/docs/build/html/_static/css/fonts/fontawesome-webfont.eot b/docs/build/html/_static/css/fonts/fontawesome-webfont.eot new file mode 100644 index 00000000..e9f60ca9 Binary files /dev/null and b/docs/build/html/_static/css/fonts/fontawesome-webfont.eot differ diff --git a/docs/build/html/_static/css/fonts/fontawesome-webfont.svg b/docs/build/html/_static/css/fonts/fontawesome-webfont.svg new file mode 100644 index 00000000..52c07733 --- /dev/null +++ b/docs/build/html/_static/css/fonts/fontawesome-webfont.svg @@ -0,0 +1,2671 @@ +Created by FontForge 20120731 at Mon Oct 24 17:37:40 2016 + By ,,, +Copyright Dave Gandy 2016. All rights reserved. [... SVG markup and glyph definitions omitted ...] diff --git a/docs/build/html/_static/css/fonts/fontawesome-webfont.ttf b/docs/build/html/_static/css/fonts/fontawesome-webfont.ttf new file mode 100644 index 00000000..35acda2f Binary files /dev/null and b/docs/build/html/_static/css/fonts/fontawesome-webfont.ttf differ diff --git a/docs/build/html/_static/css/fonts/fontawesome-webfont.woff b/docs/build/html/_static/css/fonts/fontawesome-webfont.woff new file mode 100644 index 00000000..400014a4 Binary files /dev/null and b/docs/build/html/_static/css/fonts/fontawesome-webfont.woff differ diff --git a/docs/build/html/_static/css/fonts/fontawesome-webfont.woff2 b/docs/build/html/_static/css/fonts/fontawesome-webfont.woff2 new file mode 100644 index 00000000..4d13fc60 Binary files /dev/null and b/docs/build/html/_static/css/fonts/fontawesome-webfont.woff2 differ diff --git a/docs/build/html/_static/css/fonts/lato-bold-italic.woff b/docs/build/html/_static/css/fonts/lato-bold-italic.woff new file mode 100644 index 00000000..88ad05b9 Binary files /dev/null and b/docs/build/html/_static/css/fonts/lato-bold-italic.woff differ diff --git a/docs/build/html/_static/css/fonts/lato-bold-italic.woff2 b/docs/build/html/_static/css/fonts/lato-bold-italic.woff2 new file mode 100644 index 00000000..c4e3d804 Binary files /dev/null and b/docs/build/html/_static/css/fonts/lato-bold-italic.woff2 differ diff --git a/docs/build/html/_static/css/fonts/lato-bold.woff b/docs/build/html/_static/css/fonts/lato-bold.woff new file mode 100644 index 00000000..c6dff51f Binary files /dev/null and b/docs/build/html/_static/css/fonts/lato-bold.woff differ diff --git a/docs/build/html/_static/css/fonts/lato-bold.woff2 b/docs/build/html/_static/css/fonts/lato-bold.woff2 new file mode 100644 index 00000000..bb195043 Binary files /dev/null and b/docs/build/html/_static/css/fonts/lato-bold.woff2 differ diff --git a/docs/build/html/_static/css/fonts/lato-normal-italic.woff b/docs/build/html/_static/css/fonts/lato-normal-italic.woff new file mode 100644 index 00000000..76114bc0 Binary files /dev/null and b/docs/build/html/_static/css/fonts/lato-normal-italic.woff differ diff
--git a/docs/build/html/_static/css/fonts/lato-normal-italic.woff2 b/docs/build/html/_static/css/fonts/lato-normal-italic.woff2 new file mode 100644 index 00000000..3404f37e Binary files /dev/null and b/docs/build/html/_static/css/fonts/lato-normal-italic.woff2 differ diff --git a/docs/build/html/_static/css/fonts/lato-normal.woff b/docs/build/html/_static/css/fonts/lato-normal.woff new file mode 100644 index 00000000..ae1307ff Binary files /dev/null and b/docs/build/html/_static/css/fonts/lato-normal.woff differ diff --git a/docs/build/html/_static/css/fonts/lato-normal.woff2 b/docs/build/html/_static/css/fonts/lato-normal.woff2 new file mode 100644 index 00000000..3bf98433 Binary files /dev/null and b/docs/build/html/_static/css/fonts/lato-normal.woff2 differ diff --git a/docs/build/html/_static/css/theme.css b/docs/build/html/_static/css/theme.css new file mode 100644 index 00000000..42d82ec2 --- /dev/null +++ b/docs/build/html/_static/css/theme.css @@ -0,0 +1,4 @@ +html{box-sizing:border-box}*,:after,:before{box-sizing:inherit}article,aside,details,figcaption,figure,footer,header,hgroup,nav,section{display:block}audio,canvas,video{display:inline-block;*display:inline;*zoom:1}[hidden],audio:not([controls]){display:none}*{-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box}html{font-size:100%;-webkit-text-size-adjust:100%;-ms-text-size-adjust:100%}body{margin:0}a:active,a:hover{outline:0}abbr[title]{border-bottom:1px dotted}b,strong{font-weight:700}blockquote{margin:0}dfn{font-style:italic}ins{background:#ff9;text-decoration:none}ins,mark{color:#000}mark{background:#ff0;font-style:italic;font-weight:700}.rst-content code,.rst-content tt,code,kbd,pre,samp{font-family:monospace,serif;_font-family:courier new,monospace;font-size:1em}pre{white-space:pre}q{quotes:none}q:after,q:before{content:"";content:none}small{font-size:85%}sub,sup{font-size:75%;line-height:0;position:relative;vertical-align:baseline}sup{top:-.5em}sub{bottom:-.25em}dl,ol,ul{margin:0;padding:0;list-style:none;list-style-image:none}li{list-style:none}dd{margin:0}img{border:0;-ms-interpolation-mode:bicubic;vertical-align:middle;max-width:100%}svg:not(:root){overflow:hidden}figure,form{margin:0}label{cursor:pointer}button,input,select,textarea{font-size:100%;margin:0;vertical-align:baseline;*vertical-align:middle}button,input{line-height:normal}button,input[type=button],input[type=reset],input[type=submit]{cursor:pointer;-webkit-appearance:button;*overflow:visible}button[disabled],input[disabled]{cursor:default}input[type=search]{-webkit-appearance:textfield;-moz-box-sizing:content-box;-webkit-box-sizing:content-box;box-sizing:content-box}textarea{resize:vertical}table{border-collapse:collapse;border-spacing:0}td{vertical-align:top}.chromeframe{margin:.2em 0;background:#ccc;color:#000;padding:.2em 0}.ir{display:block;border:0;text-indent:-999em;overflow:hidden;background-color:transparent;background-repeat:no-repeat;text-align:left;direction:ltr;*line-height:0}.ir br{display:none}.hidden{display:none!important;visibility:hidden}.visuallyhidden{border:0;clip:rect(0 0 0 0);height:1px;margin:-1px;overflow:hidden;padding:0;position:absolute;width:1px}.visuallyhidden.focusable:active,.visuallyhidden.focusable:focus{clip:auto;height:auto;margin:0;overflow:visible;position:static;width:auto}.invisible{visibility:hidden}.relative{position:relative}big,small{font-size:100%}@media 
print{body,html,section{background:none!important}*{box-shadow:none!important;text-shadow:none!important;filter:none!important;-ms-filter:none!important}a,a:visited{text-decoration:underline}.ir a:after,a[href^="#"]:after,a[href^="javascript:"]:after{content:""}blockquote,pre{page-break-inside:avoid}thead{display:table-header-group}img,tr{page-break-inside:avoid}img{max-width:100%!important}@page{margin:.5cm}.rst-content .toctree-wrapper>p.caption,h2,h3,p{orphans:3;widows:3}.rst-content .toctree-wrapper>p.caption,h2,h3{page-break-after:avoid}}.btn,.fa:before,.icon:before,.rst-content .admonition,.rst-content .admonition-title:before,.rst-content .admonition-todo,.rst-content .attention,.rst-content .caution,.rst-content .code-block-caption .headerlink:before,.rst-content .danger,.rst-content .eqno .headerlink:before,.rst-content .error,.rst-content .hint,.rst-content .important,.rst-content .note,.rst-content .seealso,.rst-content .tip,.rst-content .warning,.rst-content code.download span:first-child:before,.rst-content dl dt .headerlink:before,.rst-content h1 .headerlink:before,.rst-content h2 .headerlink:before,.rst-content h3 .headerlink:before,.rst-content h4 .headerlink:before,.rst-content h5 .headerlink:before,.rst-content h6 .headerlink:before,.rst-content p.caption .headerlink:before,.rst-content p .headerlink:before,.rst-content table>caption .headerlink:before,.rst-content tt.download span:first-child:before,.wy-alert,.wy-dropdown .caret:before,.wy-inline-validate.wy-inline-validate-danger .wy-input-context:before,.wy-inline-validate.wy-inline-validate-info .wy-input-context:before,.wy-inline-validate.wy-inline-validate-success .wy-input-context:before,.wy-inline-validate.wy-inline-validate-warning .wy-input-context:before,.wy-menu-vertical li.current>a,.wy-menu-vertical li.current>a button.toctree-expand:before,.wy-menu-vertical li.on a,.wy-menu-vertical li.on a button.toctree-expand:before,.wy-menu-vertical li button.toctree-expand:before,.wy-nav-top a,.wy-side-nav-search .wy-dropdown>a,.wy-side-nav-search>a,input[type=color],input[type=date],input[type=datetime-local],input[type=datetime],input[type=email],input[type=month],input[type=number],input[type=password],input[type=search],input[type=tel],input[type=text],input[type=time],input[type=url],input[type=week],select,textarea{-webkit-font-smoothing:antialiased}.clearfix{*zoom:1}.clearfix:after,.clearfix:before{display:table;content:""}.clearfix:after{clear:both}/*! 
+ * Font Awesome 4.7.0 by @davegandy - http://fontawesome.io - @fontawesome + * License - http://fontawesome.io/license (Font: SIL OFL 1.1, CSS: MIT License) + */@font-face{font-family:FontAwesome;src:url(fonts/fontawesome-webfont.eot?674f50d287a8c48dc19ba404d20fe713);src:url(fonts/fontawesome-webfont.eot?674f50d287a8c48dc19ba404d20fe713?#iefix&v=4.7.0) format("embedded-opentype"),url(fonts/fontawesome-webfont.woff2?af7ae505a9eed503f8b8e6982036873e) format("woff2"),url(fonts/fontawesome-webfont.woff?fee66e712a8a08eef5805a46892932ad) format("woff"),url(fonts/fontawesome-webfont.ttf?b06871f281fee6b241d60582ae9369b9) format("truetype"),url(fonts/fontawesome-webfont.svg?912ec66d7572ff821749319396470bde#fontawesomeregular) format("svg");font-weight:400;font-style:normal}.fa,.icon,.rst-content .admonition-title,.rst-content .code-block-caption .headerlink,.rst-content .eqno .headerlink,.rst-content code.download span:first-child,.rst-content dl dt .headerlink,.rst-content h1 .headerlink,.rst-content h2 .headerlink,.rst-content h3 .headerlink,.rst-content h4 .headerlink,.rst-content h5 .headerlink,.rst-content h6 .headerlink,.rst-content p.caption .headerlink,.rst-content p .headerlink,.rst-content table>caption .headerlink,.rst-content tt.download span:first-child,.wy-menu-vertical li.current>a button.toctree-expand,.wy-menu-vertical li.on a button.toctree-expand,.wy-menu-vertical li button.toctree-expand{display:inline-block;font:normal normal normal 14px/1 FontAwesome;font-size:inherit;text-rendering:auto;-webkit-font-smoothing:antialiased;-moz-osx-font-smoothing:grayscale}.fa-lg{font-size:1.33333em;line-height:.75em;vertical-align:-15%}.fa-2x{font-size:2em}.fa-3x{font-size:3em}.fa-4x{font-size:4em}.fa-5x{font-size:5em}.fa-fw{width:1.28571em;text-align:center}.fa-ul{padding-left:0;margin-left:2.14286em;list-style-type:none}.fa-ul>li{position:relative}.fa-li{position:absolute;left:-2.14286em;width:2.14286em;top:.14286em;text-align:center}.fa-li.fa-lg{left:-1.85714em}.fa-border{padding:.2em .25em .15em;border:.08em solid #eee;border-radius:.1em}.fa-pull-left{float:left}.fa-pull-right{float:right}.fa-pull-left.icon,.fa.fa-pull-left,.rst-content .code-block-caption .fa-pull-left.headerlink,.rst-content .eqno .fa-pull-left.headerlink,.rst-content .fa-pull-left.admonition-title,.rst-content code.download span.fa-pull-left:first-child,.rst-content dl dt .fa-pull-left.headerlink,.rst-content h1 .fa-pull-left.headerlink,.rst-content h2 .fa-pull-left.headerlink,.rst-content h3 .fa-pull-left.headerlink,.rst-content h4 .fa-pull-left.headerlink,.rst-content h5 .fa-pull-left.headerlink,.rst-content h6 .fa-pull-left.headerlink,.rst-content p .fa-pull-left.headerlink,.rst-content table>caption .fa-pull-left.headerlink,.rst-content tt.download span.fa-pull-left:first-child,.wy-menu-vertical li.current>a button.fa-pull-left.toctree-expand,.wy-menu-vertical li.on a button.fa-pull-left.toctree-expand,.wy-menu-vertical li button.fa-pull-left.toctree-expand{margin-right:.3em}.fa-pull-right.icon,.fa.fa-pull-right,.rst-content .code-block-caption .fa-pull-right.headerlink,.rst-content .eqno .fa-pull-right.headerlink,.rst-content .fa-pull-right.admonition-title,.rst-content code.download span.fa-pull-right:first-child,.rst-content dl dt .fa-pull-right.headerlink,.rst-content h1 .fa-pull-right.headerlink,.rst-content h2 .fa-pull-right.headerlink,.rst-content h3 .fa-pull-right.headerlink,.rst-content h4 .fa-pull-right.headerlink,.rst-content h5 .fa-pull-right.headerlink,.rst-content h6 
.fa-pull-right.headerlink,.rst-content p .fa-pull-right.headerlink,.rst-content table>caption .fa-pull-right.headerlink,.rst-content tt.download span.fa-pull-right:first-child,.wy-menu-vertical li.current>a button.fa-pull-right.toctree-expand,.wy-menu-vertical li.on a button.fa-pull-right.toctree-expand,.wy-menu-vertical li button.fa-pull-right.toctree-expand{margin-left:.3em}.pull-right{float:right}.pull-left{float:left}.fa.pull-left,.pull-left.icon,.rst-content .code-block-caption .pull-left.headerlink,.rst-content .eqno .pull-left.headerlink,.rst-content .pull-left.admonition-title,.rst-content code.download span.pull-left:first-child,.rst-content dl dt .pull-left.headerlink,.rst-content h1 .pull-left.headerlink,.rst-content h2 .pull-left.headerlink,.rst-content h3 .pull-left.headerlink,.rst-content h4 .pull-left.headerlink,.rst-content h5 .pull-left.headerlink,.rst-content h6 .pull-left.headerlink,.rst-content p .pull-left.headerlink,.rst-content table>caption .pull-left.headerlink,.rst-content tt.download span.pull-left:first-child,.wy-menu-vertical li.current>a button.pull-left.toctree-expand,.wy-menu-vertical li.on a button.pull-left.toctree-expand,.wy-menu-vertical li button.pull-left.toctree-expand{margin-right:.3em}.fa.pull-right,.pull-right.icon,.rst-content .code-block-caption .pull-right.headerlink,.rst-content .eqno .pull-right.headerlink,.rst-content .pull-right.admonition-title,.rst-content code.download span.pull-right:first-child,.rst-content dl dt .pull-right.headerlink,.rst-content h1 .pull-right.headerlink,.rst-content h2 .pull-right.headerlink,.rst-content h3 .pull-right.headerlink,.rst-content h4 .pull-right.headerlink,.rst-content h5 .pull-right.headerlink,.rst-content h6 .pull-right.headerlink,.rst-content p .pull-right.headerlink,.rst-content table>caption .pull-right.headerlink,.rst-content tt.download span.pull-right:first-child,.wy-menu-vertical li.current>a button.pull-right.toctree-expand,.wy-menu-vertical li.on a button.pull-right.toctree-expand,.wy-menu-vertical li button.pull-right.toctree-expand{margin-left:.3em}.fa-spin{-webkit-animation:fa-spin 2s linear infinite;animation:fa-spin 2s linear infinite}.fa-pulse{-webkit-animation:fa-spin 1s steps(8) infinite;animation:fa-spin 1s steps(8) infinite}@-webkit-keyframes fa-spin{0%{-webkit-transform:rotate(0deg);transform:rotate(0deg)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes fa-spin{0%{-webkit-transform:rotate(0deg);transform:rotate(0deg)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.fa-rotate-90{-ms-filter:"progid:DXImageTransform.Microsoft.BasicImage(rotation=1)";-webkit-transform:rotate(90deg);-ms-transform:rotate(90deg);transform:rotate(90deg)}.fa-rotate-180{-ms-filter:"progid:DXImageTransform.Microsoft.BasicImage(rotation=2)";-webkit-transform:rotate(180deg);-ms-transform:rotate(180deg);transform:rotate(180deg)}.fa-rotate-270{-ms-filter:"progid:DXImageTransform.Microsoft.BasicImage(rotation=3)";-webkit-transform:rotate(270deg);-ms-transform:rotate(270deg);transform:rotate(270deg)}.fa-flip-horizontal{-ms-filter:"progid:DXImageTransform.Microsoft.BasicImage(rotation=0, mirror=1)";-webkit-transform:scaleX(-1);-ms-transform:scaleX(-1);transform:scaleX(-1)}.fa-flip-vertical{-ms-filter:"progid:DXImageTransform.Microsoft.BasicImage(rotation=2, mirror=1)";-webkit-transform:scaleY(-1);-ms-transform:scaleY(-1);transform:scaleY(-1)}:root .fa-flip-horizontal,:root .fa-flip-vertical,:root .fa-rotate-90,:root .fa-rotate-180,:root 
.fa-rotate-270{filter:none}.fa-stack{position:relative;display:inline-block;width:2em;height:2em;line-height:2em;vertical-align:middle}.fa-stack-1x,.fa-stack-2x{position:absolute;left:0;width:100%;text-align:center}.fa-stack-1x{line-height:inherit}.fa-stack-2x{font-size:2em}.fa-inverse{color:#fff}.fa-glass:before{content:""}.fa-music:before{content:""}.fa-search:before,.icon-search:before{content:""}.fa-envelope-o:before{content:""}.fa-heart:before{content:""}.fa-star:before{content:""}.fa-star-o:before{content:""}.fa-user:before{content:""}.fa-film:before{content:""}.fa-th-large:before{content:""}.fa-th:before{content:""}.fa-th-list:before{content:""}.fa-check:before{content:""}.fa-close:before,.fa-remove:before,.fa-times:before{content:""}.fa-search-plus:before{content:""}.fa-search-minus:before{content:""}.fa-power-off:before{content:""}.fa-signal:before{content:""}.fa-cog:before,.fa-gear:before{content:""}.fa-trash-o:before{content:""}.fa-home:before,.icon-home:before{content:""}.fa-file-o:before{content:""}.fa-clock-o:before{content:""}.fa-road:before{content:""}.fa-download:before,.rst-content code.download span:first-child:before,.rst-content tt.download span:first-child:before{content:""}.fa-arrow-circle-o-down:before{content:""}.fa-arrow-circle-o-up:before{content:""}.fa-inbox:before{content:""}.fa-play-circle-o:before{content:""}.fa-repeat:before,.fa-rotate-right:before{content:""}.fa-refresh:before{content:""}.fa-list-alt:before{content:""}.fa-lock:before{content:""}.fa-flag:before{content:""}.fa-headphones:before{content:""}.fa-volume-off:before{content:""}.fa-volume-down:before{content:""}.fa-volume-up:before{content:""}.fa-qrcode:before{content:""}.fa-barcode:before{content:""}.fa-tag:before{content:""}.fa-tags:before{content:""}.fa-book:before,.icon-book:before{content:""}.fa-bookmark:before{content:""}.fa-print:before{content:""}.fa-camera:before{content:""}.fa-font:before{content:""}.fa-bold:before{content:""}.fa-italic:before{content:""}.fa-text-height:before{content:""}.fa-text-width:before{content:""}.fa-align-left:before{content:""}.fa-align-center:before{content:""}.fa-align-right:before{content:""}.fa-align-justify:before{content:""}.fa-list:before{content:""}.fa-dedent:before,.fa-outdent:before{content:""}.fa-indent:before{content:""}.fa-video-camera:before{content:""}.fa-image:before,.fa-photo:before,.fa-picture-o:before{content:""}.fa-pencil:before{content:""}.fa-map-marker:before{content:""}.fa-adjust:before{content:""}.fa-tint:before{content:""}.fa-edit:before,.fa-pencil-square-o:before{content:""}.fa-share-square-o:before{content:""}.fa-check-square-o:before{content:""}.fa-arrows:before{content:""}.fa-step-backward:before{content:""}.fa-fast-backward:before{content:""}.fa-backward:before{content:""}.fa-play:before{content:""}.fa-pause:before{content:""}.fa-stop:before{content:""}.fa-forward:before{content:""}.fa-fast-forward:before{content:""}.fa-step-forward:before{content:""}.fa-eject:before{content:""}.fa-chevron-left:before{content:""}.fa-chevron-right:before{content:""}.fa-plus-circle:before{content:""}.fa-minus-circle:before{content:""}.fa-times-circle:before,.wy-inline-validate.wy-inline-validate-danger .wy-input-context:before{content:""}.fa-check-circle:before,.wy-inline-validate.wy-inline-validate-success 
.wy-input-context:before{content:""}.fa-question-circle:before{content:""}.fa-info-circle:before{content:""}.fa-crosshairs:before{content:""}.fa-times-circle-o:before{content:""}.fa-check-circle-o:before{content:""}.fa-ban:before{content:""}.fa-arrow-left:before{content:""}.fa-arrow-right:before{content:""}.fa-arrow-up:before{content:""}.fa-arrow-down:before{content:""}.fa-mail-forward:before,.fa-share:before{content:""}.fa-expand:before{content:""}.fa-compress:before{content:""}.fa-plus:before{content:""}.fa-minus:before{content:""}.fa-asterisk:before{content:""}.fa-exclamation-circle:before,.rst-content .admonition-title:before,.wy-inline-validate.wy-inline-validate-info .wy-input-context:before,.wy-inline-validate.wy-inline-validate-warning .wy-input-context:before{content:""}.fa-gift:before{content:""}.fa-leaf:before{content:""}.fa-fire:before,.icon-fire:before{content:""}.fa-eye:before{content:""}.fa-eye-slash:before{content:""}.fa-exclamation-triangle:before,.fa-warning:before{content:""}.fa-plane:before{content:""}.fa-calendar:before{content:""}.fa-random:before{content:""}.fa-comment:before{content:""}.fa-magnet:before{content:""}.fa-chevron-up:before{content:""}.fa-chevron-down:before{content:""}.fa-retweet:before{content:""}.fa-shopping-cart:before{content:""}.fa-folder:before{content:""}.fa-folder-open:before{content:""}.fa-arrows-v:before{content:""}.fa-arrows-h:before{content:""}.fa-bar-chart-o:before,.fa-bar-chart:before{content:""}.fa-twitter-square:before{content:""}.fa-facebook-square:before{content:""}.fa-camera-retro:before{content:""}.fa-key:before{content:""}.fa-cogs:before,.fa-gears:before{content:""}.fa-comments:before{content:""}.fa-thumbs-o-up:before{content:""}.fa-thumbs-o-down:before{content:""}.fa-star-half:before{content:""}.fa-heart-o:before{content:""}.fa-sign-out:before{content:""}.fa-linkedin-square:before{content:""}.fa-thumb-tack:before{content:""}.fa-external-link:before{content:""}.fa-sign-in:before{content:""}.fa-trophy:before{content:""}.fa-github-square:before{content:""}.fa-upload:before{content:""}.fa-lemon-o:before{content:""}.fa-phone:before{content:""}.fa-square-o:before{content:""}.fa-bookmark-o:before{content:""}.fa-phone-square:before{content:""}.fa-twitter:before{content:""}.fa-facebook-f:before,.fa-facebook:before{content:""}.fa-github:before,.icon-github:before{content:""}.fa-unlock:before{content:""}.fa-credit-card:before{content:""}.fa-feed:before,.fa-rss:before{content:""}.fa-hdd-o:before{content:""}.fa-bullhorn:before{content:""}.fa-bell:before{content:""}.fa-certificate:before{content:""}.fa-hand-o-right:before{content:""}.fa-hand-o-left:before{content:""}.fa-hand-o-up:before{content:""}.fa-hand-o-down:before{content:""}.fa-arrow-circle-left:before,.icon-circle-arrow-left:before{content:""}.fa-arrow-circle-right:before,.icon-circle-arrow-right:before{content:""}.fa-arrow-circle-up:before{content:""}.fa-arrow-circle-down:before{content:""}.fa-globe:before{content:""}.fa-wrench:before{content:""}.fa-tasks:before{content:""}.fa-filter:before{content:""}.fa-briefcase:before{content:""}.fa-arrows-alt:before{content:""}.fa-group:before,.fa-users:before{content:""}.fa-chain:before,.fa-link:before,.icon-link:before{content:""}.fa-cloud:before{content:""}.fa-flask:before{content:""}.fa-cut:before,.fa-scissors:before{content:""}.fa-copy:before,.fa-files-o:before{content:""}.fa-paperclip:before{content:""}.fa-floppy-o:before,.fa-save:before{content:""}.fa
-square:before{content:""}.fa-bars:before,.fa-navicon:before,.fa-reorder:before{content:""}.fa-list-ul:before{content:""}.fa-list-ol:before{content:""}.fa-strikethrough:before{content:""}.fa-underline:before{content:""}.fa-table:before{content:""}.fa-magic:before{content:""}.fa-truck:before{content:""}.fa-pinterest:before{content:""}.fa-pinterest-square:before{content:""}.fa-google-plus-square:before{content:""}.fa-google-plus:before{content:""}.fa-money:before{content:""}.fa-caret-down:before,.icon-caret-down:before,.wy-dropdown .caret:before{content:""}.fa-caret-up:before{content:""}.fa-caret-left:before{content:""}.fa-caret-right:before{content:""}.fa-columns:before{content:""}.fa-sort:before,.fa-unsorted:before{content:""}.fa-sort-desc:before,.fa-sort-down:before{content:""}.fa-sort-asc:before,.fa-sort-up:before{content:""}.fa-envelope:before{content:""}.fa-linkedin:before{content:""}.fa-rotate-left:before,.fa-undo:before{content:""}.fa-gavel:before,.fa-legal:before{content:""}.fa-dashboard:before,.fa-tachometer:before{content:""}.fa-comment-o:before{content:""}.fa-comments-o:before{content:""}.fa-bolt:before,.fa-flash:before{content:""}.fa-sitemap:before{content:""}.fa-umbrella:before{content:""}.fa-clipboard:before,.fa-paste:before{content:""}.fa-lightbulb-o:before{content:""}.fa-exchange:before{content:""}.fa-cloud-download:before{content:""}.fa-cloud-upload:before{content:""}.fa-user-md:before{content:""}.fa-stethoscope:before{content:""}.fa-suitcase:before{content:""}.fa-bell-o:before{content:""}.fa-coffee:before{content:""}.fa-cutlery:before{content:""}.fa-file-text-o:before{content:""}.fa-building-o:before{content:""}.fa-hospital-o:before{content:""}.fa-ambulance:before{content:""}.fa-medkit:before{content:""}.fa-fighter-jet:before{content:""}.fa-beer:before{content:""}.fa-h-square:before{content:""}.fa-plus-square:before{content:""}.fa-angle-double-left:before{content:""}.fa-angle-double-right:before{content:""}.fa-angle-double-up:before{content:""}.fa-angle-double-down:before{content:""}.fa-angle-left:before{content:""}.fa-angle-right:before{content:""}.fa-angle-up:before{content:""}.fa-angle-down:before{content:""}.fa-desktop:before{content:""}.fa-laptop:before{content:""}.fa-tablet:before{content:""}.fa-mobile-phone:before,.fa-mobile:before{content:""}.fa-circle-o:before{content:""}.fa-quote-left:before{content:""}.fa-quote-right:before{content:""}.fa-spinner:before{content:""}.fa-circle:before{content:""}.fa-mail-reply:before,.fa-reply:before{content:""}.fa-github-alt:before{content:""}.fa-folder-o:before{content:""}.fa-folder-open-o:before{content:""}.fa-smile-o:before{content:""}.fa-frown-o:before{content:""}.fa-meh-o:before{content:""}.fa-gamepad:before{content:""}.fa-keyboard-o:before{content:""}.fa-flag-o:before{content:""}.fa-flag-checkered:before{content:""}.fa-terminal:before{content:""}.fa-code:before{content:""}.fa-mail-reply-all:before,.fa-reply-all:before{content:""}.fa-star-half-empty:before,.fa-star-half-full:before,.fa-star-half-o:before{content:""}.fa-location-arrow:before{content:""}.fa-crop:before{content:""}.fa-code-fork:before{content:""}.fa-chain-broken:before,.fa-unlink:before{content:""}.fa-question:before{content:""}.fa-info:before{content:""}.fa-exclamation:before{content:""}.fa-superscript:before{content:""}.fa-subscript:before{content:""}.fa-eraser:before{content:""}.fa-puzzle-piece:before{content:""}.fa-microphone:before{content:""}.fa-microphone-slash:
before{content:""}.fa-shield:before{content:""}.fa-calendar-o:before{content:""}.fa-fire-extinguisher:before{content:""}.fa-rocket:before{content:""}.fa-maxcdn:before{content:""}.fa-chevron-circle-left:before{content:""}.fa-chevron-circle-right:before{content:""}.fa-chevron-circle-up:before{content:""}.fa-chevron-circle-down:before{content:""}.fa-html5:before{content:""}.fa-css3:before{content:""}.fa-anchor:before{content:""}.fa-unlock-alt:before{content:""}.fa-bullseye:before{content:""}.fa-ellipsis-h:before{content:""}.fa-ellipsis-v:before{content:""}.fa-rss-square:before{content:""}.fa-play-circle:before{content:""}.fa-ticket:before{content:""}.fa-minus-square:before{content:""}.fa-minus-square-o:before,.wy-menu-vertical li.current>a button.toctree-expand:before,.wy-menu-vertical li.on a button.toctree-expand:before{content:""}.fa-level-up:before{content:""}.fa-level-down:before{content:""}.fa-check-square:before{content:""}.fa-pencil-square:before{content:""}.fa-external-link-square:before{content:""}.fa-share-square:before{content:""}.fa-compass:before{content:""}.fa-caret-square-o-down:before,.fa-toggle-down:before{content:""}.fa-caret-square-o-up:before,.fa-toggle-up:before{content:""}.fa-caret-square-o-right:before,.fa-toggle-right:before{content:""}.fa-eur:before,.fa-euro:before{content:""}.fa-gbp:before{content:""}.fa-dollar:before,.fa-usd:before{content:""}.fa-inr:before,.fa-rupee:before{content:""}.fa-cny:before,.fa-jpy:before,.fa-rmb:before,.fa-yen:before{content:""}.fa-rouble:before,.fa-rub:before,.fa-ruble:before{content:""}.fa-krw:before,.fa-won:before{content:""}.fa-bitcoin:before,.fa-btc:before{content:""}.fa-file:before{content:""}.fa-file-text:before{content:""}.fa-sort-alpha-asc:before{content:""}.fa-sort-alpha-desc:before{content:""}.fa-sort-amount-asc:before{content:""}.fa-sort-amount-desc:before{content:""}.fa-sort-numeric-asc:before{content:""}.fa-sort-numeric-desc:before{content:""}.fa-thumbs-up:before{content:""}.fa-thumbs-down:before{content:""}.fa-youtube-square:before{content:""}.fa-youtube:before{content:""}.fa-xing:before{content:""}.fa-xing-square:before{content:""}.fa-youtube-play:before{content:""}.fa-dropbox:before{content:""}.fa-stack-overflow:before{content:""}.fa-instagram:before{content:""}.fa-flickr:before{content:""}.fa-adn:before{content:""}.fa-bitbucket:before,.icon-bitbucket:before{content:""}.fa-bitbucket-square:before{content:""}.fa-tumblr:before{content:""}.fa-tumblr-square:before{content:""}.fa-long-arrow-down:before{content:""}.fa-long-arrow-up:before{content:""}.fa-long-arrow-left:before{content:""}.fa-long-arrow-right:before{content:""}.fa-apple:before{content:""}.fa-windows:before{content:""}.fa-android:before{content:""}.fa-linux:before{content:""}.fa-dribbble:before{content:""}.fa-skype:before{content:""}.fa-foursquare:before{content:""}.fa-trello:before{content:""}.fa-female:before{content:""}.fa-male:before{content:""}.fa-gittip:before,.fa-gratipay:before{content:""}.fa-sun-o:before{content:""}.fa-moon-o:before{content:""}.fa-archive:before{content:""}.fa-bug:before{content:""}.fa-vk:before{content:""}.fa-weibo:before{content:""}.fa-renren:before{content:""}.fa-pagelines:before{content:""}.fa-stack-exchange:before{content:""}.fa-arrow-circle-o-right:before{content:""}.fa-arrow-circle-o-left:before{content:""}.fa-caret-square-o-left:before,.fa-toggle-left:before{content:""}.fa-dot-circle-o:before{content:""}.fa-wheelchair:before{content:""}.fa-
vimeo-square:before{content:""}.fa-try:before,.fa-turkish-lira:before{content:""}.fa-plus-square-o:before,.wy-menu-vertical li button.toctree-expand:before{content:""}.fa-space-shuttle:before{content:""}.fa-slack:before{content:""}.fa-envelope-square:before{content:""}.fa-wordpress:before{content:""}.fa-openid:before{content:""}.fa-bank:before,.fa-institution:before,.fa-university:before{content:""}.fa-graduation-cap:before,.fa-mortar-board:before{content:""}.fa-yahoo:before{content:""}.fa-google:before{content:""}.fa-reddit:before{content:""}.fa-reddit-square:before{content:""}.fa-stumbleupon-circle:before{content:""}.fa-stumbleupon:before{content:""}.fa-delicious:before{content:""}.fa-digg:before{content:""}.fa-pied-piper-pp:before{content:""}.fa-pied-piper-alt:before{content:""}.fa-drupal:before{content:""}.fa-joomla:before{content:""}.fa-language:before{content:""}.fa-fax:before{content:""}.fa-building:before{content:""}.fa-child:before{content:""}.fa-paw:before{content:""}.fa-spoon:before{content:""}.fa-cube:before{content:""}.fa-cubes:before{content:""}.fa-behance:before{content:""}.fa-behance-square:before{content:""}.fa-steam:before{content:""}.fa-steam-square:before{content:""}.fa-recycle:before{content:""}.fa-automobile:before,.fa-car:before{content:""}.fa-cab:before,.fa-taxi:before{content:""}.fa-tree:before{content:""}.fa-spotify:before{content:""}.fa-deviantart:before{content:""}.fa-soundcloud:before{content:""}.fa-database:before{content:""}.fa-file-pdf-o:before{content:""}.fa-file-word-o:before{content:""}.fa-file-excel-o:before{content:""}.fa-file-powerpoint-o:before{content:""}.fa-file-image-o:before,.fa-file-photo-o:before,.fa-file-picture-o:before{content:""}.fa-file-archive-o:before,.fa-file-zip-o:before{content:""}.fa-file-audio-o:before,.fa-file-sound-o:before{content:""}.fa-file-movie-o:before,.fa-file-video-o:before{content:""}.fa-file-code-o:before{content:""}.fa-vine:before{content:""}.fa-codepen:before{content:""}.fa-jsfiddle:before{content:""}.fa-life-bouy:before,.fa-life-buoy:before,.fa-life-ring:before,.fa-life-saver:before,.fa-support:before{content:""}.fa-circle-o-notch:before{content:""}.fa-ra:before,.fa-rebel:before,.fa-resistance:before{content:""}.fa-empire:before,.fa-ge:before{content:""}.fa-git-square:before{content:""}.fa-git:before{content:""}.fa-hacker-news:before,.fa-y-combinator-square:before,.fa-yc-square:before{content:""}.fa-tencent-weibo:before{content:""}.fa-qq:before{content:""}.fa-wechat:before,.fa-weixin:before{content:""}.fa-paper-plane:before,.fa-send:before{content:""}.fa-paper-plane-o:before,.fa-send-o:before{content:""}.fa-history:before{content:""}.fa-circle-thin:before{content:""}.fa-header:before{content:""}.fa-paragraph:before{content:""}.fa-sliders:before{content:""}.fa-share-alt:before{content:""}.fa-share-alt-square:before{content:""}.fa-bomb:before{content:""}.fa-futbol-o:before,.fa-soccer-ball-o:before{content:""}.fa-tty:before{content:""}.fa-binoculars:before{content:""}.fa-plug:before{content:""}.fa-slideshare:before{content:""}.fa-twitch:before{content:""}.fa-yelp:before{content:""}.fa-newspaper-o:before{content:""}.fa-wifi:before{content:""}.fa-calculator:before{content:""}.fa-paypal:before{content:""}.fa-google-wallet:before{content:""}.fa-cc-visa:before{content:""}.fa-cc-mastercard:before{content:""}.fa-cc-discover:before{content:""}.fa-cc-amex:before{content:""}.fa-cc-paypal:before{content:""}.fa-cc-stripe:before{content:""}.fa-b
[elided: minified sphinx_rtd_theme stylesheet committed under docs/build/ (FontAwesome 4 icon classes and wy-* theme, form, table, and menu styles); generated third-party build artifact, shown here verbatim by the diff; the icon glyph codepoints inside the content:"" declarations were lost in text extraction]
ease-in;position:absolute;opacity:1;width:100%;opacity:0}[data-menu-wrap].move-center{left:0;right:auto;opacity:1}[data-menu-wrap].move-left{right:auto;left:-100%;opacity:0}[data-menu-wrap].move-right{right:-100%;left:auto;opacity:0}.wy-body-for-nav{background:#fcfcfc}.wy-grid-for-nav{position:absolute;width:100%;height:100%}.wy-nav-side{position:fixed;top:0;bottom:0;left:0;padding-bottom:2em;width:300px;overflow-x:hidden;overflow-y:hidden;min-height:100%;color:#9b9b9b;background:#343131;z-index:200}.wy-side-scroll{width:320px;position:relative;overflow-x:hidden;overflow-y:scroll;height:100%}.wy-nav-top{display:none;background:#2980b9;color:#fff;padding:.4045em .809em;position:relative;line-height:50px;text-align:center;font-size:100%;*zoom:1}.wy-nav-top:after,.wy-nav-top:before{display:table;content:""}.wy-nav-top:after{clear:both}.wy-nav-top a{color:#fff;font-weight:700}.wy-nav-top img{margin-right:12px;height:45px;width:45px;background-color:#2980b9;padding:5px;border-radius:100%}.wy-nav-top i{font-size:30px;float:left;cursor:pointer;padding-top:inherit}.wy-nav-content-wrap{margin-left:300px;background:#fcfcfc;min-height:100%}.wy-nav-content{padding:1.618em 3.236em;height:100%;max-width:800px;margin:auto}.wy-body-mask{position:fixed;width:100%;height:100%;background:rgba(0,0,0,.2);display:none;z-index:499}.wy-body-mask.on{display:block}footer{color:grey}footer p{margin-bottom:12px}.rst-content footer span.commit tt,footer span.commit .rst-content tt,footer span.commit code{padding:0;font-family:SFMono-Regular,Menlo,Monaco,Consolas,Liberation Mono,Courier New,Courier,monospace;font-size:1em;background:none;border:none;color:grey}.rst-footer-buttons{*zoom:1}.rst-footer-buttons:after,.rst-footer-buttons:before{width:100%;display:table;content:""}.rst-footer-buttons:after{clear:both}.rst-breadcrumbs-buttons{margin-top:12px;*zoom:1}.rst-breadcrumbs-buttons:after,.rst-breadcrumbs-buttons:before{display:table;content:""}.rst-breadcrumbs-buttons:after{clear:both}#search-results .search li{margin-bottom:24px;border-bottom:1px solid #e1e4e5;padding-bottom:24px}#search-results .search li:first-child{border-top:1px solid #e1e4e5;padding-top:24px}#search-results .search li a{font-size:120%;margin-bottom:12px;display:inline-block}#search-results .context{color:grey;font-size:90%}.genindextable li>ul{margin-left:24px}@media screen and (max-width:768px){.wy-body-for-nav{background:#fcfcfc}.wy-nav-top{display:block}.wy-nav-side{left:-300px}.wy-nav-side.shift{width:85%;left:0}.wy-menu.wy-menu-vertical,.wy-side-nav-search,.wy-side-scroll{width:auto}.wy-nav-content-wrap{margin-left:0}.wy-nav-content-wrap .wy-nav-content{padding:1.618em}.wy-nav-content-wrap.shift{position:fixed;min-width:100%;left:85%;top:0;height:100%;overflow:hidden}}@media screen and (min-width:1100px){.wy-nav-content-wrap{background:rgba(0,0,0,.05)}.wy-nav-content{margin:0;background:#fcfcfc}}@media print{.rst-versions,.wy-nav-side,footer{display:none}.wy-nav-content-wrap{margin-left:0}}.rst-versions{position:fixed;bottom:0;left:0;width:300px;color:#fcfcfc;background:#1f1d1d;font-family:Lato,proxima-nova,Helvetica Neue,Arial,sans-serif;z-index:400}.rst-versions a{color:#2980b9;text-decoration:none}.rst-versions .rst-badge-small{display:none}.rst-versions .rst-current-version{padding:12px;background-color:#272525;display:block;text-align:right;font-size:90%;cursor:pointer;color:#27ae60;*zoom:1}.rst-versions .rst-current-version:after,.rst-versions .rst-current-version:before{display:table;content:""}.rst-versions 
.rst-current-version:after{clear:both}.rst-content .code-block-caption .rst-versions .rst-current-version .headerlink,.rst-content .eqno .rst-versions .rst-current-version .headerlink,.rst-content .rst-versions .rst-current-version .admonition-title,.rst-content code.download .rst-versions .rst-current-version span:first-child,.rst-content dl dt .rst-versions .rst-current-version .headerlink,.rst-content h1 .rst-versions .rst-current-version .headerlink,.rst-content h2 .rst-versions .rst-current-version .headerlink,.rst-content h3 .rst-versions .rst-current-version .headerlink,.rst-content h4 .rst-versions .rst-current-version .headerlink,.rst-content h5 .rst-versions .rst-current-version .headerlink,.rst-content h6 .rst-versions .rst-current-version .headerlink,.rst-content p .rst-versions .rst-current-version .headerlink,.rst-content table>caption .rst-versions .rst-current-version .headerlink,.rst-content tt.download .rst-versions .rst-current-version span:first-child,.rst-versions .rst-current-version .fa,.rst-versions .rst-current-version .icon,.rst-versions .rst-current-version .rst-content .admonition-title,.rst-versions .rst-current-version .rst-content .code-block-caption .headerlink,.rst-versions .rst-current-version .rst-content .eqno .headerlink,.rst-versions .rst-current-version .rst-content code.download span:first-child,.rst-versions .rst-current-version .rst-content dl dt .headerlink,.rst-versions .rst-current-version .rst-content h1 .headerlink,.rst-versions .rst-current-version .rst-content h2 .headerlink,.rst-versions .rst-current-version .rst-content h3 .headerlink,.rst-versions .rst-current-version .rst-content h4 .headerlink,.rst-versions .rst-current-version .rst-content h5 .headerlink,.rst-versions .rst-current-version .rst-content h6 .headerlink,.rst-versions .rst-current-version .rst-content p .headerlink,.rst-versions .rst-current-version .rst-content table>caption .headerlink,.rst-versions .rst-current-version .rst-content tt.download span:first-child,.rst-versions .rst-current-version .wy-menu-vertical li button.toctree-expand,.wy-menu-vertical li .rst-versions .rst-current-version button.toctree-expand{color:#fcfcfc}.rst-versions .rst-current-version .fa-book,.rst-versions .rst-current-version .icon-book{float:left}.rst-versions .rst-current-version.rst-out-of-date{background-color:#e74c3c;color:#fff}.rst-versions .rst-current-version.rst-active-old-version{background-color:#f1c40f;color:#000}.rst-versions.shift-up{height:auto;max-height:100%;overflow-y:scroll}.rst-versions.shift-up .rst-other-versions{display:block}.rst-versions .rst-other-versions{font-size:90%;padding:12px;color:grey;display:none}.rst-versions .rst-other-versions hr{display:block;height:1px;border:0;margin:20px 0;padding:0;border-top:1px solid #413d3d}.rst-versions .rst-other-versions dd{display:inline-block;margin:0}.rst-versions .rst-other-versions dd a{display:inline-block;padding:6px;color:#fcfcfc}.rst-versions.rst-badge{width:auto;bottom:20px;right:20px;left:auto;border:none;max-width:300px;max-height:90%}.rst-versions.rst-badge .fa-book,.rst-versions.rst-badge .icon-book{float:none;line-height:30px}.rst-versions.rst-badge.shift-up .rst-current-version{text-align:right}.rst-versions.rst-badge.shift-up .rst-current-version .fa-book,.rst-versions.rst-badge.shift-up .rst-current-version .icon-book{float:left}.rst-versions.rst-badge>.rst-current-version{width:auto;height:30px;line-height:30px;padding:0 6px;display:block;text-align:center}@media screen and 
(max-width:768px){.rst-versions{width:85%;display:none}.rst-versions.shift{display:block}}.rst-content .toctree-wrapper>p.caption,.rst-content h1,.rst-content h2,.rst-content h3,.rst-content h4,.rst-content h5,.rst-content h6{margin-bottom:24px}.rst-content img{max-width:100%;height:auto}.rst-content div.figure,.rst-content figure{margin-bottom:24px}.rst-content div.figure .caption-text,.rst-content figure .caption-text{font-style:italic}.rst-content div.figure p:last-child.caption,.rst-content figure p:last-child.caption{margin-bottom:0}.rst-content div.figure.align-center,.rst-content figure.align-center{text-align:center}.rst-content .section>a>img,.rst-content .section>img,.rst-content section>a>img,.rst-content section>img{margin-bottom:24px}.rst-content abbr[title]{text-decoration:none}.rst-content.style-external-links a.reference.external:after{font-family:FontAwesome;content:"\f08e";color:#b3b3b3;vertical-align:super;font-size:60%;margin:0 .2em}.rst-content blockquote{margin-left:24px;line-height:24px;margin-bottom:24px}.rst-content pre.literal-block{white-space:pre;margin:0;padding:12px;font-family:SFMono-Regular,Menlo,Monaco,Consolas,Liberation Mono,Courier New,Courier,monospace;display:block;overflow:auto}.rst-content div[class^=highlight],.rst-content pre.literal-block{border:1px solid #e1e4e5;overflow-x:auto;margin:1px 0 24px}.rst-content div[class^=highlight] div[class^=highlight],.rst-content pre.literal-block div[class^=highlight]{padding:0;border:none;margin:0}.rst-content div[class^=highlight] td.code{width:100%}.rst-content .linenodiv pre{border-right:1px solid #e6e9ea;margin:0;padding:12px;font-family:SFMono-Regular,Menlo,Monaco,Consolas,Liberation Mono,Courier New,Courier,monospace;user-select:none;pointer-events:none}.rst-content div[class^=highlight] pre{white-space:pre;margin:0;padding:12px;display:block;overflow:auto}.rst-content div[class^=highlight] pre .hll{display:block;margin:0 -12px;padding:0 12px}.rst-content .linenodiv pre,.rst-content div[class^=highlight] pre,.rst-content pre.literal-block{font-family:SFMono-Regular,Menlo,Monaco,Consolas,Liberation Mono,Courier New,Courier,monospace;font-size:12px;line-height:1.4}.rst-content div.highlight .gp,.rst-content div.highlight span.linenos{user-select:none;pointer-events:none}.rst-content div.highlight span.linenos{display:inline-block;padding-left:0;padding-right:12px;margin-right:12px;border-right:1px solid #e6e9ea}.rst-content .code-block-caption{font-style:italic;font-size:85%;line-height:1;padding:1em 0;text-align:center}@media print{.rst-content .codeblock,.rst-content div[class^=highlight],.rst-content div[class^=highlight] pre{white-space:pre-wrap}}.rst-content .admonition,.rst-content .admonition-todo,.rst-content .attention,.rst-content .caution,.rst-content .danger,.rst-content .error,.rst-content .hint,.rst-content .important,.rst-content .note,.rst-content .seealso,.rst-content .tip,.rst-content .warning{clear:both}.rst-content .admonition-todo .last,.rst-content .admonition-todo>:last-child,.rst-content .admonition .last,.rst-content .admonition>:last-child,.rst-content .attention .last,.rst-content .attention>:last-child,.rst-content .caution .last,.rst-content .caution>:last-child,.rst-content .danger .last,.rst-content .danger>:last-child,.rst-content .error .last,.rst-content .error>:last-child,.rst-content .hint .last,.rst-content .hint>:last-child,.rst-content .important .last,.rst-content .important>:last-child,.rst-content .note .last,.rst-content .note>:last-child,.rst-content .seealso 
.last,.rst-content .seealso>:last-child,.rst-content .tip .last,.rst-content .tip>:last-child,.rst-content .warning .last,.rst-content .warning>:last-child{margin-bottom:0}.rst-content .admonition-title:before{margin-right:4px}.rst-content .admonition table{border-color:rgba(0,0,0,.1)}.rst-content .admonition table td,.rst-content .admonition table th{background:transparent!important;border-color:rgba(0,0,0,.1)!important}.rst-content .section ol.loweralpha,.rst-content .section ol.loweralpha>li,.rst-content .toctree-wrapper ol.loweralpha,.rst-content .toctree-wrapper ol.loweralpha>li,.rst-content section ol.loweralpha,.rst-content section ol.loweralpha>li{list-style:lower-alpha}.rst-content .section ol.upperalpha,.rst-content .section ol.upperalpha>li,.rst-content .toctree-wrapper ol.upperalpha,.rst-content .toctree-wrapper ol.upperalpha>li,.rst-content section ol.upperalpha,.rst-content section ol.upperalpha>li{list-style:upper-alpha}.rst-content .section ol li>*,.rst-content .section ul li>*,.rst-content .toctree-wrapper ol li>*,.rst-content .toctree-wrapper ul li>*,.rst-content section ol li>*,.rst-content section ul li>*{margin-top:12px;margin-bottom:12px}.rst-content .section ol li>:first-child,.rst-content .section ul li>:first-child,.rst-content .toctree-wrapper ol li>:first-child,.rst-content .toctree-wrapper ul li>:first-child,.rst-content section ol li>:first-child,.rst-content section ul li>:first-child{margin-top:0}.rst-content .section ol li>p,.rst-content .section ol li>p:last-child,.rst-content .section ul li>p,.rst-content .section ul li>p:last-child,.rst-content .toctree-wrapper ol li>p,.rst-content .toctree-wrapper ol li>p:last-child,.rst-content .toctree-wrapper ul li>p,.rst-content .toctree-wrapper ul li>p:last-child,.rst-content section ol li>p,.rst-content section ol li>p:last-child,.rst-content section ul li>p,.rst-content section ul li>p:last-child{margin-bottom:12px}.rst-content .section ol li>p:only-child,.rst-content .section ol li>p:only-child:last-child,.rst-content .section ul li>p:only-child,.rst-content .section ul li>p:only-child:last-child,.rst-content .toctree-wrapper ol li>p:only-child,.rst-content .toctree-wrapper ol li>p:only-child:last-child,.rst-content .toctree-wrapper ul li>p:only-child,.rst-content .toctree-wrapper ul li>p:only-child:last-child,.rst-content section ol li>p:only-child,.rst-content section ol li>p:only-child:last-child,.rst-content section ul li>p:only-child,.rst-content section ul li>p:only-child:last-child{margin-bottom:0}.rst-content .section ol li>ol,.rst-content .section ol li>ul,.rst-content .section ul li>ol,.rst-content .section ul li>ul,.rst-content .toctree-wrapper ol li>ol,.rst-content .toctree-wrapper ol li>ul,.rst-content .toctree-wrapper ul li>ol,.rst-content .toctree-wrapper ul li>ul,.rst-content section ol li>ol,.rst-content section ol li>ul,.rst-content section ul li>ol,.rst-content section ul li>ul{margin-bottom:12px}.rst-content .section ol.simple li>*,.rst-content .section ol.simple li ol,.rst-content .section ol.simple li ul,.rst-content .section ul.simple li>*,.rst-content .section ul.simple li ol,.rst-content .section ul.simple li ul,.rst-content .toctree-wrapper ol.simple li>*,.rst-content .toctree-wrapper ol.simple li ol,.rst-content .toctree-wrapper ol.simple li ul,.rst-content .toctree-wrapper ul.simple li>*,.rst-content .toctree-wrapper ul.simple li ol,.rst-content .toctree-wrapper ul.simple li ul,.rst-content section ol.simple li>*,.rst-content section ol.simple li ol,.rst-content section ol.simple li 
ul,.rst-content section ul.simple li>*,.rst-content section ul.simple li ol,.rst-content section ul.simple li ul{margin-top:0;margin-bottom:0}.rst-content .line-block{margin-left:0;margin-bottom:24px;line-height:24px}.rst-content .line-block .line-block{margin-left:24px;margin-bottom:0}.rst-content .topic-title{font-weight:700;margin-bottom:12px}.rst-content .toc-backref{color:#404040}.rst-content .align-right{float:right;margin:0 0 24px 24px}.rst-content .align-left{float:left;margin:0 24px 24px 0}.rst-content .align-center{margin:auto}.rst-content .align-center:not(table){display:block}.rst-content .code-block-caption .headerlink,.rst-content .eqno .headerlink,.rst-content .toctree-wrapper>p.caption .headerlink,.rst-content dl dt .headerlink,.rst-content h1 .headerlink,.rst-content h2 .headerlink,.rst-content h3 .headerlink,.rst-content h4 .headerlink,.rst-content h5 .headerlink,.rst-content h6 .headerlink,.rst-content p.caption .headerlink,.rst-content p .headerlink,.rst-content table>caption .headerlink{opacity:0;font-size:14px;font-family:FontAwesome;margin-left:.5em}.rst-content .code-block-caption .headerlink:focus,.rst-content .code-block-caption:hover .headerlink,.rst-content .eqno .headerlink:focus,.rst-content .eqno:hover .headerlink,.rst-content .toctree-wrapper>p.caption .headerlink:focus,.rst-content .toctree-wrapper>p.caption:hover .headerlink,.rst-content dl dt .headerlink:focus,.rst-content dl dt:hover .headerlink,.rst-content h1 .headerlink:focus,.rst-content h1:hover .headerlink,.rst-content h2 .headerlink:focus,.rst-content h2:hover .headerlink,.rst-content h3 .headerlink:focus,.rst-content h3:hover .headerlink,.rst-content h4 .headerlink:focus,.rst-content h4:hover .headerlink,.rst-content h5 .headerlink:focus,.rst-content h5:hover .headerlink,.rst-content h6 .headerlink:focus,.rst-content h6:hover .headerlink,.rst-content p.caption .headerlink:focus,.rst-content p.caption:hover .headerlink,.rst-content p .headerlink:focus,.rst-content p:hover .headerlink,.rst-content table>caption .headerlink:focus,.rst-content table>caption:hover .headerlink{opacity:1}.rst-content .btn:focus{outline:2px solid}.rst-content table>caption .headerlink:after{font-size:12px}.rst-content .centered{text-align:center}.rst-content .sidebar{float:right;width:40%;display:block;margin:0 0 24px 24px;padding:24px;background:#f3f6f6;border:1px solid #e1e4e5}.rst-content .sidebar dl,.rst-content .sidebar p,.rst-content .sidebar ul{font-size:90%}.rst-content .sidebar .last,.rst-content .sidebar>:last-child{margin-bottom:0}.rst-content .sidebar .sidebar-title{display:block;font-family:Roboto Slab,ff-tisa-web-pro,Georgia,Arial,sans-serif;font-weight:700;background:#e1e4e5;padding:6px 12px;margin:-24px -24px 24px;font-size:100%}.rst-content .highlighted{background:#f1c40f;box-shadow:0 0 0 2px #f1c40f;display:inline;font-weight:700}.rst-content .citation-reference,.rst-content .footnote-reference{vertical-align:baseline;position:relative;top:-.4em;line-height:0;font-size:90%}.rst-content .hlist{width:100%}.rst-content dl dt span.classifier:before{content:" : "}.rst-content dl dt span.classifier-delimiter{display:none!important}html.writer-html4 .rst-content table.docutils.citation,html.writer-html4 .rst-content table.docutils.footnote{background:none;border:none}html.writer-html4 .rst-content table.docutils.citation td,html.writer-html4 .rst-content table.docutils.citation tr,html.writer-html4 .rst-content table.docutils.footnote td,html.writer-html4 .rst-content table.docutils.footnote 
tr{border:none;background-color:transparent!important;white-space:normal}html.writer-html4 .rst-content table.docutils.citation td.label,html.writer-html4 .rst-content table.docutils.footnote td.label{padding-left:0;padding-right:0;vertical-align:top}html.writer-html5 .rst-content dl.field-list,html.writer-html5 .rst-content dl.footnote{display:grid;grid-template-columns:max-content auto}html.writer-html5 .rst-content dl.field-list>dt,html.writer-html5 .rst-content dl.footnote>dt{padding-left:1rem}html.writer-html5 .rst-content dl.field-list>dt:after,html.writer-html5 .rst-content dl.footnote>dt:after{content:":"}html.writer-html5 .rst-content dl.field-list>dd,html.writer-html5 .rst-content dl.field-list>dt,html.writer-html5 .rst-content dl.footnote>dd,html.writer-html5 .rst-content dl.footnote>dt{margin-bottom:0}html.writer-html5 .rst-content dl.footnote{font-size:.9rem}html.writer-html5 .rst-content dl.footnote>dt{margin:0 .5rem .5rem 0;line-height:1.2rem;word-break:break-all;font-weight:400}html.writer-html5 .rst-content dl.footnote>dt>span.brackets{margin-right:.5rem}html.writer-html5 .rst-content dl.footnote>dt>span.brackets:before{content:"["}html.writer-html5 .rst-content dl.footnote>dt>span.brackets:after{content:"]"}html.writer-html5 .rst-content dl.footnote>dt>span.fn-backref{font-style:italic}html.writer-html5 .rst-content dl.footnote>dd{margin:0 0 .5rem;line-height:1.2rem}html.writer-html5 .rst-content dl.footnote>dd p,html.writer-html5 .rst-content dl.option-list kbd{font-size:.9rem}.rst-content table.docutils.footnote,html.writer-html4 .rst-content table.docutils.citation,html.writer-html5 .rst-content dl.footnote{color:grey}.rst-content table.docutils.footnote code,.rst-content table.docutils.footnote tt,html.writer-html4 .rst-content table.docutils.citation code,html.writer-html4 .rst-content table.docutils.citation tt,html.writer-html5 .rst-content dl.footnote code,html.writer-html5 .rst-content dl.footnote tt{color:#555}.rst-content .wy-table-responsive.citation,.rst-content .wy-table-responsive.footnote{margin-bottom:0}.rst-content .wy-table-responsive.citation+:not(.citation),.rst-content .wy-table-responsive.footnote+:not(.footnote){margin-top:24px}.rst-content .wy-table-responsive.citation:last-child,.rst-content .wy-table-responsive.footnote:last-child{margin-bottom:24px}.rst-content table.docutils th{border-color:#e1e4e5}html.writer-html5 .rst-content table.docutils th{border:1px solid #e1e4e5}html.writer-html5 .rst-content table.docutils td>p,html.writer-html5 .rst-content table.docutils th>p{line-height:1rem;margin-bottom:0;font-size:.9rem}.rst-content table.docutils td .last,.rst-content table.docutils td .last>:last-child{margin-bottom:0}.rst-content table.field-list,.rst-content table.field-list td{border:none}.rst-content table.field-list td p{font-size:inherit;line-height:inherit}.rst-content table.field-list td>strong{display:inline-block}.rst-content table.field-list .field-name{padding-right:10px;text-align:left;white-space:nowrap}.rst-content table.field-list .field-body{text-align:left}.rst-content code,.rst-content tt{color:#000;font-family:SFMono-Regular,Menlo,Monaco,Consolas,Liberation Mono,Courier New,Courier,monospace;padding:2px 5px}.rst-content code big,.rst-content code em,.rst-content tt big,.rst-content tt em{font-size:100%!important;line-height:normal}.rst-content code.literal,.rst-content tt.literal{color:#e74c3c;white-space:normal}.rst-content code.xref,.rst-content tt.xref,a .rst-content code,a .rst-content 
tt{font-weight:700;color:#404040}.rst-content kbd,.rst-content pre,.rst-content samp{font-family:SFMono-Regular,Menlo,Monaco,Consolas,Liberation Mono,Courier New,Courier,monospace}.rst-content a code,.rst-content a tt{color:#2980b9}.rst-content dl{margin-bottom:24px}.rst-content dl dt{font-weight:700;margin-bottom:12px}.rst-content dl ol,.rst-content dl p,.rst-content dl table,.rst-content dl ul{margin-bottom:12px}.rst-content dl dd{margin:0 0 12px 24px;line-height:24px}html.writer-html4 .rst-content dl:not(.docutils),html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple){margin-bottom:24px}html.writer-html4 .rst-content dl:not(.docutils)>dt,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple)>dt{display:table;margin:6px 0;font-size:90%;line-height:normal;background:#e7f2fa;color:#2980b9;border-top:3px solid #6ab0de;padding:6px;position:relative}html.writer-html4 .rst-content dl:not(.docutils)>dt:before,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple)>dt:before{color:#6ab0de}html.writer-html4 .rst-content dl:not(.docutils)>dt .headerlink,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple)>dt .headerlink{color:#404040;font-size:100%!important}html.writer-html4 .rst-content dl:not(.docutils) dl:not(.field-list)>dt,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) dl:not(.field-list)>dt{margin-bottom:6px;border:none;border-left:3px solid #ccc;background:#f0f0f0;color:#555}html.writer-html4 .rst-content dl:not(.docutils) dl:not(.field-list)>dt .headerlink,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) dl:not(.field-list)>dt .headerlink{color:#404040;font-size:100%!important}html.writer-html4 .rst-content dl:not(.docutils)>dt:first-child,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple)>dt:first-child{margin-top:0}html.writer-html4 .rst-content dl:not(.docutils) code.descclassname,html.writer-html4 .rst-content dl:not(.docutils) code.descname,html.writer-html4 .rst-content dl:not(.docutils) tt.descclassname,html.writer-html4 .rst-content dl:not(.docutils) tt.descname,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) code.descclassname,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) code.descname,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) tt.descclassname,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) tt.descname{background-color:transparent;border:none;padding:0;font-size:100%!important}html.writer-html4 .rst-content dl:not(.docutils) code.descname,html.writer-html4 .rst-content dl:not(.docutils) tt.descname,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) code.descname,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) tt.descname{font-weight:700}html.writer-html4 .rst-content dl:not(.docutils) 
.optional,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) .optional{display:inline-block;padding:0 4px;color:#000;font-weight:700}html.writer-html4 .rst-content dl:not(.docutils) .property,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) .property{display:inline-block;padding-right:8px;max-width:100%}html.writer-html4 .rst-content dl:not(.docutils) .k,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) .k{font-style:italic}html.writer-html4 .rst-content dl:not(.docutils) .descclassname,html.writer-html4 .rst-content dl:not(.docutils) .descname,html.writer-html4 .rst-content dl:not(.docutils) .sig-name,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) .descclassname,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) .descname,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) .sig-name{font-family:SFMono-Regular,Menlo,Monaco,Consolas,Liberation Mono,Courier New,Courier,monospace;color:#000}.rst-content .viewcode-back,.rst-content .viewcode-link{display:inline-block;color:#27ae60;font-size:80%;padding-left:24px}.rst-content .viewcode-back{display:block;float:right}.rst-content p.rubric{margin-bottom:12px;font-weight:700}.rst-content code.download,.rst-content tt.download{background:inherit;padding:inherit;font-weight:400;font-family:inherit;font-size:inherit;color:inherit;border:inherit;white-space:inherit}.rst-content code.download span:first-child,.rst-content tt.download span:first-child{-webkit-font-smoothing:subpixel-antialiased}.rst-content code.download span:first-child:before,.rst-content tt.download span:first-child:before{margin-right:4px}.rst-content .guilabel{border:1px solid #7fbbe3;background:#e7f2fa;font-size:80%;font-weight:700;border-radius:4px;padding:2.4px 6px;margin:auto 2px}.rst-content .versionmodified{font-style:italic}@media screen and (max-width:480px){.rst-content .sidebar{width:100%}}span[id*=MathJax-Span]{color:#404040}.math{text-align:center}@font-face{font-family:Lato;src:url(fonts/lato-normal.woff2?bd03a2cc277bbbc338d464e679fe9942) format("woff2"),url(fonts/lato-normal.woff?27bd77b9162d388cb8d4c4217c7c5e2a) format("woff");font-weight:400;font-style:normal;font-display:block}@font-face{font-family:Lato;src:url(fonts/lato-bold.woff2?cccb897485813c7c256901dbca54ecf2) format("woff2"),url(fonts/lato-bold.woff?d878b6c29b10beca227e9eef4246111b) format("woff");font-weight:700;font-style:normal;font-display:block}@font-face{font-family:Lato;src:url(fonts/lato-bold-italic.woff2?0b6bb6725576b072c5d0b02ecdd1900d) format("woff2"),url(fonts/lato-bold-italic.woff?9c7e4e9eb485b4a121c760e61bc3707c) format("woff");font-weight:700;font-style:italic;font-display:block}@font-face{font-family:Lato;src:url(fonts/lato-normal-italic.woff2?4eb103b4d12be57cb1d040ed5e162e9d) format("woff2"),url(fonts/lato-normal-italic.woff?f28f2d6482446544ef1ea1ccc6dd5892) format("woff");font-weight:400;font-style:italic;font-display:block}@font-face{font-family:Roboto Slab;font-style:normal;font-weight:400;src:url(fonts/Roboto-Slab-Regular.woff2?7abf5b8d04d26a2cafea937019bca958) format("woff2"),url(fonts/Roboto-Slab-Regular.woff?c1be9284088d487c5e3ff0a10a92e58c) 
format("woff");font-display:block}@font-face{font-family:Roboto Slab;font-style:normal;font-weight:700;src:url(fonts/Roboto-Slab-Bold.woff2?9984f4a9bda09be08e83f2506954adbe) format("woff2"),url(fonts/Roboto-Slab-Bold.woff?bed5564a116b05148e3b3bea6fb1162a) format("woff");font-display:block} diff --git a/docs/build/html/_static/doctools.js b/docs/build/html/_static/doctools.js new file mode 100644 index 00000000..e1bfd708 --- /dev/null +++ b/docs/build/html/_static/doctools.js @@ -0,0 +1,358 @@ +/* + * doctools.js + * ~~~~~~~~~~~ + * + * Sphinx JavaScript utilities for all documentation. + * + * :copyright: Copyright 2007-2022 by the Sphinx team, see AUTHORS. + * :license: BSD, see LICENSE for details. + * + */ + +/** + * select a different prefix for underscore + */ +$u = _.noConflict(); + +/** + * make the code below compatible with browsers without + * an installed firebug like debugger +if (!window.console || !console.firebug) { + var names = ["log", "debug", "info", "warn", "error", "assert", "dir", + "dirxml", "group", "groupEnd", "time", "timeEnd", "count", "trace", + "profile", "profileEnd"]; + window.console = {}; + for (var i = 0; i < names.length; ++i) + window.console[names[i]] = function() {}; +} + */ + +/** + * small helper function to urldecode strings + * + * See https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/decodeURIComponent#Decoding_query_parameters_from_a_URL + */ +jQuery.urldecode = function(x) { + if (!x) { + return x + } + return decodeURIComponent(x.replace(/\+/g, ' ')); +}; + +/** + * small helper function to urlencode strings + */ +jQuery.urlencode = encodeURIComponent; + +/** + * This function returns the parsed url parameters of the + * current request. Multiple values per key are supported, + * it will always return arrays of strings for the value parts. + */ +jQuery.getQueryParameters = function(s) { + if (typeof s === 'undefined') + s = document.location.search; + var parts = s.substr(s.indexOf('?') + 1).split('&'); + var result = {}; + for (var i = 0; i < parts.length; i++) { + var tmp = parts[i].split('=', 2); + var key = jQuery.urldecode(tmp[0]); + var value = jQuery.urldecode(tmp[1]); + if (key in result) + result[key].push(value); + else + result[key] = [value]; + } + return result; +}; + +/** + * highlight a given string on a jquery object by wrapping it in + * span elements with the given class name. 
+ */ +jQuery.fn.highlightText = function(text, className) { + function highlight(node, addItems) { + if (node.nodeType === 3) { + var val = node.nodeValue; + var pos = val.toLowerCase().indexOf(text); + if (pos >= 0 && + !jQuery(node.parentNode).hasClass(className) && + !jQuery(node.parentNode).hasClass("nohighlight")) { + var span; + var isInSVG = jQuery(node).closest("body, svg, foreignObject").is("svg"); + if (isInSVG) { + span = document.createElementNS("http://www.w3.org/2000/svg", "tspan"); + } else { + span = document.createElement("span"); + span.className = className; + } + span.appendChild(document.createTextNode(val.substr(pos, text.length))); + node.parentNode.insertBefore(span, node.parentNode.insertBefore( + document.createTextNode(val.substr(pos + text.length)), + node.nextSibling)); + node.nodeValue = val.substr(0, pos); + if (isInSVG) { + var rect = document.createElementNS("http://www.w3.org/2000/svg", "rect"); + var bbox = node.parentElement.getBBox(); + rect.x.baseVal.value = bbox.x; + rect.y.baseVal.value = bbox.y; + rect.width.baseVal.value = bbox.width; + rect.height.baseVal.value = bbox.height; + rect.setAttribute('class', className); + addItems.push({ + "parent": node.parentNode, + "target": rect}); + } + } + } + else if (!jQuery(node).is("button, select, textarea")) { + jQuery.each(node.childNodes, function() { + highlight(this, addItems); + }); + } + } + var addItems = []; + var result = this.each(function() { + highlight(this, addItems); + }); + for (var i = 0; i < addItems.length; ++i) { + jQuery(addItems[i].parent).before(addItems[i].target); + } + return result; +}; + +/* + * backward compatibility for jQuery.browser + * This will be supported until firefox bug is fixed. + */ +if (!jQuery.browser) { + jQuery.uaMatch = function(ua) { + ua = ua.toLowerCase(); + + var match = /(chrome)[ \/]([\w.]+)/.exec(ua) || + /(webkit)[ \/]([\w.]+)/.exec(ua) || + /(opera)(?:.*version|)[ \/]([\w.]+)/.exec(ua) || + /(msie) ([\w.]+)/.exec(ua) || + ua.indexOf("compatible") < 0 && /(mozilla)(?:.*? rv:([\w.]+)|)/.exec(ua) || + []; + + return { + browser: match[ 1 ] || "", + version: match[ 2 ] || "0" + }; + }; + jQuery.browser = {}; + jQuery.browser[jQuery.uaMatch(navigator.userAgent).browser] = true; +} + +/** + * Small JavaScript module for the documentation. + */ +var Documentation = { + + init : function() { + this.fixFirefoxAnchorBug(); + this.highlightSearchWords(); + this.initIndexTable(); + this.initOnKeyListeners(); + }, + + /** + * i18n support + */ + TRANSLATIONS : {}, + PLURAL_EXPR : function(n) { return n === 1 ? 0 : 1; }, + LOCALE : 'unknown', + + // gettext and ngettext don't access this so that the functions + // can safely bound to a different name (_ = Documentation.gettext) + gettext : function(string) { + var translated = Documentation.TRANSLATIONS[string]; + if (typeof translated === 'undefined') + return string; + return (typeof translated === 'string') ? translated : translated[0]; + }, + + ngettext : function(singular, plural, n) { + var translated = Documentation.TRANSLATIONS[singular]; + if (typeof translated === 'undefined') + return (n == 1) ? 
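  // -------------------------------------------------------------------
  // Illustrative usage sketch (editor's note; not part of the vendored
  // file). gettext/ngettext look strings up in the TRANSLATIONS catalog
  // that addTranslations() below installs. With a hypothetical German
  // catalog:
  //
  //   Documentation.addTranslations({
  //     messages: { "Hide Search Matches": "Suchtreffer ausblenden" },
  //     plural_expr: "n != 1",
  //     locale: "de"
  //   });
  //   Documentation.gettext("Hide Search Matches"); // "Suchtreffer ausblenden"
  //   Documentation.ngettext("match", "matches", 1); // no entry -> "match"
  //
  // For untranslated strings gettext returns its argument unchanged, so
  // the _('...') calls below are safe with an empty catalog.
  // -------------------------------------------------------------------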
+
+  addTranslations : function(catalog) {
+    for (var key in catalog.messages)
+      this.TRANSLATIONS[key] = catalog.messages[key];
+    this.PLURAL_EXPR = new Function('n', 'return +(' + catalog.plural_expr + ')');
+    this.LOCALE = catalog.locale;
+  },
+
+  /**
+   * add context elements like header anchor links
+   */
+  addContextElements : function() {
+    $('div[id] > :header:first').each(function() {
+      $('<a class="headerlink">\u00B6</a>').
+      attr('href', '#' + this.id).
+      attr('title', _('Permalink to this headline')).
+      appendTo(this);
+    });
+    $('dt[id]').each(function() {
+      $('<a class="headerlink">\u00B6</a>').
+      attr('href', '#' + this.id).
+      attr('title', _('Permalink to this definition')).
+      appendTo(this);
+    });
+  },
+
+  /**
+   * workaround a firefox stupidity
+   * see: https://bugzilla.mozilla.org/show_bug.cgi?id=645075
+   */
+  fixFirefoxAnchorBug : function() {
+    if (document.location.hash && $.browser.mozilla)
+      window.setTimeout(function() {
+        document.location.href += '';
+      }, 10);
+  },
+
+  /**
+   * highlight the search words provided in the url in the text
+   */
+  highlightSearchWords : function() {
+    var params = $.getQueryParameters();
+    var terms = (params.highlight) ? params.highlight[0].split(/\s+/) : [];
+    if (terms.length) {
+      var body = $('div.body');
+      if (!body.length) {
+        body = $('body');
+      }
+      window.setTimeout(function() {
+        $.each(terms, function() {
+          body.highlightText(this.toLowerCase(), 'highlighted');
+        });
+      }, 10);
+      $('<p class="highlight-link"><a href="javascript:Documentation.' +
+        'hideSearchWords()">' + _('Hide Search Matches') + '</a></p>')
+          .appendTo($('#searchbox'));
+    }
+  },
+
+  /**
+   * init the domain index toggle buttons
+   */
+  initIndexTable : function() {
+    var togglers = $('img.toggler').click(function() {
+      var src = $(this).attr('src');
+      var idnum = $(this).attr('id').substr(7);
+      $('tr.cg-' + idnum).toggle();
+      if (src.substr(-9) === 'minus.png')
+        $(this).attr('src', src.substr(0, src.length-9) + 'plus.png');
+      else
+        $(this).attr('src', src.substr(0, src.length-8) + 'minus.png');
+    }).css('display', '');
+    if (DOCUMENTATION_OPTIONS.COLLAPSE_INDEX) {
+      togglers.click();
+    }
+  },
+
+  /**
+   * helper function to hide the search marks again
+   */
+  hideSearchWords : function() {
+    $('#searchbox .highlight-link').fadeOut(300);
+    $('span.highlighted').removeClass('highlighted');
+    var url = new URL(window.location);
+    url.searchParams.delete('highlight');
+    window.history.replaceState({}, '', url);
+  },
+
+  /**
+   * helper function to focus on search bar
+   */
+  focusSearchBar : function() {
+    $('input[name=q]').first().focus();
+  },
+
+  /**
+   * make the url absolute
+   */
+  makeURL : function(relativeURL) {
+    return DOCUMENTATION_OPTIONS.URL_ROOT + '/' + relativeURL;
+  },
+
+  /**
+   * get the current relative url
+   */
+  getCurrentURL : function() {
+    var path = document.location.pathname;
+    var parts = path.split(/\//);
+    $.each(DOCUMENTATION_OPTIONS.URL_ROOT.split(/\//), function() {
+      if (this === '..')
+        parts.pop();
+    });
+    var url = parts.join('/');
+    return path.substring(url.lastIndexOf('/') + 1, path.length - 1);
+  },
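  // -------------------------------------------------------------------
  // Illustrative sketch (editor's note; not part of the vendored file).
  // highlightSearchWords() above chains jQuery.getQueryParameters() into
  // the highlightText() plugin. Opening a built page under a hypothetical
  // URL such as
  //
  //   modules.html?highlight=data%20consistency
  //
  // gives params.highlight === ['data consistency']; the string is split
  // on whitespace into ['data', 'consistency'], each term is wrapped in
  // <span class="highlighted"> inside div.body, and hideSearchWords()
  // above strips those classes again, fades the injected link out, and
  // removes ?highlight= from the address bar via history.replaceState.
  // -------------------------------------------------------------------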
+
+  initOnKeyListeners: function() {
+    // only install a listener if it is really needed
+    if (!DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS &&
+        !DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS)
+      return;
+
+    $(document).keydown(function(event) {
+      var activeElementType = document.activeElement.tagName;
+      // don't navigate when in search box, textarea, dropdown or button
+      if (activeElementType !== 'TEXTAREA' && activeElementType !== 'INPUT' && activeElementType !== 'SELECT'
+          && activeElementType !== 'BUTTON') {
+        if (event.altKey || event.ctrlKey || event.metaKey)
+          return;
+
+        if (!event.shiftKey) {
+          switch (event.key) {
+            case 'ArrowLeft':
+              if (!DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS)
+                break;
+              var prevHref = $('link[rel="prev"]').prop('href');
+              if (prevHref) {
+                window.location.href = prevHref;
+                return false;
+              }
+              break;
+            case 'ArrowRight':
+              if (!DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS)
+                break;
+              var nextHref = $('link[rel="next"]').prop('href');
+              if (nextHref) {
+                window.location.href = nextHref;
+                return false;
+              }
+              break;
+            case 'Escape':
+              if (!DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS)
+                break;
+              Documentation.hideSearchWords();
+              return false;
+          }
+        }
+
+        // some keyboard layouts may need Shift to get /
+        switch (event.key) {
+          case '/':
+            if (!DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS)
+              break;
+            Documentation.focusSearchBar();
+            return false;
+        }
+      }
+    });
+  }
+};
+
+// quick alias for translations
+_ = Documentation.gettext;
+
+$(document).ready(function() {
+  Documentation.init();
+});
diff --git a/docs/build/html/_static/documentation_options.js b/docs/build/html/_static/documentation_options.js
new file mode 100644
index 00000000..7b09f8ca
--- /dev/null
+++ b/docs/build/html/_static/documentation_options.js
@@ -0,0 +1,14 @@
+var DOCUMENTATION_OPTIONS = {
+    URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'),
+    VERSION: 'v.0.0.1',
+    LANGUAGE: 'None',
+    COLLAPSE_INDEX: false,
+    BUILDER: 'html',
+    FILE_SUFFIX: '.html',
+    LINK_SUFFIX: '.html',
+    HAS_SOURCE: true,
+    SOURCELINK_SUFFIX: '.txt',
+    NAVIGATION_WITH_KEYS: false,
+    SHOW_SEARCH_SUMMARY: true,
+    ENABLE_SEARCH_SHORTCUTS: true,
+};
diff --git a/docs/build/html/_static/file.png b/docs/build/html/_static/file.png
new file mode 100644
index 00000000..a858a410
Binary files /dev/null and b/docs/build/html/_static/file.png differ
diff --git a/docs/build/html/_static/jquery-3.5.1.js b/docs/build/html/_static/jquery-3.5.1.js
new file mode 100644
index 00000000..50937333
--- /dev/null
+++ b/docs/build/html/_static/jquery-3.5.1.js
@@ -0,0 +1,10872 @@
+/*!
+ * jQuery JavaScript Library v3.5.1
+ * https://jquery.com/
+ *
+ * Includes Sizzle.js
+ * https://sizzlejs.com/
+ *
+ * Copyright JS Foundation and other contributors
+ * Released under the MIT license
+ * https://jquery.org/license
+ *
+ * Date: 2020-05-04T22:49Z
+ */
+( function( global, factory ) {
+
+	"use strict";
+
+	if ( typeof module === "object" && typeof module.exports === "object" ) {
+
+		// For CommonJS and CommonJS-like environments where a proper `window`
+		// is present, execute the factory and get jQuery.
+		// For environments that do not have a `window` with a `document`
+		// (such as Node.js), expose a factory as module.exports.
+		// This accentuates the need for the creation of a real `window`.
+		// e.g. var jQuery = require("jquery")(window);
+		// See ticket #14549 for more info.
+		module.exports = global.document ?
+			factory( global, true ) :
+			function( w ) {
+				if ( !w.document ) {
+					throw new Error( "jQuery requires a window with a document" );
+				}
+				return factory( w );
+			};
+	} else {
+		factory( global );
+	}
+
+// Pass this if window is not defined yet
+} )( typeof window !== "undefined" ? window : this, function( window, noGlobal ) {
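// ---------------------------------------------------------------------------
// Editor's sketch (not part of the vendored file): the wrapper above exports
// a factory when no global `window` exists, so jQuery can be bound to a DOM
// supplied by the caller. Under Node with a synthetic window (for example one
// produced by a DOM implementation such as jsdom) usage would look roughly
// like:
//
//   var window = new ( require( "jsdom" ).JSDOM )( "<p>hi</p>" ).window;
//   var jQuery = require( "jquery" )( window );
//   jQuery( "p" ).text(); // "hi"
//
// In a browser the else-branch calls factory( global ) directly, and the
// usual window.jQuery / window.$ globals get attached at the end of the
// factory; the CommonJS branch passes noGlobal = true to skip that step.
// ---------------------------------------------------------------------------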
+
+// Edge <= 12 - 13+, Firefox <=18 - 45+, IE 10 - 11, Safari 5.1 - 9+, iOS 6 - 9.1
+// throw exceptions when non-strict code (e.g., ASP.NET 4.5) accesses strict mode
+// arguments.callee.caller (trac-13335). But as of jQuery 3.0 (2016), strict mode should be common
+// enough that all such attempts are guarded in a try block.
+"use strict";
+
+var arr = [];
+
+var getProto = Object.getPrototypeOf;
+
+var slice = arr.slice;
+
+var flat = arr.flat ? function( array ) {
+	return arr.flat.call( array );
+} : function( array ) {
+	return arr.concat.apply( [], array );
+};
+
+
+var push = arr.push;
+
+var indexOf = arr.indexOf;
+
+var class2type = {};
+
+var toString = class2type.toString;
+
+var hasOwn = class2type.hasOwnProperty;
+
+var fnToString = hasOwn.toString;
+
+var ObjectFunctionString = fnToString.call( Object );
+
+var support = {};
+
+var isFunction = function isFunction( obj ) {
+
+		// Support: Chrome <=57, Firefox <=52
+		// In some browsers, typeof returns "function" for HTML elements
+		// (i.e., `typeof document.createElement( "object" ) === "function"`).
+		// We don't want to classify *any* DOM node as a function.
+		return typeof obj === "function" && typeof obj.nodeType !== "number";
+	};
+
+
+var isWindow = function isWindow( obj ) {
+		return obj != null && obj === obj.window;
+	};
+
+
+var document = window.document;
+
+
+
+	var preservedScriptAttributes = {
+		type: true,
+		src: true,
+		nonce: true,
+		noModule: true
+	};
+
+	function DOMEval( code, node, doc ) {
+		doc = doc || document;
+
+		var i, val,
+			script = doc.createElement( "script" );
+
+		script.text = code;
+		if ( node ) {
+			for ( i in preservedScriptAttributes ) {
+
+				// Support: Firefox 64+, Edge 18+
+				// Some browsers don't support the "nonce" property on scripts.
+				// On the other hand, just using `getAttribute` is not enough as
+				// the `nonce` attribute is reset to an empty string whenever it
+				// becomes browsing-context connected.
+				// See https://github.com/whatwg/html/issues/2369
+				// See https://html.spec.whatwg.org/#nonce-attributes
+				// The `node.getAttribute` check was added for the sake of
+				// `jQuery.globalEval` so that it can fake a nonce-containing node
+				// via an object.
+				val = node[ i ] || node.getAttribute && node.getAttribute( i );
+				if ( val ) {
+					script.setAttribute( i, val );
+				}
+			}
+		}
+		doc.head.appendChild( script ).parentNode.removeChild( script );
+	}
+
+
+function toType( obj ) {
+	if ( obj == null ) {
+		return obj + "";
+	}
+
+	// Support: Android <=2.3 only (functionish RegExp)
+	return typeof obj === "object" || typeof obj === "function" ?
+		class2type[ toString.call( obj ) ] || "object" :
+		typeof obj;
+}
+/* global Symbol */
+// Defining this global in .eslintrc.json would create a danger of using the global
+// unguarded in another place, it seems safer to define global only for this module
+
+
+
+var
+	version = "3.5.1",
+
+	// Define a local copy of jQuery
+	jQuery = function( selector, context ) {
+
+		// The jQuery object is actually just the init constructor 'enhanced'
+		// Need init if jQuery is called (just allow error to be thrown if not included)
+		return new jQuery.fn.init( selector, context );
+	};
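// ---------------------------------------------------------------------------
// Editor's sketch (not part of the vendored file): toType() refines `typeof`
// by looking the Object.prototype.toString tag up in the class2type map that
// is populated further down. Expected behaviour:
//
//   toType( [] );   // "array"  ("[object Array]" found in class2type)
//   toType( null ); // "null"   (the `obj + ""` short-circuit)
//   toType( /x/ );  // "regexp"
//   toType( 42 );   // "number" (primitives fall through to `typeof`)
//
// DOMEval() above is the CSP-aware script runner: it copies a nonce from the
// source node onto the transient <script> element so that evaluated code can
// still satisfy a nonce-based Content-Security-Policy.
// ---------------------------------------------------------------------------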
+jQuery.fn = jQuery.prototype = {
+
+	// The current version of jQuery being used
+	jquery: version,
+
+	constructor: jQuery,
+
+	// The default length of a jQuery object is 0
+	length: 0,
+
+	toArray: function() {
+		return slice.call( this );
+	},
+
+	// Get the Nth element in the matched element set OR
+	// Get the whole matched element set as a clean array
+	get: function( num ) {
+
+		// Return all the elements in a clean array
+		if ( num == null ) {
+			return slice.call( this );
+		}
+
+		// Return just the one element from the set
+		return num < 0 ? this[ num + this.length ] : this[ num ];
+	},
+
+	// Take an array of elements and push it onto the stack
+	// (returning the new matched element set)
+	pushStack: function( elems ) {
+
+		// Build a new jQuery matched element set
+		var ret = jQuery.merge( this.constructor(), elems );
+
+		// Add the old object onto the stack (as a reference)
+		ret.prevObject = this;
+
+		// Return the newly-formed element set
+		return ret;
+	},
+
+	// Execute a callback for every element in the matched set.
+	each: function( callback ) {
+		return jQuery.each( this, callback );
+	},
+
+	map: function( callback ) {
+		return this.pushStack( jQuery.map( this, function( elem, i ) {
+			return callback.call( elem, i, elem );
+		} ) );
+	},
+
+	slice: function() {
+		return this.pushStack( slice.apply( this, arguments ) );
+	},
+
+	first: function() {
+		return this.eq( 0 );
+	},
+
+	last: function() {
+		return this.eq( -1 );
+	},
+
+	even: function() {
+		return this.pushStack( jQuery.grep( this, function( _elem, i ) {
+			return ( i + 1 ) % 2;
+		} ) );
+	},
+
+	odd: function() {
+		return this.pushStack( jQuery.grep( this, function( _elem, i ) {
+			return i % 2;
+		} ) );
+	},
+
+	eq: function( i ) {
+		var len = this.length,
+			j = +i + ( i < 0 ? len : 0 );
+		return this.pushStack( j >= 0 && j < len ? [ this[ j ] ] : [] );
+	},
+
+	end: function() {
+		return this.prevObject || this.constructor();
+	},
+
+	// For internal use only.
+	// Behaves like an Array's method, not like a jQuery method.
+	push: push,
+	sort: arr.sort,
+	splice: arr.splice
+};
+
+jQuery.extend = jQuery.fn.extend = function() {
+	var options, name, src, copy, copyIsArray, clone,
+		target = arguments[ 0 ] || {},
+		i = 1,
+		length = arguments.length,
+		deep = false;
+
+	// Handle a deep copy situation
+	if ( typeof target === "boolean" ) {
+		deep = target;
+
+		// Skip the boolean and the target
+		target = arguments[ i ] || {};
+		i++;
+	}
+
+	// Handle case when target is a string or something (possible in deep copy)
+	if ( typeof target !== "object" && !isFunction( target ) ) {
+		target = {};
+	}
+
+	// Extend jQuery itself if only one argument is passed
+	if ( i === length ) {
+		target = this;
+		i--;
+	}
+
+	for ( ; i < length; i++ ) {
+
+		// Only deal with non-null/undefined values
+		if ( ( options = arguments[ i ] ) != null ) {
+
+			// Extend the base object
+			for ( name in options ) {
+				copy = options[ name ];
+
+				// Prevent Object.prototype pollution
+				// Prevent never-ending loop
+				if ( name === "__proto__" || target === copy ) {
+					continue;
+				}
+
+				// Recurse if we're merging plain objects or arrays
+				if ( deep && copy && ( jQuery.isPlainObject( copy ) ||
+					( copyIsArray = Array.isArray( copy ) ) ) ) {
+					src = target[ name ];
+
+					// Ensure proper type for the source value
+					if ( copyIsArray && !Array.isArray( src ) ) {
+						clone = [];
+					} else if ( !copyIsArray && !jQuery.isPlainObject( src ) ) {
+						clone = {};
+					} else {
+						clone = src;
+					}
+					copyIsArray = false;
+
+					// Never move original objects, clone them
+					target[ name ] = jQuery.extend( deep, clone, copy );
+
+				// Don't bring in undefined values
+				} else if ( copy !== undefined ) {
+					target[ name ] = copy;
+				}
+			}
+		}
+	}
+
+	// Return the modified object
+	return target;
+};
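// ---------------------------------------------------------------------------
// Editor's sketch (not part of the vendored file; option names hypothetical):
// the `deep` flag above decides whether nested plain objects and arrays are
// merged recursively into fresh clones or simply overwritten:
//
//   var defaults = { fft: { centered: true }, coils: 32 };
//   var user = { fft: { normalization: "ortho" } };
//
//   jQuery.extend( {}, defaults, user ).fft;
//   // => { normalization: "ortho" } (shallow: `fft` replaced wholesale)
//
//   jQuery.extend( true, {}, defaults, user ).fft;
//   // => { centered: true, normalization: "ortho" } (deep: `fft` merged)
//
// The `__proto__` guard in the inner loop is the prototype-pollution fix
// shipped in jQuery 3.4 (CVE-2019-11358).
// ---------------------------------------------------------------------------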
+
+jQuery.extend( {
+
+	// Unique for each copy of jQuery on the page
+	expando: "jQuery" + ( version + Math.random() ).replace( /\D/g, "" ),
+
+	// Assume jQuery is ready without the ready module
+	isReady: true,
+
+	error: function( msg ) {
+		throw new Error( msg );
+	},
+
+	noop: function() {},
+
+	isPlainObject: function( obj ) {
+		var proto, Ctor;
+
+		// Detect obvious negatives
+		// Use toString instead of jQuery.type to catch host objects
+		if ( !obj || toString.call( obj ) !== "[object Object]" ) {
+			return false;
+		}
+
+		proto = getProto( obj );
+
+		// Objects with no prototype (e.g., `Object.create( null )`) are plain
+		if ( !proto ) {
+			return true;
+		}
+
+		// Objects with prototype are plain iff they were constructed by a global Object function
+		Ctor = hasOwn.call( proto, "constructor" ) && proto.constructor;
+		return typeof Ctor === "function" && fnToString.call( Ctor ) === ObjectFunctionString;
+	},
+
+	isEmptyObject: function( obj ) {
+		var name;
+
+		for ( name in obj ) {
+			return false;
+		}
+		return true;
+	},
+
+	// Evaluates a script in a provided context; falls back to the global one
+	// if not specified.
+	globalEval: function( code, options, doc ) {
+		DOMEval( code, { nonce: options && options.nonce }, doc );
+	},
+
+	each: function( obj, callback ) {
+		var length, i = 0;
+
+		if ( isArrayLike( obj ) ) {
+			length = obj.length;
+			for ( ; i < length; i++ ) {
+				if ( callback.call( obj[ i ], i, obj[ i ] ) === false ) {
+					break;
+				}
+			}
+		} else {
+			for ( i in obj ) {
+				if ( callback.call( obj[ i ], i, obj[ i ] ) === false ) {
+					break;
+				}
+			}
+		}
+
+		return obj;
+	},
+
+	// results is for internal usage only
+	makeArray: function( arr, results ) {
+		var ret = results || [];
+
+		if ( arr != null ) {
+			if ( isArrayLike( Object( arr ) ) ) {
+				jQuery.merge( ret,
+					typeof arr === "string" ?
+						[ arr ] : arr
+				);
+			} else {
+				push.call( ret, arr );
+			}
+		}
+
+		return ret;
+	},
+
+	inArray: function( elem, arr, i ) {
+		return arr == null ? -1 : indexOf.call( arr, elem, i );
+	},
+
+	// Support: Android <=4.0 only, PhantomJS 1 only
+	// push.apply(_, arraylike) throws on ancient WebKit
+	merge: function( first, second ) {
+		var len = +second.length,
+			j = 0,
+			i = first.length;
+
+		for ( ; j < len; j++ ) {
+			first[ i++ ] = second[ j ];
+		}
+
+		first.length = i;
+
+		return first;
+	},
+
+	grep: function( elems, callback, invert ) {
+		var callbackInverse,
+			matches = [],
+			i = 0,
+			length = elems.length,
+			callbackExpect = !invert;
+
+		// Go through the array, only saving the items
+		// that pass the validator function
+		for ( ; i < length; i++ ) {
+			callbackInverse = !callback( elems[ i ], i );
+			if ( callbackInverse !== callbackExpect ) {
+				matches.push( elems[ i ] );
+			}
+		}
+
+		return matches;
+	},
+
+	// arg is for internal usage only
+	map: function( elems, callback, arg ) {
+		var length, value,
+			i = 0,
+			ret = [];
+
+		// Go through the array, translating each of the items to their new values
+		if ( isArrayLike( elems ) ) {
+			length = elems.length;
+			for ( ; i < length; i++ ) {
+				value = callback( elems[ i ], i, arg );
+
+				if ( value != null ) {
+					ret.push( value );
+				}
+			}
+
+		// Go through every key on the object,
+		} else {
+			for ( i in elems ) {
+				value = callback( elems[ i ], i, arg );
+
+				if ( value != null ) {
+					ret.push( value );
+				}
+			}
+		}
+
+		// Flatten any nested arrays
+		return flat( ret );
+	},
+
+	// A global GUID counter for objects
+	guid: 1,
+
+	// jQuery.support is not used in Core but other projects attach their
+	// properties to it so it needs to exist.
+	support: support
+} );
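// ---------------------------------------------------------------------------
// Editor's sketch (not part of the vendored file): a few expected results for
// the statics registered above:
//
//   jQuery.isPlainObject( {} );                    // true
//   jQuery.isPlainObject( Object.create( null ) ); // true (no prototype)
//   jQuery.isPlainObject( new Date() );            // false
//
//   jQuery.grep( [ 1, 2, 3, 4 ], function( n ) { return n % 2; } );
//   // => [ 1, 3 ]
//
//   jQuery.map( [ 1, 2 ], function( n ) { return [ n, n * 10 ]; } );
//   // => [ 1, 10, 2, 20 ] -- unlike Array#map, jQuery.map flattens one
//   // level of returned arrays (via `flat` above)
// ---------------------------------------------------------------------------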
+ support: support +} ); + +if ( typeof Symbol === "function" ) { + jQuery.fn[ Symbol.iterator ] = arr[ Symbol.iterator ]; +} + +// Populate the class2type map +jQuery.each( "Boolean Number String Function Array Date RegExp Object Error Symbol".split( " " ), +function( _i, name ) { + class2type[ "[object " + name + "]" ] = name.toLowerCase(); +} ); + +function isArrayLike( obj ) { + + // Support: real iOS 8.2 only (not reproducible in simulator) + // `in` check used to prevent JIT error (gh-2145) + // hasOwn isn't used here due to false negatives + // regarding Nodelist length in IE + var length = !!obj && "length" in obj && obj.length, + type = toType( obj ); + + if ( isFunction( obj ) || isWindow( obj ) ) { + return false; + } + + return type === "array" || length === 0 || + typeof length === "number" && length > 0 && ( length - 1 ) in obj; +} +var Sizzle = +/*! + * Sizzle CSS Selector Engine v2.3.5 + * https://sizzlejs.com/ + * + * Copyright JS Foundation and other contributors + * Released under the MIT license + * https://js.foundation/ + * + * Date: 2020-03-14 + */ +( function( window ) { +var i, + support, + Expr, + getText, + isXML, + tokenize, + compile, + select, + outermostContext, + sortInput, + hasDuplicate, + + // Local document vars + setDocument, + document, + docElem, + documentIsHTML, + rbuggyQSA, + rbuggyMatches, + matches, + contains, + + // Instance-specific data + expando = "sizzle" + 1 * new Date(), + preferredDoc = window.document, + dirruns = 0, + done = 0, + classCache = createCache(), + tokenCache = createCache(), + compilerCache = createCache(), + nonnativeSelectorCache = createCache(), + sortOrder = function( a, b ) { + if ( a === b ) { + hasDuplicate = true; + } + return 0; + }, + + // Instance methods + hasOwn = ( {} ).hasOwnProperty, + arr = [], + pop = arr.pop, + pushNative = arr.push, + push = arr.push, + slice = arr.slice, + + // Use a stripped-down indexOf as it's faster than native + // https://jsperf.com/thor-indexof-vs-for/5 + indexOf = function( list, elem ) { + var i = 0, + len = list.length; + for ( ; i < len; i++ ) { + if ( list[ i ] === elem ) { + return i; + } + } + return -1; + }, + + booleans = "checked|selected|async|autofocus|autoplay|controls|defer|disabled|hidden|" + + "ismap|loop|multiple|open|readonly|required|scoped", + + // Regular expressions + + // http://www.w3.org/TR/css3-selectors/#whitespace + whitespace = "[\\x20\\t\\r\\n\\f]", + + // https://www.w3.org/TR/css-syntax-3/#ident-token-diagram + identifier = "(?:\\\\[\\da-fA-F]{1,6}" + whitespace + + "?|\\\\[^\\r\\n\\f]|[\\w-]|[^\0-\\x7f])+", + + // Attribute selectors: http://www.w3.org/TR/selectors/#attribute-selectors + attributes = "\\[" + whitespace + "*(" + identifier + ")(?:" + whitespace + + + // Operator (capture 2) + "*([*^$|!~]?=)" + whitespace + + + // "Attribute values must be CSS identifiers [capture 5] + // or strings [capture 3 or capture 4]" + "*(?:'((?:\\\\.|[^\\\\'])*)'|\"((?:\\\\.|[^\\\\\"])*)\"|(" + identifier + "))|)" + + whitespace + "*\\]", + + pseudos = ":(" + identifier + ")(?:\\((" + + + // To reduce the number of selectors needing tokenize in the preFilter, prefer arguments: + // 1. quoted (capture 3; capture 4 or capture 5) + "('((?:\\\\.|[^\\\\'])*)'|\"((?:\\\\.|[^\\\\\"])*)\")|" + + + // 2. simple (capture 6) + "((?:\\\\.|[^\\\\()[\\]]|" + attributes + ")*)|" + + + // 3. 
anything else (capture 2) + ".*" + + ")\\)|)", + + // Leading and non-escaped trailing whitespace, capturing some non-whitespace characters preceding the latter + rwhitespace = new RegExp( whitespace + "+", "g" ), + rtrim = new RegExp( "^" + whitespace + "+|((?:^|[^\\\\])(?:\\\\.)*)" + + whitespace + "+$", "g" ), + + rcomma = new RegExp( "^" + whitespace + "*," + whitespace + "*" ), + rcombinators = new RegExp( "^" + whitespace + "*([>+~]|" + whitespace + ")" + whitespace + + "*" ), + rdescend = new RegExp( whitespace + "|>" ), + + rpseudo = new RegExp( pseudos ), + ridentifier = new RegExp( "^" + identifier + "$" ), + + matchExpr = { + "ID": new RegExp( "^#(" + identifier + ")" ), + "CLASS": new RegExp( "^\\.(" + identifier + ")" ), + "TAG": new RegExp( "^(" + identifier + "|[*])" ), + "ATTR": new RegExp( "^" + attributes ), + "PSEUDO": new RegExp( "^" + pseudos ), + "CHILD": new RegExp( "^:(only|first|last|nth|nth-last)-(child|of-type)(?:\\(" + + whitespace + "*(even|odd|(([+-]|)(\\d*)n|)" + whitespace + "*(?:([+-]|)" + + whitespace + "*(\\d+)|))" + whitespace + "*\\)|)", "i" ), + "bool": new RegExp( "^(?:" + booleans + ")$", "i" ), + + // For use in libraries implementing .is() + // We use this for POS matching in `select` + "needsContext": new RegExp( "^" + whitespace + + "*[>+~]|:(even|odd|eq|gt|lt|nth|first|last)(?:\\(" + whitespace + + "*((?:-\\d)?\\d*)" + whitespace + "*\\)|)(?=[^-]|$)", "i" ) + }, + + rhtml = /HTML$/i, + rinputs = /^(?:input|select|textarea|button)$/i, + rheader = /^h\d$/i, + + rnative = /^[^{]+\{\s*\[native \w/, + + // Easily-parseable/retrievable ID or TAG or CLASS selectors + rquickExpr = /^(?:#([\w-]+)|(\w+)|\.([\w-]+))$/, + + rsibling = /[+~]/, + + // CSS escapes + // http://www.w3.org/TR/CSS21/syndata.html#escaped-characters + runescape = new RegExp( "\\\\[\\da-fA-F]{1,6}" + whitespace + "?|\\\\([^\\r\\n\\f])", "g" ), + funescape = function( escape, nonHex ) { + var high = "0x" + escape.slice( 1 ) - 0x10000; + + return nonHex ? + + // Strip the backslash prefix from a non-hex escape sequence + nonHex : + + // Replace a hexadecimal escape sequence with the encoded Unicode code point + // Support: IE <=11+ + // For values outside the Basic Multilingual Plane (BMP), manually construct a + // surrogate pair + high < 0 ? 
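+
+	/* Illustrative worked example (not from the original source): for the CSS
+	   escape "\1D306 " the assignment above yields
+	   high = 0x1D306 - 0x10000 = 0xD306 >= 0, so the second branch below
+	   assembles the surrogate pair:
+	     0xD306 >> 10 | 0xD800   => 0xD834 (high surrogate)
+	     0xD306 & 0x3FF | 0xDC00 => 0xDF06 (low surrogate)
+	   i.e. "\uD834\uDF06" = U+1D306; the first branch handles BMP code points. */
+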
+ String.fromCharCode( high + 0x10000 ) : + String.fromCharCode( high >> 10 | 0xD800, high & 0x3FF | 0xDC00 ); + }, + + // CSS string/identifier serialization + // https://drafts.csswg.org/cssom/#common-serializing-idioms + rcssescape = /([\0-\x1f\x7f]|^-?\d)|^-$|[^\0-\x1f\x7f-\uFFFF\w-]/g, + fcssescape = function( ch, asCodePoint ) { + if ( asCodePoint ) { + + // U+0000 NULL becomes U+FFFD REPLACEMENT CHARACTER + if ( ch === "\0" ) { + return "\uFFFD"; + } + + // Control characters and (dependent upon position) numbers get escaped as code points + return ch.slice( 0, -1 ) + "\\" + + ch.charCodeAt( ch.length - 1 ).toString( 16 ) + " "; + } + + // Other potentially-special ASCII characters get backslash-escaped + return "\\" + ch; + }, + + // Used for iframes + // See setDocument() + // Removing the function wrapper causes a "Permission Denied" + // error in IE + unloadHandler = function() { + setDocument(); + }, + + inDisabledFieldset = addCombinator( + function( elem ) { + return elem.disabled === true && elem.nodeName.toLowerCase() === "fieldset"; + }, + { dir: "parentNode", next: "legend" } + ); + +// Optimize for push.apply( _, NodeList ) +try { + push.apply( + ( arr = slice.call( preferredDoc.childNodes ) ), + preferredDoc.childNodes + ); + + // Support: Android<4.0 + // Detect silently failing push.apply + // eslint-disable-next-line no-unused-expressions + arr[ preferredDoc.childNodes.length ].nodeType; +} catch ( e ) { + push = { apply: arr.length ? + + // Leverage slice if possible + function( target, els ) { + pushNative.apply( target, slice.call( els ) ); + } : + + // Support: IE<9 + // Otherwise append directly + function( target, els ) { + var j = target.length, + i = 0; + + // Can't trust NodeList.length + while ( ( target[ j++ ] = els[ i++ ] ) ) {} + target.length = j - 1; + } + }; +} + +function Sizzle( selector, context, results, seed ) { + var m, i, elem, nid, match, groups, newSelector, + newContext = context && context.ownerDocument, + + // nodeType defaults to 9, since context defaults to document + nodeType = context ? 
context.nodeType : 9; + + results = results || []; + + // Return early from calls with invalid selector or context + if ( typeof selector !== "string" || !selector || + nodeType !== 1 && nodeType !== 9 && nodeType !== 11 ) { + + return results; + } + + // Try to shortcut find operations (as opposed to filters) in HTML documents + if ( !seed ) { + setDocument( context ); + context = context || document; + + if ( documentIsHTML ) { + + // If the selector is sufficiently simple, try using a "get*By*" DOM method + // (excepting DocumentFragment context, where the methods don't exist) + if ( nodeType !== 11 && ( match = rquickExpr.exec( selector ) ) ) { + + // ID selector + if ( ( m = match[ 1 ] ) ) { + + // Document context + if ( nodeType === 9 ) { + if ( ( elem = context.getElementById( m ) ) ) { + + // Support: IE, Opera, Webkit + // TODO: identify versions + // getElementById can match elements by name instead of ID + if ( elem.id === m ) { + results.push( elem ); + return results; + } + } else { + return results; + } + + // Element context + } else { + + // Support: IE, Opera, Webkit + // TODO: identify versions + // getElementById can match elements by name instead of ID + if ( newContext && ( elem = newContext.getElementById( m ) ) && + contains( context, elem ) && + elem.id === m ) { + + results.push( elem ); + return results; + } + } + + // Type selector + } else if ( match[ 2 ] ) { + push.apply( results, context.getElementsByTagName( selector ) ); + return results; + + // Class selector + } else if ( ( m = match[ 3 ] ) && support.getElementsByClassName && + context.getElementsByClassName ) { + + push.apply( results, context.getElementsByClassName( m ) ); + return results; + } + } + + // Take advantage of querySelectorAll + if ( support.qsa && + !nonnativeSelectorCache[ selector + " " ] && + ( !rbuggyQSA || !rbuggyQSA.test( selector ) ) && + + // Support: IE 8 only + // Exclude object elements + ( nodeType !== 1 || context.nodeName.toLowerCase() !== "object" ) ) { + + newSelector = selector; + newContext = context; + + // qSA considers elements outside a scoping root when evaluating child or + // descendant combinators, which is not what we want. + // In such cases, we work around the behavior by prefixing every selector in the + // list with an ID selector referencing the scope context. + // The technique has to be used as well when a leading combinator is used + // as such selectors are not recognized by querySelectorAll. + // Thanks to Andrew Dupont for this technique. + if ( nodeType === 1 && + ( rdescend.test( selector ) || rcombinators.test( selector ) ) ) { + + // Expand context for sibling selectors + newContext = rsibling.test( selector ) && testContext( context.parentNode ) || + context; + + // We can use :scope instead of the ID hack if the browser + // supports it & if we're not changing the context. + if ( newContext !== context || !support.scope ) { + + // Capture the context ID, setting it first if necessary + if ( ( nid = context.getAttribute( "id" ) ) ) { + nid = nid.replace( rcssescape, fcssescape ); + } else { + context.setAttribute( "id", ( nid = expando ) ); + } + } + + // Prefix every selector in the list + groups = tokenize( selector ); + i = groups.length; + while ( i-- ) { + groups[ i ] = ( nid ? 
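+
+	/* Illustrative example (not from the original source): with a context
+	   element <div id="outer"> and the selector "div p", the loop below
+	   rewrites the group to "#outer div p" (or ":scope div p" where usable),
+	   so querySelectorAll matches only within the context subtree instead of
+	   evaluating the combinators against the whole document. */
+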
"#" + nid : ":scope" ) + " " + + toSelector( groups[ i ] ); + } + newSelector = groups.join( "," ); + } + + try { + push.apply( results, + newContext.querySelectorAll( newSelector ) + ); + return results; + } catch ( qsaError ) { + nonnativeSelectorCache( selector, true ); + } finally { + if ( nid === expando ) { + context.removeAttribute( "id" ); + } + } + } + } + } + + // All others + return select( selector.replace( rtrim, "$1" ), context, results, seed ); +} + +/** + * Create key-value caches of limited size + * @returns {function(string, object)} Returns the Object data after storing it on itself with + * property name the (space-suffixed) string and (if the cache is larger than Expr.cacheLength) + * deleting the oldest entry + */ +function createCache() { + var keys = []; + + function cache( key, value ) { + + // Use (key + " ") to avoid collision with native prototype properties (see Issue #157) + if ( keys.push( key + " " ) > Expr.cacheLength ) { + + // Only keep the most recent entries + delete cache[ keys.shift() ]; + } + return ( cache[ key + " " ] = value ); + } + return cache; +} + +/** + * Mark a function for special use by Sizzle + * @param {Function} fn The function to mark + */ +function markFunction( fn ) { + fn[ expando ] = true; + return fn; +} + +/** + * Support testing using an element + * @param {Function} fn Passed the created element and returns a boolean result + */ +function assert( fn ) { + var el = document.createElement( "fieldset" ); + + try { + return !!fn( el ); + } catch ( e ) { + return false; + } finally { + + // Remove from its parent by default + if ( el.parentNode ) { + el.parentNode.removeChild( el ); + } + + // release memory in IE + el = null; + } +} + +/** + * Adds the same handler for all of the specified attrs + * @param {String} attrs Pipe-separated list of attributes + * @param {Function} handler The method that will be applied + */ +function addHandle( attrs, handler ) { + var arr = attrs.split( "|" ), + i = arr.length; + + while ( i-- ) { + Expr.attrHandle[ arr[ i ] ] = handler; + } +} + +/** + * Checks document order of two siblings + * @param {Element} a + * @param {Element} b + * @returns {Number} Returns less than 0 if a precedes b, greater than 0 if a follows b + */ +function siblingCheck( a, b ) { + var cur = b && a, + diff = cur && a.nodeType === 1 && b.nodeType === 1 && + a.sourceIndex - b.sourceIndex; + + // Use IE sourceIndex if available on both nodes + if ( diff ) { + return diff; + } + + // Check if b follows a + if ( cur ) { + while ( ( cur = cur.nextSibling ) ) { + if ( cur === b ) { + return -1; + } + } + } + + return a ? 
1 : -1; +} + +/** + * Returns a function to use in pseudos for input types + * @param {String} type + */ +function createInputPseudo( type ) { + return function( elem ) { + var name = elem.nodeName.toLowerCase(); + return name === "input" && elem.type === type; + }; +} + +/** + * Returns a function to use in pseudos for buttons + * @param {String} type + */ +function createButtonPseudo( type ) { + return function( elem ) { + var name = elem.nodeName.toLowerCase(); + return ( name === "input" || name === "button" ) && elem.type === type; + }; +} + +/** + * Returns a function to use in pseudos for :enabled/:disabled + * @param {Boolean} disabled true for :disabled; false for :enabled + */ +function createDisabledPseudo( disabled ) { + + // Known :disabled false positives: fieldset[disabled] > legend:nth-of-type(n+2) :can-disable + return function( elem ) { + + // Only certain elements can match :enabled or :disabled + // https://html.spec.whatwg.org/multipage/scripting.html#selector-enabled + // https://html.spec.whatwg.org/multipage/scripting.html#selector-disabled + if ( "form" in elem ) { + + // Check for inherited disabledness on relevant non-disabled elements: + // * listed form-associated elements in a disabled fieldset + // https://html.spec.whatwg.org/multipage/forms.html#category-listed + // https://html.spec.whatwg.org/multipage/forms.html#concept-fe-disabled + // * option elements in a disabled optgroup + // https://html.spec.whatwg.org/multipage/forms.html#concept-option-disabled + // All such elements have a "form" property. + if ( elem.parentNode && elem.disabled === false ) { + + // Option elements defer to a parent optgroup if present + if ( "label" in elem ) { + if ( "label" in elem.parentNode ) { + return elem.parentNode.disabled === disabled; + } else { + return elem.disabled === disabled; + } + } + + // Support: IE 6 - 11 + // Use the isDisabled shortcut property to check for disabled fieldset ancestors + return elem.isDisabled === disabled || + + // Where there is no isDisabled, check manually + /* jshint -W018 */ + elem.isDisabled !== !disabled && + inDisabledFieldset( elem ) === disabled; + } + + return elem.disabled === disabled; + + // Try to winnow out elements that can't be disabled before trusting the disabled property. + // Some victims get caught in our net (label, legend, menu, track), but it shouldn't + // even exist on them, let alone have a boolean value. 
+ } else if ( "label" in elem ) { + return elem.disabled === disabled; + } + + // Remaining elements are neither :enabled nor :disabled + return false; + }; +} + +/** + * Returns a function to use in pseudos for positionals + * @param {Function} fn + */ +function createPositionalPseudo( fn ) { + return markFunction( function( argument ) { + argument = +argument; + return markFunction( function( seed, matches ) { + var j, + matchIndexes = fn( [], seed.length, argument ), + i = matchIndexes.length; + + // Match elements found at the specified indexes + while ( i-- ) { + if ( seed[ ( j = matchIndexes[ i ] ) ] ) { + seed[ j ] = !( matches[ j ] = seed[ j ] ); + } + } + } ); + } ); +} + +/** + * Checks a node for validity as a Sizzle context + * @param {Element|Object=} context + * @returns {Element|Object|Boolean} The input node if acceptable, otherwise a falsy value + */ +function testContext( context ) { + return context && typeof context.getElementsByTagName !== "undefined" && context; +} + +// Expose support vars for convenience +support = Sizzle.support = {}; + +/** + * Detects XML nodes + * @param {Element|Object} elem An element or a document + * @returns {Boolean} True iff elem is a non-HTML XML node + */ +isXML = Sizzle.isXML = function( elem ) { + var namespace = elem.namespaceURI, + docElem = ( elem.ownerDocument || elem ).documentElement; + + // Support: IE <=8 + // Assume HTML when documentElement doesn't yet exist, such as inside loading iframes + // https://bugs.jquery.com/ticket/4833 + return !rhtml.test( namespace || docElem && docElem.nodeName || "HTML" ); +}; + +/** + * Sets document-related variables once based on the current document + * @param {Element|Object} [doc] An element or document object to use to set the document + * @returns {Object} Returns the current document + */ +setDocument = Sizzle.setDocument = function( node ) { + var hasCompare, subWindow, + doc = node ? node.ownerDocument || node : preferredDoc; + + // Return early if doc is invalid or already selected + // Support: IE 11+, Edge 17 - 18+ + // IE/Edge sometimes throw a "Permission denied" error when strict-comparing + // two documents; shallow comparisons work. + // eslint-disable-next-line eqeqeq + if ( doc == document || doc.nodeType !== 9 || !doc.documentElement ) { + return document; + } + + // Update global variables + document = doc; + docElem = document.documentElement; + documentIsHTML = !isXML( document ); + + // Support: IE 9 - 11+, Edge 12 - 18+ + // Accessing iframe documents after unload throws "permission denied" errors (jQuery #13936) + // Support: IE 11+, Edge 17 - 18+ + // IE/Edge sometimes throw a "Permission denied" error when strict-comparing + // two documents; shallow comparisons work. + // eslint-disable-next-line eqeqeq + if ( preferredDoc != document && + ( subWindow = document.defaultView ) && subWindow.top !== subWindow ) { + + // Support: IE 11, Edge + if ( subWindow.addEventListener ) { + subWindow.addEventListener( "unload", unloadHandler, false ); + + // Support: IE 9 - 10 only + } else if ( subWindow.attachEvent ) { + subWindow.attachEvent( "onunload", unloadHandler ); + } + } + + // Support: IE 8 - 11+, Edge 12 - 18+, Chrome <=16 - 25 only, Firefox <=3.6 - 31 only, + // Safari 4 - 5 only, Opera <=11.6 - 12.x only + // IE/Edge & older browsers don't support the :scope pseudo-class. + // Support: Safari 6.0 only + // Safari 6.0 supports :scope but it's an alias of :root there. 
+ support.scope = assert( function( el ) { + docElem.appendChild( el ).appendChild( document.createElement( "div" ) ); + return typeof el.querySelectorAll !== "undefined" && + !el.querySelectorAll( ":scope fieldset div" ).length; + } ); + + /* Attributes + ---------------------------------------------------------------------- */ + + // Support: IE<8 + // Verify that getAttribute really returns attributes and not properties + // (excepting IE8 booleans) + support.attributes = assert( function( el ) { + el.className = "i"; + return !el.getAttribute( "className" ); + } ); + + /* getElement(s)By* + ---------------------------------------------------------------------- */ + + // Check if getElementsByTagName("*") returns only elements + support.getElementsByTagName = assert( function( el ) { + el.appendChild( document.createComment( "" ) ); + return !el.getElementsByTagName( "*" ).length; + } ); + + // Support: IE<9 + support.getElementsByClassName = rnative.test( document.getElementsByClassName ); + + // Support: IE<10 + // Check if getElementById returns elements by name + // The broken getElementById methods don't pick up programmatically-set names, + // so use a roundabout getElementsByName test + support.getById = assert( function( el ) { + docElem.appendChild( el ).id = expando; + return !document.getElementsByName || !document.getElementsByName( expando ).length; + } ); + + // ID filter and find + if ( support.getById ) { + Expr.filter[ "ID" ] = function( id ) { + var attrId = id.replace( runescape, funescape ); + return function( elem ) { + return elem.getAttribute( "id" ) === attrId; + }; + }; + Expr.find[ "ID" ] = function( id, context ) { + if ( typeof context.getElementById !== "undefined" && documentIsHTML ) { + var elem = context.getElementById( id ); + return elem ? [ elem ] : []; + } + }; + } else { + Expr.filter[ "ID" ] = function( id ) { + var attrId = id.replace( runescape, funescape ); + return function( elem ) { + var node = typeof elem.getAttributeNode !== "undefined" && + elem.getAttributeNode( "id" ); + return node && node.value === attrId; + }; + }; + + // Support: IE 6 - 7 only + // getElementById is not reliable as a find shortcut + Expr.find[ "ID" ] = function( id, context ) { + if ( typeof context.getElementById !== "undefined" && documentIsHTML ) { + var node, i, elems, + elem = context.getElementById( id ); + + if ( elem ) { + + // Verify the id attribute + node = elem.getAttributeNode( "id" ); + if ( node && node.value === id ) { + return [ elem ]; + } + + // Fall back on getElementsByName + elems = context.getElementsByName( id ); + i = 0; + while ( ( elem = elems[ i++ ] ) ) { + node = elem.getAttributeNode( "id" ); + if ( node && node.value === id ) { + return [ elem ]; + } + } + } + + return []; + } + }; + } + + // Tag + Expr.find[ "TAG" ] = support.getElementsByTagName ? 
+ function( tag, context ) { + if ( typeof context.getElementsByTagName !== "undefined" ) { + return context.getElementsByTagName( tag ); + + // DocumentFragment nodes don't have gEBTN + } else if ( support.qsa ) { + return context.querySelectorAll( tag ); + } + } : + + function( tag, context ) { + var elem, + tmp = [], + i = 0, + + // By happy coincidence, a (broken) gEBTN appears on DocumentFragment nodes too + results = context.getElementsByTagName( tag ); + + // Filter out possible comments + if ( tag === "*" ) { + while ( ( elem = results[ i++ ] ) ) { + if ( elem.nodeType === 1 ) { + tmp.push( elem ); + } + } + + return tmp; + } + return results; + }; + + // Class + Expr.find[ "CLASS" ] = support.getElementsByClassName && function( className, context ) { + if ( typeof context.getElementsByClassName !== "undefined" && documentIsHTML ) { + return context.getElementsByClassName( className ); + } + }; + + /* QSA/matchesSelector + ---------------------------------------------------------------------- */ + + // QSA and matchesSelector support + + // matchesSelector(:active) reports false when true (IE9/Opera 11.5) + rbuggyMatches = []; + + // qSa(:focus) reports false when true (Chrome 21) + // We allow this because of a bug in IE8/9 that throws an error + // whenever `document.activeElement` is accessed on an iframe + // So, we allow :focus to pass through QSA all the time to avoid the IE error + // See https://bugs.jquery.com/ticket/13378 + rbuggyQSA = []; + + if ( ( support.qsa = rnative.test( document.querySelectorAll ) ) ) { + + // Build QSA regex + // Regex strategy adopted from Diego Perini + assert( function( el ) { + + var input; + + // Select is set to empty string on purpose + // This is to test IE's treatment of not explicitly + // setting a boolean content attribute, + // since its presence should be enough + // https://bugs.jquery.com/ticket/12359 + docElem.appendChild( el ).innerHTML = "<a id='" + expando + "'></a>" + + "<select id='" + expando + "-\r\\' msallowcapture=''>" + + "<option selected=''></option></select>"; + + // Support: IE8, Opera 11-12.16 + // Nothing should be selected when empty strings follow ^= or $= or *= + // The test attribute must be unknown in Opera but "safe" for WinRT + // https://msdn.microsoft.com/en-us/library/ie/hh465388.aspx#attribute_section + if ( el.querySelectorAll( "[msallowcapture^='']" ).length ) { + rbuggyQSA.push( "[*^$]=" + whitespace + "*(?:''|\"\")" ); + } + + // Support: IE8 + // Boolean attributes and "value" are not treated correctly + if ( !el.querySelectorAll( "[selected]" ).length ) { + rbuggyQSA.push( "\\[" + whitespace + "*(?:value|" + booleans + ")" ); + } + + // Support: Chrome<29, Android<4.4, Safari<7.0+, iOS<7.0+, PhantomJS<1.9.8+ + if ( !el.querySelectorAll( "[id~=" + expando + "-]" ).length ) { + rbuggyQSA.push( "~=" ); + } + + // Support: IE 11+, Edge 15 - 18+ + // IE 11/Edge don't find elements on a `[name='']` query in some cases. + // Adding a temporary attribute to the document before the selection works + // around the issue. + // Interestingly, IE 10 & older don't seem to have the issue.
+ input = document.createElement( "input" ); + input.setAttribute( "name", "" ); + el.appendChild( input ); + if ( !el.querySelectorAll( "[name='']" ).length ) { + rbuggyQSA.push( "\\[" + whitespace + "*name" + whitespace + "*=" + + whitespace + "*(?:''|\"\")" ); + } + + // Webkit/Opera - :checked should return selected option elements + // http://www.w3.org/TR/2011/REC-css3-selectors-20110929/#checked + // IE8 throws error here and will not see later tests + if ( !el.querySelectorAll( ":checked" ).length ) { + rbuggyQSA.push( ":checked" ); + } + + // Support: Safari 8+, iOS 8+ + // https://bugs.webkit.org/show_bug.cgi?id=136851 + // In-page `selector#id sibling-combinator selector` fails + if ( !el.querySelectorAll( "a#" + expando + "+*" ).length ) { + rbuggyQSA.push( ".#.+[+~]" ); + } + + // Support: Firefox <=3.6 - 5 only + // Old Firefox doesn't throw on a badly-escaped identifier. + el.querySelectorAll( "\\\f" ); + rbuggyQSA.push( "[\\r\\n\\f]" ); + } ); + + assert( function( el ) { + el.innerHTML = "" + + ""; + + // Support: Windows 8 Native Apps + // The type and name attributes are restricted during .innerHTML assignment + var input = document.createElement( "input" ); + input.setAttribute( "type", "hidden" ); + el.appendChild( input ).setAttribute( "name", "D" ); + + // Support: IE8 + // Enforce case-sensitivity of name attribute + if ( el.querySelectorAll( "[name=d]" ).length ) { + rbuggyQSA.push( "name" + whitespace + "*[*^$|!~]?=" ); + } + + // FF 3.5 - :enabled/:disabled and hidden elements (hidden elements are still enabled) + // IE8 throws error here and will not see later tests + if ( el.querySelectorAll( ":enabled" ).length !== 2 ) { + rbuggyQSA.push( ":enabled", ":disabled" ); + } + + // Support: IE9-11+ + // IE's :disabled selector does not pick up the children of disabled fieldsets + docElem.appendChild( el ).disabled = true; + if ( el.querySelectorAll( ":disabled" ).length !== 2 ) { + rbuggyQSA.push( ":enabled", ":disabled" ); + } + + // Support: Opera 10 - 11 only + // Opera 10-11 does not throw on post-comma invalid pseudos + el.querySelectorAll( "*,:x" ); + rbuggyQSA.push( ",.*:" ); + } ); + } + + if ( ( support.matchesSelector = rnative.test( ( matches = docElem.matches || + docElem.webkitMatchesSelector || + docElem.mozMatchesSelector || + docElem.oMatchesSelector || + docElem.msMatchesSelector ) ) ) ) { + + assert( function( el ) { + + // Check to see if it's possible to do matchesSelector + // on a disconnected node (IE 9) + support.disconnectedMatch = matches.call( el, "*" ); + + // This should fail with an exception + // Gecko does not error, returns false instead + matches.call( el, "[s!='']:x" ); + rbuggyMatches.push( "!=", pseudos ); + } ); + } + + rbuggyQSA = rbuggyQSA.length && new RegExp( rbuggyQSA.join( "|" ) ); + rbuggyMatches = rbuggyMatches.length && new RegExp( rbuggyMatches.join( "|" ) ); + + /* Contains + ---------------------------------------------------------------------- */ + hasCompare = rnative.test( docElem.compareDocumentPosition ); + + // Element contains another + // Purposefully self-exclusive + // As in, an element does not contain itself + contains = hasCompare || rnative.test( docElem.contains ) ? + function( a, b ) { + var adown = a.nodeType === 9 ? a.documentElement : a, + bup = b && b.parentNode; + return a === bup || !!( bup && bup.nodeType === 1 && ( + adown.contains ? 
+ adown.contains( bup ) : + a.compareDocumentPosition && a.compareDocumentPosition( bup ) & 16 + ) ); + } : + function( a, b ) { + if ( b ) { + while ( ( b = b.parentNode ) ) { + if ( b === a ) { + return true; + } + } + } + return false; + }; + + /* Sorting + ---------------------------------------------------------------------- */ + + // Document order sorting + sortOrder = hasCompare ? + function( a, b ) { + + // Flag for duplicate removal + if ( a === b ) { + hasDuplicate = true; + return 0; + } + + // Sort on method existence if only one input has compareDocumentPosition + var compare = !a.compareDocumentPosition - !b.compareDocumentPosition; + if ( compare ) { + return compare; + } + + // Calculate position if both inputs belong to the same document + // Support: IE 11+, Edge 17 - 18+ + // IE/Edge sometimes throw a "Permission denied" error when strict-comparing + // two documents; shallow comparisons work. + // eslint-disable-next-line eqeqeq + compare = ( a.ownerDocument || a ) == ( b.ownerDocument || b ) ? + a.compareDocumentPosition( b ) : + + // Otherwise we know they are disconnected + 1; + + // Disconnected nodes + if ( compare & 1 || + ( !support.sortDetached && b.compareDocumentPosition( a ) === compare ) ) { + + // Choose the first element that is related to our preferred document + // Support: IE 11+, Edge 17 - 18+ + // IE/Edge sometimes throw a "Permission denied" error when strict-comparing + // two documents; shallow comparisons work. + // eslint-disable-next-line eqeqeq + if ( a == document || a.ownerDocument == preferredDoc && + contains( preferredDoc, a ) ) { + return -1; + } + + // Support: IE 11+, Edge 17 - 18+ + // IE/Edge sometimes throw a "Permission denied" error when strict-comparing + // two documents; shallow comparisons work. + // eslint-disable-next-line eqeqeq + if ( b == document || b.ownerDocument == preferredDoc && + contains( preferredDoc, b ) ) { + return 1; + } + + // Maintain original order + return sortInput ? + ( indexOf( sortInput, a ) - indexOf( sortInput, b ) ) : + 0; + } + + return compare & 4 ? -1 : 1; + } : + function( a, b ) { + + // Exit early if the nodes are identical + if ( a === b ) { + hasDuplicate = true; + return 0; + } + + var cur, + i = 0, + aup = a.parentNode, + bup = b.parentNode, + ap = [ a ], + bp = [ b ]; + + // Parentless nodes are either documents or disconnected + if ( !aup || !bup ) { + + // Support: IE 11+, Edge 17 - 18+ + // IE/Edge sometimes throw a "Permission denied" error when strict-comparing + // two documents; shallow comparisons work. + /* eslint-disable eqeqeq */ + return a == document ? -1 : + b == document ? 1 : + /* eslint-enable eqeqeq */ + aup ? -1 : + bup ? 1 : + sortInput ? + ( indexOf( sortInput, a ) - indexOf( sortInput, b ) ) : + 0; + + // If the nodes are siblings, we can do a quick check + } else if ( aup === bup ) { + return siblingCheck( a, b ); + } + + // Otherwise we need full lists of their ancestors for comparison + cur = a; + while ( ( cur = cur.parentNode ) ) { + ap.unshift( cur ); + } + cur = b; + while ( ( cur = cur.parentNode ) ) { + bp.unshift( cur ); + } + + // Walk down the tree looking for a discrepancy + while ( ap[ i ] === bp[ i ] ) { + i++; + } + + return i ? + + // Do a sibling check if the nodes have a common ancestor + siblingCheck( ap[ i ], bp[ i ] ) : + + // Otherwise nodes in our document sort first + // Support: IE 11+, Edge 17 - 18+ + // IE/Edge sometimes throw a "Permission denied" error when strict-comparing + // two documents; shallow comparisons work. 
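+
+	// Illustrative note (not from the original source): the hasCompare branch
+	// above relies on the standard compareDocumentPosition bitmask:
+	//   1 disconnected, 2 preceding, 4 following, 8 contains, 16 contained by,
+	//   32 implementation-specific
+	// so `compare & 4 ? -1 : 1` sorts a before b exactly when b follows a.
+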
+ /* eslint-disable eqeqeq */ + ap[ i ] == preferredDoc ? -1 : + bp[ i ] == preferredDoc ? 1 : + /* eslint-enable eqeqeq */ + 0; + }; + + return document; +}; + +Sizzle.matches = function( expr, elements ) { + return Sizzle( expr, null, null, elements ); +}; + +Sizzle.matchesSelector = function( elem, expr ) { + setDocument( elem ); + + if ( support.matchesSelector && documentIsHTML && + !nonnativeSelectorCache[ expr + " " ] && + ( !rbuggyMatches || !rbuggyMatches.test( expr ) ) && + ( !rbuggyQSA || !rbuggyQSA.test( expr ) ) ) { + + try { + var ret = matches.call( elem, expr ); + + // IE 9's matchesSelector returns false on disconnected nodes + if ( ret || support.disconnectedMatch || + + // As well, disconnected nodes are said to be in a document + // fragment in IE 9 + elem.document && elem.document.nodeType !== 11 ) { + return ret; + } + } catch ( e ) { + nonnativeSelectorCache( expr, true ); + } + } + + return Sizzle( expr, document, null, [ elem ] ).length > 0; +}; + +Sizzle.contains = function( context, elem ) { + + // Set document vars if needed + // Support: IE 11+, Edge 17 - 18+ + // IE/Edge sometimes throw a "Permission denied" error when strict-comparing + // two documents; shallow comparisons work. + // eslint-disable-next-line eqeqeq + if ( ( context.ownerDocument || context ) != document ) { + setDocument( context ); + } + return contains( context, elem ); +}; + +Sizzle.attr = function( elem, name ) { + + // Set document vars if needed + // Support: IE 11+, Edge 17 - 18+ + // IE/Edge sometimes throw a "Permission denied" error when strict-comparing + // two documents; shallow comparisons work. + // eslint-disable-next-line eqeqeq + if ( ( elem.ownerDocument || elem ) != document ) { + setDocument( elem ); + } + + var fn = Expr.attrHandle[ name.toLowerCase() ], + + // Don't get fooled by Object.prototype properties (jQuery #13807) + val = fn && hasOwn.call( Expr.attrHandle, name.toLowerCase() ) ? + fn( elem, name, !documentIsHTML ) : + undefined; + + return val !== undefined ? + val : + support.attributes || !documentIsHTML ? + elem.getAttribute( name ) : + ( val = elem.getAttributeNode( name ) ) && val.specified ? 
+ val.value : + null; +}; + +Sizzle.escape = function( sel ) { + return ( sel + "" ).replace( rcssescape, fcssescape ); +}; + +Sizzle.error = function( msg ) { + throw new Error( "Syntax error, unrecognized expression: " + msg ); +}; + +/** + * Document sorting and removing duplicates + * @param {ArrayLike} results + */ +Sizzle.uniqueSort = function( results ) { + var elem, + duplicates = [], + j = 0, + i = 0; + + // Unless we *know* we can detect duplicates, assume their presence + hasDuplicate = !support.detectDuplicates; + sortInput = !support.sortStable && results.slice( 0 ); + results.sort( sortOrder ); + + if ( hasDuplicate ) { + while ( ( elem = results[ i++ ] ) ) { + if ( elem === results[ i ] ) { + j = duplicates.push( i ); + } + } + while ( j-- ) { + results.splice( duplicates[ j ], 1 ); + } + } + + // Clear input after sorting to release objects + // See https://github.com/jquery/sizzle/pull/225 + sortInput = null; + + return results; +}; + +/** + * Utility function for retrieving the text value of an array of DOM nodes + * @param {Array|Element} elem + */ +getText = Sizzle.getText = function( elem ) { + var node, + ret = "", + i = 0, + nodeType = elem.nodeType; + + if ( !nodeType ) { + + // If no nodeType, this is expected to be an array + while ( ( node = elem[ i++ ] ) ) { + + // Do not traverse comment nodes + ret += getText( node ); + } + } else if ( nodeType === 1 || nodeType === 9 || nodeType === 11 ) { + + // Use textContent for elements + // innerText usage removed for consistency of new lines (jQuery #11153) + if ( typeof elem.textContent === "string" ) { + return elem.textContent; + } else { + + // Traverse its children + for ( elem = elem.firstChild; elem; elem = elem.nextSibling ) { + ret += getText( elem ); + } + } + } else if ( nodeType === 3 || nodeType === 4 ) { + return elem.nodeValue; + } + + // Do not include comment or processing instruction nodes + + return ret; +}; + +Expr = Sizzle.selectors = { + + // Can be adjusted by the user + cacheLength: 50, + + createPseudo: markFunction, + + match: matchExpr, + + attrHandle: {}, + + find: {}, + + relative: { + ">": { dir: "parentNode", first: true }, + " ": { dir: "parentNode" }, + "+": { dir: "previousSibling", first: true }, + "~": { dir: "previousSibling" } + }, + + preFilter: { + "ATTR": function( match ) { + match[ 1 ] = match[ 1 ].replace( runescape, funescape ); + + // Move the given value to match[3] whether quoted or unquoted + match[ 3 ] = ( match[ 3 ] || match[ 4 ] || + match[ 5 ] || "" ).replace( runescape, funescape ); + + if ( match[ 2 ] === "~=" ) { + match[ 3 ] = " " + match[ 3 ] + " "; + } + + return match.slice( 0, 4 ); + }, + + "CHILD": function( match ) { + + /* matches from matchExpr["CHILD"] + 1 type (only|nth|...) + 2 what (child|of-type) + 3 argument (even|odd|\d*|\d*n([+-]\d+)?|...) + 4 xn-component of xn+y argument ([+-]?\d*n|) + 5 sign of xn-component + 6 x of xn-component + 7 sign of y-component + 8 y of y-component + */ + match[ 1 ] = match[ 1 ].toLowerCase(); + + if ( match[ 1 ].slice( 0, 3 ) === "nth" ) { + + // nth-* requires argument + if ( !match[ 3 ] ) { + Sizzle.error( match[ 0 ] ); + } + + // numeric x and y parameters for Expr.filter.CHILD + // remember that false/true cast respectively to 0/1 + match[ 4 ] = +( match[ 4 ] ? 
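+
+	/* Illustrative worked example (not from the original source): for
+	   ":nth-child(3n+1)" the CHILD regex captures match[4]="3n", match[5]="",
+	   match[6]="3", match[7]="+", match[8]="1"; the coercions below produce
+	   the numeric pair 3 (cycle size) and 1 (offset). ":nth-child(odd)" has
+	   no xn component and normalizes to 2 and 1, i.e. 2n+1. */
+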
+ match[ 5 ] + ( match[ 6 ] || 1 ) : + 2 * ( match[ 3 ] === "even" || match[ 3 ] === "odd" ) ); + match[ 5 ] = +( ( match[ 7 ] + match[ 8 ] ) || match[ 3 ] === "odd" ); + + // other types prohibit arguments + } else if ( match[ 3 ] ) { + Sizzle.error( match[ 0 ] ); + } + + return match; + }, + + "PSEUDO": function( match ) { + var excess, + unquoted = !match[ 6 ] && match[ 2 ]; + + if ( matchExpr[ "CHILD" ].test( match[ 0 ] ) ) { + return null; + } + + // Accept quoted arguments as-is + if ( match[ 3 ] ) { + match[ 2 ] = match[ 4 ] || match[ 5 ] || ""; + + // Strip excess characters from unquoted arguments + } else if ( unquoted && rpseudo.test( unquoted ) && + + // Get excess from tokenize (recursively) + ( excess = tokenize( unquoted, true ) ) && + + // advance to the next closing parenthesis + ( excess = unquoted.indexOf( ")", unquoted.length - excess ) - unquoted.length ) ) { + + // excess is a negative index + match[ 0 ] = match[ 0 ].slice( 0, excess ); + match[ 2 ] = unquoted.slice( 0, excess ); + } + + // Return only captures needed by the pseudo filter method (type and argument) + return match.slice( 0, 3 ); + } + }, + + filter: { + + "TAG": function( nodeNameSelector ) { + var nodeName = nodeNameSelector.replace( runescape, funescape ).toLowerCase(); + return nodeNameSelector === "*" ? + function() { + return true; + } : + function( elem ) { + return elem.nodeName && elem.nodeName.toLowerCase() === nodeName; + }; + }, + + "CLASS": function( className ) { + var pattern = classCache[ className + " " ]; + + return pattern || + ( pattern = new RegExp( "(^|" + whitespace + + ")" + className + "(" + whitespace + "|$)" ) ) && classCache( + className, function( elem ) { + return pattern.test( + typeof elem.className === "string" && elem.className || + typeof elem.getAttribute !== "undefined" && + elem.getAttribute( "class" ) || + "" + ); + } ); + }, + + "ATTR": function( name, operator, check ) { + return function( elem ) { + var result = Sizzle.attr( elem, name ); + + if ( result == null ) { + return operator === "!="; + } + if ( !operator ) { + return true; + } + + result += ""; + + /* eslint-disable max-len */ + + return operator === "=" ? result === check : + operator === "!=" ? result !== check : + operator === "^=" ? check && result.indexOf( check ) === 0 : + operator === "*=" ? check && result.indexOf( check ) > -1 : + operator === "$=" ? check && result.slice( -check.length ) === check : + operator === "~=" ? ( " " + result.replace( rwhitespace, " " ) + " " ).indexOf( check ) > -1 : + operator === "|=" ? result === check || result.slice( 0, check.length + 1 ) === check + "-" : + false; + /* eslint-enable max-len */ + + }; + }, + + "CHILD": function( type, what, _argument, first, last ) { + var simple = type.slice( 0, 3 ) !== "nth", + forward = type.slice( -4 ) !== "last", + ofType = what === "of-type"; + + return first === 1 && last === 0 ? + + // Shortcut for :nth-*(n) + function( elem ) { + return !!elem.parentNode; + } : + + function( elem, _context, xml ) { + var cache, uniqueCache, outerCache, node, nodeIndex, start, + dir = simple !== forward ? "nextSibling" : "previousSibling", + parent = elem.parentNode, + name = ofType && elem.nodeName.toLowerCase(), + useCache = !xml && !ofType, + diff = false; + + if ( parent ) { + + // :(first|last|only)-(child|of-type) + if ( simple ) { + while ( dir ) { + node = elem; + while ( ( node = node[ dir ] ) ) { + if ( ofType ? 
+ node.nodeName.toLowerCase() === name : + node.nodeType === 1 ) { + + return false; + } + } + + // Reverse direction for :only-* (if we haven't yet done so) + start = dir = type === "only" && !start && "nextSibling"; + } + return true; + } + + start = [ forward ? parent.firstChild : parent.lastChild ]; + + // non-xml :nth-child(...) stores cache data on `parent` + if ( forward && useCache ) { + + // Seek `elem` from a previously-cached index + + // ...in a gzip-friendly way + node = parent; + outerCache = node[ expando ] || ( node[ expando ] = {} ); + + // Support: IE <9 only + // Defend against cloned attroperties (jQuery gh-1709) + uniqueCache = outerCache[ node.uniqueID ] || + ( outerCache[ node.uniqueID ] = {} ); + + cache = uniqueCache[ type ] || []; + nodeIndex = cache[ 0 ] === dirruns && cache[ 1 ]; + diff = nodeIndex && cache[ 2 ]; + node = nodeIndex && parent.childNodes[ nodeIndex ]; + + while ( ( node = ++nodeIndex && node && node[ dir ] || + + // Fallback to seeking `elem` from the start + ( diff = nodeIndex = 0 ) || start.pop() ) ) { + + // When found, cache indexes on `parent` and break + if ( node.nodeType === 1 && ++diff && node === elem ) { + uniqueCache[ type ] = [ dirruns, nodeIndex, diff ]; + break; + } + } + + } else { + + // Use previously-cached element index if available + if ( useCache ) { + + // ...in a gzip-friendly way + node = elem; + outerCache = node[ expando ] || ( node[ expando ] = {} ); + + // Support: IE <9 only + // Defend against cloned attroperties (jQuery gh-1709) + uniqueCache = outerCache[ node.uniqueID ] || + ( outerCache[ node.uniqueID ] = {} ); + + cache = uniqueCache[ type ] || []; + nodeIndex = cache[ 0 ] === dirruns && cache[ 1 ]; + diff = nodeIndex; + } + + // xml :nth-child(...) + // or :nth-last-child(...) or :nth(-last)?-of-type(...) + if ( diff === false ) { + + // Use the same loop as above to seek `elem` from the start + while ( ( node = ++nodeIndex && node && node[ dir ] || + ( diff = nodeIndex = 0 ) || start.pop() ) ) { + + if ( ( ofType ? + node.nodeName.toLowerCase() === name : + node.nodeType === 1 ) && + ++diff ) { + + // Cache the index of each encountered element + if ( useCache ) { + outerCache = node[ expando ] || + ( node[ expando ] = {} ); + + // Support: IE <9 only + // Defend against cloned attroperties (jQuery gh-1709) + uniqueCache = outerCache[ node.uniqueID ] || + ( outerCache[ node.uniqueID ] = {} ); + + uniqueCache[ type ] = [ dirruns, diff ]; + } + + if ( node === elem ) { + break; + } + } + } + } + } + + // Incorporate the offset, then check against cycle size + diff -= last; + return diff === first || ( diff % first === 0 && diff / first >= 0 ); + } + }; + }, + + "PSEUDO": function( pseudo, argument ) { + + // pseudo-class names are case-insensitive + // http://www.w3.org/TR/selectors/#pseudo-classes + // Prioritize by case sensitivity in case custom pseudos are added with uppercase letters + // Remember that setFilters inherits from pseudos + var args, + fn = Expr.pseudos[ pseudo ] || Expr.setFilters[ pseudo.toLowerCase() ] || + Sizzle.error( "unsupported pseudo: " + pseudo ); + + // The user may use createPseudo to indicate that + // arguments are needed to create the filter function + // just as Sizzle does + if ( fn[ expando ] ) { + return fn( argument ); + } + + // But maintain support for old signatures + if ( fn.length > 1 ) { + args = [ pseudo, pseudo, "", argument ]; + return Expr.setFilters.hasOwnProperty( pseudo.toLowerCase() ) ? 
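+
+	/* Illustrative example (not from the original source; the pseudo name is
+	   hypothetical): new pseudos are registered through createPseudo
+	   (markFunction), so the dispatcher above calls the definition once with
+	   the argument and reuses the returned element filter:
+	     Expr.pseudos[ "data-flag" ] = markFunction( function( name ) {
+	       return function( elem ) {
+	         return elem.hasAttribute( "data-" + name );
+	       };
+	     } ); */
+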
+ markFunction( function( seed, matches ) { + var idx, + matched = fn( seed, argument ), + i = matched.length; + while ( i-- ) { + idx = indexOf( seed, matched[ i ] ); + seed[ idx ] = !( matches[ idx ] = matched[ i ] ); + } + } ) : + function( elem ) { + return fn( elem, 0, args ); + }; + } + + return fn; + } + }, + + pseudos: { + + // Potentially complex pseudos + "not": markFunction( function( selector ) { + + // Trim the selector passed to compile + // to avoid treating leading and trailing + // spaces as combinators + var input = [], + results = [], + matcher = compile( selector.replace( rtrim, "$1" ) ); + + return matcher[ expando ] ? + markFunction( function( seed, matches, _context, xml ) { + var elem, + unmatched = matcher( seed, null, xml, [] ), + i = seed.length; + + // Match elements unmatched by `matcher` + while ( i-- ) { + if ( ( elem = unmatched[ i ] ) ) { + seed[ i ] = !( matches[ i ] = elem ); + } + } + } ) : + function( elem, _context, xml ) { + input[ 0 ] = elem; + matcher( input, null, xml, results ); + + // Don't keep the element (issue #299) + input[ 0 ] = null; + return !results.pop(); + }; + } ), + + "has": markFunction( function( selector ) { + return function( elem ) { + return Sizzle( selector, elem ).length > 0; + }; + } ), + + "contains": markFunction( function( text ) { + text = text.replace( runescape, funescape ); + return function( elem ) { + return ( elem.textContent || getText( elem ) ).indexOf( text ) > -1; + }; + } ), + + // "Whether an element is represented by a :lang() selector + // is based solely on the element's language value + // being equal to the identifier C, + // or beginning with the identifier C immediately followed by "-". + // The matching of C against the element's language value is performed case-insensitively. + // The identifier C does not have to be a valid language name." + // http://www.w3.org/TR/selectors/#lang-pseudo + "lang": markFunction( function( lang ) { + + // lang value must be a valid identifier + if ( !ridentifier.test( lang || "" ) ) { + Sizzle.error( "unsupported lang: " + lang ); + } + lang = lang.replace( runescape, funescape ).toLowerCase(); + return function( elem ) { + var elemLang; + do { + if ( ( elemLang = documentIsHTML ? 
+ elem.lang : + elem.getAttribute( "xml:lang" ) || elem.getAttribute( "lang" ) ) ) { + + elemLang = elemLang.toLowerCase(); + return elemLang === lang || elemLang.indexOf( lang + "-" ) === 0; + } + } while ( ( elem = elem.parentNode ) && elem.nodeType === 1 ); + return false; + }; + } ), + + // Miscellaneous + "target": function( elem ) { + var hash = window.location && window.location.hash; + return hash && hash.slice( 1 ) === elem.id; + }, + + "root": function( elem ) { + return elem === docElem; + }, + + "focus": function( elem ) { + return elem === document.activeElement && + ( !document.hasFocus || document.hasFocus() ) && + !!( elem.type || elem.href || ~elem.tabIndex ); + }, + + // Boolean properties + "enabled": createDisabledPseudo( false ), + "disabled": createDisabledPseudo( true ), + + "checked": function( elem ) { + + // In CSS3, :checked should return both checked and selected elements + // http://www.w3.org/TR/2011/REC-css3-selectors-20110929/#checked + var nodeName = elem.nodeName.toLowerCase(); + return ( nodeName === "input" && !!elem.checked ) || + ( nodeName === "option" && !!elem.selected ); + }, + + "selected": function( elem ) { + + // Accessing this property makes selected-by-default + // options in Safari work properly + if ( elem.parentNode ) { + // eslint-disable-next-line no-unused-expressions + elem.parentNode.selectedIndex; + } + + return elem.selected === true; + }, + + // Contents + "empty": function( elem ) { + + // http://www.w3.org/TR/selectors/#empty-pseudo + // :empty is negated by element (1) or content nodes (text: 3; cdata: 4; entity ref: 5), + // but not by others (comment: 8; processing instruction: 7; etc.) + // nodeType < 6 works because attributes (2) do not appear as children + for ( elem = elem.firstChild; elem; elem = elem.nextSibling ) { + if ( elem.nodeType < 6 ) { + return false; + } + } + return true; + }, + + "parent": function( elem ) { + return !Expr.pseudos[ "empty" ]( elem ); + }, + + // Element/input types + "header": function( elem ) { + return rheader.test( elem.nodeName ); + }, + + "input": function( elem ) { + return rinputs.test( elem.nodeName ); + }, + + "button": function( elem ) { + var name = elem.nodeName.toLowerCase(); + return name === "input" && elem.type === "button" || name === "button"; + }, + + "text": function( elem ) { + var attr; + return elem.nodeName.toLowerCase() === "input" && + elem.type === "text" && + + // Support: IE<8 + // New HTML5 attribute values (e.g., "search") appear with elem.type === "text" + ( ( attr = elem.getAttribute( "type" ) ) == null || + attr.toLowerCase() === "text" ); + }, + + // Position-in-collection + "first": createPositionalPseudo( function() { + return [ 0 ]; + } ), + + "last": createPositionalPseudo( function( _matchIndexes, length ) { + return [ length - 1 ]; + } ), + + "eq": createPositionalPseudo( function( _matchIndexes, length, argument ) { + return [ argument < 0 ? argument + length : argument ]; + } ), + + "even": createPositionalPseudo( function( matchIndexes, length ) { + var i = 0; + for ( ; i < length; i += 2 ) { + matchIndexes.push( i ); + } + return matchIndexes; + } ), + + "odd": createPositionalPseudo( function( matchIndexes, length ) { + var i = 1; + for ( ; i < length; i += 2 ) { + matchIndexes.push( i ); + } + return matchIndexes; + } ), + + "lt": createPositionalPseudo( function( matchIndexes, length, argument ) { + var i = argument < 0 ? + argument + length : + argument > length ? 
+ length : + argument; + for ( ; --i >= 0; ) { + matchIndexes.push( i ); + } + return matchIndexes; + } ), + + "gt": createPositionalPseudo( function( matchIndexes, length, argument ) { + var i = argument < 0 ? argument + length : argument; + for ( ; ++i < length; ) { + matchIndexes.push( i ); + } + return matchIndexes; + } ) + } +}; + +Expr.pseudos[ "nth" ] = Expr.pseudos[ "eq" ]; + +// Add button/input type pseudos +for ( i in { radio: true, checkbox: true, file: true, password: true, image: true } ) { + Expr.pseudos[ i ] = createInputPseudo( i ); +} +for ( i in { submit: true, reset: true } ) { + Expr.pseudos[ i ] = createButtonPseudo( i ); +} + +// Easy API for creating new setFilters +function setFilters() {} +setFilters.prototype = Expr.filters = Expr.pseudos; +Expr.setFilters = new setFilters(); + +tokenize = Sizzle.tokenize = function( selector, parseOnly ) { + var matched, match, tokens, type, + soFar, groups, preFilters, + cached = tokenCache[ selector + " " ]; + + if ( cached ) { + return parseOnly ? 0 : cached.slice( 0 ); + } + + soFar = selector; + groups = []; + preFilters = Expr.preFilter; + + while ( soFar ) { + + // Comma and first run + if ( !matched || ( match = rcomma.exec( soFar ) ) ) { + if ( match ) { + + // Don't consume trailing commas as valid + soFar = soFar.slice( match[ 0 ].length ) || soFar; + } + groups.push( ( tokens = [] ) ); + } + + matched = false; + + // Combinators + if ( ( match = rcombinators.exec( soFar ) ) ) { + matched = match.shift(); + tokens.push( { + value: matched, + + // Cast descendant combinators to space + type: match[ 0 ].replace( rtrim, " " ) + } ); + soFar = soFar.slice( matched.length ); + } + + // Filters + for ( type in Expr.filter ) { + if ( ( match = matchExpr[ type ].exec( soFar ) ) && ( !preFilters[ type ] || + ( match = preFilters[ type ]( match ) ) ) ) { + matched = match.shift(); + tokens.push( { + value: matched, + type: type, + matches: match + } ); + soFar = soFar.slice( matched.length ); + } + } + + if ( !matched ) { + break; + } + } + + // Return the length of the invalid excess + // if we're just parsing + // Otherwise, throw an error or return tokens + return parseOnly ? + soFar.length : + soFar ? + Sizzle.error( selector ) : + + // Cache the tokens + tokenCache( selector, groups ).slice( 0 ); +}; + +function toSelector( tokens ) { + var i = 0, + len = tokens.length, + selector = ""; + for ( ; i < len; i++ ) { + selector += tokens[ i ].value; + } + return selector; +} + +function addCombinator( matcher, combinator, base ) { + var dir = combinator.dir, + skip = combinator.next, + key = skip || dir, + checkNonElements = base && key === "parentNode", + doneName = done++; + + return combinator.first ? 
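+
+	/* Illustrative note (not from the original source): for the child
+	   combinator "a > b", Expr.relative maps ">" to
+	   { dir: "parentNode", first: true }, so the first form below checks only
+	   the closest ancestor; for the descendant combinator "a b" there is no
+	   `first`, and the second form walks all ancestors, caching the result
+	   under [ dirruns, doneName ] so shared ancestors are not re-matched. */
+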
+ + // Check against closest ancestor/preceding element + function( elem, context, xml ) { + while ( ( elem = elem[ dir ] ) ) { + if ( elem.nodeType === 1 || checkNonElements ) { + return matcher( elem, context, xml ); + } + } + return false; + } : + + // Check against all ancestor/preceding elements + function( elem, context, xml ) { + var oldCache, uniqueCache, outerCache, + newCache = [ dirruns, doneName ]; + + // We can't set arbitrary data on XML nodes, so they don't benefit from combinator caching + if ( xml ) { + while ( ( elem = elem[ dir ] ) ) { + if ( elem.nodeType === 1 || checkNonElements ) { + if ( matcher( elem, context, xml ) ) { + return true; + } + } + } + } else { + while ( ( elem = elem[ dir ] ) ) { + if ( elem.nodeType === 1 || checkNonElements ) { + outerCache = elem[ expando ] || ( elem[ expando ] = {} ); + + // Support: IE <9 only + // Defend against cloned attroperties (jQuery gh-1709) + uniqueCache = outerCache[ elem.uniqueID ] || + ( outerCache[ elem.uniqueID ] = {} ); + + if ( skip && skip === elem.nodeName.toLowerCase() ) { + elem = elem[ dir ] || elem; + } else if ( ( oldCache = uniqueCache[ key ] ) && + oldCache[ 0 ] === dirruns && oldCache[ 1 ] === doneName ) { + + // Assign to newCache so results back-propagate to previous elements + return ( newCache[ 2 ] = oldCache[ 2 ] ); + } else { + + // Reuse newcache so results back-propagate to previous elements + uniqueCache[ key ] = newCache; + + // A match means we're done; a fail means we have to keep checking + if ( ( newCache[ 2 ] = matcher( elem, context, xml ) ) ) { + return true; + } + } + } + } + } + return false; + }; +} + +function elementMatcher( matchers ) { + return matchers.length > 1 ? + function( elem, context, xml ) { + var i = matchers.length; + while ( i-- ) { + if ( !matchers[ i ]( elem, context, xml ) ) { + return false; + } + } + return true; + } : + matchers[ 0 ]; +} + +function multipleContexts( selector, contexts, results ) { + var i = 0, + len = contexts.length; + for ( ; i < len; i++ ) { + Sizzle( selector, contexts[ i ], results ); + } + return results; +} + +function condense( unmatched, map, filter, context, xml ) { + var elem, + newUnmatched = [], + i = 0, + len = unmatched.length, + mapped = map != null; + + for ( ; i < len; i++ ) { + if ( ( elem = unmatched[ i ] ) ) { + if ( !filter || filter( elem, context, xml ) ) { + newUnmatched.push( elem ); + if ( mapped ) { + map.push( i ); + } + } + } + } + + return newUnmatched; +} + +function setMatcher( preFilter, selector, matcher, postFilter, postFinder, postSelector ) { + if ( postFilter && !postFilter[ expando ] ) { + postFilter = setMatcher( postFilter ); + } + if ( postFinder && !postFinder[ expando ] ) { + postFinder = setMatcher( postFinder, postSelector ); + } + return markFunction( function( seed, results, context, xml ) { + var temp, i, elem, + preMap = [], + postMap = [], + preexisting = results.length, + + // Get initial elements from seed or context + elems = seed || multipleContexts( + selector || "*", + context.nodeType ? [ context ] : context, + [] + ), + + // Prefilter to get matcher input, preserving a map for seed-results synchronization + matcherIn = preFilter && ( seed || !selector ) ? + condense( elems, preMap, preFilter, context, xml ) : + elems, + + matcherOut = matcher ? + + // If we have a postFinder, or filtered seed, or non-seed postFilter or preexisting results, + postFinder || ( seed ? preFilter : preexisting || postFilter ) ? 
+ + // ...intermediate processing is necessary + [] : + + // ...otherwise use results directly + results : + matcherIn; + + // Find primary matches + if ( matcher ) { + matcher( matcherIn, matcherOut, context, xml ); + } + + // Apply postFilter + if ( postFilter ) { + temp = condense( matcherOut, postMap ); + postFilter( temp, [], context, xml ); + + // Un-match failing elements by moving them back to matcherIn + i = temp.length; + while ( i-- ) { + if ( ( elem = temp[ i ] ) ) { + matcherOut[ postMap[ i ] ] = !( matcherIn[ postMap[ i ] ] = elem ); + } + } + } + + if ( seed ) { + if ( postFinder || preFilter ) { + if ( postFinder ) { + + // Get the final matcherOut by condensing this intermediate into postFinder contexts + temp = []; + i = matcherOut.length; + while ( i-- ) { + if ( ( elem = matcherOut[ i ] ) ) { + + // Restore matcherIn since elem is not yet a final match + temp.push( ( matcherIn[ i ] = elem ) ); + } + } + postFinder( null, ( matcherOut = [] ), temp, xml ); + } + + // Move matched elements from seed to results to keep them synchronized + i = matcherOut.length; + while ( i-- ) { + if ( ( elem = matcherOut[ i ] ) && + ( temp = postFinder ? indexOf( seed, elem ) : preMap[ i ] ) > -1 ) { + + seed[ temp ] = !( results[ temp ] = elem ); + } + } + } + + // Add elements to results, through postFinder if defined + } else { + matcherOut = condense( + matcherOut === results ? + matcherOut.splice( preexisting, matcherOut.length ) : + matcherOut + ); + if ( postFinder ) { + postFinder( null, results, matcherOut, xml ); + } else { + push.apply( results, matcherOut ); + } + } + } ); +} + +function matcherFromTokens( tokens ) { + var checkContext, matcher, j, + len = tokens.length, + leadingRelative = Expr.relative[ tokens[ 0 ].type ], + implicitRelative = leadingRelative || Expr.relative[ " " ], + i = leadingRelative ? 1 : 0, + + // The foundational matcher ensures that elements are reachable from top-level context(s) + matchContext = addCombinator( function( elem ) { + return elem === checkContext; + }, implicitRelative, true ), + matchAnyContext = addCombinator( function( elem ) { + return indexOf( checkContext, elem ) > -1; + }, implicitRelative, true ), + matchers = [ function( elem, context, xml ) { + var ret = ( !leadingRelative && ( xml || context !== outermostContext ) ) || ( + ( checkContext = context ).nodeType ? + matchContext( elem, context, xml ) : + matchAnyContext( elem, context, xml ) ); + + // Avoid hanging onto element (issue #299) + checkContext = null; + return ret; + } ]; + + for ( ; i < len; i++ ) { + if ( ( matcher = Expr.relative[ tokens[ i ].type ] ) ) { + matchers = [ addCombinator( elementMatcher( matchers ), matcher ) ]; + } else { + matcher = Expr.filter[ tokens[ i ].type ].apply( null, tokens[ i ].matches ); + + // Return special upon seeing a positional matcher + if ( matcher[ expando ] ) { + + // Find the next relative operator (if any) for proper handling + j = ++i; + for ( ; j < len; j++ ) { + if ( Expr.relative[ tokens[ j ].type ] ) { + break; + } + } + return setMatcher( + i > 1 && elementMatcher( matchers ), + i > 1 && toSelector( + + // If the preceding token was a descendant combinator, insert an implicit any-element `*` + tokens + .slice( 0, i - 1 ) + .concat( { value: tokens[ i - 2 ].type === " " ? 
"*" : "" } ) + ).replace( rtrim, "$1" ), + matcher, + i < j && matcherFromTokens( tokens.slice( i, j ) ), + j < len && matcherFromTokens( ( tokens = tokens.slice( j ) ) ), + j < len && toSelector( tokens ) + ); + } + matchers.push( matcher ); + } + } + + return elementMatcher( matchers ); +} + +function matcherFromGroupMatchers( elementMatchers, setMatchers ) { + var bySet = setMatchers.length > 0, + byElement = elementMatchers.length > 0, + superMatcher = function( seed, context, xml, results, outermost ) { + var elem, j, matcher, + matchedCount = 0, + i = "0", + unmatched = seed && [], + setMatched = [], + contextBackup = outermostContext, + + // We must always have either seed elements or outermost context + elems = seed || byElement && Expr.find[ "TAG" ]( "*", outermost ), + + // Use integer dirruns iff this is the outermost matcher + dirrunsUnique = ( dirruns += contextBackup == null ? 1 : Math.random() || 0.1 ), + len = elems.length; + + if ( outermost ) { + + // Support: IE 11+, Edge 17 - 18+ + // IE/Edge sometimes throw a "Permission denied" error when strict-comparing + // two documents; shallow comparisons work. + // eslint-disable-next-line eqeqeq + outermostContext = context == document || context || outermost; + } + + // Add elements passing elementMatchers directly to results + // Support: IE<9, Safari + // Tolerate NodeList properties (IE: "length"; Safari: ) matching elements by id + for ( ; i !== len && ( elem = elems[ i ] ) != null; i++ ) { + if ( byElement && elem ) { + j = 0; + + // Support: IE 11+, Edge 17 - 18+ + // IE/Edge sometimes throw a "Permission denied" error when strict-comparing + // two documents; shallow comparisons work. + // eslint-disable-next-line eqeqeq + if ( !context && elem.ownerDocument != document ) { + setDocument( elem ); + xml = !documentIsHTML; + } + while ( ( matcher = elementMatchers[ j++ ] ) ) { + if ( matcher( elem, context || document, xml ) ) { + results.push( elem ); + break; + } + } + if ( outermost ) { + dirruns = dirrunsUnique; + } + } + + // Track unmatched elements for set filters + if ( bySet ) { + + // They will have gone through all possible matchers + if ( ( elem = !matcher && elem ) ) { + matchedCount--; + } + + // Lengthen the array for every element, matched or not + if ( seed ) { + unmatched.push( elem ); + } + } + } + + // `i` is now the count of elements visited above, and adding it to `matchedCount` + // makes the latter nonnegative. + matchedCount += i; + + // Apply set filters to unmatched elements + // NOTE: This can be skipped if there are no unmatched elements (i.e., `matchedCount` + // equals `i`), unless we didn't visit _any_ elements in the above loop because we have + // no element matchers and no seed. + // Incrementing an initially-string "0" `i` allows `i` to remain a string only in that + // case, which will result in a "00" `matchedCount` that differs from `i` but is also + // numerically zero. 
+ if ( bySet && i !== matchedCount ) { + j = 0; + while ( ( matcher = setMatchers[ j++ ] ) ) { + matcher( unmatched, setMatched, context, xml ); + } + + if ( seed ) { + + // Reintegrate element matches to eliminate the need for sorting + if ( matchedCount > 0 ) { + while ( i-- ) { + if ( !( unmatched[ i ] || setMatched[ i ] ) ) { + setMatched[ i ] = pop.call( results ); + } + } + } + + // Discard index placeholder values to get only actual matches + setMatched = condense( setMatched ); + } + + // Add matches to results + push.apply( results, setMatched ); + + // Seedless set matches succeeding multiple successful matchers stipulate sorting + if ( outermost && !seed && setMatched.length > 0 && + ( matchedCount + setMatchers.length ) > 1 ) { + + Sizzle.uniqueSort( results ); + } + } + + // Override manipulation of globals by nested matchers + if ( outermost ) { + dirruns = dirrunsUnique; + outermostContext = contextBackup; + } + + return unmatched; + }; + + return bySet ? + markFunction( superMatcher ) : + superMatcher; +} + +compile = Sizzle.compile = function( selector, match /* Internal Use Only */ ) { + var i, + setMatchers = [], + elementMatchers = [], + cached = compilerCache[ selector + " " ]; + + if ( !cached ) { + + // Generate a function of recursive functions that can be used to check each element + if ( !match ) { + match = tokenize( selector ); + } + i = match.length; + while ( i-- ) { + cached = matcherFromTokens( match[ i ] ); + if ( cached[ expando ] ) { + setMatchers.push( cached ); + } else { + elementMatchers.push( cached ); + } + } + + // Cache the compiled function + cached = compilerCache( + selector, + matcherFromGroupMatchers( elementMatchers, setMatchers ) + ); + + // Save selector and tokenization + cached.selector = selector; + } + return cached; +}; + +/** + * A low-level selection function that works with Sizzle's compiled + * selector functions + * @param {String|Function} selector A selector or a pre-compiled + * selector function built with Sizzle.compile + * @param {Element} context + * @param {Array} [results] + * @param {Array} [seed] A set of elements to match against + */ +select = Sizzle.select = function( selector, context, results, seed ) { + var i, tokens, token, type, find, + compiled = typeof selector === "function" && selector, + match = !seed && tokenize( ( selector = compiled.selector || selector ) ); + + results = results || []; + + // Try to minimize operations if there is only one selector in the list and no seed + // (the latter of which guarantees us context) + if ( match.length === 1 ) { + + // Reduce context if the leading compound selector is an ID + tokens = match[ 0 ] = match[ 0 ].slice( 0 ); + if ( tokens.length > 2 && ( token = tokens[ 0 ] ).type === "ID" && + context.nodeType === 9 && documentIsHTML && Expr.relative[ tokens[ 1 ].type ] ) { + + context = ( Expr.find[ "ID" ]( token.matches[ 0 ] + .replace( runescape, funescape ), context ) || [] )[ 0 ]; + if ( !context ) { + return results; + + // Precompiled matchers will still verify ancestry, so step up a level + } else if ( compiled ) { + context = context.parentNode; + } + + selector = selector.slice( tokens.shift().value.length ); + } + + // Fetch a seed set for right-to-left matching + i = matchExpr[ "needsContext" ].test( selector ) ? 
0 : tokens.length;
+	while ( i-- ) {
+		token = tokens[ i ];
+
+		// Abort if we hit a combinator
+		if ( Expr.relative[ ( type = token.type ) ] ) {
+			break;
+		}
+		if ( ( find = Expr.find[ type ] ) ) {
+
+			// Search, expanding context for leading sibling combinators
+			if ( ( seed = find(
+				token.matches[ 0 ].replace( runescape, funescape ),
+				rsibling.test( tokens[ 0 ].type ) && testContext( context.parentNode ) ||
+					context
+			) ) ) {
+
+				// If seed is empty or no tokens remain, we can return early
+				tokens.splice( i, 1 );
+				selector = seed.length && toSelector( tokens );
+				if ( !selector ) {
+					push.apply( results, seed );
+					return results;
+				}
+
+				break;
+			}
+		}
+	}
+
+	// Compile and execute a filtering function if one is not provided
+	// Provide `match` to avoid retokenization if we modified the selector above
+	( compiled || compile( selector, match ) )(
+		seed,
+		context,
+		!documentIsHTML,
+		results,
+		!context || rsibling.test( selector ) && testContext( context.parentNode ) || context
+	);
+	return results;
+};
+
+// One-time assignments
+
+// Sort stability
+support.sortStable = expando.split( "" ).sort( sortOrder ).join( "" ) === expando;
+
+// Support: Chrome 14-35+
+// Always assume duplicates if they aren't passed to the comparison function
+support.detectDuplicates = !!hasDuplicate;
+
+// Initialize against the default document
+setDocument();
+
+// Support: Webkit<537.32 - Safari 6.0.3/Chrome 25 (fixed in Chrome 27)
+// Detached nodes confoundingly follow *each other*
+support.sortDetached = assert( function( el ) {
+
+	// Should return 1, but returns 4 (following)
+	return el.compareDocumentPosition( document.createElement( "fieldset" ) ) & 1;
+} );
+
+// Support: IE<8
+// Prevent attribute/property "interpolation"
+// https://msdn.microsoft.com/en-us/library/ms536429%28VS.85%29.aspx
+if ( !assert( function( el ) {
+	el.innerHTML = "<a href='#'></a>";
+	return el.firstChild.getAttribute( "href" ) === "#";
+} ) ) {
+	addHandle( "type|href|height|width", function( elem, name, isXML ) {
+		if ( !isXML ) {
+			return elem.getAttribute( name, name.toLowerCase() === "type" ? 1 : 2 );
+		}
+	} );
+}
+
+// Support: IE<9
+// Use defaultValue in place of getAttribute("value")
+if ( !support.attributes || !assert( function( el ) {
+	el.innerHTML = "<input/>";
+	el.firstChild.setAttribute( "value", "" );
+	return el.firstChild.getAttribute( "value" ) === "";
+} ) ) {
+	addHandle( "value", function( elem, _name, isXML ) {
+		if ( !isXML && elem.nodeName.toLowerCase() === "input" ) {
+			return elem.defaultValue;
+		}
+	} );
+}
+
+// Support: IE<9
+// Use getAttributeNode to fetch booleans when getAttribute lies
+if ( !assert( function( el ) {
+	return el.getAttribute( "disabled" ) == null;
+} ) ) {
+	addHandle( booleans, function( elem, name, isXML ) {
+		var val;
+		if ( !isXML ) {
+			return elem[ name ] === true ? name.toLowerCase() :
+				( val = elem.getAttributeNode( name ) ) && val.specified ?
+ val.value : + null; + } + } ); +} + +return Sizzle; + +} )( window ); + + + +jQuery.find = Sizzle; +jQuery.expr = Sizzle.selectors; + +// Deprecated +jQuery.expr[ ":" ] = jQuery.expr.pseudos; +jQuery.uniqueSort = jQuery.unique = Sizzle.uniqueSort; +jQuery.text = Sizzle.getText; +jQuery.isXMLDoc = Sizzle.isXML; +jQuery.contains = Sizzle.contains; +jQuery.escapeSelector = Sizzle.escape; + + + + +var dir = function( elem, dir, until ) { + var matched = [], + truncate = until !== undefined; + + while ( ( elem = elem[ dir ] ) && elem.nodeType !== 9 ) { + if ( elem.nodeType === 1 ) { + if ( truncate && jQuery( elem ).is( until ) ) { + break; + } + matched.push( elem ); + } + } + return matched; +}; + + +var siblings = function( n, elem ) { + var matched = []; + + for ( ; n; n = n.nextSibling ) { + if ( n.nodeType === 1 && n !== elem ) { + matched.push( n ); + } + } + + return matched; +}; + + +var rneedsContext = jQuery.expr.match.needsContext; + + + +function nodeName( elem, name ) { + + return elem.nodeName && elem.nodeName.toLowerCase() === name.toLowerCase(); + +}; +var rsingleTag = ( /^<([a-z][^\/\0>:\x20\t\r\n\f]*)[\x20\t\r\n\f]*\/?>(?:<\/\1>|)$/i ); + + + +// Implement the identical functionality for filter and not +function winnow( elements, qualifier, not ) { + if ( isFunction( qualifier ) ) { + return jQuery.grep( elements, function( elem, i ) { + return !!qualifier.call( elem, i, elem ) !== not; + } ); + } + + // Single element + if ( qualifier.nodeType ) { + return jQuery.grep( elements, function( elem ) { + return ( elem === qualifier ) !== not; + } ); + } + + // Arraylike of elements (jQuery, arguments, Array) + if ( typeof qualifier !== "string" ) { + return jQuery.grep( elements, function( elem ) { + return ( indexOf.call( qualifier, elem ) > -1 ) !== not; + } ); + } + + // Filtered directly for both simple and complex selectors + return jQuery.filter( qualifier, elements, not ); +} + +jQuery.filter = function( expr, elems, not ) { + var elem = elems[ 0 ]; + + if ( not ) { + expr = ":not(" + expr + ")"; + } + + if ( elems.length === 1 && elem.nodeType === 1 ) { + return jQuery.find.matchesSelector( elem, expr ) ? [ elem ] : []; + } + + return jQuery.find.matches( expr, jQuery.grep( elems, function( elem ) { + return elem.nodeType === 1; + } ) ); +}; + +jQuery.fn.extend( { + find: function( selector ) { + var i, ret, + len = this.length, + self = this; + + if ( typeof selector !== "string" ) { + return this.pushStack( jQuery( selector ).filter( function() { + for ( i = 0; i < len; i++ ) { + if ( jQuery.contains( self[ i ], this ) ) { + return true; + } + } + } ) ); + } + + ret = this.pushStack( [] ); + + for ( i = 0; i < len; i++ ) { + jQuery.find( selector, self[ i ], ret ); + } + + return len > 1 ? jQuery.uniqueSort( ret ) : ret; + }, + filter: function( selector ) { + return this.pushStack( winnow( this, selector || [], false ) ); + }, + not: function( selector ) { + return this.pushStack( winnow( this, selector || [], true ) ); + }, + is: function( selector ) { + return !!winnow( + this, + + // If this is a positional/relative selector, check membership in the returned set + // so $("p:first").is("p:last") won't return true for a doc with two "p". + typeof selector === "string" && rneedsContext.test( selector ) ? 
+				jQuery( selector ) :
+				selector || [],
+			false
+		).length;
+	}
+} );
+
+
+// Initialize a jQuery object
+
+
+// A central reference to the root jQuery(document)
+var rootjQuery,
+
+	// A simple way to check for HTML strings
+	// Prioritize #id over <tag> to avoid XSS via location.hash (#9521)
+	// Strict HTML recognition (#11290: must start with <)
+	// Shortcut simple #id case for speed
+	rquickExpr = /^(?:\s*(<[\w\W]+>)[^>]*|#([\w-]+))$/,
+
+	init = jQuery.fn.init = function( selector, context, root ) {
+		var match, elem;
+
+		// HANDLE: $(""), $(null), $(undefined), $(false)
+		if ( !selector ) {
+			return this;
+		}
+
+		// Method init() accepts an alternate rootjQuery
+		// so migrate can support jQuery.sub (gh-2101)
+		root = root || rootjQuery;
+
+		// Handle HTML strings
+		if ( typeof selector === "string" ) {
+			if ( selector[ 0 ] === "<" &&
+				selector[ selector.length - 1 ] === ">" &&
+				selector.length >= 3 ) {
+
+				// Assume that strings that start and end with <> are HTML and skip the regex check
+				match = [ null, selector, null ];
+
+			} else {
+				match = rquickExpr.exec( selector );
+			}
+
+			// Match html or make sure no context is specified for #id
+			if ( match && ( match[ 1 ] || !context ) ) {
+
+				// HANDLE: $(html) -> $(array)
+				if ( match[ 1 ] ) {
+					context = context instanceof jQuery ? context[ 0 ] : context;
+
+					// Option to run scripts is true for back-compat
+					// Intentionally let the error be thrown if parseHTML is not present
+					jQuery.merge( this, jQuery.parseHTML(
+						match[ 1 ],
+						context && context.nodeType ? context.ownerDocument || context : document,
+						true
+					) );
+
+					// HANDLE: $(html, props)
+					if ( rsingleTag.test( match[ 1 ] ) && jQuery.isPlainObject( context ) ) {
+						for ( match in context ) {
+
+							// Properties of context are called as methods if possible
+							if ( isFunction( this[ match ] ) ) {
+								this[ match ]( context[ match ] );
+
+							// ...and otherwise set as attributes
+							} else {
+								this.attr( match, context[ match ] );
+							}
+						}
+					}
+
+					return this;
+
+				// HANDLE: $(#id)
+				} else {
+					elem = document.getElementById( match[ 2 ] );
+
+					if ( elem ) {
+
+						// Inject the element directly into the jQuery object
+						this[ 0 ] = elem;
+						this.length = 1;
+					}
+					return this;
+				}
+
+			// HANDLE: $(expr, $(...))
+			} else if ( !context || context.jquery ) {
+				return ( context || root ).find( selector );
+
+			// HANDLE: $(expr, context)
+			// (which is just equivalent to: $(context).find(expr)
+			} else {
+				return this.constructor( context ).find( selector );
+			}
+
+		// HANDLE: $(DOMElement)
+		} else if ( selector.nodeType ) {
+			this[ 0 ] = selector;
+			this.length = 1;
+			return this;
+
+		// HANDLE: $(function)
+		// Shortcut for document ready
+		} else if ( isFunction( selector ) ) {
+			return root.ready !== undefined ?
+ root.ready( selector ) : + + // Execute immediately if ready is not present + selector( jQuery ); + } + + return jQuery.makeArray( selector, this ); + }; + +// Give the init function the jQuery prototype for later instantiation +init.prototype = jQuery.fn; + +// Initialize central reference +rootjQuery = jQuery( document ); + + +var rparentsprev = /^(?:parents|prev(?:Until|All))/, + + // Methods guaranteed to produce a unique set when starting from a unique set + guaranteedUnique = { + children: true, + contents: true, + next: true, + prev: true + }; + +jQuery.fn.extend( { + has: function( target ) { + var targets = jQuery( target, this ), + l = targets.length; + + return this.filter( function() { + var i = 0; + for ( ; i < l; i++ ) { + if ( jQuery.contains( this, targets[ i ] ) ) { + return true; + } + } + } ); + }, + + closest: function( selectors, context ) { + var cur, + i = 0, + l = this.length, + matched = [], + targets = typeof selectors !== "string" && jQuery( selectors ); + + // Positional selectors never match, since there's no _selection_ context + if ( !rneedsContext.test( selectors ) ) { + for ( ; i < l; i++ ) { + for ( cur = this[ i ]; cur && cur !== context; cur = cur.parentNode ) { + + // Always skip document fragments + if ( cur.nodeType < 11 && ( targets ? + targets.index( cur ) > -1 : + + // Don't pass non-elements to Sizzle + cur.nodeType === 1 && + jQuery.find.matchesSelector( cur, selectors ) ) ) { + + matched.push( cur ); + break; + } + } + } + } + + return this.pushStack( matched.length > 1 ? jQuery.uniqueSort( matched ) : matched ); + }, + + // Determine the position of an element within the set + index: function( elem ) { + + // No argument, return index in parent + if ( !elem ) { + return ( this[ 0 ] && this[ 0 ].parentNode ) ? this.first().prevAll().length : -1; + } + + // Index in selector + if ( typeof elem === "string" ) { + return indexOf.call( jQuery( elem ), this[ 0 ] ); + } + + // Locate the position of the desired element + return indexOf.call( this, + + // If it receives a jQuery object, the first element is used + elem.jquery ? elem[ 0 ] : elem + ); + }, + + add: function( selector, context ) { + return this.pushStack( + jQuery.uniqueSort( + jQuery.merge( this.get(), jQuery( selector, context ) ) + ) + ); + }, + + addBack: function( selector ) { + return this.add( selector == null ? + this.prevObject : this.prevObject.filter( selector ) + ); + } +} ); + +function sibling( cur, dir ) { + while ( ( cur = cur[ dir ] ) && cur.nodeType !== 1 ) {} + return cur; +} + +jQuery.each( { + parent: function( elem ) { + var parent = elem.parentNode; + return parent && parent.nodeType !== 11 ? 
parent : null;
+	},
+	parents: function( elem ) {
+		return dir( elem, "parentNode" );
+	},
+	parentsUntil: function( elem, _i, until ) {
+		return dir( elem, "parentNode", until );
+	},
+	next: function( elem ) {
+		return sibling( elem, "nextSibling" );
+	},
+	prev: function( elem ) {
+		return sibling( elem, "previousSibling" );
+	},
+	nextAll: function( elem ) {
+		return dir( elem, "nextSibling" );
+	},
+	prevAll: function( elem ) {
+		return dir( elem, "previousSibling" );
+	},
+	nextUntil: function( elem, _i, until ) {
+		return dir( elem, "nextSibling", until );
+	},
+	prevUntil: function( elem, _i, until ) {
+		return dir( elem, "previousSibling", until );
+	},
+	siblings: function( elem ) {
+		return siblings( ( elem.parentNode || {} ).firstChild, elem );
+	},
+	children: function( elem ) {
+		return siblings( elem.firstChild );
+	},
+	contents: function( elem ) {
+		if ( elem.contentDocument != null &&
+
+			// Support: IE 11+
+			// <object> elements with no `data` attribute has an object
+			// `contentDocument` with a `null` prototype.
+			getProto( elem.contentDocument ) ) {
+
+			return elem.contentDocument;
+		}
+
+		// Support: IE 9 - 11 only, iOS 7 only, Android Browser <=4.3 only
+		// Treat the template element as a regular one in browsers that
+		// don't support it.
+		if ( nodeName( elem, "template" ) ) {
+			elem = elem.content || elem;
+		}
+
+		return jQuery.merge( [], elem.childNodes );
+	}
+}, function( name, fn ) {
+	jQuery.fn[ name ] = function( until, selector ) {
+		var matched = jQuery.map( this, fn, until );
+
+		if ( name.slice( -5 ) !== "Until" ) {
+			selector = until;
+		}
+
+		if ( selector && typeof selector === "string" ) {
+			matched = jQuery.filter( selector, matched );
+		}
+
+		if ( this.length > 1 ) {
+
+			// Remove duplicates
+			if ( !guaranteedUnique[ name ] ) {
+				jQuery.uniqueSort( matched );
+			}
+
+			// Reverse order for parents* and prev-derivatives
+			if ( rparentsprev.test( name ) ) {
+				matched.reverse();
+			}
+		}
+
+		return this.pushStack( matched );
+	};
+} );
+var rnothtmlwhite = ( /[^\x20\t\r\n\f]+/g );
+
+
+
+// Convert String-formatted options into Object-formatted ones
+function createOptions( options ) {
+	var object = {};
+	jQuery.each( options.match( rnothtmlwhite ) || [], function( _, flag ) {
+		object[ flag ] = true;
+	} );
+	return object;
+}
+
+/*
+ * Create a callback list using the following parameters:
+ *
+ *	options: an optional list of space-separated options that will change how
+ *			the callback list behaves or a more traditional option object
+ *
+ * By default a callback list will act like an event callback list and can be
+ * "fired" multiple times.
+ *
+ * Possible options:
+ *
+ *	once:			will ensure the callback list can only be fired once (like a Deferred)
+ *
+ *	memory:			will keep track of previous values and will call any callback added
+ *					after the list has been fired right away with the latest "memorized"
+ *					values (like a Deferred)
+ *
+ *	unique:			will ensure a callback can only be added once (no duplicate in the list)
+ *
+ *	stopOnFalse:	interrupt callings when a callback returns false
+ *
+ */
+jQuery.Callbacks = function( options ) {
+
+	// Convert options from String-formatted to Object-formatted if needed
+	// (we check in cache first)
+	options = typeof options === "string" ?
+ createOptions( options ) : + jQuery.extend( {}, options ); + + var // Flag to know if list is currently firing + firing, + + // Last fire value for non-forgettable lists + memory, + + // Flag to know if list was already fired + fired, + + // Flag to prevent firing + locked, + + // Actual callback list + list = [], + + // Queue of execution data for repeatable lists + queue = [], + + // Index of currently firing callback (modified by add/remove as needed) + firingIndex = -1, + + // Fire callbacks + fire = function() { + + // Enforce single-firing + locked = locked || options.once; + + // Execute callbacks for all pending executions, + // respecting firingIndex overrides and runtime changes + fired = firing = true; + for ( ; queue.length; firingIndex = -1 ) { + memory = queue.shift(); + while ( ++firingIndex < list.length ) { + + // Run callback and check for early termination + if ( list[ firingIndex ].apply( memory[ 0 ], memory[ 1 ] ) === false && + options.stopOnFalse ) { + + // Jump to end and forget the data so .add doesn't re-fire + firingIndex = list.length; + memory = false; + } + } + } + + // Forget the data if we're done with it + if ( !options.memory ) { + memory = false; + } + + firing = false; + + // Clean up if we're done firing for good + if ( locked ) { + + // Keep an empty list if we have data for future add calls + if ( memory ) { + list = []; + + // Otherwise, this object is spent + } else { + list = ""; + } + } + }, + + // Actual Callbacks object + self = { + + // Add a callback or a collection of callbacks to the list + add: function() { + if ( list ) { + + // If we have memory from a past run, we should fire after adding + if ( memory && !firing ) { + firingIndex = list.length - 1; + queue.push( memory ); + } + + ( function add( args ) { + jQuery.each( args, function( _, arg ) { + if ( isFunction( arg ) ) { + if ( !options.unique || !self.has( arg ) ) { + list.push( arg ); + } + } else if ( arg && arg.length && toType( arg ) !== "string" ) { + + // Inspect recursively + add( arg ); + } + } ); + } )( arguments ); + + if ( memory && !firing ) { + fire(); + } + } + return this; + }, + + // Remove a callback from the list + remove: function() { + jQuery.each( arguments, function( _, arg ) { + var index; + while ( ( index = jQuery.inArray( arg, list, index ) ) > -1 ) { + list.splice( index, 1 ); + + // Handle firing indexes + if ( index <= firingIndex ) { + firingIndex--; + } + } + } ); + return this; + }, + + // Check if a given callback is in the list. + // If no argument is given, return whether or not list has callbacks attached. + has: function( fn ) { + return fn ? + jQuery.inArray( fn, list ) > -1 : + list.length > 0; + }, + + // Remove all callbacks from the list + empty: function() { + if ( list ) { + list = []; + } + return this; + }, + + // Disable .fire and .add + // Abort any current/pending executions + // Clear all callbacks and values + disable: function() { + locked = queue = []; + list = memory = ""; + return this; + }, + disabled: function() { + return !list; + }, + + // Disable .fire + // Also disable .add unless we have memory (since it would have no effect) + // Abort any pending executions + lock: function() { + locked = queue = []; + if ( !memory && !firing ) { + list = memory = ""; + } + return this; + }, + locked: function() { + return !!locked; + }, + + // Call all callbacks with the given context and arguments + fireWith: function( context, args ) { + if ( !locked ) { + args = args || []; + args = [ context, args.slice ? 
args.slice() : args ]; + queue.push( args ); + if ( !firing ) { + fire(); + } + } + return this; + }, + + // Call all the callbacks with the given arguments + fire: function() { + self.fireWith( this, arguments ); + return this; + }, + + // To know if the callbacks have already been called at least once + fired: function() { + return !!fired; + } + }; + + return self; +}; + + +function Identity( v ) { + return v; +} +function Thrower( ex ) { + throw ex; +} + +function adoptValue( value, resolve, reject, noValue ) { + var method; + + try { + + // Check for promise aspect first to privilege synchronous behavior + if ( value && isFunction( ( method = value.promise ) ) ) { + method.call( value ).done( resolve ).fail( reject ); + + // Other thenables + } else if ( value && isFunction( ( method = value.then ) ) ) { + method.call( value, resolve, reject ); + + // Other non-thenables + } else { + + // Control `resolve` arguments by letting Array#slice cast boolean `noValue` to integer: + // * false: [ value ].slice( 0 ) => resolve( value ) + // * true: [ value ].slice( 1 ) => resolve() + resolve.apply( undefined, [ value ].slice( noValue ) ); + } + + // For Promises/A+, convert exceptions into rejections + // Since jQuery.when doesn't unwrap thenables, we can skip the extra checks appearing in + // Deferred#then to conditionally suppress rejection. + } catch ( value ) { + + // Support: Android 4.0 only + // Strict mode functions invoked without .call/.apply get global-object context + reject.apply( undefined, [ value ] ); + } +} + +jQuery.extend( { + + Deferred: function( func ) { + var tuples = [ + + // action, add listener, callbacks, + // ... .then handlers, argument index, [final state] + [ "notify", "progress", jQuery.Callbacks( "memory" ), + jQuery.Callbacks( "memory" ), 2 ], + [ "resolve", "done", jQuery.Callbacks( "once memory" ), + jQuery.Callbacks( "once memory" ), 0, "resolved" ], + [ "reject", "fail", jQuery.Callbacks( "once memory" ), + jQuery.Callbacks( "once memory" ), 1, "rejected" ] + ], + state = "pending", + promise = { + state: function() { + return state; + }, + always: function() { + deferred.done( arguments ).fail( arguments ); + return this; + }, + "catch": function( fn ) { + return promise.then( null, fn ); + }, + + // Keep pipe for back-compat + pipe: function( /* fnDone, fnFail, fnProgress */ ) { + var fns = arguments; + + return jQuery.Deferred( function( newDefer ) { + jQuery.each( tuples, function( _i, tuple ) { + + // Map tuples (progress, done, fail) to arguments (done, fail, progress) + var fn = isFunction( fns[ tuple[ 4 ] ] ) && fns[ tuple[ 4 ] ]; + + // deferred.progress(function() { bind to newDefer or newDefer.notify }) + // deferred.done(function() { bind to newDefer or newDefer.resolve }) + // deferred.fail(function() { bind to newDefer or newDefer.reject }) + deferred[ tuple[ 1 ] ]( function() { + var returned = fn && fn.apply( this, arguments ); + if ( returned && isFunction( returned.promise ) ) { + returned.promise() + .progress( newDefer.notify ) + .done( newDefer.resolve ) + .fail( newDefer.reject ); + } else { + newDefer[ tuple[ 0 ] + "With" ]( + this, + fn ? 
[ returned ] : arguments + ); + } + } ); + } ); + fns = null; + } ).promise(); + }, + then: function( onFulfilled, onRejected, onProgress ) { + var maxDepth = 0; + function resolve( depth, deferred, handler, special ) { + return function() { + var that = this, + args = arguments, + mightThrow = function() { + var returned, then; + + // Support: Promises/A+ section 2.3.3.3.3 + // https://promisesaplus.com/#point-59 + // Ignore double-resolution attempts + if ( depth < maxDepth ) { + return; + } + + returned = handler.apply( that, args ); + + // Support: Promises/A+ section 2.3.1 + // https://promisesaplus.com/#point-48 + if ( returned === deferred.promise() ) { + throw new TypeError( "Thenable self-resolution" ); + } + + // Support: Promises/A+ sections 2.3.3.1, 3.5 + // https://promisesaplus.com/#point-54 + // https://promisesaplus.com/#point-75 + // Retrieve `then` only once + then = returned && + + // Support: Promises/A+ section 2.3.4 + // https://promisesaplus.com/#point-64 + // Only check objects and functions for thenability + ( typeof returned === "object" || + typeof returned === "function" ) && + returned.then; + + // Handle a returned thenable + if ( isFunction( then ) ) { + + // Special processors (notify) just wait for resolution + if ( special ) { + then.call( + returned, + resolve( maxDepth, deferred, Identity, special ), + resolve( maxDepth, deferred, Thrower, special ) + ); + + // Normal processors (resolve) also hook into progress + } else { + + // ...and disregard older resolution values + maxDepth++; + + then.call( + returned, + resolve( maxDepth, deferred, Identity, special ), + resolve( maxDepth, deferred, Thrower, special ), + resolve( maxDepth, deferred, Identity, + deferred.notifyWith ) + ); + } + + // Handle all other returned values + } else { + + // Only substitute handlers pass on context + // and multiple values (non-spec behavior) + if ( handler !== Identity ) { + that = undefined; + args = [ returned ]; + } + + // Process the value(s) + // Default process is resolve + ( special || deferred.resolveWith )( that, args ); + } + }, + + // Only normal processors (resolve) catch and reject exceptions + process = special ? + mightThrow : + function() { + try { + mightThrow(); + } catch ( e ) { + + if ( jQuery.Deferred.exceptionHook ) { + jQuery.Deferred.exceptionHook( e, + process.stackTrace ); + } + + // Support: Promises/A+ section 2.3.3.3.4.1 + // https://promisesaplus.com/#point-61 + // Ignore post-resolution exceptions + if ( depth + 1 >= maxDepth ) { + + // Only substitute handlers pass on context + // and multiple values (non-spec behavior) + if ( handler !== Thrower ) { + that = undefined; + args = [ e ]; + } + + deferred.rejectWith( that, args ); + } + } + }; + + // Support: Promises/A+ section 2.3.3.3.1 + // https://promisesaplus.com/#point-57 + // Re-resolve promises immediately to dodge false rejection from + // subsequent errors + if ( depth ) { + process(); + } else { + + // Call an optional hook to record the stack, in case of exception + // since it's otherwise lost when execution goes async + if ( jQuery.Deferred.getStackHook ) { + process.stackTrace = jQuery.Deferred.getStackHook(); + } + window.setTimeout( process ); + } + }; + } + + return jQuery.Deferred( function( newDefer ) { + + // progress_handlers.add( ... ) + tuples[ 0 ][ 3 ].add( + resolve( + 0, + newDefer, + isFunction( onProgress ) ? + onProgress : + Identity, + newDefer.notifyWith + ) + ); + + // fulfilled_handlers.add( ... 
) + tuples[ 1 ][ 3 ].add( + resolve( + 0, + newDefer, + isFunction( onFulfilled ) ? + onFulfilled : + Identity + ) + ); + + // rejected_handlers.add( ... ) + tuples[ 2 ][ 3 ].add( + resolve( + 0, + newDefer, + isFunction( onRejected ) ? + onRejected : + Thrower + ) + ); + } ).promise(); + }, + + // Get a promise for this deferred + // If obj is provided, the promise aspect is added to the object + promise: function( obj ) { + return obj != null ? jQuery.extend( obj, promise ) : promise; + } + }, + deferred = {}; + + // Add list-specific methods + jQuery.each( tuples, function( i, tuple ) { + var list = tuple[ 2 ], + stateString = tuple[ 5 ]; + + // promise.progress = list.add + // promise.done = list.add + // promise.fail = list.add + promise[ tuple[ 1 ] ] = list.add; + + // Handle state + if ( stateString ) { + list.add( + function() { + + // state = "resolved" (i.e., fulfilled) + // state = "rejected" + state = stateString; + }, + + // rejected_callbacks.disable + // fulfilled_callbacks.disable + tuples[ 3 - i ][ 2 ].disable, + + // rejected_handlers.disable + // fulfilled_handlers.disable + tuples[ 3 - i ][ 3 ].disable, + + // progress_callbacks.lock + tuples[ 0 ][ 2 ].lock, + + // progress_handlers.lock + tuples[ 0 ][ 3 ].lock + ); + } + + // progress_handlers.fire + // fulfilled_handlers.fire + // rejected_handlers.fire + list.add( tuple[ 3 ].fire ); + + // deferred.notify = function() { deferred.notifyWith(...) } + // deferred.resolve = function() { deferred.resolveWith(...) } + // deferred.reject = function() { deferred.rejectWith(...) } + deferred[ tuple[ 0 ] ] = function() { + deferred[ tuple[ 0 ] + "With" ]( this === deferred ? undefined : this, arguments ); + return this; + }; + + // deferred.notifyWith = list.fireWith + // deferred.resolveWith = list.fireWith + // deferred.rejectWith = list.fireWith + deferred[ tuple[ 0 ] + "With" ] = list.fireWith; + } ); + + // Make the deferred a promise + promise.promise( deferred ); + + // Call given func if any + if ( func ) { + func.call( deferred, deferred ); + } + + // All done! + return deferred; + }, + + // Deferred helper + when: function( singleValue ) { + var + + // count of uncompleted subordinates + remaining = arguments.length, + + // count of unprocessed arguments + i = remaining, + + // subordinate fulfillment data + resolveContexts = Array( i ), + resolveValues = slice.call( arguments ), + + // the master Deferred + master = jQuery.Deferred(), + + // subordinate callback factory + updateFunc = function( i ) { + return function( value ) { + resolveContexts[ i ] = this; + resolveValues[ i ] = arguments.length > 1 ? slice.call( arguments ) : value; + if ( !( --remaining ) ) { + master.resolveWith( resolveContexts, resolveValues ); + } + }; + }; + + // Single- and empty arguments are adopted like Promise.resolve + if ( remaining <= 1 ) { + adoptValue( singleValue, master.done( updateFunc( i ) ).resolve, master.reject, + !remaining ); + + // Use .then() to unwrap secondary thenables (cf. gh-3000) + if ( master.state() === "pending" || + isFunction( resolveValues[ i ] && resolveValues[ i ].then ) ) { + + return master.then(); + } + } + + // Multiple arguments are aggregated like Promise.all array elements + while ( i-- ) { + adoptValue( resolveValues[ i ], updateFunc( i ), master.reject ); + } + + return master.promise(); + } +} ); + + +// These usually indicate a programmer mistake during development, +// warn about them ASAP rather than swallowing them by default. 
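+//
+// Editorial sketch (not part of the jQuery source), assuming only the public
+// jQuery.when API implemented above: multiple subordinates are aggregated like
+// Promise.all, with one resolved value per argument position, e.g.
+//
+//   var a = jQuery.Deferred(), b = jQuery.Deferred();
+//   jQuery.when( a, b ).done( function( x, y ) {
+//       console.log( x, y ); // 1 2
+//   } );
+//   a.resolve( 1 );
+//   b.resolve( 2 );
+//
+// The regexp below lists the native error names treated as such mistakes: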
+var rerrorNames = /^(Eval|Internal|Range|Reference|Syntax|Type|URI)Error$/; + +jQuery.Deferred.exceptionHook = function( error, stack ) { + + // Support: IE 8 - 9 only + // Console exists when dev tools are open, which can happen at any time + if ( window.console && window.console.warn && error && rerrorNames.test( error.name ) ) { + window.console.warn( "jQuery.Deferred exception: " + error.message, error.stack, stack ); + } +}; + + + + +jQuery.readyException = function( error ) { + window.setTimeout( function() { + throw error; + } ); +}; + + + + +// The deferred used on DOM ready +var readyList = jQuery.Deferred(); + +jQuery.fn.ready = function( fn ) { + + readyList + .then( fn ) + + // Wrap jQuery.readyException in a function so that the lookup + // happens at the time of error handling instead of callback + // registration. + .catch( function( error ) { + jQuery.readyException( error ); + } ); + + return this; +}; + +jQuery.extend( { + + // Is the DOM ready to be used? Set to true once it occurs. + isReady: false, + + // A counter to track how many items to wait for before + // the ready event fires. See #6781 + readyWait: 1, + + // Handle when the DOM is ready + ready: function( wait ) { + + // Abort if there are pending holds or we're already ready + if ( wait === true ? --jQuery.readyWait : jQuery.isReady ) { + return; + } + + // Remember that the DOM is ready + jQuery.isReady = true; + + // If a normal DOM Ready event fired, decrement, and wait if need be + if ( wait !== true && --jQuery.readyWait > 0 ) { + return; + } + + // If there are functions bound, to execute + readyList.resolveWith( document, [ jQuery ] ); + } +} ); + +jQuery.ready.then = readyList.then; + +// The ready event handler and self cleanup method +function completed() { + document.removeEventListener( "DOMContentLoaded", completed ); + window.removeEventListener( "load", completed ); + jQuery.ready(); +} + +// Catch cases where $(document).ready() is called +// after the browser event has already occurred. +// Support: IE <=9 - 10 only +// Older IE sometimes signals "interactive" too soon +if ( document.readyState === "complete" || + ( document.readyState !== "loading" && !document.documentElement.doScroll ) ) { + + // Handle it asynchronously to allow scripts the opportunity to delay ready + window.setTimeout( jQuery.ready ); + +} else { + + // Use the handy event callback + document.addEventListener( "DOMContentLoaded", completed ); + + // A fallback to window.onload, that will always work + window.addEventListener( "load", completed ); +} + + + + +// Multifunctional method to get and set values of a collection +// The value/s can optionally be executed if it's a function +var access = function( elems, fn, key, value, chainable, emptyGet, raw ) { + var i = 0, + len = elems.length, + bulk = key == null; + + // Sets many values + if ( toType( key ) === "object" ) { + chainable = true; + for ( i in key ) { + access( elems, fn, i, key[ i ], true, emptyGet, raw ); + } + + // Sets one value + } else if ( value !== undefined ) { + chainable = true; + + if ( !isFunction( value ) ) { + raw = true; + } + + if ( bulk ) { + + // Bulk operations run against the entire set + if ( raw ) { + fn.call( elems, value ); + fn = null; + + // ...except when executing function values + } else { + bulk = fn; + fn = function( elem, _key, value ) { + return bulk.call( jQuery( elem ), value ); + }; + } + } + + if ( fn ) { + for ( ; i < len; i++ ) { + fn( + elems[ i ], key, raw ? 
+ value : + value.call( elems[ i ], i, fn( elems[ i ], key ) ) + ); + } + } + } + + if ( chainable ) { + return elems; + } + + // Gets + if ( bulk ) { + return fn.call( elems ); + } + + return len ? fn( elems[ 0 ], key ) : emptyGet; +}; + + +// Matches dashed string for camelizing +var rmsPrefix = /^-ms-/, + rdashAlpha = /-([a-z])/g; + +// Used by camelCase as callback to replace() +function fcamelCase( _all, letter ) { + return letter.toUpperCase(); +} + +// Convert dashed to camelCase; used by the css and data modules +// Support: IE <=9 - 11, Edge 12 - 15 +// Microsoft forgot to hump their vendor prefix (#9572) +function camelCase( string ) { + return string.replace( rmsPrefix, "ms-" ).replace( rdashAlpha, fcamelCase ); +} +var acceptData = function( owner ) { + + // Accepts only: + // - Node + // - Node.ELEMENT_NODE + // - Node.DOCUMENT_NODE + // - Object + // - Any + return owner.nodeType === 1 || owner.nodeType === 9 || !( +owner.nodeType ); +}; + + + + +function Data() { + this.expando = jQuery.expando + Data.uid++; +} + +Data.uid = 1; + +Data.prototype = { + + cache: function( owner ) { + + // Check if the owner object already has a cache + var value = owner[ this.expando ]; + + // If not, create one + if ( !value ) { + value = {}; + + // We can accept data for non-element nodes in modern browsers, + // but we should not, see #8335. + // Always return an empty object. + if ( acceptData( owner ) ) { + + // If it is a node unlikely to be stringify-ed or looped over + // use plain assignment + if ( owner.nodeType ) { + owner[ this.expando ] = value; + + // Otherwise secure it in a non-enumerable property + // configurable must be true to allow the property to be + // deleted when data is removed + } else { + Object.defineProperty( owner, this.expando, { + value: value, + configurable: true + } ); + } + } + } + + return value; + }, + set: function( owner, data, value ) { + var prop, + cache = this.cache( owner ); + + // Handle: [ owner, key, value ] args + // Always use camelCase key (gh-2257) + if ( typeof data === "string" ) { + cache[ camelCase( data ) ] = value; + + // Handle: [ owner, { properties } ] args + } else { + + // Copy the properties one-by-one to the cache object + for ( prop in data ) { + cache[ camelCase( prop ) ] = data[ prop ]; + } + } + return cache; + }, + get: function( owner, key ) { + return key === undefined ? + this.cache( owner ) : + + // Always use camelCase key (gh-2257) + owner[ this.expando ] && owner[ this.expando ][ camelCase( key ) ]; + }, + access: function( owner, key, value ) { + + // In cases where either: + // + // 1. No key was specified + // 2. A string key was specified, but no value provided + // + // Take the "read" path and allow the get method to determine + // which value to return, respectively either: + // + // 1. The entire cache object + // 2. The data stored at the key + // + if ( key === undefined || + ( ( key && typeof key === "string" ) && value === undefined ) ) { + + return this.get( owner, key ); + } + + // When the key is not a string, or both a key and value + // are specified, set or extend (existing objects) with either: + // + // 1. An object of properties + // 2. A key and value + // + this.set( owner, key, value ); + + // Since the "set" path can have two possible entry points + // return the expected data based on which path was taken[*] + return value !== undefined ? 
value : key; + }, + remove: function( owner, key ) { + var i, + cache = owner[ this.expando ]; + + if ( cache === undefined ) { + return; + } + + if ( key !== undefined ) { + + // Support array or space separated string of keys + if ( Array.isArray( key ) ) { + + // If key is an array of keys... + // We always set camelCase keys, so remove that. + key = key.map( camelCase ); + } else { + key = camelCase( key ); + + // If a key with the spaces exists, use it. + // Otherwise, create an array by matching non-whitespace + key = key in cache ? + [ key ] : + ( key.match( rnothtmlwhite ) || [] ); + } + + i = key.length; + + while ( i-- ) { + delete cache[ key[ i ] ]; + } + } + + // Remove the expando if there's no more data + if ( key === undefined || jQuery.isEmptyObject( cache ) ) { + + // Support: Chrome <=35 - 45 + // Webkit & Blink performance suffers when deleting properties + // from DOM nodes, so set to undefined instead + // https://bugs.chromium.org/p/chromium/issues/detail?id=378607 (bug restricted) + if ( owner.nodeType ) { + owner[ this.expando ] = undefined; + } else { + delete owner[ this.expando ]; + } + } + }, + hasData: function( owner ) { + var cache = owner[ this.expando ]; + return cache !== undefined && !jQuery.isEmptyObject( cache ); + } +}; +var dataPriv = new Data(); + +var dataUser = new Data(); + + + +// Implementation Summary +// +// 1. Enforce API surface and semantic compatibility with 1.9.x branch +// 2. Improve the module's maintainability by reducing the storage +// paths to a single mechanism. +// 3. Use the same single mechanism to support "private" and "user" data. +// 4. _Never_ expose "private" data to user code (TODO: Drop _data, _removeData) +// 5. Avoid exposing implementation details on user objects (eg. expando properties) +// 6. Provide a clear path for implementation upgrade to WeakMap in 2014 + +var rbrace = /^(?:\{[\w\W]*\}|\[[\w\W]*\])$/, + rmultiDash = /[A-Z]/g; + +function getData( data ) { + if ( data === "true" ) { + return true; + } + + if ( data === "false" ) { + return false; + } + + if ( data === "null" ) { + return null; + } + + // Only convert to a number if it doesn't change the string + if ( data === +data + "" ) { + return +data; + } + + if ( rbrace.test( data ) ) { + return JSON.parse( data ); + } + + return data; +} + +function dataAttr( elem, key, data ) { + var name; + + // If nothing was found internally, try to fetch any + // data from the HTML5 data-* attribute + if ( data === undefined && elem.nodeType === 1 ) { + name = "data-" + key.replace( rmultiDash, "-$&" ).toLowerCase(); + data = elem.getAttribute( name ); + + if ( typeof data === "string" ) { + try { + data = getData( data ); + } catch ( e ) {} + + // Make sure we set the data so it isn't changed later + dataUser.set( elem, key, data ); + } else { + data = undefined; + } + } + return data; +} + +jQuery.extend( { + hasData: function( elem ) { + return dataUser.hasData( elem ) || dataPriv.hasData( elem ); + }, + + data: function( elem, name, data ) { + return dataUser.access( elem, name, data ); + }, + + removeData: function( elem, name ) { + dataUser.remove( elem, name ); + }, + + // TODO: Now that all calls to _data and _removeData have been replaced + // with direct calls to dataPriv methods, these can be deprecated. 
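+	// Editorial sketch (not part of the jQuery source): dataPriv and dataUser
+	// above are independent Data stores keyed by distinct expando properties,
+	// so internal state never collides with user data, e.g.
+	//
+	//   dataUser.set( elem, "role", "button" );
+	//   dataPriv.set( elem, "role", "fx-internal" );
+	//   dataUser.get( elem, "role" ); // "button"
+	//
+	// The deprecated wrappers around the private store follow: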
+ _data: function( elem, name, data ) { + return dataPriv.access( elem, name, data ); + }, + + _removeData: function( elem, name ) { + dataPriv.remove( elem, name ); + } +} ); + +jQuery.fn.extend( { + data: function( key, value ) { + var i, name, data, + elem = this[ 0 ], + attrs = elem && elem.attributes; + + // Gets all values + if ( key === undefined ) { + if ( this.length ) { + data = dataUser.get( elem ); + + if ( elem.nodeType === 1 && !dataPriv.get( elem, "hasDataAttrs" ) ) { + i = attrs.length; + while ( i-- ) { + + // Support: IE 11 only + // The attrs elements can be null (#14894) + if ( attrs[ i ] ) { + name = attrs[ i ].name; + if ( name.indexOf( "data-" ) === 0 ) { + name = camelCase( name.slice( 5 ) ); + dataAttr( elem, name, data[ name ] ); + } + } + } + dataPriv.set( elem, "hasDataAttrs", true ); + } + } + + return data; + } + + // Sets multiple values + if ( typeof key === "object" ) { + return this.each( function() { + dataUser.set( this, key ); + } ); + } + + return access( this, function( value ) { + var data; + + // The calling jQuery object (element matches) is not empty + // (and therefore has an element appears at this[ 0 ]) and the + // `value` parameter was not undefined. An empty jQuery object + // will result in `undefined` for elem = this[ 0 ] which will + // throw an exception if an attempt to read a data cache is made. + if ( elem && value === undefined ) { + + // Attempt to get data from the cache + // The key will always be camelCased in Data + data = dataUser.get( elem, key ); + if ( data !== undefined ) { + return data; + } + + // Attempt to "discover" the data in + // HTML5 custom data-* attrs + data = dataAttr( elem, key ); + if ( data !== undefined ) { + return data; + } + + // We tried really hard, but the data doesn't exist. + return; + } + + // Set the data... 
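+			// Editorial sketch (not part of the jQuery source): via the access()
+			// helper, .data() reads from the first element only but writes to every
+			// element, camelCasing keys in both directions, e.g.
+			//
+			//   jQuery( "div" ).data( "my-key", 1 ); // stored as myKey on each div
+			//   jQuery( "div" ).data( "myKey" );     // 1, read from the first div
+			//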
+ this.each( function() { + + // We always store the camelCased key + dataUser.set( this, key, value ); + } ); + }, null, value, arguments.length > 1, null, true ); + }, + + removeData: function( key ) { + return this.each( function() { + dataUser.remove( this, key ); + } ); + } +} ); + + +jQuery.extend( { + queue: function( elem, type, data ) { + var queue; + + if ( elem ) { + type = ( type || "fx" ) + "queue"; + queue = dataPriv.get( elem, type ); + + // Speed up dequeue by getting out quickly if this is just a lookup + if ( data ) { + if ( !queue || Array.isArray( data ) ) { + queue = dataPriv.access( elem, type, jQuery.makeArray( data ) ); + } else { + queue.push( data ); + } + } + return queue || []; + } + }, + + dequeue: function( elem, type ) { + type = type || "fx"; + + var queue = jQuery.queue( elem, type ), + startLength = queue.length, + fn = queue.shift(), + hooks = jQuery._queueHooks( elem, type ), + next = function() { + jQuery.dequeue( elem, type ); + }; + + // If the fx queue is dequeued, always remove the progress sentinel + if ( fn === "inprogress" ) { + fn = queue.shift(); + startLength--; + } + + if ( fn ) { + + // Add a progress sentinel to prevent the fx queue from being + // automatically dequeued + if ( type === "fx" ) { + queue.unshift( "inprogress" ); + } + + // Clear up the last queue stop function + delete hooks.stop; + fn.call( elem, next, hooks ); + } + + if ( !startLength && hooks ) { + hooks.empty.fire(); + } + }, + + // Not public - generate a queueHooks object, or return the current one + _queueHooks: function( elem, type ) { + var key = type + "queueHooks"; + return dataPriv.get( elem, key ) || dataPriv.access( elem, key, { + empty: jQuery.Callbacks( "once memory" ).add( function() { + dataPriv.remove( elem, [ type + "queue", key ] ); + } ) + } ); + } +} ); + +jQuery.fn.extend( { + queue: function( type, data ) { + var setter = 2; + + if ( typeof type !== "string" ) { + data = type; + type = "fx"; + setter--; + } + + if ( arguments.length < setter ) { + return jQuery.queue( this[ 0 ], type ); + } + + return data === undefined ? 
+ this : + this.each( function() { + var queue = jQuery.queue( this, type, data ); + + // Ensure a hooks for this queue + jQuery._queueHooks( this, type ); + + if ( type === "fx" && queue[ 0 ] !== "inprogress" ) { + jQuery.dequeue( this, type ); + } + } ); + }, + dequeue: function( type ) { + return this.each( function() { + jQuery.dequeue( this, type ); + } ); + }, + clearQueue: function( type ) { + return this.queue( type || "fx", [] ); + }, + + // Get a promise resolved when queues of a certain type + // are emptied (fx is the type by default) + promise: function( type, obj ) { + var tmp, + count = 1, + defer = jQuery.Deferred(), + elements = this, + i = this.length, + resolve = function() { + if ( !( --count ) ) { + defer.resolveWith( elements, [ elements ] ); + } + }; + + if ( typeof type !== "string" ) { + obj = type; + type = undefined; + } + type = type || "fx"; + + while ( i-- ) { + tmp = dataPriv.get( elements[ i ], type + "queueHooks" ); + if ( tmp && tmp.empty ) { + count++; + tmp.empty.add( resolve ); + } + } + resolve(); + return defer.promise( obj ); + } +} ); +var pnum = ( /[+-]?(?:\d*\.|)\d+(?:[eE][+-]?\d+|)/ ).source; + +var rcssNum = new RegExp( "^(?:([+-])=|)(" + pnum + ")([a-z%]*)$", "i" ); + + +var cssExpand = [ "Top", "Right", "Bottom", "Left" ]; + +var documentElement = document.documentElement; + + + + var isAttached = function( elem ) { + return jQuery.contains( elem.ownerDocument, elem ); + }, + composed = { composed: true }; + + // Support: IE 9 - 11+, Edge 12 - 18+, iOS 10.0 - 10.2 only + // Check attachment across shadow DOM boundaries when possible (gh-3504) + // Support: iOS 10.0-10.2 only + // Early iOS 10 versions support `attachShadow` but not `getRootNode`, + // leading to errors. We need to check for `getRootNode`. + if ( documentElement.getRootNode ) { + isAttached = function( elem ) { + return jQuery.contains( elem.ownerDocument, elem ) || + elem.getRootNode( composed ) === elem.ownerDocument; + }; + } +var isHiddenWithinTree = function( elem, el ) { + + // isHiddenWithinTree might be called from jQuery#filter function; + // in that case, element will be second argument + elem = el || elem; + + // Inline style trumps all + return elem.style.display === "none" || + elem.style.display === "" && + + // Otherwise, check computed style + // Support: Firefox <=43 - 45 + // Disconnected elements can have computed display: none, so first confirm that elem is + // in the document. + isAttached( elem ) && + + jQuery.css( elem, "display" ) === "none"; + }; + + + +function adjustCSS( elem, prop, valueParts, tween ) { + var adjusted, scale, + maxIterations = 20, + currentValue = tween ? + function() { + return tween.cur(); + } : + function() { + return jQuery.css( elem, prop, "" ); + }, + initial = currentValue(), + unit = valueParts && valueParts[ 3 ] || ( jQuery.cssNumber[ prop ] ? 
"" : "px" ), + + // Starting value computation is required for potential unit mismatches + initialInUnit = elem.nodeType && + ( jQuery.cssNumber[ prop ] || unit !== "px" && +initial ) && + rcssNum.exec( jQuery.css( elem, prop ) ); + + if ( initialInUnit && initialInUnit[ 3 ] !== unit ) { + + // Support: Firefox <=54 + // Halve the iteration target value to prevent interference from CSS upper bounds (gh-2144) + initial = initial / 2; + + // Trust units reported by jQuery.css + unit = unit || initialInUnit[ 3 ]; + + // Iteratively approximate from a nonzero starting point + initialInUnit = +initial || 1; + + while ( maxIterations-- ) { + + // Evaluate and update our best guess (doubling guesses that zero out). + // Finish if the scale equals or crosses 1 (making the old*new product non-positive). + jQuery.style( elem, prop, initialInUnit + unit ); + if ( ( 1 - scale ) * ( 1 - ( scale = currentValue() / initial || 0.5 ) ) <= 0 ) { + maxIterations = 0; + } + initialInUnit = initialInUnit / scale; + + } + + initialInUnit = initialInUnit * 2; + jQuery.style( elem, prop, initialInUnit + unit ); + + // Make sure we update the tween properties later on + valueParts = valueParts || []; + } + + if ( valueParts ) { + initialInUnit = +initialInUnit || +initial || 0; + + // Apply relative offset (+=/-=) if specified + adjusted = valueParts[ 1 ] ? + initialInUnit + ( valueParts[ 1 ] + 1 ) * valueParts[ 2 ] : + +valueParts[ 2 ]; + if ( tween ) { + tween.unit = unit; + tween.start = initialInUnit; + tween.end = adjusted; + } + } + return adjusted; +} + + +var defaultDisplayMap = {}; + +function getDefaultDisplay( elem ) { + var temp, + doc = elem.ownerDocument, + nodeName = elem.nodeName, + display = defaultDisplayMap[ nodeName ]; + + if ( display ) { + return display; + } + + temp = doc.body.appendChild( doc.createElement( nodeName ) ); + display = jQuery.css( temp, "display" ); + + temp.parentNode.removeChild( temp ); + + if ( display === "none" ) { + display = "block"; + } + defaultDisplayMap[ nodeName ] = display; + + return display; +} + +function showHide( elements, show ) { + var display, elem, + values = [], + index = 0, + length = elements.length; + + // Determine new display value for elements that need to change + for ( ; index < length; index++ ) { + elem = elements[ index ]; + if ( !elem.style ) { + continue; + } + + display = elem.style.display; + if ( show ) { + + // Since we force visibility upon cascade-hidden elements, an immediate (and slow) + // check is required in this first loop unless we have a nonempty display value (either + // inline or about-to-be-restored) + if ( display === "none" ) { + values[ index ] = dataPriv.get( elem, "display" ) || null; + if ( !values[ index ] ) { + elem.style.display = ""; + } + } + if ( elem.style.display === "" && isHiddenWithinTree( elem ) ) { + values[ index ] = getDefaultDisplay( elem ); + } + } else { + if ( display !== "none" ) { + values[ index ] = "none"; + + // Remember what we're overwriting + dataPriv.set( elem, "display", display ); + } + } + } + + // Set the display of the elements in a second loop to avoid constant reflow + for ( index = 0; index < length; index++ ) { + if ( values[ index ] != null ) { + elements[ index ].style.display = values[ index ]; + } + } + + return elements; +} + +jQuery.fn.extend( { + show: function() { + return showHide( this, true ); + }, + hide: function() { + return showHide( this ); + }, + toggle: function( state ) { + if ( typeof state === "boolean" ) { + return state ? 
this.show() : this.hide(); + } + + return this.each( function() { + if ( isHiddenWithinTree( this ) ) { + jQuery( this ).show(); + } else { + jQuery( this ).hide(); + } + } ); + } +} ); +var rcheckableType = ( /^(?:checkbox|radio)$/i ); + +var rtagName = ( /<([a-z][^\/\0>\x20\t\r\n\f]*)/i ); + +var rscriptType = ( /^$|^module$|\/(?:java|ecma)script/i ); + + + +( function() { + var fragment = document.createDocumentFragment(), + div = fragment.appendChild( document.createElement( "div" ) ), + input = document.createElement( "input" ); + + // Support: Android 4.0 - 4.3 only + // Check state lost if the name is set (#11217) + // Support: Windows Web Apps (WWA) + // `name` and `type` must use .setAttribute for WWA (#14901) + input.setAttribute( "type", "radio" ); + input.setAttribute( "checked", "checked" ); + input.setAttribute( "name", "t" ); + + div.appendChild( input ); + + // Support: Android <=4.1 only + // Older WebKit doesn't clone checked state correctly in fragments + support.checkClone = div.cloneNode( true ).cloneNode( true ).lastChild.checked; + + // Support: IE <=11 only + // Make sure textarea (and checkbox) defaultValue is properly cloned + div.innerHTML = ""; + support.noCloneChecked = !!div.cloneNode( true ).lastChild.defaultValue; + + // Support: IE <=9 only + // IE <=9 replaces "; + support.option = !!div.lastChild; +} )(); + + +// We have to close these tags to support XHTML (#13200) +var wrapMap = { + + // XHTML parsers do not magically insert elements in the + // same way that tag soup parsers do. So we cannot shorten + // this by omitting or other required elements. + thead: [ 1, "", "
" ], + col: [ 2, "", "
" ], + tr: [ 2, "", "
" ], + td: [ 3, "", "
" ], + + _default: [ 0, "", "" ] +}; + +wrapMap.tbody = wrapMap.tfoot = wrapMap.colgroup = wrapMap.caption = wrapMap.thead; +wrapMap.th = wrapMap.td; + +// Support: IE <=9 only +if ( !support.option ) { + wrapMap.optgroup = wrapMap.option = [ 1, "" ]; +} + + +function getAll( context, tag ) { + + // Support: IE <=9 - 11 only + // Use typeof to avoid zero-argument method invocation on host objects (#15151) + var ret; + + if ( typeof context.getElementsByTagName !== "undefined" ) { + ret = context.getElementsByTagName( tag || "*" ); + + } else if ( typeof context.querySelectorAll !== "undefined" ) { + ret = context.querySelectorAll( tag || "*" ); + + } else { + ret = []; + } + + if ( tag === undefined || tag && nodeName( context, tag ) ) { + return jQuery.merge( [ context ], ret ); + } + + return ret; +} + + +// Mark scripts as having already been evaluated +function setGlobalEval( elems, refElements ) { + var i = 0, + l = elems.length; + + for ( ; i < l; i++ ) { + dataPriv.set( + elems[ i ], + "globalEval", + !refElements || dataPriv.get( refElements[ i ], "globalEval" ) + ); + } +} + + +var rhtml = /<|&#?\w+;/; + +function buildFragment( elems, context, scripts, selection, ignored ) { + var elem, tmp, tag, wrap, attached, j, + fragment = context.createDocumentFragment(), + nodes = [], + i = 0, + l = elems.length; + + for ( ; i < l; i++ ) { + elem = elems[ i ]; + + if ( elem || elem === 0 ) { + + // Add nodes directly + if ( toType( elem ) === "object" ) { + + // Support: Android <=4.0 only, PhantomJS 1 only + // push.apply(_, arraylike) throws on ancient WebKit + jQuery.merge( nodes, elem.nodeType ? [ elem ] : elem ); + + // Convert non-html into a text node + } else if ( !rhtml.test( elem ) ) { + nodes.push( context.createTextNode( elem ) ); + + // Convert html into DOM nodes + } else { + tmp = tmp || fragment.appendChild( context.createElement( "div" ) ); + + // Deserialize a standard representation + tag = ( rtagName.exec( elem ) || [ "", "" ] )[ 1 ].toLowerCase(); + wrap = wrapMap[ tag ] || wrapMap._default; + tmp.innerHTML = wrap[ 1 ] + jQuery.htmlPrefilter( elem ) + wrap[ 2 ]; + + // Descend through wrappers to the right content + j = wrap[ 0 ]; + while ( j-- ) { + tmp = tmp.lastChild; + } + + // Support: Android <=4.0 only, PhantomJS 1 only + // push.apply(_, arraylike) throws on ancient WebKit + jQuery.merge( nodes, tmp.childNodes ); + + // Remember the top-level container + tmp = fragment.firstChild; + + // Ensure the created nodes are orphaned (#12392) + tmp.textContent = ""; + } + } + } + + // Remove wrapper from fragment + fragment.textContent = ""; + + i = 0; + while ( ( elem = nodes[ i++ ] ) ) { + + // Skip elements already in the context collection (trac-4087) + if ( selection && jQuery.inArray( elem, selection ) > -1 ) { + if ( ignored ) { + ignored.push( elem ); + } + continue; + } + + attached = isAttached( elem ); + + // Append to fragment + tmp = getAll( fragment.appendChild( elem ), "script" ); + + // Preserve script evaluation history + if ( attached ) { + setGlobalEval( tmp ); + } + + // Capture executables + if ( scripts ) { + j = 0; + while ( ( elem = tmp[ j++ ] ) ) { + if ( rscriptType.test( elem.type || "" ) ) { + scripts.push( elem ); + } + } + } + } + + return fragment; +} + + +var + rkeyEvent = /^key/, + rmouseEvent = /^(?:mouse|pointer|contextmenu|drag|drop)|click/, + rtypenamespace = /^([^.]*)(?:\.(.+)|)/; + +function returnTrue() { + return true; +} + +function returnFalse() { + return false; +} + +// Support: IE <=9 - 11+ +// focus() and blur() are 
asynchronous, except when they are no-op. +// So expect focus to be synchronous when the element is already active, +// and blur to be synchronous when the element is not already active. +// (focus and blur are always synchronous in other supported browsers, +// this just defines when we can count on it). +function expectSync( elem, type ) { + return ( elem === safeActiveElement() ) === ( type === "focus" ); +} + +// Support: IE <=9 only +// Accessing document.activeElement can throw unexpectedly +// https://bugs.jquery.com/ticket/13393 +function safeActiveElement() { + try { + return document.activeElement; + } catch ( err ) { } +} + +function on( elem, types, selector, data, fn, one ) { + var origFn, type; + + // Types can be a map of types/handlers + if ( typeof types === "object" ) { + + // ( types-Object, selector, data ) + if ( typeof selector !== "string" ) { + + // ( types-Object, data ) + data = data || selector; + selector = undefined; + } + for ( type in types ) { + on( elem, type, selector, data, types[ type ], one ); + } + return elem; + } + + if ( data == null && fn == null ) { + + // ( types, fn ) + fn = selector; + data = selector = undefined; + } else if ( fn == null ) { + if ( typeof selector === "string" ) { + + // ( types, selector, fn ) + fn = data; + data = undefined; + } else { + + // ( types, data, fn ) + fn = data; + data = selector; + selector = undefined; + } + } + if ( fn === false ) { + fn = returnFalse; + } else if ( !fn ) { + return elem; + } + + if ( one === 1 ) { + origFn = fn; + fn = function( event ) { + + // Can use an empty set, since event contains the info + jQuery().off( event ); + return origFn.apply( this, arguments ); + }; + + // Use same guid so caller can remove using origFn + fn.guid = origFn.guid || ( origFn.guid = jQuery.guid++ ); + } + return elem.each( function() { + jQuery.event.add( this, types, fn, data, selector ); + } ); +} + +/* + * Helper functions for managing events -- not part of the public interface. + * Props to Dean Edwards' addEvent library for many of the ideas. + */ +jQuery.event = { + + global: {}, + + add: function( elem, types, handler, data, selector ) { + + var handleObjIn, eventHandle, tmp, + events, t, handleObj, + special, handlers, type, namespaces, origType, + elemData = dataPriv.get( elem ); + + // Only attach events to objects that accept data + if ( !acceptData( elem ) ) { + return; + } + + // Caller can pass in an object of custom data in lieu of the handler + if ( handler.handler ) { + handleObjIn = handler; + handler = handleObjIn.handler; + selector = handleObjIn.selector; + } + + // Ensure that invalid selectors throw exceptions at attach time + // Evaluate against documentElement in case elem is a non-element node (e.g., document) + if ( selector ) { + jQuery.find.matchesSelector( documentElement, selector ); + } + + // Make sure that the handler has a unique ID, used to find/remove it later + if ( !handler.guid ) { + handler.guid = jQuery.guid++; + } + + // Init the element's event structure and main handler, if this is the first + if ( !( events = elemData.events ) ) { + events = elemData.events = Object.create( null ); + } + if ( !( eventHandle = elemData.handle ) ) { + eventHandle = elemData.handle = function( e ) { + + // Discard the second event of a jQuery.event.trigger() and + // when an event is called after a page has unloaded + return typeof jQuery !== "undefined" && jQuery.event.triggered !== e.type ? 
+ jQuery.event.dispatch.apply( elem, arguments ) : undefined; + }; + } + + // Handle multiple events separated by a space + types = ( types || "" ).match( rnothtmlwhite ) || [ "" ]; + t = types.length; + while ( t-- ) { + tmp = rtypenamespace.exec( types[ t ] ) || []; + type = origType = tmp[ 1 ]; + namespaces = ( tmp[ 2 ] || "" ).split( "." ).sort(); + + // There *must* be a type, no attaching namespace-only handlers + if ( !type ) { + continue; + } + + // If event changes its type, use the special event handlers for the changed type + special = jQuery.event.special[ type ] || {}; + + // If selector defined, determine special event api type, otherwise given type + type = ( selector ? special.delegateType : special.bindType ) || type; + + // Update special based on newly reset type + special = jQuery.event.special[ type ] || {}; + + // handleObj is passed to all event handlers + handleObj = jQuery.extend( { + type: type, + origType: origType, + data: data, + handler: handler, + guid: handler.guid, + selector: selector, + needsContext: selector && jQuery.expr.match.needsContext.test( selector ), + namespace: namespaces.join( "." ) + }, handleObjIn ); + + // Init the event handler queue if we're the first + if ( !( handlers = events[ type ] ) ) { + handlers = events[ type ] = []; + handlers.delegateCount = 0; + + // Only use addEventListener if the special events handler returns false + if ( !special.setup || + special.setup.call( elem, data, namespaces, eventHandle ) === false ) { + + if ( elem.addEventListener ) { + elem.addEventListener( type, eventHandle ); + } + } + } + + if ( special.add ) { + special.add.call( elem, handleObj ); + + if ( !handleObj.handler.guid ) { + handleObj.handler.guid = handler.guid; + } + } + + // Add to the element's handler list, delegates in front + if ( selector ) { + handlers.splice( handlers.delegateCount++, 0, handleObj ); + } else { + handlers.push( handleObj ); + } + + // Keep track of which events have ever been used, for event optimization + jQuery.event.global[ type ] = true; + } + + }, + + // Detach an event or set of events from an element + remove: function( elem, types, handler, selector, mappedTypes ) { + + var j, origCount, tmp, + events, t, handleObj, + special, handlers, type, namespaces, origType, + elemData = dataPriv.hasData( elem ) && dataPriv.get( elem ); + + if ( !elemData || !( events = elemData.events ) ) { + return; + } + + // Once for each type.namespace in types; type may be omitted + types = ( types || "" ).match( rnothtmlwhite ) || [ "" ]; + t = types.length; + while ( t-- ) { + tmp = rtypenamespace.exec( types[ t ] ) || []; + type = origType = tmp[ 1 ]; + namespaces = ( tmp[ 2 ] || "" ).split( "." ).sort(); + + // Unbind all events (on this namespace, if provided) for the element + if ( !type ) { + for ( type in events ) { + jQuery.event.remove( elem, type + types[ t ], handler, selector, true ); + } + continue; + } + + special = jQuery.event.special[ type ] || {}; + type = ( selector ? 
special.delegateType : special.bindType ) || type; + handlers = events[ type ] || []; + tmp = tmp[ 2 ] && + new RegExp( "(^|\\.)" + namespaces.join( "\\.(?:.*\\.|)" ) + "(\\.|$)" ); + + // Remove matching events + origCount = j = handlers.length; + while ( j-- ) { + handleObj = handlers[ j ]; + + if ( ( mappedTypes || origType === handleObj.origType ) && + ( !handler || handler.guid === handleObj.guid ) && + ( !tmp || tmp.test( handleObj.namespace ) ) && + ( !selector || selector === handleObj.selector || + selector === "**" && handleObj.selector ) ) { + handlers.splice( j, 1 ); + + if ( handleObj.selector ) { + handlers.delegateCount--; + } + if ( special.remove ) { + special.remove.call( elem, handleObj ); + } + } + } + + // Remove generic event handler if we removed something and no more handlers exist + // (avoids potential for endless recursion during removal of special event handlers) + if ( origCount && !handlers.length ) { + if ( !special.teardown || + special.teardown.call( elem, namespaces, elemData.handle ) === false ) { + + jQuery.removeEvent( elem, type, elemData.handle ); + } + + delete events[ type ]; + } + } + + // Remove data and the expando if it's no longer used + if ( jQuery.isEmptyObject( events ) ) { + dataPriv.remove( elem, "handle events" ); + } + }, + + dispatch: function( nativeEvent ) { + + var i, j, ret, matched, handleObj, handlerQueue, + args = new Array( arguments.length ), + + // Make a writable jQuery.Event from the native event object + event = jQuery.event.fix( nativeEvent ), + + handlers = ( + dataPriv.get( this, "events" ) || Object.create( null ) + )[ event.type ] || [], + special = jQuery.event.special[ event.type ] || {}; + + // Use the fix-ed jQuery.Event rather than the (read-only) native event + args[ 0 ] = event; + + for ( i = 1; i < arguments.length; i++ ) { + args[ i ] = arguments[ i ]; + } + + event.delegateTarget = this; + + // Call the preDispatch hook for the mapped type, and let it bail if desired + if ( special.preDispatch && special.preDispatch.call( this, event ) === false ) { + return; + } + + // Determine handlers + handlerQueue = jQuery.event.handlers.call( this, event, handlers ); + + // Run delegates first; they may want to stop propagation beneath us + i = 0; + while ( ( matched = handlerQueue[ i++ ] ) && !event.isPropagationStopped() ) { + event.currentTarget = matched.elem; + + j = 0; + while ( ( handleObj = matched.handlers[ j++ ] ) && + !event.isImmediatePropagationStopped() ) { + + // If the event is namespaced, then each handler is only invoked if it is + // specially universal or its namespaces are a superset of the event's. 
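+					// Illustrative example (not upstream): a handler bound via
+					//     jQuery( elem ).on( "click.myPlugin", fn )
+					// runs for .trigger( "click" ) and .trigger( "click.myPlugin" )
+					// but not for .trigger( "click.other" ); event.rnamespace is
+					// built elsewhere from the namespaces of the triggered type.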
+ if ( !event.rnamespace || handleObj.namespace === false || + event.rnamespace.test( handleObj.namespace ) ) { + + event.handleObj = handleObj; + event.data = handleObj.data; + + ret = ( ( jQuery.event.special[ handleObj.origType ] || {} ).handle || + handleObj.handler ).apply( matched.elem, args ); + + if ( ret !== undefined ) { + if ( ( event.result = ret ) === false ) { + event.preventDefault(); + event.stopPropagation(); + } + } + } + } + } + + // Call the postDispatch hook for the mapped type + if ( special.postDispatch ) { + special.postDispatch.call( this, event ); + } + + return event.result; + }, + + handlers: function( event, handlers ) { + var i, handleObj, sel, matchedHandlers, matchedSelectors, + handlerQueue = [], + delegateCount = handlers.delegateCount, + cur = event.target; + + // Find delegate handlers + if ( delegateCount && + + // Support: IE <=9 + // Black-hole SVG instance trees (trac-13180) + cur.nodeType && + + // Support: Firefox <=42 + // Suppress spec-violating clicks indicating a non-primary pointer button (trac-3861) + // https://www.w3.org/TR/DOM-Level-3-Events/#event-type-click + // Support: IE 11 only + // ...but not arrow key "clicks" of radio inputs, which can have `button` -1 (gh-2343) + !( event.type === "click" && event.button >= 1 ) ) { + + for ( ; cur !== this; cur = cur.parentNode || this ) { + + // Don't check non-elements (#13208) + // Don't process clicks on disabled elements (#6911, #8165, #11382, #11764) + if ( cur.nodeType === 1 && !( event.type === "click" && cur.disabled === true ) ) { + matchedHandlers = []; + matchedSelectors = {}; + for ( i = 0; i < delegateCount; i++ ) { + handleObj = handlers[ i ]; + + // Don't conflict with Object.prototype properties (#13203) + sel = handleObj.selector + " "; + + if ( matchedSelectors[ sel ] === undefined ) { + matchedSelectors[ sel ] = handleObj.needsContext ? + jQuery( sel, this ).index( cur ) > -1 : + jQuery.find( sel, this, null, [ cur ] ).length; + } + if ( matchedSelectors[ sel ] ) { + matchedHandlers.push( handleObj ); + } + } + if ( matchedHandlers.length ) { + handlerQueue.push( { elem: cur, handlers: matchedHandlers } ); + } + } + } + } + + // Add the remaining (directly-bound) handlers + cur = this; + if ( delegateCount < handlers.length ) { + handlerQueue.push( { elem: cur, handlers: handlers.slice( delegateCount ) } ); + } + + return handlerQueue; + }, + + addProp: function( name, hook ) { + Object.defineProperty( jQuery.Event.prototype, name, { + enumerable: true, + configurable: true, + + get: isFunction( hook ) ? + function() { + if ( this.originalEvent ) { + return hook( this.originalEvent ); + } + } : + function() { + if ( this.originalEvent ) { + return this.originalEvent[ name ]; + } + }, + + set: function( value ) { + Object.defineProperty( this, name, { + enumerable: true, + configurable: true, + writable: true, + value: value + } ); + } + } ); + }, + + fix: function( originalEvent ) { + return originalEvent[ jQuery.expando ] ? + originalEvent : + new jQuery.Event( originalEvent ); + }, + + special: { + load: { + + // Prevent triggered image.load events from bubbling to window.load + noBubble: true + }, + click: { + + // Utilize native event to ensure correct state for checkable inputs + setup: function( data ) { + + // For mutual compressibility with _default, replace `this` access with a local var. + // `|| data` is dead code meant only to preserve the variable through minification. 
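+			// (Illustrative gloss, not upstream: phrasing setup and _default
+			// against the same local `el` lets a minifier emit near-identical
+			// byte sequences for both, which compresses better.)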
+ var el = this || data; + + // Claim the first handler + if ( rcheckableType.test( el.type ) && + el.click && nodeName( el, "input" ) ) { + + // dataPriv.set( el, "click", ... ) + leverageNative( el, "click", returnTrue ); + } + + // Return false to allow normal processing in the caller + return false; + }, + trigger: function( data ) { + + // For mutual compressibility with _default, replace `this` access with a local var. + // `|| data` is dead code meant only to preserve the variable through minification. + var el = this || data; + + // Force setup before triggering a click + if ( rcheckableType.test( el.type ) && + el.click && nodeName( el, "input" ) ) { + + leverageNative( el, "click" ); + } + + // Return non-false to allow normal event-path propagation + return true; + }, + + // For cross-browser consistency, suppress native .click() on links + // Also prevent it if we're currently inside a leveraged native-event stack + _default: function( event ) { + var target = event.target; + return rcheckableType.test( target.type ) && + target.click && nodeName( target, "input" ) && + dataPriv.get( target, "click" ) || + nodeName( target, "a" ); + } + }, + + beforeunload: { + postDispatch: function( event ) { + + // Support: Firefox 20+ + // Firefox doesn't alert if the returnValue field is not set. + if ( event.result !== undefined && event.originalEvent ) { + event.originalEvent.returnValue = event.result; + } + } + } + } +}; + +// Ensure the presence of an event listener that handles manually-triggered +// synthetic events by interrupting progress until reinvoked in response to +// *native* events that it fires directly, ensuring that state changes have +// already occurred before other listeners are invoked. +function leverageNative( el, type, expectSync ) { + + // Missing expectSync indicates a trigger call, which must force setup through jQuery.event.add + if ( !expectSync ) { + if ( dataPriv.get( el, type ) === undefined ) { + jQuery.event.add( el, type, returnTrue ); + } + return; + } + + // Register the controller as a special universal handler for all event namespaces + dataPriv.set( el, type, false ); + jQuery.event.add( el, type, { + namespace: false, + handler: function( event ) { + var notAsync, result, + saved = dataPriv.get( this, type ); + + if ( ( event.isTrigger & 1 ) && this[ type ] ) { + + // Interrupt processing of the outer synthetic .trigger()ed event + // Saved data should be false in such cases, but might be a leftover capture object + // from an async native handler (gh-4350) + if ( !saved.length ) { + + // Store arguments for use when handling the inner native event + // There will always be at least one argument (an event object), so this array + // will not be confused with a leftover capture object. 
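+					// Illustrative trace (not upstream): for a checkbox with a
+					// bound click handler, jQuery( box ).trigger( "click" ) arrives
+					// here with saved === false (so saved.length is undefined),
+					// captures the arguments, and box.click() below re-enters this
+					// controller on the native pass, where the saved.length branch
+					// fires the inner synthetic event.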
+ saved = slice.call( arguments ); + dataPriv.set( this, type, saved ); + + // Trigger the native event and capture its result + // Support: IE <=9 - 11+ + // focus() and blur() are asynchronous + notAsync = expectSync( this, type ); + this[ type ](); + result = dataPriv.get( this, type ); + if ( saved !== result || notAsync ) { + dataPriv.set( this, type, false ); + } else { + result = {}; + } + if ( saved !== result ) { + + // Cancel the outer synthetic event + event.stopImmediatePropagation(); + event.preventDefault(); + return result.value; + } + + // If this is an inner synthetic event for an event with a bubbling surrogate + // (focus or blur), assume that the surrogate already propagated from triggering the + // native event and prevent that from happening again here. + // This technically gets the ordering wrong w.r.t. to `.trigger()` (in which the + // bubbling surrogate propagates *after* the non-bubbling base), but that seems + // less bad than duplication. + } else if ( ( jQuery.event.special[ type ] || {} ).delegateType ) { + event.stopPropagation(); + } + + // If this is a native event triggered above, everything is now in order + // Fire an inner synthetic event with the original arguments + } else if ( saved.length ) { + + // ...and capture the result + dataPriv.set( this, type, { + value: jQuery.event.trigger( + + // Support: IE <=9 - 11+ + // Extend with the prototype to reset the above stopImmediatePropagation() + jQuery.extend( saved[ 0 ], jQuery.Event.prototype ), + saved.slice( 1 ), + this + ) + } ); + + // Abort handling of the native event + event.stopImmediatePropagation(); + } + } + } ); +} + +jQuery.removeEvent = function( elem, type, handle ) { + + // This "if" is needed for plain objects + if ( elem.removeEventListener ) { + elem.removeEventListener( type, handle ); + } +}; + +jQuery.Event = function( src, props ) { + + // Allow instantiation without the 'new' keyword + if ( !( this instanceof jQuery.Event ) ) { + return new jQuery.Event( src, props ); + } + + // Event object + if ( src && src.type ) { + this.originalEvent = src; + this.type = src.type; + + // Events bubbling up the document may have been marked as prevented + // by a handler lower down the tree; reflect the correct value. + this.isDefaultPrevented = src.defaultPrevented || + src.defaultPrevented === undefined && + + // Support: Android <=2.3 only + src.returnValue === false ? + returnTrue : + returnFalse; + + // Create target properties + // Support: Safari <=6 - 7 only + // Target should not be a text node (#504, #13143) + this.target = ( src.target && src.target.nodeType === 3 ) ? 
+ src.target.parentNode : + src.target; + + this.currentTarget = src.currentTarget; + this.relatedTarget = src.relatedTarget; + + // Event type + } else { + this.type = src; + } + + // Put explicitly provided properties onto the event object + if ( props ) { + jQuery.extend( this, props ); + } + + // Create a timestamp if incoming event doesn't have one + this.timeStamp = src && src.timeStamp || Date.now(); + + // Mark it as fixed + this[ jQuery.expando ] = true; +}; + +// jQuery.Event is based on DOM3 Events as specified by the ECMAScript Language Binding +// https://www.w3.org/TR/2003/WD-DOM-Level-3-Events-20030331/ecma-script-binding.html +jQuery.Event.prototype = { + constructor: jQuery.Event, + isDefaultPrevented: returnFalse, + isPropagationStopped: returnFalse, + isImmediatePropagationStopped: returnFalse, + isSimulated: false, + + preventDefault: function() { + var e = this.originalEvent; + + this.isDefaultPrevented = returnTrue; + + if ( e && !this.isSimulated ) { + e.preventDefault(); + } + }, + stopPropagation: function() { + var e = this.originalEvent; + + this.isPropagationStopped = returnTrue; + + if ( e && !this.isSimulated ) { + e.stopPropagation(); + } + }, + stopImmediatePropagation: function() { + var e = this.originalEvent; + + this.isImmediatePropagationStopped = returnTrue; + + if ( e && !this.isSimulated ) { + e.stopImmediatePropagation(); + } + + this.stopPropagation(); + } +}; + +// Includes all common event props including KeyEvent and MouseEvent specific props +jQuery.each( { + altKey: true, + bubbles: true, + cancelable: true, + changedTouches: true, + ctrlKey: true, + detail: true, + eventPhase: true, + metaKey: true, + pageX: true, + pageY: true, + shiftKey: true, + view: true, + "char": true, + code: true, + charCode: true, + key: true, + keyCode: true, + button: true, + buttons: true, + clientX: true, + clientY: true, + offsetX: true, + offsetY: true, + pointerId: true, + pointerType: true, + screenX: true, + screenY: true, + targetTouches: true, + toElement: true, + touches: true, + + which: function( event ) { + var button = event.button; + + // Add which for key events + if ( event.which == null && rkeyEvent.test( event.type ) ) { + return event.charCode != null ? event.charCode : event.keyCode; + } + + // Add which for click: 1 === left; 2 === middle; 3 === right + if ( !event.which && button !== undefined && rmouseEvent.test( event.type ) ) { + if ( button & 1 ) { + return 1; + } + + if ( button & 2 ) { + return 3; + } + + if ( button & 4 ) { + return 2; + } + + return 0; + } + + return event.which; + } +}, jQuery.event.addProp ); + +jQuery.each( { focus: "focusin", blur: "focusout" }, function( type, delegateType ) { + jQuery.event.special[ type ] = { + + // Utilize native event if possible so blur/focus sequence is correct + setup: function() { + + // Claim the first handler + // dataPriv.set( this, "focus", ... ) + // dataPriv.set( this, "blur", ... ) + leverageNative( this, type, expectSync ); + + // Return false to allow normal processing in the caller + return false; + }, + trigger: function() { + + // Force setup before trigger + leverageNative( this, type ); + + // Return non-false to allow normal event-path propagation + return true; + }, + + delegateType: delegateType + }; +} ); + +// Create mouseenter/leave events using mouseover/out and event-time checks +// so that event delegation works in jQuery. 
+// Do the same for pointerenter/pointerleave and pointerover/pointerout
+//
+// Support: Safari 7 only
+// Safari sends mouseenter too often; see:
+// https://bugs.chromium.org/p/chromium/issues/detail?id=470258
+// for the description of the bug (it existed in older Chrome versions as well).
+jQuery.each( {
+	mouseenter: "mouseover",
+	mouseleave: "mouseout",
+	pointerenter: "pointerover",
+	pointerleave: "pointerout"
+}, function( orig, fix ) {
+	jQuery.event.special[ orig ] = {
+		delegateType: fix,
+		bindType: fix,
+
+		handle: function( event ) {
+			var ret,
+				target = this,
+				related = event.relatedTarget,
+				handleObj = event.handleObj;
+
+			// For mouseenter/leave call the handler if related is outside the target.
+			// NB: No relatedTarget if the mouse left/entered the browser window
+			if ( !related || ( related !== target && !jQuery.contains( target, related ) ) ) {
+				event.type = handleObj.origType;
+				ret = handleObj.handler.apply( this, arguments );
+				event.type = fix;
+			}
+			return ret;
+		}
+	};
+} );
+
+jQuery.fn.extend( {
+
+	on: function( types, selector, data, fn ) {
+		return on( this, types, selector, data, fn );
+	},
+	one: function( types, selector, data, fn ) {
+		return on( this, types, selector, data, fn, 1 );
+	},
+	off: function( types, selector, fn ) {
+		var handleObj, type;
+		if ( types && types.preventDefault && types.handleObj ) {
+
+			// ( event ) dispatched jQuery.Event
+			handleObj = types.handleObj;
+			jQuery( types.delegateTarget ).off(
+				handleObj.namespace ?
+					handleObj.origType + "." + handleObj.namespace :
+					handleObj.origType,
+				handleObj.selector,
+				handleObj.handler
+			);
+			return this;
+		}
+		if ( typeof types === "object" ) {
+
+			// ( types-object [, selector] )
+			for ( type in types ) {
+				this.off( type, selector, types[ type ] );
+			}
+			return this;
+		}
+		if ( selector === false || typeof selector === "function" ) {
+
+			// ( types [, fn] )
+			fn = selector;
+			selector = undefined;
+		}
+		if ( fn === false ) {
+			fn = returnFalse;
+		}
+		return this.each( function() {
+			jQuery.event.remove( this, types, fn, selector );
+		} );
+	}
+} );
+
+
+var
+
+	// Support: IE <=10 - 11, Edge 12 - 13 only
+	// In IE/Edge using regex groups here causes severe slowdowns.
+	// See https://connect.microsoft.com/IE/feedback/details/1736512/
+	rnoInnerhtml = /<script|<style|<link/i,
+
+	// checked="checked" or checked
+	rchecked = /checked\s*(?:[^=]|=\s*.checked.)/i,
+
+	rcleanScript = /^\s*<!(?:\[CDATA\[|--)|(?:\]\]|--)>\s*$/g;
+
+// Prefer a tbody over its parent table for containing new rows
+function manipulationTarget( elem, content ) {
+	if ( nodeName( elem, "table" ) &&
+		nodeName( content.nodeType !== 11 ? content : content.firstChild, "tr" ) ) {
+
+		return jQuery( elem ).children( "tbody" )[ 0 ] || elem;
+	}
+
+	return elem;
+}
+
+// Replace/restore the type attribute of script elements for safe DOM manipulation
+function disableScript( elem ) {
+	elem.type = ( elem.getAttribute( "type" ) !== null ) + "/" + elem.type;
+	return elem;
+}
+function restoreScript( elem ) {
+	if ( ( elem.type || "" ).slice( 0, 5 ) === "true/" ) {
+		elem.type = elem.type.slice( 5 );
+	} else {
+		elem.removeAttribute( "type" );
+	}
+
+	return elem;
+}
+
+function cloneCopyEvent( src, dest ) {
+	var i, l, type, pdataOld, udataOld, udataCur, events;
+
+	if ( dest.nodeType !== 1 ) {
+		return;
+	}
+
+	// 1. Copy private data: events, handlers, etc.
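+	// Illustrative aside (not upstream): this is the path that makes
+	//     jQuery( elem ).clone( true )
+	// carry event handlers to the copy -- jQuery.clone() calls cloneCopyEvent()
+	// on the root pair and, with deepDataAndEvents, on every descendant pair.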
+ if ( dataPriv.hasData( src ) ) { + pdataOld = dataPriv.get( src ); + events = pdataOld.events; + + if ( events ) { + dataPriv.remove( dest, "handle events" ); + + for ( type in events ) { + for ( i = 0, l = events[ type ].length; i < l; i++ ) { + jQuery.event.add( dest, type, events[ type ][ i ] ); + } + } + } + } + + // 2. Copy user data + if ( dataUser.hasData( src ) ) { + udataOld = dataUser.access( src ); + udataCur = jQuery.extend( {}, udataOld ); + + dataUser.set( dest, udataCur ); + } +} + +// Fix IE bugs, see support tests +function fixInput( src, dest ) { + var nodeName = dest.nodeName.toLowerCase(); + + // Fails to persist the checked state of a cloned checkbox or radio button. + if ( nodeName === "input" && rcheckableType.test( src.type ) ) { + dest.checked = src.checked; + + // Fails to return the selected option to the default selected state when cloning options + } else if ( nodeName === "input" || nodeName === "textarea" ) { + dest.defaultValue = src.defaultValue; + } +} + +function domManip( collection, args, callback, ignored ) { + + // Flatten any nested arrays + args = flat( args ); + + var fragment, first, scripts, hasScripts, node, doc, + i = 0, + l = collection.length, + iNoClone = l - 1, + value = args[ 0 ], + valueIsFunction = isFunction( value ); + + // We can't cloneNode fragments that contain checked, in WebKit + if ( valueIsFunction || + ( l > 1 && typeof value === "string" && + !support.checkClone && rchecked.test( value ) ) ) { + return collection.each( function( index ) { + var self = collection.eq( index ); + if ( valueIsFunction ) { + args[ 0 ] = value.call( this, index, self.html() ); + } + domManip( self, args, callback, ignored ); + } ); + } + + if ( l ) { + fragment = buildFragment( args, collection[ 0 ].ownerDocument, false, collection, ignored ); + first = fragment.firstChild; + + if ( fragment.childNodes.length === 1 ) { + fragment = first; + } + + // Require either new content or an interest in ignored elements to invoke the callback + if ( first || ignored ) { + scripts = jQuery.map( getAll( fragment, "script" ), disableScript ); + hasScripts = scripts.length; + + // Use the original fragment for the last item + // instead of the first because it can end up + // being emptied incorrectly in certain situations (#8070). 
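+			// Illustrative example (not upstream): for a three-element target,
+			//     jQuery( ".a" ).append( node );
+			// the loop below clones the fragment for the first two matches and
+			// moves the original into the last one, per the #8070 note above.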
+ for ( ; i < l; i++ ) { + node = fragment; + + if ( i !== iNoClone ) { + node = jQuery.clone( node, true, true ); + + // Keep references to cloned scripts for later restoration + if ( hasScripts ) { + + // Support: Android <=4.0 only, PhantomJS 1 only + // push.apply(_, arraylike) throws on ancient WebKit + jQuery.merge( scripts, getAll( node, "script" ) ); + } + } + + callback.call( collection[ i ], node, i ); + } + + if ( hasScripts ) { + doc = scripts[ scripts.length - 1 ].ownerDocument; + + // Reenable scripts + jQuery.map( scripts, restoreScript ); + + // Evaluate executable scripts on first document insertion + for ( i = 0; i < hasScripts; i++ ) { + node = scripts[ i ]; + if ( rscriptType.test( node.type || "" ) && + !dataPriv.access( node, "globalEval" ) && + jQuery.contains( doc, node ) ) { + + if ( node.src && ( node.type || "" ).toLowerCase() !== "module" ) { + + // Optional AJAX dependency, but won't run scripts if not present + if ( jQuery._evalUrl && !node.noModule ) { + jQuery._evalUrl( node.src, { + nonce: node.nonce || node.getAttribute( "nonce" ) + }, doc ); + } + } else { + DOMEval( node.textContent.replace( rcleanScript, "" ), node, doc ); + } + } + } + } + } + } + + return collection; +} + +function remove( elem, selector, keepData ) { + var node, + nodes = selector ? jQuery.filter( selector, elem ) : elem, + i = 0; + + for ( ; ( node = nodes[ i ] ) != null; i++ ) { + if ( !keepData && node.nodeType === 1 ) { + jQuery.cleanData( getAll( node ) ); + } + + if ( node.parentNode ) { + if ( keepData && isAttached( node ) ) { + setGlobalEval( getAll( node, "script" ) ); + } + node.parentNode.removeChild( node ); + } + } + + return elem; +} + +jQuery.extend( { + htmlPrefilter: function( html ) { + return html; + }, + + clone: function( elem, dataAndEvents, deepDataAndEvents ) { + var i, l, srcElements, destElements, + clone = elem.cloneNode( true ), + inPage = isAttached( elem ); + + // Fix IE cloning issues + if ( !support.noCloneChecked && ( elem.nodeType === 1 || elem.nodeType === 11 ) && + !jQuery.isXMLDoc( elem ) ) { + + // We eschew Sizzle here for performance reasons: https://jsperf.com/getall-vs-sizzle/2 + destElements = getAll( clone ); + srcElements = getAll( elem ); + + for ( i = 0, l = srcElements.length; i < l; i++ ) { + fixInput( srcElements[ i ], destElements[ i ] ); + } + } + + // Copy the events from the original to the clone + if ( dataAndEvents ) { + if ( deepDataAndEvents ) { + srcElements = srcElements || getAll( elem ); + destElements = destElements || getAll( clone ); + + for ( i = 0, l = srcElements.length; i < l; i++ ) { + cloneCopyEvent( srcElements[ i ], destElements[ i ] ); + } + } else { + cloneCopyEvent( elem, clone ); + } + } + + // Preserve script evaluation history + destElements = getAll( clone, "script" ); + if ( destElements.length > 0 ) { + setGlobalEval( destElements, !inPage && getAll( elem, "script" ) ); + } + + // Return the cloned set + return clone; + }, + + cleanData: function( elems ) { + var data, elem, type, + special = jQuery.event.special, + i = 0; + + for ( ; ( elem = elems[ i ] ) !== undefined; i++ ) { + if ( acceptData( elem ) ) { + if ( ( data = elem[ dataPriv.expando ] ) ) { + if ( data.events ) { + for ( type in data.events ) { + if ( special[ type ] ) { + jQuery.event.remove( elem, type ); + + // This is a shortcut to avoid jQuery.event.remove's overhead + } else { + jQuery.removeEvent( elem, type, data.handle ); + } + } + } + + // Support: Chrome <=35 - 45+ + // Assign undefined instead of using delete, see Data#remove 
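+					// (Illustrative gloss, not upstream: using `delete` here can
+					// knock some engines off their fast property paths for DOM
+					// nodes, so the expando is cleared by assignment instead.)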
+ elem[ dataPriv.expando ] = undefined; + } + if ( elem[ dataUser.expando ] ) { + + // Support: Chrome <=35 - 45+ + // Assign undefined instead of using delete, see Data#remove + elem[ dataUser.expando ] = undefined; + } + } + } + } +} ); + +jQuery.fn.extend( { + detach: function( selector ) { + return remove( this, selector, true ); + }, + + remove: function( selector ) { + return remove( this, selector ); + }, + + text: function( value ) { + return access( this, function( value ) { + return value === undefined ? + jQuery.text( this ) : + this.empty().each( function() { + if ( this.nodeType === 1 || this.nodeType === 11 || this.nodeType === 9 ) { + this.textContent = value; + } + } ); + }, null, value, arguments.length ); + }, + + append: function() { + return domManip( this, arguments, function( elem ) { + if ( this.nodeType === 1 || this.nodeType === 11 || this.nodeType === 9 ) { + var target = manipulationTarget( this, elem ); + target.appendChild( elem ); + } + } ); + }, + + prepend: function() { + return domManip( this, arguments, function( elem ) { + if ( this.nodeType === 1 || this.nodeType === 11 || this.nodeType === 9 ) { + var target = manipulationTarget( this, elem ); + target.insertBefore( elem, target.firstChild ); + } + } ); + }, + + before: function() { + return domManip( this, arguments, function( elem ) { + if ( this.parentNode ) { + this.parentNode.insertBefore( elem, this ); + } + } ); + }, + + after: function() { + return domManip( this, arguments, function( elem ) { + if ( this.parentNode ) { + this.parentNode.insertBefore( elem, this.nextSibling ); + } + } ); + }, + + empty: function() { + var elem, + i = 0; + + for ( ; ( elem = this[ i ] ) != null; i++ ) { + if ( elem.nodeType === 1 ) { + + // Prevent memory leaks + jQuery.cleanData( getAll( elem, false ) ); + + // Remove any remaining nodes + elem.textContent = ""; + } + } + + return this; + }, + + clone: function( dataAndEvents, deepDataAndEvents ) { + dataAndEvents = dataAndEvents == null ? false : dataAndEvents; + deepDataAndEvents = deepDataAndEvents == null ? 
dataAndEvents : deepDataAndEvents; + + return this.map( function() { + return jQuery.clone( this, dataAndEvents, deepDataAndEvents ); + } ); + }, + + html: function( value ) { + return access( this, function( value ) { + var elem = this[ 0 ] || {}, + i = 0, + l = this.length; + + if ( value === undefined && elem.nodeType === 1 ) { + return elem.innerHTML; + } + + // See if we can take a shortcut and just use innerHTML + if ( typeof value === "string" && !rnoInnerhtml.test( value ) && + !wrapMap[ ( rtagName.exec( value ) || [ "", "" ] )[ 1 ].toLowerCase() ] ) { + + value = jQuery.htmlPrefilter( value ); + + try { + for ( ; i < l; i++ ) { + elem = this[ i ] || {}; + + // Remove element nodes and prevent memory leaks + if ( elem.nodeType === 1 ) { + jQuery.cleanData( getAll( elem, false ) ); + elem.innerHTML = value; + } + } + + elem = 0; + + // If using innerHTML throws an exception, use the fallback method + } catch ( e ) {} + } + + if ( elem ) { + this.empty().append( value ); + } + }, null, value, arguments.length ); + }, + + replaceWith: function() { + var ignored = []; + + // Make the changes, replacing each non-ignored context element with the new content + return domManip( this, arguments, function( elem ) { + var parent = this.parentNode; + + if ( jQuery.inArray( this, ignored ) < 0 ) { + jQuery.cleanData( getAll( this ) ); + if ( parent ) { + parent.replaceChild( elem, this ); + } + } + + // Force callback invocation + }, ignored ); + } +} ); + +jQuery.each( { + appendTo: "append", + prependTo: "prepend", + insertBefore: "before", + insertAfter: "after", + replaceAll: "replaceWith" +}, function( name, original ) { + jQuery.fn[ name ] = function( selector ) { + var elems, + ret = [], + insert = jQuery( selector ), + last = insert.length - 1, + i = 0; + + for ( ; i <= last; i++ ) { + elems = i === last ? this : this.clone( true ); + jQuery( insert[ i ] )[ original ]( elems ); + + // Support: Android <=4.0 only, PhantomJS 1 only + // .get() because push.apply(_, arraylike) throws on ancient WebKit + push.apply( ret, elems.get() ); + } + + return this.pushStack( ret ); + }; +} ); +var rnumnonpx = new RegExp( "^(" + pnum + ")(?!px)[a-z%]+$", "i" ); + +var getStyles = function( elem ) { + + // Support: IE <=11 only, Firefox <=30 (#15098, #14150) + // IE throws on elements created in popups + // FF meanwhile throws on frame elements through "defaultView.getComputedStyle" + var view = elem.ownerDocument.defaultView; + + if ( !view || !view.opener ) { + view = window; + } + + return view.getComputedStyle( elem ); + }; + +var swap = function( elem, options, callback ) { + var ret, name, + old = {}; + + // Remember the old values, and insert the new ones + for ( name in options ) { + old[ name ] = elem.style[ name ]; + elem.style[ name ] = options[ name ]; + } + + ret = callback.call( elem ); + + // Revert the old values + for ( name in options ) { + elem.style[ name ] = old[ name ]; + } + + return ret; +}; + + +var rboxStyle = new RegExp( cssExpand.join( "|" ), "i" ); + + + +( function() { + + // Executing both pixelPosition & boxSizingReliable tests require only one layout + // so they're executed at the same time to save the second computation. 
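+	// Illustrative aside (not upstream): nothing here runs at load time. The
+	// first call to one of the lazy accessors defined below, e.g.
+	//     support.pixelPosition();
+	// runs computeStyleTests() once, caches all five measurements, and nulls
+	// `div`, so later calls return cached values without forcing layout.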
+ function computeStyleTests() { + + // This is a singleton, we need to execute it only once + if ( !div ) { + return; + } + + container.style.cssText = "position:absolute;left:-11111px;width:60px;" + + "margin-top:1px;padding:0;border:0"; + div.style.cssText = + "position:relative;display:block;box-sizing:border-box;overflow:scroll;" + + "margin:auto;border:1px;padding:1px;" + + "width:60%;top:1%"; + documentElement.appendChild( container ).appendChild( div ); + + var divStyle = window.getComputedStyle( div ); + pixelPositionVal = divStyle.top !== "1%"; + + // Support: Android 4.0 - 4.3 only, Firefox <=3 - 44 + reliableMarginLeftVal = roundPixelMeasures( divStyle.marginLeft ) === 12; + + // Support: Android 4.0 - 4.3 only, Safari <=9.1 - 10.1, iOS <=7.0 - 9.3 + // Some styles come back with percentage values, even though they shouldn't + div.style.right = "60%"; + pixelBoxStylesVal = roundPixelMeasures( divStyle.right ) === 36; + + // Support: IE 9 - 11 only + // Detect misreporting of content dimensions for box-sizing:border-box elements + boxSizingReliableVal = roundPixelMeasures( divStyle.width ) === 36; + + // Support: IE 9 only + // Detect overflow:scroll screwiness (gh-3699) + // Support: Chrome <=64 + // Don't get tricked when zoom affects offsetWidth (gh-4029) + div.style.position = "absolute"; + scrollboxSizeVal = roundPixelMeasures( div.offsetWidth / 3 ) === 12; + + documentElement.removeChild( container ); + + // Nullify the div so it wouldn't be stored in the memory and + // it will also be a sign that checks already performed + div = null; + } + + function roundPixelMeasures( measure ) { + return Math.round( parseFloat( measure ) ); + } + + var pixelPositionVal, boxSizingReliableVal, scrollboxSizeVal, pixelBoxStylesVal, + reliableTrDimensionsVal, reliableMarginLeftVal, + container = document.createElement( "div" ), + div = document.createElement( "div" ); + + // Finish early in limited (non-browser) environments + if ( !div.style ) { + return; + } + + // Support: IE <=9 - 11 only + // Style of cloned element affects source element cloned (#8908) + div.style.backgroundClip = "content-box"; + div.cloneNode( true ).style.backgroundClip = ""; + support.clearCloneStyle = div.style.backgroundClip === "content-box"; + + jQuery.extend( support, { + boxSizingReliable: function() { + computeStyleTests(); + return boxSizingReliableVal; + }, + pixelBoxStyles: function() { + computeStyleTests(); + return pixelBoxStylesVal; + }, + pixelPosition: function() { + computeStyleTests(); + return pixelPositionVal; + }, + reliableMarginLeft: function() { + computeStyleTests(); + return reliableMarginLeftVal; + }, + scrollboxSize: function() { + computeStyleTests(); + return scrollboxSizeVal; + }, + + // Support: IE 9 - 11+, Edge 15 - 18+ + // IE/Edge misreport `getComputedStyle` of table rows with width/height + // set in CSS while `offset*` properties report correct values. + // Behavior in IE 9 is more subtle than in newer versions & it passes + // some versions of this test; make sure not to make it pass there! 
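+		// (Illustrative gloss, not upstream: the probe below styles a 1px row
+		// holding a 9px child; browsers that honor the child when reporting
+		// the row's computed height pass the parseInt(...) > 3 check, while
+		// affected IE/Edge report the 1px set in CSS and fail it.)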
+ reliableTrDimensions: function() { + var table, tr, trChild, trStyle; + if ( reliableTrDimensionsVal == null ) { + table = document.createElement( "table" ); + tr = document.createElement( "tr" ); + trChild = document.createElement( "div" ); + + table.style.cssText = "position:absolute;left:-11111px"; + tr.style.height = "1px"; + trChild.style.height = "9px"; + + documentElement + .appendChild( table ) + .appendChild( tr ) + .appendChild( trChild ); + + trStyle = window.getComputedStyle( tr ); + reliableTrDimensionsVal = parseInt( trStyle.height ) > 3; + + documentElement.removeChild( table ); + } + return reliableTrDimensionsVal; + } + } ); +} )(); + + +function curCSS( elem, name, computed ) { + var width, minWidth, maxWidth, ret, + + // Support: Firefox 51+ + // Retrieving style before computed somehow + // fixes an issue with getting wrong values + // on detached elements + style = elem.style; + + computed = computed || getStyles( elem ); + + // getPropertyValue is needed for: + // .css('filter') (IE 9 only, #12537) + // .css('--customProperty) (#3144) + if ( computed ) { + ret = computed.getPropertyValue( name ) || computed[ name ]; + + if ( ret === "" && !isAttached( elem ) ) { + ret = jQuery.style( elem, name ); + } + + // A tribute to the "awesome hack by Dean Edwards" + // Android Browser returns percentage for some values, + // but width seems to be reliably pixels. + // This is against the CSSOM draft spec: + // https://drafts.csswg.org/cssom/#resolved-values + if ( !support.pixelBoxStyles() && rnumnonpx.test( ret ) && rboxStyle.test( name ) ) { + + // Remember the original values + width = style.width; + minWidth = style.minWidth; + maxWidth = style.maxWidth; + + // Put in the new values to get a computed value out + style.minWidth = style.maxWidth = style.width = ret; + ret = computed.width; + + // Revert the changed values + style.width = width; + style.minWidth = minWidth; + style.maxWidth = maxWidth; + } + } + + return ret !== undefined ? + + // Support: IE <=9 - 11 only + // IE returns zIndex value as an integer. + ret + "" : + ret; +} + + +function addGetHookIf( conditionFn, hookFn ) { + + // Define the hook, we'll check on the first run if it's really needed. + return { + get: function() { + if ( conditionFn() ) { + + // Hook not needed (or it's not possible to use it due + // to missing dependency), remove it. + delete this.get; + return; + } + + // Hook needed; redefine it so that the support test is not executed again. 
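+			// (Illustrative gloss, not upstream: `this` is the hook object that
+			// lives on jQuery.cssHooks, so `this.get = hookFn` permanently swaps
+			// this one-shot wrapper for the real getter after the first call.)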
+ return ( this.get = hookFn ).apply( this, arguments ); + } + }; +} + + +var cssPrefixes = [ "Webkit", "Moz", "ms" ], + emptyStyle = document.createElement( "div" ).style, + vendorProps = {}; + +// Return a vendor-prefixed property or undefined +function vendorPropName( name ) { + + // Check for vendor prefixed names + var capName = name[ 0 ].toUpperCase() + name.slice( 1 ), + i = cssPrefixes.length; + + while ( i-- ) { + name = cssPrefixes[ i ] + capName; + if ( name in emptyStyle ) { + return name; + } + } +} + +// Return a potentially-mapped jQuery.cssProps or vendor prefixed property +function finalPropName( name ) { + var final = jQuery.cssProps[ name ] || vendorProps[ name ]; + + if ( final ) { + return final; + } + if ( name in emptyStyle ) { + return name; + } + return vendorProps[ name ] = vendorPropName( name ) || name; +} + + +var + + // Swappable if display is none or starts with table + // except "table", "table-cell", or "table-caption" + // See here for display values: https://developer.mozilla.org/en-US/docs/CSS/display + rdisplayswap = /^(none|table(?!-c[ea]).+)/, + rcustomProp = /^--/, + cssShow = { position: "absolute", visibility: "hidden", display: "block" }, + cssNormalTransform = { + letterSpacing: "0", + fontWeight: "400" + }; + +function setPositiveNumber( _elem, value, subtract ) { + + // Any relative (+/-) values have already been + // normalized at this point + var matches = rcssNum.exec( value ); + return matches ? + + // Guard against undefined "subtract", e.g., when used as in cssHooks + Math.max( 0, matches[ 2 ] - ( subtract || 0 ) ) + ( matches[ 3 ] || "px" ) : + value; +} + +function boxModelAdjustment( elem, dimension, box, isBorderBox, styles, computedVal ) { + var i = dimension === "width" ? 1 : 0, + extra = 0, + delta = 0; + + // Adjustment may not be necessary + if ( box === ( isBorderBox ? 
"border" : "content" ) ) { + return 0; + } + + for ( ; i < 4; i += 2 ) { + + // Both box models exclude margin + if ( box === "margin" ) { + delta += jQuery.css( elem, box + cssExpand[ i ], true, styles ); + } + + // If we get here with a content-box, we're seeking "padding" or "border" or "margin" + if ( !isBorderBox ) { + + // Add padding + delta += jQuery.css( elem, "padding" + cssExpand[ i ], true, styles ); + + // For "border" or "margin", add border + if ( box !== "padding" ) { + delta += jQuery.css( elem, "border" + cssExpand[ i ] + "Width", true, styles ); + + // But still keep track of it otherwise + } else { + extra += jQuery.css( elem, "border" + cssExpand[ i ] + "Width", true, styles ); + } + + // If we get here with a border-box (content + padding + border), we're seeking "content" or + // "padding" or "margin" + } else { + + // For "content", subtract padding + if ( box === "content" ) { + delta -= jQuery.css( elem, "padding" + cssExpand[ i ], true, styles ); + } + + // For "content" or "padding", subtract border + if ( box !== "margin" ) { + delta -= jQuery.css( elem, "border" + cssExpand[ i ] + "Width", true, styles ); + } + } + } + + // Account for positive content-box scroll gutter when requested by providing computedVal + if ( !isBorderBox && computedVal >= 0 ) { + + // offsetWidth/offsetHeight is a rounded sum of content, padding, scroll gutter, and border + // Assuming integer scroll gutter, subtract the rest and round down + delta += Math.max( 0, Math.ceil( + elem[ "offset" + dimension[ 0 ].toUpperCase() + dimension.slice( 1 ) ] - + computedVal - + delta - + extra - + 0.5 + + // If offsetWidth/offsetHeight is unknown, then we can't determine content-box scroll gutter + // Use an explicit zero to avoid NaN (gh-3964) + ) ) || 0; + } + + return delta; +} + +function getWidthOrHeight( elem, dimension, extra ) { + + // Start with computed style + var styles = getStyles( elem ), + + // To avoid forcing a reflow, only fetch boxSizing if we need it (gh-4322). + // Fake content-box until we know it's needed to know the true value. + boxSizingNeeded = !support.boxSizingReliable() || extra, + isBorderBox = boxSizingNeeded && + jQuery.css( elem, "boxSizing", false, styles ) === "border-box", + valueIsBorderBox = isBorderBox, + + val = curCSS( elem, dimension, styles ), + offsetProp = "offset" + dimension[ 0 ].toUpperCase() + dimension.slice( 1 ); + + // Support: Firefox <=54 + // Return a confounding non-pixel value or feign ignorance, as appropriate. + if ( rnumnonpx.test( val ) ) { + if ( !extra ) { + return val; + } + val = "auto"; + } + + + // Support: IE 9 - 11 only + // Use offsetWidth/offsetHeight for when box sizing is unreliable. + // In those cases, the computed value can be trusted to be border-box. + if ( ( !support.boxSizingReliable() && isBorderBox || + + // Support: IE 10 - 11+, Edge 15 - 18+ + // IE/Edge misreport `getComputedStyle` of table rows with width/height + // set in CSS while `offset*` properties report correct values. + // Interestingly, in some cases IE 9 doesn't suffer from this issue. 
+ !support.reliableTrDimensions() && nodeName( elem, "tr" ) || + + // Fall back to offsetWidth/offsetHeight when value is "auto" + // This happens for inline elements with no explicit setting (gh-3571) + val === "auto" || + + // Support: Android <=4.1 - 4.3 only + // Also use offsetWidth/offsetHeight for misreported inline dimensions (gh-3602) + !parseFloat( val ) && jQuery.css( elem, "display", false, styles ) === "inline" ) && + + // Make sure the element is visible & connected + elem.getClientRects().length ) { + + isBorderBox = jQuery.css( elem, "boxSizing", false, styles ) === "border-box"; + + // Where available, offsetWidth/offsetHeight approximate border box dimensions. + // Where not available (e.g., SVG), assume unreliable box-sizing and interpret the + // retrieved value as a content box dimension. + valueIsBorderBox = offsetProp in elem; + if ( valueIsBorderBox ) { + val = elem[ offsetProp ]; + } + } + + // Normalize "" and auto + val = parseFloat( val ) || 0; + + // Adjust for the element's box model + return ( val + + boxModelAdjustment( + elem, + dimension, + extra || ( isBorderBox ? "border" : "content" ), + valueIsBorderBox, + styles, + + // Provide the current computed size to request scroll gutter calculation (gh-3589) + val + ) + ) + "px"; +} + +jQuery.extend( { + + // Add in style property hooks for overriding the default + // behavior of getting and setting a style property + cssHooks: { + opacity: { + get: function( elem, computed ) { + if ( computed ) { + + // We should always get a number back from opacity + var ret = curCSS( elem, "opacity" ); + return ret === "" ? "1" : ret; + } + } + } + }, + + // Don't automatically add "px" to these possibly-unitless properties + cssNumber: { + "animationIterationCount": true, + "columnCount": true, + "fillOpacity": true, + "flexGrow": true, + "flexShrink": true, + "fontWeight": true, + "gridArea": true, + "gridColumn": true, + "gridColumnEnd": true, + "gridColumnStart": true, + "gridRow": true, + "gridRowEnd": true, + "gridRowStart": true, + "lineHeight": true, + "opacity": true, + "order": true, + "orphans": true, + "widows": true, + "zIndex": true, + "zoom": true + }, + + // Add in properties whose names you wish to fix before + // setting or getting the value + cssProps: {}, + + // Get and set the style property on a DOM Node + style: function( elem, name, value, extra ) { + + // Don't set styles on text and comment nodes + if ( !elem || elem.nodeType === 3 || elem.nodeType === 8 || !elem.style ) { + return; + } + + // Make sure that we're working with the right name + var ret, type, hooks, + origName = camelCase( name ), + isCustomProp = rcustomProp.test( name ), + style = elem.style; + + // Make sure that we're working with the right name. We don't + // want to query the value if it is a CSS custom property + // since they are user-defined. 
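+		// Illustrative examples (not upstream):
+		//     jQuery( el ).css( "--main-color", "red" ) // custom prop: name used verbatim
+		//     jQuery( el ).css( "margin-top", 4 )       // camelCased to "marginTop",
+		//                                               // then vendor-mapped if needed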
+ if ( !isCustomProp ) { + name = finalPropName( origName ); + } + + // Gets hook for the prefixed version, then unprefixed version + hooks = jQuery.cssHooks[ name ] || jQuery.cssHooks[ origName ]; + + // Check if we're setting a value + if ( value !== undefined ) { + type = typeof value; + + // Convert "+=" or "-=" to relative numbers (#7345) + if ( type === "string" && ( ret = rcssNum.exec( value ) ) && ret[ 1 ] ) { + value = adjustCSS( elem, name, ret ); + + // Fixes bug #9237 + type = "number"; + } + + // Make sure that null and NaN values aren't set (#7116) + if ( value == null || value !== value ) { + return; + } + + // If a number was passed in, add the unit (except for certain CSS properties) + // The isCustomProp check can be removed in jQuery 4.0 when we only auto-append + // "px" to a few hardcoded values. + if ( type === "number" && !isCustomProp ) { + value += ret && ret[ 3 ] || ( jQuery.cssNumber[ origName ] ? "" : "px" ); + } + + // background-* props affect original clone's values + if ( !support.clearCloneStyle && value === "" && name.indexOf( "background" ) === 0 ) { + style[ name ] = "inherit"; + } + + // If a hook was provided, use that value, otherwise just set the specified value + if ( !hooks || !( "set" in hooks ) || + ( value = hooks.set( elem, value, extra ) ) !== undefined ) { + + if ( isCustomProp ) { + style.setProperty( name, value ); + } else { + style[ name ] = value; + } + } + + } else { + + // If a hook was provided get the non-computed value from there + if ( hooks && "get" in hooks && + ( ret = hooks.get( elem, false, extra ) ) !== undefined ) { + + return ret; + } + + // Otherwise just get the value from the style object + return style[ name ]; + } + }, + + css: function( elem, name, extra, styles ) { + var val, num, hooks, + origName = camelCase( name ), + isCustomProp = rcustomProp.test( name ); + + // Make sure that we're working with the right name. We don't + // want to modify the value if it is a CSS custom property + // since they are user-defined. + if ( !isCustomProp ) { + name = finalPropName( origName ); + } + + // Try prefixed name followed by the unprefixed name + hooks = jQuery.cssHooks[ name ] || jQuery.cssHooks[ origName ]; + + // If a hook was provided get the computed value from there + if ( hooks && "get" in hooks ) { + val = hooks.get( elem, true, extra ); + } + + // Otherwise, if a way to get the computed value exists, use that + if ( val === undefined ) { + val = curCSS( elem, name, styles ); + } + + // Convert "normal" to computed value + if ( val === "normal" && name in cssNormalTransform ) { + val = cssNormalTransform[ name ]; + } + + // Make numeric if forced or a qualifier was provided and val looks numeric + if ( extra === "" || extra ) { + num = parseFloat( val ); + return extra === true || isFinite( num ) ? num || 0 : val; + } + + return val; + } +} ); + +jQuery.each( [ "height", "width" ], function( _i, dimension ) { + jQuery.cssHooks[ dimension ] = { + get: function( elem, computed, extra ) { + if ( computed ) { + + // Certain elements can have dimension info if we invisibly show them + // but it must have a current display style that would benefit + return rdisplayswap.test( jQuery.css( elem, "display" ) ) && + + // Support: Safari 8+ + // Table columns in Safari have non-zero offsetWidth & zero + // getBoundingClientRect().width unless display is changed. + // Support: IE <=11 only + // Running getBoundingClientRect on a disconnected node + // in IE throws an error. 
+ ( !elem.getClientRects().length || !elem.getBoundingClientRect().width ) ? + swap( elem, cssShow, function() { + return getWidthOrHeight( elem, dimension, extra ); + } ) : + getWidthOrHeight( elem, dimension, extra ); + } + }, + + set: function( elem, value, extra ) { + var matches, + styles = getStyles( elem ), + + // Only read styles.position if the test has a chance to fail + // to avoid forcing a reflow. + scrollboxSizeBuggy = !support.scrollboxSize() && + styles.position === "absolute", + + // To avoid forcing a reflow, only fetch boxSizing if we need it (gh-3991) + boxSizingNeeded = scrollboxSizeBuggy || extra, + isBorderBox = boxSizingNeeded && + jQuery.css( elem, "boxSizing", false, styles ) === "border-box", + subtract = extra ? + boxModelAdjustment( + elem, + dimension, + extra, + isBorderBox, + styles + ) : + 0; + + // Account for unreliable border-box dimensions by comparing offset* to computed and + // faking a content-box to get border and padding (gh-3699) + if ( isBorderBox && scrollboxSizeBuggy ) { + subtract -= Math.ceil( + elem[ "offset" + dimension[ 0 ].toUpperCase() + dimension.slice( 1 ) ] - + parseFloat( styles[ dimension ] ) - + boxModelAdjustment( elem, dimension, "border", false, styles ) - + 0.5 + ); + } + + // Convert to pixels if value adjustment is needed + if ( subtract && ( matches = rcssNum.exec( value ) ) && + ( matches[ 3 ] || "px" ) !== "px" ) { + + elem.style[ dimension ] = value; + value = jQuery.css( elem, dimension ); + } + + return setPositiveNumber( elem, value, subtract ); + } + }; +} ); + +jQuery.cssHooks.marginLeft = addGetHookIf( support.reliableMarginLeft, + function( elem, computed ) { + if ( computed ) { + return ( parseFloat( curCSS( elem, "marginLeft" ) ) || + elem.getBoundingClientRect().left - + swap( elem, { marginLeft: 0 }, function() { + return elem.getBoundingClientRect().left; + } ) + ) + "px"; + } + } +); + +// These hooks are used by animate to expand properties +jQuery.each( { + margin: "", + padding: "", + border: "Width" +}, function( prefix, suffix ) { + jQuery.cssHooks[ prefix + suffix ] = { + expand: function( value ) { + var i = 0, + expanded = {}, + + // Assumes a single number if not a string + parts = typeof value === "string" ? value.split( " " ) : [ value ]; + + for ( ; i < 4; i++ ) { + expanded[ prefix + cssExpand[ i ] + suffix ] = + parts[ i ] || parts[ i - 2 ] || parts[ 0 ]; + } + + return expanded; + } + }; + + if ( prefix !== "margin" ) { + jQuery.cssHooks[ prefix + suffix ].set = setPositiveNumber; + } +} ); + +jQuery.fn.extend( { + css: function( name, value ) { + return access( this, function( elem, name, value ) { + var styles, len, + map = {}, + i = 0; + + if ( Array.isArray( name ) ) { + styles = getStyles( elem ); + len = name.length; + + for ( ; i < len; i++ ) { + map[ name[ i ] ] = jQuery.css( elem, name[ i ], false, styles ); + } + + return map; + } + + return value !== undefined ? + jQuery.style( elem, name, value ) : + jQuery.css( elem, name ); + }, name, value, arguments.length > 1 ); + } +} ); + + +function Tween( elem, options, prop, end, easing ) { + return new Tween.prototype.init( elem, options, prop, end, easing ); +} +jQuery.Tween = Tween; + +Tween.prototype = { + constructor: Tween, + init: function( elem, options, prop, end, easing, unit ) { + this.elem = elem; + this.prop = prop; + this.easing = easing || jQuery.easing._default; + this.options = options; + this.start = this.now = this.cur(); + this.end = end; + this.unit = unit || ( jQuery.cssNumber[ prop ] ? 
"" : "px" ); + }, + cur: function() { + var hooks = Tween.propHooks[ this.prop ]; + + return hooks && hooks.get ? + hooks.get( this ) : + Tween.propHooks._default.get( this ); + }, + run: function( percent ) { + var eased, + hooks = Tween.propHooks[ this.prop ]; + + if ( this.options.duration ) { + this.pos = eased = jQuery.easing[ this.easing ]( + percent, this.options.duration * percent, 0, 1, this.options.duration + ); + } else { + this.pos = eased = percent; + } + this.now = ( this.end - this.start ) * eased + this.start; + + if ( this.options.step ) { + this.options.step.call( this.elem, this.now, this ); + } + + if ( hooks && hooks.set ) { + hooks.set( this ); + } else { + Tween.propHooks._default.set( this ); + } + return this; + } +}; + +Tween.prototype.init.prototype = Tween.prototype; + +Tween.propHooks = { + _default: { + get: function( tween ) { + var result; + + // Use a property on the element directly when it is not a DOM element, + // or when there is no matching style property that exists. + if ( tween.elem.nodeType !== 1 || + tween.elem[ tween.prop ] != null && tween.elem.style[ tween.prop ] == null ) { + return tween.elem[ tween.prop ]; + } + + // Passing an empty string as a 3rd parameter to .css will automatically + // attempt a parseFloat and fallback to a string if the parse fails. + // Simple values such as "10px" are parsed to Float; + // complex values such as "rotate(1rad)" are returned as-is. + result = jQuery.css( tween.elem, tween.prop, "" ); + + // Empty strings, null, undefined and "auto" are converted to 0. + return !result || result === "auto" ? 0 : result; + }, + set: function( tween ) { + + // Use step hook for back compat. + // Use cssHook if its there. + // Use .style if available and use plain properties where available. + if ( jQuery.fx.step[ tween.prop ] ) { + jQuery.fx.step[ tween.prop ]( tween ); + } else if ( tween.elem.nodeType === 1 && ( + jQuery.cssHooks[ tween.prop ] || + tween.elem.style[ finalPropName( tween.prop ) ] != null ) ) { + jQuery.style( tween.elem, tween.prop, tween.now + tween.unit ); + } else { + tween.elem[ tween.prop ] = tween.now; + } + } + } +}; + +// Support: IE <=9 only +// Panic based approach to setting things on disconnected nodes +Tween.propHooks.scrollTop = Tween.propHooks.scrollLeft = { + set: function( tween ) { + if ( tween.elem.nodeType && tween.elem.parentNode ) { + tween.elem[ tween.prop ] = tween.now; + } + } +}; + +jQuery.easing = { + linear: function( p ) { + return p; + }, + swing: function( p ) { + return 0.5 - Math.cos( p * Math.PI ) / 2; + }, + _default: "swing" +}; + +jQuery.fx = Tween.prototype.init; + +// Back compat <1.8 extension point +jQuery.fx.step = {}; + + + + +var + fxNow, inProgress, + rfxtypes = /^(?:toggle|show|hide)$/, + rrun = /queueHooks$/; + +function schedule() { + if ( inProgress ) { + if ( document.hidden === false && window.requestAnimationFrame ) { + window.requestAnimationFrame( schedule ); + } else { + window.setTimeout( schedule, jQuery.fx.interval ); + } + + jQuery.fx.tick(); + } +} + +// Animations created synchronously will run synchronously +function createFxNow() { + window.setTimeout( function() { + fxNow = undefined; + } ); + return ( fxNow = Date.now() ); +} + +// Generate parameters to create a standard animation +function genFx( type, includeWidth ) { + var which, + i = 0, + attrs = { height: type }; + + // If we include width, step value is 1 to do all cssExpand values, + // otherwise step value is 2 to skip over Left and Right + includeWidth = includeWidth ? 
1 : 0; + for ( ; i < 4; i += 2 - includeWidth ) { + which = cssExpand[ i ]; + attrs[ "margin" + which ] = attrs[ "padding" + which ] = type; + } + + if ( includeWidth ) { + attrs.opacity = attrs.width = type; + } + + return attrs; +} + +function createTween( value, prop, animation ) { + var tween, + collection = ( Animation.tweeners[ prop ] || [] ).concat( Animation.tweeners[ "*" ] ), + index = 0, + length = collection.length; + for ( ; index < length; index++ ) { + if ( ( tween = collection[ index ].call( animation, prop, value ) ) ) { + + // We're done with this property + return tween; + } + } +} + +function defaultPrefilter( elem, props, opts ) { + var prop, value, toggle, hooks, oldfire, propTween, restoreDisplay, display, + isBox = "width" in props || "height" in props, + anim = this, + orig = {}, + style = elem.style, + hidden = elem.nodeType && isHiddenWithinTree( elem ), + dataShow = dataPriv.get( elem, "fxshow" ); + + // Queue-skipping animations hijack the fx hooks + if ( !opts.queue ) { + hooks = jQuery._queueHooks( elem, "fx" ); + if ( hooks.unqueued == null ) { + hooks.unqueued = 0; + oldfire = hooks.empty.fire; + hooks.empty.fire = function() { + if ( !hooks.unqueued ) { + oldfire(); + } + }; + } + hooks.unqueued++; + + anim.always( function() { + + // Ensure the complete handler is called before this completes + anim.always( function() { + hooks.unqueued--; + if ( !jQuery.queue( elem, "fx" ).length ) { + hooks.empty.fire(); + } + } ); + } ); + } + + // Detect show/hide animations + for ( prop in props ) { + value = props[ prop ]; + if ( rfxtypes.test( value ) ) { + delete props[ prop ]; + toggle = toggle || value === "toggle"; + if ( value === ( hidden ? "hide" : "show" ) ) { + + // Pretend to be hidden if this is a "show" and + // there is still data from a stopped show/hide + if ( value === "show" && dataShow && dataShow[ prop ] !== undefined ) { + hidden = true; + + // Ignore all other no-op show/hide data + } else { + continue; + } + } + orig[ prop ] = dataShow && dataShow[ prop ] || jQuery.style( elem, prop ); + } + } + + // Bail out if this is a no-op like .hide().hide() + propTween = !jQuery.isEmptyObject( props ); + if ( !propTween && jQuery.isEmptyObject( orig ) ) { + return; + } + + // Restrict "overflow" and "display" styles during box animations + if ( isBox && elem.nodeType === 1 ) { + + // Support: IE <=9 - 11, Edge 12 - 15 + // Record all 3 overflow attributes because IE does not infer the shorthand + // from identically-valued overflowX and overflowY and Edge just mirrors + // the overflowX value there. 
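+
+ // Illustrative note (not from the source): genFx() above is what the slide
+ // helpers defined further down are built on; for example
+ //
+ //     genFx( "hide" )
+ //     // => { height: "hide", marginTop: "hide", marginBottom: "hide",
+ //     //      paddingTop: "hide", paddingBottom: "hide" }
+ //
+ // so jQuery.fn.slideUp is effectively .animate() over that property map,
+ // with defaultPrefilter() handling the show/hide bookkeeping below.
+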
+ opts.overflow = [ style.overflow, style.overflowX, style.overflowY ]; + + // Identify a display type, preferring old show/hide data over the CSS cascade + restoreDisplay = dataShow && dataShow.display; + if ( restoreDisplay == null ) { + restoreDisplay = dataPriv.get( elem, "display" ); + } + display = jQuery.css( elem, "display" ); + if ( display === "none" ) { + if ( restoreDisplay ) { + display = restoreDisplay; + } else { + + // Get nonempty value(s) by temporarily forcing visibility + showHide( [ elem ], true ); + restoreDisplay = elem.style.display || restoreDisplay; + display = jQuery.css( elem, "display" ); + showHide( [ elem ] ); + } + } + + // Animate inline elements as inline-block + if ( display === "inline" || display === "inline-block" && restoreDisplay != null ) { + if ( jQuery.css( elem, "float" ) === "none" ) { + + // Restore the original display value at the end of pure show/hide animations + if ( !propTween ) { + anim.done( function() { + style.display = restoreDisplay; + } ); + if ( restoreDisplay == null ) { + display = style.display; + restoreDisplay = display === "none" ? "" : display; + } + } + style.display = "inline-block"; + } + } + } + + if ( opts.overflow ) { + style.overflow = "hidden"; + anim.always( function() { + style.overflow = opts.overflow[ 0 ]; + style.overflowX = opts.overflow[ 1 ]; + style.overflowY = opts.overflow[ 2 ]; + } ); + } + + // Implement show/hide animations + propTween = false; + for ( prop in orig ) { + + // General show/hide setup for this element animation + if ( !propTween ) { + if ( dataShow ) { + if ( "hidden" in dataShow ) { + hidden = dataShow.hidden; + } + } else { + dataShow = dataPriv.access( elem, "fxshow", { display: restoreDisplay } ); + } + + // Store hidden/visible for toggle so `.stop().toggle()` "reverses" + if ( toggle ) { + dataShow.hidden = !hidden; + } + + // Show elements before animating them + if ( hidden ) { + showHide( [ elem ], true ); + } + + /* eslint-disable no-loop-func */ + + anim.done( function() { + + /* eslint-enable no-loop-func */ + + // The final step of a "hide" animation is actually hiding the element + if ( !hidden ) { + showHide( [ elem ] ); + } + dataPriv.remove( elem, "fxshow" ); + for ( prop in orig ) { + jQuery.style( elem, prop, orig[ prop ] ); + } + } ); + } + + // Per-property setup + propTween = createTween( hidden ? dataShow[ prop ] : 0, prop, anim ); + if ( !( prop in dataShow ) ) { + dataShow[ prop ] = propTween.start; + if ( hidden ) { + propTween.end = propTween.start; + propTween.start = 0; + } + } + } +} + +function propFilter( props, specialEasing ) { + var index, name, easing, value, hooks; + + // camelCase, specialEasing and expand cssHook pass + for ( index in props ) { + name = camelCase( index ); + easing = specialEasing[ name ]; + value = props[ index ]; + if ( Array.isArray( value ) ) { + easing = value[ 1 ]; + value = props[ index ] = value[ 0 ]; + } + + if ( index !== name ) { + props[ name ] = value; + delete props[ index ]; + } + + hooks = jQuery.cssHooks[ name ]; + if ( hooks && "expand" in hooks ) { + value = hooks.expand( value ); + delete props[ name ]; + + // Not quite $.extend, this won't overwrite existing keys. 
+ // Reusing 'index' because we have the correct "name" + for ( index in value ) { + if ( !( index in props ) ) { + props[ index ] = value[ index ]; + specialEasing[ index ] = easing; + } + } + } else { + specialEasing[ name ] = easing; + } + } +} + +function Animation( elem, properties, options ) { + var result, + stopped, + index = 0, + length = Animation.prefilters.length, + deferred = jQuery.Deferred().always( function() { + + // Don't match elem in the :animated selector + delete tick.elem; + } ), + tick = function() { + if ( stopped ) { + return false; + } + var currentTime = fxNow || createFxNow(), + remaining = Math.max( 0, animation.startTime + animation.duration - currentTime ), + + // Support: Android 2.3 only + // Archaic crash bug won't allow us to use `1 - ( 0.5 || 0 )` (#12497) + temp = remaining / animation.duration || 0, + percent = 1 - temp, + index = 0, + length = animation.tweens.length; + + for ( ; index < length; index++ ) { + animation.tweens[ index ].run( percent ); + } + + deferred.notifyWith( elem, [ animation, percent, remaining ] ); + + // If there's more to do, yield + if ( percent < 1 && length ) { + return remaining; + } + + // If this was an empty animation, synthesize a final progress notification + if ( !length ) { + deferred.notifyWith( elem, [ animation, 1, 0 ] ); + } + + // Resolve the animation and report its conclusion + deferred.resolveWith( elem, [ animation ] ); + return false; + }, + animation = deferred.promise( { + elem: elem, + props: jQuery.extend( {}, properties ), + opts: jQuery.extend( true, { + specialEasing: {}, + easing: jQuery.easing._default + }, options ), + originalProperties: properties, + originalOptions: options, + startTime: fxNow || createFxNow(), + duration: options.duration, + tweens: [], + createTween: function( prop, end ) { + var tween = jQuery.Tween( elem, animation.opts, prop, end, + animation.opts.specialEasing[ prop ] || animation.opts.easing ); + animation.tweens.push( tween ); + return tween; + }, + stop: function( gotoEnd ) { + var index = 0, + + // If we are going to the end, we want to run all the tweens + // otherwise we skip this part + length = gotoEnd ? 
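+
+ // Illustrative usage (not from the source): the Array.isArray branch in
+ // propFilter() is why .animate() accepts an [ endValue, easingName ] pair
+ // per property:
+ //
+ //     jQuery( "#box" ).animate( {
+ //         width: [ 200, "linear" ],
+ //         opacity: [ 1, "swing" ]
+ //     }, 400 );
+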
animation.tweens.length : 0; + if ( stopped ) { + return this; + } + stopped = true; + for ( ; index < length; index++ ) { + animation.tweens[ index ].run( 1 ); + } + + // Resolve when we played the last frame; otherwise, reject + if ( gotoEnd ) { + deferred.notifyWith( elem, [ animation, 1, 0 ] ); + deferred.resolveWith( elem, [ animation, gotoEnd ] ); + } else { + deferred.rejectWith( elem, [ animation, gotoEnd ] ); + } + return this; + } + } ), + props = animation.props; + + propFilter( props, animation.opts.specialEasing ); + + for ( ; index < length; index++ ) { + result = Animation.prefilters[ index ].call( animation, elem, props, animation.opts ); + if ( result ) { + if ( isFunction( result.stop ) ) { + jQuery._queueHooks( animation.elem, animation.opts.queue ).stop = + result.stop.bind( result ); + } + return result; + } + } + + jQuery.map( props, createTween, animation ); + + if ( isFunction( animation.opts.start ) ) { + animation.opts.start.call( elem, animation ); + } + + // Attach callbacks from options + animation + .progress( animation.opts.progress ) + .done( animation.opts.done, animation.opts.complete ) + .fail( animation.opts.fail ) + .always( animation.opts.always ); + + jQuery.fx.timer( + jQuery.extend( tick, { + elem: elem, + anim: animation, + queue: animation.opts.queue + } ) + ); + + return animation; +} + +jQuery.Animation = jQuery.extend( Animation, { + + tweeners: { + "*": [ function( prop, value ) { + var tween = this.createTween( prop, value ); + adjustCSS( tween.elem, prop, rcssNum.exec( value ), tween ); + return tween; + } ] + }, + + tweener: function( props, callback ) { + if ( isFunction( props ) ) { + callback = props; + props = [ "*" ]; + } else { + props = props.match( rnothtmlwhite ); + } + + var prop, + index = 0, + length = props.length; + + for ( ; index < length; index++ ) { + prop = props[ index ]; + Animation.tweeners[ prop ] = Animation.tweeners[ prop ] || []; + Animation.tweeners[ prop ].unshift( callback ); + } + }, + + prefilters: [ defaultPrefilter ], + + prefilter: function( callback, prepend ) { + if ( prepend ) { + Animation.prefilters.unshift( callback ); + } else { + Animation.prefilters.push( callback ); + } + } +} ); + +jQuery.speed = function( speed, easing, fn ) { + var opt = speed && typeof speed === "object" ? 
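+
+ // Illustrative usage (not from the source): the deferred wiring above is
+ // what surfaces per-tick notifications to callers of .animate():
+ //
+ //     jQuery( "#box" ).animate( { left: 100 }, {
+ //         duration: 400,
+ //         progress: function( animation, percent, remainingMs ) {
+ //             console.log( Math.round( percent * 100 ) + "%" );
+ //         }
+ //     } );
+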
jQuery.extend( {}, speed ) : { + complete: fn || !fn && easing || + isFunction( speed ) && speed, + duration: speed, + easing: fn && easing || easing && !isFunction( easing ) && easing + }; + + // Go to the end state if fx are off + if ( jQuery.fx.off ) { + opt.duration = 0; + + } else { + if ( typeof opt.duration !== "number" ) { + if ( opt.duration in jQuery.fx.speeds ) { + opt.duration = jQuery.fx.speeds[ opt.duration ]; + + } else { + opt.duration = jQuery.fx.speeds._default; + } + } + } + + // Normalize opt.queue - true/undefined/null -> "fx" + if ( opt.queue == null || opt.queue === true ) { + opt.queue = "fx"; + } + + // Queueing + opt.old = opt.complete; + + opt.complete = function() { + if ( isFunction( opt.old ) ) { + opt.old.call( this ); + } + + if ( opt.queue ) { + jQuery.dequeue( this, opt.queue ); + } + }; + + return opt; +}; + +jQuery.fn.extend( { + fadeTo: function( speed, to, easing, callback ) { + + // Show any hidden elements after setting opacity to 0 + return this.filter( isHiddenWithinTree ).css( "opacity", 0 ).show() + + // Animate to the value specified + .end().animate( { opacity: to }, speed, easing, callback ); + }, + animate: function( prop, speed, easing, callback ) { + var empty = jQuery.isEmptyObject( prop ), + optall = jQuery.speed( speed, easing, callback ), + doAnimation = function() { + + // Operate on a copy of prop so per-property easing won't be lost + var anim = Animation( this, jQuery.extend( {}, prop ), optall ); + + // Empty animations, or finishing resolves immediately + if ( empty || dataPriv.get( this, "finish" ) ) { + anim.stop( true ); + } + }; + doAnimation.finish = doAnimation; + + return empty || optall.queue === false ? + this.each( doAnimation ) : + this.queue( optall.queue, doAnimation ); + }, + stop: function( type, clearQueue, gotoEnd ) { + var stopQueue = function( hooks ) { + var stop = hooks.stop; + delete hooks.stop; + stop( gotoEnd ); + }; + + if ( typeof type !== "string" ) { + gotoEnd = clearQueue; + clearQueue = type; + type = undefined; + } + if ( clearQueue ) { + this.queue( type || "fx", [] ); + } + + return this.each( function() { + var dequeue = true, + index = type != null && type + "queueHooks", + timers = jQuery.timers, + data = dataPriv.get( this ); + + if ( index ) { + if ( data[ index ] && data[ index ].stop ) { + stopQueue( data[ index ] ); + } + } else { + for ( index in data ) { + if ( data[ index ] && data[ index ].stop && rrun.test( index ) ) { + stopQueue( data[ index ] ); + } + } + } + + for ( index = timers.length; index--; ) { + if ( timers[ index ].elem === this && + ( type == null || timers[ index ].queue === type ) ) { + + timers[ index ].anim.stop( gotoEnd ); + dequeue = false; + timers.splice( index, 1 ); + } + } + + // Start the next in the queue if the last step wasn't forced. + // Timers currently will call their complete callbacks, which + // will dequeue but only if they were gotoEnd. + if ( dequeue || !gotoEnd ) { + jQuery.dequeue( this, type ); + } + } ); + }, + finish: function( type ) { + if ( type !== false ) { + type = type || "fx"; + } + return this.each( function() { + var index, + data = dataPriv.get( this ), + queue = data[ type + "queue" ], + hooks = data[ type + "queueHooks" ], + timers = jQuery.timers, + length = queue ? 
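+
+ // Illustrative note (not from the source): jQuery.speed() normalizes every
+ // call form to one options object, which is why these are equivalent knobs:
+ //
+ //     jQuery( "#box" ).fadeTo( "slow", 0.5 );         // "slow" => jQuery.fx.speeds.slow
+ //     jQuery( "#box" ).animate( { left: 0 }, 9999 );  // numeric durations pass through
+ //     jQuery.fx.off = true;                           // collapses every duration to 0
+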
queue.length : 0; + + // Enable finishing flag on private data + data.finish = true; + + // Empty the queue first + jQuery.queue( this, type, [] ); + + if ( hooks && hooks.stop ) { + hooks.stop.call( this, true ); + } + + // Look for any active animations, and finish them + for ( index = timers.length; index--; ) { + if ( timers[ index ].elem === this && timers[ index ].queue === type ) { + timers[ index ].anim.stop( true ); + timers.splice( index, 1 ); + } + } + + // Look for any animations in the old queue and finish them + for ( index = 0; index < length; index++ ) { + if ( queue[ index ] && queue[ index ].finish ) { + queue[ index ].finish.call( this ); + } + } + + // Turn off finishing flag + delete data.finish; + } ); + } +} ); + +jQuery.each( [ "toggle", "show", "hide" ], function( _i, name ) { + var cssFn = jQuery.fn[ name ]; + jQuery.fn[ name ] = function( speed, easing, callback ) { + return speed == null || typeof speed === "boolean" ? + cssFn.apply( this, arguments ) : + this.animate( genFx( name, true ), speed, easing, callback ); + }; +} ); + +// Generate shortcuts for custom animations +jQuery.each( { + slideDown: genFx( "show" ), + slideUp: genFx( "hide" ), + slideToggle: genFx( "toggle" ), + fadeIn: { opacity: "show" }, + fadeOut: { opacity: "hide" }, + fadeToggle: { opacity: "toggle" } +}, function( name, props ) { + jQuery.fn[ name ] = function( speed, easing, callback ) { + return this.animate( props, speed, easing, callback ); + }; +} ); + +jQuery.timers = []; +jQuery.fx.tick = function() { + var timer, + i = 0, + timers = jQuery.timers; + + fxNow = Date.now(); + + for ( ; i < timers.length; i++ ) { + timer = timers[ i ]; + + // Run the timer and safely remove it when done (allowing for external removal) + if ( !timer() && timers[ i ] === timer ) { + timers.splice( i--, 1 ); + } + } + + if ( !timers.length ) { + jQuery.fx.stop(); + } + fxNow = undefined; +}; + +jQuery.fx.timer = function( timer ) { + jQuery.timers.push( timer ); + jQuery.fx.start(); +}; + +jQuery.fx.interval = 13; +jQuery.fx.start = function() { + if ( inProgress ) { + return; + } + + inProgress = true; + schedule(); +}; + +jQuery.fx.stop = function() { + inProgress = null; +}; + +jQuery.fx.speeds = { + slow: 600, + fast: 200, + + // Default speed + _default: 400 +}; + + +// Based off of the plugin by Clint Helfers, with permission. +// https://web.archive.org/web/20100324014747/http://blindsignals.com/index.php/2009/07/jquery-delay/ +jQuery.fn.delay = function( time, type ) { + time = jQuery.fx ? 
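+
+ // Illustrative usage (not from the source): .delay() reuses the named speed
+ // table driving the timer loop above:
+ //
+ //     jQuery( "#box" ).slideUp( 200 ).delay( "slow" ).fadeIn( 200 );
+ //     // "slow" resolves through jQuery.fx.speeds to 600ms
+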
jQuery.fx.speeds[ time ] || time : time; + type = type || "fx"; + + return this.queue( type, function( next, hooks ) { + var timeout = window.setTimeout( next, time ); + hooks.stop = function() { + window.clearTimeout( timeout ); + }; + } ); +}; + + +( function() { + var input = document.createElement( "input" ), + select = document.createElement( "select" ), + opt = select.appendChild( document.createElement( "option" ) ); + + input.type = "checkbox"; + + // Support: Android <=4.3 only + // Default value for a checkbox should be "on" + support.checkOn = input.value !== ""; + + // Support: IE <=11 only + // Must access selectedIndex to make default options select + support.optSelected = opt.selected; + + // Support: IE <=11 only + // An input loses its value after becoming a radio + input = document.createElement( "input" ); + input.value = "t"; + input.type = "radio"; + support.radioValue = input.value === "t"; +} )(); + + +var boolHook, + attrHandle = jQuery.expr.attrHandle; + +jQuery.fn.extend( { + attr: function( name, value ) { + return access( this, jQuery.attr, name, value, arguments.length > 1 ); + }, + + removeAttr: function( name ) { + return this.each( function() { + jQuery.removeAttr( this, name ); + } ); + } +} ); + +jQuery.extend( { + attr: function( elem, name, value ) { + var ret, hooks, + nType = elem.nodeType; + + // Don't get/set attributes on text, comment and attribute nodes + if ( nType === 3 || nType === 8 || nType === 2 ) { + return; + } + + // Fallback to prop when attributes are not supported + if ( typeof elem.getAttribute === "undefined" ) { + return jQuery.prop( elem, name, value ); + } + + // Attribute hooks are determined by the lowercase version + // Grab necessary hook if one is defined + if ( nType !== 1 || !jQuery.isXMLDoc( elem ) ) { + hooks = jQuery.attrHooks[ name.toLowerCase() ] || + ( jQuery.expr.match.bool.test( name ) ? boolHook : undefined ); + } + + if ( value !== undefined ) { + if ( value === null ) { + jQuery.removeAttr( elem, name ); + return; + } + + if ( hooks && "set" in hooks && + ( ret = hooks.set( elem, value, name ) ) !== undefined ) { + return ret; + } + + elem.setAttribute( name, value + "" ); + return value; + } + + if ( hooks && "get" in hooks && ( ret = hooks.get( elem, name ) ) !== null ) { + return ret; + } + + ret = jQuery.find.attr( elem, name ); + + // Non-existent attributes return null, we normalize to undefined + return ret == null ? 
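+
+ // Illustrative usage (not from the source): jQuery.attr() above routes null
+ // values to removeAttr(), and boolean attributes go through boolHook below:
+ //
+ //     jQuery( "#link" ).attr( "title", null );     // attribute removed
+ //     jQuery( "#opt" ).attr( "selected", false );  // removed via boolHook
+ //     jQuery( "#opt" ).attr( "selected", true );   // set as selected="selected"
+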
undefined : ret; + }, + + attrHooks: { + type: { + set: function( elem, value ) { + if ( !support.radioValue && value === "radio" && + nodeName( elem, "input" ) ) { + var val = elem.value; + elem.setAttribute( "type", value ); + if ( val ) { + elem.value = val; + } + return value; + } + } + } + }, + + removeAttr: function( elem, value ) { + var name, + i = 0, + + // Attribute names can contain non-HTML whitespace characters + // https://html.spec.whatwg.org/multipage/syntax.html#attributes-2 + attrNames = value && value.match( rnothtmlwhite ); + + if ( attrNames && elem.nodeType === 1 ) { + while ( ( name = attrNames[ i++ ] ) ) { + elem.removeAttribute( name ); + } + } + } +} ); + +// Hooks for boolean attributes +boolHook = { + set: function( elem, value, name ) { + if ( value === false ) { + + // Remove boolean attributes when set to false + jQuery.removeAttr( elem, name ); + } else { + elem.setAttribute( name, name ); + } + return name; + } +}; + +jQuery.each( jQuery.expr.match.bool.source.match( /\w+/g ), function( _i, name ) { + var getter = attrHandle[ name ] || jQuery.find.attr; + + attrHandle[ name ] = function( elem, name, isXML ) { + var ret, handle, + lowercaseName = name.toLowerCase(); + + if ( !isXML ) { + + // Avoid an infinite loop by temporarily removing this function from the getter + handle = attrHandle[ lowercaseName ]; + attrHandle[ lowercaseName ] = ret; + ret = getter( elem, name, isXML ) != null ? + lowercaseName : + null; + attrHandle[ lowercaseName ] = handle; + } + return ret; + }; +} ); + + + + +var rfocusable = /^(?:input|select|textarea|button)$/i, + rclickable = /^(?:a|area)$/i; + +jQuery.fn.extend( { + prop: function( name, value ) { + return access( this, jQuery.prop, name, value, arguments.length > 1 ); + }, + + removeProp: function( name ) { + return this.each( function() { + delete this[ jQuery.propFix[ name ] || name ]; + } ); + } +} ); + +jQuery.extend( { + prop: function( elem, name, value ) { + var ret, hooks, + nType = elem.nodeType; + + // Don't get/set properties on text, comment and attribute nodes + if ( nType === 3 || nType === 8 || nType === 2 ) { + return; + } + + if ( nType !== 1 || !jQuery.isXMLDoc( elem ) ) { + + // Fix name and attach hooks + name = jQuery.propFix[ name ] || name; + hooks = jQuery.propHooks[ name ]; + } + + if ( value !== undefined ) { + if ( hooks && "set" in hooks && + ( ret = hooks.set( elem, value, name ) ) !== undefined ) { + return ret; + } + + return ( elem[ name ] = value ); + } + + if ( hooks && "get" in hooks && ( ret = hooks.get( elem, name ) ) !== null ) { + return ret; + } + + return elem[ name ]; + }, + + propHooks: { + tabIndex: { + get: function( elem ) { + + // Support: IE <=9 - 11 only + // elem.tabIndex doesn't always return the + // correct value when it hasn't been explicitly set + // https://web.archive.org/web/20141116233347/http://fluidproject.org/blog/2008/01/09/getting-setting-and-removing-tabindex-values-with-javascript/ + // Use proper attribute retrieval(#12072) + var tabindex = jQuery.find.attr( elem, "tabindex" ); + + if ( tabindex ) { + return parseInt( tabindex, 10 ); + } + + if ( + rfocusable.test( elem.nodeName ) || + rclickable.test( elem.nodeName ) && + elem.href + ) { + return 0; + } + + return -1; + } + } + }, + + propFix: { + "for": "htmlFor", + "class": "className" + } +} ); + +// Support: IE <=11 only +// Accessing the selectedIndex property +// forces the browser to respect setting selected +// on the option +// The getter ensures a default option is selected +// when in an 
optgroup +// eslint rule "no-unused-expressions" is disabled for this code +// since it considers such accessions noop +if ( !support.optSelected ) { + jQuery.propHooks.selected = { + get: function( elem ) { + + /* eslint no-unused-expressions: "off" */ + + var parent = elem.parentNode; + if ( parent && parent.parentNode ) { + parent.parentNode.selectedIndex; + } + return null; + }, + set: function( elem ) { + + /* eslint no-unused-expressions: "off" */ + + var parent = elem.parentNode; + if ( parent ) { + parent.selectedIndex; + + if ( parent.parentNode ) { + parent.parentNode.selectedIndex; + } + } + } + }; +} + +jQuery.each( [ + "tabIndex", + "readOnly", + "maxLength", + "cellSpacing", + "cellPadding", + "rowSpan", + "colSpan", + "useMap", + "frameBorder", + "contentEditable" +], function() { + jQuery.propFix[ this.toLowerCase() ] = this; +} ); + + + + + // Strip and collapse whitespace according to HTML spec + // https://infra.spec.whatwg.org/#strip-and-collapse-ascii-whitespace + function stripAndCollapse( value ) { + var tokens = value.match( rnothtmlwhite ) || []; + return tokens.join( " " ); + } + + +function getClass( elem ) { + return elem.getAttribute && elem.getAttribute( "class" ) || ""; +} + +function classesToArray( value ) { + if ( Array.isArray( value ) ) { + return value; + } + if ( typeof value === "string" ) { + return value.match( rnothtmlwhite ) || []; + } + return []; +} + +jQuery.fn.extend( { + addClass: function( value ) { + var classes, elem, cur, curValue, clazz, j, finalValue, + i = 0; + + if ( isFunction( value ) ) { + return this.each( function( j ) { + jQuery( this ).addClass( value.call( this, j, getClass( this ) ) ); + } ); + } + + classes = classesToArray( value ); + + if ( classes.length ) { + while ( ( elem = this[ i++ ] ) ) { + curValue = getClass( elem ); + cur = elem.nodeType === 1 && ( " " + stripAndCollapse( curValue ) + " " ); + + if ( cur ) { + j = 0; + while ( ( clazz = classes[ j++ ] ) ) { + if ( cur.indexOf( " " + clazz + " " ) < 0 ) { + cur += clazz + " "; + } + } + + // Only assign if different to avoid unneeded rendering. + finalValue = stripAndCollapse( cur ); + if ( curValue !== finalValue ) { + elem.setAttribute( "class", finalValue ); + } + } + } + } + + return this; + }, + + removeClass: function( value ) { + var classes, elem, cur, curValue, clazz, j, finalValue, + i = 0; + + if ( isFunction( value ) ) { + return this.each( function( j ) { + jQuery( this ).removeClass( value.call( this, j, getClass( this ) ) ); + } ); + } + + if ( !arguments.length ) { + return this.attr( "class", "" ); + } + + classes = classesToArray( value ); + + if ( classes.length ) { + while ( ( elem = this[ i++ ] ) ) { + curValue = getClass( elem ); + + // This expression is here for better compressibility (see addClass) + cur = elem.nodeType === 1 && ( " " + stripAndCollapse( curValue ) + " " ); + + if ( cur ) { + j = 0; + while ( ( clazz = classes[ j++ ] ) ) { + + // Remove *all* instances + while ( cur.indexOf( " " + clazz + " " ) > -1 ) { + cur = cur.replace( " " + clazz + " ", " " ); + } + } + + // Only assign if different to avoid unneeded rendering. + finalValue = stripAndCollapse( cur ); + if ( curValue !== finalValue ) { + elem.setAttribute( "class", finalValue ); + } + } + } + } + + return this; + }, + + toggleClass: function( value, stateVal ) { + var type = typeof value, + isValidValue = type === "string" || Array.isArray( value ); + + if ( typeof stateVal === "boolean" && isValidValue ) { + return stateVal ? 
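+
+ // Illustrative usage (not from the source): classesToArray() is why the
+ // class methods accept a space-separated string, an array, or a function:
+ //
+ //     jQuery( "#box" ).addClass( "a b" );
+ //     jQuery( "#box" ).addClass( [ "a", "b" ] );
+ //     jQuery( "#box" ).removeClass( function( _i, current ) {
+ //         return current.indexOf( "draft" ) > -1 ? "draft" : "";
+ //     } );
+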
this.addClass( value ) : this.removeClass( value ); + } + + if ( isFunction( value ) ) { + return this.each( function( i ) { + jQuery( this ).toggleClass( + value.call( this, i, getClass( this ), stateVal ), + stateVal + ); + } ); + } + + return this.each( function() { + var className, i, self, classNames; + + if ( isValidValue ) { + + // Toggle individual class names + i = 0; + self = jQuery( this ); + classNames = classesToArray( value ); + + while ( ( className = classNames[ i++ ] ) ) { + + // Check each className given, space separated list + if ( self.hasClass( className ) ) { + self.removeClass( className ); + } else { + self.addClass( className ); + } + } + + // Toggle whole class name + } else if ( value === undefined || type === "boolean" ) { + className = getClass( this ); + if ( className ) { + + // Store className if set + dataPriv.set( this, "__className__", className ); + } + + // If the element has a class name or if we're passed `false`, + // then remove the whole classname (if there was one, the above saved it). + // Otherwise bring back whatever was previously saved (if anything), + // falling back to the empty string if nothing was stored. + if ( this.setAttribute ) { + this.setAttribute( "class", + className || value === false ? + "" : + dataPriv.get( this, "__className__" ) || "" + ); + } + } + } ); + }, + + hasClass: function( selector ) { + var className, elem, + i = 0; + + className = " " + selector + " "; + while ( ( elem = this[ i++ ] ) ) { + if ( elem.nodeType === 1 && + ( " " + stripAndCollapse( getClass( elem ) ) + " " ).indexOf( className ) > -1 ) { + return true; + } + } + + return false; + } +} ); + + + + +var rreturn = /\r/g; + +jQuery.fn.extend( { + val: function( value ) { + var hooks, ret, valueIsFunction, + elem = this[ 0 ]; + + if ( !arguments.length ) { + if ( elem ) { + hooks = jQuery.valHooks[ elem.type ] || + jQuery.valHooks[ elem.nodeName.toLowerCase() ]; + + if ( hooks && + "get" in hooks && + ( ret = hooks.get( elem, "value" ) ) !== undefined + ) { + return ret; + } + + ret = elem.value; + + // Handle most common string cases + if ( typeof ret === "string" ) { + return ret.replace( rreturn, "" ); + } + + // Handle cases where value is null/undef or number + return ret == null ? "" : ret; + } + + return; + } + + valueIsFunction = isFunction( value ); + + return this.each( function( i ) { + var val; + + if ( this.nodeType !== 1 ) { + return; + } + + if ( valueIsFunction ) { + val = value.call( this, i, jQuery( this ).val() ); + } else { + val = value; + } + + // Treat null/undefined as ""; convert numbers to string + if ( val == null ) { + val = ""; + + } else if ( typeof val === "number" ) { + val += ""; + + } else if ( Array.isArray( val ) ) { + val = jQuery.map( val, function( value ) { + return value == null ? "" : value + ""; + } ); + } + + hooks = jQuery.valHooks[ this.type ] || jQuery.valHooks[ this.nodeName.toLowerCase() ]; + + // If set returns undefined, fall back to normal setting + if ( !hooks || !( "set" in hooks ) || hooks.set( this, val, "value" ) === undefined ) { + this.value = val; + } + } ); + } +} ); + +jQuery.extend( { + valHooks: { + option: { + get: function( elem ) { + + var val = jQuery.find.attr( elem, "value" ); + return val != null ? 
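+
+ // Illustrative usage (not from the source): the boolean short-circuit above
+ // turns toggleClass()'s second argument into a forced on/off switch
+ // (isActive here is a hypothetical boolean):
+ //
+ //     jQuery( "#box" ).toggleClass( "active", isActive );
+ //     jQuery( "#box" ).hasClass( "active" );  // whitespace-normalized membership test
+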
+ val : + + // Support: IE <=10 - 11 only + // option.text throws exceptions (#14686, #14858) + // Strip and collapse whitespace + // https://html.spec.whatwg.org/#strip-and-collapse-whitespace + stripAndCollapse( jQuery.text( elem ) ); + } + }, + select: { + get: function( elem ) { + var value, option, i, + options = elem.options, + index = elem.selectedIndex, + one = elem.type === "select-one", + values = one ? null : [], + max = one ? index + 1 : options.length; + + if ( index < 0 ) { + i = max; + + } else { + i = one ? index : 0; + } + + // Loop through all the selected options + for ( ; i < max; i++ ) { + option = options[ i ]; + + // Support: IE <=9 only + // IE8-9 doesn't update selected after form reset (#2551) + if ( ( option.selected || i === index ) && + + // Don't return options that are disabled or in a disabled optgroup + !option.disabled && + ( !option.parentNode.disabled || + !nodeName( option.parentNode, "optgroup" ) ) ) { + + // Get the specific value for the option + value = jQuery( option ).val(); + + // We don't need an array for one selects + if ( one ) { + return value; + } + + // Multi-Selects return an array + values.push( value ); + } + } + + return values; + }, + + set: function( elem, value ) { + var optionSet, option, + options = elem.options, + values = jQuery.makeArray( value ), + i = options.length; + + while ( i-- ) { + option = options[ i ]; + + /* eslint-disable no-cond-assign */ + + if ( option.selected = + jQuery.inArray( jQuery.valHooks.option.get( option ), values ) > -1 + ) { + optionSet = true; + } + + /* eslint-enable no-cond-assign */ + } + + // Force browsers to behave consistently when non-matching value is set + if ( !optionSet ) { + elem.selectedIndex = -1; + } + return values; + } + } + } +} ); + +// Radios and checkboxes getter/setter +jQuery.each( [ "radio", "checkbox" ], function() { + jQuery.valHooks[ this ] = { + set: function( elem, value ) { + if ( Array.isArray( value ) ) { + return ( elem.checked = jQuery.inArray( jQuery( elem ).val(), value ) > -1 ); + } + } + }; + if ( !support.checkOn ) { + jQuery.valHooks[ this ].get = function( elem ) { + return elem.getAttribute( "value" ) === null ? "on" : elem.value; + }; + } +} ); + + + + +// Return jQuery for attributes-only inclusion + + +support.focusin = "onfocusin" in window; + + +var rfocusMorph = /^(?:focusinfocus|focusoutblur)$/, + stopPropagationCallback = function( e ) { + e.stopPropagation(); + }; + +jQuery.extend( jQuery.event, { + + trigger: function( event, data, elem, onlyHandlers ) { + + var i, cur, tmp, bubbleType, ontype, handle, special, lastElement, + eventPath = [ elem || document ], + type = hasOwn.call( event, "type" ) ? event.type : event, + namespaces = hasOwn.call( event, "namespace" ) ? event.namespace.split( "." ) : []; + + cur = lastElement = tmp = elem = elem || document; + + // Don't do events on text and comment nodes + if ( elem.nodeType === 3 || elem.nodeType === 8 ) { + return; + } + + // focus/blur morphs to focusin/out; ensure we're not firing them right now + if ( rfocusMorph.test( type + jQuery.event.triggered ) ) { + return; + } + + if ( type.indexOf( "." ) > -1 ) { + + // Namespaced trigger; create a regexp to match event type in handle() + namespaces = type.split( "." ); + type = namespaces.shift(); + namespaces.sort(); + } + ontype = type.indexOf( ":" ) < 0 && "on" + type; + + // Caller can pass in a jQuery.Event object, Object, or just an event type string + event = event[ jQuery.expando ] ? 
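+
+ // Illustrative usage (not from the source): the select valHooks above make
+ // .val() symmetric for multi-selects:
+ //
+ //     jQuery( "select[multiple]" ).val( [ "a", "b" ] );  // selects matching options
+ //     jQuery( "select[multiple]" ).val();                // => [ "a", "b" ]
+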
+ event : + new jQuery.Event( type, typeof event === "object" && event ); + + // Trigger bitmask: & 1 for native handlers; & 2 for jQuery (always true) + event.isTrigger = onlyHandlers ? 2 : 3; + event.namespace = namespaces.join( "." ); + event.rnamespace = event.namespace ? + new RegExp( "(^|\\.)" + namespaces.join( "\\.(?:.*\\.|)" ) + "(\\.|$)" ) : + null; + + // Clean up the event in case it is being reused + event.result = undefined; + if ( !event.target ) { + event.target = elem; + } + + // Clone any incoming data and prepend the event, creating the handler arg list + data = data == null ? + [ event ] : + jQuery.makeArray( data, [ event ] ); + + // Allow special events to draw outside the lines + special = jQuery.event.special[ type ] || {}; + if ( !onlyHandlers && special.trigger && special.trigger.apply( elem, data ) === false ) { + return; + } + + // Determine event propagation path in advance, per W3C events spec (#9951) + // Bubble up to document, then to window; watch for a global ownerDocument var (#9724) + if ( !onlyHandlers && !special.noBubble && !isWindow( elem ) ) { + + bubbleType = special.delegateType || type; + if ( !rfocusMorph.test( bubbleType + type ) ) { + cur = cur.parentNode; + } + for ( ; cur; cur = cur.parentNode ) { + eventPath.push( cur ); + tmp = cur; + } + + // Only add window if we got to document (e.g., not plain obj or detached DOM) + if ( tmp === ( elem.ownerDocument || document ) ) { + eventPath.push( tmp.defaultView || tmp.parentWindow || window ); + } + } + + // Fire handlers on the event path + i = 0; + while ( ( cur = eventPath[ i++ ] ) && !event.isPropagationStopped() ) { + lastElement = cur; + event.type = i > 1 ? + bubbleType : + special.bindType || type; + + // jQuery handler + handle = ( + dataPriv.get( cur, "events" ) || Object.create( null ) + )[ event.type ] && + dataPriv.get( cur, "handle" ); + if ( handle ) { + handle.apply( cur, data ); + } + + // Native handler + handle = ontype && cur[ ontype ]; + if ( handle && handle.apply && acceptData( cur ) ) { + event.result = handle.apply( cur, data ); + if ( event.result === false ) { + event.preventDefault(); + } + } + } + event.type = type; + + // If nobody prevented the default action, do it now + if ( !onlyHandlers && !event.isDefaultPrevented() ) { + + if ( ( !special._default || + special._default.apply( eventPath.pop(), data ) === false ) && + acceptData( elem ) ) { + + // Call a native DOM method on the target with the same name as the event. 
+ // Don't do default actions on window, that's where global variables be (#6170) + if ( ontype && isFunction( elem[ type ] ) && !isWindow( elem ) ) { + + // Don't re-trigger an onFOO event when we call its FOO() method + tmp = elem[ ontype ]; + + if ( tmp ) { + elem[ ontype ] = null; + } + + // Prevent re-triggering of the same event, since we already bubbled it above + jQuery.event.triggered = type; + + if ( event.isPropagationStopped() ) { + lastElement.addEventListener( type, stopPropagationCallback ); + } + + elem[ type ](); + + if ( event.isPropagationStopped() ) { + lastElement.removeEventListener( type, stopPropagationCallback ); + } + + jQuery.event.triggered = undefined; + + if ( tmp ) { + elem[ ontype ] = tmp; + } + } + } + } + + return event.result; + }, + + // Piggyback on a donor event to simulate a different one + // Used only for `focus(in | out)` events + simulate: function( type, elem, event ) { + var e = jQuery.extend( + new jQuery.Event(), + event, + { + type: type, + isSimulated: true + } + ); + + jQuery.event.trigger( e, null, elem ); + } + +} ); + +jQuery.fn.extend( { + + trigger: function( type, data ) { + return this.each( function() { + jQuery.event.trigger( type, data, this ); + } ); + }, + triggerHandler: function( type, data ) { + var elem = this[ 0 ]; + if ( elem ) { + return jQuery.event.trigger( type, data, elem, true ); + } + } +} ); + + +// Support: Firefox <=44 +// Firefox doesn't have focus(in | out) events +// Related ticket - https://bugzilla.mozilla.org/show_bug.cgi?id=687787 +// +// Support: Chrome <=48 - 49, Safari <=9.0 - 9.1 +// focus(in | out) events fire after focus & blur events, +// which is spec violation - http://www.w3.org/TR/DOM-Level-3-Events/#events-focusevent-event-order +// Related ticket - https://bugs.chromium.org/p/chromium/issues/detail?id=449857 +if ( !support.focusin ) { + jQuery.each( { focus: "focusin", blur: "focusout" }, function( orig, fix ) { + + // Attach a single capturing handler on the document while someone wants focusin/focusout + var handler = function( event ) { + jQuery.event.simulate( fix, event.target, jQuery.event.fix( event ) ); + }; + + jQuery.event.special[ fix ] = { + setup: function() { + + // Handle: regular nodes (via `this.ownerDocument`), window + // (via `this.document`) & document (via `this`). + var doc = this.ownerDocument || this.document || this, + attaches = dataPriv.access( doc, fix ); + + if ( !attaches ) { + doc.addEventListener( orig, handler, true ); + } + dataPriv.access( doc, fix, ( attaches || 0 ) + 1 ); + }, + teardown: function() { + var doc = this.ownerDocument || this.document || this, + attaches = dataPriv.access( doc, fix ) - 1; + + if ( !attaches ) { + doc.removeEventListener( orig, handler, true ); + dataPriv.remove( doc, fix ); + + } else { + dataPriv.access( doc, fix, attaches ); + } + } + }; + } ); +} +var location = window.location; + +var nonce = { guid: Date.now() }; + +var rquery = ( /\?/ ); + + + +// Cross-browser xml parsing +jQuery.parseXML = function( data ) { + var xml; + if ( !data || typeof data !== "string" ) { + return null; + } + + // Support: IE 9 - 11 only + // IE throws on parseFromString with invalid input. 
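+
+ // Illustrative usage (not from the source): .trigger() walks the event path
+ // assembled above (and may invoke the native method), while .triggerHandler()
+ // passes onlyHandlers and returns the last handler's result without bubbling:
+ //
+ //     jQuery( "#save" ).trigger( "click.myPlugin" );           // namespaced dispatch
+ //     var result = jQuery( "#save" ).triggerHandler( "click" );
+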
+ try { + xml = ( new window.DOMParser() ).parseFromString( data, "text/xml" ); + } catch ( e ) { + xml = undefined; + } + + if ( !xml || xml.getElementsByTagName( "parsererror" ).length ) { + jQuery.error( "Invalid XML: " + data ); + } + return xml; +}; + + +var + rbracket = /\[\]$/, + rCRLF = /\r?\n/g, + rsubmitterTypes = /^(?:submit|button|image|reset|file)$/i, + rsubmittable = /^(?:input|select|textarea|keygen)/i; + +function buildParams( prefix, obj, traditional, add ) { + var name; + + if ( Array.isArray( obj ) ) { + + // Serialize array item. + jQuery.each( obj, function( i, v ) { + if ( traditional || rbracket.test( prefix ) ) { + + // Treat each array item as a scalar. + add( prefix, v ); + + } else { + + // Item is non-scalar (array or object), encode its numeric index. + buildParams( + prefix + "[" + ( typeof v === "object" && v != null ? i : "" ) + "]", + v, + traditional, + add + ); + } + } ); + + } else if ( !traditional && toType( obj ) === "object" ) { + + // Serialize object item. + for ( name in obj ) { + buildParams( prefix + "[" + name + "]", obj[ name ], traditional, add ); + } + + } else { + + // Serialize scalar item. + add( prefix, obj ); + } +} + +// Serialize an array of form elements or a set of +// key/values into a query string +jQuery.param = function( a, traditional ) { + var prefix, + s = [], + add = function( key, valueOrFunction ) { + + // If value is a function, invoke it and use its return value + var value = isFunction( valueOrFunction ) ? + valueOrFunction() : + valueOrFunction; + + s[ s.length ] = encodeURIComponent( key ) + "=" + + encodeURIComponent( value == null ? "" : value ); + }; + + if ( a == null ) { + return ""; + } + + // If an array was passed in, assume that it is an array of form elements. + if ( Array.isArray( a ) || ( a.jquery && !jQuery.isPlainObject( a ) ) ) { + + // Serialize the form elements + jQuery.each( a, function() { + add( this.name, this.value ); + } ); + + } else { + + // If traditional, encode the "old" way (the way 1.3.2 or older + // did it), otherwise encode params recursively. + for ( prefix in a ) { + buildParams( prefix, a[ prefix ], traditional, add ); + } + } + + // Return the resulting serialization + return s.join( "&" ); +}; + +jQuery.fn.extend( { + serialize: function() { + return jQuery.param( this.serializeArray() ); + }, + serializeArray: function() { + return this.map( function() { + + // Can add propHook for "elements" to filter or add form elements + var elements = jQuery.prop( this, "elements" ); + return elements ? 
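+
+ // Illustrative usage (not from the source): buildParams() recursion versus
+ // the `traditional` flag in jQuery.param():
+ //
+ //     jQuery.param( { a: [ 1, 2 ] } );        // "a%5B%5D=1&a%5B%5D=2" (a[]=1&a[]=2)
+ //     jQuery.param( { a: [ 1, 2 ] }, true );  // "a=1&a=2"
+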
jQuery.makeArray( elements ) : this; + } ) + .filter( function() { + var type = this.type; + + // Use .is( ":disabled" ) so that fieldset[disabled] works + return this.name && !jQuery( this ).is( ":disabled" ) && + rsubmittable.test( this.nodeName ) && !rsubmitterTypes.test( type ) && + ( this.checked || !rcheckableType.test( type ) ); + } ) + .map( function( _i, elem ) { + var val = jQuery( this ).val(); + + if ( val == null ) { + return null; + } + + if ( Array.isArray( val ) ) { + return jQuery.map( val, function( val ) { + return { name: elem.name, value: val.replace( rCRLF, "\r\n" ) }; + } ); + } + + return { name: elem.name, value: val.replace( rCRLF, "\r\n" ) }; + } ).get(); + } +} ); + + +var + r20 = /%20/g, + rhash = /#.*$/, + rantiCache = /([?&])_=[^&]*/, + rheaders = /^(.*?):[ \t]*([^\r\n]*)$/mg, + + // #7653, #8125, #8152: local protocol detection + rlocalProtocol = /^(?:about|app|app-storage|.+-extension|file|res|widget):$/, + rnoContent = /^(?:GET|HEAD)$/, + rprotocol = /^\/\//, + + /* Prefilters + * 1) They are useful to introduce custom dataTypes (see ajax/jsonp.js for an example) + * 2) These are called: + * - BEFORE asking for a transport + * - AFTER param serialization (s.data is a string if s.processData is true) + * 3) key is the dataType + * 4) the catchall symbol "*" can be used + * 5) execution will start with transport dataType and THEN continue down to "*" if needed + */ + prefilters = {}, + + /* Transports bindings + * 1) key is the dataType + * 2) the catchall symbol "*" can be used + * 3) selection will start with transport dataType and THEN go to "*" if needed + */ + transports = {}, + + // Avoid comment-prolog char sequence (#10098); must appease lint and evade compression + allTypes = "*/".concat( "*" ), + + // Anchor tag for parsing the document origin + originAnchor = document.createElement( "a" ); + originAnchor.href = location.href; + +// Base "constructor" for jQuery.ajaxPrefilter and jQuery.ajaxTransport +function addToPrefiltersOrTransports( structure ) { + + // dataTypeExpression is optional and defaults to "*" + return function( dataTypeExpression, func ) { + + if ( typeof dataTypeExpression !== "string" ) { + func = dataTypeExpression; + dataTypeExpression = "*"; + } + + var dataType, + i = 0, + dataTypes = dataTypeExpression.toLowerCase().match( rnothtmlwhite ) || []; + + if ( isFunction( func ) ) { + + // For each dataType in the dataTypeExpression + while ( ( dataType = dataTypes[ i++ ] ) ) { + + // Prepend if requested + if ( dataType[ 0 ] === "+" ) { + dataType = dataType.slice( 1 ) || "*"; + ( structure[ dataType ] = structure[ dataType ] || [] ).unshift( func ); + + // Otherwise append + } else { + ( structure[ dataType ] = structure[ dataType ] || [] ).push( func ); + } + } + } + }; +} + +// Base inspection function for prefilters and transports +function inspectPrefiltersOrTransports( structure, options, originalOptions, jqXHR ) { + + var inspected = {}, + seekingTransport = ( structure === transports ); + + function inspect( dataType ) { + var selected; + inspected[ dataType ] = true; + jQuery.each( structure[ dataType ] || [], function( _, prefilterOrFactory ) { + var dataTypeOrTransport = prefilterOrFactory( options, originalOptions, jqXHR ); + if ( typeof dataTypeOrTransport === "string" && + !seekingTransport && !inspected[ dataTypeOrTransport ] ) { + + options.dataTypes.unshift( dataTypeOrTransport ); + inspect( dataTypeOrTransport ); + return false; + } else if ( seekingTransport ) { + return !( selected = dataTypeOrTransport ); + } 
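+
+ // Illustrative usage (not from the source; "locale=en" is a made-up param):
+ // prefilters registered through this machinery run after param serialization
+ // but before a transport is selected:
+ //
+ //     jQuery.ajaxPrefilter( "json", function( options ) {
+ //         options.url += ( options.url.indexOf( "?" ) > -1 ? "&" : "?" ) + "locale=en";
+ //     } );
+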
+ } ); + return selected; + } + + return inspect( options.dataTypes[ 0 ] ) || !inspected[ "*" ] && inspect( "*" ); +} + +// A special extend for ajax options +// that takes "flat" options (not to be deep extended) +// Fixes #9887 +function ajaxExtend( target, src ) { + var key, deep, + flatOptions = jQuery.ajaxSettings.flatOptions || {}; + + for ( key in src ) { + if ( src[ key ] !== undefined ) { + ( flatOptions[ key ] ? target : ( deep || ( deep = {} ) ) )[ key ] = src[ key ]; + } + } + if ( deep ) { + jQuery.extend( true, target, deep ); + } + + return target; +} + +/* Handles responses to an ajax request: + * - finds the right dataType (mediates between content-type and expected dataType) + * - returns the corresponding response + */ +function ajaxHandleResponses( s, jqXHR, responses ) { + + var ct, type, finalDataType, firstDataType, + contents = s.contents, + dataTypes = s.dataTypes; + + // Remove auto dataType and get content-type in the process + while ( dataTypes[ 0 ] === "*" ) { + dataTypes.shift(); + if ( ct === undefined ) { + ct = s.mimeType || jqXHR.getResponseHeader( "Content-Type" ); + } + } + + // Check if we're dealing with a known content-type + if ( ct ) { + for ( type in contents ) { + if ( contents[ type ] && contents[ type ].test( ct ) ) { + dataTypes.unshift( type ); + break; + } + } + } + + // Check to see if we have a response for the expected dataType + if ( dataTypes[ 0 ] in responses ) { + finalDataType = dataTypes[ 0 ]; + } else { + + // Try convertible dataTypes + for ( type in responses ) { + if ( !dataTypes[ 0 ] || s.converters[ type + " " + dataTypes[ 0 ] ] ) { + finalDataType = type; + break; + } + if ( !firstDataType ) { + firstDataType = type; + } + } + + // Or just use first one + finalDataType = finalDataType || firstDataType; + } + + // If we found a dataType + // We add the dataType to the list if needed + // and return the corresponding response + if ( finalDataType ) { + if ( finalDataType !== dataTypes[ 0 ] ) { + dataTypes.unshift( finalDataType ); + } + return responses[ finalDataType ]; + } +} + +/* Chain conversions given the request and the original response + * Also sets the responseXXX fields on the jqXHR instance + */ +function ajaxConvert( s, response, jqXHR, isSuccess ) { + var conv2, current, conv, tmp, prev, + converters = {}, + + // Work with a copy of dataTypes in case we need to modify it for conversion + dataTypes = s.dataTypes.slice(); + + // Create converters map with lowercased keys + if ( dataTypes[ 1 ] ) { + for ( conv in s.converters ) { + converters[ conv.toLowerCase() ] = s.converters[ conv ]; + } + } + + current = dataTypes.shift(); + + // Convert to each sequential dataType + while ( current ) { + + if ( s.responseFields[ current ] ) { + jqXHR[ s.responseFields[ current ] ] = response; + } + + // Apply the dataFilter if provided + if ( !prev && isSuccess && s.dataFilter ) { + response = s.dataFilter( response, s.dataType ); + } + + prev = current; + current = dataTypes.shift(); + + if ( current ) { + + // There's only work to do if current dataType is non-auto + if ( current === "*" ) { + + current = prev; + + // Convert response if prev dataType is non-auto and differs from current + } else if ( prev !== "*" && prev !== current ) { + + // Seek a direct converter + conv = converters[ prev + " " + current ] || converters[ "* " + current ]; + + // If none found, seek a pair + if ( !conv ) { + for ( conv2 in converters ) { + + // If conv2 outputs current + tmp = conv2.split( " " ); + if ( tmp[ 1 ] === current ) { + + // If 
prev can be converted to accepted input + conv = converters[ prev + " " + tmp[ 0 ] ] || + converters[ "* " + tmp[ 0 ] ]; + if ( conv ) { + + // Condense equivalence converters + if ( conv === true ) { + conv = converters[ conv2 ]; + + // Otherwise, insert the intermediate dataType + } else if ( converters[ conv2 ] !== true ) { + current = tmp[ 0 ]; + dataTypes.unshift( tmp[ 1 ] ); + } + break; + } + } + } + } + + // Apply converter (if not an equivalence) + if ( conv !== true ) { + + // Unless errors are allowed to bubble, catch and return them + if ( conv && s.throws ) { + response = conv( response ); + } else { + try { + response = conv( response ); + } catch ( e ) { + return { + state: "parsererror", + error: conv ? e : "No conversion from " + prev + " to " + current + }; + } + } + } + } + } + } + + return { state: "success", data: response }; +} + +jQuery.extend( { + + // Counter for holding the number of active queries + active: 0, + + // Last-Modified header cache for next request + lastModified: {}, + etag: {}, + + ajaxSettings: { + url: location.href, + type: "GET", + isLocal: rlocalProtocol.test( location.protocol ), + global: true, + processData: true, + async: true, + contentType: "application/x-www-form-urlencoded; charset=UTF-8", + + /* + timeout: 0, + data: null, + dataType: null, + username: null, + password: null, + cache: null, + throws: false, + traditional: false, + headers: {}, + */ + + accepts: { + "*": allTypes, + text: "text/plain", + html: "text/html", + xml: "application/xml, text/xml", + json: "application/json, text/javascript" + }, + + contents: { + xml: /\bxml\b/, + html: /\bhtml/, + json: /\bjson\b/ + }, + + responseFields: { + xml: "responseXML", + text: "responseText", + json: "responseJSON" + }, + + // Data converters + // Keys separate source (or catchall "*") and destination types with a single space + converters: { + + // Convert anything to text + "* text": String, + + // Text to html (true = no transformation) + "text html": true, + + // Evaluate text as a json expression + "text json": JSON.parse, + + // Parse text as xml + "text xml": jQuery.parseXML + }, + + // For options that shouldn't be deep extended: + // you can add your own custom options here if + // and when you create one that shouldn't be + // deep extended (see ajaxExtend) + flatOptions: { + url: true, + context: true + } + }, + + // Creates a full fledged settings object into target + // with both ajaxSettings and settings fields. + // If target is omitted, writes into ajaxSettings. + ajaxSetup: function( target, settings ) { + return settings ? 
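+
+ // Illustrative sketch (not from the source; the "csv" dataType and URL are
+ // made up): ajaxConvert() chains entries from the converters table above,
+ // so a custom dataType mainly needs a converter registered via ajaxSetup():
+ //
+ //     jQuery.ajaxSetup( {
+ //         accepts: { csv: "text/csv" },
+ //         contents: { csv: /csv/ },
+ //         converters: { "text csv": function( raw ) { return raw.split( "\n" ); } }
+ //     } );
+ //     jQuery.ajax( { url: "report.csv", dataType: "csv" } );
+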
+ + // Building a settings object + ajaxExtend( ajaxExtend( target, jQuery.ajaxSettings ), settings ) : + + // Extending ajaxSettings + ajaxExtend( jQuery.ajaxSettings, target ); + }, + + ajaxPrefilter: addToPrefiltersOrTransports( prefilters ), + ajaxTransport: addToPrefiltersOrTransports( transports ), + + // Main method + ajax: function( url, options ) { + + // If url is an object, simulate pre-1.5 signature + if ( typeof url === "object" ) { + options = url; + url = undefined; + } + + // Force options to be an object + options = options || {}; + + var transport, + + // URL without anti-cache param + cacheURL, + + // Response headers + responseHeadersString, + responseHeaders, + + // timeout handle + timeoutTimer, + + // Url cleanup var + urlAnchor, + + // Request state (becomes false upon send and true upon completion) + completed, + + // To know if global events are to be dispatched + fireGlobals, + + // Loop variable + i, + + // uncached part of the url + uncached, + + // Create the final options object + s = jQuery.ajaxSetup( {}, options ), + + // Callbacks context + callbackContext = s.context || s, + + // Context for global events is callbackContext if it is a DOM node or jQuery collection + globalEventContext = s.context && + ( callbackContext.nodeType || callbackContext.jquery ) ? + jQuery( callbackContext ) : + jQuery.event, + + // Deferreds + deferred = jQuery.Deferred(), + completeDeferred = jQuery.Callbacks( "once memory" ), + + // Status-dependent callbacks + statusCode = s.statusCode || {}, + + // Headers (they are sent all at once) + requestHeaders = {}, + requestHeadersNames = {}, + + // Default abort message + strAbort = "canceled", + + // Fake xhr + jqXHR = { + readyState: 0, + + // Builds headers hashtable if needed + getResponseHeader: function( key ) { + var match; + if ( completed ) { + if ( !responseHeaders ) { + responseHeaders = {}; + while ( ( match = rheaders.exec( responseHeadersString ) ) ) { + responseHeaders[ match[ 1 ].toLowerCase() + " " ] = + ( responseHeaders[ match[ 1 ].toLowerCase() + " " ] || [] ) + .concat( match[ 2 ] ); + } + } + match = responseHeaders[ key.toLowerCase() + " " ]; + } + return match == null ? null : match.join( ", " ); + }, + + // Raw string + getAllResponseHeaders: function() { + return completed ? 
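+
+ // Illustrative usage (not from the source; the URL is made up): the jqXHR
+ // facade above lazily parses the raw header string once the request has
+ // completed, joining repeated headers with ", ":
+ //
+ //     jQuery.ajax( "/api/items" ).done( function( data, textStatus, jqXHR ) {
+ //         var type = jqXHR.getResponseHeader( "Content-Type" );
+ //     } );
+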
responseHeadersString : null; + }, + + // Caches the header + setRequestHeader: function( name, value ) { + if ( completed == null ) { + name = requestHeadersNames[ name.toLowerCase() ] = + requestHeadersNames[ name.toLowerCase() ] || name; + requestHeaders[ name ] = value; + } + return this; + }, + + // Overrides response content-type header + overrideMimeType: function( type ) { + if ( completed == null ) { + s.mimeType = type; + } + return this; + }, + + // Status-dependent callbacks + statusCode: function( map ) { + var code; + if ( map ) { + if ( completed ) { + + // Execute the appropriate callbacks + jqXHR.always( map[ jqXHR.status ] ); + } else { + + // Lazy-add the new callbacks in a way that preserves old ones + for ( code in map ) { + statusCode[ code ] = [ statusCode[ code ], map[ code ] ]; + } + } + } + return this; + }, + + // Cancel the request + abort: function( statusText ) { + var finalText = statusText || strAbort; + if ( transport ) { + transport.abort( finalText ); + } + done( 0, finalText ); + return this; + } + }; + + // Attach deferreds + deferred.promise( jqXHR ); + + // Add protocol if not provided (prefilters might expect it) + // Handle falsy url in the settings object (#10093: consistency with old signature) + // We also use the url parameter if available + s.url = ( ( url || s.url || location.href ) + "" ) + .replace( rprotocol, location.protocol + "//" ); + + // Alias method option to type as per ticket #12004 + s.type = options.method || options.type || s.method || s.type; + + // Extract dataTypes list + s.dataTypes = ( s.dataType || "*" ).toLowerCase().match( rnothtmlwhite ) || [ "" ]; + + // A cross-domain request is in order when the origin doesn't match the current origin. + if ( s.crossDomain == null ) { + urlAnchor = document.createElement( "a" ); + + // Support: IE <=8 - 11, Edge 12 - 15 + // IE throws exception on accessing the href property if url is malformed, + // e.g. 
http://example.com:80x/ + try { + urlAnchor.href = s.url; + + // Support: IE <=8 - 11 only + // Anchor's host property isn't correctly set when s.url is relative + urlAnchor.href = urlAnchor.href; + s.crossDomain = originAnchor.protocol + "//" + originAnchor.host !== + urlAnchor.protocol + "//" + urlAnchor.host; + } catch ( e ) { + + // If there is an error parsing the URL, assume it is crossDomain, + // it can be rejected by the transport if it is invalid + s.crossDomain = true; + } + } + + // Convert data if not already a string + if ( s.data && s.processData && typeof s.data !== "string" ) { + s.data = jQuery.param( s.data, s.traditional ); + } + + // Apply prefilters + inspectPrefiltersOrTransports( prefilters, s, options, jqXHR ); + + // If request was aborted inside a prefilter, stop there + if ( completed ) { + return jqXHR; + } + + // We can fire global events as of now if asked to + // Don't fire events if jQuery.event is undefined in an AMD-usage scenario (#15118) + fireGlobals = jQuery.event && s.global; + + // Watch for a new set of requests + if ( fireGlobals && jQuery.active++ === 0 ) { + jQuery.event.trigger( "ajaxStart" ); + } + + // Uppercase the type + s.type = s.type.toUpperCase(); + + // Determine if request has content + s.hasContent = !rnoContent.test( s.type ); + + // Save the URL in case we're toying with the If-Modified-Since + // and/or If-None-Match header later on + // Remove hash to simplify url manipulation + cacheURL = s.url.replace( rhash, "" ); + + // More options handling for requests with no content + if ( !s.hasContent ) { + + // Remember the hash so we can put it back + uncached = s.url.slice( cacheURL.length ); + + // If data is available and should be processed, append data to url + if ( s.data && ( s.processData || typeof s.data === "string" ) ) { + cacheURL += ( rquery.test( cacheURL ) ? "&" : "?" ) + s.data; + + // #9682: remove data so that it's not used in an eventual retry + delete s.data; + } + + // Add or update anti-cache param if needed + if ( s.cache === false ) { + cacheURL = cacheURL.replace( rantiCache, "$1" ); + uncached = ( rquery.test( cacheURL ) ? "&" : "?" ) + "_=" + ( nonce.guid++ ) + + uncached; + } + + // Put hash and anti-cache on the URL that will be requested (gh-1732) + s.url = cacheURL + uncached; + + // Change '%20' to '+' if this is encoded form body content (gh-2658) + } else if ( s.data && s.processData && + ( s.contentType || "" ).indexOf( "application/x-www-form-urlencoded" ) === 0 ) { + s.data = s.data.replace( r20, "+" ); + } + + // Set the If-Modified-Since and/or If-None-Match header, if in ifModified mode. + if ( s.ifModified ) { + if ( jQuery.lastModified[ cacheURL ] ) { + jqXHR.setRequestHeader( "If-Modified-Since", jQuery.lastModified[ cacheURL ] ); + } + if ( jQuery.etag[ cacheURL ] ) { + jqXHR.setRequestHeader( "If-None-Match", jQuery.etag[ cacheURL ] ); + } + } + + // Set the correct header, if data is being sent + if ( s.data && s.hasContent && s.contentType !== false || options.contentType ) { + jqXHR.setRequestHeader( "Content-Type", s.contentType ); + } + + // Set the Accepts header for the server, depending on the dataType + jqXHR.setRequestHeader( + "Accept", + s.dataTypes[ 0 ] && s.accepts[ s.dataTypes[ 0 ] ] ? + s.accepts[ s.dataTypes[ 0 ] ] + + ( s.dataTypes[ 0 ] !== "*" ? 
", " + allTypes + "; q=0.01" : "" ) : + s.accepts[ "*" ] + ); + + // Check for headers option + for ( i in s.headers ) { + jqXHR.setRequestHeader( i, s.headers[ i ] ); + } + + // Allow custom headers/mimetypes and early abort + if ( s.beforeSend && + ( s.beforeSend.call( callbackContext, jqXHR, s ) === false || completed ) ) { + + // Abort if not done already and return + return jqXHR.abort(); + } + + // Aborting is no longer a cancellation + strAbort = "abort"; + + // Install callbacks on deferreds + completeDeferred.add( s.complete ); + jqXHR.done( s.success ); + jqXHR.fail( s.error ); + + // Get transport + transport = inspectPrefiltersOrTransports( transports, s, options, jqXHR ); + + // If no transport, we auto-abort + if ( !transport ) { + done( -1, "No Transport" ); + } else { + jqXHR.readyState = 1; + + // Send global event + if ( fireGlobals ) { + globalEventContext.trigger( "ajaxSend", [ jqXHR, s ] ); + } + + // If request was aborted inside ajaxSend, stop there + if ( completed ) { + return jqXHR; + } + + // Timeout + if ( s.async && s.timeout > 0 ) { + timeoutTimer = window.setTimeout( function() { + jqXHR.abort( "timeout" ); + }, s.timeout ); + } + + try { + completed = false; + transport.send( requestHeaders, done ); + } catch ( e ) { + + // Rethrow post-completion exceptions + if ( completed ) { + throw e; + } + + // Propagate others as results + done( -1, e ); + } + } + + // Callback for when everything is done + function done( status, nativeStatusText, responses, headers ) { + var isSuccess, success, error, response, modified, + statusText = nativeStatusText; + + // Ignore repeat invocations + if ( completed ) { + return; + } + + completed = true; + + // Clear timeout if it exists + if ( timeoutTimer ) { + window.clearTimeout( timeoutTimer ); + } + + // Dereference transport for early garbage collection + // (no matter how long the jqXHR object will be used) + transport = undefined; + + // Cache response headers + responseHeadersString = headers || ""; + + // Set readyState + jqXHR.readyState = status > 0 ? 4 : 0; + + // Determine if successful + isSuccess = status >= 200 && status < 300 || status === 304; + + // Get response data + if ( responses ) { + response = ajaxHandleResponses( s, jqXHR, responses ); + } + + // Use a noop converter for missing script + if ( !isSuccess && jQuery.inArray( "script", s.dataTypes ) > -1 ) { + s.converters[ "text script" ] = function() {}; + } + + // Convert no matter what (that way responseXXX fields are always set) + response = ajaxConvert( s, response, jqXHR, isSuccess ); + + // If successful, handle type chaining + if ( isSuccess ) { + + // Set the If-Modified-Since and/or If-None-Match header, if in ifModified mode. 
+ if ( s.ifModified ) { + modified = jqXHR.getResponseHeader( "Last-Modified" ); + if ( modified ) { + jQuery.lastModified[ cacheURL ] = modified; + } + modified = jqXHR.getResponseHeader( "etag" ); + if ( modified ) { + jQuery.etag[ cacheURL ] = modified; + } + } + + // if no content + if ( status === 204 || s.type === "HEAD" ) { + statusText = "nocontent"; + + // if not modified + } else if ( status === 304 ) { + statusText = "notmodified"; + + // If we have data, let's convert it + } else { + statusText = response.state; + success = response.data; + error = response.error; + isSuccess = !error; + } + } else { + + // Extract error from statusText and normalize for non-aborts + error = statusText; + if ( status || !statusText ) { + statusText = "error"; + if ( status < 0 ) { + status = 0; + } + } + } + + // Set data for the fake xhr object + jqXHR.status = status; + jqXHR.statusText = ( nativeStatusText || statusText ) + ""; + + // Success/Error + if ( isSuccess ) { + deferred.resolveWith( callbackContext, [ success, statusText, jqXHR ] ); + } else { + deferred.rejectWith( callbackContext, [ jqXHR, statusText, error ] ); + } + + // Status-dependent callbacks + jqXHR.statusCode( statusCode ); + statusCode = undefined; + + if ( fireGlobals ) { + globalEventContext.trigger( isSuccess ? "ajaxSuccess" : "ajaxError", + [ jqXHR, s, isSuccess ? success : error ] ); + } + + // Complete + completeDeferred.fireWith( callbackContext, [ jqXHR, statusText ] ); + + if ( fireGlobals ) { + globalEventContext.trigger( "ajaxComplete", [ jqXHR, s ] ); + + // Handle the global AJAX counter + if ( !( --jQuery.active ) ) { + jQuery.event.trigger( "ajaxStop" ); + } + } + } + + return jqXHR; + }, + + getJSON: function( url, data, callback ) { + return jQuery.get( url, data, callback, "json" ); + }, + + getScript: function( url, callback ) { + return jQuery.get( url, undefined, callback, "script" ); + } +} ); + +jQuery.each( [ "get", "post" ], function( _i, method ) { + jQuery[ method ] = function( url, data, callback, type ) { + + // Shift arguments if data argument was omitted + if ( isFunction( data ) ) { + type = type || callback; + callback = data; + data = undefined; + } + + // The url can be an options object (which then must have .url) + return jQuery.ajax( jQuery.extend( { + url: url, + type: method, + dataType: type, + data: data, + success: callback + }, jQuery.isPlainObject( url ) && url ) ); + }; +} ); + +jQuery.ajaxPrefilter( function( s ) { + var i; + for ( i in s.headers ) { + if ( i.toLowerCase() === "content-type" ) { + s.contentType = s.headers[ i ] || ""; + } + } +} ); + + +jQuery._evalUrl = function( url, options, doc ) { + return jQuery.ajax( { + url: url, + + // Make this explicit, since user can override this through ajaxSetup (#11264) + type: "GET", + dataType: "script", + cache: true, + async: false, + global: false, + + // Only evaluate the response if it is successful (gh-4126) + // dataFilter is not invoked for failure responses, so using it instead + // of the default converter is kludgy but it works. 
+ converters: { + "text script": function() {} + }, + dataFilter: function( response ) { + jQuery.globalEval( response, options, doc ); + } + } ); +}; + + +jQuery.fn.extend( { + wrapAll: function( html ) { + var wrap; + + if ( this[ 0 ] ) { + if ( isFunction( html ) ) { + html = html.call( this[ 0 ] ); + } + + // The elements to wrap the target around + wrap = jQuery( html, this[ 0 ].ownerDocument ).eq( 0 ).clone( true ); + + if ( this[ 0 ].parentNode ) { + wrap.insertBefore( this[ 0 ] ); + } + + wrap.map( function() { + var elem = this; + + while ( elem.firstElementChild ) { + elem = elem.firstElementChild; + } + + return elem; + } ).append( this ); + } + + return this; + }, + + wrapInner: function( html ) { + if ( isFunction( html ) ) { + return this.each( function( i ) { + jQuery( this ).wrapInner( html.call( this, i ) ); + } ); + } + + return this.each( function() { + var self = jQuery( this ), + contents = self.contents(); + + if ( contents.length ) { + contents.wrapAll( html ); + + } else { + self.append( html ); + } + } ); + }, + + wrap: function( html ) { + var htmlIsFunction = isFunction( html ); + + return this.each( function( i ) { + jQuery( this ).wrapAll( htmlIsFunction ? html.call( this, i ) : html ); + } ); + }, + + unwrap: function( selector ) { + this.parent( selector ).not( "body" ).each( function() { + jQuery( this ).replaceWith( this.childNodes ); + } ); + return this; + } +} ); + + +jQuery.expr.pseudos.hidden = function( elem ) { + return !jQuery.expr.pseudos.visible( elem ); +}; +jQuery.expr.pseudos.visible = function( elem ) { + return !!( elem.offsetWidth || elem.offsetHeight || elem.getClientRects().length ); +}; + + + + +jQuery.ajaxSettings.xhr = function() { + try { + return new window.XMLHttpRequest(); + } catch ( e ) {} +}; + +var xhrSuccessStatus = { + + // File protocol always yields status code 0, assume 200 + 0: 200, + + // Support: IE <=9 only + // #1450: sometimes IE returns 1223 when it should be 204 + 1223: 204 + }, + xhrSupported = jQuery.ajaxSettings.xhr(); + +support.cors = !!xhrSupported && ( "withCredentials" in xhrSupported ); +support.ajax = xhrSupported = !!xhrSupported; + +jQuery.ajaxTransport( function( options ) { + var callback, errorCallback; + + // Cross domain only allowed if supported through XMLHttpRequest + if ( support.cors || xhrSupported && !options.crossDomain ) { + return { + send: function( headers, complete ) { + var i, + xhr = options.xhr(); + + xhr.open( + options.type, + options.url, + options.async, + options.username, + options.password + ); + + // Apply custom fields if provided + if ( options.xhrFields ) { + for ( i in options.xhrFields ) { + xhr[ i ] = options.xhrFields[ i ]; + } + } + + // Override mime type if needed + if ( options.mimeType && xhr.overrideMimeType ) { + xhr.overrideMimeType( options.mimeType ); + } + + // X-Requested-With header + // For cross-domain requests, seeing as conditions for a preflight are + // akin to a jigsaw puzzle, we simply never set it to be sure. + // (it can always be set on a per-request basis or even using ajaxSetup) + // For same-domain requests, won't change header if already provided. 
+
+				if ( !options.crossDomain && !headers[ "X-Requested-With" ] ) {
+					headers[ "X-Requested-With" ] = "XMLHttpRequest";
+				}
+
+				// Set headers
+				for ( i in headers ) {
+					xhr.setRequestHeader( i, headers[ i ] );
+				}
+
+				// Callback
+				callback = function( type ) {
+					return function() {
+						if ( callback ) {
+							callback = errorCallback = xhr.onload =
+								xhr.onerror = xhr.onabort = xhr.ontimeout =
+								xhr.onreadystatechange = null;
+
+							if ( type === "abort" ) {
+								xhr.abort();
+							} else if ( type === "error" ) {
+
+								// Support: IE <=9 only
+								// On a manual native abort, IE9 throws
+								// errors on any property access that is not readyState
+								if ( typeof xhr.status !== "number" ) {
+									complete( 0, "error" );
+								} else {
+									complete(
+
+										// File: protocol always yields status 0; see #8605, #14207
+										xhr.status,
+										xhr.statusText
+									);
+								}
+							} else {
+								complete(
+									xhrSuccessStatus[ xhr.status ] || xhr.status,
+									xhr.statusText,
+
+									// Support: IE <=9 only
+									// IE9 has no XHR2 but throws on binary (trac-11426)
+									// For XHR2 non-text, let the caller handle it (gh-2498)
+									( xhr.responseType || "text" ) !== "text" ||
+									typeof xhr.responseText !== "string" ?
+										{ binary: xhr.response } :
+										{ text: xhr.responseText },
+									xhr.getAllResponseHeaders()
+								);
+							}
+						}
+					};
+				};
+
+				// Listen to events
+				xhr.onload = callback();
+				errorCallback = xhr.onerror = xhr.ontimeout = callback( "error" );
+
+				// Support: IE 9 only
+				// Use onreadystatechange to replace onabort
+				// to handle uncaught aborts
+				if ( xhr.onabort !== undefined ) {
+					xhr.onabort = errorCallback;
+				} else {
+					xhr.onreadystatechange = function() {
+
+						// Check readyState before timeout as it changes
+						if ( xhr.readyState === 4 ) {
+
+							// Allow onerror to be called first,
+							// but that will not handle a native abort
+							// Also, save errorCallback to a variable
+							// as xhr.onerror cannot be accessed
+							window.setTimeout( function() {
+								if ( callback ) {
+									errorCallback();
+								}
+							} );
+						}
+					};
+				}
+
+				// Create the abort callback
+				callback = callback( "abort" );
+
+				try {
+
+					// Do send the request (this may raise an exception)
+					xhr.send( options.hasContent && options.data || null );
+				} catch ( e ) {
+
+					// #14683: Only rethrow if this hasn't been notified as an error yet
+					if ( callback ) {
+						throw e;
+					}
+				}
+			},
+
+			abort: function() {
+				if ( callback ) {
+					callback();
+				}
+			}
+		};
+	}
+} );
+
+
+
+
+// Prevent auto-execution of scripts when no explicit dataType was provided (See gh-2432)
+jQuery.ajaxPrefilter( function( s ) {
+	if ( s.crossDomain ) {
+		s.contents.script = false;
+	}
+} );
+
+// Install script dataType
+jQuery.ajaxSetup( {
+	accepts: {
+		script: "text/javascript, application/javascript, " +
+			"application/ecmascript, application/x-ecmascript"
+	},
+	contents: {
+		script: /\b(?:java|ecma)script\b/
+	},
+	converters: {
+		"text script": function( text ) {
+			jQuery.globalEval( text );
+			return text;
+		}
+	}
+} );
+
+// Handle cache's special case and crossDomain
+jQuery.ajaxPrefilter( "script", function( s ) {
+	if ( s.cache === undefined ) {
+		s.cache = false;
+	}
+	if ( s.crossDomain ) {
+		s.type = "GET";
+	}
+} );
+
+// Bind script tag hack transport
+jQuery.ajaxTransport( "script", function( s ) {
+
+	// This transport only deals with cross domain or forced-by-attrs requests
+	if ( s.crossDomain || s.scriptAttrs ) {
+		var script, callback;
+		return {
+			send: function( _, complete ) {
+				script = jQuery( "<script>" )
+					.attr( s.scriptAttrs || {} )
+					.prop( { charset: s.scriptCharset, src: s.url } )
+					.on( "load error", callback = function( evt ) {
+						script.remove();
+						callback = null;
+						if ( evt ) {
+							complete( evt.type === "error" ? 404 : 200, evt.type );
+						}
+					} );
+
+				// Use native DOM manipulation to avoid our domManip AJAX trickery
+				document.head.appendChild( script[ 0 ] );
+			},
+			abort: function() {
+				if ( callback ) {
+					callback();
+				}
+			}
+		};
+	}
+} );
+
+[docs/build/html/genindex.html (assumed filename) — Sphinx-generated “Index” page for the mridc v.0.0.1 documentation: navigation sidebar, empty letter sections _ and A–Z, “© Copyright 2022, Dimitrios Karkalousos.”, and the “Built with Sphinx using a theme provided by Read the Docs” footer.]
+ + + + diff --git a/docs/build/html/index.html b/docs/build/html/index.html new file mode 100644 index 00000000..26c4f030 --- /dev/null +++ b/docs/build/html/index.html @@ -0,0 +1,230 @@ + + + + + + + Welcome to mridc’s documentation! — mridc v.0.0.1 documentation + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

Welcome to mridc’s documentation!

+
+

Data Consistency for Magnetic Resonance Imaging

+

CodeQL +CircleCI +codecov +Code style: black

+
+
+

Introduction

+

MRIDC is a toolbox for applying AI methods to MR imaging. A collection of tools for data consistency and data quality +is provided for MRI data analysis. Primarily, it focuses on the following tasks:

+
+

Reconstruction:

+

1. Cascades of Independently Recurrent Inference Machines (CIRIM),
+2. Compressed Sensing (CS),
+3. Convolutional Recurrent Neural Networks (CRNN),
+4. Deep Cascade of Convolutional Neural Networks (CCNN),
+5. Down-Up Net (DUNET),
+6. End-to-End Variational Network (E2EVN),
+7. Joint Deep Model-Based MR Image and Coil Sensitivity Reconstruction Network (Joint-ICNet),
+8. Independently Recurrent Inference Machines (IRIM),
+9. KIKI-Net,
+10. Learned Primal-Dual Net (LPDNet),
+11. MultiDomainNet,
+12. Recurrent Inference Machines (RIM),
+13. Recurrent Variational Network (RVN),
+14. UNet,
+15. Variable Splitting Network (VSNet),
+16. XPDNet,
+17. and Zero-Filled reconstruction (ZF).

+
+
+

Segmentation:

+

Coming soon…

+
+
+

Acknowledgements

+

MRIDC is based on the NeMo framework, using PyTorch Lightning to enable +high-performance multi-GPU/multi-node mixed-precision training.

+

For the reconstruction methods:

+
    +
  • The implementations of 6 and 14 are based on (and owe thanks to) the fastMRI repo.

  • +
  • The implementations of 7, 9, 10, 11, 13, and 16 are based on (and owe thanks to) the DIRECT repo.

  • +
+
+
+
+

Installation

+

MRIDC is best installed in a Conda environment.

+
conda create -n mridc python=3.9
+conda activate mridc
+
+
+
+

Pip

+

Use pip installation if you want the latest stable version.

+
pip install mridc
+
+
+
+
+

From source

+

Use source installation if you want the latest development version, as well as for contributing to MRIDC.

+
git clone https://github.com/wdika/mridc
+cd mridc
+./reinstall.sh
+
+
+
+
+
+

Datasets

+

Recommended public datasets to use with this repo:

+ +
+
+

API Documentation

+

Documentation Status

+

Access the API Documentation here

+
+
+

License

+

License: Apache 2.0

+
+
+

Citation

+

Please cite MRIDC using the “Cite this repository” button or as follows:

+
@misc{mridc,
+    author = {Karkalousos, Dimitrios and Caan, Matthan},
+    title = {MRIDC: Data Consistency for Magnetic Resonance Imaging},
+    year = {2021},
+    url = {https://github.com/wdika/mridc},
+}
+
+
+
+
+

Papers

+

The following papers use the MRIDC repo:

+

[1] Karkalousos, D. et al. (2021) ‘Assessment of Data Consistency through Cascades of Independently Recurrent +Inference Machines for fast and robust accelerated MRI reconstruction’

+
+
+ +
+ + +
+
+ +
+
+
+
+ + + + diff --git a/docs/build/html/modules.html b/docs/build/html/modules.html new file mode 100644 index 00000000..8054cdc8 --- /dev/null +++ b/docs/build/html/modules.html @@ -0,0 +1,153 @@ + + + + + + + mridc — mridc v.0.0.1 documentation + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/build/html/mridc.collections.common.callbacks.html b/docs/build/html/mridc.collections.common.callbacks.html new file mode 100644 index 00000000..52ececf1 --- /dev/null +++ b/docs/build/html/mridc.collections.common.callbacks.html @@ -0,0 +1,152 @@ + + + + + + + mridc.collections.common.callbacks package — mridc v.0.0.1 documentation + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

mridc.collections.common.callbacks package

+
+

Submodules

+
+
+

mridc.collections.common.callbacks.callbacks module

+
+
+class mridc.collections.common.callbacks.callbacks.LogEpochTimeCallback[source]
+

Bases: pytorch_lightning.callbacks.base.Callback

+

Simple callback that logs how long each epoch takes, in seconds, to a PyTorch Lightning log.

+
+
+on_train_epoch_end(trainer, pl_module)[source]
+

Called at the end of each epoch.

+
+ +
+
+on_train_epoch_start(trainer, pl_module)[source]
+

Called at the start of each epoch.

+
+ +
+ +
+
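A minimal sketch of attaching the callback to a trainer (the trainer settings and model are hypothetical):

import pytorch_lightning as pl
from mridc.collections.common.callbacks.callbacks import LogEpochTimeCallback

# Logs each epoch's duration, in seconds, to the Lightning log.
trainer = pl.Trainer(max_epochs=10, callbacks=[LogEpochTimeCallback()])
# trainer.fit(model)  # `model` is a hypothetical LightningModule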
+

Module contents

+
+
+ + +
+
+ +
+
+
+
+ + + + diff --git a/docs/build/html/mridc.collections.common.data.html b/docs/build/html/mridc.collections.common.data.html new file mode 100644 index 00000000..04261f38 --- /dev/null +++ b/docs/build/html/mridc.collections.common.data.html @@ -0,0 +1,184 @@ + + + + + + + mridc.collections.common.data package — mridc v.0.0.1 documentation + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

mridc.collections.common.data package

+
+

Submodules

+
+
+

mridc.collections.common.data.dataset module

+
+
+class mridc.collections.common.data.dataset.ConcatDataset(datasets: List[Any], shuffle: bool = True, sampling_technique: str = 'random', sampling_probabilities: Optional[List[float]] = None, global_rank: int = 0, world_size: int = 1)[source]
+

Bases: torch.utils.data.dataset.IterableDataset, abc.ABC

+

A dataset that accepts multiple datasets as arguments and then samples from them based on the specified +sampling technique.

+
+
Parameters
+
    +
  • datasets (A list of datasets to sample from.) –

  • +
  • shuffle (Whether to shuffle individual datasets. Only works with non-iterable datasets. Defaults to True.) –

  • +
  • sampling_technique (Sampling technique to choose which dataset to draw a sample from. Currently supports 'random' and 'round-robin'. Defaults to 'random'.) –

  • +
  • sampling_probabilities (Probability values for sampling. Only used when sampling_technique = 'random'.) –

  • +
  • global_rank (Worker rank, used for partitioning map style datasets. Defaults to 0.) –

  • +
  • world_size (Total number of processes, used for partitioning map style datasets. Defaults to 1.) –

  • +
+
+
+
+
+__iter__()[source]
+

Returns an iterator over the dataset.

+
+ +
+
+__len__()[source]
+

Returns the number of elements in the dataset.

+
+ +
+
+get_iterable(dataset)[source]
+

Returns an iterable dataset.

+
+ +
+
+static random_generator(datasets, **kwargs)[source]
+

Generates random indices.

+
+ +
+
+static round_robin_generator(datasets, **kwargs)[source]
+

Generates indices in a round-robin fashion.

+
+ +
+ +
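A hedged sketch of drawing samples from two toy datasets in turn:

import torch
from torch.utils.data import TensorDataset

from mridc.collections.common.data.dataset import ConcatDataset

ds_a = TensorDataset(torch.arange(10))
ds_b = TensorDataset(torch.arange(100, 110))
combined = ConcatDataset(
    datasets=[ds_a, ds_b],
    shuffle=True,
    sampling_technique="round-robin",  # or "random" with sampling_probabilities
)
first = next(iter(combined))  # IterableDataset: alternates between ds_a and ds_b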
+
+

Module contents

+
+
+ + +
+
+ +
+
+
+
+ + + + diff --git a/docs/build/html/mridc.collections.common.html b/docs/build/html/mridc.collections.common.html new file mode 100644 index 00000000..accc75d6 --- /dev/null +++ b/docs/build/html/mridc.collections.common.html @@ -0,0 +1,168 @@ + + + + + + + mridc.collections.common package — mridc v.0.0.1 documentation + + + + + + + + + + + + + + + + +
+ + +
+ + +
+
+ + + + diff --git a/docs/build/html/mridc.collections.common.losses.html b/docs/build/html/mridc.collections.common.losses.html new file mode 100644 index 00000000..c0a760ed --- /dev/null +++ b/docs/build/html/mridc.collections.common.losses.html @@ -0,0 +1,203 @@ + + + + + + + mridc.collections.common.losses package — mridc v.0.0.1 documentation + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

mridc.collections.common.losses package

+
+

Submodules

+
+
+

mridc.collections.common.losses.aggregator module

+
+
+class mridc.collections.common.losses.aggregator.AggregatorLoss(num_inputs: int = 2, weights: Optional[List[float]] = None)[source]
+

Bases: mridc.core.classes.loss.Loss

+

Sums several losses into one.

+
+
Parameters
+
    +
  • num_inputs (number of input losses) –

  • +
  • weights (a list of coefficients for merging losses) –

  • +
+
+
+
+
+forward(**kwargs)[source]
+

Computes the sum of the losses.

+
+ +
+
+property input_types
+

Returns definitions of module input ports.

+
+ +
+
+property output_types
+

Returns definitions of module output ports.

+
+ +
+
+reduction: str
+
+ +
+ +
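Construction follows the documented signature; the keyword names used when calling the loss below are an assumption, since the actual port names come from input_types:

from mridc.collections.common.losses.aggregator import AggregatorLoss

agg = AggregatorLoss(num_inputs=2, weights=[0.7, 0.3])
# total = agg(loss_1=recon_loss, loss_2=ssim_loss)  # hypothetical port names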
+
+

mridc.collections.common.losses.ssim module

+
+
+class mridc.collections.common.losses.ssim.SSIMLoss(win_size: int = 7, k1: float = 0.01, k2: float = 0.03)[source]
+

Bases: torch.nn.modules.module.Module

+

SSIM loss module.

+
+
+forward(X: torch.Tensor, Y: torch.Tensor, data_range: torch.Tensor)[source]
+
+
Parameters
+
    +
  • X (First input tensor.) –

  • +
  • Y (Second input tensor.) –

  • +
  • data_range (Data range of the input tensors.) –

  • +
+
+
Return type
+

SSIM loss.

+
+
+
+ +
+
+training: bool
+
+ +
+ +
+
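A minimal sketch, assuming single-channel magnitude images of shape (batch, 1, height, width) and one data_range value per batch element:

import torch
from mridc.collections.common.losses.ssim import SSIMLoss

loss_fn = SSIMLoss(win_size=7, k1=0.01, k2=0.03)
target = torch.rand(1, 1, 320, 320)
recon = torch.rand(1, 1, 320, 320)
data_range = target.amax().reshape(1)  # data range of the input tensors
loss = loss_fn(recon, target, data_range)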
+

Module contents

+
+
+ + +
+
+ +
+
+
+
+ + + + diff --git a/docs/build/html/mridc.collections.common.metrics.html b/docs/build/html/mridc.collections.common.metrics.html new file mode 100644 index 00000000..e4f13f10 --- /dev/null +++ b/docs/build/html/mridc.collections.common.metrics.html @@ -0,0 +1,174 @@ + + + + + + + mridc.collections.common.metrics package — mridc v.0.0.1 documentation + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

mridc.collections.common.metrics package

+
+

Submodules

+
+
+

mridc.collections.common.metrics.global_average_loss_metric module

+
+
+class mridc.collections.common.metrics.global_average_loss_metric.GlobalAverageLossMetric(compute_on_step=True, dist_sync_on_step=False, process_group=None, take_avg_loss=True)[source]
+

Bases: torchmetrics.metric.Metric

+

This class is for averaging loss across multiple processes if a distributed backend is used. The true average is computed, not a running average. It does not accumulate gradients, so the averaged loss cannot be used for optimization.

+
+

Note

+

If take_avg_loss is True, the update() method loss argument has to be a mean loss. If take_avg_loss is False then the update() method loss argument has to be a sum of losses. See PyTorch Lightning Metrics for the metric usage instruction.

+
+
+
Parameters
+
    +
  • compute_on_step (The method forward() only calls update() and returns None if this is set to False. Default: True) –

  • +
  • dist_sync_on_step (Synchronize metric state across processes at each method forward() call before returning the value at the step) –

  • +
  • process_group (Specify the process group on which synchronization is called. default: None (which selects the entire world)) –

  • +
  • take_avg_loss (If True values of update() method loss argument has to be a mean loss. If False values of update() method loss argument has to be a sum of losses. default: True) –

  • +
+
+
+
+
+compute()[source]
+

Returns mean loss.

+
+ +
+
+update(loss, num_measurements)[source]
+

Updates loss_sum and num_measurements.

+
+
Parameters
+
    +
  • loss (A float zero dimensional torch.Tensor which is either sum or average of losses for processed examples. See take_avg_loss parameter of __init__().) –

  • +
  • num_measurements (An integer zero dimensional torch.Tensor which contains a number of loss measurements. The sum or mean of the results of these measurements are in the loss parameter.) –

  • +
+
+
+
+ +
+ +
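A hedged sketch of accumulating a true average across steps (the loss values are illustrative):

import torch
from mridc.collections.common.metrics.global_average_loss_metric import GlobalAverageLossMetric

metric = GlobalAverageLossMetric(take_avg_loss=True)
# Each call passes a mean loss over `num_measurements` examples.
metric.update(loss=torch.tensor(0.42), num_measurements=torch.tensor(8))
metric.update(loss=torch.tensor(0.36), num_measurements=torch.tensor(8))
mean_loss = metric.compute()  # true average over all 16 measurements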
+
+

Module contents

+
+
+ + +
+
+ +
+
+
+
+ + + + diff --git a/docs/build/html/mridc.collections.common.parts.html b/docs/build/html/mridc.collections.common.parts.html new file mode 100644 index 00000000..f79e9844 --- /dev/null +++ b/docs/build/html/mridc.collections.common.parts.html @@ -0,0 +1,412 @@ + + + + + + + mridc.collections.common.parts package — mridc v.0.0.1 documentation + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

mridc.collections.common.parts package

+
+

Submodules

+
+
+

mridc.collections.common.parts.fft module

+
+
+mridc.collections.common.parts.fft.fft2c(data: torch.Tensor, fft_type: str = 'orthogonal', fft_normalization: str = 'ortho', fft_dim: Union[int, None, List[int]] = None) torch.Tensor[source]
+

Apply centered 2 dimensional Fast Fourier Transform.

+
+
Parameters
+
    +
  • data (Complex valued input data containing at least 3 dimensions: dimensions -2 & -1 are spatial dimensions. All other dimensions are assumed to be batch dimensions.) –

  • +
  • fft_type (Specify fft type. This is important if an orthogonal transformation is needed or not.) –

  • +
  • fft_normalization ("ortho" is the default normalization used by PyTorch. Can be changed to "forward", "backward", or None.) –

  • +
  • fft_dim (dimensions to apply the FFT) –

  • +
+
+
Return type
+

The FFT of the input.

+
+
+
+ +
+
+mridc.collections.common.parts.fft.ifft2c(data: torch.Tensor, fft_type: str = 'orthogonal', fft_normalization: str = 'ortho', fft_dim: Union[int, None, List[int]] = None) torch.Tensor[source]
+

Apply centered 2 dimensional Inverse Fast Fourier Transform.

+
+
Parameters
+
    +
  • data (Complex valued input data containing at least 3 dimensions: dimensions -2 & -1 are spatial dimensions. All other dimensions are assumed to be batch dimensions.) –

  • +
  • fft_type (Specify fft type. This is important if an orthogonal transformation is needed or not.) –

  • +
  • fft_normalization ("ortho" is the default normalization used by PyTorch. Can be changed to "forward", "backward", or None.) –

  • +
  • fft_dim (dimensions to apply the FFT) –

  • +
+
+
Return type
+

The IFFT of the input.

+
+
+
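A hedged round-trip sketch; the complex dtype and leading batch dimension are assumptions read off the parameter descriptions above:

import torch
from mridc.collections.common.parts.fft import fft2c, ifft2c

image = torch.randn(4, 320, 320, dtype=torch.complex64)  # dims -2 & -1 spatial
kspace = fft2c(image, fft_type="orthogonal", fft_normalization="ortho")
recon = ifft2c(kspace, fft_type="orthogonal", fft_normalization="ortho")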
+ +
+
+

mridc.collections.common.parts.patch_utils module

+
+
+

mridc.collections.common.parts.ptl_overrides module

+
+
+class mridc.collections.common.parts.ptl_overrides.MRIDCNativeMixedPrecisionPlugin(init_scale: float = 4294967296, growth_interval: int = 1000)[source]
+

Bases: pytorch_lightning.plugins.precision.native_amp.NativeMixedPrecisionPlugin

+

Native Mixed Precision Plugin for MRIDC.

+
+ +
+
+

mridc.collections.common.parts.rnn_utils module

+
+
+mridc.collections.common.parts.rnn_utils.rnn_weights_init(module, std_init_range=0.02, xavier=True)[source]
+

# TODO: check if this is the correct way to initialize RNN weights +Initialize different weights in a Transformer model.

+
+
Parameters
+
    +
  • module (torch.nn.Module to be initialized) –

  • +
  • std_init_range (standard deviation of normal initializer) –

  • +
  • xavier (if True, the Xavier initializer will be used in Linear layers, as proposed in the AIAYN paper; otherwise a normal initializer will be used, as in the BERT paper) –

  • +
+
+
+
+ +
+
+

mridc.collections.common.parts.utils module

+
+
+mridc.collections.common.parts.utils.check_stacked_complex(data: torch.Tensor) torch.Tensor[source]
+

Check if tensor is stacked complex (real & imag parts stacked along last dim) and convert it to a combined complex +tensor.

+
+
Parameters
+

data (A complex valued tensor, where the size of the final dimension might be 2.) –

+
+
Return type
+

A complex valued tensor.

+
+
+
+ +
+
+mridc.collections.common.parts.utils.coil_combination(data: torch.Tensor, sensitivity_maps: torch.Tensor, method: str = 'SENSE', dim: int = 0) torch.Tensor[source]
+

Coil combination.

+
+
Parameters
+
    +
  • data (The input tensor.) –

  • +
  • sensitivity_maps (The sensitivity maps.) –

  • +
  • method (The coil combination method.) –

  • +
  • dim (The dimensions along which to apply the coil combination transform.) –

  • +
+
+
Return type
+

Coil combined data.

+
+
+
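A minimal sketch with the documented signature; the stacked-real shapes are assumptions:

import torch
from mridc.collections.common.parts.utils import coil_combination

imgs = torch.randn(8, 320, 320, 2)   # per-coil images, dim 0 = coils
smaps = torch.randn(8, 320, 320, 2)  # sensitivity maps
combined = coil_combination(imgs, smaps, method="SENSE", dim=0)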
+ +
+
+mridc.collections.common.parts.utils.complex_abs(data: torch.Tensor) torch.Tensor[source]
+

Compute the absolute value of a complex valued input tensor.

+
+
Parameters
+

data (A complex valued tensor, where the size of the final dimension should be 2.) –

+
+
Return type
+

Absolute value of data.

+
+
+
+ +
+
+mridc.collections.common.parts.utils.complex_abs_sq(data: torch.Tensor) torch.Tensor[source]
+

Compute the squared absolute value of a complex tensor.

+
+
Parameters
+

data (A complex valued tensor, where the size of the final dimension should be 2.) –

+
+
Return type
+

Squared absolute value of data.

+
+
+
+ +
+
+mridc.collections.common.parts.utils.complex_conj(x: torch.Tensor) torch.Tensor[source]
+

Complex conjugate.

+

This applies the complex conjugate assuming that the input array has the +last dimension as the complex dimension.

+
+
Parameters
+

x (A PyTorch tensor with the last dimension of size 2.) –

+
+
Return type
+

A PyTorch tensor with the last dimension of size 2.

+
+
+
+ +
+
+mridc.collections.common.parts.utils.complex_mul(x: torch.Tensor, y: torch.Tensor) torch.Tensor[source]
+

Complex multiplication.

+

This multiplies two complex tensors assuming that they are both stored as +real arrays with the last dimension being the complex dimension.

+
+
Parameters
+
    +
  • x (A PyTorch tensor with the last dimension of size 2.) –

  • +
  • y (A PyTorch tensor with the last dimension of size 2.) –

  • +
+
+
Return type
+

A PyTorch tensor with the last dimension of size 2.

+
+
+
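A minimal sketch of the stacked-real convention (last dimension of size 2 holds the real and imaginary parts):

import torch
from mridc.collections.common.parts.utils import complex_conj, complex_mul

x = torch.randn(320, 320, 2)
y = torch.randn(320, 320, 2)
xy = complex_mul(x, y)                      # element-wise complex product
x_abs_sq = complex_mul(x, complex_conj(x))  # real channel equals |x|^2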
+ +
+
+mridc.collections.common.parts.utils.rss(data: torch.Tensor, dim: int = 0) torch.Tensor[source]
+

Compute the Root Sum of Squares (RSS).

+

RSS is computed assuming that dim is the coil dimension.

+
+
Parameters
+
    +
  • data (The input tensor) –

  • +
  • dim (The dimensions along which to apply the RSS transform) –

  • +
+
+
Return type
+

The RSS value.

+
+
+
+ +
+
+mridc.collections.common.parts.utils.rss_complex(data: torch.Tensor, dim: int = 0) torch.Tensor[source]
+

Compute the Root Sum of Squares (RSS) for complex inputs.

+

RSS is computed assuming that dim is the coil dimension.

+
+
Parameters
+
    +
  • data (The input tensor) –

  • +
  • dim (The dimensions along which to apply the RSS transform) –

  • +
+
+
Return type
+

The RSS value.

+
+
+
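A minimal sketch, assuming dim 0 is the coil dimension as in the docstrings:

import torch
from mridc.collections.common.parts.utils import rss, rss_complex

coil_images = torch.randn(8, 320, 320)       # real-valued per-coil images
combined = rss(coil_images, dim=0)           # shape (320, 320)

coil_images_c = torch.randn(8, 320, 320, 2)  # stacked-real complex input
combined_c = rss_complex(coil_images_c, dim=0)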
+ +
+
+mridc.collections.common.parts.utils.save_reconstructions(reconstructions: Dict[str, numpy.ndarray], out_dir: pathlib.Path)[source]
+

Save reconstruction images.

+

This function writes to h5 files that are appropriate for submission to the +leaderboard.

+
+
Parameters
+
    +
  • reconstructions (A dictionary mapping input filenames to corresponding reconstructions.) –

  • +
  • out_dir (Path to the output directory where the reconstructions should be saved.) –

  • +
+
+
+
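A hedged sketch (the filename key and output directory are hypothetical):

import pathlib

import numpy as np

from mridc.collections.common.parts.utils import save_reconstructions

recons = {"file_brain_001.h5": np.random.rand(16, 320, 320)}  # one volume
save_reconstructions(recons, pathlib.Path("out/reconstructions"))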
+ +
+
+mridc.collections.common.parts.utils.sense(data: torch.Tensor, sensitivity_maps: torch.Tensor, dim: int = 0) torch.Tensor[source]
+

The SENSitivity Encoding (SENSE) transform 1.

+

References

+
+
1
+

Pruessmann KP, Weiger M, Scheidegger MB, Boesiger P. SENSE: Sensitivity encoding for fast MRI. Magn Reson Med 1999; 42:952-962.

+
+
+
+
Parameters
+
    +
  • data (The input tensor) –

  • +
  • sensitivity_maps (The sensitivity maps) –

  • +
  • dim (The coil dimension) –

  • +
+
+
Return type
+

A coil-combined image.

+
+
+
+ +
+
+mridc.collections.common.parts.utils.tensor_to_complex_np(data: torch.Tensor) numpy.ndarray[source]
+

Converts a torch tensor to a numpy array.

+
+
Parameters
+

data (Input torch tensor to be converted to numpy.) –

+
+
Return type
+

Complex Numpy array version of data.

+
+
+
+ +
+
+mridc.collections.common.parts.utils.to_tensor(data: numpy.ndarray) torch.Tensor[source]
+

Converts a numpy array to a torch tensor.

+

For complex arrays, the real and imaginary parts are stacked along the last +dimension.

+
+
Parameters
+

data (Input numpy array to be converted to torch.) –

+
+
Return type
+

Torch tensor version of data.

+
+
+
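A round-trip sketch between the complex numpy and stacked-real torch representations:

import numpy as np
from mridc.collections.common.parts.utils import tensor_to_complex_np, to_tensor

kspace_np = (np.random.randn(32, 32) + 1j * np.random.randn(32, 32)).astype(np.complex64)
kspace_t = to_tensor(kspace_np)             # real/imag stacked: (32, 32, 2)
kspace_rt = tensor_to_complex_np(kspace_t)  # back to a complex numpy array
assert np.allclose(kspace_np, kspace_rt)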
+ +
+
+

Module contents

+
+
+ + +
+
+ +
+
+
+
+ + + + diff --git a/docs/build/html/mridc.collections.html b/docs/build/html/mridc.collections.html new file mode 100644 index 00000000..91a298a2 --- /dev/null +++ b/docs/build/html/mridc.collections.html @@ -0,0 +1,223 @@ + + + + + + + mridc.collections package — mridc v.0.0.1 documentation + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

mridc.collections package

+
+

Subpackages

+
+ +
+
+
+

Module contents

+
+
+ + +
+
+ +
+
+
+
+ + + + diff --git a/docs/build/html/mridc.collections.reconstruction.data.html b/docs/build/html/mridc.collections.reconstruction.data.html new file mode 100644 index 00000000..a8ccbf92 --- /dev/null +++ b/docs/build/html/mridc.collections.reconstruction.data.html @@ -0,0 +1,431 @@ + + + + + + + mridc.collections.reconstruction.data package — mridc v.0.0.1 documentation + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

mridc.collections.reconstruction.data package

+
+

Submodules

+
+
+

mridc.collections.reconstruction.data.mri_data module

+
+
+class mridc.collections.reconstruction.data.mri_data.FastMRICombinedSliceDataset(roots: Sequence[pathlib.Path], challenges: Sequence[str], sense_roots: Optional[Sequence[pathlib.Path]] = None, transforms: Optional[Sequence[Optional[Callable]]] = None, sample_rates: Optional[Sequence[Optional[float]]] = None, volume_sample_rates: Optional[Sequence[Optional[float]]] = None, use_dataset_cache: bool = False, dataset_cache_file: Union[str, pathlib.Path, os.PathLike] = 'dataset_cache.yaml', num_cols: Optional[Tuple[int]] = None)[source]
+

Bases: torch.utils.data.dataset.Dataset

+

A dataset that combines multiple datasets.

+
+ +
+
+class mridc.collections.reconstruction.data.mri_data.FastMRISliceDataset(root: Union[str, pathlib.Path, os.PathLike], challenge: str = 'multicoil', transform: Optional[Callable] = None, sense_root: Optional[Union[str, pathlib.Path, os.PathLike]] = None, use_dataset_cache: bool = False, sample_rate: Optional[float] = None, volume_sample_rate: Optional[float] = None, dataset_cache_file: Union[str, pathlib.Path, os.PathLike] = 'dataset_cache.yaml', num_cols: Optional[Tuple[int]] = None, mask_root: Optional[Union[str, pathlib.Path, os.PathLike]] = None)[source]
+

Bases: torch.utils.data.dataset.Dataset

+

A dataset that loads slices from a single dataset.

+
+ +
+
+mridc.collections.reconstruction.data.mri_data.et_query(root: str, qlist: Sequence[str], namespace: str = 'https://www.ismrm.org/ISMRMRD') str[source]
+

Query an XML element for a list of attributes.

+
+
Parameters
+
    +
  • root (The root element of the XML tree.) –

  • +
  • qlist (A list of strings, each of which is an attribute name.) –

  • +
  • namespace (The namespace of the XML tree.) –

  • +
+
+
Return type
+

A string containing the value of the last attribute in the list.

+
+
+
+ +
+
+

mridc.collections.reconstruction.data.subsample module

+
+
+class mridc.collections.reconstruction.data.subsample.EquispacedMaskFunc(center_fractions: Sequence[float], accelerations: Sequence[int])[source]
+

Bases: mridc.collections.reconstruction.data.subsample.MaskFunc

+

EquispacedMaskFunc creates a sub-sampling mask of a given shape.

+
+
The mask selects a subset of columns from the input k-space data. If the k-space data has N columns, the mask picks out:
    +
  1. N_low_freqs = (N * center_fraction) columns in the center corresponding to low-frequencies.

  2. The other columns are selected with equal spacing at a proportion that reaches the desired acceleration rate, taking into consideration the number of low frequencies. This ensures that the expected number of columns selected is equal to (N / acceleration).
+
+
+

It is possible to use multiple center_fractions and accelerations, in which case one possible (center_fraction, acceleration) is chosen uniformly at random each time the EquispacedMaskFunc object is called.

+

Note that this function may not give equispaced samples (documented in https://github.com/facebookresearch/fastMRI/issues/54), which will require modifications to standard GRAPPA approaches. Nonetheless, this aspect of the function has been preserved to match the public multicoil data.

+
+
+__call__(shape: Sequence[int], seed: Optional[Union[int, Tuple[int, ...]]] = None, half_scan_percentage: Optional[float] = 0.0, scale: Optional[float] = 0.02) Tuple[torch.Tensor, int][source]
+
+
Parameters
+
    +
  • shape (The shape of the mask to be created. The shape should have at least 3 dimensions. Samples are drawn along the second last dimension.) –

  • +
  • seed (Seed for the random number generator. Setting the seed ensures the same mask is generated each time for the same shape. The random state is reset afterwards.) –

  • +
  • half_scan_percentage (Optional; Defines a fraction of the k-space data that is not sampled.) –

  • +
  • scale (Optional; Defines the scale of the center of the mask.) –

  • +
+
+
Return type
+

A tuple of the mask and the number of columns selected.

+
+
+
+ +
+ +
+
+class mridc.collections.reconstruction.data.subsample.Gaussian1DMaskFunc(center_fractions: Sequence[float], accelerations: Sequence[int])[source]
+

Bases: mridc.collections.reconstruction.data.subsample.MaskFunc

+

Creates a 1D sub-sampling mask of a given shape.

+

For autocalibration purposes, data points near the k-space center will be fully sampled within an ellipse whose half-axes are set to the given scale % of the fully sampled region. The remaining points will be sampled according to a Gaussian distribution.

+

The center fractions here act as Full-Width at Half-Maximum (FWHM) values.

+
+
+__call__(shape: Union[Sequence[int], numpy.ndarray], seed: Optional[Union[int, Tuple[int, ...]]] = None, half_scan_percentage: Optional[float] = 0.0, scale: Optional[float] = 0.02) Tuple[torch.Tensor, int][source]
+
+
Parameters
+
    +
  • shape (The shape of the mask to be created. The shape should have at least 3 dimensions. Samples are drawn along the second last dimension.) –

  • +
  • seed (Seed for the random number generator. Setting the seed ensures the same mask is generated each time for the same shape. The random state is reset afterwards.) –

  • +
  • half_scan_percentage (Optional; Defines a fraction of the k-space data that is not sampled.) –

  • +
  • scale (For autocalibration purposes, data points near the k-space center will be fully sampled within an ellipse of which the half-axes will set to the set scale % of the fully sampled region) –

  • +
+
+
Return type
+

A tuple of the mask and the number of columns selected.

+
+
+
+ +
+
+gaussian_coordinates()[source]
+

Creates Gaussian-sampled k-space coordinates.

+
+ +
+
+gaussian_kernel()[source]
+

Creates a Gaussian sampled k-space kernel.

+
+ +
+
+gaussian_kspace()[source]
+

Creates a Gaussian sampled k-space center.

+
+ +
+ +
+
+class mridc.collections.reconstruction.data.subsample.Gaussian2DMaskFunc(center_fractions: Sequence[float], accelerations: Sequence[int])[source]
+

Bases: mridc.collections.reconstruction.data.subsample.MaskFunc

+

Creates a 2D sub-sampling mask of a given shape.

+

For autocalibration purposes, data points near the k-space center will be fully sampled within an ellipse whose half-axes are set to the given scale % of the fully sampled region. The remaining points will be sampled according to a Gaussian distribution.

+

The center fractions here act as Full-Width at Half-Maximum (FWHM) values.

+
+
+__call__(shape: Union[Sequence[int], numpy.ndarray], seed: Optional[Union[int, Tuple[int, ...]]] = None, half_scan_percentage: Optional[float] = 0.0, scale: Optional[float] = 0.02) Tuple[torch.Tensor, int][source]
+
+
Parameters
+
    +
  • shape (The shape of the mask to be created. The shape should have at least 3 dimensions. Samples are drawn along the second last dimension.) –

  • +
  • seed (Seed for the random number generator. Setting the seed ensures the same mask is generated each time for the same shape. The random state is reset afterwards.) –

  • +
  • half_scan_percentage (Optional; Defines a fraction of the k-space data that is not sampled.) –

  • +
  • scale (For autocalibration purposes, data points near the k-space center will be fully sampled within an ellipse of which the half-axes will set to the set scale % of the fully sampled region) –

  • +
+
+
Return type
+

A tuple of the mask and the number of columns selected.

+
+
+
+ +
+
+gaussian_coordinates()[source]
+

Creates Gaussian-sampled k-space coordinates.

+
+ +
+
+gaussian_kernel()[source]
+

Creates a Gaussian kernel.

+
+ +
+
+gaussian_kspace()[source]
+

Creates a Gaussian sampled k-space center.

+
+ +
+ +
+
+class mridc.collections.reconstruction.data.subsample.MaskFunc(center_fractions: Sequence[float], accelerations: Sequence[int])[source]
+

Bases: object

+

A class that defines a mask function.

+
+
+__call__(shape: Sequence[int], seed: Optional[Union[int, Tuple[int, ...]]] = None, half_scan_percentage: Optional[float] = 0.0, scale: Optional[float] = 0.02) Tuple[torch.Tensor, int][source]
+
+
Parameters
+
    +
  • shape (Shape of the input tensor.) –

  • +
  • seed (Seed for the random number generator.) –

  • +
  • half_scan_percentage (Percentage of the low-frequency columns to be retained.) –

  • +
  • scale (Scale of the mask.) –

  • +
+
+
Return type
+

A tuple of the mask and the number of low-frequency columns retained.

+
+
+
+ +
+
+choose_acceleration()[source]
+

Choose acceleration.

+
+ +
+ +
+
+class mridc.collections.reconstruction.data.subsample.Poisson2DMaskFunc(center_fractions: Sequence[float], accelerations: Sequence[int])[source]
+

Bases: mridc.collections.reconstruction.data.subsample.MaskFunc

+

Creates a 2D sub-sampling mask of a given shape.

+

For autocalibration purposes, data points near the k-space center will be fully sampled within an ellipse whose half-axes are set to the given scale % of the fully sampled region. The remaining points will be sampled according to a (variable density) Poisson distribution.

+

For a given acceleration factor to be accurate, the scale for the fully sampled center should remain at the default 0.02. A predefined list is used to convert the acceleration factor to the appropriate r parameter needed for the variable density calculation. This list has been made to accommodate acceleration factors of 4 up to 21, rounding off to the nearest one available. As such, acceleration factors outside this range cannot be used.

+
+
+__call__(shape: Union[Sequence[int], numpy.ndarray], seed: Optional[Union[int, Tuple[int, ...]]] = None, half_scan_percentage: Optional[float] = 0.0, scale: Optional[float] = 0.02) Tuple[torch.Tensor, int][source]
+
+
Parameters
+
    +
  • shape (The shape of the mask to be created. The shape should have at least 3 dimensions. Samples are drawn along the second last dimension.) –

  • +
  • seed (Seed for the random number generator. Setting the seed ensures the same mask is generated each time for the same shape. The random state is reset afterwards.) –

  • +
  • half_scan_percentage (Optional; Defines a fraction of the k-space data that is not sampled.) –

  • +
  • scale (For autocalibration purposes, data points near the k-space center will be fully sampled within an ellipse of which the half-axes will set to the set scale % of the fully sampled region) –

  • +
+
+
Return type
+

A tuple of the mask and the number of columns selected.

+
+
+
+ +
+
+centered_circle()[source]
+

Creates a boolean centered circle image using the scale as a radius.

+
+ +
+
+poisson_disc2d()[source]
+

Creates a 2D Poisson disc pattern.

+
+ +
+ +
+
+class mridc.collections.reconstruction.data.subsample.RandomMaskFunc(center_fractions: Sequence[float], accelerations: Sequence[int])[source]
+

Bases: mridc.collections.reconstruction.data.subsample.MaskFunc

+

RandomMaskFunc creates a sub-sampling mask of a given shape.

+
+
The mask selects a subset of columns from the input k-space data. If the k-space data has N columns, the mask picks out:
    +
  1. N_low_freqs = (N * center_fraction) columns in the center corresponding to low-frequencies.

  2. The other columns are selected uniformly at random with a probability equal to: prob = (N / acceleration - N_low_freqs) / (N - N_low_freqs). This ensures that the expected number of columns selected is equal to (N / acceleration).
+
+
+

It is possible to use multiple center_fractions and accelerations, in which case one possible (center_fraction, acceleration) is chosen uniformly at random each time the RandomMaskFunc object is called.

+

For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], then there is a 50% probability that 4-fold acceleration with 8% center fraction is selected and a 50% probability that 8-fold acceleration with 4% center fraction is selected.

+
+
+__call__(shape: Sequence[int], seed: Optional[Union[int, Tuple[int, ...]]] = None, half_scan_percentage: Optional[float] = 0.0, scale: Optional[float] = 0.02) Tuple[torch.Tensor, int][source]
+
+
Parameters
+
    +
  • shape (The shape of the mask to be created. The shape should have at least 3 dimensions. Samples are drawn along the second last dimension.) –

  • +
  • seed (Seed for the random number generator. Setting the seed ensures the same mask is generated each time for the same shape. The random state is reset afterwards.) –

  • +
  • half_scan_percentage (Optional; Defines a fraction of the k-space data that is not sampled.) –

  • +
  • scale (Optional; Defines the scale of the center of the mask.) –

  • +
+
+
Return type
+

A tuple of the mask and the number of columns selected.

+
+
+
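A worked sketch of the 4x/8x example above; the k-space-like shape is an assumption (samples are drawn along the second-last dimension):

from mridc.collections.reconstruction.data.subsample import RandomMaskFunc

# 50% chance of 4-fold acceleration with 8% center fraction, 50% of 8-fold with 4%.
mask_func = RandomMaskFunc(center_fractions=[0.08, 0.04], accelerations=[4, 8])
mask, num_cols = mask_func(shape=[1, 320, 320, 2], seed=42)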
+ +
+ +
+
+mridc.collections.reconstruction.data.subsample.create_mask_for_mask_type(mask_type_str: str, center_fractions: Sequence[float], accelerations: Sequence[int]) mridc.collections.reconstruction.data.subsample.MaskFunc[source]
+

Creates a MaskFunc object for the given mask type.

+
+
Parameters
+
    +
  • mask_type_str (The string representation of the mask type.) –

  • +
  • center_fractions (The center fractions for the mask.) –

  • +
  • accelerations (The accelerations for the mask.) –

  • +
+
+
Return type
+

A MaskFunc object.

+
+
+
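A hedged sketch; passing "random" as mask_type_str (to select RandomMaskFunc) is an assumption, since the accepted strings are defined inside the function:

from mridc.collections.reconstruction.data.subsample import create_mask_for_mask_type

mask_func = create_mask_for_mask_type(
    mask_type_str="random", center_fractions=[0.08], accelerations=[4]
)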
+ +
+
+mridc.collections.reconstruction.data.subsample.temp_seed(rng: numpy.random, seed: typing.Optional[typing.Union[int, typing.Tuple[int, ...]]])[source]
+

Temporarily sets the seed of the given random number generator.

+
+
Parameters
+
    +
  • rng (The random number generator.) –

  • +
  • seed (The seed to set.) –

  • +
+
+
Return type
+

A context manager.

+
+
+
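A minimal sketch of the context manager:

import numpy as np
from mridc.collections.reconstruction.data.subsample import temp_seed

with temp_seed(np.random, 42):
    draw = np.random.rand(3)  # reproducible inside the block
# the previous random state is restored on exit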
+ +
+
+

Module contents

+
+
+ + +
+
+ +
+
+
+
+ + + + diff --git a/docs/build/html/mridc.collections.reconstruction.html b/docs/build/html/mridc.collections.reconstruction.html new file mode 100644 index 00000000..d41c78fe --- /dev/null +++ b/docs/build/html/mridc.collections.reconstruction.html @@ -0,0 +1,268 @@ + + + + + + + mridc.collections.reconstruction package — mridc v.0.0.1 documentation + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

mridc.collections.reconstruction package

+
+

Subpackages

+
+ +
+
+
+

Module contents

+
+
+ + +
+
+ +
+
+
+
+ + + + diff --git a/docs/build/html/mridc.collections.reconstruction.metrics.html b/docs/build/html/mridc.collections.reconstruction.metrics.html new file mode 100644 index 00000000..82591f55 --- /dev/null +++ b/docs/build/html/mridc.collections.reconstruction.metrics.html @@ -0,0 +1,229 @@ + + + + + + + mridc.collections.reconstruction.metrics package — mridc v.0.0.1 documentation + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

mridc.collections.reconstruction.metrics package

+
+

Submodules

+
+
+

mridc.collections.reconstruction.metrics.evaluate module

+
+
+class mridc.collections.reconstruction.metrics.evaluate.Metrics(metric_funcs, output_path, method)[source]
+

Bases: object

+

Maintains running statistics for a given collection of metrics.

+
+
+__repr__()[source]
+

Representation of the metrics.

+
+ +
+
+means()[source]
+

Mean of the means of each metric.

+
+ +
+
+push(target, recons)[source]
+

Pushes a new batch of metrics to the running statistics.

+
+
Parameters
+
    +
  • target (target image) –

  • +
  • recons (reconstructed image) –

  • +
+
+
Returns
+

dict

+
+
Return type
+

A dict where the keys are metric names and the values are the running metric values.

+
+
+
+ +
+
+stddevs()[source]
+

Standard deviation of the means of each metric.

+
+ +
+ +
+
+mridc.collections.reconstruction.metrics.evaluate.evaluate(arguments, reconstruction_key, mask_background, output_path, method, acc, no_params, slice_start, slice_end)[source]
+

Evaluates the reconstructions.

+
+
Parameters
+
    +
  • arguments (The CLI arguments.) –

  • +
  • reconstruction_key (The key of the reconstruction to evaluate.) –

  • +
  • mask_background (The background mask.) –

  • +
  • output_path (The output path.) –

  • +
  • method (The reconstruction method.) –

  • +
  • acc (The acceleration factor.) –

  • +
  • no_params (The number of parameters.) –

  • +
  • slice_start (The start slice. (optional)) –

  • +
  • slice_end (The end slice. (optional)) –

  • +
+
+
Returns
+

dict

+
+
Return type
+

A dict where the keys are metric names and the values are the mean of the metric.

+
+
+
+ +
+
+mridc.collections.reconstruction.metrics.evaluate.mse(gt: numpy.ndarray, pred: numpy.ndarray) float[source]
+

Compute Mean Squared Error (MSE)

+
+ +
+
+mridc.collections.reconstruction.metrics.evaluate.nmse(gt: numpy.ndarray, pred: numpy.ndarray) float[source]
+

Compute Normalized Mean Squared Error (NMSE)

+
+ +
+
+mridc.collections.reconstruction.metrics.evaluate.psnr(gt: numpy.ndarray, pred: numpy.ndarray, maxval: Optional[numpy.ndarray] = None) float[source]
+

Compute Peak Signal to Noise Ratio metric (PSNR)

+
+ +
+
+mridc.collections.reconstruction.metrics.evaluate.ssim(gt: numpy.ndarray, pred: numpy.ndarray, maxval: Optional[numpy.ndarray] = None) float[source]
+

Compute Structural Similarity Index Metric (SSIM)

+
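A hedged sketch of the per-volume metrics; the 3D (slices, height, width) layout and the metric_funcs mapping are assumptions:

import numpy as np
from mridc.collections.reconstruction.metrics.evaluate import Metrics, mse, nmse, psnr, ssim

gt = np.random.rand(4, 320, 320)
pred = gt + 0.01 * np.random.rand(4, 320, 320)
print(mse(gt, pred), nmse(gt, pred), psnr(gt, pred), ssim(gt, pred))

# Running statistics across volumes; output_path and method are hypothetical.
stats = Metrics({"MSE": mse, "PSNR": psnr}, output_path="out/", method="CIRIM")
stats.push(gt, pred)
print(stats.means(), stats.stddevs())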
+ +
+
+

Module contents

+
+
+ + +
+
+ +
+
+
+
+ + + + diff --git a/docs/build/html/mridc.collections.reconstruction.models.cascadenet.html b/docs/build/html/mridc.collections.reconstruction.models.cascadenet.html new file mode 100644 index 00000000..2f9689c1 --- /dev/null +++ b/docs/build/html/mridc.collections.reconstruction.models.cascadenet.html @@ -0,0 +1,210 @@ + + + + + + + mridc.collections.reconstruction.models.cascadenet package — mridc v.0.0.1 documentation + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

mridc.collections.reconstruction.models.cascadenet package

+
+

Submodules

+
+
+

mridc.collections.reconstruction.models.cascadenet.ccnn_block module

+
+
+class mridc.collections.reconstruction.models.cascadenet.ccnn_block.CascadeNetBlock(model: torch.nn.modules.module.Module, fft_type: str = 'orthogonal', no_dc: bool = False)[source]
+

Bases: torch.nn.modules.module.Module

+

Model block for CascadeNet & Convolution Recurrent Neural Network.

+

This model applies a combination of soft data consistency with the input model as a regularizer. +A series of these blocks can be stacked to form the full variational network.

+
+
+forward(pred: torch.Tensor, ref_kspace: torch.Tensor, sens_maps: torch.Tensor, mask: torch.Tensor) torch.Tensor[source]
+

Forward pass of the model block.

+
+
Parameters
+
    +
  • pred (Predicted k-space data.) – torch.Tensor, shape [batch_size, n_coils, height, width, 2]

  • +
  • ref_kspace (Reference k-space data.) – torch.Tensor, shape [batch_size, n_coils, height, width, 2]

  • +
  • sens_maps (Sensitivity maps.) – torch.Tensor, shape [batch_size, n_coils, height, width, 2]

  • +
  • mask (Mask to apply to the data.) – torch.Tensor, shape [batch_size, 1, height, width, 1]

  • +
+
+
Returns
+

torch.Tensor, shape [batch_size, height, width, 2]

+
+
Return type
+

Reconstructed image.

+
+
+
+ +
+
+sens_expand(x: torch.Tensor, sens_maps: torch.Tensor) torch.Tensor[source]
+

Expand the sensitivity maps to the same size as the input.

+
+
Parameters
+
    +
  • x (Input data.) – torch.Tensor, shape [batch_size, n_coils, height, width, 2]

  • +
  • sens_maps (Sensitivity maps.) – torch.Tensor, shape [batch_size, n_coils, height, width, 2]

  • +
+
+
Returns
+

torch.Tensor, shape [batch_size, n_coils, height, width, 2]

+
+
Return type
+

SENSE reconstruction expanded to the same size as the input.

+
+
+
+ +
+
+sens_reduce(x: torch.Tensor, sens_maps: torch.Tensor) torch.Tensor[source]
+

Reduce the sensitivity maps to the same size as the input.

+
+
Parameters
+
    +
  • x (Input data.) – torch.Tensor, shape [batch_size, n_coils, height, width, 2]

  • +
  • sens_maps (Sensitivity maps.) – torch.Tensor, shape [batch_size, n_coils, height, width, 2]

  • +
+
+
Returns
+

torch.Tensor, shape [batch_size, height, width, 2]

+
+
Return type
+

SENSE reconstruction.

+
+
+
+ +
+
+training: bool
+
+ +
+ +
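A hedged single-cascade sketch using the shapes documented above; pairing CascadeNetBlock with the Conv2d regularizer, and the channel sizes, are assumptions:

import torch
from mridc.collections.reconstruction.models.cascadenet.ccnn_block import CascadeNetBlock
from mridc.collections.reconstruction.models.conv.conv2d import Conv2d

# One cascade: a small CNN regularizer plus soft data consistency.
block = CascadeNetBlock(
    model=Conv2d(in_channels=2, out_channels=2, hidden_channels=32),
    fft_type="orthogonal",
    no_dc=False,
)
pred = torch.randn(1, 8, 320, 320, 2)        # predicted k-space
ref_kspace = torch.randn(1, 8, 320, 320, 2)  # reference (measured) k-space
sens_maps = torch.randn(1, 8, 320, 320, 2)   # coil sensitivity maps
mask = torch.ones(1, 1, 320, 320, 1)         # sampling mask
image = block(pred, ref_kspace, sens_maps, mask)  # [1, 320, 320, 2]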
+
+

Module contents

diff --git a/docs/build/html/mridc.collections.reconstruction.models.conv.html b/docs/build/html/mridc.collections.reconstruction.models.conv.html
new file mode 100644
index 00000000..4b131e3b
--- /dev/null
+++ b/docs/build/html/mridc.collections.reconstruction.models.conv.html
@@ -0,0 +1,201 @@

mridc.collections.reconstruction.models.conv package


Submodules


mridc.collections.reconstruction.models.conv.conv2d module

class mridc.collections.reconstruction.models.conv.conv2d.Conv2d(in_channels, out_channels, hidden_channels, n_convs=3, activation=PReLU(num_parameters=1), batchnorm=False)[source]

Bases: torch.nn.modules.module.Module

Implementation of a simple cascade of 2D convolutions.
If batchnorm is set to True, a batch normalization layer is applied after each convolution.

forward(x)[source]

Performs the forward pass of Conv2d.

Parameters
  x (Input tensor.) –

Return type
  Convoluted output.

training: bool
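A minimal sketch of such a cascade (a sketch under assumptions, not the package's exact module; the placement of the final activation may differ):

    import torch.nn as nn

    def conv_cascade(in_ch, out_ch, hidden_ch, n_convs=3, batchnorm=False):
        # First conv maps in_ch -> hidden_ch, the last maps hidden_ch -> out_ch.
        layers, ch_in = [], in_ch
        for i in range(n_convs):
            ch_out = out_ch if i == n_convs - 1 else hidden_ch
            layers.append(nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1))
            if batchnorm:
                layers.append(nn.BatchNorm2d(ch_out))
            layers.append(nn.PReLU())
            ch_in = ch_out
        return nn.Sequential(*layers)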

mridc.collections.reconstruction.models.conv.gruconv2d module

class mridc.collections.reconstruction.models.conv.gruconv2d.GRUConv2d(in_channels, out_channels, hidden_channels, n_convs=3, activation='ReLU', batchnorm=False)[source]

Bases: torch.nn.modules.module.Module

Implementation of a GRU followed by a number of 2D convolutions, inspired by [1].

References

[1] C. Qin, J. Schlemper, J. Caballero, A. N. Price, J. V. Hajnal and D. Rueckert, "Convolutional Recurrent Neural Networks for Dynamic MR Image Reconstruction," in IEEE Transactions on Medical Imaging, vol. 38, no. 1, pp. 280-290, Jan. 2019, doi: 10.1109/TMI.2018.2863670.

forward(x, hx: Optional[torch.Tensor] = None)[source]

Performs the forward pass of GRUConv2d.

Parameters
  • x (Input tensor.) – torch.Tensor
  • hx (Initial hidden state.) – torch.Tensor

Return type
  Convoluted output.

training: bool

Module contents

diff --git a/docs/build/html/mridc.collections.reconstruction.models.convrecnet.html b/docs/build/html/mridc.collections.reconstruction.models.convrecnet.html
new file mode 100644
index 00000000..6903553f
--- /dev/null
+++ b/docs/build/html/mridc.collections.reconstruction.models.convrecnet.html
@@ -0,0 +1,225 @@

mridc.collections.reconstruction.models.convrecnet package


Submodules


mridc.collections.reconstruction.models.convrecnet.crnn_block module

class mridc.collections.reconstruction.models.convrecnet.crnn_block.DataConsistencyLayer[source]

Bases: torch.nn.modules.module.Module

Data consistency layer for the CRNN.
This layer is used to keep the output of the CRNN consistent with the measured (input) k-space samples.

forward(pred_kspace, ref_kspace, mask)[source]

Forward pass of the data consistency layer.

training: bool
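A minimal sketch of the usual hard data-consistency rule (illustrative; the package's exact layer may use a soft, learnable weighting instead):

    import torch

    def data_consistency(pred_kspace: torch.Tensor,
                         ref_kspace: torch.Tensor,
                         mask: torch.Tensor) -> torch.Tensor:
        # Keep the measured k-space samples where the mask is 1,
        # and the network's prediction everywhere else.
        return torch.where(mask.bool(), ref_kspace, pred_kspace)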
class mridc.collections.reconstruction.models.convrecnet.crnn_block.RecurrentConvolutionalNetBlock(model: torch.nn.modules.module.Module, num_iterations: int = 10, fft_type: str = 'orthogonal', no_dc: bool = False)[source]

Bases: torch.nn.modules.module.Module

Model block for the Recurrent Convolutional Neural Network, inspired by [1].

References

[1] C. Qin, J. Schlemper, J. Caballero, A. N. Price, J. V. Hajnal and D. Rueckert, "Convolutional Recurrent Neural Networks for Dynamic MR Image Reconstruction," in IEEE Transactions on Medical Imaging, vol. 38, no. 1, pp. 280-290, Jan. 2019, doi: 10.1109/TMI.2018.2863670.
forward(ref_kspace: torch.Tensor, sens_maps: torch.Tensor, mask: torch.Tensor) → List[Union[torch.Tensor, Any]][source]

Forward pass of the model.

Parameters
  • ref_kspace (Reference k-space data.) –
  • sens_maps (Sensitivity maps.) –
  • mask (Mask to apply to the data.) –

Return type
  Reconstructed image.
sens_expand(x: torch.Tensor, sens_maps: torch.Tensor) → torch.Tensor[source]

Expand the sensitivity maps to the same size as the input.

Parameters
  • x (Input data.) –
  • sens_maps (Sensitivity maps.) –

Return type
  SENSE reconstruction expanded to the same size as the input.
sens_reduce(x: torch.Tensor, sens_maps: torch.Tensor) → torch.Tensor[source]

Reduce the sensitivity maps to the same size as the input.

Parameters
  • x (Input data.) –
  • sens_maps (Sensitivity maps.) –

Return type
  SENSE reconstruction reduced to the same size as the input.

training: bool

Module contents

diff --git a/docs/build/html/mridc.collections.reconstruction.models.crossdomain.html b/docs/build/html/mridc.collections.reconstruction.models.crossdomain.html
new file mode 100644
index 00000000..9dbb72b9
--- /dev/null
+++ b/docs/build/html/mridc.collections.reconstruction.models.crossdomain.html
@@ -0,0 +1,213 @@

mridc.collections.reconstruction.models.crossdomain package


Submodules


mridc.collections.reconstruction.models.crossdomain.crossdomain module

class mridc.collections.reconstruction.models.crossdomain.crossdomain.CrossDomainNetwork(image_model_list: torch.nn.modules.module.Module, kspace_model_list: Optional[torch.nn.modules.module.Module] = None, domain_sequence: str = 'KIKI', image_buffer_size: int = 1, kspace_buffer_size: int = 1, normalize_image: bool = False, fft_type: str = 'orthogonal', **kwargs)[source]

Bases: torch.nn.modules.module.Module

This performs optimisation in both the k-space ("K") and image ("I") domains, according to domain_sequence.

forward(masked_kspace: torch.Tensor, sensitivity_map: torch.Tensor, sampling_mask: torch.Tensor) → torch.Tensor[source]

Computes the forward pass of CrossDomainNetwork.

Parameters
  • masked_kspace (Subsampled k-space data.) – torch.Tensor, shape [batch_size, n_coils, height, width, 2]
  • sensitivity_map (Sensitivity map.) – torch.Tensor, shape [batch_size, n_coils, height, width, 2]
  • sampling_mask (Sampling mask.) – torch.Tensor, shape [batch_size, 1, height, width, 1]

Returns
  torch.Tensor, shape [batch_size, height, width, 2]

Return type
  Output image.
image_correction(block_idx, image_buffer, kspace_buffer, sampling_mask, sensitivity_map)[source]

Performs image correction.

kspace_correction(block_idx, image_buffer, kspace_buffer, sampling_mask, sensitivity_map, masked_kspace)[source]

Performs k-space correction.

training: bool
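A minimal sketch of how a domain sequence such as "KIKI" can drive alternating corrections (illustrative only; buffer handling and domain transforms are omitted):

    # Alternate k-space ("K") and image ("I") corrections per the sequence.
    def run_domain_sequence(kspace, image, domain_sequence, k_correct, i_correct):
        for block_idx, domain in enumerate(domain_sequence):
            if domain == "K":
                kspace = k_correct(block_idx, kspace)
            else:
                image = i_correct(block_idx, image)
        return image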

mridc.collections.reconstruction.models.crossdomain.multicoil module

class mridc.collections.reconstruction.models.crossdomain.multicoil.MultiCoil(model: torch.nn.modules.module.Module, coil_dim: int = 1, coil_to_batch: bool = False)[source]

Bases: torch.nn.modules.module.Module

This makes the forward pass of multi-coil data of shape (N, N_coils, H, W, C) to a model.
If coil_to_batch is set to True, the coil dimension is moved to the batch dimension. Otherwise, the model is applied to each coil's data individually.

forward(x: torch.Tensor) → torch.Tensor[source]

Performs the forward pass of MultiCoil.

Parameters
  x (Multi-coil input.) – torch.Tensor, shape (N, N_coils, H, W, C)

Returns
  torch.Tensor, shape (N, N_coils, H, W, C)

Return type
  Multi-coil output.

training: bool
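A minimal sketch of the coil_to_batch idea (illustrative, assuming the wrapped model expects channel-first 2D input; not the package's exact code):

    import torch

    def apply_per_coil(model, x: torch.Tensor, coil_to_batch: bool = True) -> torch.Tensor:
        # x: (N, N_coils, H, W, C), with C the channel (e.g. real/imag) dimension.
        n, c, h, w, ch = x.shape
        if coil_to_batch:
            # Fold coils into the batch, run the model once, then unfold.
            y = model(x.reshape(n * c, h, w, ch).permute(0, 3, 1, 2))
            return y.permute(0, 2, 3, 1).reshape(n, c, h, w, ch)
        # Otherwise loop over coils and stack the outputs.
        outs = [model(x[:, i].permute(0, 3, 1, 2)).permute(0, 2, 3, 1) for i in range(c)]
        return torch.stack(outs, dim=1)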

Module contents

diff --git a/docs/build/html/mridc.collections.reconstruction.models.didn.html b/docs/build/html/mridc.collections.reconstruction.models.didn.html
new file mode 100644
index 00000000..1b6e66c9
--- /dev/null
+++ b/docs/build/html/mridc.collections.reconstruction.models.didn.html
@@ -0,0 +1,281 @@

mridc.collections.reconstruction.models.didn package


Submodules


mridc.collections.reconstruction.models.didn.didn module

class mridc.collections.reconstruction.models.didn.didn.DIDN(in_channels: int, out_channels: int, hidden_channels: int = 128, num_dubs: int = 6, num_convs_recon: int = 9, skip_connection: bool = False)[source]

Bases: torch.nn.modules.module.Module

Deep Iterative Down-up convolutional Neural network (DIDN) implementation, as in Yu, Songhyun, et al.

References

Yu, Songhyun, et al. "Deep Iterative Down-Up CNN for Image Denoising." 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW), 2019.
static crop_to_shape(x, shape)[source]

Crops x to the specified shape.

Parameters
  • x (Input tensor with shape (*, H, W).) –
  • shape (Crop shape corresponding to H, W.) –

Return type
  Cropped tensor.
forward(x, channel_dim=1)[source]

Takes as input a torch.Tensor x and computes DIDN(x).

Parameters
  • x (Input tensor.) –
  • channel_dim (Channel dimension. Default: 1.) –

Return type
  DIDN output tensor.

training: bool
class mridc.collections.reconstruction.models.didn.didn.DUB(in_channels, out_channels)[source]

Bases: torch.nn.modules.module.Module

Down-up block (DUB) for the DIDN model, as implemented in Yu, Songhyun, et al. (see the reference above).
static crop_to_shape(x, shape)[source]

Crops x to the specified shape.

Parameters
  • x (Input tensor with shape (*, H, W).) –
  • shape (Crop shape corresponding to H, W.) –

Return type
  Cropped tensor.
forward(x)[source]

Parameters
  x (Input tensor.) –

Return type
  DUB output.
static pad(x)[source]

Pads the input in the height and width dimensions if they are odd.

Parameters
  x (Input to pad.) –

Return type
  Padded tensor.

training: bool
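A minimal sketch of padding odd spatial dimensions to even ones (illustrative):

    import torch.nn.functional as F

    def pad_to_even(x):
        # x: (..., H, W); pad by one pixel on the right/bottom where the size is odd.
        h, w = x.shape[-2:]
        return F.pad(x, (0, w % 2, 0, h % 2))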
class mridc.collections.reconstruction.models.didn.didn.ReconBlock(in_channels, num_convs)[source]

Bases: torch.nn.modules.module.Module

Reconstruction Block of the DIDN model, as implemented in Yu, Songhyun, et al. (see the reference above).
forward(input_data)[source]

Computes num_convs convolutions followed by PReLU activation on input_data.

Parameters
  input_data (Input tensor.) –

training: bool
class mridc.collections.reconstruction.models.didn.didn.Subpixel(in_channels, out_channels, upscale_factor, kernel_size, padding=0)[source]

Bases: torch.nn.modules.module.Module

Subpixel convolution layer for up-scaling of low-resolution features at super-resolution, as implemented in Yu, Songhyun, et al. (see the reference above).
forward(x)[source]

Computes the subpixel convolution on the input torch.Tensor x.

training: bool
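A subpixel convolution is usually a convolution that produces upscale_factor² feature maps per output channel, followed by a pixel shuffle; a minimal sketch (illustrative, not the package's exact layer):

    import torch.nn as nn

    def subpixel(in_channels, out_channels, upscale_factor, kernel_size, padding=0):
        # Conv produces upscale_factor**2 maps per output channel;
        # PixelShuffle rearranges them into a spatially up-scaled output.
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels * upscale_factor ** 2,
                      kernel_size, padding=padding),
            nn.PixelShuffle(upscale_factor),
        )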

Module contents

diff --git a/docs/build/html/mridc.collections.reconstruction.models.html b/docs/build/html/mridc.collections.reconstruction.models.html
new file mode 100644
index 00000000..9e9a47c1
--- /dev/null
+++ b/docs/build/html/mridc.collections.reconstruction.models.html
@@ -0,0 +1,2097 @@

mridc.collections.reconstruction.models package


Subpackages


Submodules


mridc.collections.reconstruction.models.base module

class mridc.collections.reconstruction.models.base.BaseMRIReconstructionModel(cfg: omegaconf.dictconfig.DictConfig, trainer: Optional[pytorch_lightning.trainer.trainer.Trainer] = None)[source]

Bases: mridc.core.classes.modelPT.ModelPT, abc.ABC

Base class of all MRI reconstruction models.
log_image(name, image)[source]

Logs an image.

Parameters
  • name (Name of the image.) – str
  • image (Image to log.) – torch.Tensor, shape [batch_size, n_x, n_y, 2]
static process_inputs(y, mask, init_pred)[source]

Processes the inputs to the method.

Parameters
  • y (Subsampled k-space data.) – list of torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
  • mask (Sampling mask.) – list of torch.Tensor, shape [1, 1, n_x, n_y, 1]
  • init_pred (Initial prediction.) – list of torch.Tensor, shape [batch_size, n_x, n_y, 2]

Returns
  • y (Subsampled k-space data.) – randomly selected y
  • mask (Sampling mask.) – randomly selected mask
  • init_pred (Initial prediction.) – randomly selected init_pred
  • r (Random index.)
process_loss(target, pred, _loss_fn)[source]

Processes the loss.

Parameters
  • target (Target data.) – torch.Tensor, shape [batch_size, n_x, n_y, 2]
  • pred (Final prediction(s).) – list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or torch.Tensor, shape [batch_size, n_x, n_y, 2]
  • _loss_fn (Loss function.) – torch.nn.Module, default torch.nn.L1Loss()

Returns
  loss – If self.accumulate_loss is True, returns an accumulated result over all intermediate losses.

Return type
  torch.FloatTensor, shape [1]
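A minimal sketch of how such intermediate-loss accumulation typically works (illustrative; not the package's exact code):

    import torch

    def process_loss(target, pred, loss_fn=torch.nn.L1Loss()):
        # pred may be a single estimate or a list of intermediate estimates;
        # in the latter case, average the per-estimate losses.
        if isinstance(pred, list):
            return torch.stack([loss_fn(p, target) for p in pred]).mean()
        return loss_fn(pred, target)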
setup_test_data(test_data_config: Optional[omegaconf.dictconfig.DictConfig])[source]

Sets up the test data.

Parameters
  test_data_config (Test data configuration.) – dict

Returns
  test_data – torch.utils.data.DataLoader

Return type
  Test data.
setup_training_data(train_data_config: Optional[omegaconf.dictconfig.DictConfig])[source]

Sets up the training data.

Parameters
  train_data_config (Training data configuration.) – dict

Returns
  train_data – torch.utils.data.DataLoader

Return type
  Training data.
setup_validation_data(val_data_config: Optional[omegaconf.dictconfig.DictConfig])[source]

Sets up the validation data.

Parameters
  val_data_config (Validation data configuration.) – dict

Returns
  val_data – torch.utils.data.DataLoader

Return type
  Validation data.
test_epoch_end(outputs)[source]

Called at the end of the test epoch to aggregate outputs.

Parameters
  outputs (List of outputs of the test batches.) – list of dicts

Return type
  Saves the reconstructed images to .h5 files.
test_step(batch: Dict[float, torch.Tensor], batch_idx: int) → Tuple[str, int, torch.Tensor][source]

Performs a test step.

Parameters
  • batch (Batch of data.) – Dict[str, torch.Tensor], with keys:
    'y': subsampled k-space, torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    'sensitivity_maps': sensitivity maps, torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    'mask': sampling mask, torch.Tensor, shape [1, 1, n_x, n_y, 1]
    'init_pred': initial prediction (e.g. zero-filled or PICS), torch.Tensor, shape [batch_size, n_x, n_y, 2]
    'target': target data, torch.Tensor, shape [batch_size, n_x, n_y, 2]
    'fname': filename, str, shape [batch_size]
    'slice_idx': slice index, torch.Tensor, shape [batch_size]
    'acc': acceleration factor, torch.Tensor, shape [batch_size]
    'max_value': maximum value of the magnitude image space, torch.Tensor, shape [batch_size]
    'crop_size': crop size, torch.Tensor, shape [n_x, n_y]
  • batch_idx (Batch index.) – int

Returns
  • name (Name of the volume.) – str
  • slice_num (Slice number.) – int
  • pred (Predicted data.) – torch.Tensor, shape [batch_size, n_x, n_y, 2]
training: bool
training_step(batch: Dict[float, torch.Tensor], batch_idx: int) → Dict[str, torch.Tensor][source]

Performs a training step.

Parameters
  • batch (Batch of data.) – Dict[str, torch.Tensor], with the same keys as in test_step.
  • batch_idx (Batch index.) – int

Returns
  Dict[str, torch.Tensor], with keys:
  • 'loss' – torch.Tensor, shape [1]
  • 'log' – dict, shape [1]
validation_epoch_end(outputs)[source]

Called at the end of the validation epoch to aggregate outputs.

Parameters
  outputs (List of outputs of the validation batches.) – list of dicts

Returns
  metrics – dict

Return type
  Dictionary of metrics.
validation_step(batch: Dict[float, torch.Tensor], batch_idx: int) → Dict[source]

Performs a validation step.

Parameters
  • batch (Batch of data.) – Dict[str, torch.Tensor], with the same keys as in test_step.
  • batch_idx (Batch index.) – int

Returns
  Dict[str, torch.Tensor], with keys:
  • 'loss' – torch.Tensor, shape [1]
  • 'log' – dict, shape [1]
class mridc.collections.reconstruction.models.base.BaseSensitivityModel(chans: int = 8, num_pools: int = 4, in_chans: int = 2, out_chans: int = 2, drop_prob: float = 0.0, padding_size: int = 15, mask_type: str = '2D', fft_type: str = 'orthogonal', normalize: bool = True, mask_center: bool = True)[source]

Bases: torch.nn.modules.module.Module, abc.ABC

Model for learning sensitivity estimation from k-space data.
This model applies an IFFT to multichannel k-space data and then a U-Net to the coil images to estimate coil sensitivities.
static batch_chans_to_chan_dim(x: torch.Tensor, batch_size: int) → torch.Tensor[source]

Converts the number of channels in a tensor to the channel dimension.

Parameters
  • x (Tensor to convert.) – torch.Tensor
  • batch_size (Original batch size.) – int

Returns
  torch.Tensor

Return type
  Converted tensor.
static chans_to_batch_dim(x: torch.Tensor) → Tuple[torch.Tensor, int][source]

Converts the number of channels in a tensor to the batch dimension.

Parameters
  x (Tensor to convert.) – torch.Tensor

Returns
  Tuple[torch.Tensor, int]

Return type
  Tuple of the converted tensor and the original last dimension.
static divide_root_sum_of_squares(x: torch.Tensor) → torch.Tensor[source]

Divide the input by the root of the sum of squares of the magnitude of each complex number.

Parameters
  x (Tensor to divide.) – torch.Tensor

Returns
  torch.Tensor

Return type
  RSS output tensor.
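A minimal sketch of RSS normalization for tensors whose last dimension holds the real/imaginary parts (illustrative; the coil dimension is assumed to be 1):

    import torch

    def divide_root_sum_of_squares(x: torch.Tensor, coil_dim: int = 1) -> torch.Tensor:
        # |x|^2 per complex entry, summed over coils, then the square root.
        rss = x.pow(2).sum(dim=-1).sum(dim=coil_dim).sqrt()
        return x / rss.unsqueeze(coil_dim).unsqueeze(-1)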
forward(masked_kspace: torch.Tensor, mask: torch.Tensor, num_low_frequencies: Optional[int] = None) → torch.Tensor[source]

Forward pass of the model.

Parameters
  • masked_kspace (Subsampled k-space data.) – torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
  • mask (Sampling mask.) – torch.Tensor, shape [batch_size, 1, n_x, n_y, 1]
  • num_low_frequencies (Number of low frequencies to keep.) – int

Returns
  torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]

Return type
  Normalized U-Net output tensor.
static get_pad_and_num_low_freqs(mask: torch.Tensor, num_low_frequencies: Optional[int] = None) → Tuple[torch.Tensor, torch.Tensor][source]

Get the padding to apply to the input to make it square and the number of low frequencies to keep.

Parameters
  • mask (Mask to use.) – torch.Tensor
  • num_low_frequencies (Number of low frequencies to keep.) – int

Returns
  Tuple[torch.Tensor, torch.Tensor]

Return type
  Tuple of the padding and the number of low frequencies to keep.

training: bool

mridc.collections.reconstruction.models.ccnn module

class mridc.collections.reconstruction.models.ccnn.CascadeNet(cfg: omegaconf.dictconfig.DictConfig, trainer: Optional[pytorch_lightning.trainer.trainer.Trainer] = None)[source]

Bases: mridc.collections.reconstruction.models.base.BaseMRIReconstructionModel, abc.ABC

Implementation of the Deep Cascade of Convolutional Neural Networks, as presented in Schlemper, J., Caballero, J., Hajnal, J. V., Price, A., & Rueckert, D.

References

Schlemper, J., Caballero, J., Hajnal, J. V., Price, A., & Rueckert, D., A Deep Cascade of Convolutional Neural Networks for MR Image Reconstruction. Information Processing in Medical Imaging (IPMI), 2017. Available at: https://arxiv.org/pdf/1703.00555.pdf

forward(y: torch.Tensor, sensitivity_maps: torch.Tensor, mask: torch.Tensor, init_pred: torch.Tensor, target: torch.Tensor) → torch.Tensor[source]

Forward pass of the network.

Parameters
  • y (Subsampled k-space data.) – torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
  • sensitivity_maps (Coil sensitivity maps.) – torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
  • mask (Sampling mask.) – torch.Tensor, shape [1, 1, n_x, n_y, 1]
  • init_pred (Initial prediction.) – torch.Tensor, shape [batch_size, n_x, n_y, 2]
  • target (Target data to compute the loss.) – torch.Tensor, shape [batch_size, n_x, n_y, 2]

Returns
  pred – If self.accumulate_loss is True, returns a list of all intermediate estimates. If False, returns the final estimate.

Return type
  list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or torch.Tensor, shape [batch_size, n_x, n_y, 2]

Attributes: allow_zero_length_dataloader_with_multiple_devices: bool, mse_vals: Dict, nmse_vals: Dict, precision: int, prepare_data_per_node: bool, psnr_vals: Dict, ssim_vals: Dict, trainer: Optional['pl.Trainer'], training: bool

mridc.collections.reconstruction.models.cirim module

class mridc.collections.reconstruction.models.cirim.CIRIM(cfg: omegaconf.dictconfig.DictConfig, trainer: Optional[pytorch_lightning.trainer.trainer.Trainer] = None)[source]

Bases: mridc.collections.reconstruction.models.base.BaseMRIReconstructionModel, abc.ABC

Implementation of the Cascades of Independently Recurrent Inference Machines, as presented in Karkalousos, D. et al.

References

Karkalousos, D. et al. (2021) 'Assessment of Data Consistency through Cascades of Independently Recurrent Inference Machines for fast and robust accelerated MRI reconstruction'. Available at: https://arxiv.org/abs/2111.15498v1

forward(y: torch.Tensor, sensitivity_maps: torch.Tensor, mask: torch.Tensor, init_pred: torch.Tensor, target: torch.Tensor) → Union[Generator, torch.Tensor][source]

Forward pass of the network. Parameters and return type are as for CascadeNet.forward.

process_intermediate_pred(pred, sensitivity_maps, target, do_coil_combination=False)[source]

Process the intermediate prediction.

Parameters
  • pred (Intermediate prediction.) – torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
  • sensitivity_maps (Coil sensitivity maps.) – torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
  • target (Target data to crop to size.) – torch.Tensor, shape [batch_size, n_x, n_y, 2]
  • do_coil_combination (Whether to do coil combination.) – bool, default False

Returns
  pred – Processed prediction.

Return type
  torch.Tensor, shape [batch_size, n_x, n_y, 2]

process_loss(target, pred, _loss_fn)[source]

Process the loss. Parameters and return type are as for BaseMRIReconstructionModel.process_loss.

Attributes: as for CascadeNet.

mridc.collections.reconstruction.models.crnn module

class mridc.collections.reconstruction.models.crnn.CRNNet(cfg: omegaconf.dictconfig.DictConfig, trainer: Optional[pytorch_lightning.trainer.trainer.Trainer] = None)[source]

Bases: mridc.collections.reconstruction.models.base.BaseMRIReconstructionModel, abc.ABC

Implementation of the Convolutional Recurrent Neural Network, inspired by C. Qin, J. Schlemper, J. Caballero, A. N. Price, J. V. Hajnal and D. Rueckert.

References

C. Qin, J. Schlemper, J. Caballero, A. N. Price, J. V. Hajnal and D. Rueckert, "Convolutional Recurrent Neural Networks for Dynamic MR Image Reconstruction," in IEEE Transactions on Medical Imaging, vol. 38, no. 1, pp. 280-290, Jan. 2019, doi: 10.1109/TMI.2018.2863670.

forward(y: torch.Tensor, sensitivity_maps: torch.Tensor, mask: torch.Tensor, init_pred: torch.Tensor, target: torch.Tensor) → Union[Generator, torch.Tensor][source]

Forward pass of the network. Parameters and return type are as for CascadeNet.forward.

process_intermediate_pred(pred, sensitivity_maps, target)[source]

Process the intermediate prediction. Parameters and return type are as for CIRIM.process_intermediate_pred (without do_coil_combination).

process_loss(target, pred, _loss_fn)[source]

Process the loss. Parameters and return type are as for BaseMRIReconstructionModel.process_loss.

Attributes: as for CascadeNet.

mridc.collections.reconstruction.models.dunet module

class mridc.collections.reconstruction.models.dunet.DUNet(cfg: omegaconf.dictconfig.DictConfig, trainer: Optional[pytorch_lightning.trainer.trainer.Trainer] = None)[source]

Bases: mridc.collections.reconstruction.models.base.BaseMRIReconstructionModel, abc.ABC

Implementation of the Down-Up NET, inspired by Hammernik, K, Schlemper, J, Qin, C, et al.

References

Hammernik, K, Schlemper, J, Qin, C, et al. Systematic evaluation of iterative deep neural networks for fast parallel MRI reconstruction with sensitivity-weighted coil combination. Magn Reson Med. 2021; 86: 1859–1872. https://doi.org/10.1002/mrm.28827

forward(y: torch.Tensor, sensitivity_maps: torch.Tensor, mask: torch.Tensor, init_pred: torch.Tensor, target: torch.Tensor) → torch.Tensor[source]

Forward pass of the network. Parameters and return type are as for CascadeNet.forward.

Attributes: as for CascadeNet.

mridc.collections.reconstruction.models.jointicnet module

class mridc.collections.reconstruction.models.jointicnet.JointICNet(cfg: omegaconf.dictconfig.DictConfig, trainer: Optional[pytorch_lightning.trainer.trainer.Trainer] = None)[source]

Bases: mridc.collections.reconstruction.models.base.BaseMRIReconstructionModel, abc.ABC

Implementation of the Joint Deep Model-Based MR Image and Coil Sensitivity Reconstruction Network (Joint-ICNet), as presented in Jun, Yohan, et al.

References

Jun, Yohan, et al. "Joint Deep Model-Based MR Image and Coil Sensitivity Reconstruction Network (Joint-ICNet) for Fast MRI." 2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), IEEE, 2021, pp. 5266–75. DOI.org (Crossref), https://doi.org/10.1109/CVPR46437.2021.00523.

forward(y: torch.Tensor, sensitivity_maps: torch.Tensor, mask: torch.Tensor, init_pred: torch.Tensor, target: torch.Tensor) → torch.Tensor[source]

Forward pass of the network. Parameters and return type are as for CascadeNet.forward.

update_C(idx, DC_sens, sensitivity_maps, image, y, mask) → torch.Tensor[source]

Update the coil sensitivity maps.

\[\begin{aligned}
C &= (1 - 2 \lambda_{k}^{C} \nu_{k}) C_{k}\\
C &= 2 \lambda_{k}^{C} \nu_{k} D_{C}(F^{-1}(b))\\
A(x_{k}) &= M F (C x_{k})\\
C &= 2 \nu_{k} F^{-1}\big(M^{T} (M F (C x_{k}) - b)\big) x_{k}^{*}
\end{aligned}\]

Parameters
  • idx (int) – The current iteration index.
  • DC_sens (torch.Tensor [batch_size, num_coils, num_sens_maps, num_rows, num_cols]) – The initial coil sensitivity maps.
  • sensitivity_maps (torch.Tensor [batch_size, num_coils, num_sens_maps, num_rows, num_cols]) – The coil sensitivity maps.
  • image (torch.Tensor [batch_size, num_coils, num_rows, num_cols]) – The predicted image.
  • y (torch.Tensor [batch_size, num_coils, num_rows, num_cols]) – The subsampled k-space data.
  • mask (torch.Tensor [batch_size, 1, num_rows, num_cols]) – The subsampling mask.

Returns
  sensitivity_maps – The updated coil sensitivity maps.

Return type
  torch.Tensor [batch_size, num_coils, num_sens_maps, num_rows, num_cols]

update_X(idx, image, sensitivity_maps, y, mask)[source]

Update the image.

\[\begin{aligned}
x_{k} &= (1 - 2 \lambda_{k_{I}} \mu_{k} - 2 \lambda_{k_{F}} \mu_{k}) x_{k}\\
x_{k} &= 2 \mu_{k} \big(\lambda_{k_{I}} D_{I}(x_{k}) + \lambda_{k_{F}} F^{-1}(D_{F}(f))\big)\\
A(x_{k}) - b &= M F (C x_{k}) - b\\
x_{k} &= 2 \mu_{k} A^{*} \big(A(x_{k}) - b\big)
\end{aligned}\]

Parameters
  • idx (int) – The current iteration index.
  • image (torch.Tensor [batch_size, num_coils, num_rows, num_cols]) – The predicted image.
  • sensitivity_maps (torch.Tensor [batch_size, num_coils, num_sens_maps, num_rows, num_cols]) – The coil sensitivity maps.
  • y (torch.Tensor [batch_size, num_coils, num_rows, num_cols]) – The subsampled k-space data.
  • mask (torch.Tensor [batch_size, 1, num_rows, num_cols]) – The subsampling mask.

Returns
  image – The updated image.

Return type
  torch.Tensor [batch_size, num_coils, num_rows, num_cols]

Attributes: as for CascadeNet.

mridc.collections.reconstruction.models.kikinet module

class mridc.collections.reconstruction.models.kikinet.KIKINet(cfg: omegaconf.dictconfig.DictConfig, trainer: Optional[pytorch_lightning.trainer.trainer.Trainer] = None)[source]

Bases: mridc.collections.reconstruction.models.base.BaseMRIReconstructionModel, abc.ABC

Based on the KIKINet implementation, modified to work with multi-coil k-space data, as presented in Eo, Taejoon, et al.

References

Eo, Taejoon, et al. "KIKI-Net: Cross-Domain Convolutional Neural Networks for Reconstructing Undersampled Magnetic Resonance Images." Magnetic Resonance in Medicine, vol. 80, no. 5, Nov. 2018, pp. 2188–201. PubMed, https://doi.org/10.1002/mrm.27201.

forward(y: torch.Tensor, sensitivity_maps: torch.Tensor, mask: torch.Tensor, init_pred: torch.Tensor, target: torch.Tensor) → torch.Tensor[source]

Forward pass of the network. Parameters and return type are as for CascadeNet.forward.

Attributes: as for CascadeNet.

mridc.collections.reconstruction.models.lpd module

class mridc.collections.reconstruction.models.lpd.LPDNet(cfg: omegaconf.dictconfig.DictConfig, trainer: Optional[pytorch_lightning.trainer.trainer.Trainer] = None)[source]

Bases: mridc.collections.reconstruction.models.base.BaseMRIReconstructionModel, abc.ABC

Implementation of the Learned Primal Dual network, inspired by Adler, Jonas, and Ozan Öktem.

References

Adler, Jonas, and Ozan Öktem. "Learned Primal-Dual Reconstruction." IEEE Transactions on Medical Imaging, vol. 37, no. 6, June 2018, pp. 1322–32. arXiv.org, https://doi.org/10.1109/TMI.2018.2799231.

forward(y: torch.Tensor, sensitivity_maps: torch.Tensor, mask: torch.Tensor, init_pred: torch.Tensor, target: torch.Tensor) → torch.Tensor[source]

Forward pass of the network. Parameters and return type are as for CascadeNet.forward.

Attributes: as for CascadeNet.

mridc.collections.reconstruction.models.multidomainnet module

class mridc.collections.reconstruction.models.multidomainnet.MultiDomainNet(cfg: omegaconf.dictconfig.DictConfig, trainer: Optional[pytorch_lightning.trainer.trainer.Trainer] = None)[source]

Bases: mridc.collections.reconstruction.models.base.BaseMRIReconstructionModel, abc.ABC

Feature-level multi-domain module. Inspired by the AIRS Medical submission to the FastMRI 2020 challenge.

forward(y: torch.Tensor, sensitivity_maps: torch.Tensor, mask: torch.Tensor, init_pred: torch.Tensor, target: torch.Tensor) → torch.Tensor[source]

Forward pass of the network. Parameters and return type are as for CascadeNet.forward.

Attributes: as for CascadeNet.

mridc.collections.reconstruction.models.pics module

class mridc.collections.reconstruction.models.pics.PICS(cfg: omegaconf.dictconfig.DictConfig, trainer: Optional[pytorch_lightning.trainer.trainer.Trainer] = None)[source]

Bases: mridc.collections.reconstruction.models.base.BaseMRIReconstructionModel, abc.ABC

Parallel-Imaging Compressed Sensing (PICS) reconstruction using BART, by Uecker, M. et al.

References

Uecker, M. et al. (2015) 'Berkeley Advanced Reconstruction Toolbox', Proc. Intl. Soc. Mag. Reson. Med., 23.

forward(y: torch.Tensor, sensitivity_maps: torch.Tensor, mask: torch.Tensor, target: Optional[torch.Tensor] = None) → Union[list, Any][source]

Forward pass of PICS. Parameters are as for CascadeNet.forward (without init_pred).

Returns
  pred – Predicted data.

Return type
  torch.Tensor, shape [batch_size, n_x, n_y, 2]

static process_inputs(y, mask)[source]

Process the inputs to the method.

Parameters
  • y (Subsampled k-space data.) – list of torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
  • mask (Sampling mask.) – list of torch.Tensor, shape [1, 1, n_x, n_y, 1]

Returns
  • y (Subsampled k-space data.) – randomly selected y
  • mask (Sampling mask.) – randomly selected mask
  • r (Random index.)

test_step(batch: Dict[float, torch.Tensor], batch_idx: int) → Tuple[str, int, torch.Tensor][source]

Test step.

Parameters
  • batch (Batch of data.) – Dict of torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
  • batch_idx (Batch index.) – int

Returns
  • name (Name of the volume.) – str
  • slice_num (Slice number.) – int
  • pred (Predicted data.) – torch.Tensor, shape [batch_size, n_x, n_y, 2]

Attributes: as for CascadeNet.

mridc.collections.reconstruction.models.rvn module

class mridc.collections.reconstruction.models.rvn.RecurrentVarNet(cfg: omegaconf.dictconfig.DictConfig, trainer: Optional[pytorch_lightning.trainer.trainer.Trainer] = None)[source]

Bases: mridc.collections.reconstruction.models.base.BaseMRIReconstructionModel, abc.ABC

Implementation of the Recurrent Variational Network, as presented in Yiasemis, George, et al.

References

Yiasemis, George, et al. "Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to the Task of Accelerated MRI Reconstruction." ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org, http://arxiv.org/abs/2111.09639.

forward(y: torch.Tensor, sensitivity_maps: torch.Tensor, mask: torch.Tensor, init_pred: torch.Tensor, target: torch.Tensor, **kwargs) → torch.Tensor[source]

Forward pass of the network. Parameters and return type are as for CascadeNet.forward.

Attributes: as for CascadeNet.

mridc.collections.reconstruction.models.unet module

class mridc.collections.reconstruction.models.unet.UNet(cfg: omegaconf.dictconfig.DictConfig, trainer: Optional[pytorch_lightning.trainer.trainer.Trainer] = None)[source]

Bases: mridc.collections.reconstruction.models.base.BaseMRIReconstructionModel, abc.ABC

Implementation of the UNet, as presented in O. Ronneberger, P. Fischer, and Thomas Brox.

References

O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks for biomedical image segmentation. In International Conference on Medical image computing and computer-assisted intervention, pages 234–241. Springer, 2015.

forward(y: torch.Tensor, sensitivity_maps: torch.Tensor, mask: torch.Tensor, init_pred: torch.Tensor, target: torch.Tensor) → torch.Tensor[source]

Forward pass of the network. Parameters and return type are as for CascadeNet.forward.

Attributes: as for CascadeNet.

mridc.collections.reconstruction.models.vn module

class mridc.collections.reconstruction.models.vn.VarNet(cfg: omegaconf.dictconfig.DictConfig, trainer: Optional[pytorch_lightning.trainer.trainer.Trainer] = None)[source]

Bases: mridc.collections.reconstruction.models.base.BaseMRIReconstructionModel, abc.ABC

Implementation of the End-to-end Variational Network (VN), as presented in Sriram, A. et al.

References

Sriram, A. et al. (2020) 'End-to-End Variational Networks for Accelerated MRI Reconstruction'. Available at: https://github.com/facebookresearch/fastMRI.

forward(y: torch.Tensor, sensitivity_maps: torch.Tensor, mask: torch.Tensor, init_pred: torch.Tensor, target: torch.Tensor) → torch.Tensor[source]

Forward pass of the network. Parameters and return type are as for CascadeNet.forward.

Attributes: as for CascadeNet.

mridc.collections.reconstruction.models.vsnet module

class mridc.collections.reconstruction.models.vsnet.VSNet(cfg: omegaconf.dictconfig.DictConfig, trainer: Optional[pytorch_lightning.trainer.trainer.Trainer] = None)[source]

Bases: mridc.collections.reconstruction.models.base.BaseMRIReconstructionModel, abc.ABC

Implementation of the Variable-Splitting Net, as presented in Duan, J. et al.

References

Duan, J. et al. (2019) 'Vs-net: Variable splitting network for accelerated parallel MRI reconstruction', Lecture Notes in Computer Science (including subseries Lecture Notes in Artificial Intelligence and Lecture Notes in Bioinformatics), 11767 LNCS, pp. 713–722. doi: 10.1007/978-3-030-32251-9_78.

forward(y: torch.Tensor, sensitivity_maps: torch.Tensor, mask: torch.Tensor, init_pred: torch.Tensor, target: torch.Tensor) → torch.Tensor[source]

Forward pass of the network. Parameters and return type are as for CascadeNet.forward.

Attributes: as for CascadeNet.

mridc.collections.reconstruction.models.xpdnet module

class mridc.collections.reconstruction.models.xpdnet.XPDNet(cfg: omegaconf.dictconfig.DictConfig, trainer: Optional[pytorch_lightning.trainer.trainer.Trainer] = None)[source]

Bases: mridc.collections.reconstruction.models.base.BaseMRIReconstructionModel, abc.ABC

Implementation of the XPDNet, as presented in Ramzi, Zaccharie, et al.

References

Ramzi, Zaccharie, et al. "XPDNet for MRI Reconstruction: An Application to the 2020 FastMRI Challenge." ArXiv:2010.07290 [Physics, Stat], July 2021. arXiv.org, http://arxiv.org/abs/2010.07290.

forward(y: torch.Tensor, sensitivity_maps: torch.Tensor, mask: torch.Tensor, init_pred: torch.Tensor, target: torch.Tensor) → torch.Tensor[source]

Forward pass of the network. Parameters and return type are as for CascadeNet.forward.

Attributes: as for CascadeNet.

mridc.collections.reconstruction.models.zf module

+
+
+class mridc.collections.reconstruction.models.zf.ZF(cfg: omegaconf.dictconfig.DictConfig, trainer: Optional[pytorch_lightning.trainer.trainer.Trainer] = None)[source]
+

Bases: mridc.collections.reconstruction.models.base.BaseMRIReconstructionModel, abc.ABC

+

Zero-Filled reconstruction using either root-sum-of-squares (RSS) or SENSE (SENSitivity Encoding), as presented in Pruessmann KP, Weiger M, Scheidegger MB, Boesiger P.

+

References

+
+

Pruessmann KP, Weiger M, Scheidegger MB, Boesiger P. SENSE: Sensitivity encoding for fast MRI. Magn Reson Med 1999; 42:952-962.

+
+
+
+allow_zero_length_dataloader_with_multiple_devices: bool
+
+ +
+
+forward(y: torch.Tensor, sensitivity_maps: torch.Tensor, mask: torch.Tensor, target: Optional[torch.Tensor] = None) Union[list, Any][source]
+

Forward pass of the zero-filled method.

+
+
Parameters
+
    +
  • y (Subsampled k-space data.) – torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]

  • +
  • sensitivity_maps (Coil sensitivity maps.) – torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]

  • +
  • mask (Sampling mask.) – torch.Tensor, shape [1, 1, n_x, n_y, 1]

  • +
  • init_pred (Initial prediction.) – torch.Tensor, shape [batch_size, n_x, n_y, 2]

  • +
  • target (Target data to compute the loss.) – torch.Tensor, shape [batch_size, n_x, n_y, 2]

  • +
+
+
Returns
+

pred – Predicted data.

+
+
Return type
+

torch.Tensor, shape [batch_size, n_x, n_y, 2]

+
+
+
+ +
+
+mse_vals: Dict
+
+ +
+
+nmse_vals: Dict
+
+ +
+
+precision: int
+
+ +
+
+prepare_data_per_node: bool
+
+ +
+
+static process_inputs(y, mask)[source]
+

Process the inputs to the method.

+
+
Parameters
+
    +
  • y (Subsampled k-space data.) – list of torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]

  • +
  • mask (Sampling mask.) – list of torch.Tensor, shape [1, 1, n_x, n_y, 1]

  • +
+
+
Returns
+

    +
  • y (Subsampled k-space data.) – randomly selected y

  • +
  • mask (Sampling mask.) – randomly selected mask

  • +
  • r (Random index.)

  • +
+

+
+
+
+ +
+
+psnr_vals: Dict
+
+ +
+
+ssim_vals: Dict
+
+ +
+
test_step(batch: Dict[float, torch.Tensor], batch_idx: int) Tuple[str, int, torch.Tensor][source]

Test step.

Parameters
  • batch (Batch of data.) – Dict of torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
  • batch_idx (Batch index.) – int

Returns
  • name (Name of the volume.) – str
  • slice_num (Slice number.) – int
  • pred (Predicted data.) – torch.Tensor, shape [batch_size, n_x, n_y, 2]
trainer: Optional['pl.Trainer']

training: bool

Module contents

diff --git a/docs/build/html/mridc.collections.reconstruction.models.multidomain.html b/docs/build/html/mridc.collections.reconstruction.models.multidomain.html
new file mode 100644
index 00000000..486b10cd
--- /dev/null
+++ b/docs/build/html/mridc.collections.reconstruction.models.multidomain.html
@@ -0,0 +1,252 @@
mridc.collections.reconstruction.models.multidomain package — mridc v.0.0.1 documentation

mridc.collections.reconstruction.models.multidomain package

Submodules

mridc.collections.reconstruction.models.multidomain.multidomain module
class mridc.collections.reconstruction.models.multidomain.multidomain.MultiDomainConv2d(fft_type, in_channels, out_channels, **kwargs)[source]

Bases: torch.nn.modules.module.Module

Multi-domain convolution layer.

forward(image)[source]

Forward method for the MultiDomainConv2d class.

training: bool
class mridc.collections.reconstruction.models.multidomain.multidomain.MultiDomainConvBlock(fft_type, in_channels: int, out_channels: int, dropout_probability: float)[source]

Bases: torch.nn.modules.module.Module

A multi-domain convolutional block that consists of two multi-domain convolution layers, each followed by instance normalization, LeakyReLU activation and dropout.

forward(_input: torch.Tensor)[source]

Forward method for the MultiDomainConvBlock class.

training: bool
class mridc.collections.reconstruction.models.multidomain.multidomain.MultiDomainConvTranspose2d(fft_type, in_channels, out_channels, **kwargs)[source]

Bases: torch.nn.modules.module.Module

Multi-domain convolutional transpose layer.

forward(image)[source]

Forward method for the MultiDomainConvTranspose2d class.

training: bool
class mridc.collections.reconstruction.models.multidomain.multidomain.MultiDomainUnet2d(in_channels: int, out_channels: int, num_filters: int, num_pool_layers: int, dropout_probability: float, fft_type: str = 'orthogonal')[source]

Bases: torch.nn.modules.module.Module

U-Net modification to be used with the multi-domain network, as in the AIRS Medical submission to the fastMRI 2020 challenge.

forward(input_data: torch.Tensor)[source]

Forward pass of the U-Net.

training: bool
class mridc.collections.reconstruction.models.multidomain.multidomain.StandardizationLayer(coil_dim=1, channel_dim=-1)[source]

Bases: torch.nn.modules.module.Module

Multi-channel data standardization method, inspired by the AIRS model submission to the fastMRI 2020 challenge. Given individual coil images \(\{x_i\}_{i=1}^{N_c}\) and coil sensitivity maps \(\{S_i\}_{i=1}^{N_c}\), it returns

\[[(x_{sense}, {x_{res}}_1), ..., (x_{sense}, {x_{res}}_{N_c})]\]

where \({x_{res}}_i = x_i - S_i \times x_{sense}\) and \(x_{sense} = \sum_{i=1}^{N_c} {S_i}^{*} \times x_i\).
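A minimal sketch of the standardization computation above, using complex tensors for clarity (the actual layer operates on a trailing real/imaginary dimension):

import torch

def standardize(coil_images: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    # coil_images, sens_maps: complex tensors of shape [batch, coils, H, W].
    # SENSE combination: x_sense = sum_i conj(S_i) * x_i
    x_sense = (sens_maps.conj() * coil_images).sum(dim=1, keepdim=True)
    # Per-coil residuals: x_res_i = x_i - S_i * x_sense
    x_res = coil_images - sens_maps * x_sense
    # Pair the combined image with each coil residual: [batch, coils, 2, H, W].
    return torch.stack([x_sense.expand_as(x_res), x_res], dim=2)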
forward(coil_images: torch.Tensor, sensitivity_map: torch.Tensor) torch.Tensor[source]

Forward pass.

training: bool
class mridc.collections.reconstruction.models.multidomain.multidomain.TransposeMultiDomainConvBlock(fft_type, in_channels: int, out_channels: int)[source]

Bases: torch.nn.modules.module.Module

A transpose convolutional block that consists of one convolution transpose layer followed by instance normalization and LeakyReLU activation.

forward(input_data: torch.Tensor)[source]

Forward method for the TransposeMultiDomainConvBlock class.

training: bool
Module contents

diff --git a/docs/build/html/mridc.collections.reconstruction.models.mwcnn.html b/docs/build/html/mridc.collections.reconstruction.models.mwcnn.html
new file mode 100644
index 00000000..1c321130
--- /dev/null
+++ b/docs/build/html/mridc.collections.reconstruction.models.mwcnn.html
@@ -0,0 +1,318 @@
mridc.collections.reconstruction.models.mwcnn package — mridc v.0.0.1 documentation
mridc.collections.reconstruction.models.mwcnn package

Submodules

mridc.collections.reconstruction.models.mwcnn.mwcnn module
class mridc.collections.reconstruction.models.mwcnn.mwcnn.ConvBlock(in_channels: int, out_channels: int, kernel_size: int, bias: bool = True, batchnorm: bool = False, activation: torch.nn.modules.module.Module = ReLU(inplace=True), scale: Optional[float] = 1.0)[source]

Bases: torch.nn.modules.module.Module

Convolution block for MWCNN, as implemented in Liu, Pengju, et al.

References

Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071.

forward(x: torch.Tensor) torch.Tensor[source]

Performs forward pass of ConvBlock.

Parameters
x (Input with shape (N, C, H, W).) –

Return type
Output with shape (N, C’, H’, W’).

training: bool
class mridc.collections.reconstruction.models.mwcnn.mwcnn.DWT[source]

Bases: torch.nn.modules.module.Module

2D Discrete Wavelet Transform, as implemented in Liu, Pengju, et al.

References

Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071.

static forward(x: torch.Tensor) torch.Tensor[source]

Computes DWT(x) given tensor x.

Parameters
x (Input tensor.) –

Return type
DWT of x.

training: bool
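For intuition, a minimal sketch of the single-level 2D Haar DWT commonly used in MWCNN implementations; it splits an image into four half-resolution subbands stacked along the channel dimension (shown for illustration, assuming even spatial sizes):

import torch

def haar_dwt2d(x: torch.Tensor) -> torch.Tensor:
    # x: [N, C, H, W] with even H and W. Returns [N, 4*C, H/2, W/2].
    x01, x02 = x[:, :, 0::2, :] / 2, x[:, :, 1::2, :] / 2  # even/odd rows
    x1, x2 = x01[:, :, :, 0::2], x02[:, :, :, 0::2]        # even columns
    x3, x4 = x01[:, :, :, 1::2], x02[:, :, :, 1::2]        # odd columns
    ll = x1 + x2 + x3 + x4   # low-low (approximation)
    hl = -x1 - x2 + x3 + x4  # horizontal detail
    lh = -x1 + x2 - x3 + x4  # vertical detail
    hh = x1 - x2 - x3 + x4   # diagonal detail
    return torch.cat((ll, hl, lh, hh), dim=1)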
class mridc.collections.reconstruction.models.mwcnn.mwcnn.DilatedConvBlock(in_channels: int, dilations: Tuple[int, int], kernel_size: int, out_channels: Optional[int] = None, bias: bool = True, batchnorm: bool = False, activation: torch.nn.modules.module.Module = ReLU(inplace=True), scale: Optional[float] = 1.0)[source]

Bases: torch.nn.modules.module.Module

Double dilated convolution block for MWCNN, as implemented in Liu, Pengju, et al.

References

Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071.

forward(x: torch.Tensor) torch.Tensor[source]

Performs forward pass of DilatedConvBlock.

Parameters
x (Input with shape (N, C, H, W).) –

Return type
Output with shape (N, C’, H’, W’).

training: bool
class mridc.collections.reconstruction.models.mwcnn.mwcnn.IWT[source]

Bases: torch.nn.modules.module.Module

2D Inverse Wavelet Transform, as implemented in Liu, Pengju, et al.

References

Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071.

forward(x: torch.Tensor) torch.Tensor[source]

Computes IWT(x) given tensor x.

Parameters
x (Input tensor.) –

Return type
IWT of x.

training: bool
class mridc.collections.reconstruction.models.mwcnn.mwcnn.MWCNN(input_channels: int, first_conv_hidden_channels: int, num_scales: int = 4, bias: bool = True, batchnorm: bool = False, activation: torch.nn.modules.module.Module = ReLU(inplace=True))[source]

Bases: torch.nn.modules.module.Module

Multi-level Wavelet CNN (MWCNN), as implemented in Liu, Pengju, et al.

References

Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071.

static crop_to_shape(x, shape)[source]

Crop the input to the given shape.

Parameters
  • x (Input tensor.) –
  • shape (Tuple of (height, width).) –

Return type
Cropped tensor.
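A plausible sketch of such a crop helper, assuming cropping from the top-left corner (the exact crop position is not specified in the docstring):

import torch

def crop_to_shape(x: torch.Tensor, shape: tuple) -> torch.Tensor:
    # Crop the trailing spatial dimensions of x to (height, width).
    h, w = shape
    return x[..., :h, :w]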
forward(input_tensor: torch.Tensor, res: bool = False) torch.Tensor[source]

Computes forward pass of MWCNN.

Parameters
  • input_tensor (Input tensor.) – torch.Tensor
  • res (If True, a residual connection is applied to the output.) – bool, Default: False.

Return type
Output tensor.

static pad(x)[source]

Pad the input with zeros.

Parameters
x (Input tensor.) –

Return type
Padded tensor.

training: bool
Module contents

diff --git a/docs/build/html/mridc.collections.reconstruction.models.primaldual.html b/docs/build/html/mridc.collections.reconstruction.models.primaldual.html
new file mode 100644
index 00000000..ebedeaa7
--- /dev/null
+++ b/docs/build/html/mridc.collections.reconstruction.models.primaldual.html
@@ -0,0 +1,198 @@
mridc.collections.reconstruction.models.primaldual package — mridc v.0.0.1 documentation
mridc.collections.reconstruction.models.primaldual package

Submodules

mridc.collections.reconstruction.models.primaldual.pd module
class mridc.collections.reconstruction.models.primaldual.pd.DualNet(num_dual, **kwargs)[source]

Bases: torch.nn.modules.module.Module

Dual network for the Learned Primal-Dual method.

static compute_model_per_coil(model, data)[source]

Computes the model output per coil.

Parameters
  • model (Model to compute.) –
  • data (Multi-coil input.) –

Return type
Multi-coil output.

forward(h, forward_f, g)[source]

Forward pass.

training: bool
class mridc.collections.reconstruction.models.primaldual.pd.PrimalNet(num_primal, **kwargs)[source]

Bases: torch.nn.modules.module.Module

Primal network for the Learned Primal-Dual method.

forward(f, backward_h)[source]

Forward pass of the primal network.

Parameters
  • f (Forward function.) –
  • backward_h (Backward function.) –

Return type
Primal function.

training: bool
Module contents

diff --git a/docs/build/html/mridc.collections.reconstruction.models.recurrentvarnet.html b/docs/build/html/mridc.collections.reconstruction.models.recurrentvarnet.html
new file mode 100644
index 00000000..601f5838
--- /dev/null
+++ b/docs/build/html/mridc.collections.reconstruction.models.recurrentvarnet.html
@@ -0,0 +1,241 @@
mridc.collections.reconstruction.models.recurrentvarnet package — mridc v.0.0.1 documentation
mridc.collections.reconstruction.models.recurrentvarnet package

Submodules

mridc.collections.reconstruction.models.recurrentvarnet.conv2gru module
class mridc.collections.reconstruction.models.recurrentvarnet.conv2gru.Conv2dGRU(in_channels: int, hidden_channels: int, out_channels: Optional[int] = None, num_layers: int = 2, gru_kernel_size=1, orthogonal_initialization: bool = True, instance_norm: bool = False, dense_connect: int = 0, replication_padding: bool = True)[source]

Bases: torch.nn.modules.module.Module

2D Convolutional GRU network.

forward(cell_input: torch.Tensor, previous_state: torch.Tensor) Tuple[torch.Tensor, torch.Tensor][source]

Computes the Conv2dGRU forward pass given tensors cell_input and previous_state.

Parameters
  • cell_input (Reconstruction input.) –
  • previous_state (Tensor of previous states.) –

Return type
Output and new states.
training: bool

mridc.collections.reconstruction.models.recurrentvarnet.recurentvarnet module
class mridc.collections.reconstruction.models.recurrentvarnet.recurentvarnet.RecurrentInit(in_channels: int, out_channels: int, channels: Tuple[int, ...], dilations: Tuple[int, ...], depth: int = 2, multiscale_depth: int = 1)[source]

Bases: torch.nn.modules.module.Module

Recurrent State Initializer (RSI) module of the Recurrent Variational Network, as presented in Yiasemis, George, et al. The RSI module learns to initialize the recurrent hidden state \(h_0\), the input of the first RecurrentVarNetBlock of the RecurrentVarNet.

References

Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org, http://arxiv.org/abs/2111.09639.

forward(x: torch.Tensor) torch.Tensor[source]

Computes the initialization for the recurrent unit given input x.

Parameters
x (Initialization for RecurrentInit.) –

Return type
Initial recurrent hidden state from input x.

training: bool
class mridc.collections.reconstruction.models.recurrentvarnet.recurentvarnet.RecurrentVarNetBlock(in_channels: int = 2, hidden_channels: int = 64, num_layers: int = 4, fft_type: str = 'orthogonal')[source]

Bases: torch.nn.modules.module.Module

Recurrent Variational Network Block \(\mathcal{H}_{\theta_{t}}\), as presented in Yiasemis, George, et al.

References

Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org, http://arxiv.org/abs/2111.09639.

forward(current_kspace: torch.Tensor, masked_kspace: torch.Tensor, sampling_mask: torch.Tensor, sensitivity_map: torch.Tensor, hidden_state: Union[None, torch.Tensor], coil_dim: int = 1, complex_dim: int = -1) Tuple[torch.Tensor, torch.Tensor][source]

Computes forward pass of RecurrentVarNetBlock.

Parameters
  • current_kspace (Current k-space prediction.) – torch.Tensor, shape [batch_size, n_coil, height, width, 2]
  • masked_kspace (Subsampled k-space.) – torch.Tensor, shape [batch_size, n_coil, height, width, 2]
  • sampling_mask (Sampling mask.) – torch.Tensor, shape [batch_size, 1, height, width, 1]
  • sensitivity_map (Coil sensitivities.) – torch.Tensor, shape [batch_size, n_coil, height, width, 2]
  • hidden_state (ConvGRU hidden state.) – None or torch.Tensor, shape [batch_size, n_l, height, width, hidden_channels]
  • coil_dim (Coil dimension.) – int, Default: 1.
  • complex_dim (Complex dimension.) – int, Default: -1.
Returns
  • new_kspace (New k-space prediction.) – torch.Tensor, shape [batch_size, n_coil, height, width, 2]
  • hidden_state (Next hidden state.) – list of torch.Tensor, shape [batch_size, hidden_channels, height, width, num_layers]
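A hedged usage sketch with the documented shapes; whether the default constructor arguments run end-to-end depends on the installed version:

import torch
from mridc.collections.reconstruction.models.recurrentvarnet.recurentvarnet import RecurrentVarNetBlock

block = RecurrentVarNetBlock()  # in_channels=2, hidden_channels=64 by default
batch_size, n_coil, height, width = 1, 8, 218, 170
current_kspace = torch.rand(batch_size, n_coil, height, width, 2)
masked_kspace = current_kspace.clone()
sampling_mask = torch.ones(batch_size, 1, height, width, 1)
sensitivity_map = torch.rand(batch_size, n_coil, height, width, 2)

new_kspace, hidden_state = block(
    current_kspace, masked_kspace, sampling_mask, sensitivity_map, hidden_state=None
)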
training: bool

Module contents

diff --git a/docs/build/html/mridc.collections.reconstruction.models.rim.html b/docs/build/html/mridc.collections.reconstruction.models.rim.html
new file mode 100644
index 00000000..77372ac0
--- /dev/null
+++ b/docs/build/html/mridc.collections.reconstruction.models.rim.html
@@ -0,0 +1,478 @@
mridc.collections.reconstruction.models.rim package — mridc v.0.0.1 documentation
mridc.collections.reconstruction.models.rim package

Submodules

mridc.collections.reconstruction.models.rim.conv_layers module
class mridc.collections.reconstruction.models.rim.conv_layers.ConvNonlinear(input_size, features, conv_dim, kernel_size, dilation, bias, nonlinear='relu')[source]

Bases: torch.nn.modules.module.Module

A convolutional layer with nonlinearity.

check_forward_input(_input)[source]

Checks input for correct size and shape.

static determine_conv_class(n_dim)[source]

Determines the convolutional layer class.

extra_repr()[source]

Extra information about the layer.

forward(_input)[source]

Forward pass of the convolutional layer.

reset_parameters()[source]

Resets the parameters of the convolutional layer.

training: bool
class mridc.collections.reconstruction.models.rim.conv_layers.ConvRNNStack(convs, rnn)[source]

Bases: torch.nn.modules.module.Module

A stack of convolutional RNNs.

forward(x, hidden)[source]

Parameters
  • x ([batch_size, seq_len, input_size]) –
  • hidden ([num_layers * num_directions, batch_size, hidden_size]) –

Returns
output

Return type
[batch_size, seq_len, hidden_size]

training: bool
mridc.collections.reconstruction.models.rim.rim_block module
class mridc.collections.reconstruction.models.rim.rim_block.RIMBlock(recurrent_layer=None, conv_filters=None, conv_kernels=None, conv_dilations=None, conv_bias=None, recurrent_filters=None, recurrent_kernels=None, recurrent_dilations=None, recurrent_bias=None, depth: int = 2, time_steps: int = 8, conv_dim: int = 2, no_dc: bool = False, fft_type: str = 'orthogonal')[source]

Bases: torch.nn.modules.module.Module

RIMBlock is a block of Recurrent Inference Machines (RIMs).

forward(pred: torch.Tensor, masked_kspace: torch.Tensor, sense: torch.Tensor, mask: torch.Tensor, eta: Optional[torch.Tensor] = None, hx: Optional[torch.Tensor] = None, sigma: float = 1.0, keep_eta: bool = False) Tuple[Any, Optional[Union[list, torch.Tensor]]][source]

Forward pass of the RIMBlock.

Parameters
  • pred (Predicted k-space.) –
  • masked_kspace (Subsampled k-space.) –
  • sense (Coil sensitivity maps.) –
  • mask (Sampling mask.) –
  • eta (Initial guess for eta.) –
  • hx (Initial guess for the hidden state.) –
  • sigma (Noise level.) –
  • keep_eta (Whether to keep eta.) –

Return type
Reconstructed image and hidden states.
mridc.collections.reconstruction.models.rim.rnn_cells module
class mridc.collections.reconstruction.models.rim.rnn_cells.ConvGRUCell(input_size, hidden_size, conv_dim, kernel_size, dilation=1, bias=True)[source]

Bases: mridc.collections.reconstruction.models.rim.rnn_cells.ConvGRUCellBase

A convolutional GRU cell.

forward(_input, hx)[source]

Forward pass of the ConvGRUCell.

training: bool
class mridc.collections.reconstruction.models.rim.rnn_cells.ConvGRUCellBase(input_size, hidden_size, conv_dim, kernel_size, dilation, bias)[source]

Bases: torch.nn.modules.module.Module

Base class for Convolutional Gated Recurrent Unit (GRU) cells. # TODO: add paper reference

check_forward_hidden(_input, hx, hidden_label='')[source]

Check the forward hidden state.

check_forward_input(_input)[source]

Check the forward input.

static determine_conv_class(n_dim)[source]

Determine the convolutional class to use.

extra_repr()[source]

Extra information to be printed when printing the model.

static orthotogonalize_weights(weights, chunks=1)[source]

Orthogonalize the weights of a convolutional layer.

reset_parameters()[source]

Initialize the parameters following the way proposed in the paper.

training: bool
class mridc.collections.reconstruction.models.rim.rnn_cells.ConvMGUCell(input_size, hidden_size, conv_dim, kernel_size, dilation=1, bias=True)[source]

Bases: mridc.collections.reconstruction.models.rim.rnn_cells.ConvMGUCellBase

Convolutional Minimal Gated Unit (MGU) cell.

forward(_input, hx)[source]

Forward pass of the ConvMGUCell.

training: bool
class mridc.collections.reconstruction.models.rim.rnn_cells.ConvMGUCellBase(input_size, hidden_size, conv_dim, kernel_size, dilation, bias)[source]

Bases: torch.nn.modules.module.Module

Base class for Convolutional Minimal Gated Unit (MGU) cells. # TODO: add paper reference

check_forward_hidden(_input, hx, hidden_label='')[source]

Check the forward hidden state.

check_forward_input(_input)[source]

Check the forward input.

static determine_conv_class(n_dim)[source]

Determine the convolutional class.

extra_repr()[source]

Extra information about the ConvMGUCellBase.

static orthotogonalize_weights(weights, chunks=1)[source]

Orthogonalize the weights.

reset_parameters()[source]

Reset the parameters.

training: bool
class mridc.collections.reconstruction.models.rim.rnn_cells.IndRNNCell(input_size, hidden_size, conv_dim, kernel_size, dilation=1, bias=True)[source]

Bases: mridc.collections.reconstruction.models.rim.rnn_cells.IndRNNCellBase

Independently Recurrent Neural Network (IndRNN) cell.

forward(_input, hx)[source]

Forward propagate the RNN cell.

training: bool
class mridc.collections.reconstruction.models.rim.rnn_cells.IndRNNCellBase(input_size, hidden_size, conv_dim, kernel_size, dilation, bias)[source]

Bases: torch.nn.modules.module.Module

Base class for Independently Recurrent Neural Network (IndRNN) cells, as presented in [1].

References

[1] Li, S. et al. (2018) ‘Independently Recurrent Neural Network (IndRNN): Building A Longer and Deeper RNN’, Proceedings of the IEEE Computer Society Conference on Computer Vision and Pattern Recognition, (1), pp. 5457–5466. doi: 10.1109/CVPR.2018.00572.

check_forward_hidden(_input, hx, hidden_label='')[source]

Check the forward hidden state.

check_forward_input(_input)[source]

Check the forward input.

static determine_conv_class(n_dim)[source]

Determine the convolutional class.

extra_repr()[source]

Extra information about the module, used for printing.

static orthotogonalize_weights(weights, chunks=1)[source]

Orthogonalize the weights.

reset_parameters()[source]

Reset the parameters.

training: bool
mridc.collections.reconstruction.models.rim.utils module
mridc.collections.reconstruction.models.rim.utils.log_likelihood_gradient(eta: torch.Tensor, masked_kspace: torch.Tensor, sense: torch.Tensor, mask: torch.Tensor, sigma: float, fft_type: str = 'orthogonal') torch.Tensor[source]

Computes the gradient of the log-likelihood function.

Parameters
  • eta (Initial guess for the reconstruction.) –
  • masked_kspace (Subsampled k-space data.) –
  • sense (Coil sensitivity maps.) –
  • mask (Sampling mask.) –
  • sigma (Noise level.) –
  • fft_type (Type of FFT to use.) –

Return type
Gradient of the log-likelihood function.
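For a Gaussian noise model, this gradient is typically \(A^{H}(A\eta - y)/\sigma^{2}\), where \(A\) is the subsampled SENSE forward operator. A minimal sketch under that assumption (not necessarily the exact implementation):

import torch

def log_likelihood_gradient(eta, masked_kspace, sense, mask, sigma):
    # All tensors complex: eta [B, H, W], sense and masked_kspace [B, C, H, W].
    # Forward operator A: expand with coil sensitivities, FFT, subsample.
    kspace = mask * torch.fft.fft2(sense * eta.unsqueeze(1), norm="ortho")
    residual = kspace - masked_kspace
    # Adjoint A^H: inverse FFT, combine with conjugate sensitivities.
    imspace = torch.fft.ifft2(mask * residual, norm="ortho")
    return (sense.conj() * imspace).sum(dim=1) / (sigma ** 2)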
Module contents

diff --git a/docs/build/html/mridc.collections.reconstruction.models.sigmanet.html b/docs/build/html/mridc.collections.reconstruction.models.sigmanet.html
new file mode 100644
index 00000000..ffe117dc
--- /dev/null
+++ b/docs/build/html/mridc.collections.reconstruction.models.sigmanet.html
@@ -0,0 +1,539 @@
mridc.collections.reconstruction.models.sigmanet package — mridc v.0.0.1 documentation
mridc.collections.reconstruction.models.sigmanet package

Submodules

mridc.collections.reconstruction.models.sigmanet.dc_layers module
class mridc.collections.reconstruction.models.sigmanet.dc_layers.ConjugateGradient(*args, **kwargs)[source]

Bases: torch.autograd.function.Function

Conjugate gradient solver for the prox of the data term.

static backward(ctx, grad_x)[source]

Backward pass of the conjugate gradient solver.

Parameters
  • ctx (Context object.) –
  • grad_x (Gradient of the output image.) –

Returns
grad_z

Return type
Gradient of the input image.

static complexDot(data1, data2)[source]

Complex dot product of two tensors.

static forward(ctx, z, lambdaa, y, smaps, mask, tol, max_iter, fft_type)[source]

Forward pass of the conjugate gradient solver.

Parameters
  • ctx (Context object.) –
  • z (Input image.) –
  • lambdaa (Regularization parameter.) –
  • y (Subsampled k-space data.) –
  • smaps (Coil sensitivity maps.) –
  • mask (Sampling mask.) –
  • tol (Tolerance for the stopping criterion.) –
  • max_iter (Maximum number of iterations.) –
  • fft_type (FFT type.) –

Returns
z

Return type
Output image.

static solve(x0, M, tol, max_iter)[source]

Solve the linear system Mx=b using conjugate gradient.
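A minimal, self-contained sketch of the conjugate gradient iteration such a solver performs, for a symmetric positive-definite operator M and right-hand side b (the class above works on complex MRI data; this sketch uses real tensors for brevity):

import torch

def cg_solve(b: torch.Tensor, M, x0: torch.Tensor, tol: float = 1e-6, max_iter: int = 10):
    # Solve M(x) = b with conjugate gradient; M is a callable linear operator.
    x = x0.clone()
    r = b - M(x)          # residual
    p = r.clone()         # search direction
    rs_old = (r * r).sum()
    for _ in range(max_iter):
        Mp = M(p)
        alpha = rs_old / (p * Mp).sum()
        x = x + alpha * p
        r = r - alpha * Mp
        rs_new = (r * r).sum()
        if rs_new.sqrt() < tol:
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x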
class mridc.collections.reconstruction.models.sigmanet.dc_layers.DCLayer(lambda_init=0.0, learnable=True, fft_type='orthogonal')[source]

Bases: torch.nn.modules.module.Module

Data consistency layer from DC-CNN, applied mainly to single-coil data.

forward(x, y, mask)[source]

Forward pass of the data-consistency block.

Parameters
  • x (Input image.) –
  • y (Subsampled k-space data.) –
  • mask (Sampling mask.) –

Return type
Output image.

set_learnable(flag)[source]

Set the learnable flag of the parameters.

Parameters
flag (If True, the parameters of the model are learnable.) –

training: bool
class mridc.collections.reconstruction.models.sigmanet.dc_layers.DataGDLayer(lambda_init, learnable=True, fft_type='orthogonal')[source]

Bases: torch.nn.modules.module.Module

Data layer computing a gradient step on the L2 data term.

forward(x, y, smaps, mask)[source]

Parameters
  • x (Input image.) –
  • y (Subsampled k-space data.) –
  • smaps (Coil sensitivity maps.) –
  • mask (Sampling mask.) –

Returns
data_loss

Return type
Data term loss.
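A gradient step on the L2 data term \(\tfrac{1}{2}\|A x - y\|^2\) is \(x - \lambda A^{H}(A x - y)\). A minimal sketch under that assumption, with complex tensors and the same operator convention as above:

import torch

def data_gd_step(x, y, smaps, mask, lam: float):
    # One gradient step on 0.5 * ||A x - y||^2; x [B, H, W], smaps and y [B, C, H, W].
    Ax = mask * torch.fft.fft2(smaps * x.unsqueeze(1), norm="ortho")
    grad = (smaps.conj() * torch.fft.ifft2(mask * (Ax - y), norm="ortho")).sum(dim=1)
    return x - lam * grad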
training: bool

class mridc.collections.reconstruction.models.sigmanet.dc_layers.DataIDLayer(*args, **kwargs)[source]

Bases: torch.nn.modules.module.Module

Placeholder for the data layer.

training: bool
class mridc.collections.reconstruction.models.sigmanet.dc_layers.DataProxCGLayer(lambda_init, tol=1e-06, iter=10, learnable=True, fft_type='orthogonal')[source]

Bases: torch.nn.modules.module.Module

Solves the prox with respect to the data term using conjugate gradient, as proposed by Aggarwal et al.

forward(x, f, smaps, mask)[source]

Parameters
  • x (Input image.) –
  • f (Subsampled k-space data.) –
  • smaps (Coil sensitivity maps.) –
  • mask (Sampling mask.) –

Returns
data_loss

Return type
Data term loss.

set_learnable(flag)[source]

Set the learnable flag of the parameters.

training: bool
class mridc.collections.reconstruction.models.sigmanet.dc_layers.DataVSLayer(alpha_init, beta_init, learnable=True, fft_type='orthogonal')[source]

Bases: torch.nn.modules.module.Module

Data layer using a variable splitting formulation.

forward(x, y, smaps, mask)[source]

Forward pass of the data-consistency block.

Parameters
  • x (Input image.) –
  • y (Subsampled k-space data.) –
  • smaps (Coil sensitivity maps.) –
  • mask (Sampling mask.) –

Return type
Output image.

set_learnable(flag)[source]

Set the learnable flag of the parameters.

Parameters
flag (If True, the parameters of the model are learnable.) –

training: bool
mridc.collections.reconstruction.models.sigmanet.sensitivity_net module
class mridc.collections.reconstruction.models.sigmanet.sensitivity_net.ComplexInstanceNorm[source]

Bases: torch.nn.modules.module.Module

Motivated by ‘Deep Complex Networks’ (https://arxiv.org/pdf/1705.09792.pdf).

complex_instance_norm(x, eps=1e-05)[source]

Operates on images x of size [nBatch, nSmaps, nFE, nPE, 2].

complex_pseudocovariance(data)[source]

The data variable has to be already mean-free. Operates on images x of size [nBatch, nSmaps, nFE, nPE, 2].

forward(input)[source]

Operates on images x of size [nBatch, nSmaps, nFE, nPE, 2].

normalize(x)[source]

Normalize the input x.

set_normalization(input)[source]

Set the normalization parameters for a given input.

training: bool

unnormalize(x)[source]

Unnormalize the input x.
class mridc.collections.reconstruction.models.sigmanet.sensitivity_net.ComplexNormWrapper(model)[source]

Bases: torch.nn.modules.module.Module

Wrapper for complex normalization.

forward(input)[source]

Defines the computation performed at every call. Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class mridc.collections.reconstruction.models.sigmanet.sensitivity_net.SensitivityNetwork(num_iter, model, datalayer, shared_params=True, save_space=False, reset_cache=False)[source]

Bases: torch.nn.modules.module.Module

Sensitivity network with a data term based on forward and adjoint operators containing the sensitivity maps.

copy_params(src_i, trg_j)[source]

Copy the i-th cascade net parameters to the j-th cascade net parameters.

forward(x, y, smaps, mask)[source]

Parameters
  • x (Input data.) –
  • y (Subsampled k-space data.) –
  • smaps (Coil sensitivity maps.) –
  • mask (Sampling mask.) –

Return type
Output data.

forward_save_space(x, y, smaps, mask)[source]

Parameters
  • x (Input data.) –
  • y (Subsampled k-space data.) –
  • smaps (Coil sensitivity maps.) –
  • mask (Sampling mask.) –

Return type
Output data.

freeze(i)[source]

Freeze the parameters of cascade i.

freeze_all()[source]

Freeze the parameters of all cascades.

stage_training_init()[source]

Set the stage training flag to True.

stage_training_transition_i(copy=False)[source]

Transition between stages of stage-wise training; if copy is True, parameters are copied to the next cascade.

training: bool

unfreeze(i)[source]

Unfreeze the parameters of cascade i.

unfreeze_all()[source]

Unfreeze the parameters of all cascades.
mridc.collections.reconstruction.models.sigmanet.sensitivity_net.matrix_invert(xx, xy, yx, yy)[source]

Invert a 2x2 matrix.

Module contents

diff --git a/docs/build/html/mridc.collections.reconstruction.models.unet_base.html b/docs/build/html/mridc.collections.reconstruction.models.unet_base.html
new file mode 100644
index 00000000..30a0b346
--- /dev/null
+++ b/docs/build/html/mridc.collections.reconstruction.models.unet_base.html
@@ -0,0 +1,275 @@
mridc.collections.reconstruction.models.unet_base package — mridc v.0.0.1 documentation
mridc.collections.reconstruction.models.unet_base package

Submodules

mridc.collections.reconstruction.models.unet_base.unet_block module
class mridc.collections.reconstruction.models.unet_base.unet_block.ConvBlock(in_chans: int, out_chans: int, drop_prob: float)[source]

Bases: torch.nn.modules.module.Module

A convolutional block that consists of two convolution layers, each followed by instance normalization, LeakyReLU activation and dropout.

forward(image: torch.Tensor) torch.Tensor[source]

Parameters
image (Input 4D tensor of shape (N, in_chans, H, W).) –

Return type
Output tensor of shape (N, out_chans, H, W).

training: bool
class mridc.collections.reconstruction.models.unet_base.unet_block.NormUnet(chans: int, num_pools: int, in_chans: int = 2, out_chans: int = 2, drop_prob: float = 0.0, padding_size: int = 15, normalize: bool = True, norm_groups: int = 2)[source]

Bases: torch.nn.modules.module.Module

Normalized U-Net model.

This is the same as a regular U-Net, but with normalization applied to the input before the U-Net. This keeps the values more numerically stable during training.

static chan_complex_to_last_dim(x: torch.Tensor) torch.Tensor[source]

Move the complex (real/imaginary) channel pair back to the last dimension of the input.

static complex_to_chan_dim(x: torch.Tensor) torch.Tensor[source]

Move the last (complex) dimension of the input to the channel dimension.
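A likely sketch of these two reshapes, assuming x has shape (N, C, H, W, 2) with the trailing dimension holding the real and imaginary parts:

import torch

def complex_to_chan_dim(x: torch.Tensor) -> torch.Tensor:
    # (N, C, H, W, 2) -> (N, 2*C, H, W): fold real/imag into channels.
    n, c, h, w, two = x.shape
    return x.permute(0, 4, 1, 2, 3).reshape(n, two * c, h, w)

def chan_complex_to_last_dim(x: torch.Tensor) -> torch.Tensor:
    # (N, 2*C, H, W) -> (N, C, H, W, 2): inverse of the above.
    n, c2, h, w = x.shape
    return x.reshape(n, 2, c2 // 2, h, w).permute(0, 2, 3, 4, 1)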
forward(x: torch.Tensor) torch.Tensor[source]

Forward pass of the network.

norm(x: torch.Tensor) Tuple[torch.Tensor, torch.Tensor, torch.Tensor][source]

Normalize the input.

pad(x: torch.Tensor) Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]][source]

Pad the input with zeros to make it square.

training: bool

unnorm(x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) torch.Tensor[source]

Unnormalize the input.

static unpad(x: torch.Tensor, h_pad: List[int], w_pad: List[int], h_mult: int, w_mult: int) torch.Tensor[source]

Unpad the input.
class mridc.collections.reconstruction.models.unet_base.unet_block.TransposeConvBlock(in_chans: int, out_chans: int)[source]

Bases: torch.nn.modules.module.Module

A transpose convolutional block that consists of one convolution transpose layer followed by instance normalization and LeakyReLU activation.

forward(image: torch.Tensor) torch.Tensor[source]

Parameters
image (Input 4D tensor of shape (N, in_chans, H, W).) –

Return type
Output tensor of shape (N, out_chans, H*2, W*2).

training: bool
class mridc.collections.reconstruction.models.unet_base.unet_block.Unet(in_chans: int, out_chans: int, chans: int = 32, num_pool_layers: int = 4, drop_prob: float = 0.0)[source]

Bases: torch.nn.modules.module.Module

PyTorch implementation of a U-Net model, as presented in [1].

References

[1] O. Ronneberger, P. Fischer, and T. Brox. U-Net: Convolutional networks for biomedical image segmentation. In International Conference on Medical image computing and computer-assisted intervention, pages 234–241. Springer, 2015.

forward(image: torch.Tensor) torch.Tensor[source]

Parameters
image (Input 4D tensor of shape (N, in_chans, H, W).) –

Return type
Output tensor of shape (N, out_chans, H, W).

training: bool
Module contents

diff --git a/docs/build/html/mridc.collections.reconstruction.models.variablesplittingnet.html b/docs/build/html/mridc.collections.reconstruction.models.variablesplittingnet.html
new file mode 100644
index 00000000..e1c80a47
--- /dev/null
+++ b/docs/build/html/mridc.collections.reconstruction.models.variablesplittingnet.html
@@ -0,0 +1,248 @@
mridc.collections.reconstruction.models.variablesplittingnet package — mridc v.0.0.1 documentation
mridc.collections.reconstruction.models.variablesplittingnet package

Submodules

mridc.collections.reconstruction.models.variablesplittingnet.vsnet_block module
class mridc.collections.reconstruction.models.variablesplittingnet.vsnet_block.DataConsistencyLayer[source]

Bases: torch.nn.modules.module.Module

Data consistency layer for the VSNet. This layer enforces agreement between the predicted k-space and the reference (measured) k-space at the sampled locations.

forward(pred_kspace, ref_kspace, mask)[source]

Forward pass of the data consistency layer.

training: bool
class mridc.collections.reconstruction.models.variablesplittingnet.vsnet_block.VSNetBlock(denoiser_block: torch.nn.modules.container.ModuleList, data_consistency_block: torch.nn.modules.container.ModuleList, weighted_average_block: torch.nn.modules.container.ModuleList, num_cascades: int = 8, fft_type: str = 'orthogonal')[source]

Bases: torch.nn.modules.module.Module

Model block for the Variable-Splitting Network, inspired by [1].

References

[1] Duan, J. et al. (2019) ‘Vs-net: Variable splitting network for accelerated parallel MRI reconstruction’, Lecture Notes in Computer Science (including subseries Lecture Notes in Artificial Intelligence and Lecture Notes in Bioinformatics), 11767 LNCS, pp. 713–722. doi: 10.1007/978-3-030-32251-9_78.

forward(kspace: torch.Tensor, sens_maps: torch.Tensor, mask: torch.Tensor) List[Union[torch.Tensor, Any]][source]

Parameters
  • kspace (Reference k-space data.) –
  • sens_maps (Coil sensitivity maps.) –
  • mask (Mask to apply to the data.) –

Return type
Reconstructed image.

sens_expand(x: torch.Tensor, sens_maps: torch.Tensor) torch.Tensor[source]

Expand the input image to per-coil data using the sensitivity maps.

Parameters
  • x (Input data.) –
  • sens_maps (Coil sensitivity maps.) –

Return type
SENSE reconstruction expanded to the same size as the input sens_maps.

sens_reduce(x: torch.Tensor, sens_maps: torch.Tensor) torch.Tensor[source]

Combine per-coil data into a single image using the sensitivity maps.

Parameters
  • x (Input data.) –
  • sens_maps (Coil sensitivity maps.) –

Return type
SENSE coil-combined reconstruction.

training: bool
class mridc.collections.reconstruction.models.variablesplittingnet.vsnet_block.WeightedAverageTerm[source]

Bases: torch.nn.modules.module.Module

Weighted average term for the VSNet.

forward(x, Sx)[source]

Defines the computation performed at every call. Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
Module contents

diff --git a/docs/build/html/mridc.collections.reconstruction.models.varnet.html b/docs/build/html/mridc.collections.reconstruction.models.varnet.html
new file mode 100644
index 00000000..e5db1442
--- /dev/null
+++ b/docs/build/html/mridc.collections.reconstruction.models.varnet.html
@@ -0,0 +1,199 @@
mridc.collections.reconstruction.models.varnet package — mridc v.0.0.1 documentation
mridc.collections.reconstruction.models.varnet package

Submodules

mridc.collections.reconstruction.models.varnet.vn_block module
class mridc.collections.reconstruction.models.varnet.vn_block.VarNetBlock(model: torch.nn.modules.module.Module, fft_type: str = 'orthogonal', no_dc: bool = False)[source]

Bases: torch.nn.modules.module.Module

Model block for the end-to-end variational network.

This model applies a combination of soft data consistency with the input model as a regularizer. A series of these blocks can be stacked to form the full variational network.

forward(pred: torch.Tensor, ref_kspace: torch.Tensor, sens_maps: torch.Tensor, mask: torch.Tensor) torch.Tensor[source]

Parameters
  • pred (Predicted k-space data.) –
  • ref_kspace (Reference k-space data.) –
  • sens_maps (Coil sensitivity maps.) –
  • mask (Mask to apply to the data.) –

Return type
Reconstructed image.

sens_expand(x: torch.Tensor, sens_maps: torch.Tensor) torch.Tensor[source]

Expand the input image to per-coil data using the sensitivity maps.

Parameters
  • x (Input data.) –
  • sens_maps (Coil sensitivity maps.) –

Return type
SENSE reconstruction expanded to the same size as the input sens_maps.

sens_reduce(x: torch.Tensor, sens_maps: torch.Tensor) torch.Tensor[source]

Combine per-coil data into a single image using the sensitivity maps.

Parameters
  • x (Input data.) –
  • sens_maps (Coil sensitivity maps.) –

Return type
SENSE coil-combined reconstruction.
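A common way to realize these two operators, following the fastMRI-style convention (a sketch, not necessarily the exact implementation here):

import torch

def sens_expand(x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    # Image [B, H, W] -> per-coil k-space [B, C, H, W].
    return torch.fft.fft2(sens_maps * x.unsqueeze(1), norm="ortho")

def sens_reduce(k: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    # Per-coil k-space [B, C, H, W] -> coil-combined image [B, H, W].
    return (sens_maps.conj() * torch.fft.ifft2(k, norm="ortho")).sum(dim=1)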
training: bool

Module contents

diff --git a/docs/build/html/mridc.collections.reconstruction.parts.html b/docs/build/html/mridc.collections.reconstruction.parts.html
new file mode 100644
index 00000000..842dc46d
--- /dev/null
+++ b/docs/build/html/mridc.collections.reconstruction.parts.html
@@ -0,0 +1,281 @@
mridc.collections.reconstruction.parts package — mridc v.0.0.1 documentation
mridc.collections.reconstruction.parts package

Submodules

mridc.collections.reconstruction.parts.transforms module
class mridc.collections.reconstruction.parts.transforms.MRIDataTransforms(mask_func: Optional[List[mridc.collections.reconstruction.data.subsample.MaskFunc]] = None, shift_mask: bool = False, mask_center_scale: Optional[float] = 0.02, half_scan_percentage: float = 0.0, crop_size: Optional[Tuple[int, int]] = None, kspace_crop: bool = False, crop_before_masking: bool = True, kspace_zero_filling_size: Optional[Tuple] = None, normalize_inputs: bool = False, fft_type: str = 'orthogonal', use_seed: bool = True)[source]

Bases: object

MRI preprocessing data transforms.

__call__(kspace: numpy.ndarray, sensitivity_map: numpy.ndarray, mask: numpy.ndarray, eta: numpy.ndarray, target: numpy.ndarray, attrs: Dict, fname: str, slice_idx: int) Tuple[Union[List[Union[torch.Tensor, Any]], torch.Tensor, Any], Union[torch.Tensor, None, Any], Union[List, Any], Union[torch.Tensor, None, Any], Union[torch.Tensor, Any], str, int, Union[List, torch.Tensor, Any]][source]

Apply the data transform.

Parameters
  • kspace (The k-space data.) –
  • sensitivity_map (The sensitivity map.) –
  • mask (The mask.) –
  • eta (The initial estimation.) –
  • target (The target.) –
  • attrs (The attributes.) –
  • fname (The file name.) –
  • slice_idx (The slice number.) –

Return type
The transformed data.
mridc.collections.reconstruction.parts.utils module
mridc.collections.reconstruction.parts.utils.apply_mask(data: torch.Tensor, mask_func: mridc.collections.reconstruction.data.subsample.MaskFunc, seed: Optional[Union[int, Tuple[int, ...]]] = None, padding: Optional[Sequence[int]] = None, shift: bool = False, half_scan_percentage: Optional[float] = 0.0, center_scale: Optional[float] = 0.02) Tuple[Any, Any, Any][source]

Subsample the given k-space by multiplying it with a mask.

Parameters
  • data (The input k-space data.) – This should have at least 3 dimensions, where dimensions -3 and -2 are the spatial dimensions, and the final dimension has size 2 (for complex values).
  • mask_func (A function that takes a shape (tuple of ints) and a random number seed and returns a mask.) –
  • seed (Seed for the random number generator.) –
  • padding (Padding value to apply for the mask.) –
  • shift (Toggle to shift the mask when subsampling. Applicable to 2D data.) –
  • half_scan_percentage (Percentage of k-space to be dropped.) –
  • center_scale (Scale of the center of the mask. Applicable to Gaussian masks.) –

Return type
Tuple of subsampled k-space, mask, and mask indices.
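The core operation is an element-wise multiplication of k-space with a binary mask. A self-contained sketch (the random line mask is illustrative; real masks come from a MaskFunc):

import torch

kspace = torch.rand(8, 320, 320, 2)              # [n_coils, n_x, n_y, 2]
mask = (torch.rand(1, 1, 320, 1) > 0.5).float()  # keep ~50% of phase-encode lines
masked_kspace = kspace * mask                    # zero out unsampled entries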
mridc.collections.reconstruction.parts.utils.batched_mask_center(x: torch.Tensor, mask_from: torch.Tensor, mask_to: torch.Tensor, mask_type: str = '2D') torch.Tensor[source]

Initializes a mask with the center filled in. Can operate with different masks for each batch element.

Parameters
  • x (The input real image or batch of real images.) –
  • mask_from (Part of center to start filling.) –
  • mask_to (Part of center to end filling.) –
  • mask_type (Type of mask to apply. Can be either "1D" or "2D".) –

Return type
A mask with the center filled.
mridc.collections.reconstruction.parts.utils.center_crop(data: torch.Tensor, shape: Tuple[int, int]) torch.Tensor[source]

Apply a center crop to the input real image or batch of real images.

Parameters
  • data (The input tensor to be center cropped.) – It should have at least 2 dimensions; the cropping is applied along the last two dimensions.
  • shape (The output shape.) – The shape should be smaller than the corresponding dimensions of data.

Return type
The center cropped image.
mridc.collections.reconstruction.parts.utils.center_crop_to_smallest(x: Union[torch.Tensor, numpy.ndarray], y: Union[torch.Tensor, numpy.ndarray]) Tuple[Union[torch.Tensor, numpy.ndarray], Union[torch.Tensor, numpy.ndarray]][source]

Apply a center crop on the larger image to the size of the smaller.

The minimum is taken over dim=-1 and dim=-2. If x is smaller than y at dim=-1 and y is smaller than x at dim=-2, then the returned dimension will be a mixture of the two.

Parameters
  • x (The first image.) –
  • y (The second image.) –

Return type
Tuple of tensors x and y, each cropped to the minimum size.
mridc.collections.reconstruction.parts.utils.complex_center_crop(data: torch.Tensor, shape: Tuple[int, int]) torch.Tensor[source]

Apply a center crop to the input image or batch of complex images.

Parameters
  • data (The complex input tensor to be center cropped.) – It should have at least 3 dimensions; the cropping is applied along dimensions -3 and -2, and the last dimension should have a size of 2.
  • shape (The output shape.) – The shape should be smaller than the corresponding dimensions of data.

Return type
The center cropped image.
mridc.collections.reconstruction.parts.utils.mask_center(x: torch.Tensor, mask_from: Optional[int], mask_to: Optional[int], mask_type: str = '2D') torch.Tensor[source]

Initialize a mask with the center filled in.

Parameters
  • x (The input real image or batch of real images.) –
  • mask_from (Part of center to start filling.) –
  • mask_to (Part of center to end filling.) –
  • mask_type (Type of mask to apply. Can be either "1D" or "2D".) –

Return type
A mask with the center filled.
Module contents

diff --git a/docs/build/html/mridc.core.classes.html b/docs/build/html/mridc.core.classes.html
new file mode 100644
index 00000000..a5983d7d
--- /dev/null
+++ b/docs/build/html/mridc.core.classes.html
@@ -0,0 +1,1024 @@
mridc.core.classes package — mridc v.0.0.1 documentation
mridc.core.classes package

Submodules

mridc.core.classes.common module
class mridc.core.classes.common.FileIO[source]

Bases: abc.ABC

Base class for file IO.

classmethod from_config_file(path2yaml_file: str)[source]

Instantiates an instance of an mridc Model from a YAML config file. Weights will be initialized randomly.

Parameters
path2yaml_file (path to yaml file with model configuration) –

Return type
Model instance.

classmethod restore_from(restore_path: str, override_config_path: Optional[str] = None, map_location: Optional[torch.device] = None, strict: bool = True, return_config: bool = False, trainer: Optional[pytorch_lightning.trainer.trainer.Trainer] = None, save_restore_connector: Optional[mridc.core.connectors.save_restore_connector.SaveRestoreConnector] = None)[source]

Restores a module/model with weights.

save_to(save_path: str)[source]

Saves a module/model with weights.

to_config_file(path2yaml_file: str)[source]

Saves the current instance’s configuration to a YAML config file. Weights will not be saved.

Parameters
path2yaml_file (path to yaml file where model configuration will be saved.) –
class mridc.core.classes.common.Model[source]

Bases: mridc.core.classes.common.Typing, mridc.core.classes.common.Serialization, mridc.core.classes.common.FileIO, abc.ABC

Abstract class offering an interface which should be implemented by all mridc models.

classmethod from_pretrained(model_name: str, refresh_cache: bool = False, override_config_path: Optional[str] = None, map_location: Optional[torch.device] = None, strict: bool = True, return_config: bool = False, trainer: Optional[pytorch_lightning.trainer.trainer.Trainer] = None, save_restore_connector: Optional[mridc.core.connectors.save_restore_connector.SaveRestoreConnector] = None)[source]

Instantiates an instance of mridc. Use restore_from() to instantiate from a local .mridc file.

Parameters
  • model_name (String key which will be used to find the module.) –
  • refresh_cache (If set to True, then when fetching from cloud, this will re-fetch the file even if it is already found in a cache locally.) –
  • override_config_path (Path to a yaml config that will override the internal config file.) –
  • map_location (Optional torch.device() to map the instantiated model to a device. By default (None), it will select a GPU if available, falling back to CPU otherwise.) –
  • strict (Passed to torch.load_state_dict. By default, True.) –
  • return_config (If set to True, will return just the underlying config of the restored model as an OmegaConf/DictConfig object, without instantiating the model.) –
  • trainer (Optional Trainer object to use for restoring the model.) –
  • save_restore_connector (Optional SaveRestoreConnector object to use for restoring the model.) –

Return type
A model instance of a particular model class, or its underlying config (if return_config is set).
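A hedged usage sketch; ZF is used only as an example Model subclass, and the registry may be empty for a given release:

from mridc.collections.reconstruction.models.zf import ZF

names = ZF.get_available_model_names()  # discover registered pretrained names
if names:
    model = ZF.from_pretrained(model_name=names[0])  # download or load from cache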
classmethod get_available_model_names() List[str][source]

Returns the list of model names available. To get the complete model description use list_available_models().

Return type
A list of model names.

classmethod list_available_models() Optional[mridc.core.classes.common.PretrainedModelInfo][source]

Should list all pre-trained models available. Note: There is no check that requires model names and aliases to be unique. In the case of a collision, whatever model (or alias) is listed first in the returned list will be instantiated.

Return type
A list of PretrainedModelInfo entries.
class mridc.core.classes.common.PretrainedModelInfo(pretrained_model_name: str, description: str, location: str, class_: Optional[mridc.core.classes.common.Model] = None, aliases: Optional[List[str]] = None)[source]

Bases: object

Class to store information about a pretrained model.

aliases: Optional[List[str]] = None

class_: Optional[mridc.core.classes.common.Model] = None

description: str

location: str

pretrained_model_name: str
class mridc.core.classes.common.Serialization[source]

Bases: abc.ABC

Base class for serialization.

classmethod from_config_dict(config: omegaconf.dictconfig.DictConfig, trainer: Optional[pytorch_lightning.trainer.trainer.Trainer] = None)[source]

Instantiates an object using a DictConfig-based configuration.

to_config_dict() omegaconf.dictconfig.DictConfig[source]

Returns the object’s configuration as a config dictionary.
class mridc.core.classes.common.Typing[source]

Bases: abc.ABC

An interface which endows a module with neural types.

property input_types: Optional[Dict[str, mridc.core.neural_types.neural_type.NeuralType]]

Define these to enable input neural type checks.

property output_types: Optional[Dict[str, mridc.core.neural_types.neural_type.NeuralType]]

Define these to enable output neural type checks.

mridc.core.classes.common.is_typecheck_enabled()[source]

Getter method for the typechecking state.
class mridc.core.classes.common.typecheck(input_types: Optional[Union[mridc.core.classes.common.typecheck.TypeState, Dict[str, mridc.core.neural_types.neural_type.NeuralType]]] = TypeState.UNINITIALIZED, output_types: Optional[Union[mridc.core.classes.common.typecheck.TypeState, Dict[str, mridc.core.neural_types.neural_type.NeuralType]]] = TypeState.UNINITIALIZED, ignore_collections: bool = False)[source]

Bases: object

Decorator to check the types of the input arguments.

class TypeState(value)[source]

Bases: enum.Enum

Placeholder to denote the default value of the type information provided. If the constructor of this decorator is used to override the class-level type definition, this enum value indicates that types will be overridden.

UNINITIALIZED = 0

static disable_checks()[source]

Temporarily disable type checks.

static set_typecheck_enabled(enabled: bool = True)[source]

Set the global typecheck flag.
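A hedged sketch of the decorator pattern this enables; the type declarations are placeholders (consult mridc.core.neural_types for real NeuralType definitions), and behavior with undeclared types is assumed to be a no-op check:

from mridc.core.classes.common import Typing, typecheck

class MyModule(Typing):
    @property
    def input_types(self):
        return None  # placeholder: a dict of name -> NeuralType enables input checks

    @property
    def output_types(self):
        return None  # placeholder: a dict of name -> NeuralType enables output checks

    @typecheck()
    def forward(self, x):
        return x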

mridc.core.classes.dataset module

class mridc.core.classes.dataset.Dataset[source]

Bases: torch.utils.data.dataset.Dataset, mridc.core.classes.common.Typing, mridc.core.classes.common.Serialization, abc.ABC

Dataset with output ports. Please note: subclasses of Dataset should not implement input_types.

collate_fn(batch)[source]

This is the method that the user passes as a functor to DataLoader. The method optionally performs neural type checking and adds types to the outputs.

Please note, subclasses of Dataset should not implement input_types.

# Usage:

dataloader = torch.utils.data.DataLoader(
        ....,
        collate_fn=dataset.collate_fn,
        ....
)

Return type
Collated batch, with or without types.
class mridc.core.classes.dataset.DatasetConfig(batch_size: int = 32, drop_last: bool = False, shuffle: bool = False, num_workers: Optional[int] = 0, pin_memory: bool = True)[source]

Bases: object

Dataset configuration.

batch_size: int = 32

drop_last: bool = False

num_workers: Optional[int] = 0

pin_memory: bool = True

shuffle: bool = False
class mridc.core.classes.dataset.IterableDataset[source]

Bases: torch.utils.data.dataset.IterableDataset, mridc.core.classes.common.Typing, mridc.core.classes.common.Serialization, abc.ABC

Iterable Dataset with output ports. Please note: subclasses of IterableDataset should not implement input_types.

collate_fn(batch)[source]

This is the method that the user passes as a functor to DataLoader. The method optionally performs neural type checking and adds types to the outputs.

# Usage:

dataloader = torch.utils.data.DataLoader(
        ....,
        collate_fn=dataset.collate_fn,
        ....
)

Return type
Collated batch, with or without types.
mridc.core.classes.export module
class mridc.core.classes.export.ExportFormat(value)[source]

Bases: enum.Enum

Which format to use when exporting a Neural Module for deployment.

ONNX = (1,)

TORCHSCRIPT = (2,)
+
+class mridc.core.classes.export.Exportable[source]
+

Bases: abc.ABC

+

This Interface should be implemented by particular classes derived from mridc.core.NeuralModule or +mridc.core.ModelPT. It gives these entities ability to be exported for deployment to formats such as ONNX.

+
+
+property disabled_deployment_input_names
+

Implement this method to return a set of input names disabled for export

+
+ +
+
+property disabled_deployment_output_names
+

Implement this method to return a set of output names disabled for export

+
+ +
+
+export(output: str, input_example=None, verbose=False, export_params=True, do_constant_folding=True, onnx_opset_version=None, try_script: bool = False, training=<TrainingMode.EVAL: 0>, check_trace: bool = False, use_dynamic_axes: bool = True, dynamic_axes=None, check_tolerance=0.01)[source]
+

Export the module to a file.

+
+
Parameters
  • output – The output file path.
  • input_example – A dictionary of input names and values.
  • verbose – If True, print out the export process.
  • export_params – If True, export the parameters of the module.
  • do_constant_folding – If True, do constant folding.
  • onnx_opset_version – The ONNX opset version to use.
  • try_script – If True, try to export as TorchScript.
  • training – Training mode for the export.
  • check_trace – If True, check the trace of the exported model.
  • use_dynamic_axes – If True, use dynamic axes for the export.
  • dynamic_axes – A dictionary of input names and dynamic axes.
  • check_tolerance – The tolerance for check_trace.
+
+
+
+ +
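A rough usage sketch (the module variable and file name are hypothetical; any concrete Exportable subclass should work similarly):

module = MyExportableModule()  # hypothetical subclass of Exportable
module.export(output="my_module.onnx", verbose=True, check_trace=True)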
+
+property input_module
+
+ +
+
+property input_names
+

Implement this method to return a list of input names

+
+ +
+
+property output_module
+
+ +
+
+property output_names
+

Implement this method to return a list of output names

+
+ +
+
+property supported_export_formats
+

Implement this method to return a set of export formats supported. Default is all types.

+
+ +
+ +
+
+

mridc.core.classes.loss module

+
+
+class mridc.core.classes.loss.Loss(size_average=None, reduce=None, reduction: str = 'mean')[source]
+

Bases: torch.nn.modules.loss._Loss, mridc.core.classes.common.Typing, mridc.core.classes.common.Serialization

+

Inherit this class to implement custom loss.

+
+
+reduction: str
+
+ +
+ +
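Because Loss combines torch.nn.modules.loss._Loss with the Typing/Serialization interfaces, a custom loss is written like a regular PyTorch loss. A minimal sketch (the class name and formula are illustrative only):

import torch
from mridc.core.classes.loss import Loss

class L1MagnitudeLoss(Loss):
    # Hypothetical example: mean absolute error between complex magnitudes.
    def forward(self, prediction, target):
        return torch.mean(torch.abs(torch.abs(prediction) - torch.abs(target)))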
+
+

mridc.core.classes.modelPT module

+
+
+class mridc.core.classes.modelPT.ModelPT(cfg: omegaconf.dictconfig.DictConfig, trainer: Optional[pytorch_lightning.trainer.trainer.Trainer] = None)[source]
+

Bases: pytorch_lightning.core.lightning.LightningModule, mridc.core.classes.common.Model

+

Interface for PyTorch Lightning-based mridc models.

+
+
+classmethod __init_subclass__() None[source]
+

This method is called when a subclass is created.

+
+ +
+
+property cfg
+

Property that holds the finalized internal config of the model.

+
+

Note

+

Changes to this config are not reflected in the state of the model. Please create a new model using an updated config to properly update the model.

+
+
+ +
+
+configure_optimizers()[source]
+

Configure optimizers and schedulers for training.

+
+ +
+
+classmethod extract_state_dict_from(restore_path: str, save_dir: str, split_by_module: bool = False, save_restore_connector: Optional[mridc.core.connectors.save_restore_connector.SaveRestoreConnector] = None)[source]
+

Extract the state dict(s) from a provided .mridc tarfile and save it to a directory.

+
+
Parameters
  • restore_path – path to the .mridc file from which the state dict(s) should be extracted.
  • save_dir – directory in which the saved state dict(s) should be stored.
  • split_by_module – bool flag which determines whether the output checkpoint should be for the entire Model or for the individual modules that comprise the Model.
  • save_restore_connector – Can be overridden to add custom save and restore logic.
+
+
+

Example

+

To convert the .mridc tarfile into a single Model level PyTorch checkpoint

+
state_dict = mridc.collections.asr.models.EncDecCTCModel.extract_state_dict_from('asr.mridc', './asr_ckpts')
+
+
+

To restore a model from a Model level checkpoint

+
model = mridc.collections.asr.models.EncDecCTCModel(cfg)  # or any other method of restoration
model.load_state_dict(torch.load("./asr_ckpts/model_weights.ckpt"))
+
+
+

To convert the .mridc tarfile into multiple Module level PyTorch checkpoints

+
state_dict = mridc.collections.asr.models.EncDecCTCModel.extract_state_dict_from('asr.mridc', './asr_ckpts', split_by_module=True)
+
+
+

To restore a module from a Module level checkpoint

+
model = mridc.collections.asr.models.EncDecCTCModel(cfg)  # or any other method of restoration
# load the individual components
model.preprocessor.load_state_dict(torch.load("./asr_ckpts/preprocessor.ckpt"))
model.encoder.load_state_dict(torch.load("./asr_ckpts/encoder.ckpt"))
model.decoder.load_state_dict(torch.load("./asr_ckpts/decoder.ckpt"))
+
+
+
+
Return type
+

The state dict that was loaded from the original .mridc checkpoint.

+
+
+
+ +
+
+get_test_dataloader_prefix(dataloader_idx: int = 0) str[source]
+

Get the name of one or more data loaders, which will be prepended to all logs.

+
+ +
+
+get_validation_dataloader_prefix(dataloader_idx: int = 0) str[source]
+

Get the name of one or more data loaders, which will be prepended to all logs.

+
+ +
+
+classmethod load_from_checkpoint(checkpoint_path: str, *args, map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, hparams_file: Optional[str] = None, strict: bool = True, **kwargs)[source]
+

Loads ModelPT from a checkpoint, with some maintenance of restoration. For documentation, please refer to LightningModule.load_from_checkpoint().

+
+ +
+
+load_part_of_state_dict(state_dict, include, exclude, load_from_string)[source]
+

Load part of the state dict.

+
+ +
+
+maybe_init_from_pretrained_checkpoint(cfg: omegaconf.omegaconf.OmegaConf, map_location: str = 'cpu')[source]
+

Initializes a given model with the parameters obtained via specific config arguments. The state dict of the provided model will be updated with strict=False, so that exact matching of model parameters is not required.

+

Initializations

+

init_from_mridc_model: Str path to a .mridc model, which will be instantiated in order to extract the state dict.

+

init_from_pretrained_model: Str name of a pretrained model checkpoint (obtained via cloud). The model will be downloaded (or a cached copy will be used), instantiated and then its state dict will be extracted.

+

init_from_ptl_ckpt: Str name of a PyTorch Lightning checkpoint file. It will be loaded and its state dict will be extracted.

+
+
Parameters
  • cfg – The config used to instantiate the model. It needs to contain only one of the above keys.
  • map_location – str or torch.device() which represents where the intermediate state dict (from the pretrained model or checkpoint) will be loaded.
+
+
+
+ +
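A sketch of a matching config fragment (key names as listed above; the path is hypothetical):

from omegaconf import OmegaConf

cfg = OmegaConf.create({"init_from_mridc_model": "/path/to/pretrained.mridc"})
model.maybe_init_from_pretrained_checkpoint(cfg)  # assumes `model` is a ModelPT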
+
+static multi_test_epoch_end(outputs: Union[object, List[Dict[str, torch.Tensor]]], dataloader_idx: int = 0) None[source]
+

Adds support for multiple test datasets. Should be overridden by subclasses to obtain appropriate logs for each of the dataloaders.

+
+
Parameters
  • outputs – Same as that provided by LightningModule.validation_epoch_end() for a single dataloader.
  • dataloader_idx – int representing the index of the dataloader.

Returns
A dictionary of values, optionally containing a sub-dict log, such that the values in the log will be prepended by the dataloader prefix.
+

+
+
+
+ +
+
+static multi_validation_epoch_end(outputs: Optional[Union[object, List[Dict[str, torch.Tensor]]]], dataloader_idx: int = 0) None[source]
+
+
Adds support for multiple validation datasets. Should be overridden by subclasses to obtain appropriate logs for each of the dataloaders.

+
+
+
+
Parameters
  • outputs – Same as that provided by LightningModule.validation_epoch_end() for a single dataloader.
  • dataloader_idx – int representing the index of the dataloader.

Returns
A dictionary of values, optionally containing a sub-dict log, such that the values in the log will be prepended by the dataloader prefix.
+

+
+
+
+ +
+
+property num_weights
+

Utility property that returns the total number of parameters of the Model.

+
+ +
+
+prepare_test(trainer: pytorch_lightning.trainer.trainer.Trainer) bool[source]
+

Helper method to check whether the model can safely be tested on a dataset after training (or loading a checkpoint).

+
trainer = Trainer()
if model.prepare_test(trainer):
    trainer.test(model)
+
+
+
+
Return type
+

Bool which declares the model safe to test. Provides warnings if it has to return False to guide the user.

+
+
+
+ +
+
+register_artifact(config_path: str, src: str, verify_src_exists: bool = True)[source]
+

Register model artifacts with this function. These artifacts (files) will be included inside the .mridc file when model.save_to("model.mridc") is called.

+
+
How it works:
  1. It always returns an existing absolute path which can be used during the Model constructor call. EXCEPTION: if src is None or "", nothing will be done and src will be returned.
  2. It will add a (config_path, model_utils.ArtifactItem()) pair to self.artifacts.
+
+
+

If "src" is a local existing path, it will be returned in absolute path form; if "src" starts with "mridc_file:unique_artifact_name", the .mridc file will be untarred to a temporary folder and an actual existing path will be returned; otherwise, an error will be raised.

+

WARNING: use .register_artifact calls in your models' constructors. The returned path is not guaranteed to exist after you have exited your model's constructor.

+
+
Parameters
  • config_path – Artifact key. Usually corresponds to the model config.
  • src – Path to artifact.
  • verify_src_exists – If set to False, then the artifact is optional and register_artifact will return None even if src is not found. Defaults to True.
+
+
Return type
+

If src is not None or empty, it always returns an absolute path which is guaranteed to exist during the model instance's life.

+
+
+
+ +
+
+classmethod restore_from(restore_path: str, override_config_path: Optional[Union[omegaconf.omegaconf.OmegaConf, str]] = None, map_location: Optional[torch.device] = None, strict: bool = True, return_config: bool = False, save_restore_connector: Optional[mridc.core.connectors.save_restore_connector.SaveRestoreConnector] = None, trainer: Optional[pytorch_lightning.trainer.trainer.Trainer] = None)[source]
+

Restores model instance (weights and configuration) from .mridc file.

+
+
Parameters
  • restore_path – path to the .mridc file from which the model should be instantiated.
  • override_config_path – path to a yaml config that will override the internal config file, or an OmegaConf/DictConfig object representing the model config.
  • map_location – Optional torch.device() to map the instantiated model to a device. By default (None), it will select a GPU if available, falling back to CPU otherwise.
  • strict – Passed to load_state_dict. By default, True.
  • return_config – If set to true, will return just the underlying config of the restored model as an OmegaConf/DictConfig object without instantiating the model.
  • trainer – Optional, a pytorch lightning Trainer object that will be forwarded to the instantiated model's constructor.
  • save_restore_connector – Can be overridden to add custom save and restore logic.
+
+
+

Example

+
model = mridc.collections.asr.models.EncDecCTCModel.restore_from('asr.mridc')
assert isinstance(model, mridc.collections.asr.models.EncDecCTCModel)
+
+
+
+
Return type
+

An instance of type cls or its underlying config (if return_config is set).

+
+
+
+ +
+
+save_to(save_path: str)[source]
+

Saves the model instance (weights and configuration) into a .mridc file. You can use the "restore_from" method to fully restore the instance from a .mridc file. A .mridc file is an archive (tar.gz) with the following:
- model_config.yaml - the model configuration in .yaml format; you can deserialize this into the cfg argument for the model's constructor
- model_weights.ckpt - the model checkpoint

+
+
Parameters
save_path – Path to the .mridc file where the model instance should be saved.

+
+
+
+ +
+
+set_trainer(trainer: pytorch_lightning.trainer.trainer.Trainer)[source]
+

Set an instance of Trainer object.

+
+ +
+
+set_world_size(trainer: pytorch_lightning.trainer.trainer.Trainer)[source]
+

Determines the world size from the PyTorch Lightning Trainer and then updates AppState.

+
+ +
+
+setup_multiple_test_data(test_data_config: Union[omegaconf.dictconfig.DictConfig, Dict])[source]
+

(Optionally) sets up the data loader to be used in testing, with support for multiple data loaders.

+
+ +
+
+setup_multiple_validation_data(val_data_config: Union[omegaconf.dictconfig.DictConfig, Dict])[source]
+

(Optionally) sets up the data loader to be used in validation.

+
+ +
+
+setup_optimization(optim_config: Optional[Union[omegaconf.dictconfig.DictConfig, Dict]] = None)[source]
+

Prepares an optimizer from a string name and its optional config parameters.

+
+
Parameters
optim_config – A dictionary containing the following keys:
  • lr: mandatory key for learning rate. Will raise ValueError if not provided.
  • optimizer: string name pointing to one of the available optimizers in the registry. If not provided, defaults to "adam".
  • opt_args: optional list of strings in the format "arg_name=arg_value". The list of "arg_value" will be parsed and a dictionary of optimizer kwargs will be built and supplied to instantiate the optimizer.
+

+
+
Return type
+

An instance of an optimizer.

+
+
+
+ +
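A sketch of a matching optim config (keys as documented above; values are illustrative):

optim_config = {
    "lr": 1e-3,
    "optimizer": "adam",
    "opt_args": ["weight_decay=0.0"],
}
optimizer = model.setup_optimization(optim_config)  # assumes `model` is a ModelPT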
+
+setup_optimizer_param_groups()[source]
+

Used to create param groups for the optimizer. As an example, this can be used to specify per-layer learning rates:

+
optim.SGD([
    {'params': model.base.parameters()},
    {'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)
+
+
+

See https://pytorch.org/docs/stable/optim.html for more information. By default, ModelPT will use self.parameters(). Override this method to add custom param groups.

+
+ +
+
+setup_test_data(test_data_config: Union[omegaconf.dictconfig.DictConfig, Dict])[source]
+

(Optionally) sets up the data loader to be used in testing.

+
+ +
+
+abstract setup_training_data(train_data_config: Union[omegaconf.dictconfig.DictConfig, Dict])[source]
+

Sets up the data loader to be used in training.

+
+ +
+
+abstract setup_validation_data(val_data_config: Union[omegaconf.dictconfig.DictConfig, Dict])[source]
+

Sets up the data loader to be used in validation.

+
+ +
+
+teardown(stage: str)[source]
+

Called at the end of fit and test.

+
+ +
+
+test_dataloader()[source]
+

Return the test dataloader.

+
+ +
+
+test_epoch_end(outputs: Union[List[Dict[str, torch.Tensor]], List[List[Dict[str, torch.Tensor]]]]) Optional[Dict[str, Dict[str, torch.Tensor]]][source]
+

Default test-epoch-end hook, which automatically supports multiple data loaders via multi_test_epoch_end. If multi-dataset support is not required, override this method entirely in the base class; in that case, there is no need to implement multi_test_epoch_end either.

+
+

Note

+

If more than one data loader exists, and they all provide test_loss, only the test_loss of the first data loader will be used by default. This default can be changed by passing the special key test_dl_idx: int inside the test_ds config.

+
+
+
Parameters
outputs – Single or nested list of tensor outputs from one or more data loaders.

Returns
A dictionary containing the union of all items from individual data loaders, along with merged logs from all data loaders.
+

+
+
+
+ +
+
+train_dataloader()[source]
+

Return the training dataloader.

+
+ +
+
+training: bool
+
+ +
+
+classmethod update_save_restore_connector(save_restore_connector)[source]
+

Update the save_restore_connector of the model.

+
+ +
+
+val_dataloader()[source]
+

Return the validation dataloader.

+
+ +
+
+validation_epoch_end(outputs: Union[List[Dict[str, torch.Tensor]], List[List[Dict[str, torch.Tensor]]]]) Optional[Dict[str, Dict[str, torch.Tensor]]][source]
+

Default validation-epoch-end hook, which automatically supports multiple data loaders via multi_validation_epoch_end. If multi-dataset support is not required, override this method entirely in the base class; in that case, there is no need to implement multi_validation_epoch_end either.

+
+

Note

+

If more than one data loader exists, and they all provide val_loss, only the val_loss of the first data loader will be used by default. This default can be changed by passing the special key val_dl_idx: int inside the validation_ds config.

+
+
+
Parameters
outputs – Single or nested list of tensor outputs from one or more data loaders.

Returns
A dictionary containing the union of all items from individual data loaders, along with merged logs from all data loaders.
+

+
+
+
+ +
+ +
+
+

mridc.core.classes.module module

+
+
+class mridc.core.classes.module.NeuralModule[source]
+

Bases: torch.nn.modules.module.Module, mridc.core.classes.common.Typing, mridc.core.classes.common.Serialization, mridc.core.classes.common.FileIO, abc.ABC

+

Abstract class offering an interface shared between all PyTorch Neural Modules.

+
+
+as_frozen()[source]
+

Context manager which temporarily freezes a module, yields control and finally unfreezes the module.

+
+ +
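A sketch of the freeze/unfreeze workflow (assumes `module` is a NeuralModule instance and `x` an appropriate input tensor):

with module.as_frozen():
    # parameters are frozen inside this block and restored afterwards
    output = module(x)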
+
+freeze() None[source]
+

Freeze all params for inference.

+
+ +
+
+static input_example(max_batch=None, max_dim=None)[source]
+

Override this method if random inputs won’t work

+
+
Parameters
  • max_batch – Maximum batch size to generate.
  • max_dim – Maximum dimension to generate.
+
+
Return type
+

A tuple sample of valid input data.

+
+
+
+ +
+
+property num_weights
+

Utility property that returns the total number of parameters of NeuralModule.

+
+ +
+
+training: bool
+
+ +
+
+unfreeze() None[source]
+

Unfreeze all parameters for training.

+
+ +
+ +
+
+

Module contents

+
+
diff --git a/docs/build/html/mridc.core.conf.html b/docs/build/html/mridc.core.conf.html
new file mode 100644
index 00000000..327ccf1f
--- /dev/null
+++ b/docs/build/html/mridc.core.conf.html
@@ -0,0 +1,1403 @@
+mridc.core.conf package — mridc v.0.0.1 documentation

mridc.core.conf package

+
+

Submodules

+
+
+

mridc.core.conf.base_config module

+
+
+class mridc.core.conf.base_config.Config(name: Optional[str] = None)[source]
+

Bases: object

+

Abstract mridc Configuration class.

+
+
+name: Optional[str] = None
+
+ +
+ +
+
+

mridc.core.conf.dataloader module

+
+
+class mridc.core.conf.dataloader.DataLoaderConfig(batch_size: int = '???', shuffle: bool = False, sampler: Optional[Any] = None, batch_sampler: Optional[Any] = None, num_workers: int = 0, collate_fn: Optional[Any] = None, pin_memory: bool = False, drop_last: bool = False, timeout: int = 0, worker_init_fn: Optional[Any] = None, multiprocessing_context: Optional[Any] = None)[source]
+

Bases: object

+

Configuration of PyTorch DataLoader.

+
+
Note
For the details on the function/meanings of the arguments, please refer to:
https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader

+
+
+
+
+batch_sampler: Optional[Any] = None
+
+ +
+
+batch_size: int = '???'
+
+ +
+
+collate_fn: Optional[Any] = None
+
+ +
+
+drop_last: bool = False
+
+ +
+
+multiprocessing_context: Optional[Any] = None
+
+ +
+
+num_workers: int = 0
+
+ +
+
+pin_memory: bool = False
+
+ +
+
+sampler: Optional[Any] = None
+
+ +
+
+shuffle: bool = False
+
+ +
+
+timeout: int = 0
+
+ +
+
+worker_init_fn: Optional[Any] = None
+
+ +
+ +
+
+

mridc.core.conf.hydra_runner module

+
+
+mridc.core.conf.hydra_runner.hydra_runner(config_path: Optional[str] = '.', config_name: Optional[str] = None, schema: Optional[Any] = None) Callable[[Callable[[Any], Any]], Any][source]
+

Decorator used for passing the config paths to the main function. Optionally registers a schema used for validation/providing default values.

+
+
Parameters
  • config_path – Path to the config file.
  • config_name – Name of the config file.
  • schema – Schema used for validation/providing default values.
+
+
Return type
+

A decorator that passes the config paths to the main function.

+
+
+
+ +
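Typical usage sketch (the config directory and name are illustrative):

from mridc.core.conf.hydra_runner import hydra_runner

@hydra_runner(config_path="conf", config_name="train")
def main(cfg):
    # cfg is the DictConfig assembled by hydra from conf/train.yaml
    print(cfg)

if __name__ == "__main__":
    main()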
+
+

mridc.core.conf.modelPT module

+
+
+class mridc.core.conf.modelPT.HydraConfig(run: typing.Dict[str, typing.Any] = <factory>, job_logging: typing.Dict[str, typing.Any] = <factory>)[source]
+

Bases: object

+

Configuration for the hydra framework.

+
+
+job_logging: Dict[str, Any]
+
+ +
+
+run: Dict[str, Any]
+
+ +
+ +
+
+class mridc.core.conf.modelPT.MRIDCConfig(name: str = '???', model: mridc.core.conf.modelPT.ModelConfig = '???', trainer: mridc.core.conf.trainer.TrainerConfig = TrainerConfig(logger=False, checkpoint_callback=True, callbacks=None, default_root_dir=None, gradient_clip_val=0, process_position=0, num_nodes=1, gpus=None, auto_select_gpus=False, tpu_cores=None, log_gpu_memory=None, progress_bar_refresh_rate=1, enable_progress_bar=True, overfit_batches=0.0, track_grad_norm=- 1, check_val_every_n_epoch=1, fast_dev_run=False, accumulate_grad_batches=1, max_epochs=1000, min_epochs=1, max_steps=None, min_steps=None, limit_train_batches=1.0, limit_val_batches=1.0, limit_test_batches=1.0, val_check_interval=1.0, flush_logs_every_n_steps=100, log_every_n_steps=1, accelerator='gpu', sync_batchnorm=False, precision=32, weights_summary='full', weights_save_path=None, num_sanity_val_steps=2, resume_from_checkpoint=None, profiler=None, benchmark=False, deterministic=False, auto_lr_find=False, replace_sampler_ddp=True, detect_anomaly=False, terminate_on_nan=False, auto_scale_batch_size=False, prepare_data_per_node=True, amp_backend='native', amp_level=None, plugins=None, move_metrics_to_cpu=False, multiple_trainloader_mode='max_size_cycle', limit_predict_batches=1.0, stochastic_weight_avg=False, gradient_clip_algorithm='norm', max_time=None, reload_dataloaders_every_n_epochs=0, ipus=None, devices=None, strategy='ddp', enable_checkpointing=False, enable_model_summary=True), exp_manager: Optional[Any] = ExpManagerConfig(explicit_log_dir=None, exp_dir=None, name=None, version=None, use_datetime_version=True, resume_if_exists=False, resume_past_end=False, resume_ignore_no_checkpoint=False, create_tensorboard_logger=True, summary_writer_kwargs=None, create_wandb_logger=False, wandb_logger_kwargs=None, create_checkpoint_callback=True, checkpoint_callback_params=CallbackParams(filepath=None, dirpath=None, filename=None, monitor='val_loss', verbose=True, save_last=True, save_top_k=3, save_weights_only=False, mode='min', every_n_epochs=1, prefix=None, postfix='.mridc', save_best_model=False, always_save_mridc=False, save_mridc_on_train_end=True, model_parallel_size=None), files_to_copy=None, log_step_timing=True, step_timing_kwargs=StepTimingParams(reduction='mean', sync_cuda=False, buffer_size=1), log_local_rank_0_only=False, log_global_rank_0_only=False, model_parallel_size=None), hydra: mridc.core.conf.modelPT.HydraConfig = HydraConfig(run={'dir': '.'}, job_logging={'root': {'handlers': None}}))[source]
+

Bases: object

+

Configuration for the mridc framework.

+
+
+exp_manager: Optional[Any] = ExpManagerConfig(explicit_log_dir=None, exp_dir=None, name=None, version=None, use_datetime_version=True, resume_if_exists=False, resume_past_end=False, resume_ignore_no_checkpoint=False, create_tensorboard_logger=True, summary_writer_kwargs=None, create_wandb_logger=False, wandb_logger_kwargs=None, create_checkpoint_callback=True, checkpoint_callback_params=CallbackParams(filepath=None, dirpath=None, filename=None, monitor='val_loss', verbose=True, save_last=True, save_top_k=3, save_weights_only=False, mode='min', every_n_epochs=1, prefix=None, postfix='.mridc', save_best_model=False, always_save_mridc=False, save_mridc_on_train_end=True, model_parallel_size=None), files_to_copy=None, log_step_timing=True, step_timing_kwargs=StepTimingParams(reduction='mean', sync_cuda=False, buffer_size=1), log_local_rank_0_only=False, log_global_rank_0_only=False, model_parallel_size=None)
+
+ +
+
+hydra: mridc.core.conf.modelPT.HydraConfig = HydraConfig(run={'dir': '.'}, job_logging={'root': {'handlers': None}})
+
+ +
+
+model: mridc.core.conf.modelPT.ModelConfig = '???'
+
+ +
+
+name: str = '???'
+
+ +
+
+trainer: mridc.core.conf.trainer.TrainerConfig = TrainerConfig(logger=False, checkpoint_callback=True, callbacks=None, default_root_dir=None, gradient_clip_val=0, process_position=0, num_nodes=1, gpus=None, auto_select_gpus=False, tpu_cores=None, log_gpu_memory=None, progress_bar_refresh_rate=1, enable_progress_bar=True, overfit_batches=0.0, track_grad_norm=-1, check_val_every_n_epoch=1, fast_dev_run=False, accumulate_grad_batches=1, max_epochs=1000, min_epochs=1, max_steps=None, min_steps=None, limit_train_batches=1.0, limit_val_batches=1.0, limit_test_batches=1.0, val_check_interval=1.0, flush_logs_every_n_steps=100, log_every_n_steps=1, accelerator='gpu', sync_batchnorm=False, precision=32, weights_summary='full', weights_save_path=None, num_sanity_val_steps=2, resume_from_checkpoint=None, profiler=None, benchmark=False, deterministic=False, auto_lr_find=False, replace_sampler_ddp=True, detect_anomaly=False, terminate_on_nan=False, auto_scale_batch_size=False, prepare_data_per_node=True, amp_backend='native', amp_level=None, plugins=None, move_metrics_to_cpu=False, multiple_trainloader_mode='max_size_cycle', limit_predict_batches=1.0, stochastic_weight_avg=False, gradient_clip_algorithm='norm', max_time=None, reload_dataloaders_every_n_epochs=0, ipus=None, devices=None, strategy='ddp', enable_checkpointing=False, enable_model_summary=True)
+
+ +
+ +
+
+class mridc.core.conf.modelPT.ModelConfig(train_ds: Optional[mridc.core.classes.dataset.DatasetConfig] = None, validation_ds: Optional[mridc.core.classes.dataset.DatasetConfig] = None, test_ds: Optional[mridc.core.classes.dataset.DatasetConfig] = None, optim: Optional[mridc.core.conf.modelPT.OptimConfig] = None)[source]
+

Bases: object

+

Configuration for the model.

+
+
+optim: Optional[mridc.core.conf.modelPT.OptimConfig] = None
+
+ +
+
+test_ds: Optional[mridc.core.classes.dataset.DatasetConfig] = None
+
+ +
+
+train_ds: Optional[mridc.core.classes.dataset.DatasetConfig] = None
+
+ +
+
+validation_ds: Optional[mridc.core.classes.dataset.DatasetConfig] = None
+
+ +
+ +
+
+class mridc.core.conf.modelPT.ModelConfigBuilder(model_cfg: mridc.core.conf.modelPT.ModelConfig)[source]
+

Bases: object

+

Builder for the ModelConfig class.

+
+
+build() mridc.core.conf.modelPT.ModelConfig[source]
+

Validate and return the assembled ModelConfig.

+
+ +
+
+set_optim(cfg: mridc.core.conf.optimizers.OptimizerParams, sched_cfg: Optional[mridc.core.conf.schedulers.SchedulerParams] = None)[source]
+

Set the optimizer configuration.

+
+ +
+
+set_test_ds(cfg: Optional[mridc.core.classes.dataset.DatasetConfig] = None)[source]
+

Set the test dataset configuration.

+
+ +
+
+set_train_ds(cfg: Optional[mridc.core.classes.dataset.DatasetConfig] = None)[source]
+

Set the training dataset configuration.

+
+ +
+
+set_validation_ds(cfg: Optional[mridc.core.classes.dataset.DatasetConfig] = None)[source]
+

Set the validation dataset configuration.

+
+ +
+ +
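A sketch of assembling a config with the builder (all values are illustrative):

from mridc.core.classes.dataset import DatasetConfig
from mridc.core.conf.modelPT import ModelConfig, ModelConfigBuilder

builder = ModelConfigBuilder(ModelConfig())
builder.set_train_ds(DatasetConfig(batch_size=16, shuffle=True))
builder.set_validation_ds(DatasetConfig(batch_size=16))
model_cfg = builder.build()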
+
+class mridc.core.conf.modelPT.OptimConfig(name: str = '???', sched: Optional[mridc.core.conf.modelPT.SchedConfig] = None)[source]
+

Bases: object

+

Configuration for the optimizer.

+
+
+name: str = '???'
+
+ +
+
+sched: Optional[mridc.core.conf.modelPT.SchedConfig] = None
+
+ +
+ +
+
+class mridc.core.conf.modelPT.SchedConfig(name: str = '???', min_lr: float = 0.0, last_epoch: int = - 1)[source]
+

Bases: object

+

Configuration for the scheduler.

+
+
+last_epoch: int = -1
+
+ +
+
+min_lr: float = 0.0
+
+ +
+
+name: str = '???'
+
+ +
+ +
+
+

mridc.core.conf.optimizers module

+
+
+class mridc.core.conf.optimizers.AdadeltaParams(lr: Optional[float] = '???', rho: float = 0.9, eps: float = 1e-06, weight_decay: float = 0)[source]
+

Bases: mridc.core.conf.optimizers.OptimizerParams

+

Default configuration for Adadelta optimizer.

+
+

Note

+

For the details on the function/meanings of the arguments, please refer to:
https://pytorch.org/docs/stable/optim.html#torch.optim.Adadelta

+
+
+
+eps: float = 1e-06
+
+ +
+
+rho: float = 0.9
+
+ +
+
+weight_decay: float = 0
+
+ +
+ +
+
+class mridc.core.conf.optimizers.AdagradParams(lr: Optional[float] = '???', lr_decay: float = 0, weight_decay: float = 0, initial_accumulator_value: float = 0, eps: float = 1e-10)[source]
+

Bases: mridc.core.conf.optimizers.OptimizerParams

+

Default configuration for Adagrad optimizer.

+
+

Note

+

For the details on the function/meanings of the arguments, please refer to:
https://pytorch.org/docs/stable/optim.html#torch.optim.Adagrad

+
+
+
+eps: float = 1e-10
+
+ +
+
+initial_accumulator_value: float = 0
+
+ +
+
+lr_decay: float = 0
+
+ +
+
+weight_decay: float = 0
+
+ +
+ +
+
+class mridc.core.conf.optimizers.AdamParams(lr: Optional[float] = '???', eps: float = 1e-08, weight_decay: float = 0, amsgrad: bool = False)[source]
+

Bases: mridc.core.conf.optimizers.OptimizerParams

+

Default configuration for Adam optimizer.

+
+

Note

+

For the details on the function/meanings of the arguments, please refer to:
https://pytorch.org/docs/stable/optim.html?highlight=adam#torch.optim.Adam

+
+
+
+amsgrad: bool = False
+
+ +
+
+eps: float = 1e-08
+
+ +
+
+weight_decay: float = 0
+
+ +
+ +
+
+class mridc.core.conf.optimizers.AdamWParams(lr: Optional[float] = '???', betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0, amsgrad: bool = False)[source]
+

Bases: mridc.core.conf.optimizers.OptimizerParams

+

Default configuration for AdamW optimizer.

+
+

Note

+

For the details on the function/meanings of the arguments, please refer to:
https://pytorch.org/docs/stable/optim.html#torch.optim.AdamW

+
+
+
+amsgrad: bool = False
+
+ +
+
+betas: Tuple[float, float] = (0.9, 0.999)
+
+ +
+
+eps: float = 1e-08
+
+ +
+
+weight_decay: float = 0
+
+ +
+ +
+
+class mridc.core.conf.optimizers.AdamaxParams(lr: Optional[float] = '???', betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0)[source]
+

Bases: mridc.core.conf.optimizers.OptimizerParams

+

Default configuration for Adamax optimizer.

+
+

Note

+

For the details on the function/meanings of the arguments, please refer to:
https://pytorch.org/docs/stable/optim.html#torch.optim.Adamax

+
+
+
+betas: Tuple[float, float] = (0.9, 0.999)
+
+ +
+
+eps: float = 1e-08
+
+ +
+
+weight_decay: float = 0
+
+ +
+ +
+
+class mridc.core.conf.optimizers.NovogradParams(lr: float = 0.001, betas: Tuple[float, float] = (0.95, 0.98), eps: float = 1e-08, weight_decay: float = 0, grad_averaging: bool = False, amsgrad: bool = False, luc: bool = False, luc_trust: float = 0.001, luc_eps: float = 1e-08)[source]
+

Bases: mridc.core.conf.optimizers.OptimizerParams

+

Configuration of the Novograd optimizer, as proposed in "Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks" (https://arxiv.org/abs/1905.11286).

+
+
+amsgrad: bool = False
+
+ +
+
+betas: Tuple[float, float] = (0.95, 0.98)
+
+ +
+
+eps: float = 1e-08
+
+ +
+
+grad_averaging: bool = False
+
+ +
+
+lr: float = 0.001
+
+ +
+
+luc: bool = False
+
+ +
+
+luc_eps: float = 1e-08
+
+ +
+
+luc_trust: float = 0.001
+
+ +
+
+weight_decay: float = 0
+
+ +
+ +
+
+class mridc.core.conf.optimizers.OptimizerParams(lr: Optional[float] = '???')[source]
+

Bases: object

+

Base optimizer params with no values. The user can choose to explicitly override them via command-line arguments.

+
+
+lr: Optional[float] = '???'
+
+ +
+ +
+
+class mridc.core.conf.optimizers.RMSpropParams(lr: Optional[float] = '???', alpha: float = 0.99, eps: float = 1e-08, weight_decay: float = 0, momentum: float = 0, centered: bool = False)[source]
+

Bases: mridc.core.conf.optimizers.OptimizerParams

+

Default configuration for RMSprop optimizer.

+
+

Note

+

For the details on the function/meanings of the arguments, please refer to:
https://pytorch.org/docs/stable/optim.html#torch.optim.RMSprop

+
+
+
+alpha: float = 0.99
+
+ +
+
+centered: bool = False
+
+ +
+
+eps: float = 1e-08
+
+ +
+
+momentum: float = 0
+
+ +
+
+weight_decay: float = 0
+
+ +
+ +
+
+class mridc.core.conf.optimizers.RpropParams(lr: Optional[float] = '???', etas: Tuple[float, float] = (0.5, 1.2), step_sizes: Tuple[float, float] = (1e-06, 50))[source]
+

Bases: mridc.core.conf.optimizers.OptimizerParams

+

Default configuration for the Rprop optimizer.

+
+

Note

+

For the details on the function/meanings of the arguments, please refer to:
https://pytorch.org/docs/stable/optim.html#torch.optim.Rprop

+
+
+
+etas: Tuple[float, float] = (0.5, 1.2)
+
+ +
+
+step_sizes: Tuple[float, float] = (1e-06, 50)
+
+ +
+ +
+
+class mridc.core.conf.optimizers.SGDParams(lr: Optional[float] = '???', momentum: float = 0, dampening: float = 0, weight_decay: float = 0, nesterov: bool = False)[source]
+

Bases: mridc.core.conf.optimizers.OptimizerParams

+

Default configuration for the SGD optimizer.

+
+

Note

+

For the details on the function/meanings of the arguments, please refer to:
https://pytorch.org/docs/stable/optim.html?highlight=sgd#torch.optim.SGD

+
+
+
+dampening: float = 0
+
+ +
+
+momentum: float = 0
+
+ +
+
+nesterov: bool = False
+
+ +
+
+weight_decay: float = 0
+
+ +
+ +
+
+mridc.core.conf.optimizers.get_optimizer_config(name: str, **kwargs: Optional[Dict[str, Any]]) Union[Dict[str, Optional[Dict[str, Any]]], functools.partial][source]
+

Convenience method to obtain an OptimizerParams class and partially instantiate it with optimizer kwargs.

+
+
Parameters
  • name – Name of the OptimizerParams in the registry.
  • kwargs – Optional kwargs of the optimizer used during instantiation.
+
+
Return type
+

A partially instantiated OptimizerParams.

+
+
+
+ +
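A sketch of fetching a partially-instantiated params class (the registry name and kwargs are illustrative):

from mridc.core.conf.optimizers import get_optimizer_config

adam_cfg = get_optimizer_config("adam", lr=1e-3, weight_decay=0.0)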
+
+mridc.core.conf.optimizers.register_optimizer_params(name: str, optimizer_params: mridc.core.conf.optimizers.OptimizerParams)[source]
+

Checks if the optimizer param name exists in the registry, and if it doesn't, adds it. This allows custom optimizer params to be added and called by name during instantiation.

+
+
Parameters
  • name – Name of the optimizer. Will be used as key to retrieve the optimizer.
  • optimizer_params – OptimizerParams class.
+
+
+
+ +
+
+

mridc.core.conf.schedulers module

+
+
+class mridc.core.conf.schedulers.CosineAnnealingParams(last_epoch: int = - 1, max_steps: int = 0, warmup_steps: Optional[float] = None, warmup_ratio: Optional[float] = None, constant_steps: Optional[float] = None, constant_ratio: Optional[float] = None, min_lr: float = 0.0)[source]
+

Bases: mridc.core.conf.schedulers.WarmupAnnealingHoldSchedulerParams

+

Cosine Annealing parameter config

+
+
+min_lr: float = 0.0
+
+ +
+ +
+
+class mridc.core.conf.schedulers.CyclicLRParams(last_epoch: int = - 1, base_lr: float = 0.001, max_lr: float = 0.1, step_size_up: int = 2000, step_size_down: Optional[int] = None, mode: str = 'triangular', gamma: float = 1.0, scale_mode: str = 'cycle', cycle_momentum: bool = True, base_momentum: float = 0.8, max_momentum: float = 0.9)[source]
+

Bases: mridc.core.conf.schedulers.SchedulerParams

+

Config for CyclicLR.

+
+
+base_lr: float = 0.001
+
+ +
+
+base_momentum: float = 0.8
+
+ +
+
+cycle_momentum: bool = True
+
+ +
+
+gamma: float = 1.0
+
+ +
+
+max_lr: float = 0.1
+
+ +
+
+max_momentum: float = 0.9
+
+ +
+
+mode: str = 'triangular'
+
+ +
+
+scale_mode: str = 'cycle'
+
+ +
+
+step_size_down: Optional[int] = None
+
+ +
+
+step_size_up: int = 2000
+
+ +
+ +
+
+class mridc.core.conf.schedulers.ExponentialLRParams(last_epoch: int = - 1, gamma: float = 0.9)[source]
+

Bases: mridc.core.conf.schedulers.SchedulerParams

+

Config for ExponentialLR.

+
+
+gamma: float = 0.9
+
+ +
+ +
+
+class mridc.core.conf.schedulers.InverseSquareRootAnnealingParams(last_epoch: int = - 1, max_steps: int = 0, warmup_steps: Optional[float] = None, warmup_ratio: Optional[float] = None)[source]
+

Bases: mridc.core.conf.schedulers.WarmupSchedulerParams

+

Inverse Square Root Annealing parameter config

+
+ +
+
+class mridc.core.conf.schedulers.NoamAnnealingParams(last_epoch: int = - 1, max_steps: int = 0, warmup_steps: Optional[float] = None, warmup_ratio: Optional[float] = None, min_lr: float = 0.0)[source]
+

Bases: mridc.core.conf.schedulers.WarmupSchedulerParams

+

Noam Annealing parameter config

+
+
+min_lr: float = 0.0
+
+ +
+ +
+
+class mridc.core.conf.schedulers.PolynomialDecayAnnealingParams(last_epoch: int = - 1, max_steps: int = 0, warmup_steps: Optional[float] = None, warmup_ratio: Optional[float] = None, power: float = 1.0, cycle: bool = False)[source]
+

Bases: mridc.core.conf.schedulers.WarmupSchedulerParams

+

Polynomial Decay Annealing parameter config

+
+
+cycle: bool = False
+
+ +
+
+power: float = 1.0
+
+ +
+ +
+
+class mridc.core.conf.schedulers.PolynomialHoldDecayAnnealingParams(last_epoch: int = - 1, max_steps: int = 0, warmup_steps: Optional[float] = None, warmup_ratio: Optional[float] = None, power: float = 1.0, cycle: bool = False)[source]
+

Bases: mridc.core.conf.schedulers.WarmupSchedulerParams

+

Polynomial Hold Decay Annealing parameter config

+
+
+cycle: bool = False
+
+ +
+
+power: float = 1.0
+
+ +
+ +
+
+class mridc.core.conf.schedulers.ReduceLROnPlateauParams(mode: str = 'min', factor: float = 0.1, patience: int = 10, verbose: bool = False, threshold: float = 0.0001, threshold_mode: str = 'rel', cooldown: int = 0, min_lr: float = 0, eps: float = 1e-08)[source]
+

Bases: object

+

Config for ReduceLROnPlateau.

+
+
+cooldown: int = 0
+
+ +
+
+eps: float = 1e-08
+
+ +
+
+factor: float = 0.1
+
+ +
+
+min_lr: float = 0
+
+ +
+
+mode: str = 'min'
+
+ +
+
+patience: int = 10
+
+ +
+
+threshold: float = 0.0001
+
+ +
+
+threshold_mode: str = 'rel'
+
+ +
+
+verbose: bool = False
+
+ +
+ +
+
+class mridc.core.conf.schedulers.SchedulerParams(last_epoch: int = - 1)[source]
+

Bases: object

+

Base configuration for all schedulers.

+
+
+last_epoch: int = -1
+
+ +
+ +
+
+class mridc.core.conf.schedulers.SquareAnnealingParams(last_epoch: int = - 1, max_steps: int = 0, warmup_steps: Optional[float] = None, warmup_ratio: Optional[float] = None, min_lr: float = 1e-05)[source]
+

Bases: mridc.core.conf.schedulers.WarmupSchedulerParams

+

Square Annealing parameter config

+
+
+min_lr: float = 1e-05
+
+ +
+ +
+
+class mridc.core.conf.schedulers.SquareRootAnnealingParams(last_epoch: int = - 1, max_steps: int = 0, warmup_steps: Optional[float] = None, warmup_ratio: Optional[float] = None, min_lr: float = 0.0)[source]
+

Bases: mridc.core.conf.schedulers.WarmupSchedulerParams

+

Square Root Annealing parameter config

+
+
+min_lr: float = 0.0
+
+ +
+ +
+
+class mridc.core.conf.schedulers.SquareRootConstantSchedulerParams(last_epoch: int = - 1, constant_steps: Optional[float] = None, constant_ratio: Optional[float] = None)[source]
+

Bases: mridc.core.conf.schedulers.SchedulerParams

+

Base configuration for all schedulers. It is not derived from Config as it is not a mridc object (and in particular it doesn't need a name).

+
+
+constant_ratio: Optional[float] = None
+
+ +
+
+constant_steps: Optional[float] = None
+
+ +
+ +
+
+class mridc.core.conf.schedulers.StepLRParams(last_epoch: int = - 1, step_size: float = 0.1, gamma: float = 0.1)[source]
+

Bases: mridc.core.conf.schedulers.SchedulerParams

+

Config for StepLR.

+
+
+gamma: float = 0.1
+
+ +
+
+step_size: float = 0.1
+
+ +
+ +
+
+class mridc.core.conf.schedulers.WarmupAnnealingHoldSchedulerParams(last_epoch: int = - 1, max_steps: int = 0, warmup_steps: Optional[float] = None, warmup_ratio: Optional[float] = None, constant_steps: Optional[float] = None, constant_ratio: Optional[float] = None, min_lr: float = 0.0)[source]
+

Bases: mridc.core.conf.schedulers.WarmupSchedulerParams

+

Base configuration for all schedulers.

+
+
+constant_ratio: Optional[float] = None
+
+ +
+
+constant_steps: Optional[float] = None
+
+ +
+
+min_lr: float = 0.0
+
+ +
+ +
+
+class mridc.core.conf.schedulers.WarmupAnnealingParams(last_epoch: int = - 1, max_steps: int = 0, warmup_steps: Optional[float] = None, warmup_ratio: Optional[float] = None)[source]
+

Bases: mridc.core.conf.schedulers.WarmupSchedulerParams

+

Warmup Annealing parameter config

+
+
+warmup_ratio: Optional[float] = None
+
+ +
+ +
+
+class mridc.core.conf.schedulers.WarmupHoldSchedulerParams(last_epoch: int = - 1, max_steps: int = 0, warmup_steps: Optional[float] = None, warmup_ratio: Optional[float] = None, hold_steps: Optional[float] = None, hold_ratio: Optional[float] = None, min_lr: float = 0.0)[source]
+

Bases: mridc.core.conf.schedulers.WarmupSchedulerParams

+

Base configuration for all schedulers.

+
+
+hold_ratio: Optional[float] = None
+
+ +
+
+hold_steps: Optional[float] = None
+
+ +
+
+min_lr: float = 0.0
+
+ +
+ +
+
+class mridc.core.conf.schedulers.WarmupSchedulerParams(last_epoch: int = - 1, max_steps: int = 0, warmup_steps: Optional[float] = None, warmup_ratio: Optional[float] = None)[source]
+

Bases: mridc.core.conf.schedulers.SchedulerParams

+

Base configuration for all schedulers.

+
+
+max_steps: int = 0
+
+ +
+
+warmup_ratio: Optional[float] = None
+
+ +
+
+warmup_steps: Optional[float] = None
+
+ +
+ +
+
+mridc.core.conf.schedulers.get_scheduler_config(name: str, **kwargs: Optional[Dict[str, Any]]) functools.partial[source]
+

Convenience method to obtain a SchedulerParams class and partially instantiate it with scheduler kwargs.

+
+
Parameters
  • name – Name of the SchedulerParams in the registry.
  • kwargs – Optional kwargs of the scheduler used during instantiation.
+
+
Return type
+

A partially instantiated SchedulerParams.

+
+
+
+ +
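A sketch (the registry name and kwargs are illustrative):

from mridc.core.conf.schedulers import get_scheduler_config

cosine_cfg = get_scheduler_config("CosineAnnealing", max_steps=10000, min_lr=1e-6)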
+
+mridc.core.conf.schedulers.register_scheduler_params(name: str, scheduler_params: mridc.core.conf.schedulers.SchedulerParams)[source]
+

Checks if the scheduler config name exists in the registry, and if it doesn't, adds it. This allows custom schedulers to be added and called by name during instantiation.

+
+
Parameters
  • name – Name of the scheduler. Will be used as key to retrieve the scheduler.
  • scheduler_params – SchedulerParams class.
+
+
+
+ +
+
+

mridc.core.conf.trainer module

+
+
+class mridc.core.conf.trainer.TrainerConfig(logger: Any = True, checkpoint_callback: Any = True, callbacks: Optional[Any] = None, default_root_dir: Optional[str] = None, gradient_clip_val: float = 0, process_position: int = 0, num_nodes: int = 1, gpus: Optional[Any] = None, auto_select_gpus: bool = False, tpu_cores: Optional[Any] = None, log_gpu_memory: Optional[str] = None, progress_bar_refresh_rate: int = 1, enable_progress_bar: bool = True, overfit_batches: Any = 0.0, track_grad_norm: Any = - 1, check_val_every_n_epoch: int = 1, fast_dev_run: bool = False, accumulate_grad_batches: Any = 1, max_epochs: int = 1000, min_epochs: int = 1, max_steps: Optional[int] = None, min_steps: Optional[int] = None, limit_train_batches: Any = 1.0, limit_val_batches: Any = 1.0, limit_test_batches: Any = 1.0, val_check_interval: Any = 1.0, flush_logs_every_n_steps: int = 100, log_every_n_steps: int = 50, accelerator: Optional[str] = None, sync_batchnorm: bool = False, precision: Any = 32, weights_summary: Optional[str] = 'full', weights_save_path: Optional[str] = None, num_sanity_val_steps: int = 2, resume_from_checkpoint: Optional[str] = None, profiler: Optional[Any] = None, benchmark: bool = False, deterministic: bool = False, auto_lr_find: Any = False, replace_sampler_ddp: bool = True, detect_anomaly: bool = False, terminate_on_nan: bool = False, auto_scale_batch_size: Any = False, prepare_data_per_node: bool = True, amp_backend: str = 'native', amp_level: Optional[str] = None, plugins: Optional[Any] = None, move_metrics_to_cpu: bool = False, multiple_trainloader_mode: str = 'max_size_cycle', limit_predict_batches: float = 1.0, stochastic_weight_avg: bool = False, gradient_clip_algorithm: str = 'norm', max_time: Optional[Any] = None, reload_dataloaders_every_n_epochs: int = 0, ipus: Optional[int] = None, devices: Optional[Any] = None, strategy: Optional[Any] = None, enable_checkpointing: bool = True, enable_model_summary: bool = True)[source]
+

Bases: object

+

TrainerConfig is a dataclass that holds all the hyperparameters for the training process.

+
+
+accelerator: Optional[str] = None
+
+ +
+
+accumulate_grad_batches: Any = 1
+
+ +
+
+amp_backend: str = 'native'
+
+ +
+
+amp_level: Optional[str] = None
+
+ +
+
+auto_lr_find: Any = False
+
+ +
+
+auto_scale_batch_size: Any = False
+
+ +
+
+auto_select_gpus: bool = False
+
+ +
+
+benchmark: bool = False
+
+ +
+
+callbacks: Optional[Any] = None
+
+ +
+
+check_val_every_n_epoch: int = 1
+
+ +
+
+checkpoint_callback: Any = True
+
+ +
+
+default_root_dir: Optional[str] = None
+
+ +
+
+detect_anomaly: bool = False
+
+ +
+
+deterministic: bool = False
+
+ +
+
+devices: Any = None
+
+ +
+
+enable_checkpointing: bool = True
+
+ +
+
+enable_model_summary: bool = True
+
+ +
+
+enable_progress_bar: bool = True
+
+ +
+
+fast_dev_run: bool = False
+
+ +
+
+flush_logs_every_n_steps: int = 100
+
+ +
+
+gpus: Optional[Any] = None
+
+ +
+
+gradient_clip_algorithm: str = 'norm'
+
+ +
+
+gradient_clip_val: float = 0
+
+ +
+
+ipus: Optional[int] = None
+
+ +
+
+limit_predict_batches: float = 1.0
+
+ +
+
+limit_test_batches: Any = 1.0
+
+ +
+
+limit_train_batches: Any = 1.0
+
+ +
+
+limit_val_batches: Any = 1.0
+
+ +
+
+log_every_n_steps: int = 50
+
+ +
+
+log_gpu_memory: Optional[str] = None
+
+ +
+
+logger: Any = True
+
+ +
+
+max_epochs: int = 1000
+
+ +
+
+max_steps: Optional[int] = None
+
+ +
+
+max_time: Optional[Any] = None
+
+ +
+
+min_epochs: int = 1
+
+ +
+
+min_steps: Optional[int] = None
+
+ +
+
+move_metrics_to_cpu: bool = False
+
+ +
+
+multiple_trainloader_mode: str = 'max_size_cycle'
+
+ +
+
+num_nodes: int = 1
+
+ +
+
+num_sanity_val_steps: int = 2
+
+ +
+
+overfit_batches: Any = 0.0
+
+ +
+
+plugins: Optional[Any] = None
+
+ +
+
+precision: Any = 32
+
+ +
+
+prepare_data_per_node: bool = True
+
+ +
+
+process_position: int = 0
+
+ +
+
+profiler: Optional[Any] = None
+
+ +
+
+progress_bar_refresh_rate: int = 1
+
+ +
+
+reload_dataloaders_every_n_epochs: int = 0
+
+ +
+
+replace_sampler_ddp: bool = True
+
+ +
+
+resume_from_checkpoint: Optional[str] = None
+
+ +
+
+stochastic_weight_avg: bool = False
+
+ +
+
+strategy: Any = None
+
+ +
+
+sync_batchnorm: bool = False
+
+ +
+
+terminate_on_nan: bool = False
+
+ +
+
+tpu_cores: Optional[Any] = None
+
+ +
+
+track_grad_norm: Any = -1
+
+ +
+
+val_check_interval: Any = 1.0
+
+ +
+
+weights_save_path: Optional[str] = None
+
+ +
+
+weights_summary: Optional[str] = 'full'
+
+ +
+ +
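A sketch of overriding a few trainer fields (values are illustrative; field names come from the signature above):

from mridc.core.conf.trainer import TrainerConfig

trainer_cfg = TrainerConfig(max_epochs=50, precision=16, devices=1, accelerator="gpu")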
+
+

Module contents

+
+
diff --git a/docs/build/html/mridc.core.connectors.html b/docs/build/html/mridc.core.connectors.html
new file mode 100644
index 00000000..8142589a
--- /dev/null
+++ b/docs/build/html/mridc.core.connectors.html
@@ -0,0 +1,306 @@
+mridc.core.connectors package — mridc v.0.0.1 documentation

mridc.core.connectors package

+
+

Submodules

+
+
+

mridc.core.connectors.save_restore_connector module

+
+
+class mridc.core.connectors.save_restore_connector.SaveRestoreConnector[source]
+

Bases: object

+

This class is used to save and restore the model state.

+
+
+extract_state_dict_from(restore_path: str, save_dir: str, split_by_module: bool = False)[source]
+

Extract the state dict(s) from a provided .mridc tarfile and save it to a directory.

+
+
Parameters
  • restore_path – path to the .mridc file from which the state dict(s) should be extracted.
  • save_dir – directory in which the saved state dict(s) should be stored.
  • split_by_module – bool flag which determines whether the output checkpoint should be for the entire Model or for the individual modules that comprise the Model.
+
+
+

Example

+

To convert the .mridc tarfile into a single Model level PyTorch checkpoint:

state_dict = mridc.collections.asr.models.EncDecCTCModel.extract_state_dict_from('asr.mridc', './asr_ckpts')

To restore a model from a Model level checkpoint:

model = mridc.collections.asr.models.EncDecCTCModel(cfg)  # or any other method of restoration
model.load_state_dict(torch.load("./asr_ckpts/model_weights.ckpt"))

To convert the .mridc tarfile into multiple Module level PyTorch checkpoints:

state_dict = mridc.collections.asr.models.EncDecCTCModel.extract_state_dict_from('asr.mridc', './asr_ckpts', split_by_module=True)

To restore a module from a Module level checkpoint:

model = mridc.collections.asr.models.EncDecCTCModel(cfg)  # or any other method of restoration
# load the individual components
model.preprocessor.load_state_dict(torch.load("./asr_ckpts/preprocessor.ckpt"))
model.encoder.load_state_dict(torch.load("./asr_ckpts/encoder.ckpt"))
model.decoder.load_state_dict(torch.load("./asr_ckpts/decoder.ckpt"))

+
+
Return type
+

The state dict that was loaded from the original .mridc checkpoint.

+
+
+
+ +
+
+load_config_and_state_dict(calling_cls, restore_path: str, override_config_path: Optional[Union[omegaconf.omegaconf.OmegaConf, str]] = None, map_location: Optional[torch.device] = None, strict: bool = True, return_config: bool = False, trainer: Optional[pytorch_lightning.trainer.trainer.Trainer] = None)[source]
+

Restores a model instance (weights and configuration) from a .mridc file.

+
+
Parameters
  • calling_cls – Class of the model to be restored.
  • restore_path – path to the .mridc file from which the model should be instantiated.
  • override_config_path – path to a yaml config that will override the internal config file, or an OmegaConf/DictConfig object representing the model config.
  • map_location – Optional torch.device() to map the instantiated model to a device. By default (None), it will select a GPU if available, falling back to CPU otherwise.
  • strict – Passed to load_state_dict. By default, True.
  • return_config – If set to true, will return just the underlying config of the restored model as an OmegaConf/DictConfig object without instantiating the model.
  • trainer – Optional trainer object to be used for model parallelism.
+
+
+

Example

+

model = mridc.collections.asr.models.EncDecCTCModel.restore_from('asr.mridc')
assert isinstance(model, mridc.collections.asr.models.EncDecCTCModel)

+
+
Return type
+

An instance of type cls or its underlying config (if return_config is set).

+
+
+
+ +
+
+static load_instance_with_state_dict(instance, state_dict, strict)[source]
+

Loads the state dict into the instance.

+
+ +
+
+property model_config_yaml: str
+

This property is used to get the path to the model config yaml file.

+
+ +
+
+property model_weights_ckpt: str
+

This property is used to get the path to the model weights ckpt file.

+
+ +
+
+static register_artifact(model, config_path: str, src: str, verify_src_exists: bool = True)[source]
+

Register model artifacts with this function. These artifacts (files) will be included inside the .mridc file when model.save_to("mymodel.mridc") is called.

+

How it works:
  1. It always returns an existing absolute path which can be used during the Model constructor call. EXCEPTION: if src is None or "", nothing will be done and src will be returned.
  2. It will add a (config_path, model_utils.ArtifactItem()) pair to self.artifacts. If "src" is a local existing path, it will be returned in absolute path form; if "src" starts with "mridc_file:unique_artifact_name", the .mridc file will be untarred to a temporary folder and an actual existing path will be returned; otherwise, an error will be raised.

+

WARNING: use .register_artifact calls in your models' constructors. The returned path is not guaranteed to exist after you have exited your model's constructor.

+
+
Parameters
  • model – ModelPT object to register the artifact for.
  • config_path – Artifact key. Usually corresponds to the model config.
  • src – Path to artifact.
  • verify_src_exists – If set to False, then the artifact is optional and register_artifact will return None even if src is not found. Defaults to True.
+
+
Return type
If src is not None or empty, it always returns an absolute path which is guaranteed to exist during the model instance's life.

+
+
+
+ +
+
+restore_from(calling_cls, restore_path: str, override_config_path: Optional[Union[omegaconf.omegaconf.OmegaConf, str]] = None, map_location: Optional[torch.device] = None, strict: bool = True, return_config: bool = False, trainer: Optional[pytorch_lightning.trainer.trainer.Trainer] = None)[source]
+

Restores a model instance (weights and configuration) from a .mridc file.

+
+
Parameters
  • calling_cls – The class of the model to be restored.
  • restore_path – path to the .mridc file from which the model should be instantiated.
  • override_config_path – path to a yaml config that will override the internal config file, or an OmegaConf/DictConfig object representing the model config.
  • map_location – Optional torch.device() to map the instantiated model to a device. By default (None), it will select a GPU if available, falling back to CPU otherwise.
  • strict – Passed to load_state_dict. By default, True.
  • return_config – If set to true, will return just the underlying config of the restored model as an OmegaConf/DictConfig object without instantiating the model.
  • trainer – Optional trainer object to be used for restoring the model.
+
+
Return type
+

An instance of type cls or its underlying config (if return_config is set).

+
+
+
+ +
+
+save_to(model, save_path: str)[source]
+

Saves the model instance (weights and configuration) into a .mridc file. You can use the "restore_from" method to fully restore the instance from a .mridc file. A .mridc file is an archive (tar.gz) with the following:
- model_config.yaml - the model configuration in .yaml format; you can deserialize this into the cfg argument for the model's constructor
- model_weights.ckpt - the model checkpoint

+
+
Parameters
  • model – ModelPT object to be saved.
  • save_path – Path to the .mridc file where the model instance should be saved.
+
+
+
+ +
+ +
+
+

Module contents

+
+
diff --git a/docs/build/html/mridc.core.html b/docs/build/html/mridc.core.html
new file mode 100644
index 00000000..8d06e988
--- /dev/null
+++ b/docs/build/html/mridc.core.html
@@ -0,0 +1,187 @@
+mridc.core package — mridc v.0.0.1 documentation
diff --git a/docs/build/html/mridc.core.neural_types.html b/docs/build/html/mridc.core.neural_types.html
new file mode 100644
index 00000000..9347b594
--- /dev/null
+++ b/docs/build/html/mridc.core.neural_types.html
@@ -0,0 +1,638 @@
+mridc.core.neural_types package — mridc v.0.0.1 documentation

mridc.core.neural_types package

+
+

Submodules

+
+
+

mridc.core.neural_types.axes module

+
+
+class mridc.core.neural_types.axes.AxisKind(value)[source]
+

Bases: mridc.core.neural_types.axes.AxisKindAbstract

+

This Enum represents what a varying axis dimension means. For example, does this dimension correspond to width, batch, time, etc.? The "Dimension" and "Channel" kinds are the same and are used to represent a general axis. The "Any" axis will accept any axis kind fed to it.

+
+
+Any = 5
+
+ +
+
+Batch = 0
+
+ +
+
+Channel = 2
+
+ +
+
+Dimension = 2
+
+ +
+
+FlowGroup = 7
+
+ +
+
+Height = 4
+
+ +
+
+Sequence = 6
+
+ +
+
+Singleton = 8
+
+ +
+
+Time = 1
+
+ +
+
+Width = 3
+
+ +
+
+__repr__()[source]
+

Returns short string representation of the AxisKind

+
+ +
+
+__str__()[source]
+

Returns short string representation of the AxisKind

+
+ +
+
+static from_str(label)[source]
+

Returns AxisKind instance based on short string representation

+
+ +
+
+t_with_string(text)[source]
+

Checks whether text is of the form 't_<any string>'.

+
+ +
+ +
+
+class mridc.core.neural_types.axes.AxisKindAbstract(value)[source]
+

Bases: enum.Enum

+

This is an abstract Enum that represents what a varying axis dimension means. In practice, you will almost always use the AxisKind Enum. This Enum should be inherited by your OWN Enum if you aren't satisfied with AxisKind; your own Enum can then be used instead of AxisKind.

+
+ +
+
+class mridc.core.neural_types.axes.AxisType(kind: mridc.core.neural_types.axes.AxisKindAbstract, size: Optional[int] = None, is_list=False)[source]
+

Bases: object

+

This class represents axis semantics and (optionally) its dimensionality.

+
+
Parameters
  • kind (AxisKindAbstract) – what kind of axis it is, for example Batch, Height, etc.
  • size (int, optional) – specify if the axis should have a fixed size. By default, it is set to None; you typically do not want to set it for Batch and Time.
  • is_list (bool, default=False) – whether this is a list or a tensor axis.
+
+
+
+
+__repr__()[source]
+

Returns short string representation of the AxisType

+
+ +
+ +
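A sketch of declaring axis semantics with these classes (a typical batch/channel/height/width layout; the fixed sizes are illustrative):

from mridc.core.neural_types.axes import AxisKind, AxisType

axes = (
    AxisType(AxisKind.Batch),
    AxisType(AxisKind.Channel),
    AxisType(AxisKind.Height, size=320),
    AxisType(AxisKind.Width, size=320),
)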
+
+

mridc.core.neural_types.comparison module

+
+
+class mridc.core.neural_types.comparison.NeuralTypeComparisonResult(value)[source]
+

Bases: enum.Enum

+

The result of comparing two neural type objects for compatibility, obtained when calling A.compare_to(B).

+
+
+CONTAINER_SIZE_MISMATCH = 5
+
+ +
+
+DIM_INCOMPATIBLE = 3
+
+ +
+
+GREATER = 2
+
+ +
+
+INCOMPATIBLE = 6
+
+ +
+
+LESS = 1
+
+ +
+
+SAME = 0
+
+ +
+
+SAME_TYPE_INCOMPATIBLE_PARAMS = 7
+
+ +
+
+TRANSPOSE_SAME = 4
+
+ +
+
+UNCHECKED = 8
+
+ +
+ +
+
+

mridc.core.neural_types.elements module

+
+
+class mridc.core.neural_types.elements.CategoricalValuesType[source]
+

Bases: mridc.core.neural_types.elements.PredictionsType

+

Element type to represent labels for categorical classification task

+
+ +
+
+class mridc.core.neural_types.elements.ChannelType[source]
+

Bases: mridc.core.neural_types.elements.ElementType

+

Element to represent convolutional input/output channel.

+
+ +
+
+class mridc.core.neural_types.elements.ElementType[source]
+

Bases: abc.ABC

+

Abstract class defining semantics of the tensor elements. We are relying on Python for inheritance checking

+
+
+__repr__()[source]
+

Override this method to provide a human readable representation of the type

+
+ +
+
+__str__()[source]
+

Override this method to provide a human readable representation of the type

+
+ +
+
+compare(second) mridc.core.neural_types.comparison.NeuralTypeComparisonResult[source]
+

Override this method to provide a comparison between two types.

+
+ +
+
+property fields: Optional[Tuple]
+

This should be used to logically represent tuples/structures. For example, if you want to represent a bounding box (x, y, width, height) you can put a tuple with the names (‘x’, ‘y’, ‘w’, ‘h’) in here. Under the hood this should be converted to the last tensor dimension of fixed size = len(fields). When two types are compared, their fields must match.

+
+ +
+
+property type_parameters: Dict
+

Override this property to parametrize your type. For example, you can specify a ‘storage’ type such as float, int, or bool with the ‘dtype’ keyword. As another example, if you want to represent a signal with a particular property (say, sample frequency), you can put sample_freq->value in here. When two types are compared, their type_parameters must match.

+
+ +
+ +
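To make type_parameters concrete, here is a hypothetical ElementType subclass that parametrizes itself by a sampling frequency (the class name and attribute are illustrative only; the comparison outcome assumes the base compare matches equal type_parameters):

```python
from mridc.core.neural_types.elements import ElementType

class SampledSignal(ElementType):
    """Hypothetical element type parametrized by a sampling frequency."""

    def __init__(self, freq=None):
        self._params = {"freq": freq}

    @property
    def type_parameters(self):
        # Two SampledSignal types should only compare as SAME when freq matches.
        return self._params

print(SampledSignal(freq=100).compare(SampledSignal(freq=100)))
```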
+
+class mridc.core.neural_types.elements.FloatType[source]
+

Bases: mridc.core.neural_types.elements.ElementType

+

Element type representing a single float

+
+ +
+
+class mridc.core.neural_types.elements.ImageFeatureValue[source]
+

Bases: mridc.core.neural_types.elements.ImageValue

+

Type representing an element (single value) of a (image) feature maps.

+
+ +
+
+class mridc.core.neural_types.elements.ImageValue[source]
+

Bases: mridc.core.neural_types.elements.ElementType

+

Type representing an element/value of a single image channel.

+
+ +
+
+class mridc.core.neural_types.elements.Index[source]
+

Bases: mridc.core.neural_types.elements.ElementType

+

Type representing an element being an index of the sample.

+
+ +
+
+class mridc.core.neural_types.elements.IntType[source]
+

Bases: mridc.core.neural_types.elements.ElementType

+

Element type representing a single integer

+
+ +
+
+class mridc.core.neural_types.elements.LabelsType[source]
+

Bases: mridc.core.neural_types.elements.ElementType

+

Element type to represent labels of something. For example, labels of a dataset.

+
+ +
+
+class mridc.core.neural_types.elements.Length[source]
+

Bases: mridc.core.neural_types.elements.IntType

+

Type representing an element storing a “length” (e.g. length of a list).

+
+ +
+
+class mridc.core.neural_types.elements.LengthsType[source]
+

Bases: mridc.core.neural_types.elements.ElementType

+

Element type representing lengths of something

+
+ +
+
+class mridc.core.neural_types.elements.LogDeterminantType[source]
+

Bases: mridc.core.neural_types.elements.ElementType

+

Element for representing log determinants usually used in flow models

+
+ +
+
+class mridc.core.neural_types.elements.LogprobsType[source]
+

Bases: mridc.core.neural_types.elements.ElementType

+

Element type to represent log-probabilities. For example, outputs of log softmax layers.

+
+ +
+
+class mridc.core.neural_types.elements.LossType[source]
+

Bases: mridc.core.neural_types.elements.ElementType

+

Element type to represent outputs of Loss modules

+
+ +
+
+class mridc.core.neural_types.elements.MRISignal(freq: Optional[int] = None)[source]
+

Bases: mridc.core.neural_types.elements.ElementType

+

Element type to represent the encoded representation returned by the MRI model.

+
+
Parameters
+

freq – sampling frequency of a signal. Note that two signals will only be the same if their freq is the same.

+
+
+
+
+property type_parameters
+

Returns the type parameters of the element type.

+
+ +
+ +
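A short sketch of the freq semantics described above:

```python
from mridc.core.neural_types.elements import MRISignal

# Two MRISignal element types should only match when their sampling
# frequencies agree, since freq is exposed through type_parameters.
a = MRISignal(freq=100)
print(a.compare(MRISignal(freq=100)))  # expected: SAME
print(a.compare(MRISignal(freq=200)))  # expected: parameter mismatch
```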
+
+class mridc.core.neural_types.elements.MaskType[source]
+

Bases: mridc.core.neural_types.elements.PredictionsType

+

Element type to represent a boolean mask

+
+ +
+
+class mridc.core.neural_types.elements.NormalDistributionLogVarianceType[source]
+

Bases: mridc.core.neural_types.elements.ElementType

+

Element to represent the log variance of a normal distribution

+
+ +
+
+class mridc.core.neural_types.elements.NormalDistributionMeanType[source]
+

Bases: mridc.core.neural_types.elements.ElementType

+

Element to represent the mean of a normal distribution

+
+ +
+
+class mridc.core.neural_types.elements.NormalDistributionSamplesType[source]
+

Bases: mridc.core.neural_types.elements.ProbabilityDistributionSamplesType

+

Element to represent tensors that are meant to be sampled from a valid normal distribution.

+
+ +
+
+class mridc.core.neural_types.elements.NormalizedImageValue[source]
+

Bases: mridc.core.neural_types.elements.ImageValue

+

Type representing an element/value of a single image channel normalized to <0-1> range.

+
+ +
+
+class mridc.core.neural_types.elements.PredictionsType[source]
+

Bases: mridc.core.neural_types.elements.LabelsType

+

Element type to represent some sort of predictions returned by model

+
+ +
+
+class mridc.core.neural_types.elements.ProbsType[source]
+

Bases: mridc.core.neural_types.elements.ElementType

+

Element type to represent probabilities. For example, outputs of softmax layers.

+
+ +
+
+class mridc.core.neural_types.elements.ReconstructionTarget[source]
+

Bases: mridc.core.neural_types.elements.Target

+

Type representing an element being a target value in the reconstruction task, i.e. the identifier of a desired class.

+
+ +
+
+class mridc.core.neural_types.elements.RecurrentsType[source]
+

Bases: mridc.core.neural_types.elements.ElementType

+

Element type to represent recurrent layers

+
+ +
+
+class mridc.core.neural_types.elements.RegressionValuesType[source]
+

Bases: mridc.core.neural_types.elements.PredictionsType

+

Element type to represent labels for regression task

+
+ +
+
+class mridc.core.neural_types.elements.SequenceToSequenceAlignmentType[source]
+

Bases: mridc.core.neural_types.elements.ElementType

+

Class to represent the alignment from seq-to-seq attention outputs. Generally a mapping from encoder time steps to decoder time steps.

+
+ +
+
+class mridc.core.neural_types.elements.StringLabel[source]
+

Bases: mridc.core.neural_types.elements.StringType

+

Type representing a label being a string with class name (e.g. the “hamster” class in CIFAR100).

+
+ +
+
+class mridc.core.neural_types.elements.StringType[source]
+

Bases: mridc.core.neural_types.elements.ElementType

+

Element type representing a single string

+
+ +
+
+class mridc.core.neural_types.elements.Target[source]
+

Bases: mridc.core.neural_types.elements.ElementType

+

Type representing an element being a target value.

+
+ +
+
+class mridc.core.neural_types.elements.VoidType[source]
+

Bases: mridc.core.neural_types.elements.ElementType

+

Void-like type which is compatible with everything. It is good practice to use this type only when necessary, for example when you need template-like functionality.

+
+
+compare(second: abc.ABCMeta) mridc.core.neural_types.comparison.NeuralTypeComparisonResult[source]
+

Void type is compatible with everything.

+
+ +
+ +
+
+

mridc.core.neural_types.neural_type module

+
+
+exception mridc.core.neural_types.neural_type.NeuralPortNameMismatchError(input_port_name)[source]
+

Bases: mridc.core.neural_types.neural_type.NeuralTypeError

+

Exception raised when neural module is called with incorrect port names.

+
+ +
+
+exception mridc.core.neural_types.neural_type.NeuralPortNmTensorMismatchError(class_name, port_name, first_type, second_type, type_compatibility)[source]
+

Bases: mridc.core.neural_types.neural_type.NeuralTypeError

+

Exception raised when a port is fed with a NmTensor of incompatible type.

+
+ +
+
+class mridc.core.neural_types.neural_type.NeuralType(axes: Optional[Tuple] = None, elements_type: mridc.core.neural_types.elements.ElementType = VoidType, optional=False)[source]
+

Bases: object

+
+
This is the main class representing the neural type concept. It is used to represent the types of inputs and outputs.

+
+
+
+
Parameters

  • axes – a tuple of AxisType objects representing the semantics of what varying each axis means. You can use a short, string-based form here. For example, (‘B’, ‘C’, ‘H’, ‘W’) would correspond to the NCHW format frequently used in computer vision, and (‘B’, ‘T’, ‘D’) is frequently used for signal processing and means [batch, time, dimension/channel].

  • elements_type – an instance of the ElementType class representing the semantics of what is stored inside the tensor.

  • optional – by default, this is False. If set to True, input to the port of this type can be optional.
+
+
+
+
+__eq__(other)[source]
+

Checks if two NeuralTypes are equal.

+
+ +
+
+__repr__()[source]
+

Returns string representation of NeuralType.

+
+ +
+
+compare(second) mridc.core.neural_types.comparison.NeuralTypeComparisonResult[source]
+

Performs neural type comparison of self with second. When you chain two modules’ inputs/outputs via __call__ +method, this comparison will be called to ensure neural type compatibility.

+
+ +
+
+compare_and_raise_error(parent_type_name, port_name, second_object)[source]
+

Method compares definition of one type with another and raises an error if not compatible.

+
+ +
+ +
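A minimal sketch of constructing and comparing neural types with the short string axis form described above:

```python
from mridc.core.neural_types.elements import ChannelType
from mridc.core.neural_types.neural_type import NeuralType

# An NCHW image-like port type, e.g. for coil-sensitivity maps.
image_type = NeuralType(axes=("B", "C", "H", "W"), elements_type=ChannelType())
other = NeuralType(axes=("B", "C", "H", "W"), elements_type=ChannelType())
print(image_type.compare(other))  # expected: NeuralTypeComparisonResult.SAME
```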
+
+exception mridc.core.neural_types.neural_type.NeuralTypeError[source]
+

Bases: Exception

+

Base class for neural type related exceptions.

+
+ +
+
+

Module contents

+
+
+ + +
+
+ +
+
+
+
+ + + + diff --git a/docs/build/html/mridc.core.optim.html b/docs/build/html/mridc.core.optim.html new file mode 100644 index 00000000..b5047d41 --- /dev/null +++ b/docs/build/html/mridc.core.optim.html @@ -0,0 +1,683 @@ + + + + + + + mridc.core.optim package — mridc v.0.0.1 documentation + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

mridc.core.optim package

+
+

Submodules

+
+
+

mridc.core.optim.adafactor module

+
+
+class mridc.core.optim.adafactor.Adafactor(params, lr=None, eps=(1e-30, 0.001), clip_threshold=1.0, decay_rate=- 0.8, beta1=None, weight_decay=0.0, scale_parameter=True, relative_step=True, warmup_init=False, min_step=0.01)[source]
+

Bases: torch.optim.optimizer.Optimizer

+

Implements Adafactor algorithm.

+

This implementation is based on Adafactor: Adaptive Learning Rates with Sublinear Memory Cost (see https://arxiv.org/abs/1804.04235). Note that this optimizer internally adjusts the learning rate depending on the scale_parameter, relative_step, and warmup_init options. To use a manual (external) learning rate schedule you should set scale_parameter=False and relative_step=False.

+
+
Parameters

  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups.

  • lr (float, optional, default: None) – external learning rate.

  • eps (tuple (float, float), default: (1e-30, 1e-3)) – regularization constants for square gradient and parameter scale respectively.

  • clip_threshold (float, default: 1.0) – threshold of root-mean-square of final gradient update.

  • decay_rate (float, default: -0.8) – coefficient used to compute running averages of square gradient.

  • beta1 (float, default: None) – coefficient used for computing running averages of gradient.

  • weight_decay (float, optional, default: 0) – weight decay (L2 penalty).

  • scale_parameter (bool, default: True) – if True, learning rate is scaled by root-mean-square of parameter.

  • relative_step (bool, default: True) – if True, a time-dependent learning rate is computed instead of an external learning rate.

  • warmup_init (bool, default: False) – time-dependent learning rate computation depends on whether warm-up initialization is being used.
+
+
Return type
+

Adafactor Optimizer

+
+
+
+
+step(closure=None)[source]
+

Performs a single optimization step.

+
+
Parameters
+

closure (A closure that reevaluates the model and returns the loss.) – callable (optional)

+
+
+
+ +
+
+property supports_flat_params
+

Whether the optimizer supports flat parameters.

+
+ +
+
+property supports_memory_efficient_fp16
+

Whether optimizer supports memory efficient fp16

+
+ +
+ +
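A minimal sketch of using the optimizer with a manual (external) learning rate, following the note above:

```python
import torch

from mridc.core.optim.adafactor import Adafactor

model = torch.nn.Linear(16, 1)
# For an external schedule, scale_parameter and relative_step are disabled.
optimizer = Adafactor(model.parameters(), lr=1e-3, scale_parameter=False, relative_step=False)

loss = model(torch.randn(4, 16)).pow(2).mean()
loss.backward()
optimizer.step()
```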
+
+

mridc.core.optim.lr_scheduler module

+
+
+class mridc.core.optim.lr_scheduler.CosineAnnealing(optimizer, *, max_steps, min_lr=0, last_epoch=- 1, **kwargs)[source]
+

Bases: mridc.core.optim.lr_scheduler.WarmupAnnealHoldPolicy

+

Anneal learning rate by cosine.

+
+ +
+
+class mridc.core.optim.lr_scheduler.InverseSquareRootAnnealing(optimizer, *, max_steps, last_epoch=- 1, min_lr=0.0, **kwargs)[source]
+

Bases: mridc.core.optim.lr_scheduler.WarmupPolicy

+

Inverse square root learning rate annealing.

+
+ +
+
+class mridc.core.optim.lr_scheduler.NoamAnnealing(optimizer, *, d_model, warmup_steps=None, warmup_ratio=None, max_steps=None, min_lr=0.0, last_epoch=- 1)[source]
+

Bases: torch.optim.lr_scheduler._LRScheduler

+

Noam learning rate annealing.

+
+
+get_lr()[source]
+

Get learning rate at current step.

+
+ +
+ +
+
+class mridc.core.optim.lr_scheduler.PolynomialDecayAnnealing(optimizer, *, max_steps, min_lr=0.0, power=1.0, cycle=False, last_epoch=- 1, **kwargs)[source]
+

Bases: mridc.core.optim.lr_scheduler.WarmupPolicy

+

Polynomial decay learning rate annealing.

+
+ +
+
+class mridc.core.optim.lr_scheduler.PolynomialHoldDecayAnnealing(optimizer, *, max_steps, min_lr=0.0, power=1.0, cycle=False, last_epoch=- 1, **kwargs)[source]
+

Bases: mridc.core.optim.lr_scheduler.WarmupHoldPolicy

+

Polynomial decay learning rate annealing.

+
+ +
+
+class mridc.core.optim.lr_scheduler.SquareAnnealing(optimizer, *, max_steps, min_lr=1e-05, last_epoch=- 1, **kwargs)[source]
+

Bases: mridc.core.optim.lr_scheduler.WarmupPolicy

+

Anneal learning rate by square.

+
+ +
+
+class mridc.core.optim.lr_scheduler.SquareRootAnnealing(optimizer, *, max_steps, min_lr=0, last_epoch=- 1, **kwargs)[source]
+

Bases: mridc.core.optim.lr_scheduler.WarmupPolicy

+

Anneal learning rate by square root.

+
+ +
+
+class mridc.core.optim.lr_scheduler.SquareRootConstantPolicy(optimizer, *, constant_steps=None, constant_ratio=None, max_steps=None, min_lr=0.0, last_epoch=- 1)[source]
+

Bases: torch.optim.lr_scheduler._LRScheduler

+

Adds warmup kwargs and warmup logic to lr policy. All arguments should be passed as kwargs for clarity.

+
+
Parameters

  • warmup_steps – number of training steps in warmup stage.

  • warmup_ratio – ratio of warmup steps to total steps.

  • max_steps – total number of steps while training, or None for infinite training.
+
+
+
+
+get_lr()[source]
+

Get learning rate at current step.

+
+ +
+ +
+
+class mridc.core.optim.lr_scheduler.T5InverseSquareRootAnnealing(optimizer, *, max_steps, last_epoch=- 1, min_lr=0.0, **kwargs)[source]
+

Bases: mridc.core.optim.lr_scheduler.SquareRootConstantPolicy

+

Inverse square root learning rate annealing.

+
+ +
+
+class mridc.core.optim.lr_scheduler.WarmupAnnealHoldPolicy(optimizer, *, warmup_steps=None, warmup_ratio=None, constant_steps=None, constant_ratio=None, max_steps=None, min_lr=0.0, last_epoch=- 1)[source]
+

Bases: torch.optim.lr_scheduler._LRScheduler

+

Adds warmup kwargs and warmup logic to lr policy. All arguments should be passed as kwargs for clarity.

+
+
Parameters

  • warmup_steps – number of training steps in warmup stage.

  • warmup_ratio – ratio of warmup steps to total steps.

  • max_steps – total number of steps while training, or None for infinite training.

  • min_lr – minimum lr to hold the learning rate after decay at.

  • constant_steps – number of steps to keep lr constant at.

  • constant_ratio – ratio of steps to keep lr constant.
+
+
+
+
+get_lr()[source]
+

Get learning rate at current step.

+
+ +
+ +
+
+class mridc.core.optim.lr_scheduler.WarmupAnnealing(optimizer, *, max_steps, last_epoch=- 1, min_lr=0.0, **kwargs)[source]
+

Bases: mridc.core.optim.lr_scheduler.WarmupPolicy

+

Warmup learning rate annealing.

+
+ +
+
+class mridc.core.optim.lr_scheduler.WarmupHoldPolicy(optimizer, *, warmup_steps=None, warmup_ratio=None, hold_steps=None, hold_ratio=None, max_steps=None, min_lr=0.0, last_epoch=- 1)[source]
+

Bases: mridc.core.optim.lr_scheduler.WarmupPolicy

+

Variant of WarmupPolicy which maintains a high learning rate for a defined number of steps. All arguments should be passed as kwargs for clarity.

+
+
Parameters

  • warmup_steps – number of training steps in warmup stage.

  • warmup_ratio – ratio of warmup steps to total steps.

  • hold_steps – number of training steps to hold the learning rate after warm up.

  • hold_ratio – ratio of hold steps to total steps.

  • max_steps – total number of steps while training, or None for infinite training.

Results

The learning rate is linearly increased from 0 to 1 over the warmup steps, then linearly decreased from 1 to 0 over the hold steps.
+
+
+
+
+get_lr()[source]
+

Get learning rate at current step.

+
+ +
+ +
+
+class mridc.core.optim.lr_scheduler.WarmupPolicy(optimizer, *, warmup_steps=None, warmup_ratio=None, max_steps=None, min_lr=0.0, last_epoch=- 1)[source]
+

Bases: torch.optim.lr_scheduler._LRScheduler

+

Adds warmup kwargs and warmup logic to lr policy. All arguments should be passed as kwargs for clarity.

+
+
Parameters

  • warmup_steps – number of training steps in warmup stage.

  • warmup_ratio – ratio of warmup steps to total steps.

  • max_steps – total number of steps while training, or None for infinite training.
+
+
Returns
+

lr

+
+
Return type
+

Learning rate for current step.

+
+
+
+
+get_lr()[source]
+

Get learning rate at current step.

+
+ +
+ +
+
+mridc.core.optim.lr_scheduler.compute_max_steps(max_epochs, accumulate_grad_batches, limit_train_batches, num_workers, num_samples, batch_size, drop_last)[source]
+

Compute effective max_steps from the provided parameters.

+
+ +
+
+mridc.core.optim.lr_scheduler.get_scheduler(name: str, **kwargs: Optional[Dict[str, Any]]) torch.optim.lr_scheduler._LRScheduler[source]
+

Convenience method to obtain an _LRScheduler class and partially instantiate it with optimizer kwargs.

+
+
Parameters

  • name – name of the scheduler in the registry.

  • kwargs – optional kwargs of the scheduler used during instantiation.
+
+
Return type
+

A partially instantiated _LRScheduler

+
+
+
+ +
+
+mridc.core.optim.lr_scheduler.prepare_lr_scheduler(optimizer: torch.optim.optimizer.Optimizer, scheduler_config: Optional[Union[Dict[str, Any], omegaconf.dictconfig.DictConfig]], train_dataloader: Optional[torch.utils.data.dataloader.DataLoader] = None) Optional[Dict[str, Any]][source]
+

Constructs an LR Scheduler (optionally) for a given optimizer, based on a config with the following schema.

+
+
Parameters

  • optimizer – the optimizer to use for the scheduler, configured as:

        name: <name of optimizer>
        lr: <maximal learning rate>
        # <additional optimizer arguments>
        args:
          name: auto  # special keyword, resolves to correct optimizer config for given optimizer name
          # cls: mridc.core.config.optimizers.NovogradParams  # explicit instantiation by class path
          params:  # optional override parameters for the optimizer config
            betas: [0.8, 0.5]
            weight_decay: 0.001

  • scheduler_config – the scheduler config, with the following schema:

        name: <name of scheduler>
        iters_per_batch: null  # computed at runtime; mandatory to have
        max_steps: null  # computed at runtime or explicitly set here; mandatory to have
        # pytorch lightning args <mandatory>
        monitor: val_loss
        reduce_on_plateau: false
        # <scheduler config override>
        args:
          name: auto  # special keyword, resolves to correct optimizer config for given optimizer name
          # cls: mridc.core.config.schedulers.CosineAnnealingParams  # explicit instantiation by class path
          params:  # optional override parameters for the optimizer config
            warmup_steps: null
            warmup_ratio: null
            min_lr: 0.0
            last_epoch: -1

  • train_dataloader – optional requirement; must be passed if “iters_per_batch” is defined instead of “max_steps”. Used to compute the effective “max_steps”.
+
+
Return type
+

A dictionary containing the LR Scheduler implementation if the config was successfully parsed along with other parameters required by Pytorch Lightning, otherwise None.

+
+
+
+ +
+
+mridc.core.optim.lr_scheduler.register_scheduler(name: str, scheduler: torch.optim.lr_scheduler._LRScheduler, scheduler_params: mridc.core.conf.schedulers.SchedulerParams)[source]
+

Checks if the scheduler name exists in the registry, and if it doesn’t, adds it. This allows custom schedulers to be added and called by name during instantiation.

+
+
Parameters
+
    +
  • name (Name of the optimizer. Will be used as key to retrieve the optimizer.) –

  • +
  • scheduler (Scheduler class (inherits from _LRScheduler)) –

  • +
  • scheduler_params (The parameters as a dataclass of the scheduler) –

  • +
+
+
+
+ +
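A hedged sketch of registering a plain torch scheduler under a custom name; in a real setup a dedicated SchedulerParams dataclass for StepLR would likely be supplied rather than the base class assumed here:

```python
from torch.optim.lr_scheduler import StepLR

from mridc.core.conf.schedulers import SchedulerParams
from mridc.core.optim.lr_scheduler import get_scheduler, register_scheduler

# Register StepLR under a custom name (the params dataclass is an assumption).
register_scheduler(name="my_step_lr", scheduler=StepLR, scheduler_params=SchedulerParams)

# get_scheduler returns a partially instantiated scheduler class.
scheduler_partial = get_scheduler("my_step_lr", step_size=10)
```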
+
+

mridc.core.optim.novograd module

+
+
+class mridc.core.optim.novograd.Novograd(params, lr=0.001, betas=(0.95, 0.98), eps=1e-08, weight_decay=0, grad_averaging=False, amsgrad=False, luc=False, luc_trust=0.001, luc_eps=1e-08)[source]
+

Bases: torch.optim.optimizer.Optimizer

+

Implements the Novograd algorithm. It has been proposed in “Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks” (https://arxiv.org/abs/1905.11286).

+
+
Parameters

  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups.

  • lr (float, default: 1e-3) – learning rate.

  • betas (Tuple[float, float], optional, default: (0.9, 0.999)) – coefficients used for computing running averages of gradient and its square.

  • eps (float, optional, default: 1e-8) – term added to the denominator to improve numerical stability.

  • weight_decay (float, optional, default: 0) – weight decay (L2 penalty).

  • amsgrad (boolean, optional, default: False) – whether to use the AMSGrad variant of this algorithm from the paper “On the Convergence of Adam and Beyond”.
+
+
+
+
+step(closure=None)[source]
+

Performs a single optimization step.

+
+
Parameters
+

closure (A closure that reevaluates the model and returns the loss.) –

+
+
Returns
+

loss

+
+
Return type
+

Loss (if provided)

+
+
+
+ +
+ +
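For illustration, a minimal training-step sketch:

```python
import torch

from mridc.core.optim.novograd import Novograd

model = torch.nn.Linear(16, 1)
optimizer = Novograd(model.parameters(), lr=1e-3, betas=(0.95, 0.98), weight_decay=0.001)

loss = model(torch.randn(4, 16)).pow(2).mean()
loss.backward()
optimizer.step()  # returns the loss only when a closure is supplied
optimizer.zero_grad()
```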
+
+

mridc.core.optim.optimizer_with_master_params module

+
+
+class mridc.core.optim.optimizer_with_master_params.GradBucket(numel)[source]
+

Bases: object

+

Persistent buffer for main gradients that remains allocated between training iterations.

+
+
+allreduce_buffer()[source]
+

Synchronous buffer data allreduce

+
+ +
+
+get(shape, start_index)[source]
+

Return a tensor with the input shape as a view into the 1-D data starting at start_index.

+
+ +
+
+zero()[source]
+

Reset the buffer to zero.

+
+ +
+ +
+
+class mridc.core.optim.optimizer_with_master_params.MainParamsOptimizerWrapper(optimizer, fp32_grad_accum=False, contiguous_grad_bucket=False, async_grad_allreduce=False)[source]
+

Bases: torch.optim.optimizer.Optimizer

+

Float16 optimizer wrapper for half-precision (fp16 and bf16) data types. This optimizer wrapper holds main parameters and gradients in fp32 to support stable convergence.

+
+
Parameters

  • optimizer – base optimizer such as Adam or SGD.

  • fp32_grad_accum – enable the use of fp32 in gradient accumulation and allreduce.

  • contiguous_grad_bucket – enable allocating the master gradients in contiguous memory space to reduce memory fragmentation.

  • async_grad_allreduce – enable asynchronous gradient allreduce that is executed along with the training-step backprop.
+
+
+
+
+allreduce_main_grads()[source]
+

All reduce main grads.

+
+ +
+
+property async_master_grads_allreudce
+

Return whether to use async allreduce for master grads.

+
+ +
+
+copy_model_grads_to_main_grads()[source]
+

Copy model grads to main grads.

+
+ +
+
+property fp32_grad_accumulation
+

Return whether to accumulate gradients in fp32.

+
+ +
+
+get_parameters()[source]
+

Return the parameters of the optimizer.

+
+ +
+
+grad_sync()[source]
+

A context manager to disable gradient synchronizations across data-parallel ranks.

+
+ +
+
+load_state_dict(state_dict)[source]
+

Load the state of the optimizer.

+
+ +
+
+property param_groups
+

Promote param_groups, so it can be retrieved or set via “optimizer_instance.param_groups” (for example, to adjust the learning rate).

+
+ +
+
+reload_model_params()[source]
+

Reload model params.

+
+ +
+
+property state
+

Promote state, so it can be retrieved or set via “optimizer_instance.state”.

+
+ +
+
+state_dict()[source]
+

Return the state of the optimizer.

+
+ +
+
+step(**kwargs)[source]
+

Step the optimizer.

+
+ +
+
+zero_grad(set_to_none=True)[source]
+

We only need to zero the model-related parameters, i.e. float16_groups & fp32_from_fp32_groups. We additionally zero fp32_from_float16_groups as a memory optimization to reduce fragmentation; in the case of set_to_none==True, the space used by this field can be safely deallocated at this point.

+
+ +
+ +
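A hedged sketch of wrapping a base optimizer for a half-precision model; a real run would additionally require the distributed/model-parallel state this wrapper assumes:

```python
import torch

from mridc.core.optim.optimizer_with_master_params import MainParamsOptimizerWrapper

model = torch.nn.Linear(16, 1).half()
base_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Master weights and gradients are held in fp32 for stable convergence.
optimizer = MainParamsOptimizerWrapper(
    base_optimizer,
    fp32_grad_accum=True,
    contiguous_grad_bucket=True,
)
```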
+
+

mridc.core.optim.optimizers module

+
+
+mridc.core.optim.optimizers.get_optimizer(name: str, **kwargs: Optional[Dict[str, Any]]) functools.partial[source]
+

Convenience method to obtain an Optimizer class and partially instantiate it with optimizer kwargs.

+
+
Parameters
+
    +
  • name (Name of the Optimizer in the registry.) –

  • +
  • kwargs (Optional kwargs of the optimizer used during instantiation.) –

  • +
+
+
Return type
+

A partially instantiated Optimizer.

+
+
+
+ +
+
+mridc.core.optim.optimizers.parse_optimizer_args(optimizer_name: str, optimizer_kwargs: Union[omegaconf.dictconfig.DictConfig, Dict[str, Any]]) Union[Dict[str, Any], omegaconf.dictconfig.DictConfig][source]
+

Parses a list of strings of the format “key=value” or “key2=val1,val2,…” into a dictionary of type {key=value, key2=[val1, val2], …}. This dictionary is then used to instantiate the chosen Optimizer.

+
+
Parameters

  • optimizer_name – string name of the optimizer, used for auto resolution of params.

  • optimizer_kwargs – either a list of strings in a specified format, or a dictionary. If a dictionary is provided, it is assumed to be the final parsed value and is simply returned. If a list of strings is provided, each item in the list is parsed into a new dictionary.
+
+
Return type
+

A dictionary of the parsed arguments.

+
+
+
+ +
+
+mridc.core.optim.optimizers.register_optimizer(name: str, optimizer: torch.optim.optimizer.Optimizer, optimizer_params: mridc.core.conf.optimizers.OptimizerParams)[source]
+

Checks if the optimizer name exists in the registry, and if it doesn’t, adds it. This allows custom optimizers to be added and called by name during instantiation.

+
+
Parameters
+
    +
  • name (Name of the optimizer. Will be used as key to retrieve the optimizer.) –

  • +
  • optimizer (Optimizer class.) –

  • +
  • optimizer_params (The parameters as a dataclass of the optimizer.) –

  • +
+
+
+
+ +
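Putting the two helpers together, a minimal sketch (the “key=value” strings follow the format parse_optimizer_args documents):

```python
import torch

from mridc.core.optim.optimizers import get_optimizer, parse_optimizer_args

# Parse CLI-style overrides into optimizer kwargs.
kwargs = parse_optimizer_args("novograd", ["betas=0.95,0.5", "weight_decay=0.001"])

# get_optimizer returns a functools.partial; supplying parameters completes it.
optimizer_partial = get_optimizer("novograd", lr=0.01, **kwargs)
optimizer = optimizer_partial(torch.nn.Linear(8, 1).parameters())
```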
+
+

Module contents

+
+
+ + +
+
+ +
+
+
+
+ + + + diff --git a/docs/build/html/mridc.core.utils.html b/docs/build/html/mridc.core.utils.html new file mode 100644 index 00000000..8896c10f --- /dev/null +++ b/docs/build/html/mridc.core.utils.html @@ -0,0 +1,258 @@ + + + + + + + mridc.core.utils package — mridc v.0.0.1 documentation + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

mridc.core.utils package

+
+

Submodules

+
+
+

mridc.core.utils.neural_type_utils module

+
+
+mridc.core.utils.neural_type_utils.extract_dynamic_axes(name: str, ntype: mridc.core.neural_types.neural_type.NeuralType)[source]
+

This method will extract BATCH and TIME dimension ids from each provided input/output name argument.

+

For example, if a module/model accepts an argument named “input_signal” with a type corresponding to a [Batch, Time, Dim] shape, then the returned result should contain “input_signal” -> [0, 1], because Batch and Time are dynamic axes that can change from call to call during inference.

+
+
Parameters
+
    +
  • name (Name of input or output parameter) –

  • +
  • ntype (Corresponding Neural Type) –

  • +
+
+
Return type
+

A dictionary with input/output name as key and a list of dynamic axes as value.

+
+
+
+ +
+
+mridc.core.utils.neural_type_utils.get_dynamic_axes(types, names)[source]
+

This method will return a dictionary with input/output names as keys and a list of dynamic axes as values.

+
+
Parameters
+
    +
  • types (The NeuralType of the module or model to be inspected.) –

  • +
  • names (A list of names that should be inspected.) –

  • +
+
+
Return type
+

A dictionary with input/output names as keys and a list of dynamic axes as values.

+
+
+
+ +
+
+mridc.core.utils.neural_type_utils.get_io_names(types, disabled_names)[source]
+

This method will return a list of input and output names for a given NeuralType.

+
+
Parameters
+
    +
  • types (The NeuralType of the module or model to be inspected.) –

  • +
  • disabled_names (A list of names that should be excluded from the result.) –

  • +
+
+
Return type
+

A list of input and output names.

+
+
+
+ +
+
+

mridc.core.utils.numba_utils module

+
+
+mridc.core.utils.numba_utils.is_numba_compat_strict() bool[source]
+

Returns the strictness level of the numba cuda compatibility checks. If the value is True, the numba cuda compatibility matrix must be satisfied. If the value is False, only cuda availability is checked, not compatibility. Numba cuda may still compile and run without issues in such a case, or it may fail.

+
+ +
+
+mridc.core.utils.numba_utils.numba_cpu_is_supported(min_version: str) bool[source]
+

Tests if an appropriate version of numba is installed.

+
+
Parameters
+

min_version (The minimum version of numba that is required.) –

+
+
Return type
+

bool, whether numba CPU supported with this current installation or not.

+
+
+
+ +
+
+mridc.core.utils.numba_utils.numba_cuda_is_supported(min_version: str) bool[source]
+

Tests if an appropriate version of numba is installed, and if it is, whether cuda is supported properly within it.

+
+
Parameters
+

min_version (The minimum version of numba that is required.) –

+
+
Return type
+

Whether cuda is supported with this current installation or not.

+
+
+
+ +
+
+mridc.core.utils.numba_utils.set_numba_compat_strictness(strict: bool)[source]
+

Sets the strictness level of the numba cuda compatibility checks. If the value is True, the numba cuda compatibility matrix must be satisfied. If the value is False, only cuda availability is checked, not compatibility. Numba cuda may still compile and run without issues in such a case, or it may fail.

+
+
Parameters
+

strict (Whether to enforce strict compatibility checks or relax them.) –

+
+
+
+ +
+
+mridc.core.utils.numba_utils.skip_numba_cuda_test_if_unsupported(min_version: str)[source]
+

Helper method to skip pytest test case if numba cuda is not supported.

+
+
Parameters
+

min_version (The minimum version of numba that is required.) –

+
+
+
+ +
+
+mridc.core.utils.numba_utils.with_numba_compat_strictness(strict: bool)[source]
+

Context manager to temporarily set numba cuda compatibility strictness.

+
+ +
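A short sketch of toggling and scoping the strictness level:

```python
from mridc.core.utils import numba_utils

# Relax the compatibility check globally...
numba_utils.set_numba_compat_strictness(strict=False)
print(numba_utils.is_numba_compat_strict())  # False

# ...or scope a change with the context manager.
with numba_utils.with_numba_compat_strictness(strict=True):
    print(numba_utils.is_numba_compat_strict())  # True
```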
+
+

Module contents

+
+
+ + +
+
+ +
+
+
+
+ + + + diff --git a/docs/build/html/mridc.html b/docs/build/html/mridc.html new file mode 100644 index 00000000..a61dc0d9 --- /dev/null +++ b/docs/build/html/mridc.html @@ -0,0 +1,272 @@ + + + + + + + mridc package — mridc v.0.0.1 documentation + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

mridc package

+
+

Subpackages

+
+ +
+
+
+

Submodules

+
+
+

mridc.constants module

+
+
+

mridc.launch module

+
+
+mridc.launch.main(cfg: omegaconf.dictconfig.DictConfig) None[source]
+

Main function for training and running a model

+
+
Parameters
+

cfg (Configuration (yaml) file.) – DictConfig

+
+
+
+ +
+
+

mridc.package_info module

+
+
+

Module contents

+
+
+ + +
+
+ +
+
+
+
+ + + + diff --git a/docs/build/html/mridc.utils.decorators.html b/docs/build/html/mridc.utils.decorators.html new file mode 100644 index 00000000..16ea15c6 --- /dev/null +++ b/docs/build/html/mridc.utils.decorators.html @@ -0,0 +1,189 @@ + + + + + + + mridc.utils.decorators package — mridc v.0.0.1 documentation + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

mridc.utils.decorators package

+
+

Submodules

+
+
+

mridc.utils.decorators.deprecated module

+
+
+mridc.utils.decorators.deprecated.deprecated(wrapped=None, version=None, explanation=None)[source]
+

This is a decorator which can be used to mark functions as deprecated. It will result in a warning being emitted when the function is used.

+
+
Parameters
+
    +
  • wrapped (The function to be decorated.) – function

  • +
  • version (The version of the package where the function was deprecated.) – str

  • +
  • explanation (The explanation of the deprecation.) – str

  • +
+
+
Return type
+

The decorated function.

+
+
+
+ +
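A minimal usage sketch (the function name and message are illustrative):

```python
from mridc.utils.decorators.deprecated import deprecated

@deprecated(version="0.2.0", explanation="Use the new reconstruction entry point instead.")
def old_reconstruct(kspace):
    return kspace

old_reconstruct(None)  # emits a deprecation warning when called
```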
+
+

mridc.utils.decorators.experimental module

+
+
+mridc.utils.decorators.experimental.experimental(cls)[source]
+

Decorator to mark a class as experimental.

+
+
Parameters
+

cls (The class to be decorated.) – class

+
+
Return type
+

The decorated class.

+
+
+
+ +
+
+

mridc.utils.decorators.port_docs module

+
+
+mridc.utils.decorators.port_docs.add_port_docs(wrapped=None, instance=None, value='')[source]
+

Adds port documentation to the wrapped function.

+
+
Parameters
+
    +
  • wrapped (The function to decorate.) – function

  • +
  • instance (The instance of the function.) – object

  • +
  • value (The value of the port.) – object

  • +
+
+
Return type
+

The decorated function.

+
+
+
+ +
+
+

Module contents

+
+
+ + +
+
+ +
+
+
+
+ + + + diff --git a/docs/build/html/mridc.utils.formaters.html b/docs/build/html/mridc.utils.formaters.html new file mode 100644 index 00000000..b294fc0f --- /dev/null +++ b/docs/build/html/mridc.utils.formaters.html @@ -0,0 +1,572 @@ + + + + + + + mridc.utils.formaters package — mridc v.0.0.1 documentation + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

mridc.utils.formaters package

+
+

Submodules

+
+
+

mridc.utils.formaters.base module

+
+
+class mridc.utils.formaters.base.BaseMRIDCFormatter(color=True, fmt=None, datefmt=None, colors=None)[source]
+

Bases: mridc.utils.formaters.base.BaseFormatter

+

Base formatter for MRIDC logs.

+
+
+DEFAULT_FORMAT = '%(color)s[MRIDC %(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s'
+
+ +
+ +
+
+class mridc.utils.formaters.base.DebugMRIDCFormatter(color=True, fmt=None, datefmt=None, colors=None)[source]
+

Bases: mridc.utils.formaters.base.BaseFormatter

+

Debug formatter for MRIDC logs.

+
+
+DEFAULT_FORMAT = '%(color)s[MRIDC %(levelname)1.1s %(asctime)s %(module)s:%(lineno)d rank:%(rank)s]%(end_color)s %(message)s'
+
+ +
+ +
+
+

mridc.utils.formaters.colors module

+
+
+class mridc.utils.formaters.colors.AnsiBack[source]
+

Bases: mridc.utils.formaters.colors.AnsiCodes

+

ANSI color codes for background text.

+
+
+BLACK = 40
+
+ +
+
+BLUE = 44
+
+ +
+
+CYAN = 46
+
+ +
+
+GREEN = 42
+
+ +
+
+LIGHTBLACK_EX = 100
+
+ +
+
+LIGHTBLUE_EX = 104
+
+ +
+
+LIGHTCYAN_EX = 106
+
+ +
+
+LIGHTGREEN_EX = 102
+
+ +
+
+LIGHTMAGENTA_EX = 105
+
+ +
+
+LIGHTRED_EX = 101
+
+ +
+
+LIGHTWHITE_EX = 107
+
+ +
+
+LIGHTYELLOW_EX = 103
+
+ +
+
+MAGENTA = 45
+
+ +
+
+RED = 41
+
+ +
+
+RESET = 49
+
+ +
+
+WHITE = 47
+
+ +
+
+YELLOW = 43
+
+ +
+ +
+
+class mridc.utils.formaters.colors.AnsiCodes[source]
+

Bases: object

+

ANSI color codes.

+
+ +
+
+class mridc.utils.formaters.colors.AnsiCursor[source]
+

Bases: object

+

ANSI cursor codes.

+
+
+static BACK(n=1)[source]
+

Move the cursor back n lines.

+
+
Parameters
+

n (Number of lines.) – int

+
+
Returns
+

str

+
+
Return type
+

String of characters.

+
+
+
+ +
+
+static DOWN(n=1)[source]
+

Move the cursor down n lines.

+
+
Parameters
+

n (Number of lines.) – int

+
+
Returns
+

str

+
+
Return type
+

String of characters.

+
+
+
+ +
+
+static FORWARD(n=1)[source]
+

Move the cursor forward n lines.

+
+
Parameters
+

n (Number of lines.) – int

+
+
Returns
+

str

+
+
Return type
+

String of characters.

+
+
+
+ +
+
+static POS(x=1, y=1)[source]
+

Move the cursor to the specified position.

+
+
Parameters
+
    +
  • x (X position.) – int

  • +
  • y (Y position.) – int

  • +
+
+
Returns
+

str

+
+
Return type
+

String of characters.

+
+
+
+ +
+
+static UP(n=1)[source]
+

Move the cursor up n lines.

+
+
Parameters
+

n (Number of lines.) – int

+
+
Returns
+

str

+
+
Return type
+

String of characters.

+
+
+
+ +
+ +
+
+class mridc.utils.formaters.colors.AnsiFore[source]
+

Bases: mridc.utils.formaters.colors.AnsiCodes

+

ANSI color codes for foreground text.

+
+
+BLACK = 30
+
+ +
+
+BLUE = 34
+
+ +
+
+CYAN = 36
+
+ +
+
+GREEN = 32
+
+ +
+
+LIGHTBLACK_EX = 90
+
+ +
+
+LIGHTBLUE_EX = 94
+
+ +
+
+LIGHTCYAN_EX = 96
+
+ +
+
+LIGHTGREEN_EX = 92
+
+ +
+
+LIGHTMAGENTA_EX = 95
+
+ +
+
+LIGHTRED_EX = 91
+
+ +
+
+LIGHTWHITE_EX = 97
+
+ +
+
+LIGHTYELLOW_EX = 93
+
+ +
+
+MAGENTA = 35
+
+ +
+
+RED = 31
+
+ +
+
+RESET = 39
+
+ +
+
+WHITE = 37
+
+ +
+
+YELLOW = 33
+
+ +
+ +
+
+class mridc.utils.formaters.colors.AnsiStyle[source]
+

Bases: mridc.utils.formaters.colors.AnsiCodes

+

ANSI color codes for text styles.

+
+
+BRIGHT = 1
+
+ +
+
+DIM = 2
+
+ +
+
+NORMAL = 22
+
+ +
+
+RESET_ALL = 0
+
+ +
+ +
+
+mridc.utils.formaters.colors.clear_line(mode=2)[source]
+

Clear terminal line.

+
+
Parameters
+

mode (Mode.) – int

+
+
Returns
+

str

+
+
Return type
+

String of characters.

+
+
+
+ +
+
+mridc.utils.formaters.colors.clear_screen(mode=2)[source]
+

Clear terminal screen.

+
+
Parameters
+

mode (Mode.) – int

+
+
Returns
+

str

+
+
Return type
+

String of characters.

+
+
+
+ +
+
+mridc.utils.formaters.colors.code_to_chars(code)[source]
+

Convert ANSI color code to string of characters.

+
+
Parameters
+

code (ANSI color code.) – int

+
+
Returns
+

str

+
+
Return type
+

String of characters.

+
+
+
+ +
+
+mridc.utils.formaters.colors.set_title(title)[source]
+

Set terminal title.

+
+
Parameters
+

title (Title.) – str

+
+
Returns
+

str

+
+
Return type
+

String of characters.

+
+
+
+ +
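For illustration, building escape sequences from the code tables above:

```python
from mridc.utils.formaters.colors import AnsiFore, AnsiStyle, code_to_chars

green = code_to_chars(AnsiFore.GREEN)
reset = code_to_chars(AnsiFore.RESET)
print(f"{green}reconstruction finished{reset}")

# Styles compose the same way.
print(code_to_chars(AnsiStyle.BRIGHT) + "bright" + code_to_chars(AnsiStyle.RESET_ALL))
```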
+
+

mridc.utils.formaters.utils module

+
+
+mridc.utils.formaters.utils.check_color_support()[source]
+
+
Returns
+

bool

+
+
Return type
+

True if the terminal supports color, False otherwise.

+
+
+
+ +
+
+mridc.utils.formaters.utils.to_unicode(value)[source]
+

Converts a string to unicode. If the string is already unicode, it is returned as-is. If it is a byte string, it is decoded using utf-8.

+
+
Parameters
+

value (The string to convert.) – str

+
+
Returns
+

str

+
+
Return type
+

The converted string.

+
+
+
+ +
+
+

Module contents

+
+
+ + +
+
+ +
+
+
+
+ + + + diff --git a/docs/build/html/mridc.utils.html b/docs/build/html/mridc.utils.html new file mode 100644 index 00000000..70232b36 --- /dev/null +++ b/docs/build/html/mridc.utils.html @@ -0,0 +1,1995 @@ + + + + + + + mridc.utils package — mridc v.0.0.1 documentation + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

mridc.utils package

+
+

Subpackages

+ +
+
+

Submodules

+
+
+

mridc.utils.app_state module

+
+
+class mridc.utils.app_state.AppState(*args, **kwargs)[source]
+

Bases: object

+

A singleton class that holds the state of the application.

+
+
+property checkpoint_callback_params
+

Returns the version set by exp_manager.

+
+ +
+
+property checkpoint_name
+

Returns the name set by exp_manager.

+
+ +
+
+property create_checkpoint_callback
+

Returns the create_checkpoint_callback set by exp_manager.

+
+ +
+
+property data_parallel_group
+

Property returns the data parallel group.

+
+ +
+
+property data_parallel_rank
+

Property returns the data parallel rank.

+
+ +
+
+property data_parallel_size
+

Property returns the number of GPUs in each data parallel group.

+
+ +
+
+property device_id
+

Property returns the device_id.

+
+ +
+
+property exp_dir
+

Returns the exp_dir set by exp_manager.

+
+ +
+
+get_model_metadata_from_guid(guid) mridc.utils.app_state.ModelMetadataRegistry[source]
+

Returns the global model idx and restoration path.

+
+ +
+
+property global_rank
+

Property returns the global rank.

+
+ +
+
+property is_model_being_restored: bool
+

Returns whether a model is being restored.

+
+ +
+
+property local_rank
+

Property returns the local rank.

+
+ +
+
+property log_dir
+

Returns the log_dir set by exp_manager.

+
+ +
+
+property model_parallel_size
+

Property returns the number of GPUs in each model parallel group.

+
+ +
+
+property model_restore_path
+

Returns the model_restore_path set by exp_manager.

+
+ +
+
+property mridc_file_folder: str
+

Returns the mridc_file_folder set by exp_manager.

+
+ +
+
+property name
+

Returns the name set by exp_manager.

+
+ +
+
+property pipeline_model_parallel_group
+

Property returns the model parallel group.

+
+ +
+
+property pipeline_model_parallel_rank
+

Property returns the model parallel rank.

+
+ +
+
+property pipeline_model_parallel_size
+

Property returns the number of GPUs in each model parallel group.

+
+ +
+
+property pipeline_model_parallel_split_rank
+

Property returns the model parallel split rank.

+
+ +
+
+property random_seed
+

Property returns the random seed.

+
+ +
+
+register_model_guid(guid: str, restoration_path: Optional[str] = None)[source]
+

Maps a guid to its restore path (None or last absolute path).

+
+ +
+
+reset_model_guid_registry()[source]
+

Resets the model guid registry.

+
+ +
+
+property tensor_model_parallel_group
+

Property returns the model parallel group.

+
+ +
+
+property tensor_model_parallel_rank
+

Property returns the model parallel rank.

+
+ +
+
+property tensor_model_parallel_size
+

Property returns the number of GPUs in each model parallel group.

+
+ +
+
+property version
+

Returns the version set by exp_manager.

+
+ +
+
+property world_size
+

Property returns the total number of GPUs.

+
+ +
+ +
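Since AppState is a singleton, any part of the code can read the shared state; a minimal sketch:

```python
from mridc.utils.app_state import AppState

state = AppState()  # every call returns the same underlying state
print(state.device_id, state.world_size)

# Fields populated by exp_manager (log_dir, name, version, ...) stay None
# until an experiment has been set up.
print(state.log_dir)
```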
+
+class mridc.utils.app_state.ModelMetadataRegistry(guid: str, gidx: int, restoration_path: Optional[str] = None)[source]
+

Bases: object

+

A registry for model metadata.

+
+
+gidx: int
+
+ +
+
+guid: str
+
+ +
+
+restoration_path: Optional[str] = None
+
+ +
+ +
+
+

mridc.utils.arguments module

+
+
+mridc.utils.arguments.add_optimizer_args(parent_parser: argparse.ArgumentParser, optimizer: str = 'adam', default_lr: Optional[float] = None, default_opt_args: Optional[Union[Dict[str, Any], List[str]]] = None) argparse.ArgumentParser[source]
+

Extends existing argparse with default optimizer args.

+

# Example of adding optimizer args to the command line:
python train_script.py … --optimizer "novograd" --lr 0.01 --opt_args betas=0.95,0.5 weight_decay=0.001

+
+
Parameters
+
    +
  • parent_parser (Custom CLI parser that will be extended.) – ArgumentParser

  • +
  • optimizer (Default optimizer required.) – str, default “adam”

  • +
  • default_lr (Default learning rate.) – float, default None

  • +
  • default_opt_args (Default optimizer arguments.) – Optional[Union[Dict[str, Any], List[str]]], default None

  • +
+
+
Returns
+

ArgumentParser

+
+
Return type
+

Parser extended by Optimizers arguments.

+
+
+
+ +
+
+mridc.utils.arguments.add_recon_args(parent_parser: argparse.ArgumentParser) argparse.ArgumentParser[source]
+

Extends existing argparse with default reconstruction args.

+
+
Parameters
+

parent_parser (Custom CLI parser that will be extended.) – ArgumentParser

+
+
Returns
+

ArgumentParser

+
+
Return type
+

Parser extended by Reconstruction arguments.

+
+
+
+ +
+
+mridc.utils.arguments.add_scheduler_args(parent_parser: argparse.ArgumentParser) argparse.ArgumentParser[source]
+

Extends existing argparse with default scheduler args.

+
+
Parameters
+

parent_parser (Custom CLI parser that will be extended.) – ArgumentParser

+
+
Returns
+

ArgumentParser

+
+
Return type
+

Parser extended by Schedulers arguments.

+
+
+
+ +
+
+

mridc.utils.cloud module

+
+
+mridc.utils.cloud.maybe_download_from_cloud(url, filename, subfolder=None, cache_dir=None, refresh_cache=False) str[source]
+

Download a file from a URL if it does not exist in the cache.

+
+
Parameters
+
    +
  • url (URL to download the file from.) – str

  • +
  • filename (What to download. The request will be issued to url/filename) – str

  • +
  • subfolder (Subfolder within cache_dir. The file will be stored in cache_dir/subfolder. Subfolder can be empty.) – str

  • +
  • cache_dir (A cache directory where to download. If not present, this function will attempt to create it.) – str, If None (default), then it will be $HOME/.cache/torch/mridc

  • +
  • refresh_cache (If True and cached file is present, it will delete it and re-fetch) – bool

  • +
+
+
Return type
+

If successful - absolute local path to the downloaded file else empty string.

+
+
+
+ +
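A minimal sketch with placeholder URL and filename:

```python
from mridc.utils.cloud import maybe_download_from_cloud

# Hypothetical location; on success the absolute local path is returned,
# otherwise an empty string.
path = maybe_download_from_cloud(
    url="https://example.com/checkpoints/",
    filename="model.mridc",
    subfolder="demo",
)
print(path or "download failed")
```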
+
+

mridc.utils.config_utils module

+
+
+mridc.utils.config_utils.assert_dataclass_signature_match(cls: class_type, datacls: dataclass, ignore_args: Optional[List[str]] = None, remap_args: Optional[Dict[str, str]] = None)[source]
+

Analyses the signature of a provided class and its respective dataclass, asserting that the dataclass signature matches the class __init__ signature.

Note: this is not a value-based check. This function only checks whether all argument names exist on both the class and the dataclass, and logs mismatches.
+
+
+
+
Parameters
+
    +
  • cls (Any class type - but not an instance of a class. Pass type(x) where x is an instance) – if class type is not easily available.

  • +
  • datacls (A corresponding dataclass for the above class.) –

  • +
  • ignore_args ((Optional) A list of string argument names which are forcibly ignored,) – even if mismatched in the signature. Useful when a dataclass is a superset of the +arguments of a class.

  • +
  • remap_args ((Optional) A dictionary, mapping an argument name that exists (in either the) – class or its dataclass), to another name. Useful when argument names are mismatched between +a class and its dataclass due to indirect instantiation via a helper method.

  • +
+
+
Returns

  1. A bool value which is True if the signatures matched exactly / after ignoring values, False otherwise.

  2. A set of argument names that exist in the class, but do not exist in the dataclass. If an exact signature match occurs, this will be None instead.

  3. A set of argument names that exist in the dataclass, but do not exist in the class itself. If an exact signature match occurs, this will be None instead.

Return type

A tuple containing information about the analysis.

+
+
+
+ +
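A minimal sketch with a hypothetical class/dataclass pair:

```python
from dataclasses import dataclass

from mridc.utils.config_utils import assert_dataclass_signature_match

class Reconstructor:
    def __init__(self, num_cascades: int = 8, pooling: str = "avg"):
        self.num_cascades, self.pooling = num_cascades, pooling

@dataclass
class ReconstructorConfig:
    num_cascades: int = 8
    pooling: str = "avg"

matched, cls_only, dc_only = assert_dataclass_signature_match(Reconstructor, ReconstructorConfig)
print(matched, cls_only, dc_only)  # True, None, None on an exact match
```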
+
+mridc.utils.config_utils.update_model_config(model_cls: mridc.core.conf.modelPT.MRIDCConfig, update_cfg: omegaconf.dictconfig.DictConfig, drop_missing_subconfigs: bool = True)[source]
+
+
Helper class that updates the default values of a ModelPT config class with the values in a DictConfig that mirrors the structure of the config class. Assumes the update_cfg is a DictConfig (either generated manually, via hydra or instantiated via yaml/model.cfg). This update_cfg is then used to override the default values preset inside the ModelPT config class. If drop_missing_subconfigs is set, the certain sub-configs of the ModelPT config class will be removed, if they are not found in the mirrored update_cfg. The following sub-configs are subject to potential removal:
    +
  • train_ds

  • +
  • validation_ds

  • +
  • test_ds

  • +
  • optim + nested sched

  • +
+
+
+
+
Parameters
+
    +
  • model_cls (A subclass of MRIDC, that details in entirety all the parameters that constitute the MRIDC Model.) –

  • +
  • update_cfg (A DictConfig that mirrors the structure of the MRIDCConfig data class. Used to update the default values of the config class.) –

  • +
  • drop_missing_subconfigs (Bool which determines whether to drop certain sub-configs from the MRIDCConfig class, if the corresponding sub-config is missing from update_cfg.) –

  • +
+
+
Return type
+

A DictConfig with updated values that can be used to instantiate the MRIDC Model along with supporting infrastructure.

+
+
+
+ +
+
+

mridc.utils.distributed module

+
+
+mridc.utils.distributed.initialize_distributed(args, backend='nccl')[source]
+

Initialize distributed training.

+
+
Parameters
+
    +
  • args (The arguments object.) –

  • +
  • backend (The backend to use.) – default: “nccl”

  • +
+
+
Returns

  • local_rank – the local rank of the process.

  • rank – the rank of the process.

  • world_size – the number of processes.
+

+
+
+
+ +
+
+

mridc.utils.env_var_parsing module

+
+
+exception mridc.utils.env_var_parsing.CoercionError(key, value, func)[source]
+

Bases: Exception

+

Custom error raised when a value cannot be coerced.

+
+ +
+
+exception mridc.utils.env_var_parsing.RequiredSettingMissingError(key)[source]
+

Bases: Exception

+

Custom error raised when a required env var is missing.

+
+ +
+
+mridc.utils.env_var_parsing.get_env(key, *default, **kwargs)[source]
+

Return env var. This is the parent function of all other get_foo functions, and is responsible for unpacking args/kwargs into the values that _get_env expects (it is the root function that actually interacts with environ).

+
+
Parameters
+
    +
  • key (string, the env var name to look up.) –

  • +
  • default ((optional) the value to use if the env var does not exist. If this value is not supplied, then the env var is considered to be required, and a RequiredSettingMissingError error will be raised if it does not exist.) –

  • +
  • kwargs – coerce: a func that may be supplied to coerce the value into something else. This is used by the default get_foo functions to cast strings to builtin types, but could be a function that returns a custom class.

  • +
+
+
Return type
+

The env var, coerced if required, and a default if supplied.

+
+
+
+ +
+
+mridc.utils.env_var_parsing.get_envbool(key, *default)[source]
+

Return env var cast as boolean.

+
+ +
+
+mridc.utils.env_var_parsing.get_envdate(key, *default)[source]
+

Return env var as a date.

+
+ +
+
+mridc.utils.env_var_parsing.get_envdatetime(key, *default)[source]
+

Return env var as a datetime.

+
+ +
+
+mridc.utils.env_var_parsing.get_envdecimal(key, *default)[source]
+

Return env var cast as Decimal.

+
+ +
+
+mridc.utils.env_var_parsing.get_envdict(key, *default)[source]
+

Return env var as a dict.

+
+ +
+
+mridc.utils.env_var_parsing.get_envfloat(key, *default)[source]
+

Return env var cast as float.

+
+ +
+
+mridc.utils.env_var_parsing.get_envint(key, *default)[source]
+

Return env var cast as integer.

+
+ +
+
+mridc.utils.env_var_parsing.get_envlist(key, *default, **kwargs)[source]
+

Return env var as a list.

+
+ +
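For illustration, the integer and boolean variants (the variable names are placeholders):

```python
import os

from mridc.utils.env_var_parsing import get_envbool, get_envint

os.environ["MRIDC_NUM_WORKERS"] = "4"

print(get_envint("MRIDC_NUM_WORKERS"))    # 4
print(get_envbool("MRIDC_DEBUG", False))  # falls back to the default: False
```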
+
+

mridc.utils.exceptions module

+
+
+class mridc.utils.exceptions.CheckInstall(*args, **kwargs)[source]
+

Bases: object

+

Class to check if a package is installed.

+
+ +
+
+exception mridc.utils.exceptions.LightningNotInstalledException(obj)[source]
+

Bases: mridc.utils.exceptions.MRIDCBaseException

+

Exception for when lightning is not installed

+
+ +
+
+exception mridc.utils.exceptions.MRIDCBaseException[source]
+

Bases: Exception

+

MRIDC Base Exception. All exceptions created in MRIDC should inherit from this class

+
+ +
+
+

mridc.utils.exp_manager module

+
+
+class mridc.utils.exp_manager.CallbackParams(filepath: Optional[str] = None, dirpath: Optional[str] = None, filename: Optional[str] = None, monitor: Optional[str] = 'val_loss', verbose: Optional[bool] = True, save_last: Optional[bool] = True, save_top_k: Optional[int] = 3, save_weights_only: Optional[bool] = False, mode: Optional[str] = 'min', every_n_epochs: Optional[int] = 1, prefix: Optional[str] = None, postfix: str = '.mridc', save_best_model: bool = False, always_save_mridc: bool = False, save_mridc_on_train_end: Optional[bool] = True, model_parallel_size: Optional[int] = None)[source]
+

Bases: object

+

Parameters for a callback

+
+
+always_save_mridc: bool = False
+
+ +
+
+dirpath: Optional[str] = None
+
+ +
+
+every_n_epochs: Optional[int] = 1
+
+ +
+
+filename: Optional[str] = None
+
+ +
+
+filepath: Optional[str] = None
+
+ +
+
+mode: Optional[str] = 'min'
+
+ +
+
+model_parallel_size: Optional[int] = None
+
+ +
+
+monitor: Optional[str] = 'val_loss'
+
+ +
+
+postfix: str = '.mridc'
+
+ +
+
+prefix: Optional[str] = None
+
+ +
+
+save_best_model: bool = False
+
+ +
+
+save_last: Optional[bool] = True
+
+ +
+
+save_mridc_on_train_end: Optional[bool] = True
+
+ +
+
+save_top_k: Optional[int] = 3
+
+ +
+
+save_weights_only: Optional[bool] = False
+
+ +
+
+verbose: Optional[bool] = True
+
+ +
+ +
+
+exception mridc.utils.exp_manager.CheckpointMisconfigurationError[source]
+

Bases: mridc.utils.exceptions.MRIDCBaseException

+

Raised when a mismatch between trainer.callbacks and exp_manager occurs

+
+ +
class mridc.utils.exp_manager.ExpManagerConfig(explicit_log_dir: Optional[str] = None, exp_dir: Optional[str] = None, name: Optional[str] = None, version: Optional[str] = None, use_datetime_version: Optional[bool] = True, resume_if_exists: Optional[bool] = False, resume_past_end: Optional[bool] = False, resume_ignore_no_checkpoint: Optional[bool] = False, create_tensorboard_logger: Optional[bool] = True, summary_writer_kwargs: Optional[Dict[Any, Any]] = None, create_wandb_logger: Optional[bool] = False, wandb_logger_kwargs: Optional[Dict[Any, Any]] = None, create_checkpoint_callback: Optional[bool] = True, checkpoint_callback_params: Optional[mridc.utils.exp_manager.CallbackParams] = CallbackParams(filepath=None, dirpath=None, filename=None, monitor='val_loss', verbose=True, save_last=True, save_top_k=3, save_weights_only=False, mode='min', every_n_epochs=1, prefix=None, postfix='.mridc', save_best_model=False, always_save_mridc=False, save_mridc_on_train_end=True, model_parallel_size=None), files_to_copy: Optional[List[str]] = None, log_step_timing: Optional[bool] = True, step_timing_kwargs: Optional[mridc.utils.exp_manager.StepTimingParams] = StepTimingParams(reduction='mean', sync_cuda=False, buffer_size=1), log_local_rank_0_only: Optional[bool] = False, log_global_rank_0_only: Optional[bool] = False, model_parallel_size: Optional[int] = None)

Bases: object

Configuration for the experiment manager.

checkpoint_callback_params: Optional[mridc.utils.exp_manager.CallbackParams] = CallbackParams(filepath=None, dirpath=None, filename=None, monitor='val_loss', verbose=True, save_last=True, save_top_k=3, save_weights_only=False, mode='min', every_n_epochs=1, prefix=None, postfix='.mridc', save_best_model=False, always_save_mridc=False, save_mridc_on_train_end=True, model_parallel_size=None)
create_checkpoint_callback: Optional[bool] = True
create_tensorboard_logger: Optional[bool] = True
create_wandb_logger: Optional[bool] = False
exp_dir: Optional[str] = None
explicit_log_dir: Optional[str] = None
files_to_copy: Optional[List[str]] = None
log_global_rank_0_only: Optional[bool] = False
log_local_rank_0_only: Optional[bool] = False
log_step_timing: Optional[bool] = True
model_parallel_size: Optional[int] = None
name: Optional[str] = None
resume_if_exists: Optional[bool] = False
resume_ignore_no_checkpoint: Optional[bool] = False
resume_past_end: Optional[bool] = False
step_timing_kwargs: Optional[mridc.utils.exp_manager.StepTimingParams] = StepTimingParams(reduction='mean', sync_cuda=False, buffer_size=1)
summary_writer_kwargs: Optional[Dict[Any, Any]] = None
use_datetime_version: Optional[bool] = True
version: Optional[str] = None
wandb_logger_kwargs: Optional[Dict[Any, Any]] = None

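Since these are plain dataclasses, a config can be built programmatically and converted to a DictConfig. A minimal sketch (the values are illustrative, not mridc defaults beyond what the dataclasses above define):

    from omegaconf import OmegaConf

    from mridc.utils.exp_manager import CallbackParams, ExpManagerConfig

    cfg = ExpManagerConfig(
        exp_dir="./mridc_experiments",
        name="my_experiment",
        checkpoint_callback_params=CallbackParams(monitor="val_loss", save_top_k=3),
    )
    # exp_manager accepts dict-like configs, so convert the dataclass first.
    dict_cfg = OmegaConf.structured(cfg)
    print(OmegaConf.to_yaml(dict_cfg))
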
class mridc.utils.exp_manager.LoggerList(_logger_iterable, mridc_name=None, mridc_version='')

Bases: pytorch_lightning.loggers.base.LoggerCollection

A thin wrapper on Lightning's LoggerCollection such that name and version are better aligned with exp_manager.

property name: str
    The name of the experiment.

property version: str
    The version of the experiment. If the logger was created with a version, this will be the version.

exception mridc.utils.exp_manager.LoggerMisconfigurationError(message)

Bases: mridc.utils.exceptions.MRIDCBaseException

Raised when a mismatch between trainer.logger and exp_manager occurs.

class mridc.utils.exp_manager.MRIDCModelCheckpoint(always_save_mridc=False, save_mridc_on_train_end=True, save_best_model=False, postfix='.mridc', n_resume=False, model_parallel_size=None, **kwargs)

Bases: pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint

Light wrapper around Lightning's ModelCheckpoint to force a saved checkpoint on train_end.

mridc_topk_check_previous_run()
    Check if there are previous runs with the same topk value.

on_save_checkpoint(trainer, pl_module, checkpoint)
    Override the default on_save_checkpoint to save the best model if needed.

    Parameters
      • trainer – The trainer object.
      • pl_module – The PyTorch Lightning module.
      • checkpoint – The checkpoint object.

on_train_end(trainer, pl_module)
    This is called at the end of training.

    Parameters
      • trainer – The trainer object.
      • pl_module – The PyTorch Lightning module.

exception mridc.utils.exp_manager.NotFoundError

Bases: mridc.utils.exceptions.MRIDCBaseException

Raised when a file or folder is not found.

class mridc.utils.exp_manager.StatelessTimer(duration: Optional[Union[str, datetime.timedelta, Dict[str, int]]] = None, interval: str = Interval.step, verbose: bool = True)

Bases: pytorch_lightning.callbacks.timer.Timer

Extension of PTL timers to be per run.

load_state_dict(state_dict: Dict[str, Any]) -> None
    Loads the state of the timer.

state_dict() -> Dict[str, Any]
    Saves the state of the timer.

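A short sketch of attaching such a timer (the duration string follows Lightning's Timer format DD:HH:MM:SS; the budget here is illustrative):

    import pytorch_lightning as pl

    from mridc.utils.exp_manager import StatelessTimer

    # Stop the run after a four-hour wall-clock budget. Because the timer does
    # not restore state, a resumed job starts with a fresh budget rather than
    # the elapsed time stored in the checkpoint.
    trainer = pl.Trainer(max_epochs=100, callbacks=[StatelessTimer(duration="00:04:00:00")])
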
class mridc.utils.exp_manager.StepTimingParams(reduction: Optional[str] = 'mean', sync_cuda: Optional[bool] = False, buffer_size: Optional[int] = 1)

Bases: object

Parameters for the step timing callback.

buffer_size: Optional[int] = 1
reduction: Optional[str] = 'mean'
sync_cuda: Optional[bool] = False

class mridc.utils.exp_manager.TimingCallback(timer_kwargs=None)

Bases: pytorch_lightning.callbacks.base.Callback

Logs execution time of train/val/test steps.

on_after_backward(trainer, pl_module)
    Note: this is called after the optimizer step.

on_before_backward(trainer, pl_module, loss)
    Logs the time taken for the backward pass.

on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
    Logs execution time of test steps.

on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
    Logs execution time of test steps.

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, **kwargs)
    Logs the time taken by the training batch.

on_train_batch_start(trainer, pl_module, batch, batch_idx, **kwargs)
    Called at the beginning of each training batch.

on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
    Logs the time taken by the validation step.

on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
    Logs the time taken by the validation batch.

mridc.utils.exp_manager.check_explicit_log_dir(trainer: pytorch_lightning.trainer.trainer.Trainer, explicit_log_dir: List[Union[pathlib.Path, str]], exp_dir: str, name: str, version: str) -> Tuple[pathlib.Path, str, str, str]

Checks that the passed arguments are compatible with explicit_log_dir.

Parameters
  • trainer – The trainer to check.
  • explicit_log_dir – The explicit log dir to check.
  • exp_dir – The experiment directory to check.
  • name – The experiment name to check.
  • version – The experiment version to check.

Return type
    The log_dir, exp_dir, name, and version that should be used.

Raises
    LoggerMisconfigurationError

mridc.utils.exp_manager.check_resume(trainer: pytorch_lightning.trainer.trainer.Trainer, log_dir: str, resume_past_end: bool = False, resume_ignore_no_checkpoint: bool = False)

Checks that resume=True was used correctly with the arguments passed to exp_manager. Sets trainer._checkpoint_connector.resume_from_checkpoint_fit_path as necessary.

Parameters
  • trainer – The trainer that is being used.
  • log_dir – The directory where the logs are being saved.
  • resume_past_end – Whether to resume from the end of the experiment.
  • resume_ignore_no_checkpoint – Whether to ignore if there is no checkpoint to resume from.

Raises
  • NotFoundError – If resume is True, resume_ignore_no_checkpoint is False, and checkpoints could not be found.
  • ValueError – If resume is True and more than one checkpoint was found.

mridc.utils.exp_manager.check_slurm(trainer)

Checks if the trainer is running on a SLURM cluster. If so, it will check if the trainer is running on the master node. If it is not, it will exit.

Parameters
    trainer – The trainer to check.

Return type
    True if the trainer is running on the master node, False otherwise.

mridc.utils.exp_manager.configure_checkpointing(trainer: pytorch_lightning.trainer.trainer.Trainer, log_dir: pathlib.Path, name: str, resume: bool, params: omegaconf.dictconfig.DictConfig)

Adds a ModelCheckpoint to the trainer. Raises CheckpointMisconfigurationError if the trainer already has a ModelCheckpoint callback or if trainer.weights_save_path was passed to Trainer.

mridc.utils.exp_manager.configure_loggers(trainer: pytorch_lightning.trainer.trainer.Trainer, exp_dir: List[Union[pathlib.Path, str]], name: str, version: str, create_tensorboard_logger: bool, summary_writer_kwargs: dict, create_wandb_logger: bool, wandb_kwargs: dict)

Creates a TensorboardLogger and/or WandBLogger and attaches them to the trainer. Raises ValueError if summary_writer_kwargs or wandb_kwargs are misconfigured.

Parameters
  • trainer – The trainer to attach the loggers to.
  • exp_dir – The experiment directory.
  • name – The name of the experiment.
  • version – The version of the experiment.
  • create_tensorboard_logger – Whether to create a TensorboardLogger.
  • summary_writer_kwargs – The kwargs to pass to the TensorboardLogger.
  • create_wandb_logger – Whether to create a Weights & Biases logger.
  • wandb_kwargs – The kwargs to pass to the Weights & Biases logger.

Returns
    LoggerList

Return type
    A list of loggers.

mridc.utils.exp_manager.error_checks(trainer: pytorch_lightning.trainer.trainer.Trainer, cfg: Optional[Union[omegaconf.dictconfig.DictConfig, Dict]] = None)

Checks that the passed trainer is compliant with MRIDC and exp_manager's passed configuration. Specifically, it:

  • Throws an error when hydra has changed the working directory, since this causes issues with Lightning's DDP.
  • Throws an error when the trainer has loggers defined but create_tensorboard_logger or create_wandb_logger is True.
  • Prints error messages when 1) run on multi-node and not SLURM, and 2) run on multi-GPU without DDP.

mridc.utils.exp_manager.exp_manager(trainer: pytorch_lightning.trainer.trainer.Trainer, cfg: Optional[Union[omegaconf.dictconfig.DictConfig, Dict]] = None) -> Optional[pathlib.Path]

exp_manager is a helper function used to manage folders for experiments. It follows the PyTorch Lightning paradigm of exp_dir/model_or_experiment_name/version. If the Lightning trainer has a logger, exp_manager will get exp_dir, name, and version from the logger. Otherwise, it will use the exp_dir and name arguments to create the logging directory. exp_manager also allows for explicit folder creation via explicit_log_dir.

The version can be a datetime string or an integer. Datetime versioning can be disabled by setting use_datetime_version to False. exp_manager optionally creates TensorBoardLogger, WandBLogger, and ModelCheckpoint objects from PyTorch Lightning. It copies sys.argv, and git information if available, to the logging directory, and it creates a log file for each process to log its output into.

exp_manager additionally has a resume feature (resume_if_exists) which can be used to continue training from the constructed log_dir. When you need to resume training repeatedly (for example, on a cluster where you need multiple consecutive jobs), you need to avoid creating the version folders. Therefore, from v1.0.0, when resume_if_exists is set to True, creating the version folders is skipped.

Parameters
  • trainer – The Lightning trainer object.
  • cfg – Can have the following keys:
      • explicit_log_dir: Can be used to override exp_dir/name/version folder creation. Defaults to None, which will use exp_dir, name, and version to construct the logging directory.
      • exp_dir: The base directory to create the logging directory in. Defaults to None, which logs to ./mridc_experiments.
      • name: The name of the experiment. Defaults to None, which turns into "default" via name = name or "default".
      • version: The version of the experiment. Defaults to None, which uses either a datetime string or Lightning's TensorboardLogger system of using version_{int}.
      • use_datetime_version: Whether to use a datetime string for version. Defaults to True.
      • resume_if_exists: Whether this experiment is resuming from a previous run. If True, it sets trainer._checkpoint_connector.resume_from_checkpoint_fit_path so that the trainer should auto-resume. exp_manager will move files under log_dir to log_dir/run_{int}. Defaults to False. From v1.0.0, when resume_if_exists is True, version folders are not created, to make it easier to find the log folder for subsequent runs.
      • resume_past_end: exp_manager errors out if resume_if_exists is True and a checkpoint matching *end.ckpt indicates that a previous training run fully completed. This behaviour can be disabled, in which case the *end.ckpt will be loaded, by setting resume_past_end to True. Defaults to False.
      • resume_ignore_no_checkpoint: exp_manager errors out if resume_if_exists is True and no checkpoint could be found. This behaviour can be disabled, in which case exp_manager will print a message and continue without restoring, by setting resume_ignore_no_checkpoint to True. Defaults to False.
      • create_tensorboard_logger: Whether to create a TensorBoard logger and attach it to the trainer. Defaults to True.
      • summary_writer_kwargs: A dictionary of kwargs that can be passed to Lightning's TensorboardLogger class. Note that log_dir is passed by exp_manager and cannot exist in this dict. Defaults to None.
      • create_wandb_logger: Whether to create a Weights & Biases logger and attach it to the trainer. Defaults to False.
      • wandb_logger_kwargs: A dictionary of kwargs that can be passed to Lightning's WandBLogger class. Note that name and project are required parameters if create_wandb_logger is True. Defaults to None.
      • create_checkpoint_callback: Whether to create a ModelCheckpoint callback and attach it to the trainer. The ModelCheckpoint saves the top 3 models with the best "val_loss", the most recent checkpoint under *last.ckpt, and the final checkpoint after training completes under *end.ckpt. Defaults to True.
      • files_to_copy: A list of files to copy to the experiment logging directory. Defaults to None, which copies no files.
      • log_local_rank_0_only: Whether to only create log files for local rank 0. Defaults to False. Set this to True if you are using DDP with many GPUs and do not want many log files in your exp dir.
      • log_global_rank_0_only: Whether to only create log files for global rank 0. Defaults to False. Set this to True if you are using DDP with many GPUs and do not want many log files in your exp dir.

Return type
    The final logging directory where logging files are saved. Usually the concatenation of exp_dir, name, and version. See the usage sketch below.

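A minimal usage sketch (the trainer arguments and config values are illustrative):

    import pytorch_lightning as pl
    from omegaconf import OmegaConf

    from mridc.utils.exp_manager import exp_manager

    trainer = pl.Trainer(max_epochs=10)
    cfg = OmegaConf.create(
        {
            "exp_dir": "./mridc_experiments",
            "name": "cirim_example",
            "create_tensorboard_logger": True,
            "create_checkpoint_callback": True,
        }
    )
    # Returns the resolved logging directory,
    # e.g. ./mridc_experiments/cirim_example/<datetime_version>.
    log_dir = exp_manager(trainer, cfg)
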
mridc.utils.exp_manager.get_git_diff()

Helper function that tries to get the git diff if running inside a git folder.

Returns
  • Bool – whether the git subprocess ran without error.
  • String – git subprocess output or error message.

mridc.utils.exp_manager.get_git_hash()

Helper function that tries to get the commit hash if running inside a git folder.

Returns
  • Bool – whether the git subprocess ran without error.
  • String – git subprocess output or error message.

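Both helpers return a (bool, str) pair, so a run can be stamped with its code state. A sketch (the output path is illustrative):

    from mridc.utils.exp_manager import get_git_diff, get_git_hash

    ok, git_hash = get_git_hash()
    if ok:
        print(f"running at commit {git_hash.strip()}")
    ok, diff = get_git_diff()
    if ok and diff:
        # Persist uncommitted changes next to the experiment logs.
        with open("git_diff.patch", "w") as f:
            f.write(diff)
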
mridc.utils.exp_manager.get_log_dir(trainer: pytorch_lightning.trainer.trainer.Trainer, exp_dir: Optional[str] = None, name: Optional[str] = None, version: Optional[str] = None, explicit_log_dir: Optional[str] = None, use_datetime_version: bool = True, resume_if_exists: bool = False) -> Tuple[pathlib.Path, str, str, str]

Obtains the log_dir used for exp_manager.

Parameters
  • trainer – The trainer to check.
  • exp_dir – The experiment directory to check.
  • name – The experiment name to check.
  • version – The experiment version to check.
  • explicit_log_dir – The explicit log dir to check.
  • use_datetime_version – Whether to use datetime versioning.
  • resume_if_exists – Whether to resume if the log_dir already exists.

Raises
  • LoggerMisconfigurationError – If the trainer is incompatible with the arguments.
  • NotFoundError – If resume is True, resume_ignore_no_checkpoint is False, and checkpoints could not be found.
  • ValueError – If resume is True and more than one checkpoint was found.

mridc.utils.export_utils module

class mridc.utils.export_utils.CastToFloat(mod)

Bases: torch.nn.modules.module.Module

Cast input to float.

forward(x)
    Forward pass.

training: bool

class mridc.utils.export_utils.ExportFormat(value)

Bases: enum.Enum

Which format to use when exporting a Neural Module for deployment.

ONNX = (1,)
TORCHSCRIPT = (2,)

mridc.utils.export_utils.augment_filename(output: str, prepend: str)

Augment the output filename with prepend.

mridc.utils.export_utils.forward_method(self)

Forward method for export.

mridc.utils.export_utils.get_export_format(filename: str)

Get the export format from the filename.

mridc.utils.export_utils.parse_input_example(input_example)

Parse an input example to the onnxrt input format.

mridc.utils.export_utils.replace_for_export(model: torch.nn.modules.module.Module) -> torch.nn.modules.module.Module

Top-level function to replace the default set of modules in a model.
NOTE: This occurs in place; if you want to preserve the model, make sure to copy it first.

Parameters
    model – Top-level model to replace modules in.

Return type
    The model with replaced modules.

mridc.utils.export_utils.replace_modules(model: torch.nn.modules.module.Module, expansions: Optional[Dict[str, Callable[[torch.nn.modules.module.Module], Optional[torch.nn.modules.module.Module]]]] = None) -> torch.nn.modules.module.Module

Top-level function to replace modules in a model, specified by class name, with a desired replacement.
NOTE: This occurs in place; if you want to preserve the model, make sure to copy it first.

Parameters
  • model – Top-level model to replace modules in.
  • expansions – A dictionary of module class names to functions to replace them with.

Return type
    The model with replaced modules.

mridc.utils.export_utils.simple_replace(BaseT: Type[torch.nn.modules.module.Module], DestT: Type[torch.nn.modules.module.Module]) -> Callable[[torch.nn.modules.module.Module], Optional[torch.nn.modules.module.Module]]

Generic function generator to replace a BaseT module with DestT. BaseT and DestT should have the same attributes. No weights are copied.

Parameters
  • BaseT – The base type of the module.
  • DestT – The destination type of the module.

Return type
    A function to replace BaseT with DestT.

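A sketch of the swap pattern these two helpers implement (the module types are illustrative; since no weights are copied, this is only appropriate for stateless modules such as activations):

    import torch.nn as nn

    from mridc.utils.export_utils import replace_modules, simple_replace

    model = nn.Sequential(nn.Conv2d(1, 8, 3), nn.ReLU6())
    # Map the class name that replace_modules matches against to a factory
    # that builds the replacement module.
    expansions = {"ReLU6": simple_replace(nn.ReLU6, nn.ReLU)}
    model = replace_modules(model, expansions)  # modifies model in place
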
mridc.utils.export_utils.swap_modules(model: torch.nn.modules.module.Module, mapping: Dict[str, torch.nn.modules.module.Module])

This function swaps nested modules, as specified by "dot paths" in mapping, with a desired replacement. This allows for swapping nested modules through arbitrary levels of children.
NOTE: This occurs in place; if you want to preserve the model, make sure to copy it first.

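A sketch of the dot-path addressing (paths follow the names produced by model.named_modules(); the architecture is illustrative):

    import torch.nn as nn

    from mridc.utils.export_utils import swap_modules

    model = nn.Sequential(nn.Sequential(nn.Linear(4, 4), nn.ReLU()))
    # "0.1" addresses model[0][1], i.e. the ReLU nested one level down.
    swap_modules(model, {"0.1": nn.GELU()})
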
mridc.utils.export_utils.to_onnxrt_input(input_names, input_dict, input_list)

Transforms input to the onnxrt input format.

mridc.utils.export_utils.verify_runtime(output, input_list, input_dict, input_names, output_names, output_example, check_tolerance=0.01)

Verify runtime output with onnxrt.

Parameters
  • output – The output of the module.
  • input_list – The input list of the module.
  • input_dict – The input dict of the module.
  • input_names – The input names of the module.
  • output_names – The output names of the module.
  • output_example – The output example of the module.
  • check_tolerance – The tolerance for the check.

Return type
    The runtime output.

mridc.utils.export_utils.wrap_forward_method(self)

Wraps the forward method of the module with a function that returns the output of the forward method.

mridc.utils.export_utils.wrap_module(BaseT: Type[torch.nn.modules.module.Module], DestT: Type[torch.nn.modules.module.Module]) -> Callable[[torch.nn.modules.module.Module], Optional[torch.nn.modules.module.Module]]

Generic function generator to replace a BaseT module with DestT. BaseT and DestT should have the same attributes. No weights are copied.

Parameters
  • BaseT – The base type of the module.
  • DestT – The destination type of the module.

Return type
    A function to replace BaseT with DestT.

mridc.utils.get_rank module

mridc.utils.get_rank.is_global_rank_zero()

Helper function to determine if the current process is global_rank 0 (the main process).
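The typical guard pattern (the side effect here is a stand-in for anything that should happen once per job, such as writing files):

    from mridc.utils.get_rank import is_global_rank_zero

    if is_global_rank_zero():
        print("only the main process runs this")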

mridc.utils.lightning_logger_patch module

mridc.utils.lightning_logger_patch.add_filehandlers_to_pl_logger(all_log_file, err_log_file)

Adds two filehandlers to pytorch_lightning's logger. Called in mridc.utils.exp_manager(). The first filehandler logs all messages to all_log_file, while the second filehandler logs all WARNING and higher messages to err_log_file. If "memory_err" and "memory_all" exist in HANDLERS, then those buffers are flushed to err_log_file and all_log_file respectively, and then closed.

mridc.utils.lightning_logger_patch.add_memory_handlers_to_pl_logger()

Adds two MemoryHandlers to pytorch_lightning's logger. These two handlers are essentially message buffers. This function is called in mridc.utils.__init__.py. These handlers are used in add_filehandlers_to_pl_logger to flush buffered messages to files.

mridc.utils.metaclasses module

class mridc.utils.metaclasses.Singleton

Bases: type

Implementation of a generic, thread-safe singleton metaclass. Classes that use it as their metaclass are only ever instantiated once.

__call__(*args, **kwargs)
    Returns the singleton instance. A thread-safe implementation.
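A minimal sketch of the metaclass in use (the Config class is illustrative):

    from mridc.utils.metaclasses import Singleton

    class Config(metaclass=Singleton):
        def __init__(self):
            self.value = 0

    a = Config()
    b = Config()
    assert a is b  # __call__ returns the same instance for every construction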

mridc.utils.model_utils module

class mridc.utils.model_utils.ArtifactItem

Bases: object

ArtifactItem is a dataclass that holds the information of an artifact.

hashed_path: Optional[str] = None
path: str
path_type: mridc.utils.model_utils.ArtifactPathType

class mridc.utils.model_utils.ArtifactPathType(value)

Bases: enum.Enum

ArtifactPathType refers to the type of the path that the artifact is located at.
LOCAL_PATH: A user-local filepath that exists on the file system.
TAR_PATH: A (generally flattened) filepath that exists inside an archive (which may have its own full path).

LOCAL_PATH = 0
TAR_PATH = 1

mridc.utils.model_utils.check_lib_version(lib_name: str, checked_version: str, operator) -> Tuple[Optional[bool], str]

Checks if a library is installed, and if it is, evaluates operator(lib.__version__, checked_version). The boolean result, along with a string analysis of the result, is returned. If the library is not installed at all, None is returned instead, along with a string explaining that the library is not installed.

Parameters
  • lib_name – Lower-case str name of the library that must be imported.
  • checked_version – Semver string that is compared against lib.__version__.
  • operator – Binary callable function func(a, b) -> bool that compares lib.__version__ against checked_version in some manner. Must return a boolean.

Returns
    A tuple of results:
  • Bool or None – the result of operator(lib.__version__, checked_version) if the library could be imported, or False if __version__ is not implemented in lib. None if the library is not installed at all.
  • A string analysis of the check.

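A small usage sketch (the version string is illustrative):

    import operator

    from mridc.utils.model_utils import check_lib_version

    ok, msg = check_lib_version("torch", "1.9.0", operator.ge)
    if ok is None:
        print("torch is not installed:", msg)
    else:
        print("check passed" if ok else "check failed", "-", msg)
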
mridc.utils.model_utils.convert_model_config_to_dict_config(cfg: Union[omegaconf.dictconfig.DictConfig, mridc.core.conf.modelPT.MRIDCConfig]) -> omegaconf.dictconfig.DictConfig

Converts its input into a standard DictConfig.

Possible input values are:
  • DictConfig
  • A dataclass which is a subclass of MRIDCConfig

Parameters
    cfg – A dict-like object.

Return type
    The equivalent DictConfig.

mridc.utils.model_utils.import_class_by_path(path: str)

Recursive import of a class by its path string.

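For example (the target class is arbitrary):

    from mridc.utils.model_utils import import_class_by_path

    cls = import_class_by_path("collections.OrderedDict")
    assert cls().__class__.__name__ == "OrderedDict"
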
mridc.utils.model_utils.inject_model_parallel_rank(filepath)

Injects tensor/pipeline model parallel ranks into the filepath. Does nothing if not using model parallelism.

mridc.utils.model_utils.maybe_update_config_version(cfg: omegaconf.dictconfig.DictConfig)

Recursively convert Hydra 0.x configs to Hydra 1.x configs.

Changes include:
  • cls -> _target_.
  • params -> drop params and shift all arguments to the parent.
  • target -> _target_ cannot be performed, due to ModelPT injecting target inside the class.

Parameters
    cfg – Any Hydra-compatible DictConfig.

Return type
    An updated DictConfig that conforms to the Hydra 1.x format.

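A before/after sketch of the first two rewrites (the keys are illustrative):

    from omegaconf import OmegaConf

    from mridc.utils.model_utils import maybe_update_config_version

    old_cfg = OmegaConf.create({"model": {"cls": "mridc.SomeModel", "params": {"hidden": 64}}})
    new_cfg = maybe_update_config_version(old_cfg)
    # Expected shape: {"model": {"_target_": "mridc.SomeModel", "hidden": 64}}
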
mridc.utils.model_utils.parse_dataset_as_name(name: str) -> str

Constructs a valid prefix-name from a provided file path.

Parameters
    name – Path to some valid data/manifest file, or a python object that will be used as a name for the data loader (via str() cast).

Return type
    A valid prefix-name for the data loader.

mridc.utils.model_utils.resolve_cache_dir() -> pathlib.Path

Utility method to resolve a cache directory for MRIDC that can be overridden by an environment variable.

Example
    MRIDC_CACHE_DIR="~/mridc_cache_dir/" python mridc_example_script.py

Returns
    A Path object, resolved to the absolute path of the cache directory. If no override is provided, uses an inbuilt default which adapts to mridc version strings.

mridc.utils.model_utils.resolve_dataset_name_from_cfg(cfg: omegaconf.dictconfig.DictConfig) -> Union[str, int, enum.Enum, float, bool, None, Any]

Parses items of the provided sub-config to find the first potential key that resolves to an existing file or directory.

Fast-path resolution: in order to handle cases where we need to resolve items that are not paths, a fastpath key can be provided, as defined in the global _VAL_TEST_FASTPATH_KEY. This key can be used in two ways.

1) _VAL_TEST_FASTPATH_KEY points to another key in the config. If _VAL_TEST_FASTPATH_KEY points to another key in this config itself, then we assume we want to loop through the values of that key. This allows any key in the config to become a fastpath key. Example:

    validation_ds:
        splits: "val"
        ...
        <_VAL_TEST_FASTPATH_KEY>: "splits"  # points to the key name "splits"

Then we can write the following when overriding in hydra:

    python train_file.py ... model.validation_ds.splits=[val1, val2, dev1, dev2] ...

2) _VAL_TEST_FASTPATH_KEY itself acts as the resolved key. If _VAL_TEST_FASTPATH_KEY does not point to another key in the config, then it is assumed that the items of this key itself are used for resolution. Example:

    validation_ds:
        <_VAL_TEST_FASTPATH_KEY>: "val"

Then we can write the following when overriding in hydra:

    python train_file.py ... model.validation_ds.<_VAL_TEST_FASTPATH_KEY>=[val1, val2, dev1, dev2] ...

IMPORTANT NOTE: resolution can potentially mismatch if there exist more than two valid paths and the first key does not resolve the path of the data file (but does resolve to some other valid path). To avoid this side effect, place the data path as the first item in the config file.

Parameters
    cfg – Sub-config of the config file.

Return type
    A str representing the key of the config which hosts the filepath(s), or None if the path could not be resolved.

mridc.utils.model_utils.resolve_subclass_pretrained_model_info(base_class) -> Union[List[mridc.core.classes.common.PretrainedModelInfo], Set[Any]]

Recursively traverses the inheritance graph of subclasses to extract all pretrained model info. First constructs a set of unique pretrained model info by performing a DFS over the inheritance graph. All model info belonging to the same class is added together.

Parameters
    base_class – The root class, whose subclass graph will be traversed.

Return type
    A list of unique pretrained model infos belonging to all the inherited subclasses of this base class.

mridc.utils.model_utils.resolve_validation_dataloaders(model: mridc.core.classes.modelPT.ModelPT)

Helper method that operates on the ModelPT class to automatically support multiple dataloaders for the validation set. It does so by first resolving the path to one or more data files via resolve_dataset_name_from_cfg(). If this resolution fails, it assumes the data loader is prepared to manually support / not support multiple data loaders and simply calls the appropriate setup method.

If resolution succeeds:
  • Checks if the provided path is to a single file or a list of files. If a single file is provided, simply tags that file as such and loads it via the setup method.
  • If multiple files are provided:
      • Injects a new manifest path at index "i" into the resolved key.
      • Calls the appropriate setup method to set the data loader.
      • Collects the initialized data loader in a list and preserves it.
  • Once all data loaders are processed, assigns the list of loaded loaders to the ModelPT.
  • Finally, assigns a list of unique names resolved from the file paths to the ModelPT.

Parameters
    model – ModelPT subclass, which requires >= 1 validation dataloaders to be set up.

mridc.utils.model_utils.uninject_model_parallel_rank(filepath)

Uninjects tensor/pipeline model parallel ranks from the filepath.

mridc.utils.model_utils.unique_names_check(name_list: Optional[List[str]])

Performs a uniqueness check on the resolved name list, so that it can warn users about non-unique keys.

Parameters
    name_list – List of strings resolved for data loaders.

mridc.utils.model_utils.wrap_training_step(wrapper=None, enabled=None, adapter=None, proxy=<class 'FunctionWrapper'>)

Wraps the training step of the LightningModule.

Parameters
  • wrapped – The wrapped function.
  • instance – The LightningModule instance.
  • args – The arguments passed to the wrapped function.
  • kwargs – The keyword arguments passed to the wrapped function.

Return type
    The return value of the wrapped function.

mridc.utils.mridc_logging module

class mridc.utils.mridc_logging.LogMode(value)

Bases: enum.IntEnum

Enum for the different logging modes.

EACH = 0
ONCE = 1

class mridc.utils.mridc_logging.Logger(*args, **kwargs)

Bases: object

Singleton class for logging.

CRITICAL = 50
DEBUG = 10
ERROR = 40
INFO = 20
NOTSET = 0
WARNING = 30

add_err_file_handler(log_file)
    Add a FileHandler to the logger that logs all WARNING and higher messages to a file. If the logger had a MemoryHandler at self._handlers["memory_err"], those buffered messages are flushed to the new file, and the MemoryHandler is closed.

add_file_handler(log_file)
    Add a FileHandler to the logger that logs all messages to a file. If the logger had a MemoryHandler at self._handlers["memory_all"], those buffered messages are flushed to the new file, and the MemoryHandler is closed.

add_stream_handlers(formatter=<class 'mridc.utils.formaters.base.BaseMRIDCFormatter'>)
    Add StreamHandlers that log to stdout and stderr to the logger. INFO and lower logs are streamed to stdout, while WARNING and higher are streamed to stderr. If the MRIDC_ENV_VARNAME_REDIRECT_LOGS_TO_STDERR environment variable is set, all logs are sent to stderr instead.

captureWarnings(capture)
    If capture is True, redirect all warnings to the logging package. If capture is False, ensure that warnings are not redirected to logging but to their original destinations.

critical(msg, *args, mode=LogMode.EACH, **kwargs) -> None
    Log 'msg % args' with severity 'CRITICAL'. To pass exception information, use the keyword argument exc_info with a true value, e.g. logger.critical("Houston, we have %s", "major disaster", exc_info=1)

    Parameters
      • msg – the message to log
      • *args – the arguments to the message
      • mode – the mode to log the message in
      • **kwargs – the keyword arguments to the message

debug(msg, *args, mode=LogMode.EACH, **kwargs)
    Log 'msg % args' with severity 'DEBUG'. To pass exception information, use the keyword argument exc_info with a true value, e.g. logger.debug("Houston, we have %s", "thorny problem", exc_info=1)

error(msg, *args, mode=LogMode.EACH, **kwargs)
    Log 'msg % args' with severity 'ERROR'. To pass exception information, use the keyword argument exc_info with a true value, e.g. logger.error("Houston, we have %s", "major problem", exc_info=1)

getEffectiveLevel()
    Return how much logging output will be produced.

get_verbosity()
    See getEffectiveLevel.

info(msg, *args, mode=LogMode.EACH, **kwargs)
    Log 'msg % args' with severity 'INFO'. To pass exception information, use the keyword argument exc_info with a true value, e.g. logger.info("Houston, we have %s", "interesting problem", exc_info=1)

patch_stderr_handler(stream)
    Sends messages that should log to stderr to stream instead. Useful for unit tests.

patch_stdout_handler(stream)
    Sends messages that should log to stdout to stream instead. Useful for unit tests.

remove_stream_handlers()
    Removes the StreamHandlers that log to stdout and stderr from the logger.

reset_stream_handler(formatter=<class 'mridc.utils.formaters.base.BaseMRIDCFormatter'>)
    Removes and then re-adds the stream handlers.

setLevel(verbosity_level)
    Sets the threshold for what messages will be logged.

set_verbosity(verbosity_level)
    See setLevel.

temp_verbosity(verbosity_level)
    Sets a temporary threshold for what messages will be logged.

warning(msg, *args, mode=LogMode.EACH, **kwargs)
    Log 'msg % args' with severity 'WARNING'. To pass exception information, use the keyword argument exc_info with a true value, e.g. logger.warning("Houston, we have %s", "bit of a problem", exc_info=1)

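A short usage sketch (this assumes the package exposes its Logger singleton as mridc.utils.logging, mirroring the singleton design above):

    from mridc.utils import logging
    from mridc.utils.mridc_logging import LogMode

    logging.set_verbosity(logging.INFO)
    for epoch in range(3):
        # With mode=LogMode.ONCE the warning is emitted on the first pass only.
        logging.warning("non-unique dataset names detected", mode=LogMode.ONCE)
        logging.info("epoch %d done", epoch)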

mridc.utils.timers module

class mridc.utils.timers.NamedTimer(reduction='mean', sync_cuda=False, buffer_size=-1)

Bases: object

A timer class that supports multiple named timers. A named timer can be used multiple times, in which case the average dt will be returned. A named timer cannot be started if it is already running. Use case: measuring execution of multiple code blocks.

active_timers()
    Return a list of all active named timers.

property buffer_size
    Returns the buffer size of the timer.

export()
    Exports a dictionary with the average/all dt per named timer.

get(name='')
    Returns the value of a named timer.

    Parameters
        name – timer name to return

reset(name=None)
    Resets all / a specific timer.

    Parameters
        name – timer name to reset (if None, all timers are reset)

start(name='')
    Starts measuring a named timer.

    Parameters
        name – timer name to start

stop(name='')
    Stops measuring a named timer.

    Parameters
        name – timer name to stop

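A minimal usage sketch:

    from mridc.utils.timers import NamedTimer

    timer = NamedTimer(reduction="mean")
    for _ in range(3):
        timer.start("forward")
        ...  # code block being measured
        timer.stop("forward")
    print(timer.get("forward"))  # mean dt over the three runs
    print(timer.export())        # dict with per-timer aggregates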

Module contents

diff --git a/docs/build/html/objects.inv b/docs/build/html/objects.inv new file mode 100644 index 00000000..a5e2c38a Binary files /dev/null and b/docs/build/html/objects.inv differ diff --git a/docs/build/html/py-modindex.html b/docs/build/html/py-modindex.html new file mode 100644 index 00000000..b2fcf22c --- /dev/null +++ b/docs/build/html/py-modindex.html @@ -0,0 +1,800 @@

Python Module Index — mridc v.0.0.1 documentation

Python Module Index

m

mridc
mridc.collections
mridc.collections.common
mridc.collections.common.callbacks
mridc.collections.common.callbacks.callbacks
mridc.collections.common.data
mridc.collections.common.data.dataset
mridc.collections.common.losses
mridc.collections.common.losses.aggregator
mridc.collections.common.losses.ssim
mridc.collections.common.metrics
mridc.collections.common.metrics.global_average_loss_metric
mridc.collections.common.parts
mridc.collections.common.parts.fft
mridc.collections.common.parts.patch_utils
mridc.collections.common.parts.ptl_overrides
mridc.collections.common.parts.rnn_utils
mridc.collections.common.parts.utils
mridc.collections.reconstruction
mridc.collections.reconstruction.data
mridc.collections.reconstruction.data.mri_data
mridc.collections.reconstruction.data.subsample
mridc.collections.reconstruction.metrics
mridc.collections.reconstruction.metrics.evaluate
mridc.collections.reconstruction.models
mridc.collections.reconstruction.models.base
mridc.collections.reconstruction.models.cascadenet
mridc.collections.reconstruction.models.cascadenet.ccnn_block
mridc.collections.reconstruction.models.ccnn
mridc.collections.reconstruction.models.cirim
mridc.collections.reconstruction.models.conv
mridc.collections.reconstruction.models.conv.conv2d
mridc.collections.reconstruction.models.conv.gruconv2d
mridc.collections.reconstruction.models.convrecnet
mridc.collections.reconstruction.models.convrecnet.crnn_block
mridc.collections.reconstruction.models.crnn
mridc.collections.reconstruction.models.crossdomain
mridc.collections.reconstruction.models.crossdomain.crossdomain
mridc.collections.reconstruction.models.crossdomain.multicoil
mridc.collections.reconstruction.models.didn
mridc.collections.reconstruction.models.didn.didn
mridc.collections.reconstruction.models.dunet
mridc.collections.reconstruction.models.jointicnet
mridc.collections.reconstruction.models.kikinet
mridc.collections.reconstruction.models.lpd
mridc.collections.reconstruction.models.multidomain
mridc.collections.reconstruction.models.multidomain.multidomain
mridc.collections.reconstruction.models.multidomainnet
mridc.collections.reconstruction.models.mwcnn
mridc.collections.reconstruction.models.mwcnn.mwcnn
mridc.collections.reconstruction.models.pics
mridc.collections.reconstruction.models.primaldual
mridc.collections.reconstruction.models.primaldual.pd
mridc.collections.reconstruction.models.recurrentvarnet
mridc.collections.reconstruction.models.recurrentvarnet.conv2gru
mridc.collections.reconstruction.models.recurrentvarnet.recurentvarnet
mridc.collections.reconstruction.models.rim
mridc.collections.reconstruction.models.rim.conv_layers
mridc.collections.reconstruction.models.rim.rim_block
mridc.collections.reconstruction.models.rim.rnn_cells
mridc.collections.reconstruction.models.rim.utils
mridc.collections.reconstruction.models.rvn
mridc.collections.reconstruction.models.sigmanet
mridc.collections.reconstruction.models.sigmanet.dc_layers
mridc.collections.reconstruction.models.sigmanet.sensitivity_net
mridc.collections.reconstruction.models.unet
mridc.collections.reconstruction.models.unet_base
mridc.collections.reconstruction.models.unet_base.unet_block
mridc.collections.reconstruction.models.variablesplittingnet
mridc.collections.reconstruction.models.variablesplittingnet.vsnet_block
mridc.collections.reconstruction.models.varnet
mridc.collections.reconstruction.models.varnet.vn_block
mridc.collections.reconstruction.models.vn
mridc.collections.reconstruction.models.vsnet
mridc.collections.reconstruction.models.xpdnet
mridc.collections.reconstruction.models.zf
mridc.collections.reconstruction.parts
mridc.collections.reconstruction.parts.transforms
mridc.collections.reconstruction.parts.utils
mridc.constants
mridc.core
mridc.core.classes
mridc.core.classes.common
mridc.core.classes.dataset
mridc.core.classes.export
mridc.core.classes.loss
mridc.core.classes.modelPT
mridc.core.classes.module
mridc.core.conf
mridc.core.conf.base_config
mridc.core.conf.dataloader
mridc.core.conf.hydra_runner
mridc.core.conf.modelPT
mridc.core.conf.optimizers
mridc.core.conf.schedulers
mridc.core.conf.trainer
mridc.core.connectors
mridc.core.connectors.save_restore_connector
mridc.core.neural_types
mridc.core.neural_types.axes
mridc.core.neural_types.comparison
mridc.core.neural_types.elements
mridc.core.neural_types.neural_type
mridc.core.optim
mridc.core.optim.adafactor
mridc.core.optim.lr_scheduler
mridc.core.optim.novograd
mridc.core.optim.optimizer_with_master_params
mridc.core.optim.optimizers
mridc.core.utils
mridc.core.utils.neural_type_utils
mridc.core.utils.numba_utils
mridc.launch
mridc.package_info
mridc.utils
mridc.utils.app_state
mridc.utils.arguments
mridc.utils.cloud
mridc.utils.config_utils
mridc.utils.decorators
mridc.utils.decorators.deprecated
mridc.utils.decorators.experimental
mridc.utils.decorators.port_docs
mridc.utils.distributed
mridc.utils.env_var_parsing
mridc.utils.exceptions
mridc.utils.exp_manager
mridc.utils.export_utils
mridc.utils.formaters
mridc.utils.formaters.base
mridc.utils.formaters.colors
mridc.utils.formaters.utils
mridc.utils.get_rank
mridc.utils.lightning_logger_patch
mridc.utils.metaclasses
mridc.utils.model_utils
mridc.utils.mridc_logging
mridc.utils.timers

© Copyright 2022, Dimitrios Karkalousos.

Built with Sphinx using a theme provided by Read the Docs.

diff --git a/docs/build/html/search.html b/docs/build/html/search.html new file mode 100644 index 00000000..1643a117 --- /dev/null +++ b/docs/build/html/search.html @@ -0,0 +1,115 @@

Search — mridc v.0.0.1 documentation
+ + + + + + + + + diff --git a/docs/build/html/searchindex.js b/docs/build/html/searchindex.js new file mode 100644 index 00000000..c85372df --- /dev/null +++ b/docs/build/html/searchindex.js @@ -0,0 +1 @@ +Search.setIndex({docnames:["index","modules","mridc","mridc.collections","mridc.collections.common","mridc.collections.common.callbacks","mridc.collections.common.data","mridc.collections.common.losses","mridc.collections.common.metrics","mridc.collections.common.parts","mridc.collections.reconstruction","mridc.collections.reconstruction.data","mridc.collections.reconstruction.metrics","mridc.collections.reconstruction.models","mridc.collections.reconstruction.models.cascadenet","mridc.collections.reconstruction.models.conv","mridc.collections.reconstruction.models.convrecnet","mridc.collections.reconstruction.models.crossdomain","mridc.collections.reconstruction.models.didn","mridc.collections.reconstruction.models.multidomain","mridc.collections.reconstruction.models.mwcnn","mridc.collections.reconstruction.models.primaldual","mridc.collections.reconstruction.models.recurrentvarnet","mridc.collections.reconstruction.models.rim","mridc.collections.reconstruction.models.sigmanet","mridc.collections.reconstruction.models.unet_base","mridc.collections.reconstruction.models.variablesplittingnet","mridc.collections.reconstruction.models.varnet","mridc.collections.reconstruction.parts","mridc.core","mridc.core.classes","mridc.core.conf","mridc.core.connectors","mridc.core.neural_types","mridc.core.optim","mridc.core.utils","mridc.utils","mridc.utils.decorators","mridc.utils.formaters"],envversion:{"sphinx.domains.c":2,"sphinx.domains.changeset":1,"sphinx.domains.citation":1,"sphinx.domains.cpp":5,"sphinx.domains.index":1,"sphinx.domains.javascript":2,"sphinx.domains.math":2,"sphinx.domains.python":3,"sphinx.domains.rst":2,"sphinx.domains.std":2,"sphinx.ext.viewcode":1,sphinx:56},filenames:["index.rst","modules.rst","mridc.rst","mridc.collections.rst","mridc.collections.common.rst","mridc.collections.common.callbacks.rst","mridc.collections.common.data.rst","mridc.collections.common.losses.rst","mridc.collections.common.metrics.rst","mridc.collections.common.parts.rst","mridc.collections.reconstruction.rst","mridc.collections.reconstruction.data.rst","mridc.collections.reconstruction.metrics.rst","mridc.collections.reconstruction.models.rst","mridc.collections.reconstruction.models.cascadenet.rst","mridc.collections.reconstruction.models.conv.rst","mridc.collections.reconstruction.models.convrecnet.rst","mridc.collections.reconstruction.models.crossdomain.rst","mridc.collections.reconstruction.models.didn.rst","mridc.collections.reconstruction.models.multidomain.rst","mridc.collections.reconstruction.models.mwcnn.rst","mridc.collections.reconstruction.models.primaldual.rst","mridc.collections.reconstruction.models.recurrentvarnet.rst","mridc.collections.reconstruction.models.rim.rst","mridc.collections.reconstruction.models.sigmanet.rst","mridc.collections.reconstruction.models.unet_base.rst","mridc.collections.reconstruction.models.variablesplittingnet.rst","mridc.collections.reconstruction.models.varnet.rst","mridc.collections.reconstruction.parts.rst","mridc.core.rst","mridc.core.classes.rst","mridc.core.conf.rst","mridc.core.connectors.rst","mridc.core.neural_types.rst","mridc.core.optim.rst","mridc.core.utils.rst","mridc.utils.rst","mridc.utils.decorators.rst","mridc.utils.formaters.rst"],objects:{"":[[2,0,0,"-","mridc"]],"mridc.collections":[[4,0,0,"-","common"],[10,0,0,"-","reconstructi
on"]],"mridc.collections.common":[[5,0,0,"-","callbacks"],[6,0,0,"-","data"],[7,0,0,"-","losses"],[8,0,0,"-","metrics"],[9,0,0,"-","parts"]],"mridc.collections.common.callbacks":[[5,0,0,"-","callbacks"]],"mridc.collections.common.callbacks.callbacks":[[5,1,1,"","LogEpochTimeCallback"]],"mridc.collections.common.callbacks.callbacks.LogEpochTimeCallback":[[5,2,1,"","on_train_epoch_end"],[5,2,1,"","on_train_epoch_start"]],"mridc.collections.common.data":[[6,0,0,"-","dataset"]],"mridc.collections.common.data.dataset":[[6,1,1,"","ConcatDataset"]],"mridc.collections.common.data.dataset.ConcatDataset":[[6,2,1,"","__iter__"],[6,2,1,"","__len__"],[6,2,1,"","get_iterable"],[6,2,1,"","random_generator"],[6,2,1,"","round_robin_generator"]],"mridc.collections.common.losses":[[7,0,0,"-","aggregator"],[7,0,0,"-","ssim"]],"mridc.collections.common.losses.aggregator":[[7,1,1,"","AggregatorLoss"]],"mridc.collections.common.losses.aggregator.AggregatorLoss":[[7,2,1,"","forward"],[7,3,1,"","input_types"],[7,3,1,"","output_types"],[7,4,1,"","reduction"]],"mridc.collections.common.losses.ssim":[[7,1,1,"","SSIMLoss"]],"mridc.collections.common.losses.ssim.SSIMLoss":[[7,2,1,"","forward"],[7,4,1,"","training"]],"mridc.collections.common.metrics":[[8,0,0,"-","global_average_loss_metric"]],"mridc.collections.common.metrics.global_average_loss_metric":[[8,1,1,"","GlobalAverageLossMetric"]],"mridc.collections.common.metrics.global_average_loss_metric.GlobalAverageLossMetric":[[8,2,1,"","compute"],[8,2,1,"","update"]],"mridc.collections.common.parts":[[9,0,0,"-","fft"],[9,0,0,"-","patch_utils"],[9,0,0,"-","ptl_overrides"],[9,0,0,"-","rnn_utils"],[9,0,0,"-","utils"]],"mridc.collections.common.parts.fft":[[9,5,1,"","fft2c"],[9,5,1,"","ifft2c"]],"mridc.collections.common.parts.ptl_overrides":[[9,1,1,"","MRIDCNativeMixedPrecisionPlugin"]],"mridc.collections.common.parts.rnn_utils":[[9,5,1,"","rnn_weights_init"]],"mridc.collections.common.parts.utils":[[9,5,1,"","check_stacked_complex"],[9,5,1,"","coil_combination"],[9,5,1,"","complex_abs"],[9,5,1,"","complex_abs_sq"],[9,5,1,"","complex_conj"],[9,5,1,"","complex_mul"],[9,5,1,"","rss"],[9,5,1,"","rss_complex"],[9,5,1,"","save_reconstructions"],[9,5,1,"","sense"],[9,5,1,"","tensor_to_complex_np"],[9,5,1,"","to_tensor"]],"mridc.collections.reconstruction":[[11,0,0,"-","data"],[12,0,0,"-","metrics"],[13,0,0,"-","models"],[28,0,0,"-","parts"]],"mridc.collections.reconstruction.data":[[11,0,0,"-","mri_data"],[11,0,0,"-","subsample"]],"mridc.collections.reconstruction.data.mri_data":[[11,1,1,"","FastMRICombinedSliceDataset"],[11,1,1,"","FastMRISliceDataset"],[11,5,1,"","et_query"]],"mridc.collections.reconstruction.data.subsample":[[11,1,1,"","EquispacedMaskFunc"],[11,1,1,"","Gaussian1DMaskFunc"],[11,1,1,"","Gaussian2DMaskFunc"],[11,1,1,"","MaskFunc"],[11,1,1,"","Poisson2DMaskFunc"],[11,1,1,"","RandomMaskFunc"],[11,5,1,"","create_mask_for_mask_type"],[11,5,1,"","temp_seed"]],"mridc.collections.reconstruction.data.subsample.EquispacedMaskFunc":[[11,2,1,"","__call__"]],"mridc.collections.reconstruction.data.subsample.Gaussian1DMaskFunc":[[11,2,1,"","__call__"],[11,2,1,"","gaussian_coordinates"],[11,2,1,"","gaussian_kernel"],[11,2,1,"","gaussian_kspace"]],"mridc.collections.reconstruction.data.subsample.Gaussian2DMaskFunc":[[11,2,1,"","__call__"],[11,2,1,"","gaussian_coordinates"],[11,2,1,"","gaussian_kernel"],[11,2,1,"","gaussian_kspace"]],"mridc.collections.reconstruction.data.subsample.MaskFunc":[[11,2,1,"","__call__"],[11,2,1,"","choose_acceleration"]],"mridc.collections.reconstru
ction.data.subsample.Poisson2DMaskFunc":[[11,2,1,"","__call__"],[11,2,1,"","centered_circle"],[11,2,1,"","poisson_disc2d"]],"mridc.collections.reconstruction.data.subsample.RandomMaskFunc":[[11,2,1,"","__call__"]],"mridc.collections.reconstruction.metrics":[[12,0,0,"-","evaluate"]],"mridc.collections.reconstruction.metrics.evaluate":[[12,1,1,"","Metrics"],[12,5,1,"","evaluate"],[12,5,1,"","mse"],[12,5,1,"","nmse"],[12,5,1,"","psnr"],[12,5,1,"","ssim"]],"mridc.collections.reconstruction.metrics.evaluate.Metrics":[[12,2,1,"","__repr__"],[12,2,1,"","means"],[12,2,1,"","push"],[12,2,1,"","stddevs"]],"mridc.collections.reconstruction.models":[[13,0,0,"-","base"],[14,0,0,"-","cascadenet"],[13,0,0,"-","ccnn"],[13,0,0,"-","cirim"],[15,0,0,"-","conv"],[16,0,0,"-","convrecnet"],[13,0,0,"-","crnn"],[17,0,0,"-","crossdomain"],[18,0,0,"-","didn"],[13,0,0,"-","dunet"],[13,0,0,"-","jointicnet"],[13,0,0,"-","kikinet"],[13,0,0,"-","lpd"],[19,0,0,"-","multidomain"],[13,0,0,"-","multidomainnet"],[20,0,0,"-","mwcnn"],[13,0,0,"-","pics"],[21,0,0,"-","primaldual"],[22,0,0,"-","recurrentvarnet"],[23,0,0,"-","rim"],[13,0,0,"-","rvn"],[24,0,0,"-","sigmanet"],[13,0,0,"-","unet"],[25,0,0,"-","unet_base"],[26,0,0,"-","variablesplittingnet"],[27,0,0,"-","varnet"],[13,0,0,"-","vn"],[13,0,0,"-","vsnet"],[13,0,0,"-","xpdnet"],[13,0,0,"-","zf"]],"mridc.collections.reconstruction.models.base":[[13,1,1,"","BaseMRIReconstructionModel"],[13,1,1,"","BaseSensitivityModel"]],"mridc.collections.reconstruction.models.base.BaseMRIReconstructionModel":[[13,2,1,"","log_image"],[13,2,1,"","process_inputs"],[13,2,1,"","process_loss"],[13,2,1,"","setup_test_data"],[13,2,1,"","setup_training_data"],[13,2,1,"","setup_validation_data"],[13,2,1,"","test_epoch_end"],[13,2,1,"","test_step"],[13,4,1,"","training"],[13,2,1,"","training_step"],[13,2,1,"","validation_epoch_end"],[13,2,1,"","validation_step"]],"mridc.collections.reconstruction.models.base.BaseSensitivityModel":[[13,2,1,"","batch_chans_to_chan_dim"],[13,2,1,"","chans_to_batch_dim"],[13,2,1,"","divide_root_sum_of_squares"],[13,2,1,"","forward"],[13,2,1,"","get_pad_and_num_low_freqs"],[13,4,1,"","training"]],"mridc.collections.reconstruction.models.cascadenet":[[14,0,0,"-","ccnn_block"]],"mridc.collections.reconstruction.models.cascadenet.ccnn_block":[[14,1,1,"","CascadeNetBlock"]],"mridc.collections.reconstruction.models.cascadenet.ccnn_block.CascadeNetBlock":[[14,2,1,"","forward"],[14,2,1,"","sens_expand"],[14,2,1,"","sens_reduce"],[14,4,1,"","training"]],"mridc.collections.reconstruction.models.ccnn":[[13,1,1,"","CascadeNet"]],"mridc.collections.reconstruction.models.ccnn.CascadeNet":[[13,4,1,"","allow_zero_length_dataloader_with_multiple_devices"],[13,2,1,"","forward"],[13,4,1,"","mse_vals"],[13,4,1,"","nmse_vals"],[13,4,1,"","precision"],[13,4,1,"","prepare_data_per_node"],[13,4,1,"","psnr_vals"],[13,4,1,"","ssim_vals"],[13,4,1,"","trainer"],[13,4,1,"","training"]],"mridc.collections.reconstruction.models.cirim":[[13,1,1,"","CIRIM"]],"mridc.collections.reconstruction.models.cirim.CIRIM":[[13,4,1,"","allow_zero_length_dataloader_with_multiple_devices"],[13,2,1,"","forward"],[13,4,1,"","mse_vals"],[13,4,1,"","nmse_vals"],[13,4,1,"","precision"],[13,4,1,"","prepare_data_per_node"],[13,2,1,"","process_intermediate_pred"],[13,2,1,"","process_loss"],[13,4,1,"","psnr_vals"],[13,4,1,"","ssim_vals"],[13,4,1,"","trainer"],[13,4,1,"","training"]],"mridc.collections.reconstruction.models.conv":[[15,0,0,"-","conv2d"],[15,0,0,"-","gruconv2d"]],"mridc.collections.reconstruction.models.conv.conv
2d":[[15,1,1,"","Conv2d"]],"mridc.collections.reconstruction.models.conv.conv2d.Conv2d":[[15,2,1,"","forward"],[15,4,1,"","training"]],"mridc.collections.reconstruction.models.conv.gruconv2d":[[15,1,1,"","GRUConv2d"]],"mridc.collections.reconstruction.models.conv.gruconv2d.GRUConv2d":[[15,2,1,"","forward"],[15,4,1,"","training"]],"mridc.collections.reconstruction.models.convrecnet":[[16,0,0,"-","crnn_block"]],"mridc.collections.reconstruction.models.convrecnet.crnn_block":[[16,1,1,"","DataConsistencyLayer"],[16,1,1,"","RecurrentConvolutionalNetBlock"]],"mridc.collections.reconstruction.models.convrecnet.crnn_block.DataConsistencyLayer":[[16,2,1,"","forward"],[16,4,1,"","training"]],"mridc.collections.reconstruction.models.convrecnet.crnn_block.RecurrentConvolutionalNetBlock":[[16,2,1,"","forward"],[16,2,1,"","sens_expand"],[16,2,1,"","sens_reduce"],[16,4,1,"","training"]],"mridc.collections.reconstruction.models.crnn":[[13,1,1,"","CRNNet"]],"mridc.collections.reconstruction.models.crnn.CRNNet":[[13,4,1,"","allow_zero_length_dataloader_with_multiple_devices"],[13,2,1,"","forward"],[13,4,1,"","mse_vals"],[13,4,1,"","nmse_vals"],[13,4,1,"","precision"],[13,4,1,"","prepare_data_per_node"],[13,2,1,"","process_intermediate_pred"],[13,2,1,"","process_loss"],[13,4,1,"","psnr_vals"],[13,4,1,"","ssim_vals"],[13,4,1,"","trainer"],[13,4,1,"","training"]],"mridc.collections.reconstruction.models.crossdomain":[[17,0,0,"-","crossdomain"],[17,0,0,"-","multicoil"]],"mridc.collections.reconstruction.models.crossdomain.crossdomain":[[17,1,1,"","CrossDomainNetwork"]],"mridc.collections.reconstruction.models.crossdomain.crossdomain.CrossDomainNetwork":[[17,2,1,"","forward"],[17,2,1,"","image_correction"],[17,2,1,"","kspace_correction"],[17,4,1,"","training"]],"mridc.collections.reconstruction.models.crossdomain.multicoil":[[17,1,1,"","MultiCoil"]],"mridc.collections.reconstruction.models.crossdomain.multicoil.MultiCoil":[[17,2,1,"","forward"],[17,4,1,"","training"]],"mridc.collections.reconstruction.models.didn":[[18,0,0,"-","didn"]],"mridc.collections.reconstruction.models.didn.didn":[[18,1,1,"","DIDN"],[18,1,1,"","DUB"],[18,1,1,"","ReconBlock"],[18,1,1,"","Subpixel"]],"mridc.collections.reconstruction.models.didn.didn.DIDN":[[18,2,1,"","crop_to_shape"],[18,2,1,"","forward"],[18,4,1,"","training"]],"mridc.collections.reconstruction.models.didn.didn.DUB":[[18,2,1,"","crop_to_shape"],[18,2,1,"","forward"],[18,2,1,"","pad"],[18,4,1,"","training"]],"mridc.collections.reconstruction.models.didn.didn.ReconBlock":[[18,2,1,"","forward"],[18,4,1,"","training"]],"mridc.collections.reconstruction.models.didn.didn.Subpixel":[[18,2,1,"","forward"],[18,4,1,"","training"]],"mridc.collections.reconstruction.models.dunet":[[13,1,1,"","DUNet"]],"mridc.collections.reconstruction.models.dunet.DUNet":[[13,4,1,"","allow_zero_length_dataloader_with_multiple_devices"],[13,2,1,"","forward"],[13,4,1,"","mse_vals"],[13,4,1,"","nmse_vals"],[13,4,1,"","precision"],[13,4,1,"","prepare_data_per_node"],[13,4,1,"","psnr_vals"],[13,4,1,"","ssim_vals"],[13,4,1,"","trainer"],[13,4,1,"","training"]],"mridc.collections.reconstruction.models.jointicnet":[[13,1,1,"","JointICNet"]],"mridc.collections.reconstruction.models.jointicnet.JointICNet":[[13,4,1,"","allow_zero_length_dataloader_with_multiple_devices"],[13,2,1,"","forward"],[13,4,1,"","mse_vals"],[13,4,1,"","nmse_vals"],[13,4,1,"","precision"],[13,4,1,"","prepare_data_per_node"],[13,4,1,"","psnr_vals"],[13,4,1,"","ssim_vals"],[13,4,1,"","trainer"],[13,4,1,"","training"],[13,2,1,"","update_C"],
[13,2,1,"","update_X"]],"mridc.collections.reconstruction.models.kikinet":[[13,1,1,"","KIKINet"]],"mridc.collections.reconstruction.models.kikinet.KIKINet":[[13,4,1,"","allow_zero_length_dataloader_with_multiple_devices"],[13,2,1,"","forward"],[13,4,1,"","mse_vals"],[13,4,1,"","nmse_vals"],[13,4,1,"","precision"],[13,4,1,"","prepare_data_per_node"],[13,4,1,"","psnr_vals"],[13,4,1,"","ssim_vals"],[13,4,1,"","trainer"],[13,4,1,"","training"]],"mridc.collections.reconstruction.models.lpd":[[13,1,1,"","LPDNet"]],"mridc.collections.reconstruction.models.lpd.LPDNet":[[13,4,1,"","allow_zero_length_dataloader_with_multiple_devices"],[13,2,1,"","forward"],[13,4,1,"","mse_vals"],[13,4,1,"","nmse_vals"],[13,4,1,"","precision"],[13,4,1,"","prepare_data_per_node"],[13,4,1,"","psnr_vals"],[13,4,1,"","ssim_vals"],[13,4,1,"","trainer"],[13,4,1,"","training"]],"mridc.collections.reconstruction.models.multidomain":[[19,0,0,"-","multidomain"]],"mridc.collections.reconstruction.models.multidomain.multidomain":[[19,1,1,"","MultiDomainConv2d"],[19,1,1,"","MultiDomainConvBlock"],[19,1,1,"","MultiDomainConvTranspose2d"],[19,1,1,"","MultiDomainUnet2d"],[19,1,1,"","StandardizationLayer"],[19,1,1,"","TransposeMultiDomainConvBlock"]],"mridc.collections.reconstruction.models.multidomain.multidomain.MultiDomainConv2d":[[19,2,1,"","forward"],[19,4,1,"","training"]],"mridc.collections.reconstruction.models.multidomain.multidomain.MultiDomainConvBlock":[[19,2,1,"","forward"],[19,4,1,"","training"]],"mridc.collections.reconstruction.models.multidomain.multidomain.MultiDomainConvTranspose2d":[[19,2,1,"","forward"],[19,4,1,"","training"]],"mridc.collections.reconstruction.models.multidomain.multidomain.MultiDomainUnet2d":[[19,2,1,"","forward"],[19,4,1,"","training"]],"mridc.collections.reconstruction.models.multidomain.multidomain.StandardizationLayer":[[19,2,1,"","forward"],[19,4,1,"","training"]],"mridc.collections.reconstruction.models.multidomain.multidomain.TransposeMultiDomainConvBlock":[[19,2,1,"","forward"],[19,4,1,"","training"]],"mridc.collections.reconstruction.models.multidomainnet":[[13,1,1,"","MultiDomainNet"]],"mridc.collections.reconstruction.models.multidomainnet.MultiDomainNet":[[13,4,1,"","allow_zero_length_dataloader_with_multiple_devices"],[13,2,1,"","forward"],[13,4,1,"","mse_vals"],[13,4,1,"","nmse_vals"],[13,4,1,"","precision"],[13,4,1,"","prepare_data_per_node"],[13,4,1,"","psnr_vals"],[13,4,1,"","ssim_vals"],[13,4,1,"","trainer"],[13,4,1,"","training"]],"mridc.collections.reconstruction.models.mwcnn":[[20,0,0,"-","mwcnn"]],"mridc.collections.reconstruction.models.mwcnn.mwcnn":[[20,1,1,"","ConvBlock"],[20,1,1,"","DWT"],[20,1,1,"","DilatedConvBlock"],[20,1,1,"","IWT"],[20,1,1,"","MWCNN"]],"mridc.collections.reconstruction.models.mwcnn.mwcnn.ConvBlock":[[20,2,1,"","forward"],[20,4,1,"","training"]],"mridc.collections.reconstruction.models.mwcnn.mwcnn.DWT":[[20,2,1,"","forward"],[20,4,1,"","training"]],"mridc.collections.reconstruction.models.mwcnn.mwcnn.DilatedConvBlock":[[20,2,1,"","forward"],[20,4,1,"","training"]],"mridc.collections.reconstruction.models.mwcnn.mwcnn.IWT":[[20,2,1,"","forward"],[20,4,1,"","training"]],"mridc.collections.reconstruction.models.mwcnn.mwcnn.MWCNN":[[20,2,1,"","crop_to_shape"],[20,2,1,"","forward"],[20,2,1,"","pad"],[20,4,1,"","training"]],"mridc.collections.reconstruction.models.pics":[[13,1,1,"","PICS"]],"mridc.collections.reconstruction.models.pics.PICS":[[13,4,1,"","allow_zero_length_dataloader_with_multiple_devices"],[13,2,1,"","forward"],[13,4,1,"","mse_vals"],[13,4,
1,"","nmse_vals"],[13,4,1,"","precision"],[13,4,1,"","prepare_data_per_node"],[13,2,1,"","process_inputs"],[13,4,1,"","psnr_vals"],[13,4,1,"","ssim_vals"],[13,2,1,"","test_step"],[13,4,1,"","trainer"],[13,4,1,"","training"]],"mridc.collections.reconstruction.models.primaldual":[[21,0,0,"-","pd"]],"mridc.collections.reconstruction.models.primaldual.pd":[[21,1,1,"","DualNet"],[21,1,1,"","PrimalNet"]],"mridc.collections.reconstruction.models.primaldual.pd.DualNet":[[21,2,1,"","compute_model_per_coil"],[21,2,1,"","forward"],[21,4,1,"","training"]],"mridc.collections.reconstruction.models.primaldual.pd.PrimalNet":[[21,2,1,"","forward"],[21,4,1,"","training"]],"mridc.collections.reconstruction.models.recurrentvarnet":[[22,0,0,"-","conv2gru"],[22,0,0,"-","recurentvarnet"]],"mridc.collections.reconstruction.models.recurrentvarnet.conv2gru":[[22,1,1,"","Conv2dGRU"]],"mridc.collections.reconstruction.models.recurrentvarnet.conv2gru.Conv2dGRU":[[22,2,1,"","forward"],[22,4,1,"","training"]],"mridc.collections.reconstruction.models.recurrentvarnet.recurentvarnet":[[22,1,1,"","RecurrentInit"],[22,1,1,"","RecurrentVarNetBlock"]],"mridc.collections.reconstruction.models.recurrentvarnet.recurentvarnet.RecurrentInit":[[22,2,1,"","forward"],[22,4,1,"","training"]],"mridc.collections.reconstruction.models.recurrentvarnet.recurentvarnet.RecurrentVarNetBlock":[[22,2,1,"","forward"],[22,4,1,"","training"]],"mridc.collections.reconstruction.models.rim":[[23,0,0,"-","conv_layers"],[23,0,0,"-","rim_block"],[23,0,0,"-","rnn_cells"],[23,0,0,"-","utils"]],"mridc.collections.reconstruction.models.rim.conv_layers":[[23,1,1,"","ConvNonlinear"],[23,1,1,"","ConvRNNStack"]],"mridc.collections.reconstruction.models.rim.conv_layers.ConvNonlinear":[[23,2,1,"","check_forward_input"],[23,2,1,"","determine_conv_class"],[23,2,1,"","extra_repr"],[23,2,1,"","forward"],[23,2,1,"","reset_parameters"],[23,4,1,"","training"]],"mridc.collections.reconstruction.models.rim.conv_layers.ConvRNNStack":[[23,2,1,"","forward"],[23,4,1,"","training"]],"mridc.collections.reconstruction.models.rim.rim_block":[[23,1,1,"","RIMBlock"]],"mridc.collections.reconstruction.models.rim.rim_block.RIMBlock":[[23,2,1,"","forward"],[23,4,1,"","training"]],"mridc.collections.reconstruction.models.rim.rnn_cells":[[23,1,1,"","ConvGRUCell"],[23,1,1,"","ConvGRUCellBase"],[23,1,1,"","ConvMGUCell"],[23,1,1,"","ConvMGUCellBase"],[23,1,1,"","IndRNNCell"],[23,1,1,"","IndRNNCellBase"]],"mridc.collections.reconstruction.models.rim.rnn_cells.ConvGRUCell":[[23,2,1,"","forward"],[23,4,1,"","training"]],"mridc.collections.reconstruction.models.rim.rnn_cells.ConvGRUCellBase":[[23,2,1,"","check_forward_hidden"],[23,2,1,"","check_forward_input"],[23,2,1,"","determine_conv_class"],[23,2,1,"","extra_repr"],[23,2,1,"","orthotogonalize_weights"],[23,2,1,"","reset_parameters"],[23,4,1,"","training"]],"mridc.collections.reconstruction.models.rim.rnn_cells.ConvMGUCell":[[23,2,1,"","forward"],[23,4,1,"","training"]],"mridc.collections.reconstruction.models.rim.rnn_cells.ConvMGUCellBase":[[23,2,1,"","check_forward_hidden"],[23,2,1,"","check_forward_input"],[23,2,1,"","determine_conv_class"],[23,2,1,"","extra_repr"],[23,2,1,"","orthotogonalize_weights"],[23,2,1,"","reset_parameters"],[23,4,1,"","training"]],"mridc.collections.reconstruction.models.rim.rnn_cells.IndRNNCell":[[23,2,1,"","forward"],[23,4,1,"","training"]],"mridc.collections.reconstruction.models.rim.rnn_cells.IndRNNCellBase":[[23,2,1,"","check_forward_hidden"],[23,2,1,"","check_forward_input"],[23,2,1,"","determine_conv_class"
],[23,2,1,"","extra_repr"],[23,2,1,"","orthotogonalize_weights"],[23,2,1,"","reset_parameters"],[23,4,1,"","training"]],"mridc.collections.reconstruction.models.rim.utils":[[23,5,1,"","log_likelihood_gradient"]],"mridc.collections.reconstruction.models.rvn":[[13,1,1,"","RecurrentVarNet"]],"mridc.collections.reconstruction.models.rvn.RecurrentVarNet":[[13,4,1,"","allow_zero_length_dataloader_with_multiple_devices"],[13,2,1,"","forward"],[13,4,1,"","mse_vals"],[13,4,1,"","nmse_vals"],[13,4,1,"","precision"],[13,4,1,"","prepare_data_per_node"],[13,4,1,"","psnr_vals"],[13,4,1,"","ssim_vals"],[13,4,1,"","trainer"],[13,4,1,"","training"]],"mridc.collections.reconstruction.models.sigmanet":[[24,0,0,"-","dc_layers"],[24,0,0,"-","sensitivity_net"]],"mridc.collections.reconstruction.models.sigmanet.dc_layers":[[24,1,1,"","ConjugateGradient"],[24,1,1,"","DCLayer"],[24,1,1,"","DataGDLayer"],[24,1,1,"","DataIDLayer"],[24,1,1,"","DataProxCGLayer"],[24,1,1,"","DataVSLayer"]],"mridc.collections.reconstruction.models.sigmanet.dc_layers.ConjugateGradient":[[24,2,1,"","backward"],[24,2,1,"","complexDot"],[24,2,1,"","forward"],[24,2,1,"","solve"]],"mridc.collections.reconstruction.models.sigmanet.dc_layers.DCLayer":[[24,2,1,"","forward"],[24,2,1,"","set_learnable"],[24,4,1,"","training"]],"mridc.collections.reconstruction.models.sigmanet.dc_layers.DataGDLayer":[[24,2,1,"","forward"],[24,4,1,"","training"]],"mridc.collections.reconstruction.models.sigmanet.dc_layers.DataIDLayer":[[24,4,1,"","training"]],"mridc.collections.reconstruction.models.sigmanet.dc_layers.DataProxCGLayer":[[24,2,1,"","forward"],[24,2,1,"","set_learnable"],[24,4,1,"","training"]],"mridc.collections.reconstruction.models.sigmanet.dc_layers.DataVSLayer":[[24,2,1,"","forward"],[24,2,1,"","set_learnable"],[24,4,1,"","training"]],"mridc.collections.reconstruction.models.sigmanet.sensitivity_net":[[24,1,1,"","ComplexInstanceNorm"],[24,1,1,"","ComplexNormWrapper"],[24,1,1,"","SensitivityNetwork"],[24,5,1,"","matrix_invert"]],"mridc.collections.reconstruction.models.sigmanet.sensitivity_net.ComplexInstanceNorm":[[24,2,1,"","complex_instance_norm"],[24,2,1,"","complex_pseudocovariance"],[24,2,1,"","forward"],[24,2,1,"","normalize"],[24,2,1,"","set_normalization"],[24,4,1,"","training"],[24,2,1,"","unnormalize"]],"mridc.collections.reconstruction.models.sigmanet.sensitivity_net.ComplexNormWrapper":[[24,2,1,"","forward"],[24,4,1,"","training"]],"mridc.collections.reconstruction.models.sigmanet.sensitivity_net.SensitivityNetwork":[[24,2,1,"","copy_params"],[24,2,1,"","forward"],[24,2,1,"","forward_save_space"],[24,2,1,"","freeze"],[24,2,1,"","freeze_all"],[24,2,1,"","stage_training_init"],[24,2,1,"","stage_training_transition_i"],[24,4,1,"","training"],[24,2,1,"","unfreeze"],[24,2,1,"","unfreeze_all"]],"mridc.collections.reconstruction.models.unet":[[13,1,1,"","UNet"]],"mridc.collections.reconstruction.models.unet.UNet":[[13,4,1,"","allow_zero_length_dataloader_with_multiple_devices"],[13,2,1,"","forward"],[13,4,1,"","mse_vals"],[13,4,1,"","nmse_vals"],[13,4,1,"","precision"],[13,4,1,"","prepare_data_per_node"],[13,4,1,"","psnr_vals"],[13,4,1,"","ssim_vals"],[13,4,1,"","trainer"],[13,4,1,"","training"]],"mridc.collections.reconstruction.models.unet_base":[[25,0,0,"-","unet_block"]],"mridc.collections.reconstruction.models.unet_base.unet_block":[[25,1,1,"","ConvBlock"],[25,1,1,"","NormUnet"],[25,1,1,"","TransposeConvBlock"],[25,1,1,"","Unet"]],"mridc.collections.reconstruction.models.unet_base.unet_block.ConvBlock":[[25,2,1,"","forward"],[25,4,1,"","t
raining"]],"mridc.collections.reconstruction.models.unet_base.unet_block.NormUnet":[[25,2,1,"","chan_complex_to_last_dim"],[25,2,1,"","complex_to_chan_dim"],[25,2,1,"","forward"],[25,2,1,"","norm"],[25,2,1,"","pad"],[25,4,1,"","training"],[25,2,1,"","unnorm"],[25,2,1,"","unpad"]],"mridc.collections.reconstruction.models.unet_base.unet_block.TransposeConvBlock":[[25,2,1,"","forward"],[25,4,1,"","training"]],"mridc.collections.reconstruction.models.unet_base.unet_block.Unet":[[25,2,1,"","forward"],[25,4,1,"","training"]],"mridc.collections.reconstruction.models.variablesplittingnet":[[26,0,0,"-","vsnet_block"]],"mridc.collections.reconstruction.models.variablesplittingnet.vsnet_block":[[26,1,1,"","DataConsistencyLayer"],[26,1,1,"","VSNetBlock"],[26,1,1,"","WeightedAverageTerm"]],"mridc.collections.reconstruction.models.variablesplittingnet.vsnet_block.DataConsistencyLayer":[[26,2,1,"","forward"],[26,4,1,"","training"]],"mridc.collections.reconstruction.models.variablesplittingnet.vsnet_block.VSNetBlock":[[26,2,1,"","forward"],[26,2,1,"","sens_expand"],[26,2,1,"","sens_reduce"],[26,4,1,"","training"]],"mridc.collections.reconstruction.models.variablesplittingnet.vsnet_block.WeightedAverageTerm":[[26,2,1,"","forward"],[26,4,1,"","training"]],"mridc.collections.reconstruction.models.varnet":[[27,0,0,"-","vn_block"]],"mridc.collections.reconstruction.models.varnet.vn_block":[[27,1,1,"","VarNetBlock"]],"mridc.collections.reconstruction.models.varnet.vn_block.VarNetBlock":[[27,2,1,"","forward"],[27,2,1,"","sens_expand"],[27,2,1,"","sens_reduce"],[27,4,1,"","training"]],"mridc.collections.reconstruction.models.vn":[[13,1,1,"","VarNet"]],"mridc.collections.reconstruction.models.vn.VarNet":[[13,4,1,"","allow_zero_length_dataloader_with_multiple_devices"],[13,2,1,"","forward"],[13,4,1,"","mse_vals"],[13,4,1,"","nmse_vals"],[13,4,1,"","precision"],[13,4,1,"","prepare_data_per_node"],[13,4,1,"","psnr_vals"],[13,4,1,"","ssim_vals"],[13,4,1,"","trainer"],[13,4,1,"","training"]],"mridc.collections.reconstruction.models.vsnet":[[13,1,1,"","VSNet"]],"mridc.collections.reconstruction.models.vsnet.VSNet":[[13,4,1,"","allow_zero_length_dataloader_with_multiple_devices"],[13,2,1,"","forward"],[13,4,1,"","mse_vals"],[13,4,1,"","nmse_vals"],[13,4,1,"","precision"],[13,4,1,"","prepare_data_per_node"],[13,4,1,"","psnr_vals"],[13,4,1,"","ssim_vals"],[13,4,1,"","trainer"],[13,4,1,"","training"]],"mridc.collections.reconstruction.models.xpdnet":[[13,1,1,"","XPDNet"]],"mridc.collections.reconstruction.models.xpdnet.XPDNet":[[13,4,1,"","allow_zero_length_dataloader_with_multiple_devices"],[13,2,1,"","forward"],[13,4,1,"","mse_vals"],[13,4,1,"","nmse_vals"],[13,4,1,"","precision"],[13,4,1,"","prepare_data_per_node"],[13,4,1,"","psnr_vals"],[13,4,1,"","ssim_vals"],[13,4,1,"","trainer"],[13,4,1,"","training"]],"mridc.collections.reconstruction.models.zf":[[13,1,1,"","ZF"]],"mridc.collections.reconstruction.models.zf.ZF":[[13,4,1,"","allow_zero_length_dataloader_with_multiple_devices"],[13,2,1,"","forward"],[13,4,1,"","mse_vals"],[13,4,1,"","nmse_vals"],[13,4,1,"","precision"],[13,4,1,"","prepare_data_per_node"],[13,2,1,"","process_inputs"],[13,4,1,"","psnr_vals"],[13,4,1,"","ssim_vals"],[13,2,1,"","test_step"],[13,4,1,"","trainer"],[13,4,1,"","training"]],"mridc.collections.reconstruction.parts":[[28,0,0,"-","transforms"],[28,0,0,"-","utils"]],"mridc.collections.reconstruction.parts.transforms":[[28,1,1,"","MRIDataTransforms"]],"mridc.collections.reconstruction.parts.transforms.MRIDataTransforms":[[28,2,1,"","__call__"]],"mr
idc.collections.reconstruction.parts.utils":[[28,5,1,"","apply_mask"],[28,5,1,"","batched_mask_center"],[28,5,1,"","center_crop"],[28,5,1,"","center_crop_to_smallest"],[28,5,1,"","complex_center_crop"],[28,5,1,"","mask_center"]],"mridc.core":[[30,0,0,"-","classes"],[31,0,0,"-","conf"],[32,0,0,"-","connectors"],[33,0,0,"-","neural_types"],[34,0,0,"-","optim"],[35,0,0,"-","utils"]],"mridc.core.classes":[[30,0,0,"-","common"],[30,0,0,"-","dataset"],[30,0,0,"-","export"],[30,0,0,"-","loss"],[30,0,0,"-","modelPT"],[30,0,0,"-","module"]],"mridc.core.classes.common":[[30,1,1,"","FileIO"],[30,1,1,"","Model"],[30,1,1,"","PretrainedModelInfo"],[30,1,1,"","Serialization"],[30,1,1,"","Typing"],[30,5,1,"","is_typecheck_enabled"],[30,1,1,"","typecheck"]],"mridc.core.classes.common.FileIO":[[30,2,1,"","from_config_file"],[30,2,1,"","restore_from"],[30,2,1,"","save_to"],[30,2,1,"","to_config_file"]],"mridc.core.classes.common.Model":[[30,2,1,"","from_pretrained"],[30,2,1,"","get_available_model_names"],[30,2,1,"","list_available_models"]],"mridc.core.classes.common.PretrainedModelInfo":[[30,4,1,"","aliases"],[30,4,1,"","class_"],[30,4,1,"","description"],[30,4,1,"","location"],[30,4,1,"","pretrained_model_name"]],"mridc.core.classes.common.Serialization":[[30,2,1,"","from_config_dict"],[30,2,1,"","to_config_dict"]],"mridc.core.classes.common.Typing":[[30,3,1,"","input_types"],[30,3,1,"","output_types"]],"mridc.core.classes.common.typecheck":[[30,1,1,"","TypeState"],[30,2,1,"","disable_checks"],[30,2,1,"","set_typecheck_enabled"]],"mridc.core.classes.common.typecheck.TypeState":[[30,4,1,"","UNINITIALIZED"]],"mridc.core.classes.dataset":[[30,1,1,"","Dataset"],[30,1,1,"","DatasetConfig"],[30,1,1,"","IterableDataset"]],"mridc.core.classes.dataset.Dataset":[[30,2,1,"","collate_fn"]],"mridc.core.classes.dataset.DatasetConfig":[[30,4,1,"","batch_size"],[30,4,1,"","drop_last"],[30,4,1,"","num_workers"],[30,4,1,"","pin_memory"],[30,4,1,"","shuffle"]],"mridc.core.classes.dataset.IterableDataset":[[30,2,1,"","collate_fn"]],"mridc.core.classes.export":[[30,1,1,"","ExportFormat"],[30,1,1,"","Exportable"]],"mridc.core.classes.export.ExportFormat":[[30,4,1,"","ONNX"],[30,4,1,"","TORCHSCRIPT"]],"mridc.core.classes.export.Exportable":[[30,3,1,"","disabled_deployment_input_names"],[30,3,1,"","disabled_deployment_output_names"],[30,2,1,"","export"],[30,3,1,"","input_module"],[30,3,1,"","input_names"],[30,3,1,"","output_module"],[30,3,1,"","output_names"],[30,3,1,"","supported_export_formats"]],"mridc.core.classes.loss":[[30,1,1,"","Loss"]],"mridc.core.classes.loss.Loss":[[30,4,1,"","reduction"]],"mridc.core.classes.modelPT":[[30,1,1,"","ModelPT"]],"mridc.core.classes.modelPT.ModelPT":[[30,2,1,"","__init_subclass__"],[30,3,1,"","cfg"],[30,2,1,"","configure_optimizers"],[30,2,1,"","extract_state_dict_from"],[30,2,1,"","get_test_dataloader_prefix"],[30,2,1,"","get_validation_dataloader_prefix"],[30,2,1,"","load_from_checkpoint"],[30,2,1,"","load_part_of_state_dict"],[30,2,1,"","maybe_init_from_pretrained_checkpoint"],[30,2,1,"","multi_test_epoch_end"],[30,2,1,"","multi_validation_epoch_end"],[30,3,1,"","num_weights"],[30,2,1,"","prepare_test"],[30,2,1,"","register_artifact"],[30,2,1,"","restore_from"],[30,2,1,"","save_to"],[30,2,1,"","set_trainer"],[30,2,1,"","set_world_size"],[30,2,1,"","setup_multiple_test_data"],[30,2,1,"","setup_multiple_validation_data"],[30,2,1,"","setup_optimization"],[30,2,1,"","setup_optimizer_param_groups"],[30,2,1,"","setup_test_data"],[30,2,1,"","setup_training_data"],[30,2,1,"","setup_validation_da
ta"],[30,2,1,"","teardown"],[30,2,1,"","test_dataloader"],[30,2,1,"","test_epoch_end"],[30,2,1,"","train_dataloader"],[30,4,1,"","training"],[30,2,1,"","update_save_restore_connector"],[30,2,1,"","val_dataloader"],[30,2,1,"","validation_epoch_end"]],"mridc.core.classes.module":[[30,1,1,"","NeuralModule"]],"mridc.core.classes.module.NeuralModule":[[30,2,1,"","as_frozen"],[30,2,1,"","freeze"],[30,2,1,"","input_example"],[30,3,1,"","num_weights"],[30,4,1,"","training"],[30,2,1,"","unfreeze"]],"mridc.core.conf":[[31,0,0,"-","base_config"],[31,0,0,"-","dataloader"],[31,0,0,"-","hydra_runner"],[31,0,0,"-","modelPT"],[31,0,0,"-","optimizers"],[31,0,0,"-","schedulers"],[31,0,0,"-","trainer"]],"mridc.core.conf.base_config":[[31,1,1,"","Config"]],"mridc.core.conf.base_config.Config":[[31,4,1,"","name"]],"mridc.core.conf.dataloader":[[31,1,1,"","DataLoaderConfig"]],"mridc.core.conf.dataloader.DataLoaderConfig":[[31,4,1,"","batch_sampler"],[31,4,1,"","batch_size"],[31,4,1,"","collate_fn"],[31,4,1,"","drop_last"],[31,4,1,"","multiprocessing_context"],[31,4,1,"","num_workers"],[31,4,1,"","pin_memory"],[31,4,1,"","sampler"],[31,4,1,"","shuffle"],[31,4,1,"","timeout"],[31,4,1,"","worker_init_fn"]],"mridc.core.conf.hydra_runner":[[31,5,1,"","hydra_runner"]],"mridc.core.conf.modelPT":[[31,1,1,"","HydraConfig"],[31,1,1,"","MRIDCConfig"],[31,1,1,"","ModelConfig"],[31,1,1,"","ModelConfigBuilder"],[31,1,1,"","OptimConfig"],[31,1,1,"","SchedConfig"]],"mridc.core.conf.modelPT.HydraConfig":[[31,4,1,"","job_logging"],[31,4,1,"","run"]],"mridc.core.conf.modelPT.MRIDCConfig":[[31,4,1,"","exp_manager"],[31,4,1,"","hydra"],[31,4,1,"","model"],[31,4,1,"","name"],[31,4,1,"","trainer"]],"mridc.core.conf.modelPT.ModelConfig":[[31,4,1,"","optim"],[31,4,1,"","test_ds"],[31,4,1,"","train_ds"],[31,4,1,"","validation_ds"]],"mridc.core.conf.modelPT.ModelConfigBuilder":[[31,2,1,"","build"],[31,2,1,"","set_optim"],[31,2,1,"","set_test_ds"],[31,2,1,"","set_train_ds"],[31,2,1,"","set_validation_ds"]],"mridc.core.conf.modelPT.OptimConfig":[[31,4,1,"","name"],[31,4,1,"","sched"]],"mridc.core.conf.modelPT.SchedConfig":[[31,4,1,"","last_epoch"],[31,4,1,"","min_lr"],[31,4,1,"","name"]],"mridc.core.conf.optimizers":[[31,1,1,"","AdadeltaParams"],[31,1,1,"","AdagradParams"],[31,1,1,"","AdamParams"],[31,1,1,"","AdamWParams"],[31,1,1,"","AdamaxParams"],[31,1,1,"","NovogradParams"],[31,1,1,"","OptimizerParams"],[31,1,1,"","RMSpropParams"],[31,1,1,"","RpropParams"],[31,1,1,"","SGDParams"],[31,5,1,"","get_optimizer_config"],[31,5,1,"","register_optimizer_params"]],"mridc.core.conf.optimizers.AdadeltaParams":[[31,4,1,"","eps"],[31,4,1,"","rho"],[31,4,1,"","weight_decay"]],"mridc.core.conf.optimizers.AdagradParams":[[31,4,1,"","eps"],[31,4,1,"","initial_accumulator_value"],[31,4,1,"","lr_decay"],[31,4,1,"","weight_decay"]],"mridc.core.conf.optimizers.AdamParams":[[31,4,1,"","amsgrad"],[31,4,1,"","eps"],[31,4,1,"","weight_decay"]],"mridc.core.conf.optimizers.AdamWParams":[[31,4,1,"","amsgrad"],[31,4,1,"","betas"],[31,4,1,"","eps"],[31,4,1,"","weight_decay"]],"mridc.core.conf.optimizers.AdamaxParams":[[31,4,1,"","betas"],[31,4,1,"","eps"],[31,4,1,"","weight_decay"]],"mridc.core.conf.optimizers.NovogradParams":[[31,4,1,"","amsgrad"],[31,4,1,"","betas"],[31,4,1,"","eps"],[31,4,1,"","grad_averaging"],[31,4,1,"","lr"],[31,4,1,"","luc"],[31,4,1,"","luc_eps"],[31,4,1,"","luc_trust"],[31,4,1,"","weight_decay"]],"mridc.core.conf.optimizers.OptimizerParams":[[31,4,1,"","lr"]],"mridc.core.conf.optimizers.RMSpropParams":[[31,4,1,"","alpha"],[31,4,1,"","centered
"],[31,4,1,"","eps"],[31,4,1,"","momentum"],[31,4,1,"","weight_decay"]],"mridc.core.conf.optimizers.RpropParams":[[31,4,1,"","etas"],[31,4,1,"","step_sizes"]],"mridc.core.conf.optimizers.SGDParams":[[31,4,1,"","dampening"],[31,4,1,"","momentum"],[31,4,1,"","nesterov"],[31,4,1,"","weight_decay"]],"mridc.core.conf.schedulers":[[31,1,1,"","CosineAnnealingParams"],[31,1,1,"","CyclicLRParams"],[31,1,1,"","ExponentialLRParams"],[31,1,1,"","InverseSquareRootAnnealingParams"],[31,1,1,"","NoamAnnealingParams"],[31,1,1,"","PolynomialDecayAnnealingParams"],[31,1,1,"","PolynomialHoldDecayAnnealingParams"],[31,1,1,"","ReduceLROnPlateauParams"],[31,1,1,"","SchedulerParams"],[31,1,1,"","SquareAnnealingParams"],[31,1,1,"","SquareRootAnnealingParams"],[31,1,1,"","SquareRootConstantSchedulerParams"],[31,1,1,"","StepLRParams"],[31,1,1,"","WarmupAnnealingHoldSchedulerParams"],[31,1,1,"","WarmupAnnealingParams"],[31,1,1,"","WarmupHoldSchedulerParams"],[31,1,1,"","WarmupSchedulerParams"],[31,5,1,"","get_scheduler_config"],[31,5,1,"","register_scheduler_params"]],"mridc.core.conf.schedulers.CosineAnnealingParams":[[31,4,1,"","min_lr"]],"mridc.core.conf.schedulers.CyclicLRParams":[[31,4,1,"","base_lr"],[31,4,1,"","base_momentum"],[31,4,1,"","cycle_momentum"],[31,4,1,"","gamma"],[31,4,1,"","max_lr"],[31,4,1,"","max_momentum"],[31,4,1,"","mode"],[31,4,1,"","scale_mode"],[31,4,1,"","step_size_down"],[31,4,1,"","step_size_up"]],"mridc.core.conf.schedulers.ExponentialLRParams":[[31,4,1,"","gamma"]],"mridc.core.conf.schedulers.NoamAnnealingParams":[[31,4,1,"","min_lr"]],"mridc.core.conf.schedulers.PolynomialDecayAnnealingParams":[[31,4,1,"","cycle"],[31,4,1,"","power"]],"mridc.core.conf.schedulers.PolynomialHoldDecayAnnealingParams":[[31,4,1,"","cycle"],[31,4,1,"","power"]],"mridc.core.conf.schedulers.ReduceLROnPlateauParams":[[31,4,1,"","cooldown"],[31,4,1,"","eps"],[31,4,1,"","factor"],[31,4,1,"","min_lr"],[31,4,1,"","mode"],[31,4,1,"","patience"],[31,4,1,"","threshold"],[31,4,1,"","threshold_mode"],[31,4,1,"","verbose"]],"mridc.core.conf.schedulers.SchedulerParams":[[31,4,1,"","last_epoch"]],"mridc.core.conf.schedulers.SquareAnnealingParams":[[31,4,1,"","min_lr"]],"mridc.core.conf.schedulers.SquareRootAnnealingParams":[[31,4,1,"","min_lr"]],"mridc.core.conf.schedulers.SquareRootConstantSchedulerParams":[[31,4,1,"","constant_ratio"],[31,4,1,"","constant_steps"]],"mridc.core.conf.schedulers.StepLRParams":[[31,4,1,"","gamma"],[31,4,1,"","step_size"]],"mridc.core.conf.schedulers.WarmupAnnealingHoldSchedulerParams":[[31,4,1,"","constant_ratio"],[31,4,1,"","constant_steps"],[31,4,1,"","min_lr"]],"mridc.core.conf.schedulers.WarmupAnnealingParams":[[31,4,1,"","warmup_ratio"]],"mridc.core.conf.schedulers.WarmupHoldSchedulerParams":[[31,4,1,"","hold_ratio"],[31,4,1,"","hold_steps"],[31,4,1,"","min_lr"]],"mridc.core.conf.schedulers.WarmupSchedulerParams":[[31,4,1,"","max_steps"],[31,4,1,"","warmup_ratio"],[31,4,1,"","warmup_steps"]],"mridc.core.conf.trainer":[[31,1,1,"","TrainerConfig"]],"mridc.core.conf.trainer.TrainerConfig":[[31,4,1,"","accelerator"],[31,4,1,"","accumulate_grad_batches"],[31,4,1,"","amp_backend"],[31,4,1,"","amp_level"],[31,4,1,"","auto_lr_find"],[31,4,1,"","auto_scale_batch_size"],[31,4,1,"","auto_select_gpus"],[31,4,1,"","benchmark"],[31,4,1,"","callbacks"],[31,4,1,"","check_val_every_n_epoch"],[31,4,1,"","checkpoint_callback"],[31,4,1,"","default_root_dir"],[31,4,1,"","detect_anomaly"],[31,4,1,"","deterministic"],[31,4,1,"","devices"],[31,4,1,"","enable_checkpointing"],[31,4,1,"","enable_model_summary"],[
31,4,1,"","enable_progress_bar"],[31,4,1,"","fast_dev_run"],[31,4,1,"","flush_logs_every_n_steps"],[31,4,1,"","gpus"],[31,4,1,"","gradient_clip_algorithm"],[31,4,1,"","gradient_clip_val"],[31,4,1,"","ipus"],[31,4,1,"","limit_predict_batches"],[31,4,1,"","limit_test_batches"],[31,4,1,"","limit_train_batches"],[31,4,1,"","limit_val_batches"],[31,4,1,"","log_every_n_steps"],[31,4,1,"","log_gpu_memory"],[31,4,1,"","logger"],[31,4,1,"","max_epochs"],[31,4,1,"","max_steps"],[31,4,1,"","max_time"],[31,4,1,"","min_epochs"],[31,4,1,"","min_steps"],[31,4,1,"","move_metrics_to_cpu"],[31,4,1,"","multiple_trainloader_mode"],[31,4,1,"","num_nodes"],[31,4,1,"","num_sanity_val_steps"],[31,4,1,"","overfit_batches"],[31,4,1,"","plugins"],[31,4,1,"","precision"],[31,4,1,"","prepare_data_per_node"],[31,4,1,"","process_position"],[31,4,1,"","profiler"],[31,4,1,"","progress_bar_refresh_rate"],[31,4,1,"","reload_dataloaders_every_n_epochs"],[31,4,1,"","replace_sampler_ddp"],[31,4,1,"","resume_from_checkpoint"],[31,4,1,"","stochastic_weight_avg"],[31,4,1,"","strategy"],[31,4,1,"","sync_batchnorm"],[31,4,1,"","terminate_on_nan"],[31,4,1,"","tpu_cores"],[31,4,1,"","track_grad_norm"],[31,4,1,"","val_check_interval"],[31,4,1,"","weights_save_path"],[31,4,1,"","weights_summary"]],"mridc.core.connectors":[[32,0,0,"-","save_restore_connector"]],"mridc.core.connectors.save_restore_connector":[[32,1,1,"","SaveRestoreConnector"]],"mridc.core.connectors.save_restore_connector.SaveRestoreConnector":[[32,2,1,"","extract_state_dict_from"],[32,2,1,"","load_config_and_state_dict"],[32,2,1,"","load_instance_with_state_dict"],[32,3,1,"","model_config_yaml"],[32,3,1,"","model_weights_ckpt"],[32,2,1,"","register_artifact"],[32,2,1,"","restore_from"],[32,2,1,"","save_to"]],"mridc.core.neural_types":[[33,0,0,"-","axes"],[33,0,0,"-","comparison"],[33,0,0,"-","elements"],[33,0,0,"-","neural_type"]],"mridc.core.neural_types.axes":[[33,1,1,"","AxisKind"],[33,1,1,"","AxisKindAbstract"],[33,1,1,"","AxisType"]],"mridc.core.neural_types.axes.AxisKind":[[33,4,1,"","Any"],[33,4,1,"","Batch"],[33,4,1,"","Channel"],[33,4,1,"","Dimension"],[33,4,1,"","FlowGroup"],[33,4,1,"","Height"],[33,4,1,"","Sequence"],[33,4,1,"","Singleton"],[33,4,1,"","Time"],[33,4,1,"","Width"],[33,2,1,"","__repr__"],[33,2,1,"","__str__"],[33,2,1,"","from_str"],[33,2,1,"","t_with_string"]],"mridc.core.neural_types.axes.AxisType":[[33,2,1,"","__repr__"]],"mridc.core.neural_types.comparison":[[33,1,1,"","NeuralTypeComparisonResult"]],"mridc.core.neural_types.comparison.NeuralTypeComparisonResult":[[33,4,1,"","CONTAINER_SIZE_MISMATCH"],[33,4,1,"","DIM_INCOMPATIBLE"],[33,4,1,"","GREATER"],[33,4,1,"","INCOMPATIBLE"],[33,4,1,"","LESS"],[33,4,1,"","SAME"],[33,4,1,"","SAME_TYPE_INCOMPATIBLE_PARAMS"],[33,4,1,"","TRANSPOSE_SAME"],[33,4,1,"","UNCHECKED"]],"mridc.core.neural_types.elements":[[33,1,1,"","CategoricalValuesType"],[33,1,1,"","ChannelType"],[33,1,1,"","ElementType"],[33,1,1,"","FloatType"],[33,1,1,"","ImageFeatureValue"],[33,1,1,"","ImageValue"],[33,1,1,"","Index"],[33,1,1,"","IntType"],[33,1,1,"","LabelsType"],[33,1,1,"","Length"],[33,1,1,"","LengthsType"],[33,1,1,"","LogDeterminantType"],[33,1,1,"","LogprobsType"],[33,1,1,"","LossType"],[33,1,1,"","MRISignal"],[33,1,1,"","MaskType"],[33,1,1,"","NormalDistributionLogVarianceType"],[33,1,1,"","NormalDistributionMeanType"],[33,1,1,"","NormalDistributionSamplesType"],[33,1,1,"","NormalizedImageValue"],[33,1,1,"","PredictionsType"],[33,1,1,"","ProbsType"],[33,1,1,"","ReconstructionTarget"],[33,1,1,"","RecurrentsType"],[33,1,1,"
","RegressionValuesType"],[33,1,1,"","SequenceToSequenceAlignmentType"],[33,1,1,"","StringLabel"],[33,1,1,"","StringType"],[33,1,1,"","Target"],[33,1,1,"","VoidType"]],"mridc.core.neural_types.elements.ElementType":[[33,2,1,"","__repr__"],[33,2,1,"","__str__"],[33,2,1,"","compare"],[33,3,1,"","fields"],[33,3,1,"","type_parameters"]],"mridc.core.neural_types.elements.MRISignal":[[33,3,1,"","type_parameters"]],"mridc.core.neural_types.elements.VoidType":[[33,2,1,"","compare"]],"mridc.core.neural_types.neural_type":[[33,6,1,"","NeuralPortNameMismatchError"],[33,6,1,"","NeuralPortNmTensorMismatchError"],[33,1,1,"","NeuralType"],[33,6,1,"","NeuralTypeError"]],"mridc.core.neural_types.neural_type.NeuralType":[[33,2,1,"","__eq__"],[33,2,1,"","__repr__"],[33,2,1,"","compare"],[33,2,1,"","compare_and_raise_error"]],"mridc.core.optim":[[34,0,0,"-","adafactor"],[34,0,0,"-","lr_scheduler"],[34,0,0,"-","novograd"],[34,0,0,"-","optimizer_with_master_params"],[34,0,0,"-","optimizers"]],"mridc.core.optim.adafactor":[[34,1,1,"","Adafactor"]],"mridc.core.optim.adafactor.Adafactor":[[34,2,1,"","step"],[34,3,1,"","supports_flat_params"],[34,3,1,"","supports_memory_efficient_fp16"]],"mridc.core.optim.lr_scheduler":[[34,1,1,"","CosineAnnealing"],[34,1,1,"","InverseSquareRootAnnealing"],[34,1,1,"","NoamAnnealing"],[34,1,1,"","PolynomialDecayAnnealing"],[34,1,1,"","PolynomialHoldDecayAnnealing"],[34,1,1,"","SquareAnnealing"],[34,1,1,"","SquareRootAnnealing"],[34,1,1,"","SquareRootConstantPolicy"],[34,1,1,"","T5InverseSquareRootAnnealing"],[34,1,1,"","WarmupAnnealHoldPolicy"],[34,1,1,"","WarmupAnnealing"],[34,1,1,"","WarmupHoldPolicy"],[34,1,1,"","WarmupPolicy"],[34,5,1,"","compute_max_steps"],[34,5,1,"","get_scheduler"],[34,5,1,"","prepare_lr_scheduler"],[34,5,1,"","register_scheduler"]],"mridc.core.optim.lr_scheduler.NoamAnnealing":[[34,2,1,"","get_lr"]],"mridc.core.optim.lr_scheduler.SquareRootConstantPolicy":[[34,2,1,"","get_lr"]],"mridc.core.optim.lr_scheduler.WarmupAnnealHoldPolicy":[[34,2,1,"","get_lr"]],"mridc.core.optim.lr_scheduler.WarmupHoldPolicy":[[34,2,1,"","get_lr"]],"mridc.core.optim.lr_scheduler.WarmupPolicy":[[34,2,1,"","get_lr"]],"mridc.core.optim.novograd":[[34,1,1,"","Novograd"]],"mridc.core.optim.novograd.Novograd":[[34,2,1,"","step"]],"mridc.core.optim.optimizer_with_master_params":[[34,1,1,"","GradBucket"],[34,1,1,"","MainParamsOptimizerWrapper"]],"mridc.core.optim.optimizer_with_master_params.GradBucket":[[34,2,1,"","allreduce_buffer"],[34,2,1,"","get"],[34,2,1,"","zero"]],"mridc.core.optim.optimizer_with_master_params.MainParamsOptimizerWrapper":[[34,2,1,"","allreduce_main_grads"],[34,3,1,"","async_master_grads_allreudce"],[34,2,1,"","copy_model_grads_to_main_grads"],[34,3,1,"","fp32_grad_accumulation"],[34,2,1,"","get_parameters"],[34,2,1,"","grad_sync"],[34,2,1,"","load_state_dict"],[34,3,1,"","param_groups"],[34,2,1,"","reload_model_params"],[34,3,1,"","state"],[34,2,1,"","state_dict"],[34,2,1,"","step"],[34,2,1,"","zero_grad"]],"mridc.core.optim.optimizers":[[34,5,1,"","get_optimizer"],[34,5,1,"","parse_optimizer_args"],[34,5,1,"","register_optimizer"]],"mridc.core.utils":[[35,0,0,"-","neural_type_utils"],[35,0,0,"-","numba_utils"]],"mridc.core.utils.neural_type_utils":[[35,5,1,"","extract_dynamic_axes"],[35,5,1,"","get_dynamic_axes"],[35,5,1,"","get_io_names"]],"mridc.core.utils.numba_utils":[[35,5,1,"","is_numba_compat_strict"],[35,5,1,"","numba_cpu_is_supported"],[35,5,1,"","numba_cuda_is_supported"],[35,5,1,"","set_numba_compat_strictness"],[35,5,1,"","skip_numba_cuda_test_if_unsup
ported"],[35,5,1,"","with_numba_compat_strictness"]],"mridc.launch":[[2,5,1,"","main"]],"mridc.utils":[[36,0,0,"-","app_state"],[36,0,0,"-","arguments"],[36,0,0,"-","cloud"],[36,0,0,"-","config_utils"],[37,0,0,"-","decorators"],[36,0,0,"-","distributed"],[36,0,0,"-","env_var_parsing"],[36,0,0,"-","exceptions"],[36,0,0,"-","exp_manager"],[36,0,0,"-","export_utils"],[38,0,0,"-","formaters"],[36,0,0,"-","get_rank"],[36,0,0,"-","lightning_logger_patch"],[36,0,0,"-","metaclasses"],[36,0,0,"-","model_utils"],[36,0,0,"-","mridc_logging"],[36,0,0,"-","timers"]],"mridc.utils.app_state":[[36,1,1,"","AppState"],[36,1,1,"","ModelMetadataRegistry"]],"mridc.utils.app_state.AppState":[[36,3,1,"","checkpoint_callback_params"],[36,3,1,"","checkpoint_name"],[36,3,1,"","create_checkpoint_callback"],[36,3,1,"","data_parallel_group"],[36,3,1,"","data_parallel_rank"],[36,3,1,"","data_parallel_size"],[36,3,1,"","device_id"],[36,3,1,"","exp_dir"],[36,2,1,"","get_model_metadata_from_guid"],[36,3,1,"","global_rank"],[36,3,1,"","is_model_being_restored"],[36,3,1,"","local_rank"],[36,3,1,"","log_dir"],[36,3,1,"","model_parallel_size"],[36,3,1,"","model_restore_path"],[36,3,1,"","mridc_file_folder"],[36,3,1,"","name"],[36,3,1,"","pipeline_model_parallel_group"],[36,3,1,"","pipeline_model_parallel_rank"],[36,3,1,"","pipeline_model_parallel_size"],[36,3,1,"","pipeline_model_parallel_split_rank"],[36,3,1,"","random_seed"],[36,2,1,"","register_model_guid"],[36,2,1,"","reset_model_guid_registry"],[36,3,1,"","tensor_model_parallel_group"],[36,3,1,"","tensor_model_parallel_rank"],[36,3,1,"","tensor_model_parallel_size"],[36,3,1,"","version"],[36,3,1,"","world_size"]],"mridc.utils.app_state.ModelMetadataRegistry":[[36,4,1,"","gidx"],[36,4,1,"","guid"],[36,4,1,"","restoration_path"]],"mridc.utils.arguments":[[36,5,1,"","add_optimizer_args"],[36,5,1,"","add_recon_args"],[36,5,1,"","add_scheduler_args"]],"mridc.utils.cloud":[[36,5,1,"","maybe_download_from_cloud"]],"mridc.utils.config_utils":[[36,5,1,"","assert_dataclass_signature_match"],[36,5,1,"","update_model_config"]],"mridc.utils.decorators":[[37,0,0,"-","deprecated"],[37,0,0,"-","experimental"],[37,0,0,"-","port_docs"]],"mridc.utils.decorators.deprecated":[[37,5,1,"","deprecated"]],"mridc.utils.decorators.experimental":[[37,5,1,"","experimental"]],"mridc.utils.decorators.port_docs":[[37,5,1,"","add_port_docs"]],"mridc.utils.distributed":[[36,5,1,"","initialize_distributed"]],"mridc.utils.env_var_parsing":[[36,6,1,"","CoercionError"],[36,6,1,"","RequiredSettingMissingError"],[36,5,1,"","get_env"],[36,5,1,"","get_envbool"],[36,5,1,"","get_envdate"],[36,5,1,"","get_envdatetime"],[36,5,1,"","get_envdecimal"],[36,5,1,"","get_envdict"],[36,5,1,"","get_envfloat"],[36,5,1,"","get_envint"],[36,5,1,"","get_envlist"]],"mridc.utils.exceptions":[[36,1,1,"","CheckInstall"],[36,6,1,"","LightningNotInstalledException"],[36,6,1,"","MRIDCBaseException"]],"mridc.utils.exp_manager":[[36,1,1,"","CallbackParams"],[36,6,1,"","CheckpointMisconfigurationError"],[36,1,1,"","ExpManagerConfig"],[36,1,1,"","LoggerList"],[36,6,1,"","LoggerMisconfigurationError"],[36,1,1,"","MRIDCModelCheckpoint"],[36,6,1,"","NotFoundError"],[36,1,1,"","StatelessTimer"],[36,1,1,"","StepTimingParams"],[36,1,1,"","TimingCallback"],[36,5,1,"","check_explicit_log_dir"],[36,5,1,"","check_resume"],[36,5,1,"","check_slurm"],[36,5,1,"","configure_checkpointing"],[36,5,1,"","configure_loggers"],[36,5,1,"","error_checks"],[36,5,1,"","exp_manager"],[36,5,1,"","get_git_diff"],[36,5,1,"","get_git_hash"],[36,5,1,"","get_log_dir"]],"m
ridc.utils.exp_manager.CallbackParams":[[36,4,1,"","always_save_mridc"],[36,4,1,"","dirpath"],[36,4,1,"","every_n_epochs"],[36,4,1,"","filename"],[36,4,1,"","filepath"],[36,4,1,"","mode"],[36,4,1,"","model_parallel_size"],[36,4,1,"","monitor"],[36,4,1,"","postfix"],[36,4,1,"","prefix"],[36,4,1,"","save_best_model"],[36,4,1,"","save_last"],[36,4,1,"","save_mridc_on_train_end"],[36,4,1,"","save_top_k"],[36,4,1,"","save_weights_only"],[36,4,1,"","verbose"]],"mridc.utils.exp_manager.ExpManagerConfig":[[36,4,1,"","checkpoint_callback_params"],[36,4,1,"","create_checkpoint_callback"],[36,4,1,"","create_tensorboard_logger"],[36,4,1,"","create_wandb_logger"],[36,4,1,"","exp_dir"],[36,4,1,"","explicit_log_dir"],[36,4,1,"","files_to_copy"],[36,4,1,"","log_global_rank_0_only"],[36,4,1,"","log_local_rank_0_only"],[36,4,1,"","log_step_timing"],[36,4,1,"","model_parallel_size"],[36,4,1,"","name"],[36,4,1,"","resume_if_exists"],[36,4,1,"","resume_ignore_no_checkpoint"],[36,4,1,"","resume_past_end"],[36,4,1,"","step_timing_kwargs"],[36,4,1,"","summary_writer_kwargs"],[36,4,1,"","use_datetime_version"],[36,4,1,"","version"],[36,4,1,"","wandb_logger_kwargs"]],"mridc.utils.exp_manager.LoggerList":[[36,3,1,"","name"],[36,3,1,"","version"]],"mridc.utils.exp_manager.MRIDCModelCheckpoint":[[36,2,1,"","mridc_topk_check_previous_run"],[36,2,1,"","on_save_checkpoint"],[36,2,1,"","on_train_end"]],"mridc.utils.exp_manager.StatelessTimer":[[36,2,1,"","load_state_dict"],[36,2,1,"","state_dict"]],"mridc.utils.exp_manager.StepTimingParams":[[36,4,1,"","buffer_size"],[36,4,1,"","reduction"],[36,4,1,"","sync_cuda"]],"mridc.utils.exp_manager.TimingCallback":[[36,2,1,"","on_after_backward"],[36,2,1,"","on_before_backward"],[36,2,1,"","on_test_batch_end"],[36,2,1,"","on_test_batch_start"],[36,2,1,"","on_train_batch_end"],[36,2,1,"","on_train_batch_start"],[36,2,1,"","on_validation_batch_end"],[36,2,1,"","on_validation_batch_start"]],"mridc.utils.export_utils":[[36,1,1,"","CastToFloat"],[36,1,1,"","ExportFormat"],[36,5,1,"","augment_filename"],[36,5,1,"","forward_method"],[36,5,1,"","get_export_format"],[36,5,1,"","parse_input_example"],[36,5,1,"","replace_for_export"],[36,5,1,"","replace_modules"],[36,5,1,"","simple_replace"],[36,5,1,"","swap_modules"],[36,5,1,"","to_onnxrt_input"],[36,5,1,"","verify_runtime"],[36,5,1,"","wrap_forward_method"],[36,5,1,"","wrap_module"]],"mridc.utils.export_utils.CastToFloat":[[36,2,1,"","forward"],[36,4,1,"","training"]],"mridc.utils.export_utils.ExportFormat":[[36,4,1,"","ONNX"],[36,4,1,"","TORCHSCRIPT"]],"mridc.utils.formaters":[[38,0,0,"-","base"],[38,0,0,"-","colors"],[38,0,0,"-","utils"]],"mridc.utils.formaters.base":[[38,1,1,"","BaseMRIDCFormatter"],[38,1,1,"","DebugMRIDCFormatter"]],"mridc.utils.formaters.base.BaseMRIDCFormatter":[[38,4,1,"","DEFAULT_FORMAT"]],"mridc.utils.formaters.base.DebugMRIDCFormatter":[[38,4,1,"","DEFAULT_FORMAT"]],"mridc.utils.formaters.colors":[[38,1,1,"","AnsiBack"],[38,1,1,"","AnsiCodes"],[38,1,1,"","AnsiCursor"],[38,1,1,"","AnsiFore"],[38,1,1,"","AnsiStyle"],[38,5,1,"","clear_line"],[38,5,1,"","clear_screen"],[38,5,1,"","code_to_chars"],[38,5,1,"","set_title"]],"mridc.utils.formaters.colors.AnsiBack":[[38,4,1,"","BLACK"],[38,4,1,"","BLUE"],[38,4,1,"","CYAN"],[38,4,1,"","GREEN"],[38,4,1,"","LIGHTBLACK_EX"],[38,4,1,"","LIGHTBLUE_EX"],[38,4,1,"","LIGHTCYAN_EX"],[38,4,1,"","LIGHTGREEN_EX"],[38,4,1,"","LIGHTMAGENTA_EX"],[38,4,1,"","LIGHTRED_EX"],[38,4,1,"","LIGHTWHITE_EX"],[38,4,1,"","LIGHTYELLOW_EX"],[38,4,1,"","MAGENTA"],[38,4,1,"","RED"],[38,4,1,"","RESET"],[3
8,4,1,"","WHITE"],[38,4,1,"","YELLOW"]],"mridc.utils.formaters.colors.AnsiCursor":[[38,2,1,"","BACK"],[38,2,1,"","DOWN"],[38,2,1,"","FORWARD"],[38,2,1,"","POS"],[38,2,1,"","UP"]],"mridc.utils.formaters.colors.AnsiFore":[[38,4,1,"","BLACK"],[38,4,1,"","BLUE"],[38,4,1,"","CYAN"],[38,4,1,"","GREEN"],[38,4,1,"","LIGHTBLACK_EX"],[38,4,1,"","LIGHTBLUE_EX"],[38,4,1,"","LIGHTCYAN_EX"],[38,4,1,"","LIGHTGREEN_EX"],[38,4,1,"","LIGHTMAGENTA_EX"],[38,4,1,"","LIGHTRED_EX"],[38,4,1,"","LIGHTWHITE_EX"],[38,4,1,"","LIGHTYELLOW_EX"],[38,4,1,"","MAGENTA"],[38,4,1,"","RED"],[38,4,1,"","RESET"],[38,4,1,"","WHITE"],[38,4,1,"","YELLOW"]],"mridc.utils.formaters.colors.AnsiStyle":[[38,4,1,"","BRIGHT"],[38,4,1,"","DIM"],[38,4,1,"","NORMAL"],[38,4,1,"","RESET_ALL"]],"mridc.utils.formaters.utils":[[38,5,1,"","check_color_support"],[38,5,1,"","to_unicode"]],"mridc.utils.get_rank":[[36,5,1,"","is_global_rank_zero"]],"mridc.utils.lightning_logger_patch":[[36,5,1,"","add_filehandlers_to_pl_logger"],[36,5,1,"","add_memory_handlers_to_pl_logger"]],"mridc.utils.metaclasses":[[36,1,1,"","Singleton"]],"mridc.utils.metaclasses.Singleton":[[36,2,1,"","__call__"]],"mridc.utils.model_utils":[[36,1,1,"","ArtifactItem"],[36,1,1,"","ArtifactPathType"],[36,5,1,"","check_lib_version"],[36,5,1,"","convert_model_config_to_dict_config"],[36,5,1,"","import_class_by_path"],[36,5,1,"","inject_model_parallel_rank"],[36,5,1,"","maybe_update_config_version"],[36,5,1,"","parse_dataset_as_name"],[36,5,1,"","resolve_cache_dir"],[36,5,1,"","resolve_dataset_name_from_cfg"],[36,5,1,"","resolve_subclass_pretrained_model_info"],[36,5,1,"","resolve_validation_dataloaders"],[36,5,1,"","uninject_model_parallel_rank"],[36,5,1,"","unique_names_check"],[36,5,1,"","wrap_training_step"]],"mridc.utils.model_utils.ArtifactItem":[[36,4,1,"","hashed_path"],[36,4,1,"","path"],[36,4,1,"","path_type"]],"mridc.utils.model_utils.ArtifactPathType":[[36,4,1,"","LOCAL_PATH"],[36,4,1,"","TAR_PATH"]],"mridc.utils.mridc_logging":[[36,1,1,"","LogMode"],[36,1,1,"","Logger"]],"mridc.utils.mridc_logging.LogMode":[[36,4,1,"","EACH"],[36,4,1,"","ONCE"]],"mridc.utils.mridc_logging.Logger":[[36,4,1,"","CRITICAL"],[36,4,1,"","DEBUG"],[36,4,1,"","ERROR"],[36,4,1,"","INFO"],[36,4,1,"","NOTSET"],[36,4,1,"","WARNING"],[36,2,1,"","add_err_file_handler"],[36,2,1,"","add_file_handler"],[36,2,1,"","add_stream_handlers"],[36,2,1,"","captureWarnings"],[36,2,1,"","critical"],[36,2,1,"","debug"],[36,2,1,"","error"],[36,2,1,"","getEffectiveLevel"],[36,2,1,"","get_verbosity"],[36,2,1,"","info"],[36,2,1,"","patch_stderr_handler"],[36,2,1,"","patch_stdout_handler"],[36,2,1,"","remove_stream_handlers"],[36,2,1,"","reset_stream_handler"],[36,2,1,"","setLevel"],[36,2,1,"","set_verbosity"],[36,2,1,"","temp_verbosity"],[36,2,1,"","warning"]],"mridc.utils.timers":[[36,1,1,"","NamedTimer"]],"mridc.utils.timers.NamedTimer":[[36,2,1,"","active_timers"],[36,3,1,"","buffer_size"],[36,2,1,"","export"],[36,2,1,"","get"],[36,2,1,"","reset"],[36,2,1,"","start"],[36,2,1,"","stop"]],mridc:[[3,0,0,"-","collections"],[2,0,0,"-","constants"],[29,0,0,"-","core"],[2,0,0,"-","launch"],[2,0,0,"-","package_info"],[36,0,0,"-","utils"]]},objnames:{"0":["py","module","Python module"],"1":["py","class","Python class"],"2":["py","method","Python method"],"3":["py","property","Python property"],"4":["py","attribute","Python attribute"],"5":["py","function","Python function"],"6":["py","exception","Python 
exception"]},objtypes:{"0":"py:module","1":"py:class","2":"py:method","3":"py:property","4":"py:attribute","5":"py:function","6":"py:exception"},terms:{"0":[6,7,9,11,13,18,20,22,23,24,25,28,30,31,33,34,35,36,38],"0001":31,"001":[31,34,36],"00523":13,"00555":13,"00572":23,"01":[7,30,34,36],"02":[9,11,28],"03":7,"030":[13,26],"04":11,"04235":34,"05":[24,31,34],"06":[24,31],"07071":20,"07290":13,"08":[11,31,34],"09639":[13,22],"09792":24,"1":[0,6,9,13,14,15,16,17,18,19,20,22,23,25,26,28,30,31,32,33,34,35,36,38],"10":[0,13,15,16,23,24,26,31,36],"100":[31,38],"1000":[9,31],"1002":13,"1007":[13,26],"101":38,"102":38,"103":38,"104":38,"105":38,"106":38,"107":38,"11":0,"1109":[13,15,16,23],"11286":[31,34],"11767":[13,26],"12":0,"128":18,"13":0,"1322":13,"14":0,"15":[0,13,25],"15498v1":13,"16":0,"17":0,"1703":13,"1705":24,"1804":34,"1805":20,"1859":13,"1872":13,"1905":[31,34],"1999":[9,13],"1d":[11,28],"1e":[24,30,31,34],"1s":38,"2":[0,7,9,13,14,17,22,23,24,25,28,30,31,32,33,36,38],"20":36,"2000":31,"201":13,"2010":13,"2015":[13,25],"2017":13,"2018":[13,15,16,20,23],"2019":[13,15,16,26],"2020":[13,19],"2021":[0,13,22],"21":11,"2111":[13,22],"2188":13,"22":38,"23":13,"234":[13,25],"241":[13,25],"27201":13,"2799231":13,"280":[13,15,16],"2863670":[13,15,16],"28827":13,"290":[13,15,16],"2d":[11,13,15,20,22,28],"2x2":24,"3":[0,9,11,13,15,26,28,30,31,33,34,36],"30":[34,36,38],"31":38,"32":[13,25,30,31,38],"32251":[13,26],"33":38,"34":38,"35":38,"36":38,"37":[13,38],"38":[13,15,16],"39":38,"4":[0,11,13,20,22,25,33],"40":[36,38],"41":38,"42":[9,13,38],"4294967296":9,"43":38,"44":38,"45":38,"46":38,"47":38,"49":38,"4d":25,"5":[0,13,31,33,34,36],"50":[11,31,36],"5266":13,"54":11,"5457":23,"5466":23,"6":[0,13,18,33],"64":22,"7":[0,7,33],"713":[13,26],"722":[13,26],"75":13,"8":[0,11,13,23,26,31,33,34,38],"80":13,"86":13,"9":[0,11,18,30,31,34],"90":38,"91":38,"92":38,"93":38,"94":38,"95":[31,34,36,38],"952":[9,13],"96":38,"962":[9,13],"97":38,"978":[13,26],"98":[31,34],"99":31,"999":[31,34],"9_78":[13,26],"\u00f6ktem":13,"abstract":[30,31,33],"boolean":[11,33,34,36],"byte":38,"case":[11,30,32,34,35,36],"class":[2,5,6,7,8,9,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,31,32,33,34,36,37,38],"default":[6,8,9,11,13,18,20,22,30,31,32,33,34,36],"do":[13,30,33,36],"enum":[30,33,36],"export":[2,29,36],"final":[9,13,28,30,34,36],"float":[6,7,8,9,11,12,13,19,20,23,25,28,31,33,34,36],"function":[2,9,11,13,21,23,24,26,28,30,31,32,33,36,37],"import":[9,36],"int":[6,7,9,11,13,16,17,18,19,20,22,23,25,26,28,30,31,33,36,38],"long":5,"new":[12,22,30,34,36],"null":34,"public":[0,11],"return":[6,7,8,9,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,30,31,32,33,34,35,36,37,38],"short":33,"static":[6,13,18,20,21,23,24,25,30,32,33,38],"super":18,"throw":36,"true":[6,8,9,13,15,17,20,22,23,24,25,28,30,31,32,33,34,35,36,38],"try":30,"var":36,"void":33,"while":[24,26,34,36],A:[0,6,8,9,11,12,13,14,15,16,19,22,23,25,27,28,30,31,33,34,35,36],As:[11,30],By:[30,32,33],For:[0,9,11,13,30,31,33,35],If:[8,11,13,15,17,20,24,28,30,32,33,34,35,36,38],In:[13,25,30,33,36],It:[8,11,28,30,31,32,33,34,36,37],No:36,On:34,The:[0,8,9,11,12,13,22,28,30,31,32,33,34,35,36,37,38],Then:[33,36],There:30,These:[30,32,36],To:[30,32,34,36],Will:[30,31,34],_1:19,_:[13,19,22],__call__:[11,28,33,36],__eq__:33,__init__:[8,11,36],__init_subclass__:30,__iter__:6,__len__:6,__repr__:[12,33],__str__:33,__version__:36,_checkpoint_connector:36,_get_env:36,_handler:36,_i:19,_input:[19,23],_logger_iter:36,_loss:30,_loss_fn:13,_lrschedul:34,_target_:36,_test_dl_i
dx:30,_val_test_fastpath_kei:36,ab:[13,20,22,31,34],abc:[6,13,30,33],abcmeta:33,abil:30,about:[23,30,36],abov:[30,36],absolut:[9,30,32,36],acc:[12,13],acceler:[0,11,12,13,22,26,31],accept:[6,33,35],access:0,accommod:11,accord:[11,17],accumul:[8,13,34],accumulate_grad_batch:[31,34],accumulate_loss:13,accur:11,across:[8,34],act:[11,36],activ:[0,15,18,19,20,25,36],active_tim:36,actual:[30,32,36],ad:[31,34,36],adadelta:31,adadeltaparam:31,adafactor:[2,29],adagrad:31,adagradparam:31,adam:[30,31,34,36],adamax:31,adamaxparam:31,adamparam:31,adamw:31,adamwparam:31,adapt:[31,34,36],add:[23,30,31,32,34,36,37],add_err_file_handl:36,add_file_handl:36,add_filehandlers_to_pl_logg:36,add_memory_handlers_to_pl_logg:36,add_optimizer_arg:36,add_port_doc:37,add_recon_arg:36,add_scheduler_arg:36,add_stream_handl:36,addit:34,addition:[34,36],adjoint:24,adjust:34,adler:13,advanc:13,after:[15,30,32,34,36],afterward:[11,24,26],against:36,aggarw:24,aggreg:[3,4,13],aggregatorloss:7,ai:0,aiayn:9,air:[13,19],al:[0,13,18,20,22,23,24,26],algorithm:34,alia:30,alias:30,align:[33,36],all:[9,13,24,26,30,31,34,36],all_log_fil:36,alloc:34,allow:[31,34,36],allow_zero_length_dataloader_with_multiple_devic:13,allreduc:34,allreduce_buff:34,allreduce_main_grad:34,almost:33,along:[9,11,28,30,34,36],alpha:31,alpha_init:24,alreadi:[24,30,36,38],also:36,although:[24,26],alwai:[30,32,33],always_save_mridc:[31,36],amp_backend:31,amp_level:31,amsgrad:[31,34],an:[6,8,9,11,13,30,32,33,34,35,36],anaconda3:11,analys:36,analysi:[0,36],ani:[6,13,16,23,26,28,30,31,32,33,34,36],anneal:[31,34],anoth:[33,36],ansi:38,ansiback:38,ansicod:38,ansicursor:38,ansifor:38,ansistyl:38,app_stat:[1,2],appli:[0,9,13,14,15,16,20,22,24,25,26,27,28],applic:[13,28,36],apply_mask:28,approach:11,appropri:[9,11,30,35,36],appstat:[30,36],ar:[0,8,9,11,12,24,28,30,33,35,36],arbitrari:36,archiv:[30,32,36],aren:33,arg:[24,30,34,36],arg_nam:30,arg_valu:30,argpars:36,argument:[1,2,6,8,12,30,31,32,34,35],argumentpars:36,argv:36,around:36,arrai:9,artifact:[30,32,36],artifactitem:[30,32,36],artifactpathtyp:36,artifici:[13,26],arxiv:[13,20,22,24,31,34],as_frozen:30,asctim:38,aspect:11,asr:[30,32],asr_ckpt:[30,32],assert:[30,32,36],assert_dataclass_signature_match:36,assess:[0,13],assign:36,assist:[13,25],assum:[9,34,36],async:34,async_grad_allreduc:34,async_master_grads_allreudc:34,asynchron:34,attach:36,attempt:36,attent:33,attr:28,attribut:[11,28,36],augment:36,augment_filenam:36,author:0,auto:[34,36],auto_lr_find:31,auto_scale_batch_s:31,auto_select_gpu:31,autocalibr:11,autograd:24,automat:[30,36],avail:[11,13,30,32,35,36],averag:[8,26,34,36],avoid:36,ax:[2,11,29,30,35],axi:33,axiskind:33,axiskindabstract:33,axistyp:33,b:[13,24,33,36],back:[30,32,34,38],backend:[8,36],background:[12,38],backward:[21,24,36],backward_h:21,bart:13,base:[0,2,3,5,6,7,8,9,10,11,12,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,30,31,32,33,34,36],base_class:36,base_config:[2,29],base_lr:31,base_momentum:31,baseclass:36,baseformatt:38,basemridcformatt:[36,38],basemrireconstructionmodel:13,basesensitivitymodel:13,baset:36,batch:[9,12,13,15,17,28,30,33,35,36],batch_chans_to_chan_dim:13,batch_idx:[13,36],batch_sampl:31,batch_siz:[13,14,17,22,23,30,31,34],batched_mask_cent:28,batchnorm:[15,20],becaus:35,becom:36,been:[11,31,34],befor:[8,25],begin:36,behaviour:36,being:[9,33,34,36,37],belong:36,benchmark:31,berkelei:13,bert:9,best:[0,36],beta1:34,beta:[31,34,36],beta_init:24,better:36,between:[30,33,34,36],beyond:34,bf16:34,bia:[20,23],bias:36,binari:36,bioinformat:[13,26],biomed:[13,25],bit:36,black:38,bl
ock:[14,16,18,19,20,22,23,24,25,26,27,36],block_idx:17,blue:38,boesig:[9,13],bool:[6,7,11,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,30,31,32,33,34,35,36,38],both:[9,17,36],bound:33,box:33,bright:38,brox:[13,25],buffer:[34,36],buffer_s:[31,36],build:[23,31],builder:31,built:30,builtin:36,button:0,c:[13,17,20,33],c_:13,caan:0,caballero:[13,15,16],cach:[30,36],cache_dir:36,calcul:11,call:[5,8,11,13,24,26,30,31,32,33,34,35,36],callabl:[11,30,31,34,36],callback:[3,4,31,36],callbackparam:[31,36],calling_cl:32,can:[9,14,27,28,30,31,32,33,34,35,36,37],cannot:[8,11,36],captur:36,capturewarn:36,care:[24,26],cascad:[0,13,15,24],cascadenet:[10,13],cascadenetblock:14,cast:36,casttofloat:36,categor:33,categoricalvaluestyp:33,caus:36,ccnn:[0,3,10],ccnn_block:[10,13],cd:0,cell:23,cell_input:22,center:[9,11,28,31],center_crop:28,center_crop_to_smallest:28,center_fract:11,center_scal:28,centered_circl:11,certain:36,cfg:[2,13,30,31,32,36],chain:33,challeng:[11,13,19],chan:[13,25],chan_complex_to_last_dim:25,chang:[9,30,35,36],channel:[13,18,19,22,33],channel_dim:[18,19],channeltyp:33,chans_to_batch_dim:13,charact:38,check:[9,23,30,31,33,34,35,36],check_color_support:38,check_explicit_log_dir:36,check_forward_hidden:23,check_forward_input:23,check_lib_vers:36,check_resum:36,check_slurm:36,check_stacked_complex:9,check_toler:[30,36],check_trac:30,check_val_every_n_epoch:31,checked_vers:36,checkinstal:36,checkpoint:[30,32,36],checkpoint_callback:31,checkpoint_callback_param:[31,36],checkpoint_nam:36,checkpoint_path:30,checkpointmisconfigurationerror:36,children:36,choos:[6,11,31],choose_acceler:11,chose:31,chosen:[11,34],chpt:32,chunk:23,cifar100:33,circl:11,cirim:[0,3,10],cite:0,ckpt:[30,32,36],cl:[30,32,34,36,37],clariti:34,class_:30,class_nam:33,class_typ:36,classif:33,classifi:30,classmethod:30,clear:38,clear_lin:38,clear_screen:38,cli:[12,36],clip_threshold:34,clone:0,close:36,closur:34,cloud:[1,2,30],cluster:36,cnn:[20,24],code:[36,38],code_to_char:38,coeffici:[7,34],coerc:36,coercionerror:36,coil:[0,9,13,17,19,21,22,23,24,26,27],coil_combin:9,coil_dim:[17,19,22],coil_imag:19,coil_to_batch:17,collat:30,collate_fn:[30,31],collect:[0,1,2,30,32,36],collis:30,color:[2,36],column:11,com:[0,11,13],combin:[9,11,13,14,26,27],come:0,command:[31,36],commit:36,common:[2,3,29,36],compar:[33,36],compare_and_raise_error:33,compare_to:33,comparison:[2,29],compat:[33,35,36],compil:35,complet:[30,36],complex:[9,13,22,24,25,28],complex_ab:9,complex_abs_sq:9,complex_center_crop:28,complex_conj:9,complex_dim:22,complex_instance_norm:24,complex_mul:9,complex_pseudocovari:24,complex_to_chan_dim:25,complexdot:24,complexinstancenorm:24,complexnormwrapp:24,compliant:36,compon:[30,32],compress:[0,13],compris:[30,32],comput:[7,8,9,12,13,17,18,20,21,22,23,24,25,26,33,34],compute_max_step:34,compute_model_per_coil:21,compute_on_step:8,concatdataset:6,concaten:36,concept:33,conda:0,conf:[2,29,34,36],confer:[13,23,25],config:[30,31,32,34,36],config_nam:31,config_path:[30,31,32],config_util:[1,2],configur:[2,13,30,31,32,36],configure_checkpoint:36,configure_logg:36,configure_optim:30,conform:36,conjug:[9,24],conjugategradi:24,connect:20,connector:[2,29,30],consecut:36,consid:36,consider:11,consist:[13,14,16,19,24,25,26,27],constant:[0,1,30,34],constant_ratio:[31,34],constant_step:[31,34],constitut:36,construct:[34,36],constructor:[30,32],contain:[8,9,11,24,26,30,34,35,36],container_size_mismatch:33,content:[0,1],context:[11,24,30,34,35],contigu:34,contiguous_grad_bucket:34,continu:36,contribut:0,control:30,conv2d:[10,13],conv2dgru:
22,conv2gru:[10,13],conv:[10,13,23],conv_bia:23,conv_dil:23,conv_dim:23,conv_filt:23,conv_kernel:23,conv_lay:[10,13],convblock:[20,25],conveni:[31,34],converg:34,convert:[9,11,13,25,30,32,33,36,38],convert_model_config_to_dict_config:36,convgru:22,convgrucel:23,convgrucellbas:23,convmgucel:23,convmgucellbas:23,convnonlinear:23,convolut:[0,13,14,15,16,18,19,20,22,23,25,33],convrecnet:[10,13],convrnnstack:23,cooldown:31,coordin:11,copi:[24,30,34,36],copy_model_grads_to_main_grad:34,copy_param:24,core:[0,1,2,7,13,36],correct:[9,17,23,34],correctli:36,correspond:[9,11,18,28,30,32,33,35,36],cosin:[31,34],cosineann:34,cosineannealingparam:[31,34],cost:34,could:36,cpu:[30,32,35],creat:[0,11,30,36],create_checkpoint_callback:[31,36],create_mask_for_mask_typ:11,create_tensorboard_logg:[31,36],create_wandb_logg:[31,36],creation:36,criterion:24,critic:36,crnn:[0,3,10,16],crnn_block:[10,13],crnnet:13,crop:[13,18,20,28],crop_before_mask:28,crop_siz:[13,28],crop_to_shap:[18,20],cross:13,crossdomain:[10,13],crossdomainnetwork:17,crossref:13,cs:[0,20],ctx:24,cuda:35,current:[6,13,22,30,34,35,36],current_kspac:22,cursor:38,custom:[30,31,34,36],cvf:13,cvpr46437:13,cvpr:[13,23],cyan:38,cycl:[31,34],cycle_momentum:31,cycliclr:31,cycliclrparam:31,d:[0,13,15,16,33,34,38],d_:13,d_f:13,d_i:13,d_model:34,dampen:31,data1:24,data2:24,data:[3,4,7,9,10,13,14,16,17,19,21,23,24,26,27,28,30,31,34,36],data_consistency_block:26,data_load:30,data_loss:24,data_parallel_group:36,data_parallel_rank:36,data_parallel_s:36,data_rang:7,datacl:36,dataclass:[31,34,36],dataconsistencylay:[16,26],datagdlay:24,dataidlay:24,datalay:24,dataload:[2,13,29,30,34,36],dataloader_idx:[30,36],dataloaderconfig:31,dataproxcglay:24,dataset:[2,3,4,11,29,31,33],dataset_cach:11,dataset_cache_fil:11,datasetconfig:[30,31],dataterm:24,datavslay:24,date:36,datefmt:38,datetim:36,dc:24,dc_layer:[10,13],dc_sen:13,dclayer:24,ddp:[31,36],dealloc:34,debug:[36,38],debugmridcformatt:38,decai:[31,34],decay_r:34,decim:36,declar:30,decod:[30,32,33,38],decor:[2,30,31,36],decreas:34,deep:[0,13,18,22,24,31,34],deeper:23,default_format:38,default_lr:36,default_opt_arg:36,default_root_dir:31,defin:[11,24,26,30,33,34,36],definit:[7,30,33],delet:36,denoiser_block:26,denomin:34,denot:30,dense_connect:22,densiti:11,depend:34,deploy:[30,36],deprec:[2,36],depth:[22,23],deriv:[30,31],descript:30,deseri:[30,32],desir:[11,33,36],destin:36,destt:36,detail:[31,36],detect_anomali:31,determin:[23,30,32,33,36],determine_conv_class:23,determinist:31,dev1:36,dev2:36,develop:0,deviat:[9,12],devic:[30,31,32],device_id:36,df:36,dict:[9,12,13,28,30,31,32,33,34,36],dictconfig:[2,13,30,32,34,36],dictionari:[9,13,30,34,35,36],didn:[10,13],diff:36,differ:[9,28,36],dilat:[20,22,23],dilatedconvblock:20,dim:[9,28,35,38],dim_incompat:33,dimens:[9,11,13,17,18,22,25,28,30,33,35],dimension:[8,9,33],dimitri:11,dimitrio:0,dir:[31,36],direct:0,directori:[9,30,32,36],dirpath:[31,36],disabl:[30,34,36],disable_check:30,disabled_deployment_input_nam:30,disabled_deployment_output_nam:30,disabled_nam:35,disast:36,disc:11,discret:20,dist_sync_on_step:8,distribut:[1,2,8,11,33],divid:13,divide_root_sum_of_squar:13,do_coil_combin:13,do_constant_fold:30,doc:[30,31],document:[11,30,37],doe:[8,33,36],doesn:[31,34],doi:[13,15,16,23,26],domain:[13,17,19],domain_sequ:17,done:[30,32],dot:[24,36],doubl:20,down:[0,13,18,38],download:[30,36],draw:6,drawn:11,drop:[28,36],drop_last:[30,31,34],drop_missing_subconfig:36,drop_prob:[13,25],dropout:[19,25],dropout_prob:19,dt:36,dtype:33,dual:[0,13,21],dualnet:21,duan:[13,26],dub:1
[… remainder of the regenerated Sphinx search index (searchindex.js) omitted: minified, machine-generated term data indexing the rebuilt pages, whose titles are "Welcome to mridc's documentation!", "mridc", and the per-package API pages added below …]
diff --git a/docs/make.bat b/docs/make.bat
index dc1312ab..747ffb7b 100644
--- a/docs/make.bat
+++ b/docs/make.bat
@@ -1,35 +1,35 @@
-@ECHO OFF
-
-pushd %~dp0
-
-REM Command file for Sphinx documentation
-
-if "%SPHINXBUILD%" == "" (
-	set SPHINXBUILD=sphinx-build
-)
-set SOURCEDIR=source
-set BUILDDIR=build
-
-%SPHINXBUILD% >NUL 2>NUL
-if errorlevel 9009 (
-	echo.
-	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
-	echo.installed, then set the SPHINXBUILD environment variable to point
-	echo.to the full path of the 'sphinx-build' executable. Alternatively you
-	echo.may add the Sphinx directory to PATH.
-	echo.
-	echo.If you don't have Sphinx installed, grab it from
-	echo.https://www.sphinx-doc.org/
-	exit /b 1
-)
-
-if "%1" == "" goto help
-
-%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
-goto end
-
-:help
-%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
-
-:end
-popd
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+	set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=source
+set BUILDDIR=build
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+	echo.
+	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+	echo.installed, then set the SPHINXBUILD environment variable to point
+	echo.to the full path of the 'sphinx-build' executable. Alternatively you
+	echo.may add the Sphinx directory to PATH.
+	echo.
+	echo.If you don't have Sphinx installed, grab it from
+	echo.https://www.sphinx-doc.org/
+	exit /b 1
+)
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 661a0264..238a68c8 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -37,8 +37,28 @@
 extensions = [
     "myst_parser",
     "sphinx.ext.napoleon",
+    "sphinx.ext.autodoc",
+    "sphinx.ext.doctest",
+    "sphinx.ext.viewcode",
+    "sphinx.ext.mathjax",
+    "sphinx.ext.autosummary",
 ]
 
+# Napoleon settings
+napoleon_google_docstring = True
+napoleon_numpy_docstring = True
+napoleon_include_init_with_doc = False
+napoleon_include_private_with_doc = False
+napoleon_include_special_with_doc = True
+napoleon_use_admonition_for_examples = False
+napoleon_use_admonition_for_notes = False
+napoleon_use_admonition_for_references = False
+napoleon_use_ivar = False
+napoleon_use_param = True
+napoleon_use_rtype = True
+napoleon_preprocess_types = False
+napoleon_type_aliases = None
+napoleon_attr_annotations = True
 
 # Add any paths that contain templates here, relative to this directory.
 templates_path = ["_templates"]
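[Note: the newly enabled sphinx.ext.autodoc together with the Napoleon settings above determine how docstrings in the mridc source render on the generated API pages. A minimal sketch of a NumPy-style docstring these settings would parse follows; the function name and parameters are illustrative only, not taken from mridc:

    def zero_fill(kspace, target_shape):
        """Zero-fill k-space data to a target shape.

        Parameters
        ----------
        kspace : torch.Tensor
            Complex-valued k-space measurements.
        target_shape : tuple of int
            Desired spatial size after padding.

        Returns
        -------
        torch.Tensor
            The zero-filled k-space tensor.
        """
        ...

Because napoleon_use_param and napoleon_use_rtype are True, the Parameters and Returns sections are emitted as structured :param:/:type: and :rtype: fields rather than rendered verbatim.]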
diff --git a/docs/source/index.rst b/docs/source/index.rst
index c0a6b855..e25c6b0e 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -1,20 +1,15 @@
 .. mridc documentation master file, created by
-   sphinx-quickstart on Mon May 23 12:30:03 2022.
+   sphinx-quickstart on Wed May 25 16:45:16 2022.
    You can adapt this file completely to your liking, but it should at least
    contain the root `toctree` directive.
 
-MRI Data Consistency Documentation
-==================================
-.. toctree::
-   :maxdepth: 2
-   :caption: Getting started:
-
+Welcome to mridc's documentation!
+=================================
 .. include:: ../../README.md
    :parser: myst_parser.sphinx_
 
-Indices and tables
-==================
+.. toctree::
+   :maxdepth: 4
+   :caption: API Documentation:
 
-* :ref:`genindex`
-* :ref:`modindex`
-* :ref:`search`
+   modules.rst
diff --git a/docs/source/modules.rst b/docs/source/modules.rst
new file mode 100644
index 00000000..af286e6a
--- /dev/null
+++ b/docs/source/modules.rst
@@ -0,0 +1,7 @@
+mridc
+=====
+
+.. toctree::
+   :maxdepth: 4
+
+   mridc
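[Note: every docs/source/mridc.*.rst file added below follows the same autodoc stub layout: a package heading, a Submodules section with one automodule directive per module, and a Module contents section for the package itself. Assuming the stubs were generated rather than hand-written (their layout matches the default sphinx-apidoc output), a command of roughly this shape recreates them from the repository root:

    sphinx-apidoc -o docs/source mridc

Each directive carries :members:, :undoc-members: and :show-inheritance:, so documented and undocumented members are both listed, along with base classes.]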
diff --git a/docs/source/mridc.collections.common.callbacks.rst b/docs/source/mridc.collections.common.callbacks.rst
new file mode 100644
index 00000000..c3aaa45d
--- /dev/null
+++ b/docs/source/mridc.collections.common.callbacks.rst
@@ -0,0 +1,21 @@
+mridc.collections.common.callbacks package
+==========================================
+
+Submodules
+----------
+
+mridc.collections.common.callbacks.callbacks module
+---------------------------------------------------
+
+.. automodule:: mridc.collections.common.callbacks.callbacks
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: mridc.collections.common.callbacks
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/mridc.collections.common.data.rst b/docs/source/mridc.collections.common.data.rst
new file mode 100644
index 00000000..dbd9e857
--- /dev/null
+++ b/docs/source/mridc.collections.common.data.rst
@@ -0,0 +1,21 @@
+mridc.collections.common.data package
+=====================================
+
+Submodules
+----------
+
+mridc.collections.common.data.dataset module
+--------------------------------------------
+
+.. automodule:: mridc.collections.common.data.dataset
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: mridc.collections.common.data
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/mridc.collections.common.losses.rst b/docs/source/mridc.collections.common.losses.rst
new file mode 100644
index 00000000..3b4e880a
--- /dev/null
+++ b/docs/source/mridc.collections.common.losses.rst
@@ -0,0 +1,29 @@
+mridc.collections.common.losses package
+=======================================
+
+Submodules
+----------
+
+mridc.collections.common.losses.aggregator module
+-------------------------------------------------
+
+.. automodule:: mridc.collections.common.losses.aggregator
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+mridc.collections.common.losses.ssim module
+-------------------------------------------
+
+.. automodule:: mridc.collections.common.losses.ssim
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: mridc.collections.common.losses
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/mridc.collections.common.metrics.rst b/docs/source/mridc.collections.common.metrics.rst
new file mode 100644
index 00000000..06dd9f87
--- /dev/null
+++ b/docs/source/mridc.collections.common.metrics.rst
@@ -0,0 +1,21 @@
+mridc.collections.common.metrics package
+========================================
+
+Submodules
+----------
+
+mridc.collections.common.metrics.global\_average\_loss\_metric module
+---------------------------------------------------------------------
+
+.. automodule:: mridc.collections.common.metrics.global_average_loss_metric
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: mridc.collections.common.metrics
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/mridc.collections.common.parts.rst b/docs/source/mridc.collections.common.parts.rst
new file mode 100644
index 00000000..b06319e8
--- /dev/null
+++ b/docs/source/mridc.collections.common.parts.rst
@@ -0,0 +1,53 @@
+mridc.collections.common.parts package
+======================================
+
+Submodules
+----------
+
+mridc.collections.common.parts.fft module
+-----------------------------------------
+
+.. automodule:: mridc.collections.common.parts.fft
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+mridc.collections.common.parts.patch\_utils module
+--------------------------------------------------
+
+.. automodule:: mridc.collections.common.parts.patch_utils
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+mridc.collections.common.parts.ptl\_overrides module
+----------------------------------------------------
+
+.. automodule:: mridc.collections.common.parts.ptl_overrides
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+mridc.collections.common.parts.rnn\_utils module
+------------------------------------------------
+
+.. automodule:: mridc.collections.common.parts.rnn_utils
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+mridc.collections.common.parts.utils module
+-------------------------------------------
+
+.. automodule:: mridc.collections.common.parts.utils
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: mridc.collections.common.parts
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/mridc.collections.common.rst b/docs/source/mridc.collections.common.rst
new file mode 100644
index 00000000..c0a988df
--- /dev/null
+++ b/docs/source/mridc.collections.common.rst
@@ -0,0 +1,22 @@
+mridc.collections.common package
+================================
+
+Subpackages
+-----------
+
+.. toctree::
+   :maxdepth: 4
+
+   mridc.collections.common.callbacks
+   mridc.collections.common.data
+   mridc.collections.common.losses
+   mridc.collections.common.metrics
+   mridc.collections.common.parts
+
+Module contents
+---------------
+
+.. automodule:: mridc.collections.common
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/mridc.collections.reconstruction.data.rst b/docs/source/mridc.collections.reconstruction.data.rst
new file mode 100644
index 00000000..e2e984f6
--- /dev/null
+++ b/docs/source/mridc.collections.reconstruction.data.rst
@@ -0,0 +1,29 @@
+mridc.collections.reconstruction.data package
+=============================================
+
+Submodules
+----------
+
+mridc.collections.reconstruction.data.mri\_data module
+------------------------------------------------------
+
+.. automodule:: mridc.collections.reconstruction.data.mri_data
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+mridc.collections.reconstruction.data.subsample module
+------------------------------------------------------
+
+.. automodule:: mridc.collections.reconstruction.data.subsample
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: mridc.collections.reconstruction.data
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/mridc.collections.reconstruction.metrics.rst b/docs/source/mridc.collections.reconstruction.metrics.rst
new file mode 100644
index 00000000..06da8eab
--- /dev/null
+++ b/docs/source/mridc.collections.reconstruction.metrics.rst
@@ -0,0 +1,21 @@
+mridc.collections.reconstruction.metrics package
+================================================
+
+Submodules
+----------
+
+mridc.collections.reconstruction.metrics.evaluate module
+--------------------------------------------------------
+
+.. automodule:: mridc.collections.reconstruction.metrics.evaluate
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: mridc.collections.reconstruction.metrics
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/mridc.collections.reconstruction.models.cascadenet.rst b/docs/source/mridc.collections.reconstruction.models.cascadenet.rst
new file mode 100644
index 00000000..1406af85
--- /dev/null
+++ b/docs/source/mridc.collections.reconstruction.models.cascadenet.rst
@@ -0,0 +1,21 @@
+mridc.collections.reconstruction.models.cascadenet package
+==========================================================
+
+Submodules
+----------
+
+mridc.collections.reconstruction.models.cascadenet.ccnn\_block module
+---------------------------------------------------------------------
+
+.. automodule:: mridc.collections.reconstruction.models.cascadenet.ccnn_block
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: mridc.collections.reconstruction.models.cascadenet
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/mridc.collections.reconstruction.models.conv.rst b/docs/source/mridc.collections.reconstruction.models.conv.rst
new file mode 100644
index 00000000..4ac025fe
--- /dev/null
+++ b/docs/source/mridc.collections.reconstruction.models.conv.rst
@@ -0,0 +1,29 @@
+mridc.collections.reconstruction.models.conv package
+====================================================
+
+Submodules
+----------
+
+mridc.collections.reconstruction.models.conv.conv2d module
+----------------------------------------------------------
+
+.. automodule:: mridc.collections.reconstruction.models.conv.conv2d
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+mridc.collections.reconstruction.models.conv.gruconv2d module
+-------------------------------------------------------------
+
+.. automodule:: mridc.collections.reconstruction.models.conv.gruconv2d
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: mridc.collections.reconstruction.models.conv
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/mridc.collections.reconstruction.models.convrecnet.rst b/docs/source/mridc.collections.reconstruction.models.convrecnet.rst
new file mode 100644
index 00000000..1038005e
--- /dev/null
+++ b/docs/source/mridc.collections.reconstruction.models.convrecnet.rst
@@ -0,0 +1,21 @@
+mridc.collections.reconstruction.models.convrecnet package
+==========================================================
+
+Submodules
+----------
+
+mridc.collections.reconstruction.models.convrecnet.crnn\_block module
+---------------------------------------------------------------------
+
+.. automodule:: mridc.collections.reconstruction.models.convrecnet.crnn_block
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: mridc.collections.reconstruction.models.convrecnet
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/mridc.collections.reconstruction.models.crossdomain.rst b/docs/source/mridc.collections.reconstruction.models.crossdomain.rst
new file mode 100644
index 00000000..b4341827
--- /dev/null
+++ b/docs/source/mridc.collections.reconstruction.models.crossdomain.rst
@@ -0,0 +1,29 @@
+mridc.collections.reconstruction.models.crossdomain package
+===========================================================
+
+Submodules
+----------
+
+mridc.collections.reconstruction.models.crossdomain.crossdomain module
+----------------------------------------------------------------------
+
+.. automodule:: mridc.collections.reconstruction.models.crossdomain.crossdomain
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+mridc.collections.reconstruction.models.crossdomain.multicoil module
+--------------------------------------------------------------------
+
+.. automodule:: mridc.collections.reconstruction.models.crossdomain.multicoil
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: mridc.collections.reconstruction.models.crossdomain
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/mridc.collections.reconstruction.models.didn.rst b/docs/source/mridc.collections.reconstruction.models.didn.rst
new file mode 100644
index 00000000..365e6b2e
--- /dev/null
+++ b/docs/source/mridc.collections.reconstruction.models.didn.rst
@@ -0,0 +1,21 @@
+mridc.collections.reconstruction.models.didn package
+====================================================
+
+Submodules
+----------
+
+mridc.collections.reconstruction.models.didn.didn module
+--------------------------------------------------------
+
+.. automodule:: mridc.collections.reconstruction.models.didn.didn
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: mridc.collections.reconstruction.models.didn
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/mridc.collections.reconstruction.models.multidomain.rst b/docs/source/mridc.collections.reconstruction.models.multidomain.rst
new file mode 100644
index 00000000..dab55d8e
--- /dev/null
+++ b/docs/source/mridc.collections.reconstruction.models.multidomain.rst
@@ -0,0 +1,21 @@
+mridc.collections.reconstruction.models.multidomain package
+===========================================================
+
+Submodules
+----------
+
+mridc.collections.reconstruction.models.multidomain.multidomain module
+----------------------------------------------------------------------
+
+.. automodule:: mridc.collections.reconstruction.models.multidomain.multidomain
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: mridc.collections.reconstruction.models.multidomain
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/mridc.collections.reconstruction.models.mwcnn.rst b/docs/source/mridc.collections.reconstruction.models.mwcnn.rst
new file mode 100644
index 00000000..2fa986b3
--- /dev/null
+++ b/docs/source/mridc.collections.reconstruction.models.mwcnn.rst
@@ -0,0 +1,21 @@
+mridc.collections.reconstruction.models.mwcnn package
+=====================================================
+
+Submodules
+----------
+
+mridc.collections.reconstruction.models.mwcnn.mwcnn module
+----------------------------------------------------------
+
+.. automodule:: mridc.collections.reconstruction.models.mwcnn.mwcnn
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: mridc.collections.reconstruction.models.mwcnn
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/mridc.collections.reconstruction.models.primaldual.rst b/docs/source/mridc.collections.reconstruction.models.primaldual.rst
new file mode 100644
index 00000000..72ce9d35
--- /dev/null
+++ b/docs/source/mridc.collections.reconstruction.models.primaldual.rst
@@ -0,0 +1,21 @@
+mridc.collections.reconstruction.models.primaldual package
+==========================================================
+
+Submodules
+----------
+
+mridc.collections.reconstruction.models.primaldual.pd module
+------------------------------------------------------------
+
+.. automodule:: mridc.collections.reconstruction.models.primaldual.pd
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: mridc.collections.reconstruction.models.primaldual
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/mridc.collections.reconstruction.models.recurrentvarnet.rst b/docs/source/mridc.collections.reconstruction.models.recurrentvarnet.rst
new file mode 100644
index 00000000..d82f2f7c
--- /dev/null
+++ b/docs/source/mridc.collections.reconstruction.models.recurrentvarnet.rst
@@ -0,0 +1,29 @@
+mridc.collections.reconstruction.models.recurrentvarnet package
+===============================================================
+
+Submodules
+----------
+
+mridc.collections.reconstruction.models.recurrentvarnet.conv2gru module
+-----------------------------------------------------------------------
+
+.. automodule:: mridc.collections.reconstruction.models.recurrentvarnet.conv2gru
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+mridc.collections.reconstruction.models.recurrentvarnet.recurentvarnet module
+-----------------------------------------------------------------------------
+
+.. automodule:: mridc.collections.reconstruction.models.recurrentvarnet.recurentvarnet
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: mridc.collections.reconstruction.models.recurrentvarnet
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/mridc.collections.reconstruction.models.rim.rst b/docs/source/mridc.collections.reconstruction.models.rim.rst
new file mode 100644
index 00000000..58a6db10
--- /dev/null
+++ b/docs/source/mridc.collections.reconstruction.models.rim.rst
@@ -0,0 +1,45 @@
+mridc.collections.reconstruction.models.rim package
+===================================================
+
+Submodules
+----------
+
+mridc.collections.reconstruction.models.rim.conv\_layers module
+---------------------------------------------------------------
+
+.. automodule:: mridc.collections.reconstruction.models.rim.conv_layers
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+mridc.collections.reconstruction.models.rim.rim\_block module
+-------------------------------------------------------------
+
+.. automodule:: mridc.collections.reconstruction.models.rim.rim_block
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+mridc.collections.reconstruction.models.rim.rnn\_cells module
+-------------------------------------------------------------
+
+.. automodule:: mridc.collections.reconstruction.models.rim.rnn_cells
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+mridc.collections.reconstruction.models.rim.utils module
+--------------------------------------------------------
+
+.. automodule:: mridc.collections.reconstruction.models.rim.utils
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: mridc.collections.reconstruction.models.rim
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/mridc.collections.reconstruction.models.unet_base.rst b/docs/source/mridc.collections.reconstruction.models.unet_base.rst
new file mode 100644
index 00000000..deacd5a8
--- /dev/null
+++ b/docs/source/mridc.collections.reconstruction.models.unet_base.rst
@@ -0,0 +1,21 @@
+mridc.collections.reconstruction.models.unet\_base package
+==========================================================
+
+Submodules
+----------
+
+mridc.collections.reconstruction.models.unet\_base.unet\_block module
+---------------------------------------------------------------------
+
+.. automodule:: mridc.collections.reconstruction.models.unet_base.unet_block
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: mridc.collections.reconstruction.models.unet_base
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/mridc.collections.reconstruction.models.variablesplittingnet.rst b/docs/source/mridc.collections.reconstruction.models.variablesplittingnet.rst
new file mode 100644
index 00000000..8f982eee
--- /dev/null
+++ b/docs/source/mridc.collections.reconstruction.models.variablesplittingnet.rst
@@ -0,0 +1,21 @@
+mridc.collections.reconstruction.models.variablesplittingnet package
+====================================================================
+
+Submodules
+----------
+
+mridc.collections.reconstruction.models.variablesplittingnet.vsnet\_block module
+--------------------------------------------------------------------------------
+
+.. automodule:: mridc.collections.reconstruction.models.variablesplittingnet.vsnet_block
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: mridc.collections.reconstruction.models.variablesplittingnet
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/mridc.collections.reconstruction.models.varnet.rst b/docs/source/mridc.collections.reconstruction.models.varnet.rst
new file mode 100644
index 00000000..2210165f
--- /dev/null
+++ b/docs/source/mridc.collections.reconstruction.models.varnet.rst
@@ -0,0 +1,21 @@
+mridc.collections.reconstruction.models.varnet package
+======================================================
+
+Submodules
+----------
+
+mridc.collections.reconstruction.models.varnet.vn\_block module
+---------------------------------------------------------------
+
+.. automodule:: mridc.collections.reconstruction.models.varnet.vn_block
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: mridc.collections.reconstruction.models.varnet
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/mridc.collections.reconstruction.parts.rst b/docs/source/mridc.collections.reconstruction.parts.rst
new file mode 100644
index 00000000..1aa734e4
--- /dev/null
+++ b/docs/source/mridc.collections.reconstruction.parts.rst
@@ -0,0 +1,29 @@
+mridc.collections.reconstruction.parts package
+==============================================
+
+Submodules
+----------
+
+mridc.collections.reconstruction.parts.transforms module
+--------------------------------------------------------
+
+.. automodule:: mridc.collections.reconstruction.parts.transforms
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+mridc.collections.reconstruction.parts.utils module
+---------------------------------------------------
+
+.. automodule:: mridc.collections.reconstruction.parts.utils
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: mridc.collections.reconstruction.parts
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/mridc.collections.reconstruction.rst b/docs/source/mridc.collections.reconstruction.rst
new file mode 100644
index 00000000..f24c3c4b
--- /dev/null
+++ b/docs/source/mridc.collections.reconstruction.rst
@@ -0,0 +1,21 @@
+mridc.collections.reconstruction package
+========================================
+
+Subpackages
+-----------
+
+.. toctree::
+   :maxdepth: 4
+
+   mridc.collections.reconstruction.data
+   mridc.collections.reconstruction.metrics
+   mridc.collections.reconstruction.models
+   mridc.collections.reconstruction.parts
+
+Module contents
+---------------
+
+.. automodule:: mridc.collections.reconstruction
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/mridc.collections.rst b/docs/source/mridc.collections.rst
new file mode 100644
index 00000000..362d583a
--- /dev/null
+++ b/docs/source/mridc.collections.rst
@@ -0,0 +1,19 @@
+mridc.collections package
+=========================
+
+Subpackages
+-----------
+
+.. toctree::
+   :maxdepth: 4
+
+   mridc.collections.common
+   mridc.collections.reconstruction
+
+Module contents
+---------------
+
+.. automodule:: mridc.collections
+   :members:
+   :undoc-members:
+   :show-inheritance:
automodule:: mridc.core.classes.modelPT + :members: + :undoc-members: + :show-inheritance: + +mridc.core.classes.module module +-------------------------------- + +.. automodule:: mridc.core.classes.module + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.core.classes + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/mridc.core.conf.rst b/docs/source/mridc.core.conf.rst new file mode 100644 index 00000000..efac091d --- /dev/null +++ b/docs/source/mridc.core.conf.rst @@ -0,0 +1,69 @@ +mridc.core.conf package +======================= + +Submodules +---------- + +mridc.core.conf.base\_config module +----------------------------------- + +.. automodule:: mridc.core.conf.base_config + :members: + :undoc-members: + :show-inheritance: + +mridc.core.conf.dataloader module +--------------------------------- + +.. automodule:: mridc.core.conf.dataloader + :members: + :undoc-members: + :show-inheritance: + +mridc.core.conf.hydra\_runner module +------------------------------------ + +.. automodule:: mridc.core.conf.hydra_runner + :members: + :undoc-members: + :show-inheritance: + +mridc.core.conf.modelPT module +------------------------------ + +.. automodule:: mridc.core.conf.modelPT + :members: + :undoc-members: + :show-inheritance: + +mridc.core.conf.optimizers module +--------------------------------- + +.. automodule:: mridc.core.conf.optimizers + :members: + :undoc-members: + :show-inheritance: + +mridc.core.conf.schedulers module +--------------------------------- + +.. automodule:: mridc.core.conf.schedulers + :members: + :undoc-members: + :show-inheritance: + +mridc.core.conf.trainer module +------------------------------ + +.. automodule:: mridc.core.conf.trainer + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.core.conf + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/mridc.core.connectors.rst b/docs/source/mridc.core.connectors.rst new file mode 100644 index 00000000..e7bfd68b --- /dev/null +++ b/docs/source/mridc.core.connectors.rst @@ -0,0 +1,21 @@ +mridc.core.connectors package +============================= + +Submodules +---------- + +mridc.core.connectors.save\_restore\_connector module +----------------------------------------------------- + +.. automodule:: mridc.core.connectors.save_restore_connector + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.core.connectors + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/mridc.core.neural_types.rst b/docs/source/mridc.core.neural_types.rst new file mode 100644 index 00000000..43deaebd --- /dev/null +++ b/docs/source/mridc.core.neural_types.rst @@ -0,0 +1,45 @@ +mridc.core.neural\_types package +================================ + +Submodules +---------- + +mridc.core.neural\_types.axes module +------------------------------------ + +.. automodule:: mridc.core.neural_types.axes + :members: + :undoc-members: + :show-inheritance: + +mridc.core.neural\_types.comparison module +------------------------------------------ + +.. automodule:: mridc.core.neural_types.comparison + :members: + :undoc-members: + :show-inheritance: + +mridc.core.neural\_types.elements module +---------------------------------------- + +.. 
automodule:: mridc.core.neural_types.elements + :members: + :undoc-members: + :show-inheritance: + +mridc.core.neural\_types.neural\_type module +-------------------------------------------- + +.. automodule:: mridc.core.neural_types.neural_type + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.core.neural_types + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/mridc.core.optim.rst b/docs/source/mridc.core.optim.rst new file mode 100644 index 00000000..bb352f50 --- /dev/null +++ b/docs/source/mridc.core.optim.rst @@ -0,0 +1,53 @@ +mridc.core.optim package +======================== + +Submodules +---------- + +mridc.core.optim.adafactor module +--------------------------------- + +.. automodule:: mridc.core.optim.adafactor + :members: + :undoc-members: + :show-inheritance: + +mridc.core.optim.lr\_scheduler module +------------------------------------- + +.. automodule:: mridc.core.optim.lr_scheduler + :members: + :undoc-members: + :show-inheritance: + +mridc.core.optim.novograd module +-------------------------------- + +.. automodule:: mridc.core.optim.novograd + :members: + :undoc-members: + :show-inheritance: + +mridc.core.optim.optimizer\_with\_master\_params module +------------------------------------------------------- + +.. automodule:: mridc.core.optim.optimizer_with_master_params + :members: + :undoc-members: + :show-inheritance: + +mridc.core.optim.optimizers module +---------------------------------- + +.. automodule:: mridc.core.optim.optimizers + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.core.optim + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/mridc.core.rst b/docs/source/mridc.core.rst new file mode 100644 index 00000000..bfbcf8c0 --- /dev/null +++ b/docs/source/mridc.core.rst @@ -0,0 +1,23 @@ +mridc.core package +================== + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + mridc.core.classes + mridc.core.conf + mridc.core.connectors + mridc.core.neural_types + mridc.core.optim + mridc.core.utils + +Module contents +--------------- + +.. automodule:: mridc.core + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/mridc.core.utils.rst b/docs/source/mridc.core.utils.rst new file mode 100644 index 00000000..a87d1362 --- /dev/null +++ b/docs/source/mridc.core.utils.rst @@ -0,0 +1,29 @@ +mridc.core.utils package +======================== + +Submodules +---------- + +mridc.core.utils.neural\_type\_utils module +------------------------------------------- + +.. automodule:: mridc.core.utils.neural_type_utils + :members: + :undoc-members: + :show-inheritance: + +mridc.core.utils.numba\_utils module +------------------------------------ + +.. automodule:: mridc.core.utils.numba_utils + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.core.utils + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/mridc.rst b/docs/source/mridc.rst new file mode 100644 index 00000000..09088364 --- /dev/null +++ b/docs/source/mridc.rst @@ -0,0 +1,47 @@ +mridc package +============= + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + mridc.collections + mridc.core + mridc.utils + +Submodules +---------- + +mridc.constants module +---------------------- + +.. automodule:: mridc.constants + :members: + :undoc-members: + :show-inheritance: + +mridc.launch module +------------------- + +.. 
automodule:: mridc.launch + :members: + :undoc-members: + :show-inheritance: + +mridc.package\_info module +-------------------------- + +.. automodule:: mridc.package_info + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/mridc.utils.decorators.rst b/docs/source/mridc.utils.decorators.rst new file mode 100644 index 00000000..5b55330b --- /dev/null +++ b/docs/source/mridc.utils.decorators.rst @@ -0,0 +1,37 @@ +mridc.utils.decorators package +============================== + +Submodules +---------- + +mridc.utils.decorators.deprecated module +---------------------------------------- + +.. automodule:: mridc.utils.decorators.deprecated + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.decorators.experimental module +------------------------------------------ + +.. automodule:: mridc.utils.decorators.experimental + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.decorators.port\_docs module +---------------------------------------- + +.. automodule:: mridc.utils.decorators.port_docs + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.utils.decorators + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/mridc.utils.formaters.rst b/docs/source/mridc.utils.formaters.rst new file mode 100644 index 00000000..ef8fed13 --- /dev/null +++ b/docs/source/mridc.utils.formaters.rst @@ -0,0 +1,37 @@ +mridc.utils.formaters package +============================= + +Submodules +---------- + +mridc.utils.formaters.base module +--------------------------------- + +.. automodule:: mridc.utils.formaters.base + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.formaters.colors module +----------------------------------- + +.. automodule:: mridc.utils.formaters.colors + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.formaters.utils module +---------------------------------- + +.. automodule:: mridc.utils.formaters.utils + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.utils.formaters + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/mridc.utils.rst b/docs/source/mridc.utils.rst new file mode 100644 index 00000000..dea4ead2 --- /dev/null +++ b/docs/source/mridc.utils.rst @@ -0,0 +1,142 @@ +mridc.utils package +=================== + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + mridc.utils.decorators + mridc.utils.formaters + +Submodules +---------- + +mridc.utils.app\_state module +----------------------------- + +.. automodule:: mridc.utils.app_state + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.arguments module +---------------------------- + +.. automodule:: mridc.utils.arguments + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.cloud module +------------------------ + +.. automodule:: mridc.utils.cloud + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.config\_utils module +-------------------------------- + +.. automodule:: mridc.utils.config_utils + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.distributed module +------------------------------ + +.. automodule:: mridc.utils.distributed + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.env\_var\_parsing module +------------------------------------ + +.. 
automodule:: mridc.utils.env_var_parsing + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.exceptions module +----------------------------- + +.. automodule:: mridc.utils.exceptions + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.exp\_manager module +------------------------------- + +.. automodule:: mridc.utils.exp_manager + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.export\_utils module +-------------------------------- + +.. automodule:: mridc.utils.export_utils + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.get\_rank module +---------------------------- + +.. automodule:: mridc.utils.get_rank + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.lightning\_logger\_patch module +------------------------------------------- + +.. automodule:: mridc.utils.lightning_logger_patch + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.metaclasses module +------------------------------ + +.. automodule:: mridc.utils.metaclasses + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.model\_utils module +------------------------------- + +.. automodule:: mridc.utils.model_utils + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.mridc\_logging module +--------------------------------- + +.. automodule:: mridc.utils.mridc_logging + :members: + :undoc-members: + :show-inheritance: + +mridc.utils.timers module +------------------------- + +.. automodule:: mridc.utils.timers + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mridc.utils + :members: + :undoc-members: + :show-inheritance: diff --git a/mridc/collections/common/metrics/global_average_loss_metric.py b/mridc/collections/common/metrics/global_average_loss_metric.py index c351916b..54b9617c 100644 --- a/mridc/collections/common/metrics/global_average_loss_metric.py +++ b/mridc/collections/common/metrics/global_average_loss_metric.py @@ -12,22 +12,24 @@ class GlobalAverageLossMetric(Metric): """ - This class is for averaging loss across multiple processes if a distributed backend is used. True average is - computed not running average. It does not accumulate gradients so the averaged loss cannot be used for + This class is for averaging loss across multiple processes if a distributed backend is used. The true average is \ + computed, not a running average. It does not accumulate gradients so the averaged loss cannot be used for \ optimization. - If ``take_avg_loss`` is ``True``, the :meth:`update` method ``loss`` argument has to be a mean loss. If - ``take_avg_loss`` is ``False`` then the :meth:`update` method ``loss`` argument has to be a sum of losses. - See :doc:`PyTorch Lightning Metrics` for the metric usage instruction. + + .. note:: + If ``take_avg_loss`` is ``True``, the :meth:`update` method ``loss`` argument has to be a mean loss. If \ + ``take_avg_loss`` is ``False`` then the :meth:`update` method ``loss`` argument has to be a sum of losses. \ + See PyTorch Lightning Metrics for the metric usage instructions. Parameters ---------- - compute_on_step: The method :meth:`forward` only calls ``update()`` and returns ``None`` if this is set to + compute_on_step: The method :meth:`forward` only calls ``update()`` and returns ``None`` if this is set to \ ``False``. 
Default: ``True`` - dist_sync_on_step: Synchronize metric state across processes at each method :meth:`forward` call before returning - the value at the step - process_group: Specify the process group on which synchronization is called. default: ``None`` (which selects the - entire world) - take_avg_loss: If ``True`` values of :meth:`update` method ``loss`` argument has to be a mean loss. If ``False`` + dist_sync_on_step: Synchronize metric state across processes at each method :meth:`forward` call before \ + returning the value at the step + process_group: Specify the process group on which synchronization is called. default: ``None`` (which selects \ + the entire world) + take_avg_loss: If ``True`` values of :meth:`update` method ``loss`` argument has to be a mean loss. If ``False`` \ values of :meth:`update` method ``loss`` argument has to be a sum of losses. default: ``True`` """ @@ -45,9 +47,9 @@ def update(self, loss, num_measurements): Parameters ---------- - loss: A float zero dimensional ``torch.Tensor`` which is either sum or average of losses for processed + loss: A float zero dimensional ``torch.Tensor`` which is either sum or average of losses for processed \ examples. See ``take_avg_loss`` parameter of :meth:`__init__`. - num_measurements: An integer zero dimensional ``torch.Tensor`` which contains a number of loss measurements. + num_measurements: An integer zero dimensional ``torch.Tensor`` which contains a number of loss measurements. \ The sum or mean of the results of these measurements are in the ``loss`` parameter. """ if self.take_avg_loss: diff --git a/mridc/collections/common/parts/utils.py b/mridc/collections/common/parts/utils.py index 55a1d924..c56564e9 100644 --- a/mridc/collections/common/parts/utils.py +++ b/mridc/collections/common/parts/utils.py @@ -200,7 +200,7 @@ def rss_complex(data: torch.Tensor, dim: int = 0) -> torch.Tensor: def sense(data: torch.Tensor, sensitivity_maps: torch.Tensor, dim: int = 0) -> torch.Tensor: """ - The SENSitivity Encoding (SENSE) transform [1]. + The SENSitivity Encoding (SENSE) transform [1]_. References ---------- diff --git a/mridc/collections/reconstruction/data/subsample.py b/mridc/collections/reconstruction/data/subsample.py index 328afa00..52ea84d3 100644 --- a/mridc/collections/reconstruction/data/subsample.py +++ b/mridc/collections/reconstruction/data/subsample.py @@ -46,10 +46,10 @@ def __init__(self, center_fractions: Sequence[float], accelerations: Sequence[in Parameters ---------- - center_fractions: Fraction of low-frequency columns to be retained. If multiple values are provided, then - one of these numbers is chosen uniformly each time. For 2D setting this value corresponds to setting - the Full-Width-Half-Maximum. - accelerations: Amount of under-sampling. This should have the same length as center_fractions. If multiple + center_fractions: Fraction of low-frequency columns to be retained. If multiple values are provided, then \ + one of these numbers is chosen uniformly each time. For 2D setting this value corresponds to setting the \ + Full-Width-Half-Maximum. + accelerations: Amount of under-sampling. This should have the same length as center_fractions. If multiple \ values are provided, then one of these is chosen uniformly each time. """ if len(center_fractions) != len(accelerations): @@ -66,6 +66,19 @@ def __call__( half_scan_percentage: Optional[float] = 0.0, scale: Optional[float] = 0.02, ) -> Tuple[torch.Tensor, int]: + """ + + Parameters + ---------- + shape: Shape of the input tensor. 
+ seed: Seed for the random number generator. + half_scan_percentage: Optional; Defines a fraction of the k-space data that is not sampled. + scale: Optional; Defines the scale of the center of the mask. + + Returns + ------- + A tuple of the mask and the number of low-frequency columns retained. + """ raise NotImplementedError def choose_acceleration(self): @@ -81,22 +94,18 @@ class RandomMaskFunc(MaskFunc): """ RandomMaskFunc creates a sub-sampling mask of a given shape. - The mask selects a subset of columns from the input k-space data. If the - k-space data has N columns, the mask picks out: - 1. N_low_freqs = (N * center_fraction) columns in the center - corresponding to low-frequencies. - 2. The other columns are selected uniformly at random with a - probability equal to: prob = (N / acceleration - N_low_freqs) / - (N - N_low_freqs). This ensures that the expected number of columns - selected is equal to (N / acceleration). - - It is possible to use multiple center_fractions and accelerations, in which - case one possible (center_fraction, acceleration) is chosen uniformly at - random each time the RandomMaskFunc object is called. - - For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], - then there is a 50% probability that 4-fold acceleration with 8% center - fraction is selected and a 50% probability that 8-fold acceleration with 4% + The mask selects a subset of columns from the input k-space data. If the k-space data has N columns, the mask \ + picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center corresponding to low-frequencies. + 2. The other columns are selected uniformly at random with a probability equal to: \ + prob = (N / acceleration - N_low_freqs) / (N - N_low_freqs). This ensures that the expected number of \ + columns selected is equal to (N / acceleration). + + It is possible to use multiple center_fractions and accelerations, in which case one possible (center_fraction, \ + acceleration) is chosen uniformly at random each time the RandomMaskFunc object is called. + + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], then there is a 50% probability that \ + 4-fold acceleration with 8% center fraction is selected and a 50% probability that 8-fold acceleration with 4% \ center fraction is selected. """ @@ -110,10 +119,10 @@ def __call__( """ Parameters ---------- - shape: The shape of the mask to be created. The shape should have at least 3 dimensions. Samples are drawn - along the second last dimension. - seed: Seed for the random number generator. Setting the seed ensures the same mask is generated each time - for the same shape. The random state is reset afterwards. + shape: The shape of the mask to be created. The shape should have at least 3 dimensions. Samples are drawn \ + along the second last dimension. + seed: Seed for the random number generator. Setting the seed ensures the same mask is generated each time \ + for the same shape. The random state is reset afterwards. half_scan_percentage: Optional; Defines a fraction of the k-space data that is not sampled. scale: Optional; Defines the scale of the center of the mask. @@ -147,23 +156,19 @@ class EquispacedMaskFunc(MaskFunc): """ EquispacedMaskFunc creates a sub-sampling mask of a given shape. - The mask selects a subset of columns from the input k-space data. If the - k-space data has N columns, the mask picks out: - 1. N_low_freqs = (N * center_fraction) columns in the center - corresponding to low-frequencies. - 2. 
The other columns are selected with equal spacing at a proportion - that reaches the desired acceleration rate taking into consideration - the number of low frequencies. This ensures that the expected number - of columns selected is equal to (N / acceleration) - - It is possible to use multiple center_fractions and accelerations, in which - case one possible (center_fraction, acceleration) is chosen uniformly at - random each time the EquispacedMaskFunc object is called. - - Note that this function may not give equispaced samples (documented in - https://github.com/facebookresearch/fastMRI/issues/54), which will require - modifications to standard GRAPPA approaches. Nonetheless, this aspect of - the function has been preserved to match the public multicoil data. + The mask selects a subset of columns from the input k-space data. If the k-space data has N columns, the mask \ + picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center corresponding to low-frequencies. + 2. The other columns are selected with equal spacing at a proportion that reaches the desired acceleration \ + rate taking into consideration the number of low frequencies. This ensures that the expected number of \ + columns selected is equal to (N / acceleration) + + It is possible to use multiple center_fractions and accelerations, in which case one possible (center_fraction, \ + acceleration) is chosen uniformly at random each time the EquispacedMaskFunc object is called. + + Note that this function may not give equispaced samples (documented in \ + https://github.com/facebookresearch/fastMRI/issues/54), which will require modifications to standard GRAPPA \ + approaches. Nonetheless, this aspect of the function has been preserved to match the public multicoil data. """ def __call__( @@ -176,9 +181,9 @@ def __call__( """ Parameters ---------- - shape: The shape of the mask to be created. The shape should have at least 3 dimensions. Samples are drawn + shape: The shape of the mask to be created. The shape should have at least 3 dimensions. Samples are drawn \ along the second last dimension. - seed: Seed for the random number generator. Setting the seed ensures the same mask is generated each time for + seed: Seed for the random number generator. Setting the seed ensures the same mask is generated each time for \ the same shape. The random state is reset afterwards. half_scan_percentage: Optional; Defines a fraction of the k-space data that is not sampled. scale: Optional; Defines the scale of the center of the mask. @@ -220,8 +225,8 @@ class Gaussian1DMaskFunc(MaskFunc): """ Creates a 1D sub-sampling mask of a given shape. - For autocalibration purposes, data points near the k-space center will be fully sampled within an ellipse of which - the half-axes will set to the set scale % of the fully sampled region. The remaining points will be sampled + For autocalibration purposes, data points near the k-space center will be fully sampled within an ellipse of \ + which the half-axes will set to the set scale % of the fully sampled region. The remaining points will be sampled \ according to a Gaussian distribution. The center fractions here act as Full-Width at Half-Maximum (FWHM) values. @@ -237,13 +242,13 @@ def __call__( """ Parameters ---------- - shape: The shape of the mask to be created. The shape should have at least 3 dimensions. Samples are drawn - along the second last dimension. - seed: Seed for the random number generator. 
Setting the seed ensures the same mask is generated each time - for the same shape. The random state is reset afterwards. + shape: The shape of the mask to be created. The shape should have at least 3 dimensions. Samples are drawn \ + along the second last dimension. + seed: Seed for the random number generator. Setting the seed ensures the same mask is generated each time \ + for the same shape. The random state is reset afterwards. half_scan_percentage: Optional; Defines a fraction of the k-space data that is not sampled. - scale: For autocalibration purposes, data points near the k-space center will be fully sampled within an - ellipse of which the half-axes will set to the set scale % of the fully sampled region + scale: For autocalibration purposes, data points near the k-space center will be fully sampled within an \ + ellipse of which the half-axes will be set to the given scale % of the fully sampled region Returns ------- @@ -307,8 +312,8 @@ class Gaussian2DMaskFunc(MaskFunc): """ Creates a 2D sub-sampling mask of a given shape. - For autocalibration purposes, data points near the k-space center will be fully sampled within an ellipse of which - the half-axes will set to the set scale % of the fully sampled region. The remaining points will be sampled + For autocalibration purposes, data points near the k-space center will be fully sampled within an ellipse of \ + which the half-axes will be set to the given scale % of the fully sampled region. The remaining points will be sampled \ according to a Gaussian distribution. The center fractions here act as Full-Width at Half-Maximum (FWHM) values. @@ -324,13 +329,13 @@ def __call__( """ Parameters ---------- - shape: The shape of the mask to be created. The shape should have at least 3 dimensions. Samples are drawn - along the second last dimension. - seed: Seed for the random number generator. Setting the seed ensures the same mask is generated each time - for the same shape. The random state is reset afterwards. + shape: The shape of the mask to be created. The shape should have at least 3 dimensions. Samples are drawn \ + along the second last dimension. + seed: Seed for the random number generator. Setting the seed ensures the same mask is generated each time for \ + the same shape. The random state is reset afterwards. half_scan_percentage: Optional; Defines a fraction of the k-space data that is not sampled. - scale: For autocalibration purposes, data points near the k-space center will be fully sampled within an - ellipse of which the half-axes will set to the set scale % of the fully sampled region + scale: For autocalibration purposes, data points near the k-space center will be fully sampled within an \ + ellipse of which the half-axes will be set to the given scale % of the fully sampled region Returns ------- @@ -390,14 +395,14 @@ class Poisson2DMaskFunc(MaskFunc): """ Creates a 2D sub-sampling mask of a given shape. - For autocalibration purposes, data points near the k-space center will be fully sampled within an ellipse of which - the half-axes will set to the set scale % of the fully sampled region. The remaining points will be sampled - according to a (variable density) Poisson distribution. + For autocalibration purposes, data points near the k-space center will be fully sampled within an ellipse of \ + which the half-axes will be set to the given scale % of the fully sampled region. The remaining points will be sampled \ + according to a (variable density) Poisson distribution. 
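All of the mask functions in this file share one recipe: fully sample a small autocalibration region around the k-space center, then fill the remaining sampling budget from a distribution (uniform, equispaced, Gaussian, or variable-density Poisson). Below is a minimal 1D sketch of the random variant, using the prob formula quoted in the RandomMaskFunc docstring; the function name and RNG handling are illustrative assumptions, not the mridc classes:

```python
import numpy as np

def random_1d_mask(num_cols: int, center_fraction: float, acceleration: float, seed: int = 0) -> np.ndarray:
    """Illustrative RandomMaskFunc-style 1D mask; a sketch, not the mridc implementation."""
    rng = np.random.RandomState(seed)  # seeding makes the same shape yield the same mask
    num_low_freqs = int(round(num_cols * center_fraction))

    # Keep the remaining columns with a probability chosen so that the expected
    # number of sampled columns equals num_cols / acceleration:
    # prob = (N / acceleration - N_low_freqs) / (N - N_low_freqs)
    prob = (num_cols / acceleration - num_low_freqs) / (num_cols - num_low_freqs)
    mask = rng.uniform(size=num_cols) < prob

    # Fully sample the low-frequency center for autocalibration.
    pad = (num_cols - num_low_freqs + 1) // 2
    mask[pad : pad + num_low_freqs] = True
    return mask

mask = random_1d_mask(320, center_fraction=0.08, acceleration=4)
print(mask.mean())  # ~0.25 on average, i.e. ~4-fold acceleration
```

The Gaussian and Poisson variants swap the uniform draw for a Gaussian draw (with FWHM set by the center fraction) or a variable-density Poisson draw, and fully sample an ellipse whose half-axes are scale % of the fully sampled region instead of a flat center band.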
- For a given acceleration factor to be accurate, the scale for the fully sampled center should remain at the default - 0.02. A predefined list is used to convert the acceleration factor to the appropriate r parameter needed for the - variable density calculation. This list has been made to accommodate acceleration factors of 4 up to 21, rounding - off to the nearest one available. As such, acceleration factors outside this range cannot be used. + For a given acceleration factor to be accurate, the scale for the fully sampled center should remain at the \ + default 0.02. A predefined list is used to convert the acceleration factor to the appropriate r parameter needed \ + for the variable density calculation. This list has been made to accommodate acceleration factors of 4 up to 21, \ + rounding off to the nearest one available. As such, acceleration factors outside this range cannot be used. """ def __call__( @@ -410,13 +415,13 @@ def __call__( """ Parameters ---------- - shape: The shape of the mask to be created. The shape should have at least 3 dimensions. Samples are drawn - along the second last dimension. - seed: Seed for the random number generator. Setting the seed ensures the same mask is generated each time - for the same shape. The random state is reset afterwards. + shape: The shape of the mask to be created. The shape should have at least 3 dimensions. Samples are drawn \ + along the second last dimension. + seed: Seed for the random number generator. Setting the seed ensures the same mask is generated each time \ + for the same shape. The random state is reset afterwards. half_scan_percentage: Optional; Defines a fraction of the k-space data that is not sampled. - scale: For autocalibration purposes, data points near the k-space center will be fully sampled within an - ellipse of which the half-axes will set to the set scale % of the fully sampled region + scale: For autocalibration purposes, data points near the k-space center will be fully sampled within an \ + ellipse of which the half-axes will be set to the given scale % of the fully sampled region Returns ------- diff --git a/mridc/collections/reconstruction/models/base.py b/mridc/collections/reconstruction/models/base.py index 3212b053..6e6cf263 100644 --- a/mridc/collections/reconstruction/models/base.py +++ b/mridc/collections/reconstruction/models/base.py @@ -10,15 +10,18 @@ import h5py import numpy as np import torch +import wandb from omegaconf import DictConfig from pytorch_lightning import Trainer from torch import nn from torch.utils.data import DataLoader +from torchmetrics.metric import Metric from mridc.collections.common.parts.fft import ifft2c from mridc.collections.common.parts.utils import rss_complex from mridc.collections.reconstruction.data.mri_data import FastMRISliceDataset from mridc.collections.reconstruction.data.subsample import create_mask_for_mask_type +from mridc.collections.reconstruction.metrics.evaluate import mse, nmse, psnr, ssim from mridc.collections.reconstruction.models.unet_base.unet_block import NormUnet from mridc.collections.reconstruction.parts.transforms import MRIDataTransforms from mridc.collections.reconstruction.parts.utils import batched_mask_center @@ -28,6 +31,26 @@ __all__ = ["BaseMRIReconstructionModel", "BaseSensitivityModel"] +class DistributedMetricSum(Metric): + """ + A metric that sums the values of a metric across all workers. 
+ Taken from: https://github.com/facebookresearch/fastMRI/blob/main/fastmri/pl_modules/mri_module.py + """ + + def __init__(self, dist_sync_on_step=True): + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self.add_state("quantity", default=torch.tensor(0.0), dist_reduce_fx="sum") + + def update(self, batch: torch.Tensor): # type: ignore + """Update the metric with a batch of data.""" + self.quantity += batch + + def compute(self): + """Compute the metric value.""" + return self.quantity + + class BaseMRIReconstructionModel(ModelPT, ABC): """Base class of all MRIReconstruction models.""" @@ -36,7 +59,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable self.world_size = 1 if trainer is not None: - self.world_size = trainer.num_nodes * trainer.num_gpus + self.world_size = trainer.num_nodes * trainer.num_devices cfg = convert_model_config_to_dict_config(cfg) cfg = maybe_update_config_version(cfg) @@ -44,6 +67,18 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # init superclass super().__init__(cfg=cfg, trainer=trainer) + self.MSE = DistributedMetricSum() + self.NMSE = DistributedMetricSum() + self.SSIM = DistributedMetricSum() + self.PSNR = DistributedMetricSum() + self.TotExamples = DistributedMetricSum() + + # Set evaluation metrics dictionaries + self.mse_vals: Dict = defaultdict(dict) + self.nmse_vals: Dict = defaultdict(dict) + self.ssim_vals: Dict = defaultdict(dict) + self.psnr_vals: Dict = defaultdict(dict) + # skipcq: PYL-R0201 def process_loss(self, target, pred, _loss_fn): """ @@ -85,7 +120,7 @@ def loss_fn(x, y): return loss_fn(target, pred) @staticmethod - def process_inputs(y, mask): + def process_inputs(y, mask, init_pred): """ Processes the inputs to the method. @@ -95,6 +130,8 @@ def process_inputs(y, mask): list of torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] mask: Sampling mask. list of torch.Tensor, shape [1, 1, n_x, n_y, 1] + init_pred: Initial prediction. + list of torch.Tensor, shape [batch_size, n_x, n_y, 2] Returns ------- @@ -102,6 +139,8 @@ def process_inputs(y, mask): randomly selected y mask: Sampling mask. randomly selected mask + init_pred: Initial prediction. + randomly selected init_pred r: Random index. """ if isinstance(y, list): @@ -110,7 +149,7 @@ def process_inputs(y, mask): mask = mask[r] else: r = 0 - return y, mask, r + return y, mask, init_pred, r def training_step(self, batch: Dict[float, torch.Tensor], batch_idx: int) -> Dict[str, torch.Tensor]: """ @@ -153,7 +192,7 @@ def training_step(self, batch: Dict[float, torch.Tensor], batch_idx: int) -> Dic dict, shape [1] """ y, sensitivity_maps, mask, init_pred, target, _, _, acc = batch - y, mask, r = self.process_inputs(y, mask) + y, mask, init_pred, r = self.process_inputs(y, mask, init_pred) preds = self.forward(y, sensitivity_maps, mask, init_pred, target) if self.accumulate_estimates: @@ -179,9 +218,7 @@ def validation_step(self, batch: Dict[float, torch.Tensor], batch_idx: int) -> D Parameters ---------- - batch: Batch of data. - Dict[str, torch.Tensor], with keys, - + batch: Batch of data. 
Dict[str, torch.Tensor], with keys, 'y': subsampled kspace, torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] 'sensitivity_maps': sensitivity_maps, @@ -214,7 +251,7 @@ def validation_step(self, batch: Dict[float, torch.Tensor], batch_idx: int) -> D dict, shape [1] """ y, sensitivity_maps, mask, init_pred, target, fname, slice_num, _ = batch - y, mask, _ = self.process_inputs(y, mask) + y, mask, init_pred, _ = self.process_inputs(y, mask, init_pred) preds = self.forward(y, sensitivity_maps, mask, init_pred, target) if self.accumulate_estimates: @@ -245,6 +282,17 @@ def validation_step(self, batch: Dict[float, torch.Tensor], batch_idx: int) -> D self.log_image(f"{key}/reconstruction", output) self.log_image(f"{key}/error", error) + target = target.numpy() # type: ignore + output = output.numpy() # type: ignore + self.mse_vals[fname][slice_num] = torch.tensor(mse(target, output)).view(1) + self.nmse_vals[fname][slice_num] = torch.tensor(nmse(target, output)).view(1) + self.ssim_vals[fname][slice_num] = torch.tensor(ssim(target, output, maxval=output.max() - output.min())).view( + 1 + ) + self.psnr_vals[fname][slice_num] = torch.tensor(psnr(target, output, maxval=output.max() - output.min())).view( + 1 + ) + return {"val_loss": val_loss} def test_step(self, batch: Dict[float, torch.Tensor], batch_idx: int) -> Tuple[str, int, torch.Tensor]: @@ -253,9 +301,7 @@ def test_step(self, batch: Dict[float, torch.Tensor], batch_idx: int) -> Tuple[s Parameters ---------- - batch: Batch of data. - Dict[str, torch.Tensor], with keys, - + batch: Batch of data. Dict[str, torch.Tensor], with keys, 'y': subsampled kspace, torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] 'sensitivity_maps': sensitivity_maps, @@ -289,7 +335,7 @@ def test_step(self, batch: Dict[float, torch.Tensor], batch_idx: int) -> Tuple[s torch.Tensor, shape [batch_size, n_x, n_y, 2] """ y, sensitivity_maps, mask, init_pred, target, fname, slice_num, _ = batch - y, mask, _ = self.process_inputs(y, mask) + y, mask, init_pred, _ = self.process_inputs(y, mask, init_pred) preds = self.forward(y, sensitivity_maps, mask, init_pred, target) if self.accumulate_estimates: @@ -334,9 +380,11 @@ def log_image(self, name, image): str image: Image to log. torch.Tensor, shape [batch_size, n_x, n_y, 2] - """ - self.logger.experiment.add_image(name, image, global_step=self.global_step) + if "wandb" in self.logger.__module__.lower(): + self.logger.experiment.log({name: wandb.Image(image.numpy())}) + else: + self.logger.experiment.add_image(name, image, global_step=self.global_step) def validation_epoch_end(self, outputs): """ @@ -354,6 +402,48 @@ def validation_epoch_end(self, outputs): """ self.log("val_loss", torch.stack([x["val_loss"] for x in outputs]).mean()) + # Log metrics. 
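The validation hook above caches one value per (fname, slice_num) pair; the epoch-end code that follows averages those values per volume, sum-reduces across workers, and divides by the total number of volumes. A condensed, single-process sketch of that reduction, with hypothetical file names and values (DistributedMetricSum copied from the definition earlier in this diff):

```python
from collections import defaultdict

import torch
from torchmetrics.metric import Metric

class DistributedMetricSum(Metric):
    """Sum-reducing metric; under DDP, dist_reduce_fx="sum" adds the state across ranks."""
    def __init__(self, dist_sync_on_step=True):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state("quantity", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, batch: torch.Tensor):
        self.quantity += batch

    def compute(self):
        return self.quantity

# Hypothetical cached per-slice SSIMs, keyed by volume filename then slice index.
ssim_vals = defaultdict(dict)
ssim_vals["file1.h5"] = {0: torch.tensor([0.90]), 1: torch.tensor([0.92])}
ssim_vals["file2.h5"] = {0: torch.tensor([0.88])}

SSIM, TotExamples = DistributedMetricSum(), DistributedMetricSum()
ssim_sum, local_examples = torch.tensor(0.0), 0
for fname in ssim_vals:
    # Mean over the slices of each volume, then accumulate the per-volume means.
    ssim_sum = ssim_sum + torch.mean(torch.cat([v.view(-1) for v in ssim_vals[fname].values()]))
    local_examples += 1

SSIM.update(ssim_sum)                                   # DDP would sum across ranks here
TotExamples.update(torch.tensor(float(local_examples)))
print((SSIM.compute() / TotExamples.compute()).item())  # 0.895: dataset-level SSIM
```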
+ # Taken from: https://github.com/facebookresearch/fastMRI/blob/main/fastmri/pl_modules/mri_module.py + mse_vals = defaultdict(dict) + nmse_vals = defaultdict(dict) + ssim_vals = defaultdict(dict) + psnr_vals = defaultdict(dict) + + for k in self.mse_vals.keys(): + mse_vals[k].update(self.mse_vals[k]) + for k in self.nmse_vals.keys(): + nmse_vals[k].update(self.nmse_vals[k]) + for k in self.ssim_vals.keys(): + ssim_vals[k].update(self.ssim_vals[k]) + for k in self.psnr_vals.keys(): + psnr_vals[k].update(self.psnr_vals[k]) + + # apply means across image volumes + metrics = {"MSE": 0, "NMSE": 0, "SSIM": 0, "PSNR": 0} + local_examples = 0 + for fname in mse_vals: + local_examples += 1 + metrics["MSE"] = metrics["MSE"] + torch.mean(torch.cat([v.view(-1) for _, v in mse_vals[fname].items()])) + metrics["NMSE"] = metrics["NMSE"] + torch.mean( + torch.cat([v.view(-1) for _, v in nmse_vals[fname].items()]) + ) + metrics["SSIM"] = metrics["SSIM"] + torch.mean( + torch.cat([v.view(-1) for _, v in ssim_vals[fname].items()]) + ) + metrics["PSNR"] = metrics["PSNR"] + torch.mean( + torch.cat([v.view(-1) for _, v in psnr_vals[fname].items()]) + ) + + # reduce across ddp via sum + metrics["MSE"] = self.MSE(metrics["MSE"]) + metrics["NMSE"] = self.NMSE(metrics["NMSE"]) + metrics["SSIM"] = self.SSIM(metrics["SSIM"]) + metrics["PSNR"] = self.PSNR(metrics["PSNR"]) + + tot_examples = self.TotExamples(torch.tensor(local_examples)) + for metric, value in metrics.items(): + self.log(f"{metric}", value / tot_examples) + def test_epoch_end(self, outputs): """ Called at the end of test epoch to aggregate outputs. @@ -463,7 +553,6 @@ def _setup_dataloader_from_config(cfg: DictConfig) -> DataLoader: if len(accelerations) > 2 else [create_mask_for_mask_type(mask_type, center_fractions, accelerations)] ) - else: mask_func = None # type: ignore mask_center_scale = 0.02 @@ -492,7 +581,7 @@ def _setup_dataloader_from_config(cfg: DictConfig) -> DataLoader: return torch.utils.data.DataLoader( dataset=dataset, - batch_size=1, + batch_size=cfg.get("batch_size"), sampler=sampler, num_workers=cfg.get("num_workers", 2), pin_memory=cfg.get("pin_memory", False), @@ -503,7 +592,6 @@ def _setup_dataloader_from_config(cfg: DictConfig) -> DataLoader: class BaseSensitivityModel(nn.Module, ABC): """ Model for learning sensitivity estimation from k-space data. - This model applies an IFFT to multichannel k-space data and then a U-Net to the coil images to estimate coil sensitivities. """ diff --git a/mridc/collections/reconstruction/models/ccnn.py b/mridc/collections/reconstruction/models/ccnn.py index f05b34ce..886531d3 100644 --- a/mridc/collections/reconstruction/models/ccnn.py +++ b/mridc/collections/reconstruction/models/ccnn.py @@ -22,12 +22,18 @@ class CascadeNet(BaseMRIReconstructionModel, ABC): """ - Implementation of the Deep Cascade of Convolutional Neural Networks, as presented in [1]. + Implementation of the Deep Cascade of Convolutional Neural Networks, as presented in Schlemper, J., \ + Caballero, J., Hajnal, J. V., Price, A., & Rueckert, D. References ---------- - .. [1] Schlemper, J., Caballero, J., Hajnal, J. V., Price, A., & Rueckert, D., A Deep Cascade of Convolutional Neural Networks for MR Image Reconstruction. Information Processing in Medical Imaging (IPMI), 2017. Available at: https://arxiv.org/pdf/1703.00555.pdf + .. + + Schlemper, J., Caballero, J., Hajnal, J. V., Price, A., & Rueckert, D., A Deep Cascade of Convolutional \ + Neural Networks for MR Image Reconstruction. 
Information Processing in Medical Imaging (IPMI), 2017. \ + Available at: https://arxiv.org/pdf/1703.00555.pdf + """ def __init__(self, cfg: DictConfig, trainer: Trainer = None): diff --git a/mridc/collections/reconstruction/models/cirim.py b/mridc/collections/reconstruction/models/cirim.py index e2ad5510..e9e0aa88 100644 --- a/mridc/collections/reconstruction/models/cirim.py +++ b/mridc/collections/reconstruction/models/cirim.py @@ -25,11 +25,18 @@ class CIRIM(BaseMRIReconstructionModel, ABC): """ - Implementation of the Cascades of Independently Recurrent Inference Machines, as presented in [1]. + Implementation of the Cascades of Independently Recurrent Inference Machines, as presented in \ + Karkalousos, D. et al. References ---------- - .. [1] Karkalousos, D. et al. (2021) ‘Assessment of Data Consistency through Cascades of Independently Recurrent Inference Machines for fast and robust accelerated MRI reconstruction’. Available at: https://arxiv.org/abs/2111.15498v1 + + .. + + Karkalousos, D. et al. (2021) ‘Assessment of Data Consistency through Cascades of Independently Recurrent \ + Inference Machines for fast and robust accelerated MRI reconstruction’. Available at: \ + https://arxiv.org/abs/2111.15498v1 + """ def __init__(self, cfg: DictConfig, trainer: Trainer = None): diff --git a/mridc/collections/reconstruction/models/crnn.py b/mridc/collections/reconstruction/models/crnn.py index e7ee6a44..10458f9e 100644 --- a/mridc/collections/reconstruction/models/crnn.py +++ b/mridc/collections/reconstruction/models/crnn.py @@ -24,11 +24,18 @@ class CRNNet(BaseMRIReconstructionModel, ABC): """ - Implementation of the Convolutional Recurrent Neural Network, inspired by [1]. + Implementation of the Convolutional Recurrent Neural Network, inspired by C. Qin, J. Schlemper, J. Caballero, \ + A. N. Price, J. V. Hajnal and D. Rueckert. References ---------- - .. [1] C. Qin, J. Schlemper, J. Caballero, A. N. Price, J. V. Hajnal and D. Rueckert, "Convolutional Recurrent Neural Networks for Dynamic MR Image Reconstruction," in IEEE Transactions on Medical Imaging, vol. 38, no. 1, pp. 280-290, Jan. 2019, doi: 10.1109/TMI.2018.2863670. + + .. + + C. Qin, J. Schlemper, J. Caballero, A. N. Price, J. V. Hajnal and D. Rueckert, "Convolutional Recurrent \ + Neural Networks for Dynamic MR Image Reconstruction," in IEEE Transactions on Medical Imaging, vol. 38, \ + no. 1, pp. 280-290, Jan. 2019, doi: 10.1109/TMI.2018.2863670. + """ def __init__(self, cfg: DictConfig, trainer: Trainer = None): @@ -108,7 +115,7 @@ def forward( """ sensitivity_maps = self.sens_net(y, mask) if self.use_sens_net else sensitivity_maps pred = self.crnn(y, sensitivity_maps, mask) - yield [self.process_intermediate_eta(x, sensitivity_maps, target) for x in pred] + yield [self.process_intermediate_pred(x, sensitivity_maps, target) for x in pred] def process_intermediate_pred(self, pred, sensitivity_maps, target): """ diff --git a/mridc/collections/reconstruction/models/didn/didn.py b/mridc/collections/reconstruction/models/didn/didn.py index 1f733330..aab38130 100644 --- a/mridc/collections/reconstruction/models/didn/didn.py +++ b/mridc/collections/reconstruction/models/didn/didn.py @@ -11,11 +11,17 @@ class Subpixel(nn.Module): """ - Subpixel convolution layer for up-scaling of low resolution features at super-resolution as implemented in [1]. + Subpixel convolution layer for up-scaling of low resolution features at super-resolution as implemented in \ + Yu, Songhyun, et al. References ---------- - .. [1] Yu, Songhyun, et al. 
“Deep Iterative Down-Up CNN for Image Denoising.” 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW), 2019, pp. 2095–103. IEEE Xplore, https://doi.org/10.1109/CVPRW.2019.00262. + + .. + Yu, Songhyun, et al. “Deep Iterative Down-Up CNN for Image Denoising.” 2019 IEEE/CVF Conference on Computer \ + Vision and Pattern Recognition Workshops (CVPRW), 2019, pp. 2095–103. IEEE Xplore, \ + https://doi.org/10.1109/CVPRW.2019.00262. + """ def __init__(self, in_channels, out_channels, upscale_factor, kernel_size, padding=0): @@ -43,11 +49,16 @@ def forward(self, x): class ReconBlock(nn.Module): """ - Reconstruction Block of DIDN model as implemented in [1]. + Reconstruction Block of DIDN model as implemented in Yu, Songhyun, et al. References ---------- - .. [1] Yu, Songhyun, et al. “Deep Iterative Down-Up CNN for Image Denoising.” 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW), 2019, pp. 2095–103. IEEE Xplore, https://doi.org/10.1109/CVPRW.2019.00262. + + .. + Yu, Songhyun, et al. “Deep Iterative Down-Up CNN for Image Denoising.” 2019 IEEE/CVF Conference on Computer \ + Vision and Pattern Recognition Workshops (CVPRW), 2019, pp. 2095–103. IEEE Xplore, \ + https://doi.org/10.1109/CVPRW.2019.00262. + """ def __init__(self, in_channels, num_convs): @@ -91,11 +102,16 @@ def forward(self, input_data): class DUB(nn.Module): """ - Down-up block (DUB) for DIDN model as implemented in [1]. + Down-up block (DUB) for DIDN model as implemented in Yu, Songhyun, et al. References ---------- - .. [1] Yu, Songhyun, et al. “Deep Iterative Down-Up CNN for Image Denoising.” 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW), 2019, pp. 2095–103. IEEE Xplore, https://doi.org/10.1109/CVPRW.2019.00262. + + .. + Yu, Songhyun, et al. “Deep Iterative Down-Up CNN for Image Denoising.” 2019 IEEE/CVF Conference on Computer \ + Vision and Pattern Recognition Workshops (CVPRW), 2019, pp. 2095–103. IEEE Xplore, \ + https://doi.org/10.1109/CVPRW.2019.00262. + """ def __init__( @@ -221,11 +237,16 @@ def forward(self, x): class DIDN(nn.Module): """ - Deep Iterative Down-up convolutional Neural network (DIDN) implementation as in [1]. + Deep Iterative Down-up convolutional Neural network (DIDN) implementation as in Yu, Songhyun, et al. References ---------- - .. [1] Yu, Songhyun, et al. “Deep Iterative Down-Up CNN for Image Denoising.” 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW), 2019, pp. 2095–103. IEEE Xplore, https://doi.org/10.1109/CVPRW.2019.00262. + + .. + Yu, Songhyun, et al. “Deep Iterative Down-Up CNN for Image Denoising.” 2019 IEEE/CVF Conference on Computer \ + Vision and Pattern Recognition Workshops (CVPRW), 2019, pp. 2095–103. IEEE Xplore, \ + https://doi.org/10.1109/CVPRW.2019.00262. + """ def __init__( diff --git a/mridc/collections/reconstruction/models/dunet.py b/mridc/collections/reconstruction/models/dunet.py index 35ceb008..ab83eece 100644 --- a/mridc/collections/reconstruction/models/dunet.py +++ b/mridc/collections/reconstruction/models/dunet.py @@ -29,11 +29,17 @@ class DUNet(BaseMRIReconstructionModel, ABC): """ - Implementation of the Down-Up NET, inspired by [1]. + Implementation of the Down-Up NET, inspired by Hammernik, K, Schlemper, J, Qin, C, et al. References ---------- - .. [1] Hammernik, K, Schlemper, J, Qin, C, et al. Systematic evaluation of iterative deep neural networks for fast parallel MRI reconstruction with sensitivity-weighted coil combination. 
Magn Reson Med. 2021; 86: 1859– 1872. https://doi.org/10.1002/mrm.28827 + + .. + + Hammernik, K, Schlemper, J, Qin, C, et al. Systematic evaluation of iterative deep neural networks for fast \ + parallel MRI reconstruction with sensitivity-weighted coil combination. Magn Reson Med. 2021; 86: 1859–1872. \ + https://doi.org/10.1002/mrm.28827 + """ def __init__(self, cfg: DictConfig, trainer: Trainer = None): diff --git a/mridc/collections/reconstruction/models/jointicnet.py b/mridc/collections/reconstruction/models/jointicnet.py index 362bae0c..e218c29f 100644 --- a/mridc/collections/reconstruction/models/jointicnet.py +++ b/mridc/collections/reconstruction/models/jointicnet.py @@ -21,12 +21,18 @@ class JointICNet(BaseMRIReconstructionModel, ABC): """ - Implementation of the Joint Deep Model-Based MR Image and Coil Sensitivity Reconstruction Network (Joint-ICNet), as - presented in [1]. + Implementation of the Joint Deep Model-Based MR Image and Coil Sensitivity Reconstruction Network (Joint-ICNet), \ + as presented in Jun, Yohan, et al. References ---------- - .. [1] Jun, Yohan, et al. “Joint Deep Model-Based MR Image and Coil Sensitivity Reconstruction Network (Joint-ICNet) for Fast MRI.” 2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), IEEE, 2021, pp. 5266–75. DOI.org (Crossref), https://doi.org/10.1109/CVPR46437.2021.00523. + + .. + + Jun, Yohan, et al. “Joint Deep Model-Based MR Image and Coil Sensitivity Reconstruction Network (Joint-ICNet) \ + for Fast MRI.” 2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), IEEE, 2021, pp. \ + 5266–75. DOI.org (Crossref), https://doi.org/10.1109/CVPR46437.2021.00523. + """ def __init__(self, cfg: DictConfig, trainer: Trainer = None): diff --git a/mridc/collections/reconstruction/models/kikinet.py b/mridc/collections/reconstruction/models/kikinet.py index c0a50a4b..63cead8d 100644 --- a/mridc/collections/reconstruction/models/kikinet.py +++ b/mridc/collections/reconstruction/models/kikinet.py @@ -25,11 +25,18 @@ class KIKINet(BaseMRIReconstructionModel, ABC): """ - Based on KIKINet implementation [1]. Modified to work with multi-coil k-space data. + Based on the KIKINet implementation, modified to work with multi-coil k-space data, as presented in Eo, Taejoon, \ + et al. References ---------- - .. [1] Eo, Taejoon, et al. “KIKI-Net: Cross-Domain Convolutional Neural Networks for Reconstructing Undersampled Magnetic Resonance Images.” Magnetic Resonance in Medicine, vol. 80, no. 5, Nov. 2018, pp. 2188–201. PubMed, https://doi.org/10.1002/mrm.27201. + + .. + + Eo, Taejoon, et al. “KIKI-Net: Cross-Domain Convolutional Neural Networks for Reconstructing Undersampled \ + Magnetic Resonance Images.” Magnetic Resonance in Medicine, vol. 80, no. 5, Nov. 2018, pp. 2188–201. PubMed, \ + https://doi.org/10.1002/mrm.27201. + """ def __init__(self, cfg: DictConfig, trainer: Trainer = None): diff --git a/mridc/collections/reconstruction/models/lpd.py b/mridc/collections/reconstruction/models/lpd.py index c4c624d2..1a354abd 100644 --- a/mridc/collections/reconstruction/models/lpd.py +++ b/mridc/collections/reconstruction/models/lpd.py @@ -25,11 +25,16 @@ class LPDNet(BaseMRIReconstructionModel, ABC): """ - Implementation of the Learned Primal Dual network, inspired by [1]. + Implementation of the Learned Primal Dual network, inspired by Adler, Jonas, and Ozan Öktem. References ---------- - .. [1] Adler, Jonas, and Ozan Öktem. “Learned Primal-Dual Reconstruction.” IEEE Transactions on Medical Imaging, vol. 37, no. 
6, June 2018, pp. 1322–32. arXiv.org, https://doi.org/10.1109/TMI.2018.2799231. + + .. + + Adler, Jonas, and Ozan Öktem. “Learned Primal-Dual Reconstruction.” IEEE Transactions on Medical Imaging, \ + vol. 37, no. 6, June 2018, pp. 1322–32. arXiv.org, https://doi.org/10.1109/TMI.2018.2799231. + """ def __init__(self, cfg: DictConfig, trainer: Trainer = None): diff --git a/mridc/collections/reconstruction/models/multidomain/multidomain.py b/mridc/collections/reconstruction/models/multidomain/multidomain.py index 5e07e4e7..bbdaab9b 100644 --- a/mridc/collections/reconstruction/models/multidomain/multidomain.py +++ b/mridc/collections/reconstruction/models/multidomain/multidomain.py @@ -162,13 +162,18 @@ def __repr__(self): class StandardizationLayer(nn.Module): """ Multi-channel data standardization method. Inspired by AIRS model submission to the Fast MRI 2020 challenge. - Given individual coil images :math:`\{x_i\}_{i=1}^{N_c}` and sensitivity coil maps :math:`\{S_i\}_{i=1}^{N_c}` - + Given individual coil images :math:`\{x_i\}_{i=1}^{N_c}` and sensitivity coil maps :math:`\{S_i\}_{i=1}^{N_c}` \ it returns + .. math:: - [(x_{\text{sense}}, {x_{\text{res}}}_1), ..., (x_{\text{sense}}, {x_{\text{res}}}_{N_c})] - where :math:`{x_{\text{res}}}_i = xi - S_i \times x_{\text{sense}}` and - :math:`x_{\text{sense}} = \sum_{i=1}^{N_c} {S_i}^{*} \times x_i`. + + [(x_{sense}, {x_{res}}_1), ..., (x_{sense}, {x_{res}}_{N_c})] + + where + + :math:`{x_{res}}_i = x_i - S_i \times x_{sense}` and + + :math:`x_{sense} = \sum_{i=1}^{N_c} {S_i}^{*} \times x_i`. """ def __init__(self, coil_dim=1, channel_dim=-1): diff --git a/mridc/collections/reconstruction/models/mwcnn/mwcnn.py b/mridc/collections/reconstruction/models/mwcnn/mwcnn.py index e3705bc8..37e6f7bb 100644 --- a/mridc/collections/reconstruction/models/mwcnn/mwcnn.py +++ b/mridc/collections/reconstruction/models/mwcnn/mwcnn.py @@ -14,11 +14,16 @@ class DWT(nn.Module): """ - 2D Discrete Wavelet Transform as implemented in [1]. + 2D Discrete Wavelet Transform as implemented in Liu, Pengju, et al. References ---------- - .. [1] Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071. + + .. + + Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. \ + arXiv.org, http://arxiv.org/abs/1805.07071. + """ def __init__(self): @@ -55,11 +60,16 @@ class IWT(nn.Module): """ - 2D Inverse Wavelet Transform as implemented in [1]. + 2D Inverse Wavelet Transform as implemented in Liu, Pengju, et al. References ---------- - .. [1] Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071. + + .. + + Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. \ + arXiv.org, http://arxiv.org/abs/1805.07071. + """ def __init__(self): @@ -100,11 +110,16 @@ class ConvBlock(nn.Module): """ - Convolution Block for MWCNN as implemented in [1]. + Convolution Block for MWCNN as implemented in Liu, Pengju, et al. References ---------- - .. [1] Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071. + + .. + + Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. 
\ + arXiv.org, http://arxiv.org/abs/1805.07071. + """ def __init__( @@ -173,11 +188,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class DilatedConvBlock(nn.Module): """ - Double dilated Convolution Block fpr MWCNN as implemented in [1]. + Double dilated Convolution Block for MWCNN as implemented in Liu, Pengju, et al. References ---------- - .. [1] Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071. + + .. + + Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. \ + arXiv.org, http://arxiv.org/abs/1805.07071. + """ def __init__( @@ -264,11 +284,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class MWCNN(nn.Module): """ - Multi-level Wavelet CNN (MWCNN) implementation as implemented in [1]. + Multi-level Wavelet CNN (MWCNN) as implemented in Liu, Pengju, et al. References ---------- - .. [1] Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071. + + .. + + Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. \ + arXiv.org, http://arxiv.org/abs/1805.07071. + """ def __init__( diff --git a/mridc/collections/reconstruction/models/pics.py b/mridc/collections/reconstruction/models/pics.py index 4a74981a..88d77a52 100644 --- a/mridc/collections/reconstruction/models/pics.py +++ b/mridc/collections/reconstruction/models/pics.py @@ -19,12 +19,15 @@ class PICS(BaseMRIReconstructionModel, ABC): """ - Parallel-Imaging Compressed Sensing (PICS) reconstruction using the BART [1]. + Parallel-Imaging Compressed Sensing (PICS) reconstruction using BART, by Uecker, M. et al. References ---------- - .. [1] Uecker, M. et al. (2015) ‘Berkeley Advanced Reconstruction Toolbox’, Proc. Intl. Soc. Mag. Reson. Med., 23. + .. + + Uecker, M. et al. (2015) ‘Berkeley Advanced Reconstruction Toolbox’, Proc. Intl. Soc. Mag. Reson. Med., 23. + """ def __init__(self, cfg: DictConfig, trainer: Trainer = None): diff --git a/mridc/collections/reconstruction/models/recurrentvarnet/recurentvarnet.py b/mridc/collections/reconstruction/models/recurrentvarnet/recurentvarnet.py index e3a0fdfa..8bc16596 100644 --- a/mridc/collections/reconstruction/models/recurrentvarnet/recurentvarnet.py +++ b/mridc/collections/reconstruction/models/recurrentvarnet/recurentvarnet.py @@ -18,13 +18,19 @@ class RecurrentInit(nn.Module): """ - Recurrent State Initializer (RSI) module of Recurrent Variational Network as presented in [1]. + Recurrent State Initializer (RSI) module of the Recurrent Variational Network as presented in Yiasemis, George, et al. The RSI module learns to initialize the recurrent hidden state :math:`h_0`, input of the first RecurrentVarNetBlock of the RecurrentVarNet. References ---------- - .. [1] Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org, http://arxiv.org/abs/2111.09639. + + .. + + Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to \ + the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org, \ + http://arxiv.org/abs/2111.09639. 
+ """ def __init__( @@ -102,11 +108,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class RecurrentVarNetBlock(nn.Module): """ - Recurrent Variational Network Block :math:`\mathcal{H}_{\theta_{t}}` as presented in [1]. + Recurrent Variational Network Block :math:`\mathcal{H}_{\theta_{t}}` as presented in Yiasemis, George, et al. + References ---------- - .. [1] Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org, http://arxiv.org/abs/2111.09639. + + .. + + Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to \ + the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org, \ + http://arxiv.org/abs/2111.09639. + """ def __init__( diff --git a/mridc/collections/reconstruction/models/rim/rnn_cells.py b/mridc/collections/reconstruction/models/rim/rnn_cells.py index 877dee9a..12394ef2 100644 --- a/mridc/collections/reconstruction/models/rim/rnn_cells.py +++ b/mridc/collections/reconstruction/models/rim/rnn_cells.py @@ -253,7 +253,7 @@ def forward(self, _input, hx): class IndRNNCellBase(nn.Module): """ - Base class for Independently RNN cells as presented in [1]. + Base class for Independently RNN cells as presented in [1]_. References ---------- diff --git a/mridc/collections/reconstruction/models/rvn.py b/mridc/collections/reconstruction/models/rvn.py index b81bbd47..f5e8ffa5 100644 --- a/mridc/collections/reconstruction/models/rvn.py +++ b/mridc/collections/reconstruction/models/rvn.py @@ -24,11 +24,17 @@ class RecurrentVarNet(BaseMRIReconstructionModel, ABC): """ - Implementation of the Recurrent Variational Network implementation, as presented in [1]. + Implementation of the Recurrent Variational Network implementation, as presented in Yiasemis, George, et al. References ---------- - .. [1] Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org, http://arxiv.org/abs/2111.09639. + + .. + + Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to \ + the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org, \ + http://arxiv.org/abs/2111.09639. + """ def __init__(self, cfg: DictConfig, trainer: Trainer = None): diff --git a/mridc/collections/reconstruction/models/unet.py b/mridc/collections/reconstruction/models/unet.py index 1f3631ba..63eff762 100644 --- a/mridc/collections/reconstruction/models/unet.py +++ b/mridc/collections/reconstruction/models/unet.py @@ -21,12 +21,16 @@ class UNet(BaseMRIReconstructionModel, ABC): """ - Implementation of the UNet, as presented in [1]. + Implementation of the UNet, as presented in O. Ronneberger, P. Fischer, and Thomas Brox. References ---------- + .. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks for biomedical image segmentation. \ + In International Conference on Medical image computing and computer-assisted intervention, pages 234–241. \ + Springer, 2015. - .. [1] O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks for biomedical image segmentation. In International Conference on Medical image computing and computer-assisted intervention, pages 234–241. Springer, 2015. 
""" def __init__(self, cfg: DictConfig, trainer: Trainer = None): diff --git a/mridc/collections/reconstruction/models/unet_base/unet_block.py b/mridc/collections/reconstruction/models/unet_base/unet_block.py index 67694596..43fa486f 100644 --- a/mridc/collections/reconstruction/models/unet_base/unet_block.py +++ b/mridc/collections/reconstruction/models/unet_base/unet_block.py @@ -138,7 +138,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Unet(torch.nn.Module): """ - PyTorch implementation of a U-Net model, as presented in [1]. + PyTorch implementation of a U-Net model, as presented in [1]_. References ---------- diff --git a/mridc/collections/reconstruction/models/variablesplittingnet/vsnet_block.py b/mridc/collections/reconstruction/models/variablesplittingnet/vsnet_block.py index 0563cb18..6f375493 100644 --- a/mridc/collections/reconstruction/models/variablesplittingnet/vsnet_block.py +++ b/mridc/collections/reconstruction/models/variablesplittingnet/vsnet_block.py @@ -38,7 +38,7 @@ def forward(self, x, Sx): class VSNetBlock(torch.nn.Module): """ - Model block for the Variable-Splitting Network inspired by [1]. + Model block for the Variable-Splitting Network inspired by [1]_. References ---------- diff --git a/mridc/collections/reconstruction/models/vn.py b/mridc/collections/reconstruction/models/vn.py index 9c652a97..b89a8cdc 100644 --- a/mridc/collections/reconstruction/models/vn.py +++ b/mridc/collections/reconstruction/models/vn.py @@ -22,12 +22,16 @@ class VarNet(BaseMRIReconstructionModel, ABC): """ - Implementation of the End-to-end Variational Network (VN), as presented in [1]. + Implementation of the End-to-end Variational Network (VN), as presented in Sriram, A. et al. References ---------- - .. [1] Sriram, A. et al. (2020) ‘End-to-End Variational Networks for Accelerated MRI Reconstruction’. Available at: https://github.com/facebookresearch/fastMRI. + .. + + Sriram, A. et al. (2020) ‘End-to-End Variational Networks for Accelerated MRI Reconstruction’. Available \ + at: https://github.com/facebookresearch/fastMRI. + """ def __init__(self, cfg: DictConfig, trainer: Trainer = None): diff --git a/mridc/collections/reconstruction/models/vsnet.py b/mridc/collections/reconstruction/models/vsnet.py index 870d92b1..8059ea55 100644 --- a/mridc/collections/reconstruction/models/vsnet.py +++ b/mridc/collections/reconstruction/models/vsnet.py @@ -28,11 +28,17 @@ class VSNet(BaseMRIReconstructionModel, ABC): """ - Implementation of the Variable-Splitting Net, as presented in [1]. + Implementation of the Variable-Splitting Net, as presented in Duan, J. et al. References ---------- - .. [1] Duan, J. et al. (2019) ‘Vs-net: Variable splitting network for accelerated parallel MRI reconstruction’, Lecture Notes in Computer Science (including subseries Lecture Notes in Artificial Intelligence and Lecture Notes in Bioinformatics), 11767 LNCS, pp. 713–722. doi: 10.1007/978-3-030-32251-9_78. + + .. + + Duan, J. et al. (2019) ‘Vs-net: Variable splitting network for accelerated parallel MRI reconstruction’, \ + Lecture Notes in Computer Science (including subseries Lecture Notes in Artificial Intelligence and Lecture \ + Notes in Bioinformatics), 11767 LNCS, pp. 713–722. doi: 10.1007/978-3-030-32251-9_78. 
+ """ def __init__(self, cfg: DictConfig, trainer: Trainer = None): diff --git a/mridc/collections/reconstruction/models/xpdnet.py b/mridc/collections/reconstruction/models/xpdnet.py index fed6d426..f9da07a5 100644 --- a/mridc/collections/reconstruction/models/xpdnet.py +++ b/mridc/collections/reconstruction/models/xpdnet.py @@ -24,11 +24,16 @@ class XPDNet(BaseMRIReconstructionModel, ABC): """ - Implementation of the XPDNet, as presented in [1]. + Implementation of the XPDNet, as presented in Ramzi, Zaccharie, et al. References ---------- - .. [1] Ramzi, Zaccharie, et al. “XPDNet for MRI Reconstruction: An Application to the 2020 FastMRI Challenge.” ArXiv:2010.07290 [Physics, Stat], July 2021. arXiv.org, http://arxiv.org/abs/2010.07290. + + .. + + Ramzi, Zaccharie, et al. “XPDNet for MRI Reconstruction: An Application to the 2020 FastMRI Challenge. \ + ” ArXiv:2010.07290 [Physics, Stat], July 2021. arXiv.org, http://arxiv.org/abs/2010.07290. + """ def __init__(self, cfg: DictConfig, trainer: Trainer = None): diff --git a/mridc/collections/reconstruction/models/zf.py b/mridc/collections/reconstruction/models/zf.py index bc38aa45..cf1f5439 100644 --- a/mridc/collections/reconstruction/models/zf.py +++ b/mridc/collections/reconstruction/models/zf.py @@ -20,12 +20,17 @@ class ZF(BaseMRIReconstructionModel, ABC): """ - Zero-Filled reconstruction using either root-sum-of-squares (RSS) or SENSE (SENSitivity Encoding) [1]. + Zero-Filled reconstruction using either root-sum-of-squares (RSS) or SENSE (SENSitivity Encoding), as presented \ + in Pruessmann KP, Weiger M, Scheidegger MB, Boesiger P. References ---------- - .. [1] Pruessmann KP, Weiger M, Scheidegger MB, Boesiger P. SENSE: Sensitivity encoding for fast MRI. Magn Reson Med 1999; 42:952-962. + .. + + Pruessmann KP, Weiger M, Scheidegger MB, Boesiger P. SENSE: Sensitivity encoding for fast MRI. Magn Reson \ + Med 1999; 42:952-962. + """ def __init__(self, cfg: DictConfig, trainer: Trainer = None): diff --git a/mridc/core/classes/dataset.py b/mridc/core/classes/dataset.py index 61c21d62..0bcc34e6 100644 --- a/mridc/core/classes/dataset.py +++ b/mridc/core/classes/dataset.py @@ -31,12 +31,16 @@ def collate_fn(self, batch): The method optionally performs neural type checking and add types to the outputs. Please note, subclasses of Dataset should not implement `input_types`. + # Usage: - dataloader = torch.utils.data.DataLoader( - ...., - collate_fn=dataset.collate_fn, - .... - ) + + .. code-block:: + + dataloader = torch.utils.data.DataLoader( + ...., + collate_fn=dataset.collate_fn, + .... + ) Returns ------- @@ -70,11 +74,14 @@ def collate_fn(self, batch): The method optionally performs neural type checking and add types to the outputs. # Usage: - dataloader = torch.utils.data.DataLoader( - ...., - collate_fn=dataset.collate_fn, - .... - ) + + .. code-block:: + + dataloader = torch.utils.data.DataLoader( + ...., + collate_fn=dataset.collate_fn, + .... 
+ ) Returns ------- diff --git a/mridc/core/classes/export.py b/mridc/core/classes/export.py index 86584752..e4bc6706 100644 --- a/mridc/core/classes/export.py +++ b/mridc/core/classes/export.py @@ -3,7 +3,6 @@ # Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/classes/exportable.py -import os from abc import ABC import torch diff --git a/mridc/core/classes/modelPT.py b/mridc/core/classes/modelPT.py index 4df8ea1f..841495b5 100644 --- a/mridc/core/classes/modelPT.py +++ b/mridc/core/classes/modelPT.py @@ -55,12 +55,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): Parameters ---------- - cfg: configuration object. - The cfg object should have (optionally) the following sub-configs: - * train_ds - to instantiate training dataset - * validation_ds - to instantiate validation dataset - * test_ds - to instantiate testing dataset - * optim - to instantiate optimizer with learning rate scheduler + cfg: configuration object. The cfg object should have (optionally) the following sub-configs: + - train_ds - to instantiate training dataset + - validation_ds - to instantiate validation dataset + - test_ds - to instantiate testing dataset + - optim - to instantiate optimizer with learning rate scheduler trainer: Pytorch Lightning Trainer instance """ if trainer is not None and not isinstance(trainer, Trainer): @@ -152,16 +151,19 @@ def __init_subclass__(cls) -> None: cls._save_restore_connector = SaveRestoreConnector() def register_artifact(self, config_path: str, src: str, verify_src_exists: bool = True): - """Register model artifacts with this function. These artifacts (files) will be included inside .mridc file - when model.save_to("model.mridc") is called. + """ + Register model artifacts with this function. These artifacts (files) will be included inside .mridc file when + model.save_to("model.mridc") is called. + How it works: - 1. It always returns existing absolute path which can be used during Model constructor call - EXCEPTION: src is None or "" in which case nothing will be done and src will be returned - 2. It will add (config_path, model_utils.ArtifactItem()) pair to self.artifacts + 1. It always returns existing absolute path which can be used during Model constructor call EXCEPTION: \ + src is None or "" in which case nothing will be done and src will be returned + 2. It will add (config_path, model_utils.ArtifactItem()) pair to self.artifacts + If "src" is local existing path, then it will be returned in absolute path form. - elif "src" starts with "mridc_file:unique_artifact_name": - .mridc will be untarred to a temporary folder location and an actual existing path will be returned - else an error will be raised. + elif "src" starts with "mridc_file:unique_artifact_name" .mridc will be untarred to a temporary folder \ + location and an actual existing path will be returned else an error will be raised. + WARNING: use .register_artifact calls in your models' constructors. The returned path is not guaranteed to exist after you have exited your model's constructor. @@ -169,13 +171,13 @@ def register_artifact(self, config_path: str, src: str, verify_src_exists: bool ---------- config_path: Artifact key. Usually corresponds to the model config. src: Path to artifact. - verify_src_exists: If set to False, then the artifact is optional and register_artifact will return None even - if src is not found. Defaults to True. 
+ verify_src_exists: If set to False, then the artifact is optional and register_artifact will return None \ + even if src is not found. Defaults to True. Returns ------- - If src is not None or empty it always returns absolute path which is guaranteed to exist during model instance - life. + If src is not None or empty it always returns absolute path which is guaranteed to exist during model \ + instance life. """ if src is None or not src: return src @@ -198,13 +200,13 @@ def save_to(self, save_path: str): """ Saves model instance (weights and configuration) into .mridc file. You can use "restore_from" method to fully restore instance from .mridc file. .mridc file is an archive (tar.gz) with the following: - model_config.yaml - model configuration in .yaml format. You can deserialize this into cfg argument for - model's constructor - model_wights.ckpt - model checkpoint + - model_config.yaml - model configuration in .yaml format. You can deserialize this into cfg argument for \ + model's constructor + - model_weights.ckpt - model checkpoint Parameters ---------- - save_path: Path to .mridc file where model instance should be saved. + save_path: Path to the .mridc file where the model instance should be saved. """ def maybe_make_save_dir(_path: "Path"): @@ -244,24 +246,27 @@ def restore_from( # type: ignore Parameters ---------- - restore_path: path to .mridc file from which model should be instantiated - override_config_path: path to a yaml config that will override the internal config file or an - OmegaConf/DictConfig object representing the model config. - map_location: Optional torch.device() to map the instantiated model to a device. By default (None), it will + restore_path: path to .mridc file from which model should be instantiated. + override_config_path: path to a yaml config that will override the internal config file or an \ + OmegaConf/DictConfig object representing the model config. + map_location: Optional torch.device() to map the instantiated model to a device. By default (None), it will \ select a GPU if available, falling back to CPU otherwise. strict: Passed to load_state_dict. By default, True. - return_config: If set to true, will return just the underlying config of the restored model as an + return_config: If set to true, will return just the underlying config of the restored model as an \ OmegaConf/DictConfig object without instantiating the model. - trainer: Optional, a pytorch lightning Trainer object that will be forwarded to the instantiated model's + trainer: Optional, a pytorch lightning Trainer object that will be forwarded to the instantiated model's \ constructor. save_restore_connector: Can be overridden to add custom save and restore logic. Example ------- - ``` + + .. code-block:: + model = mridc.collections.asr.models.EncDecCTCModel.restore_from('asr.mridc') assert isinstance(model, mridc.collections.asr.models.EncDecCTCModel) - ``` + + Returns ------- An instance of type cls or its underlying config (if return_config is set). @@ -376,16 +381,15 @@ def setup_optimization(self, optim_config: Optional[Union[DictConfig, Dict]] = N Parameters ---------- optim_config: A dictionary containing the following keys: - * "lr": mandatory key for learning rate. Will raise ValueError if not provided. - * "optimizer": string name pointing to one of the available optimizers in the registry. \ - If not provided, defaults to "adam". - * "opt_args": Optional list of strings, in the format "arg_name=arg_value".
\ - The list of "arg_value" will be parsed and a dictionary of optimizer kwargs \ - will be built and supplied to instantiate the optimizer. + - lr: mandatory key for learning rate. Will raise ValueError if not provided. + - optimizer: string name pointing to one of the available optimizers in the registry. If not provided, \ + defaults to "adam". + - opt_args: Optional list of strings, in the format "arg_name=arg_value". The list of "arg_value" will \ + be parsed and a dictionary of optimizer kwargs will be built and supplied to instantiate the optimizer. Returns ------- - optimizer: An instance of an optimizer. + An instance of an optimizer. """ if self._optimizer_param_groups is None: self.setup_optimizer_param_groups() @@ -428,17 +432,17 @@ def setup_optimization(self, optim_config: Optional[Union[DictConfig, Dict]] = N optim_config["sched"]["t_accumulate_grad_batches"] = self._trainer.accumulate_grad_batches optim_config["sched"]["t_limit_train_batches"] = self._trainer.limit_train_batches if self._trainer.accelerator is None: - optim_config["sched"]["t_num_workers"] = self._trainer.num_gpus or 1 + optim_config["sched"]["t_num_workers"] = self._trainer.num_devices or 1 elif self._trainer.accelerator == "ddp_cpu": - optim_config["sched"]["t_num_workers"] = self._trainer.num_processes * self._trainer.num_nodes + optim_config["sched"]["t_num_workers"] = self._trainer.num_devices * self._trainer.num_nodes elif self._trainer.accelerator == "ddp": - optim_config["sched"]["t_num_workers"] = self._trainer.num_gpus * self._trainer.num_nodes + optim_config["sched"]["t_num_workers"] = self._trainer.num_devices * self._trainer.num_nodes else: logging.warning( f"The lightning trainer received accelerator: {self._trainer.accelerator}. We " "recommend to use 'ddp' instead." ) - optim_config["sched"]["t_num_workers"] = self._trainer.num_gpus * self._trainer.num_nodes + optim_config["sched"]["t_num_workers"] = self._trainer.num_devices * self._trainer.num_nodes else: optim_config["sched"]["max_steps"] = self._trainer.max_steps @@ -539,15 +543,18 @@ def setup_optimization(self, optim_config: Optional[Union[DictConfig, Dict]] = N def setup_optimizer_param_groups(self): """ - Used to create param groups for the optimizer. - As an example, this can be used to specify per-layer learning rates: - optim.SGD([ - {'params': model.base.parameters()}, - {'params': model.classifier.parameters(), 'lr': 1e-3} - ], lr=1e-2, momentum=0.9) - See https://pytorch.org/docs/stable/optim.html for more information. - By default, ModelPT will use self.parameters(). - Override this method to add custom param groups. + Used to create param groups for the optimizer. As an example, this can be used to specify per-layer learning + rates: + + .. code-block:: + + optim.SGD([ + {'params': model.base.parameters()}, + {'params': model.classifier.parameters(), 'lr': 1e-3} + ], lr=1e-2, momentum=0.9) + + See https://pytorch.org/docs/stable/optim.html for more information. By default, ModelPT will use + self.parameters(). Override this method to add custom param groups. """ param_groups = None if hasattr(self, "parameters"): @@ -583,6 +590,7 @@ def validation_epoch_end( via `multi_validation_epoch_end`. If multi dataset support is not required, override this method entirely in base class. In such a case, there is no need to implement `multi_validation_epoch_end` either. + .. note:: If more than one data loader exists, and they all provide `val_loss`, only the `val_loss` of the first data loader will be used by default. 
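The per-layer learning-rate pattern described in setup_optimizer_param_groups can be exercised directly with a toy module; ToyModel here is a hypothetical stand-in for a ModelPT subclass:

.. code-block:: python

    import torch

    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.base = torch.nn.Linear(8, 8)
            self.classifier = torch.nn.Linear(8, 2)

    model = ToyModel()
    optimizer = torch.optim.SGD(
        [
            {"params": model.base.parameters()},  # falls back to the default lr below
            {"params": model.classifier.parameters(), "lr": 1e-3},
        ],
        lr=1e-2,
        momentum=0.9,
    )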
@@ -678,6 +686,7 @@ def test_epoch_end( via `multi_test_epoch_end`. If multi dataset support is not required, override this method entirely in base class. In such a case, there is no need to implement `multi_test_epoch_end` either. + .. note:: If more than one data loader exists, and they all provide `test_loss`, only the `test_loss` of the first data loader will be used by default. @@ -793,8 +802,8 @@ def multi_validation_epoch_end( @staticmethod def multi_test_epoch_end(outputs: Union[object, List[Dict[str, torch.Tensor]]], dataloader_idx: int = 0) -> None: """ - Adds support for multiple test datasets. Should be overridden by subclass, - to obtain appropriate logs for each of the dataloaders. + Adds support for multiple test datasets. Should be overridden by subclass, to obtain appropriate logs for each + of the dataloaders. Parameters ---------- @@ -850,23 +859,25 @@ def load_part_of_state_dict(self, state_dict, include, exclude, load_from_string @rank_zero_only def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: str = "cpu"): """ - Initializes a given model with the parameters obtained via specific config arguments. - The state dict of the provided model will be updated with `strict=False` setting to prevent - requirement of exact model parameters matching. + Initializes a given model with the parameters obtained via specific config arguments. The state dict of the \ + provided model will be updated with `strict=False` setting to prevent requirement of exact model parameters \ + matching. Initializations - --------------- - init_from_mridc_model: Str path to a .mridc model, which will be instantiated in order to extract the state + + init_from_mridc_model: Str path to a .mridc model, which will be instantiated in order to extract the state \ dict. - init_from_pretrained_model: Str name of a pretrained model checkpoint (obtained via cloud). The model will be - downloaded (or a cached copy will be used), instantiated and then its state dict will be extracted. - init_from_ptl_ckpt: Str name of a Pytorch Lightning checkpoint file. It will be loaded and the state dict will - extract. + + init_from_pretrained_model: Str name of a pretrained model checkpoint (obtained via cloud). The model will \ + be downloaded (or a cached copy will be used), instantiated and then its state dict will be extracted. + + init_from_ptl_ckpt: Str name of a Pytorch Lightning checkpoint file. It will be loaded and the state dict \ + will be extracted. Parameters ---------- cfg: The config used to instantiate the model. It needs only contain one of the above keys. - map_location: str or torch.device() which represents where the intermediate state dict (from the pretrained + map_location: str or torch.device() which represents where the intermediate state dict (from the pretrained \ model or checkpoint) will be loaded. """ args = ["init_from_mridc_model", "init_from_pretrained_model", "init_from_ptl_ckpt"] @@ -1012,29 +1023,40 @@ def extract_state_dict_from( save_dir: directory in which the saved state dict(s) should be stored split_by_module: bool flag, which determines whether the output checkpoint should be for the entire Model, or the individual module's that comprise the Model - save_restore_connector (SaveRestoreConnector): Can be overridden to add custom save and restore logic. + save_restore_connector: Can be overridden to add custom save and restore logic.
Example ------- To convert the .mridc tarfile into a single Model level PyTorch checkpoint - :: - state_dict = mridc.collections.asr.models.EncDecCTCModel.extract_state_dict_from('asr.mridc', - './asr_ckpts') + + .. code-block:: + + state_dict = mridc.collections.asr.models.EncDecCTCModel.extract_state_dict_from('asr.mridc', \ + './asr_ckpts') + To restore a model from a Model level checkpoint - :: - model = mridc.collections.asr.models.EncDecCTCModel(cfg) # or any other method of restoration - model.load_state_dict(torch.load("./asr_ckpts/model_weights.ckpt")) + + .. code-block:: + + model = mridc.collections.asr.models.EncDecCTCModel(cfg) # or any other method of restoration + model.load_state_dict(torch.load("./asr_ckpts/model_weights.ckpt")) + To convert the .mridc tarfile into multiple Module level PyTorch checkpoints - :: - state_dict = mridc.collections.asr.models.EncDecCTCModel.extract_state_dict_from('asr.mridc', - './asr_ckpts', split_by_module=True) + + .. code-block:: + + state_dict = mridc.collections.asr.models.EncDecCTCModel.extract_state_dict_from('asr.mridc', \ + './asr_ckpts', split_by_module=True) + To restore a module from a Module level checkpoint - :: - model = mridc.collections.asr.models.EncDecCTCModel(cfg) # or any other method of restoration - # load the individual components - model.preprocessor.load_state_dict(torch.load("./asr_ckpts/preprocessor.ckpt")) - model.encoder.load_state_dict(torch.load("./asr_ckpts/encoder.ckpt")) - model.decoder.load_state_dict(torch.load("./asr_ckpts/decoder.ckpt")) + + .. code-block:: + + model = mridc.collections.asr.models.EncDecCTCModel(cfg) # or any other method of restoration + # load the individual components + model.preprocessor.load_state_dict(torch.load("./asr_ckpts/preprocessor.ckpt")) + model.encoder.load_state_dict(torch.load("./asr_ckpts/encoder.ckpt")) + model.decoder.load_state_dict(torch.load("./asr_ckpts/decoder.ckpt")) Returns ------- @@ -1051,9 +1073,11 @@ def extract_state_dict_from( def prepare_test(self, trainer: "Trainer") -> bool: """ - Helper method to check whether the model can safely be tested - on a dataset after training (or loading a checkpoint). - :: + Helper method to check whether the model can safely be tested on a dataset after training (or loading a + checkpoint). + + .. code-block:: + trainer = Trainer() if model.prepare_test(trainer): trainer.test(model) @@ -1066,7 +1090,7 @@ def prepare_test(self, trainer: "Trainer") -> bool: logging.info("No `test_ds` config found within the manifest.") return False - if trainer is not None and trainer.num_gpus > 1: + if trainer is not None and trainer.num_devices > 1: # Replace ddp multi-gpu until PTL has a fix DDP_WARN = """\n\nDuring testing, it is currently advisable to construct a new Trainer " "with single GPU and no DDP to obtain accurate results. @@ -1093,8 +1117,8 @@ def set_world_size(self, trainer: Trainer): # Update AppState with world information from trainer if isinstance(trainer, Trainer): app_state = AppState() - if self._trainer.num_gpus and self._trainer.num_nodes: # type: ignore - app_state.world_size = self._trainer.num_gpus * self._trainer.num_nodes # type: ignore + if self._trainer.num_devices and self._trainer.num_nodes: # type: ignore + app_state.world_size = self._trainer.num_devices * self._trainer.num_nodes # type: ignore else: logging.warning("World size can only be set by PyTorch Lightning Trainer.") @@ -1138,7 +1162,8 @@ def num_weights(self): def cfg(self): """ Property that holds the finalized internal config of the model. 
- Note: + + .. note:: Changes to this config are not reflected in the state of the model. Please create a new model using an updated config to properly update the model. """ @@ -1148,7 +1173,8 @@ def cfg(self): def cfg(self, cfg): """ Property that holds the finalized internal config of the model. - Note: + + .. note:: Changes to this config are not reflected in the state of the model. Please create a new model using an updated config to properly update the model. """ diff --git a/mridc/core/conf/dataloader.py b/mridc/core/conf/dataloader.py index 015b412a..bd8e20f7 100644 --- a/mridc/core/conf/dataloader.py +++ b/mridc/core/conf/dataloader.py @@ -15,6 +15,7 @@ class DataLoaderConfig: """ Configuration of PyTorch DataLoader. + ..note: For the details on the function/meanings of the arguments, please refer to: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader diff --git a/mridc/core/conf/optimizers.py b/mridc/core/conf/optimizers.py index fc8568d8..798ee625 100644 --- a/mridc/core/conf/optimizers.py +++ b/mridc/core/conf/optimizers.py @@ -36,7 +36,8 @@ class OptimizerParams: class SGDParams(OptimizerParams): """ Default configuration for Adam optimizer. - ..note: + + .. note:: For the details on the function/meanings of the arguments, please refer to: https://pytorch.org/docs/stable/optim.html?highlight=sgd#torch.optim.SGD """ @@ -51,7 +52,8 @@ class SGDParams(OptimizerParams): class AdamParams(OptimizerParams): """ Default configuration for Adam optimizer. - ..note: + + .. note:: For the details on the function/meanings of the arguments, please refer to: https://pytorch.org/docs/stable/optim.html?highlight=adam#torch.optim.Adam """ @@ -66,7 +68,8 @@ class AdamParams(OptimizerParams): class AdamWParams(OptimizerParams): """ Default configuration for AdamW optimizer. - ..note: + + .. note:: For the details on the function/meanings of the arguments, please refer to: https://pytorch.org/docs/stable/optim.html#torch.optim.AdamW """ @@ -81,7 +84,8 @@ class AdamWParams(OptimizerParams): class AdadeltaParams(OptimizerParams): """ Default configuration for Adadelta optimizer. - ..note: + + .. note:: For the details on the function/meanings of the arguments, please refer to: https://pytorch.org/docs/stable/optim.html#torch.optim.Adadelta """ @@ -95,7 +99,8 @@ class AdadeltaParams(OptimizerParams): class AdamaxParams(OptimizerParams): """ Default configuration for Adamax optimizer. - ..note: + + .. note:: For the details on the function/meanings of the arguments, please refer to: https://pytorch.org/docs/stable/optim.html#torch.optim.Adamax """ @@ -109,7 +114,8 @@ class AdamaxParams(OptimizerParams): class AdagradParams(OptimizerParams): """ Default configuration for Adagrad optimizer. - ..note: + + .. note:: For the details on the function/meanings of the arguments, please refer to: https://pytorch.org/docs/stable/optim.html#torch.optim.Adagrad """ @@ -124,7 +130,8 @@ class AdagradParams(OptimizerParams): class RMSpropParams(OptimizerParams): """ Default configuration for RMSprop optimizer. - ..note: + + .. note:: For the details on the function/meanings of the arguments, please refer to: https://pytorch.org/docs/stable/optim.html#torch.optim.RMSprop """ @@ -140,7 +147,8 @@ class RMSpropParams(OptimizerParams): class RpropParams(OptimizerParams): """ Default configuration for RpropParams optimizer. - ..note: + + .. 
note:: For the details on the function/meanings of the arguments, please refer to: https://pytorch.org/docs/stable/optim.html#torch.optim.Rprop """ @@ -152,18 +160,9 @@ class RpropParams(OptimizerParams): @dataclass class NovogradParams(OptimizerParams): """ - Configuration of the Novograd optimizer. It has been proposed in "Stochastic Gradient Methods with Layer-wise - Adaptive Moments for Training of Deep Networks" (https://arxiv.org/abs/1905.11286). - The OptimizerParams is a Base Optimizer params with no values. - User can choose it to explicitly override via command line arguments. - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) - amsgrad (boolean, optional): whether to use the AMSGrad variant of this - algorithm from the paper "On the Convergence of Adam and Beyond" + Configuration of the Novograd optimizer. It has been proposed in "Stochastic Gradient Methods with Layer-wise + Adaptive Moments for Training of Deep Networks" (https://arxiv.org/abs/1905.11286). The OptimizerParams is a Base + Optimizer params with no values. User can choose to explicitly override it via command line arguments. """ betas: Tuple[float, float] = (0.95, 0.98) @@ -171,6 +170,7 @@ class NovogradParams(OptimizerParams): weight_decay: float = 0 grad_averaging: bool = False amsgrad: bool = False + lr: float = 1e-3 luc: bool = False luc_trust: float = 1e-3 luc_eps: float = 1e-8 diff --git a/mridc/core/connectors/save_restore_connector.py b/mridc/core/connectors/save_restore_connector.py index 49956155..0d679bda 100644 --- a/mridc/core/connectors/save_restore_connector.py +++ b/mridc/core/connectors/save_restore_connector.py @@ -33,9 +33,9 @@ def save_to(self, model, save_path: str): Saves model instance (weights and configuration) into .mridc file. You can use "restore_from" method to fully restore instance from .mridc file. .mridc file is an archive (tar.gz) with the following: - model_config.yaml - model configuration in .yaml format. You can deserialize this into cfg argument for - model's constructor - model_wights.chpt - model checkpoint + - model_config.yaml - model configuration in .yaml format. You can deserialize this into cfg argument for \ + model's constructor + - model_weights.ckpt - model checkpoint Parameters ---------- diff --git a/mridc/core/neural_types/axes.py b/mridc/core/neural_types/axes.py index c5706f1c..7cbc6fb5 100644 --- a/mridc/core/neural_types/axes.py +++ b/mridc/core/neural_types/axes.py @@ -19,9 +19,9 @@ class AxisKindAbstract(Enum): class AxisKind(AxisKindAbstract): """ - This Enum represents what does varying axis dimension mean. For example, does this dimension correspond to width, - batch, time, etc. The "Dimension" and "Channel" kinds are the same and used to represent a general axis. "Any" axis - will accept any axis kind fed to it. + This Enum represents what a varying axis dimension means. For example, does this dimension correspond to width, \ + batch, time, etc. The "Dimension" and "Channel" kinds are the same and used to represent a general axis. "Any" \ + axis will accept any axis kind fed to it.
""" # TODO (wdika): change names of the enums diff --git a/mridc/core/neural_types/elements.py b/mridc/core/neural_types/elements.py index dbdf371a..e674f741 100644 --- a/mridc/core/neural_types/elements.py +++ b/mridc/core/neural_types/elements.py @@ -66,10 +66,10 @@ def type_parameters(self) -> Dict: @property def fields(self) -> Optional[Tuple]: """ - This should be used to logically represent tuples/structures. For example, if you want to represent a bounding - box (x, y, width, height) you can put a tuple with names ('x', y', 'w', 'h') in here. Under the hood this - should be converted to the last tensor dimension of fixed size = len(fields). When two types are compared their - fields must match. + This should be used to logically represent tuples/structures. For example, if you want to represent a \ + bounding box (x, y, width, height) you can put a tuple with names ('x', y', 'w', 'h') in here. Under the \ + hood this should be converted to the last tensor dimension of fixed size = len(fields). When two types are \ + compared their fields must match. """ return None diff --git a/mridc/core/optim/lr_scheduler.py b/mridc/core/optim/lr_scheduler.py index d4ff903c..51154527 100644 --- a/mridc/core/optim/lr_scheduler.py +++ b/mridc/core/optim/lr_scheduler.py @@ -705,37 +705,62 @@ def prepare_lr_scheduler( Parameters ---------- optimizer: The optimizer to use for the scheduler. - name: - lr: - # - args: - name: auto # special keyword, resolves to correct optimizer config for given optimizer name - # cls: mridc.core.config.optimizers.NovogradParams # explicit instantiation by class path - params: # optional override parameters for the optimizer config - betas: [0.8, 0.5] - weight_decay: 0.001 + name: + + lr: + + # + + args: + + name: auto # special keyword, resolves to correct optimizer config for given optimizer name + + # cls: mridc.core.config.optimizers.NovogradParams # explicit instantiation by class path + + params: # optional override parameters for the optimizer config + + betas: [0.8, 0.5] + + weight_decay: 0.001 + scheduler_config: The scheduler config. + name: + iters_per_batch: null # computed at runtime; mandatory to have + max_steps: null # computed at runtime or explicitly set here; mandatory to have + # pytorch lightning args + monitor: val_loss + reduce_on_plateau: false + # + args: - name: auto # special keyword, resolves to correct optimizer config for given optimizer name - # cls: mridc.core.config.schedulers.CosineAnnealingParams # explicit instantiation by class path - params: # optional override parameters for the optimizer config - warmup_steps: null - warmup_ratio: null - min_lr: 0.0 - last_epoch: -1 - train_dataloader: Optional requirement, must be passed if "iters_per_batch" is defined instead of "max_steps". + + name: auto # special keyword, resolves to correct optimizer config for given optimizer name + + # cls: mridc.core.config.schedulers.CosineAnnealingParams # explicit instantiation by class path + + params: # optional override parameters for the optimizer config + + warmup_steps: null + + warmup_ratio: null + + min_lr: 0.0 + + last_epoch: -1 + + train_dataloader: Optional requirement, must be passed if "iters_per_batch" is defined instead of "max_steps". \ Used to compute effective "max_steps". 
Returns ------- - A dictionary containing the LR Scheduler implementation if the config was successfully parsed along with other + A dictionary containing the LR Scheduler implementation if the config was successfully parsed along with other \ parameters required by Pytorch Lightning, otherwise None. """ if scheduler_config is not None: diff --git a/mridc/launch.py b/mridc/launch.py index 918ccd1d..c08a072c 100644 --- a/mridc/launch.py +++ b/mridc/launch.py @@ -86,6 +86,8 @@ def main(cfg: DictConfig) -> None: model.load_state_dict(torch.load(checkpoint)["state_dict"]) if cfg.get("mode", None) == "train": + logging.info("Validating") + trainer.validate(model) logging.info("Training") trainer.fit(model) else: diff --git a/mridc/package_info.py b/mridc/package_info.py index a3e4308c..75d46895 100644 --- a/mridc/package_info.py +++ b/mridc/package_info.py @@ -1,8 +1,8 @@ # encoding: utf-8 MAJOR = 0 -MINOR = 0 -PATCH = 1 +MINOR = 1 +PATCH = 0 PRE_RELEASE = "" # Use the following formatting: (major, minor, patch, pre-release) diff --git a/mridc/utils/config_utils.py b/mridc/utils/config_utils.py index 7f54161f..76a2c298 100644 --- a/mridc/utils/config_utils.py +++ b/mridc/utils/config_utils.py @@ -20,28 +20,28 @@ def update_model_config(model_cls: MRIDCConfig, update_cfg: "DictConfig", drop_missing_subconfigs: bool = True): """ - Helper class that updates the default values of a ModelPT config class with the values in a DictConfig that mirrors - the structure of the config class. Assumes the `update_cfg` is a DictConfig (either generated manually, via hydra - or instantiated via yaml/model.cfg). This update_cfg is then used to override the default values preset inside the - ModelPT config class. If `drop_missing_subconfigs` is set, the certain sub-configs of the ModelPT config class - will be removed, if they are not found in the mirrored `update_cfg`. The following sub-configs are subject to - potential removal: + Helper function that updates the default values of a ModelPT config class with the values in a DictConfig that \ + mirrors the structure of the config class. Assumes the `update_cfg` is a DictConfig (either generated manually, \ + via hydra or instantiated via yaml/model.cfg). This update_cfg is then used to override the default values \ + preset inside the ModelPT config class. If `drop_missing_subconfigs` is set, certain sub-configs of the \ + ModelPT config class will be removed, if they are not found in the mirrored `update_cfg`. The following \ + sub-configs are subject to potential removal: - `train_ds` - `validation_ds` - `test_ds` - - `optim` + nested `sched`. + - `optim` + nested `sched` Parameters ---------- model_cls: A subclass of MRIDC, that details in entirety all the parameters that constitute the MRIDC Model. - update_cfg: A DictConfig that mirrors the structure of the MRIDCConfig data class. Used to update the default + update_cfg: A DictConfig that mirrors the structure of the MRIDCConfig data class. Used to update the default \ values of the config class. - drop_missing_subconfigs: Bool which determines whether to drop certain sub-configs from the MRIDCConfig class, if - the corresponding sub-config is missing from `update_cfg`. + drop_missing_subconfigs: Bool which determines whether to drop certain sub-configs from the MRIDCConfig class, \ + if the corresponding sub-config is missing from `update_cfg`.
Returns ------- - A DictConfig with updated values that can be used to instantiate the MRIDC Model along with supporting + A DictConfig with updated values that can be used to instantiate the MRIDC Model along with supporting \ infrastructure. """ if not _HAS_HYDRA: diff --git a/mridc/utils/env_var_parsing.py b/mridc/utils/env_var_parsing.py index 97aec0a5..efd160bf 100644 --- a/mridc/utils/env_var_parsing.py +++ b/mridc/utils/env_var_parsing.py @@ -42,13 +42,13 @@ def __init__(self, key): def _get_env(key, default=None, coerce=lambda x: x, required=False): """ - Return env var coerced into a type other than string. This function extends the standard os.getenv function to + Return env var coerced into a type other than string. This function extends the standard os.getenv function to \ enable the coercion of values into data types other than string (all env vars are strings by default). Parameters ---------- key: The name of the env var to retrieve. - default: The default value to return if the env var is not set. NB the default value is **not** coerced, and is + default: The default value to return if the env var is not set. NB the default value is **not** coerced, and is \ assumed to be of the correct type. coerce: A function that takes a string and returns a value of the desired type. required: If True, raises a RequiredSettingMissingError if the env var is not set. @@ -118,23 +118,21 @@ def _date(value): def get_env(key, *default, **kwargs): """ - Return env var. - This is the parent function of all other get_foo functions, - and is responsible for unpacking args/kwargs into the values - that _get_env expects (it is the root function that actually - interacts with environ). - Args: - key: string, the env var name to look up. - default: (optional) the value to use if the env var does not - exist. If this value is not supplied, then the env var is - considered to be required, and a RequiredSettingMissingError - error will be raised if it does not exist. - Kwargs: - coerce: a func that may be supplied to coerce the value into - something else. This is used by the default get_foo functions - to cast strings to builtin types, but could be a function that - returns a custom class. - Returns the env var, coerced if required, and a default if supplied. + Return env var. This is the parent function of all other get_foo functions, and is responsible for unpacking \ + args/kwargs into the values that _get_env expects (it is the root function that actually interacts with environ). + + Parameters + ---------- + key: string, the env var name to look up. + default: (optional) the value to use if the env var does not exist. If this value is not supplied, then the \ + env var is considered to be required, and a RequiredSettingMissingError error will be raised if it does not exist. + kwargs: + coerce: a func that may be supplied to coerce the value into something else. This is used by the default \ + get_foo functions to cast strings to builtin types, but could be a function that returns a custom class. + + Returns + ------- + The env var, coerced if required, and a default if supplied. 
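A short usage sketch of get_env with coercion; the environment variable names are made up for illustration:

.. code-block:: python

    import os

    from mridc.utils.env_var_parsing import get_env

    os.environ["MRIDC_NUM_RETRIES"] = "5"

    retries = get_env("MRIDC_NUM_RETRIES", coerce=int)  # required var, coerced to int
    timeout = get_env("MRIDC_TIMEOUT", 30, coerce=int)  # default is returned as-is (not coerced)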
""" if len(default) not in (0, 1): raise AssertionError("Too many args supplied.") diff --git a/mridc/utils/exp_manager.py b/mridc/utils/exp_manager.py index 95a087e4..dfaca02f 100644 --- a/mridc/utils/exp_manager.py +++ b/mridc/utils/exp_manager.py @@ -9,7 +9,7 @@ import time from copy import deepcopy from dataclasses import dataclass -from datetime import timedelta + from pathlib import Path from shutil import copy, move from typing import Any, Dict, List, Optional, Tuple, Union @@ -20,7 +20,7 @@ from omegaconf import DictConfig, OmegaConf, open_dict from pytorch_lightning import Trainer from pytorch_lightning.callbacks import Callback, ModelCheckpoint -from pytorch_lightning.callbacks.timer import Interval, Timer +from pytorch_lightning.callbacks.timer import Timer from pytorch_lightning.loggers import LoggerCollection as _LoggerCollection, TensorBoardLogger, WandbLogger from pytorch_lightning.strategies.ddp import DDPStrategy @@ -175,68 +175,66 @@ def on_after_backward(self, trainer, pl_module): def exp_manager(trainer: Trainer, cfg: Optional[Union[DictConfig, Dict]] = None) -> Optional[Path]: """ - exp_manager is a helper function used to manage folders for experiments. It follows the pytorch lightning paradigm - of exp_dir/model_or_experiment_name/version. If the lightning trainer has a logger, exp_manager will get exp_dir, - name, and version from the logger. Otherwise, it will use the exp_dir and name arguments to create the logging - directory. exp_manager also allows for explicit folder creation via explicit_log_dir. - - The version can be a datetime string or an integer. Datetime version can be disabled if you use_datetime_version is - set to False. It optionally creates TensorBoardLogger, WandBLogger, ModelCheckpoint objects from pytorch lightning. - It copies sys.argv, and git information if available to the logging directory. It creates a log file for each - process to log their output into. - - exp_manager additionally has a resume feature (resume_if_exists) which can be used to continuing training from - the constructed log_dir. When you need to continue the training repeatedly (like on a cluster which you need - multiple consecutive jobs), you need to avoid creating the version folders. Therefore, from v1.0.0, when + exp_manager is a helper function used to manage folders for experiments. It follows the pytorch lightning \ + paradigm of exp_dir/model_or_experiment_name/version. If the lightning trainer has a logger, exp_manager will \ + get exp_dir, name, and version from the logger. Otherwise, it will use the exp_dir and name arguments to create \ + the logging directory. exp_manager also allows for explicit folder creation via explicit_log_dir. + + The version can be a datetime string or an integer. Datetime version can be disabled if you use_datetime_version \ + is set to False. It optionally creates TensorBoardLogger, WandBLogger, ModelCheckpoint objects from pytorch \ + lightning. It copies sys.argv, and git information if available to the logging directory. It creates a log file \ + for each process to log their output into. + + exp_manager additionally has a resume feature (resume_if_exists) which can be used to continuing training from \ + the constructed log_dir. When you need to continue the training repeatedly (like on a cluster which you need \ + multiple consecutive jobs), you need to avoid creating the version folders. Therefore, from v1.0.0, when \ resume_if_exists is set to True, creating the version folders is ignored. 
Parameters ---------- - trainer: The lightning trainer. + trainer: The lightning trainer object. cfg: Can have the following keys: - - explicit_log_dir (str, Path): Can be used to override exp_dir/name/version folder creation. Defaults to - None, which will use exp_dir, name, and version to construct the logging directory. - - exp_dir (str, Path): The base directory to create the logging directory. Defaults to None, which logs to - ./mridc_experiments. - - name (str): The name of the experiment. Defaults to None which turns into "default" via name = name or - "default". - - version (str): The version of the experiment. Defaults to None which uses either a datetime string or - lightning's TensorboardLogger system of using version_{int}. - - use_datetime_version (bool): Whether to use a datetime string for version. Defaults to True. - - resume_if_exists (bool): Whether this experiment is resuming from a previous run. If True, it sets - trainer._checkpoint_connector.resume_from_checkpoint_fit_path so that the trainer should auto-resume. - exp_manager will move files under log_dir to log_dir/run_{int}. Defaults to False. From v1.0.0, when - resume_if_exists is True, we would not create version folders to make it easier to find the log folder - for next runs. - - resume_past_end (bool): exp_manager errors out if resume_if_exists is True and a checkpoint matching - *end.ckpt indicating a previous training run fully completed. This behaviour can be disabled, in which - case the *end.ckpt will be loaded by setting resume_past_end to True. Defaults to False. - - resume_ignore_no_checkpoint (bool): exp_manager errors out if resume_if_exists is True and no checkpoint - could be found. This behaviour can be disabled, in which case exp_manager will print a message and - continue without restoring, by setting resume_ignore_no_checkpoint to True. Defaults to False. - - create_tensorboard_logger (bool): Whether to create a tensorboard logger and attach it to the pytorch - lightning trainer. Defaults to True. - - summary_writer_kwargs (dict): A dictionary of kwargs that can be passed to lightning's TensorboardLogger - class. Note that log_dir is passed by exp_manager and cannot exist in this dict. Defaults to None. - - create_wandb_logger (bool): Whether to create a Weights and Biases logger and attach it to the pytorch - lightning trainer. Defaults to False. - - wandb_logger_kwargs (dict): A dictionary of kwargs that can be passed to lightning's WandBLogger - class. Note that name and project are required parameters if create_wandb_logger is True. - Defaults to None. - - create_checkpoint_callback (bool): Whether to create a ModelCheckpoint callback and attach it to the - pytorch lightning trainer. The ModelCheckpoint saves the top 3 models with the best "val_loss", the - most recent checkpoint under *last.ckpt, and the final checkpoint after training completes under - *end.ckpt. Defaults to True. - - files_to_copy (list): A list of files to copy to the experiment logging directory. Defaults to None which - copies no files. - - log_local_rank_0_only (bool): Whether to only create log files for local rank 0. Defaults to False. - Set this to True if you are using DDP with many GPUs and do not want many log files in your exp dir. - - log_global_rank_0_only (bool): Whether to only create log files for global rank 0. Defaults to False. - Set this to True if you are using DDP with many GPUs and do not want many log files in your exp dir. 
+ - explicit_log_dir: Can be used to override exp_dir/name/version folder creation. Defaults to None, which \ will use exp_dir, name, and version to construct the logging directory. + - exp_dir: The base directory to create the logging directory. Defaults to None, which logs to \ ./mridc_experiments. + - name: The name of the experiment. Defaults to None which turns into "default" via name = name or "default". + - version: The version of the experiment. Defaults to None which uses either a datetime string or lightning's \ TensorboardLogger system of using version_{int}. + - use_datetime_version: Whether to use a datetime string for version. Defaults to True. + - resume_if_exists: Whether this experiment is resuming from a previous run. If True, it sets \ trainer._checkpoint_connector.resume_from_checkpoint_fit_path so that the trainer should auto-resume. \ exp_manager will move files under log_dir to log_dir/run_{int}. Defaults to False. From v1.0.0, when \ resume_if_exists is True, we would not create version folders to make it easier to find the log folder for \ next runs. + - resume_past_end: exp_manager errors out if resume_if_exists is True and a checkpoint matching \*end.ckpt is \ found, indicating a previous training run fully completed. This behaviour can be disabled, in which case the \ \*end.ckpt will be loaded by setting resume_past_end to True. Defaults to False. + - resume_ignore_no_checkpoint: exp_manager errors out if resume_if_exists is True and no checkpoint could be \ found. This behaviour can be disabled, in which case exp_manager will print a message and continue without \ restoring, by setting resume_ignore_no_checkpoint to True. Defaults to False. + - create_tensorboard_logger: Whether to create a tensorboard logger and attach it to the pytorch lightning \ trainer. Defaults to True. + - summary_writer_kwargs: A dictionary of kwargs that can be passed to lightning's TensorboardLogger class. \ Note that log_dir is passed by exp_manager and cannot exist in this dict. Defaults to None. + - create_wandb_logger: Whether to create a Weights and Biases logger and attach it to the pytorch lightning \ trainer. Defaults to False. + - wandb_logger_kwargs: A dictionary of kwargs that can be passed to lightning's WandBLogger class. Note that \ name and project are required parameters if create_wandb_logger is True. Defaults to None. + - create_checkpoint_callback: Whether to create a ModelCheckpoint callback and attach it to the pytorch \ lightning trainer. The ModelCheckpoint saves the top 3 models with the best "val_loss", the most recent \ checkpoint under \*last.ckpt, and the final checkpoint after training completes under \*end.ckpt. \ Defaults to True. + - files_to_copy: A list of files to copy to the experiment logging directory. Defaults to None which copies \ no files. + - log_local_rank_0_only: Whether to only create log files for local rank 0. Defaults to False. Set this to \ True if you are using DDP with many GPUs and do not want many log files in your exp dir. + - log_global_rank_0_only: Whether to only create log files for global rank 0. Defaults to False. Set this to \ True if you are using DDP with many GPUs and do not want many log files in your exp dir. + Returns ------- - log_dir: The final logging directory where logging files are saved. Usually the concatenation of exp_dir, name, and - version. + The final logging directory where logging files are saved. Usually the concatenation of exp_dir, name, and version.
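A minimal usage sketch of exp_manager, setting only a few of the cfg keys listed above (the remaining keys keep their defaults):

.. code-block:: python

    from omegaconf import OmegaConf
    from pytorch_lightning import Trainer

    from mridc.utils.exp_manager import exp_manager

    trainer = Trainer(max_epochs=1)
    cfg = OmegaConf.create(
        {
            "exp_dir": "./mridc_experiments",
            "name": "toy_experiment",
            "create_tensorboard_logger": True,
            "resume_if_exists": False,
        }
    )
    log_dir = exp_manager(trainer, cfg)  # resolved logging directory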
""" # Add rank information to logger # Note: trainer.global_rank and trainer.is_global_zero are not set until trainer.fit, so have to hack around it diff --git a/mridc/utils/model_utils.py b/mridc/utils/model_utils.py index 90ed07b1..5df86447 100644 --- a/mridc/utils/model_utils.py +++ b/mridc/utils/model_utils.py @@ -75,16 +75,20 @@ def resolve_dataset_name_from_cfg(cfg: "DictConfig") -> Union[Union[str, int, En directory. # Fast-path Resolution - In order to handle cases where we need to resolve items that are not paths, a fastpath - key can be provided as defined in the global `_VAL_TEST_FASTPATH_KEY`. + In order to handle cases where we need to resolve items that are not paths, a fastpath key can be provided as + defined in the global `_VAL_TEST_FASTPATH_KEY`. This key can be used in two ways : - ## _VAL_TEST_FASTPATH_KEY points to another key in the config + ## _VAL_TEST_FASTPATH_KEY points to another key in the config If this _VAL_TEST_FASTPATH_KEY points to another key in this config itself, then we assume we want to loop through the values of that key. This allows for any key in the config to become a fastpath key. - Example: + Example + ------- validation_ds: + + .. code-block:: + splits: "val" ... <_VAL_TEST_FASTPATH_KEY>: "splits" <-- this points to the key name "splits" @@ -97,18 +101,22 @@ def resolve_dataset_name_from_cfg(cfg: "DictConfig") -> Union[Union[str, int, En If this _VAL_TEST_FASTPATH_KEY does not point to another key in the config, then it is assumed that the items of this key itself are used for resolution. - Example: + Example + ------- validation_ds: - ... + + .. code-block:: + <_VAL_TEST_FASTPATH_KEY>: "val" <-- this points to the key name "splits" + Then we can write the following when overriding in hydra: ```python python train_file.py ... model.validation_ds.<_VAL_TEST_FASTPATH_KEY>=[val1, val2, dev1, dev2] ... ``` # IMPORTANT NOTE: It potentially mismatch if there exist more than 2 valid paths, and the first path does *not* resolve the - path of the data file (but does resolve to some other valid path). - To avoid this side effect, place the data path as the first item on the config file. + path of the data file (but does resolve to some other valid path). To avoid this side effect, place the data path + as the first item on the config file. Parameters ---------- @@ -213,19 +221,18 @@ def unique_names_check(name_list: Optional[List[str]]): def resolve_validation_dataloaders(model: ModelPT): """ Helper method that operates on the ModelPT class to automatically support multiple dataloaders for the validation - set. - It does so by first resolving the path to one/more data files via `resolve_dataset_name_from_cfg()`. - If this resolution fails, it assumes the data loader is prepared to manually support / not support - multiple data loaders and simply calls the appropriate setup method. + set. It does so by first resolving the path to one/more data files via `resolve_dataset_name_from_cfg()`. + If this resolution fails, it assumes the data loader is prepared to manually support / not support multiple data + loaders and simply calls the appropriate setup method. If resolution succeeds: - Checks if provided path is to a single file or a list of files. - If a single file is provided, simply tags that file as such and loads it via the setup method. - If multiple files are provided: - Inject a new manifest path at index "i" into the resolved key. - Calls the appropriate setup method to set the data loader. 
-    Collects the initialized data loader in a list and preserves it.
-    Once all data loaders are processed, assigns the list of loaded loaders to the ModelPT.
-    Finally, assigns a list of unique names resolved from the file paths to the ModelPT.
+    - Checks whether the provided path points to a single file or a list of files.
+        If a single file is provided, simply tags that file as such and loads it via the setup method.
+        If multiple files are provided:
+            - Injects a new manifest path at index "i" into the resolved key.
+            - Calls the appropriate setup method to set the data loader.
+            - Collects the initialized data loader in a list and preserves it.
+        - Once all data loaders are processed, assigns the list of loaded loaders to the ModelPT.
+        - Finally, assigns a list of unique names resolved from the file paths to the ModelPT.

     Parameters
     ----------
@@ -376,9 +383,10 @@ def wrap_training_step(wrapped, instance: LightningModule, args, kwargs):

 def convert_model_config_to_dict_config(cfg: Union[DictConfig, MRIDCConfig]) -> DictConfig:
     """
     Converts its input into a standard DictConfig.
+
     Possible input values are:
-    - DictConfig
-    - A dataclass which is a subclass of MRIDCConfig
+        - DictConfig
+        - A dataclass which is a subclass of MRIDCConfig

     Parameters
     ----------
@@ -386,7 +394,7 @@ def convert_model_config_to_dict_config(cfg: Union[DictConfig, MRIDCConfig]) ->

     Returns
     -------
-    The equivalent DictConfig
+    The equivalent DictConfig.
     """
     if not _HAS_HYDRA:
         logging.error("This function requires Hydra/OmegaConf and it was not installed.")
@@ -431,9 +439,9 @@ def maybe_update_config_version(cfg: "DictConfig"):
     """
     Recursively convert Hydra 0.x configs to Hydra 1.x configs.
     Changes include:
-    - `cls` -> `_target_`.
-    - `params` -> drop params and shift all arguments to parent.
-    - `target` -> `_target_` cannot be performed due to ModelPT injecting `target` inside class.
+        - `cls` -> `_target_`.
+        - `params` -> drop `params` and shift all arguments to the parent.
+        - `target` -> `_target_` cannot be performed, due to ModelPT injecting `target` inside the class.
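The Hydra 0.x to 1.x conversion listed just above is easiest to picture with a small sketch (an editorial aid, not part of the patch). `maybe_update_config_version` is the function named in this diff's mridc/utils/model_utils.py; the model path, arguments, and expected output shape below are illustrative assumptions.

```python
# Sketch of the Hydra 0.x -> 1.x conversion described above: `cls` becomes
# `_target_`, and arguments under `params` are shifted up to the parent.
# The class path and arguments are made up for illustration.
from omegaconf import OmegaConf

from mridc.utils.model_utils import maybe_update_config_version

old_style = OmegaConf.create(
    {
        "model": {
            "cls": "mridc.collections.reconstruction.models.SomeModel",  # hypothetical class path
            "params": {"hidden_channels": 64},
        }
    }
)

new_style = maybe_update_config_version(old_style)
# Expected shape after conversion (illustrative):
#   model:
#     _target_: mridc.collections.reconstruction.models.SomeModel
#     hidden_channels: 64
print(OmegaConf.to_yaml(new_style))
```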
Parameters ---------- diff --git a/requirements/requirements.txt b/requirements/requirements.txt index ec676407..ed8ad2f8 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -19,5 +19,6 @@ sphinxcontrib-apidoc>=0.3.0 torch>=1.8.0 torchmetrics~=0.6.1 tqdm>=4.62.3 +wandb>=0.12.16 wget>=1.20.1 wrapt>=1.13.3 diff --git a/setup.py b/setup.py index 87fed1bc..93e31295 100644 --- a/setup.py +++ b/setup.py @@ -151,7 +151,7 @@ def finalize_options(self): # 5 - Production/Stable # 6 - Mature # 7 - Inactive - "Development Status :: 3 - Alpha", + "Development Status :: 5 - Production/Stable", # Indicate who your project is intended for "Intended Audience :: Developers", "Intended Audience :: Science/Research", @@ -160,10 +160,7 @@ def finalize_options(self): "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: Mathematics", "Topic :: Scientific/Engineering :: Physics", - "Topic :: Scientific/Engineering :: Image Reconstruction", "Topic :: Scientific/Engineering :: Artificial Intelligence", - "Topic :: Scientific/Engineering :: Magnetic Resonance Imaging", - "Topic :: Scientific/Engineering :: Medical Imaging", "Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries :: Python Modules", "Topic :: Utilities", diff --git a/tests/hydra/__init__.py b/tests/hydra/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/hydra/config.yaml b/tests/hydra/config.yaml deleted file mode 100644 index c76072c3..00000000 --- a/tests/hydra/config.yaml +++ /dev/null @@ -1 +0,0 @@ -dataset_name: fastmri diff --git a/tests/hydra/config_ivalid.yaml b/tests/hydra/config_ivalid.yaml deleted file mode 100644 index 63e693a0..00000000 --- a/tests/hydra/config_ivalid.yaml +++ /dev/null @@ -1,2 +0,0 @@ -dataset_name: invalid_dataset -password: not_a_password diff --git a/tests/hydra/test_hydra_runner.py b/tests/hydra/test_hydra_runner.py deleted file mode 100644 index 318395e8..00000000 --- a/tests/hydra/test_hydra_runner.py +++ /dev/null @@ -1,51 +0,0 @@ -# encoding: utf-8 -__author__ = "Dimitrios Karkalousos" - -# Taken and adapted from: https://github.com/wdika/NeMo/blob/main/tests/hydra/test_hydra_runner.py - -import subprocess -import sys -from os import path - -import pytest - - -class TestHydraRunner: - """Test the hydra runner.""" - - @pytest.mark.integration - def test_no_config(self): - """Test app without config - fields missing causes error.""" - # Create system call. - call = "python test/hydra/tmp_launch.py" - - with pytest.raises(subprocess.CalledProcessError): - # Run the call as subprocess. - subprocess.check_call(call, shell=True, stdout=sys.stdout, stderr=sys.stdout) - - @pytest.mark.integration - def test_config1(self): - """Test injection of valid config.""" - # Create system call. - call = "python test/hydra/tmp_launch.py --config-name config.yaml" - - with pytest.raises(subprocess.CalledProcessError): - # Run the call as subprocess. - subprocess.check_call(call, shell=True, stdout=sys.stdout, stderr=sys.stdout) - - # Make sure that .hydra dir is not present. - if path.exists(".hydra"): - raise AssertionError - # Make sure that default hydra log file is not present. - if path.exists("tmp_launch.log"): - raise AssertionError - - @pytest.mark.integration - def test_config1_invalid(self): - """Test injection of invalid config.""" - # Create system call. 
- call = "python test/hydra/tmp_launch.py --config-name config_invalid.yaml" - - with pytest.raises(subprocess.CalledProcessError): - # Run the call as subprocess. - subprocess.check_call(call, shell=True, stdout=sys.stdout, stderr=sys.stdout) diff --git a/tests/hydra/tmp_launch.py b/tests/hydra/tmp_launch.py deleted file mode 100644 index da400fb4..00000000 --- a/tests/hydra/tmp_launch.py +++ /dev/null @@ -1,28 +0,0 @@ -from dataclasses import dataclass - -from omegaconf import MISSING, OmegaConf - -from mridc.core.conf.hydra_runner import hydra_runner - - -@dataclass -class DefaultConfig: - """ - This is structured config for this application. - It provides the schema used for validation of user-written spec file - as well as default values of the selected parameters. - """ - - # Dataset. Available options: [imdb, sst2] - dataset_name: str = MISSING - - -@hydra_runner(config_name="DefaultConfig", schema=DefaultConfig) -def tmp_launch(cfg): - """Launch the application.""" - print(OmegaConf.to_yaml(cfg)) - _ = cfg.dataset_name - - -if __name__ == "__main__": - tmp_launch()