From 33882ca19a6e7ad7e6c22592d27723c076b691eb Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 28 May 2024 15:11:35 -0400 Subject: [PATCH] Add support for weighted LoRAs in the LoRA merge script. --- .../lora_merge/merge_lora_into_sd_model.py | 27 ++++++++++++++----- 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/src/invoke_training/scripts/_experimental/lora_merge/merge_lora_into_sd_model.py b/src/invoke_training/scripts/_experimental/lora_merge/merge_lora_into_sd_model.py index 4c3a4b28..d1cc27c7 100644 --- a/src/invoke_training/scripts/_experimental/lora_merge/merge_lora_into_sd_model.py +++ b/src/invoke_training/scripts/_experimental/lora_merge/merge_lora_into_sd_model.py @@ -89,7 +89,7 @@ def merge_lora_into_sd_model( base_model: str, base_model_variant: str | None, base_model_type: PipelineVersionEnum, - lora_models: list[str], + lora_models: list[tuple[str, float]], output: str, save_dtype: str, ): @@ -113,7 +113,7 @@ def merge_lora_into_sd_model( else: raise ValueError(f"Unexpected pipeline type: {type(pipeline)}") - for lora_model_path in lora_models: + for lora_model_path, lora_model_weight in lora_models: lora_model = LoRAModelRaw.from_checkpoint( file_path=lora_model_path, device=pipeline.device, @@ -121,8 +121,10 @@ def merge_lora_into_sd_model( base_model=to_invokeai_base_model_type(base_model_type), ) for model, lora_prefix in zip(models, lora_prefixes, strict=True): - # TODO(ryand): Parameterize the weight. - apply_lora_model_to_base_model(base_model=model, lora=lora_model, lora_weight=1.0, prefix=lora_prefix) + apply_lora_model_to_base_model( + base_model=model, lora=lora_model, lora_weight=lora_model_weight, prefix=lora_prefix + ) + logger.info(f"Applied LoRA model '{lora_model_path}' with weight {lora_model_weight}.") output_path = Path(output) output_path.mkdir(parents=True) @@ -132,6 +134,17 @@ def merge_lora_into_sd_model( logger.info(f"Saved merged model to '{output_path}'.") +def parse_lora_model_arg(lora_model_arg: str) -> tuple[str, float]: + """Parse a --lora-model argument into a tuple of the model path and weight.""" + parts = lora_model_arg.split(":") + if len(parts) == 1: + return parts[0], 1.0 + elif len(parts) == 2: + return parts[0], float(parts[1]) + else: + raise ValueError(f"Unexpected format for --lora-model arg: '{lora_model_arg}'.") + + def main(): parser = argparse.ArgumentParser() parser.add_argument( @@ -157,7 +170,9 @@ def main(): "--lora-model", type=str, nargs="+", - help="The path(s) to one or more LoRA models to merge into the base model.", + help="The path(s) to one or more LoRA models to merge into the base model. Model weights can be appended to " + "the path, separated by a colon (':'). E.g. 'path/to/lora_model:0.5'. The weight is optional and defaults to " + "1.0.", required=True, ) parser.add_argument( @@ -182,7 +197,7 @@ def main(): base_model=args.base_model, base_model_variant=args.base_model_variant, base_model_type=PipelineVersionEnum(args.base_model_type), - lora_models=args.lora_model, + lora_models=[parse_lora_model_arg(arg) for arg in args.lora_model], output=args.output, save_dtype=args.save_dtype, )