Merge pull request #127 from invoke-ai/SDXL-Finetune-Training-Addition
initial addition of Finetune tab
RyanJDick authored May 21, 2024
2 parents 3519d73 + cd85bb2 commit 24a3bc5
Showing 4 changed files with 274 additions and 7 deletions.
@@ -17,7 +17,7 @@ def __init__(self):
        self.enabled = gr.Checkbox(label="Use Aspect Ratio Bucketing", interactive=True)
        self.target_resolution = gr.Number(label="target_resolution", interactive=True, precision=0)
        self.start_dim = gr.Number(label="start_dimension", interactive=True, precision=0)
-       self.end_dim = gr.Number(label="end_imension", interactive=True, precision=0)
+       self.end_dim = gr.Number(label="end_dimension", interactive=True, precision=0)
        self.divisible_by = gr.Number(label="divisible_by", interactive=True, precision=0)

    def update_ui_components_with_config_data(
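The fields above only wire up the UI; the bucket-generation logic itself is outside this diff. As a rough sketch of how such parameters are commonly used in aspect-ratio bucketing (the helper name and exact rules here are illustrative assumptions, not invoke-training's actual implementation):

# Illustrative sketch only; invoke-training's real bucketing logic may differ.
def build_bucket_resolutions(
    target_resolution: int = 1024,
    start_dim: int = 512,
    end_dim: int = 1536,
    divisible_by: int = 64,
) -> list[tuple[int, int]]:
    """Enumerate (width, height) buckets with roughly target_resolution**2 pixels."""
    buckets = set()
    width = start_dim
    while width <= end_dim:
        # Pick the height that keeps the pixel count near the target,
        # rounded down to a multiple of `divisible_by`.
        height = (target_resolution**2 // width) // divisible_by * divisible_by
        if start_dim <= height <= end_dim:
            buckets.add((width, height))
            buckets.add((height, width))  # mirrored bucket for portrait/landscape
        width += divisible_by
    return sorted(buckets)

Each bucket keeps approximately target_resolution**2 pixels, so training batches can mix aspect ratios without cropping or squashing every image to a square.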
20 changes: 14 additions & 6 deletions src/invoke_training/ui/config_groups/optimizer_config_group.py
@@ -11,12 +11,18 @@
class AdamOptimizerConfigGroup(UIConfigElement):
    def __init__(self):
        with gr.Tab("Core"):
-           self.learning_rate = gr.Number(
-               label="Learning Rate",
-               info="Initial learning rate to use (after the potential warmup period). Note that in some training "
-               "pipelines this can be overridden for a specific group of params.",
-               interactive=True,
-           )
+           with gr.Row():
+               self.learning_rate = gr.Number(
+                   label="Learning Rate",
+                   info="Initial learning rate to use (after the potential warmup period). Note that in some training "
+                   "pipelines this can be overridden for a specific group of params.",
+                   interactive=True,
+               )
+               self.use_8bit = gr.Checkbox(
+                   label="Use 8-bit",
+                   info="Use 8-bit Adam optimizer to reduce VRAM requirements. (Requires bitsandbytes.)",
+                   interactive=True,
+               )
        with gr.Tab("Advanced"):
            with gr.Row():
                self.beta1 = gr.Number(label="beta1", interactive=True)
@@ -32,6 +38,7 @@ def update_ui_components_with_config_data(self, config: AdamOptimizerConfig) ->
            self.beta2: config.beta2,
            self.weight_decay: config.weight_decay,
            self.epsilon: config.epsilon,
+           self.use_8bit: config.use_8bit,
        }

    def update_config_with_ui_component_data(
@@ -45,6 +52,7 @@ def update_config_with_ui_component_data(
            beta2=ui_data.pop(self.beta2),
            weight_decay=ui_data.pop(self.weight_decay),
            epsilon=ui_data.pop(self.epsilon),
+           use_8bit=ui_data.pop(self.use_8bit),
        )


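The diff above only adds the use_8bit control and threads it through the config round-trip; the optimizer construction happens elsewhere in the training pipeline and is not shown here. A minimal sketch of how such a flag is typically consumed (the build_adam_optimizer helper is a hypothetical name; bnb.optim.AdamW8bit is the real bitsandbytes 8-bit optimizer class):

import torch

def build_adam_optimizer(params, learning_rate, beta1, beta2, weight_decay, epsilon, use_8bit=False):
    # Hypothetical helper for illustration; not the wiring used by invoke-training.
    if use_8bit:
        import bitsandbytes as bnb  # optional dependency, per the UI info text

        optimizer_cls = bnb.optim.AdamW8bit  # stores optimizer state in 8 bits to cut VRAM
    else:
        optimizer_cls = torch.optim.AdamW
    return optimizer_cls(params, lr=learning_rate, betas=(beta1, beta2), weight_decay=weight_decay, eps=epsilon)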
248 changes: 248 additions & 0 deletions src/invoke_training/ui/config_groups/sdxl_finetune_config_group.py
@@ -0,0 +1,248 @@
import typing

import gradio as gr

from invoke_training.pipelines.stable_diffusion_xl.finetune.config import SdxlFinetuneConfig
from invoke_training.ui.config_groups.base_pipeline_config_group import BasePipelineConfigGroup
from invoke_training.ui.config_groups.image_caption_sd_data_loader_config_group import (
    ImageCaptionSDDataLoaderConfigGroup,
)
from invoke_training.ui.config_groups.optimizer_config_group import OptimizerConfigGroup
from invoke_training.ui.config_groups.ui_config_element import UIConfigElement
from invoke_training.ui.utils.prompts import (
    convert_pos_neg_prompts_to_ui_prompts,
    convert_ui_prompts_to_pos_neg_prompts,
)
from invoke_training.ui.utils.utils import get_typing_literal_options


class SdxlFinetuneConfigGroup(UIConfigElement):
    def __init__(self):
        """The SDXL_FINETUNE configs."""

        gr.Markdown("## Basic Configs")
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Tab("Base Model"):
                    self.model = gr.Textbox(
                        label="Model",
                        info="The base model. Can be a Hugging Face Hub model name, or a path to a local model (in "
                        "diffusers or checkpoint format).",
                        type="text",
                        interactive=True,
                    )
                    self.hf_variant = gr.Textbox(
                        label="Variant",
                        info="(optional) The Hugging Face Hub model variant (e.g., fp16, fp32) to use if the model is"
                        " an HF Hub model name.",
                        type="text",
                        interactive=True,
                    )
                    self.vae_model = gr.Textbox(
                        label="VAE Model",
                        info="(optional) If set, this overrides the base model's default VAE model.",
                        type="text",
                        interactive=True,
                    )
            with gr.Column(scale=3):
                with gr.Tab("Training Outputs"):
                    self.base_pipeline_config_group = BasePipelineConfigGroup()
                    self.save_checkpoint_format = gr.Dropdown(
                        label="Checkpoint Format",
                        info="The save format for the checkpoints. `full_diffusers` saves the full model in diffusers "
                        "format. `trained_only_diffusers` saves only the parts of the model that were finetuned "
                        "(i.e. the UNet).",
                        choices=get_typing_literal_options(SdxlFinetuneConfig, "save_checkpoint_format"),
                        interactive=True,
                    )
                    self.save_dtype = gr.Dropdown(
                        label="Save Dtype",
                        info="The dtype to use when saving the model.",
                        choices=get_typing_literal_options(SdxlFinetuneConfig, "save_dtype"),
                        interactive=True,
                    )
                    self.max_checkpoints = gr.Number(
                        label="Maximum Number of Checkpoints",
                        info="The maximum number of checkpoints to keep on disk from this training run. Earlier "
                        "checkpoints will be deleted to respect this limit.",
                        interactive=True,
                        precision=0,
                    )

        gr.Markdown("## Data Configs")
        self.image_caption_sd_data_loader_config_group = ImageCaptionSDDataLoaderConfigGroup()

        gr.Markdown("## Optimizer Configs")
        self.optimizer_config_group = OptimizerConfigGroup()

        gr.Markdown("## Speed / Memory Configs")
        with gr.Group():
            with gr.Row():
                self.gradient_accumulation_steps = gr.Number(
                    label="Gradient Accumulation Steps",
                    info="The number of gradient steps to accumulate before each weight update. This is an "
                    "alternative to increasing the batch size when training with limited VRAM. "
                    "effective_batch_size = train_batch_size * gradient_accumulation_steps.",
                    precision=0,
                    interactive=True,
                )
            with gr.Row():
                self.weight_dtype = gr.Dropdown(
                    label="Weight Type",
                    info="The precision of the model weights. Lower precision can speed up training and reduce memory, "
                    "with increased risk of numerical stability issues. 'bfloat16' is recommended for most use cases "
                    "if your GPU supports it.",
                    choices=get_typing_literal_options(SdxlFinetuneConfig, "weight_dtype"),
                    interactive=True,
                )
            with gr.Row():
                self.cache_text_encoder_outputs = gr.Checkbox(
                    label="Cache Text Encoder Outputs",
                    info="Cache the text encoder outputs to increase speed. This should not be used when training the "
                    "text encoder or performing data augmentations that would change the text encoder outputs.",
                    interactive=True,
                )
                self.cache_vae_outputs = gr.Checkbox(
                    label="Cache VAE Outputs",
                    info="Cache the VAE outputs to increase speed. This should not be used when training the UNet or "
                    "performing data augmentations that would change the VAE outputs.",
                    interactive=True,
                )
            with gr.Row():
                self.enable_cpu_offload_during_validation = gr.Checkbox(
                    label="Enable CPU Offload during Validation",
                    info="Offload models to the CPU sequentially during validation. This reduces peak VRAM "
                    "requirements at the cost of slower validation during training.",
                    interactive=True,
                )
                self.gradient_checkpointing = gr.Checkbox(
                    label="Gradient Checkpointing",
                    info="If True, VRAM requirements are reduced at the cost of ~20% slower training.",
                    interactive=True,
                )

        gr.Markdown("## General Training Configs")
        with gr.Tab("Core"):
            with gr.Row():
                self.lr_scheduler = gr.Dropdown(
                    label="Learning Rate Scheduler",
                    choices=get_typing_literal_options(SdxlFinetuneConfig, "lr_scheduler"),
                    interactive=True,
                )
                self.lr_warmup_steps = gr.Number(
                    label="Warmup Steps",
                    info="The number of warmup steps in the "
                    "learning rate schedule, if applicable to the selected scheduler.",
                    interactive=True,
                )
            with gr.Row():
                self.min_snr_gamma = gr.Number(
                    label="Minimum SNR Gamma",
                    info="min_snr_gamma acts like an upper bound on the weight of samples with low noise "
                    "levels. If None, then Min-SNR weighting will not be applied. If enabled, the recommended "
                    "value is min_snr_gamma = 5.0.",
                    interactive=True,
                )
                self.max_grad_norm = gr.Number(
                    label="Max Gradient Norm",
                    info="Max gradient norm for clipping. Set to None for no clipping.",
                    interactive=True,
                )
                self.train_batch_size = gr.Number(
                    label="Batch Size",
                    info="The training batch size. Higher values require increasing amounts of VRAM.",
                    precision=0,
                    interactive=True,
                )

        gr.Markdown("## Validation")
        with gr.Group():
            self.validation_prompts = gr.Textbox(
                label="Validation Prompts",
                info="Enter one validation prompt per line. Optionally, add negative prompts after a '[NEG]' "
                "delimiter. For example: `positive prompt[NEG]negative prompt`.",
                lines=5,
                interactive=True,
            )
            self.num_validation_images_per_prompt = gr.Number(
                label="# of Validation Images to Generate per Prompt", precision=0, interactive=True
            )

    def update_ui_components_with_config_data(
        self, config: SdxlFinetuneConfig
    ) -> dict[gr.components.Component, typing.Any]:
        update_dict = {
            self.model: config.model,
            self.hf_variant: config.hf_variant,
            self.vae_model: config.vae_model,
            self.save_checkpoint_format: config.save_checkpoint_format,
            self.save_dtype: config.save_dtype,
            self.max_checkpoints: config.max_checkpoints,
            self.lr_scheduler: config.lr_scheduler,
            self.lr_warmup_steps: config.lr_warmup_steps,
            self.min_snr_gamma: config.min_snr_gamma,
            self.max_grad_norm: config.max_grad_norm,
            self.train_batch_size: config.train_batch_size,
            self.cache_text_encoder_outputs: config.cache_text_encoder_outputs,
            self.cache_vae_outputs: config.cache_vae_outputs,
            self.enable_cpu_offload_during_validation: config.enable_cpu_offload_during_validation,
            self.gradient_accumulation_steps: config.gradient_accumulation_steps,
            self.weight_dtype: config.weight_dtype,
            self.gradient_checkpointing: config.gradient_checkpointing,
            self.validation_prompts: convert_pos_neg_prompts_to_ui_prompts(
                config.validation_prompts, config.negative_validation_prompts
            ),
            self.num_validation_images_per_prompt: config.num_validation_images_per_prompt,
        }
        update_dict.update(
            self.image_caption_sd_data_loader_config_group.update_ui_components_with_config_data(config.data_loader)
        )
        update_dict.update(self.base_pipeline_config_group.update_ui_components_with_config_data(config))
        update_dict.update(self.optimizer_config_group.update_ui_components_with_config_data(config.optimizer))

        # Sanity check to catch if we accidentally forget to update a UI component.
        assert set(update_dict.keys()) == set(self.get_ui_output_components())

        return update_dict

    def update_config_with_ui_component_data(
        self, orig_config: SdxlFinetuneConfig, ui_data: dict[gr.components.Component, typing.Any]
    ) -> SdxlFinetuneConfig:
        new_config = orig_config.model_copy(deep=True)

        new_config.model = ui_data.pop(self.model)
        new_config.hf_variant = ui_data.pop(self.hf_variant) or None
        new_config.vae_model = ui_data.pop(self.vae_model) or None
        new_config.save_checkpoint_format = ui_data.pop(self.save_checkpoint_format)
        new_config.save_dtype = ui_data.pop(self.save_dtype)
        new_config.max_checkpoints = ui_data.pop(self.max_checkpoints)
        new_config.lr_scheduler = ui_data.pop(self.lr_scheduler)
        new_config.lr_warmup_steps = ui_data.pop(self.lr_warmup_steps)
        new_config.min_snr_gamma = ui_data.pop(self.min_snr_gamma)
        new_config.max_grad_norm = ui_data.pop(self.max_grad_norm)
        new_config.train_batch_size = ui_data.pop(self.train_batch_size)
        new_config.cache_text_encoder_outputs = ui_data.pop(self.cache_text_encoder_outputs)
        new_config.cache_vae_outputs = ui_data.pop(self.cache_vae_outputs)
        new_config.enable_cpu_offload_during_validation = ui_data.pop(self.enable_cpu_offload_during_validation)
        new_config.gradient_accumulation_steps = ui_data.pop(self.gradient_accumulation_steps)
        new_config.weight_dtype = ui_data.pop(self.weight_dtype)
        new_config.gradient_checkpointing = ui_data.pop(self.gradient_checkpointing)
        new_config.num_validation_images_per_prompt = ui_data.pop(self.num_validation_images_per_prompt)

        positive_prompts, negative_prompts = convert_ui_prompts_to_pos_neg_prompts(ui_data.pop(self.validation_prompts))
        new_config.validation_prompts = positive_prompts
        new_config.negative_validation_prompts = negative_prompts

        new_config.data_loader = self.image_caption_sd_data_loader_config_group.update_config_with_ui_component_data(
            new_config.data_loader, ui_data
        )
        new_config = self.base_pipeline_config_group.update_config_with_ui_component_data(new_config, ui_data)
        new_config.optimizer = self.optimizer_config_group.update_config_with_ui_component_data(
            new_config.optimizer, ui_data
        )

        # We pop items from ui_data as we use them so that we can sanity check that all the input data was transferred
        # to the config.
        assert len(ui_data) == 0

        return new_config
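The convert_pos_neg_prompts_to_ui_prompts / convert_ui_prompts_to_pos_neg_prompts helpers imported at the top of this file are not part of the diff. Based on the Validation Prompts info text (one prompt per line, optional '[NEG]' delimiter), a plausible minimal sketch of the pair (assumed behavior, not the actual implementation):

def convert_ui_prompts_to_pos_neg_prompts(ui_text: str) -> tuple[list[str], list[str] | None]:
    # Assumed behavior: split each non-empty line on the first "[NEG]" delimiter.
    positives, negatives = [], []
    for line in ui_text.splitlines():
        if not line.strip():
            continue
        positive, _, negative = line.partition("[NEG]")
        positives.append(positive)
        negatives.append(negative)
    return positives, negatives if any(negatives) else None


def convert_pos_neg_prompts_to_ui_prompts(positives: list[str], negatives: list[str] | None) -> str:
    # Inverse direction: re-join prompt pairs into one UI line per prompt.
    negatives = negatives or [""] * len(positives)
    return "\n".join(p if not n else f"{p}[NEG]{n}" for p, n in zip(positives, negatives))

Under these assumptions, "a castle[NEG]blurry, low quality" round-trips to (["a castle"], ["blurry, low quality"]) and back.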
11 changes: 11 additions & 0 deletions src/invoke_training/ui/pages/training_page.py
@@ -9,13 +9,15 @@
from invoke_training.config.pipeline_config import PipelineConfig
from invoke_training.pipelines.stable_diffusion.lora.config import SdLoraConfig
from invoke_training.pipelines.stable_diffusion.textual_inversion.config import SdTextualInversionConfig
+ from invoke_training.pipelines.stable_diffusion_xl.finetune.config import SdxlFinetuneConfig
from invoke_training.pipelines.stable_diffusion_xl.lora.config import SdxlLoraConfig
from invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config import (
SdxlLoraAndTextualInversionConfig,
)
from invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config import SdxlTextualInversionConfig
from invoke_training.ui.config_groups.sd_lora_config_group import SdLoraConfigGroup
from invoke_training.ui.config_groups.sd_textual_inversion_config_group import SdTextualInversionConfigGroup
+ from invoke_training.ui.config_groups.sdxl_finetune_config_group import SdxlFinetuneConfigGroup
from invoke_training.ui.config_groups.sdxl_lora_and_textual_inversion_config_group import (
SdxlLoraAndTextualInversionConfigGroup,
)
@@ -82,6 +84,15 @@ def __init__(self):
                    run_training_cb=self._run_training,
                    app=app,
                )
+           with gr.Tab(label="SDXL Finetune"):
+               PipelineTab(
+                   name="SDXL Finetune",
+                   default_config_file_path=str(get_config_dir_path() / "sdxl_finetune_baroque_1x24gb.yaml"),
+                   pipeline_config_cls=SdxlFinetuneConfig,
+                   config_group_cls=SdxlFinetuneConfigGroup,
+                   run_training_cb=self._run_training,
+                   app=app,
+               )

        self._app = app

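For background on the Minimum SNR Gamma field in the new tab: Min-SNR weighting (Hang et al., 2023) caps the loss weight of low-noise (high-SNR) timesteps. For epsilon-prediction models the standard per-timestep weight is min(SNR, gamma) / SNR; a short sketch of that formulation (illustrative, not code from this PR):

import torch


def min_snr_weights(snr: torch.Tensor, min_snr_gamma: float = 5.0) -> torch.Tensor:
    # weight = min(SNR, gamma) / SNR: low-SNR timesteps keep weight 1.0,
    # while high-SNR (low-noise) timesteps are down-weighted toward gamma / SNR.
    return torch.clamp(snr, max=min_snr_gamma) / snr


print(min_snr_weights(torch.tensor([0.5, 5.0, 50.0])))  # tensor([1.0000, 1.0000, 0.1000])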
