add Sana-LoRA training and guidance (#98)
* code update;

Signed-off-by: lawrence-cj <[email protected]>

* add sana-lora training files and update README.md;

Signed-off-by: lawrence-cj <[email protected]>

* fix the bug for dreambooth-sana-lora training;
update sana-lora README.md

Signed-off-by: lawrence-cj <[email protected]>

---------

Signed-off-by: lawrence-cj <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
lawrence-cj and sayakpaul authored Dec 18, 2024
1 parent 2799bdc commit cbd7a63
Showing 6 changed files with 1,598 additions and 15 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -16,6 +16,9 @@ ldm_ae*
data/*
*.pth
.gradio/
+*.bin
+*.safetensors
+*.pkl

# Byte-compiled / optimized / DLL files
__pycache__/
2 changes: 1 addition & 1 deletion README.md
@@ -36,7 +36,7 @@ As a result, Sana-0.6B is very competitive with modern giant diffusion model (e.

## 🔥🔥 News

-- (🔥 New) \[2024/12/18\] `diffusers` supports Sana-LoRA fine-tuning! Sana-LoRA's training and convergence speed is super fast. [\[Guidance\]](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sana.md). Thanks to [@paul](https://github.com/sayakpaul).
+- (🔥 New) \[2024/12/18\] `diffusers` supports Sana-LoRA fine-tuning! Sana-LoRA's training and convergence speed is super fast. [\[Guidance\]](asset/docs/sana_lora_dreambooth.md) or [\[diffusers docs\]](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sana.md).
- (🔥 New) \[2024/12/13\] `diffusers` has Sana! [All Sana models in diffusers safetensors](https://huggingface.co/collections/Efficient-Large-Model/sana-673efba2a57ed99843f11f9e) are released, and the diffusers pipelines `SanaPipeline` and `SanaPAGPipeline`, as well as `DPMSolverMultistepScheduler` (with flow matching), are all supported now. We have prepared a [Model Card](asset/docs/model_zoo.md) to help you choose.
- (🔥 New) \[2024/12/10\] 1.6B BF16 [Sana model](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16) is released for stable fine-tuning.
- (🔥 New) \[2024/12/9\] We release the [ComfyUI node](https://github.com/Efficient-Large-Model/ComfyUI_ExtraModels) for Sana. [\[Guidance\]](asset/docs/ComfyUI/comfyui.md)
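For context on the `diffusers` integration announced above, here is a minimal inference sketch with `SanaPipeline`. The model id comes from the Model Card referenced in this README; the prompt and sampler settings are illustrative assumptions, not values from this commit:

```python
import torch
from diffusers import SanaPipeline

# Load the BF16 checkpoint released for stable fine-tuning (see the Model Card).
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

# Generate one 1024x1024 image; guidance scale and step count are illustrative.
image = pipe(
    prompt="a cyberpunk cat with a neon sign that says 'Sana'",
    height=1024,
    width=1024,
    guidance_scale=4.5,
    num_inference_steps=20,
).images[0]
image.save("sana_sample.png")
```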
41 changes: 29 additions & 12 deletions asset/docs/sana_lora_dreambooth.md
@@ -22,12 +22,6 @@ cd diffusers
pip install -e .
```

-Then cd into the `examples/dreambooth` folder and run
-
-```bash
-pip install -r requirements_sana.txt
-```
-
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
@@ -59,7 +53,7 @@ Let's first download it locally:
```python
from huggingface_hub import snapshot_download

local_dir = "./dog"
local_dir = "data/dreambooth/dog"
snapshot_download(
"diffusers/dog-example",
local_dir=local_dir, repo_type="dataset",
@@ -71,14 +65,21 @@ This will also allow us to push the trained LoRA parameters to the Hugging Face

[Here is the Model Card](model_zoo.md) for you to choose the desired pre-trained model and set it as `MODEL_NAME`.

-Now, we can launch training using:
+Now, we can launch training using [this script](../../train_scripts/train_lora.sh):

```bash
+bash train_scripts/train_lora.sh
+```
+
+or you can run it locally:
+
+```bash
-export MODEL_NAME="Efficient-Large-Model/Sana_1600M_1024px_diffusers"
-export INSTANCE_DIR="dog"
+export MODEL_NAME="Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers"
+export INSTANCE_DIR="data/dreambooth/dog"
export OUTPUT_DIR="trained-sana-lora"

-accelerate launch train_dreambooth_lora_sana.py \
+accelerate launch --num_processes 4 --main_process_port 29500 --gpu_ids 0,1,2,3 \
+  train_scripts/train_dreambooth_lora_sana.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
@@ -93,7 +94,7 @@ accelerate launch train_dreambooth_lora_sana.py \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=500 \
-  --validation_prompt="A photo of sks dog in a bucket" \
+  --validation_prompt="A photo of sks dog in a pond, yarn art style" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
@@ -125,3 +126,19 @@ We provide several options for memory optimization:
- `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.

Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana) of the `SanaPipeline` to know more about the models available under the SANA family and their preferred dtypes during inference.

+## Samples
+
+We show some samples from the Sana-LoRA fine-tuning process below.
+
+<p align="center" border-radius="10px">
+  <img src="https://nvlabs.github.io/Sana/asset/content/dreambooth/step0.jpg" width="90%" alt="sana-lora-step0"/>
+  <br>
+  <em> training samples at step=0 </em>
+</p>
+
+<p align="center" border-radius="10px">
+  <img src="https://nvlabs.github.io/Sana/asset/content/dreambooth/step500.jpg" width="90%" alt="sana-lora-step500"/>
+  <br>
+  <em> training samples at step=500 </em>
+</p>
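
Once training completes, the adapter can be tried out for inference. A minimal sketch, assuming `SanaPipeline` exposes the standard diffusers `load_lora_weights` API and that the adapter was saved locally to `trained-sana-lora` (the `OUTPUT_DIR` from the command above):

```python
import torch
from diffusers import SanaPipeline

# The base model must match the one used for LoRA training.
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Load the LoRA adapter produced by the training run
# ("trained-sana-lora" is the OUTPUT_DIR used in the example command).
pipe.load_lora_weights("trained-sana-lora")

# Reuse the validation prompt from the training command above.
image = pipe(
    prompt="A photo of sks dog in a pond, yarn art style",
    guidance_scale=4.5,
    num_inference_steps=20,
).images[0]
image.save("sks_dog_lora.png")
```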
4 changes: 2 additions & 2 deletions diffusion/data/datasets/sana_data.py
@@ -88,7 +88,7 @@ def __init__(
self.logger.info(f"Loading external caption json from: original_filename{external_caption_suffixes}.json")
self.logger.info(f"Loading external clipscore json from: original_filename{external_clipscore_suffixes}.json")
self.logger.info(f"external caption clipscore threshold: {clip_thr}, temperature: {clip_thr_temperature}")
self.logger.info(f"T5 max token length: {self.max_length}")
self.logger.info(f"Text max token length: {self.max_length}")

def getdata(self, idx):
data = self.dataset[idx]
@@ -288,7 +288,7 @@ def __init__(
self.logger.info(f"Loading external caption json from: original_filename{external_caption_suffixes}.json")
self.logger.info(f"Loading external clipscore json from: original_filename{external_clipscore_suffixes}.json")
self.logger.info(f"external caption clipscore threshold: {clip_thr}, temperature: {clip_thr_temperature}")
self.logger.info(f"T5 max token length: {self.max_length}")
self.logger.info(f"Text max token length: {self.max_length}")
self.logger.warning(f"Sort the dataset: {sort_dataset}")

def _initialize_dataset(self, num_replicas, sort_dataset):