From d245e2002fa6b2b0eb6826a954d738a6481c9505 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 9 Aug 2023 13:46:06 -0700 Subject: [PATCH] more types --- sgm/inference/api.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 668cc65d..e3f3d17d 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -18,6 +18,7 @@ LinearMultistepSampler, ) from sgm.util import load_model_from_config +import torch from typing import Optional, Dict, Any @@ -226,8 +227,8 @@ def text_to_image( negative_prompt: str = "", samples: int = 1, return_latents: bool = False, - noise_strength=None, - filter=None, + noise_strength: Optional[float] = None, + filter: Any = None, ): sampler = get_sampler_config(params) @@ -260,13 +261,13 @@ def text_to_image( def image_to_image( self, params: SamplingParams, - image, + image: torch.Tensor, prompt: str, negative_prompt: str = "", samples: int = 1, return_latents: bool = False, - noise_strength=None, - filter=None, + noise_strength: Optional[float] = None, + filter: Any = None, ): sampler = get_sampler_config(params) @@ -321,7 +322,7 @@ def wrap_discretization( def refiner( self, - image, + image: torch.Tensor, prompt: str, negative_prompt: str = "", params: SamplingParams = SamplingParams( @@ -329,8 +330,8 @@ def refiner( ), samples: int = 1, return_latents: bool = False, - filter=None, - add_noise=False, + filter: Any = None, + add_noise: bool = False, ): sampler = get_sampler_config(params) value_dict = {