Refactor extract_lora_from_checkpoint.py to support loading from both diffusers and checkpoint files, and to support unet, te1, and te2 submodels.
RyanJDick committed Jun 3, 2024
1 parent 31d71c0 commit 55c9c3a
Showing 1 changed file with 228 additions and 65 deletions.
293 changes: 228 additions & 65 deletions src/invoke_training/model_merge/scripts/extract_lora_from_checkpoint.py
@@ -2,28 +2,92 @@
# https://raw.githubusercontent.com/kohya-ss/sd-scripts/bfb352bc433326a77aca3124248331eb60c49e8c/networks/extract_lora_from_models.py
# That script was originally based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py


import argparse
import logging
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Literal

import peft
import torch
from diffusers import UNet2DConditionModel
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection

from invoke_training._shared.accelerator.accelerator_utils import get_dtype_from_str
from invoke_training._shared.stable_diffusion.lora_checkpoint_utils import (
UNET_TARGET_MODULES,
save_sdxl_kohya_checkpoint,
)
from invoke_training._shared.stable_diffusion.model_loading_utils import PipelineVersionEnum, load_pipeline
from invoke_training.model_merge.extract_lora import (
PEFT_BASE_LAYER_PREFIX,
extract_lora_from_diffs,
get_patched_base_weights_from_peft_model,
get_state_dict_diff,
)
from invoke_training.model_merge.utils.parse_model_arg import parse_model_arg


@dataclass
class StableDiffusionModel:
"""A helper class to store the submodels of a SD model that we are interested in for LoRA extraction."""

unet: UNet2DConditionModel | None = None
# TODO(ryand): Figure out the actual type of these text encoders.
text_encoder: CLIPTextModel | None = None
text_encoder_2: CLIPTextModelWithProjection | None = None

def all_none(self) -> bool:
return self.unet is None and self.text_encoder is None and self.text_encoder_2 is None


def load_model(
logger: logging.Logger,
model_name_or_path: str,
model_type: PipelineVersionEnum,
variant: str | None,
dtype: torch.dtype,
) -> StableDiffusionModel:
pipeline: StableDiffusionPipeline | StableDiffusionXLPipeline | None = None
try:
pipeline = load_pipeline(
logger=logger,
model_name_or_path=model_name_or_path,
pipeline_version=model_type,
torch_dtype=dtype,
variant=variant,
)
except Exception as e:
logger.info(f"Failed to load full SD/SDXL model '{model_name_or_path}': {e}.")

if pipeline is not None:
return StableDiffusionModel(
unet=pipeline.unet, text_encoder=pipeline.text_encoder, text_encoder_2=pipeline.text_encoder_2
)

# Failed to load a full pipeline. Try to load the submodels.
logger.info("Attempting to load the available submodels.")
sd_model = StableDiffusionModel()
base_model_path = Path(model_name_or_path)
for submodel_name, submodel_class in [
("unet", UNet2DConditionModel),
("text_encoder", CLIPTextModel),
("text_encoder_2", CLIPTextModelWithProjection),
]:
try:
# TODO(ryand): Add variant fallbacks?
submodel = submodel_class.from_pretrained(
base_model_path / submodel_name, variant=variant, torch_dtype=dtype, local_files_only=True
)
setattr(sd_model, submodel_name, submodel)
except Exception:
logger.info(f"Failed to load '{submodel_name}' from '{model_name_or_path}'.")

if sd_model.all_none():
raise RuntimeError(f"Failed to load any submodels from '{model_name_or_path}'.")

return sd_model
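# A hypothetical usage sketch (not part of this commit): loading a diffusers-format
# directory that contains only a UNet. load_model() falls back to per-submodel loading
# and leaves the missing text encoders as None. The path is illustrative:
#
#   sd_model = load_model(
#       logger=logging.getLogger(__name__),
#       model_name_or_path="/models/my_unet_only_model",
#       model_type=PipelineVersionEnum("SDXL"),
#       variant=None,
#       dtype=torch.float32,
#   )
#   assert sd_model.unet is not None
#   assert sd_model.text_encoder is None and sd_model.text_encoder_2 is None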


def str_to_device(device_str: Literal["cuda", "cpu"]) -> torch.device:
@@ -35,59 +99,40 @@ def str_to_device(device_str: Literal["cuda", "cpu"]) -> torch.device:
raise ValueError(f"Unexpected device: {device_str}")


def load_sdxl_unet(model_path: str) -> UNet2DConditionModel:
variants_to_try = [None, "fp16"]
unet = None
for variant in variants_to_try:
try:
unet = UNet2DConditionModel.from_pretrained(model_path, variant=variant, local_files_only=True)
except OSError as e:
if "no file named" in str(e):
# Ok. We'll try a different variant.
pass
else:
raise
if unet is None:
raise RuntimeError(f"Failed to load UNet from '{model_path}'.")
return unet
# TODO(ryand): Delete this after integrating the variant fallback logic.
# def load_sdxl_unet(model_path: str) -> UNet2DConditionModel:
# variants_to_try = [None, "fp16"]
# unet = None
# for variant in variants_to_try:
# try:
# unet = UNet2DConditionModel.from_pretrained(model_path, variant=variant, local_files_only=True)
# except OSError as e:
# if "no file named" in str(e):
# # Ok. We'll try a different variant.
# pass
# else:
# raise
# if unet is None:
# raise RuntimeError(f"Failed to load UNet from '{model_path}'.")
# return unet


def state_dict_to_device(state_dict: dict[str, torch.Tensor], device: torch.device) -> dict[str, torch.Tensor]:
return {k: v.to(device=device) for k, v in state_dict.items()}


@torch.no_grad()
def extract_lora(
def extract_lora_from_submodel(
logger: logging.Logger,
model_type: Literal["sd1", "sdxl"],
model_orig_path: str,
model_tuned_path: str,
save_to: str,
load_precision: Literal["float32", "float16", "bfloat16"],
save_precision: Literal["float32", "float16", "bfloat16"],
device: Literal["cuda", "cpu"],
model_orig: torch.nn.Module,
model_tuned: torch.nn.Module,
device: torch.device,
out_dtype: torch.dtype,
lora_rank: int,
clamp_quantile=0.99,
):
load_dtype = get_dtype_from_str(load_precision)
save_dtype = get_dtype_from_str(save_precision)
device = str_to_device(device)

# Load models.
if model_type == "sd1":
raise NotImplementedError("SD1 support is not yet implemented.")
elif model_type == "sdxl":
logger.info(f"Loading original SDXL model: '{model_orig_path}'.")
unet_orig = load_sdxl_unet(model_orig_path)
logger.info(f"Loading tuned SDXL model: '{model_tuned_path}'.")
unet_tuned = load_sdxl_unet(model_tuned_path)

if load_dtype is not None:
unet_orig = unet_orig.to(load_dtype)
unet_tuned = unet_tuned.to(load_dtype)
else:
raise ValueError(f"Unexpected model type: '{model_type}'.")

clamp_quantile: float = 0.99,
) -> peft.PeftModel:
"""Extract LoRA weights from the diff between model_orig and model_tuned. Returns a new model_orig, wrapped in a
PeftModel, with the LoRA weights applied.
"""
# Apply LoRA to the submodel.
# The only reason we do this is to get the module names for the weights that we'll extract. We don't actually use
# the LoRA weights initialized here.
@@ -98,65 +143,177 @@ def extract_lora(
lora_alpha=lora_rank,
target_modules=UNET_TARGET_MODULES,
)
unet_tuned = peft.get_peft_model(unet_tuned, unet_lora_config)
unet_orig = peft.get_peft_model(unet_orig, unet_lora_config)
model_tuned = peft.get_peft_model(model_tuned, unet_lora_config)
model_orig = peft.get_peft_model(model_orig, unet_lora_config)

unet_tuned_base_weights = get_patched_base_weights_from_peft_model(unet_tuned)
unet_orig_base_weights = get_patched_base_weights_from_peft_model(unet_orig)
base_weights_tuned = get_patched_base_weights_from_peft_model(model_tuned)
base_weights_orig = get_patched_base_weights_from_peft_model(model_orig)

diffs = get_state_dict_diff(unet_tuned_base_weights, unet_orig_base_weights)
diffs = get_state_dict_diff(base_weights_tuned, base_weights_orig)

# Clear tuned UNet to save memory.
# Clear tuned model to save memory.
# TODO(ryand): We also need to clear the state_dicts. Move the diff extraction to a separate function so that memory
# cleanup is handled by scoping.
del unet_tuned
del model_tuned

# Apply SVD (Singular Value Decomposition) to the diffs.
# We only use the requested device for this calculation, since it is slow, then move the results back to the CPU.
logger.info("Calculating LoRA weights with SVD.")
diffs = state_dict_to_device(diffs, device)
lora_weights = extract_lora_from_diffs(
diffs=diffs, rank=lora_rank, clamp_quantile=clamp_quantile, out_dtype=save_dtype
diffs=diffs, rank=lora_rank, clamp_quantile=clamp_quantile, out_dtype=out_dtype
)

# Prepare state dict for LoRA.
lora_state_dict = {}
for module_name, (lora_up, lora_down) in lora_weights.items():
lora_state_dict[PEFT_BASE_LAYER_PREFIX + module_name + ".lora_A.default.weight"] = lora_down
lora_state_dict[PEFT_BASE_LAYER_PREFIX + module_name + ".lora_B.default.weight"] = lora_up
# TODO(ryand): Double-check that this isn't needed with peft.
# The alpha value is set once globally in the PEFT model, so no need to set it for each module.
# lora_state_dict[peft_base_layer_suffix + module_name + ".alpha"] = torch.tensor(down_weight.size()[0])

lora_state_dict = state_dict_to_device(lora_state_dict, torch.device("cpu"))

# Load the state_dict into the LoRA model.
unet_orig.load_state_dict(lora_state_dict, strict=False, assign=True)
model_orig.load_state_dict(lora_state_dict, strict=False, assign=True)

return model_orig
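# A minimal sketch (an assumption, not this repo's code) of the rank-r truncated SVD
# that extract_lora_from_diffs is expected to apply to each 2D weight diff. The real
# implementation lives in invoke_training.model_merge.extract_lora and may differ in
# details (e.g. conv-weight reshaping):
#
#   U, S, Vh = torch.linalg.svd(diff.float())
#   U, S, Vh = U[:, :rank], S[:rank], Vh[:rank, :]
#   U = U @ torch.diag(S)  # fold the singular values into the up-projection
#   dist = torch.cat([U.flatten(), Vh.flatten()])
#   clamp_val = torch.quantile(dist.abs(), clamp_quantile)
#   lora_up = U.clamp(-clamp_val, clamp_val).to(out_dtype)     # lora_B weight
#   lora_down = Vh.clamp(-clamp_val, clamp_val).to(out_dtype)  # lora_A weight
#   # diff is approximated by lora_up @ lora_down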


@torch.no_grad()
def extract_lora(
logger: logging.Logger,
model_type: PipelineVersionEnum,
orig_model_name_or_path: str,
orig_model_variant: str | None,
tuned_model_name_or_path: str,
tuned_model_variant: str | None,
save_to: str,
load_precision: Literal["float32", "float16", "bfloat16"],
save_precision: Literal["float32", "float16", "bfloat16"],
device: Literal["cuda", "cpu"],
lora_rank: int,
clamp_quantile: float = 0.99,
):
load_dtype = get_dtype_from_str(load_precision)
save_dtype = get_dtype_from_str(save_precision)
device = str_to_device(device)

# Load models.
# if model_type == "sd1":
# raise NotImplementedError("SD1 support is not yet implemented.")
# elif model_type == "sdxl":
# logger.info(f"Loading original SDXL model: '{model_orig_path}'.")
# unet_orig = load_sdxl_unet(model_orig_path)
# logger.info(f"Loading tuned SDXL model: '{model_tuned_path}'.")
# unet_tuned = load_sdxl_unet(model_tuned_path)

# if load_dtype is not None:
# unet_orig = unet_orig.to(load_dtype)
# unet_tuned = unet_tuned.to(load_dtype)
# else:
# raise ValueError(f"Unexpected model type: '{model_type}'.")

orig_model = load_model(
logger=logger,
model_name_or_path=orig_model_name_or_path,
model_type=model_type,
dtype=load_dtype,
variant=orig_model_variant,
)
tuned_model = load_model(
logger=logger,
model_name_or_path=tuned_model_name_or_path,
model_type=model_type,
dtype=load_dtype,
variant=tuned_model_variant,
)

# TODO(ryand): Consolidate these calls to extract_lora_from_submodel.
unet_orig_with_lora = None
if orig_model.unet is not None and tuned_model.unet is not None:
logger.info("Extracting LoRA from UNet.")
unet_orig_with_lora = extract_lora_from_submodel(
logger=logger,
model_orig=orig_model.unet,
model_tuned=tuned_model.unet,
device=device,
out_dtype=save_dtype,
lora_rank=lora_rank,
clamp_quantile=clamp_quantile,
)

text_encoder_orig_with_lora = None
if orig_model.text_encoder is not None and tuned_model.text_encoder is not None:
logger.info("Extracting LoRA from text encoder.")
text_encoder_orig_with_lora = extract_lora_from_submodel(
logger=logger,
model_orig=orig_model.text_encoder,
model_tuned=tuned_model.text_encoder,
device=device,
out_dtype=save_dtype,
lora_rank=lora_rank,
clamp_quantile=clamp_quantile,
)

text_encoder_2_orig_with_lora = None
if orig_model.text_encoder_2 is not None and tuned_model.text_encoder_2 is not None:
logger.info("Extracting LoRA from text encoder 2.")
text_encoder_2_orig_with_lora = extract_lora_from_submodel(
logger=logger,
model_orig=orig_model.text_encoder_2,
model_tuned=tuned_model.text_encoder_2,
device=device,
out_dtype=save_dtype,
lora_rank=lora_rank,
clamp_quantile=clamp_quantile,
)

# Save the LoRA weights.
save_to_path = Path(save_to)
assert save_to_path.suffix == ".safetensors"
if save_to_path.exists():
raise FileExistsError(f"Destination file already exists: '{save_to}'.")
save_to_path.parent.mkdir(parents=True, exist_ok=True)
save_sdxl_kohya_checkpoint(save_to_path, unet=unet_orig, text_encoder_1=None, text_encoder_2=None)
save_sdxl_kohya_checkpoint(
save_to_path,
unet=unet_orig_with_lora,
text_encoder_1=text_encoder_orig_with_lora,
text_encoder_2=text_encoder_2_orig_with_lora,
)

logger.info(f"Saved LoRA weights to: {save_to_path}")


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model-type", type=str, required=True, choices=["sd1", "sdxl"], help="The base model type.")

parser.add_argument(
"--model-type",
type=str,
required=True,
choices=["SD", "SDXL"],
help="The base model type ['SD', 'SDXL'].",
)
parser.add_argument(
"--model-orig",
type=str,
required=True,
help="Path to the original model. (This should be a unet directory in diffusers format.)",
help="Path or HF Hub name of the original model. The model must be in one of the following formats: "
"1) a single checkpoint file (e.g. '.safetensors') containing all submodels, "
"2) a model in diffusers format containing all submodels, "
"or 3) a model in diffusers format containing a subset of the submodels (e.g. only a UNet)."
"An HF variant can optionally be appended to the model name after a double-colon delimiter ('::')."
"E.g. '--model-orig runwayml/stable-diffusion-v1-5::fp16'",
)
parser.add_argument(
"--model-tuned",
type=str,
required=True,
help="Path to the tuned model. (This should be a unet directory in diffusers format.)",
help="Path or HF Hub name of the tuned model. The model must be in one of the following formats: "
"1) a single checkpoint file (e.g. '.safetensors') containing all submodels, "
"2) a model in diffusers format containing all submodels, "
"or 3) a model in diffusers format containing a subset of the submodels (e.g. only a UNet)."
"An HF variant can optionally be appended to the model name after a double-colon delimiter ('::')."
"E.g. '--model-orig runwayml/stable-diffusion-v1-5::fp16'",
)
parser.add_argument(
"--save-to",
@@ -190,11 +347,17 @@ def main():
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

orig_model_name_or_path, orig_model_variant = parse_model_arg(args.model_orig)
tuned_model_name_or_path, tuned_model_variant = parse_model_arg(args.model_tuned)
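# A minimal sketch (an assumption) of the '::' parsing performed by parse_model_arg;
# the real helper is invoke_training.model_merge.utils.parse_model_arg:
#
#   def parse_model_arg(arg: str) -> tuple[str, str | None]:
#       name, sep, variant = arg.partition("::")
#       return name, (variant if sep else None)
#
#   parse_model_arg("runwayml/stable-diffusion-v1-5::fp16")
#   # -> ("runwayml/stable-diffusion-v1-5", "fp16")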

extract_lora(
logger=logger,
model_type=args.model_type,
model_orig_path=args.model_orig,
model_tuned_path=args.model_tuned,
model_type=PipelineVersionEnum(args.model_type),
orig_model_name_or_path=orig_model_name_or_path,
orig_model_variant=orig_model_variant,
tuned_model_name_or_path=tuned_model_name_or_path,
tuned_model_variant=tuned_model_variant,
save_to=args.save_to,
load_precision=args.load_precision,
save_precision=args.save_precision,
