…177)

* Nv labs GitHub repo/nv labs GitHub repo main adding controlnet (#23)
  * add scripts of controlnet; pre-commit;
  * add all we need for controlnet inference; runs successfully;
  * move sample txt files into one dir; update readme;
  * add readme for controlnet; update readme;
  * add 1.6B controlnet-related model and config files;
  * update readme && pre-commit;
  * update readme.md;
  * add all needed for the online controlnet demo; runs successfully;
  * small bug fixed;
  * code update && pre-commit;
  * update controlnet readme; pre-commit;
* add test controlnet in CI; fix controlnet config bug;
* add ref image for controlnet;
* update controlnet readme;
* update controlnet CI;

Signed-off-by: lawrence-cj <[email protected]>
Co-authored-by: Enze Xie <[email protected]>
1 parent dd38c12 · commit 93c9f9d
Showing 32 changed files with 2,054 additions and 10 deletions.
@@ -0,0 +1,306 @@
# Changed from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
import argparse
import os
import random
import socket
import tempfile
import time

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

from app import safety_check
from app.sana_controlnet_pipeline import SanaControlNetPipeline

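# Style templates applied to the user prompt; "{prompt}" is filled in via str.format in run().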
STYLES = {
    "None": "{prompt}",
    "Cinematic": "cinematic still {prompt}. emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
    "3D Model": "professional 3d model {prompt}. octane render, highly detailed, volumetric, dramatic lighting",
    "Anime": "anime artwork {prompt}. anime style, key visual, vibrant, studio anime, highly detailed",
    "Digital Art": "concept art {prompt}. digital artwork, illustrative, painterly, matte painting, highly detailed",
    "Photographic": "cinematic photo {prompt}. 35mm photograph, film, bokeh, professional, 4k, highly detailed",
    "Pixel art": "pixel-art {prompt}. low-res, blocky, pixel art style, 8-bit graphics",
    "Fantasy art": "ethereal fantasy concept art of {prompt}. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
    "Neonpunk": "neonpunk style {prompt}. cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
    "Manga": "manga style {prompt}. vibrant, high-energy, detailed, iconic, Japanese comic style",
}
DEFAULT_STYLE_NAME = "None"
STYLE_NAMES = list(STYLES.keys())

MAX_SEED = 1000000000
DEFAULT_SKETCH_GUIDANCE = 0.28
DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255))


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, help="config")
    parser.add_argument(
        "--model_path",
        nargs="?",
        default="hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth",
        type=str,
        help="Path to the model file (positional)",
    )
    parser.add_argument("--output", default="./", type=str)
    parser.add_argument("--bs", default=1, type=int)
    parser.add_argument("--image_size", default=1024, type=int)
    parser.add_argument("--cfg_scale", default=5.0, type=float)
    parser.add_argument("--pag_scale", default=2.0, type=float)
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--step", default=-1, type=int)
    parser.add_argument("--custom_image_size", default=None, type=int)
    parser.add_argument("--share", action="store_true")
    parser.add_argument(
        "--shield_model_path",
        type=str,
        help="The path to the shield model; ShieldGemma-2B is used by default.",
        default="google/shieldgemma-2b",
    )

    return parser.parse_known_args()[0]


args = get_args()

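# Heavy initialization only happens on GPU machines; on CPU the demo page still renders,
# but generation is unsupported (see the DESCRIPTION notice below).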
if torch.cuda.is_available():
    model_path = args.model_path
    pipe = SanaControlNetPipeline(args.config)
    pipe.from_pretrained(model_path)
    pipe.register_progress_bar(gr.Progress())

    # safety checker
    safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
    safety_checker_model = AutoModelForCausalLM.from_pretrained(
        args.shield_model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ).to(device)


def save_image(img):
    # The Sketchpad returns a dict of layers; keep only the flattened composite.
    if isinstance(img, dict):
        img = img["composite"]
    temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    img.save(temp_file.name)
    return temp_file.name


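# In-place normalization from [low, high] to [0, 1]; used below to map the pipeline's [-1, 1] output tensors.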
def norm_ip(img, low, high):
    img.clamp_(min=low, max=high)
    img.sub_(low).div_(max(high - low, 1e-5))
    return img


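# inference_mode() alone already disables autograd tracking; the extra no_grad() is redundant but harmless.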
@torch.no_grad()
@torch.inference_mode()
def run(
    image,
    prompt: str,
    prompt_template: str,
    sketch_thickness: int,
    guidance_scale: float,
    inference_steps: int,
    seed: int,
    blend_alpha: float,
) -> tuple[Image.Image, str]:
    print(f"Prompt: {prompt}")
    image_numpy = np.array(image["composite"].convert("RGB"))

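    # 3145628 is just under 1024 * 1024 * 3 (= 3145728), i.e. the canvas is (almost) entirely white or black.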
    if prompt.strip() == "" and (np.sum(image_numpy == 255) >= 3145628 or np.sum(image_numpy == 0) >= 3145628):
        return blank_image, "Please input the prompt or draw something."

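    # Flagged prompts are swapped for a harmless placeholder rather than refused (see DESCRIPTION below).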
    if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
        prompt = "A red heart."

    prompt = prompt_template.format(prompt=prompt)
    pipe.set_blend_alpha(blend_alpha)
    start_time = time.time()
    images = pipe(
        prompt=prompt,
        ref_image=image["composite"],
        guidance_scale=guidance_scale,
        num_inference_steps=inference_steps,
        num_images_per_prompt=1,
        sketch_thickness=sketch_thickness,
        generator=torch.Generator(device=device).manual_seed(seed),
    )

    latency = time.time() - start_time

    if latency < 1:
        latency = latency * 1000
        latency_str = f"{latency:.2f}ms"
    else:
        latency_str = f"{latency:.2f}s"
    torch.cuda.empty_cache()

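    # Convert each decoded tensor from [-1, 1] CHW to an 8-bit HWC PIL image.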
    img = [
        Image.fromarray(
            norm_ip(img, -1, 1)
            .mul(255)
            .add_(0.5)
            .clamp_(0, 255)
            .permute(1, 2, 0)
            .to("cpu", torch.uint8)
            .numpy()
            .astype(np.uint8)
        )
        for img in images
    ]
    img = img[0]
    return img, latency_str


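# Infer the banner's model size from the checkpoint path: "1600M" in the name means the 1.6B model.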
model_size = "1.6" if "1600M" in args.model_path else "0.6"
title = """
<div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
    <img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="50%" alt="logo"/>
</div>
"""
DESCRIPTION = f"""
<p><span style="font-size: 36px; font-weight: bold;">Sana-ControlNet-{model_size}B</span><span style="font-size: 20px; font-weight: bold;">{args.image_size}px</span></p>
<p style="font-size: 18px; font-weight: bold;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer</p>
<p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2410.10629">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[Github]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span></p>
<p style="font-size: 18px; font-weight: bold;">Powered by <a href="https://hanlab.mit.edu/projects/dc-ae">DC-AE</a> with 32x latent space, running on node {socket.gethostname()}.</p>
<p style="font-size: 16px; font-weight: bold;">An unsafe word will give you a 'Red Heart' in the image instead.</p>
"""
if model_size == "0.6":
    DESCRIPTION += "\n<p>The 0.6B model's text rendering ability is limited.</p>"
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"


with gr.Blocks(css_paths="asset/app_styles/controlnet_app_style.css", title="Sana Sketch-to-Image Demo") as demo:
    gr.Markdown(title)
    gr.HTML(DESCRIPTION)

    with gr.Row(elem_id="main_row"):
        with gr.Column(elem_id="column_input"):
            gr.Markdown("## INPUT", elem_id="input_header")
            with gr.Group():
                canvas = gr.Sketchpad(
                    value=blank_image,
                    height=640,
                    image_mode="RGB",
                    sources=["upload", "clipboard"],
                    type="pil",
                    label="Sketch",
                    show_label=False,
                    show_download_button=True,
                    interactive=True,
                    transforms=[],
                    canvas_size=(1024, 1024),
                    scale=1,
                    brush=gr.Brush(default_size=3, colors=["#000000"], color_mode="fixed"),
                    format="png",
                    layers=False,
                )
                with gr.Row():
                    prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
                    run_button = gr.Button("Run", scale=1, elem_id="run_button")
                    download_sketch = gr.DownloadButton("Download Sketch", scale=1, elem_id="download_sketch")
                with gr.Row():
                    style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
                    prompt_template = gr.Textbox(
                        label="Prompt Style Template", value=STYLES[DEFAULT_STYLE_NAME], scale=2, max_lines=1
                    )

            with gr.Row():
                sketch_thickness = gr.Slider(
                    label="Sketch Thickness",
                    minimum=1,
                    maximum=4,
                    step=1,
                    value=2,
                )
            with gr.Row():
                inference_steps = gr.Slider(
                    label="Sampling steps",
                    minimum=5,
                    maximum=40,
                    step=1,
                    value=20,
                )
                guidance_scale = gr.Slider(
                    label="CFG Guidance scale",
                    minimum=1,
                    maximum=10,
                    step=0.1,
                    value=4.5,
                )
                blend_alpha = gr.Slider(
                    label="Blend Alpha",
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0,
                )
            with gr.Row():
                seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
                randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")

        with gr.Column(elem_id="column_output"):
            gr.Markdown("## OUTPUT", elem_id="output_header")
            with gr.Group():
                result = gr.Image(
                    format="png",
                    height=640,
                    image_mode="RGB",
                    type="pil",
                    label="Result",
                    show_label=False,
                    show_download_button=True,
                    interactive=False,
                    elem_id="output_image",
                )
                latency_result = gr.Text(label="Inference Latency", show_label=True)

            download_result = gr.DownloadButton("Download Result", elem_id="download_result")
            gr.Markdown("### Instructions")
            gr.Markdown("**1**. Enter a text prompt (e.g. a cat)")
            gr.Markdown("**2**. Start sketching or upload a reference image")
            gr.Markdown("**3**. Change the image style using a style template")
            gr.Markdown("**4**. Try different seeds to generate different results")

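    # Shared argument lists for every trigger below: seed randomization, style change, prompt submit,
    # the Run button, and live canvas edits all call run() with the same inputs/outputs.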
    run_inputs = [canvas, prompt, prompt_template, sketch_thickness, guidance_scale, inference_steps, seed, blend_alpha]
    run_outputs = [result, latency_result]

    randomize_seed.click(
        lambda: random.randint(0, MAX_SEED),
        inputs=[],
        outputs=seed,
        api_name=False,
        queue=False,
    ).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)

    style.change(
        lambda x: STYLES[x],
        inputs=[style],
        outputs=[prompt_template],
        api_name=False,
        queue=False,
    ).then(fn=run, inputs=run_inputs, outputs=run_outputs, api_name=False)
    gr.on(
        triggers=[prompt.submit, run_button.click, canvas.change],
        fn=run,
        inputs=run_inputs,
        outputs=run_outputs,
        api_name=False,
    )

    download_sketch.click(fn=save_image, inputs=canvas, outputs=download_sketch)
    download_result.click(fn=save_image, inputs=result, outputs=download_result)
    gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")


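# Example launch (script and config names are illustrative):
#   DEMO_PORT=15432 python app_sana_controlnet.py --config configs/sana_controlnet.yaml --share
# The server listens on 0.0.0.0 at DEMO_PORT (default 15432).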
if __name__ == "__main__":
    demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=False, share=args.share)