diff --git a/.gitignore b/.gitignore
old mode 100644
new mode 100755
index d785003..7b112ef
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,19 @@
+# Sana related files
+.idea/
+*.png
+*.json
+tmp*
+output*
+output/
+outputs/
+wandb/
+.vscode/
+private/
+ldm_ae*
+data/*
+*.pth
+.gradio/
+
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
@@ -106,8 +122,10 @@ ipython_config.py
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
-# https://pdm.fming.dev/#use-with-ide
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
+.pdm-python
+.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
@@ -157,15 +175,4 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
-.idea/
-
-*png
-*json
-tmp*
-output/
-wandb/
-.vscode/
-private/
-ldm_ae*
-data/*
-*pth
+#.idea/
diff --git a/LICENSE b/LICENSE
new file mode 100755
index 0000000..af26232
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,117 @@
+Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+
+
+Nvidia Source Code License-NC
+
+=======================================================================
+
+1. Definitions
+
+“Licensor” means any person or entity that distributes its Work.
+
+“Work” means (a) the original work of authorship made available under
+this license, which may include software, documentation, or other
+files, and (b) any additions to or derivative works thereof
+that are made available under this license.
+
+“NVIDIA Processors” means any central processing unit (CPU),
+graphics processing unit (GPU), field-programmable gate array (FPGA),
+application-specific integrated circuit (ASIC) or any combination
+thereof designed, made, sold, or provided by NVIDIA or its affiliates.
+
+The terms “reproduce,” “reproduction,” “derivative works,” and
+“distribution” have the meaning as provided under U.S. copyright law;
+provided, however, that for the purposes of this license, derivative
+works shall not include works that remain separable from, or merely
+link (or bind by name) to the interfaces of, the Work.
+
+Works are “made available” under this license by including in or with
+the Work either (a) a copyright notice referencing the applicability
+of this license to the Work, or (b) a copy of this license.
+
+"Safe Model" means ShieldGemma-2B, which is a series of safety
+content moderation models designed to moderate four categories of
+harmful content: sexually explicit material, dangerous content,
+hate speech, and harassment, and which you separately obtain
+from Google at https://huggingface.co/google/shieldgemma-2b.
+
+
+2. License Grant
+
+2.1 Copyright Grant. Subject to the terms and conditions of this
+license, each Licensor grants to you a perpetual, worldwide,
+non-exclusive, royalty-free, copyright license to use, reproduce,
+prepare derivative works of, publicly display, publicly perform,
+sublicense and distribute its Work and any resulting derivative
+works in any form.
+
+3. Limitations
+
+3.1 Redistribution. You may reproduce or distribute the Work only if
+(a) you do so under this license, (b) you include a complete copy of
+this license with your distribution, and (c) you retain without
+modification any copyright, patent, trademark, or attribution notices
+that are present in the Work.
+
+3.2 Derivative Works. You may specify that additional or different
+terms apply to the use, reproduction, and distribution of your
+derivative works of the Work (“Your Terms”) only if (a) Your Terms
+provide that the use limitation in Section 3.3 applies to your
+derivative works, and (b) you identify the specific derivative works
+that are subject to Your Terms. Notwithstanding Your Terms, this
+license (including the redistribution requirements in Section 3.1)
+will continue to apply to the Work itself.
+
+3.3 Use Limitation. The Work and any derivative works thereof only may
+be used or intended for use non-commercially and with NVIDIA Processors,
+in accordance with Section 3.4, below. Notwithstanding the foregoing,
+NVIDIA Corporation and its affiliates may use the Work and any
+derivative works commercially. As used herein, “non-commercially”
+means for research or evaluation purposes only.
+
+3.4 You shall filter your input content to the Work and any derivative
+works thereof through the Safe Model to ensure that no content described
+as Not Safe For Work (NSFW) is processed or generated. You shall not use
+the Work to process or generate NSFW content. You are solely responsible
+for any damages and liabilities arising from your failure to adequately
+filter content in accordance with this section. As used herein,
+“Not Safe For Work” or “NSFW” means content, videos or website pages
+that contain potentially disturbing subject matter, including but not
+limited to content that is sexually explicit, dangerous, hate,
+or harassment.
+
+3.5 Patent Claims. If you bring or threaten to bring a patent claim
+against any Licensor (including any claim, cross-claim or counterclaim
+in a lawsuit) to enforce any patents that you allege are infringed by
+any Work, then your rights under this license from such Licensor
+(including the grant in Section 2.1) will terminate immediately.
+
+3.6 Trademarks. This license does not grant any rights to use any
+Licensor’s or its affiliates’ names, logos, or trademarks, except as
+necessary to reproduce the notices described in this license.
+
+3.7 Termination. If you violate any term of this license, then your
+rights under this license (including the grant in Section 2.1) will
+terminate immediately.
+
+4. Disclaimer of Warranty.
+
+THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
+NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES
+UNDER THIS LICENSE.
+
+5. Limitation of Liability.
+
+EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
+THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
+SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
+INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
+OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
+(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
+LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
+DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE
+POSSIBILITY OF SUCH DAMAGES.
+
+=======================================================================
diff --git a/README.md b/README.md
index 101e8b5..dd7655e 100644
--- a/README.md
+++ b/README.md
@@ -9,6 +9,7 @@
+
@@ -34,13 +35,22 @@ As a result, Sana-0.6B is very competitive with modern giant diffusion model (e.
## 🔥🔥 News
-- Sana code is coming soon
-- (🔥 New) \[2024/10\] [Demo](https://nv-sana.mit.edu/) is released.
-- (🔥 New) \[2024/10\] [DC-AE Code](https://github.com/mit-han-lab/efficientvit/blob/master/applications/dc_ae/README.md) and [weights](https://huggingface.co/collections/mit-han-lab/dc-ae-670085b9400ad7197bb1009b) are released!
+- (🔥 New) \[2024/11\] Training & Inference & Metrics code are released.
+- \[2024/10\] [Demo](https://nv-sana.mit.edu/) is released.
+- \[2024/10\] [DC-AE Code](https://github.com/mit-han-lab/efficientvit/blob/master/applications/dc_ae/README.md) and [weights](https://huggingface.co/collections/mit-han-lab/dc-ae-670085b9400ad7197bb1009b) are released!
- \[2024/10\] [Paper](https://arxiv.org/abs/2410.10629) is on Arxiv!
## Performance
+| Methods (1024x1024) | Throughput (samples/s) | Latency (s) | Params (B) | Speedup | FID 👆 | CLIP 👆 | GenEval 👆 | DPG 👆 |
+|------------------------------|------------------------|-------------|------------|-----------|-------------|--------------|-------------|-------------|
+| FLUX-dev | 0.04 | 23.0 | 12.0 | 1.0× | 10.15 | 27.47 | _0.67_ | _84.0_ |
+| **Sana-0.6B** | 1.7 | 0.9 | 0.6 | **39.5×** | 5.81 | 28.36 | 0.64 | 83.6 |
+| **Sana-1.6B** | 1.0 | 1.2 | 1.6 | **23.3×** | **5.76** | 28.67 | 0.66 | **84.8** |
+
+
+ Click to show all
+
| Methods | Throughput (samples/s) | Latency (s) | Params (B) | Speedup | FID 👆 | CLIP 👆 | GenEval 👆 | DPG 👆 |
|------------------------------|------------------------|-------------|------------|-----------|-------------|--------------|-------------|-------------|
| _**512 × 512 resolution**_ | | | | | | | | |
@@ -61,28 +71,149 @@ As a result, Sana-0.6B is very competitive with modern giant diffusion model (e.
| **Sana-0.6B** | 1.7 | 0.9 | 0.6 | **39.5×** | 5.81 | 28.36 | 0.64 | 83.6 |
| **Sana-1.6B** | 1.0 | 1.2 | 1.6 | **23.3×** | **5.76** | 28.67 | 0.66 | **84.8** |
+
+
## Contents
+- [Env](#-1-dependencies-and-installation)
+- [Demo](#-3-how-to-inference)
+- [Training](#-2-how-to-train)
+- [Testing](#-4-how-to-inference--test-metrics-fid-clip-score-geneval-dpg-bench-etc)
- [TODO](#to-do-list)
- [Citation](#bibtex)
-## 💪To-Do List
+# 🔧 1. Dependencies and Installation
+
+- Python >= 3.10.0 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html))
+- [PyTorch >= 2.0.1+cu12.1](https://pytorch.org/)
+
+```bash
+git clone https://github.com/NVlabs/Sana.git
+cd Sana
+
+./environment_setup.sh sana
+# or you can install each components step by step following environment_setup.sh
+```
+
+# 💻 2. How to Play with Sana (Inference)
+
+## 💰Hardware requirement
+
+- 9GB VRAM is required for 0.6B model and 12GB VRAM for 1.6B model. Our later quantization version will require less than 8GB for inference.
+- All the tests are done on A100 GPUs. Different GPU version may be different.
+
+## 🔛 Quick start with [Gradio](https://www.gradio.app/guides/quickstart)
+
+```bash
+# official online demo
+DEMO_PORT=15432 \
+pyhton app/sana_app.py \
+ --config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
+ --model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
+```
+
+```python
+import torch
+from app.sana_pipeline import SanaPipeline
+from torchvision.utils import save_image
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+generator = torch.Generator(device=device).manual_seed(42)
+
+sana = SanaPipeline("configs/sana_config/1024ms/Sana_1600M_img1024.yaml")
+sana.from_pretrained("hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth")
+prompt = 'a cyberpunk cat with a neon sign that says "Sana"'
+
+image = sana(
+ prompt=prompt,
+ height=1024,
+ width=1024,
+ guidance_scale=5.0,
+ pag_guidance_scale=2.0,
+ num_inference_steps=18,
+ generator=generator,
+)
+save_image(image, 'output/sana.png', nrow=1, normalize=True, value_range=(-1, 1))
+```
+
+## 🔛 Run inference with TXT or JSON files
+
+```bash
+# Run samples in a txt file
+python scripts/inference.py \
+ --config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
+ --model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
+ --txt_file=asset/samples_mini.txt
+
+# Run samples in a json file
+python scripts/inference.py \
+ --config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
+ --model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
+ --json_file=asset/samples_mini.json
+```
+
+where each line of [`asset/samples_mini.txt`](asset/samples_mini.txt) contains a prompt to generate
+
+# 🔥 3. How to Train Sana
+
+## 💰Hardware requirement
+
+- 32GB VRAM is required for both 0.6B and 1.6B model's training
+
+We provide a training example here and you can also select your desired config file from [config files dir](configs/sana_config) based on your data structure.
+
+To launch Sana training, you will first need to prepare data in the following formats
+
+```bash
+asset/example_data
+├── AAA.txt
+├── AAA.png
+├── BCC.txt
+├── BCC.png
+├── ......
+├── CCC.txt
+└── CCC.png
+```
+
+Then Sana's training can be launched via
+
+```bash
+# Example of training Sana 0.6B with 512x512 resolution
+bash train_scripts/train.sh \
+ configs/sana_config/512ms/Sana_600M_img512.yaml \
+ --data.data_dir="[asset/example_data]" \
+ --data.type=SanaImgDataset \
+ --model.multi_scale=false \
+ --train.train_batch_size=32
+
+# Example of training Sana 1.6B with 1024x1024 resolution
+bash train_scripts/train.sh \
+ configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
+ --data.data_dir="[asset/example_data]" \
+ --data.type=SanaImgDataset \
+ --model.multi_scale=false \
+ --train.train_batch_size=8
+```
+
+# 💻 4. Metric toolkit
+
+Refer to [Toolkit Manual](asset/docs/metrics_toolkit.md).
+
+# 💪To-Do List
We will try our best to release
-- \[ \] Training code
-- \[ \] Inference code
+- \[x\] Training code
+- \[x\] Inference code
- \[ \] Model zoo
- \[ \] Diffusers
- \[ \] ComfyUI
+- \[ \] Laptop development
# 🤗Acknowledgements
- Thanks to [PixArt-α](https://github.com/PixArt-alpha/PixArt-alpha), [PixArt-Σ](https://github.com/PixArt-alpha/PixArt-sigma) and [Efficient-ViT](https://github.com/mit-han-lab/efficientvit) for their wonderful work and codebase!
-[//]: # (- Thanks to [Diffusers](https://github.com/huggingface/diffusers) for their wonderful technical support and awesome collaboration!)
-[//]: # (- Thanks to [Hugging Face](https://github.com/huggingface) for sponsoring the nicely demo!)
-
# 📖BibTeX
```
@@ -96,7 +227,3 @@ We will try our best to release
url={https://arxiv.org/abs/2410.10629},
}
```
-
-[//]: # (## Star History)
-
-[//]: # ([![Star History Chart](https://api.star-history.com/svg?repos=NVlabs/Sana&type=Date)](https://star-history.com/#NVlabs/sana&Date))
diff --git a/asset/docs/metrics_toolkit.md b/asset/docs/metrics_toolkit.md
new file mode 100644
index 0000000..925b298
--- /dev/null
+++ b/asset/docs/metrics_toolkit.md
@@ -0,0 +1,118 @@
+# 💻 How to Inference & Test Metrics (FID, CLIP Score, GenEval, DPG-Bench, etc...)
+
+This ToolKit will automatically inference your model and log the metrics results onto wandb as chart for better illustration. We curerntly support:
+
+- \[x\] [FID](https://github.com/mseitzer/pytorch-fid) & [CLIP-Score](https://github.com/openai/CLIP)
+- \[x\] [GenEval](https://github.com/djghosh13/geneval)
+- \[x\] [DPG-Bench](https://github.com/TencentQQGYLab/ELLA)
+- \[x\] [ImageReward](https://github.com/THUDM/ImageReward/tree/main)
+
+### 0. Install corresponding env for GenEval and DPG-Bench
+
+Make sure you can activate the following envs:
+
+- `conda activate geneval`([GenEval](https://github.com/djghosh13/geneval))
+- `conda activate dpg`([DGB-Bench](https://github.com/TencentQQGYLab/ELLA))
+
+### 0.1 Prepare data.
+
+Metirc FID & CLIP-Score on [MJHQ-30K](https://huggingface.co/datasets/playgroundai/MJHQ-30K)
+
+```python
+from huggingface_hub import hf_hub_download
+
+hf_hub_download(
+ repo_id="playgroundai/MJHQ-30K",
+ filename="mjhq30k_imgs.zip",
+ local_dir="data/test/PG-eval-data/MJHQ-30K/",
+ repo_type="dataset"
+)
+```
+
+Unzip mjhq30k_imgs.zip into its per-category folder structure.
+
+```
+data/test/PG-eval-data/MJHQ-30K/imgs/
+├── animals
+├── art
+├── fashion
+├── food
+├── indoor
+├── landscape
+├── logo
+├── people
+├── plants
+└── vehicles
+```
+
+### 0.2 Prepare checkpoints
+
+```bash
+huggingface-cli download Efficient-Large-Model/Sana_1600M_1024px --repo-type model --local-dir ./output/Sana_1600M_1024px --local-dir-use-symlinks False
+```
+
+### 1. directly \[Inference and Metric\] a .pth file
+
+```bash
+# We provide four scripts for evaluating metrics:
+fid_clipscore_launch=scripts/bash_run_inference_metric.sh
+geneval_launch=scripts/bash_run_inference_metric_geneval.sh
+dpg_launch=scripts/bash_run_inference_metric_dpg.sh
+image_reward_launch=scripts/bash_run_inference_metric_imagereward.sh
+
+# Use following format to metric your models:
+# bash $correspoinding_metric_launch $your_config_file_path $your_relative_pth_file_path
+
+# example
+bash $geneval_launch \
+ configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
+ output/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
+```
+
+### 2. \[Inference and Metric\] a list of .pth files using a txt file
+
+You can also write all your pth files of a job in one txt file, eg. [model_paths.txt](../model_paths.txt)
+
+```bash
+# Use following format to metric your models, gathering in a txt file:
+# bash $correspoinding_metric_launch $your_config_file_path $your_txt_file_path_containing_pth_path
+
+# We suggest follow the file tree structure in our project for robust experiment
+# example
+bash scripts/bash_run_inference_metric.sh \
+ configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
+ asset/model_paths.txt
+```
+
+### 3. You will get the following data tree.
+
+```
+output
+├──your_job_name/ (everything will be saved here)
+│ ├──config.yaml
+│ ├──train_log.log
+
+│ ├──checkpoints (all checkpoints)
+│ │ ├──epoch_1_step_6666.pth
+│ │ ├──epoch_1_step_8888.pth
+│ │ ├──......
+
+│ ├──vis (all visualization result dirs)
+│ │ ├──visualization_file_name
+│ │ │ ├──xxxxxxx.jpg
+│ │ │ ├──......
+│ │ ├──visualization_file_name2
+│ │ │ ├──xxxxxxx.jpg
+│ │ │ ├──......
+│ ├──......
+
+│ ├──metrics (all metrics testing related files)
+│ │ ├──model_paths.txt Optional(👈)(relative path of testing ckpts)
+│ │ │ ├──output/your_job_name/checkpoings/epoch_1_step_6666.pth
+│ │ │ ├──output/your_job_name/checkpoings/epoch_1_step_8888.pth
+│ │ ├──fid_img_paths.txt Optional(👈)(name of testing img_dir in vis)
+│ │ │ ├──visualization_file_name
+│ │ │ ├──visualization_file_name2
+│ │ ├──cached_img_paths.txt Optional(👈)
+│ │ ├──......
+```
diff --git a/asset/example_data/00000000.png b/asset/example_data/00000000.png
new file mode 100644
index 0000000..e4babe1
Binary files /dev/null and b/asset/example_data/00000000.png differ
diff --git a/asset/example_data/00000000.txt b/asset/example_data/00000000.txt
new file mode 100644
index 0000000..42be5fc
--- /dev/null
+++ b/asset/example_data/00000000.txt
@@ -0,0 +1 @@
+a cyberpunk cat with a neon sign that says "Sana".
diff --git a/asset/examples.py b/asset/examples.py
new file mode 100755
index 0000000..6e2a75f
--- /dev/null
+++ b/asset/examples.py
@@ -0,0 +1,69 @@
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+examples = [
+ [
+ "A small cactus with a happy face in the Sahara desert.",
+ "flow_dpm-solver",
+ 20,
+ 5.0,
+ 2.5,
+ ],
+ [
+ "An extreme close-up of an gray-haired man with a beard in his 60s, he is deep in thought pondering the history"
+ "of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits "
+ "mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt, he wears a brown beret "
+ "and glasses and has a very professorial appearance, and the end he offers a subtle closed-mouth smile "
+ "as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and "
+ "the Parisian streets and city in the background, depth of field, cinematic 35mm film.",
+ "flow_dpm-solver",
+ 20,
+ 5.0,
+ 2.5,
+ ],
+ [
+ "An illustration of a human heart made of translucent glass, standing on a pedestal amidst a stormy sea. "
+ "Rays of sunlight pierce the clouds, illuminating the heart, revealing a tiny universe within. "
+ "The quote 'Find the universe within you' is etched in bold letters across the horizon."
+ "blue and pink, brilliantly illuminated in the background.",
+ "flow_dpm-solver",
+ 20,
+ 5.0,
+ 2.5,
+ ],
+ [
+ "A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape.",
+ "flow_dpm-solver",
+ 20,
+ 5.0,
+ 2.5,
+ ],
+ [
+ "A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in.",
+ "flow_dpm-solver",
+ 20,
+ 5.0,
+ 2.5,
+ ],
+ [
+ "a kayak in the water, in the style of optical color mixing, aerial view, rainbowcore, "
+ "national geographic photo, 8k resolution, crayon art, interactive artwork",
+ "flow_dpm-solver",
+ 20,
+ 5.0,
+ 2.5,
+ ],
+]
diff --git a/asset/model-incremental.jpg b/asset/model-incremental.jpg
index 88bfe96..6107e93 100644
Binary files a/asset/model-incremental.jpg and b/asset/model-incremental.jpg differ
diff --git a/asset/model_paths.txt b/asset/model_paths.txt
new file mode 100644
index 0000000..8e62194
--- /dev/null
+++ b/asset/model_paths.txt
@@ -0,0 +1,2 @@
+output/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
+output/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
diff --git a/asset/samples.txt b/asset/samples.txt
new file mode 100755
index 0000000..31de23e
--- /dev/null
+++ b/asset/samples.txt
@@ -0,0 +1,125 @@
+A small cactus with a happy face in the Sahara desert.
+Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.
+beautiful lady, freckles, big smile, blue eyes, short ginger hair, dark makeup, wearing a floral blue vest top, soft light, dark grey background
+stars, water, brilliantly, gorgeous large scale scene, a little girl, in the style of dreamy realism, light gold and amber, blue and pink, brilliantly illuminated in the background.
+nature vs human nature, surreal, UHD, 8k, hyper details, rich colors, photograph.
+Spectacular Tiny World in the Transparent Jar On the Table, interior of the Great Hall, Elaborate, Carved Architecture, Anatomy, Symetrical, Geometric and Parameteric Details, Precision Flat line Details, Pattern, Dark fantasy, Dark errie mood and ineffably mysterious mood, Technical design, Intricate Ultra Detail, Ornate Detail, Stylized and Futuristic and Biomorphic Details, Architectural Concept, Low contrast Details, Cinematic Lighting, 8k, by moebius, Fullshot, Epic, Fullshot, Octane render, Unreal ,Photorealistic, Hyperrealism
+anthropomorphic profile of the white snow owl Crystal priestess , art deco painting, pretty and expressive eyes, ornate costume, mythical, ethereal, intricate, elaborate, hyperrealism, hyper detailed, 3D, 8K, Ultra Realistic, high octane, ultra resolution, amazing detail, perfection, In frame, photorealistic, cinematic lighting, visual clarity, shading , Lumen Reflections, Super-Resolution, gigapixel, color grading, retouch, enhanced, PBR, Blender, V-ray, Procreate, zBrush, Unreal Engine 5, cinematic, volumetric, dramatic, neon lighting, wide angle lens ,no digital painting blur
+The parametric hotel lobby is a sleek and modern space with plenty of natural light. The lobby is spacious and open with a variety of seating options. The front desk is a sleek white counter with a parametric design. The walls are a light blue color with parametric patterns. The floor is a light wood color with a parametric design. There are plenty of plants and flowers throughout the space. The overall effect is a calm and relaxing space. occlusion, moody, sunset, concept art, octane rendering, 8k, highly detailed, concept art, highly detailed, beautiful scenery, cinematic, beautiful light, hyperreal, octane render, hdr, long exposure, 8K, realistic, fog, moody, fire and explosions, smoke, 50mm f2.8
+Bright scene, aerial view, ancient city, fantasy, gorgeous light, mirror reflection, high detail, wide angle lens.
+8k uhd A man looks up at the starry sky, lonely and ethereal, Minimalism, Chaotic composition Op Art
+A middle-aged woman of Asian descent, her dark hair streaked with silver, appears fractured and splintered, intricately embedded within a sea of broken porcelain. The porcelain glistens with splatter paint patterns in a harmonious blend of glossy and matte blues, greens, oranges, and reds, capturing her dance in a surreal juxtaposition of movement and stillness. Her skin tone, a light hue like the porcelain, adds an almost mystical quality to her form.
+A 4k dslr image of a lemur wearing a red magician hat and a blue coat performing magic tricks with cards in a garden.
+A alpaca made of colorful building blocks, cyberpunk
+A baby painter trying to draw very simple picture, white background
+A boy and a girl fall in love
+A dog that has been meditating all the time
+A man is sitting in a chair with his chin resting on his hand. The chair, along with the man's feet, are submerged in the sea. Strikingly, the man's back is on fire.
+A painter study hard to learn how to draw with many concepts in the air, white background
+A painter with low quality, white background, pixel art
+A person standing on the desert, desert waves, gossip illustration, half red, half blue, abstract image of sand, clear style, trendy illustration, outdoor, top view, clear style, precision art, ultra high definition image
+A silhouette of a grand piano overlooking a dusky cityscape viewed from a top-floor penthouse, rendered in the bold and vivid sytle of a vintage travel poster.
+A sureal parallel world where mankind avoid extinction by preserving nature, epic trees, water streams, various flowers, intricate details, rich colors, rich vegetation, cinematic, symmetrical, beautiful lighting, V-Ray render, sun rays, magical lights, photography
+A woman is shopping for fresh produce at the farmer's market.
+A worker that looks like a mixture of cow and horse is working hard to type code
+A young man dressed in ancient Chinese clothing, Asian people, White robe, Handsome, Hand gestures forming a spell, Martial arts and fairy-like vibe, Carrying a legendary-level giant sword on the back, Game character, Surrounded by runes, Cyberpunk style, neon lights, best quality, masterpiece, cg, hdr, high-definition, extremely detailed, photorealistic, epic, character design, detailed face, superhero, hero, detailed UHD, real-time, vfx, 3D rendering, 8k
+An alien octopus floats through a protal reading a newspaper
+An epressive oil painting of a basketbal player dunking, depicted as an explosion of a nebula
+art collection style and fashion shoot, in the style of made of glass, dark blue and light pink, paul rand, solarpunk, camille vivier, beth didonato hair, barbiecore, hyper-realistic
+artistic
+beautiful secen
+Crocodile in a sweater
+Design a letter A, 3D stereoscopic Ice material Interior light blue Conceptual product design Futuristic Blind box toy Handcrafted Exquisite 3D effect Full body display Ultra-high precision Ultra-detailed Perfect lighting OC Renderer Blender 8k Ultra-sharp Ultra-noise reduction
+Floating,colossal,futuristic statue in the sky, awe-inspiring and serenein the style of Stuart Lippincott:2with detailed composition and subtle geometric elements.This sanctuary-ike atmosphere features crisp clarity and soft amber tones.In contrasttiny human figures surround the statueThe pieceincorporates flowing draperiesreminiscent of Shwedoff and Philip McKay's stylesemphasizing thejuxtaposition between the powerful presence of the statue and thevulnerability of the minuscule human figuresshwedoff
+knolling of a drawing tools for painter
+Leonardo da Vinci's Last Supper content, Van Goph's Starry Night Style
+Luffy from ONEPIECE, handsome face, fantasy
+photography shot through an outdoor window of a coffee shop with neon sign lighting, window glares and reflections, depth of field, {little girl with red hair sitting at a table, portrait, kodak portra 800,105 mm f1.8
+poster of a mechanical cat, techical Schematics viewed from front and side view on light white blueprint paper, illustartion drafting style, illustation, typography, conceptual art, dark fantasy steampunk, cinematic, dark fantasy
+The girl in the car is filled with goldfish and flowers, goldfish can fly, Kawaguchi Renko's art, natural posture, holiday dadcore, youthful energy and pressure, body stretching, goldfish simulation movies in the sky, super details, and dreamy high photography. Colorful. Covered by water and goldfish, indoor scene, close-up shot in XT4 movie
+The image features a woman wearing a red shirt with an icon. She appears to be posing for the camera, and her outfit includes a pair of jeans. The woman seems to be in a good mood, as she is smiling. The background of the image is blurry, focusing more on the woman and her attire.
+The towel was on top of the hard counter.
+A vast landscape made entirely of various meats spreads out before the viewer. tender, succulent hills of roast beef, chicken drumstick trees, bacon rivers, and ham boulders create a surreal, yet appetizing scene. the sky is adorned with pepperoni sun and salami clouds.
+I want to supplement vitamin c, please help me paint related food.
+A vibrant yellow banana-shaped couch sits in a cozy living room, its curve cradling a pile of colorful cushions. on the wooden floor, a patterned rug adds a touch of eclectic charm, and a potted plant sits in the corner, reaching towards the sunlight filtering through the window.
+A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape.
+A blue jay standing on a large basket of rainbow macarons.
+A bucket bag made of blue suede. The bag is decorated with intricate golden paisley patterns. The handle of the bag is made of rubies and pearls.
+An alien octopus floats through a portal reading a newspaper.
+bird's eye view of a city.
+beautiful scene
+A 2D animation of a folk music band composed of anthropomorphic autumn leaves, each playing traditional bluegrass instruments, amidst a rustic forest setting dappled with the soft light of a harvest moon.
+In front of a deep black backdrop, a figure of middle years, her Tongan skin rich and glowing, is captured mid-twirl, her curly hair flowing like a storm behind her. Her attire resembles a whirlwind of marble and porcelain fragments. Illuminated by the gleam of scattered porcelain shards, creating a dreamlike atmosphere, the dancer manages to appear fragmented, yet maintains a harmonious and fluid form.
+Digital illustration of a beach scene crafted from yarn. The sandy beach is depicted with beige yarn, waves are made of blue and white yarn crashing onto the shore. A yarn sun sets on the horizon, casting a warm glow. Yarn palm trees sway gently, and little yarn seashells dot the shoreline.
+Illustration of a chic chair with a design reminiscent of a pumpkin’s form, with deep orange cushioning, in a stylish loft setting.
+A detailed oil painting of an old sea captain, steering his ship through a storm. Saltwater is splashing against his weathered face, determination in his eyes. Twirling malevolent clouds are seen above and stern waves threaten to submerge the ship while seagulls dive and twirl through the chaotic landscape. Thunder and lights embark in the distance, illuminating the scene with an eerie green glow.
+An illustration of a human heart made of translucent glass, standing on a pedestal amidst a stormy sea. Rays of sunlight pierce the clouds, illuminating the heart, revealing a tiny universe within. The quote 'Find the universe within you' is etched in bold letters across the horizon.
+A modern architectural building with large glass windows, situated on a cliff overlooking a serene ocean at sunset
+photo of an ancient shipwreck nestled on the ocean floor. Marine plants have claimed the wooden structure, and fish swim in and out of its hollow spaces. Sunken treasures and old cannons are scattered around, providing a glimpse into the past
+A 3D render of a coffee mug placed on a window sill during a stormy day. The storm outside the window is reflected in the coffee, with miniature lightning bolts and turbulent waves seen inside the mug. The room is dimly lit, adding to the dramatic atmosphere.A minimap diorama of a cafe adorned with indoor plants. Wooden beams crisscross above, and a cold brew station stands out with tiny bottles and glasses.
+An antique botanical illustration drawn with fine lines and a touch of watercolour whimsy, depicting a strange lily crossed with a Venus flytrap, its petals poised as if ready to snap shut on any unsuspecting insects.An illustration inspired by old-world botanical sketches blends a cactus with lilac blooms into a Möbius strip, using detailed lines and subtle watercolor touches to capture nature's diverse beauty and mathematical intrigue.
+An ink sketch style illustration of a small hedgehog holding a piece of watermelon with its tiny paws, taking little bites with its eyes closed in delight.Photo of a lychee-inspired spherical chair, with a bumpy white exterior and plush interior, set against a tropical wallpaper.
+3d digital art of an adorable ghost, glowing within, holding a heart shaped pumpkin, Halloween, super cute, spooky haunted house background
+professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.
+an astronaut sitting in a diner, eating fries, cinematic, analog film
+Chinese architecture, ancient style,mountain, bird, lotus, pond, big tree, 4K Unity, octane rendering.
+Ethereal fantasy concept art of thunder god with hammer. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy.
+A Japanese girl walking along a path, surrounding by blooming oriental cherry, pink petal slowly falling down to the ground
+A Ukiyoe style painting, an astronaut riding a unicorn, In the background there is an ancient Japanese architecture
+Steampunk makeup, in the style of vray tracing, colorful impasto, uhd image, indonesian art, fine feather details with bright red and yellow and green and pink and orange colours, intricate patterns and details, dark cyan and amber makeup. Rich colourful plumes. Victorian style.
+A cute teddy bear in front of a plain white wall, warm and brown fur, soft and fluffy
+The beautiful scenery of Seattle, painting by Al Capp.
+Photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang.
+An astronaut riding a horse on the moon, oil painting by Van Gogh.
+A deep forest clearing with a mirrored pond reflecting a galaxy-filled night sky
+Realistic oil painting of a stunning model merged in multicolor splash made of finely torn paper, eye contact, walking with class in a street.
+a chinese model is sitting on a train, magazine cover, clothes made of plastic, photorealistic,futuristic style, gray and green light, movie lighting, 32K HD
+a handsome 24 years old boy in the middle with sky color background wearing eye glasses, it's super detailed with anime style, it's a portrait with delicated eyes and nice looking face
+a kayak in the water, in the style of optical color mixing, aerial view, rainbowcore, national geographic photo, 8k resolution, crayon art, interactive artwork
+3D rendering miniature scene design, Many tall buildings, A winding urban road runs through the middle,a lot of cars on the road, transparent material pipeline transports Materials, ,there are many people around, in thestyle of light orange and yellow, graphic design- inspired illustrations, classic still-life, beeple, josan gon-zalez, manga-influenced, miniature dioramas, in thestyle of playful and whimsical designs, graphic de-sign-inspired illustrations, minimalism, hyperrealismlomo lca, e-commerce C4D style, e-commerce posterUl, UX, octane render, blender
+Close-up photos of models, hazy light and shadow, laser metal hair accessories, soft and beautiful, light gold pupils, white eyelashes, low saturation, real skin details, clear pores and fine lines, light reflection and refraction, ultra-clear, cinematography, award-winning works
+A cute orange kitten sliding down an aqua slide. happy excited. 16mm lens in front. we see his excitement and scared in the eye. vibrant colors. water splashing on the lens
+Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field.
+A gorgeously rendered papercraft world of a coral reef, rife with colorful fish and sea creatures.
+An extreme close-up of an gray-haired man with a beard in his 60s, he is deep in thought pondering the history of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt , he wears a brown beret and glasses and has a very professorial appearance, and the end he offers a subtle closed-mouth smile as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and the Parisian streets and city in the background, depth of field, cinematic 35mm film.
+A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in.
+A New Zealand female business owner stands and is happy that his business is growing by having good VoIP and broadband supplied by Voyager Internet. This business owner is dressed semi casual and is standing with a funky office space in the background. The image is light and bright and is well lit. This image needs to be shot like a professional photo shoot using a Canon R6 with high quality 25mm lens. This image has a shallow depth of field
+The parametric hotel lobby is a sleek and modern space with plenty of natural light. The lobby is spacious and open with a variety of seating options. The front desk is a sleek white counter with a parametric design. The walls are a light blue color with parametric patterns. The floor is a light wood color with a parametric design. There are plenty of plants and flowers throughout the space. The overall effect is a calm and relaxing space. occlusion, moody, sunset, concept art, octane rendering, 8k, highly detailed, concept art, highly detailed, beautiful scenery, cinematic, beautiful light, hyperreal, octane render, hdr, long exposure, 8K, realistic, fog, moody, fire and explosions, smoke, 50mm f2.8
+Editorial photoshoot of a old woman, high fashion 2000s fashion
+Mural Painted of Prince in Purple Rain on side of 5 story brick building next to zen garden vacant lot in the urban center district, rgb
+Cozy Scandinavian living room, there is a cat sleeping on the couch, depth of field
+Street style centered straight shot photo shot on Afga Vista 400, lense 50mm, of a two women,skin to skin touch face, emotion, hughing, natural blond hair, natural features, ultra detailed, skin texture, Rembrandt light, soft shadows
+Frog, in forest, colorful, no watermark, no signature, in forest, 8k
+selfie of a woman and her lion cub on the plains
+A fisherman fixing his net sitting on a beautiful tropical beach at sunset with bending palm trees fishing gear and a small boat on shore
+Coast, decorative painting, horizon, modern, fashionable, full of abstract feeling, full of imagination, the picture reveals the sense of time passing, there is a feeling of the end of the world
+A close up of a branch of a tree and a golden bug on the top a leaf, shutterstock contest winner,ecological art, depth of field, shallow depth of field, macro photography
+Outdoor style fashion photo, full – body shot of a man with short brown hair, happy and smiling, he is standing on his hipster bicycle wearing a light blue long sleeved blouse with closed buttons and dark blue jeans trousers, in the background the exterior of an Aldi store, fully lit background, natural afternoon lighting
+beautiful woman sniper, wearing soviet army uniform, one eye on sniper lens, in snow ground
+A very attractive and natural woman, sitting on a yoka mat, breathing, eye closed, no make up, intense satisfaction, she looks like she is intensely relaxed, yoga class, sunrise, 35mm
+a close up of a helmet on a person, digital art, inspired by Han Gan, cloisonnism, female, victorian armor, ultramarine, best of behance, anton fadeev 8 k, fined detail, sci-fi character, elegant armor, fantasy art behance
+a melting apple
+yellow FIAT 500 Cinquecento 1957 driving through liechtenstein castle with a lot of banknotes scattered behind ,filled with wads of cash , car color yellow, license plate R-33
+tented resort in the desert, rocky and sandy terrain, 5 star hotel, beautiful landscape, landscape photography, depth of view, Fujifilm GFX 100 –uplight
+Full body shot, a French woman, Photography, French Streets background, backlighting, rim light, Fujifilm.
+Modern luxury contemporary luxury home interiors house, in the style of mimicking ruined materials, ray tracing, haunting houses, and stone, capture the essence of nature, gray and bronze, dynamic outdoor shots.
+Over the shoulder game perspective, game screen of Diablo 4, Inside the gorgeous palace is the wet ground, The necromancer knelt before the king, and a horde of skeletons he summoned stood at his side, cinematic light.
+Color photo of a corgi made of transparent glass, standing on the riverside in Yosemite National Park.
+Happy dreamy owl monster sitting on a tree branch, colorful glittering particles, forest background, detailed feathers.
+Game-Art - An island with different geographical properties and multiple small cities floating in space
+Photorealistic closeup video of two pirate ships battling each other as they sail inside a cup of coffee.
+A car made out of vegetables.
+A serene lakeside during autumn with trees displaying a palette of fiery colors.
+A realistic landscape shot of the Northern Lights dancing over a snowy mountain range in Iceland.
+A deep forest clearing with a mirrored pond reflecting a galaxy-filled night sky.
+Drone view of waves crashing against the rugged cliffs along Big Sur’s Garay Point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore.
+A curvy timber house near a sea, designed by Zaha Hadid, represent the image of a cold, modern architecture, at night, white lighting, highly detailed.
+Eiffel Tower was Made up of more than 2 million translucent straws to look like a cloud, with the bell tower at the top of the building, Michel installed huge foam-making machines in the forest to blow huge amounts of unpredictable wet clouds in the building's classic architecture.
+Close-up photos of models, hazy light and shadow, laser metal hair accessories, soft and beautiful, light gold pupils, white eyelashes, low saturation, real skin details, clear pores and fine lines, light reflection and refraction, ultra-clear, cinematography, award-winning works.
+smiling cartoon dog sits at a table, coffee mug on hand, as a room goes up in flames. "Help" the dog is yelling.
+A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about.
+A close-up photo of a person. The subject is a woman. She wore a blue coat with a gray dress underneath. She has blue eyes and blond hair and wears a pair of earrings. Behind are blurred city buildings and streets.
+👧 with 🌹 in the ❄️
+🐶 Wearing 🕶 flying on the 🌈
+a cyberpunk cat with a neon sign that says "MIT"
+a black and white picture of a woman looking through the window, in the style of Duffy Sheridan, Anna Razumovskaya, smooth and shiny, wavy, Patrick Demarchelier, album covers, lush and detailed.
diff --git a/asset/samples_mini.txt b/asset/samples_mini.txt
new file mode 100755
index 0000000..2775ad7
--- /dev/null
+++ b/asset/samples_mini.txt
@@ -0,0 +1,10 @@
+A cyberpunk cat with a neon sign that says 'Sana'.
+A small cactus with a happy face in the Sahara desert.
+The towel was on top of the hard counter.
+A vast landscape made entirely of various meats spreads out before the viewer. tender, succulent hills of roast beef, chicken drumstick trees, bacon rivers, and ham boulders create a surreal, yet appetizing scene. the sky is adorned with pepperoni sun and salami clouds.
+I want to supplement vitamin c, please help me paint related food.
+A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape.
+an old rusted robot wearing pants and a jacket riding skis in a supermarket.
+professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.
+Astronaut in a jungle, cold color palette, muted colors, detailed
+a stunning and luxurious bedroom carved into a rocky mountainside seamlessly blending nature with modern design with a plush earth-toned bed textured stone walls circular fireplace massive uniquely shaped window framing snow-capped mountains dense forests.
diff --git a/configs/sana_app_config/Sana_1600M_app.yaml b/configs/sana_app_config/Sana_1600M_app.yaml
new file mode 100644
index 0000000..ec941f2
--- /dev/null
+++ b/configs/sana_app_config/Sana_1600M_app.yaml
@@ -0,0 +1,107 @@
+data:
+ data_dir: []
+ image_size: 1024
+ caption_proportion:
+ prompt: 1
+ external_caption_suffixes: []
+ external_clipscore_suffixes: []
+ clip_thr_temperature: 0.1
+ clip_thr: 25.0
+ load_text_feat: false
+ load_vae_feat: false
+ transform: default_train
+ type: SanaWebDatasetMS
+ data:
+ sort_dataset: false
+# model config
+model:
+ model: SanaMS_1600M_P1_D20
+ image_size: 1024
+ mixed_precision: fp16 # ['fp16', 'fp32', 'bf16']
+ fp32_attention: true
+ load_from:
+ resume_from:
+ aspect_ratio_type: ASPECT_RATIO_1024
+ multi_scale: true
+ #pe_interpolation: 1.
+ attn_type: linear
+ ffn_type: glumbconv
+ mlp_acts:
+ - silu
+ - silu
+ -
+ mlp_ratio: 2.5
+ use_pe: false
+ qk_norm: false
+ class_dropout_prob: 0.1
+ # CFG & PAG settings
+ pag_applied_layers:
+ - 8
+# VAE setting
+vae:
+ vae_type: dc-ae
+ vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
+ scale_factor: 0.41407
+ vae_latent_dim: 32
+ vae_downsample_rate: 32
+ sample_posterior: true
+# text encoder
+text_encoder:
+ text_encoder_name: gemma-2-2b-it
+ y_norm: true
+ y_norm_scale_factor: 0.01
+ model_max_length: 300
+ # CHI
+ chi_prompt:
+ - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
+ - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
+ - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
+ - 'Here are examples of how to transform or refine prompts:'
+ - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
+ - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
+ - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
+ - 'User Prompt: '
+# Sana schedule Flow
+scheduler:
+ predict_v: true
+ noise_schedule: linear_flow
+ pred_sigma: false
+ flow_shift: 3.0
+ # logit-normal timestep
+ weighting_scheme: logit_normal
+ logit_mean: 0.0
+ logit_std: 1.0
+ vis_sampler: flow_dpm-solver
+# training setting
+train:
+ num_workers: 10
+ seed: 1
+ train_batch_size: 64
+ num_epochs: 100
+ gradient_accumulation_steps: 1
+ grad_checkpointing: true
+ gradient_clip: 0.1
+ optimizer:
+ betas:
+ - 0.9
+ - 0.999
+ - 0.9999
+ eps:
+ - 1.0e-30
+ - 1.0e-16
+ lr: 0.0001
+ type: CAMEWrapper
+ weight_decay: 0.0
+ lr_schedule: constant
+ lr_schedule_args:
+ num_warmup_steps: 2000
+ local_save_vis: true # if save log image locally
+ visualize: true
+ eval_sampling_steps: 500
+ log_interval: 20
+ save_model_epochs: 5
+ save_model_steps: 500
+ work_dir: output/debug
+ online_metric: false
+ eval_metric_step: 2000
+ online_metric_dir: metric_helper
diff --git a/configs/sana_app_config/Sana_600M_app.yaml b/configs/sana_app_config/Sana_600M_app.yaml
new file mode 100644
index 0000000..f6ae866
--- /dev/null
+++ b/configs/sana_app_config/Sana_600M_app.yaml
@@ -0,0 +1,105 @@
+data:
+ data_dir: []
+ image_size: 1024
+ caption_proportion:
+ prompt: 1
+ external_caption_suffixes: []
+ external_clipscore_suffixes: []
+ clip_thr_temperature: 0.1
+ clip_thr: 25.0
+ load_text_feat: false
+ load_vae_feat: true
+ transform: default_train
+ type: SanaWebDatasetMS
+ sort_dataset: false
+# model config
+model:
+ model: SanaMS_600M_P1_D28
+ image_size: 1024
+ mixed_precision: fp16 # ['fp16', 'fp32', 'bf16']
+ fp32_attention: true
+ load_from:
+ resume_from:
+ aspect_ratio_type: ASPECT_RATIO_1024
+ multi_scale: true
+ attn_type: linear
+ ffn_type: glumbconv
+ mlp_acts:
+ - silu
+ - silu
+ -
+ mlp_ratio: 2.5
+ use_pe: false
+ qk_norm: false
+ class_dropout_prob: 0.1
+ # CFG & PAG settings
+ pag_applied_layers:
+ - 14
+# VAE setting
+vae:
+ vae_type: dc-ae
+ vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
+ scale_factor: 0.41407
+ vae_latent_dim: 32
+ vae_downsample_rate: 32
+ sample_posterior: true
+# text encoder
+text_encoder:
+ text_encoder_name: gemma-2-2b-it
+ y_norm: true
+ y_norm_scale_factor: 0.01
+ model_max_length: 300
+ # CHI
+ chi_prompt:
+ - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
+ - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
+ - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
+ - 'Here are examples of how to transform or refine prompts:'
+ - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
+ - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
+ - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
+ - 'User Prompt: '
+# Sana schedule Flow
+scheduler:
+ predict_v: true
+ noise_schedule: linear_flow
+ pred_sigma: false
+ flow_shift: 4.0
+ # logit-normal timestep
+ weighting_scheme: logit_normal
+ logit_mean: 0.0
+ logit_std: 1.0
+ vis_sampler: flow_dpm-solver
+# training setting
+train:
+ num_workers: 10
+ seed: 1
+ train_batch_size: 64
+ num_epochs: 100
+ gradient_accumulation_steps: 1
+ grad_checkpointing: true
+ gradient_clip: 0.1
+ optimizer:
+ betas:
+ - 0.9
+ - 0.999
+ - 0.9999
+ eps:
+ - 1.0e-30
+ - 1.0e-16
+ lr: 0.0001
+ type: CAMEWrapper
+ weight_decay: 0.0
+ lr_schedule: constant
+ lr_schedule_args:
+ num_warmup_steps: 2000
+ local_save_vis: true # if save log image locally
+ visualize: true
+ eval_sampling_steps: 500
+ log_interval: 20
+ save_model_epochs: 5
+ save_model_steps: 500
+ work_dir: output/debug
+ online_metric: false
+ eval_metric_step: 2000
+ online_metric_dir: metric_helper
diff --git a/configs/sana_base.yaml b/configs/sana_base.yaml
new file mode 100644
index 0000000..abbd935
--- /dev/null
+++ b/configs/sana_base.yaml
@@ -0,0 +1,140 @@
+# data settings
+data:
+ data_dir: []
+ caption_proportion:
+ prompt: 1
+ external_caption_suffixes: []
+ external_clipscore_suffixes: []
+ clip_thr_temperature: 1.0
+ clip_thr: 0.0
+ sort_dataset: false
+ load_text_feat: false
+ load_vae_feat: false
+ transform: default_train
+ type: SanaWebDatasetMS
+ image_size: 512
+ hq_only: false
+ valid_num: 0
+# model settings
+model:
+ model: SanaMS_600M_P1_D28
+ image_size: 512
+ mixed_precision: fp16 # ['fp16', 'fp32', 'bf16']
+ fp32_attention: true
+ load_from:
+ resume_from:
+ checkpoint:
+ load_ema: false
+ resume_lr_scheduler: true
+ resume_optimizer: true
+ aspect_ratio_type: ASPECT_RATIO_1024
+ multi_scale: true
+ pe_interpolation: 1.0
+ micro_condition: false
+ attn_type: linear # 'flash', 'linear', 'vanilla', 'triton_linear'
+ cross_norm: false
+ autocast_linear_attn: false
+ ffn_type: glumbconv
+ mlp_acts:
+ - silu
+ - silu
+ -
+ mlp_ratio: 2.5
+ use_pe: false
+ qk_norm: false
+ class_dropout_prob: 0.0
+ linear_head_dim: 32
+ # CFG & PAG settings
+ cfg_scale: 4
+ guidance_type: classifier-free
+ pag_applied_layers: [14]
+# text encoder settings
+text_encoder:
+ text_encoder_name: gemma-2-2b-it
+ caption_channels: 2304
+ y_norm: false
+ y_norm_scale_factor: 1.0
+ model_max_length: 300
+ chi_prompt: []
+# VAE settings
+vae:
+ vae_type: dc-ae
+ vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
+ scale_factor: 0.41407
+ vae_latent_dim: 32
+ vae_downsample_rate: 32
+ sample_posterior: true
+# Scheduler settings
+scheduler:
+ train_sampling_steps: 1000
+ predict_v: True
+ noise_schedule: linear_flow
+ pred_sigma: false
+ flow_shift: 1.0
+ weighting_scheme: logit_normal
+ logit_mean: 0.0
+ logit_std: 1.0
+ vis_sampler: flow_dpm-solver
+# training settings
+train:
+ num_workers: 4
+ seed: 43
+ train_batch_size: 32
+ num_epochs: 100
+ gradient_accumulation_steps: 1
+ grad_checkpointing: false
+ gradient_clip: 1.0
+ gc_step: 1
+ # optimizer settings
+ optimizer:
+ eps: 1.0e-10
+ lr: 0.0001
+ type: AdamW
+ weight_decay: 0.03
+ lr_schedule: constant
+ lr_schedule_args:
+ num_warmup_steps: 500
+ auto_lr:
+ rule: sqrt
+ ema_rate: 0.9999
+ eval_batch_size: 16
+ use_fsdp: false
+ use_flash_attn: false
+ eval_sampling_steps: 250
+ lora_rank: 4
+ log_interval: 50
+ mask_type: 'null'
+ mask_loss_coef: 0.0
+ load_mask_index: false
+ snr_loss: false
+ real_prompt_ratio: 1.0
+ debug_nan: false
+ # checkpoint settings
+ save_image_epochs: 1
+ save_model_epochs: 1
+ save_model_steps: 1000000
+ # visualization settings
+ visualize: false
+ null_embed_root: output/pretrained_models/
+ valid_prompt_embed_root: output/tmp_embed/
+ validation_prompts:
+ - dog
+ - portrait photo of a girl, photograph, highly detailed face, depth of field
+ - Self-portrait oil painting, a beautiful cyborg with golden hair, 8k
+ - Astronaut in a jungle, cold color palette, muted colors, detailed, 8k
+ - A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece
+ local_save_vis: false
+ deterministic_validation: true
+ online_metric: false
+ eval_metric_step: 5000
+ online_metric_dir: metric_helper
+ # work dir settings
+ work_dir: /cache/exps/
+ skip_step: 0
+ # LCM settings
+ loss_type: huber
+ huber_c: 0.001
+ num_ddim_timesteps: 50
+ w_max: 15.0
+ w_min: 3.0
+ ema_decay: 0.95
diff --git a/configs/sana_config/1024ms/Sana_1600M_img1024.yaml b/configs/sana_config/1024ms/Sana_1600M_img1024.yaml
new file mode 100644
index 0000000..07a0884
--- /dev/null
+++ b/configs/sana_config/1024ms/Sana_1600M_img1024.yaml
@@ -0,0 +1,109 @@
+data:
+ data_dir: [data/data_public/dir1]
+ image_size: 1024
+ caption_proportion:
+ prompt: 1
+ external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B]
+ external_clipscore_suffixes:
+ - _InternVL2-26B_clip_score
+ - _VILA1-5-13B_clip_score
+ - _prompt_clip_score
+ clip_thr_temperature: 0.1
+ clip_thr: 25.0
+ load_text_feat: false
+ load_vae_feat: false
+ transform: default_train
+ type: SanaWebDatasetMS
+ sort_dataset: false
+# model config
+model:
+ model: SanaMS_1600M_P1_D20
+ image_size: 1024
+ mixed_precision: fp16 # ['fp16', 'fp32', 'bf16']
+ fp32_attention: true
+ load_from:
+ resume_from:
+ aspect_ratio_type: ASPECT_RATIO_1024
+ multi_scale: true
+ #pe_interpolation: 1.
+ attn_type: linear
+ ffn_type: glumbconv
+ mlp_acts:
+ - silu
+ - silu
+ -
+ mlp_ratio: 2.5
+ use_pe: false
+ qk_norm: false
+ class_dropout_prob: 0.1
+ # PAG
+ pag_applied_layers:
+ - 8
+# VAE setting
+vae:
+ vae_type: dc-ae
+ vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
+ scale_factor: 0.41407
+ vae_latent_dim: 32
+ vae_downsample_rate: 32
+ sample_posterior: true
+# text encoder
+text_encoder:
+ text_encoder_name: gemma-2-2b-it
+ y_norm: true
+ y_norm_scale_factor: 0.01
+ model_max_length: 300
+ # CHI
+ chi_prompt:
+ - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
+ - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
+ - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
+ - 'Here are examples of how to transform or refine prompts:'
+ - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
+ - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
+ - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
+ - 'User Prompt: '
+# Sana schedule Flow
+scheduler:
+ predict_v: true
+ noise_schedule: linear_flow
+ pred_sigma: false
+ flow_shift: 3.0
+ # logit-normal timestep
+ weighting_scheme: logit_normal
+ logit_mean: 0.0
+ logit_std: 1.0
+ vis_sampler: flow_dpm-solver
+# training setting
+train:
+ num_workers: 10
+ seed: 1
+ train_batch_size: 64
+ num_epochs: 100
+ gradient_accumulation_steps: 1
+ grad_checkpointing: true
+ gradient_clip: 0.1
+ optimizer:
+ betas:
+ - 0.9
+ - 0.999
+ - 0.9999
+ eps:
+ - 1.0e-30
+ - 1.0e-16
+ lr: 0.0001
+ type: CAMEWrapper
+ weight_decay: 0.0
+ lr_schedule: constant
+ lr_schedule_args:
+ num_warmup_steps: 2000
+ local_save_vis: true # if save log image locally
+ visualize: true
+ eval_sampling_steps: 500
+ log_interval: 20
+ save_model_epochs: 5
+ save_model_steps: 500
+ work_dir: output/debug
+ online_metric: false
+ eval_metric_step: 2000
+ online_metric_dir: metric_helper
diff --git a/configs/sana_config/1024ms/Sana_600M_img1024.yaml b/configs/sana_config/1024ms/Sana_600M_img1024.yaml
new file mode 100644
index 0000000..2cce306
--- /dev/null
+++ b/configs/sana_config/1024ms/Sana_600M_img1024.yaml
@@ -0,0 +1,105 @@
+data:
+ data_dir: [data/data_public/dir1]
+ image_size: 1024
+ caption_proportion:
+ prompt: 1
+ external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B]
+ external_clipscore_suffixes:
+ - _InternVL2-26B_clip_score
+ - _VILA1-5-13B_clip_score
+ - _prompt_clip_score
+ clip_thr_temperature: 0.1
+ clip_thr: 25.0
+ load_text_feat: false
+ load_vae_feat: false
+ transform: default_train
+ type: SanaWebDatasetMS
+ sort_dataset: false
+# model config
+model:
+ model: SanaMS_600M_P1_D28
+ image_size: 1024
+ mixed_precision: fp16
+ fp32_attention: true
+ load_from:
+ resume_from:
+ aspect_ratio_type: ASPECT_RATIO_1024
+ multi_scale: true
+ attn_type: linear
+ ffn_type: glumbconv
+ mlp_acts:
+ - silu
+ - silu
+ -
+ mlp_ratio: 2.5
+ use_pe: false
+ qk_norm: false
+ class_dropout_prob: 0.1
+# VAE setting
+vae:
+ vae_type: dc-ae
+ vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
+ scale_factor: 0.41407
+ vae_latent_dim: 32
+ vae_downsample_rate: 32
+ sample_posterior: true
+# text encoder
+text_encoder:
+ text_encoder_name: gemma-2-2b-it
+ y_norm: true
+ y_norm_scale_factor: 0.01
+ model_max_length: 300
+ # CHI
+ chi_prompt:
+ - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
+ - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
+ - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
+ - 'Here are examples of how to transform or refine prompts:'
+ - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
+ - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
+ - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
+ - 'User Prompt: '
+# Sana schedule Flow
+scheduler:
+ predict_v: true
+ noise_schedule: linear_flow
+ pred_sigma: false
+ flow_shift: 4.0
+ # logit-normal timestep
+ weighting_scheme: logit_normal
+ logit_mean: 0.0
+ logit_std: 1.0
+ vis_sampler: flow_dpm-solver
+# training setting
+train:
+ num_workers: 10
+ seed: 1
+ train_batch_size: 64
+ num_epochs: 100
+ gradient_accumulation_steps: 1
+ grad_checkpointing: true
+ gradient_clip: 0.1
+ optimizer:
+ betas:
+ - 0.9
+ - 0.999
+ - 0.9999
+ eps:
+ - 1.0e-30
+ - 1.0e-16
+ lr: 0.0001
+ type: CAMEWrapper
+ weight_decay: 0.0
+ lr_schedule: constant
+ lr_schedule_args:
+ num_warmup_steps: 2000
+ local_save_vis: true # if save log image locally
+ visualize: true
+ eval_sampling_steps: 500
+ log_interval: 20
+ save_model_epochs: 5
+ save_model_steps: 500
+ work_dir: output/debug
+ online_metric: false
+ eval_metric_step: 2000
+ online_metric_dir: metric_helper
diff --git a/configs/sana_config/512ms/Sana_1600M_img512.yaml b/configs/sana_config/512ms/Sana_1600M_img512.yaml
new file mode 100644
index 0000000..140b7cb
--- /dev/null
+++ b/configs/sana_config/512ms/Sana_1600M_img512.yaml
@@ -0,0 +1,108 @@
+data:
+ data_dir: [data/data_public/dir1]
+ image_size: 512
+ caption_proportion:
+ prompt: 1
+ external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B]
+ external_clipscore_suffixes:
+ - _InternVL2-26B_clip_score
+ - _VILA1-5-13B_clip_score
+ - _prompt_clip_score
+ clip_thr_temperature: 0.1
+ clip_thr: 25.0
+ load_text_feat: false
+ load_vae_feat: false
+ transform: default_train
+ type: SanaWebDatasetMS
+ sort_dataset: false
+# model config
+model:
+ model: SanaMS_1600M_P1_D20
+ image_size: 512
+ mixed_precision: fp16 # ['fp16', 'fp32', 'bf16']
+ fp32_attention: true
+ load_from:
+ resume_from:
+ aspect_ratio_type: ASPECT_RATIO_512
+ multi_scale: true
+ attn_type: linear
+ ffn_type: glumbconv
+ mlp_acts:
+ - silu
+ - silu
+ -
+ mlp_ratio: 2.5
+ use_pe: false
+ qk_norm: false
+ class_dropout_prob: 0.1
+ # PAG
+ pag_applied_layers:
+ - 8
+# VAE setting
+vae:
+ vae_type: dc-ae
+ vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
+ scale_factor: 0.41407
+ vae_latent_dim: 32
+ vae_downsample_rate: 32
+ sample_posterior: true
+# text encoder
+text_encoder:
+ text_encoder_name: gemma-2-2b-it
+ y_norm: true
+ y_norm_scale_factor: 0.01
+ model_max_length: 300
+ # CHI
+ chi_prompt:
+ - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
+ - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
+ - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
+ - 'Here are examples of how to transform or refine prompts:'
+ - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
+ - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
+ - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
+ - 'User Prompt: '
+# Sana schedule Flow
+scheduler:
+ predict_v: true
+ noise_schedule: linear_flow
+ pred_sigma: false
+ flow_shift: 1.0
+ # logit-normal timestep
+ weighting_scheme: logit_normal
+ logit_mean: 0.0
+ logit_std: 1.0
+ vis_sampler: flow_dpm-solver
+# training setting
+train:
+ num_workers: 10
+ seed: 1
+ train_batch_size: 64
+ num_epochs: 100
+ gradient_accumulation_steps: 1
+ grad_checkpointing: true
+ gradient_clip: 0.1
+ optimizer:
+ betas:
+ - 0.9
+ - 0.999
+ - 0.9999
+ eps:
+ - 1.0e-30
+ - 1.0e-16
+ lr: 0.0001
+ type: CAMEWrapper
+ weight_decay: 0.0
+ lr_schedule: constant
+ lr_schedule_args:
+ num_warmup_steps: 2000
+ local_save_vis: true # if save log image locally
+ visualize: true
+ eval_sampling_steps: 500
+ log_interval: 20
+ save_model_epochs: 5
+ save_model_steps: 500
+ work_dir: output/debug
+ online_metric: false
+ eval_metric_step: 2000
+ online_metric_dir: metric_helper
diff --git a/configs/sana_config/512ms/Sana_600M_img512.yaml b/configs/sana_config/512ms/Sana_600M_img512.yaml
new file mode 100644
index 0000000..b3f2e12
--- /dev/null
+++ b/configs/sana_config/512ms/Sana_600M_img512.yaml
@@ -0,0 +1,107 @@
+data:
+ data_dir: [data/data_public/dir1]
+ image_size: 512
+ caption_proportion:
+ prompt: 1
+ external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B]
+ external_clipscore_suffixes:
+ - _InternVL2-26B_clip_score
+ - _VILA1-5-13B_clip_score
+ - _prompt_clip_score
+ clip_thr_temperature: 0.1
+ clip_thr: 25.0
+ load_text_feat: false
+ load_vae_feat: false
+ transform: default_train
+ type: SanaWebDatasetMS
+ sort_dataset: false
+# model config
+model:
+ model: SanaMS_600M_P1_D28
+ image_size: 512
+ mixed_precision: fp16
+ fp32_attention: true
+ load_from:
+ resume_from:
+ aspect_ratio_type: ASPECT_RATIO_512
+ multi_scale: true
+ #pe_interpolation: 1.
+ attn_type: linear
+ linear_head_dim: 32
+ ffn_type: glumbconv
+ mlp_acts:
+ - silu
+ - silu
+ - null
+ mlp_ratio: 2.5
+ use_pe: false
+ qk_norm: false
+ class_dropout_prob: 0.1
+# VAE setting
+vae:
+ vae_type: dc-ae
+ vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
+ scale_factor: 0.41407
+ vae_latent_dim: 32
+ vae_downsample_rate: 32
+ sample_posterior: true
+# text encoder
+text_encoder:
+ text_encoder_name: gemma-2-2b-it
+ y_norm: true
+ y_norm_scale_factor: 0.01
+ model_max_length: 300
+ # CHI
+ chi_prompt:
+ - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
+ - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
+ - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
+ - 'Here are examples of how to transform or refine prompts:'
+ - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
+ - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
+ - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
+ - 'User Prompt: '
+# Sana schedule Flow
+scheduler:
+ predict_v: true
+ noise_schedule: linear_flow
+ pred_sigma: false
+ flow_shift: 1.0
+ # logit-normal timestep
+ weighting_scheme: logit_normal
+ logit_mean: 0.0
+ logit_std: 1.0
+ vis_sampler: flow_dpm-solver
+# training setting
+train:
+ num_workers: 10
+ seed: 1
+ train_batch_size: 128
+ num_epochs: 100
+ gradient_accumulation_steps: 1
+ grad_checkpointing: true
+ gradient_clip: 0.1
+ optimizer:
+ betas:
+ - 0.9
+ - 0.999
+ - 0.9999
+ eps:
+ - 1.0e-30
+ - 1.0e-16
+ lr: 0.0001
+ type: CAMEWrapper
+ weight_decay: 0.0
+ lr_schedule: constant
+ lr_schedule_args:
+ num_warmup_steps: 2000
+ local_save_vis: true # if save log image locally
+ visualize: true
+ eval_sampling_steps: 500
+ log_interval: 20
+ save_model_epochs: 5
+ save_model_steps: 500
+ work_dir: output/debug
+ online_metric: false
+ eval_metric_step: 2000
+ online_metric_dir: metric_helper
diff --git a/configs/sana_config/512ms/ci_Sana_600M_img512.yaml b/configs/sana_config/512ms/ci_Sana_600M_img512.yaml
new file mode 100644
index 0000000..6a84578
--- /dev/null
+++ b/configs/sana_config/512ms/ci_Sana_600M_img512.yaml
@@ -0,0 +1,107 @@
+data:
+ data_dir: [data/data_public/vaef32c32_v2_512/dir1]
+ image_size: 512
+ caption_proportion:
+ prompt: 1
+ external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B]
+ external_clipscore_suffixes:
+ - _InternVL2-26B_clip_score
+ - _VILA1-5-13B_clip_score
+ - _prompt_clip_score
+ clip_thr_temperature: 0.1
+ clip_thr: 25.0
+ load_text_feat: false
+ load_vae_feat: false
+ transform: default_train
+ type: SanaWebDatasetMS
+ sort_dataset: false
+# model config
+model:
+ model: SanaMS_600M_P1_D28
+ image_size: 512
+ mixed_precision: fp16
+ fp32_attention: true
+ load_from:
+ resume_from:
+ aspect_ratio_type: ASPECT_RATIO_512
+ multi_scale: true
+ #pe_interpolation: 1.
+ attn_type: linear
+ linear_head_dim: 32
+ ffn_type: glumbconv
+ mlp_acts:
+ - silu
+ - silu
+ - null
+ mlp_ratio: 2.5
+ use_pe: false
+ qk_norm: false
+ class_dropout_prob: 0.1
+# VAE setting
+vae:
+ vae_type: dc-ae
+ vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
+ scale_factor: 0.41407
+ vae_latent_dim: 32
+ vae_downsample_rate: 32
+ sample_posterior: true
+# text encoder
+text_encoder:
+ text_encoder_name: gemma-2-2b-it
+ y_norm: true
+ y_norm_scale_factor: 0.01
+ model_max_length: 300
+ # CHI
+ chi_prompt:
+ - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
+ - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
+ - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
+ - 'Here are examples of how to transform or refine prompts:'
+ - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
+ - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
+ - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
+ - 'User Prompt: '
+# Sana schedule Flow
+scheduler:
+ predict_v: true
+ noise_schedule: linear_flow
+ pred_sigma: false
+ flow_shift: 1.0
+ # logit-normal timestep
+ weighting_scheme: logit_normal
+ logit_mean: 0.0
+ logit_std: 1.0
+ vis_sampler: flow_dpm-solver
+# training setting
+train:
+ num_workers: 10
+ seed: 1
+ train_batch_size: 64
+ num_epochs: 1
+ gradient_accumulation_steps: 1
+ grad_checkpointing: true
+ gradient_clip: 0.1
+ optimizer:
+ betas:
+ - 0.9
+ - 0.999
+ - 0.9999
+ eps:
+ - 1.0e-30
+ - 1.0e-16
+ lr: 0.0001
+ type: CAMEWrapper
+ weight_decay: 0.0
+ lr_schedule: constant
+ lr_schedule_args:
+ num_warmup_steps: 2000
+ local_save_vis: true # if save log image locally
+ visualize: true
+ eval_sampling_steps: 500
+ log_interval: 20
+ save_model_epochs: 5
+ save_model_steps: 500
+ work_dir: output/debug
+ online_metric: false
+ eval_metric_step: 2000
+ online_metric_dir: metric_helper
diff --git a/configs/sana_config/512ms/sample_dataset.yaml b/configs/sana_config/512ms/sample_dataset.yaml
new file mode 100644
index 0000000..be53827
--- /dev/null
+++ b/configs/sana_config/512ms/sample_dataset.yaml
@@ -0,0 +1,107 @@
+data:
+ data_dir: [asset/example_data]
+ image_size: 512
+ caption_proportion:
+ prompt: 1
+ external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B] # json fils
+ external_clipscore_suffixes: # json files
+ - _InternVL2-26B_clip_score
+ - _VILA1-5-13B_clip_score
+ - _prompt_clip_score
+ clip_thr_temperature: 0.1
+ clip_thr: 25.0
+ load_text_feat: false
+ load_vae_feat: false
+ transform: default_train
+ type: SanaImgDataset
+ sort_dataset: false
+# model config
+model:
+ model: SanaMS_600M_P1_D28
+ image_size: 512
+ mixed_precision: fp16
+ fp32_attention: true
+ load_from:
+ resume_from:
+ aspect_ratio_type: ASPECT_RATIO_512
+ multi_scale: false
+ #pe_interpolation: 1.
+ attn_type: linear
+ linear_head_dim: 32
+ ffn_type: glumbconv
+ mlp_acts:
+ - silu
+ - silu
+ - null
+ mlp_ratio: 2.5
+ use_pe: false
+ qk_norm: false
+ class_dropout_prob: 0.1
+# VAE setting
+vae:
+ vae_type: dc-ae
+ vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
+ scale_factor: 0.41407
+ vae_latent_dim: 32
+ vae_downsample_rate: 32
+ sample_posterior: true
+# text encoder
+text_encoder:
+ text_encoder_name: gemma-2-2b-it
+ y_norm: true
+ y_norm_scale_factor: 0.01
+ model_max_length: 300
+ # CHI
+ chi_prompt:
+ - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
+ - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
+ - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
+ - 'Here are examples of how to transform or refine prompts:'
+ - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
+ - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
+ - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
+ - 'User Prompt: '
+# Sana schedule Flow
+scheduler:
+ predict_v: true
+ noise_schedule: linear_flow
+ pred_sigma: false
+ flow_shift: 1.0
+ # logit-normal timestep
+ weighting_scheme: logit_normal
+ logit_mean: 0.0
+ logit_std: 1.0
+ vis_sampler: flow_dpm-solver
+# training setting
+train:
+ num_workers: 10
+ seed: 1
+ train_batch_size: 128
+ num_epochs: 100
+ gradient_accumulation_steps: 1
+ grad_checkpointing: true
+ gradient_clip: 0.1
+ optimizer:
+ betas:
+ - 0.9
+ - 0.999
+ - 0.9999
+ eps:
+ - 1.0e-30
+ - 1.0e-16
+ lr: 0.0001
+ type: CAMEWrapper
+ weight_decay: 0.0
+ lr_schedule: constant
+ lr_schedule_args:
+ num_warmup_steps: 2000
+ local_save_vis: true # if save log image locally
+ visualize: true
+ eval_sampling_steps: 500
+ log_interval: 20
+ save_model_epochs: 5
+ save_model_steps: 500
+ work_dir: output/debug
+ online_metric: false
+ eval_metric_step: 2000
+ online_metric_dir: metric_helper
diff --git a/diffusion/data/wids/__init__.py b/diffusion/data/wids/__init__.py
new file mode 100755
index 0000000..bd3bd14
--- /dev/null
+++ b/diffusion/data/wids/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
+# This file is part of the WebDataset library.
+# See the LICENSE file for licensing terms (BSD-style).
+#
+# flake8: noqa
+
+from .wids import (
+ ChunkedSampler,
+ DistributedChunkedSampler,
+ DistributedLocalSampler,
+ DistributedRangedSampler,
+ ShardedSampler,
+ ShardListDataset,
+ ShardListDatasetMulti,
+ lru_json_load,
+)
diff --git a/diffusion/data/wids/wids.py b/diffusion/data/wids/wids.py
new file mode 100755
index 0000000..1e78f73
--- /dev/null
+++ b/diffusion/data/wids/wids.py
@@ -0,0 +1,1051 @@
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is modified from https://github.com/NVlabs/VILA/tree/main/llava/wids
+import base64
+import gzip
+import hashlib
+import io
+import json
+import math
+import os
+import os.path as osp
+import random
+import re
+import sqlite3
+import sys
+import tempfile
+import uuid
+import warnings
+from functools import lru_cache, partial
+from typing import Any, BinaryIO, Dict, Optional, TypeVar, Union
+from urllib.parse import quote, urlparse
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch.utils.data.distributed import DistributedSampler
+
+from .wids_dl import download_and_open
+from .wids_lru import LRUCache
+from .wids_mmtar import MMIndexedTar
+from .wids_specs import load_dsdesc_and_resolve, urldir
+from .wids_tar import TarFileReader, find_index_file
+
+try:
+ from torch.utils.data import Dataset, Sampler
+except ImportError:
+
+ class Dataset:
+ pass
+
+ class Sampler:
+ pass
+
+
+T = TypeVar("T")
+
+T_co = TypeVar("T_co", covariant=True)
+
+
+def compute_file_md5sum(fname: Union[str, BinaryIO], chunksize: int = 1000000) -> str:
+ """Compute the md5sum of a file in chunks.
+
+ Parameters
+ ----------
+ fname : Union[str, BinaryIO]
+ Filename or file object
+ chunksize : int, optional
+ Chunk size in bytes, by default 1000000
+
+ Returns
+ -------
+ str
+ MD5 sum of the file
+
+ Examples
+ --------
+ >>> compute_file_md5sum("test.txt")
+ 'd41d8cd98f00b204e9800998ecf8427e'
+ """
+ md5 = hashlib.md5()
+ if isinstance(fname, str):
+ with open(fname, "rb") as f:
+ for chunk in iter(lambda: f.read(chunksize), b""):
+ md5.update(chunk)
+ else:
+ fname.seek(0)
+ for chunk in iter(lambda: fname.read(chunksize), b""):
+ md5.update(chunk)
+ return md5.hexdigest()
+
+
+def compute_file_md5sum(fname: Union[str, BinaryIO], chunksize: int = 1000000) -> str:
+ """Compute the md5sum of a file in chunks."""
+ md5 = hashlib.md5()
+ if isinstance(fname, str):
+ with open(fname, "rb") as f:
+ for chunk in iter(lambda: f.read(chunksize), b""):
+ md5.update(chunk)
+ else:
+ fname.seek(0)
+ for chunk in iter(lambda: fname.read(chunksize), b""):
+ md5.update(chunk)
+ return md5.hexdigest()
+
+
+def compute_num_samples(fname):
+ ds = IndexedTarSamples(fname)
+ return len(ds)
+
+
+def splitname(fname):
+ """Returns the basename and extension of a filename"""
+ assert "." in fname, "Filename must have an extension"
+ # basename, extension = re.match(r"^((?:.*/)?.*?)(\..*)$", fname).groups()
+ basename, extension = os.path.splitext(fname)
+ return basename, extension
+
+
+# NOTE(ligeng): change to ordered mapping to more flexbile dict
+# TODO(ligeng): submit a PR to fix the mapping issue.
+def group_by_key(names):
+ """Group the file names by key.
+
+ Args:
+ names: A list of file names.
+
+ Returns:
+ A list of lists of indices, where each sublist contains indices of files
+ with the same key.
+ """
+ groups = []
+ kmaps = {}
+ for i, fname in enumerate(names):
+ # Ignore files that are not in a subdirectory.
+ if "." not in fname:
+ print(f"Warning: Ignoring file {fname} (no '.')")
+ continue
+ if fname == ".":
+ print(f"Warning: Ignoring the '.' file.")
+ continue
+ key, ext = splitname(fname)
+ if key not in kmaps:
+ kmaps[key] = []
+ kmaps[key].append(i)
+ for k, v in kmaps.items():
+ groups.append(v)
+ return groups
+
+
+def default_decoder(sample: Dict[str, Any], format: Optional[Union[bool, str]] = True):
+ """A default decoder for webdataset.
+
+ This handles common file extensions: .txt, .cls, .cls2,
+ .jpg, .png, .json, .npy, .mp, .pt, .pth, .pickle, .pkl.
+ These are the most common extensions used in webdataset.
+ For other extensions, users can provide their own decoder.
+
+ Args:
+ sample: sample, modified in place
+ """
+ sample = dict(sample)
+ for key, stream in sample.items():
+ extensions = key.split(".")
+ if len(extensions) < 1:
+ continue
+ extension = extensions[-1]
+ if extension in ["gz"]:
+ decompressed = gzip.decompress(stream.read())
+ stream = io.BytesIO(decompressed)
+ if len(extensions) < 2:
+ sample[key] = stream
+ continue
+ extension = extensions[-2]
+ if key.startswith("__"):
+ continue
+ elif extension in ["txt", "text"]:
+ value = stream.read()
+ sample[key] = value.decode("utf-8")
+ elif extension in ["cls", "cls2"]:
+ value = stream.read()
+ sample[key] = int(value.decode("utf-8"))
+ elif extension in ["jpg", "png", "ppm", "pgm", "pbm", "pnm"]:
+ if format == "PIL":
+ import PIL.Image
+
+ sample[key] = PIL.Image.open(stream)
+ elif format == "numpy":
+ import numpy as np
+
+ sample[key] = np.asarray(PIL.Image.open(stream))
+ else:
+ raise ValueError(f"Unknown format: {format}")
+ elif extension == "json":
+ import json
+
+ value = stream.read()
+ sample[key] = json.loads(value)
+ elif extension == "npy":
+ import numpy as np
+
+ sample[key] = np.load(stream)
+ elif extension == "mp":
+ import msgpack
+
+ value = stream.read()
+ sample[key] = msgpack.unpackb(value, raw=False)
+ elif extension in ["pt", "pth"]:
+ import torch
+
+ sample[key] = torch.load(stream)
+ elif extension in ["pickle", "pkl"]:
+ import pickle
+
+ sample[key] = pickle.load(stream)
+ elif extension == "mp4":
+ # Write stream to a temporary file
+ # with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmpfile:
+ # tmpfile.write(stream.read())
+ # tmpfile_path = tmpfile.name
+
+ # sample[key] = tmpfile_path
+ sample[key] = io.BytesIO(stream.read())
+ return sample
+
+
+def update_dict_with_extend(original_dict, update_dict):
+ for key, value in update_dict.items():
+ if key in original_dict and isinstance(original_dict[key], list) and isinstance(value, list):
+ original_dict[key].extend(value)
+ else:
+ original_dict[key] = value
+
+
+open_itfs = {}
+
+
+class IndexedTarSamples:
+ """A class that accesses samples in a tar file. The tar file must follow
+ WebDataset conventions. The tar file is indexed when the IndexedTarSamples
+ object is created. The samples are accessed by index using the __getitem__
+ method. The __getitem__ method returns a dictionary containing the files
+ for the sample. The key for each file is the extension of the file name.
+ The key "__key__" is reserved for the key of the sample (the basename of
+ each file without the extension). For example, if the tar file contains
+ the files "sample1.jpg" and "sample1.txt", then the sample with key
+ "sample1" will be returned as the dictionary {"jpg": ..., "txt": ...}.
+ """
+
+ def __init__(
+ self,
+ *,
+ path=None,
+ stream=None,
+ md5sum=None,
+ expected_size=None,
+ use_mmap=True,
+ index_file=find_index_file,
+ ):
+ assert path is not None or stream is not None
+
+ # Create TarFileReader object to read from tar_file
+ self.path = path
+ stream = self.stream = stream or open(path, "rb")
+
+ # verify the MD5 sum
+ if md5sum is not None:
+ stream.seek(0)
+ got = compute_file_md5sum(stream)
+ assert got == md5sum, f"MD5 sum mismatch: expected {md5sum}, got {got}"
+ stream.seek(0)
+
+ # use either the mmap or the stream based implementation
+ # NOTE(ligeng): https://stackoverflow.com/questions/11072705/twitter-trends-api-unicodedecodeerror-utf8-codec-cant-decode-byte-0x8b-in-po
+ # import gzip
+ # print("convert to gzip IO stream")
+ # stream = gzip.GzipFile(fileobj=stream)
+
+ if use_mmap:
+ self.reader = MMIndexedTar(stream)
+ else:
+ self.reader = TarFileReader(stream, index_file=index_file)
+
+ # Get list of all files in stream
+ all_files = self.reader.names()
+
+ # Group files by key into samples
+ self.samples = group_by_key(all_files)
+ # print("DEBUG:", list(all_files)[:20])
+ # print("DEBUG:", self.samples[:20])
+
+ # check that the number of samples is correct
+ if expected_size is not None:
+ assert len(self) == expected_size, f"Expected {expected_size} samples, got {len(self)}"
+
+ self.uuid = str(uuid.uuid4())
+
+ def close(self):
+ self.reader.close()
+ if not self.stream.closed:
+ self.stream.close()
+
+ def __len__(self):
+ return len(self.samples)
+
+ def __getitem__(self, idx):
+ # Get indexes of files for the sample at index idx
+ try:
+ indexes = self.samples[idx]
+ except IndexError as e:
+ print(f"[wids-debug] curr idx: {idx}, total sample length: {len(self.samples)} {e}")
+ raise e
+ sample = {}
+ key = None
+ for i in indexes:
+ # Get filename and data for the file at index i
+ fname, data = self.reader.get_file(i)
+ # Split filename into key and extension
+ k, ext = splitname(fname)
+ # Make sure all files in sample have same key
+ key = key or k
+ assert key == k
+ sample[ext] = data
+ # Add key to sample
+ sample["__key__"] = key
+ return sample
+
+ def __str__(self):
+ return f""
+
+ def __repr__(self):
+ return str(self)
+
+
+def hash_localname(dldir="/tmp/_wids_cache"):
+ os.makedirs(dldir, exist_ok=True)
+
+ connection = sqlite3.connect(os.path.join(dldir, "cache.db"))
+ cursor = connection.cursor()
+ cursor.execute("CREATE TABLE IF NOT EXISTS cache (url TEXT PRIMARY KEY, path TEXT, checksum TEXT)")
+ connection.commit()
+
+ def f(shard):
+ """Given a URL, return a local name for the shard."""
+ if shard.startswith("pipe:"):
+ # uuencode the entire URL string
+ hex32 = base64.urlsafe_b64encode(hashlib.sha256(shard.encode()).digest())[:32].decode()
+ return os.path.join(dldir, "pipe__" + hex32)
+ else:
+ # we hash the host and directory components into a 16 character string
+ dirname = urldir(shard)
+ hex16 = base64.urlsafe_b64encode(hashlib.sha256(dirname.encode()).digest())[:16].decode()
+ # the cache name is the concatenation of the hex16 string and the file name component of the URL
+ cachename = "data__" + hex16 + "__" + os.path.basename(urlparse(shard).path)
+ checksum = None
+ cursor.execute(
+ "INSERT OR REPLACE INTO cache VALUES (?, ?, ?)",
+ (shard, cachename, checksum),
+ )
+ connection.commit()
+ return os.path.join(dldir, cachename)
+
+ return f
+
+
+def cache_localname(cachedir):
+ os.makedirs(cachedir, exist_ok=True)
+
+ def f(shard):
+ """Given a URL, return a local name for the shard."""
+ path = urlparse(shard).path
+ fname = os.path.basename(path)
+ return os.path.join(cachedir, fname)
+
+ return f
+
+
+def default_localname(dldir="/tmp/_wids_cache"):
+ os.makedirs(dldir, exist_ok=True)
+
+ def f(shard):
+ """Given a URL, return a local name for the shard."""
+ cachename = quote(shard, safe="+-")
+ return os.path.join(dldir, cachename)
+
+ return f
+
+
+class LRUShards:
+ """A class that manages a cache of shards. The cache is a LRU cache that
+ stores the local names of the shards as keys and the downloaded paths as
+ values. The shards are downloaded to a directory specified by dldir.
+ The local name of a shard is computed by the localname function, which
+ takes the shard URL as an argument. If keep is True, the downloaded files
+ are not deleted when they are no longer needed.
+ """
+
+ def __init__(self, lru_size, keep=False, localname=default_localname()):
+ self.localname = localname
+ # the cache contains the local name as the key and the downloaded path as the value
+ self.lru = LRUCache(lru_size, release_handler=self.release_handler)
+ # keep statistics
+ self.reset_stats()
+
+ def reset_stats(self):
+ self.accesses = 0
+ self.misses = 0
+
+ def __len__(self):
+ return len(self.lru)
+
+ def release_handler(self, key, value):
+ value.close()
+
+ def clear(self):
+ self.lru.clear()
+
+ def get_shard(self, url):
+ assert isinstance(url, str)
+ self.accesses += 1
+ if url not in self.lru:
+ local = self.localname(url)
+ with download_and_open(url, local) as stream:
+ itf = IndexedTarSamples(path=local, stream=stream)
+ self.lru[url] = itf
+ self.misses += 1
+ self.last_missed = True
+ else:
+ self.last_missed = False
+ return self.lru[url]
+
+
+def interpret_transformations(transformations):
+ """Interpret the transformations argument.
+
+ This takes care of transformations specified as string shortcuts
+ and returns a list of callables.
+ """
+ if not isinstance(transformations, list):
+ transformations = [transformations]
+
+ result = []
+
+ for transformation in transformations:
+ if transformation == "PIL":
+ transformation = partial(default_decoder, format="PIL")
+ elif transformation == "numpy":
+ transformation = partial(default_decoder, format="numpy")
+ else:
+ assert callable(transformation)
+ result.append(transformation)
+
+ return result
+
+
+def hash_dataset_name(input_string):
+ """Compute a hash of the input string and return the first 16 characters of the hash."""
+ # Compute SHA256 hash of the input string
+ hash_object = hashlib.sha256(input_string.encode())
+ hash_digest = hash_object.digest()
+
+ # Encode the hash in base64
+ base64_encoded_hash = base64.urlsafe_b64encode(hash_digest)
+
+ # Return the first 16 characters of the base64-encoded hash
+ return base64_encoded_hash[:16].decode("ascii")
+
+
+@lru_cache(maxsize=16)
+def lru_json_load(fpath):
+ with open(fpath) as fp:
+ return json.load(fp)
+
+
+class ShardListDataset(Dataset[T]):
+ """An indexable dataset based on a list of shards.
+
+ The dataset is either given as a list of shards with optional options and name,
+ or as a URL pointing to a JSON descriptor file.
+
+ Datasets can reference other datasets via `source_url`.
+
+ Shard references within a dataset are resolve relative to an explicitly
+ given `base` property, or relative to the URL from which the dataset
+ descriptor was loaded.
+ """
+
+ def __init__(
+ self,
+ shards,
+ *,
+ cache_size=int(1e12),
+ cache_dir=None,
+ lru_size=10,
+ dataset_name=None,
+ localname=None,
+ transformations="PIL",
+ keep=False,
+ base=None,
+ options=None,
+ ):
+ """Create a ShardListDataset.
+
+ Args:
+ shards: a list of (filename, length) pairs or a URL pointing to a JSON descriptor file
+ cache_size: the number of shards to keep in the cache
+ lru_size: the number of shards to keep in the LRU cache
+ localname: a function that maps URLs to local filenames
+
+ Note that there are two caches: an on-disk directory, and an in-memory LRU cache.
+ """
+ if options is None:
+ options = {}
+ super().__init__()
+ # shards is a list of (filename, length) pairs. We'll need to
+ # keep track of the lengths and cumulative lengths to know how
+ # to map indices to shards and indices within shards.
+ if isinstance(shards, (str, io.IOBase)):
+ if base is None and isinstance(shards, str):
+ shards = osp.expanduser(shards)
+ base = urldir(shards)
+ self.base = base
+ self.spec = load_dsdesc_and_resolve(shards, options=options, base=base)
+ self.shards = self.spec.get("shardlist", [])
+ self.dataset_name = self.spec.get("name") or hash_dataset_name(str(shards))
+ else:
+ raise NotImplementedError("Only support taking path/url to JSON descriptor file.")
+ self.base = None
+ self.spec = options
+ self.shards = shards
+ self.dataset_name = dataset_name or hash_dataset_name(str(shards))
+
+ self.lengths = [shard["nsamples"] for shard in self.shards]
+ self.cum_lengths = np.cumsum(self.lengths)
+ self.total_length = self.cum_lengths[-1]
+
+ if cache_dir is not None:
+ # when a cache dir is explicitly given, we download files into
+ # that directory without any changes
+ self.cache_dir = cache_dir
+ self.localname = cache_localname(cache_dir)
+ elif localname is not None:
+ # when a localname function is given, we use that
+ self.cache_dir = None
+ self.localname = localname
+ else:
+ import getpass
+
+ # when no cache dir or localname are given, use the cache from the environment
+ self.cache_dir = os.environ.get("WIDS_CACHE", f"~/.cache/_wids_cache")
+ self.cache_dir = osp.expanduser(self.cache_dir)
+ self.localname = default_localname(self.cache_dir)
+
+ self.data_info = (
+ f"[WebShardedList] {str(shards)}, base: {self.base,}, name: {self.spec.get('name')}, "
+ f"nfiles: {str(len(self.shards))}"
+ )
+ if True or int(os.environ.get("WIDS_VERBOSE", 0)):
+ nbytes = sum(shard.get("filesize", 0) for shard in self.shards)
+ nsamples = sum(shard["nsamples"] for shard in self.shards)
+ self.data_info += f"nbytes: {str(nbytes)}, samples: {str(nsamples),}, cache: {self.cache_dir} "
+ # print(
+ # "[WebShardedList]",
+ # str(shards),
+ # "base:",
+ # self.base,
+ # "name:",
+ # self.spec.get("name"),
+ # "nfiles:",
+ # len(self.shards),
+ # "nbytes:",
+ # nbytes,
+ # "samples:",
+ # nsamples,
+ # "cache:",
+ # self.cache_dir,
+ # file=sys.stderr,
+ # )
+ self.transformations = interpret_transformations(transformations)
+
+ if lru_size > 200:
+ warnings.warn("LRU size is very large; consider reducing it to avoid running out of file descriptors")
+ self.cache = LRUShards(lru_size, localname=self.localname, keep=keep)
+
+ def add_transform(self, transform):
+ """Add a transformation to the dataset."""
+ self.transformations.append(transform)
+ return self
+
+ def __len__(self):
+ """Return the total number of samples in the dataset."""
+ return self.total_length
+
+ def get_stats(self):
+ """Return the number of cache accesses and misses."""
+ return self.cache.accesses, self.cache.misses
+
+ def check_cache_misses(self):
+ """Check if the cache miss rate is too high."""
+ accesses, misses = self.get_stats()
+ if accesses > 100 and misses / accesses > 0.3:
+ # output a warning only once
+ self.check_cache_misses = lambda: None
+ print(f"Warning: ShardListDataset has a cache miss rate of {misses * 100.0 / accesses:.1%}%")
+
+ def get_shard(self, index):
+ """Get the shard and index within the shard corresponding to the given index."""
+ # Find the shard corresponding to the given index.
+ shard_idx = np.searchsorted(self.cum_lengths, index, side="right")
+
+ # Figure out which index within the shard corresponds to the
+ # given index.
+ if shard_idx == 0:
+ inner_idx = index
+ else:
+ inner_idx = index - self.cum_lengths[shard_idx - 1]
+
+ # Get the shard and return the corresponding element.
+ desc = self.shards[shard_idx]
+ url = desc["url"]
+ if url.startswith(("https://", "http://", "gs://", "/", "~")):
+ # absolute path or url path
+ url = url
+ else:
+ # concat relative path
+ if self.base is None and "base_path" not in self.spec:
+ raise FileNotFoundError("passing a relative path in shardlist but no base found.")
+ base_path = self.spec["base_path"] if "base_path" in self.spec else self.base
+ url = osp.abspath(osp.join(osp.expanduser(base_path), url))
+
+ desc["url"] = url
+ try:
+ shard = self.cache.get_shard(url)
+ except UnicodeDecodeError as e:
+ print("UnicodeDecodeError:", desc)
+ raise e
+ return shard, inner_idx, desc
+
+ def __getitem__(self, index):
+ """Return the sample corresponding to the given index."""
+ shard, inner_idx, desc = self.get_shard(index)
+ sample = shard[inner_idx]
+
+ # Check if we're missing the cache too often.
+ self.check_cache_misses()
+
+ sample["__dataset__"] = desc.get("dataset")
+ sample["__index__"] = index
+ sample["__shard__"] = desc["url"]
+ sample["__shardindex__"] = inner_idx
+
+ # Apply transformations
+ for transform in self.transformations:
+ sample = transform(sample)
+
+ return sample
+
+ def close(self):
+ """Close the dataset."""
+ self.cache.clear()
+
+
+class ShardListDatasetMulti(ShardListDataset):
+ """An indexable dataset based on a list of shards.
+
+ The dataset is either given as a list of shards with optional options and name,
+ or as a URL pointing to a JSON descriptor file.
+
+ Datasets can reference other datasets via `source_url`.
+
+ Shard references within a dataset are resolve relative to an explicitly
+ given `base` property, or relative to the URL from which the dataset
+ descriptor was loaded.
+ """
+
+ def __init__(
+ self,
+ shards,
+ *,
+ cache_size=int(1e12),
+ cache_dir=None,
+ lru_size=10,
+ dataset_name=None,
+ localname=None,
+ transformations="PIL",
+ keep=False,
+ base=None,
+ options=None,
+ sort_data_inseq=False,
+ num_replicas=None,
+ ):
+ """Create a ShardListDataset.
+
+ Args:
+ shards: a list of (filename, length) pairs or a URL pointing to a JSON descriptor file
+ cache_size: the number of shards to keep in the cache
+ lru_size: the number of shards to keep in the LRU cache
+ localname: a function that maps URLs to local filenames
+
+ Note that there are two caches: an on-disk directory, and an in-memory LRU cache.
+ """
+ if options is None:
+ options = {}
+ # shards is a list of (filename, length) pairs. We'll need to
+ # keep track of the lengths and cumulative lengths to know how
+ # to map indices to shards and indices within shards.
+ shards_lists = shards if isinstance(shards, list) else [shards]
+ bases = base if isinstance(base, list) else [base] * len(shards_lists)
+ self.spec = {}
+ self.shards = []
+ self.num_per_dir = {}
+ for base, shards in zip(bases, shards_lists):
+ if isinstance(shards, (str, io.IOBase)):
+ if base is None and isinstance(shards, str):
+ shards = osp.expanduser(shards)
+ base = urldir(shards)
+ self.base = base
+ _spec = load_dsdesc_and_resolve(shards, options=options, base=base)
+ update_dict_with_extend(self.spec, _spec)
+ self.num_per_dir[os.path.basename(os.path.dirname(shards))] = sum(
+ [shard["nsamples"] for shard in _spec.get("shardlist", [])]
+ )
+ else:
+ raise NotImplementedError("Only support taking path/url to JSON descriptor file.")
+ self.base = None
+ self.spec = options
+ self.shards = shards
+ self.dataset_name = dataset_name or hash_dataset_name(str(shards))
+
+ if sort_data_inseq and len(self.spec.get("shardlist", [])) > 0:
+ num_replicas = num_replicas or dist.get_world_size()
+ self.spec["shardlist"] = split_and_recombine(self.spec["shardlist"], num_replicas)
+
+ self.shards.extend(self.spec.get("shardlist", []))
+ self.dataset_name = self.spec.get("name") or hash_dataset_name(str(shards))
+
+ self.lengths = [shard["nsamples"] for shard in self.shards]
+ self.cum_lengths = np.cumsum(self.lengths)
+ self.total_length = self.cum_lengths[-1]
+
+ if cache_dir is not None:
+ # when a cache dir is explicitly given, we download files into
+ # that directory without any changes
+ self.cache_dir = cache_dir
+ self.localname = cache_localname(cache_dir)
+ elif localname is not None:
+ # when a localname function is given, we use that
+ self.cache_dir = None
+ self.localname = localname
+ else:
+ import getpass
+
+ # when no cache dir or localname are given, use the cache from the environment
+ self.cache_dir = os.environ.get("WIDS_CACHE", f"~/.cache/_wids_cache")
+ self.cache_dir = osp.expanduser(self.cache_dir)
+ self.localname = default_localname(self.cache_dir)
+
+ self.data_info = (
+ f"[WebShardedList] {str(shards)}, base: {self.base,}, name: {self.spec.get('name')}, "
+ f"nfiles: {str(len(self.shards))}"
+ )
+ if True or int(os.environ.get("WIDS_VERBOSE", 0)):
+ nbytes = sum(shard.get("filesize", 0) for shard in self.shards)
+ nsamples = sum(shard["nsamples"] for shard in self.shards)
+ self.data_info += f"nbytes: {str(nbytes)}, samples: {str(nsamples),}, cache: {self.cache_dir} "
+ self.transformations = interpret_transformations(transformations)
+
+ if lru_size > 200:
+ warnings.warn("LRU size is very large; consider reducing it to avoid running out of file descriptors")
+ self.cache = LRUShards(lru_size, localname=self.localname, keep=keep)
+
+
+def split_and_recombine(lst, n):
+ from collections import OrderedDict
+
+ def extract_prefix(i):
+ return i["url"].split("/")[-2]
+
+ unique_parts = list(OrderedDict((extract_prefix(item), None) for item in lst).keys())
+ split_dict = {part: [] for part in unique_parts}
+
+ for part in unique_parts:
+ part_list = [item for item in lst if extract_prefix(item) == part]
+ chunk_size = max(1, len(part_list) // n) # 确保 chunk_size 至少为 1
+ chunks = [part_list[i * chunk_size : (i + 1) * chunk_size] for i in range(n)]
+
+ # 处理最后一个 chunk,如果数量不均匀,将剩余的元素添加到最后一个 chunk
+ if len(part_list) % n != 0:
+ chunks[-1].extend(part_list[n * chunk_size :])
+
+ split_dict[part] = chunks
+
+ recombined_list = []
+ for i in range(n):
+ for part in unique_parts:
+ recombined_list.extend(split_dict[part][i])
+
+ return recombined_list
+
+
+def lengths_to_ranges(lengths):
+ """Convert a list of lengths to a list of ranges."""
+ ranges = []
+ start = 0
+ for length in lengths:
+ ranges.append((start, start + length))
+ start += length
+ return ranges
+
+
+def intersect_range(a, b):
+ """Return the intersection of the two half-open integer intervals."""
+ result = max(a[0], b[0]), min(a[1], b[1])
+ if result[0] >= result[1]:
+ return None
+ return result
+
+
+def intersect_ranges(rangelist, r):
+ """Return the intersection of the half-open integer interval r with the list of half-open integer intervals."""
+ result = []
+ for a in rangelist:
+ x = intersect_range(a, r)
+ if x is not None:
+ result.append(x)
+ return result
+
+
+def iterate_ranges(ranges, rng, indexshuffle=True, shardshuffle=True):
+ """Iterate over the ranges in a random order."""
+ shard_indexes = list(range(len(ranges)))
+ if shardshuffle:
+ rng.shuffle(shard_indexes)
+ for i in shard_indexes:
+ lo, hi = ranges[i]
+ sample_indexes = list(range(lo, hi))
+ if indexshuffle:
+ rng.shuffle(sample_indexes)
+ yield from sample_indexes
+
+
+class ShardListSampler(Sampler):
+ """A sampler that samples consistent with a ShardListDataset.
+
+ This sampler is used to sample from a ShardListDataset in a way that
+ preserves locality.
+
+ This returns a permutation of the indexes by shard, then a permutation of
+ indexes within each shard. This ensures that the data is accessed in a
+ way that preserves locality.
+
+ Note that how this ends up splitting data between multiple workers ends up
+ on the details of the DataLoader. Generally, it will likely load samples from the
+ same shard in each worker.
+
+ Other more sophisticated shard-aware samplers are possible and will likely
+ be added.
+ """
+
+ def __init__(self, dataset, *, lengths=None, seed=0, shufflefirst=False):
+ if lengths is None:
+ lengths = list(dataset.lengths)
+ self.ranges = lengths_to_ranges(lengths)
+ self.seed = seed
+ self.shufflefirst = shufflefirst
+ self.epoch = 0
+
+ def __iter__(self):
+ self.rng = random.Random(self.seed + 1289738273 * self.epoch)
+ shardshuffle = self.shufflefirst or self.epoch > 0
+ yield from iterate_ranges(self.ranges, self.rng, shardshuffle=shardshuffle)
+ self.epoch += 1
+
+
+ShardedSampler = ShardListSampler
+
+
+class ChunkedSampler(Sampler):
+ """A sampler that samples in chunks and then shuffles the samples within each chunk.
+
+ This preserves locality of reference while still shuffling the data.
+ """
+
+ def __init__(
+ self,
+ dataset,
+ *,
+ num_samples=None,
+ chunksize=2000,
+ seed=0,
+ shuffle=False,
+ shufflefirst=False,
+ ):
+ if isinstance(num_samples, int):
+ lo, hi = 0, num_samples
+ elif num_samples is None:
+ lo, hi = 0, len(dataset)
+ else:
+ lo, hi = num_samples
+ self.ranges = [(i, min(i + chunksize, hi)) for i in range(lo, hi, chunksize)]
+ self.seed = seed
+ self.shuffle = shuffle
+ self.shufflefirst = shufflefirst
+ self.epoch = 0
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
+
+ def __iter__(self):
+ self.rng = random.Random(self.seed + 1289738273 * self.epoch)
+ shardshuffle = self.shufflefirst or self.epoch > 0
+ yield from iterate_ranges(
+ self.ranges,
+ self.rng,
+ indexshuffle=self.shuffle,
+ shardshuffle=(self.shuffle and shardshuffle),
+ )
+ self.epoch += 1
+
+ def __len__(self):
+ return len(self.ranges)
+
+
+def DistributedChunkedSampler(
+ dataset: Dataset,
+ *,
+ num_replicas: Optional[int] = None,
+ num_samples: Optional[int] = None,
+ rank: Optional[int] = None,
+ shuffle: bool = True,
+ shufflefirst: bool = False,
+ seed: int = 0,
+ drop_last: bool = None,
+ chunksize: int = 1000000,
+) -> ChunkedSampler:
+ """Return a ChunkedSampler for the current worker in distributed training.
+
+ Reverts to a simple ChunkedSampler if not running in distributed mode.
+
+ Since the split among workers takes place before the chunk shuffle,
+ workers end up with a fixed set of shards they need to download. The
+ more workers, the fewer shards are used by each worker.
+ """
+ if drop_last is not None:
+ warnings.warn("DistributedChunkedSampler does not support drop_last, thus it will be ignored")
+ if not dist.is_initialized():
+ warnings.warn("DistributedChunkedSampler is called without distributed initialized; assuming single process")
+ num_replicas = 1
+ rank = 0
+ else:
+ num_replicas = num_replicas or dist.get_world_size()
+ rank = rank or dist.get_rank()
+ assert rank >= 0 and rank < num_replicas
+
+ num_samples = num_samples or len(dataset)
+ worker_chunk = (num_samples + num_replicas - 1) // num_replicas
+ worker_start = rank * worker_chunk
+ worker_end = min(worker_start + worker_chunk, num_samples)
+ return ChunkedSampler(
+ dataset,
+ num_samples=(worker_start, worker_end),
+ chunksize=chunksize,
+ seed=seed,
+ shuffle=shuffle,
+ shufflefirst=shufflefirst,
+ )
+
+
+class DistributedRangedSampler(Sampler):
+ """A sampler that samples in chunks and then shuffles the samples within each chunk.
+
+ This preserves locality of reference while still shuffling the data.
+ """
+
+ def __init__(
+ self,
+ dataset: Dataset,
+ num_replicas: Optional[int] = None,
+ num_samples: Optional[int] = None,
+ rank: Optional[int] = None,
+ drop_last: bool = None,
+ ):
+ if drop_last is not None:
+ warnings.warn("DistributedChunkedSampler does not support drop_last, thus it will be ignored")
+ if not dist.is_initialized():
+ warnings.warn(
+ "DistributedChunkedSampler is called without distributed initialized; assuming single process"
+ )
+ num_replicas = 1
+ rank = 0
+ else:
+ num_replicas = num_replicas or dist.get_world_size()
+ rank = rank or dist.get_rank()
+ assert rank >= 0 and rank < num_replicas
+ num_samples = num_samples or len(dataset)
+ self.worker_chunk = num_samples // num_replicas
+ self.worker_start = rank * self.worker_chunk
+ self.worker_end = min((rank + 1) * self.worker_chunk, num_samples)
+ self.ranges = range(self.worker_start, self.worker_end)
+ self.epoch = 0
+ self.step_start = 0
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
+
+ def __len__(self):
+ return len(self.ranges)
+
+ def set_start(self, start):
+ self.step_start = start
+
+ def __iter__(self):
+ yield from self.ranges[self.step_start :]
+ self.epoch += 1
+
+
+class DistributedLocalSampler(DistributedSampler):
+ def __iter__(self):
+ if self.shuffle:
+ # deterministically shuffle based on epoch and seed
+ g = torch.Generator()
+ g.manual_seed(self.seed + self.epoch)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
+ else:
+ indices = list(range(len(self.dataset))) # type: ignore[arg-type]
+
+ if not self.drop_last:
+ # add extra samples to make it evenly divisible
+ padding_size = self.total_size - len(indices)
+ if padding_size <= len(indices):
+ indices += indices[:padding_size]
+ else:
+ indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
+ else:
+ # remove tail of data to make it evenly divisible.
+ indices = indices[: self.total_size]
+ assert len(indices) == self.total_size
+
+ # subsample
+ # indices = indices[self.rank:self.total_size:self.num_replicas]
+ chunk_size = self.total_size // self.num_replicas
+ begin_idx = chunk_size * self.rank
+ stop_idx = chunk_size * (self.rank + 1)
+ indices = indices[begin_idx:stop_idx]
+
+ # print("[SamplerIndices: ]", indices)
+ assert len(indices) == self.num_samples
+ return iter(indices)
diff --git a/diffusion/data/wids/wids_dl.py b/diffusion/data/wids/wids_dl.py
new file mode 100755
index 0000000..3eed890
--- /dev/null
+++ b/diffusion/data/wids/wids_dl.py
@@ -0,0 +1,174 @@
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is copied from https://github.com/NVlabs/VILA/tree/main/llava/wids
+import fcntl
+import os
+import shutil
+import sys
+import time
+from collections import deque
+from datetime import datetime
+from urllib.parse import urlparse
+
+recent_downloads = deque(maxlen=1000)
+
+open_objects = {}
+max_open_objects = 100
+
+
+class ULockFile:
+ """A simple locking class. We don't need any of the third
+ party libraries since we rely on POSIX semantics for linking
+ below anyway."""
+
+ def __init__(self, path):
+ self.lockfile_path = path
+ self.lockfile = None
+
+ def __enter__(self):
+ self.lockfile = open(self.lockfile_path, "w")
+ fcntl.flock(self.lockfile.fileno(), fcntl.LOCK_EX)
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ fcntl.flock(self.lockfile.fileno(), fcntl.LOCK_UN)
+ self.lockfile.close()
+ self.lockfile = None
+ try:
+ os.unlink(self.lockfile_path)
+ except FileNotFoundError:
+ pass
+
+
+def pipe_download(remote, local):
+ """Perform a download for a pipe: url."""
+ assert remote.startswith("pipe:")
+ cmd = remote[5:]
+ cmd = cmd.format(local=local)
+ assert os.system(cmd) == 0, "Command failed: %s" % cmd
+
+
+def copy_file(remote, local):
+ remote = urlparse(remote)
+ assert remote.scheme in ["file", ""]
+ # use absolute path
+ remote = os.path.abspath(remote.path)
+ local = urlparse(local)
+ assert local.scheme in ["file", ""]
+ local = os.path.abspath(local.path)
+ if remote == local:
+ return
+ # check if the local file exists
+ shutil.copyfile(remote, local)
+
+
+verbose_cmd = int(os.environ.get("WIDS_VERBOSE_CMD", "0"))
+
+
+def vcmd(flag, verbose_flag=""):
+ return verbose_flag if verbose_cmd else flag
+
+
+default_cmds = {
+ "posixpath": copy_file,
+ "file": copy_file,
+ "pipe": pipe_download,
+ "http": "curl " + vcmd("-s") + " -L {url} -o {local}",
+ "https": "curl " + vcmd("-s") + " -L {url} -o {local}",
+ "ftp": "curl " + vcmd("-s") + " -L {url} -o {local}",
+ "ftps": "curl " + vcmd("-s") + " -L {url} -o {local}",
+ "gs": "gsutil " + vcmd("-q") + " cp {url} {local}",
+ "s3": "aws s3 cp {url} {local}",
+}
+
+
+# TODO(ligeng): change HTTPS download to python requests library
+
+
+def download_file_no_log(remote, local, handlers=default_cmds):
+ """Download a file from a remote url to a local path.
+ The remote url can be a pipe: url, in which case the remainder of
+ the url is treated as a command template that is executed to perform the download.
+ """
+
+ if remote.startswith("pipe:"):
+ schema = "pipe"
+ else:
+ schema = urlparse(remote).scheme
+ if schema is None or schema == "":
+ schema = "posixpath"
+ # get the handler
+ handler = handlers.get(schema)
+ if handler is None:
+ raise ValueError("Unknown schema: %s" % schema)
+ # call the handler
+ if callable(handler):
+ handler(remote, local)
+ else:
+ assert isinstance(handler, str)
+ cmd = handler.format(url=remote, local=local)
+ assert os.system(cmd) == 0, "Command failed: %s" % cmd
+ return local
+
+
+def download_file(remote, local, handlers=default_cmds, verbose=False):
+ start = time.time()
+ try:
+ return download_file_no_log(remote, local, handlers=handlers)
+ finally:
+ recent_downloads.append((remote, local, time.time(), time.time() - start))
+ if verbose:
+ print(
+ "downloaded",
+ remote,
+ "to",
+ local,
+ "in",
+ time.time() - start,
+ "seconds",
+ file=sys.stderr,
+ )
+
+
+def download_and_open(remote, local, mode="rb", handlers=default_cmds, verbose=False):
+ with ULockFile(local + ".lock"):
+ if os.path.exists(remote):
+ # print("enter1", remote, local, mode)
+ result = open(remote, mode)
+ else:
+ # print("enter2", remote, local, mode)
+ if not os.path.exists(local):
+ if verbose:
+ print("downloading", remote, "to", local, file=sys.stderr)
+ download_file(remote, local, handlers=handlers)
+ else:
+ if verbose:
+ print("using cached", local, file=sys.stderr)
+ result = open(local, mode)
+
+ # input()
+
+ if open_objects is not None:
+ for k, v in list(open_objects.items()):
+ if v.closed:
+ del open_objects[k]
+ if len(open_objects) > max_open_objects:
+ raise RuntimeError("Too many open objects")
+ current_time = datetime.now().strftime("%Y%m%d%H%M%S")
+ key = tuple(str(x) for x in [remote, local, mode, current_time])
+ open_objects[key] = result
+ return result
diff --git a/diffusion/data/wids/wids_lru.py b/diffusion/data/wids/wids_lru.py
new file mode 100755
index 0000000..7bd0106
--- /dev/null
+++ b/diffusion/data/wids/wids_lru.py
@@ -0,0 +1,81 @@
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is copied from https://github.com/NVlabs/VILA/tree/main/llava/wids
+from collections import OrderedDict
+
+
+class LRUCache:
+ def __init__(self, capacity: int, release_handler=None):
+ """Initialize a new LRU cache with the given capacity."""
+ self.capacity = capacity
+ self.cache = OrderedDict()
+ self.release_handler = release_handler
+
+ def __getitem__(self, key):
+ """Return the value associated with the given key, or None."""
+ if key not in self.cache:
+ return None
+ self.cache.move_to_end(key)
+ return self.cache[key]
+
+ def __setitem__(self, key, value):
+ """Associate the given value with the given key."""
+ if key in self.cache:
+ self.cache.move_to_end(key)
+ self.cache[key] = value
+ if len(self.cache) > self.capacity:
+ key, value = self.cache.popitem(last=False)
+ if self.release_handler is not None:
+ self.release_handler(key, value)
+
+ def __delitem__(self, key):
+ """Remove the given key from the cache."""
+ if key in self.cache:
+ if self.release_handler is not None:
+ value = self.cache[key]
+ self.release_handler(key, value)
+ del self.cache[key]
+
+ def __len__(self):
+ """Return the number of entries in the cache."""
+ return len(self.cache)
+
+ def __contains__(self, key):
+ """Return whether the cache contains the given key."""
+ return key in self.cache
+
+ def items(self):
+ """Return an iterator over the keys of the cache."""
+ return self.cache.items()
+
+ def keys(self):
+ """Return an iterator over the keys of the cache."""
+ return self.cache.keys()
+
+ def values(self):
+ """Return an iterator over the values of the cache."""
+ return self.cache.values()
+
+ def clear(self):
+ for key in list(self.keys()):
+ value = self.cache[key]
+ if self.release_handler is not None:
+ self.release_handler(key, value)
+ del self[key]
+
+ def __del__(self):
+ self.clear()
diff --git a/diffusion/data/wids/wids_mmtar.py b/diffusion/data/wids/wids_mmtar.py
new file mode 100755
index 0000000..72ac778
--- /dev/null
+++ b/diffusion/data/wids/wids_mmtar.py
@@ -0,0 +1,168 @@
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is copied from https://github.com/NVlabs/VILA/tree/main/llava/wids
+import collections
+import fcntl
+import io
+import mmap
+import os
+import struct
+
+TarHeader = collections.namedtuple(
+ "TarHeader",
+ [
+ "name",
+ "mode",
+ "uid",
+ "gid",
+ "size",
+ "mtime",
+ "chksum",
+ "typeflag",
+ "linkname",
+ "magic",
+ "version",
+ "uname",
+ "gname",
+ "devmajor",
+ "devminor",
+ "prefix",
+ ],
+)
+
+
+def parse_tar_header(header_bytes):
+ header = struct.unpack("!100s8s8s8s12s12s8s1s100s6s2s32s32s8s8s155s", header_bytes)
+ return TarHeader(*header)
+
+
+def next_header(offset, header):
+ block_size = 512
+ size = header.size.decode("utf-8").strip("\x00")
+ if size == "":
+ return -1
+ size = int(size, 8)
+ # compute the file size rounded up to the next block size if it is a partial block
+ padded_file_size = (size + block_size - 1) // block_size * block_size
+ return offset + block_size + padded_file_size
+
+
+# TODO(ligeng): support gzip stream
+class MMIndexedTar:
+ def __init__(self, fname, index_file=None, verbose=True, cleanup_callback=None):
+ self.verbose = verbose
+ self.cleanup_callback = cleanup_callback
+ if isinstance(fname, str):
+ self.stream = open(fname, "rb")
+ self.fname = fname
+ elif isinstance(fname, io.IOBase):
+ self.stream = fname
+ self.fname = None
+ self.mmapped_file = mmap.mmap(self.stream.fileno(), 0, access=mmap.ACCESS_READ)
+ if cleanup_callback:
+ cleanup_callback(fname, self.stream.fileno(), "start")
+ self._build_index()
+
+ def close(self, dispose=False):
+ if self.cleanup_callback:
+ self.cleanup_callback(self.fname, self.stream.fileno(), "end")
+ self.mmapped_file.close()
+ self.stream.close()
+
+ def _build_index(self):
+ self.by_name = {}
+ self.by_index = []
+ offset = 0
+ while offset >= 0 and offset < len(self.mmapped_file):
+ header = parse_tar_header(self.mmapped_file[offset : offset + 500])
+ name = header.name.decode("utf-8").strip("\x00")
+ typeflag = header.typeflag.decode("utf-8").strip("\x00")
+ if name != "" and name != "././@PaxHeader" and typeflag in ["0", ""]:
+ try:
+ size = int(header.size.decode("utf-8")[:-1], 8)
+ except ValueError as exn:
+ print(header)
+ raise exn
+ self.by_name[name] = offset
+ self.by_index.append((name, offset, size))
+ offset = next_header(offset, header)
+
+ def names(self):
+ return self.by_name.keys()
+
+ def get_at_offset(self, offset):
+ header = parse_tar_header(self.mmapped_file[offset : offset + 500])
+ name = header.name.decode("utf-8").strip("\x00")
+ start = offset + 512
+ end = start + int(header.size.decode("utf-8")[:-1], 8)
+ return name, self.mmapped_file[start:end]
+
+ def get_at_index(self, index):
+ name, offset, size = self.by_index[index]
+ return self.get_at_offset(offset)
+
+ def get_by_name(self, name):
+ offset = self.by_name[name]
+ return self.get_at_offset(offset)
+
+ def __iter__(self):
+ for name, offset, size in self.by_index:
+ yield name, self.mmapped_file[offset + 512 : offset + 512 + size]
+
+ def __getitem__(self, key):
+ if isinstance(key, int):
+ return self.get_at_index(key)
+ else:
+ return self.get_by_name(key)
+
+ def __len__(self):
+ return len(self.by_index)
+
+ def get_file(self, i):
+ fname, data = self.get_at_index(i)
+ return fname, io.BytesIO(data)
+
+
+def keep_while_reading(fname, fd, phase, delay=0.0):
+ """This is a possible cleanup callback for cleanup_callback of MIndexedTar.
+
+ It assumes that as long as there are some readers for a file,
+ more readers may be trying to open it.
+
+ Note that on Linux, unlinking the file doesn't matter after
+ it has been mmapped. The contents will only be deleted when
+ all readers close the file. The unlinking merely makes the file
+ unavailable to new readers, since the downloader checks first
+ whether the file exists.
+ """
+ assert delay == 0.0, "delay not implemented"
+ if fd < 0 or fname is None:
+ return
+ if phase == "start":
+ fcntl.flock(fd, fcntl.LOCK_SH)
+ elif phase == "end":
+ try:
+ fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
+ os.unlink(fname)
+ except FileNotFoundError:
+ # someone else deleted it already
+ pass
+ except BlockingIOError:
+ # we couldn't get an exclusive lock, so someone else is still reading
+ pass
+ else:
+ raise ValueError(f"Unknown phase {phase}")
diff --git a/diffusion/data/wids/wids_specs.py b/diffusion/data/wids/wids_specs.py
new file mode 100755
index 0000000..0178e75
--- /dev/null
+++ b/diffusion/data/wids/wids_specs.py
@@ -0,0 +1,192 @@
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is copied from https://github.com/NVlabs/VILA/tree/main/llava/wids
+import io
+import json
+import os
+import tempfile
+from urllib.parse import urlparse, urlunparse
+
+from .wids_dl import download_and_open
+
+
+def urldir(url):
+ """Return the directory part of a url."""
+ parsed_url = urlparse(url)
+ path = parsed_url.path
+ directory = os.path.dirname(path)
+ return parsed_url._replace(path=directory).geturl()
+
+
+def urlmerge(base, url):
+ """Merge a base URL and a relative URL.
+
+ The function fills in any missing part of the url from the base,
+ except for params, query, and fragment, which are taken only from the 'url'.
+ For the pathname component, it merges the paths like os.path.join:
+ an absolute path in 'url' overrides the base path, otherwise the paths are merged.
+
+ Parameters:
+ base (str): The base URL.
+ url (str): The URL to merge with the base.
+
+ Returns:
+ str: The merged URL.
+ """
+ # Parse the base and the relative URL
+ parsed_base = urlparse(base)
+ parsed_url = urlparse(url)
+
+ # Merge paths using os.path.join
+ # If the url path is absolute, it overrides the base path
+ if parsed_url.path.startswith("/"):
+ merged_path = parsed_url.path
+ else:
+ merged_path = os.path.normpath(os.path.join(parsed_base.path, parsed_url.path))
+
+ # Construct the merged URL
+ merged_url = urlunparse(
+ (
+ parsed_url.scheme or parsed_base.scheme,
+ parsed_url.netloc or parsed_base.netloc,
+ merged_path,
+ parsed_url.params, # Use params from the url only
+ parsed_url.query, # Use query from the url only
+ parsed_url.fragment, # Use fragment from the url only
+ )
+ )
+
+ return merged_url
+
+
+def check_shards(l):
+ """Check that a list of shards is well-formed.
+
+ This checks that the list is a list of dictionaries, and that
+ each dictionary has a "url" and a "nsamples" key.
+ """
+ assert isinstance(l, list)
+ for shard in l:
+ assert isinstance(shard, dict)
+ assert "url" in shard
+ assert "nsamples" in shard
+ return l
+
+
+def set_all(l, k, v):
+ """Set a key to a value in a list of dictionaries."""
+ if v is None:
+ return
+ for x in l:
+ if k not in x:
+ x[k] = v
+
+
+def load_remote_dsdesc_raw(source):
+ """Load a remote or local dataset description in JSON format."""
+ if isinstance(source, str):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ dlname = os.path.join(tmpdir, "dataset.json")
+ with download_and_open(source, dlname) as f:
+ dsdesc = json.load(f)
+ elif isinstance(source, io.IOBase):
+ dsdesc = json.load(source)
+ else:
+ # FIXME: use gopen
+ import requests
+
+ jsondata = requests.get(source).text
+ dsdesc = json.loads(jsondata)
+ return dsdesc
+
+
+def rebase_shardlist(shardlist, base):
+ """Rebase the URLs in a shardlist."""
+ if base is None:
+ return shardlist
+ for shard in shardlist:
+ shard["url"] = urlmerge(base, shard["url"])
+ return shardlist
+
+
+def resolve_dsdesc(dsdesc, *, options=None, base=None):
+ """Resolve a dataset description.
+
+ This rebases the shards as necessary and loads any remote references.
+
+ Dataset descriptions are JSON files. They must have the following format;
+
+ {
+ "wids_version": 1,
+ # optional immediate shardlist
+ "shardlist": [
+ {"url": "http://example.com/file.tar", "nsamples": 1000},
+ ...
+ ],
+ # sub-datasets
+ "datasets": [
+ {"source_url": "http://example.com/dataset.json"},
+ {"shardlist": [
+ {"url": "http://example.com/file.tar", "nsamples": 1000},
+ ...
+ ]}
+ ...
+ ]
+ }
+ """
+ if options is None:
+ options = {}
+ assert isinstance(dsdesc, dict)
+ dsdesc = dict(dsdesc, **options)
+ shardlist = rebase_shardlist(dsdesc.get("shardlist", []), base)
+ assert shardlist is not None
+ set_all(shardlist, "weight", dsdesc.get("weight"))
+ set_all(shardlist, "name", dsdesc.get("name"))
+ check_shards(shardlist)
+ assert "wids_version" in dsdesc, "No wids_version in dataset description"
+ assert dsdesc["wids_version"] == 1, "Unknown wids_version"
+ for component in dsdesc.get("datasets", []):
+ # we use the weight from the reference to the dataset,
+ # regardless of remote loading
+ weight = component.get("weight")
+ # follow any source_url dsdescs through remote loading
+ source_url = None
+ if "source_url" in component:
+ source_url = component["source_url"]
+ component = load_remote_dsdesc_raw(source_url)
+ assert "source_url" not in component, "double indirection in dataset description"
+ assert "shardlist" in component, "no shardlist in dataset description"
+ # if the component has a base, use it to rebase the shardlist
+ # otherwise use the base from the source_url, if any
+ subbase = component.get("base", urldir(source_url) if source_url else None)
+ if subbase is not None:
+ rebase_shardlist(component["shardlist"], subbase)
+ l = check_shards(component["shardlist"])
+ set_all(l, "weight", weight)
+ set_all(l, "source_url", source_url)
+ set_all(l, "dataset", component.get("name"))
+ shardlist.extend(l)
+ assert len(shardlist) > 0, "No shards found"
+ dsdesc["shardlist"] = shardlist
+ return dsdesc
+
+
+def load_dsdesc_and_resolve(source, *, options=None, base=None):
+ if options is None:
+ options = {}
+ dsdesc = load_remote_dsdesc_raw(source)
+ return resolve_dsdesc(dsdesc, base=base, options=options)
diff --git a/diffusion/data/wids/wids_tar.py b/diffusion/data/wids/wids_tar.py
new file mode 100755
index 0000000..270aaaf
--- /dev/null
+++ b/diffusion/data/wids/wids_tar.py
@@ -0,0 +1,98 @@
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is copied from https://github.com/NVlabs/VILA/tree/main/llava/wids
+import io
+import os
+import os.path
+import pickle
+import re
+import tarfile
+
+import numpy as np
+
+
+def find_index_file(file):
+ prefix, last_ext = os.path.splitext(file)
+ if re.match("._[0-9]+_$", last_ext):
+ return prefix + ".index"
+ else:
+ return file + ".index"
+
+
+class TarFileReader:
+ def __init__(self, file, index_file=find_index_file, verbose=True):
+ self.verbose = verbose
+ if callable(index_file):
+ index_file = index_file(file)
+ self.index_file = index_file
+
+ # Open the tar file and keep it open
+ if isinstance(file, str):
+ self.tar_file = tarfile.open(file, "r")
+ else:
+ self.tar_file = tarfile.open(fileobj=file, mode="r")
+
+ # Create the index
+ self._create_tar_index()
+
+ def _create_tar_index(self):
+ if self.index_file is not None and os.path.exists(self.index_file):
+ if self.verbose:
+ print("Loading tar index from", self.index_file)
+ with open(self.index_file, "rb") as stream:
+ self.fnames, self.index = pickle.load(stream)
+ return
+ # Create an empty list for the index
+ self.fnames = []
+ self.index = []
+
+ if self.verbose:
+ print("Creating tar index for", self.tar_file.name, "at", self.index_file)
+ # Iterate over the members of the tar file
+ for member in self.tar_file:
+ # If the member is a file, add it to the index
+ if member.isfile():
+ # Get the file's offset
+ offset = self.tar_file.fileobj.tell()
+ self.fnames.append(member.name)
+ self.index.append([offset, member.size])
+ if self.verbose:
+ print("Done creating tar index for", self.tar_file.name, "at", self.index_file)
+ self.index = np.array(self.index)
+ if self.index_file is not None:
+ if os.path.exists(self.index_file + ".temp"):
+ os.unlink(self.index_file + ".temp")
+ with open(self.index_file + ".temp", "wb") as stream:
+ pickle.dump((self.fnames, self.index), stream)
+ os.rename(self.index_file + ".temp", self.index_file)
+
+ def names(self):
+ return self.fnames
+
+ def __len__(self):
+ return len(self.index)
+
+ def get_file(self, i):
+ name = self.fnames[i]
+ offset, size = self.index[i]
+ self.tar_file.fileobj.seek(offset)
+ file_bytes = self.tar_file.fileobj.read(size)
+ return name, io.BytesIO(file_bytes)
+
+ def close(self):
+ # Close the tar file
+ self.tar_file.close()
diff --git a/diffusion/model/dc_ae/efficientvit/__init__.py b/diffusion/model/dc_ae/efficientvit/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/diffusion/model/dc_ae/efficientvit/ae_model_zoo.py b/diffusion/model/dc_ae/efficientvit/ae_model_zoo.py
new file mode 100644
index 0000000..3f4f0a4
--- /dev/null
+++ b/diffusion/model/dc_ae/efficientvit/ae_model_zoo.py
@@ -0,0 +1,82 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Callable, Optional
+
+import diffusers
+import torch
+from huggingface_hub import PyTorchModelHubMixin
+from torch import nn
+
+from ..efficientvit.models.efficientvit.dc_ae import DCAE, DCAEConfig, dc_ae_f32c32, dc_ae_f64c128, dc_ae_f128c512
+
+__all__ = ["create_dc_ae_model_cfg", "DCAE_HF", "AutoencoderKL"]
+
+
+REGISTERED_DCAE_MODEL: dict[str, tuple[Callable, Optional[str]]] = {
+ "dc-ae-f32c32-in-1.0": (dc_ae_f32c32, None),
+ "dc-ae-f64c128-in-1.0": (dc_ae_f64c128, None),
+ "dc-ae-f128c512-in-1.0": (dc_ae_f128c512, None),
+ #################################################################################################
+ "dc-ae-f32c32-mix-1.0": (dc_ae_f32c32, None),
+ "dc-ae-f64c128-mix-1.0": (dc_ae_f64c128, None),
+ "dc-ae-f128c512-mix-1.0": (dc_ae_f128c512, None),
+ #################################################################################################
+ "dc-ae-f32c32-sana-1.0": (dc_ae_f32c32, None),
+}
+
+
+def create_dc_ae_model_cfg(name: str, pretrained_path: Optional[str] = None) -> DCAEConfig:
+ assert name in REGISTERED_DCAE_MODEL, f"{name} is not supported"
+ dc_ae_cls, default_pt_path = REGISTERED_DCAE_MODEL[name]
+ pretrained_path = default_pt_path if pretrained_path is None else pretrained_path
+ model_cfg = dc_ae_cls(name, pretrained_path)
+ return model_cfg
+
+
+class DCAE_HF(DCAE, PyTorchModelHubMixin):
+ def __init__(self, model_name: str):
+ cfg = create_dc_ae_model_cfg(model_name)
+ DCAE.__init__(self, cfg)
+
+
+class AutoencoderKL(nn.Module):
+ def __init__(self, model_name: str):
+ super().__init__()
+ self.model_name = model_name
+ if self.model_name in ["stabilityai/sd-vae-ft-ema"]:
+ self.model = diffusers.models.AutoencoderKL.from_pretrained(self.model_name)
+ self.spatial_compression_ratio = 8
+ elif self.model_name == "flux-vae":
+ from diffusers import FluxPipeline
+
+ pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
+ self.model = diffusers.models.AutoencoderKL.from_pretrained(pipe.vae.config._name_or_path)
+ self.spatial_compression_ratio = 8
+ else:
+ raise ValueError(f"{self.model_name} is not supported for AutoencoderKL")
+
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
+ if self.model_name in ["stabilityai/sd-vae-ft-ema", "flux-vae"]:
+ return self.model.encode(x).latent_dist.sample()
+ else:
+ raise ValueError(f"{self.model_name} is not supported for AutoencoderKL")
+
+ def decode(self, latent: torch.Tensor) -> torch.Tensor:
+ if self.model_name in ["stabilityai/sd-vae-ft-ema", "flux-vae"]:
+ return self.model.decode(latent).sample
+ else:
+ raise ValueError(f"{self.model_name} is not supported for AutoencoderKL")
diff --git a/diffusion/model/dc_ae/efficientvit/apps/__init__.py b/diffusion/model/dc_ae/efficientvit/apps/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/diffusion/model/dc_ae/efficientvit/apps/setup.py b/diffusion/model/dc_ae/efficientvit/apps/setup.py
new file mode 100644
index 0000000..5ca892b
--- /dev/null
+++ b/diffusion/model/dc_ae/efficientvit/apps/setup.py
@@ -0,0 +1,103 @@
+import os
+import time
+from copy import deepcopy
+from typing import Optional
+
+import torch.backends.cudnn
+import torch.distributed
+import torch.nn as nn
+
+from ..apps.utils import (
+ dist_init,
+ dump_config,
+ get_dist_local_rank,
+ get_dist_rank,
+ get_dist_size,
+ init_modules,
+ is_master,
+ load_config,
+ partial_update_config,
+ zero_last_gamma,
+)
+from ..models.utils import build_kwargs_from_config, load_state_dict_from_file
+
+__all__ = [
+ "save_exp_config",
+ "setup_dist_env",
+ "setup_seed",
+ "setup_exp_config",
+ "init_model",
+]
+
+
+def save_exp_config(exp_config: dict, path: str, name="config.yaml") -> None:
+ if not is_master():
+ return
+ dump_config(exp_config, os.path.join(path, name))
+
+
+def setup_dist_env(gpu: Optional[str] = None) -> None:
+ if gpu is not None:
+ os.environ["CUDA_VISIBLE_DEVICES"] = gpu
+ if not torch.distributed.is_initialized():
+ dist_init()
+ torch.backends.cudnn.benchmark = True
+ torch.cuda.set_device(get_dist_local_rank())
+
+
+def setup_seed(manual_seed: int, resume: bool) -> None:
+ if resume:
+ manual_seed = int(time.time())
+ manual_seed = get_dist_rank() + manual_seed
+ torch.manual_seed(manual_seed)
+ torch.cuda.manual_seed_all(manual_seed)
+
+
+def setup_exp_config(config_path: str, recursive=True, opt_args: Optional[dict] = None) -> dict:
+ # load config
+ if not os.path.isfile(config_path):
+ raise ValueError(config_path)
+
+ fpaths = [config_path]
+ if recursive:
+ extension = os.path.splitext(config_path)[1]
+ while os.path.dirname(config_path) != config_path:
+ config_path = os.path.dirname(config_path)
+ fpath = os.path.join(config_path, "default" + extension)
+ if os.path.isfile(fpath):
+ fpaths.append(fpath)
+ fpaths = fpaths[::-1]
+
+ default_config = load_config(fpaths[0])
+ exp_config = deepcopy(default_config)
+ for fpath in fpaths[1:]:
+ partial_update_config(exp_config, load_config(fpath))
+ # update config via args
+ if opt_args is not None:
+ partial_update_config(exp_config, opt_args)
+
+ return exp_config
+
+
+def init_model(
+ network: nn.Module,
+ init_from: Optional[str] = None,
+ backbone_init_from: Optional[str] = None,
+ rand_init="trunc_normal",
+ last_gamma=None,
+) -> None:
+ # initialization
+ init_modules(network, init_type=rand_init)
+ # zero gamma of last bn in each block
+ if last_gamma is not None:
+ zero_last_gamma(network, last_gamma)
+
+ # load weight
+ if init_from is not None and os.path.isfile(init_from):
+ network.load_state_dict(load_state_dict_from_file(init_from))
+ print(f"Loaded init from {init_from}")
+ elif backbone_init_from is not None and os.path.isfile(backbone_init_from):
+ network.backbone.load_state_dict(load_state_dict_from_file(backbone_init_from))
+ print(f"Loaded backbone init from {backbone_init_from}")
+ else:
+ print(f"Random init ({rand_init}) with last gamma {last_gamma}")
diff --git a/diffusion/model/dc_ae/efficientvit/apps/trainer/__init__.py b/diffusion/model/dc_ae/efficientvit/apps/trainer/__init__.py
new file mode 100644
index 0000000..1e3b210
--- /dev/null
+++ b/diffusion/model/dc_ae/efficientvit/apps/trainer/__init__.py
@@ -0,0 +1 @@
+from .run_config import *
diff --git a/diffusion/model/dc_ae/efficientvit/apps/trainer/run_config.py b/diffusion/model/dc_ae/efficientvit/apps/trainer/run_config.py
new file mode 100644
index 0000000..442bb13
--- /dev/null
+++ b/diffusion/model/dc_ae/efficientvit/apps/trainer/run_config.py
@@ -0,0 +1,128 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+from typing import Any
+
+import numpy as np
+import torch.nn as nn
+
+from ...apps.utils import CosineLRwithWarmup, build_optimizer
+
+__all__ = ["Scheduler", "RunConfig"]
+
+
+class Scheduler:
+ PROGRESS = 0
+
+
+class RunConfig:
+ n_epochs: int
+ init_lr: float
+ warmup_epochs: int
+ warmup_lr: float
+ lr_schedule_name: str
+ lr_schedule_param: dict
+ optimizer_name: str
+ optimizer_params: dict
+ weight_decay: float
+ no_wd_keys: list
+ grad_clip: float # allow none to turn off grad clipping
+ reset_bn: bool
+ reset_bn_size: int
+ reset_bn_batch_size: int
+ eval_image_size: list # allow none to use image_size in data_provider
+
+ @property
+ def none_allowed(self):
+ return ["grad_clip", "eval_image_size"]
+
+ def __init__(self, **kwargs): # arguments must be passed as kwargs
+ for k, val in kwargs.items():
+ setattr(self, k, val)
+
+ # check that all relevant configs are there
+ annotations = {}
+ for clas in type(self).mro():
+ if hasattr(clas, "__annotations__"):
+ annotations.update(clas.__annotations__)
+ for k, k_type in annotations.items():
+ assert hasattr(self, k), f"Key {k} with type {k_type} required for initialization."
+ attr = getattr(self, k)
+ if k in self.none_allowed:
+ k_type = (k_type, type(None))
+ assert isinstance(attr, k_type), f"Key {k} must be type {k_type}, provided={attr}."
+
+ self.global_step = 0
+ self.batch_per_epoch = 1
+
+ def build_optimizer(self, network: nn.Module) -> tuple[Any, Any]:
+ r"""require setting 'batch_per_epoch' before building optimizer & lr_scheduler"""
+ param_dict = {}
+ for name, param in network.named_parameters():
+ if param.requires_grad:
+ opt_config = [self.weight_decay, self.init_lr]
+ if self.no_wd_keys is not None and len(self.no_wd_keys) > 0:
+ if np.any([key in name for key in self.no_wd_keys]):
+ opt_config[0] = 0
+ opt_key = json.dumps(opt_config)
+ param_dict[opt_key] = param_dict.get(opt_key, []) + [param]
+
+ net_params = []
+ for opt_key, param_list in param_dict.items():
+ wd, lr = json.loads(opt_key)
+ net_params.append({"params": param_list, "weight_decay": wd, "lr": lr})
+
+ optimizer = build_optimizer(net_params, self.optimizer_name, self.optimizer_params, self.init_lr)
+ # build lr scheduler
+ if self.lr_schedule_name == "cosine":
+ decay_steps = []
+ for epoch in self.lr_schedule_param.get("step", []):
+ decay_steps.append(epoch * self.batch_per_epoch)
+ decay_steps.append(self.n_epochs * self.batch_per_epoch)
+ decay_steps.sort()
+ lr_scheduler = CosineLRwithWarmup(
+ optimizer,
+ self.warmup_epochs * self.batch_per_epoch,
+ self.warmup_lr,
+ decay_steps,
+ )
+ else:
+ raise NotImplementedError
+ return optimizer, lr_scheduler
+
+ def update_global_step(self, epoch, batch_id=0) -> None:
+ self.global_step = epoch * self.batch_per_epoch + batch_id
+ Scheduler.PROGRESS = self.progress
+
+ @property
+ def progress(self) -> float:
+ warmup_steps = self.warmup_epochs * self.batch_per_epoch
+ steps = max(0, self.global_step - warmup_steps)
+ return steps / (self.n_epochs * self.batch_per_epoch)
+
+ def step(self) -> None:
+ self.global_step += 1
+ Scheduler.PROGRESS = self.progress
+
+ def get_remaining_epoch(self, epoch, post=True) -> int:
+ return self.n_epochs + self.warmup_epochs - epoch - int(post)
+
+ def epoch_format(self, epoch: int) -> str:
+ epoch_format = f"%.{len(str(self.n_epochs))}d"
+ epoch_format = f"[{epoch_format}/{epoch_format}]"
+ epoch_format = epoch_format % (epoch + 1 - self.warmup_epochs, self.n_epochs)
+ return epoch_format
diff --git a/diffusion/model/dc_ae/efficientvit/apps/utils/__init__.py b/diffusion/model/dc_ae/efficientvit/apps/utils/__init__.py
new file mode 100644
index 0000000..3b4bc1d
--- /dev/null
+++ b/diffusion/model/dc_ae/efficientvit/apps/utils/__init__.py
@@ -0,0 +1,10 @@
+from .dist import *
+from .ema import *
+
+# from .export import *
+from .image import *
+from .init import *
+from .lr import *
+from .metric import *
+from .misc import *
+from .opt import *
diff --git a/diffusion/model/dc_ae/efficientvit/apps/utils/dist.py b/diffusion/model/dc_ae/efficientvit/apps/utils/dist.py
new file mode 100644
index 0000000..b1625a2
--- /dev/null
+++ b/diffusion/model/dc_ae/efficientvit/apps/utils/dist.py
@@ -0,0 +1,91 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+from typing import Union
+
+import torch
+import torch.distributed
+
+from ...models.utils.list import list_mean, list_sum
+
+__all__ = [
+ "dist_init",
+ "is_dist_initialized",
+ "get_dist_rank",
+ "get_dist_size",
+ "is_master",
+ "dist_barrier",
+ "get_dist_local_rank",
+ "sync_tensor",
+]
+
+
+def dist_init() -> None:
+ if is_dist_initialized():
+ return
+ try:
+ torch.distributed.init_process_group(backend="nccl")
+ assert torch.distributed.is_initialized()
+ except Exception:
+ os.environ["RANK"] = "0"
+ os.environ["WORLD_SIZE"] = "1"
+ os.environ["LOCAL_RANK"] = "0"
+ print("warning: dist not init")
+
+
+def is_dist_initialized() -> bool:
+ return torch.distributed.is_initialized()
+
+
+def get_dist_rank() -> int:
+ return int(os.environ["RANK"])
+
+
+def get_dist_size() -> int:
+ return int(os.environ["WORLD_SIZE"])
+
+
+def is_master() -> bool:
+ return get_dist_rank() == 0
+
+
+def dist_barrier() -> None:
+ if is_dist_initialized():
+ torch.distributed.barrier()
+
+
+def get_dist_local_rank() -> int:
+ return int(os.environ["LOCAL_RANK"])
+
+
+def sync_tensor(tensor: Union[torch.Tensor, float], reduce="mean") -> Union[torch.Tensor, list[torch.Tensor]]:
+ if not is_dist_initialized():
+ return tensor
+ if not isinstance(tensor, torch.Tensor):
+ tensor = torch.Tensor(1).fill_(tensor).cuda()
+ tensor_list = [torch.empty_like(tensor) for _ in range(get_dist_size())]
+ torch.distributed.all_gather(tensor_list, tensor.contiguous(), async_op=False)
+ if reduce == "mean":
+ return list_mean(tensor_list)
+ elif reduce == "sum":
+ return list_sum(tensor_list)
+ elif reduce == "cat":
+ return torch.cat(tensor_list, dim=0)
+ elif reduce == "root":
+ return tensor_list[0]
+ else:
+ return tensor_list
diff --git a/diffusion/model/dc_ae/efficientvit/apps/utils/ema.py b/diffusion/model/dc_ae/efficientvit/apps/utils/ema.py
new file mode 100644
index 0000000..0e88c4c
--- /dev/null
+++ b/diffusion/model/dc_ae/efficientvit/apps/utils/ema.py
@@ -0,0 +1,54 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import copy
+import math
+
+import torch
+import torch.nn as nn
+
+from ...models.utils import is_parallel
+
+__all__ = ["EMA"]
+
+
+def update_ema(ema: nn.Module, new_state_dict: dict[str, torch.Tensor], decay: float) -> None:
+ for k, v in ema.state_dict().items():
+ if v.dtype.is_floating_point:
+ v -= (1.0 - decay) * (v - new_state_dict[k].detach())
+
+
+class EMA:
+ def __init__(self, model: nn.Module, decay: float, warmup_steps=2000):
+ self.shadows = copy.deepcopy(model.module if is_parallel(model) else model).eval()
+ self.decay = decay
+ self.warmup_steps = warmup_steps
+
+ for p in self.shadows.parameters():
+ p.requires_grad = False
+
+ def step(self, model: nn.Module, global_step: int) -> None:
+ with torch.no_grad():
+ msd = (model.module if is_parallel(model) else model).state_dict()
+ update_ema(self.shadows, msd, self.decay * (1 - math.exp(-global_step / self.warmup_steps)))
+
+ def state_dict(self) -> dict[float, dict[str, torch.Tensor]]:
+ return {self.decay: self.shadows.state_dict()}
+
+ def load_state_dict(self, state_dict: dict[float, dict[str, torch.Tensor]]) -> None:
+ for decay in state_dict:
+ if decay == self.decay:
+ self.shadows.load_state_dict(state_dict[decay])
diff --git a/diffusion/model/dc_ae/efficientvit/apps/utils/export.py b/diffusion/model/dc_ae/efficientvit/apps/utils/export.py
new file mode 100644
index 0000000..8ec1b98
--- /dev/null
+++ b/diffusion/model/dc_ae/efficientvit/apps/utils/export.py
@@ -0,0 +1,58 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import io
+import os
+from typing import Any
+
+import onnx
+import torch
+import torch.nn as nn
+from onnxsim import simplify as simplify_func
+
+__all__ = ["export_onnx"]
+
+
+def export_onnx(model: nn.Module, export_path: str, sample_inputs: Any, simplify=True, opset=11) -> None:
+ """Export a model to a platform-specific onnx format.
+
+ Args:
+ model: a torch.nn.Module object.
+ export_path: export location.
+ sample_inputs: Any.
+ simplify: a flag to turn on onnx-simplifier
+ opset: int
+ """
+ model.eval()
+
+ buffer = io.BytesIO()
+ with torch.no_grad():
+ torch.onnx.export(model, sample_inputs, buffer, opset_version=opset)
+ buffer.seek(0, 0)
+ if simplify:
+ onnx_model = onnx.load_model(buffer)
+ onnx_model, success = simplify_func(onnx_model)
+ assert success
+ new_buffer = io.BytesIO()
+ onnx.save(onnx_model, new_buffer)
+ buffer = new_buffer
+ buffer.seek(0, 0)
+
+ if buffer.getbuffer().nbytes > 0:
+ save_dir = os.path.dirname(export_path)
+ os.makedirs(save_dir, exist_ok=True)
+ with open(export_path, "wb") as f:
+ f.write(buffer.read())
diff --git a/diffusion/model/dc_ae/efficientvit/apps/utils/image.py b/diffusion/model/dc_ae/efficientvit/apps/utils/image.py
new file mode 100644
index 0000000..9db9d92
--- /dev/null
+++ b/diffusion/model/dc_ae/efficientvit/apps/utils/image.py
@@ -0,0 +1,190 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import pathlib
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+from PIL import Image
+from torch.utils.data.dataset import Dataset
+from torchvision.datasets import ImageFolder
+
+__all__ = ["load_image", "load_image_from_dir", "DMCrop", "CustomImageFolder", "ImageDataset"]
+
+
+def load_image(data_path: str, mode="rgb") -> Image.Image:
+ img = Image.open(data_path)
+ if mode == "rgb":
+ img = img.convert("RGB")
+ return img
+
+
+def load_image_from_dir(
+ dir_path: str,
+ suffix: Union[str, tuple[str, ...], list[str]] = (".jpg", ".JPEG", ".png"),
+ return_mode="path",
+ k: Optional[int] = None,
+ shuffle_func: Optional[Callable] = None,
+) -> Union[list, tuple[list, list]]:
+ suffix = [suffix] if isinstance(suffix, str) else suffix
+
+ file_list = []
+ for dirpath, _, fnames in os.walk(dir_path):
+ for fname in fnames:
+ if pathlib.Path(fname).suffix not in suffix:
+ continue
+ image_path = os.path.join(dirpath, fname)
+ file_list.append(image_path)
+
+ if shuffle_func is not None and k is not None:
+ shuffle_file_list = shuffle_func(file_list)
+ file_list = shuffle_file_list or file_list
+ file_list = file_list[:k]
+
+ file_list = sorted(file_list)
+
+ if return_mode == "path":
+ return file_list
+ else:
+ files = []
+ path_list = []
+ for file_path in file_list:
+ try:
+ files.append(load_image(file_path))
+ path_list.append(file_path)
+ except Exception:
+ print(f"Fail to load {file_path}")
+ if return_mode == "image":
+ return files
+ else:
+ return path_list, files
+
+
+class DMCrop:
+ """center/random crop used in diffusion models"""
+
+ def __init__(self, size: int) -> None:
+ self.size = size
+
+ def __call__(self, pil_image: Image.Image) -> Image.Image:
+ """
+ Center cropping implementation from ADM.
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
+ """
+ image_size = self.size
+ if pil_image.size == (image_size, image_size):
+ return pil_image
+
+ while min(*pil_image.size) >= 2 * image_size:
+ pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
+
+ scale = image_size / min(*pil_image.size)
+ pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
+
+ arr = np.array(pil_image)
+ crop_y = (arr.shape[0] - image_size) // 2
+ crop_x = (arr.shape[1] - image_size) // 2
+ return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
+
+
+class CustomImageFolder(ImageFolder):
+ def __init__(self, root: str, transform: Optional[Callable] = None, return_dict: bool = False):
+ root = os.path.expanduser(root)
+ self.return_dict = return_dict
+ super().__init__(root, transform)
+
+ def __getitem__(self, index: int) -> Union[dict[str, Any], tuple[Any, Any]]:
+ path, target = self.samples[index]
+ image = load_image(path)
+ if self.transform is not None:
+ image = self.transform(image)
+ if self.return_dict:
+ return {
+ "index": index,
+ "image_path": path,
+ "image": image,
+ "label": target,
+ }
+ else:
+ return image, target
+
+
+class ImageDataset(Dataset):
+ def __init__(
+ self,
+ data_dirs: Union[str, list[str]],
+ splits: Optional[Union[str, list[Optional[str]]]] = None,
+ transform: Optional[Callable] = None,
+ suffix=(".jpg", ".JPEG", ".png"),
+ pil=True,
+ return_dict=True,
+ ) -> None:
+ super().__init__()
+
+ self.data_dirs = [data_dirs] if isinstance(data_dirs, str) else data_dirs
+ if isinstance(splits, list):
+ assert len(splits) == len(self.data_dirs)
+ self.splits = splits
+ elif isinstance(splits, str):
+ assert len(self.data_dirs) == 1
+ self.splits = [splits]
+ else:
+ self.splits = [None for _ in range(len(self.data_dirs))]
+
+ self.transform = transform
+ self.pil = pil
+ self.return_dict = return_dict
+
+ # load all images [image_path]
+ self.samples = []
+ for data_dir, split in zip(self.data_dirs, self.splits):
+ if split is None:
+ samples = load_image_from_dir(data_dir, suffix, return_mode="path")
+ else:
+ samples = []
+ with open(split) as fin:
+ for line in fin.readlines():
+ relative_path = line[:-1]
+ full_path = os.path.join(data_dir, relative_path)
+ samples.append(full_path)
+ self.samples += samples
+
+ def __len__(self) -> int:
+ return len(self.samples)
+
+ def __getitem__(self, index: int, skip_image=False) -> dict[str, Any]:
+ image_path = self.samples[index]
+
+ if skip_image:
+ image = None
+ else:
+ try:
+ image = load_image(image_path, return_pil=self.pil)
+ except Exception:
+ print(f"Fail to load {image_path}")
+ raise OSError
+ if self.transform is not None:
+ image = self.transform(image)
+ if self.return_dict:
+ return {
+ "index": index,
+ "image_path": image_path,
+ "image_name": os.path.basename(image_path),
+ "data": image,
+ }
+ else:
+ return image
diff --git a/diffusion/model/dc_ae/efficientvit/apps/utils/init.py b/diffusion/model/dc_ae/efficientvit/apps/utils/init.py
new file mode 100644
index 0000000..415981e
--- /dev/null
+++ b/diffusion/model/dc_ae/efficientvit/apps/utils/init.py
@@ -0,0 +1,80 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+import torch
+import torch.nn as nn
+from torch.nn.modules.batchnorm import _BatchNorm
+
+__all__ = ["init_modules", "zero_last_gamma"]
+
+
+def init_modules(model: Union[nn.Module, list[nn.Module]], init_type="trunc_normal") -> None:
+ _DEFAULT_INIT_PARAM = {"trunc_normal": 0.02}
+
+ if isinstance(model, list):
+ for sub_module in model:
+ init_modules(sub_module, init_type)
+ else:
+ init_params = init_type.split("@")
+ init_params = float(init_params[1]) if len(init_params) > 1 else None
+
+ if init_type.startswith("trunc_normal"):
+ init_func = lambda param: nn.init.trunc_normal_(
+ param, std=(_DEFAULT_INIT_PARAM["trunc_normal"] if init_params is None else init_params)
+ )
+ else:
+ raise NotImplementedError
+
+ for m in model.modules():
+ if isinstance(m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
+ init_func(m.weight)
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Embedding):
+ init_func(m.weight)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+ else:
+ weight = getattr(m, "weight", None)
+ bias = getattr(m, "bias", None)
+ if isinstance(weight, torch.nn.Parameter):
+ init_func(weight)
+ if isinstance(bias, torch.nn.Parameter):
+ bias.data.zero_()
+
+
+def zero_last_gamma(model: nn.Module, init_val=0) -> None:
+ import efficientvit.models.nn.ops as ops
+
+ for m in model.modules():
+ if isinstance(m, ops.ResidualBlock) and isinstance(m.shortcut, ops.IdentityLayer):
+ if isinstance(m.main, (ops.DSConv, ops.MBConv, ops.FusedMBConv)):
+ parent_module = m.main.point_conv
+ elif isinstance(m.main, ops.ResBlock):
+ parent_module = m.main.conv2
+ elif isinstance(m.main, ops.ConvLayer):
+ parent_module = m.main
+ elif isinstance(m.main, (ops.LiteMLA)):
+ parent_module = m.main.proj
+ else:
+ parent_module = None
+ if parent_module is not None:
+ norm = getattr(parent_module, "norm", None)
+ if norm is not None:
+ nn.init.constant_(norm.weight, init_val)
diff --git a/diffusion/model/dc_ae/efficientvit/apps/utils/lr.py b/diffusion/model/dc_ae/efficientvit/apps/utils/lr.py
new file mode 100644
index 0000000..56aa636
--- /dev/null
+++ b/diffusion/model/dc_ae/efficientvit/apps/utils/lr.py
@@ -0,0 +1,79 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import math
+from typing import Union
+
+import torch
+
+from ...models.utils.list import val2list
+
+__all__ = ["CosineLRwithWarmup", "ConstantLRwithWarmup"]
+
+
+class CosineLRwithWarmup(torch.optim.lr_scheduler._LRScheduler):
+ def __init__(
+ self,
+ optimizer: torch.optim.Optimizer,
+ warmup_steps: int,
+ warmup_lr: float,
+ decay_steps: Union[int, list[int]],
+ last_epoch: int = -1,
+ ) -> None:
+ self.warmup_steps = warmup_steps
+ self.warmup_lr = warmup_lr
+ self.decay_steps = val2list(decay_steps)
+ super().__init__(optimizer, last_epoch)
+
+ def get_lr(self) -> list[float]:
+ if self.last_epoch < self.warmup_steps:
+ return [
+ (base_lr - self.warmup_lr) * (self.last_epoch + 1) / self.warmup_steps + self.warmup_lr
+ for base_lr in self.base_lrs
+ ]
+ else:
+ current_steps = self.last_epoch - self.warmup_steps
+ decay_steps = [0] + self.decay_steps
+ idx = len(decay_steps) - 2
+ for i, decay_step in enumerate(decay_steps[:-1]):
+ if decay_step <= current_steps < decay_steps[i + 1]:
+ idx = i
+ break
+ current_steps -= decay_steps[idx]
+ decay_step = decay_steps[idx + 1] - decay_steps[idx]
+ return [0.5 * base_lr * (1 + math.cos(math.pi * current_steps / decay_step)) for base_lr in self.base_lrs]
+
+
+class ConstantLRwithWarmup(torch.optim.lr_scheduler._LRScheduler):
+ def __init__(
+ self,
+ optimizer: torch.optim.Optimizer,
+ warmup_steps: int,
+ warmup_lr: float,
+ last_epoch: int = -1,
+ ) -> None:
+ self.warmup_steps = warmup_steps
+ self.warmup_lr = warmup_lr
+ super().__init__(optimizer, last_epoch)
+
+ def get_lr(self) -> list[float]:
+ if self.last_epoch < self.warmup_steps:
+ return [
+ (base_lr - self.warmup_lr) * (self.last_epoch + 1) / self.warmup_steps + self.warmup_lr
+ for base_lr in self.base_lrs
+ ]
+ else:
+ return self.base_lrs
diff --git a/diffusion/model/dc_ae/efficientvit/apps/utils/metric.py b/diffusion/model/dc_ae/efficientvit/apps/utils/metric.py
new file mode 100644
index 0000000..b22f1d1
--- /dev/null
+++ b/diffusion/model/dc_ae/efficientvit/apps/utils/metric.py
@@ -0,0 +1,47 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+import torch
+
+from ...apps.utils.dist import sync_tensor
+
+__all__ = ["AverageMeter"]
+
+
+class AverageMeter:
+ """Computes and stores the average and current value."""
+
+ def __init__(self, is_distributed=True):
+ self.is_distributed = is_distributed
+ self.sum = 0
+ self.count = 0
+
+ def _sync(self, val: Union[torch.Tensor, int, float]) -> Union[torch.Tensor, int, float]:
+ return sync_tensor(val, reduce="sum") if self.is_distributed else val
+
+ def update(self, val: Union[torch.Tensor, int, float], delta_n=1):
+ self.count += self._sync(delta_n)
+ self.sum += self._sync(val * delta_n)
+
+ def get_count(self) -> Union[torch.Tensor, int, float]:
+ return self.count.item() if isinstance(self.count, torch.Tensor) and self.count.numel() == 1 else self.count
+
+ @property
+ def avg(self):
+ avg = -1 if self.count == 0 else self.sum / self.count
+ return avg.item() if isinstance(avg, torch.Tensor) and avg.numel() == 1 else avg
diff --git a/diffusion/model/dc_ae/efficientvit/apps/utils/misc.py b/diffusion/model/dc_ae/efficientvit/apps/utils/misc.py
new file mode 100644
index 0000000..06273a1
--- /dev/null
+++ b/diffusion/model/dc_ae/efficientvit/apps/utils/misc.py
@@ -0,0 +1,114 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+from typing import Union
+
+import yaml
+
+__all__ = [
+ "parse_with_yaml",
+ "parse_unknown_args",
+ "partial_update_config",
+ "resolve_and_load_config",
+ "load_config",
+ "dump_config",
+]
+
+
+def parse_with_yaml(config_str: str) -> Union[str, dict]:
+ try:
+ # add space manually for dict
+ if "{" in config_str and "}" in config_str and ":" in config_str:
+ out_str = config_str.replace(":", ": ")
+ else:
+ out_str = config_str
+ return yaml.safe_load(out_str)
+ except ValueError:
+ # return raw string if parsing fails
+ return config_str
+
+
+def parse_unknown_args(unknown: list) -> dict:
+ """Parse unknown args."""
+ index = 0
+ parsed_dict = {}
+ while index < len(unknown):
+ key, val = unknown[index], unknown[index + 1]
+ index += 2
+ if not key.startswith("--"):
+ continue
+ key = key[2:]
+
+ # try parsing with either dot notation or full yaml notation
+ # Note that the vanilla case "--key value" will be parsed the same
+ if "." in key:
+ # key == a.b.c, val == val --> parsed_dict[a][b][c] = val
+ keys = key.split(".")
+ dict_to_update = parsed_dict
+ for key in keys[:-1]:
+ if not (key in dict_to_update and isinstance(dict_to_update[key], dict)):
+ dict_to_update[key] = {}
+ dict_to_update = dict_to_update[key]
+ dict_to_update[keys[-1]] = parse_with_yaml(val) # so we can parse lists, bools, etc...
+ else:
+ parsed_dict[key] = parse_with_yaml(val)
+ return parsed_dict
+
+
+def partial_update_config(config: dict, partial_config: dict) -> dict:
+ for key in partial_config:
+ if key in config and isinstance(partial_config[key], dict) and isinstance(config[key], dict):
+ partial_update_config(config[key], partial_config[key])
+ else:
+ config[key] = partial_config[key]
+ return config
+
+
+def resolve_and_load_config(path: str, config_name="config.yaml") -> dict:
+ path = os.path.realpath(os.path.expanduser(path))
+ if os.path.isdir(path):
+ config_path = os.path.join(path, config_name)
+ else:
+ config_path = path
+ if os.path.isfile(config_path):
+ pass
+ else:
+ raise Exception(f"Cannot find a valid config at {path}")
+ config = load_config(config_path)
+ return config
+
+
+class SafeLoaderWithTuple(yaml.SafeLoader):
+ """A yaml safe loader with python tuple loading capabilities."""
+
+ def construct_python_tuple(self, node):
+ return tuple(self.construct_sequence(node))
+
+
+SafeLoaderWithTuple.add_constructor("tag:yaml.org,2002:python/tuple", SafeLoaderWithTuple.construct_python_tuple)
+
+
+def load_config(filename: str) -> dict:
+ """Load a yaml file."""
+ filename = os.path.realpath(os.path.expanduser(filename))
+ return yaml.load(open(filename), Loader=SafeLoaderWithTuple)
+
+
+def dump_config(config: dict, filename: str) -> None:
+ """Dump a config file"""
+ filename = os.path.realpath(os.path.expanduser(filename))
+ yaml.dump(config, open(filename, "w"), sort_keys=False)
diff --git a/diffusion/model/dc_ae/efficientvit/apps/utils/opt.py b/diffusion/model/dc_ae/efficientvit/apps/utils/opt.py
new file mode 100644
index 0000000..a968a94
--- /dev/null
+++ b/diffusion/model/dc_ae/efficientvit/apps/utils/opt.py
@@ -0,0 +1,42 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+
+__all__ = ["REGISTERED_OPTIMIZER_DICT", "build_optimizer"]
+
+# register optimizer here
+# name: optimizer, kwargs with default values
+REGISTERED_OPTIMIZER_DICT: dict[str, tuple[type, dict[str, Any]]] = {
+ "sgd": (torch.optim.SGD, {"momentum": 0.9, "nesterov": True}),
+ "adam": (torch.optim.Adam, {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False}),
+ "adamw": (torch.optim.AdamW, {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False}),
+}
+
+
+def build_optimizer(
+ net_params, optimizer_name: str, optimizer_params: Optional[dict], init_lr: float
+) -> torch.optim.Optimizer:
+ optimizer_class, default_params = REGISTERED_OPTIMIZER_DICT[optimizer_name]
+ optimizer_params = {} if optimizer_params is None else optimizer_params
+
+ for key in default_params:
+ if key in optimizer_params:
+ default_params[key] = optimizer_params[key]
+ optimizer = optimizer_class(net_params, init_lr, **default_params)
+ return optimizer
diff --git a/diffusion/model/dc_ae/efficientvit/models/__init__.py b/diffusion/model/dc_ae/efficientvit/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/diffusion/model/dc_ae/efficientvit/models/efficientvit/__init__.py b/diffusion/model/dc_ae/efficientvit/models/efficientvit/__init__.py
new file mode 100644
index 0000000..ce6455c
--- /dev/null
+++ b/diffusion/model/dc_ae/efficientvit/models/efficientvit/__init__.py
@@ -0,0 +1 @@
+from .dc_ae import *
diff --git a/diffusion/model/dc_ae/efficientvit/models/efficientvit/dc_ae.py b/diffusion/model/dc_ae/efficientvit/models/efficientvit/dc_ae.py
new file mode 100644
index 0000000..64f162c
--- /dev/null
+++ b/diffusion/model/dc_ae/efficientvit/models/efficientvit/dc_ae.py
@@ -0,0 +1,517 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from dataclasses import dataclass, field
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from omegaconf import MISSING, OmegaConf
+from torch import Tensor
+
+from ...models.nn.act import build_act
+from ...models.nn.norm import build_norm
+from ...models.nn.ops import (
+ ChannelDuplicatingPixelUnshuffleUpSampleLayer,
+ ConvLayer,
+ ConvPixelShuffleUpSampleLayer,
+ ConvPixelUnshuffleDownSampleLayer,
+ EfficientViTBlock,
+ IdentityLayer,
+ InterpolateConvUpSampleLayer,
+ OpSequential,
+ PixelUnshuffleChannelAveragingDownSampleLayer,
+ ResBlock,
+ ResidualBlock,
+)
+
+__all__ = ["DCAE", "dc_ae_f32c32", "dc_ae_f64c128", "dc_ae_f128c512"]
+
+
+@dataclass
+class EncoderConfig:
+ in_channels: int = MISSING
+ latent_channels: int = MISSING
+ width_list: tuple[int, ...] = (128, 256, 512, 512, 1024, 1024)
+ depth_list: tuple[int, ...] = (2, 2, 2, 2, 2, 2)
+ block_type: Any = "ResBlock"
+ norm: str = "trms2d"
+ act: str = "silu"
+ downsample_block_type: str = "ConvPixelUnshuffle"
+ downsample_match_channel: bool = True
+ downsample_shortcut: Optional[str] = "averaging"
+ out_norm: Optional[str] = None
+ out_act: Optional[str] = None
+ out_shortcut: Optional[str] = "averaging"
+ double_latent: bool = False
+
+
+@dataclass
+class DecoderConfig:
+ in_channels: int = MISSING
+ latent_channels: int = MISSING
+ in_shortcut: Optional[str] = "duplicating"
+ width_list: tuple[int, ...] = (128, 256, 512, 512, 1024, 1024)
+ depth_list: tuple[int, ...] = (2, 2, 2, 2, 2, 2)
+ block_type: Any = "ResBlock"
+ norm: Any = "trms2d"
+ act: Any = "silu"
+ upsample_block_type: str = "ConvPixelShuffle"
+ upsample_match_channel: bool = True
+ upsample_shortcut: str = "duplicating"
+ out_norm: str = "trms2d"
+ out_act: str = "relu"
+
+
+@dataclass
+class DCAEConfig:
+ in_channels: int = 3
+ latent_channels: int = 32
+ encoder: EncoderConfig = field(
+ default_factory=lambda: EncoderConfig(in_channels="${..in_channels}", latent_channels="${..latent_channels}")
+ )
+ decoder: DecoderConfig = field(
+ default_factory=lambda: DecoderConfig(in_channels="${..in_channels}", latent_channels="${..latent_channels}")
+ )
+ use_quant_conv: bool = False
+
+ pretrained_path: Optional[str] = None
+ pretrained_source: str = "dc-ae"
+
+ scaling_factor: Optional[float] = None
+
+
+def build_block(
+ block_type: str, in_channels: int, out_channels: int, norm: Optional[str], act: Optional[str]
+) -> nn.Module:
+ if block_type == "ResBlock":
+ assert in_channels == out_channels
+ main_block = ResBlock(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=1,
+ use_bias=(True, False),
+ norm=(None, norm),
+ act_func=(act, None),
+ )
+ block = ResidualBlock(main_block, IdentityLayer())
+ elif block_type == "EViT_GLU":
+ assert in_channels == out_channels
+ block = EfficientViTBlock(in_channels, norm=norm, act_func=act, local_module="GLUMBConv", scales=())
+ elif block_type == "EViTS5_GLU":
+ assert in_channels == out_channels
+ block = EfficientViTBlock(in_channels, norm=norm, act_func=act, local_module="GLUMBConv", scales=(5,))
+ else:
+ raise ValueError(f"block_type {block_type} is not supported")
+ return block
+
+
+def build_stage_main(
+ width: int, depth: int, block_type: str | list[str], norm: str, act: str, input_width: int
+) -> list[nn.Module]:
+ assert isinstance(block_type, str) or (isinstance(block_type, list) and depth == len(block_type))
+ stage = []
+ for d in range(depth):
+ current_block_type = block_type[d] if isinstance(block_type, list) else block_type
+ block = build_block(
+ block_type=current_block_type,
+ in_channels=width if d > 0 else input_width,
+ out_channels=width,
+ norm=norm,
+ act=act,
+ )
+ stage.append(block)
+ return stage
+
+
+def build_downsample_block(block_type: str, in_channels: int, out_channels: int, shortcut: Optional[str]) -> nn.Module:
+ if block_type == "Conv":
+ block = ConvLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=2,
+ use_bias=True,
+ norm=None,
+ act_func=None,
+ )
+ elif block_type == "ConvPixelUnshuffle":
+ block = ConvPixelUnshuffleDownSampleLayer(
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, factor=2
+ )
+ else:
+ raise ValueError(f"block_type {block_type} is not supported for downsampling")
+ if shortcut is None:
+ pass
+ elif shortcut == "averaging":
+ shortcut_block = PixelUnshuffleChannelAveragingDownSampleLayer(
+ in_channels=in_channels, out_channels=out_channels, factor=2
+ )
+ block = ResidualBlock(block, shortcut_block)
+ else:
+ raise ValueError(f"shortcut {shortcut} is not supported for downsample")
+ return block
+
+
+def build_upsample_block(block_type: str, in_channels: int, out_channels: int, shortcut: Optional[str]) -> nn.Module:
+ if block_type == "ConvPixelShuffle":
+ block = ConvPixelShuffleUpSampleLayer(
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, factor=2
+ )
+ elif block_type == "InterpolateConv":
+ block = InterpolateConvUpSampleLayer(
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, factor=2
+ )
+ else:
+ raise ValueError(f"block_type {block_type} is not supported for upsampling")
+ if shortcut is None:
+ pass
+ elif shortcut == "duplicating":
+ shortcut_block = ChannelDuplicatingPixelUnshuffleUpSampleLayer(
+ in_channels=in_channels, out_channels=out_channels, factor=2
+ )
+ block = ResidualBlock(block, shortcut_block)
+ else:
+ raise ValueError(f"shortcut {shortcut} is not supported for upsample")
+ return block
+
+
+def build_encoder_project_in_block(in_channels: int, out_channels: int, factor: int, downsample_block_type: str):
+ if factor == 1:
+ block = ConvLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=1,
+ use_bias=True,
+ norm=None,
+ act_func=None,
+ )
+ elif factor == 2:
+ block = build_downsample_block(
+ block_type=downsample_block_type, in_channels=in_channels, out_channels=out_channels, shortcut=None
+ )
+ else:
+ raise ValueError(f"downsample factor {factor} is not supported for encoder project in")
+ return block
+
+
+def build_encoder_project_out_block(
+ in_channels: int, out_channels: int, norm: Optional[str], act: Optional[str], shortcut: Optional[str]
+):
+ block = OpSequential(
+ [
+ build_norm(norm),
+ build_act(act),
+ ConvLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=1,
+ use_bias=True,
+ norm=None,
+ act_func=None,
+ ),
+ ]
+ )
+ if shortcut is None:
+ pass
+ elif shortcut == "averaging":
+ shortcut_block = PixelUnshuffleChannelAveragingDownSampleLayer(
+ in_channels=in_channels, out_channels=out_channels, factor=1
+ )
+ block = ResidualBlock(block, shortcut_block)
+ else:
+ raise ValueError(f"shortcut {shortcut} is not supported for encoder project out")
+ return block
+
+
+def build_decoder_project_in_block(in_channels: int, out_channels: int, shortcut: Optional[str]):
+ block = ConvLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=1,
+ use_bias=True,
+ norm=None,
+ act_func=None,
+ )
+ if shortcut is None:
+ pass
+ elif shortcut == "duplicating":
+ shortcut_block = ChannelDuplicatingPixelUnshuffleUpSampleLayer(
+ in_channels=in_channels, out_channels=out_channels, factor=1
+ )
+ block = ResidualBlock(block, shortcut_block)
+ else:
+ raise ValueError(f"shortcut {shortcut} is not supported for decoder project in")
+ return block
+
+
+def build_decoder_project_out_block(
+ in_channels: int, out_channels: int, factor: int, upsample_block_type: str, norm: Optional[str], act: Optional[str]
+):
+ layers: list[nn.Module] = [
+ build_norm(norm, in_channels),
+ build_act(act),
+ ]
+ if factor == 1:
+ layers.append(
+ ConvLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=1,
+ use_bias=True,
+ norm=None,
+ act_func=None,
+ )
+ )
+ elif factor == 2:
+ layers.append(
+ build_upsample_block(
+ block_type=upsample_block_type, in_channels=in_channels, out_channels=out_channels, shortcut=None
+ )
+ )
+ else:
+ raise ValueError(f"upsample factor {factor} is not supported for decoder project out")
+ return OpSequential(layers)
+
+
+class Encoder(nn.Module):
+ def __init__(self, cfg: EncoderConfig):
+ super().__init__()
+ self.cfg = cfg
+ num_stages = len(cfg.width_list)
+ self.num_stages = num_stages
+ assert len(cfg.depth_list) == num_stages
+ assert len(cfg.width_list) == num_stages
+ assert isinstance(cfg.block_type, str) or (
+ isinstance(cfg.block_type, list) and len(cfg.block_type) == num_stages
+ )
+
+ self.project_in = build_encoder_project_in_block(
+ in_channels=cfg.in_channels,
+ out_channels=cfg.width_list[0] if cfg.depth_list[0] > 0 else cfg.width_list[1],
+ factor=1 if cfg.depth_list[0] > 0 else 2,
+ downsample_block_type=cfg.downsample_block_type,
+ )
+
+ self.stages: list[OpSequential] = []
+ for stage_id, (width, depth) in enumerate(zip(cfg.width_list, cfg.depth_list)):
+ block_type = cfg.block_type[stage_id] if isinstance(cfg.block_type, list) else cfg.block_type
+ stage = build_stage_main(
+ width=width, depth=depth, block_type=block_type, norm=cfg.norm, act=cfg.act, input_width=width
+ )
+
+ if stage_id < num_stages - 1 and depth > 0:
+ downsample_block = build_downsample_block(
+ block_type=cfg.downsample_block_type,
+ in_channels=width,
+ out_channels=cfg.width_list[stage_id + 1] if cfg.downsample_match_channel else width,
+ shortcut=cfg.downsample_shortcut,
+ )
+ stage.append(downsample_block)
+ self.stages.append(OpSequential(stage))
+ self.stages = nn.ModuleList(self.stages)
+
+ self.project_out = build_encoder_project_out_block(
+ in_channels=cfg.width_list[-1],
+ out_channels=2 * cfg.latent_channels if cfg.double_latent else cfg.latent_channels,
+ norm=cfg.out_norm,
+ act=cfg.out_act,
+ shortcut=cfg.out_shortcut,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.project_in(x)
+ for stage in self.stages:
+ if len(stage.op_list) == 0:
+ continue
+ x = stage(x)
+ x = self.project_out(x)
+ return x
+
+
+class Decoder(nn.Module):
+ def __init__(self, cfg: DecoderConfig):
+ super().__init__()
+ self.cfg = cfg
+ num_stages = len(cfg.width_list)
+ self.num_stages = num_stages
+ assert len(cfg.depth_list) == num_stages
+ assert len(cfg.width_list) == num_stages
+ assert isinstance(cfg.block_type, str) or (
+ isinstance(cfg.block_type, list) and len(cfg.block_type) == num_stages
+ )
+ assert isinstance(cfg.norm, str) or (isinstance(cfg.norm, list) and len(cfg.norm) == num_stages)
+ assert isinstance(cfg.act, str) or (isinstance(cfg.act, list) and len(cfg.act) == num_stages)
+
+ self.project_in = build_decoder_project_in_block(
+ in_channels=cfg.latent_channels,
+ out_channels=cfg.width_list[-1],
+ shortcut=cfg.in_shortcut,
+ )
+
+ self.stages: list[OpSequential] = []
+ for stage_id, (width, depth) in reversed(list(enumerate(zip(cfg.width_list, cfg.depth_list)))):
+ stage = []
+ if stage_id < num_stages - 1 and depth > 0:
+ upsample_block = build_upsample_block(
+ block_type=cfg.upsample_block_type,
+ in_channels=cfg.width_list[stage_id + 1],
+ out_channels=width if cfg.upsample_match_channel else cfg.width_list[stage_id + 1],
+ shortcut=cfg.upsample_shortcut,
+ )
+ stage.append(upsample_block)
+
+ block_type = cfg.block_type[stage_id] if isinstance(cfg.block_type, list) else cfg.block_type
+ norm = cfg.norm[stage_id] if isinstance(cfg.norm, list) else cfg.norm
+ act = cfg.act[stage_id] if isinstance(cfg.act, list) else cfg.act
+ stage.extend(
+ build_stage_main(
+ width=width,
+ depth=depth,
+ block_type=block_type,
+ norm=norm,
+ act=act,
+ input_width=(
+ width if cfg.upsample_match_channel else cfg.width_list[min(stage_id + 1, num_stages - 1)]
+ ),
+ )
+ )
+ self.stages.insert(0, OpSequential(stage))
+ self.stages = nn.ModuleList(self.stages)
+
+ self.project_out = build_decoder_project_out_block(
+ in_channels=cfg.width_list[0] if cfg.depth_list[0] > 0 else cfg.width_list[1],
+ out_channels=cfg.in_channels,
+ factor=1 if cfg.depth_list[0] > 0 else 2,
+ upsample_block_type=cfg.upsample_block_type,
+ norm=cfg.out_norm,
+ act=cfg.out_act,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.project_in(x)
+ for stage in reversed(self.stages):
+ if len(stage.op_list) == 0:
+ continue
+ x = stage(x)
+ x = self.project_out(x)
+ return x
+
+
+class DCAE(nn.Module):
+ def __init__(self, cfg: DCAEConfig):
+ super().__init__()
+ self.cfg = cfg
+ self.encoder = Encoder(cfg.encoder)
+ self.decoder = Decoder(cfg.decoder)
+
+ if self.cfg.pretrained_path is not None:
+ self.load_model()
+
+ def load_model(self):
+ if self.cfg.pretrained_source == "dc-ae":
+ state_dict = torch.load(self.cfg.pretrained_path, map_location="cpu", weights_only=True)["state_dict"]
+ self.load_state_dict(state_dict)
+ else:
+ raise NotImplementedError
+
+ @property
+ def spatial_compression_ratio(self) -> int:
+ return 2 ** (self.decoder.num_stages - 1)
+
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.encoder(x)
+ return x
+
+ def decode(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.decoder(x)
+ return x
+
+ def forward(self, x: torch.Tensor, global_step: int) -> tuple[Any, Tensor, dict[Any, Any]]:
+ x = self.encoder(x)
+ x = self.decoder(x)
+ return x, torch.tensor(0), {}
+
+
+def dc_ae_f32c32(name: str, pretrained_path: str) -> DCAEConfig:
+ if name in ["dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"]:
+ cfg_str = (
+ "latent_channels=32 "
+ "encoder.block_type=[ResBlock,ResBlock,ResBlock,EViT_GLU,EViT_GLU,EViT_GLU] "
+ "encoder.width_list=[128,256,512,512,1024,1024] encoder.depth_list=[0,4,8,2,2,2] "
+ "decoder.block_type=[ResBlock,ResBlock,ResBlock,EViT_GLU,EViT_GLU,EViT_GLU] "
+ "decoder.width_list=[128,256,512,512,1024,1024] decoder.depth_list=[0,5,10,2,2,2] "
+ "decoder.norm=[bn2d,bn2d,bn2d,trms2d,trms2d,trms2d] decoder.act=[relu,relu,relu,silu,silu,silu]"
+ )
+ elif name in ["dc-ae-f32c32-sana-1.0"]:
+ cfg_str = (
+ "latent_channels=32 "
+ "encoder.block_type=[ResBlock,ResBlock,ResBlock,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU] "
+ "encoder.width_list=[128,256,512,512,1024,1024] encoder.depth_list=[2,2,2,3,3,3] "
+ "encoder.downsample_block_type=Conv "
+ "decoder.block_type=[ResBlock,ResBlock,ResBlock,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU] "
+ "decoder.width_list=[128,256,512,512,1024,1024] decoder.depth_list=[3,3,3,3,3,3] "
+ "decoder.upsample_block_type=InterpolateConv "
+ "decoder.norm=trms2d decoder.act=silu "
+ "scaling_factor=0.41407"
+ )
+ else:
+ raise NotImplementedError
+ cfg = OmegaConf.from_dotlist(cfg_str.split(" "))
+ cfg: DCAEConfig = OmegaConf.to_object(OmegaConf.merge(OmegaConf.structured(DCAEConfig), cfg))
+ cfg.pretrained_path = pretrained_path
+ return cfg
+
+
+def dc_ae_f64c128(name: str, pretrained_path: Optional[str] = None) -> DCAEConfig:
+ if name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]:
+ cfg_str = (
+ "latent_channels=128 "
+ "encoder.block_type=[ResBlock,ResBlock,ResBlock,EViT_GLU,EViT_GLU,EViT_GLU,EViT_GLU] "
+ "encoder.width_list=[128,256,512,512,1024,1024,2048] encoder.depth_list=[0,4,8,2,2,2,2] "
+ "decoder.block_type=[ResBlock,ResBlock,ResBlock,EViT_GLU,EViT_GLU,EViT_GLU,EViT_GLU] "
+ "decoder.width_list=[128,256,512,512,1024,1024,2048] decoder.depth_list=[0,5,10,2,2,2,2] "
+ "decoder.norm=[bn2d,bn2d,bn2d,trms2d,trms2d,trms2d,trms2d] decoder.act=[relu,relu,relu,silu,silu,silu,silu]"
+ )
+ else:
+ raise NotImplementedError
+ cfg = OmegaConf.from_dotlist(cfg_str.split(" "))
+ cfg: DCAEConfig = OmegaConf.to_object(OmegaConf.merge(OmegaConf.structured(DCAEConfig), cfg))
+ cfg.pretrained_path = pretrained_path
+ return cfg
+
+
+def dc_ae_f128c512(name: str, pretrained_path: Optional[str] = None) -> DCAEConfig:
+ if name in ["dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"]:
+ cfg_str = (
+ "latent_channels=512 "
+ "encoder.block_type=[ResBlock,ResBlock,ResBlock,EViT_GLU,EViT_GLU,EViT_GLU,EViT_GLU,EViT_GLU] "
+ "encoder.width_list=[128,256,512,512,1024,1024,2048,2048] encoder.depth_list=[0,4,8,2,2,2,2,2] "
+ "decoder.block_type=[ResBlock,ResBlock,ResBlock,EViT_GLU,EViT_GLU,EViT_GLU,EViT_GLU,EViT_GLU] "
+ "decoder.width_list=[128,256,512,512,1024,1024,2048,2048] decoder.depth_list=[0,5,10,2,2,2,2,2] "
+ "decoder.norm=[bn2d,bn2d,bn2d,trms2d,trms2d,trms2d,trms2d,trms2d] decoder.act=[relu,relu,relu,silu,silu,silu,silu,silu]"
+ )
+ else:
+ raise NotImplementedError
+ cfg = OmegaConf.from_dotlist(cfg_str.split(" "))
+ cfg: DCAEConfig = OmegaConf.to_object(OmegaConf.merge(OmegaConf.structured(DCAEConfig), cfg))
+ cfg.pretrained_path = pretrained_path
+ return cfg
diff --git a/diffusion/model/dc_ae/efficientvit/models/nn/__init__.py b/diffusion/model/dc_ae/efficientvit/models/nn/__init__.py
new file mode 100644
index 0000000..3328d2e
--- /dev/null
+++ b/diffusion/model/dc_ae/efficientvit/models/nn/__init__.py
@@ -0,0 +1,5 @@
+from .act import *
+from .drop import *
+from .norm import *
+from .ops import *
+from .triton_rms_norm import *
diff --git a/diffusion/model/dc_ae/efficientvit/models/nn/act.py b/diffusion/model/dc_ae/efficientvit/models/nn/act.py
new file mode 100644
index 0000000..4907742
--- /dev/null
+++ b/diffusion/model/dc_ae/efficientvit/models/nn/act.py
@@ -0,0 +1,43 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from functools import partial
+from typing import Optional
+
+import torch.nn as nn
+
+from ...models.utils import build_kwargs_from_config
+
+__all__ = ["build_act"]
+
+
+# register activation function here
+REGISTERED_ACT_DICT: dict[str, type] = {
+ "relu": nn.ReLU,
+ "relu6": nn.ReLU6,
+ "hswish": nn.Hardswish,
+ "silu": nn.SiLU,
+ "gelu": partial(nn.GELU, approximate="tanh"),
+}
+
+
+def build_act(name: str, **kwargs) -> Optional[nn.Module]:
+ if name in REGISTERED_ACT_DICT:
+ act_cls = REGISTERED_ACT_DICT[name]
+ args = build_kwargs_from_config(kwargs, act_cls)
+ return act_cls(**args)
+ else:
+ return None
diff --git a/diffusion/model/dc_ae/efficientvit/models/nn/drop.py b/diffusion/model/dc_ae/efficientvit/models/nn/drop.py
new file mode 100644
index 0000000..a0ddf62
--- /dev/null
+++ b/diffusion/model/dc_ae/efficientvit/models/nn/drop.py
@@ -0,0 +1,102 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ...apps.trainer.run_config import Scheduler
+from ...models.nn.ops import IdentityLayer, ResidualBlock
+from ...models.utils import build_kwargs_from_config
+
+__all__ = ["apply_drop_func"]
+
+
+def apply_drop_func(network: nn.Module, drop_config: Optional[dict[str, Any]]) -> None:
+ if drop_config is None:
+ return
+
+ drop_lookup_table = {
+ "droppath": apply_droppath,
+ }
+
+ drop_func = drop_lookup_table[drop_config["name"]]
+ drop_kwargs = build_kwargs_from_config(drop_config, drop_func)
+
+ drop_func(network, **drop_kwargs)
+
+
+def apply_droppath(
+ network: nn.Module,
+ drop_prob: float,
+ linear_decay=True,
+ scheduled=True,
+ skip=0,
+) -> None:
+ all_valid_blocks = []
+ for m in network.modules():
+ for name, sub_module in m.named_children():
+ if isinstance(sub_module, ResidualBlock) and isinstance(sub_module.shortcut, IdentityLayer):
+ all_valid_blocks.append((m, name, sub_module))
+ all_valid_blocks = all_valid_blocks[skip:]
+ for i, (m, name, sub_module) in enumerate(all_valid_blocks):
+ prob = drop_prob * (i + 1) / len(all_valid_blocks) if linear_decay else drop_prob
+ new_module = DropPathResidualBlock(
+ sub_module.main,
+ sub_module.shortcut,
+ sub_module.post_act,
+ sub_module.pre_norm,
+ prob,
+ scheduled,
+ )
+ m._modules[name] = new_module
+
+
+class DropPathResidualBlock(ResidualBlock):
+ def __init__(
+ self,
+ main: nn.Module,
+ shortcut: Optional[nn.Module],
+ post_act=None,
+ pre_norm: Optional[nn.Module] = None,
+ ######################################
+ drop_prob: float = 0,
+ scheduled=True,
+ ):
+ super().__init__(main, shortcut, post_act, pre_norm)
+
+ self.drop_prob = drop_prob
+ self.scheduled = scheduled
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if not self.training or self.drop_prob == 0 or not isinstance(self.shortcut, IdentityLayer):
+ return ResidualBlock.forward(self, x)
+ else:
+ drop_prob = self.drop_prob
+ if self.scheduled:
+ drop_prob *= np.clip(Scheduler.PROGRESS, 0, 1)
+ keep_prob = 1 - drop_prob
+
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ random_tensor.floor_() # binarize
+
+ res = self.forward_main(x) / keep_prob * random_tensor + self.shortcut(x)
+ if self.post_act:
+ res = self.post_act(res)
+ return res
diff --git a/diffusion/model/dc_ae/efficientvit/models/nn/norm.py b/diffusion/model/dc_ae/efficientvit/models/nn/norm.py
new file mode 100644
index 0000000..5e62beb
--- /dev/null
+++ b/diffusion/model/dc_ae/efficientvit/models/nn/norm.py
@@ -0,0 +1,157 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from ...models.nn.triton_rms_norm import TritonRMSNorm2dFunc
+from ...models.utils import build_kwargs_from_config
+
+__all__ = ["LayerNorm2d", "TritonRMSNorm2d", "build_norm", "reset_bn", "set_norm_eps"]
+
+
+class LayerNorm2d(nn.LayerNorm):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ out = x - torch.mean(x, dim=1, keepdim=True)
+ out = out / torch.sqrt(torch.square(out).mean(dim=1, keepdim=True) + self.eps)
+ if self.elementwise_affine:
+ out = out * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
+ return out
+
+
+class TritonRMSNorm2d(nn.LayerNorm):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return TritonRMSNorm2dFunc.apply(x, self.weight, self.bias, self.eps)
+
+
+# register normalization function here
+REGISTERED_NORM_DICT: dict[str, type] = {
+ "bn2d": nn.BatchNorm2d,
+ "ln": nn.LayerNorm,
+ "ln2d": LayerNorm2d,
+ "trms2d": TritonRMSNorm2d,
+}
+
+
+def build_norm(name="bn2d", num_features=None, **kwargs) -> Optional[nn.Module]:
+ if name in ["ln", "ln2d", "trms2d"]:
+ kwargs["normalized_shape"] = num_features
+ else:
+ kwargs["num_features"] = num_features
+ if name in REGISTERED_NORM_DICT:
+ norm_cls = REGISTERED_NORM_DICT[name]
+ args = build_kwargs_from_config(kwargs, norm_cls)
+ return norm_cls(**args)
+ else:
+ return None
+
+
+def reset_bn(
+ model: nn.Module,
+ data_loader: list,
+ sync=True,
+ progress_bar=False,
+) -> None:
+ import copy
+
+ import torch.nn.functional as F
+ from efficientvit.apps.utils import AverageMeter, is_master, sync_tensor
+ from efficientvit.models.utils import get_device, list_join
+ from tqdm import tqdm
+
+ bn_mean = {}
+ bn_var = {}
+
+ tmp_model = copy.deepcopy(model)
+ for name, m in tmp_model.named_modules():
+ if isinstance(m, _BatchNorm):
+ bn_mean[name] = AverageMeter(is_distributed=False)
+ bn_var[name] = AverageMeter(is_distributed=False)
+
+ def new_forward(bn, mean_est, var_est):
+ def lambda_forward(x):
+ x = x.contiguous()
+ if sync:
+ batch_mean = x.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) # 1, C, 1, 1
+ batch_mean = sync_tensor(batch_mean, reduce="cat")
+ batch_mean = torch.mean(batch_mean, dim=0, keepdim=True)
+
+ batch_var = (x - batch_mean) * (x - batch_mean)
+ batch_var = batch_var.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
+ batch_var = sync_tensor(batch_var, reduce="cat")
+ batch_var = torch.mean(batch_var, dim=0, keepdim=True)
+ else:
+ batch_mean = x.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) # 1, C, 1, 1
+ batch_var = (x - batch_mean) * (x - batch_mean)
+ batch_var = batch_var.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
+
+ batch_mean = torch.squeeze(batch_mean)
+ batch_var = torch.squeeze(batch_var)
+
+ mean_est.update(batch_mean.data, x.size(0))
+ var_est.update(batch_var.data, x.size(0))
+
+ # bn forward using calculated mean & var
+ _feature_dim = batch_mean.shape[0]
+ return F.batch_norm(
+ x,
+ batch_mean,
+ batch_var,
+ bn.weight[:_feature_dim],
+ bn.bias[:_feature_dim],
+ False,
+ 0.0,
+ bn.eps,
+ )
+
+ return lambda_forward
+
+ m.forward = new_forward(m, bn_mean[name], bn_var[name])
+
+ # skip if there is no batch normalization layers in the network
+ if len(bn_mean) == 0:
+ return
+
+ tmp_model.eval()
+ with torch.no_grad():
+ with tqdm(total=len(data_loader), desc="reset bn", disable=not progress_bar or not is_master()) as t:
+ for images in data_loader:
+ images = images.to(get_device(tmp_model))
+ tmp_model(images)
+ t.set_postfix(
+ {
+ "bs": images.size(0),
+ "res": list_join(images.shape[-2:], "x"),
+ }
+ )
+ t.update()
+
+ for name, m in model.named_modules():
+ if name in bn_mean and bn_mean[name].count > 0:
+ feature_dim = bn_mean[name].avg.size(0)
+ assert isinstance(m, _BatchNorm)
+ m.running_mean.data[:feature_dim].copy_(bn_mean[name].avg)
+ m.running_var.data[:feature_dim].copy_(bn_var[name].avg)
+
+
+def set_norm_eps(model: nn.Module, eps: Optional[float] = None) -> None:
+ for m in model.modules():
+ if isinstance(m, (nn.GroupNorm, nn.LayerNorm, _BatchNorm)):
+ if eps is not None:
+ m.eps = eps
diff --git a/diffusion/model/dc_ae/efficientvit/models/nn/ops.py b/diffusion/model/dc_ae/efficientvit/models/nn/ops.py
new file mode 100644
index 0000000..66fb682
--- /dev/null
+++ b/diffusion/model/dc_ae/efficientvit/models/nn/ops.py
@@ -0,0 +1,835 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...models.nn.act import build_act
+from ...models.nn.norm import build_norm
+from ...models.utils import get_same_padding, list_sum, resize, val2list, val2tuple
+
+__all__ = [
+ "ConvLayer",
+ "UpSampleLayer",
+ "ConvPixelUnshuffleDownSampleLayer",
+ "PixelUnshuffleChannelAveragingDownSampleLayer",
+ "ConvPixelShuffleUpSampleLayer",
+ "ChannelDuplicatingPixelUnshuffleUpSampleLayer",
+ "LinearLayer",
+ "IdentityLayer",
+ "DSConv",
+ "MBConv",
+ "FusedMBConv",
+ "ResBlock",
+ "LiteMLA",
+ "EfficientViTBlock",
+ "ResidualBlock",
+ "DAGBlock",
+ "OpSequential",
+]
+
+
+#################################################################################
+# Basic Layers #
+#################################################################################
+
+
+class ConvLayer(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size=3,
+ stride=1,
+ dilation=1,
+ groups=1,
+ use_bias=False,
+ dropout=0,
+ norm="bn2d",
+ act_func="relu",
+ ):
+ super().__init__()
+
+ padding = get_same_padding(kernel_size)
+ padding *= dilation
+
+ self.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else None
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=(kernel_size, kernel_size),
+ stride=(stride, stride),
+ padding=padding,
+ dilation=(dilation, dilation),
+ groups=groups,
+ bias=use_bias,
+ )
+ self.norm = build_norm(norm, num_features=out_channels)
+ self.act = build_act(act_func)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.dropout is not None:
+ x = self.dropout(x)
+ x = self.conv(x)
+ if self.norm:
+ x = self.norm(x)
+ if self.act:
+ x = self.act(x)
+ return x
+
+
+class UpSampleLayer(nn.Module):
+ def __init__(
+ self,
+ mode="bicubic",
+ size: Optional[int | tuple[int, int] | list[int]] = None,
+ factor=2,
+ align_corners=False,
+ ):
+ super().__init__()
+ self.mode = mode
+ self.size = val2list(size, 2) if size is not None else None
+ self.factor = None if self.size is not None else factor
+ self.align_corners = align_corners
+
+ @torch.autocast(device_type="cuda", enabled=False)
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if (self.size is not None and tuple(x.shape[-2:]) == self.size) or self.factor == 1:
+ return x
+ if x.dtype in [torch.float16, torch.bfloat16]:
+ x = x.float()
+ return resize(x, self.size, self.factor, self.mode, self.align_corners)
+
+
+class ConvPixelUnshuffleDownSampleLayer(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ factor: int,
+ ):
+ super().__init__()
+ self.factor = factor
+ out_ratio = factor**2
+ assert out_channels % out_ratio == 0
+ self.conv = ConvLayer(
+ in_channels=in_channels,
+ out_channels=out_channels // out_ratio,
+ kernel_size=kernel_size,
+ use_bias=True,
+ norm=None,
+ act_func=None,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.conv(x)
+ x = F.pixel_unshuffle(x, self.factor)
+ return x
+
+
+class PixelUnshuffleChannelAveragingDownSampleLayer(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ factor: int,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.factor = factor
+ assert in_channels * factor**2 % out_channels == 0
+ self.group_size = in_channels * factor**2 // out_channels
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = F.pixel_unshuffle(x, self.factor)
+ B, C, H, W = x.shape
+ x = x.view(B, self.out_channels, self.group_size, H, W)
+ x = x.mean(dim=2)
+ return x
+
+
+class ConvPixelShuffleUpSampleLayer(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ factor: int,
+ ):
+ super().__init__()
+ self.factor = factor
+ out_ratio = factor**2
+ self.conv = ConvLayer(
+ in_channels=in_channels,
+ out_channels=out_channels * out_ratio,
+ kernel_size=kernel_size,
+ use_bias=True,
+ norm=None,
+ act_func=None,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.conv(x)
+ x = F.pixel_shuffle(x, self.factor)
+ return x
+
+
+class InterpolateConvUpSampleLayer(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ factor: int,
+ mode: str = "nearest",
+ ) -> None:
+ super().__init__()
+ self.factor = factor
+ self.mode = mode
+ self.conv = ConvLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ use_bias=True,
+ norm=None,
+ act_func=None,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = torch.nn.functional.interpolate(x, scale_factor=self.factor, mode=self.mode)
+ x = self.conv(x)
+ return x
+
+
+class ChannelDuplicatingPixelUnshuffleUpSampleLayer(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ factor: int,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.factor = factor
+ assert out_channels * factor**2 % in_channels == 0
+ self.repeats = out_channels * factor**2 // in_channels
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = x.repeat_interleave(self.repeats, dim=1)
+ x = F.pixel_shuffle(x, self.factor)
+ return x
+
+
+class LinearLayer(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ use_bias=True,
+ dropout=0,
+ norm=None,
+ act_func=None,
+ ):
+ super().__init__()
+
+ self.dropout = nn.Dropout(dropout, inplace=False) if dropout > 0 else None
+ self.linear = nn.Linear(in_features, out_features, use_bias)
+ self.norm = build_norm(norm, num_features=out_features)
+ self.act = build_act(act_func)
+
+ def _try_squeeze(self, x: torch.Tensor) -> torch.Tensor:
+ if x.dim() > 2:
+ x = torch.flatten(x, start_dim=1)
+ return x
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self._try_squeeze(x)
+ if self.dropout:
+ x = self.dropout(x)
+ x = self.linear(x)
+ if self.norm:
+ x = self.norm(x)
+ if self.act:
+ x = self.act(x)
+ return x
+
+
+class IdentityLayer(nn.Module):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return x
+
+
+#################################################################################
+# Basic Blocks #
+#################################################################################
+
+
+class DSConv(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size=3,
+ stride=1,
+ use_bias=False,
+ norm=("bn2d", "bn2d"),
+ act_func=("relu6", None),
+ ):
+ super().__init__()
+
+ use_bias = val2tuple(use_bias, 2)
+ norm = val2tuple(norm, 2)
+ act_func = val2tuple(act_func, 2)
+
+ self.depth_conv = ConvLayer(
+ in_channels,
+ in_channels,
+ kernel_size,
+ stride,
+ groups=in_channels,
+ norm=norm[0],
+ act_func=act_func[0],
+ use_bias=use_bias[0],
+ )
+ self.point_conv = ConvLayer(
+ in_channels,
+ out_channels,
+ 1,
+ norm=norm[1],
+ act_func=act_func[1],
+ use_bias=use_bias[1],
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.depth_conv(x)
+ x = self.point_conv(x)
+ return x
+
+
+class MBConv(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size=3,
+ stride=1,
+ mid_channels=None,
+ expand_ratio=6,
+ use_bias=False,
+ norm=("bn2d", "bn2d", "bn2d"),
+ act_func=("relu6", "relu6", None),
+ ):
+ super().__init__()
+
+ use_bias = val2tuple(use_bias, 3)
+ norm = val2tuple(norm, 3)
+ act_func = val2tuple(act_func, 3)
+ mid_channels = round(in_channels * expand_ratio) if mid_channels is None else mid_channels
+
+ self.inverted_conv = ConvLayer(
+ in_channels,
+ mid_channels,
+ 1,
+ stride=1,
+ norm=norm[0],
+ act_func=act_func[0],
+ use_bias=use_bias[0],
+ )
+ self.depth_conv = ConvLayer(
+ mid_channels,
+ mid_channels,
+ kernel_size,
+ stride=stride,
+ groups=mid_channels,
+ norm=norm[1],
+ act_func=act_func[1],
+ use_bias=use_bias[1],
+ )
+ self.point_conv = ConvLayer(
+ mid_channels,
+ out_channels,
+ 1,
+ norm=norm[2],
+ act_func=act_func[2],
+ use_bias=use_bias[2],
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.inverted_conv(x)
+ x = self.depth_conv(x)
+ x = self.point_conv(x)
+ return x
+
+
+class FusedMBConv(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size=3,
+ stride=1,
+ mid_channels=None,
+ expand_ratio=6,
+ groups=1,
+ use_bias=False,
+ norm=("bn2d", "bn2d"),
+ act_func=("relu6", None),
+ ):
+ super().__init__()
+ use_bias = val2tuple(use_bias, 2)
+ norm = val2tuple(norm, 2)
+ act_func = val2tuple(act_func, 2)
+
+ mid_channels = round(in_channels * expand_ratio) if mid_channels is None else mid_channels
+
+ self.spatial_conv = ConvLayer(
+ in_channels,
+ mid_channels,
+ kernel_size,
+ stride,
+ groups=groups,
+ use_bias=use_bias[0],
+ norm=norm[0],
+ act_func=act_func[0],
+ )
+ self.point_conv = ConvLayer(
+ mid_channels,
+ out_channels,
+ 1,
+ use_bias=use_bias[1],
+ norm=norm[1],
+ act_func=act_func[1],
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.spatial_conv(x)
+ x = self.point_conv(x)
+ return x
+
+
+class GLUMBConv(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size=3,
+ stride=1,
+ mid_channels=None,
+ expand_ratio=6,
+ use_bias=False,
+ norm=(None, None, "ln2d"),
+ act_func=("silu", "silu", None),
+ ):
+ super().__init__()
+ use_bias = val2tuple(use_bias, 3)
+ norm = val2tuple(norm, 3)
+ act_func = val2tuple(act_func, 3)
+
+ mid_channels = round(in_channels * expand_ratio) if mid_channels is None else mid_channels
+
+ self.glu_act = build_act(act_func[1], inplace=False)
+ self.inverted_conv = ConvLayer(
+ in_channels,
+ mid_channels * 2,
+ 1,
+ use_bias=use_bias[0],
+ norm=norm[0],
+ act_func=act_func[0],
+ )
+ self.depth_conv = ConvLayer(
+ mid_channels * 2,
+ mid_channels * 2,
+ kernel_size,
+ stride=stride,
+ groups=mid_channels * 2,
+ use_bias=use_bias[1],
+ norm=norm[1],
+ act_func=None,
+ )
+ self.point_conv = ConvLayer(
+ mid_channels,
+ out_channels,
+ 1,
+ use_bias=use_bias[2],
+ norm=norm[2],
+ act_func=act_func[2],
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.inverted_conv(x)
+ x = self.depth_conv(x)
+
+ x, gate = torch.chunk(x, 2, dim=1)
+ gate = self.glu_act(gate)
+ x = x * gate
+
+ x = self.point_conv(x)
+ return x
+
+
+class ResBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size=3,
+ stride=1,
+ mid_channels=None,
+ expand_ratio=1,
+ use_bias=False,
+ norm=("bn2d", "bn2d"),
+ act_func=("relu6", None),
+ ):
+ super().__init__()
+ use_bias = val2tuple(use_bias, 2)
+ norm = val2tuple(norm, 2)
+ act_func = val2tuple(act_func, 2)
+
+ mid_channels = round(in_channels * expand_ratio) if mid_channels is None else mid_channels
+
+ self.conv1 = ConvLayer(
+ in_channels,
+ mid_channels,
+ kernel_size,
+ stride,
+ use_bias=use_bias[0],
+ norm=norm[0],
+ act_func=act_func[0],
+ )
+ self.conv2 = ConvLayer(
+ mid_channels,
+ out_channels,
+ kernel_size,
+ 1,
+ use_bias=use_bias[1],
+ norm=norm[1],
+ act_func=act_func[1],
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.conv1(x)
+ x = self.conv2(x)
+ return x
+
+
+class LiteMLA(nn.Module):
+ r"""Lightweight multi-scale linear attention"""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ heads: Optional[int] = None,
+ heads_ratio: float = 1.0,
+ dim=8,
+ use_bias=False,
+ norm=(None, "bn2d"),
+ act_func=(None, None),
+ kernel_func="relu",
+ scales: tuple[int, ...] = (5,),
+ eps=1.0e-15,
+ ):
+ super().__init__()
+ self.eps = eps
+ heads = int(in_channels // dim * heads_ratio) if heads is None else heads
+
+ total_dim = heads * dim
+
+ use_bias = val2tuple(use_bias, 2)
+ norm = val2tuple(norm, 2)
+ act_func = val2tuple(act_func, 2)
+
+ self.dim = dim
+ self.qkv = ConvLayer(
+ in_channels,
+ 3 * total_dim,
+ 1,
+ use_bias=use_bias[0],
+ norm=norm[0],
+ act_func=act_func[0],
+ )
+ self.aggreg = nn.ModuleList(
+ [
+ nn.Sequential(
+ nn.Conv2d(
+ 3 * total_dim,
+ 3 * total_dim,
+ scale,
+ padding=get_same_padding(scale),
+ groups=3 * total_dim,
+ bias=use_bias[0],
+ ),
+ nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]),
+ )
+ for scale in scales
+ ]
+ )
+ self.kernel_func = build_act(kernel_func, inplace=False)
+
+ self.proj = ConvLayer(
+ total_dim * (1 + len(scales)),
+ out_channels,
+ 1,
+ use_bias=use_bias[1],
+ norm=norm[1],
+ act_func=act_func[1],
+ )
+
+ @torch.autocast(device_type="cuda", enabled=False)
+ def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:
+ B, _, H, W = list(qkv.size())
+
+ if qkv.dtype == torch.float16:
+ qkv = qkv.float()
+
+ qkv = torch.reshape(
+ qkv,
+ (
+ B,
+ -1,
+ 3 * self.dim,
+ H * W,
+ ),
+ )
+ q, k, v = (
+ qkv[:, :, 0 : self.dim],
+ qkv[:, :, self.dim : 2 * self.dim],
+ qkv[:, :, 2 * self.dim :],
+ )
+
+ # lightweight linear attention
+ q = self.kernel_func(q)
+ k = self.kernel_func(k)
+
+ # linear matmul
+ trans_k = k.transpose(-1, -2)
+
+ v = F.pad(v, (0, 0, 0, 1), mode="constant", value=1)
+ vk = torch.matmul(v, trans_k)
+ out = torch.matmul(vk, q)
+ if out.dtype == torch.bfloat16:
+ out = out.float()
+ out = out[:, :, :-1] / (out[:, :, -1:] + self.eps)
+
+ out = torch.reshape(out, (B, -1, H, W))
+ return out
+
+ @torch.autocast(device_type="cuda", enabled=False)
+ def relu_quadratic_att(self, qkv: torch.Tensor) -> torch.Tensor:
+ B, _, H, W = list(qkv.size())
+
+ qkv = torch.reshape(
+ qkv,
+ (
+ B,
+ -1,
+ 3 * self.dim,
+ H * W,
+ ),
+ )
+ q, k, v = (
+ qkv[:, :, 0 : self.dim],
+ qkv[:, :, self.dim : 2 * self.dim],
+ qkv[:, :, 2 * self.dim :],
+ )
+
+ q = self.kernel_func(q)
+ k = self.kernel_func(k)
+
+ att_map = torch.matmul(k.transpose(-1, -2), q) # b h n n
+ original_dtype = att_map.dtype
+ if original_dtype in [torch.float16, torch.bfloat16]:
+ att_map = att_map.float()
+ att_map = att_map / (torch.sum(att_map, dim=2, keepdim=True) + self.eps) # b h n n
+ att_map = att_map.to(original_dtype)
+ out = torch.matmul(v, att_map) # b h d n
+
+ out = torch.reshape(out, (B, -1, H, W))
+ return out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # generate multi-scale q, k, v
+ qkv = self.qkv(x)
+ multi_scale_qkv = [qkv]
+ for op in self.aggreg:
+ multi_scale_qkv.append(op(qkv))
+ qkv = torch.cat(multi_scale_qkv, dim=1)
+
+ H, W = list(qkv.size())[-2:]
+ if H * W > self.dim:
+ out = self.relu_linear_att(qkv).to(qkv.dtype)
+ else:
+ out = self.relu_quadratic_att(qkv)
+ out = self.proj(out)
+
+ return out
+
+
+class EfficientViTBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ heads_ratio: float = 1.0,
+ dim=32,
+ expand_ratio: float = 4,
+ scales: tuple[int, ...] = (5,),
+ norm: str = "bn2d",
+ act_func: str = "hswish",
+ context_module: str = "LiteMLA",
+ local_module: str = "MBConv",
+ ):
+ super().__init__()
+ if context_module == "LiteMLA":
+ self.context_module = ResidualBlock(
+ LiteMLA(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ heads_ratio=heads_ratio,
+ dim=dim,
+ norm=(None, norm),
+ scales=scales,
+ ),
+ IdentityLayer(),
+ )
+ else:
+ raise ValueError(f"context_module {context_module} is not supported")
+ if local_module == "MBConv":
+ self.local_module = ResidualBlock(
+ MBConv(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ expand_ratio=expand_ratio,
+ use_bias=(True, True, False),
+ norm=(None, None, norm),
+ act_func=(act_func, act_func, None),
+ ),
+ IdentityLayer(),
+ )
+ elif local_module == "GLUMBConv":
+ self.local_module = ResidualBlock(
+ GLUMBConv(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ expand_ratio=expand_ratio,
+ use_bias=(True, True, False),
+ norm=(None, None, norm),
+ act_func=(act_func, act_func, None),
+ ),
+ IdentityLayer(),
+ )
+ else:
+ raise NotImplementedError(f"local_module {local_module} is not supported")
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.context_module(x)
+ x = self.local_module(x)
+ return x
+
+
+#################################################################################
+# Functional Blocks #
+#################################################################################
+
+
+class ResidualBlock(nn.Module):
+ def __init__(
+ self,
+ main: Optional[nn.Module],
+ shortcut: Optional[nn.Module],
+ post_act=None,
+ pre_norm: Optional[nn.Module] = None,
+ ):
+ super().__init__()
+
+ self.pre_norm = pre_norm
+ self.main = main
+ self.shortcut = shortcut
+ self.post_act = build_act(post_act)
+
+ def forward_main(self, x: torch.Tensor) -> torch.Tensor:
+ if self.pre_norm is None:
+ return self.main(x)
+ else:
+ return self.main(self.pre_norm(x))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.main is None:
+ res = x
+ elif self.shortcut is None:
+ res = self.forward_main(x)
+ else:
+ res = self.forward_main(x) + self.shortcut(x)
+ if self.post_act:
+ res = self.post_act(res)
+ return res
+
+
+class DAGBlock(nn.Module):
+ def __init__(
+ self,
+ inputs: dict[str, nn.Module],
+ merge: str,
+ post_input: Optional[nn.Module],
+ middle: nn.Module,
+ outputs: dict[str, nn.Module],
+ ):
+ super().__init__()
+
+ self.input_keys = list(inputs.keys())
+ self.input_ops = nn.ModuleList(list(inputs.values()))
+ self.merge = merge
+ self.post_input = post_input
+
+ self.middle = middle
+
+ self.output_keys = list(outputs.keys())
+ self.output_ops = nn.ModuleList(list(outputs.values()))
+
+ def forward(self, feature_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ feat = [op(feature_dict[key]) for key, op in zip(self.input_keys, self.input_ops)]
+ if self.merge == "add":
+ feat = list_sum(feat)
+ elif self.merge == "cat":
+ feat = torch.concat(feat, dim=1)
+ else:
+ raise NotImplementedError
+ if self.post_input is not None:
+ feat = self.post_input(feat)
+ feat = self.middle(feat)
+ for key, op in zip(self.output_keys, self.output_ops):
+ feature_dict[key] = op(feat)
+ return feature_dict
+
+
+class OpSequential(nn.Module):
+ def __init__(self, op_list: list[Optional[nn.Module]]):
+ super().__init__()
+ valid_op_list = []
+ for op in op_list:
+ if op is not None:
+ valid_op_list.append(op)
+ self.op_list = nn.ModuleList(valid_op_list)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ for op in self.op_list:
+ x = op(x)
+ return x
diff --git a/diffusion/model/dc_ae/efficientvit/models/nn/triton_rms_norm.py b/diffusion/model/dc_ae/efficientvit/models/nn/triton_rms_norm.py
new file mode 100644
index 0000000..6f559b5
--- /dev/null
+++ b/diffusion/model/dc_ae/efficientvit/models/nn/triton_rms_norm.py
@@ -0,0 +1,207 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+__all__ = ["TritonRMSNorm2dFunc"]
+
+
+@triton.jit
+def _rms_norm_2d_fwd_fused(
+ X, # pointer to the input
+ Y, # pointer to the output
+ W, # pointer to the weights
+ B, # pointer to the biases
+ Rrms, # pointer to the 1/rms
+ M,
+ C,
+ N,
+ num_blocks, # number of columns in X
+ eps, # epsilon to avoid division by zero
+ BLOCK_SIZE: tl.constexpr,
+):
+ # Map the program id to the row of X and Y it should compute.
+ m_n = tl.program_id(0)
+ m, n = m_n // num_blocks, m_n % num_blocks
+
+ Y += m * C * N
+ X += m * C * N
+ # Compute mean
+
+ cols = n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ mask = cols < N
+
+ x_sum_square = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+ for off in range(0, C):
+ x = tl.load(X + off * N + cols, mask=mask, other=0.0).to(tl.float32)
+ x_sum_square += x * x
+ mean_square = x_sum_square / C
+ rrms = 1 / tl.sqrt(mean_square + eps)
+ # Write rstd
+ tl.store(Rrms + m * N + cols, rrms, mask=mask)
+ # Normalize and apply linear transformation
+ for off in range(0, C):
+ pos = off * N + cols
+ w = tl.load(W + off)
+ b = tl.load(B + off)
+ x = tl.load(X + pos, mask=mask, other=0.0).to(tl.float32)
+ x_hat = x * rrms
+ y = x_hat * w + b
+ # Write output
+ tl.store(Y + pos, y, mask=mask)
+
+
+@triton.jit
+def _rms_norm_2d_bwd_dx_fused(
+ DX, # pointer to the input gradient
+ DY, # pointer to the output gradient
+ DW, # pointer to the partial sum of weights gradient
+ DB, # pointer to the partial sum of biases gradient
+ X, # pointer to the input
+ W, # pointer to the weights
+ B, # pointer to the biases
+ Rrms, # pointer to the 1/rms
+ M,
+ C,
+ N, # number of columns in X
+ num_blocks,
+ eps, # epsilon to avoid division by zero
+ GROUP_SIZE_M: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr,
+ BLOCK_SIZE_C: tl.constexpr,
+):
+ # Map the program id to the elements of X, DX, and DY it should compute.
+ m_n = tl.program_id(0)
+ m, n = m_n // num_blocks, m_n % num_blocks
+ X += m * C * N
+ DY += m * C * N
+ DX += m * C * N
+ Rrms += m * N
+
+ cols = n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ mask = cols < N
+ # Offset locks and weights/biases gradient pointer for parallel reduction
+ DW = DW + m_n * C
+ DB = DB + m_n * C
+ rrms = tl.load(Rrms + cols, mask=mask, other=1)
+ # Load data to SRAM
+ c1 = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+ for off in range(0, C):
+ pos = off * N + cols
+ x = tl.load(X + pos, mask=mask, other=0).to(tl.float32)
+ dy = tl.load(DY + pos, mask=mask, other=0).to(tl.float32)
+ w = tl.load(W + off).to(tl.float32)
+ # Compute dx
+ xhat = x * rrms
+ wdy = w * dy
+ xhat = tl.where(mask, xhat, 0.0)
+ wdy = tl.where(mask, wdy, 0.0)
+ c1 += xhat * wdy
+ # Accumulate partial sums for dw/db
+ tl.store(DW + off, tl.sum((dy * xhat).to(w.dtype), axis=0))
+ tl.store(DB + off, tl.sum(dy.to(w.dtype), axis=0))
+
+ c1 /= C
+ for off in range(0, C):
+ pos = off * N + cols
+ x = tl.load(X + pos, mask=mask, other=0).to(tl.float32)
+ dy = tl.load(DY + pos, mask=mask, other=0).to(tl.float32)
+ w = tl.load(W + off).to(tl.float32)
+ xhat = x * rrms
+ wdy = w * dy
+ dx = (wdy - (xhat * c1)) * rrms
+ # Write dx
+ tl.store(DX + pos, dx, mask=mask)
+
+
+class TritonRMSNorm2dFunc(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, weight, bias, eps):
+ # allocate output
+ y = torch.empty_like(x)
+ # reshape input data into 2D tensor
+ x_arg = x.reshape(x.shape[0], x.shape[1], -1)
+ M, C, N = x_arg.shape
+ rrms = torch.empty((M, N), dtype=torch.float32, device="cuda")
+ # Less than 64KB per feature: enqueue fused kernel
+ BLOCK_SIZE = 256
+ num_blocks = triton.cdiv(N, BLOCK_SIZE)
+ num_warps = 8
+ # enqueue kernel
+ _rms_norm_2d_fwd_fused[(M * num_blocks,)]( #
+ x_arg,
+ y,
+ weight,
+ bias,
+ rrms, #
+ M,
+ C,
+ N,
+ num_blocks,
+ eps, #
+ BLOCK_SIZE=BLOCK_SIZE,
+ num_warps=num_warps,
+ num_ctas=1,
+ )
+ ctx.save_for_backward(x, weight, bias, rrms)
+ ctx.BLOCK_SIZE = BLOCK_SIZE
+ ctx.num_blocks = num_blocks
+ ctx.num_warps = num_warps
+ ctx.eps = eps
+ return y
+
+ @staticmethod
+ def backward(ctx, dy):
+ x, w, b, rrms = ctx.saved_tensors
+ num_blocks = ctx.num_blocks
+
+ x_arg = x.reshape(x.shape[0], x.shape[1], -1)
+ M, C, N = x_arg.shape
+ # GROUP_SIZE_M = 64
+ GROUP_SIZE_M = M * num_blocks
+ # allocate output
+ _dw = torch.empty((GROUP_SIZE_M, C), dtype=x.dtype, device=w.device)
+ _db = torch.empty((GROUP_SIZE_M, C), dtype=x.dtype, device=w.device)
+ dw = torch.empty((C,), dtype=w.dtype, device=w.device)
+ db = torch.empty((C,), dtype=w.dtype, device=w.device)
+ dx = torch.empty_like(dy)
+ # enqueue kernel using forward pass heuristics
+ # also compute partial sums for DW and DB
+ # print(f"M={M}, num_blocks={num_blocks}, dx={dx.shape}, dy={dy.shape}, _dw={_dw.shape}, _db={_db.shape}, x={x.shape}, w={w.shape}, b={b.shape}, m={m.shape}, v={v.shape}, M={M}, C={C}, N={N}")
+ _rms_norm_2d_bwd_dx_fused[(M * num_blocks,)]( #
+ dx,
+ dy,
+ _dw,
+ _db,
+ x,
+ w,
+ b,
+ rrms, #
+ M,
+ C,
+ N,
+ num_blocks,
+ ctx.eps, #
+ BLOCK_SIZE=ctx.BLOCK_SIZE,
+ GROUP_SIZE_M=GROUP_SIZE_M, #
+ BLOCK_SIZE_C=triton.next_power_of_2(C),
+ num_warps=ctx.num_warps,
+ )
+ dw = _dw.sum(dim=0)
+ db = _db.sum(dim=0)
+ return dx, dw, db, None
diff --git a/diffusion/model/dc_ae/efficientvit/models/utils/__init__.py b/diffusion/model/dc_ae/efficientvit/models/utils/__init__.py
new file mode 100644
index 0000000..4155f95
--- /dev/null
+++ b/diffusion/model/dc_ae/efficientvit/models/utils/__init__.py
@@ -0,0 +1,3 @@
+from .list import *
+from .network import *
+from .random import *
diff --git a/diffusion/model/dc_ae/efficientvit/models/utils/list.py b/diffusion/model/dc_ae/efficientvit/models/utils/list.py
new file mode 100644
index 0000000..2dd8fc4
--- /dev/null
+++ b/diffusion/model/dc_ae/efficientvit/models/utils/list.py
@@ -0,0 +1,67 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Union
+
+__all__ = [
+ "list_sum",
+ "list_mean",
+ "weighted_list_sum",
+ "list_join",
+ "val2list",
+ "val2tuple",
+ "squeeze_list",
+]
+
+
+def list_sum(x: list) -> Any:
+ return x[0] if len(x) == 1 else x[0] + list_sum(x[1:])
+
+
+def list_mean(x: list) -> Any:
+ return list_sum(x) / len(x)
+
+
+def weighted_list_sum(x: list, weights: list) -> Any:
+ assert len(x) == len(weights)
+ return x[0] * weights[0] if len(x) == 1 else x[0] * weights[0] + weighted_list_sum(x[1:], weights[1:])
+
+
+def list_join(x: list, sep="\t", format_str="%s") -> str:
+ return sep.join([format_str % val for val in x])
+
+
+def val2list(x: Union[list, tuple, Any], repeat_time=1) -> list:
+ if isinstance(x, (list, tuple)):
+ return list(x)
+ return [x for _ in range(repeat_time)]
+
+
+def val2tuple(x: Union[list, tuple, Any], min_len: int = 1, idx_repeat: int = -1) -> tuple:
+ x = val2list(x)
+
+ # repeat elements if necessary
+ if len(x) > 0:
+ x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
+
+ return tuple(x)
+
+
+def squeeze_list(x: Optional[list]) -> Union[list, Any]:
+ if x is not None and len(x) == 1:
+ return x[0]
+ else:
+ return x
diff --git a/diffusion/model/dc_ae/efficientvit/models/utils/network.py b/diffusion/model/dc_ae/efficientvit/models/utils/network.py
new file mode 100644
index 0000000..0e23374
--- /dev/null
+++ b/diffusion/model/dc_ae/efficientvit/models/utils/network.py
@@ -0,0 +1,111 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import collections
+import os
+from inspect import signature
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = [
+ "is_parallel",
+ "get_device",
+ "get_same_padding",
+ "resize",
+ "build_kwargs_from_config",
+ "load_state_dict_from_file",
+ "get_submodule_weights",
+]
+
+
+def is_parallel(model: nn.Module) -> bool:
+ return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))
+
+
+def get_device(model: nn.Module) -> torch.device:
+ return model.parameters().__next__().device
+
+
+def get_dtype(model: nn.Module) -> torch.dtype:
+ return model.parameters().__next__().dtype
+
+
+def get_same_padding(kernel_size: Union[int, tuple[int, ...]]) -> Union[int, tuple[int, ...]]:
+ if isinstance(kernel_size, tuple):
+ return tuple([get_same_padding(ks) for ks in kernel_size])
+ else:
+ assert kernel_size % 2 > 0, "kernel size should be odd number"
+ return kernel_size // 2
+
+
+def resize(
+ x: torch.Tensor,
+ size: Optional[Any] = None,
+ scale_factor: Optional[list[float]] = None,
+ mode: str = "bicubic",
+ align_corners: Optional[bool] = False,
+) -> torch.Tensor:
+ if mode in {"bilinear", "bicubic"}:
+ return F.interpolate(
+ x,
+ size=size,
+ scale_factor=scale_factor,
+ mode=mode,
+ align_corners=align_corners,
+ )
+ elif mode in {"nearest", "area"}:
+ return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode)
+ else:
+ raise NotImplementedError(f"resize(mode={mode}) not implemented.")
+
+
+def build_kwargs_from_config(config: dict, target_func: Callable) -> dict[str, Any]:
+ valid_keys = list(signature(target_func).parameters)
+ kwargs = {}
+ for key in config:
+ if key in valid_keys:
+ kwargs[key] = config[key]
+ return kwargs
+
+
+def load_state_dict_from_file(file: str, only_state_dict=True) -> dict[str, torch.Tensor]:
+ file = os.path.realpath(os.path.expanduser(file))
+ checkpoint = torch.load(file, map_location="cpu", weights_only=True)
+ if only_state_dict and "state_dict" in checkpoint:
+ checkpoint = checkpoint["state_dict"]
+ return checkpoint
+
+
+def get_submodule_weights(weights: collections.OrderedDict, prefix: str):
+ submodule_weights = collections.OrderedDict()
+ len_prefix = len(prefix)
+ for key, weight in weights.items():
+ if key.startswith(prefix):
+ submodule_weights[key[len_prefix:]] = weight
+ return submodule_weights
+
+
+def get_dtype_from_str(dtype: str) -> torch.dtype:
+ if dtype == "fp32":
+ return torch.float32
+ if dtype == "fp16":
+ return torch.float16
+ if dtype == "bf16":
+ return torch.bfloat16
+ raise NotImplementedError(f"dtype {dtype} is not supported")
diff --git a/diffusion/model/dc_ae/efficientvit/models/utils/random.py b/diffusion/model/dc_ae/efficientvit/models/utils/random.py
new file mode 100644
index 0000000..26675a0
--- /dev/null
+++ b/diffusion/model/dc_ae/efficientvit/models/utils/random.py
@@ -0,0 +1,79 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Union
+
+import numpy as np
+import torch
+
+__all__ = [
+ "torch_randint",
+ "torch_random",
+ "torch_shuffle",
+ "torch_uniform",
+ "torch_random_choices",
+]
+
+
+def torch_randint(low: int, high: int, generator: Optional[torch.Generator] = None) -> int:
+ """uniform: [low, high)"""
+ if low == high:
+ return low
+ else:
+ assert low < high
+ return int(torch.randint(low=low, high=high, generator=generator, size=(1,)))
+
+
+def torch_random(generator: Optional[torch.Generator] = None) -> float:
+ """uniform distribution on the interval [0, 1)"""
+ return float(torch.rand(1, generator=generator))
+
+
+def torch_shuffle(src_list: list[Any], generator: Optional[torch.Generator] = None) -> list[Any]:
+ rand_indexes = torch.randperm(len(src_list), generator=generator).tolist()
+ return [src_list[i] for i in rand_indexes]
+
+
+def torch_uniform(low: float, high: float, generator: Optional[torch.Generator] = None) -> float:
+ """uniform distribution on the interval [low, high)"""
+ rand_val = torch_random(generator)
+ return (high - low) * rand_val + low
+
+
+def torch_random_choices(
+ src_list: list[Any],
+ generator: Optional[torch.Generator] = None,
+ k=1,
+ weight_list: Optional[list[float]] = None,
+) -> Union[Any, list]:
+ if weight_list is None:
+ rand_idx = torch.randint(low=0, high=len(src_list), generator=generator, size=(k,))
+ out_list = [src_list[i] for i in rand_idx]
+ else:
+ assert len(weight_list) == len(src_list)
+ accumulate_weight_list = np.cumsum(weight_list)
+
+ out_list = []
+ for _ in range(k):
+ val = torch_uniform(0, accumulate_weight_list[-1], generator)
+ active_id = 0
+ for i, weight_val in enumerate(accumulate_weight_list):
+ active_id = i
+ if weight_val > val:
+ break
+ out_list.append(src_list[active_id])
+
+ return out_list[0] if k == 1 else out_list
diff --git a/diffusion/model/nets/fastlinear/develop_triton_ffn.py b/diffusion/model/nets/fastlinear/develop_triton_ffn.py
new file mode 100644
index 0000000..285c3e6
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/develop_triton_ffn.py
@@ -0,0 +1,307 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import time
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import ipdb
+import torch
+from modules.mb_conv_pre_glu import MBConvPreGLU
+from modules.triton_mb_conv_pre_glu import TritonMBConvPreGLU
+from modules.utils.compare_results import compare_results
+from modules.utils.dtype import get_dtype_from_str
+from modules.utils.export_onnx import export_onnx
+from omegaconf import OmegaConf
+from torch import nn
+from torch.nn import functional as F
+from torchprofile import profile_macs
+
+
+@dataclass
+class DevelopTritonFFNConfig:
+ batch_size: int = 16
+ input_size: int = 1024 // 32 // 1
+ num_channels: int = 1152
+ mlp_ratio: float = 2.5
+ ffn_type: str = "MBConvPreGLU"
+ act: Tuple[Optional[str]] = ("silu", "silu", None)
+
+ device: str = "cuda"
+ dtype: str = "fp16"
+
+ profile_macs: bool = False
+ test_correctness: bool = False
+ warmup_iterations: int = 50
+ iterations: int = 1000
+ random_weight: bool = True
+ backward: bool = False
+ autocast: bool = False
+ use_cuda_graph: bool = False
+
+ export_model: bool = False
+ opset: int = 17
+ export_path: str = ""
+ export_dtype: str = "fp32"
+ export_device: str = "cuda"
+
+
+# def simulate_litemla(x: torch.Tensor, qkv_weight: torch.Tensor, proj_weight: torch.Tensor, proj_bias: torch.Tensor, num_heads: int, head_dim: int, eps: float, backward: bool):
+# B, N, C = x.shape
+# qkv = F.linear(x, qkv_weight).reshape(B, N, 3, C).permute(0, 2, 3, 1)
+# q, k, v = qkv.unbind(1) # B, 3, C, N --> B, C, N
+
+# q = q.reshape(B, C // head_dim, head_dim, N) # b, h, h_d, N
+# k = k.reshape(B, C // head_dim, head_dim, N).transpose(-1, -2) # b, h, N, h_d
+# v = v.reshape(B, C // head_dim, head_dim, N) # b, h, h_d, N
+
+# q = F.relu(q) # B, h, h_d, N
+# k = F.relu(k)
+
+# q, k, v = q.float(), k.float(), v.float()
+# if backward:
+# k.retain_grad()
+# v.retain_grad()
+# q.retain_grad()
+# v_pad = F.pad(v, (0, 0, 0, 1), mode="constant", value=1)
+# vk = torch.matmul(v_pad, k)
+# if backward:
+# vk.retain_grad()
+# vk_q = torch.matmul(vk, q)
+# vk_q_numerator, vk_q_denominator = vk_q[:, :, :-1], vk_q[:, :, -1:]
+# if backward:
+# vk_q_numerator.retain_grad()
+# vk_q_denominator.retain_grad()
+# vk_q_divide = (vk_q_numerator / (vk_q_denominator + eps)).to(x.dtype)
+
+# proj_input = vk_q_divide.view(B, C, N).permute(0, 2, 1) # B, N, C
+# if backward:
+# proj_input.retain_grad()
+# y = F.linear(proj_input, proj_weight, proj_bias)
+# output_dict = {
+# "q": q,
+# "k": k,
+# "v": v,
+# "vk": vk,
+# "proj_input": proj_input,
+# "vk_q_numerator": vk_q_numerator,
+# "vk_q_denominator": vk_q_denominator,
+# "vk_q_divide": vk_q_divide,
+# "y": y,
+# }
+# return output_dict
+
+
+def main():
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+ torch.cuda.manual_seed(0)
+ torch.manual_seed(0)
+
+ cfg = OmegaConf.structured(DevelopTritonFFNConfig)
+ cli_cfg = OmegaConf.from_cli()
+ cfg = OmegaConf.merge(cfg, OmegaConf.masked_copy(cli_cfg, cfg.keys()))
+ cfg: DevelopTritonFFNConfig = OmegaConf.to_object(cfg)
+
+ torch.set_grad_enabled(cfg.backward)
+
+ device = torch.device("cuda")
+ if cfg.autocast:
+ dtype = torch.float32
+ autocast_dtype = get_dtype_from_str(cfg.dtype)
+ else:
+ dtype = get_dtype_from_str(cfg.dtype)
+ autocast_dtype = None
+
+ print(cfg.ffn_type)
+ if cfg.ffn_type == "MBConvPreGLU":
+ block = MBConvPreGLU(
+ in_dim=cfg.num_channels,
+ out_dim=cfg.num_channels,
+ mid_dim=int(cfg.num_channels * cfg.mlp_ratio),
+ use_bias=(True, True, False),
+ norm=None,
+ act=cfg.act,
+ )
+ elif cfg.ffn_type == "TritonMBConvPreGLU":
+ block = TritonMBConvPreGLU(
+ in_dim=cfg.num_channels,
+ out_dim=cfg.num_channels,
+ mid_dim=int(cfg.num_channels * cfg.mlp_ratio),
+ use_bias=(True, True, False),
+ norm=None,
+ act=cfg.act,
+ )
+ else:
+ raise NotImplementedError
+
+ print(
+ f"bs: {cfg.batch_size}, ffn_type: {cfg.ffn_type}, mlp_ratio: {cfg.mlp_ratio}, latent_size: {cfg.input_size} X {cfg.input_size}"
+ )
+ print(f"MLP: {block.__class__.__name__}, MLP Parameters: {sum(p.numel() for p in block.parameters()) / 1e6:.2f}M")
+
+ if not cfg.backward:
+ block = block.eval()
+ block = block.to(device=device, dtype=dtype, memory_format=torch.channels_last)
+
+ if cfg.random_weight:
+ for param in block.parameters():
+ nn.init.trunc_normal_(param, std=0.001)
+
+ if cfg.profile_macs:
+ macs = profile_macs(block, x)
+ print(f"macs: {macs}")
+
+ if cfg.export_model:
+ export_dtype = get_dtype_from_str(cfg.export_dtype)
+ export_device = torch.device(cfg.export_device)
+ assert cfg.export_path != ""
+ export_onnx(
+ block.to(device=export_device, dtype=export_dtype),
+ (1, cfg.input_size**2, cfg.num_channels),
+ cfg.export_path,
+ cfg.opset,
+ export_dtype,
+ export_device,
+ )
+ elif cfg.test_correctness:
+ if cfg.ffn_type in ["MBConvPreGLU", "TritonMBConvPreGLU"]:
+ ref_block = (
+ MBConvPreGLU(
+ in_dim=cfg.num_channels,
+ out_dim=cfg.num_channels,
+ mid_dim=int(cfg.num_channels * cfg.mlp_ratio),
+ use_bias=(True, True, False),
+ norm=None,
+ act=cfg.act,
+ )
+ .eval()
+ .to(device=device, memory_format=torch.channels_last)
+ )
+ else:
+ raise NotImplementedError(f"ffn_type {cfg.ffn_type} is not supported")
+ block.load_state_dict(ref_block.state_dict())
+ correct = True
+ for i in range(10):
+ ref_x = torch.randn(
+ cfg.batch_size, cfg.input_size**2, cfg.num_channels, device=device, requires_grad=cfg.backward
+ )
+ x = ref_x.clone().detach().to(dtype=dtype).requires_grad_(cfg.backward)
+ with torch.autocast(device_type="cuda", dtype=autocast_dtype, enabled=cfg.autocast):
+ output = block(x)
+ ref_output = ref_block(ref_x)
+ if cfg.backward:
+ dy = 0.1 * torch.randn_like(output)
+ output.backward(dy)
+ ref_output.backward(dy.float())
+ output_float = output.float()
+ if not torch.allclose(output_float, ref_output):
+ correct = False
+ max_error_pos = (output_float - ref_output).abs().view(-1).argmax()
+ print(f"comparing forward results")
+ print(
+ f"max error: {(output_float - ref_output).abs().max()}, mean error: {(output_float - ref_output).abs().mean()}"
+ )
+ print(f"max error pos: {ref_output.view(-1)[max_error_pos]} {output_float.view(-1)[max_error_pos]}")
+ if cfg.backward:
+ for (name, param), (ref_name, ref_param) in zip(block.named_parameters(), ref_block.named_parameters()):
+ assert name == ref_name
+ compare_results(f"{name} grad", param.grad, ref_param.grad)
+ compare_results(f"x grad", x.grad, ref_x.grad)
+ if correct:
+ print("correct!")
+ elif cfg.use_cuda_graph:
+ x = torch.randn(
+ cfg.batch_size,
+ cfg.input_size**2,
+ cfg.num_channels,
+ device=device,
+ dtype=dtype,
+ requires_grad=cfg.backward,
+ )
+ grad_y = 0.1 * torch.randn_like(x)
+
+ s = torch.cuda.Stream()
+ s.wait_stream(torch.cuda.current_stream())
+ with torch.cuda.stream(s):
+ for i in range(cfg.warmup_iterations):
+ with torch.autocast(device_type="cuda", dtype=autocast_dtype, enabled=cfg.autocast):
+ y = block(x)
+ if cfg.backward:
+ y.backward(grad_y)
+ torch.cuda.current_stream().wait_stream(s)
+
+ g = torch.cuda.CUDAGraph()
+ # Sets grads to None before capture, so backward() will create
+ # .grad attributes with allocations from the graph's private pool
+ with torch.cuda.graph(g):
+ with torch.autocast(device_type="cuda", dtype=autocast_dtype, enabled=cfg.autocast):
+ y = block(x)
+ if cfg.backward:
+ y.backward(grad_y)
+
+ torch.cuda.synchronize()
+ start_time = time.time()
+ for i in range(cfg.iterations):
+ g.replay()
+ torch.cuda.synchronize()
+ end_time = time.time()
+ print(f"using cuda graph:")
+ print(f"each step takes {(end_time - start_time) * 1000 / cfg.iterations:.2f} ms")
+ print(f"max memory allocated: {torch.cuda.max_memory_allocated() / 1024 ** 3:.4f} GB\n{'-' * 80}")
+ else:
+ x = torch.randn(
+ cfg.batch_size,
+ cfg.input_size**2,
+ cfg.num_channels,
+ device=device,
+ dtype=dtype,
+ requires_grad=cfg.backward,
+ )
+ grad_y = 0.1 * torch.randn_like(x)
+ for i in range(cfg.warmup_iterations):
+ # ipdb.set_trace()
+ with torch.autocast(device_type="cuda", dtype=autocast_dtype, enabled=cfg.autocast):
+ y = block(x)
+ if cfg.backward:
+ y.backward(grad_y)
+
+ torch.cuda.synchronize()
+ start_time = time.time()
+ for i in range(cfg.iterations):
+ with torch.autocast(device_type="cuda", dtype=autocast_dtype, enabled=cfg.autocast):
+ y = block(x)
+ if cfg.backward:
+ y.backward(grad_y)
+ torch.cuda.synchronize()
+ end_time = time.time()
+ print(f"each step takes {(end_time - start_time) * 1000 / cfg.iterations:.2f} ms")
+ # ipdb.set_trace()
+ print(f"max memory allocated: {torch.cuda.max_memory_allocated() / 1024 ** 3:.4f} GB\n{'-' * 80}")
+
+
+if __name__ == "__main__":
+ main()
+
+"""
+# 64x64 fp16
+python -m develop_triton_ffn ffn_type=MBConvPreGLU test_correctness=True
+each step takes 12.45 ms
+max memory allocated: 1.8467 GB
+
+python -m develop_triton_ffn ffn_type=TritonMBConvPreGLU test_correctness=True
+
+"""
diff --git a/diffusion/model/nets/fastlinear/develop_triton_litemla.py b/diffusion/model/nets/fastlinear/develop_triton_litemla.py
new file mode 100644
index 0000000..cb62b34
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/develop_triton_litemla.py
@@ -0,0 +1,321 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import time
+from dataclasses import dataclass
+
+import ipdb
+import torch
+from modules.flash_attn import FlashAttention
+from modules.lite_mla import LiteMLA
+from modules.triton_lite_mla import TritonLiteMLA
+from modules.triton_lite_mla_fwd import TritonLiteMLAFwd
+from modules.utils.dtype import get_dtype_from_str
+from modules.utils.export_onnx import export_onnx
+from omegaconf import OmegaConf
+from torch import nn
+from torch.nn import functional as F
+from torchprofile import profile_macs
+
+
+@dataclass
+class DevelopTritonLiteMLAConfig:
+ batch_size: int = 16
+ input_size: int = 1024 // 8 // 2
+ num_channels: int = 1152
+ num_heads: int = 36
+ attn_type: str = "LiteMLA"
+
+ device: str = "cuda"
+ dtype: str = "fp16"
+
+ profile_macs: bool = False
+ test_correctness: bool = False
+ warmup_iterations: int = 50
+ iterations: int = 1000
+ random_weight: bool = True
+ backward: bool = False
+ autocast: bool = False
+ use_cuda_graph: bool = False
+
+ export_model: bool = False
+ opset: int = 17
+ export_path: str = ""
+ export_dtype: str = "fp32"
+ export_device: str = "cuda"
+
+
+def simulate_litemla(
+ x: torch.Tensor,
+ qkv_weight: torch.Tensor,
+ proj_weight: torch.Tensor,
+ proj_bias: torch.Tensor,
+ num_heads: int,
+ head_dim: int,
+ eps: float,
+ backward: bool,
+):
+ B, N, C = x.shape
+ qkv = F.linear(x, qkv_weight).reshape(B, N, 3, C).permute(0, 2, 3, 1)
+ q, k, v = qkv.unbind(1) # B, 3, C, N --> B, C, N
+
+ q = q.reshape(B, C // head_dim, head_dim, N) # b, h, h_d, N
+ k = k.reshape(B, C // head_dim, head_dim, N).transpose(-1, -2) # b, h, N, h_d
+ v = v.reshape(B, C // head_dim, head_dim, N) # b, h, h_d, N
+
+ q = F.relu(q) # B, h, h_d, N
+ k = F.relu(k)
+
+ q, k, v = q.float(), k.float(), v.float()
+ if backward:
+ k.retain_grad()
+ v.retain_grad()
+ q.retain_grad()
+ v_pad = F.pad(v, (0, 0, 0, 1), mode="constant", value=1)
+ vk = torch.matmul(v_pad, k)
+ if backward:
+ vk.retain_grad()
+ vk_q = torch.matmul(vk, q)
+ vk_q_numerator, vk_q_denominator = vk_q[:, :, :-1], vk_q[:, :, -1:]
+ if backward:
+ vk_q_numerator.retain_grad()
+ vk_q_denominator.retain_grad()
+ vk_q_divide = (vk_q_numerator / (vk_q_denominator + eps)).to(x.dtype)
+
+ proj_input = vk_q_divide.view(B, C, N).permute(0, 2, 1) # B, N, C
+ if backward:
+ proj_input.retain_grad()
+ y = F.linear(proj_input, proj_weight, proj_bias)
+ output_dict = {
+ "q": q,
+ "k": k,
+ "v": v,
+ "vk": vk,
+ "proj_input": proj_input,
+ "vk_q_numerator": vk_q_numerator,
+ "vk_q_denominator": vk_q_denominator,
+ "vk_q_divide": vk_q_divide,
+ "y": y,
+ }
+ return output_dict
+
+
+def main():
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+ LiteMLA.fp32_attention = True
+ torch.cuda.manual_seed(0)
+ torch.manual_seed(0)
+
+ cfg = OmegaConf.structured(DevelopTritonLiteMLAConfig)
+ cli_cfg = OmegaConf.from_cli()
+ cfg = OmegaConf.merge(cfg, OmegaConf.masked_copy(cli_cfg, cfg.keys()))
+ cfg: DevelopTritonLiteMLAConfig = OmegaConf.to_object(cfg)
+
+ torch.set_grad_enabled(cfg.backward)
+
+ device = torch.device("cuda")
+ if cfg.autocast:
+ dtype = torch.float32
+ autocast_dtype = get_dtype_from_str(cfg.dtype)
+ else:
+ dtype = get_dtype_from_str(cfg.dtype)
+ autocast_dtype = None
+
+ if cfg.attn_type == "LiteMLA":
+ block = LiteMLA(cfg.num_channels, cfg.num_channels, dim=cfg.num_channels // cfg.num_heads, eps=1e-8)
+ elif cfg.attn_type == "TritonLiteMLA":
+ block = TritonLiteMLA(cfg.num_channels, cfg.num_heads, eps=1e-8)
+ elif cfg.attn_type == "TritonLiteMLAFwd":
+ block = TritonLiteMLAFwd(cfg.num_channels, cfg.num_heads, eps=1e-8)
+ elif cfg.attn_type == "FlashAttention":
+ block = FlashAttention(cfg.num_channels, cfg.num_heads)
+ else:
+ raise NotImplementedError
+
+ if not cfg.backward:
+ block = block.eval()
+ block = block.to(device=device, dtype=dtype, memory_format=torch.channels_last)
+
+ if cfg.random_weight:
+ for param in block.parameters():
+ nn.init.trunc_normal_(param, std=0.001)
+
+ if cfg.profile_macs:
+ macs = profile_macs(block, x)
+ print(f"macs: {macs}")
+
+ if cfg.export_model:
+ export_dtype = get_dtype_from_str(cfg.export_dtype)
+ export_device = torch.device(cfg.export_device)
+ assert cfg.export_path != ""
+ export_onnx(
+ block.to(device=export_device, dtype=export_dtype),
+ (1, cfg.input_size**2, cfg.num_channels),
+ cfg.export_path,
+ cfg.opset,
+ export_dtype,
+ export_device,
+ )
+ if cfg.test_correctness:
+ ref_block = (
+ LiteMLA(cfg.num_channels, cfg.num_channels, dim=cfg.num_channels // cfg.num_heads, eps=1e-8)
+ .eval()
+ .to(device=device, memory_format=torch.channels_last)
+ )
+ block.load_state_dict(ref_block.state_dict())
+ correct = True
+ for i in range(10):
+ ref_x = torch.randn(
+ cfg.batch_size, cfg.input_size**2, cfg.num_channels, device=device, requires_grad=cfg.backward
+ )
+ x = ref_x.clone().detach().to(dtype=dtype).requires_grad_(cfg.backward)
+ with torch.autocast(device_type="cuda", dtype=autocast_dtype, enabled=cfg.autocast):
+ output = block(x)
+ ref_output_dict = simulate_litemla(
+ ref_x,
+ ref_block.qkv.weight,
+ ref_block.proj.weight,
+ ref_block.proj.bias,
+ ref_block.in_dim // ref_block.dim,
+ ref_block.dim,
+ ref_block.eps,
+ cfg.backward,
+ )
+ ref_output = ref_output_dict["y"]
+ if cfg.backward:
+ dy = 0.1 * torch.randn_like(output)
+ output.backward(dy)
+ ref_output.backward(dy.float())
+ # ipdb.set_trace()
+ ref_output_1 = ref_block(ref_x)
+ assert torch.allclose(ref_output, ref_output_1)
+ output_float = output.float()
+ if not torch.allclose(output_float, ref_output):
+ correct = False
+ max_error_pos = (output_float - ref_output).abs().view(-1).argmax()
+ print(f"comparing forward results")
+ print(
+ f"max error: {(output_float - ref_output).abs().max()}, mean error: {(output_float - ref_output).abs().mean()}"
+ )
+ print(f"max error pos: {ref_output.view(-1)[max_error_pos]} {output_float.view(-1)[max_error_pos]}")
+ if cfg.backward:
+ for name, grad, ref_grad in [
+ ("proj_weight", block.proj.weight.grad, ref_block.proj.weight.grad),
+ ("proj_bias", block.proj.bias.grad, ref_block.proj.bias.grad),
+ ("qkv_weight", block.qkv.weight.grad, ref_block.qkv.weight.grad),
+ ("x", x.grad, ref_x.grad),
+ ]:
+ print(f"comparing {name}")
+ grad_float = grad.float()
+ max_error_pos = (grad_float - ref_grad).abs().view(-1).argmax()
+ print(
+ f"max error: {(grad_float - ref_grad).abs().max()}, mean error: {(grad_float - ref_grad).abs().mean()}"
+ )
+ print(f"max error pos: {ref_grad.view(-1)[max_error_pos]} {grad_float.view(-1)[max_error_pos]}")
+ # ipdb.set_trace()
+ if correct:
+ print("correct!")
+ elif cfg.use_cuda_graph:
+ x = torch.randn(
+ cfg.batch_size,
+ cfg.input_size**2,
+ cfg.num_channels,
+ device=device,
+ dtype=dtype,
+ requires_grad=cfg.backward,
+ )
+ grad_y = 0.1 * torch.randn_like(x)
+
+ s = torch.cuda.Stream()
+ s.wait_stream(torch.cuda.current_stream())
+ with torch.cuda.stream(s):
+ for i in range(cfg.warmup_iterations):
+ with torch.autocast(device_type="cuda", dtype=autocast_dtype, enabled=cfg.autocast):
+ y = block(x)
+ if cfg.backward:
+ y.backward(grad_y)
+ torch.cuda.current_stream().wait_stream(s)
+
+ g = torch.cuda.CUDAGraph()
+ # Sets grads to None before capture, so backward() will create
+ # .grad attributes with allocations from the graph's private pool
+ with torch.cuda.graph(g):
+ with torch.autocast(device_type="cuda", dtype=autocast_dtype, enabled=cfg.autocast):
+ y = block(x)
+ if cfg.backward:
+ y.backward(grad_y)
+
+ torch.cuda.synchronize()
+ start_time = time.time()
+ for i in range(cfg.iterations):
+ g.replay()
+ torch.cuda.synchronize()
+ end_time = time.time()
+ print(f"using cuda graph:")
+ print(f"each step takes {(end_time-start_time)*1000/cfg.iterations:.2f} ms")
+ print(f"max memory allocated: {torch.cuda.max_memory_allocated()/1024**3:.4f} GB")
+ else:
+ x = torch.randn(
+ cfg.batch_size,
+ cfg.input_size**2,
+ cfg.num_channels,
+ device=device,
+ dtype=dtype,
+ requires_grad=cfg.backward,
+ )
+ grad_y = 0.1 * torch.randn_like(x)
+ for i in range(cfg.warmup_iterations):
+ # ipdb.set_trace()
+ with torch.autocast(device_type="cuda", dtype=autocast_dtype, enabled=cfg.autocast):
+ y = block(x)
+ if cfg.backward:
+ y.backward(grad_y)
+
+ torch.cuda.synchronize()
+ start_time = time.time()
+ for i in range(cfg.iterations):
+ with torch.autocast(device_type="cuda", dtype=autocast_dtype, enabled=cfg.autocast):
+ y = block(x)
+ if cfg.backward:
+ y.backward(grad_y)
+ torch.cuda.synchronize()
+ end_time = time.time()
+ print(f"each step takes {(end_time - start_time) * 1000 / cfg.iterations:.2f} ms")
+ # ipdb.set_trace()
+ print(f"max memory allocated: {torch.cuda.max_memory_allocated() / 1024 ** 3:.4f} GB")
+
+ # x = torch.randn(cfg.batch_size*2, (cfg.input_size*2)**2, cfg.num_channels, device=device, dtype=dtype, requires_grad=cfg.backward)
+ # grad_y = 0.1*torch.randn_like(x)
+ # with torch.autocast(device_type="cuda", dtype=autocast_dtype, enabled=cfg.autocast):
+ # y = block(x)
+ # if cfg.backward:
+ # y.backward(grad_y)
+
+
+if __name__ == "__main__":
+ main()
+
+"""
+# 64x64 fp16
+python -m develop_triton_litemla attn_type=LiteMLA test_correctness=True
+each step takes 10.81 ms
+max memory allocated: 2.2984 GB
+
+python -m develop_triton_litemla attn_type=TritonLiteMLA test_correctness=True
+each step takes 4.70 ms
+max memory allocated: 1.6480 GB
+"""
diff --git a/diffusion/model/nets/fastlinear/modules/__init__.py b/diffusion/model/nets/fastlinear/modules/__init__.py
new file mode 100644
index 0000000..5f4ed21
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/modules/__init__.py
@@ -0,0 +1,21 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from .triton_lite_mla import *
+from .triton_lite_mla_fwd import *
+from .triton_mb_conv_pre_glu import *
+
+# from .flash_attn import *
diff --git a/diffusion/model/nets/fastlinear/modules/flash_attn.py b/diffusion/model/nets/fastlinear/modules/flash_attn.py
new file mode 100644
index 0000000..080220e
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/modules/flash_attn.py
@@ -0,0 +1,43 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+from flash_attn import flash_attn_func
+from torch import nn
+from torch.nn import functional as F
+
+
+class FlashAttention(nn.Module):
+ def __init__(self, dim: int, num_heads: int):
+ super().__init__()
+ self.dim = dim
+ assert dim % num_heads == 0
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=False)
+ self.proj_out = torch.nn.Linear(dim, dim)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x).view(B, N, 3, C) # B, N, 3, C
+ q, k, v = qkv.unbind(2) # B, N, C
+ k = k.reshape(B, N, self.num_heads, self.head_dim)
+ v = v.reshape(B, N, self.num_heads, self.head_dim)
+ q = q.reshape(B, N, self.num_heads, self.head_dim)
+ out = flash_attn_func(q, k, v) # B, N, H, c
+ out = self.proj_out(out.view(B, N, C)) # B, N, C
+ return out
diff --git a/diffusion/model/nets/fastlinear/modules/lite_mla.py b/diffusion/model/nets/fastlinear/modules/lite_mla.py
new file mode 100644
index 0000000..6cd6f0f
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/modules/lite_mla.py
@@ -0,0 +1,105 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+class LiteMLA(nn.Module):
+ r"""Lightweight multiscale linear attention"""
+
+ PAD_VAL = 1
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ heads: Optional[int] = None,
+ heads_ratio: float = 1.0,
+ dim=32,
+ kernel_func="relu",
+ scales: Optional[Tuple[int]] = (5,),
+ eps=1e-15,
+ use_bias=False,
+ norm=(None, "bn2d"),
+ act=(None, None),
+ ):
+ heads = heads or int(out_dim // dim * heads_ratio)
+ super().__init__()
+
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+ self.heads = heads
+ self.dim = dim
+ self.scales = scales
+ self.eps = eps
+
+ self.aggreg = None
+ scales = ()
+ self.kernel_func = nn.ReLU(inplace=False)
+
+ self.qkv = nn.Linear(in_dim, in_dim * 3, bias=use_bias)
+ self.proj = nn.Linear(out_dim, out_dim)
+
+ @torch.cuda.amp.autocast(enabled=os.environ.get("AUTOCAST_LINEAR_ATTN", False) == "true")
+ def attn_matmul(self, q, k, v: torch.Tensor) -> torch.Tensor:
+ # lightweight linear attention
+ q = self.kernel_func(q) # B, h, h_d, N
+ k = self.kernel_func(k)
+
+ use_fp32_attention = getattr(self, "fp32_attention", False) # necessary for NAN loss
+ if use_fp32_attention:
+ q, k, v = q.float(), k.float(), v.float()
+ v = F.pad(v, (0, 0, 0, 1), mode="constant", value=LiteMLA.PAD_VAL)
+ vk = torch.matmul(v, k)
+ out = torch.matmul(vk, q)
+ if out.dtype in [torch.float16, torch.bfloat16]:
+ out = out.float()
+ out = out[:, :, :-1] / (out[:, :, -1:] + self.eps)
+
+ return out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, C).permute(0, 2, 3, 1)
+ # B, 3, C, N --> B, C, N
+ q, k, v = qkv.unbind(1)
+ dtype = q.dtype
+
+ q = q.reshape(B, C // self.dim, self.dim, N) # b, h, h_d, N
+ k = k.reshape(B, C // self.dim, self.dim, N).transpose(-1, -2) # b, h, N, h_d
+ v = v.reshape(B, C // self.dim, self.dim, N) # b, h, h_d, N
+
+ out = self.attn_matmul(q, k, v).to(dtype)
+
+ out = out.view(B, C, N).permute(0, 2, 1) # B, N, C
+ out = self.proj(out)
+
+ return out
+
+ @property
+ def module_str(self) -> str:
+ _str = type(self).__name__ + "("
+ eps = f"{self.eps:.1E}"
+ _str += f"i={self.in_dim},o={self.out_dim},h={self.heads},d={self.dim},eps={eps}"
+ return _str
+
+ def __repr__(self):
+ return f"EPS{self.eps}-" + super().__repr__()
diff --git a/diffusion/model/nets/fastlinear/modules/mb_conv_pre_glu.py b/diffusion/model/nets/fastlinear/modules/mb_conv_pre_glu.py
new file mode 100644
index 0000000..af71b02
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/modules/mb_conv_pre_glu.py
@@ -0,0 +1,111 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+from torch import nn
+
+from .nn.act import build_act, get_act_name
+from .nn.conv import ConvLayer
+from .nn.norm import build_norm, get_norm_name
+from .utils.model import get_same_padding, val2tuple
+
+
+class MBConvPreGLU(nn.Module):
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ kernel_size=3,
+ stride=1,
+ mid_dim=None,
+ expand=6,
+ padding: int or None = None,
+ use_bias=False,
+ norm=(None, None, "ln2d"),
+ act=("silu", "silu", None),
+ ):
+ super().__init__()
+ use_bias = val2tuple(use_bias, 3)
+ norm = val2tuple(norm, 3)
+ act = val2tuple(act, 3)
+
+ mid_dim = mid_dim or round(in_dim * expand)
+
+ self.inverted_conv = ConvLayer(
+ in_dim,
+ mid_dim * 2,
+ 1,
+ use_bias=use_bias[0],
+ norm=norm[0],
+ act=None,
+ )
+ self.glu_act = build_act(act[0], inplace=False)
+ self.depth_conv = ConvLayer(
+ mid_dim,
+ mid_dim,
+ kernel_size,
+ stride=stride,
+ groups=mid_dim,
+ padding=padding,
+ use_bias=use_bias[1],
+ norm=norm[1],
+ act=act[1],
+ )
+ self.point_conv = ConvLayer(
+ mid_dim,
+ out_dim,
+ 1,
+ use_bias=use_bias[2],
+ norm=norm[2],
+ act=act[2],
+ )
+
+ def forward(self, x: torch.Tensor, HW=None) -> torch.Tensor:
+ B, N, C = x.shape
+ if HW is None:
+ H = W = int(N**0.5)
+ else:
+ H, W = HW
+
+ x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
+
+ x = self.inverted_conv(x)
+ x, gate = torch.chunk(x, 2, dim=1)
+ gate = self.glu_act(gate)
+ x = x * gate
+
+ x = self.depth_conv(x)
+ x = self.point_conv(x)
+
+ x = x.reshape(B, C, N).permute(0, 2, 1)
+ return x
+
+ @property
+ def module_str(self) -> str:
+ _str = f"{self.depth_conv.kernel_size}{type(self).__name__}("
+ _str += f"in={self.inverted_conv.in_dim},mid={self.depth_conv.in_dim},out={self.point_conv.out_dim},s={self.depth_conv.stride}"
+ _str += (
+ f",norm={get_norm_name(self.inverted_conv.norm)}"
+ f"+{get_norm_name(self.depth_conv.norm)}"
+ f"+{get_norm_name(self.point_conv.norm)}"
+ )
+ _str += (
+ f",act={get_act_name(self.inverted_conv.act)}"
+ f"+{get_act_name(self.depth_conv.act)}"
+ f"+{get_act_name(self.point_conv.act)}"
+ )
+ _str += f",glu_act={get_act_name(self.glu_act)})"
+ return _str
diff --git a/diffusion/model/nets/fastlinear/modules/nn/act.py b/diffusion/model/nets/fastlinear/modules/nn/act.py
new file mode 100644
index 0000000..10b4658
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/modules/nn/act.py
@@ -0,0 +1,59 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import copy
+
+import torch.nn as nn
+
+__all__ = ["build_act", "get_act_name"]
+
+# register activation function here
+# name: module, kwargs with default values
+REGISTERED_ACT_DICT: dict[str, tuple[type, dict[str, any]]] = {
+ "relu": (nn.ReLU, {"inplace": True}),
+ "relu6": (nn.ReLU6, {"inplace": True}),
+ "hswish": (nn.Hardswish, {"inplace": True}),
+ "hsigmoid": (nn.Hardsigmoid, {"inplace": True}),
+ "swish": (nn.SiLU, {"inplace": True}),
+ "silu": (nn.SiLU, {"inplace": True}),
+ "tanh": (nn.Tanh, {}),
+ "sigmoid": (nn.Sigmoid, {}),
+ "gelu": (nn.GELU, {"approximate": "tanh"}),
+ "mish": (nn.Mish, {"inplace": True}),
+ "identity": (nn.Identity, {}),
+}
+
+
+def build_act(name: str or None, **kwargs) -> nn.Module or None:
+ if name in REGISTERED_ACT_DICT:
+ act_cls, default_args = copy.deepcopy(REGISTERED_ACT_DICT[name])
+ for key in default_args:
+ if key in kwargs:
+ default_args[key] = kwargs[key]
+ return act_cls(**default_args)
+ elif name is None or name.lower() == "none":
+ return None
+ else:
+ raise ValueError(f"do not support: {name}")
+
+
+def get_act_name(act: nn.Module or None) -> str or None:
+ if act is None:
+ return None
+ module2name = {}
+ for key, config in REGISTERED_ACT_DICT.items():
+ module2name[config[0].__name__] = key
+ return module2name.get(type(act).__name__, "unknown")
diff --git a/diffusion/model/nets/fastlinear/modules/nn/conv.py b/diffusion/model/nets/fastlinear/modules/nn/conv.py
new file mode 100644
index 0000000..93978bc
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/modules/nn/conv.py
@@ -0,0 +1,76 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+from torch import nn
+
+from ..utils.model import get_same_padding
+from .act import build_act, get_act_name
+from .norm import build_norm, get_norm_name
+
+
+class ConvLayer(nn.Module):
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ kernel_size=3,
+ stride=1,
+ dilation=1,
+ groups=1,
+ padding: int or None = None,
+ use_bias=False,
+ dropout=0.0,
+ norm="bn2d",
+ act="relu",
+ ):
+ super().__init__()
+ if padding is None:
+ padding = get_same_padding(kernel_size)
+ padding *= dilation
+
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.dilation = dilation
+ self.groups = groups
+ self.padding = padding
+ self.use_bias = use_bias
+
+ self.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else None
+ self.conv = nn.Conv2d(
+ in_dim,
+ out_dim,
+ kernel_size=(kernel_size, kernel_size),
+ stride=(stride, stride),
+ padding=padding,
+ dilation=(dilation, dilation),
+ groups=groups,
+ bias=use_bias,
+ )
+ self.norm = build_norm(norm, num_features=out_dim)
+ self.act = build_act(act)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.dropout is not None:
+ x = self.dropout(x)
+ x = self.conv(x)
+ if self.norm:
+ x = self.norm(x)
+ if self.act:
+ x = self.act(x)
+ return x
diff --git a/diffusion/model/nets/fastlinear/modules/nn/norm.py b/diffusion/model/nets/fastlinear/modules/nn/norm.py
new file mode 100644
index 0000000..2e6fb78
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/modules/nn/norm.py
@@ -0,0 +1,231 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import copy
+import warnings
+
+import torch
+import torch.nn as nn
+from torch.nn.modules.batchnorm import _BatchNorm
+
+__all__ = ["LayerNorm2d", "build_norm", "get_norm_name", "reset_bn", "remove_bn", "set_norm_eps"]
+
+
+class LayerNorm2d(nn.LayerNorm):
+ rmsnorm = False
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ out = x if LayerNorm2d.rmsnorm else x - torch.mean(x, dim=1, keepdim=True)
+ out = out / torch.sqrt(torch.square(out).mean(dim=1, keepdim=True) + self.eps)
+ if self.elementwise_affine:
+ out = out * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
+ return out
+
+ def extra_repr(self) -> str:
+ return f"{self.normalized_shape}, eps={self.eps}, elementwise_affine={self.elementwise_affine}, rmsnorm={self.rmsnorm}"
+
+
+# register normalization function here
+# name: module, kwargs with default values
+REGISTERED_NORMALIZATION_DICT: dict[str, tuple[type, dict[str, any]]] = {
+ "bn2d": (nn.BatchNorm2d, {"num_features": None, "eps": 1e-5, "momentum": 0.1, "affine": True}),
+ "syncbn": (nn.SyncBatchNorm, {"num_features": None, "eps": 1e-5, "momentum": 0.1, "affine": True}),
+ "ln": (nn.LayerNorm, {"normalized_shape": None, "eps": 1e-5, "elementwise_affine": True}),
+ "ln2d": (LayerNorm2d, {"normalized_shape": None, "eps": 1e-5, "elementwise_affine": True}),
+}
+
+
+def build_norm(name="bn2d", num_features=None, affine=True, **kwargs) -> nn.Module or None:
+ if name in ["ln", "ln2d"]:
+ kwargs["normalized_shape"] = num_features
+ kwargs["elementwise_affine"] = affine
+ else:
+ kwargs["num_features"] = num_features
+ kwargs["affine"] = affine
+ if name in REGISTERED_NORMALIZATION_DICT:
+ norm_cls, default_args = copy.deepcopy(REGISTERED_NORMALIZATION_DICT[name])
+ for key in default_args:
+ if key in kwargs:
+ default_args[key] = kwargs[key]
+ return norm_cls(**default_args)
+ elif name is None or name.lower() == "none":
+ return None
+ else:
+ raise ValueError("do not support: %s" % name)
+
+
+def get_norm_name(norm: nn.Module or None) -> str or None:
+ if norm is None:
+ return None
+ module2name = {}
+ for key, config in REGISTERED_NORMALIZATION_DICT.items():
+ module2name[config[0].__name__] = key
+ return module2name.get(type(norm).__name__, "unknown")
+
+
+def reset_bn(
+ model: nn.Module,
+ data_loader: list,
+ sync=True,
+ progress_bar=False,
+) -> None:
+ import copy
+
+ import torch.nn.functional as F
+ from packages.apps.utils import AverageMeter, is_master, sync_tensor
+ from packages.models.utils import get_device, list_join
+ from tqdm import tqdm
+
+ bn_mean = {}
+ bn_var = {}
+
+ tmp_model = copy.deepcopy(model)
+ for name, m in tmp_model.named_modules():
+ if isinstance(m, _BatchNorm):
+ bn_mean[name] = AverageMeter(is_distributed=False)
+ bn_var[name] = AverageMeter(is_distributed=False)
+
+ def new_forward(bn, mean_est, var_est):
+ def lambda_forward(x):
+ x = x.contiguous()
+ if sync:
+ batch_mean = x.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) # 1, C, 1, 1
+ batch_mean = sync_tensor(batch_mean, reduce="cat")
+ batch_mean = torch.mean(batch_mean, dim=0, keepdim=True)
+
+ batch_var = (x - batch_mean) * (x - batch_mean)
+ batch_var = batch_var.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
+ batch_var = sync_tensor(batch_var, reduce="cat")
+ batch_var = torch.mean(batch_var, dim=0, keepdim=True)
+ else:
+ batch_mean = x.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) # 1, C, 1, 1
+ batch_var = (x - batch_mean) * (x - batch_mean)
+ batch_var = batch_var.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
+
+ batch_mean = torch.squeeze(batch_mean)
+ batch_var = torch.squeeze(batch_var)
+
+ mean_est.update(batch_mean.data, x.size(0))
+ var_est.update(batch_var.data, x.size(0))
+
+ # bn forward using calculated mean & var
+ _feature_dim = batch_mean.shape[0]
+ return F.batch_norm(
+ x,
+ batch_mean,
+ batch_var,
+ bn.weight[:_feature_dim],
+ bn.bias[:_feature_dim],
+ False,
+ 0.0,
+ bn.eps,
+ )
+
+ return lambda_forward
+
+ m.forward = new_forward(m, bn_mean[name], bn_var[name])
+
+ # skip if there is no batch normalization layers in the network
+ if len(bn_mean) == 0:
+ return
+
+ tmp_model.eval()
+ with torch.inference_mode():
+ with tqdm(total=len(data_loader), desc="reset bn", disable=not progress_bar or not is_master()) as t:
+ for images in data_loader:
+ images = images.to(get_device(tmp_model))
+ tmp_model(images)
+ t.set_postfix(
+ {
+ "bs": images.size(0),
+ "res": list_join(images.shape[-2:], "x"),
+ }
+ )
+ t.update()
+
+ for name, m in model.named_modules():
+ if name in bn_mean and bn_mean[name].count > 0:
+ feature_dim = bn_mean[name].avg.size(0)
+ assert isinstance(m, _BatchNorm)
+ m.running_mean.data[:feature_dim].copy_(bn_mean[name].avg)
+ m.running_var.data[:feature_dim].copy_(bn_var[name].avg)
+
+
+def remove_bn(model: nn.Module) -> None:
+ for m in model.modules():
+ if isinstance(m, _BatchNorm):
+ m.weight = m.bias = None
+ m.forward = lambda x: x
+
+
+def set_norm_eps(model: nn.Module, eps: float or None = None, momentum: float or None = None) -> None:
+ for m in model.modules():
+ if isinstance(m, (nn.GroupNorm, nn.LayerNorm, _BatchNorm)):
+ if eps is not None:
+ m.eps = eps
+ if momentum is not None:
+ m.momentum = momentum
+
+
+try:
+ from apex.normalization import FusedRMSNorm as RMSNorm
+except ImportError:
+ warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
+
+ class RMSNorm(torch.nn.Module):
+ def __init__(self, dim: int, scale_factor=1.0, eps: float = 1e-6):
+ """
+ Initialize the RMSNorm normalization layer.
+
+ Args:
+ dim (int): The dimension of the input tensor.
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+ Attributes:
+ eps (float): A small value added to the denominator for numerical stability.
+ weight (nn.Parameter): Learnable scaling parameter.
+
+ """
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim) * scale_factor)
+
+ def _norm(self, x):
+ """
+ Apply the RMSNorm normalization to the input tensor.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: The normalized tensor.
+
+ """
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x):
+ """
+ Forward pass through the RMSNorm layer.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: The output tensor after applying RMSNorm.
+
+ """
+ output = self._norm(x.float()).type_as(x)
+ return output * self.weight
diff --git a/diffusion/model/nets/fastlinear/modules/triton_lite_mla.py b/diffusion/model/nets/fastlinear/modules/triton_lite_mla.py
new file mode 100644
index 0000000..8bfb0ee
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/modules/triton_lite_mla.py
@@ -0,0 +1,134 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import ipdb
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .triton_lite_mla_kernels.linear_relu_fwd import linear_relu_fwd
+from .triton_lite_mla_kernels.mm import matmul # for autocast
+from .triton_lite_mla_kernels.pad_vk_mm_fwd import pad_vk_mm_fwd
+from .triton_lite_mla_kernels.proj_divide_bwd import proj_divide_bwd
+from .triton_lite_mla_kernels.vk_mm_relu_bwd import vk_mm_relu_bwd
+from .triton_lite_mla_kernels.vk_q_mm_divide_fwd import vk_q_mm_divide_fwd
+from .triton_lite_mla_kernels.vk_q_mm_relu_bwd import vk_q_mm_relu_bwd
+
+
+class TritonLiteMLAFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x: torch.Tensor,
+ qkv_weight: torch.Tensor,
+ proj_weight: torch.Tensor,
+ proj_bias: Optional[torch.Tensor],
+ num_heads: int,
+ head_dim: int,
+ eps: float,
+ ) -> torch.Tensor:
+ ctx.x_dtype, ctx.qkv_weight_dtype, ctx.proj_dtype = x.dtype, qkv_weight.dtype, proj_weight.dtype
+ if torch.is_autocast_enabled():
+ autocast_dtype = torch.get_autocast_gpu_dtype()
+ x = x.to(autocast_dtype)
+ qkv_weight = qkv_weight.to(autocast_dtype)
+ proj_weight = proj_weight.to(autocast_dtype)
+ if proj_bias is not None:
+ proj_bias = proj_bias.to(autocast_dtype)
+ B, N, C = x.shape
+ qkv, relu_mask = linear_relu_fwd(x, qkv_weight) # B, N, 3*C. autocast is processed here
+ qkv, relu_mask = qkv.view(B, N, 3, C), relu_mask.view(B, N, 3, C)
+ q, k, v = qkv.unbind(2) # B, N, C
+ k = k.reshape(B, N, num_heads, head_dim)
+ v = v.reshape(B, N, num_heads, head_dim)
+ q = q.reshape(B, N, num_heads, head_dim)
+ vk = pad_vk_mm_fwd(v, k, torch.float, torch.float)
+ proj_input, vk_q = vk_q_mm_divide_fwd(vk, q, eps, torch.float, qkv.dtype)
+ proj_input = proj_input.view(B, N, C)
+ y = F.linear(proj_input, proj_weight, proj_bias)
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1] or ctx.needs_input_grad[2] or ctx.needs_input_grad[3]:
+ ctx.save_for_backward(x, qkv_weight, relu_mask, v, k, vk, q, vk_q, proj_input, proj_weight)
+ ctx.eps = eps
+ if torch.get_autocast_gpu_dtype() == torch.float16:
+ y = y.clip(-65504, 65504)
+ return y
+
+ @staticmethod
+ def backward(ctx, grad_y: torch.Tensor):
+ x, qkv_weight, relu_mask, v, k, vk, q, vk_q, proj_input, proj_weight = ctx.saved_tensors
+ B, N, H, C1 = vk_q.shape
+ C = C1 - 1
+
+ # ipdb.set_trace()
+ grad_proj_weight = (
+ (grad_y.reshape(-1, H * C).T @ proj_input.view(-1, H * C)).to(ctx.proj_dtype)
+ if ctx.needs_input_grad[2]
+ else None
+ )
+ grad_proj_bias = grad_y.sum((0, 1)).to(ctx.proj_dtype) if ctx.needs_input_grad[3] else None
+ #
+ grad_vk_q = proj_divide_bwd(grad_y, proj_weight, vk_q, ctx.eps)
+ del grad_y, vk_q
+
+ grad_qkv = torch.empty(B, N, 3, H, C, dtype=q.dtype, device=q.device)
+ grad_vk = vk_q_mm_relu_bwd(grad_vk_q, vk, q, relu_mask[:, :, 0].view(B, N, H, C), grad_qkv[:, :, 0])
+ del grad_vk_q, vk
+
+ vk_mm_relu_bwd(grad_vk, k, v, relu_mask[:, :, 1].view(B, N, H, C), grad_qkv[:, :, 1], grad_qkv[:, :, 2])
+ del grad_vk, q, k, v, relu_mask
+
+ grad_qkv_weight = (
+ (grad_qkv.view(B * N, 3 * H * C).T @ x.view(B * N, H * C)).to(ctx.qkv_weight_dtype)
+ if ctx.needs_input_grad[1]
+ else None
+ )
+ grad_x = (grad_qkv.view(B, N, 3 * H * C) @ qkv_weight).to(ctx.x_dtype) if ctx.needs_input_grad[0] else None
+ del grad_qkv
+
+ return grad_x, grad_qkv_weight, grad_proj_weight, grad_proj_bias, None, None, None
+
+
+class TritonLiteMLA(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ eps=1e-15,
+ use_bias=False,
+ ):
+ super().__init__()
+ self.dim, self.num_heads, self.head_dim, self.eps = dim, num_heads, dim // num_heads, eps
+ if use_bias:
+ raise NotImplementedError(f"use_bias is not supported for TritonLiteMLA")
+ self.qkv = nn.Linear(dim, dim * 3, bias=use_bias)
+ self.proj = nn.Linear(dim, dim)
+
+ def forward(self, x: torch.Tensor, mask=None, HW=None, block_id=None) -> torch.Tensor:
+ return TritonLiteMLAFunction.apply(
+ x, self.qkv.weight, self.proj.weight, self.proj.bias, self.num_heads, self.head_dim, self.eps
+ )
+
+ @property
+ def module_str(self) -> str:
+ _str = type(self).__name__ + "("
+ eps = f"{self.eps:.1E}"
+ _str += f"i={self.in_dim},o={self.out_dim},h={self.heads},d={self.dim},eps={eps}"
+ return _str
+
+ def __repr__(self):
+ return f"EPS{self.eps}-" + super().__repr__()
diff --git a/diffusion/model/nets/fastlinear/modules/triton_lite_mla_fwd.py b/diffusion/model/nets/fastlinear/modules/triton_lite_mla_fwd.py
new file mode 100644
index 0000000..bd2e3fe
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/modules/triton_lite_mla_fwd.py
@@ -0,0 +1,117 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import ipdb
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .triton_lite_mla_kernels.linear_relu_fwd import linear_relu_fwd
+from .triton_lite_mla_kernels.pad_vk_mm_fwd import pad_vk_mm_fwd
+from .triton_lite_mla_kernels.vk_q_mm_divide_fwd import vk_q_mm_divide_fwd
+
+
+class TritonLiteMLAFwdFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x: torch.Tensor,
+ qkv_weight: torch.Tensor,
+ proj_weight: torch.Tensor,
+ proj_bias: torch.Tensor,
+ num_heads: int,
+ head_dim: int,
+ eps: float,
+ ) -> torch.Tensor:
+ # ipdb.set_trace()
+ B, N, C = x.shape
+ qkv, relu_mask = linear_relu_fwd(x, qkv_weight) # .view(B, N, 3, C) # B, N, 3, C
+ qkv, relu_mask = qkv.view(B, N, 3, C), relu_mask.view(B, N, 3, C)
+ q, k, v = qkv.unbind(2) # B, N, C
+ k = k.reshape(B, N, num_heads, head_dim)
+ v = v.reshape(B, N, num_heads, head_dim)
+ q = q.reshape(B, N, num_heads, head_dim)
+ vk = pad_vk_mm_fwd(v, k, torch.float, torch.float)
+ proj_input, vk_q = vk_q_mm_divide_fwd(vk, q, eps, torch.float, x.dtype)
+ proj_input = proj_input.view(B, N, C)
+ y = F.linear(proj_input, proj_weight, proj_bias)
+ ctx.save_for_backward(x, qkv_weight, relu_mask, v, k, vk, q, vk_q, proj_input, proj_weight)
+ ctx.eps = eps
+ return y
+
+ @staticmethod
+ def backward(ctx, grad_y: torch.Tensor):
+ x, qkv_weight, relu_mask, v, k, vk, q, vk_q, proj_input, proj_weight = ctx.saved_tensors
+ B, N, H, C1 = vk_q.shape
+ C = C1 - 1
+
+ grad_proj_weight = grad_y.reshape(-1, H * C).T @ proj_input.view(-1, H * C)
+ grad_proj_bias = grad_y.sum((0, 1))
+ #
+ grad_proj_input = grad_y @ proj_weight
+ grad_vk_q_numerator = grad_proj_input.view(B, N, H, C) / (vk_q[:, :, :, -1:] + ctx.eps)
+ grad_vk_q_denominator = (
+ -(grad_proj_input.view(B, N, H, C) * vk_q[:, :, :, :-1]).sum(-1, keepdim=True)
+ / (vk_q[:, :, :, -1:] + ctx.eps) ** 2
+ )
+ grad_vk_q = torch.cat([grad_vk_q_numerator, grad_vk_q_denominator], dim=-1)
+
+ grad_q = (grad_vk_q.permute(0, 2, 1, 3) @ vk).permute(0, 2, 1, 3)
+ grad_vk = grad_vk_q.permute(0, 2, 3, 1) @ q.float().permute(0, 2, 1, 3)
+ grad_q.mul_(relu_mask[:, :, 0].view(B, N, H, C))
+
+ grad_v = (grad_vk @ k.float().permute(0, 2, 3, 1)).permute(0, 3, 1, 2)[:, :, :, :-1]
+ grad_k = ((v.float().permute(0, 2, 1, 3) @ grad_vk[:, :, :-1]) + grad_vk[:, :, -1:]).permute(0, 2, 1, 3)
+ grad_k.mul_(relu_mask[:, :, 1].view(B, N, H, C))
+
+ grad_qkv = torch.stack([grad_q, grad_k, grad_v], dim=2).view(B, N, 3 * H * C).to(x.dtype)
+ grad_qkv_weight = grad_qkv.view(B * N, 3 * H * C).T @ x.view(B * N, H * C)
+ grad_x = grad_qkv @ qkv_weight
+
+ # ipdb.set_trace()
+
+ return grad_x, grad_qkv_weight, grad_proj_weight, grad_proj_bias, None, None, None
+
+
+class TritonLiteMLAFwd(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ eps=1e-15,
+ use_bias=False,
+ ):
+ super().__init__()
+ self.dim, self.num_heads, self.head_dim, self.eps = dim, num_heads, dim // num_heads, eps
+ if use_bias:
+ raise NotImplementedError(f"use_bias is not supported for TritonLiteMLA")
+ self.qkv = nn.Linear(dim, dim * 3, bias=use_bias)
+ self.proj = nn.Linear(dim, dim)
+
+ def forward(self, x: torch.Tensor, mask=None, HW=None, block_id=None) -> torch.Tensor:
+ return TritonLiteMLAFwdFunction.apply(
+ x, self.qkv.weight, self.proj.weight, self.proj.bias, self.num_heads, self.head_dim, self.eps
+ )
+
+ @property
+ def module_str(self) -> str:
+ _str = type(self).__name__ + "("
+ eps = f"{self.eps:.1E}"
+ _str += f"i={self.in_dim},o={self.out_dim},h={self.heads},d={self.dim},eps={eps}"
+ return _str
+
+ def __repr__(self):
+ return f"EPS{self.eps}-" + super().__repr__()
diff --git a/diffusion/model/nets/fastlinear/modules/triton_lite_mla_kernels/custom_autotune.py b/diffusion/model/nets/fastlinear/modules/triton_lite_mla_kernels/custom_autotune.py
new file mode 100644
index 0000000..3174744
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/modules/triton_lite_mla_kernels/custom_autotune.py
@@ -0,0 +1,123 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import builtins
+import json
+import os
+import pickle
+import time
+
+import ipdb
+import torch
+import torch.distributed as dist
+from triton.runtime.autotuner import Autotuner
+
+
+class CustomAutotuner(Autotuner):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.best_config_cache_path = os.path.expanduser(
+ os.path.join(
+ "~",
+ ".triton",
+ "best_config_cache",
+ torch.cuda.get_device_name(0).replace(" ", "_"),
+ self.base_fn.__name__ + ".pkl",
+ )
+ )
+ if os.path.exists(self.best_config_cache_path):
+ with open(self.best_config_cache_path, "rb") as f:
+ self.cache = pickle.load(f)
+
+ def run(self, *args, **kwargs):
+ self.nargs = dict(zip(self.arg_names, args))
+ used_cached_result = True
+ if len(self.configs) > 1:
+ all_args = {**self.nargs, **kwargs}
+ _args = []
+ for name in self.arg_names:
+ if name in all_args:
+ _args.append(all_args[name])
+ key = [_args[i] for i in self.key_idx]
+ for arg in _args:
+ if hasattr(arg, "dtype"):
+ key.append(str(arg.dtype))
+ key = tuple(key)
+ if key not in self.cache:
+ # prune configs
+ used_cached_result = False
+ pruned_configs = self.prune_configs(kwargs)
+ bench_start = time.time()
+ timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
+ bench_end = time.time()
+ self.bench_time = bench_end - bench_start
+ self.cache[key] = builtins.min(timings, key=timings.get)
+ self.pre_hook(args, reset_only=True)
+ self.configs_timings = timings
+ if not dist.is_initialized() or dist.get_rank() == 0:
+ best_config_cache_dir = os.path.dirname(self.best_config_cache_path)
+ os.makedirs(best_config_cache_dir, exist_ok=True)
+ with open(self.best_config_cache_path, "wb") as f:
+ pickle.dump(self.cache, f)
+ config = self.cache[key]
+ else:
+ config = self.configs[0]
+ self.best_config = config
+ if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result:
+ print(
+ f"Triton autotuning for function {self.base_fn.__name__} finished after "
+ f"{self.bench_time:.2f}s; best config selected: {self.best_config};"
+ )
+ if config.pre_hook is not None:
+ config.pre_hook({**self.nargs, **kwargs, **config.all_kwargs()})
+ ret = self.fn.run(
+ *args,
+ **kwargs,
+ **config.all_kwargs(),
+ )
+ self.nargs = None
+ return ret
+
+
+def custom_autotune(
+ configs,
+ key,
+ prune_configs_by=None,
+ reset_to_zero=None,
+ restore_value=None,
+ pre_hook=None,
+ post_hook=None,
+ warmup=25,
+ rep=100,
+ use_cuda_graph=False,
+):
+ def decorator(fn):
+ return CustomAutotuner(
+ fn,
+ fn.arg_names,
+ configs,
+ key,
+ reset_to_zero,
+ restore_value,
+ pre_hook=pre_hook,
+ post_hook=post_hook,
+ prune_configs_by=prune_configs_by,
+ warmup=warmup,
+ rep=rep,
+ use_cuda_graph=use_cuda_graph,
+ )
+
+ return decorator
diff --git a/diffusion/model/nets/fastlinear/modules/triton_lite_mla_kernels/linear_relu_fwd.py b/diffusion/model/nets/fastlinear/modules/triton_lite_mla_kernels/linear_relu_fwd.py
new file mode 100644
index 0000000..b2c2062
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/modules/triton_lite_mla_kernels/linear_relu_fwd.py
@@ -0,0 +1,223 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import ipdb
+import torch
+import triton
+import triton.language as tl
+
+from ..utils.custom_autotune import custom_autotune
+
+
+def get_cuda_autotune_config():
+ return [
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=5, num_warps=2
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=5, num_warps=2
+ ),
+ # Good config for fp8 inputs.
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
+ num_stages=3,
+ num_warps=8,
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
+ num_stages=3,
+ num_warps=8,
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ ]
+
+
+def get_autotune_config():
+ return get_cuda_autotune_config()
+
+
+# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
+# - A list of `triton.Config` objects that define different configurations of
+# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
+# - An auto-tuning *key* whose change in values will trigger evaluation of all the
+# provided configs
+@custom_autotune(
+ configs=get_autotune_config(),
+ key=["M", "N", "K"],
+)
+@triton.jit
+def linear_relu_fwd_kernel(
+ # Pointers to matrices
+ a_ptr,
+ b_ptr,
+ c_ptr,
+ r_ptr,
+ # Matrix dimensions
+ M,
+ N,
+ K,
+ num_relu_channels,
+ # The stride variables represent how much to increase the ptr by when moving by 1
+ # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
+ # by to get the element one row down (A has M rows).
+ stride_am,
+ stride_ak, #
+ stride_bn,
+ stride_bk, #
+ stride_cm,
+ stride_cn,
+ # Meta-parameters
+ BLOCK_SIZE_M: tl.constexpr,
+ BLOCK_SIZE_N: tl.constexpr,
+ BLOCK_SIZE_K: tl.constexpr, #
+ GROUP_SIZE_M: tl.constexpr, #
+):
+ """Kernel for computing the matmul C = A x B.
+ A has shape (M, K), B has shape (K, N) and C has shape (M, N)
+ """
+ # -----------------------------------------------------------
+ # Map program ids `pid` to the block of C it should compute.
+ # This is done in a grouped ordering to promote L2 data reuse.
+ # See above `L2 Cache Optimizations` section for details.
+ pid = tl.program_id(axis=0)
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_SIZE_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
+
+ # ----------------------------------------------------------
+ # Create pointers for the first blocks of A and B.
+ # We will advance this pointer as we move in the K direction
+ # and accumulate
+ # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
+ # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
+ # See above `Pointer Arithmetic` section for details
+ offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
+ a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # BLOCK_SIZE_M, BLOCK_SIZE_K
+ b_ptrs = b_ptr + (offs_bn[None, :] * stride_bn + offs_k[:, None] * stride_bk) # BLOCK_SIZE_K, BLOCK_SIZE_N
+
+ # -----------------------------------------------------------
+ # Iterate to compute a block of the C matrix.
+ # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+ # of fp32 values for higher accuracy.
+ # `accumulator` will be converted back to fp16 after the loop.
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+ # Load the next block of A and B, generate a mask by checking the K dimension.
+ # If it is out of bounds, set it to 0.
+ a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+ b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+ # We accumulate along the K dimension.
+ accumulator = tl.dot(a, b, accumulator)
+ # Advance the ptrs to the next K block.
+ a_ptrs += BLOCK_SIZE_K * stride_ak
+ b_ptrs += BLOCK_SIZE_K * stride_bk
+ # You can fuse arbitrary activation functions here
+ # while the accumulator is still in FP32!
+ relu_mask = (accumulator >= 0) | (offs_bn[None, :] >= num_relu_channels)
+ accumulator = tl.where(relu_mask, accumulator, 0)
+ # accumulator = tl.where(accumulator >= 0, accumulator, 0)
+ c = accumulator.to(c_ptr.dtype.element_ty)
+
+ # -----------------------------------------------------------
+ # Write back the block of the output matrix C with masks.
+ offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+ c_offs = stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+ c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+ tl.store(c_ptr + c_offs, c, mask=c_mask)
+ tl.store(r_ptr + c_offs, relu_mask, mask=c_mask)
+
+
+def linear_relu_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+ # Check constraints.
+ assert a.shape[-1] == b.shape[1], "Incompatible dimensions"
+ assert a.dim() >= 2 and b.dim() == 2
+ M, K, N = torch.prod(torch.tensor(a.shape[:-1])).item(), a.shape[-1], b.shape[0]
+ assert N % 3 == 0 # first 2/3 of N need relu
+
+ # ref_c = a@b.mT
+ # ref_c[..., :2*N//3] = torch.nn.functional.relu(ref_c[..., :2*N//3])
+ # return ref_c
+
+ # Allocates output.
+ c = torch.empty(a.shape[:-1] + (N,), device=a.device, dtype=a.dtype)
+ relu_mask = torch.empty(a.shape[:-1] + (N,), device=a.device, dtype=bool)
+ # 1D launch kernel where each block gets its own program.
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),)
+ if a.dtype == b.dtype:
+ linear_relu_fwd_kernel[grid](
+ a,
+ b,
+ c,
+ relu_mask, #
+ M,
+ N,
+ K,
+ 2 * N // 3, #
+ a.stride(-2),
+ a.stride(-1), #
+ b.stride(0),
+ b.stride(1), #
+ c.stride(-2),
+ c.stride(-1), # the stride of c and relu_mask should be the same
+ )
+ else:
+ raise NotImplementedError(f"data type {a.dtype} {b.dtype} is not support")
+ return c, relu_mask
diff --git a/diffusion/model/nets/fastlinear/modules/triton_lite_mla_kernels/mm.py b/diffusion/model/nets/fastlinear/modules/triton_lite_mla_kernels/mm.py
new file mode 100644
index 0000000..e7521b7
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/modules/triton_lite_mla_kernels/mm.py
@@ -0,0 +1,218 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+from ..utils.dtype import get_tl_dtype_from_torch_dtype
+
+
+def get_cuda_autotune_config():
+ return [
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=5, num_warps=2
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=5, num_warps=2
+ ),
+ # Good config for fp8 inputs.
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
+ num_stages=3,
+ num_warps=8,
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
+ num_stages=3,
+ num_warps=8,
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ ]
+
+
+def get_autotune_config():
+ return get_cuda_autotune_config()
+
+
+# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
+# - A list of `triton.Config` objects that define different configurations of
+# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
+# - An auto-tuning *key* whose change in values will trigger evaluation of all the
+# provided configs
+@triton.autotune(
+ configs=get_autotune_config(),
+ key=["M", "N", "K"],
+)
+@triton.jit
+def matmul_kernel(
+ # Pointers to matrices
+ a_ptr,
+ b_ptr,
+ c_ptr,
+ # Matrix dimensions
+ M,
+ N,
+ K,
+ # The stride variables represent how much to increase the ptr by when moving by 1
+ # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
+ # by to get the element one row down (A has M rows).
+ stride_am,
+ stride_ak, #
+ stride_bk,
+ stride_bn, #
+ stride_cm,
+ stride_cn,
+ # Meta-parameters
+ BLOCK_SIZE_M: tl.constexpr,
+ BLOCK_SIZE_N: tl.constexpr,
+ BLOCK_SIZE_K: tl.constexpr, #
+ GROUP_SIZE_M: tl.constexpr, #
+ compute_dtype: tl.constexpr,
+):
+ """Kernel for computing the matmul C = A x B.
+ A has shape (M, K), B has shape (K, N) and C has shape (M, N)
+ """
+ # -----------------------------------------------------------
+ # Map program ids `pid` to the block of C it should compute.
+ # This is done in a grouped ordering to promote L2 data reuse.
+ # See above `L2 Cache Optimizations` section for details.
+ pid = tl.program_id(axis=0)
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_SIZE_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
+
+ # ----------------------------------------------------------
+ # Create pointers for the first blocks of A and B.
+ # We will advance this pointer as we move in the K direction
+ # and accumulate
+ # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
+ # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
+ # See above `Pointer Arithmetic` section for details
+ offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
+ a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+ b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+ # -----------------------------------------------------------
+ # Iterate to compute a block of the C matrix.
+ # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+ # of fp32 values for higher accuracy.
+ # `accumulator` will be converted back to fp16 after the loop.
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+ # Load the next block of A and B, generate a mask by checking the K dimension.
+ # If it is out of bounds, set it to 0.
+ a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0).to(compute_dtype)
+ b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0).to(compute_dtype)
+ # We accumulate along the K dimension.
+ accumulator = tl.dot(a, b, accumulator)
+ # if pid_m == num_pid_m-1 and pid_n == 0:
+ # tl.device_print("M", M)
+ # tl.device_print("offs_am", offs_am)
+ # tl.device_print("a", a)
+ # # tl.device_print("a max 0", tl.max(a, axis=0))
+ # # tl.device_print("a max 1", tl.max(a, axis=1))
+ # tl.device_print("offs_bn", offs_bn)
+ # tl.device_print("b", b)
+ # # tl.device_print("b max 0", tl.max(b, axis=0))
+ # # tl.device_print("b max 1", tl.max(b, axis=1))
+ # Advance the ptrs to the next K block.
+ a_ptrs += BLOCK_SIZE_K * stride_ak
+ b_ptrs += BLOCK_SIZE_K * stride_bk
+ # You can fuse arbitrary activation functions here
+ # while the accumulator is still in FP32!
+ c = accumulator.to(c_ptr.dtype.element_ty)
+
+ # -----------------------------------------------------------
+ # Write back the block of the output matrix C with masks.
+ offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+ c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+ c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+ tl.store(c_ptrs, c, mask=c_mask)
+
+
+def matmul(a: torch.Tensor, b: torch.Tensor, compute_dtype: torch.dtype, output_dtype: torch.dtype) -> torch.Tensor:
+ # Check constraints.
+ assert a.shape[-1] == b.shape[0], "Incompatible dimensions"
+ M = torch.prod(torch.tensor(a.shape[:-1])).item()
+ K, N = b.shape
+ if a.dtype == b.dtype == compute_dtype == output_dtype:
+ return a @ b
+ # Allocates output.
+ c = torch.empty(a.shape[:-1] + (N,), device=a.device, dtype=output_dtype)
+ # 1D launch kernel where each block gets its own program.
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),)
+ matmul_kernel[grid](
+ a,
+ b,
+ c, #
+ M,
+ N,
+ K, #
+ a.stride(-2),
+ a.stride(-1), #
+ b.stride(0),
+ b.stride(1), #
+ c.stride(-2),
+ c.stride(-1), #
+ compute_dtype=get_tl_dtype_from_torch_dtype(compute_dtype),
+ )
+ return c
diff --git a/diffusion/model/nets/fastlinear/modules/triton_lite_mla_kernels/pad_vk_mm_fwd.py b/diffusion/model/nets/fastlinear/modules/triton_lite_mla_kernels/pad_vk_mm_fwd.py
new file mode 100644
index 0000000..2ec793e
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/modules/triton_lite_mla_kernels/pad_vk_mm_fwd.py
@@ -0,0 +1,197 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import ipdb
+import torch
+import triton
+import triton.language as tl
+
+from ..utils.custom_autotune import custom_autotune
+
+
+def get_cuda_autotune_config():
+ return [
+ triton.Config({"BLOCK_SIZE_C": 256, "BLOCK_SIZE_N": 64}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_SIZE_C": 256, "BLOCK_SIZE_N": 32}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_C": 128, "BLOCK_SIZE_N": 32}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_C": 64, "BLOCK_SIZE_N": 32}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_C": 128, "BLOCK_SIZE_N": 32}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_C": 32, "BLOCK_SIZE_N": 32}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_C": 32, "BLOCK_SIZE_N": 32}, num_stages=5, num_warps=2),
+ triton.Config({"BLOCK_SIZE_C": 64, "BLOCK_SIZE_N": 32}, num_stages=5, num_warps=2),
+ # Good config for fp8 inputs.
+ triton.Config({"BLOCK_SIZE_C": 256, "BLOCK_SIZE_N": 128}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_SIZE_C": 128, "BLOCK_SIZE_N": 128}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_SIZE_C": 64, "BLOCK_SIZE_N": 128}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_C": 256, "BLOCK_SIZE_N": 128}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_C": 128, "BLOCK_SIZE_N": 128}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_C": 64, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_C": 128, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_C": 32, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=4),
+ ]
+
+
+def get_autotune_config():
+ return get_cuda_autotune_config()
+
+
+# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
+# - A list of `triton.Config` objects that define different configurations of
+# meta-parameters (e.g., `BLOCK_SIZE_C1`) and compilation options (e.g., `num_warps`) to try
+# - An auto-tuning *key* whose change in values will trigger evaluation of all the
+# provided configs
+@custom_autotune(
+ configs=get_autotune_config(),
+ key=["B", "N", "H", "C"],
+)
+@triton.jit
+def pad_vk_mm_fwd_kernel_fp32_fp32(
+ # Pointers to matrices
+ a_ptr,
+ b_ptr,
+ c_ptr,
+ # Matrix dimensions
+ B,
+ N,
+ H,
+ C,
+ # The stride variables represent how much to increase the ptr by when moving by 1
+ # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
+ # by to get the element one row down (A has M rows).
+ stride_ab,
+ stride_an,
+ stride_ah,
+ stride_ac, #
+ stride_bb,
+ stride_bn,
+ stride_bh,
+ stride_bc, #
+ stride_cb,
+ stride_ch,
+ stride_cc1,
+ stride_cc,
+ # Meta-parameters
+ BLOCK_SIZE_C1: tl.constexpr,
+ BLOCK_SIZE_C: tl.constexpr,
+ BLOCK_SIZE_N: tl.constexpr, #
+):
+ """Kernel for computing the matmul C = A x B.
+ A has shape (M, K), B has shape (K, N) and C has shape (M, N)
+ """
+ # -----------------------------------------------------------
+ # Map program ids `pid` to the block of C it should compute.
+ # This is done in a grouped ordering to promote L2 data reuse.
+ # See above `L2 Cache Optimizations` section for details.
+ pid = tl.program_id(axis=0)
+ num_pid_bc = tl.cdiv(C, BLOCK_SIZE_C)
+ pid_b, pid_h, pid_bc = pid // num_pid_bc // H, (pid // num_pid_bc) % H, pid % num_pid_bc
+
+ # ----------------------------------------------------------
+ # Create pointers for the first blocks of A and B.
+ # We will advance this pointer as we move in the K direction
+ # and accumulate
+ # `a_ptrs` is a block of [BLOCK_SIZE_C1, BLOCK_SIZE_N] pointers
+ # `b_ptrs` is a block of [BLOCK_SIZE_N, BLOCK_SIZE_C] pointers
+ # See above `Pointer Arithmetic` section for details
+ offs_ac = tl.arange(0, BLOCK_SIZE_C1) % C
+ offs_bc = (pid_bc * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C)) % C
+ offs_n = tl.arange(0, BLOCK_SIZE_N)
+ a_ptrs = a_ptr + (
+ pid_b * stride_ab + pid_h * stride_ah + offs_ac[:, None] * stride_ac + offs_n[None, :] * stride_an
+ )
+ b_ptrs = b_ptr + (
+ pid_b * stride_bb + pid_h * stride_bh + offs_n[:, None] * stride_bn + offs_bc[None, :] * stride_bc
+ )
+ # -----------------------------------------------------------
+ # Iterate to compute a block of the C matrix.
+ # We accumulate into a `[BLOCK_SIZE_C1, BLOCK_SIZE_C]` block
+ # of fp32 values for higher accuracy.
+ # `accumulator` will be converted back to fp16 after the loop.
+ accumulator = tl.zeros((BLOCK_SIZE_C1, BLOCK_SIZE_C), dtype=tl.float32)
+ accumulator1 = tl.zeros((BLOCK_SIZE_C,), dtype=tl.float32)
+ for n in range(0, tl.cdiv(N, BLOCK_SIZE_N)):
+ # Load the next block of A and B, generate a mask by checking the K dimension.
+ # If it is out of bounds, set it to 0.
+ a = tl.load(a_ptrs, mask=offs_n[None, :] < N - n * BLOCK_SIZE_N, other=0.0).to(tl.float32)
+ b = tl.load(b_ptrs, mask=offs_n[:, None] < N - n * BLOCK_SIZE_N, other=0.0).to(tl.float32)
+ # We accumulate along the K dimension.
+ accumulator = tl.dot(a, b, accumulator)
+ accumulator1 += tl.sum(b, axis=0)
+ # Advance the ptrs to the next K block.
+ a_ptrs += BLOCK_SIZE_N * stride_an
+ b_ptrs += BLOCK_SIZE_N * stride_bn
+ # You can fuse arbitrary activation functions here
+ # while the accumulator is still in FP32!
+ c = accumulator
+ c1 = accumulator1
+
+ # -----------------------------------------------------------
+ # Write back the block of the output matrix C with masks.
+ offs_cc1 = tl.arange(0, BLOCK_SIZE_C1)
+ offs_cc = pid_bc * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C)
+ c_ptrs = (
+ c_ptr + stride_cb * pid_b + stride_ch * pid_h + stride_cc1 * offs_cc1[:, None] + stride_cc * offs_cc[None, :]
+ )
+ c_mask = (offs_cc1[:, None] < C) & (offs_cc[None, :] < C)
+ tl.store(c_ptrs, c, mask=c_mask)
+ c1_ptrs = c_ptr + stride_cb * pid_b + stride_ch * pid_h + stride_cc1 * C + stride_cc * offs_cc
+ c1_mask = offs_cc < C
+ tl.store(c1_ptrs, c1, mask=c1_mask)
+
+
+def pad_vk_mm_fwd(a, b, compute_dtype: torch.dtype, output_dtype: torch.dtype):
+ """
+ Input:
+ v: (B, N, H, C)
+ k: (B, N, H, C)
+ Output:
+ vk: (B, H, C+1, C)
+ """
+ # Check constraints.
+ assert a.dim() == 4 and b.dim() == 4
+ assert a.shape == b.shape, "Incompatible dimensions"
+ B, N, H, C = a.shape
+ # Allocates output.
+ c = torch.empty((B, H, C + 1, C), device=a.device, dtype=output_dtype)
+ # 1D launch kernel where each block gets its own program.
+ grid = lambda META: (B * H * triton.cdiv(C, META["BLOCK_SIZE_C"]),)
+ if compute_dtype == torch.float and output_dtype == torch.float:
+ pad_vk_mm_fwd_kernel_fp32_fp32[grid](
+ a,
+ b,
+ c, #
+ B,
+ N,
+ H,
+ C, #
+ a.stride(-4),
+ a.stride(-3),
+ a.stride(-2),
+ a.stride(-1), #
+ b.stride(-4),
+ b.stride(-3),
+ b.stride(-2),
+ b.stride(-1), #
+ c.stride(-4),
+ c.stride(-3),
+ c.stride(-2),
+ c.stride(-1), #
+ BLOCK_SIZE_C1=triton.next_power_of_2(C),
+ )
+ else:
+ raise NotImplementedError()
+ # ipdb.set_trace()
+ return c
diff --git a/diffusion/model/nets/fastlinear/modules/triton_lite_mla_kernels/proj_divide_bwd.py b/diffusion/model/nets/fastlinear/modules/triton_lite_mla_kernels/proj_divide_bwd.py
new file mode 100644
index 0000000..8bf60cc
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/modules/triton_lite_mla_kernels/proj_divide_bwd.py
@@ -0,0 +1,306 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import ipdb
+import torch
+import triton
+import triton.language as tl
+
+from ..utils.custom_autotune import custom_autotune
+
+
+def get_cuda_autotune_config():
+ return [
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_H_": 8, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8},
+ num_stages=3,
+ num_warps=8,
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_H_": 8, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_H_": 4, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_H_": 2, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_H_": 4, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_H_": 1, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_H_": 1, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
+ num_stages=5,
+ num_warps=2,
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_H_": 2, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
+ num_stages=5,
+ num_warps=2,
+ ),
+ # Good config for fp8 inputs.
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_H_": 8, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
+ num_stages=3,
+ num_warps=8,
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_H_": 4, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
+ num_stages=3,
+ num_warps=8,
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_H_": 2, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_H_": 8, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_H_": 4, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_H_": 2, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_H_": 4, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_H_": 1, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8},
+ num_stages=4,
+ num_warps=4,
+ ),
+ ]
+
+
+def get_autotune_config():
+ return get_cuda_autotune_config()
+
+
+# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
+# - A list of `triton.Config` objects that define different configurations of
+# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
+# - An auto-tuning *key* whose change in values will trigger evaluation of all the
+# provided configs
+@custom_autotune(
+ configs=get_autotune_config(),
+ key=["M", "N", "K", "H_", "C_"],
+)
+@triton.jit
+def proj_divide_bwd_kernel(
+ # Pointers to matrices
+ grad_y_ptr,
+ project_weight_ptr,
+ vk_q_ptr,
+ grad_vk_q_ptr,
+ # Matrix dimensions
+ M,
+ N,
+ K,
+ H_,
+ C_,
+ # The stride variables represent how much to increase the ptr by when moving by 1
+ # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
+ # by to get the element one row down (A has M rows).
+ stride_grad_y_m,
+ stride_grad_y_k, #
+ stride_project_weight_k,
+ stride_project_weight_n, #
+ stride_vk_q_m,
+ stride_vk_q_h_,
+ stride_vk_q_c_,
+ eps,
+ # Meta-parameters
+ BLOCK_SIZE_M: tl.constexpr,
+ BLOCK_SIZE_N: tl.constexpr,
+ BLOCK_SIZE_K: tl.constexpr, #
+ GROUP_SIZE_M: tl.constexpr,
+ BLOCK_SIZE_C_: tl.constexpr,
+ BLOCK_SIZE_H_: tl.constexpr,
+):
+ """Kernel for computing the matmul C = A x B.
+ A has shape (M, K), B has shape (K, N) and C has shape (M, N)
+ """
+ # -----------------------------------------------------------
+ # Map program ids `pid` to the block of C it should compute.
+ # This is done in a grouped ordering to promote L2 data reuse.
+ # See above `L2 Cache Optimizations` section for details.
+ pid = tl.program_id(axis=0)
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_SIZE_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
+
+ # ----------------------------------------------------------
+ # Create pointers for the first blocks of A and B.
+ # We will advance this pointer as we move in the K direction
+ # and accumulate
+ # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
+ # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
+ # See above `Pointer Arithmetic` section for details
+ offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+ offs_h_ = (pid_n * BLOCK_SIZE_H_ + tl.arange(0, BLOCK_SIZE_H_)) % H_
+ offs_c_ = tl.arange(0, BLOCK_SIZE_C_)
+ offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+ # offs_hc_ = tl.reshape(offs_n, BLOCK_SIZE_H_, BLOCK_SIZE_C_)
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
+ grad_y_ptrs = grad_y_ptr + (
+ offs_m[:, None] * stride_grad_y_m + offs_k[None, :] * stride_grad_y_k
+ ) # BLOCK_SIZE_M, BLOCK_SIZE_K
+ project_weight_ptrs = project_weight_ptr + (
+ offs_n[None, :] * stride_project_weight_n + offs_k[:, None] * stride_project_weight_k
+ ) # BLOCK_SIZE_K, BLOCK_SIZE_N
+
+ # -----------------------------------------------------------
+ # Iterate to compute a block of the C matrix.
+ # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+ # of fp32 values for higher accuracy.
+ # `accumulator` will be converted back to fp16 after the loop.
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+ # Load the next block of A and B, generate a mask by checking the K dimension.
+ # If it is out of bounds, set it to 0.
+ grad_y = tl.load(grad_y_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+ project_weight = tl.load(project_weight_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0).to(
+ grad_y_ptr.dtype.element_ty
+ )
+ # We accumulate along the K dimension.
+ accumulator = tl.dot(grad_y, project_weight, accumulator)
+ # Advance the ptrs to the next K block.
+ grad_y_ptrs += BLOCK_SIZE_K * stride_grad_y_k
+ project_weight_ptrs += BLOCK_SIZE_K * stride_project_weight_k
+ grad_proj_input = accumulator.to(grad_vk_q_ptr.dtype.element_ty) # BLOCK_SIZE_M, BLOCK_SIZE_N
+ grad_proj_input = tl.reshape(
+ grad_proj_input, BLOCK_SIZE_M, BLOCK_SIZE_H_, BLOCK_SIZE_C_
+ ) # BLOCK_SIZE_M, BLOCK_SIZE_H_, C_
+
+ vk_q_numerator_ptrs = (
+ vk_q_ptr
+ + offs_m[:, None, None] * stride_vk_q_m
+ + offs_h_[None, :, None] * stride_vk_q_h_
+ + offs_c_[None, None, :] * stride_vk_q_c_
+ ) # BLOCK_SIZE_M, BLOCK_SIZE_H_, C_
+ vk_q_denominator_ptrs = (
+ vk_q_ptr
+ + offs_m[:, None, None] * stride_vk_q_m
+ + offs_h_[None, :, None] * stride_vk_q_h_
+ + BLOCK_SIZE_C_ * stride_vk_q_c_
+ ) # BLOCK_SIZE_M, BLOCK_SIZE_H_, 1
+ vk_q_numerator = tl.load(vk_q_numerator_ptrs)
+ vk_q_denominator = tl.load(vk_q_denominator_ptrs) + eps
+
+ grad_vk_q_numerator = grad_proj_input / vk_q_denominator
+ grad_vk_q_denominator = -tl.sum(grad_vk_q_numerator * vk_q_numerator, axis=2, keep_dims=True) / vk_q_denominator
+
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ offs_h_ = pid_n * BLOCK_SIZE_H_ + tl.arange(0, BLOCK_SIZE_H_)
+ grad_vk_q_numerator_ptrs = (
+ grad_vk_q_ptr
+ + offs_m[:, None, None] * stride_vk_q_m
+ + offs_h_[None, :, None] * stride_vk_q_h_
+ + offs_c_[None, None, :] * stride_vk_q_c_
+ )
+ grad_vk_q_denominator_ptrs = (
+ grad_vk_q_ptr
+ + offs_m[:, None, None] * stride_vk_q_m
+ + offs_h_[None, :, None] * stride_vk_q_h_
+ + BLOCK_SIZE_C_ * stride_vk_q_c_
+ )
+ grad_vk_q_mask = (offs_m[:, None, None] < M) & (offs_h_[None, :, None] < H_)
+ tl.store(grad_vk_q_numerator_ptrs, grad_vk_q_numerator, mask=grad_vk_q_mask)
+ tl.store(grad_vk_q_denominator_ptrs, grad_vk_q_denominator, mask=grad_vk_q_mask)
+
+
+def proj_divide_bwd(grad_y: torch.Tensor, proj_weight: torch.Tensor, vk_q: torch.Tensor, eps: float) -> torch.Tensor:
+ """
+ Input:
+ grad_y: (B, N, H*C)
+ proj_weight: (H*C, H*C)
+ vk_q: (B, N, H, C+1)
+ Output:
+ grad_vk_q: (B, N, H, C+1)
+ """
+ assert vk_q.is_contiguous() # to ensure the stride of vk_q and grad_vk_q are the same
+
+ assert grad_y.dim() == 3 and proj_weight.dim() == 2 and vk_q.dim() == 4
+ assert grad_y.shape[0] == vk_q.shape[0]
+ assert grad_y.shape[1] == vk_q.shape[1]
+ assert grad_y.shape[2] == proj_weight.shape[0] == proj_weight.shape[1] == vk_q.shape[2] * (vk_q.shape[3] - 1)
+
+ B_, N_, H_, C1_ = vk_q.shape
+ C_ = C1_ - 1
+ assert C_ == 32, "currently only support C=32, to ensure reduction for C in each thread"
+
+ M, K, N = B_ * N_, H_ * C_, H_ * C_
+
+ # Allocates output.
+ grad_vk_q = torch.empty_like(vk_q)
+ # 1D launch kernel where each block gets its own program.
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),)
+ proj_divide_bwd_kernel[grid](
+ grad_y,
+ proj_weight,
+ vk_q,
+ grad_vk_q, #
+ M,
+ N,
+ K,
+ H_,
+ C_, #
+ grad_y.stride(1),
+ grad_y.stride(2), #
+ proj_weight.stride(0),
+ proj_weight.stride(1), #
+ grad_vk_q.stride(1),
+ grad_vk_q.stride(2),
+ grad_vk_q.stride(3), #
+ eps,
+ BLOCK_SIZE_C_=C_,
+ )
+
+ # ref_grad_proj_input = grad_y@proj_weight
+ # ref_grad_vk_q_numerator = ref_grad_proj_input.view(B_, N_, H_, C_)/(vk_q[:, :, :, -1:]+eps)
+ # ref_grad_vk_q_denominator = -(ref_grad_proj_input.view(B_, N_, H_, C_)*vk_q[:, :, :, :-1]).sum(-1, keepdim=True)/(vk_q[:, :, :, -1:]+eps)**2
+ # ref_grad_vk_q = torch.cat([ref_grad_vk_q_numerator, ref_grad_vk_q_denominator], dim=-1)
+ # ipdb.set_trace()
+
+ return grad_vk_q
diff --git a/diffusion/model/nets/fastlinear/modules/triton_lite_mla_kernels/vk_mm_relu_bwd.py b/diffusion/model/nets/fastlinear/modules/triton_lite_mla_kernels/vk_mm_relu_bwd.py
new file mode 100644
index 0000000..76ff50c
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/modules/triton_lite_mla_kernels/vk_mm_relu_bwd.py
@@ -0,0 +1,205 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import ipdb
+import torch
+import triton
+import triton.language as tl
+
+from ..utils.custom_autotune import custom_autotune
+
+
+def get_cuda_autotune_config():
+ return [
+ triton.Config({"BLOCK_SIZE_N": 256}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_SIZE_N": 256}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_N": 128}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_SIZE_N": 128}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_N": 64}, num_stages=5, num_warps=2),
+ triton.Config({"BLOCK_SIZE_N": 64}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_N": 32}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_N": 32}, num_stages=5, num_warps=2),
+ ]
+
+
+def get_autotune_config():
+ return get_cuda_autotune_config()
+
+
+# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
+# - A list of `triton.Config` objects that define different configurations of
+# meta-parameters (e.g., `BLOCK_SIZE_C1`) and compilation options (e.g., `num_warps`) to try
+# - An auto-tuning *key* whose change in values will trigger evaluation of all the
+# provided configs
+@custom_autotune(
+ configs=get_autotune_config(),
+ key=["B", "N", "H", "C"],
+)
+@triton.jit
+def vk_mm_relu_bwd_kernel(
+ # Pointers to matrices
+ grad_vk_ptr,
+ k_ptr,
+ v_ptr,
+ k_relu_mask_ptr,
+ grad_k_ptr,
+ grad_v_ptr, #
+ # Matrix dimensions
+ B,
+ N,
+ H,
+ C,
+ # The stride variables represent how much to increase the ptr by when moving by 1
+ # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
+ # by to get the element one row down (A has M rows).
+ stride_vk_b,
+ stride_vk_h,
+ stride_vk_c1,
+ stride_vk_c,
+ stride_k_b,
+ stride_k_n,
+ stride_k_h,
+ stride_k_c,
+ stride_grad_k_b,
+ stride_grad_k_n,
+ stride_grad_k_h,
+ stride_grad_k_c,
+ # Meta-parameters
+ BLOCK_SIZE_C: tl.constexpr,
+ BLOCK_SIZE_N: tl.constexpr, #
+):
+ """
+ Input:
+ grad_vk: (B, H, C+1, C), fp32
+ k: (B, N, H, C), fp16
+ v: (B, N, H, C), fp16
+ k_relu_mask: (B, N, H, C), bool
+ Output:
+ grad_k: (B, N, H, C), fp16
+ grad_v: (B, N, H, C), fp16
+ """
+ # -----------------------------------------------------------
+ # Map program ids `pid` to the block of C it should compute.
+ # This is done in a grouped ordering to promote L2 data reuse.
+ # See above `L2 Cache Optimizations` section for details.
+ pid = tl.program_id(axis=0)
+ pid_b, pid_h = pid // H, pid % H
+
+ offs_c = tl.arange(0, BLOCK_SIZE_C)
+ c_mask = offs_c < C
+ offs_n = tl.arange(0, BLOCK_SIZE_N)
+
+ grad_vk_ptrs = (
+ grad_vk_ptr
+ + pid_b * stride_vk_b
+ + pid_h * stride_vk_h
+ + offs_c[:, None] * stride_vk_c1
+ + offs_c[None, :] * stride_vk_c
+ ) # Cv, Ck
+ grad_vk = tl.load(grad_vk_ptrs, mask=c_mask[:, None] & c_mask[None, :], other=0.0) # Cv, Ck
+ grad_vk_last_row_ptrs = (
+ grad_vk_ptr + pid_b * stride_vk_b + pid_h * stride_vk_h + C * stride_vk_c1 + offs_c * stride_vk_c
+ ) # Ck
+ grad_vk_last_row = tl.load(grad_vk_last_row_ptrs, mask=c_mask, other=0.0) # Ck
+ k_offs = (
+ pid_b * stride_k_b + pid_h * stride_k_h + offs_n[:, None] * stride_k_n + offs_c[None, :] * stride_k_c
+ ) # n, C
+ grad_k_offs = (
+ pid_b * stride_grad_k_b
+ + pid_h * stride_grad_k_h
+ + offs_n[:, None] * stride_grad_k_n
+ + offs_c[None, :] * stride_grad_k_c
+ ) # n, C
+
+ for n in range(0, tl.cdiv(N, BLOCK_SIZE_N)):
+ n_mask = offs_n < N - n * BLOCK_SIZE_N
+ nc_mask = n_mask[:, None] & c_mask[None, :]
+
+ k = tl.load(k_ptr + k_offs, mask=nc_mask, other=0.0).to(tl.float32) # n, Ck
+ grad_v = tl.dot(k, tl.trans(grad_vk)).to(grad_v_ptr.dtype.element_ty) # n, Cv
+ tl.store(grad_v_ptr + grad_k_offs, grad_v, mask=nc_mask)
+
+ v = tl.load(v_ptr + k_offs, mask=nc_mask, other=0.0).to(tl.float32) # n, Ck
+ grad_k = tl.dot(v, grad_vk) + grad_vk_last_row # n, Ck
+ k_relu_mask = tl.load(k_relu_mask_ptr + k_offs, mask=nc_mask, other=0) # n, Ck
+ grad_k = tl.where(k_relu_mask, grad_k, 0).to(grad_k_ptr.dtype.element_ty) # n, Ck
+ tl.store(grad_k_ptr + grad_k_offs, grad_k, mask=nc_mask)
+
+ k_offs += BLOCK_SIZE_N * stride_k_n
+ grad_k_offs += BLOCK_SIZE_N * stride_grad_k_n
+
+
+def vk_mm_relu_bwd(
+ grad_vk: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ k_relu_mask: torch.Tensor,
+ grad_k: torch.Tensor,
+ grad_v: torch.Tensor,
+) -> None:
+ """
+ Input:
+ grad_vk: (B, H, C+1, C), fp32
+ k: (B, N, H, C), fp16
+ v: (B, N, H, C), fp16
+ k_relu_mask: (B, N, H, C), bool
+ grad_k: (B, N, H, C), fp16
+ grad_v: (B, N, H, C), fp16
+ """
+
+ # ref_grad_v = (grad_vk@k.float().permute(0, 2, 3, 1)).permute(0, 3, 1, 2)[:, :, :, :-1]
+ # ref_grad_k = ((v.float().permute(0, 2, 1, 3)@grad_vk[:, :, :-1])+grad_vk[:, :, -1:]).permute(0, 2, 1, 3)
+ # ref_grad_k.mul_(k_relu_mask)
+ # return ref_grad_k, ref_grad_v
+
+ assert grad_vk.dim() == 4 and k.dim() == 4 and v.dim() == 4 and k_relu_mask.dim() == 4
+ assert k.shape == v.shape == k_relu_mask.shape
+ assert grad_vk.shape[0] == k.shape[0] # B
+ assert grad_vk.shape[1] == k.shape[2] # N
+ assert grad_vk.shape[2] - 1 == grad_vk.shape[3] == k.shape[3] # C
+
+ assert k.stride() == v.stride() == k_relu_mask.stride()
+
+ B, N, H, C = k.shape
+ # 1D launch kernel where each block gets its own program.
+ grid = lambda META: (B * H,)
+ vk_mm_relu_bwd_kernel[grid](
+ grad_vk,
+ k,
+ v,
+ k_relu_mask,
+ grad_k,
+ grad_v, #
+ B,
+ N,
+ H,
+ C, #
+ grad_vk.stride(0),
+ grad_vk.stride(1),
+ grad_vk.stride(2),
+ grad_vk.stride(3), #
+ k.stride(0),
+ k.stride(1),
+ k.stride(2),
+ k.stride(3), #
+ grad_k.stride(0),
+ grad_k.stride(1),
+ grad_k.stride(2),
+ grad_k.stride(3), #
+ BLOCK_SIZE_C=triton.next_power_of_2(C),
+ )
+
+ # ipdb.set_trace()
diff --git a/diffusion/model/nets/fastlinear/modules/triton_lite_mla_kernels/vk_q_mm_divide_fwd.py b/diffusion/model/nets/fastlinear/modules/triton_lite_mla_kernels/vk_q_mm_divide_fwd.py
new file mode 100644
index 0000000..27ecd97
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/modules/triton_lite_mla_kernels/vk_q_mm_divide_fwd.py
@@ -0,0 +1,221 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import ipdb
+import torch
+import triton
+import triton.language as tl
+
+from ..utils.custom_autotune import custom_autotune
+
+
+def get_cuda_autotune_config():
+ return [
+ triton.Config({"BLOCK_SIZE_N": 256, "BLOCK_SIZE_D": 64}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_SIZE_N": 256, "BLOCK_SIZE_D": 32}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_D": 32}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_D": 32}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_D": 32}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_D": 32}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_D": 32}, num_stages=5, num_warps=2),
+ triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_D": 32}, num_stages=5, num_warps=2),
+ # Good config for fp8 inputs.
+ triton.Config({"BLOCK_SIZE_N": 256, "BLOCK_SIZE_D": 128}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_D": 128}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_D": 128}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_N": 256, "BLOCK_SIZE_D": 128}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_D": 128}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_D": 64}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_D": 64}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_D": 64}, num_stages=4, num_warps=4),
+ ]
+
+
+def get_autotune_config():
+ return get_cuda_autotune_config()
+
+
+# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
+# - A list of `triton.Config` objects that define different configurations of
+# meta-parameters (e.g., `BLOCK_SIZE_C1`) and compilation options (e.g., `num_warps`) to try
+# - An auto-tuning *key* whose change in values will trigger evaluation of all the
+# provided configs
+@custom_autotune(
+ configs=get_autotune_config(),
+ key=["B", "N", "H", "D"],
+)
+@triton.jit
+def vk_q_mm_divide_fwd_kernel_fp32(
+ # Pointers to matrices
+ a_ptr,
+ b_ptr,
+ c_ptr,
+ c_mid_ptr,
+ # Matrix dimensions
+ B,
+ N,
+ H,
+ D,
+ # The stride variables represent how much to increase the ptr by when moving by 1
+ # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
+ # by to get the element one row down (A has M rows).
+ stride_ab,
+ stride_ah,
+ stride_ac1,
+ stride_ad,
+ stride_bb,
+ stride_bn,
+ stride_bh,
+ stride_bd,
+ stride_cb,
+ stride_cn,
+ stride_ch,
+ stride_cc,
+ stride_cmidb,
+ stride_cmidn,
+ stride_cmidh,
+ stride_cmidc,
+ eps,
+ # Meta-parameters
+ BLOCK_SIZE_C1: tl.constexpr,
+ BLOCK_SIZE_N: tl.constexpr,
+ BLOCK_SIZE_D: tl.constexpr, #
+):
+ # -----------------------------------------------------------
+ # Map program ids `pid` to the block of C it should compute.
+ # This is done in a grouped ordering to promote L2 data reuse.
+ # See above `L2 Cache Optimizations` section for details.
+ pid = tl.program_id(axis=0)
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+ pid_b, pid_h, pid_n = pid // num_pid_n // H, (pid // num_pid_n) % H, pid % num_pid_n
+
+ # ----------------------------------------------------------
+ # Create pointers for the first blocks of A and B.
+ # We will advance this pointer as we move in the K direction
+ # and accumulate
+ # `a_ptrs` is a block of [BLOCK_SIZE_C1, BLOCK_SIZE_D] pointers
+ # `b_ptrs` is a block of [BLOCK_SIZE_D, BLOCK_SIZE_N] pointers
+ # See above `Pointer Arithmetic` section for details
+ offs_ac = tl.arange(0, BLOCK_SIZE_C1) % D
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+ offs_d = tl.arange(0, BLOCK_SIZE_D)
+ a_ptrs = a_ptr + (
+ pid_b * stride_ab + pid_h * stride_ah + offs_ac[:, None] * stride_ac1 + offs_d[None, :] * stride_ad
+ )
+ a1_ptrs = a_ptr + (pid_b * stride_ab + pid_h * stride_ah + D * stride_ac1 + offs_d[:, None] * stride_ad)
+ b_ptrs = b_ptr + (
+ pid_b * stride_bb + pid_h * stride_bh + offs_d[:, None] * stride_bd + offs_bn[None, :] * stride_bn
+ )
+ # -----------------------------------------------------------
+ # Iterate to compute a block of the C matrix.
+ # We accumulate into a `[BLOCK_SIZE_C1, BLOCK_SIZE_N]` block
+ # of fp32 values for higher accuracy.
+ # `accumulator` will be converted back to fp16 after the loop.
+ accumulator = tl.zeros((BLOCK_SIZE_C1, BLOCK_SIZE_N), dtype=tl.float32)
+ accumulator1 = tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32)
+ for d in range(0, tl.cdiv(D, BLOCK_SIZE_D)):
+ # Load the next block of A and B, generate a mask by checking the K dimension.
+ # If it is out of bounds, set it to 0.
+ a = tl.load(a_ptrs, mask=offs_d[None, :] < D - d * BLOCK_SIZE_D, other=0.0).to(tl.float32)
+ a1 = tl.load(a1_ptrs, mask=offs_d[:, None] < D - d * BLOCK_SIZE_D, other=0.0).to(tl.float32)
+ b = tl.load(b_ptrs, mask=offs_d[:, None] < D - d * BLOCK_SIZE_D, other=0.0).to(tl.float32)
+ # We accumulate along the K dimension.
+ accumulator = tl.dot(a, b, accumulator)
+ accumulator1 += tl.sum(a1 * b, axis=0)
+ # Advance the ptrs to the next K block.
+ a_ptrs += BLOCK_SIZE_D * stride_ad
+ a1_ptrs += BLOCK_SIZE_D * stride_ad
+ b_ptrs += BLOCK_SIZE_D * stride_bd
+ # You can fuse arbitrary activation functions here
+ # while the accumulator is still in FP32!
+ c = (accumulator / (accumulator1 + eps)).to(c_ptr.dtype.element_ty)
+
+ # -----------------------------------------------------------
+ # Write back the block of the output matrix C with masks.
+ offs_cc = tl.arange(0, BLOCK_SIZE_C1)
+ offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+ c_ptrs = c_ptr + stride_cb * pid_b + stride_ch * pid_h + stride_cc * offs_cc[:, None] + stride_cn * offs_cn[None, :]
+ c_mask = (offs_cc[:, None] < D) & (offs_cn[None, :] < N)
+ tl.store(c_ptrs, c, mask=c_mask)
+
+ c_mid_ptrs = (
+ c_mid_ptr
+ + stride_cmidb * pid_b
+ + stride_cmidh * pid_h
+ + stride_cmidc * offs_cc[:, None]
+ + stride_cmidn * offs_cn[None, :]
+ )
+ tl.store(c_mid_ptrs, accumulator, mask=c_mask)
+ c_mid_ptrs_lastrow = (
+ c_mid_ptr + stride_cmidb * pid_b + stride_cmidh * pid_h + stride_cmidc * D + stride_cmidn * offs_cn
+ )
+ tl.store(c_mid_ptrs_lastrow, accumulator1, mask=offs_cn < N)
+
+
+def vk_q_mm_divide_fwd(
+ a: torch.Tensor, b: torch.Tensor, eps: float, compute_dtype: torch.dtype, output_dtype: torch.dtype
+) -> torch.Tensor:
+ """
+ a: (B, H, C+1, D) # C=D
+ b: (B, N, H, D)
+ """
+ # Check constraints.
+ assert a.dim() == 4 and b.dim() == 4
+ assert (
+ a.shape[0] == b.shape[0]
+ and a.shape[1] == b.shape[2]
+ and a.shape[3] == b.shape[3]
+ and a.shape[2] == a.shape[3] + 1
+ )
+
+ B, N, H, D = b.shape
+ # Allocates output.
+ c_mid = torch.empty((B, N, H, D + 1), device=a.device, dtype=compute_dtype)
+ c = torch.empty((B, N, H, D), device=a.device, dtype=output_dtype)
+ # 1D launch kernel where each block gets its own program.
+ grid = lambda META: (B * H * triton.cdiv(N, META["BLOCK_SIZE_N"]),)
+ if compute_dtype == torch.float:
+ vk_q_mm_divide_fwd_kernel_fp32[grid](
+ a,
+ b,
+ c,
+ c_mid, #
+ B,
+ N,
+ H,
+ D, #
+ a.stride(0),
+ a.stride(1),
+ a.stride(2),
+ a.stride(3), #
+ b.stride(0),
+ b.stride(1),
+ b.stride(2),
+ b.stride(3), #
+ c.stride(0),
+ c.stride(1),
+ c.stride(2),
+ c.stride(3), #
+ c_mid.stride(0),
+ c_mid.stride(1),
+ c_mid.stride(2),
+ c_mid.stride(3), #
+ eps,
+ BLOCK_SIZE_C1=triton.next_power_of_2(D),
+ )
+ else:
+ raise NotImplementedError()
+ return c, c_mid
diff --git a/diffusion/model/nets/fastlinear/modules/triton_lite_mla_kernels/vk_q_mm_relu_bwd.py b/diffusion/model/nets/fastlinear/modules/triton_lite_mla_kernels/vk_q_mm_relu_bwd.py
new file mode 100644
index 0000000..b707b0e
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/modules/triton_lite_mla_kernels/vk_q_mm_relu_bwd.py
@@ -0,0 +1,216 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import ipdb
+import torch
+import triton
+import triton.language as tl
+
+from ..utils.custom_autotune import custom_autotune
+
+
+def get_cuda_autotune_config():
+ return [
+ triton.Config({"BLOCK_SIZE_N": 256}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_SIZE_N": 256}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_N": 128}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_SIZE_N": 128}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_N": 64}, num_stages=5, num_warps=2),
+ triton.Config({"BLOCK_SIZE_N": 64}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_N": 32}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_N": 32}, num_stages=5, num_warps=2),
+ ]
+
+
+def get_autotune_config():
+ return get_cuda_autotune_config()
+
+
+# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
+# - A list of `triton.Config` objects that define different configurations of
+# meta-parameters (e.g., `BLOCK_SIZE_C1`) and compilation options (e.g., `num_warps`) to try
+# - An auto-tuning *key* whose change in values will trigger evaluation of all the
+# provided configs
+@custom_autotune(
+ configs=get_autotune_config(),
+ key=["B", "N", "H", "C"],
+)
+@triton.jit
+def vk_q_mm_relu_bwd_kernel(
+ # Pointers to matrices
+ grad_vk_q_ptr,
+ vk_ptr,
+ q_ptr,
+ q_relu_mask_ptr,
+ grad_vk_ptr,
+ grad_q_ptr, #
+ # Matrix dimensions
+ B,
+ N,
+ H,
+ C,
+ # The stride variables represent how much to increase the ptr by when moving by 1
+ # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
+ # by to get the element one row down (A has M rows).
+ stride_vk_q_b,
+ stride_vk_q_n,
+ stride_vk_q_h,
+ stride_vk_q_c1,
+ stride_vk_b,
+ stride_vk_h,
+ stride_vk_c1,
+ stride_vk_c,
+ stride_q_b,
+ stride_q_n,
+ stride_q_h,
+ stride_q_c,
+ stride_grad_q_b,
+ stride_grad_q_n,
+ stride_grad_q_h,
+ stride_grad_q_c,
+ # Meta-parameters
+ BLOCK_SIZE_C: tl.constexpr,
+ BLOCK_SIZE_C1: tl.constexpr,
+ BLOCK_SIZE_N: tl.constexpr, #
+):
+ """
+ Input:
+ grad_vk_q: (B, N, H, C+1), fp32
+ vk: (B, H, C+1, C), fp32
+ q: (B, N, H, C), fp16
+ q_relu_mask: (B, N, H, C), bool
+ Output:
+ grad_vk: (B, H, C+1, C), fp32
+ grad_q: (B, N, H, C), fp16
+ """
+ # -----------------------------------------------------------
+ # Map program ids `pid` to the block of C it should compute.
+ # This is done in a grouped ordering to promote L2 data reuse.
+ # See above `L2 Cache Optimizations` section for details.
+ pid = tl.program_id(axis=0)
+ pid_b, pid_h = pid // H, pid % H
+
+ offs_c = tl.arange(0, BLOCK_SIZE_C)
+ c_mask = offs_c < C
+ offs_c1 = tl.arange(0, BLOCK_SIZE_C1)
+ c1_mask = offs_c1 < C + 1
+ offs_n = tl.arange(0, BLOCK_SIZE_N)
+ # n_mask = offs_n < N
+ grad_vk_q_ptrs = (
+ grad_vk_q_ptr
+ + pid_b * stride_vk_q_b
+ + pid_h * stride_vk_q_h
+ + offs_n[:, None] * stride_vk_q_n
+ + offs_c1[None, :] * stride_vk_q_c1
+ ) # n, C1
+ vk_offs = (
+ pid_b * stride_vk_b + pid_h * stride_vk_h + offs_c1[:, None] * stride_vk_c1 + offs_c[None, :] * stride_vk_c
+ ) # C1, C
+ q_offs = (
+ pid_b * stride_q_b + pid_h * stride_q_h + offs_c[:, None] * stride_q_c + offs_n[None, :] * stride_q_n
+ ) # C, n
+ grad_q_offs = (
+ pid_b * stride_grad_q_b
+ + pid_h * stride_grad_q_h
+ + offs_c[:, None] * stride_grad_q_c
+ + offs_n[None, :] * stride_grad_q_n
+ ) # C, n
+
+ vk = tl.load(vk_ptr + vk_offs, mask=c1_mask[:, None] & c_mask[None, :], other=0.0) # C1, C
+ grad_vk = tl.zeros((BLOCK_SIZE_C, BLOCK_SIZE_C1), dtype=tl.float32)
+ for n in range(0, tl.cdiv(N, BLOCK_SIZE_N)):
+ n_mask = offs_n < N - n * BLOCK_SIZE_N
+
+ grad_vk_q = tl.load(grad_vk_q_ptrs, mask=n_mask[:, None] & c1_mask[None, :], other=0.0) # n, C1
+ q = tl.load(q_ptr + q_offs, mask=c_mask[:, None] & n_mask[None, :], other=0.0).to(tl.float32) # C, n
+ q_relu_mask = tl.load(q_relu_mask_ptr + q_offs, mask=c_mask[:, None] & n_mask[None, :], other=0) # C, n
+
+ grad_q = tl.trans(tl.dot(grad_vk_q, vk)) # n, C -> C, n
+ grad_q = tl.where(q_relu_mask, grad_q, 0).to(grad_q_ptr.dtype.element_ty) # C, n
+ grad_vk = tl.dot(q, grad_vk_q, grad_vk)
+
+ tl.store(grad_q_ptr + grad_q_offs, grad_q, mask=c_mask[:, None] & n_mask[None, :])
+
+ grad_vk_q_ptrs += BLOCK_SIZE_N * stride_vk_q_n
+ q_offs += BLOCK_SIZE_N * stride_q_n
+ grad_q_offs += BLOCK_SIZE_N * stride_grad_q_n
+
+ tl.store(grad_vk_ptr + vk_offs, tl.trans(grad_vk), mask=c1_mask[:, None] & c_mask[None, :])
+
+
+def vk_q_mm_relu_bwd(
+ grad_vk_q: torch.Tensor, vk: torch.Tensor, q: torch.Tensor, q_relu_mask: torch.Tensor, grad_q: torch.Tensor
+) -> torch.Tensor:
+ """
+ Input:
+ grad_vk_q: (B, N, H, C+1), fp32
+ vk: (B, H, C+1, C), fp32
+ q: (B, N, H, C), fp16
+ q_relu_mask: (B, N, H, C), bool
+ grad_q: (B, N, H, C), fp16
+ Output:
+ grad_vk: (B, H, C+1, C), fp32
+ """
+
+ assert grad_vk_q.dim() == 4 and vk.dim() == 4 and q.dim() == 4 and q_relu_mask.dim() == 4
+ assert q.shape == q_relu_mask.shape
+ assert grad_vk_q.shape[0] == vk.shape[0] == q.shape[0] # B
+ assert grad_vk_q.shape[1] == q.shape[1] # N
+ assert grad_vk_q.shape[2] == vk.shape[1] == q.shape[2] # N
+ assert grad_vk_q.shape[3] - 1 == vk.shape[2] - 1 == vk.shape[3] == q.shape[3] # C
+
+ B, N, H, C = q.shape
+ # Allocates output.
+ grad_vk = torch.empty_like(vk)
+
+ # 1D launch kernel where each block gets its own program.
+ grid = lambda META: (B * H,)
+ vk_q_mm_relu_bwd_kernel[grid](
+ grad_vk_q,
+ vk,
+ q,
+ q_relu_mask,
+ grad_vk,
+ grad_q, #
+ B,
+ N,
+ H,
+ C, #
+ grad_vk_q.stride(0),
+ grad_vk_q.stride(1),
+ grad_vk_q.stride(2),
+ grad_vk_q.stride(3), #
+ vk.stride(0),
+ vk.stride(1),
+ vk.stride(2),
+ vk.stride(3), #
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ q.stride(3), #
+ grad_q.stride(0),
+ grad_q.stride(1),
+ grad_q.stride(2),
+ grad_q.stride(3), #
+ BLOCK_SIZE_C=triton.next_power_of_2(C),
+ BLOCK_SIZE_C1=triton.next_power_of_2(C + 1),
+ )
+
+ # ref_grad_q = (grad_vk_q.permute(0, 2, 1, 3)@vk).permute(0, 2, 1, 3)
+ # ref_grad_vk = (grad_vk_q.permute(0, 2, 3, 1)@q.float().permute(0, 2, 1, 3))
+ # ref_grad_q.mul_(q_relu_mask)
+ # ipdb.set_trace()
+ return grad_vk
diff --git a/diffusion/model/nets/fastlinear/modules/triton_mb_conv_pre_glu.py b/diffusion/model/nets/fastlinear/modules/triton_mb_conv_pre_glu.py
new file mode 100644
index 0000000..8c8f024
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/modules/triton_mb_conv_pre_glu.py
@@ -0,0 +1,127 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+from torch import nn
+
+from .nn.act import build_act, get_act_name
+from .nn.conv import ConvLayer
+from .nn.norm import build_norm, get_norm_name
+from .triton_mb_conv_pre_glu_kernels.depthwise_conv_fwd import depthwise_conv_fwd
+from .triton_mb_conv_pre_glu_kernels.linear_glu_fwd import linear_glu_fwd
+from .utils.model import get_same_padding, val2tuple
+
+
+class TritonMBConvPreGLU(nn.Module):
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ kernel_size=3,
+ stride=1,
+ mid_dim=None,
+ expand=6,
+ padding: int or None = None,
+ use_bias=False,
+ norm=(None, None, "ln2d"),
+ act=("silu", "silu", None),
+ ):
+ super().__init__()
+ use_bias = val2tuple(use_bias, 3)
+ norm = val2tuple(norm, 3)
+ act = val2tuple(act, 3)
+
+ mid_dim = mid_dim or round(in_dim * expand)
+
+ assert (
+ use_bias == (True, True, False)
+ and norm == (None, None, None)
+ and act == ("silu", "silu", None)
+ and stride == 1
+ and padding is None
+ )
+
+ self.inverted_conv = ConvLayer(
+ in_dim,
+ mid_dim * 2,
+ 1,
+ use_bias=use_bias[0],
+ norm=norm[0],
+ act=None,
+ )
+ self.glu_act = build_act(act[0], inplace=False)
+ self.depth_conv = ConvLayer(
+ mid_dim,
+ mid_dim,
+ kernel_size,
+ stride=stride,
+ groups=mid_dim,
+ padding=padding,
+ use_bias=use_bias[1],
+ norm=norm[1],
+ act=act[1],
+ )
+ self.point_conv = ConvLayer(
+ mid_dim,
+ out_dim,
+ 1,
+ use_bias=use_bias[2],
+ norm=norm[2],
+ act=act[2],
+ )
+
+ def forward(self, x: torch.Tensor, HW=None) -> torch.Tensor:
+ C = x.shape[2]
+ # x = self.inverted_conv(x)
+ # x, gate = torch.chunk(x, 2, dim=1)
+ # gate = self.glu_act(gate)
+ # x = x * gate
+ x = linear_glu_fwd(x, self.inverted_conv.conv.weight[:, :, 0, 0], self.inverted_conv.conv.bias)
+
+ B, N, D = x.shape
+ if HW is None:
+ H = W = int(N**0.5)
+ else:
+ H, W = HW
+
+ x = x.reshape(B, H, W, D)
+ # x = depthwise_conv_fwd(x, self.depth_conv.conv.weight[:, 0], self.depth_conv.conv.bias)
+ # x = self.depth_conv.act(x)
+
+ x = x.permute(0, 3, 1, 2)
+
+ x = self.depth_conv(x)
+ x = self.point_conv(x)
+
+ x = x.reshape(B, C, N).permute(0, 2, 1)
+ return x
+
+ @property
+ def module_str(self) -> str:
+ _str = f"{self.depth_conv.kernel_size}{type(self).__name__}("
+ _str += f"in={self.inverted_conv.in_dim},mid={self.depth_conv.in_dim},out={self.point_conv.out_dim},s={self.depth_conv.stride}"
+ _str += (
+ f",norm={get_norm_name(self.inverted_conv.norm)}"
+ f"+{get_norm_name(self.depth_conv.norm)}"
+ f"+{get_norm_name(self.point_conv.norm)}"
+ )
+ _str += (
+ f",act={get_act_name(self.inverted_conv.act)}"
+ f"+{get_act_name(self.depth_conv.act)}"
+ f"+{get_act_name(self.point_conv.act)}"
+ )
+ _str += f",glu_act={get_act_name(self.glu_act)})"
+ return _str
diff --git a/diffusion/model/nets/fastlinear/modules/triton_mb_conv_pre_glu_kernels/depthwise_conv_fwd.py b/diffusion/model/nets/fastlinear/modules/triton_mb_conv_pre_glu_kernels/depthwise_conv_fwd.py
new file mode 100644
index 0000000..3c81d30
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/modules/triton_mb_conv_pre_glu_kernels/depthwise_conv_fwd.py
@@ -0,0 +1,204 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import ipdb
+import torch
+import triton
+import triton.language as tl
+
+# from ..utils.custom_autotune import custom_autotune
+
+
+def get_cuda_autotune_config():
+ return [
+ triton.Config({"BLOCK_SIZE_H": 128, "BLOCK_SIZE_W": 256}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_SIZE_H": 64, "BLOCK_SIZE_W": 256}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_H": 128, "BLOCK_SIZE_W": 128}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_H": 128, "BLOCK_SIZE_W": 64}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_H": 64, "BLOCK_SIZE_W": 128}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_H": 128, "BLOCK_SIZE_W": 32}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_H": 64, "BLOCK_SIZE_W": 32}, num_stages=5, num_warps=2),
+ triton.Config({"BLOCK_SIZE_H": 32, "BLOCK_SIZE_W": 64}, num_stages=5, num_warps=2),
+ # Good config for fp8 inputs.
+ triton.Config({"BLOCK_SIZE_H": 128, "BLOCK_SIZE_W": 256}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_SIZE_H": 256, "BLOCK_SIZE_W": 128}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_SIZE_H": 256, "BLOCK_SIZE_W": 64}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_H": 64, "BLOCK_SIZE_W": 256}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_H": 128, "BLOCK_SIZE_W": 128}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_H": 128, "BLOCK_SIZE_W": 64}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_H": 64, "BLOCK_SIZE_W": 128}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_SIZE_H": 128, "BLOCK_SIZE_W": 32}, num_stages=4, num_warps=4),
+ ]
+
+
+def get_autotune_config():
+ return get_cuda_autotune_config()
+
+
+# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
+# - A list of `triton.Config` objects that define different configurations of
+# meta-parameters (e.g., `BLOCK_SIZE_H`) and compilation options (e.g., `num_warps`) to try
+# - An auto-tuning *key* whose change in values will trigger evaluation of all the
+# provided configs
+@triton.autotune(
+ configs=get_autotune_config(),
+ key=["B", "H", "W", "C", "K"],
+)
+@triton.jit
+def depthwise_conv_fwd_kernel(
+ # Pointers to matrices
+ x_ptr,
+ weight_ptr,
+ bias_ptr,
+ y_ptr,
+ # Matrix dimensions
+ B,
+ H,
+ W,
+ C,
+ K,
+ # The stride variables represent how much to increase the ptr by when moving by 1
+ # element in a particular dimension. E.g. `stride_x_m` is how much to increase `x_ptr`
+ # by to get the element one row down (A has M rows).
+ stride_x_b,
+ stride_x_h,
+ stride_x_w,
+ stride_x_c, #
+ stride_weight_c,
+ stride_weight_k1,
+ stride_weight_k2, #
+ stride_bias_c,
+ # Meta-parameters
+ BLOCK_SIZE_H: tl.constexpr,
+ BLOCK_SIZE_W: tl.constexpr, #
+):
+ """
+ Input:
+ x: (B, H, W, C)
+ weight: (C, K, K)
+ bias: (C,)
+ Output:
+ y: (B, H, W, C)
+ """
+ # -----------------------------------------------------------
+ # Map program ids `pid` to the block of C it should compute.
+ # This is done in a grouped ordering to promote L2 data reuse.
+ # See above `L2 Cache Optimizations` section for details.
+ pid = tl.program_id(axis=0)
+ num_pid_h = tl.cdiv(H, BLOCK_SIZE_H)
+ num_pid_w = tl.cdiv(W, BLOCK_SIZE_W)
+ pid_bc, pid_hw = pid // (num_pid_h * num_pid_w), pid % (num_pid_h * num_pid_w)
+ pid_b, pid_c, pid_h, pid_w = pid_bc // C, pid_bc % C, pid_hw // num_pid_w, pid_hw % num_pid_w
+
+ offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
+ offs_w = pid_w * BLOCK_SIZE_W + tl.arange(0, BLOCK_SIZE_W)
+
+ offs_xy = (
+ pid_b * stride_x_b + offs_h[:, None] * stride_x_h + offs_w[None, :] * stride_x_w + pid_c * stride_x_c
+ ) # BLOCK_SIZE_H, BLOCK_SIZE_W
+
+ K_2 = K // 2
+ accumulator = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_W), dtype=tl.float32)
+ for kh in range(-K_2, K_2 + 1):
+ mask_h = (offs_h >= -kh) & (offs_h < H - kh)
+ for kw in range(-K_2, K_2 + 1):
+ mask_w = (offs_w >= -kw) & (offs_w < W - kw)
+ weight = tl.load(
+ weight_ptr + pid_c * stride_weight_c + (kh + K_2) * stride_weight_k1 + (kw + K_2) * stride_weight_k2
+ )
+ x = tl.load(
+ x_ptr + offs_xy + kh * stride_x_h + kw * stride_x_w, mask=mask_h[:, None] & mask_w[None, :], other=0.0
+ )
+ accumulator += weight * x
+ bias = tl.load(bias_ptr + pid_c * stride_bias_c)
+ y = (accumulator + bias).to(y_ptr.dtype.element_ty)
+
+ # -----------------------------------------------------------
+ # Write back the block of the output matrix C with masks.
+ y_mask = (offs_h[:, None] < H) & (offs_w[None, :] < W)
+ tl.store(y_ptr + offs_xy, y, mask=y_mask)
+
+
+def depthwise_conv_fwd(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
+ """
+ Input:
+ x: (B, H, W, C)
+ weight: (C, K, K)
+ bias: (C,)
+ Output:
+ y: (B, H, W, C)
+ """
+ # ipdb.set_trace()
+ assert x.dim() == 4 and weight.dim() == 3 and bias.dim() == 1
+ assert x.shape[-1] == weight.shape[0] == bias.shape[0] # C
+ assert weight.shape[1] == weight.shape[2] # K
+ B, H, W, C = x.shape
+ K = weight.shape[1]
+
+ # Allocates output.
+ y = torch.empty_like(x)
+ # 1D launch kernel where each block gets its own program.
+ grid = lambda META: (B * C * triton.cdiv(H, META["BLOCK_SIZE_H"]) * triton.cdiv(W, META["BLOCK_SIZE_W"]),)
+ if x.dtype == weight.dtype == bias.dtype:
+ depthwise_conv_fwd_kernel[grid](
+ x,
+ weight,
+ bias,
+ y, #
+ B,
+ H,
+ W,
+ C,
+ K, #
+ x.stride(0),
+ x.stride(1),
+ x.stride(2),
+ x.stride(3), #
+ weight.stride(0),
+ weight.stride(1),
+ weight.stride(2), #
+ bias.stride(0),
+ )
+ else:
+ raise NotImplementedError(f"data type {x.dtype} {weight.dtype} {bias.dtype} is not support")
+ return y
+
+
+def debug():
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+ torch.cuda.manual_seed(0)
+ torch.manual_seed(0)
+
+ device = torch.device("cuda")
+ dtype = torch.float16
+
+ conv = torch.nn.Conv2d(
+ in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1, groups=512, device=device, dtype=dtype
+ )
+ x = torch.randn(16, 512, 32, 32, device=device, dtype=dtype).to(memory_format=torch.channels_last)
+ ref_y = conv(x)
+ tri_y = depthwise_conv_fwd(x.permute(0, 2, 3, 1), conv.weight[:, 0], conv.bias).permute(0, 3, 1, 2)
+
+ ipdb.set_trace()
+
+
+if __name__ == "__main__":
+ debug()
+
+"""
+python -m modules.depthwise_conv_fwd
+"""
diff --git a/diffusion/model/nets/fastlinear/modules/triton_mb_conv_pre_glu_kernels/linear_glu_fwd.py b/diffusion/model/nets/fastlinear/modules/triton_mb_conv_pre_glu_kernels/linear_glu_fwd.py
new file mode 100644
index 0000000..0c72a5e
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/modules/triton_mb_conv_pre_glu_kernels/linear_glu_fwd.py
@@ -0,0 +1,243 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import ipdb
+import torch
+import triton
+import triton.language as tl
+
+from ..utils.custom_autotune import custom_autotune
+
+
+def get_cuda_autotune_config():
+ return [
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=5, num_warps=2
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=5, num_warps=2
+ ),
+ # Good config for fp8 inputs.
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
+ num_stages=3,
+ num_warps=8,
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
+ num_stages=3,
+ num_warps=8,
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
+ ),
+ ]
+
+
+def get_autotune_config():
+ return get_cuda_autotune_config()
+
+
+# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
+# - A list of `triton.Config` objects that define different configurations of
+# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
+# - An auto-tuning *key* whose change in values will trigger evaluation of all the
+# provided configs
+@custom_autotune(
+ configs=get_autotune_config(),
+ key=["M", "N", "K"],
+)
+@triton.jit
+def linear_glu_fwd_kernel(
+ # Pointers to matrices
+ x_ptr,
+ weight_ptr,
+ bias_ptr,
+ y_ptr,
+ # Matrix dimensions
+ M,
+ N,
+ K,
+ # The stride variables represent how much to increase the ptr by when moving by 1
+ # element in a particular dimension. E.g. `stride_x_m` is how much to increase `x_ptr`
+ # by to get the element one row down (A has M rows).
+ stride_x_m,
+ stride_x_k, #
+ stride_weight_n,
+ stride_weight_k, #
+ stride_bias_n,
+ stride_y_m,
+ stride_y_n,
+ # Meta-parameters
+ BLOCK_SIZE_M: tl.constexpr,
+ BLOCK_SIZE_N: tl.constexpr,
+ BLOCK_SIZE_K: tl.constexpr, #
+ GROUP_SIZE_M: tl.constexpr, #
+):
+ """
+ Input:
+ x: (..., C)
+ weight: (2*D, C)
+ bias: (2*D,)
+ Output:
+ y: (..., D)
+ """
+ # -----------------------------------------------------------
+ # Map program ids `pid` to the block of C it should compute.
+ # This is done in a grouped ordering to promote L2 data reuse.
+ # See above `L2 Cache Optimizations` section for details.
+ pid = tl.program_id(axis=0)
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_SIZE_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
+
+ # ----------------------------------------------------------
+ # Create pointers for the first blocks of A and B.
+ # We will advance this pointer as we move in the K direction
+ # and accumulate
+ # `x_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
+ # `weight_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
+ # See above `Pointer Arithmetic` section for details
+ offs_x_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+ offs_weight_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
+ x_ptrs = x_ptr + (offs_x_m[:, None] * stride_x_m + offs_k[None, :] * stride_x_k) # BLOCK_SIZE_M, BLOCK_SIZE_K
+ weight_ptrs = weight_ptr + (
+ offs_weight_n[None, :] * stride_weight_n + offs_k[:, None] * stride_weight_k
+ ) # BLOCK_SIZE_K, BLOCK_SIZE_N
+ weight_1_ptrs = weight_ptr + (
+ (N + offs_weight_n[None, :]) * stride_weight_n + offs_k[:, None] * stride_weight_k
+ ) # BLOCK_SIZE_K, BLOCK_SIZE_N
+
+ # -----------------------------------------------------------
+ # Iterate to compute a block of the C matrix.
+ # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+ # of fp32 values for higher accuracy.
+ # `accumulator` will be converted back to fp16 after the loop.
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+ accumulator_1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+ # Load the next block of A and B, generate a mask by checking the K dimension.
+ # If it is out of bounds, set it to 0.
+ x = tl.load(x_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+ weight = tl.load(weight_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+ weight_1 = tl.load(weight_1_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+ # We accumulate along the K dimension.
+ accumulator = tl.dot(x, weight, accumulator)
+ accumulator_1 = tl.dot(x, weight_1, accumulator_1)
+ # Advance the ptrs to the next K block.
+ x_ptrs += BLOCK_SIZE_K * stride_x_k
+ weight_ptrs += BLOCK_SIZE_K * stride_weight_k
+ weight_1_ptrs += BLOCK_SIZE_K * stride_weight_k
+
+ bias_ptrs = bias_ptr + (offs_weight_n * stride_bias_n) # BLOCK_SIZE_N
+ bias_1_ptrs = bias_ptr + ((N + offs_weight_n) * stride_bias_n) # BLOCK_SIZE_N
+ bias = tl.load(bias_ptrs)
+ bias_1 = tl.load(bias_1_ptrs)
+ accumulator += bias
+ accumulator_1 += bias_1
+
+ y = accumulator * accumulator_1 * tl.sigmoid(accumulator_1).to(y_ptr.dtype.element_ty)
+
+ # -----------------------------------------------------------
+ # Write back the block of the output matrix C with masks.
+ offs_y_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ offs_y_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+ y_offs = stride_y_m * offs_y_m[:, None] + stride_y_n * offs_y_n[None, :]
+ y_mask = (offs_y_m[:, None] < M) & (offs_y_n[None, :] < N)
+ tl.store(y_ptr + y_offs, y, mask=y_mask)
+
+
+def linear_glu_fwd(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
+ """
+ Input:
+ x: (..., C)
+ weight: (2*D, C)
+ bias: (2*D,)
+ Output:
+ y: (..., D)
+ """
+ # ipdb.set_trace()
+ assert x.dim() >= 1 and weight.dim() == 2 and bias.dim() == 1
+ assert x.shape[-1] == weight.shape[-1] # C
+ assert weight.shape[0] == bias.shape[0] # D
+ assert weight.shape[0] % 2 == 0 # D
+ M, K, N = torch.prod(torch.tensor(x.shape[:-1])).item(), x.shape[-1], weight.shape[0] // 2
+
+ # Allocates output.
+ y = torch.empty(x.shape[:-1] + (N,), device=x.device, dtype=x.dtype)
+ # 1D launch kernel where each block gets its own program.
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),)
+ if x.dtype == weight.dtype == bias.dtype:
+ linear_glu_fwd_kernel[grid](
+ x,
+ weight,
+ bias,
+ y, #
+ M,
+ N,
+ K, #
+ x.stride(-2),
+ x.stride(-1), #
+ weight.stride(0),
+ weight.stride(1), #
+ bias.stride(0),
+ y.stride(-2),
+ y.stride(-1),
+ )
+ else:
+ raise NotImplementedError(f"data type {x.dtype} {weight.dtype} {bias.dtype} is not support")
+ return y
diff --git a/diffusion/model/nets/fastlinear/modules/utils/compare_results.py b/diffusion/model/nets/fastlinear/modules/utils/compare_results.py
new file mode 100644
index 0000000..9d24ca4
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/modules/utils/compare_results.py
@@ -0,0 +1,25 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def compare_results(name: str, result: torch.Tensor, ref_result: torch.Tensor):
+ print(f"comparing {name}")
+ diff = (result - ref_result).abs().view(-1)
+ max_error_pos = diff.argmax()
+ print(f"max error: {diff.max()}, mean error: {diff.mean()}")
+ print(f"max error pos: {result.view(-1)[max_error_pos]} {ref_result.view(-1)[max_error_pos]}")
diff --git a/diffusion/model/nets/fastlinear/modules/utils/custom_autotune.py b/diffusion/model/nets/fastlinear/modules/utils/custom_autotune.py
new file mode 100644
index 0000000..3174744
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/modules/utils/custom_autotune.py
@@ -0,0 +1,123 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import builtins
+import json
+import os
+import pickle
+import time
+
+import ipdb
+import torch
+import torch.distributed as dist
+from triton.runtime.autotuner import Autotuner
+
+
+class CustomAutotuner(Autotuner):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.best_config_cache_path = os.path.expanduser(
+ os.path.join(
+ "~",
+ ".triton",
+ "best_config_cache",
+ torch.cuda.get_device_name(0).replace(" ", "_"),
+ self.base_fn.__name__ + ".pkl",
+ )
+ )
+ if os.path.exists(self.best_config_cache_path):
+ with open(self.best_config_cache_path, "rb") as f:
+ self.cache = pickle.load(f)
+
+ def run(self, *args, **kwargs):
+ self.nargs = dict(zip(self.arg_names, args))
+ used_cached_result = True
+ if len(self.configs) > 1:
+ all_args = {**self.nargs, **kwargs}
+ _args = []
+ for name in self.arg_names:
+ if name in all_args:
+ _args.append(all_args[name])
+ key = [_args[i] for i in self.key_idx]
+ for arg in _args:
+ if hasattr(arg, "dtype"):
+ key.append(str(arg.dtype))
+ key = tuple(key)
+ if key not in self.cache:
+ # prune configs
+ used_cached_result = False
+ pruned_configs = self.prune_configs(kwargs)
+ bench_start = time.time()
+ timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
+ bench_end = time.time()
+ self.bench_time = bench_end - bench_start
+ self.cache[key] = builtins.min(timings, key=timings.get)
+ self.pre_hook(args, reset_only=True)
+ self.configs_timings = timings
+ if not dist.is_initialized() or dist.get_rank() == 0:
+ best_config_cache_dir = os.path.dirname(self.best_config_cache_path)
+ os.makedirs(best_config_cache_dir, exist_ok=True)
+ with open(self.best_config_cache_path, "wb") as f:
+ pickle.dump(self.cache, f)
+ config = self.cache[key]
+ else:
+ config = self.configs[0]
+ self.best_config = config
+ if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result:
+ print(
+ f"Triton autotuning for function {self.base_fn.__name__} finished after "
+ f"{self.bench_time:.2f}s; best config selected: {self.best_config};"
+ )
+ if config.pre_hook is not None:
+ config.pre_hook({**self.nargs, **kwargs, **config.all_kwargs()})
+ ret = self.fn.run(
+ *args,
+ **kwargs,
+ **config.all_kwargs(),
+ )
+ self.nargs = None
+ return ret
+
+
+def custom_autotune(
+ configs,
+ key,
+ prune_configs_by=None,
+ reset_to_zero=None,
+ restore_value=None,
+ pre_hook=None,
+ post_hook=None,
+ warmup=25,
+ rep=100,
+ use_cuda_graph=False,
+):
+ def decorator(fn):
+ return CustomAutotuner(
+ fn,
+ fn.arg_names,
+ configs,
+ key,
+ reset_to_zero,
+ restore_value,
+ pre_hook=pre_hook,
+ post_hook=post_hook,
+ prune_configs_by=prune_configs_by,
+ warmup=warmup,
+ rep=rep,
+ use_cuda_graph=use_cuda_graph,
+ )
+
+ return decorator
diff --git a/diffusion/model/nets/fastlinear/modules/utils/dtype.py b/diffusion/model/nets/fastlinear/modules/utils/dtype.py
new file mode 100644
index 0000000..27e0640
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/modules/utils/dtype.py
@@ -0,0 +1,39 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+
+def get_dtype_from_str(dtype: str) -> torch.dtype:
+ if dtype == "fp32":
+ return torch.float32
+ if dtype == "fp16":
+ return torch.float16
+ if dtype == "bf16":
+ return torch.bfloat16
+ raise NotImplementedError(f"dtype {dtype} is not supported")
+
+
+def get_tl_dtype_from_torch_dtype(dtype: torch.dtype) -> tl.dtype:
+ if dtype == torch.float32:
+ return tl.float32
+ if dtype == torch.float16:
+ return tl.float16
+ if dtype == torch.bfloat16:
+ return tl.bfloat16
+ raise NotImplementedError(f"dtype {dtype} is not supported")
diff --git a/diffusion/model/nets/fastlinear/modules/utils/export_onnx.py b/diffusion/model/nets/fastlinear/modules/utils/export_onnx.py
new file mode 100644
index 0000000..81e7be2
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/modules/utils/export_onnx.py
@@ -0,0 +1,63 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import warnings
+from typing import Any, Tuple
+
+import torch
+
+
+def export_onnx(
+ model: torch.nn.Module,
+ input_shape: Tuple[int],
+ export_path: str,
+ opset: int,
+ export_dtype: torch.dtype,
+ export_device: torch.device,
+) -> None:
+ model.eval()
+
+ dummy_input = {"x": torch.randn(input_shape, dtype=export_dtype, device=export_device)}
+ dynamic_axes = {
+ "x": {0: "batch_size"},
+ }
+
+ # _ = model(**dummy_input)
+
+ output_names = ["image_embeddings"]
+
+ export_dir = os.path.dirname(export_path)
+ if not os.path.exists(export_dir):
+ os.makedirs(export_dir)
+
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
+ warnings.filterwarnings("ignore", category=UserWarning)
+ print(f"Exporting onnx model to {export_path}...")
+ with open(export_path, "wb") as f:
+ torch.onnx.export(
+ model,
+ tuple(dummy_input.values()),
+ f,
+ export_params=True,
+ verbose=False,
+ opset_version=opset,
+ do_constant_folding=True,
+ input_names=list(dummy_input.keys()),
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ )
diff --git a/diffusion/model/nets/fastlinear/modules/utils/model.py b/diffusion/model/nets/fastlinear/modules/utils/model.py
new file mode 100644
index 0000000..c2979d7
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/modules/utils/model.py
@@ -0,0 +1,42 @@
+# Copyright 2024 MIT Han Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+
+def val2list(x: list or tuple or any, repeat_time=1) -> list: # type: ignore
+ """Repeat `val` for `repeat_time` times and return the list or val if list/tuple."""
+ if isinstance(x, (list, tuple)):
+ return list(x)
+ return [x for _ in range(repeat_time)]
+
+
+def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple: # type: ignore
+ """Return tuple with min_len by repeating element at idx_repeat."""
+ # convert to list first
+ x = val2list(x)
+
+ # repeat elements if necessary
+ if len(x) > 0:
+ x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
+
+ return tuple(x)
+
+
+def get_same_padding(kernel_size: int or tuple[int, ...]) -> int or tuple[int, ...]:
+ if isinstance(kernel_size, tuple):
+ return tuple([get_same_padding(ks) for ks in kernel_size])
+ else:
+ assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number"
+ return kernel_size // 2
diff --git a/diffusion/model/nets/fastlinear/readme.md b/diffusion/model/nets/fastlinear/readme.md
new file mode 100644
index 0000000..64539f7
--- /dev/null
+++ b/diffusion/model/nets/fastlinear/readme.md
@@ -0,0 +1,32 @@
+# a fast implementation of linear attention
+
+## 64x64, fp16
+
+```bash
+# validate correctness
+## fp16 vs fp32
+python -m develop_triton_litemla attn_type=LiteMLA test_correctness=True
+## triton fp16 vs fp32
+python -m develop_triton_litemla attn_type=TritonLiteMLA test_correctness=True
+
+# test performance
+## fp16, forward
+python -m develop_triton_litemla attn_type=LiteMLA
+each step takes 10.81 ms
+max memory allocated: 2.2984 GB
+
+## triton fp16, forward
+python -m develop_triton_litemla attn_type=TritonLiteMLA
+each step takes 4.70 ms
+max memory allocated: 1.6480 GB
+
+## fp16, backward
+python -m develop_triton_litemla attn_type=LiteMLA backward=True
+each step takes 35.34 ms
+max memory allocated: 3.4412 GB
+
+## triton fp16, backward
+python -m develop_triton_litemla attn_type=TritonLiteMLA backward=True
+each step takes 14.25 ms
+max memory allocated: 2.4704 GB
+```
diff --git a/diffusion/model/nets/sana_U_shape.py b/diffusion/model/nets/sana_U_shape.py
new file mode 100644
index 0000000..d16b895
--- /dev/null
+++ b/diffusion/model/nets/sana_U_shape.py
@@ -0,0 +1,369 @@
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is modified from https://github.com/PixArt-alpha/PixArt-sigma
+import os
+
+import torch
+import torch.nn as nn
+from timm.models.layers import DropPath
+
+from diffusion.model.builder import MODELS
+from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp
+from diffusion.model.nets.fastlinear.modules import TritonLiteMLA
+from diffusion.model.nets.sana import Sana, get_2d_sincos_pos_embed
+from diffusion.model.nets.sana_blocks import (
+ Attention,
+ CaptionEmbedder,
+ FlashAttention,
+ LiteLA,
+ MultiHeadCrossAttention,
+ PatchEmbed,
+ T2IFinalLayer,
+ TimestepEmbedder,
+ t2i_modulate,
+)
+from diffusion.model.norms import RMSNorm
+from diffusion.model.utils import auto_grad_checkpoint, to_2tuple
+from diffusion.utils.logger import get_root_logger
+
+
+class SanaUBlock(nn.Module):
+ """
+ A SanaU block with global shared adaptive layer norm (adaLN-single) conditioning and U-shaped model.
+ """
+
+ def __init__(
+ self,
+ hidden_size,
+ num_heads,
+ mlp_ratio=4.0,
+ drop_path=0,
+ input_size=None,
+ qk_norm=False,
+ attn_type="flash",
+ ffn_type="mlp",
+ mlp_acts=("silu", "silu", None),
+ skip_linear=False,
+ **block_kwargs,
+ ):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ if attn_type == "flash":
+ # flash self attention
+ self.attn = FlashAttention(
+ hidden_size,
+ num_heads=num_heads,
+ qkv_bias=True,
+ qk_norm=qk_norm,
+ **block_kwargs,
+ )
+ elif attn_type == "linear":
+ # linear self attention
+ # TODO: Here the num_heads set to 36 for tmp used
+ self_num_heads = hidden_size // 32
+ self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm)
+ elif attn_type == "triton_linear":
+ # linear self attention with triton kernel fusion
+ # TODO: Here the num_heads set to 36 for tmp used
+ self_num_heads = hidden_size // 32
+ self.attn = TritonLiteMLA(hidden_size, num_heads=self_num_heads, eps=1e-8)
+ elif attn_type == "vanilla":
+ # vanilla self attention
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True)
+ else:
+ raise ValueError(f"{attn_type} type is not defined.")
+
+ self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ # to be compatible with lower version pytorch
+ if ffn_type == "dwmlp":
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = DWMlp(
+ in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
+ )
+ elif ffn_type == "glumbconv":
+ self.mlp = GLUMBConv(
+ in_features=hidden_size,
+ hidden_features=int(hidden_size * mlp_ratio),
+ use_bias=(True, True, False),
+ norm=(None, None, None),
+ act=mlp_acts,
+ )
+ elif ffn_type == "mbconvpreglu":
+ self.mlp = MBConvPreGLU(
+ in_dim=hidden_size,
+ out_dim=hidden_size,
+ mid_dim=int(hidden_size * mlp_ratio),
+ use_bias=(True, True, False),
+ norm=None,
+ act=("silu", "silu", None),
+ )
+ elif ffn_type == "mlp":
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = Mlp(
+ in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
+ )
+ else:
+ raise ValueError(f"{ffn_type} type is not defined.")
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
+
+ # skip connection
+ if skip_linear:
+ self.skip_linear = nn.Linear(hidden_size * 2, hidden_size, bias=True)
+
+ def forward(self, x, y, t, mask=None, skip_x=None, **kwargs):
+ B, N, C = x.shape
+ if skip_x is not None:
+ x = self.skip_linear(torch.cat([x, skip_x], dim=-1))
+
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + t.reshape(B, 6, -1)
+ ).chunk(6, dim=1)
+ x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C))
+ x = x + self.cross_attn(x, y, mask)
+ x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
+
+ return x
+
+
+#############################################################################
+# Core SanaU Model #
+#################################################################################
+@MODELS.register_module()
+class SanaU(Sana):
+ """
+ Diffusion model with a Transformer backbone.
+ """
+
+ def __init__(
+ self,
+ input_size=32,
+ patch_size=2,
+ in_channels=4,
+ hidden_size=1152,
+ depth=29,
+ num_heads=16,
+ mlp_ratio=4.0,
+ class_dropout_prob=0.1,
+ learn_sigma=True,
+ pred_sigma=True,
+ drop_path: float = 0.0,
+ caption_channels=2304,
+ pe_interpolation=1.0,
+ config=None,
+ model_max_length=300,
+ micro_condition=False,
+ qk_norm=False,
+ y_norm=False,
+ norm_eps=1e-5,
+ attn_type="flash",
+ ffn_type="mlp",
+ use_pe=True,
+ y_norm_scale_factor=1.0,
+ patch_embed_kernel=None,
+ mlp_acts=("silu", "silu", None),
+ **kwargs,
+ ):
+ super().__init__(
+ input_size=input_size,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ hidden_size=hidden_size,
+ depth=depth,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ class_dropout_prob=class_dropout_prob,
+ learn_sigma=learn_sigma,
+ pred_sigma=pred_sigma,
+ drop_path=drop_path,
+ caption_channels=caption_channels,
+ pe_interpolation=pe_interpolation,
+ config=config,
+ model_max_length=model_max_length,
+ micro_condition=micro_condition,
+ qk_norm=qk_norm,
+ y_norm=y_norm,
+ norm_eps=norm_eps,
+ attn_type=attn_type,
+ ffn_type=ffn_type,
+ use_pe=use_pe,
+ y_norm_scale_factor=y_norm_scale_factor,
+ patch_embed_kernel=patch_embed_kernel,
+ mlp_acts=mlp_acts,
+ **kwargs,
+ )
+
+ kernel_size = patch_embed_kernel or patch_size
+ self.x_embedder = PatchEmbed(
+ input_size, patch_size, in_channels, hidden_size, kernel_size=kernel_size, bias=True
+ )
+ self.t_embedder = TimestepEmbedder(hidden_size)
+ num_patches = self.x_embedder.num_patches
+ self.base_size = input_size // self.patch_size
+ # Will use fixed sin-cos embedding:
+ self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
+
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
+
+ self.y_embedder = CaptionEmbedder(
+ in_channels=caption_channels,
+ hidden_size=hidden_size,
+ uncond_prob=class_dropout_prob,
+ act_layer=approx_gelu,
+ token_num=model_max_length,
+ )
+ if self.y_norm:
+ self.attention_y_norm = RMSNorm(hidden_size, scale_factor=y_norm_scale_factor, eps=norm_eps)
+ drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
+ self.blocks = nn.ModuleList(
+ [
+ SanaUBlock(
+ hidden_size,
+ num_heads,
+ mlp_ratio=mlp_ratio,
+ drop_path=drop_path[i],
+ input_size=(input_size // patch_size, input_size // patch_size),
+ qk_norm=qk_norm,
+ attn_type=attn_type,
+ ffn_type=ffn_type,
+ mlp_acts=mlp_acts,
+ skip_linear=i > depth // 2,
+ )
+ for i in range(depth)
+ ]
+ )
+ self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)
+
+ self.initialize_weights()
+
+ if config:
+ logger = get_root_logger(os.path.join(config.work_dir, "train_log.log"))
+ logger = logger.warning
+ else:
+ logger = print
+ logger(f"use pe: {use_pe}, position embed interpolation: {self.pe_interpolation}, base size: {self.base_size}")
+ logger(
+ f"attention type: {attn_type}; ffn type: {ffn_type}; "
+ f"autocast linear attn: {os.environ.get('AUTOCAST_LINEAR_ATTN', False)}"
+ )
+
+ def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs):
+ """
+ Forward pass of SanaU.
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
+ t: (N,) tensor of diffusion timesteps
+ y: (N, 1, 120, C) tensor of class labels
+ """
+ x = x.to(self.dtype)
+ timestep = timestep.to(self.dtype)
+ y = y.to(self.dtype)
+ pos_embed = self.pos_embed.to(self.dtype)
+ self.h, self.w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
+ x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
+ t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
+ t0 = self.t_block(t)
+ y = self.y_embedder(y, self.training) # (N, 1, L, D)
+ if self.y_norm:
+ y = self.attention_y_norm(y)
+ if mask is not None:
+ if mask.shape[0] != y.shape[0]:
+ mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
+ mask = mask.squeeze(1).squeeze(1)
+ y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
+ y_lens = mask.sum(dim=1).tolist()
+ else:
+ y_lens = [y.shape[2]] * y.shape[0]
+ y = y.squeeze(1).view(1, -1, x.shape[-1])
+ results_hooker = {}
+ for i, block in enumerate(self.blocks):
+ if i > len(self.blocks) // 2:
+ x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip_x=results_hooker[len(self.blocks) - i - 1])
+ else:
+ x = auto_grad_checkpoint(block, x, y, t0, y_lens) # (N, T, D) #support grad checkpoint
+ results_hooker[i] = x
+ x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
+ x = self.unpatchify(x) # (N, out_channels, H, W)
+ return x
+
+ def initialize_weights(self):
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+
+ self.apply(_basic_init)
+
+ if self.use_pe:
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
+ pos_embed = get_2d_sincos_pos_embed(
+ self.pos_embed.shape[-1],
+ int(self.x_embedder.num_patches**0.5),
+ pe_interpolation=self.pe_interpolation,
+ base_size=self.base_size,
+ )
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+ nn.init.normal_(self.t_block[1].weight, std=0.02)
+
+ # Initialize caption embedding MLP:
+ nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
+ nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
+
+ @property
+ def dtype(self):
+ return next(self.parameters()).dtype
+
+
+#################################################################################
+# SanaU Configs #
+#################################################################################
+@MODELS.register_module()
+def SanaMSU_600M_P1_D28(**kwargs):
+ return SanaU(depth=28, hidden_size=1152, patch_size=1, num_heads=16, **kwargs)
+
+
+@MODELS.register_module()
+def SanaMSU_600M_P2_D28(**kwargs):
+ return SanaU(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
+
+
+@MODELS.register_module()
+def SanaMSU_600M_P4_D28(**kwargs):
+ return SanaU(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
+
+
+@MODELS.register_module()
+def SanaMSU_P1_D20(**kwargs):
+ # 20 layers, 1648.48M
+ return SanaU(depth=20, hidden_size=2240, patch_size=1, num_heads=20, **kwargs)
+
+
+@MODELS.register_module()
+def SanaMSU_P2_D20(**kwargs):
+ # 28 layers, 1648.48M
+ return SanaU(depth=20, hidden_size=2240, patch_size=2, num_heads=20, **kwargs)
diff --git a/diffusion/model/nets/sana_U_shape_multi_scale.py b/diffusion/model/nets/sana_U_shape_multi_scale.py
new file mode 100644
index 0000000..927fc1b
--- /dev/null
+++ b/diffusion/model/nets/sana_U_shape_multi_scale.py
@@ -0,0 +1,376 @@
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is modified from https://github.com/PixArt-alpha/PixArt-sigma
+import torch
+import torch.nn as nn
+from timm.models.layers import DropPath
+
+from diffusion.model.builder import MODELS
+from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp
+from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonLiteMLAFwd
+from diffusion.model.nets.sana import Sana, get_2d_sincos_pos_embed
+from diffusion.model.nets.sana_blocks import (
+ Attention,
+ CaptionEmbedder,
+ FlashAttention,
+ LiteLA,
+ MultiHeadCrossAttention,
+ PatchEmbedMS,
+ SizeEmbedder,
+ T2IFinalLayer,
+ t2i_modulate,
+)
+from diffusion.model.utils import auto_grad_checkpoint, to_2tuple
+
+
+class SanaUMSBlock(nn.Module):
+ """
+ A SanaU block with global shared adaptive layer norm (adaLN-single) conditioning and U-shaped model.
+ """
+
+ def __init__(
+ self,
+ hidden_size,
+ num_heads,
+ mlp_ratio=4.0,
+ drop_path=0.0,
+ input_size=None,
+ qk_norm=False,
+ attn_type="flash",
+ ffn_type="mlp",
+ mlp_acts=("silu", "silu", None),
+ skip_linear=False,
+ **block_kwargs,
+ ):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ if attn_type == "flash":
+ # flash self attention
+ self.attn = FlashAttention(
+ hidden_size,
+ num_heads=num_heads,
+ qkv_bias=True,
+ qk_norm=qk_norm,
+ **block_kwargs,
+ )
+ elif attn_type == "linear":
+ # linear self attention
+ # TODO: Here the num_heads set to 36 for tmp used
+ self_num_heads = hidden_size // 32
+ self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm)
+ elif attn_type == "triton_linear":
+ # linear self attention with triton kernel fusion
+ self_num_heads = hidden_size // 32
+ self.attn = TritonLiteMLA(hidden_size, num_heads=self_num_heads, eps=1e-8)
+ elif attn_type == "vanilla":
+ # vanilla self attention
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True)
+ else:
+ raise ValueError(f"{attn_type} type is not defined.")
+
+ self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ if ffn_type == "dwmlp":
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = DWMlp(
+ in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
+ )
+ elif ffn_type == "glumbconv":
+ self.mlp = GLUMBConv(
+ in_features=hidden_size,
+ hidden_features=int(hidden_size * mlp_ratio),
+ use_bias=(True, True, False),
+ norm=(None, None, None),
+ act=mlp_acts,
+ )
+ elif ffn_type == "mlp":
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = Mlp(
+ in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
+ )
+ elif ffn_type == "mbconvpreglu":
+ self.mlp = MBConvPreGLU(
+ in_dim=hidden_size,
+ out_dim=hidden_size,
+ mid_dim=int(hidden_size * mlp_ratio),
+ use_bias=(True, True, False),
+ norm=None,
+ act=("silu", "silu", None),
+ )
+ else:
+ raise ValueError(f"{ffn_type} type is not defined.")
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
+
+ # skip connection
+ if skip_linear:
+ self.skip_linear = nn.Linear(hidden_size * 2, hidden_size, bias=True)
+
+ def forward(self, x, y, t, mask=None, HW=None, skip_x=None, **kwargs):
+ B, N, C = x.shape
+ if skip_x is not None:
+ x = self.skip_linear(torch.cat([x, skip_x], dim=-1))
+
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + t.reshape(B, 6, -1)
+ ).chunk(6, dim=1)
+ x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
+ x = x + self.cross_attn(x, y, mask)
+ x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp), HW=HW))
+
+ return x
+
+
+#############################################################################
+# Core SanaUMS Model #
+#################################################################################
+@MODELS.register_module()
+class SanaUMS(Sana):
+ """
+ Diffusion model with a Transformer backbone.
+ """
+
+ def __init__(
+ self,
+ input_size=32,
+ patch_size=2,
+ in_channels=4,
+ hidden_size=1152,
+ depth=29,
+ num_heads=16,
+ mlp_ratio=4.0,
+ class_dropout_prob=0.1,
+ learn_sigma=True,
+ pred_sigma=True,
+ drop_path: float = 0.0,
+ caption_channels=2304,
+ pe_interpolation=1.0,
+ config=None,
+ model_max_length=300,
+ micro_condition=False,
+ qk_norm=False,
+ y_norm=False,
+ norm_eps=1e-5,
+ attn_type="flash",
+ ffn_type="mlp",
+ use_pe=True,
+ y_norm_scale_factor=1.0,
+ patch_embed_kernel=None,
+ mlp_acts=("silu", "silu", None),
+ **kwargs,
+ ):
+ super().__init__(
+ input_size=input_size,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ hidden_size=hidden_size,
+ depth=depth,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ class_dropout_prob=class_dropout_prob,
+ learn_sigma=learn_sigma,
+ pred_sigma=pred_sigma,
+ drop_path=drop_path,
+ caption_channels=caption_channels,
+ pe_interpolation=pe_interpolation,
+ config=config,
+ model_max_length=model_max_length,
+ micro_condition=micro_condition,
+ qk_norm=qk_norm,
+ y_norm=y_norm,
+ norm_eps=norm_eps,
+ attn_type=attn_type,
+ ffn_type=ffn_type,
+ use_pe=use_pe,
+ y_norm_scale_factor=y_norm_scale_factor,
+ patch_embed_kernel=patch_embed_kernel,
+ mlp_acts=mlp_acts,
+ **kwargs,
+ )
+ self.h = self.w = 0
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
+
+ kernel_size = patch_embed_kernel or patch_size
+ self.x_embedder = PatchEmbedMS(patch_size, in_channels, hidden_size, kernel_size=kernel_size, bias=True)
+ self.y_embedder = CaptionEmbedder(
+ in_channels=caption_channels,
+ hidden_size=hidden_size,
+ uncond_prob=class_dropout_prob,
+ act_layer=approx_gelu,
+ token_num=model_max_length,
+ )
+ self.micro_conditioning = micro_condition
+ drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
+ self.blocks = nn.ModuleList(
+ [
+ SanaUMSBlock(
+ hidden_size,
+ num_heads,
+ mlp_ratio=mlp_ratio,
+ drop_path=drop_path[i],
+ input_size=(input_size // patch_size, input_size // patch_size),
+ qk_norm=qk_norm,
+ attn_type=attn_type,
+ ffn_type=ffn_type,
+ mlp_acts=mlp_acts,
+ skip_linear=i > depth // 2,
+ )
+ for i in range(depth)
+ ]
+ )
+ self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)
+
+ self.initialize()
+
+ def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs):
+ """
+ Forward pass of SanaUMS.
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
+ t: (N,) tensor of diffusion timesteps
+ y: (N, 1, 120, C) tensor of class labels
+ """
+ bs = x.shape[0]
+ x = x.to(self.dtype)
+ timestep = timestep.to(self.dtype)
+ y = y.to(self.dtype)
+ self.h, self.w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
+ if self.use_pe:
+ pos_embed = (
+ torch.from_numpy(
+ get_2d_sincos_pos_embed(
+ self.pos_embed.shape[-1],
+ (self.h, self.w),
+ pe_interpolation=self.pe_interpolation,
+ base_size=self.base_size,
+ )
+ )
+ .unsqueeze(0)
+ .to(x.device)
+ .to(self.dtype)
+ )
+ x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
+ else:
+ x = self.x_embedder(x)
+
+ t = self.t_embedder(timestep) # (N, D)
+
+ t0 = self.t_block(t)
+ y = self.y_embedder(y, self.training) # (N, D)
+ if self.y_norm:
+ y = self.attention_y_norm(y)
+
+ if mask is not None:
+ if mask.shape[0] != y.shape[0]:
+ mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
+ mask = mask.squeeze(1).squeeze(1)
+ y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
+ y_lens = mask.sum(dim=1).tolist()
+ else:
+ y_lens = [y.shape[2]] * y.shape[0]
+ y = y.squeeze(1).view(1, -1, x.shape[-1])
+ results_hooker = {}
+ for i, block in enumerate(self.blocks):
+ if i > len(self.blocks) // 2:
+ x = auto_grad_checkpoint(
+ block, x, y, t0, y_lens, (self.h, self.w), results_hooker[len(self.blocks) - i - 1]
+ )
+ else:
+ x = auto_grad_checkpoint(
+ block, x, y, t0, y_lens, (self.h, self.w)
+ ) # (N, T, D) #support grad checkpoint
+ results_hooker[i] = x
+
+ x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
+ x = self.unpatchify(x) # (N, out_channels, H, W)
+
+ return x
+
+ def unpatchify(self, x):
+ """
+ x: (N, T, patch_size**2 * C)
+ imgs: (N, H, W, C)
+ """
+ c = self.out_channels
+ p = self.x_embedder.patch_size[0]
+ assert self.h * self.w == x.shape[1]
+
+ x = x.reshape(shape=(x.shape[0], self.h, self.w, p, p, c))
+ x = torch.einsum("nhwpqc->nchpwq", x)
+ imgs = x.reshape(shape=(x.shape[0], c, self.h * p, self.w * p))
+ return imgs
+
+ def initialize(self):
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+
+ self.apply(_basic_init)
+
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+ nn.init.normal_(self.t_block[1].weight, std=0.02)
+ if self.micro_conditioning:
+ nn.init.normal_(self.csize_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.csize_embedder.mlp[2].weight, std=0.02)
+ nn.init.normal_(self.ar_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.ar_embedder.mlp[2].weight, std=0.02)
+
+ # Initialize caption embedding MLP:
+ nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
+ nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
+
+
+#################################################################################
+# SanaU multi-scale Configs #
+#################################################################################
+
+
+@MODELS.register_module()
+def SanaUMS_600M_P1_D28(**kwargs):
+ return SanaUMS(depth=28, hidden_size=1152, patch_size=1, num_heads=16, **kwargs)
+
+
+@MODELS.register_module()
+def SanaUMS_600M_P2_D28(**kwargs):
+ return SanaUMS(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
+
+
+@MODELS.register_module()
+def SanaUMS_600M_P4_D28(**kwargs):
+ return SanaUMS(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
+
+
+@MODELS.register_module()
+def SanaUMS_P1_D20(**kwargs):
+ # 20 layers, 1648.48M
+ return SanaUMS(depth=20, hidden_size=2240, patch_size=1, num_heads=20, **kwargs)
+
+
+@MODELS.register_module()
+def SanaUMS_P2_D20(**kwargs):
+ # 28 layers, 1648.48M
+ return SanaUMS(depth=20, hidden_size=2240, patch_size=2, num_heads=20, **kwargs)
diff --git a/diffusion/model/nets/sana_multi_scale_adaln.py b/diffusion/model/nets/sana_multi_scale_adaln.py
new file mode 100644
index 0000000..8c2ce8a
--- /dev/null
+++ b/diffusion/model/nets/sana_multi_scale_adaln.py
@@ -0,0 +1,382 @@
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is modified from https://github.com/PixArt-alpha/PixArt-sigma
+import torch
+import torch.nn as nn
+from timm.models.layers import DropPath
+
+from diffusion.model.builder import MODELS
+from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp
+from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonLiteMLAFwd
+from diffusion.model.nets.sana import Sana, get_2d_sincos_pos_embed
+from diffusion.model.nets.sana_blocks import (
+ Attention,
+ CaptionEmbedder,
+ FlashAttention,
+ LiteLA,
+ MultiHeadCrossAttention,
+ PatchEmbedMS,
+ SizeEmbedder,
+ T2IFinalLayer,
+ modulate,
+)
+from diffusion.model.utils import auto_grad_checkpoint, to_2tuple
+
+
+class SanaMSAdaLNBlock(nn.Module):
+ """
+ A Sana block with layer-wise adaptive layer norm zero (adaLN-Zero) conditioning.
+ """
+
+ def __init__(
+ self,
+ hidden_size,
+ num_heads,
+ mlp_ratio=4.0,
+ drop_path=0.0,
+ input_size=None,
+ qk_norm=False,
+ attn_type="flash",
+ ffn_type="mlp",
+ mlp_acts=("silu", "silu", None),
+ **block_kwargs,
+ ):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ if attn_type == "flash":
+ # flash self attention
+ self.attn = FlashAttention(
+ hidden_size,
+ num_heads=num_heads,
+ qkv_bias=True,
+ qk_norm=qk_norm,
+ **block_kwargs,
+ )
+ elif attn_type == "linear":
+ # linear self attention
+ # TODO: Here the num_heads set to 36 for tmp used
+ self_num_heads = hidden_size // 32
+ self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm)
+ elif attn_type == "triton_linear":
+ # linear self attention with triton kernel fusion
+ self_num_heads = hidden_size // 32
+ self.attn = TritonLiteMLA(hidden_size, num_heads=self_num_heads, eps=1e-8)
+ elif attn_type == "vanilla":
+ # vanilla self attention
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True)
+ else:
+ raise ValueError(f"{attn_type} type is not defined.")
+
+ self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ if ffn_type == "dwmlp":
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = DWMlp(
+ in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
+ )
+ elif ffn_type == "glumbconv":
+ self.mlp = GLUMBConv(
+ in_features=hidden_size,
+ hidden_features=int(hidden_size * mlp_ratio),
+ use_bias=(True, True, False),
+ norm=(None, None, None),
+ act=mlp_acts,
+ )
+ elif ffn_type == "mlp":
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = Mlp(
+ in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
+ )
+ elif ffn_type == "mbconvpreglu":
+ self.mlp = MBConvPreGLU(
+ in_dim=hidden_size,
+ out_dim=hidden_size,
+ mid_dim=int(hidden_size * mlp_ratio),
+ use_bias=(True, True, False),
+ norm=None,
+ act=("silu", "silu", None),
+ )
+ else:
+ raise ValueError(f"{ffn_type} type is not defined.")
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.scale_shift_table = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+ self.silu = nn.SiLU()
+
+ def forward(self, x, y, t, mask=None, HW=None, **kwargs):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.scale_shift_table(self.silu(t)).chunk(
+ 6, dim=1
+ )
+
+ x = x + self.drop_path(gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
+ x = x + self.cross_attn(x, y, mask)
+ x = x + self.drop_path(gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp), HW=HW))
+
+ return x
+
+
+#############################################################################
+# Core Sana with AdaLN Model #
+#################################################################################
+@MODELS.register_module()
+class SanaMSAdaLN(Sana):
+ """
+ Diffusion model with a Transformer backbone.
+ """
+
+ def __init__(
+ self,
+ input_size=32,
+ patch_size=2,
+ in_channels=4,
+ hidden_size=1152,
+ depth=28,
+ num_heads=16,
+ mlp_ratio=4.0,
+ class_dropout_prob=0.1,
+ learn_sigma=True,
+ pred_sigma=True,
+ drop_path: float = 0.0,
+ caption_channels=2304,
+ pe_interpolation=1.0,
+ config=None,
+ model_max_length=300,
+ micro_condition=False,
+ qk_norm=False,
+ y_norm=False,
+ norm_eps=1e-5,
+ attn_type="flash",
+ ffn_type="mlp",
+ use_pe=True,
+ y_norm_scale_factor=1.0,
+ patch_embed_kernel=None,
+ mlp_acts=("silu", "silu", None),
+ **kwargs,
+ ):
+ super().__init__(
+ input_size=input_size,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ hidden_size=hidden_size,
+ depth=depth,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ class_dropout_prob=class_dropout_prob,
+ learn_sigma=learn_sigma,
+ pred_sigma=pred_sigma,
+ drop_path=drop_path,
+ caption_channels=caption_channels,
+ pe_interpolation=pe_interpolation,
+ config=config,
+ model_max_length=model_max_length,
+ micro_condition=micro_condition,
+ qk_norm=qk_norm,
+ y_norm=y_norm,
+ norm_eps=norm_eps,
+ attn_type=attn_type,
+ ffn_type=ffn_type,
+ use_pe=use_pe,
+ y_norm_scale_factor=y_norm_scale_factor,
+ patch_embed_kernel=patch_embed_kernel,
+ mlp_acts=mlp_acts,
+ **kwargs,
+ )
+ self.h = self.w = 0
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ kernel_size = patch_embed_kernel or patch_size
+
+ self.x_embedder = PatchEmbedMS(patch_size, in_channels, hidden_size, kernel_size=kernel_size, bias=True)
+ self.y_embedder = CaptionEmbedder(
+ in_channels=caption_channels,
+ hidden_size=hidden_size,
+ uncond_prob=class_dropout_prob,
+ act_layer=approx_gelu,
+ token_num=model_max_length,
+ )
+ self.micro_conditioning = micro_condition
+ if self.micro_conditioning:
+ self.csize_embedder = SizeEmbedder(hidden_size // 3) # c_size embed
+ self.ar_embedder = SizeEmbedder(hidden_size // 3) # aspect ratio embed
+ self.global_y_embed = None
+ self.t_block = None
+ drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
+ self.blocks = nn.ModuleList(
+ [
+ SanaMSAdaLNBlock(
+ hidden_size,
+ num_heads,
+ mlp_ratio=mlp_ratio,
+ drop_path=drop_path[i],
+ input_size=(input_size // patch_size, input_size // patch_size),
+ qk_norm=qk_norm,
+ attn_type=attn_type,
+ ffn_type=ffn_type,
+ mlp_acts=mlp_acts,
+ )
+ for i in range(depth)
+ ]
+ )
+ self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)
+
+ self.initialize()
+
+ def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs):
+ """
+ Forward pass of Sana.
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
+ t: (N,) tensor of diffusion timesteps
+ y: (N, 1, 120, C) tensor of class labels
+ """
+ bs = x.shape[0]
+ x = x.to(self.dtype)
+ timestep = timestep.to(self.dtype)
+ y = y.to(self.dtype)
+ self.h, self.w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
+ if self.use_pe:
+ pos_embed = (
+ torch.from_numpy(
+ get_2d_sincos_pos_embed(
+ self.pos_embed.shape[-1],
+ (self.h, self.w),
+ pe_interpolation=self.pe_interpolation,
+ base_size=self.base_size,
+ )
+ )
+ .unsqueeze(0)
+ .to(x.device)
+ .to(self.dtype)
+ )
+ x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
+ else:
+ x = self.x_embedder(x)
+
+ t = self.t_embedder(timestep) # (N, D)
+
+ if self.micro_conditioning:
+ c_size, ar = data_info["img_hw"].to(self.dtype), data_info["aspect_ratio"].to(self.dtype)
+ csize = self.csize_embedder(c_size, bs) # (N, D)
+ ar = self.ar_embedder(ar, bs) # (N, D)
+ t = t + torch.cat([csize, ar], dim=1)
+
+ y = self.y_embedder(y, self.training) # (N, D)
+ if self.y_norm:
+ y = self.attention_y_norm(y)
+
+ if self.global_y_embed:
+ global_y = self.global_y_embedder(y) # (N, D)
+ t = t + global_y
+
+ if mask is not None:
+ if mask.shape[0] != y.shape[0]:
+ mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
+ mask = mask.squeeze(1).squeeze(1)
+ y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
+ y_lens = mask.sum(dim=1).tolist()
+ else:
+ y_lens = [y.shape[2]] * y.shape[0]
+ y = y.squeeze(1).view(1, -1, x.shape[-1])
+ for block in self.blocks:
+ x = auto_grad_checkpoint(
+ block, x, y, t, y_lens, (self.h, self.w), **kwargs
+ ) # (N, T, D) #support grad checkpoint
+
+ x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
+ x = self.unpatchify(x) # (N, out_channels, H, W)
+
+ return x
+
+ def forward_with_dpmsolver(self, x, timestep, y, data_info, **kwargs):
+ """
+ dpm solver donnot need variance prediction
+ """
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
+ model_out = self.forward(x, timestep, y, data_info=data_info, **kwargs)
+ return model_out.chunk(2, dim=1)[0] if self.pred_sigma else model_out
+
+ def unpatchify(self, x):
+ """
+ x: (N, T, patch_size**2 * C)
+ imgs: (N, H, W, C)
+ """
+ c = self.out_channels
+ p = self.x_embedder.patch_size[0]
+ assert self.h * self.w == x.shape[1]
+
+ x = x.reshape(shape=(x.shape[0], self.h, self.w, p, p, c))
+ x = torch.einsum("nhwpqc->nchpwq", x)
+ imgs = x.reshape(shape=(x.shape[0], c, self.h * p, self.w * p))
+ return imgs
+
+ def initialize(self):
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+
+ self.apply(_basic_init)
+
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+ # nn.init.normal_(self.t_block[1].weight, std=0.02)
+ if self.micro_conditioning:
+ nn.init.normal_(self.csize_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.csize_embedder.mlp[2].weight, std=0.02)
+ nn.init.normal_(self.ar_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.ar_embedder.mlp[2].weight, std=0.02)
+
+ # Initialize caption embedding MLP:
+ nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
+ nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
+
+
+#################################################################################
+# Sana Multi-scale Configs #
+#################################################################################
+
+
+@MODELS.register_module()
+def SanaMSAdaLN_600M_P1_D28(**kwargs):
+ return SanaMSAdaLN(depth=28, hidden_size=1152, patch_size=1, num_heads=16, **kwargs)
+
+
+@MODELS.register_module()
+def SanaMSAdaLN_600M_P2_D28(**kwargs):
+ return SanaMSAdaLN(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
+
+
+@MODELS.register_module()
+def SanaMSAdaLN_600M_P4_D28(**kwargs):
+ return SanaMSAdaLN(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
+
+
+@MODELS.register_module()
+def SanaMSAdaLN_P1_D20(**kwargs):
+ # 20 layers, 1648.48M
+ return SanaMSAdaLN(depth=20, hidden_size=2240, patch_size=1, num_heads=20, **kwargs)
+
+
+@MODELS.register_module()
+def SanaMSAdaLN_P2_D20(**kwargs):
+ # 28 layers, 1648.48M
+ return SanaMSAdaLN(depth=20, hidden_size=2240, patch_size=2, num_heads=20, **kwargs)
diff --git a/diffusion/model/nets/sana_others.py b/diffusion/model/nets/sana_others.py
new file mode 100644
index 0000000..519a011
--- /dev/null
+++ b/diffusion/model/nets/sana_others.py
@@ -0,0 +1,300 @@
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is modified from https://github.com/PixArt-alpha/PixArt-sigma
+import torch
+import torch.nn as nn
+from timm.models.layers import DropPath
+
+from diffusion.model.nets.basic_modules import DWMlp, MBConvPreGLU, Mlp
+from diffusion.model.nets.fastlinear.modules import TritonLiteMLA
+from diffusion.model.nets.sana_blocks import Attention, FlashAttention, MultiHeadCrossAttention, t2i_modulate
+
+
+class SanaMSPABlock(nn.Module):
+ """
+ A Sana block with adaptive layer norm zero (adaLN-Zero) conditioning.
+ reference VIT-22B
+ https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L224
+ """
+
+ def __init__(
+ self,
+ hidden_size,
+ num_heads,
+ mlp_ratio=4.0,
+ drop_path=0.0,
+ input_size=None,
+ sampling=None,
+ sr_ratio=1,
+ qk_norm=False,
+ attn_type="flash",
+ ffn_type="mlp",
+ mlp_acts=("silu", "silu", None),
+ **block_kwargs,
+ ):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.norm1 = nn.LayerNorm(hidden_size * 3, elementwise_affine=False, eps=1e-6)
+ if attn_type == "flash":
+ # flash self attention
+ self.attn = FlashAttention(
+ hidden_size,
+ num_heads=num_heads,
+ qkv_bias=True,
+ sampling=sampling,
+ sr_ratio=sr_ratio,
+ qk_norm=qk_norm,
+ **block_kwargs,
+ )
+ print("currently not support parallel attn")
+ exit()
+ elif attn_type == "linear":
+ # linear self attention
+ # TODO: Here the num_heads set to 36 for tmp used
+ self_num_heads = hidden_size // 32
+ # self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8)
+ self.attn = SlimLiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8)
+ elif attn_type == "triton_linear":
+ # linear self attention with triton kernel fusion
+ self_num_heads = hidden_size // 32
+ self.attn = TritonLiteMLA(hidden_size, num_heads=self_num_heads, eps=1e-8)
+ print("currently not support parallel attn")
+ exit()
+ elif attn_type == "vanilla":
+ # vanilla self attention
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True)
+ print("currently not support parallel attn")
+ exit()
+ else:
+ raise ValueError(f"{attn_type} type is not defined.")
+
+ self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
+ self.norm2 = nn.LayerNorm(int(hidden_size * mlp_ratio * 2), elementwise_affine=False, eps=1e-6)
+ if ffn_type == "dwmlp":
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = DWMlp(
+ in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
+ )
+ print("currently not support parallel attn")
+ exit()
+ elif ffn_type == "glumbconv":
+ self.mlp = SlimGLUMBConv(
+ in_features=hidden_size,
+ hidden_features=int(hidden_size * mlp_ratio),
+ use_bias=(True, True, False),
+ norm=(None, None, None),
+ act=mlp_acts,
+ )
+ elif ffn_type == "mlp":
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = Mlp(
+ in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
+ )
+ print("currently not support parallel attn")
+ exit()
+ elif ffn_type == "mbconvpreglu":
+ self.mlp = MBConvPreGLU(
+ in_dim=hidden_size,
+ out_dim=hidden_size,
+ mid_dim=int(hidden_size * mlp_ratio),
+ use_bias=(True, True, False),
+ norm=None,
+ act=("silu", "silu", None),
+ )
+ print("currently not support parallel attn")
+ exit()
+ else:
+ raise ValueError(f"{ffn_type} type is not defined.")
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
+
+ # parallel layers
+ self.mlp_ratio = mlp_ratio
+ self.in_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.in_proj = nn.Linear(hidden_size, (hidden_size * 3 + int(hidden_size * mlp_ratio * 2)))
+ self.in_split = [hidden_size * 3] + [int(hidden_size * mlp_ratio * 2)]
+
+ def forward(self, x, y, t, mask=None, HW=None, **kwargs):
+ B, N, C = x.shape
+
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + t.reshape(B, 6, -1)
+ ).chunk(6, dim=1)
+ # original Attention code
+ # x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
+ # x = x + self.cross_attn(x, y, mask)
+ # x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp), HW=HW))
+
+ # combine GLUMBConv fc1 & qkv projections
+ # x_1 = self.in_norm(x)
+ # x_1 = self.in_proj(x_1)
+ x_1 = self.in_proj(self.in_norm(x))
+ qkv, x_mlp = torch.split(x_1, self.in_split, dim=-1)
+
+ qkv = t2i_modulate(self.norm1(qkv), shift_msa.repeat(1, 1, 3), scale_msa.repeat(1, 1, 3))
+ x_mlp = t2i_modulate(
+ self.norm2(x_mlp),
+ shift_mlp.repeat(1, 1, int(self.mlp_ratio * 2)),
+ scale_mlp.repeat(1, 1, int(self.mlp_ratio * 2)),
+ )
+ # qkv = self.norm1(qkv)
+ # x_mlp = self.norm2(x_mlp)
+
+ # branch 1
+ x_attn = gate_msa * self.attn(qkv, HW=HW)
+ x_attn = x_attn + self.cross_attn(x_attn, y, mask)
+
+ # branch 2
+ x_mlp = gate_mlp * self.mlp(x_mlp, HW=HW)
+
+ # Add residual w/ drop path & layer scale applied
+ x = x + self.drop_path(x_attn + x_mlp)
+
+ return x
+
+
+class SanaMSPABlock(nn.Module):
+ """
+ A Sana block with adaptive layer norm zero (adaLN-Zero) conditioning.
+ reference VIT-22B
+ https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L224
+ """
+
+ def __init__(
+ self,
+ hidden_size,
+ num_heads,
+ mlp_ratio=4.0,
+ drop_path=0.0,
+ input_size=None,
+ sampling=None,
+ sr_ratio=1,
+ qk_norm=False,
+ attn_type="flash",
+ ffn_type="mlp",
+ mlp_acts=("silu", "silu", None),
+ **block_kwargs,
+ ):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.norm1 = nn.LayerNorm(hidden_size * 3, elementwise_affine=False, eps=1e-6)
+ if attn_type == "flash":
+ # flash self attention
+ self.attn = FlashAttention(
+ hidden_size,
+ num_heads=num_heads,
+ qkv_bias=True,
+ sampling=sampling,
+ sr_ratio=sr_ratio,
+ qk_norm=qk_norm,
+ **block_kwargs,
+ )
+ print("currently not support parallel attn")
+ exit()
+ elif attn_type == "linear":
+ # linear self attention
+ # TODO: Here the num_heads set to 36 for tmp used
+ self_num_heads = hidden_size // 32
+ # self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8)
+ self.attn = SlimLiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8)
+ elif attn_type == "triton_linear":
+ # linear self attention with triton kernel fusion
+ self_num_heads = hidden_size // 32
+ self.attn = TritonLiteMLA(hidden_size, num_heads=self_num_heads, eps=1e-8)
+ print("currently not support parallel attn")
+ exit()
+ elif attn_type == "vanilla":
+ # vanilla self attention
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True)
+ print("currently not support parallel attn")
+ exit()
+ else:
+ raise ValueError(f"{attn_type} type is not defined.")
+
+ self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
+ self.norm2 = nn.LayerNorm(int(hidden_size * mlp_ratio * 2), elementwise_affine=False, eps=1e-6)
+ if ffn_type == "dwmlp":
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = DWMlp(
+ in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
+ )
+ print("currently not support parallel attn")
+ exit()
+ elif ffn_type == "glumbconv":
+ self.mlp = SlimGLUMBConv(
+ in_features=hidden_size,
+ hidden_features=int(hidden_size * mlp_ratio),
+ use_bias=(True, True, False),
+ norm=(None, None, None),
+ act=mlp_acts,
+ )
+ elif ffn_type == "mlp":
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = Mlp(
+ in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
+ )
+ print("currently not support parallel attn")
+ exit()
+ elif ffn_type == "mbconvpreglu":
+ self.mlp = MBConvPreGLU(
+ in_dim=hidden_size,
+ out_dim=hidden_size,
+ mid_dim=int(hidden_size * mlp_ratio),
+ use_bias=(True, True, False),
+ norm=None,
+ act=("silu", "silu", None),
+ )
+ print("currently not support parallel attn")
+ exit()
+ else:
+ raise ValueError(f"{ffn_type} type is not defined.")
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
+
+ # parallel layers
+ self.mlp_ratio = mlp_ratio
+ self.in_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.in_proj = nn.Linear(hidden_size, (hidden_size * 3 + int(hidden_size * mlp_ratio * 2)))
+ self.in_split = [hidden_size * 3] + [int(hidden_size * mlp_ratio * 2)]
+
+ def forward(self, x, y, t, mask=None, HW=None, **kwargs):
+ B, N, C = x.shape
+
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + t.reshape(B, 6, -1)
+ ).chunk(6, dim=1)
+ x_1 = self.in_proj(self.in_norm(x))
+ qkv, x_mlp = torch.split(x_1, self.in_split, dim=-1)
+
+ qkv = t2i_modulate(self.norm1(qkv), shift_msa.repeat(1, 1, 3), scale_msa.repeat(1, 1, 3))
+ x_mlp = t2i_modulate(
+ self.norm2(x_mlp),
+ shift_mlp.repeat(1, 1, int(self.mlp_ratio * 2)),
+ scale_mlp.repeat(1, 1, int(self.mlp_ratio * 2)),
+ )
+
+ # branch 1
+ x_attn = gate_msa * self.attn(qkv, HW=HW)
+ x_attn = x_attn + self.cross_attn(x_attn, y, mask)
+
+ # branch 2
+ x_mlp = gate_mlp * self.mlp(x_mlp, HW=HW)
+
+ # Add residual w/ drop path & layer scale applied
+ x = x + self.drop_path(x_attn + x_mlp)
+
+ return x
diff --git a/tools/metrics/clip-score/.gitignore b/tools/metrics/clip-score/.gitignore
new file mode 100644
index 0000000..0447b8b
--- /dev/null
+++ b/tools/metrics/clip-score/.gitignore
@@ -0,0 +1,116 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
diff --git a/tools/metrics/clip-score/LICENSE b/tools/metrics/clip-score/LICENSE
new file mode 100644
index 0000000..261eeb9
--- /dev/null
+++ b/tools/metrics/clip-score/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/tools/metrics/clip-score/README.md b/tools/metrics/clip-score/README.md
new file mode 100644
index 0000000..3ca5b01
--- /dev/null
+++ b/tools/metrics/clip-score/README.md
@@ -0,0 +1,99 @@
+# CLIP Score for PyTorch
+
+[![PyPI](https://img.shields.io/pypi/v/clip-score.svg)](https://pypi.org/project/clip-score/)
+
+This repository provides a batch-wise quick processing for calculating CLIP scores. It uses the pretrained CLIP model to measure the cosine similarity between two modalities. The project structure is adapted from [pytorch-fid](https://github.com/mseitzer/pytorch-fid) and [CLIP](https://github.com/openai/CLIP).
+
+## Installation
+
+Requirements:
+
+- Install PyTorch:
+ ```
+ pip install torch # Choose a version that suits your GPU
+ ```
+- Install CLIP:
+ ```
+ pip install git+https://github.com/openai/CLIP.git
+ ```
+- Install clip-score from [PyPI](https://pypi.org/project/clip-score/):
+ ```
+ pip install clip-score
+ ```
+
+## Data Input Specifications
+
+This project is designed to process paired images and text files, and therefore requires two directories: one for images and one for text files.
+
+### Image Files
+
+All images should be stored in a single directory. The image files can be in either `.png` or `.jpg` format.
+
+### Text Files
+
+All text data should be contained in plain text files in a separate directory. These text files should have the extension `.txt`.
+
+### File Number and Naming
+
+The number of files in the image directory should be exactly equal to the number of files in the text directory. Additionally, the files in the image directory and text directory should be paired by file name. For instance, if there is a `cat.png` in the image directory, there should be a corresponding `cat.txt` in the text directory.
+
+### Directory Structure Example
+
+Below is an example of the expected directory structure:
+
+```plaintext
+├── path/to/image
+│ ├── cat.png
+│ ├── dog.png
+│ └── bird.jpg
+└── path/to/text
+ ├── cat.txt
+ ├── dog.txt
+ └── bird.txt
+```
+
+In this example, `cat.png` is paired with `cat.txt`, `dog.png` is paired with `dog.txt`, and `bird.jpg` is paired with `bird.txt`.
+
+Please adhere to the specified structure to ensure correct operation of the program. If there are any questions or issues, feel free to raise an issue here on GitHub.
+
+## Usage
+
+To compute the CLIP score between images and texts, make sure that the image and text data are contained in two separate folders, and each sample has the same name in both modalities. Run the following command:
+
+```
+python -m clip_score path/to/image path/to/text
+```
+
+If GPU is available, the project is set to run automatically on a GPU by default. If you want to specify a particular GPU, you can use the `--device cuda:N` flag when running the script, where `N` is the index of the GPU you wish to use. In case you want to run the program on a CPU instead, you can specify this by using the `--device cpu` flag.
+
+## Computing CLIP Score within the Same Modality
+
+If you want to calculate the CLIP score within the same modality (e.g., image-image or text-text), follow the same folder structure as mentioned above. Additionally, specify the preferred modalities using the `--real_flag` and `--fake_flag` options. By default, `--real_flag=img` and `--fake_flag=txt`. Examples:
+
+```
+python -m clip_score path/to/imageA path/to/imageB --real_flag img --fake_flag img
+python -m clip_score path/to/textA path/to/textB --real_flag txt --fake_flag txt
+```
+
+## Citing
+
+If you use this repository in your research, consider citing it using the following Bibtex entry:
+
+```
+@misc{taited2023CLIPScore,
+ author={SUN Zhengwentai},
+ title={{clip-score: CLIP Score for PyTorch}},
+ month={March},
+ year={2023},
+ note={Version 0.1.1},
+ howpublished={\url{https://github.com/taited/clip-score}},
+}
+```
+
+## License
+
+This implementation is licensed under the Apache License 2.0.
+
+The project structure is adapted from [mseitzer's pytorch-fid](https://github.com/mseitzer/pytorch-fid) project. The CLIP model is adapted from [OpenAI's CLIP](https://github.com/openai/CLIP).
+
+The CLIP Score was introduced in OpenAI's [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020).
diff --git a/tools/metrics/clip-score/clip_score.py b/tools/metrics/clip-score/clip_score.py
new file mode 100644
index 0000000..d0c3e05
--- /dev/null
+++ b/tools/metrics/clip-score/clip_score.py
@@ -0,0 +1,399 @@
+import io
+import os
+from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
+
+import clip
+import numpy as np
+import torch
+import webdataset as wds
+from PIL import Image
+from torch.utils.data import DataLoader, Dataset, IterableDataset
+
+from diffusion.data.transforms import get_transform
+from tools.metrics.utils import tracker
+
+try:
+ from tqdm import tqdm
+except ImportError:
+ # If tqdm is not available, provide a mock version of it
+ def tqdm(x):
+ return x
+
+
+import json
+
+IMAGE_EXTENSIONS = {"bmp", "jpg", "jpeg", "pgm", "png", "ppm", "tif", "tiff", "webp"}
+TEXT_EXTENSIONS = {"txt"}
+
+
+class DummyDataset(Dataset):
+ FLAGS = ["img", "txt", "json"]
+
+ def __init__(
+ self,
+ real_path,
+ fake_path,
+ real_flag: str = "img",
+ fake_flag: str = "img",
+ gen_img_path="",
+ transform=None,
+ tokenizer=None,
+ ) -> None:
+ super().__init__()
+ assert (
+ real_flag in self.FLAGS and fake_flag in self.FLAGS
+ ), f"CLIP Score only support modality of {self.FLAGS}. However, get {real_flag} and {fake_flag}"
+ self.gen_img_path = gen_img_path
+ print(f"images are from {gen_img_path}")
+ self.real_folder = self._load_img_from_path(real_path)
+ self.real_flag = real_flag
+ self.fake_data = self._load_txt_from_path(fake_path)
+ self.transform = transform
+ self.tokenizer = tokenizer
+ self.data_dict = {}
+
+ def __len__(self):
+ return len(self.real_folder)
+
+ def __getitem__(self, index):
+ if index >= len(self):
+ raise IndexError
+ real_path = self.real_folder[index]
+ real_data = self._load_modality(real_path, self.real_flag)
+ fake_data = self._load_txt(self.fake_data[index])
+ sample = dict(real=real_data, fake=fake_data, prompt=self.fake_data[index])
+ return sample
+
+ def _load_modality(self, path, modality):
+ if modality == "img":
+ data = self._load_img(path)
+ else:
+ raise TypeError(f"Got unexpected modality: {modality}")
+ return data
+
+ def _load_txt(self, data):
+ if self.tokenizer is not None:
+ data = self.tokenizer(data, context_length=77, truncate=True).squeeze()
+ return data
+
+ def _load_img(self, path):
+ img = Image.open(path)
+ if self.transform is not None:
+ img = self.transform(img)
+ return img
+
+ def _load_img_from_path(self, path):
+ image_list = []
+ if path.endswith(".json"):
+ with open(path) as file:
+ data_dict = json.load(file)
+ all_lines = list(data_dict.keys())[:sample_nums]
+ if isinstance(all_lines, list):
+ for k in all_lines:
+ img_path = os.path.join(self.gen_img_path, f"{k}.jpg")
+ image_list.append(img_path)
+ elif isinstance(all_lines, dict):
+ assert sample_nums >= 30_000, ValueError(f"{sample_nums} is not supported for json files")
+ for k, v in all_lines.items():
+ img_path = os.path.join(self.gen_img_path, f"{k}.jpg")
+ image_list.append(img_path)
+
+ else:
+ raise ValueError(f"Only JSON file type is supported now. Wrong with: {path}")
+
+ return image_list
+
+ def _load_txt_from_path(self, path):
+ txt_list = []
+ if path.endswith(".json"):
+ with open(path) as file:
+ data_dict = json.load(file)
+ all_lines = list(data_dict.keys())[:sample_nums]
+ if isinstance(all_lines, list):
+ for k in all_lines:
+ v = data_dict[k]
+ txt_list.append(v["prompt"])
+ elif isinstance(all_lines, dict):
+ assert sample_nums >= 30_000, ValueError(f"{sample_nums} is not supported for json files")
+ for k, v in all_lines.items():
+ txt_list.append(v["prompt"])
+ else:
+ raise ValueError(f"Only JSON file type is supported now. Wrong with: {path}")
+
+ return txt_list
+
+
+class DummyTarDataset(IterableDataset):
+ def __init__(
+ self, tar_path, transform=None, external_json_path=None, prompt_key="prompt", tokenizer=None, **kwargs
+ ):
+ assert ".tar" in tar_path
+ self.sample_nums = args.sample_nums
+ self.dataset = (
+ wds.WebDataset(tar_path)
+ .map(self.safe_decode)
+ .to_tuple("png;jpg", "json", "__key__")
+ .map(self.process_sample)
+ .slice(0, self.sample_nums)
+ )
+ if external_json_path is not None and os.path.exists(external_json_path):
+ print(f"Loading {external_json_path}, wait...")
+ self.json_file = json.load(open(external_json_path))
+ else:
+ self.json_file = {}
+ assert prompt_key == "prompt"
+ self.prompt_key = prompt_key
+ self.transform = transform
+ self.tokenizer = tokenizer
+
+ def __iter__(self):
+ return self._generator()
+
+ def _generator(self):
+ for i, (ori_img, info, key) in enumerate(self.dataset):
+ if self.transform is not None:
+ img = self.transform(ori_img)
+
+ if key in self.json_file:
+ info.update(self.json_file[key])
+
+ prompt = info.get(self.prompt_key, "")
+ if not prompt:
+ prompt = ""
+ print(f"{self.prompt_key} not exist in {key}.json")
+ txt_feat = self._load_txt(prompt)
+
+ yield dict(
+ real=img, fake=txt_feat, prompt=prompt, ori_img=np.array(img), key=key, prompt_key=self.prompt_key
+ )
+
+ def __len__(self):
+ return self.sample_nums
+
+ def _load_txt(self, data):
+ if self.tokenizer is not None:
+ data = self.tokenizer(data, context_length=77, truncate=True).squeeze()
+ return data
+
+ @staticmethod
+ def process_sample(sample):
+ try:
+ image_bytes, json_bytes, key = sample
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+ json_dict = json.loads(json_bytes)
+ return image, json_dict, key
+ except (ValueError, TypeError, OSError) as e:
+ print(f"Skipping sample due to error: {e}")
+ return None
+
+ @staticmethod
+ def safe_decode(sample):
+ def custom_decode(sample):
+ result = {}
+ for k, v in sample.items():
+ result[k] = v
+ return result
+
+ try:
+ return custom_decode(sample)
+ except Exception as e:
+ print(f"skipping sample due to decode error: {e}")
+ return None
+
+
+@torch.no_grad()
+def calculate_clip_score(dataloader, model, real_flag, fake_flag, save_json_path=None):
+ score_acc = 0.0
+ sample_num = 0.0
+ json_dict = {} if save_json_path is not None else None
+ logit_scale = model.logit_scale.exp()
+ for batch_data in tqdm(dataloader, desc=f"CLIP-Score: {args.exp_name}", position=args.gpu_id, leave=True):
+ real_features = forward_modality(model, batch_data["real"], real_flag)
+ fake_features = forward_modality(model, batch_data["fake"], fake_flag)
+
+ # normalize features
+ real_features = real_features / real_features.norm(dim=1, keepdim=True).to(torch.float32)
+ fake_features = fake_features / fake_features.norm(dim=1, keepdim=True).to(torch.float32)
+
+ score = logit_scale * (fake_features * real_features).sum()
+ if save_json_path is not None:
+ json_dict[batch_data["key"][0]] = {f"{batch_data['prompt_key'][0]}": f"{score:.04f}"}
+
+ score_acc += score
+ sample_num += batch_data["real"].shape[0]
+
+ if save_json_path is not None:
+ json.dump(json_dict, open(save_json_path, "w"))
+ return score_acc / sample_num
+
+
+@torch.no_grad()
+def calculate_clip_score_official(dataloader):
+ import numpy as np
+ from torchmetrics.multimodal.clip_score import CLIPScore
+
+ clip_score_fn = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14").to(device)
+ # clip_score_fn = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16").to(device)
+ all_clip_scores = []
+
+ for batch_data in tqdm(dataloader, desc=args.exp_name, position=args.gpu_id, leave=True):
+ imgs = batch_data["real"].add_(1.0).mul_(0.5)
+ imgs = (imgs * 255).to(dtype=torch.uint8, device=device)
+
+ prompts = batch_data["prompt"]
+ clip_scores = clip_score_fn(imgs, prompts).detach().cpu()
+ all_clip_scores.append(float(clip_scores))
+
+ clip_scores = float(np.mean(all_clip_scores))
+ return clip_scores
+
+
+def forward_modality(model, data, flag):
+ device = next(model.parameters()).device
+ if flag == "img":
+ features = model.encode_image(data.to(device))
+ elif flag == "txt":
+ features = model.encode_text(data.to(device))
+ else:
+ raise TypeError
+ return features
+
+
+def main():
+ txt_path = args.txt_path if args.txt_path is not None else args.img_path
+ gen_img_path = str(os.path.join(args.img_path, args.exp_name))
+ if ".tar" in gen_img_path:
+ save_txt_path = os.path.join(txt_path, f"{args.exp_name}_{args.tar_prompt_key}_clip_score.txt").replace(
+ ".tar", ""
+ )
+ save_json_path = save_txt_path.replace(".tar", "").replace(".txt", ".json")
+ if os.path.exists(save_json_path):
+ print(f"{save_json_path} exists. Finished.")
+ return None
+ else:
+ save_txt_path = os.path.join(txt_path, f"{args.exp_name}_sample{sample_nums}_clip_score.txt")
+ save_json_path = None
+ if os.path.exists(save_txt_path):
+ with open(save_txt_path) as f:
+ clip_score = f.readlines()[0].strip()
+ print(f"CLIP Score: {clip_score}: {args.exp_name}")
+ return {args.exp_name: float(clip_score)}
+
+ print(f"Loading CLIP model: {args.clip_model}")
+ if args.clipscore_type == "diffusers":
+ preprocess = get_transform("default_train", 512)
+ else:
+ model, preprocess = clip.load(args.clip_model, device=device)
+
+ if ".tar" in gen_img_path:
+ dataset = DummyTarDataset(
+ gen_img_path,
+ transform=preprocess,
+ external_json_path=args.external_json_file,
+ prompt_key=args.tar_prompt_key,
+ tokenizer=clip.tokenize,
+ )
+ else:
+ dataset = DummyDataset(
+ args.real_path,
+ args.fake_path,
+ args.real_flag,
+ args.fake_flag,
+ transform=preprocess,
+ tokenizer=clip.tokenize,
+ gen_img_path=gen_img_path,
+ )
+ dataloader = DataLoader(dataset, args.batch_size, num_workers=num_workers, pin_memory=True)
+
+ print("Calculating CLIP Score:")
+ if args.clipscore_type == "diffusers":
+ clip_score = calculate_clip_score_official(dataloader)
+ else:
+ clip_score = calculate_clip_score(
+ dataloader, model, args.real_flag, args.fake_flag, save_json_path=save_json_path
+ )
+ clip_score = clip_score.cpu().item()
+ print("CLIP Score: ", clip_score)
+ with open(save_txt_path, "w") as file:
+ file.write(str(clip_score))
+ print(f"Result saved at: {save_txt_path}")
+
+ return {args.exp_name: clip_score}
+
+
+def parse_args():
+ parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
+ parser.add_argument("--batch-size", type=int, default=50, help="Batch size to use")
+ parser.add_argument("--clip-model", type=str, default="ViT-L/14", help="CLIP model to use")
+ # parser.add_argument('--clip-model', type=str, default='ViT-B/16', help='CLIP model to use')
+ parser.add_argument("--img_path", type=str, default=None)
+ parser.add_argument("--txt_path", type=str, default=None)
+ parser.add_argument("--sample_nums", type=int, default=30_000)
+ parser.add_argument("--exp_name", type=str, default="Sana")
+ parser.add_argument(
+ "--num-workers", type=int, help="Number of processes to use for data loading. Defaults to `min(8, num_cpus)`"
+ )
+ parser.add_argument("--device", type=str, default=None, help="Device to use. Like cuda, cuda:0 or cpu")
+ parser.add_argument("--real_flag", type=str, default="img", help="The modality of real path. Default to img")
+ parser.add_argument("--fake_flag", type=str, default="txt", help="The modality of real path. Default to txt")
+ parser.add_argument("--real_path", type=str, help="Paths to the generated images")
+ parser.add_argument("--fake_path", type=str, help="Paths to the generated images")
+ parser.add_argument("--external_json_file", type=str, default=None, help="external meta json file for tar_file")
+ parser.add_argument("--tar_prompt_key", type=str, default="prompt", help="key name of prompt in json")
+
+ # online logging setting
+ parser.add_argument("--clipscore_type", type=str, default="self", choices=["diffusers", "self"])
+ parser.add_argument("--log_metric", type=str, default="metric")
+ parser.add_argument("--gpu_id", type=int, default=0)
+ parser.add_argument("--log_clip_score", action="store_true")
+ parser.add_argument("--suffix_label", type=str, default="", help="used for clip_score online log")
+ parser.add_argument("--tracker_pattern", type=str, default="epoch_step", help="used for fid online log")
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default=None,
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="t2i-evit-baseline",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+ parser.add_argument(
+ "--name",
+ type=str,
+ default="baseline",
+ help=("Wandb Project Name"),
+ )
+ args = parser.parse_args()
+ return args
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ sample_nums = args.sample_nums
+ if args.device is None:
+ device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
+ else:
+ device = torch.device(args.device)
+
+ if args.num_workers is None:
+ try:
+ num_cpus = len(os.sched_getaffinity(0))
+ except AttributeError:
+ num_cpus = os.cpu_count()
+ num_workers = min(num_cpus, 8) if num_cpus is not None else 0
+ else:
+ num_workers = args.num_workers
+
+ args.exp_name = os.path.basename(args.exp_name) or os.path.dirname(args.exp_name)
+ clip_score_result = main()
+ if args.log_clip_score:
+ tracker(args, clip_score_result, args.suffix_label, pattern=args.tracker_pattern, metric="CLIP-Score")
diff --git a/tools/metrics/clip-score/setup.py b/tools/metrics/clip-score/setup.py
new file mode 100644
index 0000000..67ada71
--- /dev/null
+++ b/tools/metrics/clip-score/setup.py
@@ -0,0 +1,53 @@
+import os
+
+import setuptools
+
+
+def read(rel_path):
+ base_path = os.path.abspath(os.path.dirname(__file__))
+ with open(os.path.join(base_path, rel_path)) as f:
+ return f.read()
+
+
+def get_version(rel_path):
+ for line in read(rel_path).splitlines():
+ if line.startswith("__version__"):
+ delim = '"' if '"' in line else "'"
+ return line.split(delim)[1]
+
+ raise RuntimeError("Unable to find version string.")
+
+
+if __name__ == "__main__":
+ setuptools.setup(
+ name="clip-score",
+ version=get_version(os.path.join("src", "clip_score", "__init__.py")),
+ author="Taited",
+ author_email="taited9160@gmail.com",
+ description=("Package for calculating CLIP-Score" " using PyTorch"),
+ long_description=read("README.md"),
+ long_description_content_type="text/markdown",
+ url="https://github.com/taited/clip-score",
+ package_dir={"": "src"},
+ packages=setuptools.find_packages(where="src"),
+ classifiers=[
+ "Programming Language :: Python :: 3",
+ "License :: OSI Approved :: Apache Software License",
+ ],
+ python_requires=">=3.5",
+ entry_points={
+ "console_scripts": [
+ "clip-score = clip_score.clip_score:main",
+ ],
+ },
+ install_requires=[
+ "numpy",
+ "pillow",
+ "torch>=1.7.1",
+ "torchvision>=0.8.2",
+ "ftfy",
+ "regex",
+ "tqdm",
+ ],
+ extras_require={"dev": ["flake8", "flake8-bugbear", "flake8-isort", "nox"]},
+ )
diff --git a/tools/metrics/clip-score/src/clip_score/__init__.py b/tools/metrics/clip-score/src/clip_score/__init__.py
new file mode 100644
index 0000000..485f44a
--- /dev/null
+++ b/tools/metrics/clip-score/src/clip_score/__init__.py
@@ -0,0 +1 @@
+__version__ = "0.1.1"
diff --git a/tools/metrics/clip-score/src/clip_score/__main__.py b/tools/metrics/clip-score/src/clip_score/__main__.py
new file mode 100644
index 0000000..71cac08
--- /dev/null
+++ b/tools/metrics/clip-score/src/clip_score/__main__.py
@@ -0,0 +1,3 @@
+import clip_score.clip_score
+
+clip_score.clip_score.main()
diff --git a/tools/metrics/clip-score/src/clip_score/clip_score.py b/tools/metrics/clip-score/src/clip_score/clip_score.py
new file mode 100644
index 0000000..63c4d84
--- /dev/null
+++ b/tools/metrics/clip-score/src/clip_score/clip_score.py
@@ -0,0 +1,210 @@
+"""Calculates the CLIP Scores
+
+The CLIP model is a contrasitively learned language-image model. There is
+an image encoder and a text encoder. It is believed that the CLIP model could
+measure the similarity of cross modalities. Please find more information from
+https://github.com/openai/CLIP.
+
+The CLIP Score measures the Cosine Similarity between two embedded features.
+This repository utilizes the pretrained CLIP Model to calculate
+the mean average of cosine similarities.
+
+See --help to see further details.
+
+Code apapted from https://github.com/mseitzer/pytorch-fid and https://github.com/openai/CLIP.
+
+Copyright 2023 The Hong Kong Polytechnic University
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import os
+import os.path as osp
+from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
+
+import clip
+import torch
+from PIL import Image
+from torch.utils.data import DataLoader, Dataset
+
+try:
+ from tqdm import tqdm
+except ImportError:
+ # If tqdm is not available, provide a mock version of it
+ def tqdm(x):
+ return x
+
+
+parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
+parser.add_argument("--batch-size", type=int, default=50, help="Batch size to use")
+parser.add_argument("--clip-model", type=str, default="ViT-B/32", help="CLIP model to use")
+parser.add_argument(
+ "--num-workers", type=int, help=("Number of processes to use for data loading. " "Defaults to `min(8, num_cpus)`")
+)
+parser.add_argument("--device", type=str, default=None, help="Device to use. Like cuda, cuda:0 or cpu")
+parser.add_argument("--real_flag", type=str, default="img", help=("The modality of real path. " "Default to img"))
+parser.add_argument("--fake_flag", type=str, default="txt", help=("The modality of real path. " "Default to txt"))
+parser.add_argument("real_path", type=str, help=("Paths to the generated images or " "to .npz statistic files"))
+parser.add_argument("fake_path", type=str, help=("Paths to the generated images or " "to .npz statistic files"))
+
+IMAGE_EXTENSIONS = {"bmp", "jpg", "jpeg", "pgm", "png", "ppm", "tif", "tiff", "webp"}
+
+TEXT_EXTENSIONS = {"txt"}
+
+
+class DummyDataset(Dataset):
+
+ FLAGS = ["img", "txt"]
+
+ def __init__(
+ self, real_path, fake_path, real_flag: str = "img", fake_flag: str = "img", transform=None, tokenizer=None
+ ) -> None:
+ super().__init__()
+ assert (
+ real_flag in self.FLAGS and fake_flag in self.FLAGS
+ ), f"CLIP Score only support modality of {self.FLAGS}. However, get {real_flag} and {fake_flag}"
+ self.real_folder = self._combine_without_prefix(real_path)
+ self.real_flag = real_flag
+ self.fake_foler = self._combine_without_prefix(fake_path)
+ self.fake_flag = fake_flag
+ self.transform = transform
+ self.tokenizer = tokenizer
+ # assert self._check()
+
+ def __len__(self):
+ return len(self.real_folder)
+
+ def __getitem__(self, index):
+ if index >= len(self):
+ raise IndexError
+ real_path = self.real_folder[index]
+ fake_path = self.fake_foler[index]
+ real_data = self._load_modality(real_path, self.real_flag)
+ fake_data = self._load_modality(fake_path, self.fake_flag)
+
+ sample = dict(real=real_data, fake=fake_data)
+ return sample
+
+ def _load_modality(self, path, modality):
+ if modality == "img":
+ data = self._load_img(path)
+ elif modality == "txt":
+ data = self._load_txt(path)
+ else:
+ raise TypeError(f"Got unexpected modality: {modality}")
+ return data
+
+ def _load_img(self, path):
+ img = Image.open(path)
+ if self.transform is not None:
+ img = self.transform(img)
+ return img
+
+ def _load_txt(self, path):
+ with open(path) as fp:
+ data = fp.read()
+ fp.close()
+ if self.tokenizer is not None:
+ data = self.tokenizer(data).squeeze()
+ return data
+
+ def _check(self):
+ for idx in range(len(self)):
+ real_name = self.real_folder[idx].split(".")
+ fake_name = self.fake_folder[idx].split(".")
+ if fake_name != real_name:
+ return False
+ return True
+
+ def _combine_without_prefix(self, folder_path, prefix="."):
+ folder = []
+ for name in os.listdir(folder_path):
+ if name[0] == prefix:
+ continue
+ folder.append(osp.join(folder_path, name))
+ folder.sort()
+ return folder
+
+
+@torch.no_grad()
+def calculate_clip_score(dataloader, model, real_flag, fake_flag):
+ score_acc = 0.0
+ sample_num = 0.0
+ logit_scale = model.logit_scale.exp()
+ for batch_data in tqdm(dataloader):
+ real = batch_data["real"]
+ real_features = forward_modality(model, real, real_flag)
+ fake = batch_data["fake"]
+ fake_features = forward_modality(model, fake, fake_flag)
+
+ # normalize features
+ real_features = real_features / real_features.norm(dim=1, keepdim=True).to(torch.float32)
+ fake_features = fake_features / fake_features.norm(dim=1, keepdim=True).to(torch.float32)
+
+ # calculate scores
+ # score = logit_scale * real_features @ fake_features.t()
+ # score_acc += torch.diag(score).sum()
+ score = logit_scale * (fake_features * real_features).sum()
+ score_acc += score
+ sample_num += real.shape[0]
+
+ return score_acc / sample_num
+
+
+def forward_modality(model, data, flag):
+ device = next(model.parameters()).device
+ if flag == "img":
+ features = model.encode_image(data.to(device))
+ elif flag == "txt":
+ features = model.encode_text(data.to(device))
+ else:
+ raise TypeError
+ return features
+
+
+def main():
+ args = parser.parse_args()
+
+ if args.device is None:
+ device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
+ else:
+ device = torch.device(args.device)
+
+ if args.num_workers is None:
+ try:
+ num_cpus = len(os.sched_getaffinity(0))
+ except AttributeError:
+ # os.sched_getaffinity is not available under Windows, use
+ # os.cpu_count instead (which may not return the *available* number
+ # of CPUs).
+ num_cpus = os.cpu_count()
+
+ num_workers = min(num_cpus, 8) if num_cpus is not None else 0
+ else:
+ num_workers = args.num_workers
+
+ print(f"Loading CLIP model: {args.clip_model}")
+ model, preprocess = clip.load(args.clip_model, device=device)
+
+ dataset = DummyDataset(
+ args.real_path, args.fake_path, args.real_flag, args.fake_flag, transform=preprocess, tokenizer=clip.tokenize
+ )
+ dataloader = DataLoader(dataset, args.batch_size, num_workers=num_workers, pin_memory=True)
+
+ print("Calculating CLIP Score:")
+ clip_score = calculate_clip_score(dataloader, model, args.real_flag, args.fake_flag)
+ clip_score = clip_score.cpu().item()
+ print("CLIP Score: ", clip_score)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/metrics/pytorch-fid/.gitignore b/tools/metrics/pytorch-fid/.gitignore
new file mode 100644
index 0000000..0447b8b
--- /dev/null
+++ b/tools/metrics/pytorch-fid/.gitignore
@@ -0,0 +1,116 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
diff --git a/tools/metrics/pytorch-fid/CHANGELOG.md b/tools/metrics/pytorch-fid/CHANGELOG.md
new file mode 100644
index 0000000..8f2f744
--- /dev/null
+++ b/tools/metrics/pytorch-fid/CHANGELOG.md
@@ -0,0 +1,41 @@
+# Changelog
+
+## \[0.3.0\] - 2023-01-05
+
+### Added
+
+- Add argument `--save-stats` allowing to compute dataset statistics and save them as an `.npz` file ([#80](https://github.com/mseitzer/pytorch-fid/pull/80)). The `.npz` file can be used in subsequent FID computations instead of recomputing the dataset statistics. This option can be used in the following way: `python -m pytorch_fid --save-stats path/to/dataset path/to/outputfile`.
+
+### Fixed
+
+- Do not use `os.sched_getaffinity` to get number of available CPUs on Windows, as it is not available there ([232b3b14](https://github.com/mseitzer/pytorch-fid/commit/232b3b1468800102fcceaf6f2bb8977811fc991a), [#84](https://github.com/mseitzer/pytorch-fid/issues/84)).
+- Do not use Inception model argument `pretrained`, as it was deprecated in torchvision 0.13 ([#88](https://github.com/mseitzer/pytorch-fid/pull/88)).
+
+## \[0.2.1\] - 2021-10-10
+
+### Added
+
+- Add argument `--num-workers` to select number of dataloader processes ([#66](https://github.com/mseitzer/pytorch-fid/pull/66)). Defaults to 8 or the number of available CPUs if less than 8 CPUs are available.
+
+### Fixed
+
+- Fixed package setup to work under Windows ([#55](https://github.com/mseitzer/pytorch-fid/pull/55), [#72](https://github.com/mseitzer/pytorch-fid/issues/72))
+
+## \[0.2.0\] - 2020-11-30
+
+### Added
+
+- Load images using a Pytorch dataloader, which should result in a speed-up. ([#47](https://github.com/mseitzer/pytorch-fid/pull/47))
+- Support more image extensions ([#53](https://github.com/mseitzer/pytorch-fid/pull/53))
+- Improve tooling by setting up Nox, add linting and test support ([#52](https://github.com/mseitzer/pytorch-fid/pull/52))
+- Add some unit tests
+
+## \[0.1.1\] - 2020-08-16
+
+### Fixed
+
+- Fixed software license string in `setup.py`
+
+## \[0.1.0\] - 2020-08-16
+
+Initial release as a pypi package. Use `pip install pytorch-fid` to install.
diff --git a/tools/metrics/pytorch-fid/LICENSE b/tools/metrics/pytorch-fid/LICENSE
new file mode 100644
index 0000000..261eeb9
--- /dev/null
+++ b/tools/metrics/pytorch-fid/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/tools/metrics/pytorch-fid/README.md b/tools/metrics/pytorch-fid/README.md
new file mode 100644
index 0000000..c74549a
--- /dev/null
+++ b/tools/metrics/pytorch-fid/README.md
@@ -0,0 +1,93 @@
+[![PyPI](https://img.shields.io/pypi/v/pytorch-fid.svg)](https://pypi.org/project/pytorch-fid/)
+
+# FID score for PyTorch
+
+This is a port of the official implementation of [Fréchet Inception Distance](https://arxiv.org/abs/1706.08500) to PyTorch.
+See [https://github.com/bioinf-jku/TTUR](https://github.com/bioinf-jku/TTUR) for the original implementation using Tensorflow.
+
+FID is a measure of similarity between two datasets of images.
+It was shown to correlate well with human judgement of visual quality and is most often used to evaluate the quality of samples of Generative Adversarial Networks.
+FID is calculated by computing the [Fréchet distance](https://en.wikipedia.org/wiki/Fr%C3%A9chet_distance) between two Gaussians fitted to feature representations of the Inception network.
+
+Further insights and an independent evaluation of the FID score can be found in [Are GANs Created Equal? A Large-Scale Study](https://arxiv.org/abs/1711.10337).
+
+The weights and the model are exactly the same as in [the official Tensorflow implementation](https://github.com/bioinf-jku/TTUR), and were tested to give very similar results (e.g. `.08` absolute error and `0.0009` relative error on LSUN, using ProGAN generated images). However, due to differences in the image interpolation implementation and library backends, FID results still differ slightly from the original implementation. So if you report FID scores in your paper, and you want them to be *exactly comparable* to FID scores reported in other papers, you should consider using [the official Tensorflow implementation](https://github.com/bioinf-jku/TTUR).
+
+## Installation
+
+Install from [pip](https://pypi.org/project/pytorch-fid/):
+
+```
+pip install pytorch-fid
+```
+
+Requirements:
+
+- python3
+- pytorch
+- torchvision
+- pillow
+- numpy
+- scipy
+
+## Usage
+
+To compute the FID score between two datasets, where images of each dataset are contained in an individual folder:
+
+```
+python -m pytorch_fid path/to/dataset1 path/to/dataset2
+```
+
+To run the evaluation on GPU, use the flag `--device cuda:N`, where `N` is the index of the GPU to use.
+
+### Using different layers for feature maps
+
+In difference to the official implementation, you can choose to use a different feature layer of the Inception network instead of the default `pool3` layer.
+As the lower layer features still have spatial extent, the features are first global average pooled to a vector before estimating mean and covariance.
+
+This might be useful if the datasets you want to compare have less than the otherwise required 2048 images.
+Note that this changes the magnitude of the FID score and you can not compare them against scores calculated on another dimensionality.
+The resulting scores might also no longer correlate with visual quality.
+
+You can select the dimensionality of features to use with the flag `--dims N`, where N is the dimensionality of features.
+The choices are:
+
+- 64: first max pooling features
+- 192: second max pooling features
+- 768: pre-aux classifier features
+- 2048: final average pooling features (this is the default)
+
+## Generating a compatible `.npz` archive from a dataset
+
+A frequent use case will be to compare multiple models against an original dataset.
+To save training multiple times on the original dataset, there is also the ability to generate a compatible `.npz` archive from a dataset. This is done using any combination of the previously mentioned arguments with the addition of the `--save-stats` flag. For example:
+
+```
+python -m pytorch_fid --save-stats path/to/dataset path/to/outputfile
+```
+
+The output file may then be used in place of the path to the original dataset for further comparisons.
+
+## Citing
+
+If you use this repository in your research, consider citing it using the following Bibtex entry:
+
+```
+@misc{Seitzer2020FID,
+ author={Maximilian Seitzer},
+ title={{pytorch-fid: FID Score for PyTorch}},
+ month={August},
+ year={2020},
+ note={Version 0.3.0},
+ howpublished={\url{https://github.com/mseitzer/pytorch-fid}},
+}
+```
+
+## License
+
+This implementation is licensed under the Apache License 2.0.
+
+FID was introduced by Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler and Sepp Hochreiter in "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium", see [https://arxiv.org/abs/1706.08500](https://arxiv.org/abs/1706.08500)
+
+The original implementation is by the Institute of Bioinformatics, JKU Linz, licensed under the Apache License 2.0.
+See [https://github.com/bioinf-jku/TTUR](https://github.com/bioinf-jku/TTUR).
diff --git a/tools/metrics/pytorch-fid/compute_fid.py b/tools/metrics/pytorch-fid/compute_fid.py
new file mode 100644
index 0000000..a83fdc7
--- /dev/null
+++ b/tools/metrics/pytorch-fid/compute_fid.py
@@ -0,0 +1,327 @@
+import json
+import os
+import pathlib
+from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
+
+import numpy as np
+import torch
+import torchvision.transforms as T
+from PIL import Image
+from pytorch_fid.inception import InceptionV3
+from scipy import linalg
+from torch.nn.functional import adaptive_avg_pool2d
+
+from tools.metrics.utils import tracker
+
+try:
+ from tqdm import tqdm
+except ImportError:
+ # If tqdm is not available, provide a mock version of it
+ def tqdm(x):
+ return x
+
+
+IMAGE_EXTENSIONS = {"bmp", "jpg", "jpeg", "pgm", "png", "ppm", "tif", "tiff", "webp"}
+
+
+class ImagePathDataset(torch.utils.data.Dataset):
+ def __init__(self, files, transforms=None):
+ self.files = files
+ self.transforms = transforms
+
+ def __len__(self):
+ return len(self.files)
+
+ def __getitem__(self, i):
+ path = self.files[i]
+ try:
+ img = Image.open(path)
+ assert img.mode == "RGB"
+ if self.transforms is not None:
+ img = self.transforms(img)
+ except Exception as e:
+ raise FileNotFoundError(path, "\n", e)
+
+ return img
+
+
+def get_activations(files, model, batch_size=50, dims=2048, device="cpu", num_workers=1):
+ model.eval()
+
+ if batch_size > len(files):
+ print("Warning: batch size is bigger than the data size. " "Setting batch size to data size")
+ batch_size = len(files)
+ transform = T.Compose(
+ [
+ T.Resize(args.img_size), # Image.BICUBIC
+ T.CenterCrop(args.img_size),
+ T.ToTensor(),
+ ]
+ )
+ dataset = ImagePathDataset(files, transforms=transform)
+ dataloader = torch.utils.data.DataLoader(
+ dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers
+ )
+
+ pred_arr = np.empty((len(files), dims))
+
+ start_idx = 0
+
+ for batch in tqdm(dataloader, desc=f"FID: {args.exp_name}", position=args.gpu_id, leave=True):
+ batch = batch.to(device)
+
+ with torch.no_grad():
+ pred = model(batch)[0]
+
+ # If model output is not scalar, apply global spatial average pooling.
+ # This happens if you choose a dimensionality not equal 2048.
+ if pred.size(2) != 1 or pred.size(3) != 1:
+ pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
+
+ pred = pred.squeeze(3).squeeze(2).cpu().numpy()
+
+ pred_arr[start_idx : start_idx + pred.shape[0]] = pred
+
+ start_idx = start_idx + pred.shape[0]
+
+ return pred_arr
+
+
+def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
+
+ mu1 = np.atleast_1d(mu1)
+ mu2 = np.atleast_1d(mu2)
+
+ sigma1 = np.atleast_2d(sigma1)
+ sigma2 = np.atleast_2d(sigma2)
+
+ assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
+ assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"
+
+ diff = mu1 - mu2
+
+ # Product might be almost singular
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+ if not np.isfinite(covmean).all():
+ msg = ("fid calculation produces singular product; " "adding %s to diagonal of cov estimates") % eps
+ print(msg)
+ offset = np.eye(sigma1.shape[0]) * eps
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+ # Numerical error might give slight imaginary component
+ if np.iscomplexobj(covmean):
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+ m = np.max(np.abs(covmean.imag))
+ raise ValueError(f"Imaginary component {m}")
+ covmean = covmean.real
+
+ tr_covmean = np.trace(covmean)
+
+ return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
+
+
+def calculate_activation_statistics(files, model, batch_size=50, dims=2048, device="cpu", num_workers=1):
+ act = get_activations(files, model, batch_size, dims, device, num_workers)
+ mu = np.mean(act, axis=0)
+ sigma = np.cov(act, rowvar=False)
+ return mu, sigma
+
+
+def compute_statistics_of_path(path, model, batch_size, dims, device, num_workers=1, flag="ref"):
+ if path.endswith(".npz"):
+ print("loaded from npz files")
+ with np.load(path) as f:
+ m, s = f["mu"][:], f["sigma"][:]
+ elif path.endswith(".json"):
+ with open(path) as file:
+ data_dict = json.load(file)
+ all_lines = list(data_dict.keys())[:sample_nums]
+
+ files = []
+ if isinstance(all_lines, list):
+ for k in all_lines:
+ v = data_dict[k]
+ if "PG-eval-data" in args.img_path:
+ img_path = os.path.join(args.img_path, v["category"], f"{k}.jpg")
+ else:
+ img_path = os.path.join(args.img_path, args.exp_name, f"{k}.jpg")
+ files.append(img_path)
+ elif isinstance(all_lines, dict):
+ assert sample_nums >= 30_000, ValueError(f"{sample_nums} is not supported for json files")
+ for k, v in all_lines.items():
+ if "PG-eval-data" in args.img_path:
+ img_path = os.path.join(args.img_path, v["category"], f"{k}.jpg")
+ else:
+ img_path = os.path.join(args.img_path, args.exp_name, f"{k}.jpg")
+ files.append(img_path)
+
+ files = sorted(files)
+ m, s = calculate_activation_statistics(files, model, batch_size, dims, device, num_workers)
+ else:
+ path = pathlib.Path(path)
+ files = sorted([file for ext in IMAGE_EXTENSIONS for file in path.glob(f"*.{ext}")])
+
+ m, s = calculate_activation_statistics(files, model, batch_size, dims, device, num_workers)
+ return m, s
+
+
+def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1):
+ """Calculates the FID of two paths"""
+ for p in paths:
+ if not os.path.exists(p):
+ raise RuntimeError("Invalid path: %s" % p)
+
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
+
+ model = InceptionV3([block_idx]).to(device)
+
+ m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, dims, device, num_workers, flag="ref")
+ m2, s2 = compute_statistics_of_path(paths[1], model, batch_size, dims, device, num_workers, flag="gen")
+ fid_value = calculate_frechet_distance(m1, s1, m2, s2)
+
+ return fid_value
+
+
+def save_fid_stats(paths, batch_size, device, dims, num_workers=1):
+ """Calculates the FID of two paths"""
+ if not os.path.exists(paths[0]):
+ raise RuntimeError("Invalid path: %s" % paths[0])
+
+ if os.path.exists(paths[1]):
+ raise RuntimeError("Existing output file: %s" % paths[1])
+
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
+
+ model = InceptionV3([block_idx]).to(device)
+
+ print(f"Saving statistics for {paths[0]}")
+
+ m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, dims, device, num_workers, flag="ref")
+ np.savez_compressed(paths[1], mu=m1, sigma=s1)
+
+
+def main():
+ txt_path = args.txt_path if args.txt_path is not None else args.img_path
+ save_txt_path = os.path.join(txt_path, f"{args.exp_name}_sample{sample_nums}.txt")
+ if os.path.exists(save_txt_path):
+ with open(save_txt_path) as f:
+ fid_value = f.readlines()[0].strip()
+ print(f"FID {fid_value}: {args.exp_name}")
+ return {args.exp_name: float(fid_value)}
+
+ if args.device is None:
+ device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
+ else:
+ device = torch.device(args.device)
+
+ if args.num_workers is None:
+ try:
+ num_cpus = len(os.sched_getaffinity(0))
+ except AttributeError:
+ num_cpus = os.cpu_count()
+
+ num_workers = min(num_cpus, 8) if num_cpus is not None else 0
+ else:
+ num_workers = args.num_workers
+
+ if args.save_stats:
+ save_fid_stats(args.path, args.batch_size, device, args.dims, num_workers)
+ return
+
+ fid_value = calculate_fid_given_paths(args.path, args.batch_size, device, args.dims, num_workers)
+ print(f"FID {fid_value}: {args.exp_name}")
+ with open(save_txt_path, "w") as file:
+ file.write(str(fid_value))
+
+ return {args.exp_name: fid_value}
+
+
+def parse_args():
+ parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
+ parser.add_argument("--batch-size", type=int, default=50, help="Batch size to use")
+ parser.add_argument(
+ "--num-workers", type=int, help="Number of processes to use for data loading. Defaults to `min(8, num_cpus)`"
+ )
+ parser.add_argument("--img_size", type=int, default=512)
+ parser.add_argument("--device", type=str, default="cuda", help="Device to use. Like cuda, cuda:0 or cpu")
+
+ parser.add_argument("--img_path", type=str, default=None)
+ parser.add_argument("--exp_name", type=str, default="Sana")
+ parser.add_argument("--txt_path", type=str, default=None)
+ parser.add_argument("--sample_nums", type=int, default=30_000)
+
+ parser.add_argument(
+ "--dims",
+ type=int,
+ default=2048,
+ choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
+ help="Dimensionality of Inception features to use. By default, uses pool3 features",
+ )
+ parser.add_argument(
+ "--save-stats",
+ action="store_true",
+ help="Generate an npz archive from a directory of samples. The first path is used as input and the second as output.",
+ )
+ parser.add_argument("--stat", action="store_true")
+ parser.add_argument(
+ "--path", type=str, nargs=2, default=["", ""], help="Paths to the generated images or to .npz statistic files"
+ )
+
+ # online logging setting
+ parser.add_argument("--log_metric", type=str, default="metric")
+ parser.add_argument("--gpu_id", type=int, default=0)
+ parser.add_argument("--log_fid", action="store_true")
+ parser.add_argument("--suffix_label", type=str, default="", help="used for fid online log")
+ parser.add_argument("--tracker_pattern", type=str, default="epoch_step", help="used for fid online log")
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default=None,
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="t2i-evit-baseline",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+ parser.add_argument(
+ "--name",
+ type=str,
+ default="baseline",
+ help=("Wandb Project Name"),
+ )
+ args = parser.parse_args()
+ return args
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ sample_nums = args.sample_nums
+ if args.stat:
+ if args.device is None:
+ device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
+ else:
+ device = torch.device(args.device)
+
+ if args.num_workers is None:
+ try:
+ num_cpus = len(os.sched_getaffinity(0))
+ except AttributeError:
+ num_cpus = os.cpu_count()
+ num_workers = min(num_cpus, 8) if num_cpus is not None else 0
+ else:
+ num_workers = args.num_workers
+ save_fid_stats(args.path, args.batch_size, device, args.dims, num_workers)
+ else:
+ print(args.path, args.exp_name)
+ args.exp_name = os.path.basename(args.exp_name) or os.path.dirname(args.exp_name)
+ fid_result = main()
+ if args.log_fid:
+ tracker(args, fid_result, args.suffix_label, pattern=args.tracker_pattern, metric="FID")
diff --git a/tools/metrics/pytorch-fid/noxfile.py b/tools/metrics/pytorch-fid/noxfile.py
new file mode 100644
index 0000000..651bd8e
--- /dev/null
+++ b/tools/metrics/pytorch-fid/noxfile.py
@@ -0,0 +1,21 @@
+import nox
+
+LOCATIONS = ("src/", "tests/", "noxfile.py", "setup.py")
+
+
+@nox.session
+def lint(session):
+ session.install("flake8")
+ session.install("flake8-bugbear")
+ session.install("flake8-isort")
+
+ args = session.posargs or LOCATIONS
+ session.run("flake8", *args)
+
+
+@nox.session
+def tests(session):
+ session.install(".")
+ session.install("pytest")
+ session.install("pytest-mock")
+ session.run("pytest", *session.posargs)
diff --git a/tools/metrics/pytorch-fid/setup.cfg b/tools/metrics/pytorch-fid/setup.cfg
new file mode 100644
index 0000000..a1b91ae
--- /dev/null
+++ b/tools/metrics/pytorch-fid/setup.cfg
@@ -0,0 +1,8 @@
+[flake8]
+select=F,W,E,I,B,B9
+ignore=W503,B950
+max-line-length=79
+
+[isort]
+multi_line_output=1
+line_length=79
diff --git a/tools/metrics/pytorch-fid/setup.py b/tools/metrics/pytorch-fid/setup.py
new file mode 100644
index 0000000..c7c1ad1
--- /dev/null
+++ b/tools/metrics/pytorch-fid/setup.py
@@ -0,0 +1,45 @@
+import os
+
+import setuptools
+
+
+def read(rel_path):
+ base_path = os.path.abspath(os.path.dirname(__file__))
+ with open(os.path.join(base_path, rel_path)) as f:
+ return f.read()
+
+
+def get_version(rel_path):
+ for line in read(rel_path).splitlines():
+ if line.startswith("__version__"):
+ # __version__ = "0.9"
+ delim = '"' if '"' in line else "'"
+ return line.split(delim)[1]
+
+ raise RuntimeError("Unable to find version string.")
+
+
+if __name__ == "__main__":
+ setuptools.setup(
+ name="pytorch-fid",
+ version=get_version(os.path.join("src", "pytorch_fid", "__init__.py")),
+ author="Max Seitzer",
+ description=("Package for calculating Frechet Inception Distance (FID)" " using PyTorch"),
+ long_description=read("README.md"),
+ long_description_content_type="text/markdown",
+ url="https://github.com/mseitzer/pytorch-fid",
+ package_dir={"": "src"},
+ packages=setuptools.find_packages(where="src"),
+ classifiers=[
+ "Programming Language :: Python :: 3",
+ "License :: OSI Approved :: Apache Software License",
+ ],
+ python_requires=">=3.5",
+ entry_points={
+ "console_scripts": [
+ "pytorch-fid = pytorch_fid.fid_score:main",
+ ],
+ },
+ install_requires=["numpy", "pillow", "scipy", "torch>=1.0.1", "torchvision>=0.2.2"],
+ extras_require={"dev": ["flake8", "flake8-bugbear", "flake8-isort", "nox"]},
+ )
diff --git a/tools/metrics/pytorch-fid/src/pytorch_fid/__init__.py b/tools/metrics/pytorch-fid/src/pytorch_fid/__init__.py
new file mode 100644
index 0000000..493f741
--- /dev/null
+++ b/tools/metrics/pytorch-fid/src/pytorch_fid/__init__.py
@@ -0,0 +1 @@
+__version__ = "0.3.0"
diff --git a/tools/metrics/pytorch-fid/src/pytorch_fid/__main__.py b/tools/metrics/pytorch-fid/src/pytorch_fid/__main__.py
new file mode 100644
index 0000000..197ee40
--- /dev/null
+++ b/tools/metrics/pytorch-fid/src/pytorch_fid/__main__.py
@@ -0,0 +1,3 @@
+import pytorch_fid.fid_score
+
+pytorch_fid.fid_score.main()
diff --git a/tools/metrics/pytorch-fid/src/pytorch_fid/fid_score.py b/tools/metrics/pytorch-fid/src/pytorch_fid/fid_score.py
new file mode 100644
index 0000000..2e4950a
--- /dev/null
+++ b/tools/metrics/pytorch-fid/src/pytorch_fid/fid_score.py
@@ -0,0 +1,307 @@
+"""Calculates the Frechet Inception Distance (FID) to evalulate GANs
+
+The FID metric calculates the distance between two distributions of images.
+Typically, we have summary statistics (mean & covariance matrix) of one
+of these distributions, while the 2nd distribution is given by a GAN.
+
+When run as a stand-alone program, it compares the distribution of
+images that are stored as PNG/JPEG at a specified location with a
+distribution given by summary statistics (in pickle format).
+
+The FID is calculated by assuming that X_1 and X_2 are the activations of
+the pool_3 layer of the inception net for generated samples and real world
+samples respectively.
+
+See --help to see further details.
+
+Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
+of Tensorflow
+
+Copyright 2018 Institute of Bioinformatics, JKU Linz
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import os
+import pathlib
+from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
+
+import numpy as np
+import torch
+import torchvision.transforms as TF
+from PIL import Image
+from scipy import linalg
+from torch.nn.functional import adaptive_avg_pool2d
+
+try:
+ from tqdm import tqdm
+except ImportError:
+ # If tqdm is not available, provide a mock version of it
+ def tqdm(x):
+ return x
+
+
+from pytorch_fid.inception import InceptionV3
+
+parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
+parser.add_argument("--batch-size", type=int, default=50, help="Batch size to use")
+parser.add_argument(
+ "--num-workers", type=int, help=("Number of processes to use for data loading. " "Defaults to `min(8, num_cpus)`")
+)
+parser.add_argument("--device", type=str, default=None, help="Device to use. Like cuda, cuda:0 or cpu")
+parser.add_argument(
+ "--dims",
+ type=int,
+ default=2048,
+ choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
+ help=("Dimensionality of Inception features to use. " "By default, uses pool3 features"),
+)
+parser.add_argument(
+ "--save-stats",
+ action="store_true",
+ help=(
+ "Generate an npz archive from a directory of samples. "
+ "The first path is used as input and the second as output."
+ ),
+)
+parser.add_argument("path", type=str, nargs=2, help=("Paths to the generated images or " "to .npz statistic files"))
+
+IMAGE_EXTENSIONS = {"bmp", "jpg", "jpeg", "pgm", "png", "ppm", "tif", "tiff", "webp"}
+
+
+class ImagePathDataset(torch.utils.data.Dataset):
+ def __init__(self, files, transforms=None):
+ self.files = files
+ self.transforms = transforms
+
+ def __len__(self):
+ return len(self.files)
+
+ def __getitem__(self, i):
+ path = self.files[i]
+ img = Image.open(path).convert("RGB")
+ if self.transforms is not None:
+ img = self.transforms(img)
+ return img
+
+
+def get_activations(files, model, batch_size=50, dims=2048, device="cpu", num_workers=1):
+ """Calculates the activations of the pool_3 layer for all images.
+
+ Params:
+ -- files : List of image files paths
+ -- model : Instance of inception model
+ -- batch_size : Batch size of images for the model to process at once.
+ Make sure that the number of samples is a multiple of
+ the batch size, otherwise some samples are ignored. This
+ behavior is retained to match the original FID score
+ implementation.
+ -- dims : Dimensionality of features returned by Inception
+ -- device : Device to run calculations
+ -- num_workers : Number of parallel dataloader workers
+
+ Returns:
+ -- A numpy array of dimension (num images, dims) that contains the
+ activations of the given tensor when feeding inception with the
+ query tensor.
+ """
+ model.eval()
+
+ if batch_size > len(files):
+ print("Warning: batch size is bigger than the data size. " "Setting batch size to data size")
+ batch_size = len(files)
+
+ dataset = ImagePathDataset(files, transforms=TF.ToTensor())
+ dataloader = torch.utils.data.DataLoader(
+ dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers
+ )
+
+ pred_arr = np.empty((len(files), dims))
+
+ start_idx = 0
+
+ for batch in tqdm(dataloader):
+ batch = batch.to(device)
+
+ with torch.no_grad():
+ pred = model(batch)[0]
+
+ # If model output is not scalar, apply global spatial average pooling.
+ # This happens if you choose a dimensionality not equal 2048.
+ if pred.size(2) != 1 or pred.size(3) != 1:
+ pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
+
+ pred = pred.squeeze(3).squeeze(2).cpu().numpy()
+
+ pred_arr[start_idx : start_idx + pred.shape[0]] = pred
+
+ start_idx = start_idx + pred.shape[0]
+
+ return pred_arr
+
+
+def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
+ """Numpy implementation of the Frechet Distance.
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
+ and X_2 ~ N(mu_2, C_2) is
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
+
+ Stable version by Dougal J. Sutherland.
+
+ Params:
+ -- mu1 : Numpy array containing the activations of a layer of the
+ inception net (like returned by the function 'get_predictions')
+ for generated samples.
+ -- mu2 : The sample mean over activations, precalculated on an
+ representative data set.
+ -- sigma1: The covariance matrix over activations for generated samples.
+ -- sigma2: The covariance matrix over activations, precalculated on an
+ representative data set.
+
+ Returns:
+ -- : The Frechet Distance.
+ """
+
+ mu1 = np.atleast_1d(mu1)
+ mu2 = np.atleast_1d(mu2)
+
+ sigma1 = np.atleast_2d(sigma1)
+ sigma2 = np.atleast_2d(sigma2)
+
+ assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
+ assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"
+
+ diff = mu1 - mu2
+
+ # Product might be almost singular
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+ if not np.isfinite(covmean).all():
+ msg = ("fid calculation produces singular product; " "adding %s to diagonal of cov estimates") % eps
+ print(msg)
+ offset = np.eye(sigma1.shape[0]) * eps
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+ # Numerical error might give slight imaginary component
+ if np.iscomplexobj(covmean):
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+ m = np.max(np.abs(covmean.imag))
+ raise ValueError(f"Imaginary component {m}")
+ covmean = covmean.real
+
+ tr_covmean = np.trace(covmean)
+
+ return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
+
+
+def calculate_activation_statistics(files, model, batch_size=50, dims=2048, device="cpu", num_workers=1):
+ """Calculation of the statistics used by the FID.
+ Params:
+ -- files : List of image files paths
+ -- model : Instance of inception model
+ -- batch_size : The images numpy array is split into batches with
+ batch size batch_size. A reasonable batch size
+ depends on the hardware.
+ -- dims : Dimensionality of features returned by Inception
+ -- device : Device to run calculations
+ -- num_workers : Number of parallel dataloader workers
+
+ Returns:
+ -- mu : The mean over samples of the activations of the pool_3 layer of
+ the inception model.
+ -- sigma : The covariance matrix of the activations of the pool_3 layer of
+ the inception model.
+ """
+ act = get_activations(files, model, batch_size, dims, device, num_workers)
+ mu = np.mean(act, axis=0)
+ sigma = np.cov(act, rowvar=False)
+ return mu, sigma
+
+
+def compute_statistics_of_path(path, model, batch_size, dims, device, num_workers=1):
+ if path.endswith(".npz"):
+ with np.load(path) as f:
+ m, s = f["mu"][:], f["sigma"][:]
+ else:
+ path = pathlib.Path(path)
+ files = sorted([file for ext in IMAGE_EXTENSIONS for file in path.glob(f"*.{ext}")])
+ m, s = calculate_activation_statistics(files, model, batch_size, dims, device, num_workers)
+
+ return m, s
+
+
+def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1):
+ """Calculates the FID of two paths"""
+ for p in paths:
+ if not os.path.exists(p):
+ raise RuntimeError("Invalid path: %s" % p)
+
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
+
+ model = InceptionV3([block_idx]).to(device)
+
+ m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, dims, device, num_workers)
+ m2, s2 = compute_statistics_of_path(paths[1], model, batch_size, dims, device, num_workers)
+ fid_value = calculate_frechet_distance(m1, s1, m2, s2)
+
+ return fid_value
+
+
+def save_fid_stats(paths, batch_size, device, dims, num_workers=1):
+ """Calculates the FID of two paths"""
+ if not os.path.exists(paths[0]):
+ raise RuntimeError("Invalid path: %s" % paths[0])
+
+ if os.path.exists(paths[1]):
+ raise RuntimeError("Existing output file: %s" % paths[1])
+
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
+
+ model = InceptionV3([block_idx]).to(device)
+
+ print(f"Saving statistics for {paths[0]}")
+
+ m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, dims, device, num_workers)
+
+ np.savez_compressed(paths[1], mu=m1, sigma=s1)
+
+
+def main():
+ args = parser.parse_args()
+
+ if args.device is None:
+ device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
+ else:
+ device = torch.device(args.device)
+
+ if args.num_workers is None:
+ try:
+ num_cpus = len(os.sched_getaffinity(0))
+ except AttributeError:
+ # os.sched_getaffinity is not available under Windows, use
+ # os.cpu_count instead (which may not return the *available* number
+ # of CPUs).
+ num_cpus = os.cpu_count()
+
+ num_workers = min(num_cpus, 8) if num_cpus is not None else 0
+ else:
+ num_workers = args.num_workers
+
+ if args.save_stats:
+ save_fid_stats(args.path, args.batch_size, device, args.dims, num_workers)
+ return
+
+ fid_value = calculate_fid_given_paths(args.path, args.batch_size, device, args.dims, num_workers)
+ print("FID: ", fid_value)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/metrics/pytorch-fid/src/pytorch_fid/inception.py b/tools/metrics/pytorch-fid/src/pytorch_fid/inception.py
new file mode 100644
index 0000000..c703082
--- /dev/null
+++ b/tools/metrics/pytorch-fid/src/pytorch_fid/inception.py
@@ -0,0 +1,336 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+
+try:
+ from torchvision.models.utils import load_state_dict_from_url
+except ImportError:
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
+
+# Inception weights ported to Pytorch from
+# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+FID_WEIGHTS_URL = "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" # noqa: E501
+
+
+class InceptionV3(nn.Module):
+ """Pretrained InceptionV3 network returning feature maps"""
+
+ # Index of default block of inception to return,
+ # corresponds to output of final average pooling
+ DEFAULT_BLOCK_INDEX = 3
+
+ # Maps feature dimensionality to their output blocks indices
+ BLOCK_INDEX_BY_DIM = {
+ 64: 0, # First max pooling features
+ 192: 1, # Second max pooling featurs
+ 768: 2, # Pre-aux classifier features
+ 2048: 3, # Final average pooling features
+ }
+
+ def __init__(
+ self,
+ output_blocks=(DEFAULT_BLOCK_INDEX,),
+ resize_input=True,
+ normalize_input=True,
+ requires_grad=False,
+ use_fid_inception=True,
+ ):
+ """Build pretrained InceptionV3
+
+ Parameters
+ ----------
+ output_blocks : list of int
+ Indices of blocks to return features of. Possible values are:
+ - 0: corresponds to output of first max pooling
+ - 1: corresponds to output of second max pooling
+ - 2: corresponds to output which is fed to aux classifier
+ - 3: corresponds to output of final average pooling
+ resize_input : bool
+ If true, bilinearly resizes input to width and height 299 before
+ feeding input to model. As the network without fully connected
+ layers is fully convolutional, it should be able to handle inputs
+ of arbitrary size, so resizing might not be strictly needed
+ normalize_input : bool
+ If true, scales the input from range (0, 1) to the range the
+ pretrained Inception network expects, namely (-1, 1)
+ requires_grad : bool
+ If true, parameters of the model require gradients. Possibly useful
+ for finetuning the network
+ use_fid_inception : bool
+ If true, uses the pretrained Inception model used in Tensorflow's
+ FID implementation. If false, uses the pretrained Inception model
+ available in torchvision. The FID Inception model has different
+ weights and a slightly different structure from torchvision's
+ Inception model. If you want to compute FID scores, you are
+ strongly advised to set this parameter to true to get comparable
+ results.
+ """
+ super().__init__()
+
+ self.resize_input = resize_input
+ self.normalize_input = normalize_input
+ self.output_blocks = sorted(output_blocks)
+ self.last_needed_block = max(output_blocks)
+
+ assert self.last_needed_block <= 3, "Last possible output block index is 3"
+
+ self.blocks = nn.ModuleList()
+
+ if use_fid_inception:
+ inception = fid_inception_v3()
+ else:
+ inception = _inception_v3(weights="DEFAULT")
+
+ # Block 0: input to maxpool1
+ block0 = [
+ inception.Conv2d_1a_3x3,
+ inception.Conv2d_2a_3x3,
+ inception.Conv2d_2b_3x3,
+ nn.MaxPool2d(kernel_size=3, stride=2),
+ ]
+ self.blocks.append(nn.Sequential(*block0))
+
+ # Block 1: maxpool1 to maxpool2
+ if self.last_needed_block >= 1:
+ block1 = [inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, nn.MaxPool2d(kernel_size=3, stride=2)]
+ self.blocks.append(nn.Sequential(*block1))
+
+ # Block 2: maxpool2 to aux classifier
+ if self.last_needed_block >= 2:
+ block2 = [
+ inception.Mixed_5b,
+ inception.Mixed_5c,
+ inception.Mixed_5d,
+ inception.Mixed_6a,
+ inception.Mixed_6b,
+ inception.Mixed_6c,
+ inception.Mixed_6d,
+ inception.Mixed_6e,
+ ]
+ self.blocks.append(nn.Sequential(*block2))
+
+ # Block 3: aux classifier to final avgpool
+ if self.last_needed_block >= 3:
+ block3 = [
+ inception.Mixed_7a,
+ inception.Mixed_7b,
+ inception.Mixed_7c,
+ nn.AdaptiveAvgPool2d(output_size=(1, 1)),
+ ]
+ self.blocks.append(nn.Sequential(*block3))
+
+ for param in self.parameters():
+ param.requires_grad = requires_grad
+
+ def forward(self, inp):
+ """Get Inception feature maps
+
+ Parameters
+ ----------
+ inp : torch.autograd.Variable
+ Input tensor of shape Bx3xHxW. Values are expected to be in
+ range (0, 1)
+
+ Returns
+ -------
+ List of torch.autograd.Variable, corresponding to the selected output
+ block, sorted ascending by index
+ """
+ outp = []
+ x = inp
+
+ if self.resize_input:
+ x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=False)
+
+ if self.normalize_input:
+ x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
+
+ for idx, block in enumerate(self.blocks):
+ x = block(x)
+ if idx in self.output_blocks:
+ outp.append(x)
+
+ if idx == self.last_needed_block:
+ break
+
+ return outp
+
+
+def _inception_v3(*args, **kwargs):
+ """Wraps `torchvision.models.inception_v3`"""
+ try:
+ version = tuple(map(int, torchvision.__version__.split(".")[:2]))
+ except ValueError:
+ # Just a caution against weird version strings
+ version = (0,)
+
+ # Skips default weight inititialization if supported by torchvision
+ # version. See https://github.com/mseitzer/pytorch-fid/issues/28.
+ if version >= (0, 6):
+ kwargs["init_weights"] = False
+
+ # Backwards compatibility: `weights` argument was handled by `pretrained`
+ # argument prior to version 0.13.
+ if version < (0, 13) and "weights" in kwargs:
+ if kwargs["weights"] == "DEFAULT":
+ kwargs["pretrained"] = True
+ elif kwargs["weights"] is None:
+ kwargs["pretrained"] = False
+ else:
+ raise ValueError(
+ "weights=={} not supported in torchvision {}".format(kwargs["weights"], torchvision.__version__)
+ )
+ del kwargs["weights"]
+
+ return torchvision.models.inception_v3(*args, **kwargs)
+
+
+def fid_inception_v3():
+ """Build pretrained Inception model for FID computation
+
+ The Inception model for FID computation uses a different set of weights
+ and has a slightly different structure than torchvision's Inception.
+
+ This method first constructs torchvision's Inception and then patches the
+ necessary parts that are different in the FID Inception model.
+ """
+ inception = _inception_v3(num_classes=1008, aux_logits=False, weights=None)
+ inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
+ inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
+ inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
+ inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
+ inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
+ inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
+ inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
+ inception.Mixed_7b = FIDInceptionE_1(1280)
+ inception.Mixed_7c = FIDInceptionE_2(2048)
+
+ # state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
+ # inception.load_state_dict(state_dict)
+ inception.load_state_dict(
+ torch.load("output/pretrained_models/pt_inception-2015-12-05-6726825d.pth", map_location="cpu")
+ )
+ print(f"model loaded")
+ return inception
+
+
+class FIDInceptionA(torchvision.models.inception.InceptionA):
+ """InceptionA block patched for FID computation"""
+
+ def __init__(self, in_channels, pool_features):
+ super().__init__(in_channels, pool_features)
+
+ def forward(self, x):
+ branch1x1 = self.branch1x1(x)
+
+ branch5x5 = self.branch5x5_1(x)
+ branch5x5 = self.branch5x5_2(branch5x5)
+
+ branch3x3dbl = self.branch3x3dbl_1(x)
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+ # Patch: Tensorflow's average pool does not use the padded zero's in
+ # its average calculation
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
+ return torch.cat(outputs, 1)
+
+
+class FIDInceptionC(torchvision.models.inception.InceptionC):
+ """InceptionC block patched for FID computation"""
+
+ def __init__(self, in_channels, channels_7x7):
+ super().__init__(in_channels, channels_7x7)
+
+ def forward(self, x):
+ branch1x1 = self.branch1x1(x)
+
+ branch7x7 = self.branch7x7_1(x)
+ branch7x7 = self.branch7x7_2(branch7x7)
+ branch7x7 = self.branch7x7_3(branch7x7)
+
+ branch7x7dbl = self.branch7x7dbl_1(x)
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
+
+ # Patch: Tensorflow's average pool does not use the padded zero's in
+ # its average calculation
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
+ return torch.cat(outputs, 1)
+
+
+class FIDInceptionE_1(torchvision.models.inception.InceptionE):
+ """First InceptionE block patched for FID computation"""
+
+ def __init__(self, in_channels):
+ super().__init__(in_channels)
+
+ def forward(self, x):
+ branch1x1 = self.branch1x1(x)
+
+ branch3x3 = self.branch3x3_1(x)
+ branch3x3 = [
+ self.branch3x3_2a(branch3x3),
+ self.branch3x3_2b(branch3x3),
+ ]
+ branch3x3 = torch.cat(branch3x3, 1)
+
+ branch3x3dbl = self.branch3x3dbl_1(x)
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+ branch3x3dbl = [
+ self.branch3x3dbl_3a(branch3x3dbl),
+ self.branch3x3dbl_3b(branch3x3dbl),
+ ]
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+ # Patch: Tensorflow's average pool does not use the padded zero's in
+ # its average calculation
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+ return torch.cat(outputs, 1)
+
+
+class FIDInceptionE_2(torchvision.models.inception.InceptionE):
+ """Second InceptionE block patched for FID computation"""
+
+ def __init__(self, in_channels):
+ super().__init__(in_channels)
+
+ def forward(self, x):
+ branch1x1 = self.branch1x1(x)
+
+ branch3x3 = self.branch3x3_1(x)
+ branch3x3 = [
+ self.branch3x3_2a(branch3x3),
+ self.branch3x3_2b(branch3x3),
+ ]
+ branch3x3 = torch.cat(branch3x3, 1)
+
+ branch3x3dbl = self.branch3x3dbl_1(x)
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+ branch3x3dbl = [
+ self.branch3x3dbl_3a(branch3x3dbl),
+ self.branch3x3dbl_3b(branch3x3dbl),
+ ]
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+ # Patch: The FID Inception model uses max pooling instead of average
+ # pooling. This is likely an error in this specific Inception
+ # implementation, as other Inception models use average pooling here
+ # (which matches the description in the paper).
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+ return torch.cat(outputs, 1)
diff --git a/tools/metrics/pytorch-fid/tests/test_fid_score.py b/tools/metrics/pytorch-fid/tests/test_fid_score.py
new file mode 100644
index 0000000..6520c1c
--- /dev/null
+++ b/tools/metrics/pytorch-fid/tests/test_fid_score.py
@@ -0,0 +1,90 @@
+import numpy as np
+import pytest
+import torch
+from PIL import Image
+from pytorch_fid import fid_score, inception
+
+
+@pytest.fixture
+def device():
+ return torch.device("cpu")
+
+
+def test_calculate_fid_given_statistics(mocker, tmp_path, device):
+ dim = 2048
+ m1, m2 = np.zeros((dim,)), np.ones((dim,))
+ sigma = np.eye(dim)
+
+ def dummy_statistics(path, model, batch_size, dims, device, num_workers):
+ if path.endswith("1"):
+ return m1, sigma
+ elif path.endswith("2"):
+ return m2, sigma
+ else:
+ raise ValueError
+
+ mocker.patch("pytorch_fid.fid_score.compute_statistics_of_path", side_effect=dummy_statistics)
+
+ dir_names = ["1", "2"]
+ paths = []
+ for name in dir_names:
+ path = tmp_path / name
+ path.mkdir()
+ paths.append(str(path))
+
+ fid_value = fid_score.calculate_fid_given_paths(paths, batch_size=dim, device=device, dims=dim, num_workers=0)
+
+ # Given equal covariance, FID is just the squared norm of difference
+ assert fid_value == np.sum((m1 - m2) ** 2)
+
+
+def test_compute_statistics_of_path(mocker, tmp_path, device):
+ model = mocker.MagicMock(inception.InceptionV3)()
+ model.side_effect = lambda inp: [inp.mean(dim=(2, 3), keepdim=True)]
+
+ size = (4, 4, 3)
+ arrays = [np.zeros(size), np.ones(size) * 0.5, np.ones(size)]
+ images = [(arr * 255).astype(np.uint8) for arr in arrays]
+
+ paths = []
+ for idx, image in enumerate(images):
+ paths.append(str(tmp_path / f"{idx}.png"))
+ Image.fromarray(image, mode="RGB").save(paths[-1])
+
+ stats = fid_score.compute_statistics_of_path(
+ str(tmp_path), model, batch_size=len(images), dims=3, device=device, num_workers=0
+ )
+
+ assert np.allclose(stats[0], np.ones((3,)) * 0.5, atol=1e-3)
+ assert np.allclose(stats[1], np.ones((3, 3)) * 0.25)
+
+
+def test_compute_statistics_of_path_from_file(mocker, tmp_path, device):
+ model = mocker.MagicMock(inception.InceptionV3)()
+
+ mu = np.random.randn(5)
+ sigma = np.random.randn(5, 5)
+
+ path = tmp_path / "stats.npz"
+ with path.open("wb") as f:
+ np.savez(f, mu=mu, sigma=sigma)
+
+ stats = fid_score.compute_statistics_of_path(str(path), model, batch_size=1, dims=5, device=device, num_workers=0)
+
+ assert np.allclose(stats[0], mu)
+ assert np.allclose(stats[1], sigma)
+
+
+def test_image_types(tmp_path):
+ in_arr = np.ones((24, 24, 3), dtype=np.uint8) * 255
+ in_image = Image.fromarray(in_arr, mode="RGB")
+
+ paths = []
+ for ext in fid_score.IMAGE_EXTENSIONS:
+ paths.append(str(tmp_path / f"img.{ext}"))
+ in_image.save(paths[-1])
+
+ dataset = fid_score.ImagePathDataset(paths)
+
+ for img in dataset:
+ assert np.allclose(np.array(img), in_arr)