Default to saving LoRA checkpoints in kohya format.
RyanJDick committed Jan 26, 2024
1 parent 79d7c8a commit 9834cca
Showing 7 changed files with 130 additions and 13 deletions.
@@ -23,6 +23,9 @@ class LoraAndTiTrainingConfig(BasePipelineConfig):
"""The Hugging Face Hub model variant to use. Only applies if `model` is a Hugging Face Hub model name.
"""

lora_checkpoint_format: Literal["invoke_peft", "kohya"] = "kohya"
"""The format of the LoRA checkpoint to save. Choose between `invoke_peft` or `kohya`."""

# Helpful discussion for understanding how this works at inference time:
# https://github.com/huggingface/diffusers/pull/3144#discussion_r1172413509
num_vectors: int = 1
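For orientation (not part of the commit): the new field is a plain pydantic `Literal` with `"kohya"` as the default, so an unset config now produces kohya checkpoints, and an unrecognized value fails validation. A minimal sketch using a stand-in class rather than the real `LoraAndTiTrainingConfig`:

```python
# Minimal sketch: a stand-in pydantic model mirroring the new field; the real
# config classes live in invoke_training.config.pipelines.
from typing import Literal

from pydantic import BaseModel, ValidationError


class ExampleLoraAndTiConfig(BaseModel):
    lora_checkpoint_format: Literal["invoke_peft", "kohya"] = "kohya"


print(ExampleLoraAndTiConfig().lora_checkpoint_format)  # "kohya" is now the default
print(ExampleLoraAndTiConfig(lora_checkpoint_format="invoke_peft").lora_checkpoint_format)

try:
    ExampleLoraAndTiConfig(lora_checkpoint_format="diffusers")  # rejected by the Literal
except ValidationError as err:
    print(err)
```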
3 changes: 3 additions & 0 deletions src/invoke_training/config/pipelines/finetune_lora_config.py
@@ -44,6 +44,9 @@ class LoRATrainingConfig(BasePipelineConfig):
generation time with the resultant LoRA model.
"""

lora_checkpoint_format: Literal["invoke_peft", "kohya"] = "kohya"
"""The format of the LoRA checkpoint to save. Choose between `invoke_peft` or `kohya`."""

train_unet: bool = True
"""Whether to add LoRA layers to the UNet model and train it.
"""
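The same field is added to `LoRATrainingConfig`, so both pipelines honour it. A minimal sketch of how a value parsed from a pipeline config file (e.g. YAML loaded into a dict) would reach the field; `ExampleLoraConfig` is a stand-in, not the real class:

```python
# Minimal sketch with a stand-in config class; in invoke-training the value
# would come from the pipeline's config file.
from typing import Literal

from pydantic import BaseModel


class ExampleLoraConfig(BaseModel):
    lora_checkpoint_format: Literal["invoke_peft", "kohya"] = "kohya"


raw = {"lora_checkpoint_format": "invoke_peft"}  # e.g. parsed from a YAML config
print(ExampleLoraConfig(**raw).lora_checkpoint_format)  # keeps the PEFT directory layout
```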
@@ -7,6 +7,7 @@
import tempfile
import time
from pathlib import Path
from typing import Literal

import peft
import torch
@@ -32,6 +33,7 @@
TEXT_ENCODER_TARGET_MODULES,
UNET_TARGET_MODULES,
load_sd_peft_checkpoint,
save_sd_kohya_checkpoint,
save_sd_peft_checkpoint,
)
from invoke_training.training._shared.stable_diffusion.model_loading_utils import load_models_sd
@@ -46,14 +48,20 @@ def _save_sd_lora_checkpoint(
text_encoder: peft.PeftModel | None,
logger: logging.Logger,
checkpoint_tracker: CheckpointTracker,
lora_checkpoint_format: Literal["invoke_peft", "kohya"],
):
# Prune checkpoints and get new checkpoint path.
num_pruned = checkpoint_tracker.prune(1)
if num_pruned > 0:
logger.info(f"Pruned {num_pruned} checkpoint(s).")
save_path = checkpoint_tracker.get_path(idx)

save_sd_peft_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)
if lora_checkpoint_format == "invoke_peft":
save_sd_peft_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)
elif lora_checkpoint_format == "kohya":
save_sd_kohya_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)
else:
raise ValueError(f"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.")


def train_forward_dpo( # noqa: C901
@@ -421,14 +429,14 @@ def prep_peft_model(model, lr: float | None = None):
epoch_checkpoint_tracker = CheckpointTracker(
base_dir=ckpt_dir,
prefix="checkpoint_epoch",
extension=".safetensors",
extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
max_checkpoints=config.max_checkpoints,
)

step_checkpoint_tracker = CheckpointTracker(
base_dir=ckpt_dir,
prefix="checkpoint_step",
extension=".safetensors",
extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
max_checkpoints=config.max_checkpoints,
)

@@ -513,6 +521,7 @@ def prep_peft_model(model, lr: float | None = None):
text_encoder=accelerator.unwrap_model(text_encoder) if training_text_encoder else None,
logger=logger,
checkpoint_tracker=step_checkpoint_tracker,
lora_checkpoint_format=config.lora_checkpoint_format,
)

logs = {
@@ -534,6 +543,7 @@ def prep_peft_model(model, lr: float | None = None):
text_encoder=accelerator.unwrap_model(text_encoder) if training_text_encoder else None,
logger=logger,
checkpoint_tracker=epoch_checkpoint_tracker,
lora_checkpoint_format=config.lora_checkpoint_format,
)

# Generate validation images every n epochs.
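Net effect of the changes to this training script (a hedged summary): with the kohya default, each saved checkpoint is a single `.safetensors` file instead of a PEFT checkpoint directory, which is why the `CheckpointTracker`s now set `extension=".safetensors"` only for kohya. Such a file can be inspected directly; a sketch with an illustrative path:

```python
# Hedged sketch: inspect a saved kohya-format LoRA checkpoint. The path is
# illustrative; the key layout (lora_down/lora_up/alpha) is the usual
# kohya-ss convention targeted by the converter in lora_checkpoint_utils.
from safetensors.torch import load_file

state_dict = load_file("output/ckpt_dir/checkpoint_step-00002000.safetensors")
for key in sorted(state_dict)[:6]:
    print(key, tuple(state_dict[key].shape))
# Typical key shapes:
#   lora_unet_<module_path>.alpha
#   lora_unet_<module_path>.lora_down.weight
#   lora_unet_<module_path>.lora_up.weight
#   lora_te_<module_path>.lora_down.weight ...
```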
@@ -191,6 +191,66 @@ def _convert_peft_state_dict_to_kohya_state_dict(
return kohya_ss_state_dict


def _convert_peft_models_to_kohya_state_dict(
kohya_prefixes: list[str], models: list[peft.PeftModel]
) -> dict[str, torch.Tensor]:
kohya_state_dict = {}
default_adapter_name = "default"

for kohya_prefix, peft_model in zip(kohya_prefixes, models, strict=True):
lora_config = peft_model.peft_config[default_adapter_name]
assert isinstance(lora_config, peft.LoraConfig)

peft_state_dict = peft.get_peft_model_state_dict(peft_model, adapter_name=default_adapter_name)

kohya_state_dict.update(
_convert_peft_state_dict_to_kohya_state_dict(
lora_config=lora_config,
peft_state_dict=peft_state_dict,
prefix=kohya_prefix,
dtype=torch.float32,
)
)

return kohya_state_dict


def save_sd_kohya_checkpoint(checkpoint_path: Path, unet: peft.PeftModel | None, text_encoder: peft.PeftModel | None):
kohya_prefixes = []
models = []
for kohya_prefix, peft_model in zip([SD_KOHYA_UNET_KEY, SD_KOHYA_TEXT_ENCODER_KEY], [unet, text_encoder]):
if peft_model is not None:
kohya_prefixes.append(kohya_prefix)
models.append(peft_model)

kohya_state_dict = _convert_peft_models_to_kohya_state_dict(kohya_prefixes=kohya_prefixes, models=models)

checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
save_state_dict(kohya_state_dict, checkpoint_path)


def save_sdxl_kohya_checkpoint(
checkpoint_path: Path,
unet: peft.PeftModel | None,
text_encoder_1: peft.PeftModel | None,
text_encoder_2: peft.PeftModel | None,
):
kohya_prefixes = []
models = []
for kohya_prefix, peft_model in zip(
[SDXL_KOHYA_UNET_KEY, SDXL_KOHYA_TEXT_ENCODER_1_KEY, SDXL_KOHYA_TEXT_ENCODER_2_KEY],
[unet, text_encoder_1, text_encoder_2],
):
if peft_model is not None:
kohya_prefixes.append(kohya_prefix)
models.append(peft_model)

kohya_state_dict = _convert_peft_models_to_kohya_state_dict(kohya_prefixes=kohya_prefixes, models=models)

checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
save_state_dict(kohya_state_dict, checkpoint_path)


def convert_sd_peft_checkpoint_to_kohya_state_dict(
in_checkpoint_dir: Path,
out_checkpoint_file: Path,
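For context on what `_convert_peft_models_to_kohya_state_dict` consumes (a runnable sketch on a toy module, not the real UNet or text encoder): `peft.get_peft_model_state_dict` yields `lora_A`/`lora_B` keys, which the existing `_convert_peft_state_dict_to_kohya_state_dict` then maps onto kohya's `lora_down`/`lora_up`/`alpha` names under the given prefix.

```python
# Toy example: the PEFT-side state dict that the kohya converter starts from.
# TinyBlock stands in for a UNet/text-encoder submodule; "to_q" mimics an
# attention projection like those in the real target-module lists.
import torch
import peft


class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.to_q = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.to_q(x)


lora_config = peft.LoraConfig(r=4, lora_alpha=4, target_modules=["to_q"])
peft_model = peft.get_peft_model(TinyBlock(), lora_config)

peft_state_dict = peft.get_peft_model_state_dict(peft_model, adapter_name="default")
print(list(peft_state_dict))
# e.g. ['base_model.model.to_q.lora_A.weight', 'base_model.model.to_q.lora_B.weight']
```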
@@ -6,7 +6,7 @@
import tempfile
import time
from pathlib import Path
from typing import Optional, Union
from typing import Literal, Optional, Union

import peft
import torch
@@ -39,6 +39,7 @@
from invoke_training.training._shared.stable_diffusion.lora_checkpoint_utils import (
TEXT_ENCODER_TARGET_MODULES,
UNET_TARGET_MODULES,
save_sd_kohya_checkpoint,
save_sd_peft_checkpoint,
)
from invoke_training.training._shared.stable_diffusion.model_loading_utils import load_models_sd
@@ -52,14 +53,20 @@ def _save_sd_lora_checkpoint(
text_encoder: peft.PeftModel | None,
logger: logging.Logger,
checkpoint_tracker: CheckpointTracker,
lora_checkpoint_format: Literal["invoke_peft", "kohya"],
):
# Prune checkpoints and get new checkpoint path.
num_pruned = checkpoint_tracker.prune(1)
if num_pruned > 0:
logger.info(f"Pruned {num_pruned} checkpoint(s).")
save_path = checkpoint_tracker.get_path(idx)

save_sd_peft_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)
if lora_checkpoint_format == "invoke_peft":
save_sd_peft_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)
elif lora_checkpoint_format == "kohya":
save_sd_kohya_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)
else:
raise ValueError(f"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.")


def _build_data_loader(
@@ -437,12 +444,14 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
base_dir=ckpt_dir,
prefix="checkpoint_epoch",
max_checkpoints=config.max_checkpoints,
extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
)

step_checkpoint_tracker = CheckpointTracker(
base_dir=ckpt_dir,
prefix="checkpoint_step",
max_checkpoints=config.max_checkpoints,
extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
)

# Train!
@@ -524,6 +533,7 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
text_encoder=accelerator.unwrap_model(text_encoder) if config.train_text_encoder else None,
logger=logger,
checkpoint_tracker=step_checkpoint_tracker,
lora_checkpoint_format=config.lora_checkpoint_format,
)

logs = {
@@ -545,6 +555,7 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
text_encoder=accelerator.unwrap_model(text_encoder) if config.train_text_encoder else None,
logger=logger,
checkpoint_tracker=epoch_checkpoint_tracker,
lora_checkpoint_format=config.lora_checkpoint_format,
)

# Generate validation images every n epochs.
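The SD LoRA script gets the same treatment: the format is threaded into `_save_sd_lora_checkpoint`, and the checkpoint trackers switch their file extension with it. A tiny sketch of that conditional (illustrative names only; the real paths come from `CheckpointTracker.get_path`):

```python
# Illustrative only: the two formats produce different kinds of artifacts.
lora_checkpoint_format = "kohya"
extension = ".safetensors" if lora_checkpoint_format == "kohya" else None

# kohya       -> e.g. checkpoint_epoch-00000010.safetensors  (one file per checkpoint)
# invoke_peft -> e.g. checkpoint_epoch-00000010/              (a PEFT checkpoint directory)
print(extension)
```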
@@ -5,6 +5,7 @@
import os
import time
from pathlib import Path
from typing import Literal

import peft
import torch
@@ -32,6 +33,7 @@
from invoke_training.training._shared.stable_diffusion.lora_checkpoint_utils import (
TEXT_ENCODER_TARGET_MODULES,
UNET_TARGET_MODULES,
save_sdxl_kohya_checkpoint,
save_sdxl_peft_checkpoint,
)
from invoke_training.training._shared.stable_diffusion.model_loading_utils import (
@@ -54,19 +56,30 @@ def _save_sdxl_lora_and_ti_checkpoint(
accelerator: Accelerator,
logger: logging.Logger,
checkpoint_tracker: CheckpointTracker,
lora_checkpoint_format: Literal["invoke_peft", "kohya"],
):
# Prune checkpoints and get new checkpoint path.
num_pruned = checkpoint_tracker.prune(1)
if num_pruned > 0:
logger.info(f"Pruned {num_pruned} checkpoint(s).")
save_path = checkpoint_tracker.get_path(idx)

save_sdxl_peft_checkpoint(
Path(save_path),
unet=unet if config.train_unet else None,
text_encoder_1=text_encoder_1 if config.train_text_encoder else None,
text_encoder_2=text_encoder_2 if config.train_text_encoder else None,
)
if lora_checkpoint_format == "invoke_peft":
save_sdxl_peft_checkpoint(
Path(save_path),
unet=unet if config.train_unet else None,
text_encoder_1=text_encoder_1 if config.train_text_encoder else None,
text_encoder_2=text_encoder_2 if config.train_text_encoder else None,
)
elif lora_checkpoint_format == "kohya":
save_sdxl_kohya_checkpoint(
Path(save_path) / "lora.safetensors",
unet=unet if config.train_unet else None,
text_encoder_1=text_encoder_1 if config.train_text_encoder else None,
text_encoder_2=text_encoder_2 if config.train_text_encoder else None,
)
else:
raise ValueError(f"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.")

if config.train_ti:
ti_checkpoint_path = Path(save_path) / "embeddings.safetensors"
@@ -487,6 +500,7 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float) -> peft.P
accelerator=accelerator,
logger=logger,
checkpoint_tracker=step_checkpoint_tracker,
lora_checkpoint_format=config.lora_checkpoint_format,
)

logs = {
@@ -512,6 +526,7 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float) -> peft.P
accelerator=accelerator,
logger=logger,
checkpoint_tracker=epoch_checkpoint_tracker,
lora_checkpoint_format=config.lora_checkpoint_format,
)
accelerator.wait_for_everyone()

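Note that the SDXL LoRA + TI case differs from the pure-LoRA scripts: the kohya LoRA weights are written to `lora.safetensors` inside the checkpoint directory, next to the `embeddings.safetensors` file used for textual inversion. A hedged sketch of the expected contents (directory name is illustrative):

```python
# Hedged sketch: listing a LoRA + TI checkpoint directory saved with
# lora_checkpoint_format="kohya"; the directory name is illustrative.
from pathlib import Path

ckpt_dir = Path("output/ckpt_dir/checkpoint_step-00002000")
for f in sorted(ckpt_dir.glob("*.safetensors")):
    print(f.name)
# Expected:
#   embeddings.safetensors  (textual-inversion embeddings, when train_ti is enabled)
#   lora.safetensors        (kohya-format LoRA weights, per this diff)
```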
@@ -6,7 +6,7 @@
import tempfile
import time
from pathlib import Path
from typing import Optional, Union
from typing import Literal, Optional, Union

import peft
import torch
@@ -43,6 +43,7 @@
from invoke_training.training._shared.stable_diffusion.lora_checkpoint_utils import (
TEXT_ENCODER_TARGET_MODULES,
UNET_TARGET_MODULES,
save_sdxl_kohya_checkpoint,
save_sdxl_peft_checkpoint,
)
from invoke_training.training._shared.stable_diffusion.model_loading_utils import (
@@ -60,14 +61,24 @@ def _save_sdxl_lora_checkpoint(
text_encoder_2: peft.PeftModel | None,
logger: logging.Logger,
checkpoint_tracker: CheckpointTracker,
lora_checkpoint_format: Literal["invoke_peft", "kohya"],
):
# Prune checkpoints and get new checkpoint path.
num_pruned = checkpoint_tracker.prune(1)
if num_pruned > 0:
logger.info(f"Pruned {num_pruned} checkpoint(s).")
save_path = checkpoint_tracker.get_path(idx)

save_sdxl_peft_checkpoint(Path(save_path), unet=unet, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2)
if lora_checkpoint_format == "invoke_peft":
save_sdxl_peft_checkpoint(
Path(save_path), unet=unet, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2
)
elif lora_checkpoint_format == "kohya":
save_sdxl_kohya_checkpoint(
Path(save_path), unet=unet, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2
)
else:
raise ValueError(f"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.")


def _build_data_loader(
@@ -528,12 +539,14 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
base_dir=ckpt_dir,
prefix="checkpoint_epoch",
max_checkpoints=config.max_checkpoints,
extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
)

step_checkpoint_tracker = CheckpointTracker(
base_dir=ckpt_dir,
prefix="checkpoint_step",
max_checkpoints=config.max_checkpoints,
extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
)

# Train!
@@ -620,6 +633,7 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
text_encoder_2=text_encoder_2,
logger=logger,
checkpoint_tracker=step_checkpoint_tracker,
lora_checkpoint_format=config.lora_checkpoint_format,
)

logs = {
@@ -641,6 +655,7 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
text_encoder_2=text_encoder_2,
logger=logger,
checkpoint_tracker=epoch_checkpoint_tracker,
lora_checkpoint_format=config.lora_checkpoint_format,
)
accelerator.wait_for_everyone()

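One practical upside of the kohya default (a hedged aside, not from the diff): the single-file checkpoints are the format most Stable Diffusion tooling expects, and diffusers can typically load them directly. A sketch with illustrative model and file names:

```python
# Hedged sketch: loading a kohya-format SDXL LoRA checkpoint with diffusers.
# The model ID, checkpoint directory, and weight_name are illustrative.
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.load_lora_weights("output/ckpt_dir", weight_name="checkpoint_step-00002000.safetensors")
image = pipe("a photo in the trained style").images[0]
```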
