Skip to content

Commit

Permalink
Nv labs GitHub repo/nv labs GitHub repo main adding controlnet (#23) (#…
Browse files Browse the repository at this point in the history
…177)

* Nv labs GitHub repo/nv labs GitHub repo main adding controlnet (#23)

* 1. add scripts of controlnet;
2. pre-commit;

Signed-off-by: lawrence-cj <[email protected]>

* add all we need for controlnet inference and run successful;

* move samples txt file into one dir;
update readme;

* 1. add readme for controlnet;
2. update readme;

* add 1.6B controlnet related model and config files;

Signed-off-by: lawrence-cj <[email protected]>

* update readme && pre-commit;

Signed-off-by: lawrence-cj <[email protected]>

* 1. update readme.md

* 1. add all need for online controlnet demo;
2. run success;

* little bug fixed;

Signed-off-by: lawrence-cj <[email protected]>

* code update && pre-commit;

Signed-off-by: lawrence-cj <[email protected]>

* 1. update controlnet readme;
2. pre-commit;

Signed-off-by: lawrence-cj <[email protected]>

* 1. update controlnet readme;
2. pre-commit;

Signed-off-by: lawrence-cj <[email protected]>

---------

Signed-off-by: lawrence-cj <[email protected]>
Co-authored-by: Enze Xie <[email protected]>

* 1. add test controlnet in CI;
2. fix controlnet config bug;

Signed-off-by: lawrence-cj <[email protected]>

* add ref image for controlnet;

Signed-off-by: lawrence-cj <[email protected]>

* update controlnet readme;

Signed-off-by: lawrence-cj <[email protected]>

* update controlnet CI;

Signed-off-by: lawrence-cj <[email protected]>

---------

Signed-off-by: lawrence-cj <[email protected]>
Co-authored-by: Enze Xie <[email protected]>
  • Loading branch information
lawrence-cj and xieenze authored Feb 10, 2025
1 parent dd38c12 commit 93c9f9d
Show file tree
Hide file tree
Showing 32 changed files with 2,054 additions and 10 deletions.
9 changes: 5 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ As a result, Sana-0.6B is very competitive with modern giant diffusion models (e

## 🔥🔥 News

- (🔥 New) \[2025/2/10\] 🚀Sana + ControlNet is released. [\[Guidance\]](asset/docs/sana_controlnet) | [\[Model\]](asset/docs/model_zoo.md)
- (🔥 New) \[2025/1/30\] Release CAME-8bit optimizer code. Saving more GPU memory during training. [\[How to config\]](https://github.com/NVlabs/Sana/blob/main/configs/sana_config/1024ms/Sana_1600M_img1024_CAME8bit.yaml#L86)
- (🔥 New) \[2025/1/29\] 🎉 🎉 🎉**SANA 1.5 is out! Figure out how to do efficient training & inference scaling!** 🚀[\[Tech Report\]](asset/SANA_1.5.pdf)
- (🔥 New) \[2025/1/29\] 🎉 🎉 🎉**SANA 1.5 is out! Figure out how to do efficient training & inference scaling!** 🚀[\[Tech Report\]](https://arxiv.org/abs/2501.18427)
- (🔥 New) \[2025/1/24\] 4bit-Sana is released, powered by [SVDQuant and Nunchaku](https://github.com/mit-han-lab/nunchaku) inference engine. Now run your Sana within **8GB** GPU VRAM [\[Guidance\]](asset/docs/4bit_sana.md) [\[Demo\]](https://svdquant.mit.edu/) [\[Model\]](asset/docs/model_zoo.md)
- (🔥 New) \[2025/1/24\] DCAE-1.1 is released, better reconstruction quality. [\[Model\]](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.1) [\[diffusers\]](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers)
- (🔥 New) \[2025/1/23\] **Sana is accepted by ICLR-2025.** 🎉🎉🎉
Expand Down Expand Up @@ -271,16 +272,16 @@ docker run --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \
python scripts/inference.py \
--config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
--model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth \
--txt_file=asset/samples_mini.txt
--txt_file=asset/samples/samples_mini.txt

# Run samples in a json file
python scripts/inference.py \
--config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
--model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth \
--json_file=asset/samples_mini.json
--json_file=asset/samples/samples_mini.json
```

where each line of [`asset/samples_mini.txt`](asset/samples_mini.txt) contains a prompt to generate
where each line of [`asset/samples/samples_mini.txt`](asset/samples/samples_mini.txt) contains a prompt to generate

# 🔥 3. How to Train Sana

Expand Down
1 change: 0 additions & 1 deletion app/app_sana.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@ def get_args():
args = get_args()

if torch.cuda.is_available():
weight_dtype = torch.float16
model_path = args.model_path
pipe = SanaPipeline(args.config)
pipe.from_pretrained(model_path)
Expand Down
306 changes: 306 additions & 0 deletions app/app_sana_controlnet_hed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,306 @@
# Changed from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
import argparse
import os
import random
import socket
import tempfile
import time

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

from app import safety_check
from app.sana_controlnet_pipeline import SanaControlNetPipeline

STYLES = {
"None": "{prompt}",
"Cinematic": "cinematic still {prompt}. emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
"3D Model": "professional 3d model {prompt}. octane render, highly detailed, volumetric, dramatic lighting",
"Anime": "anime artwork {prompt}. anime style, key visual, vibrant, studio anime, highly detailed",
"Digital Art": "concept art {prompt}. digital artwork, illustrative, painterly, matte painting, highly detailed",
"Photographic": "cinematic photo {prompt}. 35mm photograph, film, bokeh, professional, 4k, highly detailed",
"Pixel art": "pixel-art {prompt}. low-res, blocky, pixel art style, 8-bit graphics",
"Fantasy art": "ethereal fantasy concept art of {prompt}. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
"Neonpunk": "neonpunk style {prompt}. cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
"Manga": "manga style {prompt}. vibrant, high-energy, detailed, iconic, Japanese comic style",
}
DEFAULT_STYLE_NAME = "None"
STYLE_NAMES = list(STYLES.keys())

MAX_SEED = 1000000000
DEFAULT_SKETCH_GUIDANCE = 0.28
DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255))


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, help="config")
parser.add_argument(
"--model_path",
nargs="?",
default="hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth",
type=str,
help="Path to the model file (positional)",
)
parser.add_argument("--output", default="./", type=str)
parser.add_argument("--bs", default=1, type=int)
parser.add_argument("--image_size", default=1024, type=int)
parser.add_argument("--cfg_scale", default=5.0, type=float)
parser.add_argument("--pag_scale", default=2.0, type=float)
parser.add_argument("--seed", default=42, type=int)
parser.add_argument("--step", default=-1, type=int)
parser.add_argument("--custom_image_size", default=None, type=int)
parser.add_argument("--share", action="store_true")
parser.add_argument(
"--shield_model_path",
type=str,
help="The path to shield model, we employ ShieldGemma-2B by default.",
default="google/shieldgemma-2b",
)

return parser.parse_known_args()[0]


args = get_args()

if torch.cuda.is_available():
model_path = args.model_path
pipe = SanaControlNetPipeline(args.config)
pipe.from_pretrained(model_path)
pipe.register_progress_bar(gr.Progress())

# safety checker
safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
safety_checker_model = AutoModelForCausalLM.from_pretrained(
args.shield_model_path,
device_map="auto",
torch_dtype=torch.bfloat16,
).to(device)


def save_image(img):
if isinstance(img, dict):
img = img["composite"]
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
img.save(temp_file.name)
return temp_file.name


def norm_ip(img, low, high):
img.clamp_(min=low, max=high)
img.sub_(low).div_(max(high - low, 1e-5))
return img


@torch.no_grad()
@torch.inference_mode()
def run(
image,
prompt: str,
prompt_template: str,
sketch_thickness: int,
guidance_scale: float,
inference_steps: int,
seed: int,
blend_alpha: float,
) -> tuple[Image, str]:

print(f"Prompt: {prompt}")
image_numpy = np.array(image["composite"].convert("RGB"))

if prompt.strip() == "" and (np.sum(image_numpy == 255) >= 3145628 or np.sum(image_numpy == 0) >= 3145628):
return blank_image, "Please input the prompt or draw something."

if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
prompt = "A red heart."

prompt = prompt_template.format(prompt=prompt)
pipe.set_blend_alpha(blend_alpha)
start_time = time.time()
images = pipe(
prompt=prompt,
ref_image=image["composite"],
guidance_scale=guidance_scale,
num_inference_steps=inference_steps,
num_images_per_prompt=1,
sketch_thickness=sketch_thickness,
generator=torch.Generator(device=device).manual_seed(seed),
)

latency = time.time() - start_time

if latency < 1:
latency = latency * 1000
latency_str = f"{latency:.2f}ms"
else:
latency_str = f"{latency:.2f}s"
torch.cuda.empty_cache()

img = [
Image.fromarray(
norm_ip(img, -1, 1)
.mul(255)
.add_(0.5)
.clamp_(0, 255)
.permute(1, 2, 0)
.to("cpu", torch.uint8)
.numpy()
.astype(np.uint8)
)
for img in images
]
img = img[0]
return img, latency_str


model_size = "1.6" if "1600M" in args.model_path else "0.6"
title = f"""
<div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
<img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="50%" alt="logo"/>
</div>
"""
DESCRIPTION = f"""
<p><span style="font-size: 36px; font-weight: bold;">Sana-ControlNet-{model_size}B</span><span style="font-size: 20px; font-weight: bold;">{args.image_size}px</span></p>
<p style="font-size: 18px; font-weight: bold;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer</p>
<p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2410.10629">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[Github]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span</p>
<p style="font-size: 18px; font-weight: bold;">Powered by <a href="https://hanlab.mit.edu/projects/dc-ae">DC-AE</a> with 32x latent space, </p>running on node {socket.gethostname()}.
<p style="font-size: 16px; font-weight: bold;">Unsafe word will give you a 'Red Heart' in the image instead.</p>
"""
if model_size == "0.6":
DESCRIPTION += "\n<p>0.6B model's text rendering ability is limited.</p>"
if not torch.cuda.is_available():
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"


with gr.Blocks(css_paths="asset/app_styles/controlnet_app_style.css", title=f"Sana Sketch-to-Image Demo") as demo:
gr.Markdown(title)
gr.HTML(DESCRIPTION)

with gr.Row(elem_id="main_row"):
with gr.Column(elem_id="column_input"):
gr.Markdown("## INPUT", elem_id="input_header")
with gr.Group():
canvas = gr.Sketchpad(
value=blank_image,
height=640,
image_mode="RGB",
sources=["upload", "clipboard"],
type="pil",
label="Sketch",
show_label=False,
show_download_button=True,
interactive=True,
transforms=[],
canvas_size=(1024, 1024),
scale=1,
brush=gr.Brush(default_size=3, colors=["#000000"], color_mode="fixed"),
format="png",
layers=False,
)
with gr.Row():
prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
run_button = gr.Button("Run", scale=1, elem_id="run_button")
download_sketch = gr.DownloadButton("Download Sketch", scale=1, elem_id="download_sketch")
with gr.Row():
style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
prompt_template = gr.Textbox(
label="Prompt Style Template", value=STYLES[DEFAULT_STYLE_NAME], scale=2, max_lines=1
)

with gr.Row():
sketch_thickness = gr.Slider(
label="Sketch Thickness",
minimum=1,
maximum=4,
step=1,
value=2,
)
with gr.Row():
inference_steps = gr.Slider(
label="Sampling steps",
minimum=5,
maximum=40,
step=1,
value=20,
)
guidance_scale = gr.Slider(
label="CFG Guidance scale",
minimum=1,
maximum=10,
step=0.1,
value=4.5,
)
blend_alpha = gr.Slider(
label="Blend Alpha",
minimum=0,
maximum=1,
step=0.1,
value=0,
)
with gr.Row():
seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")

with gr.Column(elem_id="column_output"):
gr.Markdown("## OUTPUT", elem_id="output_header")
with gr.Group():
result = gr.Image(
format="png",
height=640,
image_mode="RGB",
type="pil",
label="Result",
show_label=False,
show_download_button=True,
interactive=False,
elem_id="output_image",
)
latency_result = gr.Text(label="Inference Latency", show_label=True)

download_result = gr.DownloadButton("Download Result", elem_id="download_result")
gr.Markdown("### Instructions")
gr.Markdown("**1**. Enter a text prompt (e.g. a cat)")
gr.Markdown("**2**. Start sketching or upload a reference image")
gr.Markdown("**3**. Change the image style using a style template")
gr.Markdown("**4**. Try different seeds to generate different results")

run_inputs = [canvas, prompt, prompt_template, sketch_thickness, guidance_scale, inference_steps, seed, blend_alpha]
run_outputs = [result, latency_result]

randomize_seed.click(
lambda: random.randint(0, MAX_SEED),
inputs=[],
outputs=seed,
api_name=False,
queue=False,
).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)

style.change(
lambda x: STYLES[x],
inputs=[style],
outputs=[prompt_template],
api_name=False,
queue=False,
).then(fn=run, inputs=run_inputs, outputs=run_outputs, api_name=False)
gr.on(
triggers=[prompt.submit, run_button.click, canvas.change],
fn=run,
inputs=run_inputs,
outputs=run_outputs,
api_name=False,
)

download_sketch.click(fn=save_image, inputs=canvas, outputs=download_sketch)
download_result.click(fn=save_image, inputs=result, outputs=download_result)
gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")


if __name__ == "__main__":
demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=False, share=args.share)
Loading

0 comments on commit 93c9f9d

Please sign in to comment.