From c83f7cf86999d9a3693c62bd87ace5c750e96c8f Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Tue, 28 May 2024 14:41:54 -0400
Subject: [PATCH] Get LoRA merge script working with InvokeAI LoRA loading.

---
 .../lora_merge/merge_lora_into_sd_model.py | 99 ++++++++++++++++---
 1 file changed, 87 insertions(+), 12 deletions(-)

diff --git a/src/invoke_training/scripts/_experimental/lora_merge/merge_lora_into_sd_model.py b/src/invoke_training/scripts/_experimental/lora_merge/merge_lora_into_sd_model.py
index ab6d777c..4c3a4b28 100644
--- a/src/invoke_training/scripts/_experimental/lora_merge/merge_lora_into_sd_model.py
+++ b/src/invoke_training/scripts/_experimental/lora_merge/merge_lora_into_sd_model.py
@@ -1,9 +1,17 @@
-import argparse
+import argparse  # noqa: I001
 import logging
 from pathlib import Path
 from typing import Literal
 
 import torch
+from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
+
+# fmt: off
+# HACK(ryand): Import order matters, because invokeai contains circular imports.
+from invokeai.backend.model_manager import BaseModelType
+from invokeai.backend.lora import LoRAModelRaw
+from invokeai.backend.model_patcher import ModelPatcher
+# fmt: on
 
 from invoke_training._shared.stable_diffusion.model_loading_utils import PipelineVersionEnum, load_pipeline
 
@@ -20,6 +28,62 @@ def str_to_dtype(dtype_str: Literal["float32", "float16", "bfloat16"]):
     raise ValueError(f"Unexpected dtype: {dtype_str}")
 
 
+def to_invokeai_base_model_type(base_model_type: PipelineVersionEnum):
+    if base_model_type == PipelineVersionEnum.SD:
+        return BaseModelType.StableDiffusion1
+    elif base_model_type == PipelineVersionEnum.SDXL:
+        return BaseModelType.StableDiffusionXL
+    else:
+        raise ValueError(f"Unexpected base_model_type: {base_model_type}")
+
+
+@torch.no_grad()
+def apply_lora_model_to_base_model(
+    base_model: torch.nn.Module,
+    lora: LoRAModelRaw,
+    lora_weight: float,
+    prefix: str,
+):
+    """Apply a LoRAModelRaw model to a base model.
+
+    This implementation is based on:
+    https://github.com/invoke-ai/InvokeAI/blob/df91d1b8497e95c9520fb2f46522384220429011/invokeai/backend/model_patcher.py#L105
+
+    This function is simplified relative to the original implementation, because it does not need to support
+    unpatching.
+
+    Args:
+        base_model (torch.nn.Module): The base model to patch.
+        lora (LoRAModelRaw): The LoRA model to apply.
+        lora_weight (float): The weight with which to apply the LoRA model.
+        prefix (str): The prefix of the LoRA layers to apply to this base_model.
+    """
+    for layer_key, layer in lora.layers.items():
+        if not layer_key.startswith(prefix):
+            continue
+
+        module_key, module = ModelPatcher._resolve_lora_key(base_model, layer_key, prefix)
+
+        # All of the LoRA weight calculations will be done on the same device as the module weight.
+        device = module.weight.device
+        dtype = module.weight.dtype
+
+        layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
+
+        # We intentionally move to the target device first, then cast. Experimentally, this was found to be
+        # significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the same thing in a
+        # single call to '.to(...)'.
+        layer.to(device=device)
+        layer.to(dtype=torch.float32)
+        layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
+        layer.to(device=torch.device("cpu"))
+
+        if module.weight.shape != layer_weight.shape:
+            assert hasattr(layer_weight, "reshape")
+            layer_weight = layer_weight.reshape(module.weight.shape)
+
+        module.weight += layer_weight.to(dtype=dtype)
+
+
+@torch.no_grad()
 def merge_lora_into_sd_model(
     logger: logging.Logger,
     base_model: str,
@@ -29,7 +93,7 @@
     output: str,
     save_dtype: str,
 ):
-    pipeline = load_pipeline(
+    pipeline: StableDiffusionXLPipeline | StableDiffusionPipeline = load_pipeline(
         logger=logger, model_name_or_path=base_model, pipeline_version=base_model_type, variant=base_model_variant
     )
     save_dtype = str_to_dtype(save_dtype)
@@ -38,16 +102,27 @@
 
     pipeline.to(save_dtype)
 
-    lora_adapter_names = []
-    for i, lora_model in enumerate(lora_models):
-        lora_adapter_name = f"lora_{i}"
-        pipeline.load_lora_weights(lora_model, adapter_name=lora_adapter_name)
-        lora_adapter_names.append(lora_adapter_name)
-
-    logger.info(f"Loaded {len(lora_models)} LoRA models.")
-
-    pipeline.set_adapters(adapter_names=lora_adapter_names, adapter_weights=[1.0] * len(lora_adapter_names))
-    pipeline.fuse_lora()
+    models: list[torch.nn.Module] = []
+    lora_prefixes: list[str] = []
+    if isinstance(pipeline, StableDiffusionPipeline):
+        models = [pipeline.unet, pipeline.text_encoder]
+        lora_prefixes = ["lora_unet_", "lora_te_"]
+    elif isinstance(pipeline, StableDiffusionXLPipeline):
+        models = [pipeline.unet, pipeline.text_encoder, pipeline.text_encoder_2]
+        lora_prefixes = ["lora_unet_", "lora_te1_", "lora_te2_"]
+    else:
+        raise ValueError(f"Unexpected pipeline type: {type(pipeline)}")
+
+    for lora_model_path in lora_models:
+        lora_model = LoRAModelRaw.from_checkpoint(
+            file_path=lora_model_path,
+            device=pipeline.device,
+            dtype=save_dtype,
+            base_model=to_invokeai_base_model_type(base_model_type),
+        )
+        for model, lora_prefix in zip(models, lora_prefixes, strict=True):
+            # TODO(ryand): Parameterize the weight.
+            apply_lora_model_to_base_model(base_model=model, lora=lora_model, lora_weight=1.0, prefix=lora_prefix)
 
     output_path = Path(output)
     output_path.mkdir(parents=True)
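
For reference, the per-layer update that apply_lora_model_to_base_model performs reduces to W <- W + lora_weight * (alpha / rank) * delta_W, where delta_W is the full-rank weight returned by layer.get_weight(). Below is a minimal stand-alone sketch of that update on a single linear layer; the up/down factors, shapes, and names are illustrative assumptions, not InvokeAI's internal layer layout.

# Minimal sketch (not InvokeAI code) of the merge rule used in apply_lora_model_to_base_model:
#     W  <-  W + lora_weight * (alpha / rank) * delta_W
# where delta_W is the LoRA update reconstructed from its low-rank factors.
import torch

torch.manual_seed(0)

out_features, in_features, rank = 8, 16, 4
base_linear = torch.nn.Linear(in_features, out_features, bias=False)

# Hypothetical LoRA factors; real checkpoints store such factors per patched module.
lora_up = torch.randn(out_features, rank)
lora_down = torch.randn(rank, in_features)
alpha = 4.0
lora_weight = 1.0

with torch.no_grad():
    delta_w = lora_up @ lora_down        # reconstructed full-rank update, shape (out, in)
    scale = alpha / rank                 # same scaling as layer.alpha / layer.rank above
    base_linear.weight += lora_weight * scale * delta_w

As in the patch, accumulating in float32 and casting back to the module dtype only at the end avoids precision loss when the base weights are stored in float16/bfloat16.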
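The patched entry point can also be driven directly from Python, without the argparse CLI. Below is a usage sketch assuming an SDXL base model and a single LoRA checkpoint; the model name, paths, variant value, and dtype string are placeholders, and the keyword arguments mirror the merge_lora_into_sd_model signature shown in the hunks above.

# Usage sketch (not part of the patch). Model name, paths, and variant are placeholders.
import logging

from invoke_training._shared.stable_diffusion.model_loading_utils import PipelineVersionEnum
from invoke_training.scripts._experimental.lora_merge.merge_lora_into_sd_model import merge_lora_into_sd_model

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

merge_lora_into_sd_model(
    logger=logger,
    base_model="stabilityai/stable-diffusion-xl-base-1.0",  # placeholder: HF model name or local path
    base_model_type=PipelineVersionEnum.SDXL,
    base_model_variant=None,  # assumption: no variant (e.g. "fp16") is needed
    lora_models=["/path/to/lora.safetensors"],  # placeholder LoRA checkpoint path
    output="/path/to/merged_model",  # must not already exist; the script calls mkdir(parents=True)
    save_dtype="float16",
)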