From 8b93c91a07e8a74069d40e321d9b7f75b4656382 Mon Sep 17 00:00:00 2001
From: hipsterusername
Date: Sat, 18 May 2024 19:32:52 -0700
Subject: [PATCH 1/6] initial addition of Finetune tab

---
 .../sdxl_finetune_config_group.py             | 233 ++++++++++++++++++
 src/invoke_training/ui/pages/training_page.py |  13 +
 2 files changed, 246 insertions(+)
 create mode 100644 src/invoke_training/ui/config_groups/sdxl_finetune_config_group.py

diff --git a/src/invoke_training/ui/config_groups/sdxl_finetune_config_group.py b/src/invoke_training/ui/config_groups/sdxl_finetune_config_group.py
new file mode 100644
index 00000000..575530dd
--- /dev/null
+++ b/src/invoke_training/ui/config_groups/sdxl_finetune_config_group.py
@@ -0,0 +1,233 @@
+import typing
+
+import gradio as gr
+
+from invoke_training.pipelines.stable_diffusion_xl.finetune.config import SdxlFinetuneConfig
+from invoke_training.ui.config_groups.base_pipeline_config_group import BasePipelineConfigGroup
+from invoke_training.ui.config_groups.image_caption_sd_data_loader_config_group import ImageCaptionSDDataLoaderConfigGroup
+from invoke_training.ui.config_groups.optimizer_config_group import OptimizerConfigGroup
+from invoke_training.ui.config_groups.ui_config_element import UIConfigElement
+from invoke_training.ui.utils.prompts import convert_pos_neg_prompts_to_ui_prompts, convert_ui_prompts_to_pos_neg_prompts
+from invoke_training.ui.utils.utils import get_typing_literal_options
+
+
+class SdxlFinetuneConfigGroup(UIConfigElement):
+    def __init__(self):
+        """The SDXL_FINETUNE configs."""
+
+        gr.Markdown("## Basic Configs")
+        with gr.Row():
+            with gr.Column(scale=1):
+                with gr.Tab("Base Model"):
+                    self.model = gr.Textbox(
+                        label="Model",
+                        info="The base model. Can be a Hugging Face Hub model name, or a path to a local model (in "
+                        "diffusers or checkpoint format).",
+                        type="text",
+                        interactive=True,
+                    )
+                    self.hf_variant = gr.Textbox(
+                        label="Variant",
+                        info="(optional) The Hugging Face Hub model variant (e.g., fp16, fp32) to use if the model is a"
+                        " HF Hub model name.",
+                        type="text",
+                        interactive=True,
+                    )
+                    self.vae_model = gr.Textbox(
+                        label="VAE Model",
+                        info="(optional) If set, this overrides the base model's default VAE model.",
+                        type="text",
+                        interactive=True,
+                    )
+            with gr.Column(scale=3):
+                with gr.Tab("Training Outputs"):
+                    self.base_pipeline_config_group = BasePipelineConfigGroup()
+                    self.save_checkpoint_format = gr.Dropdown(
+                        label="Checkpoint Format",
+                        choices=get_typing_literal_options(SdxlFinetuneConfig, "save_checkpoint_format"),
+                        interactive=True,
+                    )
+                    self.save_dtype = gr.Dropdown(
+                        label="Save Dtype",
+                        choices=get_typing_literal_options(SdxlFinetuneConfig, "save_dtype"),
+                        interactive=True,
+                    )
+                    self.max_checkpoints = gr.Number(
+                        label="Maximum Number of Checkpoints",
+                        info="The maximum number of checkpoints to keep on disk from this training run. Earlier "
+                        "checkpoints will be deleted to respect this limit.",
+                        interactive=True,
+                        precision=0,
+                    )
+
+        gr.Markdown("## Data Configs")
+        self.image_caption_sd_data_loader_config_group = ImageCaptionSDDataLoaderConfigGroup()
+
+        gr.Markdown("## Optimizer Configs")
+        self.optimizer_config_group = OptimizerConfigGroup()
+
+        gr.Markdown("## Speed / Memory Configs")
+        with gr.Group():
+            with gr.Row():
+                self.gradient_accumulation_steps = gr.Number(
+                    label="Gradient Accumulation Steps",
+                    info="The number of gradient steps to accumulate before each weight update. This is an alternative "
+                    "to increasing the batch size when training with limited VRAM. "
+ "effective_batch_size = train_batch_size * gradient_accumulation_steps.", + precision=0, + interactive=True, + ) + with gr.Row(): + self.weight_dtype = gr.Dropdown( + label="Weight Type", + info="The precision of the model weights. Lower precision can speed up training and reduce memory, " + "with increased risk of numerical stability issues. 'bfloat16' is recommended for most use cases " + "if your GPU supports it.", + choices=get_typing_literal_options(SdxlFinetuneConfig, "weight_dtype"), + interactive=True, + ) + with gr.Row(): + self.cache_text_encoder_outputs = gr.Checkbox( + label="Cache Text Encoder Outputs", + info="Cache the text encoder outputs to increase speed. This should not be used when training the " + "text encoder or performing data augmentations that would change the text encoder outputs.", + interactive=True, + ) + self.cache_vae_outputs = gr.Checkbox( + label="Cache VAE Outputs", + info="Cache the VAE outputs to increase speed. This should not be used when training the UNet or " + "performing data augmentations that would change the VAE outputs.", + interactive=True, + ) + with gr.Row(): + self.enable_cpu_offload_during_validation = gr.Checkbox( + label="Enable CPU Offload during Validation", + info="Offload models to the CPU sequentially during validation. This reduces peak VRAM " + "requirements at the cost of slower validation during training.", + interactive=True, + ) + self.gradient_checkpointing = gr.Checkbox( + label="Gradient Checkpointing", + info="If True, VRAM requirements are reduced at the cost of ~20% slower training", + interactive=True, + ) + + gr.Markdown("## General Training Configs") + with gr.Tab("Core"): + with gr.Row(): + self.lr_scheduler = gr.Dropdown( + label="Learning Rate Scheduler", + choices=get_typing_literal_options(SdxlFinetuneConfig, "lr_scheduler"), + interactive=True, + ) + self.lr_warmup_steps = gr.Number( + label="Warmup Steps", + info="The number of warmup steps in the " + "learning rate schedule, if applicable to the selected scheduler.", + interactive=True, + ) + with gr.Row(): + self.min_snr_gamma = gr.Number( + label="Minimum SNR Gamma", + info="min_snr_gamma acts like an an upper bound on the weight of samples with low noise " + "levels. If None, then Min-SNR weighting will not be applied. If enabled, the recommended " + "value is min_snr gamma = 5.0.", + interactive=True, + ) + self.max_grad_norm = gr.Number( + label="Max Gradient Norm", + info="Max gradient norm for clipping. Set to None for no clipping.", + interactive=True, + ) + self.train_batch_size = gr.Number( + label="Batch Size", + info="The Training Batch Size - Higher values require increasing amounts of VRAM.", + precision=0, + interactive=True, + ) + + gr.Markdown("## Validation") + with gr.Group(): + self.validation_prompts = gr.Textbox( + label="Validation Prompts", + info="Enter one validation prompt per line. Optionally, add negative prompts after a '[NEG]' " + "delimiter. For example: `positive prompt[NEG]negative prompt`. 
", + lines=5, + interactive=True, + ) + self.num_validation_images_per_prompt = gr.Number( + label="# of Validation Images to Generate per Prompt", precision=0, interactive=True + ) + + def update_ui_components_with_config_data(self, config: SdxlFinetuneConfig) -> dict[gr.components.Component, typing.Any]: + update_dict = { + self.model: config.model, + self.hf_variant: config.hf_variant, + self.vae_model: config.vae_model, + self.save_checkpoint_format: config.save_checkpoint_format, + self.save_dtype: config.save_dtype, + self.max_checkpoints: config.max_checkpoints, + self.lr_scheduler: config.lr_scheduler, + self.lr_warmup_steps: config.lr_warmup_steps, + self.min_snr_gamma: config.min_snr_gamma, + self.max_grad_norm: config.max_grad_norm, + self.train_batch_size: config.train_batch_size, + self.cache_text_encoder_outputs: config.cache_text_encoder_outputs, + self.cache_vae_outputs: config.cache_vae_outputs, + self.enable_cpu_offload_during_validation: config.enable_cpu_offload_during_validation, + self.gradient_accumulation_steps: config.gradient_accumulation_steps, + self.weight_dtype: config.weight_dtype, + self.gradient_checkpointing: config.gradient_checkpointing, + self.validation_prompts: convert_pos_neg_prompts_to_ui_prompts( + config.validation_prompts, config.negative_validation_prompts + ), + self.num_validation_images_per_prompt: config.num_validation_images_per_prompt, + } + update_dict.update( + self.image_caption_sd_data_loader_config_group.update_ui_components_with_config_data(config.data_loader) + ) + update_dict.update(self.base_pipeline_config_group.update_ui_components_with_config_data(config)) + update_dict.update(self.optimizer_config_group.update_ui_components_with_config_data(config.optimizer)) + + # Sanity check to catch if we accidentally forget to update a UI component. 
+        assert set(update_dict.keys()) == set(self.get_ui_output_components())
+
+        return update_dict
+
+    def update_config_with_ui_component_data(self, orig_config: SdxlFinetuneConfig, ui_data: dict[gr.components.Component, typing.Any]) -> SdxlFinetuneConfig:
+        new_config = orig_config.model_copy(deep=True)
+
+        new_config.model = ui_data.pop(self.model)
+        new_config.hf_variant = ui_data.pop(self.hf_variant) or None
+        new_config.vae_model = ui_data.pop(self.vae_model) or None
+        new_config.save_checkpoint_format = ui_data.pop(self.save_checkpoint_format)
+        new_config.save_dtype = ui_data.pop(self.save_dtype)
+        new_config.max_checkpoints = ui_data.pop(self.max_checkpoints)
+        new_config.lr_scheduler = ui_data.pop(self.lr_scheduler)
+        new_config.lr_warmup_steps = ui_data.pop(self.lr_warmup_steps)
+        new_config.min_snr_gamma = ui_data.pop(self.min_snr_gamma)
+        new_config.max_grad_norm = ui_data.pop(self.max_grad_norm)
+        new_config.train_batch_size = ui_data.pop(self.train_batch_size)
+        new_config.cache_text_encoder_outputs = ui_data.pop(self.cache_text_encoder_outputs)
+        new_config.cache_vae_outputs = ui_data.pop(self.cache_vae_outputs)
+        new_config.enable_cpu_offload_during_validation = ui_data.pop(self.enable_cpu_offload_during_validation)
+        new_config.gradient_accumulation_steps = ui_data.pop(self.gradient_accumulation_steps)
+        new_config.weight_dtype = ui_data.pop(self.weight_dtype)
+        new_config.gradient_checkpointing = ui_data.pop(self.gradient_checkpointing)
+        new_config.num_validation_images_per_prompt = ui_data.pop(self.num_validation_images_per_prompt)
+
+        positive_prompts, negative_prompts = convert_ui_prompts_to_pos_neg_prompts(ui_data.pop(self.validation_prompts))
+        new_config.validation_prompts = positive_prompts
+        new_config.negative_validation_prompts = negative_prompts
+
+        new_config.data_loader = self.image_caption_sd_data_loader_config_group.update_config_with_ui_component_data(
+            new_config.data_loader, ui_data
+        )
+        new_config = self.base_pipeline_config_group.update_config_with_ui_component_data(new_config, ui_data)
+        new_config.optimizer = self.optimizer_config_group.update_config_with_ui_component_data(new_config.optimizer, ui_data)
+
+        # We pop items from ui_data as we use them so that we can sanity check that all the input data was transferred
+        # to the config.
+        assert len(ui_data) == 0
+
+        return new_config
diff --git a/src/invoke_training/ui/pages/training_page.py b/src/invoke_training/ui/pages/training_page.py
index 1752e501..cc9086cb 100644
--- a/src/invoke_training/ui/pages/training_page.py
+++ b/src/invoke_training/ui/pages/training_page.py
@@ -14,6 +14,8 @@
     SdxlLoraAndTextualInversionConfig,
 )
 from invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config import SdxlTextualInversionConfig
+from invoke_training.pipelines.stable_diffusion_xl.finetune.config import SdxlFinetuneConfig
+
 from invoke_training.ui.config_groups.sd_lora_config_group import SdLoraConfigGroup
 from invoke_training.ui.config_groups.sd_textual_inversion_config_group import SdTextualInversionConfigGroup
 from invoke_training.ui.config_groups.sdxl_lora_and_textual_inversion_config_group import (
@@ -21,6 +23,8 @@
 )
 from invoke_training.ui.config_groups.sdxl_lora_config_group import SdxlLoraConfigGroup
 from invoke_training.ui.config_groups.sdxl_textual_inversion_config_group import SdxlTextualInversionConfigGroup
+from invoke_training.ui.config_groups.sdxl_finetune_config_group import SdxlFinetuneConfigGroup
+
 from invoke_training.ui.gradio_blocks.header import Header
 from invoke_training.ui.gradio_blocks.pipeline_tab import PipelineTab
 from invoke_training.ui.utils.utils import get_config_dir_path
@@ -82,6 +86,15 @@ def __init__(self):
                 run_training_cb=self._run_training,
                 app=app,
             )
+            with gr.Tab(label="SDXL Finetune"):
+                PipelineTab(
+                    name="SDXL Finetune",
+                    default_config_file_path=str(get_config_dir_path() / "sdxl_finetune_config_baroque_1x24gb.yaml"),
+                    pipeline_config_cls=SdxlFinetuneConfig,
+                    config_group_cls=SdxlFinetuneConfigGroup,
+                    run_training_cb=self._run_training,
+                    app=app,
+                )
 
         self._app = app

From 707f6dcd7ed9f5332076c6ee626b4081553dff22 Mon Sep 17 00:00:00 2001
From: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com>
Date: Sun, 19 May 2024 23:28:34 -0400
Subject: [PATCH 2/6] Fix reference config name.
---
 src/invoke_training/ui/pages/training_page.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/src/invoke_training/ui/pages/training_page.py b/src/invoke_training/ui/pages/training_page.py
index cc9086cb..599bd320 100644
--- a/src/invoke_training/ui/pages/training_page.py
+++ b/src/invoke_training/ui/pages/training_page.py
@@ -9,22 +9,20 @@
 from invoke_training.config.pipeline_config import PipelineConfig
 from invoke_training.pipelines.stable_diffusion.lora.config import SdLoraConfig
 from invoke_training.pipelines.stable_diffusion.textual_inversion.config import SdTextualInversionConfig
+from invoke_training.pipelines.stable_diffusion_xl.finetune.config import SdxlFinetuneConfig
 from invoke_training.pipelines.stable_diffusion_xl.lora.config import SdxlLoraConfig
 from invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config import (
     SdxlLoraAndTextualInversionConfig,
 )
 from invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config import SdxlTextualInversionConfig
-from invoke_training.pipelines.stable_diffusion_xl.finetune.config import SdxlFinetuneConfig
-
 from invoke_training.ui.config_groups.sd_lora_config_group import SdLoraConfigGroup
 from invoke_training.ui.config_groups.sd_textual_inversion_config_group import SdTextualInversionConfigGroup
+from invoke_training.ui.config_groups.sdxl_finetune_config_group import SdxlFinetuneConfigGroup
 from invoke_training.ui.config_groups.sdxl_lora_and_textual_inversion_config_group import (
     SdxlLoraAndTextualInversionConfigGroup,
 )
 from invoke_training.ui.config_groups.sdxl_lora_config_group import SdxlLoraConfigGroup
 from invoke_training.ui.config_groups.sdxl_textual_inversion_config_group import SdxlTextualInversionConfigGroup
-from invoke_training.ui.config_groups.sdxl_finetune_config_group import SdxlFinetuneConfigGroup
-
 from invoke_training.ui.gradio_blocks.header import Header
 from invoke_training.ui.gradio_blocks.pipeline_tab import PipelineTab
 from invoke_training.ui.utils.utils import get_config_dir_path
@@ -89,7 +87,7 @@ def __init__(self):
         with gr.Tab(label="SDXL Finetune"):
             PipelineTab(
                 name="SDXL Finetune",
-                default_config_file_path=str(get_config_dir_path() / "sdxl_finetune_config_baroque_1x24gb.yaml"),
+                default_config_file_path=str(get_config_dir_path() / "sdxl_finetune_baroque_1x24gb.yaml"),
                 pipeline_config_cls=SdxlFinetuneConfig,
                 config_group_cls=SdxlFinetuneConfigGroup,
                 run_training_cb=self._run_training,

From 5ac72b37b4778d1c6d5912a55ce17645c375c3c4 Mon Sep 17 00:00:00 2001
From: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com>
Date: Sun, 19 May 2024 23:36:11 -0400
Subject: [PATCH 3/6] Fix end_dimension typo

---
 .../ui/config_groups/aspect_ratio_bucket_config_group.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/invoke_training/ui/config_groups/aspect_ratio_bucket_config_group.py b/src/invoke_training/ui/config_groups/aspect_ratio_bucket_config_group.py
index 50ae168e..5d5ceff2 100644
--- a/src/invoke_training/ui/config_groups/aspect_ratio_bucket_config_group.py
+++ b/src/invoke_training/ui/config_groups/aspect_ratio_bucket_config_group.py
@@ -17,7 +17,7 @@ def __init__(self):
         self.enabled = gr.Checkbox(label="Use Aspect Ratio Bucketing", interactive=True)
         self.target_resolution = gr.Number(label="target_resolution", interactive=True, precision=0)
         self.start_dim = gr.Number(label="start_dimension", interactive=True, precision=0)
-        self.end_dim = gr.Number(label="end_imension", interactive=True, precision=0)
+        self.end_dim = gr.Number(label="end_dimension", interactive=True, precision=0)
= gr.Number(label="end_dimension", interactive=True, precision=0) self.divisible_by = gr.Number(label="divisible_by", interactive=True, precision=0) def update_ui_components_with_config_data( From 399c130d9d3104556a01c14c128b1760a9b7fa3e Mon Sep 17 00:00:00 2001 From: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com> Date: Mon, 20 May 2024 08:12:38 -0400 Subject: [PATCH 4/6] ruff fixes? --- .../sdxl_finetune_config_group.py | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/invoke_training/ui/config_groups/sdxl_finetune_config_group.py b/src/invoke_training/ui/config_groups/sdxl_finetune_config_group.py index 575530dd..961fd8f8 100644 --- a/src/invoke_training/ui/config_groups/sdxl_finetune_config_group.py +++ b/src/invoke_training/ui/config_groups/sdxl_finetune_config_group.py @@ -4,10 +4,15 @@ from invoke_training.pipelines.stable_diffusion_xl.finetune.config import SdxlFinetuneConfig from invoke_training.ui.config_groups.base_pipeline_config_group import BasePipelineConfigGroup -from invoke_training.ui.config_groups.image_caption_sd_data_loader_config_group import ImageCaptionSDDataLoaderConfigGroup +from invoke_training.ui.config_groups.image_caption_sd_data_loader_config_group import ( + ImageCaptionSDDataLoaderConfigGroup, +) from invoke_training.ui.config_groups.optimizer_config_group import OptimizerConfigGroup from invoke_training.ui.config_groups.ui_config_element import UIConfigElement -from invoke_training.ui.utils.prompts import convert_pos_neg_prompts_to_ui_prompts, convert_ui_prompts_to_pos_neg_prompts +from invoke_training.ui.utils.prompts import ( + convert_pos_neg_prompts_to_ui_prompts, + convert_ui_prompts_to_pos_neg_prompts, +) from invoke_training.ui.utils.utils import get_typing_literal_options @@ -159,7 +164,9 @@ def __init__(self): label="# of Validation Images to Generate per Prompt", precision=0, interactive=True ) - def update_ui_components_with_config_data(self, config: SdxlFinetuneConfig) -> dict[gr.components.Component, typing.Any]: + def update_ui_components_with_config_data( + self, config: SdxlFinetuneConfig + ) -> dict[gr.components.Component, typing.Any]: update_dict = { self.model: config.model, self.hf_variant: config.hf_variant, @@ -194,7 +201,9 @@ def update_ui_components_with_config_data(self, config: SdxlFinetuneConfig) -> d return update_dict - def update_config_with_ui_component_data(self, orig_config: SdxlFinetuneConfig, ui_data: dict[gr.components.Component, typing.Any]) -> SdxlFinetuneConfig: + def update_config_with_ui_component_data( + self, orig_config: SdxlFinetuneConfig, ui_data: dict[gr.components.Component, typing.Any] + ) -> SdxlFinetuneConfig: new_config = orig_config.model_copy(deep=True) new_config.model = ui_data.pop(self.model) @@ -224,7 +233,9 @@ def update_config_with_ui_component_data(self, orig_config: SdxlFinetuneConfig, new_config.data_loader, ui_data ) new_config = self.base_pipeline_config_group.update_config_with_ui_component_data(new_config, ui_data) - new_config.optimizer = self.optimizer_config_group.update_config_with_ui_component_data(new_config.optimizer, ui_data) + new_config.optimizer = self.optimizer_config_group.update_config_with_ui_component_data( + new_config.optimizer, ui_data + ) # We pop items from ui_data as we use them so that we can sanity check that all the input data was transferred # to the config. 
From 654bba160d473c68aee80f5f62ddece91bcdd45e Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Tue, 21 May 2024 09:28:59 -0400
Subject: [PATCH 5/6] Add help messages to the SDXL finetune UI.

---
 .../ui/config_groups/sdxl_finetune_config_group.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/invoke_training/ui/config_groups/sdxl_finetune_config_group.py b/src/invoke_training/ui/config_groups/sdxl_finetune_config_group.py
index 961fd8f8..1fe9311e 100644
--- a/src/invoke_training/ui/config_groups/sdxl_finetune_config_group.py
+++ b/src/invoke_training/ui/config_groups/sdxl_finetune_config_group.py
@@ -49,11 +49,15 @@ def __init__(self):
                     self.base_pipeline_config_group = BasePipelineConfigGroup()
                     self.save_checkpoint_format = gr.Dropdown(
                         label="Checkpoint Format",
+                        info="The save format for the checkpoints. `full_diffusers` saves the full model in diffusers "
+                        "format. `trained_only_diffusers` saves only the parts of the model that were finetuned "
+                        "(i.e. the UNet).",
                         choices=get_typing_literal_options(SdxlFinetuneConfig, "save_checkpoint_format"),
                         interactive=True,
                     )
                     self.save_dtype = gr.Dropdown(
                         label="Save Dtype",
+                        info="The dtype to use when saving the model.",
                         choices=get_typing_literal_options(SdxlFinetuneConfig, "save_dtype"),
                         interactive=True,
                     )

From cd85bb29a3473afece5322a1167dbab6ecdfec21 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Tue, 21 May 2024 10:16:00 -0400
Subject: [PATCH 6/6] Add use_8bit Adam optimizer config to the UI.

---
 .../config_groups/optimizer_config_group.py   | 20 +++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/src/invoke_training/ui/config_groups/optimizer_config_group.py b/src/invoke_training/ui/config_groups/optimizer_config_group.py
index a6c95ee3..5cd06e43 100644
--- a/src/invoke_training/ui/config_groups/optimizer_config_group.py
+++ b/src/invoke_training/ui/config_groups/optimizer_config_group.py
@@ -11,12 +11,18 @@
 class AdamOptimizerConfigGroup(UIConfigElement):
     def __init__(self):
         with gr.Tab("Core"):
-            self.learning_rate = gr.Number(
-                label="Learning Rate",
-                info="Initial learning rate to use (after the potential warmup period). Note that in some training "
-                "pipelines this can be overriden for a specific group of params.",
-                interactive=True,
-            )
+            with gr.Row():
+                self.learning_rate = gr.Number(
+                    label="Learning Rate",
+                    info="Initial learning rate to use (after the potential warmup period). Note that in some training "
+                    "pipelines this can be overridden for a specific group of params.",
+                    interactive=True,
+                )
+                self.use_8bit = gr.Checkbox(
+                    label="Use 8-bit",
+                    info="Use 8-bit Adam optimizer to reduce VRAM requirements. (Requires bitsandbytes.)",
+                    interactive=True,
+                )
         with gr.Tab("Advanced"):
             with gr.Row():
                 self.beta1 = gr.Number(label="beta1", interactive=True)
@@ -32,6 +38,7 @@ def update_ui_components_with_config_data(self, config: AdamOptimizerConfig) ->
             self.beta2: config.beta2,
             self.weight_decay: config.weight_decay,
             self.epsilon: config.epsilon,
+            self.use_8bit: config.use_8bit,
         }
 
     def update_config_with_ui_component_data(
@@ -45,6 +52,7 @@ def update_config_with_ui_component_data(
             beta2=ui_data.pop(self.beta2),
             weight_decay=ui_data.pop(self.weight_decay),
             epsilon=ui_data.pop(self.epsilon),
+            use_8bit=ui_data.pop(self.use_8bit),
         )
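
A note on the use_8bit flag added in PATCH 6/6: this series only plumbs the checkbox through
AdamOptimizerConfig; the optimizer swap itself happens in the training pipeline. A minimal sketch of that
mapping, assuming the config fields shown above transfer directly to optimizer kwargs (the exact selection
logic in invoke-training may differ):

    import torch

    def build_adam_optimizer(params, cfg):
        # Map the AdamOptimizerConfig fields from the UI onto optimizer kwargs.
        kwargs = dict(lr=cfg.learning_rate, betas=(cfg.beta1, cfg.beta2),
                      weight_decay=cfg.weight_decay, eps=cfg.epsilon)
        if cfg.use_8bit:
            # bitsandbytes keeps optimizer state in 8-bit, which roughly halves
            # optimizer VRAM for a model as large as the SDXL UNet.
            import bitsandbytes as bnb  # requires the bitsandbytes package
            return bnb.optim.Adam8bit(params, **kwargs)
        return torch.optim.Adam(params, **kwargs)

Similarly, the Minimum SNR Gamma field added in PATCH 1/6 refers to Min-SNR loss weighting (Hang et al.,
2023). For an epsilon-prediction model the per-timestep loss weight is min(SNR, gamma) / SNR, with gamma = 5.0
recommended; a sketch of the weighting itself (how the pipeline applies it may differ):

    import torch

    def min_snr_weights(snr: torch.Tensor, gamma: float = 5.0) -> torch.Tensor:
        # Clamp the effective weight of low-noise (high-SNR) timesteps: the
        # weight is 1.0 wherever SNR <= gamma, and gamma / SNR above that.
        return snr.clamp(max=gamma) / snr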