Skip to content

Commit

Permalink
Add support for weighted LoRAs in the LoRA merge script.
Browse files Browse the repository at this point in the history
  • Loading branch information
RyanJDick committed May 28, 2024
1 parent c83f7cf commit 33882ca
Showing 1 changed file with 21 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def merge_lora_into_sd_model(
base_model: str,
base_model_variant: str | None,
base_model_type: PipelineVersionEnum,
lora_models: list[str],
lora_models: list[tuple[str, float]],
output: str,
save_dtype: str,
):
Expand All @@ -113,16 +113,18 @@ def merge_lora_into_sd_model(
else:
raise ValueError(f"Unexpected pipeline type: {type(pipeline)}")

for lora_model_path in lora_models:
for lora_model_path, lora_model_weight in lora_models:
lora_model = LoRAModelRaw.from_checkpoint(
file_path=lora_model_path,
device=pipeline.device,
dtype=save_dtype,
base_model=to_invokeai_base_model_type(base_model_type),
)
for model, lora_prefix in zip(models, lora_prefixes, strict=True):
# TODO(ryand): Parameterize the weight.
apply_lora_model_to_base_model(base_model=model, lora=lora_model, lora_weight=1.0, prefix=lora_prefix)
apply_lora_model_to_base_model(
base_model=model, lora=lora_model, lora_weight=lora_model_weight, prefix=lora_prefix
)
logger.info(f"Applied LoRA model '{lora_model_path}' with weight {lora_model_weight}.")

output_path = Path(output)
output_path.mkdir(parents=True)
Expand All @@ -132,6 +134,17 @@ def merge_lora_into_sd_model(
logger.info(f"Saved merged model to '{output_path}'.")


def parse_lora_model_arg(lora_model_arg: str) -> tuple[str, float]:
"""Parse a --lora-model argument into a tuple of the model path and weight."""
parts = lora_model_arg.split(":")
if len(parts) == 1:
return parts[0], 1.0
elif len(parts) == 2:
return parts[0], float(parts[1])
else:
raise ValueError(f"Unexpected format for --lora-model arg: '{lora_model_arg}'.")


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
Expand All @@ -157,7 +170,9 @@ def main():
"--lora-model",
type=str,
nargs="+",
help="The path(s) to one or more LoRA models to merge into the base model.",
help="The path(s) to one or more LoRA models to merge into the base model. Model weights can be appended to "
"the path, separated by a colon (':'). E.g. 'path/to/lora_model:0.5'. The weight is optional and defaults to "
"1.0.",
required=True,
)
parser.add_argument(
Expand All @@ -182,7 +197,7 @@ def main():
base_model=args.base_model,
base_model_variant=args.base_model_variant,
base_model_type=PipelineVersionEnum(args.base_model_type),
lora_models=args.lora_model,
lora_models=[parse_lora_model_arg(arg) for arg in args.lora_model],
output=args.output,
save_dtype=args.save_dtype,
)
Expand Down

0 comments on commit 33882ca

Please sign in to comment.