This repository is the official implementation of the paper Flash Diffusion: Accelerating Any Conditional Diffusion Model for Few Steps Image Generation.
Images generated using 4 NFEs
In this paper, we propose an efficient, fast, versatile and LoRA-compatible distillation method to accelerate the generation of pre-trained diffusion models: Flash Diffusion. The method reaches state-of-the-art performance in terms of FID and CLIP-Score for few-step image generation on the COCO 2014 and COCO 2017 datasets, while requiring only several GPU hours of training and fewer trainable parameters than existing methods. Beyond its efficiency, the versatility of the method is demonstrated across several tasks, such as text-to-image, inpainting, face-swapping and super-resolution, and across different diffusion model backbones, using either UNet-based denoisers (SD1.5, SDXL) or DiT (Pixart-α), as well as adapters. In all cases, the method drastically reduces the number of sampling steps while maintaining very high-quality image generation.
- Method overview
- Results overview
- Installation 🛠️
- Text2Image model distillation
- Distilling a custom model 🚀
- Inference with 🤗 Hugging Face pipelines
- Using Flash with ComfyUI
- Flash for training-free LoRAs acceleration 🎨
- Citing this repository
Our method aims to create a fast, reliable, and adaptable approach for various uses. We propose to train a student model to predict in a single step a denoised multiple-step teacher prediction of a corrupted input sample. Additionally, we sample timesteps from an adaptable distribution that shifts during training to help the student model target specific timesteps.
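As a rough illustration of a timestep distribution that shifts during training, the sketch below draws timesteps from a Gaussian mixture whose weights are interpolated as training progresses. Everything here is an assumption made for illustration: the anchor timesteps, the start and end weights, the linear interpolation, the standard deviation and the function names are hypothetical and are not taken from the paper or from this repository.

```python
import random


def mixture_weights(progress, start, end):
    """Linearly interpolate mixture weights from `start` to `end`, then normalize."""
    raw = [(1.0 - progress) * s + progress * e for s, e in zip(start, end)]
    total = sum(raw)
    return [w / total for w in raw]


def sample_timestep(progress, anchors, start, end, std=0.05, rng=random):
    """Draw one timestep in [0, 1] from a Gaussian mixture centred on `anchors`."""
    weights = mixture_weights(progress, start, end)
    centre = rng.choices(anchors, weights=weights, k=1)[0]
    # Clamp so the sampled timestep stays in the valid range.
    return min(1.0, max(0.0, rng.gauss(centre, std)))


# Hypothetical setup: 4 anchor timesteps, matching a 4-step sampler.
anchors = [0.25, 0.5, 0.75, 1.0]
start_weights = [0.0, 0.0, 0.5, 0.5]    # early training: favour noisier timesteps
end_weights = [0.25, 0.25, 0.25, 0.25]  # late training: cover all anchors equally

for progress in (0.0, 0.5, 1.0):
    print(progress, mixture_weights(progress, start_weights, end_weights))
```

In a training loop, `progress` would be the fraction of training completed, and `sample_timestep` would be called once per sample to pick the noise level at which the student is trained.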
Flash Diffusion is compatible with various backbones such as
- Flash Stable Diffusion 3, distilled from a Stable Diffusion 3 teacher
- Flash SDXL, distilled from a SDXL teacher
- Flash Pixart (DiT), distilled from a Pixart-α teacher
- Flash SD, distilled from a SD1.5 teacher
It can also be used to accelerate existing LoRAs in a training-free manner. See the section on training-free LoRA acceleration to learn more.
Images generated using 4 NFEs
To get up and running, first create a virtual environment with Python 3.10 or later and activate it, using either venv or conda:
# With venv
python3.10 -m venv envs/flash_diffusion
source envs/flash_diffusion/bin/activate
# With conda
conda create -n flash_diffusion python=3.10
conda activate flash_diffusion
Then install the required dependencies (if on GPU) and the repo in editable mode
pip install --upgrade pip
pip install -r requirements.txt
pip install -e .
The main scripts to reproduce the main experiments of the paper are located in the examples folder. We provide 5 different scripts:
- train_flash_sd3.py: Distills SD3 model
- train_flash_sdxl.py: Distills SDXL model
- train_flash_pixart.py: Distills Pixart-α model
- train_flash_canny_adapter.py: Distills a T2I Canny Adapter
- train_flash_sd.py: Distills SD1.5 model
In examples/configs, you will find the configuration yaml associated with each script. The only thing you need to do is amend the SHARDS_PATH_OR_URLS section of the yaml so the model is trained on your own data. Please note that this package uses webdataset to handle the datastream, so the urls you use should be formatted according to the webdataset format. In particular, for these examples, each sample needs to be composed of a jpg file containing the image and a json file containing the caption under the key caption and the image aesthetics score under the key aesthetic_score:
sample = {
    "jpg": dummy_image,
    "json": {
        "caption": "dummy caption",
        "aesthetic_score": 6.0
    }
}
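As a minimal sketch of how such samples could be packed into a shard, the snippet below writes one sample into a tar archive using only the Python standard library. It relies on the webdataset convention that files sharing a basename inside a tar archive form one sample. The helper name `write_shard`, the shard file name, the sample key and the placeholder image bytes are all hypothetical; in practice the jpg entry would hold a real JPEG-encoded image.

```python
import io
import json
import os
import tarfile
import tempfile


def write_shard(path, samples):
    """Write (key, jpg_bytes, metadata) tuples into a webdataset-style tar shard."""
    with tarfile.open(path, "w") as tar:
        for key, jpg_bytes, metadata in samples:
            # Files sharing the same basename are grouped into one sample.
            members = {
                f"{key}.jpg": jpg_bytes,
                f"{key}.json": json.dumps(metadata).encode("utf-8"),
            }
            for name, payload in members.items():
                info = tarfile.TarInfo(name=name)
                info.size = len(payload)
                tar.addfile(info, io.BytesIO(payload))


# Placeholder bytes standing in for a real JPEG-encoded image.
dummy_jpg = b"\xff\xd8\xff\xd9"
shard_path = os.path.join(tempfile.mkdtemp(), "shard-000000.tar")
write_shard(
    shard_path,
    [("000000", dummy_jpg, {"caption": "dummy caption", "aesthetic_score": 6.0})],
)
print("wrote", shard_path)
```

The path (or URL) of the resulting shard is what would then go in the SHARDS_PATH_OR_URLS section of the configuration yaml.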
The scripts can then be launched by simply running
# Set the number of gpus & nodes you want to use
export SLURM_NPROCS=1
export SLURM_NNODES=1
# Distills SD1.5
python3.10 examples/train_flash_sd.py
# Distills SDXL1.0
python3.10 examples/train_flash_sdxl.py
# Distills Pixart-α
python3.10 examples/train_flash_pixart.py
# Distills T2I Canny adapter
python3.10 examples/train_flash_canny_adapter.py
This package is also intended to support custom model distillation.
from copy import deepcopy
from flash.models.unets import DiffusersUNet2DCondWrapper
from flash.models.vae import AutoencoderKLDiffusers, AutoencoderKLDiffusersConfig
# TorchNNEmbedder and TorchNNEmbedderConfig are used below; they are assumed
# here to live in the same module as the other embedders.
from flash.models.embedders import (
    ClipEmbedder,
    ClipEmbedderConfig,
    ClipEmbedderWithProjection,
    ConditionerWrapper,
    TorchNNEmbedder,
    TorchNNEmbedderConfig,
)
# Create the VAE
vae_config = AutoencoderKLDiffusersConfig(
    "stabilityai/sdxl-vae"  # VAE from HF Hub
)
vae = AutoencoderKLDiffusers(config=vae_config)
## Create the Conditioners ##
# A Clip conditioner returning 2 types of conditioning
embedder_1_config = ClipEmbedderConfig(
    version="stabilityai/stable-diffusion-xl-base-1.0",  # from HF Hub
    text_embedder_subfolder="text_encoder_2",
    tokenizer_subfolder="tokenizer_2",
    input_key="text",
    always_return_pooled=True,  # Return a 1-dimensional tensor
)
embedder_1 = ClipEmbedder(config=embedder_1_config)
# Embedder acting on a low-resolution image injected into the UNet via concatenation
embedder_2_config = TorchNNEmbedderConfig(
    nn_modules=["torch.nn.Conv2d"],
    nn_modules_kwargs=[
        dict(
            in_channels=3,
            out_channels=6,
            kernel_size=3,
            padding=1,
            stride=2,
        ),
    ],
    input_key="downsampled_image",
    # Illustrative value: the original snippet read this from a pytest
    # fixture (request.param), which is undefined outside a test.
    unconditional_conditioning_rate=0.0,
)
embedder_2 = TorchNNEmbedder(config=embedder_2_config)
conditioner_wrapper = ConditionerWrapper(
    conditioners=[embedder_1, embedder_2]
)
# Create the teacher denoiser
teacher_denoiser = DiffusersUNet2DCondWrapper(
    in_channels=4 + 6,  # VAE channels + concat conditioning
    out_channels=4,  # VAE channels
    cross_attention_dim=1280,  # cross-attention conditioning
    projection_class_embeddings_input_dim=1280,  # add conditioning
    class_embed_type="projection",
)
# Load the teacher weights
...
# Create the student denoiser as a copy of the teacher
student_denoiser = deepcopy(teacher_denoiser)
The distilled models can also be used for inference with Hugging Face pipelines. For example, the Flash Pixart LoRA can be loaded on top of the Pixart-α teacher as follows:
import torch
from diffusers import PixArtAlphaPipeline, Transformer2DModel, LCMScheduler
from peft import PeftModel
# Load the teacher transformer, then apply the Flash LoRA on top of it
transformer = Transformer2DModel.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    subfolder="transformer",
    torch_dtype=torch.float16
)
transformer = PeftModel.from_pretrained(
    transformer,
    "jasperai/flash-pixart"
)
# Pipeline
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    transformer=transformer,
    torch_dtype=torch.float16
)
# Scheduler
pipe.scheduler = LCMScheduler.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    subfolder="scheduler",
    timestep_spacing="trailing",
)
pipe.to("cuda")
prompt = "A raccoon reading a book in a lush forest."
image = pipe(prompt, num_inference_steps=4, guidance_scale=0).images[0]
To use FlashSDXL locally with ComfyUI, you need to:
- Make sure your ComfyUI install is up to date
- Download the checkpoint from Hugging Face. In case you wonder how, go to "Files and versions", open the comfy folder and hit the download button next to FlashSDXL.safetensors
- Move the new checkpoint file to your local comfyUI/models/loras/ folder
- Use it as a LoRA on top of sd_xl_base_1.0_0.9vae.safetensors. A simple ComfyUI workflow.json is provided in this repo in examples/comfy
Disclaimer: The model has been trained to work with a cfg scale of 1 and an LCM scheduler, but these parameters can be tweaked a bit.
Flash Diffusion models can also be combined with existing LoRAs to unlock few-step generation in a training-free manner. They can be integrated directly into Hugging Face pipelines. See the example below.
from diffusers import DiffusionPipeline, LCMScheduler
import torch
user_lora_id = "TheLastBen/Papercut_SDXL"
trigger_word = "papercut"
flash_lora_id = "jasperai/flash-sdxl"
# Load Pipeline
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    variant="fp16"
)
# Set scheduler
pipe.scheduler = LCMScheduler.from_config(
    pipe.scheduler.config
)
# Load LoRAs
pipe.load_lora_weights(flash_lora_id, adapter_name="flash")
pipe.load_lora_weights(user_lora_id, adapter_name="lora")
pipe.set_adapters(["flash", "lora"], adapter_weights=[1.0, 1.0])
pipe.to(device="cuda", dtype=torch.float16)
prompt = f"{trigger_word} a cute corgi"
image = pipe(
    prompt,
    num_inference_steps=4,
    guidance_scale=0
).images[0]
This code is released under the Creative Commons BY-NC 4.0 license.
If you find this work useful or use it in your research, please consider citing us
@misc{chadebec2024flash,
title={Flash Diffusion: Accelerating Any Conditional Diffusion Model for Few Steps Image Generation},
author={Clement Chadebec and Onur Tasar and Eyal Benaroche and Benjamin Aubin},
year={2024},
eprint={2406.02347},
archivePrefix={arXiv},
primaryClass={cs.CV}
}