diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index b37e2ebd..fec7d33e 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -36,7 +36,7 @@ def init_st( pipeline = SamplingPipeline( model_spec=spec, use_fp16=True, - device_manager=CudaModelManager(device="cuda", swap_device="cpu"), + device=CudaModelManager(device="cuda", swap_device="cpu"), ) else: pipeline = SamplingPipeline(model_spec=spec, use_fp16=False) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 96aead65..e680dc55 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -57,6 +57,13 @@ class Thresholder(str, Enum): @dataclass class SamplingParams: + """ + Parameters for sampling. + The defaults here are derived from user preference testing. + They will be subject to change in the future, likely pulled + from model specs instead of global defaults. + """ + width: int = 1024 height: int = 1024 steps: int = 40 @@ -167,7 +174,9 @@ def __init__( model_path: Optional[str] = None, config_path: Optional[str] = None, use_fp16: bool = True, - device_manager: DeviceModelManager = CudaModelManager(device="cuda"), + device: Union[DeviceModelManager, str, torch.device] = CudaModelManager( + device="cuda" + ), ) -> None: """ Sampling pipeline for generating images from a model. @@ -177,7 +186,7 @@ def __init__( @param model_path: Path to model checkpoints folder. @param config_path: Path to model config folder. @param use_fp16: Whether to use fp16 for sampling. - @param model_loader: Model loader class to use. Defaults to CudaModelLoader. + @param device: Device manager to use with this pipeline. If a string or torch.device is passed, a device manager will be created based on device type if possible. """ self.model_id = model_id @@ -205,7 +214,13 @@ def __init__( f"Checkpoint {self.ckpt} not found, check model spec or config_path" ) - self.device_manager = device_manager + if isinstance(device, torch.device) or isinstance(device, str): + if torch.device(device).type == "cuda": + self.device_manager = CudaModelManager(device=device) + else: + self.device_manager = DeviceModelManager(device=device) + else: + self.device_manager = device self.model = self._load_model( device_manager=self.device_manager, use_fp16=use_fp16 )