From 3447b09435e72e856a4b29b436c2fbe61159a42f Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 3 Jan 2024 16:30:24 +0000 Subject: [PATCH] 6676 port diffusion schedulers (#7332) Towards #6676 . ### Description This adds some base classes for DDPM noise schedulers + three scheduler types. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham --- docs/source/networks.rst | 20 ++ monai/networks/schedulers/__init__.py | 17 ++ monai/networks/schedulers/ddim.py | 284 ++++++++++++++++++++++ monai/networks/schedulers/ddpm.py | 243 +++++++++++++++++++ monai/networks/schedulers/pndm.py | 316 +++++++++++++++++++++++++ monai/networks/schedulers/scheduler.py | 203 ++++++++++++++++ monai/utils/misc.py | 4 +- tests/test_scheduler_ddim.py | 83 +++++++ tests/test_scheduler_ddpm.py | 104 ++++++++ tests/test_scheduler_pndm.py | 108 +++++++++ 10 files changed, 1380 insertions(+), 2 deletions(-) create mode 100644 monai/networks/schedulers/__init__.py create mode 100644 monai/networks/schedulers/ddim.py create mode 100644 monai/networks/schedulers/ddpm.py create mode 100644 monai/networks/schedulers/pndm.py create mode 100644 monai/networks/schedulers/scheduler.py create mode 100644 tests/test_scheduler_ddim.py create mode 100644 tests/test_scheduler_ddpm.py create mode 100644 tests/test_scheduler_pndm.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 79d5ef822e..f9375f1e97 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -778,6 +778,26 @@ Nets .. autoclass:: MultiScalePatchDiscriminator :members: +Diffusion Schedulers +-------------------- +.. autoclass:: monai.networks.schedulers.Scheduler + :members: + +`DDPM Scheduler` +~~~~~~~~~~~~~~~~ +.. autoclass:: monai.networks.schedulers.DDPMScheduler + :members: + +`DDIM Scheduler` +~~~~~~~~~~~~~~~~ +.. autoclass:: monai.networks.schedulers.DDIMScheduler + :members: + +`PNDM Scheduler` +~~~~~~~~~~~~~~~~ +.. autoclass:: monai.networks.schedulers.PNDMScheduler + :members: + Utilities --------- .. automodule:: monai.networks.utils diff --git a/monai/networks/schedulers/__init__.py b/monai/networks/schedulers/__init__.py new file mode 100644 index 0000000000..29e9020d65 --- /dev/null +++ b/monai/networks/schedulers/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from .ddim import DDIMScheduler +from .ddpm import DDPMScheduler +from .pndm import PNDMScheduler +from .scheduler import NoiseSchedules, Scheduler diff --git a/monai/networks/schedulers/ddim.py b/monai/networks/schedulers/ddim.py new file mode 100644 index 0000000000..ec47ff8dc6 --- /dev/null +++ b/monai/networks/schedulers/ddim.py @@ -0,0 +1,284 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +import numpy as np +import torch + +from monai.utils import StrEnum + +from .scheduler import Scheduler + + +class DDIMPredictionType(StrEnum): + """ + Set of valid prediction type names for the DDIM scheduler's `prediction_type` argument. + + epsilon: predicting the noise of the diffusion process + sample: directly predicting the noisy sample + v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf + """ + + EPSILON = "epsilon" + SAMPLE = "sample" + V_PREDICTION = "v_prediction" + + +class DDIMScheduler(Scheduler): + """ + Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising + diffusion probabilistic models (DDPMs) with non-Markovian guidance. Based on: Song et al. "Denoising Diffusion + Implicit Models" https://arxiv.org/abs/2010.02502 + + Args: + num_train_timesteps: number of diffusion steps used to train the model. + schedule: member of NoiseSchedules, name of noise schedule function in component store + clip_sample: option to clip predicted sample between -1 and 1 for numerical stability. + set_alpha_to_one: each diffusion step uses the value of alphas product at that step and at the previous one. + For the final step there is no previous alpha. When this option is `True` the previous alpha product is + fixed to `1`, otherwise it uses the value of alpha at step 0. + steps_offset: an offset added to the inference steps. 
You can use a combination of `steps_offset=1` and
+        `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+        stable diffusion.
+        prediction_type: member of DDIMPredictionType
+        schedule_args: arguments to pass to the schedule function
+
+    """
+
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        schedule: str = "linear_beta",
+        clip_sample: bool = True,
+        set_alpha_to_one: bool = True,
+        steps_offset: int = 0,
+        prediction_type: str = DDIMPredictionType.EPSILON,
+        **schedule_args,
+    ) -> None:
+        super().__init__(num_train_timesteps, schedule, **schedule_args)
+
+        if prediction_type not in DDIMPredictionType.__members__.values():
+            raise ValueError("Argument `prediction_type` must be a member of DDIMPredictionType")
+
+        self.prediction_type = prediction_type
+
+        # At every step in ddim, we are looking into the previous alphas_cumprod
+        # For the final step, there is no previous alphas_cumprod because we are already at 0
+        # `set_alpha_to_one` decides whether we set this parameter simply to one or
+        # whether we use the final alpha of the "non-previous" one.
+        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = 1.0
+
+        self.timesteps = torch.from_numpy(np.arange(0, self.num_train_timesteps)[::-1].astype(np.int64))
+
+        self.clip_sample = clip_sample
+        self.steps_offset = steps_offset
+
+        # default the number of inference timesteps to the number of train steps
+        self.num_inference_steps: int
+        self.set_timesteps(self.num_train_timesteps)
+
+    def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
+        """
+        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.
+            device: target device to put the data.
+        """
+        if num_inference_steps > self.num_train_timesteps:
+            raise ValueError(
+                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:"
+                f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+                f" maximal {self.num_train_timesteps} timesteps."
+            )
+
+        self.num_inference_steps = num_inference_steps
+        step_ratio = self.num_train_timesteps // self.num_inference_steps
+        # creates integer timesteps by multiplying by ratio
+        # casting to int to avoid issues when num_inference_step is power of 3
+        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+        self.timesteps = torch.from_numpy(timesteps).to(device)
+        self.timesteps += self.steps_offset
+
+    def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        beta_prod_t = 1 - alpha_prod_t
+        beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+        variance: torch.Tensor = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+        return variance
+
+    def step(
+        self,
+        model_output: torch.Tensor,
+        timestep: int,
+        sample: torch.Tensor,
+        eta: float = 0.0,
+        generator: torch.Generator | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Predict the sample at the previous timestep by reversing the SDE.
Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output: direct output from learned diffusion model.
+            timestep: current discrete timestep in the diffusion chain.
+            sample: current instance of sample being created by diffusion process.
+            eta: weight of noise for added noise in diffusion step.
+            generator: random number generator.
+
+        Returns:
+            pred_prev_sample: Predicted previous sample
+            pred_original_sample: Predicted original sample
+        """
+        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+        # Ideally, read the DDIM paper in detail to understand this function fully
+
+        # Notation (<variable name> -> <name in paper>)
+        # - model_output -> e_theta(x_t, t)
+        # - pred_original_sample -> f_theta(x_t, t) or x_0
+        # - std_dev_t -> sigma_t
+        # - eta -> η
+        # - pred_sample_direction -> "direction pointing to x_t"
+        # - pred_prev_sample -> "x_t-1"
+
+        # 1. get previous step value (=t-1)
+        prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps
+
+        # 2. compute alphas, betas
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+        beta_prod_t = 1 - alpha_prod_t
+
+        # 3. compute predicted original sample from predicted noise also called
+        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+        if self.prediction_type == DDIMPredictionType.EPSILON:
+            pred_original_sample = (sample - (beta_prod_t**0.5) * model_output) / (alpha_prod_t**0.5)
+            pred_epsilon = model_output
+        elif self.prediction_type == DDIMPredictionType.SAMPLE:
+            pred_original_sample = model_output
+            pred_epsilon = (sample - (alpha_prod_t**0.5) * pred_original_sample) / (beta_prod_t**0.5)
+        elif self.prediction_type == DDIMPredictionType.V_PREDICTION:
+            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+
+        # 4. Clip "predicted x_0"
+        if self.clip_sample:
+            pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+
+        # 5. compute variance: "sigma_t(η)" -> see formula (16)
+        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+        variance = self._get_variance(timestep, prev_timestep)
+        std_dev_t = eta * variance**0.5
+
+        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * pred_epsilon
+
+        # 7. compute x_t-1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+        pred_prev_sample = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction
+
+        if eta > 0:
+            # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
+            device: torch.device = torch.device(model_output.device if torch.is_tensor(model_output) else "cpu")
+            noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
+            variance = self._get_variance(timestep, prev_timestep) ** 0.5 * eta * noise
+
+            pred_prev_sample = pred_prev_sample + variance
+
+        return pred_prev_sample, pred_original_sample
+
+    def reversed_step(
+        self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Predict the sample at the next timestep by reversing the SDE.
Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output: direct output from learned diffusion model.
+            timestep: current discrete timestep in the diffusion chain.
+            sample: current instance of sample being created by diffusion process.
+
+        Returns:
+            pred_post_sample: Predicted next sample
+            pred_original_sample: Predicted original sample
+        """
+        # See Appendix F at https://arxiv.org/pdf/2105.05233.pdf, or Equation (6) in https://arxiv.org/pdf/2203.04306.pdf
+
+        # Notation (<variable name> -> <name in paper>)
+        # - model_output -> e_theta(x_t, t)
+        # - pred_original_sample -> f_theta(x_t, t) or x_0
+        # - std_dev_t -> sigma_t
+        # - eta -> η
+        # - pred_sample_direction -> "direction pointing to x_t"
+        # - pred_post_sample -> "x_t+1"
+
+        # 1. get next step value (=t+1)
+        prev_timestep = timestep + self.num_train_timesteps // self.num_inference_steps
+
+        # 2. compute alphas, betas at timestep t+1
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+        beta_prod_t = 1 - alpha_prod_t
+
+        # 3. compute predicted original sample from predicted noise also called
+        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+
+        if self.prediction_type == DDIMPredictionType.EPSILON:
+            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+            pred_epsilon = model_output
+        elif self.prediction_type == DDIMPredictionType.SAMPLE:
+            pred_original_sample = model_output
+            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+        elif self.prediction_type == DDIMPredictionType.V_PREDICTION:
+            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+
+        # 4. Clip "predicted x_0"
+        if self.clip_sample:
+            pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+
+        # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+        pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon
+
+        # 6. compute x_t+1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+        pred_post_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+        return pred_post_sample, pred_original_sample
diff --git a/monai/networks/schedulers/ddpm.py b/monai/networks/schedulers/ddpm.py
new file mode 100644
index 0000000000..a5173a1b65
--- /dev/null
+++ b/monai/networks/schedulers/ddpm.py
@@ -0,0 +1,243 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# =========================================================================
+# Adapted from https://github.com/huggingface/diffusers
+# which has the following license:
+# https://github.com/huggingface/diffusers/blob/main/LICENSE
+#
+# Copyright 2022 UC Berkeley Team and The HuggingFace Team.
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +import numpy as np +import torch + +from monai.utils import StrEnum + +from .scheduler import Scheduler + + +class DDPMVarianceType(StrEnum): + """ + Valid names for DDPM Scheduler's `variance_type` argument. Options to clip the variance used when adding noise + to the denoised sample. + """ + + FIXED_SMALL = "fixed_small" + FIXED_LARGE = "fixed_large" + LEARNED = "learned" + LEARNED_RANGE = "learned_range" + + +class DDPMPredictionType(StrEnum): + """ + Set of valid prediction type names for the DDPM scheduler's `prediction_type` argument. + + epsilon: predicting the noise of the diffusion process + sample: directly predicting the noisy sample + v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf + """ + + EPSILON = "epsilon" + SAMPLE = "sample" + V_PREDICTION = "v_prediction" + + +class DDPMScheduler(Scheduler): + """ + Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and + Langevin dynamics sampling. Based on: Ho et al., "Denoising Diffusion Probabilistic Models" + https://arxiv.org/abs/2006.11239 + + Args: + num_train_timesteps: number of diffusion steps used to train the model. + schedule: member of NoiseSchedules, name of noise schedule function in component store + variance_type: member of DDPMVarianceType + clip_sample: option to clip predicted sample between -1 and 1 for numerical stability. + prediction_type: member of DDPMPredictionType + schedule_args: arguments to pass to the schedule function + """ + + def __init__( + self, + num_train_timesteps: int = 1000, + schedule: str = "linear_beta", + variance_type: str = DDPMVarianceType.FIXED_SMALL, + clip_sample: bool = True, + prediction_type: str = DDPMPredictionType.EPSILON, + **schedule_args, + ) -> None: + super().__init__(num_train_timesteps, schedule, **schedule_args) + + if variance_type not in DDPMVarianceType.__members__.values(): + raise ValueError("Argument `variance_type` must be a member of `DDPMVarianceType`") + + if prediction_type not in DDPMPredictionType.__members__.values(): + raise ValueError("Argument `prediction_type` must be a member of `DDPMPredictionType`") + + self.clip_sample = clip_sample + self.variance_type = variance_type + self.prediction_type = prediction_type + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. + device: target device to put the data. 
+ """ + if num_inference_steps > self.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + step_ratio = self.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + + def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor: + """ + Compute the mean of the posterior at timestep t. + + Args: + timestep: current timestep. + x0: the noise-free input. + x_t: the input noised to timestep t. + + Returns: + Returns the mean + """ + # these attributes are used for calculating the posterior, q(x_{t-1}|x_t,x_0), + # (see formula (5-7) from https://arxiv.org/pdf/2006.11239.pdf) + alpha_t = self.alphas[timestep] + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one + + x_0_coefficient = alpha_prod_t_prev.sqrt() * self.betas[timestep] / (1 - alpha_prod_t) + x_t_coefficient = alpha_t.sqrt() * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) + + mean: torch.Tensor = x_0_coefficient * x_0 + x_t_coefficient * x_t + + return mean + + def _get_variance(self, timestep: int, predicted_variance: torch.Tensor | None = None) -> torch.Tensor: + """ + Compute the variance of the posterior at timestep t. + + Args: + timestep: current timestep. + predicted_variance: variance predicted by the model. + + Returns: + Returns the variance + """ + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one + + # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + # and sample from it to get previous sample + # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + variance: torch.Tensor = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep] + # hacks - were probably added for training stability + if self.variance_type == DDPMVarianceType.FIXED_SMALL: + variance = torch.clamp(variance, min=1e-20) + elif self.variance_type == DDPMVarianceType.FIXED_LARGE: + variance = self.betas[timestep] + elif self.variance_type == DDPMVarianceType.LEARNED and predicted_variance is not None: + return predicted_variance + elif self.variance_type == DDPMVarianceType.LEARNED_RANGE and predicted_variance is not None: + min_log = variance + max_log = self.betas[timestep] + frac = (predicted_variance + 1) / 2 + variance = frac * max_log + (1 - frac) * min_log + + return variance + + def step( + self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, generator: torch.Generator | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. 
+ sample: current instance of sample being created by diffusion process. + generator: random number generator. + + Returns: + pred_prev_sample: Predicted previous sample + """ + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if self.prediction_type == DDPMPredictionType.EPSILON: + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.prediction_type == DDPMPredictionType.SAMPLE: + pred_original_sample = model_output + elif self.prediction_type == DDPMPredictionType.V_PREDICTION: + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + + # 3. Clip "predicted x_0" + if self.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[timestep]) / beta_prod_t + current_sample_coeff = self.alphas[timestep] ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample + + # 6. Add noise + variance = 0 + if timestep > 0: + noise = torch.randn( + model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator + ).to(model_output.device) + variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise + + pred_prev_sample = pred_prev_sample + variance + + return pred_prev_sample, pred_original_sample diff --git a/monai/networks/schedulers/pndm.py b/monai/networks/schedulers/pndm.py new file mode 100644 index 0000000000..c0728bbdff --- /dev/null +++ b/monai/networks/schedulers/pndm.py @@ -0,0 +1,316 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+
+from __future__ import annotations
+
+from typing import Any
+
+import numpy as np
+import torch
+
+from monai.utils import StrEnum
+
+from .scheduler import Scheduler
+
+
+class PNDMPredictionType(StrEnum):
+    """
+    Set of valid prediction type names for the PNDM scheduler's `prediction_type` argument.
+
+    epsilon: predicting the noise of the diffusion process
+    v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf
+    """
+
+    EPSILON = "epsilon"
+    V_PREDICTION = "v_prediction"
+
+
+class PNDMScheduler(Scheduler):
+    """
+    Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
+    namely the Runge-Kutta method and a linear multi-step method. Based on: Liu et al.,
+    "Pseudo Numerical Methods for Diffusion Models on Manifolds" https://arxiv.org/abs/2202.09778
+
+    Args:
+        num_train_timesteps: number of diffusion steps used to train the model.
+        schedule: member of NoiseSchedules, name of noise schedule function in component store
+        skip_prk_steps:
+            allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
+            before the PLMS steps.
+        set_alpha_to_one:
+            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
+            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+            otherwise it uses the value of alpha at step 0.
+        prediction_type: member of PNDMPredictionType
+        steps_offset:
+            an offset added to the inference steps. You can use a combination of `steps_offset=1` and
+            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+            stable diffusion.
+        schedule_args: arguments to pass to the schedule function
+    """
+
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        schedule: str = "linear_beta",
+        skip_prk_steps: bool = False,
+        set_alpha_to_one: bool = False,
+        prediction_type: str = PNDMPredictionType.EPSILON,
+        steps_offset: int = 0,
+        **schedule_args,
+    ) -> None:
+        super().__init__(num_train_timesteps, schedule, **schedule_args)
+
+        if prediction_type not in PNDMPredictionType.__members__.values():
+            raise ValueError("Argument `prediction_type` must be a member of PNDMPredictionType")
+
+        self.prediction_type = prediction_type
+
+        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = 1.0
+
+        # For now we only support F-PNDM, i.e. the runge-kutta method
+        # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
+        # mainly at formula (9), (12), (13) and the Algorithm 2.
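+        # pndm_order = 4 reflects the fourth-order multi-step scheme of F-PNDM: up to four previous
+        # noise predictions (stored in self.ets) are combined in each linear multi-step update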
+ self.pndm_order = 4 + + self.skip_prk_steps = skip_prk_steps + self.steps_offset = steps_offset + + # running values + self.cur_model_output = torch.Tensor() + self.counter = 0 + self.cur_sample = torch.Tensor() + self.ets: list = [] + + # default the number of inference timesteps to the number of train steps + self.set_timesteps(num_train_timesteps) + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. + device: target device to put the data. + """ + if num_inference_steps > self.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + step_ratio = self.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().astype(np.int64) + self._timesteps += self.steps_offset + + if self.skip_prk_steps: + # for some models like stable diffusion the prk steps can/should be skipped to + # produce better results. When using PNDM with `self.skip_prk_steps` the implementation + # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 + self.prk_timesteps = np.array([]) + self.plms_timesteps = self._timesteps[::-1] + + else: + prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile( + np.array([0, self.num_train_timesteps // num_inference_steps // 2]), self.pndm_order + ) + self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy() + self.plms_timesteps = self._timesteps[:-3][ + ::-1 + ].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy + + timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + # update num_inference_steps - necessary if we use prk steps + self.num_inference_steps = len(self.timesteps) + + self.ets = [] + self.counter = 0 + + def step(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> tuple[torch.Tensor, Any]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`. + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. 
Returns:
+            pred_prev_sample: Predicted previous sample
+        """
+        # return a tuple for consistency with samplers that return (previous pred, original sample pred)
+
+        if self.counter < len(self.prk_timesteps) and not self.skip_prk_steps:
+            return self.step_prk(model_output=model_output, timestep=timestep, sample=sample), None
+        else:
+            return self.step_plms(model_output=model_output, timestep=timestep, sample=sample), None
+
+    def step_prk(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> torch.Tensor:
+        """
+        Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
+        solution to the differential equation.
+
+        Args:
+            model_output: direct output from learned diffusion model.
+            timestep: current discrete timestep in the diffusion chain.
+            sample: current instance of sample being created by diffusion process.
+
+        Returns:
+            pred_prev_sample: Predicted previous sample
+        """
+        if self.num_inference_steps is None:
+            raise ValueError(
+                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+            )
+
+        diff_to_prev = 0 if self.counter % 2 else self.num_train_timesteps // self.num_inference_steps // 2
+        prev_timestep = timestep - diff_to_prev
+        timestep = self.prk_timesteps[self.counter // 4 * 4]
+
+        if self.counter % 4 == 0:
+            self.cur_model_output = 1 / 6 * model_output
+            self.ets.append(model_output)
+            self.cur_sample = sample
+        elif (self.counter - 1) % 4 == 0:
+            self.cur_model_output += 1 / 3 * model_output
+        elif (self.counter - 2) % 4 == 0:
+            self.cur_model_output += 1 / 3 * model_output
+        elif (self.counter - 3) % 4 == 0:
+            model_output = self.cur_model_output + 1 / 6 * model_output
+            self.cur_model_output = torch.Tensor()
+
+        # cur_sample should not be an empty torch.Tensor()
+        cur_sample = self.cur_sample if self.cur_sample.numel() != 0 else sample
+
+        prev_sample: torch.Tensor = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
+        self.counter += 1
+
+        return prev_sample
+
+    def step_plms(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> Any:
+        """
+        Step function propagating the sample with the linear multi-step method. Only one forward pass is made per
+        step; the stored model outputs from previous steps are reused to approximate the solution.
+
+        Args:
+            model_output: direct output from learned diffusion model.
+            timestep: current discrete timestep in the diffusion chain.
+            sample: current instance of sample being created by diffusion process.
+
+        Returns:
+            pred_prev_sample: Predicted previous sample
+        """
+        if self.num_inference_steps is None:
+            raise ValueError(
+                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+            )
+
+        if not self.skip_prk_steps and len(self.ets) < 3:
+            raise ValueError(
+                f"{self.__class__} can only be run AFTER scheduler has been run "
+                "in 'prk' mode for at least 12 iterations "
+            )
+
+        prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps
+
+        if self.counter != 1:
+            self.ets = self.ets[-3:]
+            self.ets.append(model_output)
+        else:
+            prev_timestep = timestep
+            timestep = timestep + self.num_train_timesteps // self.num_inference_steps
+
+        if len(self.ets) == 1 and self.counter == 0:
+            model_output = model_output
+            self.cur_sample = sample
+        elif len(self.ets) == 1 and self.counter == 1:
+            model_output = (model_output + self.ets[-1]) / 2
+            sample = self.cur_sample
+            self.cur_sample = torch.Tensor()
+        elif len(self.ets) == 2:
+            model_output = (3 * self.ets[-1] - self.ets[-2]) / 2
+        elif len(self.ets) == 3:
+            model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
+        else:
+            model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
+
+        prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
+        self.counter += 1
+
+        return prev_sample
+
+    def _get_prev_sample(self, sample: torch.Tensor, timestep: int, prev_timestep: int, model_output: torch.Tensor):
+        # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
+        # this function computes x_(t−δ) using the formula of (9)
+        # Note that x_t needs to be added to both sides of the equation
+
+        # Notation (<variable name> -> <name in paper>)
+        # alpha_prod_t -> α_t
+        # alpha_prod_t_prev -> α_(t−δ)
+        # beta_prod_t -> (1 - α_t)
+        # beta_prod_t_prev -> (1 - α_(t−δ))
+        # sample -> x_t
+        # model_output -> e_θ(x_t, t)
+        # prev_sample -> x_(t−δ)
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        beta_prod_t = 1 - alpha_prod_t
+        beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+        if self.prediction_type == PNDMPredictionType.V_PREDICTION:
+            model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+
+        # corresponds to (α_(t−δ) - α_t) divided by
+        # denominator of x_t in formula (9) and plus 1
+        # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) =
+        # sqrt(α_(t−δ)) / sqrt(α_t)
+        sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)
+
+        # corresponds to denominator of e_θ(x_t, t) in formula (9)
+        model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
+            alpha_prod_t * beta_prod_t * alpha_prod_t_prev
+        ) ** (0.5)
+
+        # full formula (9)
+        prev_sample = (
+            sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff
+        )
+
+        return prev_sample
diff --git a/monai/networks/schedulers/scheduler.py b/monai/networks/schedulers/scheduler.py
new file mode 100644
index 0000000000..17bb526abc
--- /dev/null
+++ b/monai/networks/schedulers/scheduler.py
@@ -0,0 +1,203 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + + +from __future__ import annotations + +import torch +import torch.nn as nn + +from monai.utils import ComponentStore, unsqueeze_right + +NoiseSchedules = ComponentStore("NoiseSchedules", "Functions to generate noise schedules") + + +@NoiseSchedules.add_def("linear_beta", "Linear beta schedule") +def _linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2): + """ + Linear beta noise schedule function. + + Args: + num_train_timesteps: number of timesteps + beta_start: start of beta range, default 1e-4 + beta_end: end of beta range, default 2e-2 + + Returns: + betas: beta schedule tensor + """ + return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + + +@NoiseSchedules.add_def("scaled_linear_beta", "Scaled linear beta schedule") +def _scaled_linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2): + """ + Scaled linear beta noise schedule function. + + Args: + num_train_timesteps: number of timesteps + beta_start: start of beta range, default 1e-4 + beta_end: end of beta range, default 2e-2 + + Returns: + betas: beta schedule tensor + """ + return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + + +@NoiseSchedules.add_def("sigmoid_beta", "Sigmoid beta schedule") +def _sigmoid_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2, sig_range: float = 6): + """ + Sigmoid beta noise schedule function. 
+
+    Args:
+        num_train_timesteps: number of timesteps
+        beta_start: start of beta range, default 1e-4
+        beta_end: end of beta range, default 2e-2
+        sig_range: pos/neg range of sigmoid input, default 6
+
+    Returns:
+        betas: beta schedule tensor
+    """
+    betas = torch.linspace(-sig_range, sig_range, num_train_timesteps)
+    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
+
+
+@NoiseSchedules.add_def("cosine", "Cosine schedule")
+def _cosine_beta(num_train_timesteps: int, s: float = 8e-3):
+    """
+    Cosine noise schedule, see https://arxiv.org/abs/2102.09672
+
+    Args:
+        num_train_timesteps: number of timesteps
+        s: smoothing factor, default 8e-3 (see referenced paper)
+
+    Returns:
+        (betas, alphas, alpha_cumprod) values
+    """
+    x = torch.linspace(0, num_train_timesteps, num_train_timesteps + 1)
+    alphas_cumprod = torch.cos(((x / num_train_timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
+    alphas_cumprod /= alphas_cumprod[0].item()
+    alphas = torch.clip(alphas_cumprod[1:] / alphas_cumprod[:-1], 0.0001, 0.9999)
+    betas = 1.0 - alphas
+    return betas, alphas, alphas_cumprod[:-1]
+
+
+class Scheduler(nn.Module):
+    """
+    Base class for other schedulers based on a noise schedule function.
+
+    This class is meant as the base for other schedulers which implement their own way of sampling or stepping. Here
+    the class defines beta, alpha, and alpha_cumprod values from a noise schedule function named with `schedule`,
+    which is the name of a component in NoiseSchedules. These components must all be callables which return either
+    the beta schedule alone or a triple containing (betas, alphas, alphas_cumprod) values. New schedule functions
+    can be provided by using the NoiseSchedules.add_def, for example:
+
+    .. code-block:: python
+
+        from monai.networks.schedulers import NoiseSchedules, DDPMScheduler
+
+        @NoiseSchedules.add_def("my_beta_schedule", "Some description of your function")
+        def _beta_function(num_train_timesteps, beta_start=1e-4, beta_end=2e-2):
+            return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+
+        scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="my_beta_schedule")
+
+    All such functions should have an initial positional integer argument `num_train_timesteps` stating the number of
+    timesteps the schedule is for, otherwise any other arguments can be given which will be passed by keyword through
+    the constructor's `schedule_args` value. To see what noise functions are available, print the object NoiseSchedules
+    to get a listing of stored objects with their docstring descriptions.
+
+    Note: in previous versions of the schedulers the argument `schedule_beta` was used to state the beta schedule
+    type, this is now replaced with `schedule` and most names used with the previous argument now have "_beta" appended
+    to them, e.g. 'schedule_beta="linear"' -> 'schedule="linear_beta"'. The `beta_start` and `beta_end` arguments are
+    still used for some schedules but these are provided as keyword arguments now.
+
+    Args:
+        num_train_timesteps: number of diffusion steps used to train the model.
+ schedule: member of NoiseSchedules, + a named function returning the beta tensor or (betas, alphas, alphas_cumprod) triple + schedule_args: arguments to pass to the schedule function + """ + + def __init__(self, num_train_timesteps: int = 1000, schedule: str = "linear_beta", **schedule_args) -> None: + super().__init__() + schedule_args["num_train_timesteps"] = num_train_timesteps + noise_sched = NoiseSchedules[schedule](**schedule_args) + + # set betas, alphas, alphas_cumprod based off return value from noise function + if isinstance(noise_sched, tuple): + self.betas, self.alphas, self.alphas_cumprod = noise_sched + else: + self.betas = noise_sched + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + self.num_train_timesteps = num_train_timesteps + self.one = torch.tensor(1.0) + + # settable values + self.num_inference_steps: int | None = None + self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1) + + def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + """ + Add noise to the original samples. + + Args: + original_samples: original samples + noise: noise to add to samples + timesteps: timesteps tensor indicating the timestep to be computed for each sample. + + Returns: + noisy_samples: sample with added noise + """ + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_cumprod: torch.Tensor = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, original_samples.ndim) + sqrt_one_minus_alpha_prod: torch.Tensor = unsqueeze_right( + (1 - self.alphas_cumprod[timesteps]) ** 0.5, original_samples.ndim + ) + + noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, sample.ndim) + sqrt_one_minus_alpha_prod = unsqueeze_right((1 - self.alphas_cumprod[timesteps]) ** 0.5, sample.ndim) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity diff --git a/monai/utils/misc.py b/monai/utils/misc.py index d6ff370f69..4f2501a7ee 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -890,11 +890,11 @@ def is_sqrt(num: Sequence[int] | int) -> bool: return ensure_tuple(ret) == num -def unsqueeze_right(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor: +def unsqueeze_right(arr: torch.Tensor, ndim: int) -> torch.Tensor: """Append 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" return arr[(...,) + (None,) * (ndim - arr.ndim)] -def unsqueeze_left(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor: +def unsqueeze_left(arr: torch.Tensor, ndim: int) -> torch.Tensor: """Prepend 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" return arr[(None,) * (ndim - arr.ndim)] diff --git a/tests/test_scheduler_ddim.py b/tests/test_scheduler_ddim.py new file mode 100644 index 0000000000..1a8f8cab67 --- /dev/null +++ b/tests/test_scheduler_ddim.py @@ -0,0 +1,83 @@ 
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.networks.schedulers import DDIMScheduler
+from tests.utils import assert_allclose
+
+TEST_2D_CASE = []
+for beta_schedule in ["linear_beta", "scaled_linear_beta"]:
+    TEST_2D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)])
+
+TEST_3D_CASE = []
+for beta_schedule in ["linear_beta", "scaled_linear_beta"]:
+    TEST_3D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)])
+
+TEST_CASES = TEST_2D_CASE + TEST_3D_CASE
+
+TEST_FULl_LOOP = [
+    [{"schedule": "linear_beta"}, (1, 1, 2, 2), torch.Tensor([[[[-0.9579, -0.6457], [0.4684, -0.9694]]]])]
+]
+
+
+class TestDDIMScheduler(unittest.TestCase):
+    @parameterized.expand(TEST_CASES)
+    def test_add_noise(self, input_param, input_shape, expected_shape):
+        scheduler = DDIMScheduler(**input_param)
+        scheduler.set_timesteps(num_inference_steps=100)
+        original_sample = torch.zeros(input_shape)
+        noise = torch.randn_like(original_sample)
+        timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long()
+
+        noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps)
+        self.assertEqual(noisy.shape, expected_shape)
+
+    @parameterized.expand(TEST_CASES)
+    def test_step_shape(self, input_param, input_shape, expected_shape):
+        scheduler = DDIMScheduler(**input_param)
+        scheduler.set_timesteps(num_inference_steps=100)
+        model_output = torch.randn(input_shape)
+        sample = torch.randn(input_shape)
+        output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample)
+        self.assertEqual(output_step[0].shape, expected_shape)
+        self.assertEqual(output_step[1].shape, expected_shape)
+
+    @parameterized.expand(TEST_FULl_LOOP)
+    def test_full_timestep_loop(self, input_param, input_shape, expected_output):
+        scheduler = DDIMScheduler(**input_param)
+        scheduler.set_timesteps(50)
+        torch.manual_seed(42)
+        model_output = torch.randn(input_shape)
+        sample = torch.randn(input_shape)
+        for t in range(50):
+            sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample)
+        assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3)
+
+    def test_set_timesteps(self):
+        scheduler = DDIMScheduler(num_train_timesteps=1000)
+        scheduler.set_timesteps(num_inference_steps=100)
+        self.assertEqual(scheduler.num_inference_steps, 100)
+        self.assertEqual(len(scheduler.timesteps), 100)
+
+    def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self):
+        scheduler = DDIMScheduler(num_train_timesteps=1000)
+        with self.assertRaises(ValueError):
+            scheduler.set_timesteps(num_inference_steps=2000)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_scheduler_ddpm.py b/tests/test_scheduler_ddpm.py
new file mode 100644
index 0000000000..f0447aded2
--- /dev/null
+++ b/tests/test_scheduler_ddpm.py
@@ -0,0 +1,104
@@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.schedulers import DDPMScheduler +from tests.utils import assert_allclose + +TEST_2D_CASE = [] +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + for variance_type in ["fixed_small", "fixed_large"]: + TEST_2D_CASE.append( + [{"schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16), (2, 6, 16, 16)] + ) + +TEST_3D_CASE = [] +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + for variance_type in ["fixed_small", "fixed_large"]: + TEST_3D_CASE.append( + [{"schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)] + ) + +TEST_CASES = TEST_2D_CASE + TEST_3D_CASE + +TEST_FULl_LOOP = [ + [{"schedule": "linear_beta"}, (1, 1, 2, 2), torch.Tensor([[[[-1.0153, -0.3218], [0.8454, -0.7870]]]])] +] + + +class TestDDPMScheduler(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_add_noise(self, input_param, input_shape, expected_shape): + scheduler = DDPMScheduler(**input_param) + original_sample = torch.zeros(input_shape) + noise = torch.randn_like(original_sample) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long() + + noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps) + self.assertEqual(noisy.shape, expected_shape) + + @parameterized.expand(TEST_CASES) + def test_step_shape(self, input_param, input_shape, expected_shape): + scheduler = DDPMScheduler(**input_param) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) + self.assertEqual(output_step[0].shape, expected_shape) + self.assertEqual(output_step[1].shape, expected_shape) + + @parameterized.expand(TEST_FULl_LOOP) + def test_full_timestep_loop(self, input_param, input_shape, expected_output): + scheduler = DDPMScheduler(**input_param) + scheduler.set_timesteps(50) + torch.manual_seed(42) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + for t in range(50): + sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample) + assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3) + + @parameterized.expand(TEST_CASES) + def test_get_velocity_shape(self, input_param, input_shape, expected_shape): + scheduler = DDPMScheduler(**input_param) + sample = torch.randn(input_shape) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],)).long() + velocity = scheduler.get_velocity(sample=sample, noise=sample, timesteps=timesteps) + self.assertEqual(velocity.shape, expected_shape) + + def test_step_learned(self): + for variance_type in ["learned", "learned_range"]: + scheduler = DDPMScheduler(variance_type=variance_type) + model_output = torch.randn(2, 6, 16, 16) + sample = 
torch.randn(2, 3, 16, 16)
+            output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample)
+            self.assertEqual(output_step[0].shape, sample.shape)
+            self.assertEqual(output_step[1].shape, sample.shape)
+
+    def test_set_timesteps(self):
+        scheduler = DDPMScheduler(num_train_timesteps=1000)
+        scheduler.set_timesteps(num_inference_steps=100)
+        self.assertEqual(scheduler.num_inference_steps, 100)
+        self.assertEqual(len(scheduler.timesteps), 100)
+
+    def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self):
+        scheduler = DDPMScheduler(num_train_timesteps=1000)
+        with self.assertRaises(ValueError):
+            scheduler.set_timesteps(num_inference_steps=2000)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_scheduler_pndm.py b/tests/test_scheduler_pndm.py
new file mode 100644
index 0000000000..69e5e403f5
--- /dev/null
+++ b/tests/test_scheduler_pndm.py
@@ -0,0 +1,108 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.networks.schedulers import PNDMScheduler
+from tests.utils import assert_allclose
+
+TEST_2D_CASE = []
+for beta_schedule in ["linear_beta", "scaled_linear_beta"]:
+    TEST_2D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)])
+
+TEST_3D_CASE = []
+for beta_schedule in ["linear_beta", "scaled_linear_beta"]:
+    TEST_3D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)])
+
+TEST_CASES = TEST_2D_CASE + TEST_3D_CASE
+
+TEST_FULl_LOOP = [
+    [
+        {"schedule": "linear_beta"},
+        (1, 1, 2, 2),
+        torch.Tensor([[[[-2123055.2500, -459014.2812], [2863438.0000, -1263401.7500]]]]),
+    ]
+]
+
+
+class TestPNDMScheduler(unittest.TestCase):
+    @parameterized.expand(TEST_CASES)
+    def test_add_noise(self, input_param, input_shape, expected_shape):
+        scheduler = PNDMScheduler(**input_param)
+        original_sample = torch.zeros(input_shape)
+        noise = torch.randn_like(original_sample)
+        timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long()
+        noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps)
+        self.assertEqual(noisy.shape, expected_shape)
+
+    @parameterized.expand(TEST_CASES)
+    def test_step_shape(self, input_param, input_shape, expected_shape):
+        scheduler = PNDMScheduler(**input_param)
+        scheduler.set_timesteps(600)
+        model_output = torch.randn(input_shape)
+        sample = torch.randn(input_shape)
+        output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample)
+        self.assertEqual(output_step[0].shape, expected_shape)
+        self.assertEqual(output_step[1], None)
+
+    @parameterized.expand(TEST_FULl_LOOP)
+    def test_full_timestep_loop(self, input_param, input_shape, expected_output):
+        scheduler = PNDMScheduler(**input_param)
+        scheduler.set_timesteps(50)
+        torch.manual_seed(42)
+        model_output = torch.randn(input_shape)
+        sample = torch.randn(input_shape)
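+        # note: the ascending loop index is passed to step() directly as the timestep value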
+ for t in range(50): + sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample) + assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3) + + @parameterized.expand(TEST_FULl_LOOP) + def test_timestep_two_loops(self, input_param, input_shape, expected_output): + scheduler = PNDMScheduler(**input_param) + scheduler.set_timesteps(50) + torch.manual_seed(42) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + for t in range(50): + sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample) + torch.manual_seed(42) + model_output2 = torch.randn(input_shape) + sample2 = torch.randn(input_shape) + scheduler.set_timesteps(50) + for t in range(50): + sample2, _ = scheduler.step(model_output=model_output2, timestep=t, sample=sample2) + assert_allclose(sample, sample2, rtol=1e-3, atol=1e-3) + + def test_set_timesteps(self): + scheduler = PNDMScheduler(num_train_timesteps=1000, skip_prk_steps=True) + scheduler.set_timesteps(num_inference_steps=100) + self.assertEqual(scheduler.num_inference_steps, 100) + self.assertEqual(len(scheduler.timesteps), 100) + + def test_set_timesteps_prk(self): + scheduler = PNDMScheduler(num_train_timesteps=1000, skip_prk_steps=False) + scheduler.set_timesteps(num_inference_steps=100) + self.assertEqual(scheduler.num_inference_steps, 109) + self.assertEqual(len(scheduler.timesteps), 109) + + def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self): + scheduler = PNDMScheduler(num_train_timesteps=1000) + with self.assertRaises(ValueError): + scheduler.set_timesteps(num_inference_steps=2000) + + +if __name__ == "__main__": + unittest.main()
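
For reviewers, a minimal end-to-end usage sketch of the ported API (not part of the diff). The `model` function below is a hypothetical stand-in for a trained noise-prediction network; the scheduler calls match what this PR adds:

```python
import torch

from monai.networks.schedulers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="linear_beta")

# training side: corrupt a clean batch to random timesteps with add_noise()
images = torch.randn(2, 1, 64, 64)  # hypothetical clean batch
noise = torch.randn_like(images)
timesteps = torch.randint(0, scheduler.num_train_timesteps, (images.shape[0],)).long()
noisy = scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps)


def model(x, t):
    # hypothetical stand-in for a trained diffusion UNet
    return torch.zeros_like(x)


# sampling side: iterate over scheduler.timesteps, denoising step by step
scheduler.set_timesteps(num_inference_steps=50)
sample = torch.randn(1, 1, 64, 64)  # start from pure noise
for t in scheduler.timesteps:
    model_output = model(sample, t)
    # step() returns (predicted previous sample, predicted x_0)
    sample, _ = scheduler.step(model_output=model_output, timestep=int(t), sample=sample)
```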