Skip to content

Commit

Permalink
initial code update;
Browse files Browse the repository at this point in the history
Signed-off-by: junsongc <[email protected]>
  • Loading branch information
xieenze authored and lawrence-cj committed Nov 21, 2024
1 parent 80976a9 commit 6cfe51e
Show file tree
Hide file tree
Showing 109 changed files with 14,813 additions and 26 deletions.
33 changes: 20 additions & 13 deletions .gitignore
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
# Sana related files
.idea/
*.png
*.json
tmp*
output*
output/
outputs/
wandb/
.vscode/
private/
ldm_ae*
data/*
*.pth
.gradio/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down Expand Up @@ -106,8 +122,10 @@ ipython_config.py
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
Expand Down Expand Up @@ -157,15 +175,4 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

*png
*json
tmp*
output/
wandb/
.vscode/
private/
ldm_ae*
data/*
*pth
#.idea/
117 changes: 117 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
Copyright (c) 2019, NVIDIA Corporation. All rights reserved.


Nvidia Source Code License-NC

=======================================================================

1. Definitions

“Licensor” means any person or entity that distributes its Work.

“Work” means (a) the original work of authorship made available under
this license, which may include software, documentation, or other
files, and (b) any additions to or derivative works thereof
that are made available under this license.

“NVIDIA Processors” means any central processing unit (CPU),
graphics processing unit (GPU), field-programmable gate array (FPGA),
application-specific integrated circuit (ASIC) or any combination
thereof designed, made, sold, or provided by NVIDIA or its affiliates.

The terms “reproduce,” “reproduction,” “derivative works,” and
“distribution” have the meaning as provided under U.S. copyright law;
provided, however, that for the purposes of this license, derivative
works shall not include works that remain separable from, or merely
link (or bind by name) to the interfaces of, the Work.

Works are “made available” under this license by including in or with
the Work either (a) a copyright notice referencing the applicability
of this license to the Work, or (b) a copy of this license.

"Safe Model" means ShieldGemma-2B, which is a series of safety
content moderation models designed to moderate four categories of
harmful content: sexually explicit material, dangerous content,
hate speech, and harassment, and which you separately obtain
from Google at https://huggingface.co/google/shieldgemma-2b.


2. License Grant

2.1 Copyright Grant. Subject to the terms and conditions of this
license, each Licensor grants to you a perpetual, worldwide,
non-exclusive, royalty-free, copyright license to use, reproduce,
prepare derivative works of, publicly display, publicly perform,
sublicense and distribute its Work and any resulting derivative
works in any form.

3. Limitations

3.1 Redistribution. You may reproduce or distribute the Work only if
(a) you do so under this license, (b) you include a complete copy of
this license with your distribution, and (c) you retain without
modification any copyright, patent, trademark, or attribution notices
that are present in the Work.

3.2 Derivative Works. You may specify that additional or different
terms apply to the use, reproduction, and distribution of your
derivative works of the Work (“Your Terms”) only if (a) Your Terms
provide that the use limitation in Section 3.3 applies to your
derivative works, and (b) you identify the specific derivative works
that are subject to Your Terms. Notwithstanding Your Terms, this
license (including the redistribution requirements in Section 3.1)
will continue to apply to the Work itself.

3.3 Use Limitation. The Work and any derivative works thereof only may
be used or intended for use non-commercially and with NVIDIA Processors,
in accordance with Section 3.4, below. Notwithstanding the foregoing,
NVIDIA Corporation and its affiliates may use the Work and any
derivative works commercially. As used herein, “non-commercially”
means for research or evaluation purposes only.

3.4 You shall filter your input content to the Work and any derivative
works thereof through the Safe Model to ensure that no content described
as Not Safe For Work (NSFW) is processed or generated. You shall not use
the Work to process or generate NSFW content. You are solely responsible
for any damages and liabilities arising from your failure to adequately
filter content in accordance with this section. As used herein,
“Not Safe For Work” or “NSFW” means content, videos or website pages
that contain potentially disturbing subject matter, including but not
limited to content that is sexually explicit, dangerous, hate,
or harassment.

3.5 Patent Claims. If you bring or threaten to bring a patent claim
against any Licensor (including any claim, cross-claim or counterclaim
in a lawsuit) to enforce any patents that you allege are infringed by
any Work, then your rights under this license from such Licensor
(including the grant in Section 2.1) will terminate immediately.

3.6 Trademarks. This license does not grant any rights to use any
Licensor’s or its affiliates’ names, logos, or trademarks, except as
necessary to reproduce the notices described in this license.

3.7 Termination. If you violate any term of this license, then your
rights under this license (including the grant in Section 2.1) will
terminate immediately.

4. Disclaimer of Warranty.

THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES
UNDER THIS LICENSE.

5. Limitation of Liability.

EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE
POSSIBILITY OF SUCH DAMAGES.

=======================================================================
153 changes: 140 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
<a href="https://hanlab.mit.edu/projects/sana/"><img src="https://img.shields.io/static/v1?label=Page&message=MIT&color=darkred&logo=github-pages"></a> &ensp;
<a href="https://arxiv.org/abs/2410.10629"><img src="https://img.shields.io/static/v1?label=Arxiv&message=Sana&color=red&logo=arxiv"></a> &ensp;
<a href="https://nv-sana.mit.edu/"><img src="https://img.shields.io/static/v1?label=Demo&message=MIT&color=yellow"></a> &ensp;
<a href="https://discord.gg/rde6eaE5Ta"><img src="https://img.shields.io/static/v1?label=Discuss&message=Discord&color=purple&logo=discord"></a> &ensp;
</div>

<p align="center" border-raduis="10px">
Expand All @@ -34,13 +35,22 @@ As a result, Sana-0.6B is very competitive with modern giant diffusion model (e.

## 🔥🔥 News

- Sana code is coming soon
- (🔥 New) \[2024/10\] [Demo](https://nv-sana.mit.edu/) is released.
- (🔥 New) \[2024/10\] [DC-AE Code](https://github.com/mit-han-lab/efficientvit/blob/master/applications/dc_ae/README.md) and [weights](https://huggingface.co/collections/mit-han-lab/dc-ae-670085b9400ad7197bb1009b) are released!
- (🔥 New) \[2024/11\] Training & Inference & Metrics code are released.
- \[2024/10\] [Demo](https://nv-sana.mit.edu/) is released.
- \[2024/10\] [DC-AE Code](https://github.com/mit-han-lab/efficientvit/blob/master/applications/dc_ae/README.md) and [weights](https://huggingface.co/collections/mit-han-lab/dc-ae-670085b9400ad7197bb1009b) are released!
- \[2024/10\] [Paper](https://arxiv.org/abs/2410.10629) is on Arxiv!

## Performance

| Methods (1024x1024) | Throughput (samples/s) | Latency (s) | Params (B) | Speedup | FID 👆 | CLIP 👆 | GenEval 👆 | DPG 👆 |
|------------------------------|------------------------|-------------|------------|-----------|-------------|--------------|-------------|-------------|
| FLUX-dev | 0.04 | 23.0 | 12.0 | 1.0× | 10.15 | 27.47 | _0.67_ | _84.0_ |
| **Sana-0.6B** | 1.7 | 0.9 | 0.6 | **39.5×** | <u>5.81</u> | 28.36 | 0.64 | 83.6 |
| **Sana-1.6B** | 1.0 | 1.2 | 1.6 | **23.3×** | **5.76** | <u>28.67</u> | <u>0.66</u> | **84.8** |

<details>
<summary><h3>Click to show all</h3></summary>

| Methods | Throughput (samples/s) | Latency (s) | Params (B) | Speedup | FID 👆 | CLIP 👆 | GenEval 👆 | DPG 👆 |
|------------------------------|------------------------|-------------|------------|-----------|-------------|--------------|-------------|-------------|
| _**512 × 512 resolution**_ | | | | | | | | |
Expand All @@ -61,28 +71,149 @@ As a result, Sana-0.6B is very competitive with modern giant diffusion model (e.
| **Sana-0.6B** | 1.7 | 0.9 | 0.6 | **39.5×** | <u>5.81</u> | 28.36 | 0.64 | 83.6 |
| **Sana-1.6B** | 1.0 | 1.2 | 1.6 | **23.3×** | **5.76** | <u>28.67</u> | <u>0.66</u> | **84.8** |

</details>

## Contents

- [Env](#-1-dependencies-and-installation)
- [Demo](#-3-how-to-inference)
- [Training](#-2-how-to-train)
- [Testing](#-4-how-to-inference--test-metrics-fid-clip-score-geneval-dpg-bench-etc)
- [TODO](#to-do-list)
- [Citation](#bibtex)

## 💪To-Do List
# 🔧 1. Dependencies and Installation

- Python >= 3.10.0 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html))
- [PyTorch >= 2.0.1+cu12.1](https://pytorch.org/)

```bash
git clone https://github.com/NVlabs/Sana.git
cd Sana

./environment_setup.sh sana
# or you can install each components step by step following environment_setup.sh
```

# 💻 2. How to Play with Sana (Inference)

## 💰Hardware requirement

- 9GB VRAM is required for 0.6B model and 12GB VRAM for 1.6B model. Our later quantization version will require less than 8GB for inference.
- All the tests are done on A100 GPUs. Different GPU version may be different.

## 🔛 Quick start with [Gradio](https://www.gradio.app/guides/quickstart)

```bash
# official online demo
DEMO_PORT=15432 \
pyhton app/sana_app.py \
--config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
--model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
```

```python
import torch
from app.sana_pipeline import SanaPipeline
from torchvision.utils import save_image

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
generator = torch.Generator(device=device).manual_seed(42)

sana = SanaPipeline("configs/sana_config/1024ms/Sana_1600M_img1024.yaml")
sana.from_pretrained("hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth")
prompt = 'a cyberpunk cat with a neon sign that says "Sana"'

image = sana(
prompt=prompt,
height=1024,
width=1024,
guidance_scale=5.0,
pag_guidance_scale=2.0,
num_inference_steps=18,
generator=generator,
)
save_image(image, 'output/sana.png', nrow=1, normalize=True, value_range=(-1, 1))
```

## 🔛 Run inference with TXT or JSON files

```bash
# Run samples in a txt file
python scripts/inference.py \
--config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
--model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
--txt_file=asset/samples_mini.txt

# Run samples in a json file
python scripts/inference.py \
--config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
--model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
--json_file=asset/samples_mini.json
```

where each line of [`asset/samples_mini.txt`](asset/samples_mini.txt) contains a prompt to generate

# 🔥 3. How to Train Sana

## 💰Hardware requirement

- 32GB VRAM is required for both 0.6B and 1.6B model's training

We provide a training example here and you can also select your desired config file from [config files dir](configs/sana_config) based on your data structure.

To launch Sana training, you will first need to prepare data in the following formats

```bash
asset/example_data
├── AAA.txt
├── AAA.png
├── BCC.txt
├── BCC.png
├── ......
├── CCC.txt
└── CCC.png
```

Then Sana's training can be launched via

```bash
# Example of training Sana 0.6B with 512x512 resolution
bash train_scripts/train.sh \
configs/sana_config/512ms/Sana_600M_img512.yaml \
--data.data_dir="[asset/example_data]" \
--data.type=SanaImgDataset \
--model.multi_scale=false \
--train.train_batch_size=32

# Example of training Sana 1.6B with 1024x1024 resolution
bash train_scripts/train.sh \
configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
--data.data_dir="[asset/example_data]" \
--data.type=SanaImgDataset \
--model.multi_scale=false \
--train.train_batch_size=8
```

# 💻 4. Metric toolkit

Refer to [Toolkit Manual](asset/docs/metrics_toolkit.md).

# 💪To-Do List

We will try our best to release

- \[ \] Training code
- \[ \] Inference code
- \[x\] Training code
- \[x\] Inference code
- \[ \] Model zoo
- \[ \] Diffusers
- \[ \] ComfyUI
- \[ \] Laptop development

# 🤗Acknowledgements

- Thanks to [PixArt-α](https://github.com/PixArt-alpha/PixArt-alpha), [PixArt-Σ](https://github.com/PixArt-alpha/PixArt-sigma) and [Efficient-ViT](https://github.com/mit-han-lab/efficientvit) for their wonderful work and codebase!

[//]: # (- Thanks to [Diffusers]&#40;https://github.com/huggingface/diffusers&#41; for their wonderful technical support and awesome collaboration!)
[//]: # (- Thanks to [Hugging Face]&#40;https://github.com/huggingface&#41; for sponsoring the nicely demo!)

# 📖BibTeX

```
Expand All @@ -96,7 +227,3 @@ We will try our best to release
url={https://arxiv.org/abs/2410.10629},
}
```

[//]: # (## Star History)

[//]: # ([![Star History Chart]&#40;https://api.star-history.com/svg?repos=NVlabs/Sana&type=Date&#41;]&#40;https://star-history.com/#NVlabs/sana&Date&#41;)
Loading

0 comments on commit 6cfe51e

Please sign in to comment.