Get LoRA merge script working with InvokeAI LoRA loading.
RyanJDick committed May 28, 2024
1 parent dbc13ae commit c83f7cf
Showing 1 changed file with 87 additions and 12 deletions.
@@ -1,9 +1,17 @@
import argparse
import argparse # noqa: I001
import logging
from pathlib import Path
from typing import Literal

import torch
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline

# fmt: off
# HACK(ryand): Import order matters, because invokeai contains circular imports.
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
# fmt: on

from invoke_training._shared.stable_diffusion.model_loading_utils import PipelineVersionEnum, load_pipeline

@@ -20,6 +28,62 @@ def str_to_dtype(dtype_str: Literal["float32", "float16", "bfloat16"]):
raise ValueError(f"Unexpected dtype: {dtype_str}")


def to_invokeai_base_model_type(base_model_type: PipelineVersionEnum):
if base_model_type == PipelineVersionEnum.SD:
return BaseModelType.StableDiffusion1
elif base_model_type == PipelineVersionEnum.SDXL:
return BaseModelType.StableDiffusionXL
else:
raise ValueError(f"Unexpected base_model_type: {base_model_type}")


@torch.no_grad()
def apply_lora_model_to_base_model(
base_model: torch.nn.Module,
lora: LoRAModelRaw,
lora_weight: float,
prefix: str,
):
"""Apply a LoRAModelRaw model to a base model.
This implementation is based on:
https://github.com/invoke-ai/InvokeAI/blob/df91d1b8497e95c9520fb2f46522384220429011/invokeai/backend/model_patcher.py#L105
This function is simplified relative to the original implementation, because it does not need to support unpatching.
Args:
base_model (torch.nn.Module): The base model to patch.
lora (LoRAModelRaw): The LoRA model to apply.
lora_weight (float): The weight with which to apply the LoRA model.
prefix (str): The prefix of the LoRA layers to apply to this base_model.
"""
for layer_key, layer in lora.layers.items():
if not layer_key.startswith(prefix):
continue

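# Resolve the LoRA layer key (e.g. "lora_unet_...") to the corresponding submodule of the base model.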
module_key, module = ModelPatcher._resolve_lora_key(base_model, layer_key, prefix)

# All of the LoRA weight calculations will be done on the same device as the module weight.
device = module.weight.device
dtype = module.weight.dtype

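# Standard LoRA scaling factor: alpha / rank, falling back to 1.0 if either value is unset.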
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0

# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
layer.to(device=torch.device("cpu"))

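# get_weight() may return a delta whose shape differs from the module weight; reshape it to match if needed.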
if module.weight.shape != layer_weight.shape:
assert hasattr(layer_weight, "reshape")
layer_weight = layer_weight.reshape(module.weight.shape)

module.weight += layer_weight.to(dtype=dtype)


@torch.no_grad()
def merge_lora_into_sd_model(
logger: logging.Logger,
base_model: str,
@@ -29,7 +93,7 @@ def merge_lora_into_sd_model(
output: str,
save_dtype: str,
):
pipeline = load_pipeline(
pipeline: StableDiffusionXLPipeline | StableDiffusionPipeline = load_pipeline(
logger=logger, model_name_or_path=base_model, pipeline_version=base_model_type, variant=base_model_variant
)
save_dtype = str_to_dtype(save_dtype)
@@ -38,16 +102,27 @@ def merge_lora_into_sd_model(

pipeline.to(save_dtype)

lora_adapter_names = []
for i, lora_model in enumerate(lora_models):
lora_adapter_name = f"lora_{i}"
pipeline.load_lora_weights(lora_model, adapter_name=lora_adapter_name)
lora_adapter_names.append(lora_adapter_name)

logger.info(f"Loaded {len(lora_models)} LoRA models.")

pipeline.set_adapters(adapter_names=lora_adapter_names, adapter_weights=[1.0] * len(lora_adapter_names))
pipeline.fuse_lora()
models: list[torch.nn.Module] = []
lora_prefixes: list[str] = []
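# Select the sub-models to patch and the LoRA key prefix used for each, based on the pipeline type.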
if isinstance(pipeline, StableDiffusionPipeline):
models = [pipeline.unet, pipeline.text_encoder]
lora_prefixes = ["lora_unet_", "lora_te_"]
elif isinstance(pipeline, StableDiffusionXLPipeline):
models = [pipeline.unet, pipeline.text_encoder, pipeline.text_encoder_2]
lora_prefixes = ["lora_unet_", "lora_te1_", "lora_te2_"]
else:
raise ValueError(f"Unexpected pipeline type: {type(pipeline)}")

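# Load each LoRA checkpoint in InvokeAI's format and merge it into every matching sub-model.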
for lora_model_path in lora_models:
lora_model = LoRAModelRaw.from_checkpoint(
file_path=lora_model_path,
device=pipeline.device,
dtype=save_dtype,
base_model=to_invokeai_base_model_type(base_model_type),
)
for model, lora_prefix in zip(models, lora_prefixes, strict=True):
# TODO(ryand): Parameterize the weight.
apply_lora_model_to_base_model(base_model=model, lora=lora_model, lora_weight=1.0, prefix=lora_prefix)

output_path = Path(output)
output_path.mkdir(parents=True)
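
For reference, the merge performed by apply_lora_model_to_base_model reduces to W <- W + lora_weight * (alpha / rank) * delta_W, where delta_W is the layer's LoRA delta. Below is a minimal, self-contained sketch of that same update for a plain torch.nn.Linear layer; the up/down tensors and the helper name are illustrative assumptions, not InvokeAI's API or part of this commit.

import torch


def merge_lora_into_linear(
    linear: torch.nn.Linear,
    up: torch.Tensor,  # (out_features, rank)
    down: torch.Tensor,  # (rank, in_features)
    alpha: float,
    lora_weight: float = 1.0,
) -> None:
    """Merge a single LoRA factorization (delta = up @ down) into a Linear layer, in place."""
    rank = down.shape[0]
    scale = alpha / rank if (alpha and rank) else 1.0
    # Compute the delta in float32 for accuracy, then cast back to the module's dtype.
    delta = (up.float() @ down.float()) * (lora_weight * scale)
    with torch.no_grad():
        linear.weight += delta.to(dtype=linear.weight.dtype, device=linear.weight.device)


# Example usage with random tensors:
# linear = torch.nn.Linear(16, 32)
# up, down = torch.randn(32, 4), torch.randn(4, 16)
# merge_lora_into_linear(linear, up, down, alpha=4.0)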