Skip to content

Commit

Permalink
Merge pull request #80 from invoke-ai/save-kohya
Browse files Browse the repository at this point in the history
Default to saving LoRA checkpoints in Kohya format
  • Loading branch information
RyanJDick authored Jan 26, 2024
2 parents 062988e + 49e11db commit 17332fc
Show file tree
Hide file tree
Showing 9 changed files with 203 additions and 41 deletions.
14 changes: 2 additions & 12 deletions docs/get_started/quick_start.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,25 +27,15 @@ Monitor the training process with Tensorboard by running `tensorboard --logdir o
![Screenshot of the Tensorboard UI showing validation images.](../images/tensorboard_val_images_screenshot.png)
*Validation images in the Tensorboard UI.*

### 5. Select a checkpoint
### 5. Invokeai
Select a checkpoint based on the quality of the generated images. In this short training run, there are only 3 checkpoints to choose from. As an example, we'll use the **Epoch 2** checkpoint.

Internally, `invoke-training` stores the LoRA checkpoints in [PEFT format](https://huggingface.co/docs/peft/v0.7.1/en/package_reference/peft_model#peft.PeftModel.save_pretrained). We will convert the selected checkpoint to 'Kohya' format, because it has more widespread support across various UIs:
```bash
# Note: You will have to replace the timestamp in the checkpoint path.
python src/invoke_training/scripts/convert_sd_lora_to_kohya_format.py \
--src-ckpt-dir output/finetune_lora_sd_pokemon/1691088769.5694647/checkpoint_epoch-00000002 \
--dst-ckpt-file output/finetune_lora_sd_pokemon/1691088769.5694647/checkpoint_epoch-00000002_kohya.safetensors
```

### 5. InvokeAI

If you haven't already, setup [InvokeAI](https://github.com/invoke-ai/InvokeAI) by following its documentation.

Copy your selected LoRA checkpoint into your `${INVOKEAI_ROOT}/autoimport/lora` directory. For example:
```bash
# Note: You will have to replace the timestamp in the checkpoint path.
cp output/finetune_lora_sd_pokemon/1691088769.5694647/checkpoint_epoch-00000002_kohya.safetensors ${INVOKEAI_ROOT}/autoimport/lora/pokemon_epoch-00000002.safetensors
cp output/1691088769.5694647/checkpoint_epoch-00000002.safetensors ${INVOKEAI_ROOT}/autoimport/lora/pokemon_epoch-00000002.safetensors
```

You can now use your trained Pokemon LoRA in the InvokeAI UI! 🎉
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ class LoraAndTiTrainingConfig(BasePipelineConfig):
"""The Hugging Face Hub model variant to use. Only applies if `model` is a Hugging Face Hub model name.
"""

lora_checkpoint_format: Literal["invoke_peft", "kohya"] = "kohya"
"""The format of the LoRA checkpoint to save. Choose between `invoke_peft` or `kohya`."""

# Helpful discussion for understanding how this works at inference time:
# https://github.com/huggingface/diffusers/pull/3144#discussion_r1172413509
num_vectors: int = 1
Expand Down
3 changes: 3 additions & 0 deletions src/invoke_training/config/pipelines/finetune_lora_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ class LoRATrainingConfig(BasePipelineConfig):
generation time with the resultant LoRA model.
"""

lora_checkpoint_format: Literal["invoke_peft", "kohya"] = "kohya"
"""The format of the LoRA checkpoint to save. Choose between `invoke_peft` or `kohya`."""

train_unet: bool = True
"""Whether to add LoRA layers to the UNet model and train it.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import tempfile
import time
from pathlib import Path
from typing import Literal

import peft
import torch
Expand All @@ -32,6 +33,7 @@
TEXT_ENCODER_TARGET_MODULES,
UNET_TARGET_MODULES,
load_sd_peft_checkpoint,
save_sd_kohya_checkpoint,
save_sd_peft_checkpoint,
)
from invoke_training.training._shared.stable_diffusion.model_loading_utils import load_models_sd
Expand All @@ -46,14 +48,20 @@ def _save_sd_lora_checkpoint(
text_encoder: peft.PeftModel | None,
logger: logging.Logger,
checkpoint_tracker: CheckpointTracker,
lora_checkpoint_format: Literal["invoke_peft", "kohya"],
):
# Prune checkpoints and get new checkpoint path.
num_pruned = checkpoint_tracker.prune(1)
if num_pruned > 0:
logger.info(f"Pruned {num_pruned} checkpoint(s).")
save_path = checkpoint_tracker.get_path(idx)

save_sd_peft_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)
if lora_checkpoint_format == "invoke_peft":
save_sd_peft_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)
elif lora_checkpoint_format == "kohya":
save_sd_kohya_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)
else:
raise ValueError(f"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.")


def train_forward_dpo( # noqa: C901
Expand Down Expand Up @@ -421,14 +429,14 @@ def prep_peft_model(model, lr: float | None = None):
epoch_checkpoint_tracker = CheckpointTracker(
base_dir=ckpt_dir,
prefix="checkpoint_epoch",
extension=".safetensors",
extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
max_checkpoints=config.max_checkpoints,
)

step_checkpoint_tracker = CheckpointTracker(
base_dir=ckpt_dir,
prefix="checkpoint_step",
extension=".safetensors",
extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
max_checkpoints=config.max_checkpoints,
)

Expand Down Expand Up @@ -513,6 +521,7 @@ def prep_peft_model(model, lr: float | None = None):
text_encoder=accelerator.unwrap_model(text_encoder) if training_text_encoder else None,
logger=logger,
checkpoint_tracker=step_checkpoint_tracker,
lora_checkpoint_format=config.lora_checkpoint_format,
)

logs = {
Expand All @@ -534,6 +543,7 @@ def prep_peft_model(model, lr: float | None = None):
text_encoder=accelerator.unwrap_model(text_encoder) if training_text_encoder else None,
logger=logger,
checkpoint_tracker=epoch_checkpoint_tracker,
lora_checkpoint_format=config.lora_checkpoint_format,
)

# Generate validation images every n epochs.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from pathlib import Path

import peft
Expand Down Expand Up @@ -36,6 +37,24 @@
SDXL_PEFT_TEXT_ENCODER_1_KEY = "text_encoder_1"
SDXL_PEFT_TEXT_ENCODER_2_KEY = "text_encoder_2"

SD_KOHYA_UNET_KEY = "lora_unet"
SD_KOHYA_TEXT_ENCODER_KEY = "lora_te"

SDXL_KOHYA_UNET_KEY = "lora_unet"
SDXL_KOHYA_TEXT_ENCODER_1_KEY = "lora_te1"
SDXL_KOHYA_TEXT_ENCODER_2_KEY = "lora_te2"

SD_PEFT_TO_KOHYA_KEYS = {
SD_PEFT_UNET_KEY: SD_KOHYA_UNET_KEY,
SD_PEFT_TEXT_ENCODER_KEY: SD_KOHYA_TEXT_ENCODER_KEY,
}

SDXL_PEFT_TO_KOHYA_KEYS = {
SDXL_PEFT_UNET_KEY: SDXL_KOHYA_UNET_KEY,
SDXL_PEFT_TEXT_ENCODER_1_KEY: SDXL_KOHYA_TEXT_ENCODER_1_KEY,
SDXL_PEFT_TEXT_ENCODER_2_KEY: SDXL_KOHYA_TEXT_ENCODER_2_KEY,
}


def save_multi_model_peft_checkpoint(checkpoint_dir: Path | str, models: dict[str, peft.PeftModel]):
"""Save a dict of PeftModels to a checkpoint directory.
Expand Down Expand Up @@ -172,28 +191,100 @@ def _convert_peft_state_dict_to_kohya_state_dict(
return kohya_ss_state_dict


def _convert_peft_models_to_kohya_state_dict(
kohya_prefixes: list[str], models: list[peft.PeftModel]
) -> dict[str, torch.Tensor]:
kohya_state_dict = {}
default_adapter_name = "default"

for kohya_prefix, peft_model in zip(kohya_prefixes, models, strict=True):
lora_config = peft_model.peft_config[default_adapter_name]
assert isinstance(lora_config, peft.LoraConfig)

peft_state_dict = peft.get_peft_model_state_dict(peft_model, adapter_name=default_adapter_name)

kohya_state_dict.update(
_convert_peft_state_dict_to_kohya_state_dict(
lora_config=lora_config,
peft_state_dict=peft_state_dict,
prefix=kohya_prefix,
dtype=torch.float32,
)
)

return kohya_state_dict


def save_sd_kohya_checkpoint(checkpoint_path: Path, unet: peft.PeftModel | None, text_encoder: peft.PeftModel | None):
kohya_prefixes = []
models = []
for kohya_prefix, peft_model in zip([SD_KOHYA_UNET_KEY, SD_KOHYA_TEXT_ENCODER_KEY], [unet, text_encoder]):
if peft_model is not None:
kohya_prefixes.append(kohya_prefix)
models.append(peft_model)

kohya_state_dict = _convert_peft_models_to_kohya_state_dict(kohya_prefixes=kohya_prefixes, models=models)

checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
save_state_dict(kohya_state_dict, checkpoint_path)


def save_sdxl_kohya_checkpoint(
checkpoint_path: Path,
unet: peft.PeftModel | None,
text_encoder_1: peft.PeftModel | None,
text_encoder_2: peft.PeftModel | None,
):
kohya_prefixes = []
models = []
for kohya_prefix, peft_model in zip(
[SDXL_KOHYA_UNET_KEY, SDXL_KOHYA_TEXT_ENCODER_1_KEY, SDXL_KOHYA_TEXT_ENCODER_2_KEY],
[unet, text_encoder_1, text_encoder_2],
):
if peft_model is not None:
kohya_prefixes.append(kohya_prefix)
models.append(peft_model)

kohya_state_dict = _convert_peft_models_to_kohya_state_dict(kohya_prefixes=kohya_prefixes, models=models)

checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
save_state_dict(kohya_state_dict, checkpoint_path)


def convert_sd_peft_checkpoint_to_kohya_state_dict(
in_checkpoint_dir: Path,
out_checkpoint_file: Path,
dtype: torch.dtype = torch.float32,
) -> dict[str, torch.Tensor]:
"""Convert SD v1 PEFT models to a Kohya-format LoRA state dict."""
"""Convert SD v1 or SDXL PEFT models to a Kohya-format LoRA state dict."""
# Get the immediate subdirectories of the checkpoint directory. We assume that each subdirectory is a PEFT model.
peft_model_dirs = os.listdir(in_checkpoint_dir)
peft_model_dirs = [in_checkpoint_dir / d for d in peft_model_dirs] # Convert to Path objects.
peft_model_dirs = [d for d in peft_model_dirs if d.is_dir()] # Filter out non-directories.

if len(peft_model_dirs) == 0:
raise ValueError(f"No checkpoint files found in directory '{in_checkpoint_dir}'.")

kohya_state_dict = {}
for kohya_prefix, peft_model_key in [("lora_unet", SD_PEFT_UNET_KEY), ("lora_te", SD_PEFT_TEXT_ENCODER_KEY)]:
peft_model_dir = in_checkpoint_dir / peft_model_key

if peft_model_dir.exists():
# Note: This logic to load the LoraConfig and weights directly is based on how it is done here:
# https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/src/peft/peft_model.py#L672-L689
# This may need to be updated in the future to support other adapter types (LoKr, LoHa, etc.).
# Also, I could see this interface breaking in the future.
lora_config = peft.LoraConfig.from_pretrained(peft_model_dir)
lora_weights = peft.utils.load_peft_weights(peft_model_dir, device="cpu")

kohya_state_dict.update(
_convert_peft_state_dict_to_kohya_state_dict(
lora_config=lora_config, peft_state_dict=lora_weights, prefix=kohya_prefix, dtype=dtype
)
for peft_model_dir in peft_model_dirs:
if peft_model_dir.name in SD_PEFT_TO_KOHYA_KEYS:
kohya_prefix = SD_PEFT_TO_KOHYA_KEYS[peft_model_dir.name]
elif peft_model_dir.name in SDXL_PEFT_TO_KOHYA_KEYS:
kohya_prefix = SDXL_PEFT_TO_KOHYA_KEYS[peft_model_dir.name]
else:
raise ValueError(f"Unrecognized checkpoint directory: '{peft_model_dir}'.")

# Note: This logic to load the LoraConfig and weights directly is based on how it is done here:
# https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/src/peft/peft_model.py#L672-L689
# This may need to be updated in the future to support other adapter types (LoKr, LoHa, etc.).
# Also, I could see this interface breaking in the future.
lora_config = peft.LoraConfig.from_pretrained(peft_model_dir)
lora_weights = peft.utils.load_peft_weights(peft_model_dir, device="cpu")

kohya_state_dict.update(
_convert_peft_state_dict_to_kohya_state_dict(
lora_config=lora_config, peft_state_dict=lora_weights, prefix=kohya_prefix, dtype=dtype
)
)

save_state_dict(kohya_state_dict, out_checkpoint_file)
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import tempfile
import time
from pathlib import Path
from typing import Optional, Union
from typing import Literal, Optional, Union

import peft
import torch
Expand Down Expand Up @@ -39,6 +39,7 @@
from invoke_training.training._shared.stable_diffusion.lora_checkpoint_utils import (
TEXT_ENCODER_TARGET_MODULES,
UNET_TARGET_MODULES,
save_sd_kohya_checkpoint,
save_sd_peft_checkpoint,
)
from invoke_training.training._shared.stable_diffusion.model_loading_utils import load_models_sd
Expand All @@ -52,14 +53,20 @@ def _save_sd_lora_checkpoint(
text_encoder: peft.PeftModel | None,
logger: logging.Logger,
checkpoint_tracker: CheckpointTracker,
lora_checkpoint_format: Literal["invoke_peft", "kohya"],
):
# Prune checkpoints and get new checkpoint path.
num_pruned = checkpoint_tracker.prune(1)
if num_pruned > 0:
logger.info(f"Pruned {num_pruned} checkpoint(s).")
save_path = checkpoint_tracker.get_path(idx)

save_sd_peft_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)
if lora_checkpoint_format == "invoke_peft":
save_sd_peft_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)
elif lora_checkpoint_format == "kohya":
save_sd_kohya_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)
else:
raise ValueError(f"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.")


def _build_data_loader(
Expand Down Expand Up @@ -437,12 +444,14 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
base_dir=ckpt_dir,
prefix="checkpoint_epoch",
max_checkpoints=config.max_checkpoints,
extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
)

step_checkpoint_tracker = CheckpointTracker(
base_dir=ckpt_dir,
prefix="checkpoint_step",
max_checkpoints=config.max_checkpoints,
extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
)

# Train!
Expand Down Expand Up @@ -524,6 +533,7 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
text_encoder=accelerator.unwrap_model(text_encoder) if config.train_text_encoder else None,
logger=logger,
checkpoint_tracker=step_checkpoint_tracker,
lora_checkpoint_format=config.lora_checkpoint_format,
)

logs = {
Expand All @@ -545,6 +555,7 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
text_encoder=accelerator.unwrap_model(text_encoder) if config.train_text_encoder else None,
logger=logger,
checkpoint_tracker=epoch_checkpoint_tracker,
lora_checkpoint_format=config.lora_checkpoint_format,
)

# Generate validation images every n epochs.
Expand Down
Loading

0 comments on commit 17332fc

Please sign in to comment.