Skip to content

Commit

Permalink
some PR fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
palp committed Aug 10, 2023
1 parent b51c36b commit 8011d54
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 16 deletions.
7 changes: 3 additions & 4 deletions scripts/demo/streamlit_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,11 +191,10 @@ def init_sampling(
)
)

params = get_discretization(params, key=key)
params = get_discretization(params=params, key=key)
params = get_guider(params=params, key=key)
params = get_sampler(params=params, key=key)

params = get_guider(key=key, params=params)

params = get_sampler(params, key=key)
return params, num_rows, num_cols


Expand Down
4 changes: 2 additions & 2 deletions sgm/inference/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,8 @@ def __init__(
model_spec: Optional[SamplingSpec] = None,
model_path: Optional[Union[str, pathlib.Path]] = None,
config_path: Optional[Union[str, pathlib.Path]] = None,
device: Union[str, torch.Device] = "cuda",
swap_device: Optional[Union[str, torch.Device]] = None,
device: Union[str, torch.device] = "cuda",
swap_device: Optional[Union[str, torch.device]] = None,
use_fp16: bool = True,
) -> None:
"""
Expand Down
20 changes: 10 additions & 10 deletions sgm/inference/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def do_sample(
with autocast(device) as precision_scope:
with model.ema_scope():
num_samples = [num_samples]
with SwapToDevice(model.conditioner, device):
with swap_to_device(model.conditioner, device):
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
Expand Down Expand Up @@ -190,11 +190,11 @@ def denoiser(input, sigma, c):
model.model, input, sigma, c, **additional_model_inputs
)

with SwapToDevice(model.denoiser, device):
with SwapToDevice(model.model, device):
with swap_to_device(model.denoiser, device):
with swap_to_device(model.model, device):
samples_z = sampler(denoiser, randn, cond=c, uc=uc)

with SwapToDevice(model.first_stage_model, device):
with swap_to_device(model.first_stage_model, device):
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

Expand Down Expand Up @@ -294,7 +294,7 @@ def do_img2img(
with torch.no_grad():
with autocast(device):
with model.ema_scope():
with SwapToDevice(model.conditioner, device):
with swap_to_device(model.conditioner, device):
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
Expand All @@ -314,7 +314,7 @@ def do_img2img(
if skip_encode:
z = img
else:
with SwapToDevice(model.first_stage_model, device):
with swap_to_device(model.first_stage_model, device):
z = model.encode_first_stage(img)

noise = torch.randn_like(z)
Expand All @@ -337,11 +337,11 @@ def do_img2img(
def denoiser(x, sigma, c):
return model.denoiser(model.model, x, sigma, c)

with SwapToDevice(model.denoiser, device):
with SwapToDevice(model.model, device):
with swap_to_device(model.denoiser, device):
with swap_to_device(model.model, device):
samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)

with SwapToDevice(model.first_stage_model, device):
with swap_to_device(model.first_stage_model, device):
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

Expand All @@ -354,7 +354,7 @@ def denoiser(x, sigma, c):


@contextlib.contextmanager
def SwapToDevice(
def swap_to_device(
model: Union[torch.nn.Module, torch.Tensor], device: Union[torch.device, str]
):
"""
Expand Down

0 comments on commit 8011d54

Please sign in to comment.