Helpers + Streamlit refactor #102

Open

palp wants to merge 57 commits into main from helpers-fixes

Commits (57):
45feb6c  Use wrapper correctly in refiner helper (palp, Aug 2, 2023)
853adb4  Add defaults to refiner function (palp, Aug 3, 2023)
73287ec  Extract method for img2img wrapper (palp, Aug 3, 2023)
44943df  Allow loading custom models and improve path logic (palp, Aug 3, 2023)
baf79d2  black (palp, Aug 4, 2023)
4e2236f  Fix path logic for development installs (palp, Aug 4, 2023)
19fa4da  run black again (palp, Aug 4, 2023)
84d3a7f  fix fallback logic for config path (palp, Aug 4, 2023)
4aea6fa  Fix checkpoint loading too (palp, Aug 4, 2023)
77d0e27  format (palp, Aug 4, 2023)
b216934  align with streamlit helpers and re-de-duplicate (palp, Aug 6, 2023)
f06c67c  formatting, remove reference (palp, Aug 6, 2023)
ea5f232  move conditioner to device (palp, Aug 6, 2023)
0c2c5c6  fix device check (palp, Aug 6, 2023)
451c76a  format (palp, Aug 6, 2023)
f2fba1d  fix noisy latent handling (palp, Aug 6, 2023)
8f8757b  version bump for changes to inference helpers (palp, Aug 6, 2023)
76ca428  fix path resolution bug (palp, Aug 6, 2023)
ced97f0  update defaults (palp, Aug 6, 2023)
6c18c84  rename ModelOnDevice to SwapToDevice (palp, Aug 6, 2023)
49fe53c  use env var for sgm checkpoints path (palp, Aug 7, 2023)
7e7fee3  system env var (palp, Aug 7, 2023)
c4b7baf  Streamlit refactor (#105) (palp, Aug 7, 2023)
a726ce3  replace usage of get (palp, Aug 9, 2023)
f86ffac  context manager (palp, Aug 9, 2023)
a009aa8  adding some typing (palp, Aug 9, 2023)
725bea9  pull in import fix (palp, Aug 9, 2023)
d245e20  more types (palp, Aug 9, 2023)
b51c36b  extract path resolution method, fix/improve device swapping support (palp, Aug 10, 2023)
8011d54  some PR fixes (palp, Aug 10, 2023)
fc498bf  remove duplicate imports (palp, Aug 10, 2023)
e190ecc  path helper & model swapping rewrite (palp, Aug 10, 2023)
47805f2  finish device manager refactor (palp, Aug 10, 2023)
9b18e6f  update api module (palp, Aug 10, 2023)
de7a627  more fixes and cleanup (palp, Aug 10, 2023)
3e7ada7  fix autocast (palp, Aug 10, 2023)
26b10f5  fix missing index (palp, Aug 10, 2023)
a25662e  low vram checkbox fix, remove magic strings (palp, Aug 10, 2023)
b3866d1  move checkbox out of cached resource (palp, Aug 10, 2023)
8839526  update helpers (palp, Aug 10, 2023)
3816aaa  simplify device_manager usage (palp, Aug 10, 2023)
2aebc88  split fp16 and swapping functionality (palp, Aug 10, 2023)
5c17043  change default (palp, Aug 10, 2023)
cd81956  text updates (palp, Aug 10, 2023)
d6f2b78  pass options into state2 init (palp, Aug 10, 2023)
fe46320  fix for orig dimensions (palp, Aug 11, 2023)
d4307be  Test model device manager and fix bugs (palp, Aug 12, 2023)
98c4b77  cleanup imports in test (palp, Aug 12, 2023)
f670453  abstract device defaults (palp, Aug 12, 2023)
c065573  fix streamlit inputs (palp, Aug 12, 2023)
fbe93fc  PR fixes, model specific defaults (palp, Aug 12, 2023)
5fde7e7  set a default scale (palp, Aug 12, 2023)
65c6ec1  run black (palp, Aug 12, 2023)
e32972b  remove extra init (palp, Aug 12, 2023)
2fc4680  Easier default params (palp, Aug 12, 2023)
e289621  fix reference (palp, Aug 12, 2023)
7ef5489  Merge branch 'main' into helpers-fixes (palp, Aug 17, 2023)
Files changed:
scripts/demo/sampling.py (2 changes: 0 additions & 2 deletions)

@@ -19,12 +19,10 @@
 )
 from scripts.demo.streamlit_helpers import (
     get_interactive_image,
     get_unique_embedder_keys_from_conditioner,
     init_embedder_options,
     init_sampling,
     init_save_locally,
     init_st,
     perform_save_locally,
     set_lowvram_mode,
     show_samples,
 )
scripts/demo/streamlit_helpers.py (10 changes: 5 additions & 5 deletions)

@@ -20,9 +20,7 @@
     SamplingPipeline,
     Thresholder,
 )
-from sgm.inference.helpers import (
-    embed_watermark,
-)
+from sgm.inference.helpers import embed_watermark, CudaModelLoader


 @st.cache_resource()
@@ -35,10 +33,12 @@ def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, Any]:

     if lowvram_mode:
         pipeline = SamplingPipeline(
-            model_spec=spec, use_fp16=True, device="cuda", swap_device="cpu"
+            model_spec=spec,
+            use_fp16=True,
+            model_loader=CudaModelLoader(device="cuda", swap_device="cpu"),
         )
     else:
-        pipeline = SamplingPipeline(model_spec=spec, use_fp16=True, device="cuda")
+        pipeline = SamplingPipeline(model_spec=spec, use_fp16=False)

     state["spec"] = spec
     state["model"] = pipeline
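The effect of the two branches above, sketched as plain API calls outside Streamlit. This is hypothetical usage, not part of the diff: ModelArchitecture.SDXL_V1_BASE is assumed to be one of the available enum members, and the model checkpoints must be resolvable on disk.

    from sgm.inference.api import ModelArchitecture, SamplingPipeline
    from sgm.inference.helpers import CudaModelLoader

    # Low-VRAM: weights are parked on the CPU swap device and moved onto the
    # GPU only while a submodel is actually in use.
    lowvram_pipeline = SamplingPipeline(
        model_id=ModelArchitecture.SDXL_V1_BASE,  # assumed enum member
        use_fp16=True,
        model_loader=CudaModelLoader(device="cuda", swap_device="cpu"),
    )

    # Default: the model stays resident on the CUDA device for the lifetime
    # of the pipeline (CudaModelLoader with no swap_device).
    pipeline = SamplingPipeline(model_id=ModelArchitecture.SDXL_V1_BASE)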
sgm/inference/api.py (36 changes: 11 additions & 25 deletions)

@@ -2,10 +2,11 @@
 from enum import Enum
 from omegaconf import OmegaConf
 import os
-import pathlib
 from sgm.inference.helpers import (
     do_sample,
     do_img2img,
+    BaseDeviceModelLoader,
+    CudaModelLoader,
     Img2ImgDiscretizationWrapper,
     Txt2NoisyDiscretizationWrapper,
 )
@@ -17,7 +18,7 @@
     DPMPP2MSampler,
     LinearMultistepSampler,
 )
-from sgm.util import load_model_from_config
+from sgm.util import load_model_from_config, get_configs_path, get_checkpoints_path
 import torch
 from typing import Optional, Dict, Any, Union
@@ -163,11 +164,10 @@ def __init__(
         self,
         model_id: Optional[ModelArchitecture] = None,
         model_spec: Optional[SamplingSpec] = None,
-        model_path: Optional[Union[str, pathlib.Path]] = None,
-        config_path: Optional[Union[str, pathlib.Path]] = None,
-        device: Union[str, torch.device] = "cuda",
-        swap_device: Optional[Union[str, torch.device]] = None,
+        model_path: Optional[str] = None,
+        config_path: Optional[str] = None,
         use_fp16: bool = True,
+        model_loader: BaseDeviceModelLoader = CudaModelLoader(device="cuda"),
     ) -> None:
         """
         Sampling pipeline for generating images from a model.
@@ -176,9 +176,8 @@
         @param model_spec: Model specification to use. If not specified, model_id must be specified.
         @param model_path: Path to model checkpoints folder.
         @param config_path: Path to model config folder.
-        @param device: Device to use for sampling.
-        @param swap_device: Device to swap models to when not in use.
         @param use_fp16: Whether to use fp16 for sampling.
+        @param model_loader: Model loader class to use. Defaults to CudaModelLoader.
         """

         self.model_id = model_id
@@ -192,11 +191,11 @@
             raise ValueError("Either model_id or model_spec should be provided")

         if model_path is None:
-            model_path = self._resolve_default_path("checkpoints")
+            model_path = get_checkpoints_path()
         if config_path is None:
-            config_path = self._resolve_default_path("configs/inference")
-        self.config = str(pathlib.Path(config_path) / self.specs.config)
-        self.ckpt = str(pathlib.Path(model_path) / self.specs.ckpt)
+            config_path = get_configs_path()
+        self.config = os.path.join(config_path, "inference", self.specs.config)
+        self.ckpt = os.path.join(model_path, self.specs.ckpt)
         if not os.path.exists(self.config):
             raise ValueError(
                 f"Config {self.config} not found, check model spec or config_path"
@@ -210,19 +209,6 @@ def __init__(
         load_device = device if swap_device is None else swap_device
         self.model = self._load_model(device=load_device, use_fp16=use_fp16)

-    def _resolve_default_path(self, suffix: str) -> pathlib.Path:
-        # Resolves a path relative to the root of the module or repo
-        repo_path = pathlib.Path(__file__).parent.parent.parent.resolve() / suffix
-        module_path = pathlib.Path(__file__).parent.parent.resolve() / suffix
-        path = module_path / suffix
-        if not os.path.exists(path):
-            path = repo_path / suffix
-        if not os.path.exists(path):
-            raise ValueError(
-                f"Default locations for {suffix} not found, please specify path"
-            )
-        return pathlib.Path(path)
-
     def _load_model(self, device="cuda", use_fp16=True):
         config = OmegaConf.load(self.config)
         model = load_model_from_config(config, self.ckpt)
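A sketch of what the reworked constructor defaults mean for callers; the override paths below are hypothetical, and the enum member is assumed to exist:

    from sgm.inference.api import ModelArchitecture, SamplingPipeline

    # With no paths given, checkpoints resolve via get_checkpoints_path() and
    # configs via get_configs_path(); both remain overridable per instance.
    pipeline = SamplingPipeline(model_id=ModelArchitecture.SDXL_V1_BASE)

    # Explicit override (hypothetical paths). Note that config_path must
    # contain an inference/ subfolder, since the spec's config file is joined
    # as os.path.join(config_path, "inference", spec.config).
    custom = SamplingPipeline(
        model_id=ModelArchitecture.SDXL_V1_BASE,
        model_path="/data/sgm/checkpoints",
        config_path="/data/sgm/configs",
    )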
sgm/inference/helpers.py (93 changes: 63 additions & 30 deletions)

@@ -10,6 +10,7 @@
 from imwatermark import WatermarkEncoder
 from omegaconf import ListConfig
 from torch import autocast
+from abc import ABC, abstractmethod

 from sgm.util import append_dims
@@ -353,35 +354,67 @@ def denoiser(x, sigma, c):
     return samples


-@contextlib.contextmanager
-def swap_to_device(
-    model: Union[torch.nn.Module, torch.Tensor], device: Union[torch.device, str]
-):
-    """
-    Context manager that swaps a model or tensor to a device, and then swaps it back to its original device
-    when the context is exited.
-    """
-    if isinstance(model, torch.Tensor):
-        original_device = model.device
-    else:
-        param = next(model.parameters(), None)
-        if param is not None:
-            original_device = param.device
-        else:
-            buf = next(model.buffers(), None)
-            if buf is not None:
-                original_device = buf.device
-            else:
-                # If device could not be found, do nothing
-                return
-    device = torch.device(device)
-
-    if device != original_device:
-        model.to(device)
-
-    yield
-
-    if device != original_device:
-        model.to(original_device)
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+class BaseDeviceModelLoader(ABC):
+    """
+    Base class for model loaders, which manage the device(s) a model is loaded to and used on.
+    """
+
+    @abstractmethod
+    def __init__(self, device: Union[torch.device, str]):
+        """
+        Args:
+            device (Union[torch.device, str]): The device to use for the model.
+        """
+        pass
+
+    def load(self, model: torch.nn.Module):
+        """
+        Loads a model to the device.
+        """
+        pass
+
+    @contextlib.contextmanager
+    def use(self, model: torch.nn.Module):
+        """
+        Context manager that ensures a model is on the correct device during use.
+        """
+        yield
+
+
+class CudaModelLoader(BaseDeviceModelLoader):
+    """
+    Model loader that loads a model to a CUDA device, optionally swapping to CPU when not in use.
+    """
+
+    def __init__(
+        self,
+        device: Union[torch.device, str] = "cuda",
+        swap_device: Optional[Union[torch.device, str]] = None,
+    ):
+        """
+        Args:
+            device (Union[torch.device, str]): The device to use for the model.
+            swap_device (Optional[Union[torch.device, str]]): The device to swap
+                the model to when not in use. Defaults to `device`, i.e. no swapping.
+        """
+        self.device = torch.device(device)
+        self.swap_device = (
+            torch.device(swap_device) if swap_device is not None else self.device
+        )
+
+    def load(self, model: Union[torch.nn.Module, torch.Tensor]):
+        """
+        Loads a model to the device.
+        """
+        model.to(self.swap_device)
+
+    @contextlib.contextmanager
+    def use(self, model: Union[torch.nn.Module, torch.Tensor]):
+        """
+        Context manager that ensures a model is on the correct device during use.
+        """
+        if self.device != self.swap_device:
+            model.to(self.device)
+        yield
+        if self.device != self.swap_device:
+            model.to(self.swap_device)
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
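Since this file replaces the swap_to_device context manager with loader objects, a minimal usage sketch may help; it assumes a CUDA-capable machine and uses a toy nn.Linear in place of a diffusion submodel:

    import torch
    from sgm.inference.helpers import CudaModelLoader

    loader = CudaModelLoader(device="cuda", swap_device="cpu")

    model = torch.nn.Linear(4, 4)  # toy stand-in for a diffusion submodel
    loader.load(model)             # parks the model on the swap device (CPU)

    with loader.use(model):        # moved to CUDA for the duration of the block
        x = torch.randn(1, 4, device="cuda")
        y = model(x)
    # On exit the model is back on the CPU and the CUDA cache has been emptied.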
sgm/util.py (18 changes: 18 additions & 0 deletions)

@@ -230,6 +230,24 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
     return model


+def get_checkpoints_path() -> str:
+    """
+    Get the `checkpoints` directory.
+    This could be in the root of the repository for a working copy,
+    or in the cwd for other use cases.
+    """
+    this_dir = os.path.dirname(__file__)
+    candidates = (
+        os.path.join(this_dir, "checkpoints"),
+        os.path.join(os.getcwd(), "checkpoints"),
+    )
+    for candidate in candidates:
+        candidate = os.path.abspath(candidate)
+        if os.path.isdir(candidate):
+            return candidate
+    raise FileNotFoundError(f"Could not find SGM checkpoints in {candidates}")
+
+
 def get_configs_path() -> str:
     """
     Get the `configs` directory.
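A short sketch of the resolver's behavior, assuming a checkpoints/ directory exists either inside the sgm package directory or under the current working directory:

    from sgm.util import get_checkpoints_path, get_configs_path

    # Tries <sgm package dir>/checkpoints first, then <cwd>/checkpoints, and
    # returns the first candidate that is an existing directory (as an
    # absolute path); raises FileNotFoundError if neither exists.
    ckpt_dir = get_checkpoints_path()
    cfg_dir = get_configs_path()  # analogous resolution for configs
    print(ckpt_dir, cfg_dir)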