diff --git a/src/create_quantiles.py b/src/create_quantiles.py
index 181cf62..5241b32 100644
--- a/src/create_quantiles.py
+++ b/src/create_quantiles.py
@@ -29,13 +29,15 @@ def main(cfg: DictConfig):
     )
     # Extract the xarray dataset and denormalize it
-    xr_ds = dataset.xr_data.map(denorm).sel(time=slice(START_YEAR, END_YEAR))
+    xr_ds = dataset.xr_data.map(denorm).sel(time=slice(START_YEAR, END_YEAR)).compute()
+    # Group by each day of the year
+    groups = xr_ds.groupby("time.dayofyear")
-    # Compute the quantiles
-    quantiles = xr_ds.load().quantile(cfg.quantile, dim="time").drop_vars("quantile")
+    quantiles = groups.quantile(q=[0.9, 0.95, 0.99, 0.999], dim="time")
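+    # The result has one set of quantiles per calendar day: a new "quantile" dimension
+    # alongside "dayofyear" and the spatial dimensions.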
     # Save the quantiles
-    save_name = f"{cfg.var}_{int(cfg.quantile * 100)}.nc"
+    save_name = f"{cfg.var}_quantiles.nc"
     save_path = os.path.join(cfg.paths.quantile_dir, cfg.esm, save_name)
     # Delete the file if it already exists (avoids permission denied errors)
diff --git a/src/custom_diffusers/configuration_utils.py b/src/custom_diffusers/configuration_utils.py
new file mode 100644
index 0000000..6ed837f
--- /dev/null
+++ b/src/custom_diffusers/configuration_utils.py
@@ -0,0 +1,702 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" ConfigMixin base class and utilities."""
+import dataclasses
+import functools
+import importlib
+import inspect
+import json
+import os
+import re
+from collections import OrderedDict
+from pathlib import PosixPath
+from typing import Any, Dict, Tuple, Union
+
+import numpy as np
+from huggingface_hub import create_repo, hf_hub_download
+from huggingface_hub.utils import (
+ EntryNotFoundError,
+ RepositoryNotFoundError,
+ RevisionNotFoundError,
+ validate_hf_hub_args,
+)
+from requests import HTTPError
+
+from . import __version__
+from .utils import (
+    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+    DummyObject,  # used in `extract_init_dict`; assumed to be re-exported by .utils as in upstream diffusers
+    deprecate,
+    extract_commit_hash,
+    http_user_agent,
+    logging,
+)
+
+
+logger = logging.get_logger(__name__)
+
+_re_configuration_file = re.compile(r"config\.(.*)\.json")
+
+
+class FrozenDict(OrderedDict):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ for key, value in self.items():
+ setattr(self, key, value)
+
+ self.__frozen = True
+
+ def __delitem__(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
+
+ def setdefault(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
+
+ def pop(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
+
+ def update(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
+
+ def __setattr__(self, name, value):
+ if hasattr(self, "__frozen") and self.__frozen:
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
+ super().__setattr__(name, value)
+
+ def __setitem__(self, name, value):
+ if hasattr(self, "__frozen") and self.__frozen:
+            raise Exception(f"You cannot use ``__setitem__`` on a {self.__class__.__name__} instance.")
+ super().__setitem__(name, value)
+
+
+class ConfigMixin:
+ r"""
+ Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
+ provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
+ saving classes that inherit from [`ConfigMixin`].
+
+ Class attributes:
+        - **config_name** (`str`) -- A filename under which the config should be stored when calling
+ [`~ConfigMixin.save_config`] (should be overridden by parent class).
+ - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
+ overridden by subclass).
+ - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
+ - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
+ should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
+ subclass).
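+
+    Examples:
+
+    A minimal illustrative sketch (the `MyScheduler` subclass below is hypothetical and not part of this module):
+
+    ```python
+    >>> class MyScheduler(ConfigMixin):
+    ...     config_name = "scheduler_config.json"
+    ...
+    ...     @register_to_config
+    ...     def __init__(self, num_train_timesteps: int = 1000, beta_start: float = 0.0001):
+    ...         pass
+
+    >>> scheduler = MyScheduler(num_train_timesteps=500)
+    >>> scheduler.config.num_train_timesteps
+    500
+    ```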
+ """
+
+ config_name = None
+ ignore_for_config = []
+ has_compatibles = False
+
+ _deprecated_kwargs = []
+
+ def register_to_config(self, **kwargs):
+ if self.config_name is None:
+ raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
+ # Special case for `kwargs` used in deprecation warning added to schedulers
+ # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
+ # or solve in a more general way.
+ kwargs.pop("kwargs", None)
+
+ if not hasattr(self, "_internal_dict"):
+ internal_dict = kwargs
+ else:
+ previous_dict = dict(self._internal_dict)
+ internal_dict = {**self._internal_dict, **kwargs}
+ logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
+
+ self._internal_dict = FrozenDict(internal_dict)
+
+ def __getattr__(self, name: str) -> Any:
+ """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
+ config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
+
+ This function is mostly copied from PyTorch's __getattr__ overwrite:
+ https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
+ """
+
+ is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
+ is_attribute = name in self.__dict__
+
+ if is_in_config and not is_attribute:
+ deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
+ deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
+ return self._internal_dict[name]
+
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+
+ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
+ [`~ConfigMixin.from_config`] class method.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the configuration JSON file is saved (will be created if it does not exist).
+ push_to_hub (`bool`, *optional*, defaults to `False`):
+ Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+ namespace).
+ kwargs (`Dict[str, Any]`, *optional*):
+ Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+ """
+ if os.path.isfile(save_directory):
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ # If we save using the predefined names, we can load using `from_config`
+ output_config_file = os.path.join(save_directory, self.config_name)
+
+ self.to_json_file(output_config_file)
+ logger.info(f"Configuration saved in {output_config_file}")
+
+ if push_to_hub:
+ commit_message = kwargs.pop("commit_message", None)
+ private = kwargs.pop("private", False)
+ create_pr = kwargs.pop("create_pr", False)
+ token = kwargs.pop("token", None)
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+ repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
+
+ self._upload_folder(
+ save_directory,
+ repo_id,
+ token=token,
+ commit_message=commit_message,
+ create_pr=create_pr,
+ )
+
+ @classmethod
+ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
+ r"""
+ Instantiate a Python class from a config dictionary.
+
+ Parameters:
+ config (`Dict[str, Any]`):
+ A config dictionary from which the Python class is instantiated. Make sure to only load configuration
+ files of compatible classes.
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ Whether kwargs that are not consumed by the Python class should be returned or not.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to update the configuration object (after it is loaded) and initiate the Python class.
+ `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
+ overwrite the same named arguments in `config`.
+
+ Returns:
+ [`ModelMixin`] or [`SchedulerMixin`]:
+ A model or scheduler object instantiated from a config dictionary.
+
+ Examples:
+
+ ```python
+ >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
+
+ >>> # Download scheduler from huggingface.co and cache.
+ >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
+
+ >>> # Instantiate DDIM scheduler class with same config as DDPM
+ >>> scheduler = DDIMScheduler.from_config(scheduler.config)
+
+ >>> # Instantiate PNDM scheduler class with same config as DDPM
+ >>> scheduler = PNDMScheduler.from_config(scheduler.config)
+ ```
+ """
+ # <===== TO BE REMOVED WITH DEPRECATION
+ # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
+ if "pretrained_model_name_or_path" in kwargs:
+ config = kwargs.pop("pretrained_model_name_or_path")
+
+ if config is None:
+ raise ValueError("Please make sure to provide a config as the first positional argument.")
+ # ======>
+
+ if not isinstance(config, dict):
+ deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
+ if "Scheduler" in cls.__name__:
+ deprecation_message += (
+ f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
+ " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
+ " be removed in v1.0.0."
+ )
+ elif "Model" in cls.__name__:
+ deprecation_message += (
+ f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
+ f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
+ " instead. This functionality will be removed in v1.0.0."
+ )
+ deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
+ config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
+
+ init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
+
+ # Allow dtype to be specified on initialization
+ if "dtype" in unused_kwargs:
+ init_dict["dtype"] = unused_kwargs.pop("dtype")
+
+ # add possible deprecated kwargs
+ for deprecated_kwarg in cls._deprecated_kwargs:
+ if deprecated_kwarg in unused_kwargs:
+ init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
+
+ # Return model and optionally state and/or unused_kwargs
+ model = cls(**init_dict)
+
+ # make sure to also save config parameters that might be used for compatible classes
+ # update _class_name
+ if "_class_name" in hidden_dict:
+ hidden_dict["_class_name"] = cls.__name__
+
+ model.register_to_config(**hidden_dict)
+
+ # add hidden kwargs of compatible classes to unused_kwargs
+ unused_kwargs = {**unused_kwargs, **hidden_dict}
+
+ if return_unused_kwargs:
+ return (model, unused_kwargs)
+ else:
+ return model
+
+ @classmethod
+ def get_config_dict(cls, *args, **kwargs):
+ deprecation_message = (
+ f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
+ " removed in version v1.0.0"
+ )
+ deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
+ return cls.load_config(*args, **kwargs)
+
+ @classmethod
+ @validate_hf_hub_args
+ def load_config(
+ cls,
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ return_unused_kwargs=False,
+ return_commit_hash=False,
+ **kwargs,
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ r"""
+ Load a model or scheduler configuration.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
+ [`~ConfigMixin.save_config`].
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+ incompletely downloaded files are deleted.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+                Whether unused keyword arguments of the config are returned.
+            return_commit_hash (`bool`, *optional*, defaults to `False`):
+                Whether the `commit_hash` of the loaded configuration is returned.
+
+ Returns:
+ `dict`:
+ A dictionary of all the parameters stored in a JSON configuration file.
+
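+        Examples:
+
+        A short sketch of the explicit two-step flow (reusing the repository id from the [`~ConfigMixin.from_config`]
+        example; requires network access or a populated local cache):
+
+        ```python
+        >>> from diffusers import DDPMScheduler
+
+        >>> config = DDPMScheduler.load_config("google/ddpm-cifar10-32")
+        >>> scheduler = DDPMScheduler.from_config(config)
+        ```
+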
+ """
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ token = kwargs.pop("token", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ revision = kwargs.pop("revision", None)
+ _ = kwargs.pop("mirror", None)
+ subfolder = kwargs.pop("subfolder", None)
+ user_agent = kwargs.pop("user_agent", {})
+
+ user_agent = {**user_agent, "file_type": "config"}
+ user_agent = http_user_agent(user_agent)
+
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+
+ if cls.config_name is None:
+ raise ValueError(
+ "`self.config_name` is not defined. Note that one should not load a config from "
+ "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
+ )
+
+ if os.path.isfile(pretrained_model_name_or_path):
+ config_file = pretrained_model_name_or_path
+ elif os.path.isdir(pretrained_model_name_or_path):
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
+ # Load from a PyTorch checkpoint
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
+ elif subfolder is not None and os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+ ):
+ config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+ else:
+ raise EnvironmentError(
+ f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
+ )
+ else:
+ try:
+ # Load from URL or cache if already cached
+ config_file = hf_hub_download(
+ pretrained_model_name_or_path,
+ filename=cls.config_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ token=token,
+ user_agent=user_agent,
+ subfolder=subfolder,
+ revision=revision,
+ )
+ except RepositoryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
+ " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
+ " token having permission to this repo with `token` or log in with `huggingface-cli login`."
+ )
+ except RevisionNotFoundError:
+ raise EnvironmentError(
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
+ " this model name. Check the model page at"
+ f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+ )
+ except EntryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
+ )
+ except HTTPError as err:
+ raise EnvironmentError(
+ "There was a specific connection error when trying to load"
+ f" {pretrained_model_name_or_path}:\n{err}"
+ )
+ except ValueError:
+ raise EnvironmentError(
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+ f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
+ " run the library in offline mode at"
+ " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+ )
+ except EnvironmentError:
+ raise EnvironmentError(
+ f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+ f"containing a {cls.config_name} file"
+ )
+
+ try:
+ # Load config dict
+ config_dict = cls._dict_from_json_file(config_file)
+
+ commit_hash = extract_commit_hash(config_file)
+ except (json.JSONDecodeError, UnicodeDecodeError):
+ raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
+
+ if not (return_unused_kwargs or return_commit_hash):
+ return config_dict
+
+ outputs = (config_dict,)
+
+ if return_unused_kwargs:
+ outputs += (kwargs,)
+
+ if return_commit_hash:
+ outputs += (commit_hash,)
+
+ return outputs
+
+ @staticmethod
+ def _get_init_keys(cls):
+ return set(dict(inspect.signature(cls.__init__).parameters).keys())
+
+ @classmethod
+ def extract_init_dict(cls, config_dict, **kwargs):
+ # Skip keys that were not present in the original config, so default __init__ values were used
+ used_defaults = config_dict.get("_use_default_values", [])
+ config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}
+
+ # 0. Copy origin config dict
+ original_dict = dict(config_dict.items())
+
+ # 1. Retrieve expected config attributes from __init__ signature
+ expected_keys = cls._get_init_keys(cls)
+ expected_keys.remove("self")
+ # remove general kwargs if present in dict
+ if "kwargs" in expected_keys:
+ expected_keys.remove("kwargs")
+ # remove flax internal keys
+ if hasattr(cls, "_flax_internal_args"):
+ for arg in cls._flax_internal_args:
+ expected_keys.remove(arg)
+
+ # 2. Remove attributes that cannot be expected from expected config attributes
+ # remove keys to be ignored
+ if len(cls.ignore_for_config) > 0:
+ expected_keys = expected_keys - set(cls.ignore_for_config)
+
+ # load diffusers library to import compatible and original scheduler
+ diffusers_library = importlib.import_module(__name__.split(".")[0])
+
+ if cls.has_compatibles:
+ compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
+ else:
+ compatible_classes = []
+
+ expected_keys_comp_cls = set()
+ for c in compatible_classes:
+ expected_keys_c = cls._get_init_keys(c)
+ expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
+ expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
+ config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
+
+ # remove attributes from orig class that cannot be expected
+ orig_cls_name = config_dict.pop("_class_name", cls.__name__)
+ if (
+ isinstance(orig_cls_name, str)
+ and orig_cls_name != cls.__name__
+ and hasattr(diffusers_library, orig_cls_name)
+ ):
+ orig_cls = getattr(diffusers_library, orig_cls_name)
+ unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
+ config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
+ elif not isinstance(orig_cls_name, str) and not isinstance(orig_cls_name, (list, tuple)):
+ raise ValueError(
+ "Make sure that the `_class_name` is of type string or list of string (for custom pipelines)."
+ )
+
+ # remove private attributes
+ config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
+
+ # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
+ init_dict = {}
+ for key in expected_keys:
+ # if config param is passed to kwarg and is present in config dict
+ # it should overwrite existing config dict key
+ if key in kwargs and key in config_dict:
+ config_dict[key] = kwargs.pop(key)
+
+ if key in kwargs:
+ # overwrite key
+ init_dict[key] = kwargs.pop(key)
+ elif key in config_dict:
+ # use value from config dict
+ init_dict[key] = config_dict.pop(key)
+
+ # 4. Give nice warning if unexpected values have been passed
+ if len(config_dict) > 0:
+ logger.warning(
+ f"The config attributes {config_dict} were passed to {cls.__name__}, "
+ "but are not expected and will be ignored. Please verify your "
+ f"{cls.config_name} configuration file."
+ )
+
+ # 5. Give nice info if config attributes are initialized to default because they have not been passed
+ passed_keys = set(init_dict.keys())
+ if len(expected_keys - passed_keys) > 0:
+ logger.info(
+ f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
+ )
+
+ # 6. Define unused keyword arguments
+ unused_kwargs = {**config_dict, **kwargs}
+
+ # 7. Define "hidden" config parameters that were saved for compatible classes
+ hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
+
+ return init_dict, unused_kwargs, hidden_config_dict
+
+ @classmethod
+ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
+ with open(json_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
+ return json.loads(text)
+
+ def __repr__(self):
+ return f"{self.__class__.__name__} {self.to_json_string()}"
+
+ @property
+ def config(self) -> Dict[str, Any]:
+ """
+ Returns the config of the class as a frozen dictionary
+
+ Returns:
+ `Dict[str, Any]`: Config of the class.
+ """
+ return self._internal_dict
+
+ def to_json_string(self) -> str:
+ """
+ Serializes the configuration instance to a JSON string.
+
+ Returns:
+ `str`:
+ String containing all the attributes that make up the configuration instance in JSON format.
+ """
+ config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
+ config_dict["_class_name"] = self.__class__.__name__
+ config_dict["_diffusers_version"] = __version__
+
+ def to_json_saveable(value):
+ if isinstance(value, np.ndarray):
+ value = value.tolist()
+ elif isinstance(value, PosixPath):
+ value = str(value)
+ return value
+
+ config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
+ # Don't save "_ignore_files" or "_use_default_values"
+ config_dict.pop("_ignore_files", None)
+ config_dict.pop("_use_default_values", None)
+
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+ """
+ Save the configuration instance's parameters to a JSON file.
+
+ Args:
+ json_file_path (`str` or `os.PathLike`):
+ Path to the JSON file to save a configuration instance's parameters.
+ """
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ writer.write(self.to_json_string())
+
+
+def register_to_config(init):
+ r"""
+ Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
+    automatically sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that
+    shouldn't be registered in the config, use the `ignore_for_config` class variable.
+
+ Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
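+
+    Example (illustrative sketch; `MyModel` and its arguments are hypothetical):
+
+    ```python
+    >>> class MyModel(ConfigMixin):
+    ...     config_name = "config.json"
+    ...     ignore_for_config = ["verbose"]
+    ...
+    ...     @register_to_config
+    ...     def __init__(self, hidden_size: int = 64, verbose: bool = False):
+    ...         self.verbose = verbose
+
+    >>> model = MyModel(hidden_size=128, verbose=True)
+    >>> model.config.hidden_size
+    128
+    >>> "verbose" in model.config
+    False
+    ```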
+ """
+
+ @functools.wraps(init)
+ def inner_init(self, *args, **kwargs):
+ # Ignore private kwargs in the init.
+ init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
+ config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
+ if not isinstance(self, ConfigMixin):
+ raise RuntimeError(
+ f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
+ "not inherit from `ConfigMixin`."
+ )
+
+ ignore = getattr(self, "ignore_for_config", [])
+ # Get positional arguments aligned with kwargs
+ new_kwargs = {}
+ signature = inspect.signature(init)
+ parameters = {
+ name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
+ }
+ for arg, name in zip(args, parameters.keys()):
+ new_kwargs[name] = arg
+
+ # Then add all kwargs
+ new_kwargs.update(
+ {
+ k: init_kwargs.get(k, default)
+ for k, default in parameters.items()
+ if k not in ignore and k not in new_kwargs
+ }
+ )
+
+ # Take note of the parameters that were not present in the loaded config
+ if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
+ new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
+
+ new_kwargs = {**config_init_kwargs, **new_kwargs}
+ getattr(self, "register_to_config")(**new_kwargs)
+ init(self, *args, **init_kwargs)
+
+ return inner_init
+
+
+def flax_register_to_config(cls):
+ original_init = cls.__init__
+
+ @functools.wraps(original_init)
+ def init(self, *args, **kwargs):
+ if not isinstance(self, ConfigMixin):
+ raise RuntimeError(
+ f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
+ "not inherit from `ConfigMixin`."
+ )
+
+ # Ignore private kwargs in the init. Retrieve all passed attributes
+ init_kwargs = dict(kwargs.items())
+
+ # Retrieve default values
+ fields = dataclasses.fields(self)
+ default_kwargs = {}
+ for field in fields:
+ # ignore flax specific attributes
+ if field.name in self._flax_internal_args:
+ continue
+ if type(field.default) == dataclasses._MISSING_TYPE:
+ default_kwargs[field.name] = None
+ else:
+ default_kwargs[field.name] = getattr(self, field.name)
+
+ # Make sure init_kwargs override default kwargs
+ new_kwargs = {**default_kwargs, **init_kwargs}
+ # dtype should be part of `init_kwargs`, but not `new_kwargs`
+ if "dtype" in new_kwargs:
+ new_kwargs.pop("dtype")
+
+ # Get positional arguments aligned with kwargs
+ for i, arg in enumerate(args):
+ name = fields[i].name
+ new_kwargs[name] = arg
+
+ # Take note of the parameters that were not present in the loaded config
+ if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
+ new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
+
+ getattr(self, "register_to_config")(**new_kwargs)
+ original_init(self, *args, **kwargs)
+
+ cls.__init__ = init
+ return cls
\ No newline at end of file
diff --git a/src/custom_diffusers/dpmsolver_multistep.py b/src/custom_diffusers/dpmsolver_multistep.py
new file mode 100644
index 0000000..da8b149
--- /dev/null
+++ b/src/custom_diffusers/dpmsolver_multistep.py
@@ -0,0 +1,1160 @@
+# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import deprecate
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.schedulers.scheduling_utils import (
+ KarrasDiffusionSchedulers,
+ SchedulerMixin,
+ SchedulerOutput,
+)
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+
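+    # Discretize the continuous schedule: beta_i = min(1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), max_beta),
+    # where t_i = i / num_diffusion_timesteps.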
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
+def rescale_zero_terminal_snr(betas):
+ """
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+
+
+ Args:
+ betas (`torch.FloatTensor`):
+ the betas that the scheduler is being initialized with.
+
+ Returns:
+ `torch.FloatTensor`: rescaled betas with zero terminal SNR
+ """
+ # Convert betas to alphas_bar_sqrt
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
+ alphas = torch.cat([alphas_bar[0:1], alphas])
+ betas = 1 - alphas
+
+ return betas
+
+
+class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
+ """
+ `DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ beta_start (`float`, defaults to 0.0001):
+ The starting `beta` value of inference.
+ beta_end (`float`, defaults to 0.02):
+ The final `beta` value.
+ beta_schedule (`str`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, *optional*):
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
+ solver_order (`int`, defaults to 2):
+ The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
+ sampling, and `solver_order=3` for unconditional sampling.
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
+ Video](https://imagen.research.google/video/paper.pdf) paper).
+ thresholding (`bool`, defaults to `False`):
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+ as Stable Diffusion.
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+ sample_max_value (`float`, defaults to 1.0):
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
+ `algorithm_type="dpmsolver++"`.
+ algorithm_type (`str`, defaults to `dpmsolver++`):
+ Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
+ `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
+ paper, and the `dpmsolver++` type implements the algorithms in the
+ [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
+ `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
+ solver_type (`str`, defaults to `midpoint`):
+ Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
+ sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
+ lower_order_final (`bool`, defaults to `True`):
+ Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
+ stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
+ euler_at_final (`bool`, defaults to `False`):
+ Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
+            richness. This can stabilize the sampling of the SDE variant of DPMSolver for a small number of inference
+ steps, but sometimes may result in blurring.
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+ Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
+ the sigmas are determined according to a sequence of noise levels {σi}.
+ use_lu_lambdas (`bool`, *optional*, defaults to `False`):
+ Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
+ the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
+ `lambda(t)`.
+ final_sigmas_type (`str`, defaults to `"zero"`):
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma
+ is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+ lambda_min_clipped (`float`, defaults to `-inf`):
+ Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
+ cosine (`squaredcos_cap_v2`) noise schedule.
+ variance_type (`str`, *optional*):
+ Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
+ contains the predicted Gaussian variance.
+ timestep_spacing (`str`, defaults to `"linspace"`):
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ steps_offset (`int`, defaults to 0):
+ An offset added to the inference steps, as required by some model families.
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
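+
+    Examples:
+
+    A minimal usage sketch (illustrative only; the denoising loop with an actual model is omitted):
+
+    ```python
+    >>> # assuming DPMSolverMultistepScheduler has been imported from this module
+    >>> scheduler = DPMSolverMultistepScheduler(num_train_timesteps=1000, solver_order=2)
+    >>> scheduler.set_timesteps(num_inference_steps=25)
+    >>> scheduler.timesteps.shape
+    torch.Size([25])
+    ```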
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ solver_order: int = 2,
+ prediction_type: str = "epsilon",
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ sample_max_value: float = 1.0,
+ algorithm_type: str = "dpmsolver++",
+ solver_type: str = "midpoint",
+ lower_order_final: bool = True,
+ euler_at_final: bool = False,
+ use_karras_sigmas: Optional[bool] = False,
+ use_lu_lambdas: Optional[bool] = False,
+ final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
+ lambda_min_clipped: float = -float("inf"),
+ variance_type: Optional[str] = None,
+ timestep_spacing: str = "linspace",
+ steps_offset: int = 0,
+ rescale_betas_zero_snr: bool = False,
+ ):
+ if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
+ deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
+ deprecate(
+ "algorithm_types dpmsolver and sde-dpmsolver",
+ "1.0.0",
+ deprecation_message,
+ )
+
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(
+ beta_start, beta_end, num_train_timesteps, dtype=torch.float32
+ )
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(
+ beta_start**0.5,
+ beta_end**0.5,
+ num_train_timesteps,
+ dtype=torch.float32,
+ )
+ ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(
+ f"{beta_schedule} does is not implemented for {self.__class__}"
+ )
+
+ if rescale_betas_zero_snr:
+ self.betas = rescale_zero_terminal_snr(self.betas)
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ if rescale_betas_zero_snr:
+ # Close to 0 without being 0 so first sigma is not inf
+ # FP16 smallest positive subnormal works well here
+ self.alphas_cumprod[-1] = 2**-24
+
+ # Currently we only support VP-type noise schedule
+ self.alpha_t = torch.sqrt(self.alphas_cumprod)
+ self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
+ self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
+ self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # settings for DPM-Solver
+ if algorithm_type not in [
+ "dpmsolver",
+ "dpmsolver++",
+ "sde-dpmsolver",
+ "sde-dpmsolver++",
+ ]:
+ if algorithm_type == "deis":
+ self.register_to_config(algorithm_type="dpmsolver++")
+ else:
+ raise NotImplementedError(
+ f"{algorithm_type} does is not implemented for {self.__class__}"
+ )
+
+ if solver_type not in ["midpoint", "heun"]:
+ if solver_type in ["logrho", "bh1", "bh2"]:
+ self.register_to_config(solver_type="midpoint")
+ else:
+ raise NotImplementedError(
+ f"{solver_type} does is not implemented for {self.__class__}"
+ )
+
+ if (
+ algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"]
+ and final_sigmas_type == "zero"
+ ):
+ raise ValueError(
+ f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
+ )
+
+ # setable values
+ self.num_inference_steps = None
+ timesteps = np.linspace(
+ 0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32
+ )[::-1].copy()
+ self.timesteps = torch.from_numpy(timesteps)
+ self.model_outputs = [None] * solver_order
+ self.lower_order_nums = 0
+ self._step_index = None
+ self._begin_index = None
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
+
+ @property
+ def step_index(self):
+ """
+        The index counter for the current timestep. It will increase by 1 after each scheduler step.
+ """
+ return self._step_index
+
+ @property
+ def begin_index(self):
+ """
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+ """
+ return self._begin_index
+
+ def set_begin_index(self, begin_index: int = 0):
+ """
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+ Args:
+ begin_index (`int`):
+ The begin index for the scheduler.
+ """
+ self._begin_index = begin_index
+
+ def set_timesteps(
+ self, num_inference_steps: int = None, device: Union[str, torch.device] = None
+ ):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, *optional*):
+                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ """
+ # Clipping the minimum of all lambda(t) for numerical stability.
+ # This is critical for cosine (squaredcos_cap_v2) noise schedule.
+ clipped_idx = torch.searchsorted(
+ torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped
+ )
+ last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()
+
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ if self.config.timestep_spacing == "linspace":
+ timesteps = (
+ np.linspace(0, last_timestep - 1, num_inference_steps + 1)
+ .round()[::-1][:-1]
+ .copy()
+ .astype(np.int64)
+ )
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = last_timestep // (num_inference_steps + 1)
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (
+ (np.arange(0, num_inference_steps + 1) * step_ratio)
+ .round()[::-1][:-1]
+ .copy()
+ .astype(np.int64)
+ )
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = self.config.num_train_timesteps / num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (
+ np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
+ )
+ timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+ )
+
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ log_sigmas = np.log(sigmas)
+
+ if self.config.use_karras_sigmas:
+ sigmas = np.flip(sigmas).copy()
+ sigmas = self._convert_to_karras(
+ in_sigmas=sigmas, num_inference_steps=num_inference_steps
+ )
+ timesteps = np.array(
+ [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]
+ ).round()
+ elif self.config.use_lu_lambdas:
+ lambdas = np.flip(log_sigmas.copy())
+ lambdas = self._convert_to_lu(
+ in_lambdas=lambdas, num_inference_steps=num_inference_steps
+ )
+ sigmas = np.exp(lambdas)
+ timesteps = np.array(
+ [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]
+ ).round()
+ else:
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+
+ if self.config.final_sigmas_type == "sigma_min":
+ sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
+ elif self.config.final_sigmas_type == "zero":
+ sigma_last = 0
+ else:
+ raise ValueError(
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+ )
+
+ sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
+
+ self.sigmas = torch.from_numpy(sigmas)
+ self.timesteps = torch.from_numpy(timesteps).to(
+ device=device, dtype=torch.int64
+ )
+
+ self.num_inference_steps = len(timesteps)
+
+ self.model_outputs = [
+ None,
+ ] * self.config.solver_order
+ self.lower_order_nums = 0
+
+ # add an index counter for schedulers that allow duplicated timesteps
+ self._step_index = None
+ self._begin_index = None
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+ """
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+ https://arxiv.org/abs/2205.11487
+ """
+ dtype = sample.dtype
+ batch_size, channels, *remaining_dims = sample.shape
+
+ if dtype not in (torch.float32, torch.float64):
+ sample = (
+ sample.float()
+ ) # upcast for quantile calculation, and clamp not implemented for cpu half
+
+ # Flatten sample for doing quantile calculation along each image
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
+
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+ s = torch.clamp(
+ s, min=1, max=self.config.sample_max_value
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
+ sample = (
+ torch.clamp(sample, -s, s) / s
+ ) # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
+ sample = sample.to(dtype)
+
+ return sample
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
+ def _sigma_to_t(self, sigma, log_sigmas):
+ # get log sigma
+ log_sigma = np.log(np.maximum(sigma, 1e-10))
+
+ # get distribution
+ dists = log_sigma - log_sigmas[:, np.newaxis]
+
+ # get sigmas range
+ low_idx = (
+ np.cumsum((dists >= 0), axis=0)
+ .argmax(axis=0)
+ .clip(max=log_sigmas.shape[0] - 2)
+ )
+ high_idx = low_idx + 1
+
+ low = log_sigmas[low_idx]
+ high = log_sigmas[high_idx]
+
+ # interpolate sigmas
+ w = (low - log_sigma) / (low - high)
+ w = np.clip(w, 0, 1)
+
+ # transform interpolation to time range
+ t = (1 - w) * low_idx + w * high_idx
+ t = t.reshape(sigma.shape)
+ return t
+
+ def _sigma_to_alpha_sigma_t(self, sigma):
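+        # For the VP noise schedule used here, sigma = sqrt(1 - alphas_cumprod) / sqrt(alphas_cumprod),
+        # so alpha_t = 1 / sqrt(sigma**2 + 1) = sqrt(alphas_cumprod) and
+        # sigma_t = sigma * alpha_t = sqrt(1 - alphas_cumprod).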
+ alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
+ sigma_t = sigma * alpha_t
+
+ return alpha_t, sigma_t
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+ def _convert_to_karras(
+ self, in_sigmas: torch.FloatTensor, num_inference_steps
+ ) -> torch.FloatTensor:
+ """Constructs the noise schedule of Karras et al. (2022)."""
+
+ # Hack to make sure that other schedulers which copy this function don't break
+ # TODO: Add this logic to the other schedulers
+ if hasattr(self.config, "sigma_min"):
+ sigma_min = self.config.sigma_min
+ else:
+ sigma_min = None
+
+ if hasattr(self.config, "sigma_max"):
+ sigma_max = self.config.sigma_max
+ else:
+ sigma_max = None
+
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+ rho = 7.0 # 7.0 is the value used in the paper
+ ramp = np.linspace(0, 1, num_inference_steps)
+ min_inv_rho = sigma_min ** (1 / rho)
+ max_inv_rho = sigma_max ** (1 / rho)
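+        # Interpolate linearly in sigma**(1/rho) space from sigma_max down to sigma_min,
+        # then raise back to the rho-th power (Karras et al., 2022).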
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+ return sigmas
+
+ def _convert_to_lu(
+ self, in_lambdas: torch.FloatTensor, num_inference_steps
+ ) -> torch.FloatTensor:
+ """Constructs the noise schedule of Lu et al. (2022)."""
+
+ lambda_min: float = in_lambdas[-1].item()
+ lambda_max: float = in_lambdas[0].item()
+
+ rho = 1.0 # 1.0 is the value used in the paper
+ ramp = np.linspace(0, 1, num_inference_steps)
+ min_inv_rho = lambda_min ** (1 / rho)
+ max_inv_rho = lambda_max ** (1 / rho)
+ lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+ return lambdas
+
+ def convert_model_output(
+ self,
+ model_output: torch.FloatTensor,
+ *args,
+ sample: torch.FloatTensor = None,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ """
+ Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
+ designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
+ integral of the data prediction model.
+
+