From 6e950277c97bd6b9439d5dc749ad4ce84a1a7ed0 Mon Sep 17 00:00:00 2001
From: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com>
Date: Sun, 9 Jun 2024 07:26:06 -0700
Subject: [PATCH] Revert "aider: Added all potential options for Merging and
 Extracting with progressively disclosed UI elements."

This reverts commit ce0046d15fbe160ec852c603f82c75508efea913.
---
 .../ui/pages/model_merge_page.py | 33 +++++--------------
 1 file changed, 9 insertions(+), 24 deletions(-)

diff --git a/src/invoke_training/ui/pages/model_merge_page.py b/src/invoke_training/ui/pages/model_merge_page.py
index 35ca3105..c7596574 100644
--- a/src/invoke_training/ui/pages/model_merge_page.py
+++ b/src/invoke_training/ui/pages/model_merge_page.py
@@ -15,22 +15,17 @@ def __init__(self):
         self._app = app
 
     def _create_merge_tab(self):
-        with gr.Row():
-            gr.Markdown("## Merge LoRA into SD Model")
         with gr.Row():
             gr.Markdown("## Merge LoRA into SD Model")
         with gr.Row():
             base_model = gr.Textbox(label="Base Model Path")
-            base_model_variant = gr.Textbox(label="Base Model Variant (Optional)")
-            base_model_type = gr.Dropdown(choices=["SD", "SDXL"], label="Base Model Type")
-            lora_models = gr.Textbox(label="LoRA Models (comma-separated paths with optional weights, e.g., 'path1::0.5,path2')")
+            lora_model = gr.Textbox(label="LoRA Model Path")
             output_path = gr.Textbox(label="Output Path")
-            save_dtype = gr.Dropdown(choices=["float32", "float16", "bfloat16"], label="Save Dtype")
             merge_button = gr.Button("Merge")
 
         merge_button.click(
             fn=self._merge_lora_into_sd_model,
-            inputs=[base_model, base_model_variant, base_model_type, lora_models, output_path, save_dtype],
+            inputs=[base_model, lora_model, output_path],
             outputs=[]
         )
 
@@ -38,34 +33,24 @@ def _create_extract_tab(self):
         with gr.Row():
             gr.Markdown("## Extract LoRA from Checkpoint")
         with gr.Row():
-            gr.Markdown("## Extract LoRA from Checkpoint")
-        with gr.Row():
-            model_type = gr.Dropdown(choices=["sd1", "sdxl"], label="Model Type")
             model_orig = gr.Textbox(label="Original Model Path")
             model_tuned = gr.Textbox(label="Tuned Model Path")
             save_to = gr.Textbox(label="Save To Path")
-            load_precision = gr.Dropdown(choices=["fp32", "fp16", "bf16"], label="Load Precision")
-            save_precision = gr.Dropdown(choices=["fp32", "fp16", "bf16"], label="Save Precision")
-            device = gr.Dropdown(choices=["cuda", "cpu"], label="Device")
-            lora_rank = gr.Slider(minimum=1, maximum=128, step=1, label="LoRA Rank")
-            clamp_quantile = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Clamp Quantile")
             extract_button = gr.Button("Extract")
 
         extract_button.click(
             fn=self._extract_lora_from_checkpoint,
-            inputs=[model_type, model_orig, model_tuned, save_to, load_precision, save_precision, device, lora_rank, clamp_quantile],
+            inputs=[model_orig, model_tuned, save_to],
             outputs=[]
         )
 
-    def _merge_lora_into_sd_model(self, base_model, base_model_variant, base_model_type, lora_models, output_path, save_dtype):
-        lora_models_list = [tuple(lm.split("::")) if "::" in lm else (lm, 1.0) for lm in lora_models.split(",")]
-        lora_models_list = [(path, float(weight)) for path, weight in lora_models_list]
-        # Call the actual merge function here
-        print(f"Merging LoRA models {lora_models_list} into base model {base_model} with variant {base_model_variant} and type {base_model_type}, saving to {output_path} with dtype {save_dtype}")
+    def _merge_lora_into_sd_model(self, base_model, lora_model, output_path):
+        # Placeholder function for merging LoRA into SD model
+        print(f"Merging LoRA model {lora_model} into base model {base_model} and saving to {output_path}")
 
-    def _extract_lora_from_checkpoint(self, model_type, model_orig, model_tuned, save_to, load_precision, save_precision, device, lora_rank, clamp_quantile):
-        # Call the actual extraction function here
-        print(f"Extracting LoRA from {model_tuned} using original model {model_orig} with type {model_type}, saving to {save_to} with load precision {load_precision}, save precision {save_precision}, on device {device}, with rank {lora_rank} and clamp quantile {clamp_quantile}")
+    def _extract_lora_from_checkpoint(self, model_orig, model_tuned, save_to):
+        # Placeholder function for extracting LoRA from checkpoint
+        print(f"Extracting LoRA from {model_tuned} using original model {model_orig} and saving to {save_to}")
 
     def app(self):
         return self._app