diff --git a/src/create_quantiles.py b/src/create_quantiles.py index 181cf62..5241b32 100644 --- a/src/create_quantiles.py +++ b/src/create_quantiles.py @@ -29,13 +29,15 @@ def main(cfg: DictConfig): ) # Extract the xarray dataset and denormalize it - xr_ds = dataset.xr_data.map(denorm).sel(time=slice(START_YEAR, END_YEAR)) + xr_ds = dataset.xr_data.map(denorm).sel(time=slice(START_YEAR, END_YEAR)).compute() + breakpoint() + # Group by each day of the year + groups = xr_ds.groupby("time.dayofyear") - # Compute the quantiles - quantiles = xr_ds.load().quantile(cfg.quantile, dim="time").drop_vars("quantile") + quantiles = groups.quantile(q=[0.9, 0.95, 0.99, 0.999], dim="time") # Save the quantiles - save_name = f"{cfg.var}_{int(cfg.quantile * 100)}.nc" + save_name = f"{cfg.var}_quantiles.nc" save_path = os.path.join(cfg.paths.quantile_dir, cfg.esm, save_name) # Delete the file if it already exists (avoids permission denied errors) diff --git a/src/custom_diffusers/configuration_utils.py b/src/custom_diffusers/configuration_utils.py new file mode 100644 index 0000000..6ed837f --- /dev/null +++ b/src/custom_diffusers/configuration_utils.py @@ -0,0 +1,702 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" ConfigMixin base class and utilities.""" +import dataclasses +import functools +import importlib +import inspect +import json +import os +import re +from collections import OrderedDict +from pathlib import PosixPath +from typing import Any, Dict, Tuple, Union + +import numpy as np +from huggingface_hub import create_repo, hf_hub_download +from huggingface_hub.utils import ( + EntryNotFoundError, + RepositoryNotFoundError, + RevisionNotFoundError, + validate_hf_hub_args, +) +from requests import HTTPError + +from . import __version__ +from .utils import ( + HUGGINGFACE_CO_RESOLVE_ENDPOINT, + deprecate, + extract_commit_hash, + http_user_agent, + logging, +) + + +logger = logging.get_logger(__name__) + +_re_configuration_file = re.compile(r"config\.(.*)\.json") + + +class FrozenDict(OrderedDict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + for key, value in self.items(): + setattr(self, key, value) + + self.__frozen = True + + def __delitem__(self, *args, **kwargs): + raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") + + def setdefault(self, *args, **kwargs): + raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") + + def pop(self, *args, **kwargs): + raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") + + def update(self, *args, **kwargs): + raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") + + def __setattr__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") + super().__setattr__(name, value) + + def __setitem__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") + super().__setitem__(name, value) + + +class ConfigMixin: + r""" + Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also + provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and + saving classes that inherit from [`ConfigMixin`]. + + Class attributes: + - **config_name** (`str`) -- A filename under which the config should stored when calling + [`~ConfigMixin.save_config`] (should be overridden by parent class). + - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be + overridden by subclass). + - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass). + - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function + should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by + subclass). + """ + + config_name = None + ignore_for_config = [] + has_compatibles = False + + _deprecated_kwargs = [] + + def register_to_config(self, **kwargs): + if self.config_name is None: + raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`") + # Special case for `kwargs` used in deprecation warning added to schedulers + # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument, + # or solve in a more general way. + kwargs.pop("kwargs", None) + + if not hasattr(self, "_internal_dict"): + internal_dict = kwargs + else: + previous_dict = dict(self._internal_dict) + internal_dict = {**self._internal_dict, **kwargs} + logger.debug(f"Updating config from {previous_dict} to {internal_dict}") + + self._internal_dict = FrozenDict(internal_dict) + + def __getattr__(self, name: str) -> Any: + """The only reason we overwrite `getattr` here is to gracefully deprecate accessing + config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 + + This function is mostly copied from PyTorch's __getattr__ overwrite: + https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module + """ + + is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name) + is_attribute = name in self.__dict__ + + if is_in_config and not is_attribute: + deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'." + deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False) + return self._internal_dict[name] + + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the + [`~ConfigMixin.from_config`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file is saved (will be created if it does not exist). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + if os.path.isfile(save_directory): + raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + + os.makedirs(save_directory, exist_ok=True) + + # If we save using the predefined names, we can load using `from_config` + output_config_file = os.path.join(save_directory, self.config_name) + + self.to_json_file(output_config_file) + logger.info(f"Configuration saved in {output_config_file}") + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + private = kwargs.pop("private", False) + create_pr = kwargs.pop("create_pr", False) + token = kwargs.pop("token", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id + + self._upload_folder( + save_directory, + repo_id, + token=token, + commit_message=commit_message, + create_pr=create_pr, + ) + + @classmethod + def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs): + r""" + Instantiate a Python class from a config dictionary. + + Parameters: + config (`Dict[str, Any]`): + A config dictionary from which the Python class is instantiated. Make sure to only load configuration + files of compatible classes. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + Whether kwargs that are not consumed by the Python class should be returned or not. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it is loaded) and initiate the Python class. + `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually + overwrite the same named arguments in `config`. + + Returns: + [`ModelMixin`] or [`SchedulerMixin`]: + A model or scheduler object instantiated from a config dictionary. + + Examples: + + ```python + >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler + + >>> # Download scheduler from huggingface.co and cache. + >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32") + + >>> # Instantiate DDIM scheduler class with same config as DDPM + >>> scheduler = DDIMScheduler.from_config(scheduler.config) + + >>> # Instantiate PNDM scheduler class with same config as DDPM + >>> scheduler = PNDMScheduler.from_config(scheduler.config) + ``` + """ + # <===== TO BE REMOVED WITH DEPRECATION + # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated + if "pretrained_model_name_or_path" in kwargs: + config = kwargs.pop("pretrained_model_name_or_path") + + if config is None: + raise ValueError("Please make sure to provide a config as the first positional argument.") + # ======> + + if not isinstance(config, dict): + deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`." + if "Scheduler" in cls.__name__: + deprecation_message += ( + f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead." + " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will" + " be removed in v1.0.0." + ) + elif "Model" in cls.__name__: + deprecation_message += ( + f"If you were trying to load a model, please use {cls}.load_config(...) followed by" + f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary" + " instead. This functionality will be removed in v1.0.0." + ) + deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False) + config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs) + + init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs) + + # Allow dtype to be specified on initialization + if "dtype" in unused_kwargs: + init_dict["dtype"] = unused_kwargs.pop("dtype") + + # add possible deprecated kwargs + for deprecated_kwarg in cls._deprecated_kwargs: + if deprecated_kwarg in unused_kwargs: + init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg) + + # Return model and optionally state and/or unused_kwargs + model = cls(**init_dict) + + # make sure to also save config parameters that might be used for compatible classes + # update _class_name + if "_class_name" in hidden_dict: + hidden_dict["_class_name"] = cls.__name__ + + model.register_to_config(**hidden_dict) + + # add hidden kwargs of compatible classes to unused_kwargs + unused_kwargs = {**unused_kwargs, **hidden_dict} + + if return_unused_kwargs: + return (model, unused_kwargs) + else: + return model + + @classmethod + def get_config_dict(cls, *args, **kwargs): + deprecation_message = ( + f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be" + " removed in version v1.0.0" + ) + deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False) + return cls.load_config(*args, **kwargs) + + @classmethod + @validate_hf_hub_args + def load_config( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + return_unused_kwargs=False, + return_commit_hash=False, + **kwargs, + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + r""" + Load a model or scheduler configuration. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with + [`~ConfigMixin.save_config`]. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + return_unused_kwargs (`bool`, *optional*, defaults to `False): + Whether unused keyword arguments of the config are returned. + return_commit_hash (`bool`, *optional*, defaults to `False): + Whether the `commit_hash` of the loaded configuration are returned. + + Returns: + `dict`: + A dictionary of all the parameters stored in a JSON configuration file. + + """ + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + _ = kwargs.pop("mirror", None) + subfolder = kwargs.pop("subfolder", None) + user_agent = kwargs.pop("user_agent", {}) + + user_agent = {**user_agent, "file_type": "config"} + user_agent = http_user_agent(user_agent) + + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + + if cls.config_name is None: + raise ValueError( + "`self.config_name` is not defined. Note that one should not load a config from " + "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`" + ) + + if os.path.isfile(pretrained_model_name_or_path): + config_file = pretrained_model_name_or_path + elif os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)): + # Load from a PyTorch checkpoint + config_file = os.path.join(pretrained_model_name_or_path, cls.config_name) + elif subfolder is not None and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) + ): + config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) + else: + raise EnvironmentError( + f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}." + ) + else: + try: + # Load from URL or cache if already cached + config_file = hf_hub_download( + pretrained_model_name_or_path, + filename=cls.config_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision, + ) + except RepositoryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier" + " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a" + " token having permission to this repo with `token` or log in with `huggingface-cli login`." + ) + except RevisionNotFoundError: + raise EnvironmentError( + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for" + " this model name. Check the model page at" + f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." + ) + except EntryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}." + ) + except HTTPError as err: + raise EnvironmentError( + "There was a specific connection error when trying to load" + f" {pretrained_model_name_or_path}:\n{err}" + ) + except ValueError: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" + f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" + f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to" + " run the library in offline mode at" + " 'https://huggingface.co/docs/diffusers/installation#offline-mode'." + ) + except EnvironmentError: + raise EnvironmentError( + f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing a {cls.config_name} file" + ) + + try: + # Load config dict + config_dict = cls._dict_from_json_file(config_file) + + commit_hash = extract_commit_hash(config_file) + except (json.JSONDecodeError, UnicodeDecodeError): + raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.") + + if not (return_unused_kwargs or return_commit_hash): + return config_dict + + outputs = (config_dict,) + + if return_unused_kwargs: + outputs += (kwargs,) + + if return_commit_hash: + outputs += (commit_hash,) + + return outputs + + @staticmethod + def _get_init_keys(cls): + return set(dict(inspect.signature(cls.__init__).parameters).keys()) + + @classmethod + def extract_init_dict(cls, config_dict, **kwargs): + # Skip keys that were not present in the original config, so default __init__ values were used + used_defaults = config_dict.get("_use_default_values", []) + config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"} + + # 0. Copy origin config dict + original_dict = dict(config_dict.items()) + + # 1. Retrieve expected config attributes from __init__ signature + expected_keys = cls._get_init_keys(cls) + expected_keys.remove("self") + # remove general kwargs if present in dict + if "kwargs" in expected_keys: + expected_keys.remove("kwargs") + # remove flax internal keys + if hasattr(cls, "_flax_internal_args"): + for arg in cls._flax_internal_args: + expected_keys.remove(arg) + + # 2. Remove attributes that cannot be expected from expected config attributes + # remove keys to be ignored + if len(cls.ignore_for_config) > 0: + expected_keys = expected_keys - set(cls.ignore_for_config) + + # load diffusers library to import compatible and original scheduler + diffusers_library = importlib.import_module(__name__.split(".")[0]) + + if cls.has_compatibles: + compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)] + else: + compatible_classes = [] + + expected_keys_comp_cls = set() + for c in compatible_classes: + expected_keys_c = cls._get_init_keys(c) + expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c) + expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls) + config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls} + + # remove attributes from orig class that cannot be expected + orig_cls_name = config_dict.pop("_class_name", cls.__name__) + if ( + isinstance(orig_cls_name, str) + and orig_cls_name != cls.__name__ + and hasattr(diffusers_library, orig_cls_name) + ): + orig_cls = getattr(diffusers_library, orig_cls_name) + unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys + config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig} + elif not isinstance(orig_cls_name, str) and not isinstance(orig_cls_name, (list, tuple)): + raise ValueError( + "Make sure that the `_class_name` is of type string or list of string (for custom pipelines)." + ) + + # remove private attributes + config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")} + + # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments + init_dict = {} + for key in expected_keys: + # if config param is passed to kwarg and is present in config dict + # it should overwrite existing config dict key + if key in kwargs and key in config_dict: + config_dict[key] = kwargs.pop(key) + + if key in kwargs: + # overwrite key + init_dict[key] = kwargs.pop(key) + elif key in config_dict: + # use value from config dict + init_dict[key] = config_dict.pop(key) + + # 4. Give nice warning if unexpected values have been passed + if len(config_dict) > 0: + logger.warning( + f"The config attributes {config_dict} were passed to {cls.__name__}, " + "but are not expected and will be ignored. Please verify your " + f"{cls.config_name} configuration file." + ) + + # 5. Give nice info if config attributes are initialized to default because they have not been passed + passed_keys = set(init_dict.keys()) + if len(expected_keys - passed_keys) > 0: + logger.info( + f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values." + ) + + # 6. Define unused keyword arguments + unused_kwargs = {**config_dict, **kwargs} + + # 7. Define "hidden" config parameters that were saved for compatible classes + hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict} + + return init_dict, unused_kwargs, hidden_config_dict + + @classmethod + def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]): + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + return json.loads(text) + + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + @property + def config(self) -> Dict[str, Any]: + """ + Returns the config of the class as a frozen dictionary + + Returns: + `Dict[str, Any]`: Config of the class. + """ + return self._internal_dict + + def to_json_string(self) -> str: + """ + Serializes the configuration instance to a JSON string. + + Returns: + `str`: + String containing all the attributes that make up the configuration instance in JSON format. + """ + config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {} + config_dict["_class_name"] = self.__class__.__name__ + config_dict["_diffusers_version"] = __version__ + + def to_json_saveable(value): + if isinstance(value, np.ndarray): + value = value.tolist() + elif isinstance(value, PosixPath): + value = str(value) + return value + + config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()} + # Don't save "_ignore_files" or "_use_default_values" + config_dict.pop("_ignore_files", None) + config_dict.pop("_use_default_values", None) + + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save the configuration instance's parameters to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file to save a configuration instance's parameters. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string()) + + +def register_to_config(init): + r""" + Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are + automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that + shouldn't be registered in the config, use the `ignore_for_config` class variable + + Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init! + """ + + @functools.wraps(init) + def inner_init(self, *args, **kwargs): + # Ignore private kwargs in the init. + init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} + config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")} + if not isinstance(self, ConfigMixin): + raise RuntimeError( + f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does " + "not inherit from `ConfigMixin`." + ) + + ignore = getattr(self, "ignore_for_config", []) + # Get positional arguments aligned with kwargs + new_kwargs = {} + signature = inspect.signature(init) + parameters = { + name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore + } + for arg, name in zip(args, parameters.keys()): + new_kwargs[name] = arg + + # Then add all kwargs + new_kwargs.update( + { + k: init_kwargs.get(k, default) + for k, default in parameters.items() + if k not in ignore and k not in new_kwargs + } + ) + + # Take note of the parameters that were not present in the loaded config + if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0: + new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs)) + + new_kwargs = {**config_init_kwargs, **new_kwargs} + getattr(self, "register_to_config")(**new_kwargs) + init(self, *args, **init_kwargs) + + return inner_init + + +def flax_register_to_config(cls): + original_init = cls.__init__ + + @functools.wraps(original_init) + def init(self, *args, **kwargs): + if not isinstance(self, ConfigMixin): + raise RuntimeError( + f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does " + "not inherit from `ConfigMixin`." + ) + + # Ignore private kwargs in the init. Retrieve all passed attributes + init_kwargs = dict(kwargs.items()) + + # Retrieve default values + fields = dataclasses.fields(self) + default_kwargs = {} + for field in fields: + # ignore flax specific attributes + if field.name in self._flax_internal_args: + continue + if type(field.default) == dataclasses._MISSING_TYPE: + default_kwargs[field.name] = None + else: + default_kwargs[field.name] = getattr(self, field.name) + + # Make sure init_kwargs override default kwargs + new_kwargs = {**default_kwargs, **init_kwargs} + # dtype should be part of `init_kwargs`, but not `new_kwargs` + if "dtype" in new_kwargs: + new_kwargs.pop("dtype") + + # Get positional arguments aligned with kwargs + for i, arg in enumerate(args): + name = fields[i].name + new_kwargs[name] = arg + + # Take note of the parameters that were not present in the loaded config + if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0: + new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs)) + + getattr(self, "register_to_config")(**new_kwargs) + original_init(self, *args, **kwargs) + + cls.__init__ = init + return cls \ No newline at end of file diff --git a/src/custom_diffusers/dpmsolver_multistep.py b/src/custom_diffusers/dpmsolver_multistep.py new file mode 100644 index 0000000..da8b149 --- /dev/null +++ b/src/custom_diffusers/dpmsolver_multistep.py @@ -0,0 +1,1160 @@ +# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import deprecate +from diffusers.utils.torch_utils import randn_tensor +from diffusers.schedulers.scheduling_utils import ( + KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput, +) + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++"`. + algorithm_type (`str`, defaults to `dpmsolver++`): + Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The + `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) + paper, and the `dpmsolver++` type implements the algorithms in the + [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or + `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + use_lu_lambdas (`bool`, *optional*, defaults to `False`): + Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during + the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of + `lambda(t)`. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma + is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + lambda_min_clipped (`float`, defaults to `-inf`): + Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the + cosine (`squaredcos_cap_v2`) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output + contains the predicted Gaussian variance. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + euler_at_final: bool = False, + use_karras_sigmas: Optional[bool] = False, + use_lu_lambdas: Optional[bool] = False, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + rescale_betas_zero_snr: bool = False, + ): + if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" + deprecate( + "algorithm_types dpmsolver and sde-dpmsolver", + "1.0.0", + deprecation_message, + ) + + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace( + beta_start, beta_end, num_train_timesteps, dtype=torch.float32 + ) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float32, + ) + ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError( + f"{beta_schedule} does is not implemented for {self.__class__}" + ) + + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + if rescale_betas_zero_snr: + # Close to 0 without being 0 so first sigma is not inf + # FP16 smallest positive subnormal works well here + self.alphas_cumprod[-1] = 2**-24 + + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # settings for DPM-Solver + if algorithm_type not in [ + "dpmsolver", + "dpmsolver++", + "sde-dpmsolver", + "sde-dpmsolver++", + ]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError( + f"{algorithm_type} does is not implemented for {self.__class__}" + ) + + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError( + f"{solver_type} does is not implemented for {self.__class__}" + ) + + if ( + algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] + and final_sigmas_type == "zero" + ): + raise ValueError( + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead." + ) + + # setable values + self.num_inference_steps = None + timesteps = np.linspace( + 0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32 + )[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + @property + def step_index(self): + """ + The index counter for current timestep. It will increae 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def set_timesteps( + self, num_inference_steps: int = None, device: Union[str, torch.device] = None + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + clipped_idx = torch.searchsorted( + torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped + ) + last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item() + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, last_timestep - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = last_timestep // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + (np.arange(0, num_inference_steps + 1) * step_ratio) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64) + ) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + + if self.config.use_karras_sigmas: + sigmas = np.flip(sigmas).copy() + sigmas = self._convert_to_karras( + in_sigmas=sigmas, num_inference_steps=num_inference_steps + ) + timesteps = np.array( + [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas] + ).round() + elif self.config.use_lu_lambdas: + lambdas = np.flip(log_sigmas.copy()) + lambdas = self._convert_to_lu( + in_lambdas=lambdas, num_inference_steps=num_inference_steps + ) + sigmas = np.exp(lambdas) + timesteps = np.array( + [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas] + ).round() + else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to( + device=device, dtype=torch.int64 + ) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = ( + sample.float() + ) # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = ( + torch.clamp(sample, -s, s) / s + ) # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = ( + np.cumsum((dists >= 0), axis=0) + .argmax(axis=0) + .clip(max=log_sigmas.shape[0] - 2) + ) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras( + self, in_sigmas: torch.FloatTensor, num_inference_steps + ) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def _convert_to_lu( + self, in_lambdas: torch.FloatTensor, num_inference_steps + ) -> torch.FloatTensor: + """Constructs the noise schedule of Lu et al. (2022).""" + + lambda_min: float = in_lambdas[-1].item() + lambda_max: float = in_lambdas[0].item() + + rho = 1.0 # 1.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = lambda_min ** (1 / rho) + max_inv_rho = lambda_max ** (1 / rho) + lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return lambdas + + def convert_model_output( + self, + model_output: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + **kwargs, + ) -> torch.FloatTensor: + """ + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + + + Args: + model_output (`torch.FloatTensor`): + The direct output from the learned diffusion model. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.FloatTensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + model_output = model_output[:, :3] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + + # DPM-Solver needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + epsilon = model_output[:, :3] + else: + epsilon = model_output + elif self.config.prediction_type == "sample": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = (sample - alpha_t * model_output) / sigma_t + elif self.config.prediction_type == "v_prediction": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = alpha_t * model_output + sigma_t * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = (sample - sigma_t * epsilon) / alpha_t + x0_pred = self._threshold_sample(x0_pred) + epsilon = (sample - alpha_t * x0_pred) / sigma_t + + return epsilon + + def dpm_solver_first_order_update( + self, + model_output: torch.FloatTensor, + *args, + sample: torch.FloatTensor = None, + noise: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> torch.FloatTensor: + """ + One step for the first-order DPMSolver (equivalent to DDIM). + + Args: + model_output (`torch.FloatTensor`): + The direct output from the learned diffusion model. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.FloatTensor`: + The sample tensor at the previous timestep. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + ) + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / sigma_s) * sample - ( + alpha_t * (torch.exp(-h) - 1.0) + ) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / alpha_s) * sample - ( + sigma_t * (torch.exp(h) - 1.0) + ) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ( + (sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ( + (alpha_t / alpha_s) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + return x_t + + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.FloatTensor], + *args, + sample: torch.FloatTensor = None, + noise: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> torch.FloatTensor: + """ + One step for the second-order multistep DPMSolver. + + Args: + model_output_list (`List[torch.FloatTensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.FloatTensor`: + The sample tensor at the previous timestep. + """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + return x_t + + def multistep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.FloatTensor], + *args, + sample: torch.FloatTensor = None, + **kwargs, + ) -> torch.FloatTensor: + """ + One step for the third-order multistep DPMSolver. + + Args: + model_output_list (`List[torch.FloatTensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.FloatTensor`): + A current instance of a sample created by diffusion process. + + Returns: + `torch.FloatTensor`: + The sample tensor at the previous timestep. + """ + + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + self.sigmas[self.step_index - 2], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) + + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 + ) + return x_t + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + index_candidates = (schedule_timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + return step_index + + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def get_velocity( + self, + sample: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + generator=None, + variance_noise: Optional[torch.FloatTensor] = None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep DPMSolver. + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.FloatTensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`LEdits++`]. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Improve numerical stability for small number of steps + lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( + self.config.euler_at_final + or (self.config.lower_order_final and len(self.timesteps) < 15) + or self.config.final_sigmas_type == "zero" + ) + lower_order_second = ( + (self.step_index == len(self.timesteps) - 2) + and self.config.lower_order_final + and len(self.timesteps) < 15 + ) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + if ( + self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] + and variance_noise is None + ): + noise = randn_tensor( + model_output.shape, + generator=generator, + device=model_output.device, + dtype=torch.float32, + ) + elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + noise = variance_noise.to(device=model_output.device, dtype=torch.float32) + else: + noise = None + + if ( + self.config.solver_order == 1 + or self.lower_order_nums < 1 + or lower_order_final + ): + prev_sample = self.dpm_solver_first_order_update( + model_output, sample=sample, noise=noise + ) + elif ( + self.config.solver_order == 2 + or self.lower_order_nums < 2 + or lower_order_second + ): + prev_sample = self.multistep_dpm_solver_second_order_update( + self.model_outputs, sample=sample, noise=noise + ) + else: + prev_sample = self.multistep_dpm_solver_third_order_update( + self.model_outputs, sample=sample + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # Cast sample back to expected dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input( + self, sample: torch.FloatTensor, *args, **kwargs + ) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): + The input sample. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + return sample + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to( + device=original_samples.device, dtype=original_samples.dtype + ) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to( + original_samples.device, dtype=torch.float32 + ) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training + if self.begin_index is None: + step_indices = [ + self.index_for_timestep(t, schedule_timesteps) for t in timesteps + ] + else: + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/data/climate_dataset.py b/src/data/climate_dataset.py index 3adfc25..01ed7ae 100644 --- a/src/data/climate_dataset.py +++ b/src/data/climate_dataset.py @@ -8,6 +8,7 @@ import xarray as xr from torch.utils.data import Dataset, DataLoader from accelerate import Accelerator +import dask # Constants for the minimum and maximum of our datasets MIN_MAX_CONSTANTS = {"tas": (-85.0, 60.0), "pr": (0.0, 6.0)} @@ -17,12 +18,12 @@ # Normalization and Inverse Normalization functions NORM_FN = { - "tas": lambda x: x / 20, - "pr": lambda x: np.log(1 + x), + "tas": lambda x: (x - 4.5) / 21.0, + "pr": lambda x: np.cbrt(x), } DENORM_FN = { - "tas": lambda x: x * 20, - "pr": lambda x: np.exp(x) - 1, + "tas": lambda x: x * 21.0 + 4.5, + "pr": lambda x: x**3, } # These functions transform the range of the data to [-1, 1] @@ -75,9 +76,11 @@ def __init__( data_dir: str, scenario: str, vars: list[str], + spatial_resolution=None, ): self.seq_len = seq_len self.realizations = realizations + self.spatial_resolution = spatial_resolution self.data_dir = os.path.join(data_dir, esm, scenario) @@ -98,6 +101,7 @@ def estimate_num_batches(self, batch_size: int) -> int: def load_data(self, realization: str): """Loads the data from the specified paths and returns it as an xarray Dataset.""" + realization_dir = os.path.join(self.data_dir, realization, "*.nc") # Open up the dataset and make sure it's sorted by time @@ -108,6 +112,11 @@ def load_data(self, realization: str): # Apply preprocessing and normalization self.xr_data = dataset.map(preprocess).map(normalize) + + if self.spatial_resolution is not None: + with dask.config.set(**{'array.slicing.split_large_chunks' : False}): + self.xr_data = self.xr_data.coarsen(lon=3, lat=2).mean() + self.tensor_data = self.convert_xarray_to_tensor(self.xr_data) def convert_xarray_to_tensor(self, ds: xr.Dataset) -> torch.Tensor: diff --git a/src/generate.py b/src/generate.py index e693c58..03fe77a 100644 --- a/src/generate.py +++ b/src/generate.py @@ -14,6 +14,7 @@ from diffusers import DDPMScheduler import xarray as xr from tqdm import tqdm +import pandas as pd # Local imports from data.climate_dataset import ClimateDataset @@ -26,6 +27,17 @@ realization_dict = {"gen": "r2", "val": "r2", "test": "r1"} +def get_starting_index(directory: str) -> int: + """Goes through a directory of files named "member_i.nc" and returns the next available index.""" + files = os.listdir(directory) + indices = [ + int(file.split("_")[1].split(".")[0]) + for file in files + if file.startswith("member") + ] + return max(indices) + 1 if indices else 0 + + def create_batches( xr_ds: xr.Dataset, dataset: ClimateDataset, @@ -92,6 +104,11 @@ def main(config: DictConfig) -> None: assert config.load_path, "Must specify a load path" assert os.path.isfile(config.load_path), "Invalid load path" + # Make sure num samples is 1 if gen mode is not gen + assert ( + config.samples_per == 1 or config.gen_mode == "gen" + ), "Number of samples must be 1 for val and test" + # Initialize all necessary objects accelerator = Accelerator(**config.accelerator, even_batches=False) @@ -101,7 +118,8 @@ def main(config: DictConfig) -> None: scenario=config.scenario, data_dir=config.paths.data_dir, realizations=[realization_dict[config.gen_mode]], - vars=[config.variable], + vars=config.variables, + spatial_resolution=config.spatial_resolution ) scheduler: DDPMScheduler = instantiate(config.scheduler) scheduler.set_timesteps(config.sample_steps) @@ -110,8 +128,10 @@ def main(config: DictConfig) -> None: # Load the model from the checkpoint chkpt: Checkpoint = torch.load(config.load_path, map_location="cpu") model = chkpt["EMA"].eval() + model = model.to(accelerator.device) else: model = None + # Grab the Xarray dataset from the dataset object xr_ds = dataset.xr_data.load() @@ -128,43 +148,56 @@ def main(config: DictConfig) -> None: # Prepare the model and dataloader for distributed training model, dataloader = accelerator.prepare(model, dataloader) - gen_samples = [] - - for tensor_batch, coords in tqdm( - dataloader, disable=not accelerator.is_main_process - ): - if model is not None: - gen_months = generate_samples( - tensor_batch, - scheduler=scheduler, - sample_steps=config.sample_steps, - model=model, - disable=not accelerator.is_main_process, + for i in tqdm(range(config.samples_per)): + gen_samples = [] + + for tensor_batch, coords in tqdm( + dataloader, disable=not accelerator.is_main_process + ): + tensor_batch = tensor_batch.to(accelerator.device) + if model is not None: + gen_months = generate_samples( + tensor_batch, + scheduler=scheduler, + sample_steps=config.sample_steps, + model=model, + disable=True, + ) + else: + gen_months = tensor_batch + + for i in range(len(gen_months)): + gen_samples.append( + dataset.convert_tensor_to_xarray(gen_months[i], coords=coords[i]) + ) + + gen_samples = accelerator.gather_for_metrics(gen_samples) + gen_samples = xr.concat(gen_samples, "time").drop_vars("height").sortby("time") + + if accelerator.is_main_process: + + # If we are generating multiple samples, create a directory for them + save_name = f"{config.gen_mode}_{config.save_name + '_' if config.save_name is not None else ''}{'_'.join(config.variables)}_{config.start_year}-{config.end_year}.nc" + save_path = os.path.join( + config.paths.save_dir, config.esm, config.scenario, save_name ) - else: - gen_months = tensor_batch + if config.gen_mode == "gen" and config.samples_per > 1: + save_dir = save_path.strip(".nc") + if not os.path.isdir(save_dir): + os.mkdir(save_dir) - for i in range(len(gen_months)): - gen_samples.append( - dataset.convert_tensor_to_xarray(gen_months[i], coords=coords[i]) - ) + mem_index = get_starting_index(save_dir) + save_path = os.path.join(save_dir, f"member_{mem_index}.nc") + + else: + # Delete the file if it already exists (avoids permission denied errors) + if os.path.isfile(save_path): + os.remove(save_path) + + # Save the generated samples + gen_samples.to_netcdf(save_path) - gen_samples = accelerator.gather_for_metrics(gen_samples) - gen_samples = xr.concat(gen_samples, "time").drop_vars("height").sortby("time") - - if accelerator.is_main_process: - # Construct the save path - save_name = f"{config.gen_mode}_{config.save_name + '_' if config.save_name is not None else ''}{config.variable}_{config.start_year}-{config.end_year}.nc" - save_path = os.path.join( - config.paths.save_dir, config.esm, config.scenario, save_name - ) - # Delete the file if it already exists (avoids permission denied errors) - if os.path.isfile(save_path): - os.remove(save_path) - # Save the generated samples - gen_samples.to_netcdf(save_path) - - os.chmod(save_path, 0o770) + os.chmod(save_path, 0o770) if __name__ == "__main__": diff --git a/src/models/rotary_embedding.py b/src/models/rotary_embedding.py new file mode 100644 index 0000000..4503afc --- /dev/null +++ b/src/models/rotary_embedding.py @@ -0,0 +1,283 @@ +from math import pi, log + +import torch +from torch.nn import Module, ModuleList +from torch.cuda.amp import autocast +from torch import nn, einsum, broadcast_tensors, Tensor + +from einops import rearrange, repeat + +from beartype import beartype +from beartype.typing import Literal, Union, Optional + +# helper functions + +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +# broadcat, as tortoise-tts was using it + +def broadcat(tensors, dim = -1): + broadcasted_tensors = broadcast_tensors(*tensors) + return torch.cat(broadcasted_tensors, dim = dim) + +# rotary embedding helper functions + +def rotate_half(x): + x = rearrange(x, '... (d r) -> ... d r', r = 2) + x1, x2 = x.unbind(dim = -1) + x = torch.stack((-x2, x1), dim = -1) + return rearrange(x, '... d r -> ... (d r)') + +@autocast(enabled = False) +def apply_rotary_emb(freqs, t, start_index = 0, scale = 1., seq_dim = -2): + if t.ndim == 3: + seq_len = t.shape[seq_dim] + freqs = freqs[-seq_len:].to(t) + + rot_dim = freqs.shape[-1] + end_index = start_index + rot_dim + + assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' + + t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] + t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) + return torch.cat((t_left, t, t_right), dim = -1) + +# learned rotation helpers + +def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None): + if exists(freq_ranges): + rotations = einsum('..., f -> ... f', rotations, freq_ranges) + rotations = rearrange(rotations, '... r f -> ... (r f)') + + rotations = repeat(rotations, '... n -> ... (n r)', r = 2) + return apply_rotary_emb(rotations, t, start_index = start_index) + +# classes + +class RotaryEmbedding(Module): + @beartype + def __init__( + self, + dim, + custom_freqs: Optional[Tensor] = None, + freqs_for: Union[ + Literal['lang'], + Literal['pixel'], + Literal['constant'] + ] = 'lang', + theta = 10000, + max_freq = 10, + num_freqs = 1, + learned_freq = False, + use_xpos = False, + xpos_scale_base = 512, + interpolate_factor = 1., + theta_rescale_factor = 1., + seq_before_head_dim = False, + cache_if_possible = True + ): + super().__init__() + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + self.freqs_for = freqs_for + + if exists(custom_freqs): + freqs = custom_freqs + elif freqs_for == 'lang': + freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) + elif freqs_for == 'pixel': + freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi + elif freqs_for == 'constant': + freqs = torch.ones(num_freqs).float() + + self.cache_if_possible = cache_if_possible + + self.tmp_store('cached_freqs', None) + self.tmp_store('cached_scales', None) + + self.freqs = nn.Parameter(freqs, requires_grad = learned_freq) + + self.learned_freq = learned_freq + + # dummy for device + + self.tmp_store('dummy', torch.tensor(0)) + + # default sequence dimension + + self.seq_before_head_dim = seq_before_head_dim + self.default_seq_dim = -3 if seq_before_head_dim else -2 + + # interpolation factors + + assert interpolate_factor >= 1. + self.interpolate_factor = interpolate_factor + + # xpos + + self.use_xpos = use_xpos + if not use_xpos: + self.tmp_store('scale', None) + return + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + self.scale_base = xpos_scale_base + self.tmp_store('scale', scale) + + @property + def device(self): + return self.dummy.device + + def tmp_store(self, key, value): + self.register_buffer(key, value, persistent = False) + + def get_seq_pos(self, seq_len, device, dtype, offset = 0): + return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor + + def rotate_queries_or_keys(self, t, seq_dim = None, offset = 0, freq_seq_len = None): + seq_dim = default(seq_dim, self.default_seq_dim) + + assert not self.use_xpos, 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings' + + device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim] + + if exists(freq_seq_len): + assert freq_seq_len >= seq_len + seq_len = freq_seq_len + + freqs = self.forward(self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset), seq_len = seq_len, offset = offset) + freqs = freqs.to(t.device) + + if seq_dim == -3: + freqs = rearrange(freqs, 'n d -> n 1 d') + + return apply_rotary_emb(freqs, t, seq_dim = seq_dim) + + def rotate_queries_with_cached_keys(self, q, k, seq_dim = None, offset = 0): + seq_dim = default(seq_dim, self.default_seq_dim) + + q_len, k_len = q.shape[seq_dim], k.shape[seq_dim] + assert q_len <= k_len + rotated_q = self.rotate_queries_or_keys(q, seq_dim = seq_dim, freq_seq_len = k_len) + rotated_k = self.rotate_queries_or_keys(k, seq_dim = seq_dim) + + rotated_q = rotated_q.type(q.dtype) + rotated_k = rotated_k.type(k.dtype) + + return rotated_q, rotated_k + + def rotate_queries_and_keys(self, q, k, seq_dim = None): + seq_dim = default(seq_dim, self.default_seq_dim) + + assert self.use_xpos + device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim] + + seq = self.get_seq_pos(seq_len, dtype = dtype, device = device) + + freqs = self.forward(seq, seq_len = seq_len) + scale = self.get_scale(seq, seq_len = seq_len).to(dtype) + + if seq_dim == -3: + freqs = rearrange(freqs, 'n d -> n 1 d') + scale = rearrange(scale, 'n d -> n 1 d') + + rotated_q = apply_rotary_emb(freqs, q, scale = scale, seq_dim = seq_dim) + rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1, seq_dim = seq_dim) + + rotated_q = rotated_q.type(q.dtype) + rotated_k = rotated_k.type(k.dtype) + + return rotated_q, rotated_k + + @beartype + def get_scale( + self, + t: Tensor, + seq_len: Optional[int] = None, + offset = 0 + ): + assert self.use_xpos + + should_cache = ( + self.cache_if_possible and + exists(seq_len) + ) + + if ( + should_cache and \ + exists(self.cached_scales) and \ + (seq_len + offset) <= self.cached_scales.shape[0] + ): + return self.cached_scales[offset:(offset + seq_len)] + + scale = 1. + if self.use_xpos: + power = (t - len(t) // 2) / self.scale_base + scale = self.scale ** rearrange(power, 'n -> n 1') + scale = torch.cat((scale, scale), dim = -1) + + if should_cache: + self.tmp_store('cached_scales', scale) + + return scale + + def get_axial_freqs(self, *dims): + Colon = slice(None) + all_freqs = [] + + for ind, dim in enumerate(dims): + if self.freqs_for == 'pixel': + pos = torch.linspace(-1, 1, steps = dim, device = self.device) + else: + pos = torch.arange(dim, device = self.device) + + freqs = self.forward(pos, seq_len = dim) + + all_axis = [None] * len(dims) + all_axis[ind] = Colon + + new_axis_slice = (Ellipsis, *all_axis, Colon) + all_freqs.append(freqs[new_axis_slice]) + + all_freqs = broadcast_tensors(*all_freqs) + return torch.cat(all_freqs, dim = -1) + + @autocast(enabled = False) + def forward( + self, + t: Tensor, + seq_len = None, + offset = 0 + ): + should_cache = ( + self.cache_if_possible and \ + not self.learned_freq and \ + exists(seq_len) and \ + self.freqs_for != 'pixel' + ) + + if ( + should_cache and \ + exists(self.cached_freqs) and \ + (offset + seq_len) <= self.cached_freqs.shape[0] + ): + return self.cached_freqs[offset:(offset + seq_len)].detach() + + freqs = self.freqs + + freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs) + freqs = repeat(freqs, '... n -> ... (n r)', r = 2) + + if should_cache: + self.tmp_store('cached_freqs', freqs.detach()) + + return freqs \ No newline at end of file diff --git a/src/models/video_net.py b/src/models/video_net.py index 35160a7..73e01d2 100644 --- a/src/models/video_net.py +++ b/src/models/video_net.py @@ -5,10 +5,12 @@ import torch.nn as nn from einops_exts import rearrange_many from einops import rearrange -from rotary_embedding_torch import RotaryEmbedding + from torch.utils import checkpoint as torch_checkpoint +from models.rotary_embedding import RotaryEmbedding + def checkpoint(fn, *args, enabled=False): if enabled: @@ -415,7 +417,6 @@ def _forward( q = q * self.scale # rotate positions into queries and keys for time attention - if exists(self.rotary_emb): q = self.rotary_emb.rotate_queries_or_keys(q) k = self.rotary_emb.rotate_queries_or_keys(k) @@ -566,6 +567,7 @@ def __init__( attn_heads=8, attn_dim_head=32, use_sparse_linear_attn=True, + use_mid_attn=False, init_kernel_size=7, resnet_groups=8, use_checkpoint=False, @@ -596,7 +598,7 @@ def __init__( (1, init_kernel_size, init_kernel_size), padding=(0, init_padding, init_padding), ) - + rotary_emb = RotaryEmbedding(min(32, attn_dim_head)) # If we are using temporal attn over convolution if use_temp_attn: # Define positional encodings and a temporal attention constructor @@ -604,8 +606,6 @@ def __init__( heads=attn_heads, max_distance=32 ) - # Create rotary embeddings for positional information - rotary_emb = RotaryEmbedding(32) # Create temporal attention operation only just frames def temporal_op(dim): @@ -626,6 +626,19 @@ def temporal_op(dim): def temporal_op(dim): return TemporalCNN(dim, kernel_size=3, use_checkpoint=use_checkpoint) + # Define positional encodings and a temporal attention constructor + self.time_rel_pos_bias = RelativePositionBias( + heads=attn_heads, max_distance=32 + ) + def temporal_attn(dim): + return EinopsToAndFrom( + "b c f h w", + "b (h w) f c", + Attention( + dim, heads=attn_heads, dim_head=attn_dim_head, rotary_emb=rotary_emb + ), + ) + # Initial input temporal operation self.input_temp_op = Residual(PreNorm(model_dim, temporal_op(model_dim))) @@ -663,6 +676,7 @@ def temporal_op(dim): # Constructing down blocks of U-Net for ind, (dim_in, dim_out) in enumerate(in_out): is_last = ind >= (num_resolutions - 1) + has_attn = ind >= (num_resolutions - 3) # Downblock: 2 Residual Blocks, 1 spatial linear attention, 1 temporal attention, 1 downsample (at all but last levels) self.downs.append( @@ -670,19 +684,21 @@ def temporal_op(dim): [ block_klass_cond(dim_in, dim_out), block_klass_cond(dim_out, dim_out), - Residual( - PreNorm( - dim_out, - SpatialLinearAttention( + ( + Residual( + PreNorm( dim_out, - heads=attn_heads, - use_checkpoint=use_checkpoint, - ), + SpatialLinearAttention( + dim_out, + heads=attn_heads, + use_checkpoint=use_checkpoint, + ), + ) ) - ) - if use_sparse_linear_attn - else nn.Identity(), - Residual(PreNorm(dim_out, temporal_op(dim_out))), + if use_sparse_linear_attn or has_attn + else nn.Identity() + ), + Residual(PreNorm(dim_out, temporal_op(dim_out) if not has_attn else temporal_attn(dim_out))), Downsample(dim_out) if not is_last else nn.Identity(), ] ) @@ -694,7 +710,7 @@ def temporal_op(dim): self.mid_block1 = block_klass_cond(mid_dim, mid_dim) # Only do spatial attn on middle layer if we are using spatial attn - if use_sparse_linear_attn: + if use_mid_attn: spatial_attn = EinopsToAndFrom( "b c f h w", "b f (h w) c", @@ -705,12 +721,14 @@ def temporal_op(dim): else: self.mid_spatial_attn = nn.Identity() - self.mid_temporal_attn = Residual(PreNorm(mid_dim, temporal_op(mid_dim))) + self.mid_temporal_attn = Residual(PreNorm(mid_dim, temporal_attn(mid_dim))) self.mid_block2 = block_klass_cond(mid_dim, mid_dim) # Construct Up Blocks for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): + is_last = ind >= (num_resolutions - 1) + has_attn = ind in [0, 1, 2] # Up Block: 2 Residual blocks, 1 spatial attention, 1 temporal attention, 1 upsampling layer self.ups.append( @@ -720,19 +738,21 @@ def temporal_op(dim): dim_out * 2, dim_in ), # dim_out * 2 to account for incoming residual connection block_klass_cond(dim_in, dim_in), - Residual( - PreNorm( - dim_in, - SpatialLinearAttention( + ( + Residual( + PreNorm( dim_in, - heads=attn_heads, - use_checkpoint=use_checkpoint, - ), + SpatialLinearAttention( + dim_in, + heads=attn_heads, + use_checkpoint=use_checkpoint, + ), + ) ) - ) - if use_sparse_linear_attn - else nn.Identity(), - Residual(PreNorm(dim_in, temporal_op(dim_in))), + if use_sparse_linear_attn or has_attn + else nn.Identity() + ), + Residual(PreNorm(dim_in, temporal_op(dim_in) if not has_attn else temporal_attn(dim_in))), Upsample(dim_in) if not is_last else nn.Identity(), ] ) @@ -768,7 +788,8 @@ def forward( timesteps = timesteps[None].to(x.device) # If we are using attn for temporal representation - if self.use_temp_attn: + + if exists(self.time_rel_pos_bias): # Create keyword arguments for attention operations time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device=x.device) focus_present_mask = default( diff --git a/src/prepare_data.py b/src/prepare_data.py index ad32d83..5c412bd 100644 --- a/src/prepare_data.py +++ b/src/prepare_data.py @@ -82,7 +82,7 @@ def collect_var_data(path_list: list[str], base_dir: str) -> xr.Dataset: all_data = [] for path in path_list: all_data.append(xr.open_dataset(os.path.join(base_dir, path))) - return xr.concat(all_data, dim="time").sortby("time") + return xr.concat(all_data, dim="time").sortby("time").drop("time_bnds") @hydra.main(version_base=None, config_path="../configs", config_name="prepare_data") @@ -108,18 +108,19 @@ def main(cfg: DictConfig): cfg.scenario, ) - start_year = data["start_year"] - end_year = data["end_year"] + start_year = cfg.start_year + end_year = cfg.end_year # Iterate through each realization in our JSON file for realization, realization_data in data["realizations"].items(): + if realization in ["r1", "r2"]: + print(realization) + continue # Merge the two variables together - dataset = xr.merge( - [ - collect_var_data(path_list, load_dir) - for path_list in realization_data.values() - ] - ) + datasets = [collect_var_data(path_list, load_dir) for path_list in reversed(realization_data.values())] + datasets[1] = datasets[1].assign_coords({"time" : datasets[0].time}) + + dataset = xr.merge(datasets, join="right", compat="override") dataset = process_dataset(dataset, start_year, end_year) print(f"Finished processing realization {realization}") diff --git a/src/test_gen.ipynb b/src/test_gen.ipynb new file mode 100644 index 0000000..4d84608 --- /dev/null +++ b/src/test_gen.ipynb @@ -0,0 +1,1679 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import torch\n", + "import xarray as xr\n", + "\n", + "BASE_DIR = \"/research/hutchinson/data/ml_climate/saved_samples/\"\n", + "CHKPT_PATH = \"../checkpoints/ipsl_tas_rcp85_2.pt\"\n", + "VAR = \"tas\"\n", + "ESM = \"IPSL\"\n", + "SCENARIO = \"rcp85\"\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'tas' (time: 140, lat: 96, lon: 96)>\n",
+       "[1290240 values with dtype=float32]\n",
+       "Coordinates:\n",
+       "  * time     (time) object 2100-01-01 12:00:00 ... 2100-01-28 12:00:00\n",
+       "  * lat      (lat) float64 -90.0 -88.11 -86.21 -84.32 ... 84.32 86.21 88.11 90.0\n",
+       "  * lon      (lon) float64 0.0 3.75 7.5 11.25 15.0 ... 345.0 348.8 352.5 356.2
" + ], + "text/plain": [ + "\n", + "[1290240 values with dtype=float32]\n", + "Coordinates:\n", + " * time (time) object 2100-01-01 12:00:00 ... 2100-01-28 12:00:00\n", + " * lat (lat) float64 -90.0 -88.11 -86.21 -84.32 ... 84.32 86.21 88.11 90.0\n", + " * lon (lon) float64 0.0 3.75 7.5 11.25 15.0 ... 345.0 348.8 352.5 356.2" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "val_file = f\"gen_multi_samples_{VAR}_2100-2100.nc\"\n", + "\n", + "# Open up the datasets\n", + "val_set = xr.open_dataset(os.path.join(BASE_DIR, ESM, SCENARIO, val_file))[VAR]\n", + "\n", + "val_set" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'tas' (lat: 96, lon: 96)>\n",
+       "[9216 values with dtype=float32]\n",
+       "Coordinates:\n",
+       "    time     object 2080-01-01 12:00:00\n",
+       "  * lat      (lat) float64 -90.0 -88.11 -86.21 -84.32 ... 84.32 86.21 88.11 90.0\n",
+       "  * lon      (lon) float64 0.0 3.75 7.5 11.25 15.0 ... 345.0 348.8 352.5 356.2
" + ], + "text/plain": [ + "\n", + "[9216 values with dtype=float32]\n", + "Coordinates:\n", + " time object 2080-01-01 12:00:00\n", + " * lat (lat) float64 -90.0 -88.11 -86.21 -84.32 ... 84.32 86.21 88.11 90.0\n", + " * lon (lon) float64 0.0 3.75 7.5 11.25 15.0 ... 345.0 348.8 352.5 356.2" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cond_map = val_set.isel(time=0)\n", + "cond_map" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "UNetModel3D(\n", + " (input_conv): Conv3d(2, 64, kernel_size=(1, 7, 7), stride=(1, 1, 1), padding=(0, 3, 3))\n", + " (time_rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 8)\n", + " )\n", + " (input_temp_op): Residual(\n", + " (fn): PreNorm(\n", + " (fn): TemporalCNN(\n", + " (temporal_conv): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " )\n", + " (norm): LayerNorm()\n", + " )\n", + " )\n", + " (time_mlp): Sequential(\n", + " (0): SinusoidalPosEmb()\n", + " (1): Linear(in_features=64, out_features=256, bias=True)\n", + " (2): SiLU()\n", + " (3): Linear(in_features=256, out_features=256, bias=True)\n", + " )\n", + " (downs): ModuleList(\n", + " (0): ModuleList(\n", + " (0-1): 2 x ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): SiLU()\n", + " (1): Linear(in_features=256, out_features=128, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (proj): Conv3d(64, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (block2): Block(\n", + " (proj): Conv3d(64, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Identity()\n", + " (3): Residual(\n", + " (fn): PreNorm(\n", + " (fn): TemporalCNN(\n", + " (temporal_conv): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " )\n", + " (norm): LayerNorm()\n", + " )\n", + " )\n", + " (4): Conv3d(64, 64, kernel_size=(1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1))\n", + " )\n", + " (1): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): SiLU()\n", + " (1): Linear(in_features=256, out_features=256, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (proj): Conv3d(64, 128, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (block2): Block(\n", + " (proj): Conv3d(128, 128, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (res_conv): Conv3d(64, 128, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): SiLU()\n", + " (1): Linear(in_features=256, out_features=256, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (proj): Conv3d(128, 128, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (block2): Block(\n", + " (proj): Conv3d(128, 128, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): PreNorm(\n", + " (fn): SpatialLinearAttention(\n", + " (to_qkv): Conv2d(128, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (to_out): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (norm): LayerNorm()\n", + " )\n", + " )\n", + " (3): Residual(\n", + " (fn): PreNorm(\n", + " (fn): EinopsToAndFrom(\n", + " (fn): Attention(\n", + " (rotary_emb): RotaryEmbedding()\n", + " (to_qkv): Linear(in_features=128, out_features=768, bias=False)\n", + " (to_out): Linear(in_features=256, out_features=128, bias=False)\n", + " )\n", + " )\n", + " (norm): LayerNorm()\n", + " )\n", + " )\n", + " (4): Conv3d(128, 128, kernel_size=(1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1))\n", + " )\n", + " (2): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): SiLU()\n", + " (1): Linear(in_features=256, out_features=384, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (proj): Conv3d(128, 192, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 192, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (block2): Block(\n", + " (proj): Conv3d(192, 192, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 192, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (res_conv): Conv3d(128, 192, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): SiLU()\n", + " (1): Linear(in_features=256, out_features=384, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (proj): Conv3d(192, 192, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 192, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (block2): Block(\n", + " (proj): Conv3d(192, 192, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 192, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): PreNorm(\n", + " (fn): SpatialLinearAttention(\n", + " (to_qkv): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (to_out): Conv2d(256, 192, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (norm): LayerNorm()\n", + " )\n", + " )\n", + " (3): Residual(\n", + " (fn): PreNorm(\n", + " (fn): EinopsToAndFrom(\n", + " (fn): Attention(\n", + " (rotary_emb): RotaryEmbedding()\n", + " (to_qkv): Linear(in_features=192, out_features=768, bias=False)\n", + " (to_out): Linear(in_features=256, out_features=192, bias=False)\n", + " )\n", + " )\n", + " (norm): LayerNorm()\n", + " )\n", + " )\n", + " (4): Conv3d(192, 192, kernel_size=(1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1))\n", + " )\n", + " (3): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): SiLU()\n", + " (1): Linear(in_features=256, out_features=512, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (proj): Conv3d(192, 256, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (block2): Block(\n", + " (proj): Conv3d(256, 256, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (res_conv): Conv3d(192, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): SiLU()\n", + " (1): Linear(in_features=256, out_features=512, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (proj): Conv3d(256, 256, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (block2): Block(\n", + " (proj): Conv3d(256, 256, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): PreNorm(\n", + " (fn): SpatialLinearAttention(\n", + " (to_qkv): Conv2d(256, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (to_out): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (norm): LayerNorm()\n", + " )\n", + " )\n", + " (3): Residual(\n", + " (fn): PreNorm(\n", + " (fn): EinopsToAndFrom(\n", + " (fn): Attention(\n", + " (rotary_emb): RotaryEmbedding()\n", + " (to_qkv): Linear(in_features=256, out_features=768, bias=False)\n", + " (to_out): Linear(in_features=256, out_features=256, bias=False)\n", + " )\n", + " )\n", + " (norm): LayerNorm()\n", + " )\n", + " )\n", + " (4): Identity()\n", + " )\n", + " )\n", + " (ups): ModuleList(\n", + " (0): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): SiLU()\n", + " (1): Linear(in_features=256, out_features=384, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (proj): Conv3d(512, 192, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 192, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (block2): Block(\n", + " (proj): Conv3d(192, 192, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 192, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (res_conv): Conv3d(512, 192, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): SiLU()\n", + " (1): Linear(in_features=256, out_features=384, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (proj): Conv3d(192, 192, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 192, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (block2): Block(\n", + " (proj): Conv3d(192, 192, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 192, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): PreNorm(\n", + " (fn): SpatialLinearAttention(\n", + " (to_qkv): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (to_out): Conv2d(256, 192, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (norm): LayerNorm()\n", + " )\n", + " )\n", + " (3): Residual(\n", + " (fn): PreNorm(\n", + " (fn): EinopsToAndFrom(\n", + " (fn): Attention(\n", + " (rotary_emb): RotaryEmbedding()\n", + " (to_qkv): Linear(in_features=192, out_features=768, bias=False)\n", + " (to_out): Linear(in_features=256, out_features=192, bias=False)\n", + " )\n", + " )\n", + " (norm): LayerNorm()\n", + " )\n", + " )\n", + " (4): ConvTranspose3d(192, 192, kernel_size=(1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1))\n", + " )\n", + " (1): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): SiLU()\n", + " (1): Linear(in_features=256, out_features=256, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (proj): Conv3d(384, 128, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (block2): Block(\n", + " (proj): Conv3d(128, 128, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (res_conv): Conv3d(384, 128, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): SiLU()\n", + " (1): Linear(in_features=256, out_features=256, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (proj): Conv3d(128, 128, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (block2): Block(\n", + " (proj): Conv3d(128, 128, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): PreNorm(\n", + " (fn): SpatialLinearAttention(\n", + " (to_qkv): Conv2d(128, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (to_out): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (norm): LayerNorm()\n", + " )\n", + " )\n", + " (3): Residual(\n", + " (fn): PreNorm(\n", + " (fn): EinopsToAndFrom(\n", + " (fn): Attention(\n", + " (rotary_emb): RotaryEmbedding()\n", + " (to_qkv): Linear(in_features=128, out_features=768, bias=False)\n", + " (to_out): Linear(in_features=256, out_features=128, bias=False)\n", + " )\n", + " )\n", + " (norm): LayerNorm()\n", + " )\n", + " )\n", + " (4): ConvTranspose3d(128, 128, kernel_size=(1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1))\n", + " )\n", + " (2): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): SiLU()\n", + " (1): Linear(in_features=256, out_features=128, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (proj): Conv3d(256, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (block2): Block(\n", + " (proj): Conv3d(64, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (res_conv): Conv3d(256, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): SiLU()\n", + " (1): Linear(in_features=256, out_features=128, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (proj): Conv3d(64, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (block2): Block(\n", + " (proj): Conv3d(64, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): PreNorm(\n", + " (fn): SpatialLinearAttention(\n", + " (to_qkv): Conv2d(64, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (to_out): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (norm): LayerNorm()\n", + " )\n", + " )\n", + " (3): Residual(\n", + " (fn): PreNorm(\n", + " (fn): EinopsToAndFrom(\n", + " (fn): Attention(\n", + " (rotary_emb): RotaryEmbedding()\n", + " (to_qkv): Linear(in_features=64, out_features=768, bias=False)\n", + " (to_out): Linear(in_features=256, out_features=64, bias=False)\n", + " )\n", + " )\n", + " (norm): LayerNorm()\n", + " )\n", + " )\n", + " (4): ConvTranspose3d(64, 64, kernel_size=(1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1))\n", + " )\n", + " (3): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): SiLU()\n", + " (1): Linear(in_features=256, out_features=128, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (proj): Conv3d(128, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (block2): Block(\n", + " (proj): Conv3d(64, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (res_conv): Conv3d(128, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): SiLU()\n", + " (1): Linear(in_features=256, out_features=128, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (proj): Conv3d(64, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (block2): Block(\n", + " (proj): Conv3d(64, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Identity()\n", + " (3): Residual(\n", + " (fn): PreNorm(\n", + " (fn): TemporalCNN(\n", + " (temporal_conv): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " )\n", + " (norm): LayerNorm()\n", + " )\n", + " )\n", + " (4): Identity()\n", + " )\n", + " )\n", + " (mid_block1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): SiLU()\n", + " (1): Linear(in_features=256, out_features=512, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (proj): Conv3d(256, 256, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (block2): Block(\n", + " (proj): Conv3d(256, 256, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (mid_spatial_attn): Residual(\n", + " (fn): PreNorm(\n", + " (fn): EinopsToAndFrom(\n", + " (fn): Attention(\n", + " (to_qkv): Linear(in_features=256, out_features=768, bias=False)\n", + " (to_out): Linear(in_features=256, out_features=256, bias=False)\n", + " )\n", + " )\n", + " (norm): LayerNorm()\n", + " )\n", + " )\n", + " (mid_temporal_attn): Residual(\n", + " (fn): PreNorm(\n", + " (fn): EinopsToAndFrom(\n", + " (fn): Attention(\n", + " (rotary_emb): RotaryEmbedding()\n", + " (to_qkv): Linear(in_features=256, out_features=768, bias=False)\n", + " (to_out): Linear(in_features=256, out_features=256, bias=False)\n", + " )\n", + " )\n", + " (norm): LayerNorm()\n", + " )\n", + " )\n", + " (mid_block2): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): SiLU()\n", + " (1): Linear(in_features=256, out_features=512, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (proj): Conv3d(256, 256, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (block2): Block(\n", + " (proj): Conv3d(256, 256, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (out_conv): Sequential(\n", + " (0): ResnetBlock(\n", + " (block1): Block(\n", + " (proj): Conv3d(128, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (block2): Block(\n", + " (proj): Conv3d(64, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))\n", + " (norm): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (act): SiLU()\n", + " )\n", + " (res_conv): Conv3d(128, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " )\n", + " (1): Conv3d(64, 1, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " )\n", + ")" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chkpt = torch.load(CHKPT_PATH, map_location=\"cpu\")\n", + "model = chkpt[\"EMA\"].to(0)\n", + "model.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "diffesm", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/train.py b/src/train.py index 8e39341..fb8ded0 100644 --- a/src/train.py +++ b/src/train.py @@ -36,6 +36,7 @@ def main(cfg: DictConfig) -> None: logger.info(f"Instantiating model <{cfg.model._target_}>") model: UNetModel3D = instantiate(cfg.model) + logger.info(str(model)) logger.info(f"Instantiating scheduler <{cfg.scheduler._target_}>") scheduler: DDPMScheduler = instantiate(cfg.scheduler) diff --git a/src/trainers/unet_trainer.py b/src/trainers/unet_trainer.py index c0e7fa9..b237ace 100644 --- a/src/trainers/unet_trainer.py +++ b/src/trainers/unet_trainer.py @@ -1,4 +1,5 @@ import os +import random from typing import Any, Callable import torch @@ -20,6 +21,23 @@ from custom_diffusers.continuous_ddpm import ContinuousDDPM +def calc_mse_loss(model_output, target): + """Manually calculate mse loss""" + spatial_loss = (model_output - target) ** 2 + + # Weight the equator more heavily than the poles + latitude = torch.linspace( + -80, 80, spatial_loss.shape[-2], device=spatial_loss.device + ) + latitude_rad = torch.deg2rad(latitude) + latitude_weight = torch.cos(latitude_rad) + + # Weight the loss + lat_weighted_loss = (spatial_loss * latitude_weight).mean() + + return lat_weighted_loss + + class UNetTrainer: """Trainer class for 2D diffusion models.""" @@ -176,6 +194,18 @@ def train(self): progress_bar.close() + def get_original_sample(self, noisy_sample, model_output, timesteps): + + alpha_prod_t = self.scheduler.alphas_cumprod[timesteps].view(-1, 1, 1, 1, 1) + beta_prod_t = 1 - alpha_prod_t + + pred_original_sample = (alpha_prod_t**0.5) * noisy_sample - (beta_prod_t**0.5) * model_output + + return pred_original_sample + + + + def get_loss(self, batch): clean_samples = batch.to(self.weight_dtype) cond_map = reduce(clean_samples, "b v t h w -> b v 1 h w", "mean").repeat( @@ -217,7 +247,22 @@ def get_loss(self, batch): raise NotImplementedError("Only epsilon and v_prediction supported") # Calculate loss and update gradients - loss = mse_loss(model_output.float(), target.float()) + mse_loss = calc_mse_loss(model_output, target) + + # Calculate the avg conditional loss + pred_original_sample = self.get_original_sample(noisy_samples, model_output, timesteps) + + # Get the mean of both the clean and the predicted original sample + clean_mean = clean_samples.mean(dim=-3) + pred_mean = pred_original_sample.mean(dim=-3) + + cond_loss = ((clean_mean - pred_mean) ** 2).mean() + + # Calculate the loss + loss = mse_loss + cond_loss * self.cond_loss_scaling + + + # Scale the loss by cosine-weighted latitude self.accelerator.backward(loss) if self.accelerator.sync_gradients: @@ -253,7 +298,7 @@ def sample(self) -> None: self.ema_model.eval() # Grab a random sample from validation set - batch = next(iter(self.val_loader.generate()))[0:1] + batch = random.choice(self.val_set).unsqueeze(0).to(self.accelerator.device) clean_samples = batch.to(self.weight_dtype)