diff --git a/docs/get_started/quick_start.md b/docs/get_started/quick_start.md
index b14ccc49..2b65419e 100644
--- a/docs/get_started/quick_start.md
+++ b/docs/get_started/quick_start.md
@@ -27,25 +27,15 @@ Monitor the training process with Tensorboard by running `tensorboard --logdir o
 ![Screenshot of the Tensorboard UI showing validation images.](../images/tensorboard_val_images_screenshot.png)
 *Validation images in the Tensorboard UI.*
 
-### 5. Select a checkpoint
+### 5. InvokeAI
 
 Select a checkpoint based on the quality of the generated images. In this short training run, there are only 3 checkpoints to choose from. As an example, we'll use the **Epoch 2** checkpoint.
 
-Internally, `invoke-training` stores the LoRA checkpoints in [PEFT format](https://huggingface.co/docs/peft/v0.7.1/en/package_reference/peft_model#peft.PeftModel.save_pretrained). We will convert the selected checkpoint to 'Kohya' format, because it has more widespread support across various UIs:
-```bash
-# Note: You will have to replace the timestamp in the checkpoint path.
-python src/invoke_training/scripts/convert_sd_lora_to_kohya_format.py \
-    --src-ckpt-dir output/finetune_lora_sd_pokemon/1691088769.5694647/checkpoint_epoch-00000002 \
-    --dst-ckpt-file output/finetune_lora_sd_pokemon/1691088769.5694647/checkpoint_epoch-00000002_kohya.safetensors
-```
-
-### 5. InvokeAI
-
 If you haven't already, setup [InvokeAI](https://github.com/invoke-ai/InvokeAI) by following its documentation.
 
 Copy your selected LoRA checkpoint into your `${INVOKEAI_ROOT}/autoimport/lora` directory. For example:
 ```bash
 # Note: You will have to replace the timestamp in the checkpoint path.
-cp output/finetune_lora_sd_pokemon/1691088769.5694647/checkpoint_epoch-00000002_kohya.safetensors ${INVOKEAI_ROOT}/autoimport/lora/pokemon_epoch-00000002.safetensors
+cp output/1691088769.5694647/checkpoint_epoch-00000002.safetensors ${INVOKEAI_ROOT}/autoimport/lora/pokemon_epoch-00000002.safetensors
 ```
 You can now use your trained Pokemon LoRA in the InvokeAI UI! 🎉
diff --git a/src/invoke_training/config/pipelines/finetune_lora_and_ti_config.py b/src/invoke_training/config/pipelines/finetune_lora_and_ti_config.py
index 1f085446..fc98c99b 100644
--- a/src/invoke_training/config/pipelines/finetune_lora_and_ti_config.py
+++ b/src/invoke_training/config/pipelines/finetune_lora_and_ti_config.py
@@ -23,6 +23,9 @@ class LoraAndTiTrainingConfig(BasePipelineConfig):
     """The Hugging Face Hub model variant to use. Only applies if `model` is a Hugging Face Hub model name.
     """
 
+    lora_checkpoint_format: Literal["invoke_peft", "kohya"] = "kohya"
+    """The format of the LoRA checkpoint to save. Choose between `invoke_peft` and `kohya`."""
+
     # Helpful discussion for understanding how this works at inference time:
     # https://github.com/huggingface/diffusers/pull/3144#discussion_r1172413509
     num_vectors: int = 1
diff --git a/src/invoke_training/config/pipelines/finetune_lora_config.py b/src/invoke_training/config/pipelines/finetune_lora_config.py
index c9678f41..7cf36aad 100644
--- a/src/invoke_training/config/pipelines/finetune_lora_config.py
+++ b/src/invoke_training/config/pipelines/finetune_lora_config.py
@@ -44,6 +44,9 @@ class LoRATrainingConfig(BasePipelineConfig):
     generation time with the resultant LoRA model.
     """
 
+    lora_checkpoint_format: Literal["invoke_peft", "kohya"] = "kohya"
+    """The format of the LoRA checkpoint to save. Choose between `invoke_peft` and `kohya`."""
+
     train_unet: bool = True
     """Whether to add LoRA layers to the UNet model and train it.
""" diff --git a/src/invoke_training/training/_experimental/dpo/diffusion_dpo_lora_sd.py b/src/invoke_training/training/_experimental/dpo/diffusion_dpo_lora_sd.py index 67a72df9..887e8191 100644 --- a/src/invoke_training/training/_experimental/dpo/diffusion_dpo_lora_sd.py +++ b/src/invoke_training/training/_experimental/dpo/diffusion_dpo_lora_sd.py @@ -7,6 +7,7 @@ import tempfile import time from pathlib import Path +from typing import Literal import peft import torch @@ -32,6 +33,7 @@ TEXT_ENCODER_TARGET_MODULES, UNET_TARGET_MODULES, load_sd_peft_checkpoint, + save_sd_kohya_checkpoint, save_sd_peft_checkpoint, ) from invoke_training.training._shared.stable_diffusion.model_loading_utils import load_models_sd @@ -46,6 +48,7 @@ def _save_sd_lora_checkpoint( text_encoder: peft.PeftModel | None, logger: logging.Logger, checkpoint_tracker: CheckpointTracker, + lora_checkpoint_format: Literal["invoke_peft", "kohya"], ): # Prune checkpoints and get new checkpoint path. num_pruned = checkpoint_tracker.prune(1) @@ -53,7 +56,12 @@ def _save_sd_lora_checkpoint( logger.info(f"Pruned {num_pruned} checkpoint(s).") save_path = checkpoint_tracker.get_path(idx) - save_sd_peft_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder) + if lora_checkpoint_format == "invoke_peft": + save_sd_peft_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder) + elif lora_checkpoint_format == "kohya": + save_sd_kohya_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder) + else: + raise ValueError(f"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.") def train_forward_dpo( # noqa: C901 @@ -421,14 +429,14 @@ def prep_peft_model(model, lr: float | None = None): epoch_checkpoint_tracker = CheckpointTracker( base_dir=ckpt_dir, prefix="checkpoint_epoch", - extension=".safetensors", + extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None, max_checkpoints=config.max_checkpoints, ) step_checkpoint_tracker = CheckpointTracker( base_dir=ckpt_dir, prefix="checkpoint_step", - extension=".safetensors", + extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None, max_checkpoints=config.max_checkpoints, ) @@ -513,6 +521,7 @@ def prep_peft_model(model, lr: float | None = None): text_encoder=accelerator.unwrap_model(text_encoder) if training_text_encoder else None, logger=logger, checkpoint_tracker=step_checkpoint_tracker, + lora_checkpoint_format=config.lora_checkpoint_format, ) logs = { @@ -534,6 +543,7 @@ def prep_peft_model(model, lr: float | None = None): text_encoder=accelerator.unwrap_model(text_encoder) if training_text_encoder else None, logger=logger, checkpoint_tracker=epoch_checkpoint_tracker, + lora_checkpoint_format=config.lora_checkpoint_format, ) # Generate validation images every n epochs. 
diff --git a/src/invoke_training/training/_shared/stable_diffusion/lora_checkpoint_utils.py b/src/invoke_training/training/_shared/stable_diffusion/lora_checkpoint_utils.py
index 065176c9..98a5e35c 100644
--- a/src/invoke_training/training/_shared/stable_diffusion/lora_checkpoint_utils.py
+++ b/src/invoke_training/training/_shared/stable_diffusion/lora_checkpoint_utils.py
@@ -1,3 +1,4 @@
+import os
 from pathlib import Path
 
 import peft
@@ -36,6 +37,24 @@
 SDXL_PEFT_TEXT_ENCODER_1_KEY = "text_encoder_1"
 SDXL_PEFT_TEXT_ENCODER_2_KEY = "text_encoder_2"
 
+SD_KOHYA_UNET_KEY = "lora_unet"
+SD_KOHYA_TEXT_ENCODER_KEY = "lora_te"
+
+SDXL_KOHYA_UNET_KEY = "lora_unet"
+SDXL_KOHYA_TEXT_ENCODER_1_KEY = "lora_te1"
+SDXL_KOHYA_TEXT_ENCODER_2_KEY = "lora_te2"
+
+SD_PEFT_TO_KOHYA_KEYS = {
+    SD_PEFT_UNET_KEY: SD_KOHYA_UNET_KEY,
+    SD_PEFT_TEXT_ENCODER_KEY: SD_KOHYA_TEXT_ENCODER_KEY,
+}
+
+SDXL_PEFT_TO_KOHYA_KEYS = {
+    SDXL_PEFT_UNET_KEY: SDXL_KOHYA_UNET_KEY,
+    SDXL_PEFT_TEXT_ENCODER_1_KEY: SDXL_KOHYA_TEXT_ENCODER_1_KEY,
+    SDXL_PEFT_TEXT_ENCODER_2_KEY: SDXL_KOHYA_TEXT_ENCODER_2_KEY,
+}
+
 
 def save_multi_model_peft_checkpoint(checkpoint_dir: Path | str, models: dict[str, peft.PeftModel]):
     """Save a dict of PeftModels to a checkpoint directory.
@@ -172,28 +191,100 @@ def _convert_peft_state_dict_to_kohya_state_dict(
     return kohya_ss_state_dict
 
 
+def _convert_peft_models_to_kohya_state_dict(
+    kohya_prefixes: list[str], models: list[peft.PeftModel]
+) -> dict[str, torch.Tensor]:
+    kohya_state_dict = {}
+    default_adapter_name = "default"
+
+    for kohya_prefix, peft_model in zip(kohya_prefixes, models, strict=True):
+        lora_config = peft_model.peft_config[default_adapter_name]
+        assert isinstance(lora_config, peft.LoraConfig)
+
+        peft_state_dict = peft.get_peft_model_state_dict(peft_model, adapter_name=default_adapter_name)
+
+        kohya_state_dict.update(
+            _convert_peft_state_dict_to_kohya_state_dict(
+                lora_config=lora_config,
+                peft_state_dict=peft_state_dict,
+                prefix=kohya_prefix,
+                dtype=torch.float32,
+            )
+        )
+
+    return kohya_state_dict
+
+
+def save_sd_kohya_checkpoint(checkpoint_path: Path, unet: peft.PeftModel | None, text_encoder: peft.PeftModel | None):
+    kohya_prefixes = []
+    models = []
+    for kohya_prefix, peft_model in zip([SD_KOHYA_UNET_KEY, SD_KOHYA_TEXT_ENCODER_KEY], [unet, text_encoder]):
+        if peft_model is not None:
+            kohya_prefixes.append(kohya_prefix)
+            models.append(peft_model)
+
+    kohya_state_dict = _convert_peft_models_to_kohya_state_dict(kohya_prefixes=kohya_prefixes, models=models)
+
+    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
+    save_state_dict(kohya_state_dict, checkpoint_path)
+
+
+def save_sdxl_kohya_checkpoint(
+    checkpoint_path: Path,
+    unet: peft.PeftModel | None,
+    text_encoder_1: peft.PeftModel | None,
+    text_encoder_2: peft.PeftModel | None,
+):
+    kohya_prefixes = []
+    models = []
+    for kohya_prefix, peft_model in zip(
+        [SDXL_KOHYA_UNET_KEY, SDXL_KOHYA_TEXT_ENCODER_1_KEY, SDXL_KOHYA_TEXT_ENCODER_2_KEY],
+        [unet, text_encoder_1, text_encoder_2],
+    ):
+        if peft_model is not None:
+            kohya_prefixes.append(kohya_prefix)
+            models.append(peft_model)
+
+    kohya_state_dict = _convert_peft_models_to_kohya_state_dict(kohya_prefixes=kohya_prefixes, models=models)
+
+    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
+    save_state_dict(kohya_state_dict, checkpoint_path)
+
+
 def convert_sd_peft_checkpoint_to_kohya_state_dict(
     in_checkpoint_dir: Path,
     out_checkpoint_file: Path,
     dtype: torch.dtype = torch.float32,
 ) -> dict[str, torch.Tensor]:
-    """Convert SD v1 PEFT models to a Kohya-format LoRA state dict."""
+    """Convert SD v1 or SDXL PEFT models to a Kohya-format LoRA state dict."""
+    # Get the immediate subdirectories of the checkpoint directory. We assume that each subdirectory is a PEFT model.
+    peft_model_dirs = os.listdir(in_checkpoint_dir)
+    peft_model_dirs = [in_checkpoint_dir / d for d in peft_model_dirs]  # Convert to Path objects.
+    peft_model_dirs = [d for d in peft_model_dirs if d.is_dir()]  # Filter out non-directories.
+
+    if len(peft_model_dirs) == 0:
+        raise ValueError(f"No checkpoint files found in directory '{in_checkpoint_dir}'.")
+
     kohya_state_dict = {}
-    for kohya_prefix, peft_model_key in [("lora_unet", SD_PEFT_UNET_KEY), ("lora_te", SD_PEFT_TEXT_ENCODER_KEY)]:
-        peft_model_dir = in_checkpoint_dir / peft_model_key
-
-        if peft_model_dir.exists():
-            # Note: This logic to load the LoraConfig and weights directly is based on how it is done here:
-            # https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/src/peft/peft_model.py#L672-L689
-            # This may need to be updated in the future to support other adapter types (LoKr, LoHa, etc.).
-            # Also, I could see this interface breaking in the future.
-            lora_config = peft.LoraConfig.from_pretrained(peft_model_dir)
-            lora_weights = peft.utils.load_peft_weights(peft_model_dir, device="cpu")
-
-            kohya_state_dict.update(
-                _convert_peft_state_dict_to_kohya_state_dict(
-                    lora_config=lora_config, peft_state_dict=lora_weights, prefix=kohya_prefix, dtype=dtype
-                )
+    for peft_model_dir in peft_model_dirs:
+        if peft_model_dir.name in SD_PEFT_TO_KOHYA_KEYS:
+            kohya_prefix = SD_PEFT_TO_KOHYA_KEYS[peft_model_dir.name]
+        elif peft_model_dir.name in SDXL_PEFT_TO_KOHYA_KEYS:
+            kohya_prefix = SDXL_PEFT_TO_KOHYA_KEYS[peft_model_dir.name]
+        else:
+            raise ValueError(f"Unrecognized checkpoint directory: '{peft_model_dir}'.")
+
+        # Note: This logic to load the LoraConfig and weights directly is based on how it is done here:
+        # https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/src/peft/peft_model.py#L672-L689
+        # This may need to be updated in the future to support other adapter types (LoKr, LoHa, etc.).
+        # Also, I could see this interface breaking in the future.
+        lora_config = peft.LoraConfig.from_pretrained(peft_model_dir)
+        lora_weights = peft.utils.load_peft_weights(peft_model_dir, device="cpu")
+
+        kohya_state_dict.update(
+            _convert_peft_state_dict_to_kohya_state_dict(
+                lora_config=lora_config, peft_state_dict=lora_weights, prefix=kohya_prefix, dtype=dtype
             )
+        )
 
     save_state_dict(kohya_state_dict, out_checkpoint_file)
diff --git a/src/invoke_training/training/pipelines/stable_diffusion/finetune_lora_sd.py b/src/invoke_training/training/pipelines/stable_diffusion/finetune_lora_sd.py
index 6080040f..4ca82632 100644
--- a/src/invoke_training/training/pipelines/stable_diffusion/finetune_lora_sd.py
+++ b/src/invoke_training/training/pipelines/stable_diffusion/finetune_lora_sd.py
@@ -6,7 +6,7 @@
 import tempfile
 import time
 from pathlib import Path
-from typing import Optional, Union
+from typing import Literal, Optional, Union
 
 import peft
 import torch
@@ -39,6 +39,7 @@
 from invoke_training.training._shared.stable_diffusion.lora_checkpoint_utils import (
     TEXT_ENCODER_TARGET_MODULES,
     UNET_TARGET_MODULES,
+    save_sd_kohya_checkpoint,
     save_sd_peft_checkpoint,
 )
 from invoke_training.training._shared.stable_diffusion.model_loading_utils import load_models_sd
@@ -52,6 +53,7 @@ def _save_sd_lora_checkpoint(
     text_encoder: peft.PeftModel | None,
     logger: logging.Logger,
     checkpoint_tracker: CheckpointTracker,
+    lora_checkpoint_format: Literal["invoke_peft", "kohya"],
 ):
     # Prune checkpoints and get new checkpoint path.
     num_pruned = checkpoint_tracker.prune(1)
@@ -59,7 +61,12 @@ def _save_sd_lora_checkpoint(
         logger.info(f"Pruned {num_pruned} checkpoint(s).")
 
     save_path = checkpoint_tracker.get_path(idx)
-    save_sd_peft_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)
+    if lora_checkpoint_format == "invoke_peft":
+        save_sd_peft_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)
+    elif lora_checkpoint_format == "kohya":
+        save_sd_kohya_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)
+    else:
+        raise ValueError(f"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.")
 
 
 def _build_data_loader(
@@ -437,12 +444,14 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
         base_dir=ckpt_dir,
         prefix="checkpoint_epoch",
         max_checkpoints=config.max_checkpoints,
+        extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
     )
 
     step_checkpoint_tracker = CheckpointTracker(
         base_dir=ckpt_dir,
         prefix="checkpoint_step",
         max_checkpoints=config.max_checkpoints,
+        extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
     )
 
     # Train!
@@ -524,6 +533,7 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
                         text_encoder=accelerator.unwrap_model(text_encoder) if config.train_text_encoder else None,
                         logger=logger,
                         checkpoint_tracker=step_checkpoint_tracker,
+                        lora_checkpoint_format=config.lora_checkpoint_format,
                     )
 
                 logs = {
@@ -545,6 +555,7 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
                     text_encoder=accelerator.unwrap_model(text_encoder) if config.train_text_encoder else None,
                     logger=logger,
                     checkpoint_tracker=epoch_checkpoint_tracker,
+                    lora_checkpoint_format=config.lora_checkpoint_format,
                 )
 
         # Generate validation images every n epochs.
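Beyond the error handling exercised by the new tests at the end of this diff, the reworked `convert_sd_peft_checkpoint_to_kohya_state_dict` above maps each PEFT sub-directory of a checkpoint to its Kohya prefix, so a previously saved `invoke_peft` checkpoint (SD v1 or SDXL) can still be converted after the fact. A minimal usage sketch, with placeholder paths:

```python
# Sketch: converting an existing invoke_peft checkpoint directory into a single
# Kohya-format .safetensors file. The paths below are placeholders.
from pathlib import Path

import torch

from invoke_training.training._shared.stable_diffusion.lora_checkpoint_utils import (
    convert_sd_peft_checkpoint_to_kohya_state_dict,
)

convert_sd_peft_checkpoint_to_kohya_state_dict(
    in_checkpoint_dir=Path("output/1691088769.5694647/checkpoint_epoch-00000002"),
    out_checkpoint_file=Path("output/1691088769.5694647/checkpoint_epoch-00000002_kohya.safetensors"),
    dtype=torch.float32,
)
```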
diff --git a/src/invoke_training/training/pipelines/stable_diffusion_xl/finetune_lora_and_ti_sdxl.py b/src/invoke_training/training/pipelines/stable_diffusion_xl/finetune_lora_and_ti_sdxl.py
index e38a4b8f..88a1b3fd 100644
--- a/src/invoke_training/training/pipelines/stable_diffusion_xl/finetune_lora_and_ti_sdxl.py
+++ b/src/invoke_training/training/pipelines/stable_diffusion_xl/finetune_lora_and_ti_sdxl.py
@@ -5,6 +5,7 @@
 import os
 import time
 from pathlib import Path
+from typing import Literal
 
 import peft
 import torch
@@ -32,6 +33,7 @@
 from invoke_training.training._shared.stable_diffusion.lora_checkpoint_utils import (
     TEXT_ENCODER_TARGET_MODULES,
     UNET_TARGET_MODULES,
+    save_sdxl_kohya_checkpoint,
     save_sdxl_peft_checkpoint,
 )
 from invoke_training.training._shared.stable_diffusion.model_loading_utils import (
@@ -54,6 +56,7 @@ def _save_sdxl_lora_and_ti_checkpoint(
     accelerator: Accelerator,
     logger: logging.Logger,
     checkpoint_tracker: CheckpointTracker,
+    lora_checkpoint_format: Literal["invoke_peft", "kohya"],
 ):
     # Prune checkpoints and get new checkpoint path.
     num_pruned = checkpoint_tracker.prune(1)
@@ -61,12 +64,22 @@ def _save_sdxl_lora_and_ti_checkpoint(
         logger.info(f"Pruned {num_pruned} checkpoint(s).")
 
     save_path = checkpoint_tracker.get_path(idx)
-    save_sdxl_peft_checkpoint(
-        Path(save_path),
-        unet=unet if config.train_unet else None,
-        text_encoder_1=text_encoder_1 if config.train_text_encoder else None,
-        text_encoder_2=text_encoder_2 if config.train_text_encoder else None,
-    )
+    if lora_checkpoint_format == "invoke_peft":
+        save_sdxl_peft_checkpoint(
+            Path(save_path),
+            unet=unet if config.train_unet else None,
+            text_encoder_1=text_encoder_1 if config.train_text_encoder else None,
+            text_encoder_2=text_encoder_2 if config.train_text_encoder else None,
+        )
+    elif lora_checkpoint_format == "kohya":
+        save_sdxl_kohya_checkpoint(
+            Path(save_path) / "lora.safetensors",
+            unet=unet if config.train_unet else None,
+            text_encoder_1=text_encoder_1 if config.train_text_encoder else None,
+            text_encoder_2=text_encoder_2 if config.train_text_encoder else None,
+        )
+    else:
+        raise ValueError(f"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.")
 
     if config.train_ti:
         ti_checkpoint_path = Path(save_path) / "embeddings.safetensors"
@@ -487,6 +500,7 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float) -> peft.P
                         accelerator=accelerator,
                         logger=logger,
                         checkpoint_tracker=step_checkpoint_tracker,
+                        lora_checkpoint_format=config.lora_checkpoint_format,
                     )
 
                 logs = {
@@ -512,6 +526,7 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float) -> peft.P
                     accelerator=accelerator,
                     logger=logger,
                     checkpoint_tracker=epoch_checkpoint_tracker,
+                    lora_checkpoint_format=config.lora_checkpoint_format,
                 )
 
         accelerator.wait_for_everyone()
diff --git a/src/invoke_training/training/pipelines/stable_diffusion_xl/finetune_lora_sdxl.py b/src/invoke_training/training/pipelines/stable_diffusion_xl/finetune_lora_sdxl.py
index 02dca74e..95afdf71 100644
--- a/src/invoke_training/training/pipelines/stable_diffusion_xl/finetune_lora_sdxl.py
+++ b/src/invoke_training/training/pipelines/stable_diffusion_xl/finetune_lora_sdxl.py
@@ -6,7 +6,7 @@
 import tempfile
 import time
 from pathlib import Path
-from typing import Optional, Union
+from typing import Literal, Optional, Union
 
 import peft
 import torch
@@ -43,6 +43,7 @@
 from invoke_training.training._shared.stable_diffusion.lora_checkpoint_utils import (
     TEXT_ENCODER_TARGET_MODULES,
     UNET_TARGET_MODULES,
+    save_sdxl_kohya_checkpoint,
     save_sdxl_peft_checkpoint,
 )
 from invoke_training.training._shared.stable_diffusion.model_loading_utils import (
@@ -60,6 +61,7 @@ def _save_sdxl_lora_checkpoint(
     text_encoder_2: peft.PeftModel | None,
     logger: logging.Logger,
     checkpoint_tracker: CheckpointTracker,
+    lora_checkpoint_format: Literal["invoke_peft", "kohya"],
 ):
     # Prune checkpoints and get new checkpoint path.
     num_pruned = checkpoint_tracker.prune(1)
@@ -67,7 +69,16 @@ def _save_sdxl_lora_checkpoint(
         logger.info(f"Pruned {num_pruned} checkpoint(s).")
 
     save_path = checkpoint_tracker.get_path(idx)
-    save_sdxl_peft_checkpoint(Path(save_path), unet=unet, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2)
+    if lora_checkpoint_format == "invoke_peft":
+        save_sdxl_peft_checkpoint(
+            Path(save_path), unet=unet, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2
+        )
+    elif lora_checkpoint_format == "kohya":
+        save_sdxl_kohya_checkpoint(
+            Path(save_path), unet=unet, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2
+        )
+    else:
+        raise ValueError(f"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.")
 
 
 def _build_data_loader(
@@ -528,12 +539,14 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
         base_dir=ckpt_dir,
         prefix="checkpoint_epoch",
         max_checkpoints=config.max_checkpoints,
+        extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
     )
 
     step_checkpoint_tracker = CheckpointTracker(
         base_dir=ckpt_dir,
         prefix="checkpoint_step",
         max_checkpoints=config.max_checkpoints,
+        extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
     )
 
     # Train!
@@ -620,6 +633,7 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
                         text_encoder_2=text_encoder_2,
                         logger=logger,
                         checkpoint_tracker=step_checkpoint_tracker,
+                        lora_checkpoint_format=config.lora_checkpoint_format,
                     )
 
                 logs = {
@@ -641,6 +655,7 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
                     text_encoder_2=text_encoder_2,
                     logger=logger,
                     checkpoint_tracker=epoch_checkpoint_tracker,
+                    lora_checkpoint_format=config.lora_checkpoint_format,
                 )
 
         accelerator.wait_for_everyone()
diff --git a/tests/invoke_training/training/_shared/stable_diffusion/test_lora_checkpoint_utils.py b/tests/invoke_training/training/_shared/stable_diffusion/test_lora_checkpoint_utils.py
new file mode 100644
index 00000000..ff1a902c
--- /dev/null
+++ b/tests/invoke_training/training/_shared/stable_diffusion/test_lora_checkpoint_utils.py
@@ -0,0 +1,24 @@
+from pathlib import Path
+
+import pytest
+
+from invoke_training.training._shared.stable_diffusion.lora_checkpoint_utils import (
+    convert_sd_peft_checkpoint_to_kohya_state_dict,
+)
+
+
+def test_convert_sd_peft_checkpoint_to_kohya_state_dict_raise_on_empty_directory(tmp_path: Path):
+    with pytest.raises(ValueError, match="No checkpoint files found in directory"):
+        convert_sd_peft_checkpoint_to_kohya_state_dict(
+            in_checkpoint_dir=tmp_path, out_checkpoint_file=tmp_path / "out.safetensors"
+        )
+
+
+def test_convert_sd_peft_checkpoint_to_kohya_state_dict_raise_on_unexpected_subdirectory(tmp_path: Path):
+    subdirectory = tmp_path / "subdir"
+    subdirectory.mkdir()
+
+    with pytest.raises(ValueError, match=f"Unrecognized checkpoint directory: '{subdirectory}'."):
+        convert_sd_peft_checkpoint_to_kohya_state_dict(
+            in_checkpoint_dir=tmp_path, out_checkpoint_file=tmp_path / "out.safetensors"
+        )
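The new tests cover the converter's error paths. For a quick, informal check that a `kohya`-format checkpoint produced by the updated pipelines carries the expected key prefixes (`lora_unet` / `lora_te` for SD v1, `lora_te1` / `lora_te2` for SDXL), the flat state dict can be inspected with `safetensors`; this is only a sketch with a placeholder path, not part of the diff:

```python
# Sketch: inspecting a Kohya-format LoRA checkpoint (the path is a placeholder).
from safetensors.torch import load_file

state_dict = load_file("output/1691088769.5694647/checkpoint_epoch-00000002.safetensors")

# Kohya-format keys are flat and carry a per-model prefix, e.g. "lora_unet_..." or "lora_te_...".
expected_prefixes = ("lora_unet", "lora_te")
unexpected = [k for k in state_dict if not k.startswith(expected_prefixes)]
print(f"{len(state_dict)} tensors, {len(unexpected)} with unexpected prefixes")
```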