forked from wooyeolbaek/attention-map-diffusers
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patht2i.py
28 lines (24 loc) · 777 Bytes
/
t2i.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
from diffusers import StableDiffusionXLPipeline
from utils import (
cross_attn_init,
register_cross_attention_hook,
attn_maps,
get_net_attn_map,
resize_net_attn_map,
save_net_attn_map,
)
# Demo script: generate an image with Stable Diffusion XL for a fixed prompt,
# save it, then export the attention maps returned by the ``utils`` helpers.

# Run settings, gathered in one place.
MODEL_DIR = "/data/intern/dan/data/base/stable-diffusion-xl-base-1.0"  # local checkpoint directory
PROMPT = "A photo of a black puppy, christmas atmosphere"
IMAGE_PATH = "test.png"  # relative path: written to the current working directory
ATTN_MAP_DIR = "attn_maps"

# Must run before the pipeline is constructed (original statement order kept).
# NOTE(review): presumably installs the attention-capturing processors -- confirm in utils.
cross_attn_init()

# Load the pipeline in float16, replace its UNet with the hooked one returned
# by the helper, and only then move the whole pipeline to the GPU.
sdxl_pipe = StableDiffusionXLPipeline.from_pretrained(MODEL_DIR, torch_dtype=torch.float16)
sdxl_pipe.unet = register_cross_attention_hook(sdxl_pipe.unet)
sdxl_pipe = sdxl_pipe.to("cuda")

# Run generation and keep only the first image of the result.
generated_images = sdxl_pipe(PROMPT).images
picture = generated_images[0]
picture.save(IMAGE_PATH)

# Fetch the attention maps, resize them using the generated image's size, and
# hand them to the saver together with the tokenizer and the prompt.
# NOTE(review): presumably one map is written per prompt token -- confirm in utils.
picture_size = picture.size
resized_attn_maps = resize_net_attn_map(get_net_attn_map(picture_size), picture_size)
save_net_attn_map(resized_attn_maps, ATTN_MAP_DIR, sdxl_pipe.tokenizer, PROMPT)