Skip to content

Commit

Permalink
Revert "aider: Added all potential options for Merging and Extracting with progressively disclosed UI elements."
Browse files Browse the repository at this point in the history

This reverts commit ce0046d.
  • Loading branch information
hipsterusername committed Jun 9, 2024
1 parent 87a05e8 commit 6e95027
Showing 1 changed file with 9 additions and 24 deletions.
33 changes: 9 additions & 24 deletions src/invoke_training/ui/pages/model_merge_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,57 +15,42 @@ def __init__(self):
self._app = app

def _create_merge_tab(self):
    """Build the "Merge LoRA into SD Model" tab.

    Creates the three path inputs and wires the Merge button to
    ``self._merge_lora_into_sd_model``. The diff residue in the scraped
    source (duplicate Markdown rows, extra widgets, and a duplicated
    ``inputs=`` keyword, which is a SyntaxError) is removed; this is the
    post-revert version of the method.
    """
    with gr.Row():
        gr.Markdown("## Merge LoRA into SD Model")
    with gr.Row():
        base_model = gr.Textbox(label="Base Model Path")
        lora_model = gr.Textbox(label="LoRA Model Path")
        output_path = gr.Textbox(label="Output Path")
        merge_button = gr.Button("Merge")

    merge_button.click(
        fn=self._merge_lora_into_sd_model,
        inputs=[base_model, lora_model, output_path],
        outputs=[]
    )

def _create_extract_tab(self):
    """Build the "Extract LoRA from Checkpoint" tab.

    Creates the three path inputs and wires the Extract button to
    ``self._extract_lora_from_checkpoint``. The scraped diff residue
    (duplicate Markdown rows, extra widgets, and a duplicated
    ``inputs=`` keyword, which is a SyntaxError) is removed; this is the
    post-revert version of the method.
    """
    with gr.Row():
        gr.Markdown("## Extract LoRA from Checkpoint")
    with gr.Row():
        model_orig = gr.Textbox(label="Original Model Path")
        model_tuned = gr.Textbox(label="Tuned Model Path")
        save_to = gr.Textbox(label="Save To Path")
        extract_button = gr.Button("Extract")

    extract_button.click(
        fn=self._extract_lora_from_checkpoint,
        inputs=[model_orig, model_tuned, save_to],
        outputs=[]
    )

def _merge_lora_into_sd_model(self, base_model, base_model_variant, base_model_type, lora_models, output_path, save_dtype):
lora_models_list = [tuple(lm.split("::")) if "::" in lm else (lm, 1.0) for lm in lora_models.split(",")]
lora_models_list = [(path, float(weight)) for path, weight in lora_models_list]
# Call the actual merge function here
print(f"Merging LoRA models {lora_models_list} into base model {base_model} with variant {base_model_variant} and type {base_model_type}, saving to {output_path} with dtype {save_dtype}")
def _merge_lora_into_sd_model(self, base_model, lora_model, output_path):
# Placeholder function for merging LoRA into SD model
print(f"Merging LoRA model {lora_model} into base model {base_model} and saving to {output_path}")

def _extract_lora_from_checkpoint(self, model_type, model_orig, model_tuned, save_to, load_precision, save_precision, device, lora_rank, clamp_quantile):
# Call the actual extraction function here
print(f"Extracting LoRA from {model_tuned} using original model {model_orig} with type {model_type}, saving to {save_to} with load precision {load_precision}, save precision {save_precision}, on device {device}, with rank {lora_rank} and clamp quantile {clamp_quantile}")
def _extract_lora_from_checkpoint(self, model_orig, model_tuned, save_to):
# Placeholder function for extracting LoRA from checkpoint
print(f"Extracting LoRA from {model_tuned} using original model {model_orig} and saving to {save_to}")

def app(self):
    """Return the UI app object built for this page (stored in ``_app``)."""
    page_app = self._app
    return page_app

0 comments on commit 6e95027

Please sign in to comment.