Merge branch 'main' into feature/omnigen
bghira authored Nov 13, 2024
2 parents 250b5ab + 23ec487 commit 01beb97
Showing 14 changed files with 328 additions and 51 deletions.
129 changes: 94 additions & 35 deletions OPTIONS.md

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions documentation/DREAMBOOTH.md
@@ -222,6 +222,10 @@ Alternatively, one might use the real name of their subject, or a 'similar enoug

After a number of training experiments, it seems as though a 'similar enough' celebrity is the best choice, especially if prompting the model for the person's real name ends up looking dissimilar.

# CLIP score tracking

If you wish to enable evaluations to score the model's performance, see [this document](/documentation/evaluation/CLIP_SCORES.md) for information on configuring and interpreting CLIP scores.

# Refiner tuning

If you're a fan of the SDXL refiner, you may find that it causes your generations to "ruin" the results of your Dreamboothed model.
4 changes: 4 additions & 0 deletions documentation/MIXTURE_OF_EXPERTS.md
@@ -105,6 +105,10 @@ If you'd like a demonstration dataset, [pseudo-camera-10k](https://huggingface.c

Stage two refiner training will automatically select images from each of your training sets, and use those as inputs for partial denoising at validation time.

## CLIP score tracking

If you wish to enable evaluations to score the model's performance, see [this document](/documentation/evaluation/CLIP_SCORES.md) for information on configuring and interpreting CLIP scores.

## Putting it all together at inference time

If you'd like to plug both of the models together to experiment with in a simple script, this will get you started:
27 changes: 27 additions & 0 deletions documentation/evaluation/CLIP_SCORES.md
@@ -0,0 +1,27 @@
# CLIP score tracking

CLIP scores loosely measure a model's ability to follow prompts; they are not related to image quality or fidelity at all.

The `clip/mean` score of your model indicates how closely the features extracted from the image align with the features extracted from the prompt. It is currently a popular metric for judging general prompt adherence, though it is typically evaluated across a very large number of test prompts (around 5,000, e.g. Parti Prompts).

CLIP score generation during model pretraining can help demonstrate that the model is approaching its objective, but once a `clip/mean` value of around `.30` to `.39` is reached, the comparison seems to become less meaningful. A model with an average CLIP score of around `.33` can outperform one with an average of `.36` in human analysis. However, a model with a very low average CLIP score, around `0.18` to `0.22`, will probably perform quite poorly.

Within a single test run, some prompts will receive a very low CLIP score of around `0.14` (the `clip/min` value in the tracker charts) even though their images align fairly well with the user prompt and have high image quality; conversely, CLIP scores as high as `0.39` (the `clip/max` value in the tracker charts) may come from images of questionable quality, as this test is not meant to capture that information. This is why such a large number of prompts is typically used to measure model performance - _and even then_, the results remain noisy.
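
For example, if a validation run produced per-prompt scores of `0.14`, `0.31`, `0.36`, and `0.39` (hypothetical values), the tracker would chart `clip/min` as `0.14`, `clip/max` as `0.39`, and `clip/mean` as `0.30` - and only the mean, aggregated over many prompts, is worth comparing between checkpoints.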

On their own, CLIP scores do not take long to calculate; however, the number of prompts required for a meaningful evaluation can make a full run take an incredibly long time.

Since it doesn't take much to run, it doesn't hurt to include CLIP evaluation in small training runs. You may discover a pattern in the outputs that makes it sensible to abandon a training run or adjust hyperparameters such as the learning rate.

To include a standard prompt library for evaluation, `--validation_prompt_library` can be provided, which produces a roughly comparable benchmark between training runs.

In `config.json`:

```json
{
...
"evaluation_type": "clip",
"pretrained_evaluation_model_name_or_path": "openai/clip-vit-large-patch14-336",
"report_to": "tensorboard", # or wandb
...
}
```
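
Under the hood, the evaluator computes the score with `torchmetrics`' functional `clip_score`, mirroring the `CLIPModelEvaluator` in `helpers/training/evaluation.py`. A minimal standalone sketch (the image and prompt lists are placeholders you supply yourself) looks roughly like this:

```python
from functools import partial

import torch
from torchmetrics.functional.multimodal import clip_score
from torchvision import transforms

# Same checkpoint as the default --pretrained_evaluation_model_name_or_path.
clip_score_fn = partial(clip_score, model_name_or_path="openai/clip-vit-large-patch14-336")
to_tensor = transforms.ToTensor()


def clip_mean_score(images, prompts):
    # clip_score expects pixel values in the 0-255 range, so the [0, 1]
    # tensors produced by ToTensor() are scaled back up before stacking.
    images_tensor = torch.stack([to_tensor(img) * 255 for img in images])
    # Returns the mean CLIP score across the batch as a scalar tensor.
    return clip_score_fn(images_tensor, prompts).detach().cpu()
```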
5 changes: 5 additions & 0 deletions documentation/quickstart/FLUX.md
@@ -191,6 +191,10 @@ A set of diverse prompt will help determine whether the model is collapsing as i

> ℹ️ Flux is a flow-matching model and shorter prompts that have strong similarities will result in practically the same image being produced by the model. Be sure to use longer, more descriptive prompts.
#### CLIP score tracking

If you wish to enable evaluations to score the model's performance, see [this document](/documentation/evaluation/CLIP_SCORES.md) for information on configuring and interpreting CLIP scores.

#### Flux time schedule shifting

Flow-matching models such as Flux and SD3 have a property called "shift" that allows us to shift the trained portion of the timestep schedule using a simple decimal value.
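
As a rough illustration (an assumed formulation matching the common static shift used by flow-matching schedulers such as diffusers' `FlowMatchEulerDiscreteScheduler`; SimpleTuner's exact code path may differ), a shift value rescales each sigma like this:

```python
import torch


def shift_sigmas(sigmas: torch.Tensor, shift: float) -> torch.Tensor:
    # Larger shift values push more of the trained schedule toward the
    # noisier (high-sigma) end of the timestep range.
    return shift * sigmas / (1 + (shift - 1) * sigmas)


# With shift=3.0, sigma=0.5 becomes 0.75: the middle of the schedule is
# treated as if it were much noisier.
print(shift_sigmas(torch.tensor([0.25, 0.50, 0.75]), shift=3.0))  # tensor([0.5000, 0.7500, 0.9000])
```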
@@ -409,6 +413,7 @@ Currently, the lowest VRAM utilisation (9090M) can be attained with:
- Batch size: 1, zero gradient accumulation steps
- DeepSpeed: disabled / unconfigured
- PyTorch: 2.6 Nightly (Sept 29th build)
- Using `--quantize_via=cpu` to avoid out-of-memory errors during startup on <=16G cards.

Speed was approximately 1.4 iterations per second on a 4090.

4 changes: 4 additions & 0 deletions documentation/quickstart/KOLORS.md
@@ -260,3 +260,7 @@ bash train.sh
This will begin the text embed and VAE output caching to disk.

For more information, see the [dataloader](/documentation/DATALOADER.md) and [tutorial](/TUTORIAL.md) documents.

### CLIP score tracking

If you wish to enable evaluations to score the model's performance, see [this document](/documentation/evaluation/CLIP_SCORES.md) for information on configuring and interpreting CLIP scores.
6 changes: 5 additions & 1 deletion documentation/quickstart/SD3.md
@@ -349,4 +349,8 @@ For more information on regularisation datasets, see [this section](/documentati

### Quantised training

See [this section](/documentation/DREAMBOOTH.md#quantised-model-training-loralycoris-only) of the Dreambooth guide for information on configuring quantisation for SD3 and other models.

### CLIP score tracking

If you wish to enable evaluations to score the model's performance, see [this document](/documentation/evaluation/CLIP_SCORES.md) for information on configuring and interpreting CLIP scores.
4 changes: 4 additions & 0 deletions documentation/quickstart/SIGMA.md
@@ -216,3 +216,7 @@ bash train.sh
This will begin the text embed and VAE output caching to disk.

For more information, see the [dataloader](/documentation/DATALOADER.md) and [tutorial](/TUTORIAL.md) documents.

### CLIP score tracking

If you wish to enable evaluations to score the model's performance, see [this document](/documentation/evaluation/CLIP_SCORES.md) for information on configuring and interpreting CLIP scores.
29 changes: 29 additions & 0 deletions helpers/configuration/cmd_args.py
@@ -601,6 +601,15 @@ def get_argument_parser():
" but if you are at that point of contention, it's possible that your GPU has too little RAM. Default: 4."
),
)
parser.add_argument(
"--vae_enable_tiling",
action="store_true",
default=False,
help=(
"If set, will enable tiling for VAE caching. This is useful for very large images when VRAM is limited."
" This may be required for 2048px VAE caching on 24G accelerators, in addition to reducing --vae_batch_size."
),
)
parser.add_argument(
"--vae_cache_scan_behaviour",
type=str,
@@ -1330,6 +1339,26 @@ def get_argument_parser():
" This can be disabled with this option."
),
)
parser.add_argument(
"--evaluation_type",
type=str,
default=None,
choices=["clip", "none"],
help=(
"Validations must be enabled for model evaluation to function. The default is to use no evaluator,"
" and 'clip' will use a CLIP model to evaluate the resulting model's performance during validations."
)
)
parser.add_argument(
"--pretrained_evaluation_model_name_or_path",
type=str,
default="openai/clip-vit-large-patch14-336",
help=(
"Optionally provide a custom model to use for ViT evaluations."
" The default is currently clip-vit-large-patch14-336, allowing for lower patch sizes (greater accuracy)"
" and an input resolution of 336x336."
)
)
parser.add_argument(
"--validation_on_startup",
action="store_true",
2 changes: 2 additions & 0 deletions helpers/data_backend/aws.py
@@ -106,6 +106,8 @@ def exists(self, s3_key):
except (NoCredentialsError, PartialCredentialsError) as e:
raise e # Raise credential errors to the caller
except Exception as e:
if "An error occurred (404) when calling the HeadObject operation: Not Found" in str(e):
return False
logger.error(f'Error checking existence of S3 key "{s3_key}": {e}')
if i == self.read_retry_limit - 1:
# We have reached our maximum retry count.
32 changes: 29 additions & 3 deletions helpers/publishing/metadata.py
@@ -362,12 +362,39 @@ def sd3_schedule_info(args):

return output_str

def ddpm_schedule_info(args):
"""Information about DDPM schedules, eg. rescaled betas or offset noise"""
output_args = []
if args.snr_gamma:
output_args.append(f"snr_gamma={args.snr_gamma}")
if args.use_soft_min_snr:
output_args.append(f"use_soft_min_snr")
if args.soft_min_snr_sigma_data:
output_args.append(f"soft_min_snr_sigma_data={args.soft_min_snr_sigma_data}")
if args.rescale_betas_zero_snr:
output_args.append(f"rescale_betas_zero_snr")
if args.offset_noise:
output_args.append(f"offset_noise")
output_args.append(f"noise_offset={args.noise_offset}")
output_args.append(f"noise_offset_probability={args.noise_offset_probability}")
output_args.append(f"training_scheduler_timestep_spacing={args.training_scheduler_timestep_spacing}")
output_args.append(f"validation_scheduler_timestep_spacing={args.validation_scheduler_timestep_spacing}")
output_str = (
f" (extra parameters={output_args})"
if output_args
else " (no special parameters set)"
)

return output_str

def model_schedule_info(args):
if args.model_family == "flux":
return flux_schedule_info(args)
if args.model_family == "sd3":
return sd3_schedule_info(args)
else:
return ddpm_schedule_info(args)



def save_model_card(
@@ -495,10 +522,9 @@ def save_model_card(
- Gradient accumulation steps: {StateTracker.get_args().gradient_accumulation_steps}
- Number of GPUs: {StateTracker.get_accelerator().num_processes}
- Prediction type: {'flow-matching' if (StateTracker.get_args().model_family in ["sd3", "flux"]) else StateTracker.get_args().prediction_type}{model_schedule_info(args=StateTracker.get_args())}
- Rescaled betas zero SNR: {StateTracker.get_args().rescale_betas_zero_snr}
- Optimizer: {StateTracker.get_args().optimizer}{optimizer_config if optimizer_config is not None else ''}
- Precision: {'Pure BF16' if torch.backends.mps.is_available() or StateTracker.get_args().mixed_precision == "bf16" else 'FP32'}
- Quantised: {f'Yes: {StateTracker.get_args().base_model_precision}' if StateTracker.get_args().base_model_precision != "no_change" else 'No'}
- Trainable parameter precision: {'Pure BF16' if torch.backends.mps.is_available() or StateTracker.get_args().mixed_precision == "bf16" else 'FP32'}
- Quantised base model: {f'Yes ({StateTracker.get_args().base_model_precision})' if StateTracker.get_args().base_model_precision != "no_change" else 'No'}
- Xformers: {'Enabled' if StateTracker.get_args().enable_xformers_memory_efficient_attention else 'Not used'}
{lora_info(args=StateTracker.get_args())}
48 changes: 48 additions & 0 deletions helpers/training/evaluation.py
@@ -0,0 +1,48 @@
from functools import partial
from torchmetrics.functional.multimodal import clip_score
from torchvision import transforms
import torch, logging, os
import numpy as np
from PIL import Image
from helpers.training.state_tracker import StateTracker

logger = logging.getLogger("ModelEvaluator")
logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))

model_evaluator_map = {
"clip": "CLIPModelEvaluator",
}

class ModelEvaluator:
def __init__(self, pretrained_model_name_or_path):
raise NotImplementedError("Subclasses is incomplete, no __init__ method was found.")

def evaluate(self, images, prompts):
raise NotImplementedError("Subclasses should implement the evaluate() method.")

@staticmethod
def from_config(args):
"""Instantiate a ModelEvaluator from the training config, if set to do so."""
if not StateTracker.get_accelerator().is_main_process:
return None
if args.evaluation_type is not None and args.evaluation_type.lower() != "" and args.evaluation_type.lower() != "none":
model_evaluator = model_evaluator_map[args.evaluation_type]
return globals()[model_evaluator](args.pretrained_evaluation_model_name_or_path)

return None


class CLIPModelEvaluator(ModelEvaluator):
def __init__(self, pretrained_model_name_or_path='openai/clip-vit-large-patch14-336'):
self.clip_score_fn = partial(clip_score, model_name_or_path=pretrained_model_name_or_path)
self.preprocess = transforms.Compose([
transforms.ToTensor()
])

def evaluate(self, images, prompts):
# Preprocess images
images_tensor = torch.stack([self.preprocess(img) * 255 for img in images])
# Compute CLIP scores
result = self.clip_score_fn(images_tensor, prompts).detach().cpu()

return result
13 changes: 13 additions & 0 deletions helpers/training/trainer.py
@@ -21,6 +21,7 @@
from helpers.caching.memory import reclaim_memory
from helpers.training.multi_process import _get_rank as get_rank
from helpers.training.validation import Validation, prepare_validation_prompt_list
from helpers.training.evaluation import ModelEvaluator
from helpers.training.state_tracker import StateTracker
from helpers.training.schedulers import load_scheduler_from_args
from helpers.training.custom_schedule import get_lr_scheduler
@@ -468,6 +469,9 @@ def init_vae(self, move_to_accelerator: bool = True):
)
self.config.vae_kwargs["subfolder"] = None
self.vae = AutoencoderKL.from_pretrained(**self.config.vae_kwargs)
if self.vae is not None and self.config.vae_enable_tiling and hasattr(self.vae, 'enable_tiling'):
logger.warning("Enabling VAE tiling for greatly reduced memory consumption due to --vae_enable_tiling which may result in VAE tiling artifacts in encoded latents.")
self.vae.enable_tiling()
if not move_to_accelerator:
logger.debug("Not moving VAE to accelerator.")
return
@@ -1379,6 +1383,7 @@ def init_validations(self):
):
logger.error("Cannot run validations with DeepSpeed ZeRO stage 3.")
return
model_evaluator = ModelEvaluator.from_config(args=self.config)
self.validation = Validation(
accelerator=self.accelerator,
unet=self.unet,
@@ -1400,6 +1405,7 @@
ema_model=self.ema_model,
vae=self.vae,
controlnet=self.controlnet if self.config.controlnet else None,
model_evaluator=model_evaluator
)
if not self.config.train_text_encoder and self.validation is not None:
self.validation.clear_text_encoders()
@@ -2673,6 +2679,13 @@ def train(self):
self.guidance_values_list = []
if grad_norm is not None:
wandb_logs["grad_norm"] = grad_norm
if self.validation is not None and hasattr(self.validation, 'evaluation_result'):
eval_result = self.validation.get_eval_result()
if eval_result is not None and type(eval_result) == dict:
# add the dict to wandb_logs
self.validation.clear_eval_result()
wandb_logs.update(eval_result)

progress_bar.update(1)
self.state["global_step"] += 1
current_epoch_step += 1