-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathdepth_txt2img.py
65 lines (56 loc) · 2.25 KB
/
depth_txt2img.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import argparse
import sys
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler
import numpy as np
from PIL import Image
import torch
# REVIEW
# Default torch device, picked once at import time in order of preference:
# CUDA, then the MPS backend, then plain CPU as the last resort.
if torch.cuda.is_available():
    _DEFAULT_DEVICE = 'cuda'
elif torch.backends.mps.is_available():
    _DEFAULT_DEVICE = 'mps'
else:
    _DEFAULT_DEVICE = 'cpu'
class TextToObjectImage:
    """Text-to-image generator built on a ControlNet-conditioned Stable Diffusion pipeline.

    The constructor loads a ControlNet checkpoint and a Stable Diffusion
    checkpoint with ``from_pretrained``, replaces the pipeline's scheduler with
    ``UniPCMultistepScheduler`` and moves the pipeline to the requested device.
    The loaded pipeline is kept on ``self.pipe``.
    """

    def __init__(
        self,
        device=_DEFAULT_DEVICE,
        model='Lykon/dreamshaper-8',
        cn_model='lllyasviel/control_v11p_sd15_normalbae',
    ):
        """Load the models and prepare the pipeline.

        Args:
            device: Device the pipeline is moved to (e.g. ``'cuda'``, ``'mps'``,
                ``'cpu'``). Defaults to the module-level auto-detected device.
            model: Identifier of the base checkpoint handed to
                ``StableDiffusionControlNetPipeline.from_pretrained``.
            cn_model: Identifier of the ControlNet checkpoint handed to
                ``ControlNetModel.from_pretrained``.
                NOTE(review): the default name suggests a normal-map ControlNet,
                while the CLI describes its input as a depth image; confirm
                which conditioning is intended.
        """
        # Half precision is not reliably runnable on CPU (diffusers warns that
        # fp16 pipelines cannot run there), so use float32 when targeting the
        # CPU. The fp16 weight variant is still downloaded and cast on load.
        dtype = torch.float32 if str(device).startswith('cpu') else torch.float16

        controlnet = ControlNetModel.from_pretrained(
            cn_model, torch_dtype=dtype, variant='fp16',
        )
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            model,
            controlnet=controlnet,
            torch_dtype=dtype,
            variant='fp16',
            safety_checker=None,  # safety checker is explicitly disabled
        )
        # Rebuild the scheduler from the existing scheduler's config so only
        # the sampling algorithm changes.
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
        self.pipe = self.pipe.to(device)

    def generate(self, desc: str, steps: int, control_image: Image.Image) -> Image.Image:
        """Generate one image of ``desc`` conditioned on ``control_image``.

        Args:
            desc: Short description of the object; it is prepended to a fixed
                prompt suffix.
            steps: Number of inference steps passed to the pipeline.
            control_image: Conditioning image. Its ``width`` and ``height`` are
                also passed as the requested output size.

        Returns:
            The first (and only requested) image of the pipeline output.
        """
        result = self.pipe(
            prompt=f'{desc}, front and back view, 180, reverse, 3D rendering, high quality 4K, flat',
            negative_prompt='lighting, shadows, grid, dark, mesh',
            num_inference_steps=steps,
            num_images_per_prompt=1,
            image=control_image,
            width=control_image.width,
            height=control_image.height,
        )
        return result.images[0]
if __name__ == '__main__':
    # Command-line entry point: read a description and a control image,
    # generate one image, and write it to the requested output path.
    parser = argparse.ArgumentParser()
    parser.add_argument('desc', help='Short description of desired model appearance')
    parser.add_argument('depth_img', help='Depth control image')
    parser.add_argument('output_path', help='Path for generated image')
    parser.add_argument(
        '--image-model',
        help='SD 1.5-based model for texture image gen',
        default='Lykon/dreamshaper-8',
    )
    parser.add_argument('--steps', type=int, default=12)
    parser.add_argument(
        '--device',
        default=_DEFAULT_DEVICE,
        type=str,
        help='Device to prefer. Default: try to auto-detect from platform (CUDA, Metal)'
    )
    args = parser.parse_args()

    t2i = TextToObjectImage(args.device, args.image_model)
    # Open the control image in a context manager so its file handle is
    # released once generation finishes, instead of leaking until exit.
    with Image.open(args.depth_img) as control_image:
        generated = t2i.generate(args.desc, args.steps, control_image)
    generated.save(args.output_path)