From 5927f571a789034e152bf7268e97ca1e5d4d5a31 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 11 Nov 2024 17:51:06 +0000 Subject: [PATCH 01/11] add --vae_enable_tiling to encode large res images with less vram used --- helpers/configuration/cmd_args.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py index 14e1b470..53471635 100644 --- a/helpers/configuration/cmd_args.py +++ b/helpers/configuration/cmd_args.py @@ -592,6 +592,15 @@ def get_argument_parser(): " but if you are at that point of contention, it's possible that your GPU has too little RAM. Default: 4." ), ) + parser.add_argument( + "--vae_enable_tiling", + action="store_true", + default=False, + help=( + "If set, will enable tiling for VAE caching. This is useful for very large images when VRAM is limited." + " This may be required for 2048px VAE caching on 24G accelerators, in addition to reducing --vae_batch_size." + ), + ) parser.add_argument( "--vae_cache_scan_behaviour", type=str, From bc9d5d904c53aa13f0466fb945177c84f4787d70 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 11 Nov 2024 17:52:35 +0000 Subject: [PATCH 02/11] s3: when file does not exist, handle generic 404 error for headobject --- helpers/data_backend/aws.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/helpers/data_backend/aws.py b/helpers/data_backend/aws.py index b1d3b788..3b5ef8af 100644 --- a/helpers/data_backend/aws.py +++ b/helpers/data_backend/aws.py @@ -106,6 +106,8 @@ def exists(self, s3_key): except (NoCredentialsError, PartialCredentialsError) as e: raise e # Raise credential errors to the caller except Exception as e: + if "An error occurred (404) when calling the HeadObject operation: Not Found" in str(e): + return False logger.error(f'Error checking existence of S3 key "{s3_key}": {e}') if i == self.read_retry_limit - 1: # We have reached our maximum retry count. From 2f227f8606c915ab9759768729630eec98594554 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 11 Nov 2024 17:54:02 +0000 Subject: [PATCH 03/11] trainer: enable vae tiling when enabled --- helpers/training/trainer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index 72b76559..904cc31f 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -468,6 +468,9 @@ def init_vae(self, move_to_accelerator: bool = True): ) self.config.vae_kwargs["subfolder"] = None self.vae = AutoencoderKL.from_pretrained(**self.config.vae_kwargs) + if self.vae is not None and self.config.vae_enable_tiling and hasattr(self.vae, 'enable_tiling'): + logger.warning("Enabling VAE tiling for greatly reduced memory consumption due to --vae_enable_tiling which may result in VAE tiling artifacts in encoded latents.") + self.vae.enable_tiling() if not move_to_accelerator: logger.debug("Not moving VAE to accelerator.") return From bbae68256972f891205c14fddf372310348e7dde Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 11 Nov 2024 17:54:41 +0000 Subject: [PATCH 04/11] validation: fix error when torch compile is disabled for lycoris --- helpers/training/validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helpers/training/validation.py b/helpers/training/validation.py index e0bef972..717612e3 100644 --- a/helpers/training/validation.py +++ b/helpers/training/validation.py @@ -1085,7 +1085,7 @@ def setup_pipeline(self, validation_type, enable_ema_model: bool = True): if self.args.validation_torch_compile: if self.deepspeed: logger.warning("DeepSpeed does not support torch compile. Disabling. Set --validation_torch_compile=False to suppress this warning.") - elif self.lora_type.lower() == "lycoris": + elif self.args.lora_type.lower() == "lycoris": logger.warning("LyCORIS does not support torch compile for validation due to graph compile breaks. Disabling. Set --validation_torch_compile=False to suppress this warning.") else: if self.unet is not None and not is_compiled_module(self.unet): From dfc3f8fd7101eb606173b0be615b0933a0c68a4c Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 11 Nov 2024 18:04:12 +0000 Subject: [PATCH 05/11] update OPTIONS doc --- OPTIONS.md | 112 ++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 77 insertions(+), 35 deletions(-) diff --git a/OPTIONS.md b/OPTIONS.md index 032cc468..a6bf5cf1 100644 --- a/OPTIONS.md +++ b/OPTIONS.md @@ -356,7 +356,10 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr] [--model_type {full,lora,deepfloyd-full,deepfloyd-lora,deepfloyd-stage2,deepfloyd-stage2-lora}] [--flux_lora_target {mmdit,context,context+ffs,all,all+ffs,ai-toolkit,tiny,nano}] [--flow_matching_sigmoid_scale FLOW_MATCHING_SIGMOID_SCALE] - [--flux_fast_schedule] + [--flux_fast_schedule] [--flux_use_uniform_schedule] + [--flux_use_beta_schedule] + [--flux_beta_schedule_alpha FLUX_BETA_SCHEDULE_ALPHA] + [--flux_beta_schedule_beta FLUX_BETA_SCHEDULE_BETA] [--flux_schedule_shift FLUX_SCHEDULE_SHIFT] [--flux_schedule_auto_shift] [--flux_guidance_mode {constant,random-range,mobius}] @@ -366,8 +369,9 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr] [--flux_attention_masked_training] [--t5_padding {zero,unmodified}] [--smoldit] [--smoldit_config {smoldit-small,smoldit-swiglu,smoldit-base,smoldit-large,smoldit-huge}] - [--flow_matching_loss {diffusers,compatible,diffusion}] - [--sd3_t5_mask_behaviour {do-nothing,mask}] + [--flow_matching_loss {diffusers,compatible,diffusion,sd35}] + [--sd3_clip_uncond_behaviour {empty_string,zero}] + [--sd3_t5_uncond_behaviour {empty_string,zero}] [--lora_type {standard,lycoris}] [--lora_init_type {default,gaussian,loftq,olora,pissa}] [--init_lora INIT_LORA] [--lora_rank LORA_RANK] @@ -396,7 +400,7 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr] [--disable_segmented_timestep_sampling] [--rescale_betas_zero_snr] [--vae_dtype {default,fp16,fp32,bf16}] - [--vae_batch_size VAE_BATCH_SIZE] + [--vae_batch_size VAE_BATCH_SIZE] [--vae_enable_tiling] [--vae_cache_scan_behaviour {recreate,sync}] [--vae_cache_ondemand] [--compress_disk_cache] [--aspect_bucket_disable_rebuild] [--keep_vae_loaded] @@ -448,9 +452,9 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr] [--ema_update_interval EMA_UPDATE_INTERVAL] [--ema_decay EMA_DECAY] [--non_ema_revision NON_EMA_REVISION] [--offload_param_path OFFLOAD_PARAM_PATH] --optimizer - {adamw_bf16,ao-adamw8bit,ao-adamw4bit,ao-adamfp8,ao-adamwfp8,adamw_schedulefree,adamw_schedulefree+aggressive,adamw_schedulefree+no_kahan,optimi-stableadamw,optimi-adamw,optimi-lion,optimi-radam,optimi-ranger,optimi-adan,optimi-adam,optimi-sgd} + {adamw_bf16,ao-adamw8bit,ao-adamw4bit,ao-adamfp8,ao-adamwfp8,adamw_schedulefree,adamw_schedulefree+aggressive,adamw_schedulefree+no_kahan,optimi-stableadamw,optimi-adamw,optimi-lion,optimi-radam,optimi-ranger,optimi-adan,optimi-adam,optimi-sgd,soap,bnb-adagrad,bnb-adagrad8bit,bnb-adam,bnb-adam8bit,bnb-adamw,bnb-adamw8bit,bnb-adamw-paged,bnb-adamw8bit-paged,bnb-ademamix,bnb-ademamix8bit,bnb-ademamix-paged,bnb-ademamix8bit-paged,bnb-lion,bnb-lion8bit,bnb-lion-paged,bnb-lion8bit-paged} [--optimizer_config OPTIMIZER_CONFIG] - [--optimizer_cpu_offload_method {none,torchao}] + [--optimizer_cpu_offload_method {none}] [--optimizer_offload_gradients] [--fuse_optimizer] [--optimizer_beta1 OPTIMIZER_BETA1] [--optimizer_beta2 OPTIMIZER_BETA2] @@ -466,6 +470,10 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr] [--validation_on_startup] [--validation_seed_source {gpu,cpu}] [--validation_torch_compile] [--validation_torch_compile_mode {max-autotune,reduce-overhead,default}] + [--validation_guidance_skip_layers VALIDATION_GUIDANCE_SKIP_LAYERS] + [--validation_guidance_skip_layers_start VALIDATION_GUIDANCE_SKIP_LAYERS_START] + [--validation_guidance_skip_layers_stop VALIDATION_GUIDANCE_SKIP_LAYERS_STOP] + [--validation_guidance_skip_scale VALIDATION_GUIDANCE_SKIP_SCALE] [--allow_tf32] [--disable_tf32] [--validation_using_datasets] [--webhook_config WEBHOOK_CONFIG] [--webhook_reporting_interval WEBHOOK_REPORTING_INTERVAL] @@ -487,11 +495,12 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr] [--mixed_precision {bf16,no}] [--gradient_precision {unmodified,fp32}] [--quantize_via {cpu,accelerator}] - [--base_model_precision {no_change,fp8-quanto,fp8uz-quanto,int8-quanto,int4-quanto,int2-quanto,int8-torchao}] + [--base_model_precision {no_change,int8-quanto,int4-quanto,int2-quanto,int8-torchao,nf4-bnb,fp8-quanto,fp8uz-quanto}] + [--quantize_activations] [--base_model_default_dtype {bf16,fp32}] - [--text_encoder_1_precision {no_change,fp8-quanto,fp8uz-quanto,int8-quanto,int4-quanto,int2-quanto,int8-torchao}] - [--text_encoder_2_precision {no_change,fp8-quanto,fp8uz-quanto,int8-quanto,int4-quanto,int2-quanto,int8-torchao}] - [--text_encoder_3_precision {no_change,fp8-quanto,fp8uz-quanto,int8-quanto,int4-quanto,int2-quanto,int8-torchao}] + [--text_encoder_1_precision {no_change,int8-quanto,int4-quanto,int2-quanto,int8-torchao,nf4-bnb,fp8-quanto,fp8uz-quanto}] + [--text_encoder_2_precision {no_change,int8-quanto,int4-quanto,int2-quanto,int8-torchao,nf4-bnb,fp8-quanto,fp8uz-quanto}] + [--text_encoder_3_precision {no_change,int8-quanto,int4-quanto,int2-quanto,int8-torchao,nf4-bnb,fp8-quanto,fp8uz-quanto}] [--local_rank LOCAL_RANK] [--enable_xformers_memory_efficient_attention] [--set_grads_to_none] [--noise_offset NOISE_OFFSET] @@ -568,6 +577,21 @@ options: schedule closer to what it was trained with, which has improved results in short experiments. Thanks to @mhirki for the contribution. + --flux_use_uniform_schedule + Whether or not to use a uniform schedule with Flux + instead of sigmoid. Using uniform sampling may help + preserve more capabilities from the base model. Some + tasks may not benefit from this. + --flux_use_beta_schedule + Whether or not to use a beta schedule with Flux + instead of sigmoid. The default values of alpha and + beta approximate a sigmoid. + --flux_beta_schedule_alpha FLUX_BETA_SCHEDULE_ALPHA + The alpha value of the flux beta schedule. Default is + 2.0 + --flux_beta_schedule_beta FLUX_BETA_SCHEDULE_BETA + The beta value of the flux beta schedule. Default is + 2.0 --flux_schedule_shift FLUX_SCHEDULE_SHIFT Shift the noise schedule. This is a value between 0 and ~4.0, where 0 disables the timestep-dependent @@ -627,28 +651,26 @@ options: --smoldit_config {smoldit-small,smoldit-swiglu,smoldit-base,smoldit-large,smoldit-huge} The SmolDiT configuration to use. This is a list of pre-configured models. The default is 'smoldit-base'. - --flow_matching_loss {diffusers,compatible,diffusion} + --flow_matching_loss {diffusers,compatible,diffusion,sd35} A discrepancy exists between the Diffusers implementation of flow matching and the minimal implementation provided by StabilityAI. This experimental option allows switching loss calculations to be compatible with those. Additionally, 'diffusion' is offered as an option to reparameterise a model to - v_prediction loss. - --sd3_t5_mask_behaviour {do-nothing,mask} - StabilityAI did not correctly implement their - attention masking on T5 inputs for SD3 Medium. This - option enables you to switch between their broken - implementation or the corrected mask implementation. - Although, the corrected masking is still applied via - hackish workaround, manually applying the mask to the - prompt embeds so that the padded positions are zero. - This improves the results for short captions, but does - not change the behaviour for long captions. It is - important to note that this limitation currently - prevents expansion of SD3 Medium's prompt length, as - it will unnecessarily attend to every token in the - prompt embed, even masked positions. + v_prediction loss. sd35 provides the ability to train + on SD3.5's flow-matching target, which is the denoised + sample. + --sd3_clip_uncond_behaviour {empty_string,zero} + SD3 can be trained using zeroed prompt embeds during + unconditional dropout, or an encoded empty string may + be used instead (the default). Changing this value may + stabilise or destabilise training. The default is + 'empty_string'. + --sd3_t5_uncond_behaviour {empty_string,zero} + Override the value of unconditional prompts from T5 + embeds. The default is to follow the value of + --sd3_clip_uncond_behaviour. --lora_type {standard,lycoris} When training using --model_type=lora, you may specify a different type of LoRA to train here. standard @@ -812,6 +834,11 @@ options: issues, but if you are at that point of contention, it's possible that your GPU has too little RAM. Default: 4. + --vae_enable_tiling If set, will enable tiling for VAE caching. This is + useful for very large images when VRAM is limited. + This may be required for 2048px VAE caching on 24G + accelerators, in addition to reducing + --vae_batch_size. --vae_cache_scan_behaviour {recreate,sync} When a mismatched latent vector is detected, a scan will be initiated to locate inconsistencies and @@ -1138,16 +1165,16 @@ options: When using DeepSpeed ZeRo stage 2 or 3 with NVMe offload, this may be specified to provide a path for the offload. - --optimizer {adamw_bf16,ao-adamw8bit,ao-adamw4bit,ao-adamfp8,ao-adamwfp8,adamw_schedulefree,adamw_schedulefree+aggressive,adamw_schedulefree+no_kahan,optimi-stableadamw,optimi-adamw,optimi-lion,optimi-radam,optimi-ranger,optimi-adan,optimi-adam,optimi-sgd} + --optimizer {adamw_bf16,ao-adamw8bit,ao-adamw4bit,ao-adamfp8,ao-adamwfp8,adamw_schedulefree,adamw_schedulefree+aggressive,adamw_schedulefree+no_kahan,optimi-stableadamw,optimi-adamw,optimi-lion,optimi-radam,optimi-ranger,optimi-adan,optimi-adam,optimi-sgd,soap,bnb-adagrad,bnb-adagrad8bit,bnb-adam,bnb-adam8bit,bnb-adamw,bnb-adamw8bit,bnb-adamw-paged,bnb-adamw8bit-paged,bnb-ademamix,bnb-ademamix8bit,bnb-ademamix-paged,bnb-ademamix8bit-paged,bnb-lion,bnb-lion8bit,bnb-lion-paged,bnb-lion8bit-paged} --optimizer_config OPTIMIZER_CONFIG When setting a given optimizer, this allows a comma- separated list of key-value pairs to be provided that will override the optimizer defaults. For example, `-- optimizer_config=decouple_lr=True,weight_decay=0.01`. - --optimizer_cpu_offload_method {none,torchao} - When loading an optimiser, a CPU offload mechanism can - be used. Currently, no offload is used by default, and - only torchao is supported. + --optimizer_cpu_offload_method {none} + This option is a placeholder. In the future, it will + allow for the selection of different CPU offload + methods. --optimizer_offload_gradients When creating a CPU-offloaded optimiser, the gradients can be offloaded to the CPU to save more memory. @@ -1231,6 +1258,18 @@ options: PyTorch provides different modes for the Torch Inductor when compiling graphs. max-autotune, the default mode, provides the most benefit. + --validation_guidance_skip_layers VALIDATION_GUIDANCE_SKIP_LAYERS + StabilityAI recommends a value of [7, 8, 9] for Stable + Diffusion 3.5 Medium. + --validation_guidance_skip_layers_start VALIDATION_GUIDANCE_SKIP_LAYERS_START + StabilityAI recommends a value of 0.01 for SLG start. + --validation_guidance_skip_layers_stop VALIDATION_GUIDANCE_SKIP_LAYERS_STOP + StabilityAI recommends a value of 0.2 for SLG start. + --validation_guidance_skip_scale VALIDATION_GUIDANCE_SKIP_SCALE + StabilityAI recommends a value of 2.8 for SLG guidance + skip scaling. When adding more layers, you must + increase the scale, eg. adding one more layer requires + doubling the value given. --allow_tf32 Deprecated option. TF32 is now enabled by default. Use --disable_tf32 to disable. --disable_tf32 Previous defaults were to disable TF32 on Ampere GPUs. @@ -1350,7 +1389,7 @@ options: required, but the process completes in milliseconds. When done on the CPU, the process may take upwards of 60 seconds, but can complete without OOM on 16G cards. - --base_model_precision {no_change,fp8-quanto,fp8uz-quanto,int8-quanto,int4-quanto,int2-quanto,int8-torchao} + --base_model_precision {no_change,int8-quanto,int4-quanto,int2-quanto,int8-torchao,nf4-bnb,fp8-quanto,fp8uz-quanto} When training a LoRA, you might want to quantise the base model to a lower precision to save more VRAM. The default value, 'no_change', does not quantise any @@ -1358,6 +1397,9 @@ options: Bits n Bytes for quantisation (NVIDIA, maybe AMD). Using 'fp8-quanto' will require Quanto for quantisation (Apple Silicon, NVIDIA, AMD). + --quantize_activations + (EXPERIMENTAL) This option is currently unsupported, + and exists solely for development purposes. --base_model_default_dtype {bf16,fp32} Unlike --mixed_precision, this value applies specifically for the default weights of your quantised @@ -1368,7 +1410,7 @@ options: optimizers than adamw_bf16. However, this uses marginally more memory, and may not be necessary for your use case. - --text_encoder_1_precision {no_change,fp8-quanto,fp8uz-quanto,int8-quanto,int4-quanto,int2-quanto,int8-torchao} + --text_encoder_1_precision {no_change,int8-quanto,int4-quanto,int2-quanto,int8-torchao,nf4-bnb,fp8-quanto,fp8uz-quanto} When training a LoRA, you might want to quantise text encoder 1 to a lower precision to save more VRAM. The default value is to follow base_model_precision @@ -1376,7 +1418,7 @@ options: Bits n Bytes for quantisation (NVIDIA, maybe AMD). Using 'fp8-quanto' will require Quanto for quantisation (Apple Silicon, NVIDIA, AMD). - --text_encoder_2_precision {no_change,fp8-quanto,fp8uz-quanto,int8-quanto,int4-quanto,int2-quanto,int8-torchao} + --text_encoder_2_precision {no_change,int8-quanto,int4-quanto,int2-quanto,int8-torchao,nf4-bnb,fp8-quanto,fp8uz-quanto} When training a LoRA, you might want to quantise text encoder 2 to a lower precision to save more VRAM. The default value is to follow base_model_precision @@ -1384,7 +1426,7 @@ options: Bits n Bytes for quantisation (NVIDIA, maybe AMD). Using 'fp8-quanto' will require Quanto for quantisation (Apple Silicon, NVIDIA, AMD). - --text_encoder_3_precision {no_change,fp8-quanto,fp8uz-quanto,int8-quanto,int4-quanto,int2-quanto,int8-torchao} + --text_encoder_3_precision {no_change,int8-quanto,int4-quanto,int2-quanto,int8-torchao,nf4-bnb,fp8-quanto,fp8uz-quanto} When training a LoRA, you might want to quantise text encoder 3 to a lower precision to save more VRAM. The default value is to follow base_model_precision From a675609f851045d069b6a504589de0a71c27dfbb Mon Sep 17 00:00:00 2001 From: bghira Date: Tue, 12 Nov 2024 19:27:23 +0000 Subject: [PATCH 06/11] add clip score tracking --- OPTIONS.md | 17 +++++++++++ helpers/configuration/cmd_args.py | 20 +++++++++++++ helpers/training/evaluation.py | 48 +++++++++++++++++++++++++++++ helpers/training/trainer.py | 10 +++++++ helpers/training/validation.py | 50 +++++++++++++++++++++++++++---- 5 files changed, 140 insertions(+), 5 deletions(-) create mode 100644 helpers/training/evaluation.py diff --git a/OPTIONS.md b/OPTIONS.md index a6bf5cf1..348c16e1 100644 --- a/OPTIONS.md +++ b/OPTIONS.md @@ -203,6 +203,11 @@ A lot of settings are instead set through the [dataloader config](/documentation - **What**: Output image resolution, measured in pixels, or, formatted as: `widthxheight`, as in `1024x1024`. Multiple resolutions can be defined, separated by commas. - **Why**: All images generated during validation will be this resolution. Useful if the model is being trained with a different resolution. +### `--validation_model_evaluator` + +- **What**: Enable CLIP evaluation of generated images during validations. +- **Why**: CLIP scores calculate the distance of the generated image features to the provided validation prompt. This can give an idea of whether prompt adherence is improving, though it requires a large number of validation prompts to have any meaningful value. +- **Options**: "none" or "clip" ### `--crop` @@ -467,6 +472,8 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr] [--model_card_note MODEL_CARD_NOTE] [--model_card_safe_for_work] [--logging_dir LOGGING_DIR] [--benchmark_base_model] [--disable_benchmark] + [--validation_model_evaluator {clip,none}] + [--pretrained_validation_model_name_or_path PRETRAINED_VALIDATION_MODEL_NAME_OR_PATH] [--validation_on_startup] [--validation_seed_source {gpu,cpu}] [--validation_torch_compile] [--validation_torch_compile_mode {max-autotune,reduce-overhead,default}] @@ -1236,6 +1243,16 @@ options: --disable_benchmark By default, the model will be benchmarked on the first batch of the first epoch. This can be disabled with this option. + --validation_model_evaluator {clip,none} + Validations must be enabled for model evaluation to + function. The default is to use no evaluator, and + 'clip' will use a CLIP model to evaluate the resulting + model's performance during validations. + --pretrained_validation_model_name_or_path PRETRAINED_VALIDATION_MODEL_NAME_OR_PATH + Optionally provide a custom model to use for ViT + evaluations. The default is currently clip-vit-large- + patch14-336, allowing for lower patch sizes (greater + accuracy) and an input resolution of 336x336. --validation_on_startup When training begins, the starting model will have validation prompts run through it, for later diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py index 53471635..2baf3b02 100644 --- a/helpers/configuration/cmd_args.py +++ b/helpers/configuration/cmd_args.py @@ -1330,6 +1330,26 @@ def get_argument_parser(): " This can be disabled with this option." ), ) + parser.add_argument( + "--validation_model_evaluator", + type=str, + default=None, + choices=["clip", "none"], + help=( + "Validations must be enabled for model evaluation to function. The default is to use no evaluator," + " and 'clip' will use a CLIP model to evaluate the resulting model's performance during validations." + ) + ) + parser.add_argument( + "--pretrained_validation_model_name_or_path", + type=str, + default="openai/clip-vit-large-patch14-336", + help=( + "Optionally provide a custom model to use for ViT evaluations." + " The default is currently clip-vit-large-patch14-336, allowing for lower patch sizes (greater accuracy)" + " and an input resolution of 336x336." + ) + ) parser.add_argument( "--validation_on_startup", action="store_true", diff --git a/helpers/training/evaluation.py b/helpers/training/evaluation.py new file mode 100644 index 00000000..5b91b008 --- /dev/null +++ b/helpers/training/evaluation.py @@ -0,0 +1,48 @@ +from functools import partial +from torchmetrics.functional.multimodal import clip_score +from torchvision import transforms +import torch, logging, os +import numpy as np +from PIL import Image +from helpers.training.state_tracker import StateTracker + +logger = logging.getLogger("ModelEvaluator") +logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) + +model_evaluator_map = { + "clip": "CLIPModelEvaluator", +} + +class ModelEvaluator: + def __init__(self, pretrained_model_name_or_path): + raise NotImplementedError("Subclasses is incomplete, no __init__ method was found.") + + def evaluate(self, images, prompts): + raise NotImplementedError("Subclasses should implement the evaluate() method.") + + @staticmethod + def from_config(args): + """Instantiate a ModelEvaluator from the training config, if set to do so.""" + if not StateTracker.get_accelerator().is_main_process: + return None + if args.validation_model_evaluator is not None and args.validation_model_evaluator.lower() != "" and args.validation_model_evaluator.lower() != "none": + model_evaluator = model_evaluator_map[args.validation_model_evaluator] + return globals()[model_evaluator](args.pretrained_validation_model_name_or_path) + + return None + + +class CLIPModelEvaluator(ModelEvaluator): + def __init__(self, pretrained_model_name_or_path='openai/clip-vit-large-patch14-336'): + self.clip_score_fn = partial(clip_score, model_name_or_path=pretrained_model_name_or_path) + self.preprocess = transforms.Compose([ + transforms.ToTensor() + ]) + + def evaluate(self, images, prompts): + # Preprocess images + images_tensor = torch.stack([self.preprocess(img) * 255 for img in images]) + # Compute CLIP scores + result = self.clip_score_fn(images_tensor, prompts).detach().cpu() + + return result diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index 904cc31f..577c8ec1 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -21,6 +21,7 @@ from helpers.caching.memory import reclaim_memory from helpers.training.multi_process import _get_rank as get_rank from helpers.training.validation import Validation, prepare_validation_prompt_list +from helpers.training.evaluation import ModelEvaluator from helpers.training.state_tracker import StateTracker from helpers.training.schedulers import load_scheduler_from_args from helpers.training.custom_schedule import get_lr_scheduler @@ -1353,6 +1354,7 @@ def init_validations(self): ): logger.error("Cannot run validations with DeepSpeed ZeRO stage 3.") return + model_evaluator = ModelEvaluator.from_config(args=self.config) self.validation = Validation( accelerator=self.accelerator, unet=self.unet, @@ -1374,6 +1376,7 @@ def init_validations(self): ema_model=self.ema_model, vae=self.vae, controlnet=self.controlnet if self.config.controlnet else None, + model_evaluator=model_evaluator ) if not self.config.train_text_encoder and self.validation is not None: self.validation.clear_text_encoders() @@ -2589,6 +2592,13 @@ def train(self): self.guidance_values_list = [] if grad_norm is not None: wandb_logs["grad_norm"] = grad_norm + if self.validation is not None and hasattr(self.validation, 'evaluation_result'): + eval_result = self.validation.get_eval_result() + if eval_result is not None and type(eval_result) == dict: + # add the dict to wandb_logs + self.validation.clear_eval_result() + wandb_logs.update(eval_result) + progress_bar.update(1) self.state["global_step"] += 1 current_epoch_step += 1 diff --git a/helpers/training/validation.py b/helpers/training/validation.py index 717612e3..3e45590f 100644 --- a/helpers/training/validation.py +++ b/helpers/training/validation.py @@ -398,6 +398,7 @@ def __init__( text_encoder_3=None, tokenizer_3=None, is_deepspeed: bool = False, + model_evaluator=None, ): self.accelerator = accelerator self.prompt_handler = None @@ -453,8 +454,11 @@ def __init__( if not is_deepspeed else "cuda" if torch.cuda.is_available() else "cpu" ) - + self.model_evaluator = model_evaluator + if self.model_evaluator is not None: + logger.info(f"Using model evaluator: {self.model_evaluator}") self._update_state() + self.eval_scores = {} def _validation_seed_source(self): if self.args.validation_seed_source == "gpu": @@ -897,6 +901,8 @@ def run_validations( self.setup_scheduler() self.process_prompts() self.finalize_validation(validation_type) + if self.evaluation_result is not None: + logger.info(f"Evaluation result: {self.evaluation_result}") logger.debug("Validation process completed.") self.clean_pipeline() @@ -1120,6 +1126,8 @@ def clean_pipeline(self): def process_prompts(self): """Processes each validation prompt and logs the result.""" + self.validation_prompt_dict = {} + self.evaluation_result = None validation_images = {} _content = zip(self.validation_shortnames, self.validation_prompts) total_samples = ( @@ -1127,6 +1135,7 @@ def process_prompts(self): if self.validation_shortnames is not None else 0 ) + self.eval_scores = {} if self.validation_image_inputs: # Override the pipeline inputs to be entirely based upon the validation image inputs. _content = self.validation_image_inputs @@ -1148,19 +1157,27 @@ def process_prompts(self): raise ValueError( f"Validation content is not in the correct format: {content}" ) + self.validation_prompt_dict[shortname] = prompt logger.debug(f"Processing validation for prompt: {prompt}") validation_images.update( self.validate_prompt(prompt, shortname, validation_input_image) ) self._save_images(validation_images, shortname, prompt) - self._log_validations_to_webhook(validation_images, shortname, prompt) logger.debug(f"Completed generating image: {prompt}") - self.validation_images = validation_images + self.validation_images = validation_images + self.evaluation_result = self.evaluate_images() + self._log_validations_to_webhook(validation_images, shortname, prompt) try: self._log_validations_to_trackers(validation_images) except Exception as e: logger.error(f"Error logging validation images: {e}") + def get_eval_result(self): + return self.evaluation_result or {} + + def clear_eval_result(self): + self.evaluation_result = None + def stitch_conditioning_images(self, validation_image_results, conditioning_image): """ For each image, make a new canvas and place it side by side with its equivalent from {self.validation_image_inputs} @@ -1364,8 +1381,11 @@ def _log_validations_to_webhook( ): if StateTracker.get_webhook_handler() is not None: StateTracker.get_webhook_handler().send( - f"Validation image for `{validation_shortname if validation_shortname != '' else '(blank shortname)'}`" - f"\nValidation prompt: `{validation_prompt if validation_prompt != '' else '(blank prompt)'}`", + ( + f"Validation image for `{validation_shortname if validation_shortname != '' else '(blank shortname)'}`" + f"\nValidation prompt: `{validation_prompt if validation_prompt != '' else '(blank prompt)'}`" + f"\nEvaluation score: {self.eval_scores.get(validation_shortname, 'N/A')}" + ), images=validation_images[validation_shortname], ) @@ -1470,3 +1490,23 @@ def finalize_validation(self, validation_type, enable_ema_model: bool = True): self.pipeline = None if torch.cuda.is_available(): torch.cuda.empty_cache() + + def evaluate_images(self): + if self.model_evaluator is None: + return None + for shortname, image_list in self.validation_images.items(): + if shortname in self.eval_scores: + continue + prompt = self.validation_prompt_dict.get(shortname, '') + for image in image_list: + evaluation_score = self.model_evaluator.evaluate([image], [prompt]) + self.eval_scores[shortname] = round(float(evaluation_score), 4) + # Log the scores into dict: {"min", "max", "mean", "std"} + result = { + "clip/min": min(self.eval_scores.values()), + "clip/max": max(self.eval_scores.values()), + "clip/mean": np.mean(list(self.eval_scores.values())), + "clip/std": np.std(list(self.eval_scores.values())), + } + + return result \ No newline at end of file From 28f66360ba5ad48954a7aed0c6bedf7b27d48adb Mon Sep 17 00:00:00 2001 From: bghira Date: Tue, 12 Nov 2024 22:38:21 +0000 Subject: [PATCH 07/11] evaluation: only evaluate the original validation image, not the stitched copy --- helpers/training/validation.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/helpers/training/validation.py b/helpers/training/validation.py index 3e45590f..d91f1b8c 100644 --- a/helpers/training/validation.py +++ b/helpers/training/validation.py @@ -1159,13 +1159,14 @@ def process_prompts(self): ) self.validation_prompt_dict[shortname] = prompt logger.debug(f"Processing validation for prompt: {prompt}") + stitched_validation_images, original_validation_images = self.validate_prompt(prompt, shortname, validation_input_image) validation_images.update( - self.validate_prompt(prompt, shortname, validation_input_image) + stitched_validation_images ) self._save_images(validation_images, shortname, prompt) logger.debug(f"Completed generating image: {prompt}") self.validation_images = validation_images - self.evaluation_result = self.evaluate_images() + self.evaluation_result = self.evaluate_images(original_validation_images) self._log_validations_to_webhook(validation_images, shortname, prompt) try: self._log_validations_to_trackers(validation_images) @@ -1199,7 +1200,10 @@ def validate_prompt( """Generate validation images for a single prompt.""" # Placeholder for actual image generation and logging logger.debug(f"Validating prompt: {prompt}") + # benchmarked / stitched validation images validation_images = {} + # untouched / un-stitched validation images + original_validation_images = {} for resolution in self.validation_resolutions: extra_validation_kwargs = {} if not self.args.validation_randomize: @@ -1265,6 +1269,7 @@ def validate_prompt( ) if validation_shortname not in validation_images: validation_images[validation_shortname] = [] + original_validation_images[validation_shortname] = [] try: extra_validation_kwargs.update(self._gather_prompt_embeds(prompt)) except Exception as e: @@ -1331,10 +1336,11 @@ def validate_prompt( pipeline_kwargs.pop("negative_mask")[0], dim=0 ).to(device=self.inference_device, dtype=self.weight_dtype) - validation_image_results = self.pipeline(**pipeline_kwargs).images + original_validation_image_results = self.pipeline(**pipeline_kwargs).images + validation_image_results = original_validation_image_results.copy() if self.args.controlnet: validation_image_results = self.stitch_conditioning_images( - validation_image_results, extra_validation_kwargs["image"] + original_validation_image_results, extra_validation_kwargs["image"] ) elif not self.args.disable_benchmark and self.benchmark_exists( "base_model" @@ -1348,6 +1354,7 @@ def validate_prompt( validation_image_results[0], benchmark_image ) validation_images[validation_shortname].extend(validation_image_results) + original_validation_images[validation_shortname].extend(original_validation_image_results) except Exception as e: import traceback @@ -1356,7 +1363,7 @@ def validate_prompt( ) continue - return validation_images + return validation_images, original_validation_images def _save_images(self, validation_images, validation_shortname, validation_prompt): validation_img_idx = 0 @@ -1491,10 +1498,10 @@ def finalize_validation(self, validation_type, enable_ema_model: bool = True): if torch.cuda.is_available(): torch.cuda.empty_cache() - def evaluate_images(self): + def evaluate_images(self, images: list = None): if self.model_evaluator is None: return None - for shortname, image_list in self.validation_images.items(): + for shortname, image_list in images.items(): if shortname in self.eval_scores: continue prompt = self.validation_prompt_dict.get(shortname, '') From ced89ee7f8491927399520b08929f5ec97fa3686 Mon Sep 17 00:00:00 2001 From: bghira Date: Wed, 13 Nov 2024 00:59:47 +0000 Subject: [PATCH 08/11] add CLIP evaluation doc and links from quickstarts --- documentation/DREAMBOOTH.md | 4 ++++ documentation/MIXTURE_OF_EXPERTS.md | 4 ++++ documentation/evaluation/CLIP_SCORES.md | 24 ++++++++++++++++++++++++ documentation/quickstart/FLUX.md | 4 ++++ documentation/quickstart/KOLORS.md | 4 ++++ documentation/quickstart/SD3.md | 6 +++++- documentation/quickstart/SIGMA.md | 4 ++++ 7 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 documentation/evaluation/CLIP_SCORES.md diff --git a/documentation/DREAMBOOTH.md b/documentation/DREAMBOOTH.md index 44d79267..46cf3892 100644 --- a/documentation/DREAMBOOTH.md +++ b/documentation/DREAMBOOTH.md @@ -222,6 +222,10 @@ Alternatively, one might use the real name of their subject, or a 'similar enoug After a number of training experiments, it seems as though a 'similar enough' celebrity is the best choice, especially if prompting the model for the person's real name ends up looking dissimilar. +# CLIP score tracking + +If you wish to enable evaluations to score the model's performance, see [this document](/documentation/evaluation/CLIP_SCORES.md) for information on configuring and interpreting CLIP scores. + # Refiner tuning If you're a fan of the SDXL refiner, you may find that it causes your generations to "ruin" the results of your Dreamboothed model. diff --git a/documentation/MIXTURE_OF_EXPERTS.md b/documentation/MIXTURE_OF_EXPERTS.md index f54953d5..1e7663d5 100644 --- a/documentation/MIXTURE_OF_EXPERTS.md +++ b/documentation/MIXTURE_OF_EXPERTS.md @@ -105,6 +105,10 @@ If you'd like a demonstration dataset, [pseudo-camera-10k](https://huggingface.c Stage two refiner training will automatically select images from each of your training sets, and use those as inputs for partial denoising at validation time. +## CLIP score tracking + +If you wish to enable evaluations to score the model's performance, see [this document](/documentation/evaluation/CLIP_SCORES.md) for information on configuring and interpreting CLIP scores. + ## Putting it all together at inference time If you'd like to plug both of the models together to experiment with in a simple script, this will get you started: diff --git a/documentation/evaluation/CLIP_SCORES.md b/documentation/evaluation/CLIP_SCORES.md new file mode 100644 index 00000000..4ce477ca --- /dev/null +++ b/documentation/evaluation/CLIP_SCORES.md @@ -0,0 +1,24 @@ +## CLIP score tracking + +The CLIP score of your model indicates how closely the features extracted from the image align with the features extracted from the prompt. It is currently a popular metric for determining general prompt adherence, though is typically evaluated across a very large (~5,000) number of test prompts (eg. Parti Prompts). + +CLIP score generation during model pretraining can help demonstrate that the model is approaching its objective, but once a `clip/mean` value around `.30` to `.39` is reached, the comparison seems to become less meaningful. Models that show an average CLIP score around `.33` can outperform a model with an average CLIP score of `.36` in human analysis. + +CLIP scores are not related to image quality. Within a single test run, some prompts will result in a very low CLIP score of around `0.14` (`clip/min` value in the tracker charts) even though their images align fairly well with the user prompt; conversely, CLIP scores as high as `0.39` (`clip/max` value in the tracker charts) may appear from images with questionable quality, as this test is not meant to capture this information. + +On its own, CLIP scores do not take long to calculate; however, the number of prompts required for meaningful evaluation can make it take an incredibly long time. + +Since it doesn't take much to run, it doesn't hurt to include CLIP evaluation in small training runs. Perhaps you will discover a pattern of the outputs where it makes sense to abandon a training run or adjust other hyperparameters such as the learning rate. + +To include a standard prompt library for evaluation, `--validation_prompt_library` can be provided and then we will generate a somewhat relative benchmark between training runs. + +In `config.json`: + +```json +{ + ... + "evaluation_type": "clip", + "pretrained_evaluation_model_name_or_path": "openai/clip-vit-large-patch14-336", + ... +} +``` diff --git a/documentation/quickstart/FLUX.md b/documentation/quickstart/FLUX.md index 41f4e64c..57e1f4ea 100644 --- a/documentation/quickstart/FLUX.md +++ b/documentation/quickstart/FLUX.md @@ -191,6 +191,10 @@ A set of diverse prompt will help determine whether the model is collapsing as i > ℹ️ Flux is a flow-matching model and shorter prompts that have strong similarities will result in practically the same image being produced by the model. Be sure to use longer, more descriptive prompts. +#### CLIP score tracking + +If you wish to enable evaluations to score the model's performance, see [this document](/documentation/evaluation/CLIP_SCORES.md) for information on configuring and interpreting CLIP scores. + #### Flux time schedule shifting Flow-matching models such as Flux and SD3 have a property called "shift" that allows us to shift the trained portion of the timestep schedule using a simple decimal value. diff --git a/documentation/quickstart/KOLORS.md b/documentation/quickstart/KOLORS.md index 86f50b4d..d052eacc 100644 --- a/documentation/quickstart/KOLORS.md +++ b/documentation/quickstart/KOLORS.md @@ -260,3 +260,7 @@ bash train.sh This will begin the text embed and VAE output caching to disk. For more information, see the [dataloader](/documentation/DATALOADER.md) and [tutorial](/TUTORIAL.md) documents. + +### CLIP score tracking + +If you wish to enable evaluations to score the model's performance, see [this document](/documentation/evaluation/CLIP_SCORES.md) for information on configuring and interpreting CLIP scores. diff --git a/documentation/quickstart/SD3.md b/documentation/quickstart/SD3.md index 16837520..5d646562 100644 --- a/documentation/quickstart/SD3.md +++ b/documentation/quickstart/SD3.md @@ -349,4 +349,8 @@ For more information on regularisation datasets, see [this section](/documentati ### Quantised training -See [this section](/documentation/DREAMBOOTH.md#quantised-model-training-loralycoris-only) of the Dreambooth guide for information on configuring quantisation for SD3 and other models. \ No newline at end of file +See [this section](/documentation/DREAMBOOTH.md#quantised-model-training-loralycoris-only) of the Dreambooth guide for information on configuring quantisation for SD3 and other models. + +### CLIP score tracking + +If you wish to enable evaluations to score the model's performance, see [this document](/documentation/evaluation/CLIP_SCORES.md) for information on configuring and interpreting CLIP scores. diff --git a/documentation/quickstart/SIGMA.md b/documentation/quickstart/SIGMA.md index f02e72b3..cf3389aa 100644 --- a/documentation/quickstart/SIGMA.md +++ b/documentation/quickstart/SIGMA.md @@ -216,3 +216,7 @@ bash train.sh This will begin the text embed and VAE output caching to disk. For more information, see the [dataloader](/documentation/DATALOADER.md) and [tutorial](/TUTORIAL.md) documents. + +### CLIP score tracking + +If you wish to enable evaluations to score the model's performance, see [this document](/documentation/evaluation/CLIP_SCORES.md) for information on configuring and interpreting CLIP scores. From 6a8b707f6827a5bb372e91b6ba0c646c749d1fac Mon Sep 17 00:00:00 2001 From: bghira Date: Wed, 13 Nov 2024 01:13:51 +0000 Subject: [PATCH 09/11] revise CLIP score doc text --- documentation/evaluation/CLIP_SCORES.md | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/documentation/evaluation/CLIP_SCORES.md b/documentation/evaluation/CLIP_SCORES.md index 4ce477ca..bb30d6e1 100644 --- a/documentation/evaluation/CLIP_SCORES.md +++ b/documentation/evaluation/CLIP_SCORES.md @@ -1,10 +1,12 @@ -## CLIP score tracking +# CLIP score tracking -The CLIP score of your model indicates how closely the features extracted from the image align with the features extracted from the prompt. It is currently a popular metric for determining general prompt adherence, though is typically evaluated across a very large (~5,000) number of test prompts (eg. Parti Prompts). +CLIP scores are loosely related to measurement of a model's ability to follow prompts; it is not at all related to image quality/fidelity. -CLIP score generation during model pretraining can help demonstrate that the model is approaching its objective, but once a `clip/mean` value around `.30` to `.39` is reached, the comparison seems to become less meaningful. Models that show an average CLIP score around `.33` can outperform a model with an average CLIP score of `.36` in human analysis. +The `clip/mean` score of your model indicates how closely the features extracted from the image align with the features extracted from the prompt. It is currently a popular metric for determining general prompt adherence, though is typically evaluated across a very large (~5,000) number of test prompts (eg. Parti Prompts). -CLIP scores are not related to image quality. Within a single test run, some prompts will result in a very low CLIP score of around `0.14` (`clip/min` value in the tracker charts) even though their images align fairly well with the user prompt; conversely, CLIP scores as high as `0.39` (`clip/max` value in the tracker charts) may appear from images with questionable quality, as this test is not meant to capture this information. +CLIP score generation during model pretraining can help demonstrate that the model is approaching its objective, but once a `clip/mean` value around `.30` to `.39` is reached, the comparison seems to become less meaningful. Models that show an average CLIP score around `.33` can outperform a model with an average CLIP score of `.36` in human analysis. However, a model with a very low average CLIP score around `0.18` to `0.22` will probably be pretty poorly-performing. + +Within a single test run, some prompts will result in a very low CLIP score of around `0.14` (`clip/min` value in the tracker charts) even though their images align fairly well with the user prompt and have high image quality; conversely, CLIP scores as high as `0.39` (`clip/max` value in the tracker charts) may appear from images with questionable quality, as this test is not meant to capture this information. This is why such a large number of prompts are typically used to measure model performance - _and even then_.. On its own, CLIP scores do not take long to calculate; however, the number of prompts required for meaningful evaluation can make it take an incredibly long time. @@ -19,6 +21,7 @@ In `config.json`: ... "evaluation_type": "clip", "pretrained_evaluation_model_name_or_path": "openai/clip-vit-large-patch14-336", + "report_to": "tensorboard", # or wandb ... } ``` From 7088595038e271698bfb9e541cbd5adfcd22ae80 Mon Sep 17 00:00:00 2001 From: bghira Date: Wed, 13 Nov 2024 16:38:02 +0000 Subject: [PATCH 10/11] add documentation updates --- OPTIONS.md | 10 +++++----- documentation/quickstart/FLUX.md | 1 + helpers/configuration/cmd_args.py | 4 ++-- helpers/training/evaluation.py | 6 +++--- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/OPTIONS.md b/OPTIONS.md index 348c16e1..edeea098 100644 --- a/OPTIONS.md +++ b/OPTIONS.md @@ -203,7 +203,7 @@ A lot of settings are instead set through the [dataloader config](/documentation - **What**: Output image resolution, measured in pixels, or, formatted as: `widthxheight`, as in `1024x1024`. Multiple resolutions can be defined, separated by commas. - **Why**: All images generated during validation will be this resolution. Useful if the model is being trained with a different resolution. -### `--validation_model_evaluator` +### `--evaluation_type` - **What**: Enable CLIP evaluation of generated images during validations. - **Why**: CLIP scores calculate the distance of the generated image features to the provided validation prompt. This can give an idea of whether prompt adherence is improving, though it requires a large number of validation prompts to have any meaningful value. @@ -472,8 +472,8 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr] [--model_card_note MODEL_CARD_NOTE] [--model_card_safe_for_work] [--logging_dir LOGGING_DIR] [--benchmark_base_model] [--disable_benchmark] - [--validation_model_evaluator {clip,none}] - [--pretrained_validation_model_name_or_path PRETRAINED_VALIDATION_MODEL_NAME_OR_PATH] + [--evaluation_type {clip,none}] + [--pretrained_evaluation_model_name_or_path pretrained_evaluation_model_name_or_path] [--validation_on_startup] [--validation_seed_source {gpu,cpu}] [--validation_torch_compile] [--validation_torch_compile_mode {max-autotune,reduce-overhead,default}] @@ -1243,12 +1243,12 @@ options: --disable_benchmark By default, the model will be benchmarked on the first batch of the first epoch. This can be disabled with this option. - --validation_model_evaluator {clip,none} + --evaluation_type {clip,none} Validations must be enabled for model evaluation to function. The default is to use no evaluator, and 'clip' will use a CLIP model to evaluate the resulting model's performance during validations. - --pretrained_validation_model_name_or_path PRETRAINED_VALIDATION_MODEL_NAME_OR_PATH + --pretrained_evaluation_model_name_or_path pretrained_evaluation_model_name_or_path Optionally provide a custom model to use for ViT evaluations. The default is currently clip-vit-large- patch14-336, allowing for lower patch sizes (greater diff --git a/documentation/quickstart/FLUX.md b/documentation/quickstart/FLUX.md index 57e1f4ea..ddb86e1d 100644 --- a/documentation/quickstart/FLUX.md +++ b/documentation/quickstart/FLUX.md @@ -413,6 +413,7 @@ Currently, the lowest VRAM utilisation (9090M) can be attained with: - Batch size: 1, zero gradient accumulation steps - DeepSpeed: disabled / unconfigured - PyTorch: 2.6 Nightly (Sept 29th build) +- Using `--quantize_via=cpu` to avoid outOfMemory error during startup on <=16G cards. Speed was approximately 1.4 iterations per second on a 4090. diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py index 2baf3b02..745b61a1 100644 --- a/helpers/configuration/cmd_args.py +++ b/helpers/configuration/cmd_args.py @@ -1331,7 +1331,7 @@ def get_argument_parser(): ), ) parser.add_argument( - "--validation_model_evaluator", + "--evaluation_type", type=str, default=None, choices=["clip", "none"], @@ -1341,7 +1341,7 @@ def get_argument_parser(): ) ) parser.add_argument( - "--pretrained_validation_model_name_or_path", + "--pretrained_evaluation_model_name_or_path", type=str, default="openai/clip-vit-large-patch14-336", help=( diff --git a/helpers/training/evaluation.py b/helpers/training/evaluation.py index 5b91b008..351fd94c 100644 --- a/helpers/training/evaluation.py +++ b/helpers/training/evaluation.py @@ -25,9 +25,9 @@ def from_config(args): """Instantiate a ModelEvaluator from the training config, if set to do so.""" if not StateTracker.get_accelerator().is_main_process: return None - if args.validation_model_evaluator is not None and args.validation_model_evaluator.lower() != "" and args.validation_model_evaluator.lower() != "none": - model_evaluator = model_evaluator_map[args.validation_model_evaluator] - return globals()[model_evaluator](args.pretrained_validation_model_name_or_path) + if args.evaluation_type is not None and args.evaluation_type.lower() != "" and args.evaluation_type.lower() != "none": + model_evaluator = model_evaluator_map[args.evaluation_type] + return globals()[model_evaluator](args.pretrained_evaluation_model_name_or_path) return None From ca023b59ec84cfdb54f704726164e12236def7d6 Mon Sep 17 00:00:00 2001 From: bghira Date: Wed, 13 Nov 2024 19:04:36 +0000 Subject: [PATCH 11/11] metadata: add more ddpm related schedule info to the model card --- helpers/publishing/metadata.py | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/helpers/publishing/metadata.py b/helpers/publishing/metadata.py index b3cd0568..4aeccbc5 100644 --- a/helpers/publishing/metadata.py +++ b/helpers/publishing/metadata.py @@ -362,12 +362,39 @@ def sd3_schedule_info(args): return output_str +def ddpm_schedule_info(args): + """Information about DDPM schedules, eg. rescaled betas or offset noise""" + output_args = [] + if args.snr_gamma: + output_args.append(f"snr_gamma={args.snr_gamma}") + if args.use_soft_min_snr: + output_args.append(f"use_soft_min_snr") + if args.soft_min_snr_sigma_data: + output_args.append(f"soft_min_snr_sigma_data={args.soft_min_snr_sigma_data}") + if args.rescale_betas_zero_snr: + output_args.append(f"rescale_betas_zero_snr") + if args.offset_noise: + output_args.append(f"offset_noise") + output_args.append(f"noise_offset={args.noise_offset}") + output_args.append(f"noise_offset_probability={args.noise_offset_probability}") + output_args.append(f"training_scheduler_timestep_spacing={args.training_scheduler_timestep_spacing}") + output_args.append(f"validation_scheduler_timestep_spacing={args.validation_scheduler_timestep_spacing}") + output_str = ( + f" (extra parameters={output_args})" + if output_args + else " (no special parameters set)" + ) + + return output_str def model_schedule_info(args): if args.model_family == "flux": return flux_schedule_info(args) if args.model_family == "sd3": return sd3_schedule_info(args) + else: + return ddpm_schedule_info(args) + def save_model_card( @@ -495,10 +522,9 @@ def save_model_card( - Gradient accumulation steps: {StateTracker.get_args().gradient_accumulation_steps} - Number of GPUs: {StateTracker.get_accelerator().num_processes} - Prediction type: {'flow-matching' if (StateTracker.get_args().model_family in ["sd3", "flux"]) else StateTracker.get_args().prediction_type}{model_schedule_info(args=StateTracker.get_args())} -- Rescaled betas zero SNR: {StateTracker.get_args().rescale_betas_zero_snr} - Optimizer: {StateTracker.get_args().optimizer}{optimizer_config if optimizer_config is not None else ''} -- Precision: {'Pure BF16' if torch.backends.mps.is_available() or StateTracker.get_args().mixed_precision == "bf16" else 'FP32'} -- Quantised: {f'Yes: {StateTracker.get_args().base_model_precision}' if StateTracker.get_args().base_model_precision != "no_change" else 'No'} +- Trainable parameter precision: {'Pure BF16' if torch.backends.mps.is_available() or StateTracker.get_args().mixed_precision == "bf16" else 'FP32'} +- Quantised base model: {f'Yes ({StateTracker.get_args().base_model_precision})' if StateTracker.get_args().base_model_precision != "no_change" else 'No'} - Xformers: {'Enabled' if StateTracker.get_args().enable_xformers_memory_efficient_attention else 'Not used'} {lora_info(args=StateTracker.get_args())}