diff --git a/docker-entrypoint.py b/docker-entrypoint.py
index eff129c..310d0a1 100755
--- a/docker-entrypoint.py
+++ b/docker-entrypoint.py
@@ -3,7 +3,11 @@
 import torch
 from PIL import Image
 from torch import autocast
-from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
+from diffusers import (
+    StableDiffusionPipeline,
+    StableDiffusionImg2ImgPipeline,
+    StableDiffusionInpaintPipeline,
+)
 
 
 def cuda_device():
@@ -24,16 +28,21 @@ def skip_safety_checker(images, *args, **kwargs):
     return images, False
 
 
-def stable_diffusion_pipeline(model, image, half, skip, do_slice, token):
+def stable_diffusion_pipeline(model, image, mask, half, skip, do_slice, token):
     if token is None:
         with open("token.txt") as f:
             token = f.read().replace("\n", "")
 
     diffuser = StableDiffusionPipeline
+
     if image is not None:
         diffuser = StableDiffusionImg2ImgPipeline
         image = load_image(image)
 
+    if mask is not None:
+        diffuser = StableDiffusionInpaintPipeline
+        mask = load_image(mask)
+
     dtype, rev = (torch.float16, "fp16") if half else (torch.float32, "main")
 
     print("load pipeline start:", iso_date_time())
@@ -50,7 +59,7 @@ def stable_diffusion_pipeline(model, image, half, skip, do_slice, token):
 
     print("loaded models after:", iso_date_time())
 
-    return pipeline, image
+    return pipeline, image, mask
 
 
 def stable_diffusion_inference(
@@ -58,6 +67,7 @@ def stable_diffusion_inference(
     prompt,
     neg_prompt,
     image,
+    mask,
     samples,
     iters,
     height,
@@ -79,6 +89,8 @@ def stable_diffusion_inference(
             prompt,
             negative_prompt=neg_prompt,
             init_image=image,
+            image=image,
+            mask_image=mask,
             height=height,
             width=width,
             num_images_per_prompt=samples,
@@ -162,7 +174,13 @@ def main():
         "--image",
         type=str,
         nargs="?",
-        help="The input filename to use for image-to-image diffusion",
+        help="The input image to use for image-to-image diffusion",
+    )
+    parser.add_argument(
+        "--mask",
+        type=str,
+        nargs="?",
+        help="The input mask to use for diffusion inpainting",
     )
     parser.add_argument(
         "--model",
@@ -200,8 +218,14 @@ def main():
     if args.prompt0 is not None:
         args.prompt = args.prompt0
 
-    pipeline, image = stable_diffusion_pipeline(
-        args.model, args.image, args.half, args.skip, args.attention_slicing, args.token
+    pipeline, image, mask = stable_diffusion_pipeline(
+        args.model,
+        args.image,
+        args.mask,
+        args.half,
+        args.skip,
+        args.attention_slicing,
+        args.token,
     )
 
     stable_diffusion_inference(
@@ -209,6 +233,7 @@
         args.prompt,
         args.negative_prompt,
         image,
+        mask,
         args.n_samples,
         args.n_iter,
         args.H,
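
Note: the mask_image keyword added in the pipeline call above mirrors how StableDiffusionInpaintPipeline is invoked directly in diffusers. Below is a minimal standalone sketch of that call; the checkpoint id, file names, and sizes are illustrative assumptions and not part of this change (in the script they come from --model, --image, and --mask).

import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

# Illustrative checkpoint; any inpainting-capable Stable Diffusion model id works here.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
).to("cuda")

# Illustrative file names; the mask convention is white = regenerate, black = keep.
init_image = Image.open("input.png").convert("RGB").resize((512, 512))
mask_image = Image.open("mask.png").convert("RGB").resize((512, 512))

result = pipe(
    prompt="a photo of a cat on a park bench",
    image=init_image,        # picture to edit
    mask_image=mask_image,   # region to repaint
).images[0]
result.save("output.png")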