diff --git a/README.md b/README.md index 24f17f6d..330a71f2 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,10 @@ A library for training custom Stable Diffusion models (fine-tuning, LoRA training, textual inversion, etc.) that can be used in [InvokeAI](https://github.com/invoke-ai/InvokeAI). +> [!WARNING] +> `invoke-training` is still under active development, and breaking changes are likely. Full backwards compatibility will not be guaranteed until v1.0.0. +> In the meantime, I recommend pinning to a specific commit hash. + ## Documentation https://invoke-ai.github.io/invoke-training/ diff --git a/configs/finetune_lora_sd_pokemon_1x8gb_example.yaml b/configs/finetune_lora_sd_pokemon_1x8gb_example.yaml index 59ca0cc8..0f835525 100644 --- a/configs/finetune_lora_sd_pokemon_1x8gb_example.yaml +++ b/configs/finetune_lora_sd_pokemon_1x8gb_example.yaml @@ -10,7 +10,7 @@ type: FINETUNE_LORA_SD seed: 1 output: - base_output_dir: output/ + base_output_dir: output/finetune_lora_sd_pokemon/ optimizer: learning_rate: 1.0 @@ -28,12 +28,13 @@ data_loader: dataset_name: lambdalabs/pokemon-blip-captions image_transforms: resolution: 512 + dataloader_num_workers: 4 # General model: runwayml/stable-diffusion-v1-5 gradient_accumulation_steps: 1 mixed_precision: fp16 -xformers: True +xformers: False gradient_checkpointing: True # Dataset size is 833. Set max_train_steps to train for 2 epochs. # ceil(833 / 4) * 3 diff --git a/configs/finetune_lora_sdxl_pokemon_1x24gb_example.yaml b/configs/finetune_lora_sdxl_pokemon_1x24gb_example.yaml index bd7e4771..5f930080 100644 --- a/configs/finetune_lora_sdxl_pokemon_1x24gb_example.yaml +++ b/configs/finetune_lora_sdxl_pokemon_1x24gb_example.yaml @@ -9,7 +9,7 @@ type: FINETUNE_LORA_SDXL seed: 1 output: - base_output_dir: output/ + base_output_dir: output/finetune_lora_sdxl_pokemon/ optimizer: learning_rate: 1.0 @@ -33,7 +33,7 @@ model: stabilityai/stable-diffusion-xl-base-1.0 vae_model: madebyollin/sdxl-vae-fp16-fix gradient_accumulation_steps: 1 mixed_precision: fp16 -xformers: True +xformers: False gradient_checkpointing: True # Dataset size is 833. Set max_train_steps to train for 2 epochs. # ceil(833 / 6) * 2 diff --git a/configs/finetune_lora_sdxl_pokemon_1x8gb_example.yaml b/configs/finetune_lora_sdxl_pokemon_1x8gb_example.yaml index c6525eba..4b8b6548 100644 --- a/configs/finetune_lora_sdxl_pokemon_1x8gb_example.yaml +++ b/configs/finetune_lora_sdxl_pokemon_1x8gb_example.yaml @@ -10,7 +10,7 @@ type: FINETUNE_LORA_SDXL seed: 1 output: - base_output_dir: output/ + base_output_dir: output/finetune_lora_sdxl_pokemon/ optimizer: learning_rate: 1.0 @@ -37,7 +37,7 @@ cache_text_encoder_outputs: True enable_cpu_offload_during_validation: True gradient_accumulation_steps: 4 mixed_precision: fp16 -xformers: True +xformers: False gradient_checkpointing: True # Dataset size is 833. Set max_train_steps to train for 2 epochs. # ceil(833 / 4) * 3 diff --git a/docs/get_started/quick_start.md b/docs/get_started/quick_start.md index fb2eb08f..b14ccc49 100644 --- a/docs/get_started/quick_start.md +++ b/docs/get_started/quick_start.md @@ -27,15 +27,25 @@ Monitor the training process with Tensorboard by running `tensorboard --logdir o ![Screenshot of the Tensorboard UI showing validation images.](../images/tensorboard_val_images_screenshot.png) *Validation images in the Tensorboard UI.* -### 5. InvokeAI +### 5. Select a checkpoint Select a checkpoint based on the quality of the generated images. In this short training run, there are only 3 checkpoints to choose from.
As an example, we'll use the **Epoch 2** checkpoint. +Internally, `invoke-training` stores the LoRA checkpoints in [PEFT format](https://huggingface.co/docs/peft/v0.7.1/en/package_reference/peft_model#peft.PeftModel.save_pretrained). We will convert the selected checkpoint to 'Kohya' format, because it has more widespread support across various UIs: +```bash +# Note: You will have to replace the timestamp in the checkpoint path. +python src/invoke_training/scripts/convert_sd_lora_to_kohya_format.py \ + --src-ckpt-dir output/finetune_lora_sd_pokemon/1691088769.5694647/checkpoint_epoch-00000002 \ + --dst-ckpt-file output/finetune_lora_sd_pokemon/1691088769.5694647/checkpoint_epoch-00000002_kohya.safetensors +``` + +### 6. InvokeAI + If you haven't already, setup [InvokeAI](https://github.com/invoke-ai/InvokeAI) by following its documentation. Copy your selected LoRA checkpoint into your `${INVOKEAI_ROOT}/autoimport/lora` directory. For example: ```bash # Note: You will have to replace the timestamp in the checkpoint path. -cp output/1691088769.5694647/checkpoint_epoch-00000002.safetensors ${INVOKEAI_ROOT}/autoimport/lora/pokemon_epoch-00000002.safetensors +cp output/finetune_lora_sd_pokemon/1691088769.5694647/checkpoint_epoch-00000002_kohya.safetensors ${INVOKEAI_ROOT}/autoimport/lora/pokemon_epoch-00000002.safetensors ``` You can now use your trained Pokemon LoRA in the InvokeAI UI! 🎉 diff --git a/pyproject.toml b/pyproject.toml index 24dd5503..645332d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,11 +15,12 @@ classifiers = [ "Operating System :: OS Independent", ] dependencies = [ - "accelerate~=0.21.0", + "accelerate~=0.25.0", "datasets~=2.14.3", - "diffusers~=0.24.0", + "diffusers~=0.25.0", "numpy", "omegaconf", + "peft~=0.7.0", "Pillow", "prodigyopt", "pydantic", @@ -29,7 +30,7 @@ dependencies = [ "torch>=2.1.2", "torchvision", "tqdm", - "transformers~=4.35.0", + "transformers~=4.36.0", "xformers>=0.0.23", ] diff --git a/src/invoke_training/core/__init__.py b/src/invoke_training/core/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/invoke_training/core/lora/__init__.py b/src/invoke_training/core/lora/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/invoke_training/core/lora/injection/__init__.py b/src/invoke_training/core/lora/injection/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/invoke_training/core/lora/injection/lora_layer_collection.py b/src/invoke_training/core/lora/injection/lora_layer_collection.py deleted file mode 100644 index 7d470169..00000000 --- a/src/invoke_training/core/lora/injection/lora_layer_collection.py +++ /dev/null @@ -1,43 +0,0 @@ -import typing - -import torch - -from invoke_training.core.lora.layers import BaseLoRALayer - - -class LoRALayerCollection(torch.nn.Module): - """A collection of LoRA layers (with names). Typically used to perform operations on a group of LoRA layers during - training. - """ - - def __init__(self): - super().__init__() - - # A torch.nn.ModuleDict may seem like a more natural choice here, but it does not allow keys that contain '.' - # characters. Using a standard python dict is also inconvenient, because it would be ignored by torch.nn.Module - # methods such as `.parameters()` and `.train()`.
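As an aside on the comment above: names registered on a `torch.nn.Module` cannot contain `'.'`, which is why this (now removed) class tracked layer names in a parallel list instead of using a `torch.nn.ModuleDict`. A minimal standalone sketch of the constraint (not part of this repo):

```python
import torch

layers = torch.nn.ModuleDict()
try:
    # Fully-qualified LoRA keys contain '.', which ModuleDict rejects.
    layers["lora_unet.down_blocks.0"] = torch.nn.Linear(4, 4)
except KeyError as err:
    print(err)  # module name can't contain ".", got: lora_unet.down_blocks.0
```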
- self._layers = torch.nn.ModuleList() - self._names = [] - - def add_layer(self, layer: BaseLoRALayer, name: str): - self._layers.append(layer) - self._names.append(name) - - def __len__(self): - return len(self._layers) - - def get_lora_state_dict(self) -> typing.Dict[str, torch.Tensor]: - """A custom alternative to .state_dict() that uses the layer names provided to add_layer(...) as key - prefixes. - """ - state_dict: typing.Dict[str, torch.Tensor] = {} - - for name, layer in zip(self._names, self._layers): - layer_state_dict = layer.state_dict() - for key, state in layer_state_dict.items(): - full_key = name + "." + key - if full_key in state_dict: - raise RuntimeError(f"Multiple state elements map to the same key: '{full_key}'.") - state_dict[full_key] = state - - return state_dict diff --git a/src/invoke_training/core/lora/injection/stable_diffusion.py b/src/invoke_training/core/lora/injection/stable_diffusion.py deleted file mode 100644 index c202aceb..00000000 --- a/src/invoke_training/core/lora/injection/stable_diffusion.py +++ /dev/null @@ -1,108 +0,0 @@ -import typing - -import torch -from diffusers.models import Transformer2DModel, UNet2DConditionModel -from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear -from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D -from transformers import CLIPTextModel -from transformers.models.clip.modeling_clip import CLIPMLP, CLIPAttention - -from invoke_training.core.lora.injection.lora_layer_collection import LoRALayerCollection -from invoke_training.core.lora.injection.utils import inject_lora_layers -from invoke_training.core.lora.layers import LoRAConv2dLayer, LoRALinearLayer - - -def inject_lora_into_unet( - unet: UNet2DConditionModel, include_non_attention_blocks: bool = False, lora_rank_dim: int = 4 -) -> LoRALayerCollection: - """Inject LoRA layers into a Stable Diffusion UNet model. - - Args: - unet (UNet2DConditionModel): The UNet model to inject LoRA layers into. - include_non_attention_blocks (bool, optional): Whether to inject LoRA layers into the linear/conv layers of the - non-attention blocks (`ResnetBlock2D`, `Downsample2D`, `Upsample2D`). Defaults to False. - lora_rank_dim (int, optional): The LoRA layer rank dimension. - Returns: - LoRALayerCollection: The LoRA layers that were added to the UNet. 
- """ - include_descendants_of = {Transformer2DModel} - if include_non_attention_blocks: - include_descendants_of.update({ResnetBlock2D, Downsample2D, Upsample2D}) - - lora_layers = inject_lora_layers( - module=unet, - lora_map={ - torch.nn.Linear: LoRALinearLayer, - LoRACompatibleLinear: LoRALinearLayer, - torch.nn.Conv2d: LoRAConv2dLayer, - LoRACompatibleConv: LoRAConv2dLayer, - }, - include_descendants_of=include_descendants_of, - exclude_descendants_of=None, - prefix="lora_unet", - dtype=torch.float32, - lora_rank_dim=lora_rank_dim, - ) - - return lora_layers - - -def inject_lora_into_clip_text_encoder(text_encoder: CLIPTextModel, prefix: str = "lora_te", lora_rank_dim: int = 4): - lora_layers = inject_lora_layers( - module=text_encoder, - lora_map={ - torch.nn.Linear: LoRALinearLayer, - torch.nn.Conv2d: LoRAConv2dLayer, - }, - include_descendants_of={CLIPAttention, CLIPMLP}, - exclude_descendants_of=None, - prefix=prefix, - dtype=torch.float32, - lora_rank_dim=lora_rank_dim, - ) - - return lora_layers - - -def convert_lora_state_dict_to_kohya_format( - state_dict: typing.Dict[str, torch.Tensor], -) -> typing.Dict[str, torch.Tensor]: - """Convert a Stable Diffusion LoRA state_dict from internal invoke-training format to kohya_ss format. - - Args: - state_dict (typing.Dict[str, torch.Tensor]): LoRA layer state_dict in invoke-training format. - - Raises: - ValueError: If state_dict contains unexpected keys. - RuntimeError: If two input keys map to the same output kohya_ss key. - - Returns: - typing.Dict[str, torch.Tensor]: LoRA layer state_dict in kohya_ss format. - """ - new_state_dict = {} - - # The following logic converts state_dict keys from the internal invoke-training format to kohya_ss format. - # Example conversion: - # from: 'lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q._down.weight' - # to: 'lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight' - for key, val in state_dict.items(): - if key.endswith("._up.weight"): - key_start = key.removesuffix("._up.weight") - key_end = ".lora_up.weight" - elif key.endswith("._down.weight"): - key_start = key.removesuffix("._down.weight") - key_end = ".lora_down.weight" - elif key.endswith(".alpha"): - key_start = key.removesuffix(".alpha") - key_end = ".alpha" - else: - raise ValueError(f"Unexpected key in state_dict: '{key}'.") - - new_key = key_start.replace(".", "_") + key_end - - if new_key in new_state_dict: - raise RuntimeError("Multiple input keys map to the same kohya_ss key: '{new_key}'.") - - new_state_dict[new_key] = val - - return new_state_dict diff --git a/src/invoke_training/core/lora/injection/utils.py b/src/invoke_training/core/lora/injection/utils.py deleted file mode 100644 index 52807bb0..00000000 --- a/src/invoke_training/core/lora/injection/utils.py +++ /dev/null @@ -1,133 +0,0 @@ -import typing - -import torch - -from invoke_training.core.lora.injection.lora_layer_collection import LoRALayerCollection -from invoke_training.core.lora.layers import BaseLoRALayer -from invoke_training.core.lora.lora_block import LoRABlock - - -def find_modules( - module: torch.nn.Module, - targets: typing.Set[typing.Type[torch.nn.Module]], - include_descendants_of: typing.Optional[typing.Set[typing.Type[torch.nn.Module]]] = None, - exclude_descendants_of: typing.Optional[typing.Set[typing.Type[torch.nn.Module]]] = None, - memo: typing.Optional[typing.Set[torch.nn.Module]] = None, - prefix: str = "", - parent: typing.Optional[torch.nn.Module] = None, -) -> 
typing.Iterator[typing.Tuple[str, torch.nn.Module, torch.nn.Module]]: - """Find sub-modules of 'module' that satisfy the search criteria. - Args: - module (torch.nn.Module): The base module whose sub-modules will be searched. - targets (typing.Set[typing.Type[torch.nn.Module]]): The set of module types to search for. - include_descendants_of (typing.Set[typing.Type[torch.nn.Module]], optional): If set, then only - descendants of these types (and their subclasses) will be searched. 'exclude_descendants_of' takes - precedence over 'include_descendants_of'. - exclude_descendants_of (typing.Set[typing.Type[torch.nn.Module]], optional): If set, then the - descendants of these types (and their subclasses) will be ignored in the search. 'exclude_descendants_of' - takes precedence over 'include_descendants_of'. - memo (typing.Set[torch.nn.Module], optional): A memo to store the set of modules already visited in the search. - memo is typically only set in recursive calls of this function. - prefix (str, optional): A prefix that will be added to the module name. - parent (torch.nn.Module, optional): The parent of 'module'. This is used for tracking the parent in recursive - calls to this function so that it can be returned along with the module. - Yields: - typing.Tuple[str, torch.nn.Module, torch.nn.Module]: A tuple (name, parent, module) that matches the search - criteria. - """ - - if memo is None: - memo = set() - - if module in memo: - # We've already visited this module in the search. - return - - memo.add(module) - - # If we have hit an excluded module type, do not search any further. - # Note that this takes precedence over include_descendants_of. - if exclude_descendants_of is not None and isinstance(module, tuple(exclude_descendants_of)): - return - - # If the include_descendants_of requirement is already satisfied, and this module matches a target class, then all - # of the search criteria are satisfied, so yield it. - if include_descendants_of is None and isinstance(module, tuple(targets)): - yield prefix, parent, module - - # The include_descendants_of requirement is NOT YET satisfied. Check if this module satisfies it. - updated_include_descendants_of = include_descendants_of - if include_descendants_of is not None and isinstance(module, tuple(include_descendants_of)): - # Drop the include_descendants_of requirement if this module satisfied it. - updated_include_descendants_of = None - - # Recursively search the child modules. - for child_name, child_module in module.named_children(): - submodule_prefix = prefix + ("." if prefix else "") + child_name - yield from find_modules( - module=child_module, - targets=targets, - include_descendants_of=updated_include_descendants_of, - exclude_descendants_of=exclude_descendants_of, - memo=memo, - prefix=submodule_prefix, - parent=module, - ) - - -def inject_lora_layers( - module: torch.nn.Module, - lora_map: typing.Dict[type[torch.nn.Module], type[BaseLoRALayer]], - include_descendants_of: typing.Optional[typing.Set[typing.Type[torch.nn.Module]]] = None, - exclude_descendants_of: typing.Optional[typing.Set[typing.Type[torch.nn.Module]]] = None, - prefix: str = "", - dtype: torch.dtype = None, - lora_rank_dim: int = 4, -) -> LoRALayerCollection: - """Iterates over all of the modules in 'module' and if they are present in 'replace_map' then replaces them with the - mapped LoRA layer type. - Args: - module (torch.nn.Module): The original module that will be monkeypatched. 
- lora_map (typing.Dict[type[torch.nn.Module], type[torch.nn.Module]]): A mapping from module types that should - have LoRA layers added to the type of LoRA layers that should be used. - Example: - ``` - lora_map = {torch.nn.Linear: LoRALinearLayer} - ``` - include_descendants_of (typing.Set[typing.Type[torch.nn.Module]], optional): Forwarded to find_modules(...). - exclude_descendants_of (typing.Set[typing.Type[torch.nn.Module]], optional): Forwarded to find_modules(...). - prefix (str, optional): A prefix that will be added to the names of all of the LoRA layers. - dtype (torch.dtype, optional): The dtype to construct the new layer with. - lora_rank_dim (int, optional): The rank dimension to use for the injected LoRA layers. - Returns: - LoRALayerCollection: A ModuleDict of all of the LoRA layers that were injected into the module. - """ - lora_layers = LoRALayerCollection() - - for name, parent, module in find_modules( - module=module, - targets=lora_map.keys(), - include_descendants_of=include_descendants_of, - exclude_descendants_of=exclude_descendants_of, - prefix=prefix, - ): - # Lookup the LoRA class to use. - lora_layer_cls = lora_map[type(module)] - - # Initialize the LoRA layer with the correct dimensions. - lora_layer = lora_layer_cls.from_layer(module, rank=lora_rank_dim, dtype=dtype) - - # Join the LoRA layer and the original layer in a block. - lora_block = LoRABlock(original_module=module, lora_layer=lora_layer) - - # Monkey-patch the parent module with the new LoRA block. - child_field_name = name.split(".")[-1] - setattr( - parent, - child_field_name, - lora_block, - ) - - lora_layers.add_layer(lora_layer, name) - - return lora_layers diff --git a/src/invoke_training/core/lora/layers/__init__.py b/src/invoke_training/core/lora/layers/__init__.py deleted file mode 100644 index 7af93c10..00000000 --- a/src/invoke_training/core/lora/layers/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .base_lora_layer import BaseLoRALayer # noqa: F401 -from .lora_conv_layer import ( # noqa: F401 - LoRAConv1dLayer, - LoRAConv2dLayer, - LoRAConv3dLayer, -) -from .lora_linear_layer import LoRALinearLayer # noqa: F401 diff --git a/src/invoke_training/core/lora/layers/base_lora_layer.py b/src/invoke_training/core/lora/layers/base_lora_layer.py deleted file mode 100644 index d45a6672..00000000 --- a/src/invoke_training/core/lora/layers/base_lora_layer.py +++ /dev/null @@ -1,25 +0,0 @@ -import torch - - -class BaseLoRALayer(torch.nn.Module): - """An interface that is implemented by all LoRA layers.""" - - @classmethod - def from_layer( - cls, - layer: torch.nn.Module, - device: torch.device = None, - dtype: torch.dtype = None, - **kwargs, - ): - """Initialize a LoRA layer with dimensions that are compatible with 'layer'. - Args: - layer (torch.nn.Module): The existing layer whose in/out dimensions will be matched. - device (torch.device, optional): The device to construct the new layer on. - dtype (torch.dtype, optional): The dtype to construct the new layer with. - Raises: - TypeError: If layer has an unsupported type. - Returns: - cls: The new LoRA layer. - """ - raise NotImplementedError("from_layer(...) 
is not yet implemented.") diff --git a/src/invoke_training/core/lora/layers/lora_conv_layer.py b/src/invoke_training/core/lora/layers/lora_conv_layer.py deleted file mode 100644 index f8c8763f..00000000 --- a/src/invoke_training/core/lora/layers/lora_conv_layer.py +++ /dev/null @@ -1,131 +0,0 @@ -import math -import typing - -import torch - -from invoke_training.core.lora.layers import BaseLoRALayer - - -class LoRAConvLayer(BaseLoRALayer): - """An implementation of a conv LoRA layer based on the paper 'LoRA: Low-Rank Adaptation of Large Language Models'. - (https://arxiv.org/pdf/2106.09685.pdf) - """ - - @property - def conv_module(self): - """The conv module to be set by child classes. One of torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d.""" - raise NotImplementedError( - "LoRAConvLayer cannot be used directly. Use LoRAConv1dLayer, LoRAConv2dLayer, or LoRAConv3dLayer instead." - ) - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: typing.Union[int, tuple[int]] = 1, - stride: typing.Union[int, tuple[int]] = 1, - padding: typing.Union[str, int, tuple[int]] = 0, - rank: int = 4, - alpha: float = 1.0, - device: torch.device = None, - dtype: torch.dtype = None, - ): - """Initialize a LoRAConvLayer. - Args: - in_channels (int): The number of channels expected on inputs to this layer. - out_channels (int): The number of channels on outputs from this layer. - kernel_size: The kernel_size of the conv layer that this LoRA layer is mirroring. See torch.nn.Conv* docs. - stride: The stride of the conv layer that this LoRA layer is mirroring. See torch.nn.Conv* docs. - padding: The padding of the conv layer that this LoRA layer is mirroring. See torch.nn.Conv* docs. - rank (int, optional): The internal rank of the layer. See the paper for details. - alpha (float, optional): A scaling factor that enables tuning the rank without having to adjust the learning - rate. The recommendation from the paper is to set alpha equal to the first rank that you try and then do - not tune it further. See the paper for more details. - device (torch.device, optional): Device where weights will be initialized. - dtype (torch.dtype, optional): Weight dtype. - Raises: - ValueError: If the rank is greater than either in_channels or out_channels. - """ - super().__init__() - - if rank > min(in_channels, out_channels): - raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_channels, out_channels)}") - - self._down = self.conv_module( - in_channels, - rank, - kernel_size=kernel_size, - stride=stride, - padding=padding, - bias=False, - device=device, - dtype=dtype, - ) - self._up = self.conv_module(rank, out_channels, kernel_size=1, stride=1, bias=False, device=device, dtype=dtype) - - # Register alpha as a buffer so that it is not trained, but still gets saved to the state_dict. - self.register_buffer("alpha", torch.tensor(alpha, device=device, dtype=dtype)) - - self._rank = rank - - self.reset_parameters() - - def reset_parameters(self): - # This initialization is based on Microsoft's implementation: - # https://github.com/microsoft/LoRA/blob/998cfe4d351f4d6b4a47f0921dec2397aa0b9dfe/loralib/layers.py#L279 - torch.nn.init.kaiming_uniform_(self._down.weight, a=math.sqrt(5)) - torch.nn.init.zeros_(self._up.weight) - - @classmethod - def from_layer( - cls, - layer: torch.nn.Module, - rank: int = 4, - alpha: float = 1.0, - device: torch.device = None, - dtype: torch.dtype = None, - ): - """Initialize a LoRAConvLayer with dimensions that are compatible with `layer`. 
- Args: - layer (torch.nn.Module): The existing layer whose in/out dimensions will be matched. - rank, alpha, device, dtype: These args are forwarded to __init__(...). If device or dtype are None, they - will be inferred from `layer`. - Raises: - TypeError: If `layer` has an unsupported type. - Returns: - LoRAConvLayer: The new LoRAConvLayer. - """ - if isinstance(layer, cls.conv_module): - return cls( - in_channels=layer.in_channels, - out_channels=layer.out_channels, - kernel_size=layer.kernel_size, - stride=layer.stride, - padding=layer.padding, - rank=rank, - alpha=alpha, - device=layer.weight.device if device is None else device, - dtype=layer.weight.dtype if dtype is None else dtype, - ) - else: - raise TypeError(f"'{__class__.__name__}' cannot be initialized from a layer of type '{type(layer)}'.") - - def forward(self, input: torch.Tensor): - down_hidden = self._down(input) - up_hidden = self._up(down_hidden) - - up_hidden *= self.alpha / self._rank - - return up_hidden - - -class LoRAConv1dLayer(LoRAConvLayer): - conv_module = torch.nn.Conv1d - - -class LoRAConv2dLayer(LoRAConvLayer): - conv_module = torch.nn.Conv2d - - -class LoRAConv3dLayer(LoRAConvLayer): - conv_module = torch.nn.Conv3d diff --git a/src/invoke_training/core/lora/layers/lora_linear_layer.py b/src/invoke_training/core/lora/layers/lora_linear_layer.py deleted file mode 100644 index 55b65a84..00000000 --- a/src/invoke_training/core/lora/layers/lora_linear_layer.py +++ /dev/null @@ -1,93 +0,0 @@ -import math - -import torch - -from invoke_training.core.lora.layers import BaseLoRALayer - - -class LoRALinearLayer(BaseLoRALayer): - """An implementation of a linear LoRA layer based on the paper 'LoRA: Low-Rank Adaptation of Large Language Models'. - (https://arxiv.org/pdf/2106.09685.pdf) - """ - - def __init__( - self, - in_features: int, - out_features: int, - rank: int = 4, - alpha: float = 1.0, - device: torch.device = None, - dtype: torch.dtype = None, - ): - """Initialize a LoRALinearLayer. - Args: - in_features (int): Inputs to this layer will be expected to have shape (..., in_features). - out_features (int): This layer will produce outputs with shape (..., out_features). - rank (int, optional): The internal rank of the layer. See the paper for details. - alpha (float, optional): A scaling factor that enables tuning the rank without having to adjust the learning - rate. The recommendation from the paper is to set alpha equal to the first rank that you try and then do - not tune it further. See the paper for more details. - device (torch.device, optional): Device where weights will be initialized. - dtype (torch.dtype, optional): Weight dtype. - Raises: - ValueError: If the rank is greater than either in_features or out_features. - """ - super().__init__() - - if rank > min(in_features, out_features): - raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_features, out_features)}") - - self._down = torch.nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) - self._up = torch.nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) - - # Register alpha as a buffer so that it is not trained, but still gets saved to the state_dict. 
- self.register_buffer("alpha", torch.tensor(alpha, device=device, dtype=dtype)) - - self._rank = rank - - self.reset_parameters() - - def reset_parameters(self): - # This initialization is based on Microsoft's implementation: - # https://github.com/microsoft/LoRA/blob/998cfe4d351f4d6b4a47f0921dec2397aa0b9dfe/loralib/layers.py#L123 - torch.nn.init.kaiming_uniform_(self._down.weight, a=math.sqrt(5)) - torch.nn.init.zeros_(self._up.weight) - - @classmethod - def from_layer( - cls, - layer: torch.nn.Linear, - rank: int = 4, - alpha: float = 1.0, - device: torch.device = None, - dtype: torch.dtype = None, - ): - """Initialize a LoRALinearLayer with dimensions that are compatible with 'layer'. - Args: - layer (torch.nn.Linear): The existing layer whose in/out dimensions will be matched. - rank, alpha, device, dtype: These args are forwarded to __init__(...). If device or dtype are None, they - will be inferred from 'layer'. - Raises: - TypeError: If 'layer' has an unsupported type. - Returns: - LoRALinearLayer: The new LoRALinearLayer. - """ - if isinstance(layer, torch.nn.Linear): - return cls( - layer.in_features, - layer.out_features, - rank, - alpha, - layer.weight.device if device is None else device, - layer.weight.dtype if dtype is None else dtype, - ) - else: - raise TypeError(f"'{__class__.__name__}' cannot be initialized from a layer of type '{type(layer)}'.") - - def forward(self, input: torch.Tensor): - down_hidden = self._down(input) - up_hidden = self._up(down_hidden) - - up_hidden *= self.alpha / self._rank - - return up_hidden diff --git a/src/invoke_training/core/lora/lora_block.py b/src/invoke_training/core/lora/lora_block.py deleted file mode 100644 index fdb18df4..00000000 --- a/src/invoke_training/core/lora/lora_block.py +++ /dev/null @@ -1,22 +0,0 @@ -import torch - - -class LoRABlock(torch.nn.Module): - """A wrapper block that combines the outputs of an 'original' module and a parallel 'LoRA' layer.""" - - def __init__(self, original_module: torch.nn.Module, lora_layer: torch.nn.Module, lora_multiplier: float = 1.0): - """Initialize a LoRABlock. - Args: - original_module (torch.nn.Module): The original module. - lora_layer (torch.nn.Module): The LoRA layer. - lora_multiplier (float, optional): A multiplier applied to the LoRA layer output before adding it to the - original module output. Defaults to 1.0. 
- """ - super().__init__() - - self.original_module = original_module - self.lora_layer = lora_layer - self.lora_multiplier = lora_multiplier - - def forward(self, input, *args, **kwargs): - return self.original_module(input, *args, **kwargs) + self.lora_multiplier * self.lora_layer(input) diff --git a/src/invoke_training/scripts/convert_sd_lora_to_kohya_format.py b/src/invoke_training/scripts/convert_sd_lora_to_kohya_format.py new file mode 100644 index 00000000..19046144 --- /dev/null +++ b/src/invoke_training/scripts/convert_sd_lora_to_kohya_format.py @@ -0,0 +1,55 @@ +import argparse +from pathlib import Path + +import torch + +from invoke_training.training.shared.stable_diffusion.lora_checkpoint_utils import ( + convert_sd_peft_checkpoint_to_kohya_state_dict, +) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Convert a Stable Diffusion LoRA checkpoint to kohya format.") + parser.add_argument( + "--src-ckpt-dir", + type=str, + required=True, + help="Path to the source checkpoint directory.", + ) + parser.add_argument( + "--dst-ckpt-file", + type=str, + required=True, + help="Path to the destination Kohya checkpoint file.", + ) + parser.add_argument( + "--dtype", + type=str, + default="fp16", + help="The precision to save the kohya state dict in. One of ['fp16', 'fp32'].", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + + in_checkpoint_dir = Path(args.src_ckpt_dir) + out_checkpoint_file = Path(args.dst_ckpt_file) + + if args.dtype == "fp32": + dtype = torch.float32 + elif args.dtype == "fp16": + dtype = torch.float16 + else: + raise ValueError(f"Unsupported --dtype = '{args.dtype}'.") + + convert_sd_peft_checkpoint_to_kohya_state_dict( + in_checkpoint_dir=in_checkpoint_dir, out_checkpoint_file=out_checkpoint_file, dtype=dtype + ) + + print(f"Saved kohya checkpoint to '{out_checkpoint_file}'.") + + +if __name__ == "__main__": + main() diff --git a/src/invoke_training/training/pipelines/stable_diffusion/finetune_lora_sd.py b/src/invoke_training/training/pipelines/stable_diffusion/finetune_lora_sd.py index 70104083..f4614e4b 100644 --- a/src/invoke_training/training/pipelines/stable_diffusion/finetune_lora_sd.py +++ b/src/invoke_training/training/pipelines/stable_diffusion/finetune_lora_sd.py @@ -1,3 +1,4 @@ +import itertools import json import logging import math @@ -7,6 +8,7 @@ from typing import Optional, Union import numpy as np +import peft import torch import torch.utils.data from accelerate import Accelerator @@ -23,10 +25,6 @@ DreamboothSDDataLoaderConfig, ImageCaptionSDDataLoaderConfig, ) -from invoke_training.core.lora.injection.stable_diffusion import ( - inject_lora_into_clip_text_encoder, - inject_lora_into_unet, -) from invoke_training.training.shared.accelerator.accelerator_utils import ( get_mixed_precision_dtype, initialize_accelerator, @@ -42,10 +40,34 @@ ) from invoke_training.training.shared.data.transforms.tensor_disk_cache import TensorDiskCache from invoke_training.training.shared.optimizer.optimizer_utils import initialize_optimizer -from invoke_training.training.shared.stable_diffusion.lora_checkpoint_utils import save_lora_checkpoint +from invoke_training.training.shared.stable_diffusion.lora_checkpoint_utils import ( + save_sd_lora_checkpoint, +) from invoke_training.training.shared.stable_diffusion.model_loading_utils import PipelineVersionEnum, load_pipeline from invoke_training.training.shared.stable_diffusion.tokenize_captions import tokenize_captions +# Copied from 
https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/examples/stable_diffusion/train_dreambooth.py#L49C1-L65C87 +# TODO(ryand): Is this the set of modules that we want to use? +# UNET_TARGET_MODULES = [ +# "to_q", +# "to_k", +# "to_v", +# "proj", +# "proj_in", +# "proj_out", +# "conv", +# "conv1", +# "conv2", +# "conv_shortcut", +# "to_out.0", +# "time_emb_proj", +# "ff.net.2", +# ] +# TEXT_ENCODER_TARGET_MODULES = ["fc1", "fc2", "q_proj", "k_proj", "v_proj", "out_proj"] +# Module lists copied from diffusers training script: +UNET_TARGET_MODULES = ["to_k", "to_q", "to_v", "to_out.0"] +TEXT_ENCODER_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "out_proj"] + def load_models( config: FinetuneLoRASDConfig, @@ -409,6 +431,8 @@ def run_training(config: FinetuneLoRASDConfig): # noqa: C901 if config.xformers: import xformers # noqa: F401 + # TODO(ryand): There is a known issue if xformers is enabled when training in mixed precision where xformers + # will fail because Q, K, V have different dtypes. unet.enable_xformers_memory_efficient_attention() vae.enable_xformers_memory_efficient_attention() @@ -467,26 +491,56 @@ def run_training(config: FinetuneLoRASDConfig): # noqa: C901 unet.to(accelerator.device, dtype=weight_dtype) - lora_layers = torch.nn.ModuleDict() + # Add LoRA layers to the models being trained. + trainable_param_groups = [] + all_trainable_models: list[peft.PeftModel] = [] + + def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = None) -> peft.PeftModel: + peft_model = peft.get_peft_model(model, lora_config) + peft_model.print_trainable_parameters() + + # Populate `trainable_param_groups`, to be passed to the optimizer. + # Note: PeftModel.parameters() returns only the trainable LoRA params. + param_group = {"params": list(filter(lambda p: p.requires_grad, peft_model.parameters()))} + if lr is not None: + param_group["lr"] = lr + trainable_param_groups.append(param_group) + + # Populate all_trainable_models. + all_trainable_models.append(peft_model) + + peft_model.train() + + return peft_model + + # Add LoRA layers to the model. trainable_param_groups = [] if config.train_unet: - lora_layers["unet"] = inject_lora_into_unet( - unet, config.train_unet_non_attention_blocks, lora_rank_dim=config.lora_rank_dim + unet_lora_config = peft.LoraConfig( + r=config.lora_rank_dim, + # TODO(ryand): Diffusers uses lora_alpha=config.lora_rank_dim. Is that preferred? 
+ lora_alpha=1.0, + target_modules=UNET_TARGET_MODULES, ) - unet_param_group = {"params": lora_layers["unet"].parameters()} - if config.unet_learning_rate is not None: - unet_param_group["lr"] = config.unet_learning_rate - trainable_param_groups.append(unet_param_group) + unet = inject_lora_layers(unet, unet_lora_config, lr=config.unet_learning_rate) + if config.train_text_encoder: - lora_layers["text_encoder"] = inject_lora_into_clip_text_encoder( - text_encoder, lora_rank_dim=config.lora_rank_dim + text_encoder_lora_config = peft.LoraConfig( + r=config.lora_rank_dim, + lora_alpha=1.0, + # init_lora_weights="gaussian", + target_modules=TEXT_ENCODER_TARGET_MODULES, ) - text_encoder_param_group = {"params": lora_layers["text_encoder"].parameters()} - if config.text_encoder_learning_rate is not None: - text_encoder_param_group["lr"] = config.text_encoder_learning_rate - trainable_param_groups.append(text_encoder_param_group) + text_encoder = inject_lora_layers(text_encoder, text_encoder_lora_config, lr=config.text_encoder_learning_rate) + + # Make sure all trainable params are in float32. + for trainable_model in all_trainable_models: + for param in trainable_model.parameters(): + if param.requires_grad: + param.data = param.to(torch.float32) if config.gradient_checkpointing: + # We want to enable gradient checkpointing in the UNet regardless of whether it is being trained. unet.enable_gradient_checkpointing() # unet must be in train() mode for gradient checkpointing to take effect. # At the time of writing, the unet dropout probabilities default to 0, so putting the unet in train mode does @@ -494,13 +548,17 @@ def run_training(config: FinetuneLoRASDConfig): # noqa: C901 unet.train() if config.train_text_encoder: text_encoder.gradient_checkpointing_enable() - # text_encoder must be in train() mode for gradient checkpointing to take effect. - # At the time of writing, the text_encoder dropout probabilities default to 0, so putting the text_encoder - # in train mode does not change its forward behavior. + + # The text encoder must be in train() mode for gradient checkpointing to take effect. This should + # already be the case, since we are training the text_encoder, but we do it explicitly to make it clear + # that this is required. + # At the time of writing, the text encoder dropout probabilities default to 0, so putting the text + # encoders in train mode does not change their forward behavior. text_encoder.train() - # Set requires_grad = True on the first parameters of the text encoder. Without this, the text encoder LoRA - # would have 0 gradients, and so would not get trained. + # Set requires_grad = True on the first parameters of the text encoders. Without this, the text encoder + # LoRA weights would have 0 gradients, and so would not get trained. Note that the set of + # trainable_param_groups has already been populated - the embeddings will not be trained. text_encoder.text_model.embeddings.requires_grad_(True) optimizer = initialize_optimizer(config.optimizer, trainable_param_groups) @@ -529,21 +587,19 @@ def run_training(config: FinetuneLoRASDConfig): # noqa: C901 prepared_result: tuple[ UNet2DConditionModel, CLIPTextModel, - torch.nn.ModuleDict, torch.optim.Optimizer, torch.utils.data.DataLoader, torch.optim.lr_scheduler.LRScheduler, ] = accelerator.prepare( unet, text_encoder, - lora_layers, optimizer, data_loader, lr_scheduler, # Disable automatic device placement for text_encoder if the text encoder outputs were cached. 
- device_placement=[True, not config.cache_text_encoder_outputs, True, True, True, True], + device_placement=[True, not config.cache_text_encoder_outputs, True, True, True], ) - unet, text_encoder, lora_layers, optimizer, data_loader, lr_scheduler = prepared_result + unet, text_encoder, optimizer, data_loader, lr_scheduler = prepared_result # Calculate the number of epochs and total training steps. A "step" represents a single weight update operation # (i.e. takes into account gradient accumulation steps). @@ -560,14 +616,12 @@ def run_training(config: FinetuneLoRASDConfig): # noqa: C901 epoch_checkpoint_tracker = CheckpointTracker( base_dir=out_dir, prefix="checkpoint_epoch", - extension=f".{config.output.save_model_as}", max_checkpoints=config.max_checkpoints, ) step_checkpoint_tracker = CheckpointTracker( base_dir=out_dir, prefix="checkpoint_step", - extension=f".{config.output.save_model_as}", max_checkpoints=config.max_checkpoints, ) @@ -592,11 +646,9 @@ def run_training(config: FinetuneLoRASDConfig): # noqa: C901 progress_bar.set_description("Steps") for epoch in range(first_epoch, num_train_epochs): - lora_layers.train() - train_loss = 0.0 for data_batch in data_loader: - with accelerator.accumulate(lora_layers): + with accelerator.accumulate(unet, text_encoder): loss = train_forward( config, data_batch, @@ -616,7 +668,7 @@ def run_training(config: FinetuneLoRASDConfig): # noqa: C901 # Backpropagate. accelerator.backward(loss) if accelerator.sync_gradients and config.max_grad_norm is not None: - params_to_clip = lora_layers.parameters() + params_to_clip = itertools.chain.from_iterable([m.parameters() for m in all_trainable_models]) accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm) optimizer.step() lr_scheduler.step() @@ -646,7 +698,13 @@ def run_training(config: FinetuneLoRASDConfig): # noqa: C901 if config.save_every_n_steps is not None and (global_step + 1) % config.save_every_n_steps == 0: accelerator.wait_for_everyone() if accelerator.is_main_process: - save_lora_checkpoint(global_step + 1, lora_layers, logger, step_checkpoint_tracker) + save_sd_lora_checkpoint( + idx=global_step + 1, + unet=accelerator.unwrap_model(unet) if config.train_unet else None, + text_encoder=accelerator.unwrap_model(text_encoder) if config.train_text_encoder else None, + logger=logger, + checkpoint_tracker=step_checkpoint_tracker, + ) logs = { "step_loss": loss.detach().item(), @@ -660,8 +718,14 @@ def run_training(config: FinetuneLoRASDConfig): # noqa: C901 # Save a checkpoint every n epochs. if config.save_every_n_epochs is not None and (epoch + 1) % config.save_every_n_epochs == 0: if accelerator.is_main_process: - save_lora_checkpoint(epoch + 1, lora_layers, logger, epoch_checkpoint_tracker) accelerator.wait_for_everyone() + save_sd_lora_checkpoint( + idx=epoch + 1, + unet=accelerator.unwrap_model(unet) if config.train_unet else None, + text_encoder=accelerator.unwrap_model(text_encoder) if config.train_text_encoder else None, + logger=logger, + checkpoint_tracker=epoch_checkpoint_tracker, + ) # Generate validation images every n epochs. 
if len(config.validation_prompts) > 0 and (epoch + 1) % config.validate_every_n_epochs == 0: diff --git a/src/invoke_training/training/pipelines/stable_diffusion/textual_inversion_sd.py b/src/invoke_training/training/pipelines/stable_diffusion/textual_inversion_sd.py index affb4679..ef4be3b6 100644 --- a/src/invoke_training/training/pipelines/stable_diffusion/textual_inversion_sd.py +++ b/src/invoke_training/training/pipelines/stable_diffusion/textual_inversion_sd.py @@ -238,11 +238,12 @@ def run_training(config: TextualInversionSDConfig): # noqa: C901 text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) if config.gradient_checkpointing: + # We want to enable gradient checkpointing in the UNet regardless of whether it is being trained. + unet.enable_gradient_checkpointing() # unet must be in train() mode for gradient checkpointing to take effect. # At the time of writing, the unet dropout probabilities default to 0, so putting the unet in train mode does # not change its forward behavior. unet.train() - unet.enable_gradient_checkpointing() # The text_encoder will be put in .train() mode later, so we don't need to worry about that here. # Note: There are some weird interactions gradient checkpointing and requires_grad_() when training a diff --git a/src/invoke_training/training/pipelines/stable_diffusion_xl/finetune_lora_sdxl.py b/src/invoke_training/training/pipelines/stable_diffusion_xl/finetune_lora_sdxl.py index fa1bda73..68de82bd 100644 --- a/src/invoke_training/training/pipelines/stable_diffusion_xl/finetune_lora_sdxl.py +++ b/src/invoke_training/training/pipelines/stable_diffusion_xl/finetune_lora_sdxl.py @@ -1,3 +1,4 @@ +import itertools import json import logging import math @@ -7,6 +8,7 @@ from typing import Optional, Union import numpy as np +import peft import torch import torch.utils.data from accelerate import Accelerator @@ -16,17 +18,13 @@ from diffusers.optimization import get_scheduler from torch.utils.data import DataLoader from tqdm.auto import tqdm -from transformers import CLIPPreTrainedModel, CLIPTextModel, PretrainedConfig, PreTrainedTokenizer +from transformers import CLIPPreTrainedModel, CLIPTextModel, PreTrainedTokenizer from invoke_training.config.pipelines.finetune_lora_config import FinetuneLoRASDXLConfig from invoke_training.config.shared.data.data_loader_config import ( DreamboothSDDataLoaderConfig, ImageCaptionSDDataLoaderConfig, ) -from invoke_training.core.lora.injection.stable_diffusion import ( - inject_lora_into_clip_text_encoder, - inject_lora_into_unet, -) from invoke_training.training.pipelines.stable_diffusion.finetune_lora_sd import cache_vae_outputs from invoke_training.training.shared.accelerator.accelerator_utils import ( get_mixed_precision_dtype, @@ -42,42 +40,33 @@ ) from invoke_training.training.shared.data.transforms.tensor_disk_cache import TensorDiskCache from invoke_training.training.shared.optimizer.optimizer_utils import initialize_optimizer -from invoke_training.training.shared.stable_diffusion.lora_checkpoint_utils import save_lora_checkpoint +from invoke_training.training.shared.stable_diffusion.lora_checkpoint_utils import ( + save_sdxl_lora_checkpoint, +) from invoke_training.training.shared.stable_diffusion.model_loading_utils import PipelineVersionEnum, load_pipeline from invoke_training.training.shared.stable_diffusion.tokenize_captions import tokenize_captions - -def _import_model_class_for_model(pretrained_model_name_or_path: str, subfolder: str = "", revision: str = "main"): - """Lookup the model class in 
a diffusers model config, import the class, and return it. This function is useful when - loading models that could be one of many possible classes. - - Args: - pretrained_model_name_or_path (str): The diffusers model name/path. - subfolder (str, optional): The model subfolder. - revision (str, optional): The diffusers model revision. - - - Raises: - ValueError: If the detected model class is not recognize. - - Returns: - type: The model class. - """ - text_encoder_config = PretrainedConfig.from_pretrained( - pretrained_model_name_or_path, subfolder=subfolder, revision=revision - ) - model_class = text_encoder_config.architectures[0] - - if model_class == "CLIPTextModel": - from transformers import CLIPTextModel - - return CLIPTextModel - elif model_class == "CLIPTextModelWithProjection": - from transformers import CLIPTextModelWithProjection - - return CLIPTextModelWithProjection - else: - raise ValueError(f"{model_class} is not supported.") +# Copied from https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/examples/stable_diffusion/train_dreambooth.py#L49C1-L65C87 +# TODO(ryand): Confirm that this is the set of modules that we want to use. +# UNET_TARGET_MODULES = [ +# "to_q", +# "to_k", +# "to_v", +# "proj", +# "proj_in", +# "proj_out", +# "conv", +# "conv1", +# "conv2", +# "conv_shortcut", +# "to_out.0", +# "time_emb_proj", +# "ff.net.2", +# ] +# TEXT_ENCODER_TARGET_MODULES = ["fc1", "fc2", "q_proj", "k_proj", "v_proj", "out_proj"] +# Module lists copied from diffusers training script: +UNET_TARGET_MODULES = ["to_k", "to_q", "to_v", "to_out.0"] +TEXT_ENCODER_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "out_proj"] def load_models( @@ -480,6 +469,8 @@ def run_training(config: FinetuneLoRASDXLConfig): # noqa: C901 if config.xformers: import xformers # noqa: F401 + # TODO(ryand): There is a known issue if xformers is enabled when training in mixed precision where xformers + # will fail because Q, K, V have different dtypes. unet.enable_xformers_memory_efficient_attention() vae.enable_xformers_memory_efficient_attention() @@ -544,30 +535,59 @@ def run_training(config: FinetuneLoRASDXLConfig): # noqa: C901 unet.to(accelerator.device, dtype=weight_dtype) - lora_layers = torch.nn.ModuleDict() + # Add LoRA layers to the models being trained. trainable_param_groups = [] + all_trainable_models: list[peft.PeftModel] = [] + + def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = None) -> peft.PeftModel: + peft_model = peft.get_peft_model(model, lora_config) + peft_model.print_trainable_parameters() + + # Populate `trainable_param_groups`, to be passed to the optimizer. + # Note: PeftModel.parameters() returns only the trainable LoRA params. + param_group = {"params": list(filter(lambda p: p.requires_grad, peft_model.parameters()))} + if lr is not None: + param_group["lr"] = lr + trainable_param_groups.append(param_group) + + # Populate all_trainable_models. + all_trainable_models.append(peft_model) + + peft_model.train() + + return peft_model + if config.train_unet: - lora_layers["unet"] = inject_lora_into_unet( - unet, config.train_unet_non_attention_blocks, lora_rank_dim=config.lora_rank_dim + unet_lora_config = peft.LoraConfig( + r=config.lora_rank_dim, + # TODO(ryand): Diffusers uses lora_alpha=config.lora_rank_dim. Is that preferred? 
+ lora_alpha=1.0, + target_modules=UNET_TARGET_MODULES, ) - unet_param_group = {"params": lora_layers["unet"].parameters()} - if config.unet_learning_rate is not None: - unet_param_group["lr"] = config.unet_learning_rate - trainable_param_groups.append(unet_param_group) + unet = inject_lora_layers(unet, unet_lora_config, lr=config.unet_learning_rate) + if config.train_text_encoder: - for te_model, key, lora_prefix in [ - (text_encoder_1, "text_encoder_1", "lora_te1"), - (text_encoder_2, "text_encoder_2", "lora_te2"), - ]: - lora_layers[key] = inject_lora_into_clip_text_encoder( - te_model, lora_prefix, lora_rank_dim=config.lora_rank_dim - ) - text_encoder_param_group = {"params": lora_layers[key].parameters()} - if config.text_encoder_learning_rate is not None: - text_encoder_param_group["lr"] = config.text_encoder_learning_rate - trainable_param_groups.append(text_encoder_param_group) + text_encoder_lora_config = peft.LoraConfig( + r=config.lora_rank_dim, + lora_alpha=1.0, + # init_lora_weights="gaussian", + target_modules=TEXT_ENCODER_TARGET_MODULES, + ) + text_encoder_1 = inject_lora_layers( + text_encoder_1, text_encoder_lora_config, lr=config.text_encoder_learning_rate + ) + text_encoder_2 = inject_lora_layers( + text_encoder_2, text_encoder_lora_config, lr=config.text_encoder_learning_rate + ) + + # Make sure all trainable params are in float32. + for trainable_model in all_trainable_models: + for param in trainable_model.parameters(): + if param.requires_grad: + param.data = param.to(torch.float32) if config.gradient_checkpointing: + # We want to enable gradient checkpointing in the UNet regardless of whether it is being trained. unet.enable_gradient_checkpointing() # unet must be in train() mode for gradient checkpointing to take effect. # At the time of writing, the unet dropout probabilities default to 0, so putting the unet in train mode does @@ -577,13 +597,16 @@ def run_training(config: FinetuneLoRASDXLConfig): # noqa: C901 for te in [text_encoder_1, text_encoder_2]: te.gradient_checkpointing_enable() - # The text encoders must be in train() mode for gradient checkpointing to take effect. + # The text encoders must be in train() mode for gradient checkpointing to take effect. This should + # already be the case, since we are training the text_encoders, but we do it explicitly to make it clear + # that this is required. # At the time of writing, the text encoder dropout probabilities default to 0, so putting the text # encoders in train mode does not change their forward behavior. te.train() # Set requires_grad = True on the first parameters of the text encoders. Without this, the text encoder - # LoRA weights would have 0 gradients, and so would not get trained. + # LoRA weights would have 0 gradients, and so would not get trained. Note that the set of + # trainable_param_groups has already been populated - the embeddings will not be trained.
te.text_model.embeddings.requires_grad_(True) optimizer = initialize_optimizer(config.optimizer, trainable_param_groups) @@ -609,9 +632,8 @@ def run_training(config: FinetuneLoRASDXLConfig): # noqa: C901 prepared_result: tuple[ UNet2DConditionModel, - CLIPPreTrainedModel, - CLIPPreTrainedModel, - torch.nn.ModuleDict, + peft.PeftModel | CLIPTextModel, + peft.PeftModel | CLIPTextModel, torch.optim.Optimizer, torch.utils.data.DataLoader, torch.optim.lr_scheduler.LRScheduler, @@ -619,7 +641,6 @@ def run_training(config: FinetuneLoRASDXLConfig): # noqa: C901 unet, text_encoder_1, text_encoder_2, - lora_layers, optimizer, data_loader, lr_scheduler, @@ -631,10 +652,9 @@ def run_training(config: FinetuneLoRASDXLConfig): # noqa: C901 True, True, True, - True, ], ) - unet, text_encoder_1, text_encoder_2, lora_layers, optimizer, data_loader, lr_scheduler = prepared_result + unet, text_encoder_1, text_encoder_2, optimizer, data_loader, lr_scheduler = prepared_result # Calculate the number of epochs and total training steps. A "step" represents a single weight update operation # (i.e. takes into account gradient accumulation steps). @@ -651,14 +671,12 @@ def run_training(config: FinetuneLoRASDXLConfig): # noqa: C901 epoch_checkpoint_tracker = CheckpointTracker( base_dir=out_dir, prefix="checkpoint_epoch", - extension=f".{config.output.save_model_as}", max_checkpoints=config.max_checkpoints, ) step_checkpoint_tracker = CheckpointTracker( base_dir=out_dir, prefix="checkpoint_step", - extension=f".{config.output.save_model_as}", max_checkpoints=config.max_checkpoints, ) @@ -683,11 +701,9 @@ def run_training(config: FinetuneLoRASDXLConfig): # noqa: C901 progress_bar.set_description("Steps") for epoch in range(first_epoch, num_train_epochs): - lora_layers.train() - train_loss = 0.0 for data_batch in data_loader: - with accelerator.accumulate(lora_layers): + with accelerator.accumulate(unet, text_encoder_1, text_encoder_2): loss = train_forward( accelerator, data_batch, @@ -711,7 +727,7 @@ def run_training(config: FinetuneLoRASDXLConfig): # noqa: C901 # Backpropagate. accelerator.backward(loss) if accelerator.sync_gradients and config.max_grad_norm is not None: - params_to_clip = lora_layers.parameters() + params_to_clip = itertools.chain.from_iterable([m.parameters() for m in all_trainable_models]) accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm) optimizer.step() lr_scheduler.step() @@ -741,7 +757,14 @@ def run_training(config: FinetuneLoRASDXLConfig): # noqa: C901 if config.save_every_n_steps is not None and (global_step + 1) % config.save_every_n_steps == 0: accelerator.wait_for_everyone() if accelerator.is_main_process: - save_lora_checkpoint(global_step + 1, lora_layers, logger, step_checkpoint_tracker) + save_sdxl_lora_checkpoint( + idx=global_step + 1, + unet=unet, + text_encoder_1=text_encoder_1, + text_encoder_2=text_encoder_2, + logger=logger, + checkpoint_tracker=step_checkpoint_tracker, + ) logs = { "step_loss": loss.detach().item(), @@ -755,7 +778,14 @@ def run_training(config: FinetuneLoRASDXLConfig): # noqa: C901 # Save a checkpoint every n epochs. 
if config.save_every_n_epochs is not None and (epoch + 1) % config.save_every_n_epochs == 0: if accelerator.is_main_process: - save_lora_checkpoint(epoch + 1, lora_layers, logger, epoch_checkpoint_tracker) + save_sdxl_lora_checkpoint( + idx=epoch + 1, + unet=unet, + text_encoder_1=text_encoder_1, + text_encoder_2=text_encoder_2, + logger=logger, + checkpoint_tracker=epoch_checkpoint_tracker, + ) accelerator.wait_for_everyone() # Generate validation images every n epochs. diff --git a/src/invoke_training/training/pipelines/stable_diffusion_xl/textual_inversion_sdxl.py b/src/invoke_training/training/pipelines/stable_diffusion_xl/textual_inversion_sdxl.py index d39b3eb2..5c1ff416 100644 --- a/src/invoke_training/training/pipelines/stable_diffusion_xl/textual_inversion_sdxl.py +++ b/src/invoke_training/training/pipelines/stable_diffusion_xl/textual_inversion_sdxl.py @@ -176,6 +176,7 @@ def run_training(config: TextualInversionSDXLConfig): # noqa: C901 text_encoder_2.text_model.embeddings.token_embedding.requires_grad_(True) if config.gradient_checkpointing: + # We want to enable gradient checkpointing in the UNet regardless of whether it is being trained. unet.enable_gradient_checkpointing() # unet must be in train() mode for gradient checkpointing to take effect. # At the time of writing, the unet dropout probabilities default to 0, so putting the unet in train mode does diff --git a/src/invoke_training/training/shared/stable_diffusion/lora_checkpoint_utils.py b/src/invoke_training/training/shared/stable_diffusion/lora_checkpoint_utils.py index 23224887..8b9ee4bb 100644 --- a/src/invoke_training/training/shared/stable_diffusion/lora_checkpoint_utils.py +++ b/src/invoke_training/training/shared/stable_diffusion/lora_checkpoint_utils.py @@ -1,41 +1,198 @@ import logging +from pathlib import Path +import peft import torch +from diffusers import UNet2DConditionModel +from transformers import CLIPTextModel -from invoke_training.core.lora.injection.stable_diffusion import ( - convert_lora_state_dict_to_kohya_format, -) from invoke_training.training.shared.checkpoints.checkpoint_tracker import CheckpointTracker from invoke_training.training.shared.checkpoints.serialization import save_state_dict +SD_PEFT_UNET_KEY = "unet" +SD_PEFT_TEXT_ENCODER_KEY = "text_encoder" -def save_lora_checkpoint( +SDXL_PEFT_UNET_KEY = "unet" +SDXL_PEFT_TEXT_ENCODER_1_KEY = "text_encoder_1" +SDXL_PEFT_TEXT_ENCODER_2_KEY = "text_encoder_2" + + +def save_multi_model_peft_checkpoint(checkpoint_dir: Path, models: dict[str, peft.PeftModel]): + """Save a dict of PeftModels to a checkpoint directory. + + The `models` dict keys are used as the subdirectories for each individual model. + + `load_multi_model_peft_checkpoint(...)` can be used to load the resultant checkpoint. 
+ """ + for model_key, peft_model in models.items(): + assert isinstance(peft_model, peft.PeftModel) + peft_model.save_pretrained(str(checkpoint_dir / model_key)) + + +def load_multi_model_peft_checkpoint( + checkpoint_dir: Path, + models: dict[str, torch.nn.Module], + is_trainable: bool = False, + raise_if_subdir_missing: bool = True, +) -> dict[str, torch.nn.Module]: + """Load a multi-model PEFT checkpoint that was saved with `save_multi_model_peft_checkpoint(...)`.""" + assert checkpoint_dir.exists() + + out_models = {} + for model_key, model in models: + dir_path: Path = checkpoint_dir / model_key + if dir_path.exists(): + out_models[model_key] = peft.PeftModel.from_pretrained(model, dir_path, is_trainable=is_trainable) + else: + if raise_if_subdir_missing: + raise ValueError(f"'{dir_path}' does not exist.") + else: + # Pass through the model unchanged. + out_models[model_key] = model + + return out_models + + +def save_sd_peft_checkpoint(checkpoint_dir: Path, unet: peft.PeftModel | None, text_encoder: peft.PeftModel | None): + models = {} + if unet is not None: + models[SD_PEFT_UNET_KEY] = unet + if text_encoder is not None: + models[SD_PEFT_TEXT_ENCODER_KEY] = text_encoder + + save_multi_model_peft_checkpoint(checkpoint_dir=checkpoint_dir, models=models) + + +def load_sd_peft_checkpoint( + checkpoint_dir: Path, unet: UNet2DConditionModel, text_encoder: CLIPTextModel, is_trainable: bool = False +): + models = load_multi_model_peft_checkpoint( + checkpoint_dir=checkpoint_dir, + models={SD_PEFT_UNET_KEY: unet, SD_PEFT_TEXT_ENCODER_KEY: text_encoder}, + is_trainable=is_trainable, + raise_if_subdir_missing=False, + ) + + return models[SD_PEFT_UNET_KEY], models[SD_PEFT_TEXT_ENCODER_KEY] + + +def save_sdxl_peft_checkpoint( + checkpoint_dir: Path, + unet: peft.PeftModel | None, + text_encoder_1: peft.PeftModel | None, + text_encoder_2: peft.PeftModel | None, +): + models = {} + if unet is not None: + models[SDXL_PEFT_UNET_KEY] = unet + if text_encoder_1 is not None: + models[SDXL_PEFT_TEXT_ENCODER_1_KEY] = text_encoder_1 + if text_encoder_2 is not None: + models[SDXL_PEFT_TEXT_ENCODER_2_KEY] = text_encoder_2 + + save_multi_model_peft_checkpoint(checkpoint_dir=checkpoint_dir, models=models) + + +def load_sdxl_peft_checkpoint( + checkpoint_dir: Path, + unet: UNet2DConditionModel, + text_encoder_1: CLIPTextModel, + text_encoder_2: CLIPTextModel, + is_trainable: bool = False, +): + models = load_multi_model_peft_checkpoint( + checkpoint_dir=checkpoint_dir, + models={ + SDXL_PEFT_UNET_KEY: unet, + SDXL_PEFT_TEXT_ENCODER_1_KEY: text_encoder_1, + SDXL_PEFT_TEXT_ENCODER_2_KEY: text_encoder_2, + }, + is_trainable=is_trainable, + raise_if_subdir_missing=False, + ) + + return models[SDXL_PEFT_UNET_KEY], models[SDXL_PEFT_TEXT_ENCODER_1_KEY], models[SDXL_PEFT_TEXT_ENCODER_2_KEY] + + +def save_sd_lora_checkpoint( idx: int, - lora_layers: torch.nn.ModuleDict, + unet: peft.PeftModel | None, + text_encoder: peft.PeftModel | None, logger: logging.Logger, checkpoint_tracker: CheckpointTracker, ): - """Save a LoRA checkpoint. Old checkpoints are deleted if necessary to respect the config.max_checkpoints config. - - Args: - idx (int): The checkpoint index (typically step count or epoch). - lora_layers (torch.nn.ModuleDict): The LoRA layers to save in a ModuleDict mapping keys to - `LoRALayerCollection`s. - logger (logging.Logger): Logger. - checkpoint_tracker (CheckpointTracker): The checkpoint tracker. - """ # Prune checkpoints and get new checkpoint path. 
num_pruned = checkpoint_tracker.prune(1) if num_pruned > 0: logger.info(f"Pruned {num_pruned} checkpoint(s).") save_path = checkpoint_tracker.get_path(idx) - state_dict = {} - for model_lora_layers in lora_layers.values(): - model_state_dict = model_lora_layers.get_lora_state_dict() - model_kohya_state_dict = convert_lora_state_dict_to_kohya_format(model_state_dict) - state_dict.update(model_kohya_state_dict) + save_sd_peft_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder) + + +def save_sdxl_lora_checkpoint( + idx: int, + unet: peft.PeftModel | None, + text_encoder_1: peft.PeftModel | None, + text_encoder_2: peft.PeftModel | None, + logger: logging.Logger, + checkpoint_tracker: CheckpointTracker, +): + # Prune checkpoints and get new checkpoint path. + num_pruned = checkpoint_tracker.prune(1) + if num_pruned > 0: + logger.info(f"Pruned {num_pruned} checkpoint(s).") + save_path = checkpoint_tracker.get_path(idx) + + save_sdxl_peft_checkpoint(Path(save_path), unet=unet, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2) + + +# This implementation is based on +# https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/examples/lora_dreambooth/convert_peft_sd_lora_to_kohya_ss.py#L20 +def _convert_peft_state_dict_to_kohya_state_dict( + lora_config: peft.LoraConfig, + peft_state_dict: dict[str, torch.Tensor], + prefix: str, + dtype: torch.dtype, +) -> dict[str, torch.Tensor]: + kohya_ss_state_dict = {} + for peft_key, weight in peft_state_dict.items(): + kohya_key = peft_key.replace("base_model.model", prefix) + kohya_key = kohya_key.replace("lora_A", "lora_down") + kohya_key = kohya_key.replace("lora_B", "lora_up") + kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2) + kohya_ss_state_dict[kohya_key] = weight.to(dtype) + + # Set alpha parameter + if "lora_down" in kohya_key: + alpha_key = f'{kohya_key.split(".")[0]}.alpha' + kohya_ss_state_dict[alpha_key] = torch.tensor(lora_config.lora_alpha).to(dtype) + + return kohya_ss_state_dict + + +def convert_sd_peft_checkpoint_to_kohya_state_dict( + in_checkpoint_dir: Path, + out_checkpoint_file: Path, + dtype: torch.dtype = torch.float32, +) -> dict[str, torch.Tensor]: + """Convert SD v1 PEFT models to a Kohya-format LoRA state dict.""" + kohya_state_dict = {} + for kohya_prefix, peft_model_key in [("lora_unet", SD_PEFT_UNET_KEY), ("lora_te", SD_PEFT_TEXT_ENCODER_KEY)]: + peft_model_dir = in_checkpoint_dir / peft_model_key + + if peft_model_dir.exists(): + # Note: This logic to load the LoraConfig and weights directly is based on how it is done here: + # https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/src/peft/peft_model.py#L672-L689 + # This may need to be updated in the future to support other adapter types (LoKr, LoHa, etc.). + # Also, I could see this interface breaking in the future. 
+ lora_config = peft.LoraConfig.from_pretrained(peft_model_dir) + lora_weights = peft.utils.load_peft_weights(peft_model_dir, device="cpu") + + kohya_state_dict.update( + _convert_peft_state_dict_to_kohya_state_dict( + lora_config=lora_config, peft_state_dict=lora_weights, prefix=kohya_prefix, dtype=dtype + ) + ) - save_state_dict(state_dict, save_path) - # accelerator.save_state(save_path) - logger.info(f"Saved state to '{save_path}'.") + save_state_dict(kohya_state_dict, out_checkpoint_file) diff --git a/tests/invoke_training/core/lora/__init__.py b/tests/invoke_training/core/lora/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/invoke_training/core/lora/injection/__init__.py b/tests/invoke_training/core/lora/injection/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/invoke_training/core/lora/injection/test_lora_layer_collection.py b/tests/invoke_training/core/lora/injection/test_lora_layer_collection.py deleted file mode 100644 index 8f97fe4d..00000000 --- a/tests/invoke_training/core/lora/injection/test_lora_layer_collection.py +++ /dev/null @@ -1,37 +0,0 @@ -import pytest - -from invoke_training.core.lora.injection.lora_layer_collection import LoRALayerCollection -from invoke_training.core.lora.layers import LoRALinearLayer - - -def test_lora_layer_collection_state_dict(): - """Test the behavior of LoRALayerCollection.get_lora_state_dict().""" - lora_layers = LoRALayerCollection() - - lora_layers.add_layer(LoRALinearLayer(8, 16), "lora_layer_1") - lora_layers.add_layer(LoRALinearLayer(16, 32), "lora_layer_2") - - state_dict = lora_layers.get_lora_state_dict() - - expected_state_keys = { - "lora_layer_1._down.weight", - "lora_layer_1._up.weight", - "lora_layer_1.alpha", - "lora_layer_2._down.weight", - "lora_layer_2._up.weight", - "lora_layer_2.alpha", - } - assert set(state_dict.keys()) == expected_state_keys - - -def test_lora_layer_collection_state_dict_conflicting_keys(): - """Test that LoRALayerCollection.get_lora_state_dict() raises an exception if state Tensors have conflicting - keys. - """ - lora_layers = LoRALayerCollection() - - lora_layers.add_layer(LoRALinearLayer(8, 16), "lora_layer_1") - lora_layers.add_layer(LoRALinearLayer(16, 32), "lora_layer_1") # Insert same layer type with same key. 
- - with pytest.raises(RuntimeError): - _ = lora_layers.get_lora_state_dict() diff --git a/tests/invoke_training/core/lora/injection/test_stable_diffusion.py b/tests/invoke_training/core/lora/injection/test_stable_diffusion.py deleted file mode 100644 index 506df81d..00000000 --- a/tests/invoke_training/core/lora/injection/test_stable_diffusion.py +++ /dev/null @@ -1,183 +0,0 @@ -import pytest -import torch -from diffusers.models import UNet2DConditionModel -from transformers import CLIPTextModel - -from invoke_training.core.lora.injection.stable_diffusion import ( - convert_lora_state_dict_to_kohya_format, - inject_lora_into_clip_text_encoder, - inject_lora_into_unet, -) - - -@pytest.mark.loads_model -@pytest.mark.parametrize( - ["model_name", "expected_num_layers"], - [ - ("runwayml/stable-diffusion-v1-5", 192), - ("stabilityai/stable-diffusion-xl-base-1.0", 722), - ], -) -def test_inject_lora_into_unet_smoke(model_name: str, expected_num_layers: int): - """Smoke test of inject_lora_into_unet(...).""" - unet = UNet2DConditionModel.from_pretrained( - model_name, - subfolder="unet", - variant="fp16", - local_files_only=True, - ) - lora_layers = inject_lora_into_unet(unet) - - # These assertions are based on a manual check of the injected layers and comparison against the behaviour of - # kohya_ss. They are included here to force another manual review after any future breaking change. - assert len(lora_layers) == expected_num_layers - for layer_name in lora_layers._names: - assert layer_name.endswith( - ("to_q", "to_k", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2", ".proj_in", ".proj_out") - ) - - -@pytest.mark.loads_model -@pytest.mark.parametrize( - ["model_name", "expected_num_layers"], - [ - ("runwayml/stable-diffusion-v1-5", 278), - ("stabilityai/stable-diffusion-xl-base-1.0", 788), - ], -) -def test_inject_lora_into_unet_non_attention_layers_smoke(model_name: str, expected_num_layers: int): - """Smoke test of inject_lora_into_unet(..., include_non_attention_blocks=True).""" - unet = UNet2DConditionModel.from_pretrained( - model_name, - subfolder="unet", - variant="fp16", - local_files_only=True, - ) - lora_layers = inject_lora_into_unet(unet, include_non_attention_blocks=True) - - # These assertions are based on a manual check of the injected layers and comparison against the behaviour of - # kohya_ss. They are included here to force another manual review after any future breaking change. 
- assert len(lora_layers) == expected_num_layers - for layer_name in lora_layers._names: - assert layer_name.endswith( - ( - "to_q", - "to_k", - "to_v", - "to_out.0", - "ff.net.0.proj", - "ff.net.2", - ".proj_in", - ".proj_out", - ".conv1", - ".conv2", - ".time_emb_proj", - ".conv", - ".conv_shortcut", - ) - ) - - -@pytest.mark.loads_model -@pytest.mark.parametrize( - ["model_name", "text_encoder_name", "expected_num_layers"], - [ - ("stabilityai/stable-diffusion-xl-base-1.0", "text_encoder", 72), - ("stabilityai/stable-diffusion-xl-base-1.0", "text_encoder_2", 192), - ("runwayml/stable-diffusion-v1-5", "text_encoder", 72), - ], -) -def test_inject_lora_into_clip_text_encoder_smoke(model_name, text_encoder_name, expected_num_layers): - """Smoke test of inject_lora_into_clip_text_encoder(...).""" - text_encoder = CLIPTextModel.from_pretrained( - model_name, - subfolder=text_encoder_name, - variant="fp16", - local_files_only=True, - ) - - lora_layers = inject_lora_into_clip_text_encoder(text_encoder) - - # These assertions are based on a manual check of the injected layers and comparison against the behaviour of - # kohya_ss. They are included here to force another manual review after any future breaking change. - assert len(lora_layers) == expected_num_layers - for layer_name in lora_layers._names: - assert layer_name.endswith(("mlp.fc1", "mlp.fc2", "k_proj", "out_proj", "q_proj", "v_proj")) - - -@pytest.mark.loads_model -@pytest.mark.loads_model -@pytest.mark.parametrize( - ["model_name", "expected_num_layers"], - [ - ("runwayml/stable-diffusion-v1-5", 192), - ("stabilityai/stable-diffusion-xl-base-1.0", 722), - ], -) -def test_convert_lora_state_dict_to_kohya_format_smoke(model_name: str, expected_num_layers: int): - """Smoke test of convert_lora_state_dict_to_kohya_format(...) with full SD 1.5 model.""" - unet = UNet2DConditionModel.from_pretrained( - model_name, - subfolder="unet", - variant="fp16", - local_files_only=True, - ) - lora_layers = inject_lora_into_unet(unet) - lora_state_dict = lora_layers.get_lora_state_dict() - kohya_state_dict = convert_lora_state_dict_to_kohya_format(lora_state_dict) - - # These assertions are based on a manual check of the injected layers and comparison against the behaviour of - # kohya_ss. They are included here to force another manual review after any future breaking change. 
- assert len(kohya_state_dict) == expected_num_layers * 3 - for key in kohya_state_dict.keys(): - assert key.startswith("lora_unet_") - assert key.endswith((".lora_down.weight", ".lora_up.weight", ".alpha")) - - -def test_convert_lora_state_dict_to_kohya_format(): - """Basic test of convert_lora_state_dict_to_kohya_format(...).""" - down_weight = torch.Tensor(4, 2) - up_weight = torch.Tensor(2, 4) - alpha = torch.Tensor([1.0]) - in_state_dict = { - "lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q._down.weight": down_weight, - "lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q._up.weight": up_weight, - "lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.alpha": alpha, - } - - out_state_dict = convert_lora_state_dict_to_kohya_format(in_state_dict) - - expected_out_state_dict = { - "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight": down_weight, - "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_up.weight": up_weight, - "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.alpha": alpha, - } - - assert out_state_dict == expected_out_state_dict - - -def test_convert_lora_state_dict_to_kohya_format_unexpected_key(): - """Test that convert_lora_state_dict_to_kohya_format(...) raises an exception if it receives an unexpected - key. - """ - in_state_dict = { - "lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q._down.unexpected": torch.Tensor(4, 2), - } - - with pytest.raises(ValueError): - _ = convert_lora_state_dict_to_kohya_format(in_state_dict) - - -def test_convert_lora_state_dict_to_kohya_format_conflicting_keys(): - """Test that convert_lora_state_dict_to_kohya_format(...) raises an exception if multiple keys map to the same - output key. - """ - # Note: There are differences in the '.' and '_' characters of these keys, but they both map to the same output - # kohya_ss keys. - in_state_dict = { - "lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q._down.weight": torch.Tensor(4, 2), - "lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1_to_q._down.weight": torch.Tensor(4, 2), - } - - with pytest.raises(RuntimeError): - _ = convert_lora_state_dict_to_kohya_format(in_state_dict) diff --git a/tests/invoke_training/core/lora/injection/test_utils.py b/tests/invoke_training/core/lora/injection/test_utils.py deleted file mode 100644 index 8f1f4703..00000000 --- a/tests/invoke_training/core/lora/injection/test_utils.py +++ /dev/null @@ -1,202 +0,0 @@ -import torch - -from invoke_training.core.lora.injection.utils import find_modules, inject_lora_layers -from invoke_training.core.lora.layers import LoRALinearLayer -from invoke_training.core.lora.lora_block import LoRABlock - - -def test_find_modules_simple(): - """Test find_modules(...) behaviour on a simple ModuleDict structure.""" - # Construct mock module. - linear1 = torch.nn.Linear(4, 8) - linear2 = torch.nn.Linear(8, 16) - conv1 = torch.nn.Conv2d(16, 32, 5) - module = torch.nn.ModuleDict( - { - "linear1": linear1, - "linear2": linear2, - "conv1": conv1, - } - ) - - result = list(find_modules(module, {torch.nn.Linear})) - - # Validate result. 
- assert len(result) == 2 - result_by_name = {n: (n, p, m) for (n, p, m) in result} - - assert result_by_name["linear1"][0] == "linear1" - assert result_by_name["linear1"][1] == module - assert result_by_name["linear1"][2] == linear1 - - assert result_by_name["linear2"][0] == "linear2" - assert result_by_name["linear2"][1] == module - assert result_by_name["linear2"][2] == linear2 - - -def test_find_modules_nested(): - """Test find_modules(...) behaviour when target modules are nested.""" - # Construct mock module. - linear1 = torch.nn.Linear(4, 8) - linear2 = torch.nn.Linear(8, 16) - conv1 = torch.nn.Conv2d(16, 32, 5) - dict1 = torch.nn.ModuleDict({"linear2": linear2}) - module = torch.nn.ModuleDict( - { - "linear1": linear1, - "dict1": dict1, - "conv1": conv1, - } - ) - - result = list(find_modules(module, {torch.nn.Linear, torch.nn.ModuleDict})) - - # Validate result. - assert len(result) == 4 - result_by_name = {n: (n, p, m) for (n, p, m) in result} - - assert result_by_name[""][0] == "" - assert result_by_name[""][1] is None - assert result_by_name[""][2] == module - - assert result_by_name["dict1"][0] == "dict1" - assert result_by_name["dict1"][1] == module - assert result_by_name["dict1"][2] == dict1 - - assert result_by_name["linear1"][0] == "linear1" - assert result_by_name["linear1"][1] == module - assert result_by_name["linear1"][2] == linear1 - - assert result_by_name["dict1.linear2"][0] == "dict1.linear2" - assert result_by_name["dict1.linear2"][1] == dict1 - assert result_by_name["dict1.linear2"][2] == linear2 - - -def test_find_modules_include_descendants_of(): - """Test include_descendants_of parameter to find_modules(...).""" - # Construct mock module. - linear1 = torch.nn.Linear(4, 8) - linear2 = torch.nn.Linear(8, 16) - conv1 = torch.nn.Conv2d(16, 32, 5) - list1 = torch.nn.ModuleList([conv1, linear2]) - module = torch.nn.ModuleDict( - { - "linear1": linear1, - "list1": list1, - } - ) - - # linear1 should be ignored, because it is not a descendant of a ModuleList. - result = list(find_modules(module, {torch.nn.Linear}, include_descendants_of={torch.nn.ModuleList})) - - # Validate result. - assert len(result) == 1 - result_by_name = {n: (n, p, m) for (n, p, m) in result} - - assert result_by_name["list1.1"][0] == "list1.1" - assert result_by_name["list1.1"][1] == list1 - assert result_by_name["list1.1"][2] == linear2 - - -def test_find_modules_exclude_descendants_of(): - """Test exclude_descendants_of parameter to find_modules(...).""" - # Construct mock module. - linear1 = torch.nn.Linear(4, 8) - linear2 = torch.nn.Linear(8, 16) - conv1 = torch.nn.Conv2d(16, 32, 5) - list1 = torch.nn.ModuleList( - [conv1, linear2] - ) # linear2 should be ignored, because it is a descendant of a ModuleList. - module = torch.nn.ModuleDict( - { - "linear1": linear1, - "list1": list1, - } - ) - - result = list(find_modules(module, {torch.nn.Linear}, exclude_descendants_of={torch.nn.ModuleList})) - - # Validate result. - assert len(result) == 1 - result_by_name = {n: (n, p, m) for (n, p, m) in result} - - assert result_by_name["linear1"][0] == "linear1" - assert result_by_name["linear1"][1] == module - assert result_by_name["linear1"][2] == linear1 - - -def test_find_modules_exclude_precedence_over_include(): - """Test that exclude_descendants_of takes precedence over include_descendants_of find_modules(...).""" - # Construct mock module. 
- linear1 = torch.nn.Linear(4, 8) - list1 = torch.nn.ModuleList([linear1]) - module = torch.nn.ModuleDict({"list1": list1}) - - # Test that when a descendant is excluded, exclude_descendants_of takes precedence over - # include_descendants_of. - result = list( - find_modules( - module, - {torch.nn.Linear}, - include_descendants_of={torch.nn.ModuleDict}, - exclude_descendants_of={torch.nn.ModuleList}, - ) - ) - assert len(result) == 0 - - # Test that when an ancestor is excluded, exclude_descendants_of takes precedence over - # include_descendants_of. - result = list( - find_modules( - module, - {torch.nn.Linear}, - include_descendants_of={torch.nn.ModuleList}, - exclude_descendants_of={torch.nn.ModuleDict}, - ) - ) - assert len(result) == 0 - - -def test_find_modules_duplicate(): - """Test that duplicate modules are only returned once.""" - # Construct mock module. Include linear1 twice. - linear1 = torch.nn.Linear(4, 8) - list1 = torch.nn.ModuleList([linear1]) - module = torch.nn.ModuleDict( - { - "linear1": linear1, - "list1": list1, - } - ) - - result = list(find_modules(module, {torch.nn.Linear}, exclude_descendants_of={torch.nn.ModuleList})) - - # Validate result. - assert len(result) == 1 - result_by_name = {n: (n, p, m) for (n, p, m) in result} - - assert result_by_name["linear1"][0] == "linear1" - assert result_by_name["linear1"][1] == module - assert result_by_name["linear1"][2] == linear1 - - -def test_inject_lora_layers(): - # Construct mock module. - linear1 = torch.nn.Linear(4, 8) - conv1 = torch.nn.Conv2d(16, 32, 5) - module = torch.nn.ModuleDict( - { - "linear1": linear1, - "conv1": conv1, - } - ) - - lora_layers = inject_lora_layers(module, {torch.nn.Linear: LoRALinearLayer}, prefix="lora_unet") - - assert len(lora_layers) == 1 - assert all([k.startswith("lora_unet") for k in lora_layers.get_lora_state_dict()]) - - assert isinstance(module["linear1"], LoRABlock) - assert module["linear1"].original_module == linear1 - assert module["linear1"].lora_layer._down.in_features == linear1.in_features - assert module["linear1"].lora_layer._up.out_features == linear1.out_features diff --git a/tests/invoke_training/core/lora/layers/__init__.py b/tests/invoke_training/core/lora/layers/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/invoke_training/core/lora/layers/test_lora_conv_layer.py b/tests/invoke_training/core/lora/layers/test_lora_conv_layer.py deleted file mode 100644 index 25032a9c..00000000 --- a/tests/invoke_training/core/lora/layers/test_lora_conv_layer.py +++ /dev/null @@ -1,189 +0,0 @@ -import typing - -import pytest -import torch - -from invoke_training.core.lora.layers import ( - LoRAConv1dLayer, - LoRAConv2dLayer, - LoRAConv3dLayer, -) -from invoke_training.core.lora.layers.lora_conv_layer import LoRAConvLayer - - -def test_lora_conv_layer_initialize_base_class(): - """Test that attempting to directly initialize a LoRAConvLayer raise a NotImplementedError.""" - with pytest.raises(NotImplementedError): - _ = LoRAConvLayer(4, 8) - - -@pytest.mark.parametrize( - ["lora_conv_cls", "conv_dims"], [(LoRAConv1dLayer, 1), (LoRAConv2dLayer, 2), (LoRAConv3dLayer, 3)] -) -class TestLoRAConvLayers: - """Test class for applying tests to each of the LoRAConv*Layer classes.""" - - def test_lora_conv_layer_output_dim(self, lora_conv_cls: typing.Type[LoRAConvLayer], conv_dims: int): - """Test that LoRAConv*Layer produces an output with the expected dimensions.""" - batch_size = 10 - in_channels = 8 - out_channels = 16 - layer = lora_conv_cls(in_channels, 
out_channels) - - in_shape = (batch_size, in_channels) + (5,) * conv_dims - x = torch.rand(in_shape) - with torch.no_grad(): - y = layer(x) - - expected_out_shape = (batch_size, out_channels) + (5,) * conv_dims - assert y.shape == expected_out_shape - - def test_lora_conv_layer_invalid_input_dim(self, lora_conv_cls: typing.Type[LoRAConvLayer], conv_dims: int): - """Test that LoRAConv*Layer raises an exception if it receives an input with invalid dimensions.""" - batch_size = 10 - in_channels = 8 - out_channels = 16 - layer = lora_conv_cls(in_channels, out_channels) - - in_shape = (batch_size, in_channels + 1) + (5,) * conv_dims # Bad input dimension. - x = torch.rand(in_shape) - with pytest.raises(RuntimeError): - _ = layer(x) - - def test_lora_conv_layer_zero_after_init(self, lora_conv_cls: typing.Type[LoRAConvLayer], conv_dims: int): - """Test that a newly-initialized LoRAConv*Layer produces all zeros before it is trained.""" - batch_size = 10 - in_channels = 8 - out_channels = 16 - layer = lora_conv_cls(in_channels, out_channels) - - in_shape = (batch_size, in_channels) + (5,) * conv_dims - x = torch.rand(in_shape) - with torch.no_grad(): - y = layer(x) - - assert not torch.allclose(x, torch.Tensor([0.0]), rtol=0.0) # The random input was non-zero. - assert torch.allclose(y, torch.Tensor([0.0]), rtol=0.0) # The untrained outputs are zero. - - def test_lora_conv_layer_from_layer(self, lora_conv_cls: typing.Type[LoRAConvLayer], conv_dims: int): - """Test that a LoRAConv*Layer can be initialized correctly from a torch.nn.Conv* layer.""" - batch_size = 10 - in_channels = 8 - out_channels = 16 - original_layer = lora_conv_cls.conv_module(in_channels, out_channels, kernel_size=3, padding="same") - - lora_layer = lora_conv_cls.from_layer(original_layer) - - in_shape = (batch_size, in_channels) + (5,) * conv_dims - x = torch.rand(in_shape) - with torch.no_grad(): - y = lora_layer(x) - - expected_out_shape = (batch_size, out_channels) + (5,) * conv_dims - assert y.shape == expected_out_shape - - def test_lora_conv_layer_from_layer_kernel_and_stride( - self, lora_conv_cls: typing.Type[LoRAConvLayer], conv_dims: int - ): - """Test that a LoRAConv*Layer is initialized with the correct kernel_size, stride, and padding when initialized - from a torch.nn.Conv* layer.""" - batch_size = 10 - in_channels = 8 - out_channels = 16 - original_layer = lora_conv_cls.conv_module(in_channels, out_channels, kernel_size=3, stride=2, padding="valid") - - lora_layer = lora_conv_cls.from_layer(original_layer) - - # Check the internal layer config. - assert lora_layer._down.kernel_size == original_layer.kernel_size - assert lora_layer._down.stride == original_layer.stride - assert lora_layer._down.padding == original_layer.padding - - in_shape = (batch_size, in_channels) + (6,) * conv_dims - x = torch.rand(in_shape) - with torch.no_grad(): - y = lora_layer(x) - - # The combination of kernel_size, stride, and padding should reduce the dimensions to this output shape: - expected_out_shape = (batch_size, out_channels) + (2,) * conv_dims - assert y.shape == expected_out_shape - - @pytest.mark.cuda - @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) - def test_lora_conv_layer_from_layer_inherit_device_and_dtype( - self, lora_conv_cls: typing.Type[LoRAConvLayer], conv_dims: int, dtype: torch.dtype - ): - """Test that when a LoRAConv*Layer is initialized with from_layer(...), it correctly inherits the device and - dtype. 
- """ - batch_size = 10 - in_channels = 8 - out_channels = 16 - original_layer = lora_conv_cls.conv_module( - in_channels, out_channels, kernel_size=3, padding="same", device=torch.device("cuda"), dtype=dtype - ) - - lora_layer = lora_conv_cls.from_layer(original_layer) - - in_shape = (batch_size, in_channels) + (5,) * conv_dims - x = torch.rand(in_shape, device=torch.device("cuda"), dtype=dtype) - with torch.no_grad(): - y = lora_layer(x) - - expected_out_shape = (batch_size, out_channels) + (5,) * conv_dims - assert y.shape == expected_out_shape - # Assert that lora_layer's internal layers have correct device and dtype. - assert lora_layer._down.weight.device == original_layer.weight.device - assert lora_layer._down.weight.dtype == original_layer.weight.dtype - assert lora_layer._up.weight.device == original_layer.weight.device - assert lora_layer._up.weight.dtype == original_layer.weight.dtype - - @pytest.mark.cuda - @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) - def test_lora_conv_layer_from_layer_override_device_and_dtype( - self, lora_conv_cls: typing.Type[LoRAConvLayer], conv_dims: int, dtype: torch.dtype - ): - """Test that when a LoRAConv*Layer is initialized with from_layer(...), the device and dtype can be - overriden.""" - batch_size = 10 - in_channels = 8 - out_channels = 16 - # Original layer has dtype float32 on CPU. - original_layer = lora_conv_cls.conv_module( - in_channels, out_channels, kernel_size=3, padding="same", dtype=torch.float32 - ) - - target_device = torch.device("cuda:0") - lora_layer = lora_conv_cls.from_layer(original_layer, device=target_device, dtype=dtype) - - in_shape = (batch_size, in_channels) + (5,) * conv_dims - x = torch.rand(in_shape, device=torch.device("cuda"), dtype=dtype) - with torch.no_grad(): - y = lora_layer(x) - - expected_out_shape = (batch_size, out_channels) + (5,) * conv_dims - assert y.shape == expected_out_shape - # Assert that lora_layer's internal layers have correct device and dtype. - assert lora_layer._down.weight.device == target_device - assert lora_layer._down.weight.dtype == dtype - assert lora_layer._up.weight.device == target_device - assert lora_layer._up.weight.dtype == dtype - - def test_lora_conv_layer_state_dict_roundtrip(self, lora_conv_cls: typing.Type[LoRAConvLayer], conv_dims: int): - original_layer = lora_conv_cls(8, 16) - - state_dict = original_layer.state_dict() - - roundtrip_layer = lora_conv_cls(8, 16, alpha=2.0) - - # Prior to loading the state_dict, the roundtrip_layer is different than the original_layer. - # (We don't compare the _up layer, because it is initialized to zeros so should match already.) - assert not torch.allclose(roundtrip_layer._down.weight, original_layer._down.weight) - assert not torch.allclose(roundtrip_layer.alpha, original_layer.alpha) - - roundtrip_layer.load_state_dict(state_dict) - - # After loading the state_dict the roundtrip_layer and original_layer match. 
- assert torch.allclose(roundtrip_layer._down.weight, original_layer._down.weight) - assert torch.allclose(roundtrip_layer._up.weight, original_layer._up.weight) - assert torch.allclose(roundtrip_layer.alpha, original_layer.alpha) diff --git a/tests/invoke_training/core/lora/layers/test_lora_linear_layer.py b/tests/invoke_training/core/lora/layers/test_lora_linear_layer.py deleted file mode 100644 index bdedd46b..00000000 --- a/tests/invoke_training/core/lora/layers/test_lora_linear_layer.py +++ /dev/null @@ -1,122 +0,0 @@ -import pytest -import torch - -from invoke_training.core.lora.layers import LoRALinearLayer - - -def test_lora_linear_layer_output_dim(): - """Test LoRALinearLayer produces an output with the expected dimension.""" - batch_size = 10 - in_features = 8 - out_features = 16 - layer = LoRALinearLayer(in_features, out_features, 2) - - x = torch.rand((batch_size, in_features)) - with torch.no_grad(): - y = layer(x) - - assert y.shape == (batch_size, out_features) - - -def test_lora_linear_layer_invalid_input_dim(): - """Test that LoRALinearLayer throws an exception if it receives an input with invalid dimensions.""" - in_features = 8 - out_features = 16 - layer = LoRALinearLayer(in_features, out_features, 2) - - x = torch.rand((10, in_features + 1)) # Bad input dimension. - - with pytest.raises(RuntimeError): - _ = layer(x) - - -def test_lora_linear_layer_zero_after_init(): - """Test that a newly-initialized LoRALinearLayer produces all zeros before it is trained.""" - batch_size = 10 - in_features = 8 - out_features = 16 - layer = LoRALinearLayer(in_features, out_features, 2) - - x = torch.rand((batch_size, in_features)) - with torch.no_grad(): - y = layer(x) - - assert not torch.allclose(x, torch.Tensor([0.0]), rtol=0.0) # The random input was non-zero. - assert torch.allclose(y, torch.Tensor([0.0]), rtol=0.0) # The untrained outputs are zero. - - -def test_lora_linear_layer_from_layer(): - """Test that a LoRALinearLayer can be initialized correctly from a torch.nn.Linear layer.""" - batch_size = 10 - in_features = 4 - out_features = 16 - original_layer = torch.nn.Linear(in_features, out_features) - - lora_layer: LoRALinearLayer = LoRALinearLayer.from_layer(original_layer) - - x = torch.rand((batch_size, in_features)) - with torch.no_grad(): - y = lora_layer(x) - - assert y.shape == (batch_size, out_features) - - -@pytest.mark.cuda -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) -def test_lora_linear_layer_from_layer_inherit_device_and_dtype(dtype): - """Test that when a LoRALinearLayer is initialized with from_layer(...), it correctly inherits the device and - dtype. - """ - batch_size = 10 - in_features = 4 - out_features = 16 - original_layer = torch.nn.Linear(in_features, out_features, device=torch.device("cuda"), dtype=dtype) - - lora_layer: LoRALinearLayer = LoRALinearLayer.from_layer(original_layer) - - x = torch.rand((batch_size, in_features), device=torch.device("cuda"), dtype=dtype) - with torch.no_grad(): - y = lora_layer(x) - - assert y.shape == (batch_size, out_features) - # Assert that lora_layer's internal layers have correct device and dtype. 
- assert lora_layer._down.weight.device == original_layer.weight.device - assert lora_layer._down.weight.dtype == original_layer.weight.dtype - assert lora_layer._up.weight.device == original_layer.weight.device - assert lora_layer._up.weight.dtype == original_layer.weight.dtype - - -@pytest.mark.cuda -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) -def test_lora_linear_layer_from_layer_override_device_and_dtype(dtype): - """Test that when a LoRALinearLayer is initialized with from_layer(...), the device and dtype can be overriden.""" - # Original layer has dtype float32 on CPU. - original_layer = torch.nn.Linear(4, 16, dtype=torch.float32) - - target_device = torch.device("cuda:0") - lora_layer: LoRALinearLayer = LoRALinearLayer.from_layer(original_layer, device=target_device, dtype=dtype) - - # Assert that lora_layer's internal layers have correct device and dtype. - assert lora_layer._down.weight.device == target_device - assert lora_layer._down.weight.dtype == dtype - assert lora_layer._up.weight.device == target_device - assert lora_layer._up.weight.dtype == dtype - - -def test_lora_linear_layer_state_dict_roundtrip(): - original_layer = LoRALinearLayer(4, 8) - - state_dict = original_layer.state_dict() - - roundtrip_layer = LoRALinearLayer(4, 8, alpha=2.0) - - # Prior to loading the state_dict, the roundtrip_layer is different than the original_layer. - assert not torch.allclose(roundtrip_layer._down.weight, original_layer._down.weight) - assert not torch.allclose(roundtrip_layer.alpha, original_layer.alpha) - - roundtrip_layer.load_state_dict(state_dict) - - # After loading the state_dict the roundtrip_layer and original_layer match. - assert torch.allclose(roundtrip_layer._down.weight, original_layer._down.weight) - assert torch.allclose(roundtrip_layer._up.weight, original_layer._up.weight) - assert torch.allclose(roundtrip_layer.alpha, original_layer.alpha) diff --git a/tests/invoke_training/core/lora/test_lora_block.py b/tests/invoke_training/core/lora/test_lora_block.py deleted file mode 100644 index ce9737da..00000000 --- a/tests/invoke_training/core/lora/test_lora_block.py +++ /dev/null @@ -1,20 +0,0 @@ -import torch - -from invoke_training.core.lora.lora_block import LoRABlock - - -def test_lora_block_multiplier(): - """A basic test that the lora_multitplier param is being applied correctly.""" - original = torch.nn.Linear(1, 1, bias=False) - original.weight = torch.nn.Parameter(torch.Tensor([[1]])) - - lora = torch.nn.Linear(1, 1, bias=False) - lora.weight = torch.nn.Parameter(torch.Tensor([[2]])) - - layer = LoRABlock(original, lora, lora_multiplier=2) - - with torch.no_grad(): - y = layer(torch.Tensor([[1]])) - - # expected: y = (1 * in) + 2 * (2 * in) = 5 - torch.testing.assert_close(y, torch.Tensor([[5]]))
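
The new multi-model PEFT checkpoint helpers added in `lora_checkpoint_utils.py` above store one PEFT model per dict key, each in its own subdirectory of the checkpoint directory. Below is a minimal, illustrative round-trip using toy modules; the `ToyModel` class, the LoRA config values, and the `output/example_peft_checkpoint` path are assumptions for demonstration only, not part of the library.

```python
from pathlib import Path

import peft
import torch

from invoke_training.training.shared.stable_diffusion.lora_checkpoint_utils import (
    load_multi_model_peft_checkpoint,
    save_multi_model_peft_checkpoint,
)


class ToyModel(torch.nn.Module):
    """Stand-in for a UNet/text encoder, used only to illustrate the checkpoint layout."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)


# Wrap two base models with LoRA adapters.
lora_config = peft.LoraConfig(r=4, lora_alpha=4, target_modules=["proj"])
models = {
    "unet": peft.get_peft_model(ToyModel(), lora_config),
    "text_encoder": peft.get_peft_model(ToyModel(), lora_config),
}

checkpoint_dir = Path("output/example_peft_checkpoint")  # Hypothetical path.

# Each dict key becomes a subdirectory containing one saved PEFT adapter.
save_multi_model_peft_checkpoint(checkpoint_dir=checkpoint_dir, models=models)

# Loading wraps fresh base models in PeftModels initialized from the saved adapters.
loaded = load_multi_model_peft_checkpoint(
    checkpoint_dir=checkpoint_dir,
    models={"unet": ToyModel(), "text_encoder": ToyModel()},
)
print(sorted(loaded.keys()))  # ['text_encoder', 'unet']
```

If a subdirectory is missing at load time, `load_multi_model_peft_checkpoint(...)` either raises or passes the base model through unchanged, depending on `raise_if_subdir_missing`.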
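Likewise, here is a minimal sketch of the new SD v1 conversion entry point, `convert_sd_peft_checkpoint_to_kohya_state_dict(...)`. Both paths below are placeholders; the function reads whichever of the `unet` / `text_encoder` PEFT subdirectories exist in the input checkpoint and writes a single state dict whose keys carry the `lora_unet` / `lora_te` prefixes expected by Kohya-style LoRA loaders.

```python
from pathlib import Path

import torch

from invoke_training.training.shared.stable_diffusion.lora_checkpoint_utils import (
    convert_sd_peft_checkpoint_to_kohya_state_dict,
)

# Convert a PEFT-format LoRA checkpoint directory to a single Kohya-format file.
convert_sd_peft_checkpoint_to_kohya_state_dict(
    in_checkpoint_dir=Path("path/to/peft_lora_checkpoint"),  # Placeholder input directory.
    out_checkpoint_file=Path("path/to/lora_kohya.safetensors"),  # Placeholder output file.
    dtype=torch.float32,
)
```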