Commit
Init image (#27)
adds init_image and init_mask functionality
dmarx authored Sep 7, 2022
1 parent fd78f4d commit 2cd72cc
Showing 4 changed files with 310 additions and 178 deletions.
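
The gist of the change: `generate()` can now seed diffusion from an existing image (optionally with a mask), with `start_schedule`/`end_schedule` controlling how strongly the init image constrains the result. Below is a minimal sketch of how that might look from Python, based on the new `generate()` signature in this diff; the `StabilityInference` class name, the environment-variable handling, and the response-iteration shape are assumptions drawn from the wider SDK, not from this commit:

```
import io
import os

from PIL import Image
from stability_sdk import client

# Assumed client class name; this diff only shows its __init__ and generate().
stability = client.StabilityInference(
    host=os.getenv("STABILITY_HOST", "grpc.stability.ai:443"),
    key=os.environ["STABILITY_KEY"],
    engine="stable-diffusion-v1-5",
)

init = Image.open("init.png")  # hypothetical starting image
mask = Image.open("mask.png")  # hypothetical mask; only valid together with init_image

answers = stability.generate(
    prompt="a watercolor fox in a snowy forest",
    init_image=init,
    mask_image=mask,
    start_schedule=0.6,   # closer to 1.0 = stronger prompt, weaker trace of the init image
    end_schedule=0.01,
    seed=12345,
)

# Assumed response shape: a stream of answers carrying binary image artifacts.
for answer in answers:
    for artifact in answer.artifacts:
        if artifact.binary:
            Image.open(io.BytesIO(artifact.binary)).save("out.png")
```

Passing `mask_image` without `init_image` raises a `ValueError`, per the new validation in `generate()`.
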
12 changes: 6 additions & 6 deletions README.md
@@ -16,7 +16,7 @@ Install the PyPI package via:

* Use Python venv: `python3 -m venv pyenv`
* Set up in venv dependencies: `pyenv/bin/pip3 install -r requirements.txt`
* `pyenv/bin/enable` to use the venv.
* `pyenv/bin/activate` to use the venv.
* Set the `STABILITY_HOST` environment variable. This is by default set to the production endpoint `grpc.stability.ai:443`.
* Set the `STABILITY_KEY` environment variable.

@@ -31,11 +31,11 @@ See usage demo notebooks in ./nbs

## Command line usage
```
usage: client.py [-h] [--height HEIGHT] [--width WIDTH]
[--cfg_scale CFG_SCALE] [--sampler SAMPLER] [--steps STEPS]
[--seed SEED] [--prefix PREFIX] [--no-store]
[--num_samples NUM_SAMPLES] [--show]
prompt [prompt ...]
usage: python -m stability_sdk.client [-h] [--height HEIGHT] [--width WIDTH]
[--cfg_scale CFG_SCALE] [--sampler SAMPLER] [--steps STEPS]
[--seed SEED] [--prefix PREFIX] [--no-store]
[--num_samples NUM_SAMPLES] [--show]
prompt [prompt ...]
positional arguments:
prompt
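
For reference, an illustrative invocation exercising the new CLI flags added in client.py further down; the file names and prompt are placeholders:

```
python -m stability_sdk.client --init_image init.png --mask_image mask.png --start_schedule 0.6 --steps 50 "a watercolor fox in a snowy forest"
```
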
295 changes: 167 additions & 128 deletions nbs/demo_colab.ipynb

Large diffs are not rendered by default.

110 changes: 96 additions & 14 deletions src/stability_sdk/client.py
@@ -14,6 +14,7 @@
from typing import Dict, Generator, List, Union, Any, Sequence, Tuple
from dotenv import load_dotenv
from google.protobuf.json_format import MessageToJson
from PIL import Image

load_dotenv()

@@ -38,6 +39,28 @@
"k_lms": generation.SAMPLER_K_LMS,
}

def image_to_prompt(im, init: bool = False, mask: bool = False) -> Tuple[str, generation.Prompt]:
if init and mask:
raise ValueError("init and mask cannot both be True")
buf = io.BytesIO()
im.save(buf, format='PNG')
buf.seek(0)
if mask:
return generation.Prompt(
artifact=generation.Artifact(
type=generation.ARTIFACT_MASK,
binary=buf.getvalue()
)
)
return generation.Prompt(
artifact=generation.Artifact(
type=generation.ARTIFACT_IMAGE,
binary=buf.getvalue()
),
parameters=generation.PromptParameters(
init=init
),
)

def get_sampler_from_str(s: str) -> generation.DiffusionSampler:
"""
@@ -129,7 +152,7 @@ def __init__(
self,
host: str = "grpc.stability.ai:443",
key: str = "",
engine: str = "stable-diffusion-v1",
engine: str = "stable-diffusion-v1-5",
verbose: bool = False,
wait_for_ready: bool = True,
):
@@ -177,8 +200,12 @@ def __init__(
def generate(
self,
prompt: Union[List[str], str],
init_image: Image.Image = None,
mask_image: Image.Image = None,
height: int = 512,
width: int = 512,
start_schedule: float = 0.5,
end_schedule: float = 0.01,
cfg_scale: float = 7.0,
sampler: generation.DiffusionSampler = generation.SAMPLER_K_LMS,
steps: int = 50,
@@ -191,8 +218,12 @@ def generate(
Generate images from a prompt.
:param prompt: Prompt to generate images from.
:param init_image: Init image.
:param mask_image: Mask image
:param height: Height of the generated images.
:param width: Width of the generated images.
:param start_schedule: Start schedule for init image.
:param end_schedule: End schedule for init image.
:param cfg_scale: Scale of the configuration.
:param sampler: Sampler to use.
:param steps: Number of steps to take.
@@ -205,18 +236,46 @@ def generate(
if safety and classifiers is None:
classifiers = generation.ClassifierParameters()

if not prompt:
raise ValueError("prompt must be provided")
if not prompt and not init_image:
raise ValueError("prompt and/or init_image must be provided")

if mask_image and not init_image:
raise ValueError("If mask_image is provided, init_image must also be provided")

request_id = str(uuid.uuid4())

if not seed:
seed = [random.randrange(0, 4294967295)]
else:
seed = [seed]

if isinstance(prompt, str):
prompt = [generation.Prompt(text=prompt)]
else:
elif isinstance(prompt, Sequence):
prompt = [generation.Prompt(text=p) for p in prompt]
else:
raise TypeError("prompt must be a string or a sequence")

if init_image:
prompt += [image_to_prompt(init_image, init=True)]
parameters = generation.StepParameter(
scaled_step=0,
sampler=generation.SamplerParameters(
cfg_scale=cfg_scale,
),
schedule=generation.ScheduleParameters(
start=start_schedule,
end=end_schedule,
)
),
if mask_image:
prompt += [image_to_prompt(mask_image, mask=True)]
else:
parameters = generation.StepParameter(
scaled_step=0,
sampler=generation.SamplerParameters(
cfg_scale=cfg_scale),
),

rq = generation.Request(
engine_id=self.engine,
@@ -229,12 +288,7 @@ def generate(
seed=seed,
steps=steps,
samples=samples,
parameters=[
generation.StepParameter(
scaled_step=0,
sampler=generation.SamplerParameters(cfg_scale=cfg_scale),
)
],
parameters=parameters,
),
classifier=classifiers,
)
@@ -272,11 +326,15 @@ def build_request_dict(cli_args: Namespace) -> Dict[str, Any]:
return {
"height": cli_args.height,
"width": cli_args.width,
"start_schedule": cli_args.start_schedule,
"end_schedule": cli_args.end_schedule,
"cfg_scale": cli_args.cfg_scale,
"sampler": get_sampler_from_str(cli_args.sampler),
"steps": cli_args.steps,
"seed": cli_args.seed,
"samples": cli_args.num_samples,
"init_image": cli_args.init_image,
"mask_image": cli_args.mask_image,
}


@@ -312,6 +370,14 @@ def build_request_dict(cli_args: Namespace) -> Dict[str, Any]:
parser.add_argument(
"--width", "-W", type=int, default=512, help="[512] width of image"
)
parser.add_argument(
"--start_schedule",
type=float, default=0.5, help="[0.5] start schedule for init image (must be greater than 0, 1 is full strength text prompt, no trace of image)"
)
parser.add_argument(
"--end_schedule",
type=float, default=0.01, help="[0.01] end schedule for init image"
)
parser.add_argument(
"--cfg_scale", "-C", type=float, default=7.0, help="[7.0] CFG scale factor"
)
@@ -345,17 +411,33 @@ def build_request_dict(cli_args: Namespace) -> Dict[str, Any]:
"-e",
type=str,
help="engine to use for inference",
default="stable-diffusion-v1",
default="stable-diffusion-v1-5",
)
parser.add_argument(
"--init_image", "-i",
type=str,
help="Init image",
)
parser.add_argument(
"--mask_image", "-m",
type=str,
help="Mask image",
)
parser.add_argument("prompt", nargs="+")
parser.add_argument("prompt", nargs="*")

args = parser.parse_args()
if not args.prompt:
logger.warning("prompt must be provided")
if not args.prompt and not args.init_image:
logger.warning("prompt or init image must be provided")
parser.print_help()
sys.exit(1)
else:
args.prompt = " ".join(args.prompt)

if args.init_image:
args.init_image = Image.open(args.init_image)

if args.mask_image:
args.mask_image = Image.open(args.mask_image)

request = build_request_dict(args)

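
As a quick sanity check of the new helper, here is a sketch of calling `image_to_prompt` directly; the image files are hypothetical, and the import path assumes the function stays module-level in `stability_sdk.client`:

```
from PIL import Image
from stability_sdk.client import image_to_prompt

init_prompt = image_to_prompt(Image.open("init.png"), init=True)  # ARTIFACT_IMAGE with PromptParameters(init=True)
mask_prompt = image_to_prompt(Image.open("mask.png"), mask=True)  # ARTIFACT_MASK, no prompt parameters
# image_to_prompt(img, init=True, mask=True) raises ValueError, per the guard at the top of the helper.
# Note: it returns a bare generation.Prompt, despite the Tuple[str, generation.Prompt] annotation.
```

Inside `generate()`, these prompts are appended after the text prompts, and the init path swaps in a `StepParameter` carrying `ScheduleParameters(start=start_schedule, end=end_schedule)` in place of the previous sampler-only parameters.
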