forked from chenxwh/Kandinsky-2
-
Notifications
You must be signed in to change notification settings - Fork 1
/
predict.py
72 lines (69 loc) · 2.4 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
from typing import List
from cog import BasePredictor, Input, Path
from kandinsky2 import get_kandinsky2
class Predictor(BasePredictor):
    """Cog predictor exposing Kandinsky 2.1 text-to-image generation."""

    def setup(self):
        """Build the model once and keep it on ``self.model`` for reuse.

        The model is requested from ``get_kandinsky2`` for the ``"cuda"``
        device with the ``text2img`` task, model version ``2.1``, flash
        attention disabled, and ``./kandinsky2-weights`` as the cache
        directory.
        """
        model_options = {
            "task_type": "text2img",
            "cache_dir": "./kandinsky2-weights",
            "model_version": "2.1",
            "use_flash_attention": False,
        }
        self.model = get_kandinsky2("cuda", **model_options)

    def predict(
        self,
        prompt: str = Input(
            description="Input Prompt",
            default="red cat, 4k photo",
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps",
            ge=1,
            le=500,
            default=50,
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance",
            ge=1,
            le=20,
            default=4,
        ),
        scheduler: str = Input(
            description="Choose a scheduler",
            default="p_sampler",
            choices=["ddim_sampler", "p_sampler", "plms_sampler"],
        ),
        prior_cf_scale: int = Input(default=4),
        prior_steps: str = Input(default="5"),
        width: int = Input(
            description="Choose width. Lower the setting if out of memory.",
            default=512,
            choices=[256, 288, 432, 512, 576, 768, 1024],
        ),
        height: int = Input(
            description="Choose height. Lower the setting if out of memory.",
            default=512,
            choices=[256, 288, 432, 512, 576, 768, 1024],
        ),
        batch_size: int = Input(
            description="Choose batch size. Lower the setting if out of memory.",
            default=1,
            choices=[1, 2, 3, 4],
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed",
            default=None,
        ),
    ) -> List[Path]:
        """Generate images for ``prompt`` and return the saved PNG paths.

        Every argument except ``prompt`` is forwarded to
        ``self.model.generate_text2img`` as a keyword argument
        (``num_inference_steps`` as ``num_steps``, ``height`` as ``h``,
        ``width`` as ``w``, ``scheduler`` as ``sampler``; the remaining
        ones under their own names).

        When ``seed`` is ``None`` a random value in ``0..65535`` is drawn
        from ``os.urandom``.

        Returns:
            One ``Path`` per generated image, pointing at
            ``/tmp/out_<index>.png``.
        """
        if seed is None:
            # Two random bytes -> an unsigned 16-bit seed.
            seed = int.from_bytes(os.urandom(2), byteorder="big")

        # NOTE(review): ``seed`` is handed to generate_text2img as a keyword;
        # confirm the installed kandinsky2 build accepts it.
        generation_options = {
            "num_steps": num_inference_steps,
            "batch_size": batch_size,
            "guidance_scale": guidance_scale,
            "h": height,
            "w": width,
            "sampler": scheduler,
            "prior_cf_scale": prior_cf_scale,
            "prior_steps": prior_steps,
            "seed": seed,
        }
        images = self.model.generate_text2img(prompt, **generation_options)

        saved_paths = []
        for index, image in enumerate(images):
            # Presumably PIL images (anything exposing ``save(path)``) —
            # verify against the model's return type.
            tmp_file = f"/tmp/out_{index}.png"
            # Each image is written twice: first under /tmp (the path that is
            # returned), then as a copy relative to the working directory.
            for destination in (tmp_file, f"out_{index}.png"):
                image.save(destination)
            saved_paths.append(Path(tmp_file))
        return saved_paths