Skip to content

Commit

Permalink
align with streamlit helpers and re-de-deuplicate
Browse files Browse the repository at this point in the history
  • Loading branch information
palp committed Aug 6, 2023
1 parent 77d0e27 commit b216934
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 468 deletions.
24 changes: 18 additions & 6 deletions scripts/demo/sampling.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from pytorch_lightning import seed_everything

from sgm.inference.helpers import (
do_img2img,
do_sample,
get_unique_embedder_keys_from_conditioner,
perform_save_locally,
)
from scripts.demo.streamlit_helpers import *

SAVE_PATH = "outputs/demo/txt2img/"
Expand Down Expand Up @@ -99,9 +105,7 @@ def load_img(display=True, key=None, device="cuda"):
st.image(image)
w, h = image.size
print(f"loaded input image of size ({w}, {h})")
width, height = map(
lambda x: x - x % 64, (w, h)
) # resize to integer multiple of 64
width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
image = image.resize((width, height))
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
Expand Down Expand Up @@ -143,6 +147,8 @@ def run_txt2img(

if st.button("Sample"):
st.write(f"**Model I:** {version}")
outputs = st.empty()
st.text("Sampling")
out = do_sample(
state["model"],
sampler,
Expand All @@ -156,6 +162,9 @@ def run_txt2img(
return_latents=return_latents,
filter=filter,
)

show_samples(out, outputs)

return out


Expand Down Expand Up @@ -184,16 +193,16 @@ def run_img2img(
prompt=prompt,
negative_prompt=negative_prompt,
)
strength = st.number_input(
"**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
)
strength = st.number_input("**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0)
sampler, num_rows, num_cols = init_sampling(
img2img_strength=strength,
stage2strength=stage2strength,
)
num_samples = num_rows * num_cols

if st.button("Sample"):
outputs = st.empty()
st.text("Sampling")
out = do_img2img(
repeat(img, "1 ... -> n ...", n=num_samples),
state["model"],
Expand All @@ -204,6 +213,7 @@ def run_img2img(
return_latents=return_latents,
filter=filter,
)
show_samples(out, outputs)
return out


Expand Down Expand Up @@ -342,6 +352,7 @@ def apply_refiner(
samples_z = None

if add_pipeline and samples_z is not None:
outputs = st.empty()
st.write("**Running Refinement Stage**")
samples = apply_refiner(
samples_z,
Expand All @@ -353,6 +364,7 @@ def apply_refiner(
filter=state.get("filter"),
finish_denoising=finish_denoising,
)
show_samples(samples, outputs)

if save_locally and samples is not None:
perform_save_locally(save_path, samples)
Loading

0 comments on commit b216934

Please sign in to comment.