diff --git a/src/invoke_training/config/pipelines/finetune_lora_and_ti_config.py b/src/invoke_training/config/pipelines/finetune_lora_and_ti_config.py
index 1f085446..fc98c99b 100644
--- a/src/invoke_training/config/pipelines/finetune_lora_and_ti_config.py
+++ b/src/invoke_training/config/pipelines/finetune_lora_and_ti_config.py
@@ -23,6 +23,9 @@ class LoraAndTiTrainingConfig(BasePipelineConfig):
     """The Hugging Face Hub model variant to use. Only applies if `model` is a Hugging Face Hub model name.
     """
 
+    lora_checkpoint_format: Literal["invoke_peft", "kohya"] = "kohya"
+    """The format of the LoRA checkpoint to save. Choose between `invoke_peft` or `kohya`."""
+
     # Helpful discussion for understanding how this works at inference time:
     # https://github.com/huggingface/diffusers/pull/3144#discussion_r1172413509
     num_vectors: int = 1
diff --git a/src/invoke_training/config/pipelines/finetune_lora_config.py b/src/invoke_training/config/pipelines/finetune_lora_config.py
index c9678f41..7cf36aad 100644
--- a/src/invoke_training/config/pipelines/finetune_lora_config.py
+++ b/src/invoke_training/config/pipelines/finetune_lora_config.py
@@ -44,6 +44,9 @@ class LoRATrainingConfig(BasePipelineConfig):
     generation time with the resultant LoRA model.
     """
 
+    lora_checkpoint_format: Literal["invoke_peft", "kohya"] = "kohya"
+    """The format of the LoRA checkpoint to save. Choose between `invoke_peft` or `kohya`."""
+
     train_unet: bool = True
     """Whether to add LoRA layers to the UNet model and train it.
     """
diff --git a/src/invoke_training/training/_experimental/dpo/diffusion_dpo_lora_sd.py b/src/invoke_training/training/_experimental/dpo/diffusion_dpo_lora_sd.py
index 67a72df9..887e8191 100644
--- a/src/invoke_training/training/_experimental/dpo/diffusion_dpo_lora_sd.py
+++ b/src/invoke_training/training/_experimental/dpo/diffusion_dpo_lora_sd.py
@@ -7,6 +7,7 @@
 import tempfile
 import time
 from pathlib import Path
+from typing import Literal
 
 import peft
 import torch
@@ -32,6 +33,7 @@
     TEXT_ENCODER_TARGET_MODULES,
     UNET_TARGET_MODULES,
     load_sd_peft_checkpoint,
+    save_sd_kohya_checkpoint,
     save_sd_peft_checkpoint,
 )
 from invoke_training.training._shared.stable_diffusion.model_loading_utils import load_models_sd
@@ -46,6 +48,7 @@ def _save_sd_lora_checkpoint(
     text_encoder: peft.PeftModel | None,
     logger: logging.Logger,
     checkpoint_tracker: CheckpointTracker,
+    lora_checkpoint_format: Literal["invoke_peft", "kohya"],
 ):
     # Prune checkpoints and get new checkpoint path.
     num_pruned = checkpoint_tracker.prune(1)
@@ -53,7 +56,12 @@ def _save_sd_lora_checkpoint(
         logger.info(f"Pruned {num_pruned} checkpoint(s).")
 
     save_path = checkpoint_tracker.get_path(idx)
-    save_sd_peft_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)
+    if lora_checkpoint_format == "invoke_peft":
+        save_sd_peft_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)
+    elif lora_checkpoint_format == "kohya":
+        save_sd_kohya_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)
+    else:
+        raise ValueError(f"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.")
 
 
 def train_forward_dpo(  # noqa: C901
@@ -421,14 +429,14 @@ def prep_peft_model(model, lr: float | None = None):
     epoch_checkpoint_tracker = CheckpointTracker(
         base_dir=ckpt_dir,
         prefix="checkpoint_epoch",
-        extension=".safetensors",
+        extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
         max_checkpoints=config.max_checkpoints,
     )
 
     step_checkpoint_tracker = CheckpointTracker(
         base_dir=ckpt_dir,
         prefix="checkpoint_step",
-        extension=".safetensors",
+        extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
         max_checkpoints=config.max_checkpoints,
     )
 
@@ -513,6 +521,7 @@ def prep_peft_model(model, lr: float | None = None):
                         text_encoder=accelerator.unwrap_model(text_encoder) if training_text_encoder else None,
                         logger=logger,
                         checkpoint_tracker=step_checkpoint_tracker,
+                        lora_checkpoint_format=config.lora_checkpoint_format,
                     )
 
                 logs = {
@@ -534,6 +543,7 @@ def prep_peft_model(model, lr: float | None = None):
                 text_encoder=accelerator.unwrap_model(text_encoder) if training_text_encoder else None,
                 logger=logger,
                 checkpoint_tracker=epoch_checkpoint_tracker,
+                lora_checkpoint_format=config.lora_checkpoint_format,
             )
 
         # Generate validation images every n epochs.
diff --git a/src/invoke_training/training/_shared/stable_diffusion/lora_checkpoint_utils.py b/src/invoke_training/training/_shared/stable_diffusion/lora_checkpoint_utils.py
index d7639699..98a5e35c 100644
--- a/src/invoke_training/training/_shared/stable_diffusion/lora_checkpoint_utils.py
+++ b/src/invoke_training/training/_shared/stable_diffusion/lora_checkpoint_utils.py
@@ -191,6 +191,66 @@ def _convert_peft_state_dict_to_kohya_state_dict(
     return kohya_ss_state_dict
 
 
+def _convert_peft_models_to_kohya_state_dict(
+    kohya_prefixes: list[str], models: list[peft.PeftModel]
+) -> dict[str, torch.Tensor]:
+    kohya_state_dict = {}
+    default_adapter_name = "default"
+
+    for kohya_prefix, peft_model in zip(kohya_prefixes, models, strict=True):
+        lora_config = peft_model.peft_config[default_adapter_name]
+        assert isinstance(lora_config, peft.LoraConfig)
+
+        peft_state_dict = peft.get_peft_model_state_dict(peft_model, adapter_name=default_adapter_name)
+
+        kohya_state_dict.update(
+            _convert_peft_state_dict_to_kohya_state_dict(
+                lora_config=lora_config,
+                peft_state_dict=peft_state_dict,
+                prefix=kohya_prefix,
+                dtype=torch.float32,
+            )
+        )
+
+    return kohya_state_dict
+
+
+def save_sd_kohya_checkpoint(checkpoint_path: Path, unet: peft.PeftModel | None, text_encoder: peft.PeftModel | None):
+    kohya_prefixes = []
+    models = []
+    for kohya_prefix, peft_model in zip([SD_KOHYA_UNET_KEY, SD_KOHYA_TEXT_ENCODER_KEY], [unet, text_encoder]):
+        if peft_model is not None:
+            kohya_prefixes.append(kohya_prefix)
+            models.append(peft_model)
+
+    kohya_state_dict = _convert_peft_models_to_kohya_state_dict(kohya_prefixes=kohya_prefixes, models=models)
+
+    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
+    save_state_dict(kohya_state_dict, checkpoint_path)
+
+
+def save_sdxl_kohya_checkpoint(
+    checkpoint_path: Path,
+    unet: peft.PeftModel | None,
+    text_encoder_1: peft.PeftModel | None,
+    text_encoder_2: peft.PeftModel | None,
+):
+    kohya_prefixes = []
+    models = []
+    for kohya_prefix, peft_model in zip(
+        [SDXL_KOHYA_UNET_KEY, SDXL_KOHYA_TEXT_ENCODER_1_KEY, SDXL_KOHYA_TEXT_ENCODER_2_KEY],
+        [unet, text_encoder_1, text_encoder_2],
+    ):
+        if peft_model is not None:
+            kohya_prefixes.append(kohya_prefix)
+            models.append(peft_model)
+
+    kohya_state_dict = _convert_peft_models_to_kohya_state_dict(kohya_prefixes=kohya_prefixes, models=models)
+
+    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
+    save_state_dict(kohya_state_dict, checkpoint_path)
+
+
 def convert_sd_peft_checkpoint_to_kohya_state_dict(
     in_checkpoint_dir: Path,
     out_checkpoint_file: Path,
diff --git a/src/invoke_training/training/pipelines/stable_diffusion/finetune_lora_sd.py b/src/invoke_training/training/pipelines/stable_diffusion/finetune_lora_sd.py
index 6080040f..4ca82632 100644
--- a/src/invoke_training/training/pipelines/stable_diffusion/finetune_lora_sd.py
+++ b/src/invoke_training/training/pipelines/stable_diffusion/finetune_lora_sd.py
@@ -6,7 +6,7 @@
 import tempfile
 import time
 from pathlib import Path
-from typing import Optional, Union
+from typing import Literal, Optional, Union
 
 import peft
 import torch
@@ -39,6 +39,7 @@
 from invoke_training.training._shared.stable_diffusion.lora_checkpoint_utils import (
     TEXT_ENCODER_TARGET_MODULES,
     UNET_TARGET_MODULES,
+    save_sd_kohya_checkpoint,
     save_sd_peft_checkpoint,
 )
 from invoke_training.training._shared.stable_diffusion.model_loading_utils import load_models_sd
@@ -52,6 +53,7 @@ def _save_sd_lora_checkpoint(
     text_encoder: peft.PeftModel | None,
     logger: logging.Logger,
     checkpoint_tracker: CheckpointTracker,
+    lora_checkpoint_format: Literal["invoke_peft", "kohya"],
 ):
     # Prune checkpoints and get new checkpoint path.
     num_pruned = checkpoint_tracker.prune(1)
@@ -59,7 +61,12 @@ def _save_sd_lora_checkpoint(
         logger.info(f"Pruned {num_pruned} checkpoint(s).")
 
     save_path = checkpoint_tracker.get_path(idx)
-    save_sd_peft_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)
+    if lora_checkpoint_format == "invoke_peft":
+        save_sd_peft_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)
+    elif lora_checkpoint_format == "kohya":
+        save_sd_kohya_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)
+    else:
+        raise ValueError(f"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.")
 
 
 def _build_data_loader(
@@ -437,12 +444,14 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
         base_dir=ckpt_dir,
         prefix="checkpoint_epoch",
         max_checkpoints=config.max_checkpoints,
+        extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
     )
 
     step_checkpoint_tracker = CheckpointTracker(
         base_dir=ckpt_dir,
         prefix="checkpoint_step",
         max_checkpoints=config.max_checkpoints,
+        extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
     )
 
     # Train!
@@ -524,6 +533,7 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
                         text_encoder=accelerator.unwrap_model(text_encoder) if config.train_text_encoder else None,
                         logger=logger,
                         checkpoint_tracker=step_checkpoint_tracker,
+                        lora_checkpoint_format=config.lora_checkpoint_format,
                     )
 
                 logs = {
@@ -545,6 +555,7 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
                 text_encoder=accelerator.unwrap_model(text_encoder) if config.train_text_encoder else None,
                 logger=logger,
                 checkpoint_tracker=epoch_checkpoint_tracker,
+                lora_checkpoint_format=config.lora_checkpoint_format,
             )
 
         # Generate validation images every n epochs.
diff --git a/src/invoke_training/training/pipelines/stable_diffusion_xl/finetune_lora_and_ti_sdxl.py b/src/invoke_training/training/pipelines/stable_diffusion_xl/finetune_lora_and_ti_sdxl.py
index e38a4b8f..88a1b3fd 100644
--- a/src/invoke_training/training/pipelines/stable_diffusion_xl/finetune_lora_and_ti_sdxl.py
+++ b/src/invoke_training/training/pipelines/stable_diffusion_xl/finetune_lora_and_ti_sdxl.py
@@ -5,6 +5,7 @@
 import os
 import time
 from pathlib import Path
+from typing import Literal
 
 import peft
 import torch
@@ -32,6 +33,7 @@
 from invoke_training.training._shared.stable_diffusion.lora_checkpoint_utils import (
     TEXT_ENCODER_TARGET_MODULES,
     UNET_TARGET_MODULES,
+    save_sdxl_kohya_checkpoint,
     save_sdxl_peft_checkpoint,
 )
 from invoke_training.training._shared.stable_diffusion.model_loading_utils import (
@@ -54,6 +56,7 @@ def _save_sdxl_lora_and_ti_checkpoint(
     accelerator: Accelerator,
     logger: logging.Logger,
     checkpoint_tracker: CheckpointTracker,
+    lora_checkpoint_format: Literal["invoke_peft", "kohya"],
 ):
     # Prune checkpoints and get new checkpoint path.
     num_pruned = checkpoint_tracker.prune(1)
@@ -61,12 +64,22 @@ def _save_sdxl_lora_and_ti_checkpoint(
         logger.info(f"Pruned {num_pruned} checkpoint(s).")
 
     save_path = checkpoint_tracker.get_path(idx)
-    save_sdxl_peft_checkpoint(
-        Path(save_path),
-        unet=unet if config.train_unet else None,
-        text_encoder_1=text_encoder_1 if config.train_text_encoder else None,
-        text_encoder_2=text_encoder_2 if config.train_text_encoder else None,
-    )
+    if lora_checkpoint_format == "invoke_peft":
+        save_sdxl_peft_checkpoint(
+            Path(save_path),
+            unet=unet if config.train_unet else None,
+            text_encoder_1=text_encoder_1 if config.train_text_encoder else None,
+            text_encoder_2=text_encoder_2 if config.train_text_encoder else None,
+        )
+    elif lora_checkpoint_format == "kohya":
+        save_sdxl_kohya_checkpoint(
+            Path(save_path) / "lora.safetensors",
+            unet=unet if config.train_unet else None,
+            text_encoder_1=text_encoder_1 if config.train_text_encoder else None,
+            text_encoder_2=text_encoder_2 if config.train_text_encoder else None,
+        )
+    else:
+        raise ValueError(f"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.")
 
     if config.train_ti:
         ti_checkpoint_path = Path(save_path) / "embeddings.safetensors"
@@ -487,6 +500,7 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float) -> peft.P
                         accelerator=accelerator,
                         logger=logger,
                         checkpoint_tracker=step_checkpoint_tracker,
+                        lora_checkpoint_format=config.lora_checkpoint_format,
                     )
 
                 logs = {
@@ -512,6 +526,7 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float) -> peft.P
                 accelerator=accelerator,
                 logger=logger,
                 checkpoint_tracker=epoch_checkpoint_tracker,
+                lora_checkpoint_format=config.lora_checkpoint_format,
             )
 
     accelerator.wait_for_everyone()
diff --git a/src/invoke_training/training/pipelines/stable_diffusion_xl/finetune_lora_sdxl.py b/src/invoke_training/training/pipelines/stable_diffusion_xl/finetune_lora_sdxl.py
index 02dca74e..95afdf71 100644
--- a/src/invoke_training/training/pipelines/stable_diffusion_xl/finetune_lora_sdxl.py
+++ b/src/invoke_training/training/pipelines/stable_diffusion_xl/finetune_lora_sdxl.py
@@ -6,7 +6,7 @@
 import tempfile
 import time
 from pathlib import Path
-from typing import Optional, Union
+from typing import Literal, Optional, Union
 
 import peft
 import torch
@@ -43,6 +43,7 @@
 from invoke_training.training._shared.stable_diffusion.lora_checkpoint_utils import (
     TEXT_ENCODER_TARGET_MODULES,
     UNET_TARGET_MODULES,
+    save_sdxl_kohya_checkpoint,
     save_sdxl_peft_checkpoint,
 )
 from invoke_training.training._shared.stable_diffusion.model_loading_utils import (
@@ -60,6 +61,7 @@ def _save_sdxl_lora_checkpoint(
     text_encoder_2: peft.PeftModel | None,
     logger: logging.Logger,
     checkpoint_tracker: CheckpointTracker,
+    lora_checkpoint_format: Literal["invoke_peft", "kohya"],
 ):
     # Prune checkpoints and get new checkpoint path.
     num_pruned = checkpoint_tracker.prune(1)
@@ -67,7 +69,16 @@ def _save_sdxl_lora_checkpoint(
         logger.info(f"Pruned {num_pruned} checkpoint(s).")
 
     save_path = checkpoint_tracker.get_path(idx)
-    save_sdxl_peft_checkpoint(Path(save_path), unet=unet, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2)
+    if lora_checkpoint_format == "invoke_peft":
+        save_sdxl_peft_checkpoint(
+            Path(save_path), unet=unet, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2
+        )
+    elif lora_checkpoint_format == "kohya":
+        save_sdxl_kohya_checkpoint(
+            Path(save_path), unet=unet, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2
+        )
+    else:
+        raise ValueError(f"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.")
 
 
 def _build_data_loader(
@@ -528,12 +539,14 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
         base_dir=ckpt_dir,
         prefix="checkpoint_epoch",
         max_checkpoints=config.max_checkpoints,
+        extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
     )
 
     step_checkpoint_tracker = CheckpointTracker(
         base_dir=ckpt_dir,
         prefix="checkpoint_step",
         max_checkpoints=config.max_checkpoints,
+        extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
    )
 
     # Train!
@@ -620,6 +633,7 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
                         text_encoder_2=text_encoder_2,
                         logger=logger,
                         checkpoint_tracker=step_checkpoint_tracker,
+                        lora_checkpoint_format=config.lora_checkpoint_format,
                     )
 
                 logs = {
@@ -641,6 +655,7 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
                 text_encoder_2=text_encoder_2,
                 logger=logger,
                 checkpoint_tracker=epoch_checkpoint_tracker,
+                lora_checkpoint_format=config.lora_checkpoint_format,
            )
 
     accelerator.wait_for_everyone()
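
For reference, a minimal sketch of how the kohya-format helper introduced in lora_checkpoint_utils.py can be exercised directly, mirroring the "kohya" branch of _save_sd_lora_checkpoint above. The base model ID, LoRA rank, and output path are illustrative assumptions, not values taken from this diff.

# Minimal sketch (not part of this diff): save a LoRA-wrapped SD UNet in kohya format.
# The model ID, LoRA rank, and output path are illustrative assumptions.
from pathlib import Path

import peft
from diffusers import UNet2DConditionModel

from invoke_training.training._shared.stable_diffusion.lora_checkpoint_utils import (
    UNET_TARGET_MODULES,
    save_sd_kohya_checkpoint,
)

# Wrap a base UNet with LoRA layers via peft, roughly as the trainers' LoRA-injection helpers do.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet = peft.get_peft_model(unet, peft.LoraConfig(r=4, lora_alpha=4, target_modules=UNET_TARGET_MODULES))

# With lora_checkpoint_format == "kohya", each checkpoint is a single .safetensors file
# (hence the extension passed to CheckpointTracker) rather than an invoke_peft directory.
save_sd_kohya_checkpoint(
    Path("output/checkpoints/checkpoint_step-00001000.safetensors"),  # illustrative path
    unet=unet,
    text_encoder=None,  # pass the LoRA-wrapped text encoder here when it is also trained
)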