Skip to content

Commit

Permalink
Merge pull request #91 from noskill/checkcomp
Browse files Browse the repository at this point in the history
check for allowed components in from_pipe method
  • Loading branch information
noskill authored Jan 6, 2025
2 parents e31ca25 + 09eec2f commit 0fddad9
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
21 changes: 15 additions & 6 deletions multigen/pipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,13 +539,22 @@ def __init__(self, *args, pipe: Optional[StableDiffusionImg2ImgPipeline] = None,
def _from_pipe(self, pipe, **args):
cls = pipe.__class__
if 'StableDiffusionXLPipeline' in str(cls) :
return self._classxl(**pipe.components, **args)
return self.__verify_from_pipe(self._classxl, pipe, **args)
elif 'StableDiffusionPipeline' in str(cls):
return self._class(**pipe.components, **args)
return self.__verify_from_pipe(self._class, pipe, **args)
elif 'Flux' in str(cls):
return self._classflux(**pipe.components, **args)
return self.__verify_from_pipe(self._classflux, pipe, **args)
raise RuntimeError(f"can't load pipeline from type {cls}")

def __verify_from_pipe(self, cls, pipe, **args):
allowed = util.get_allowed_components(cls)
source_components = set(pipe.components.keys())
target_components = set(allowed)

logging.debug("Missing components: ", target_components - source_components)
logging.debug("Extra components: ", source_components - target_components)
return cls(**{k: v for (k, v) in pipe.components.items() if k in allowed}, **args)

def setup(self, image=None, image_painted=None, mask=None, blur=4,
blur_compose=4, sample_mode='sample', scale=None, **kwargs):
"""
Expand Down Expand Up @@ -697,12 +706,12 @@ class Cond2ImPipe(BasePipe):
"inpaint": 1.0,
"qr": 1.5
})
cond_scales_defaults_flux = defaultdict(lambda: 0.8,

cond_scales_defaults_flux = defaultdict(lambda: 0.8,
{"canny-dev": 0.6})

def __init__(self, model_id, pipe: Optional[StableDiffusionControlNetPipeline] = None,
ctypes=["soft"], cnets: Optional[List[ControlNetModel]]=None,
ctypes=["soft"], cnets: Optional[List[ControlNetModel]]=None,
cnet_ids: Optional[List[str]]=None, model_type=None, **args):
"""
Constructor
Expand Down
7 changes: 7 additions & 0 deletions multigen/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import psutil
from PIL import Image
import copy as cp
from inspect import signature
import optimum.quanto
from optimum.quanto import freeze, qfloat8, quantize as _quantize
from diffusers.utils import is_accelerate_available
Expand Down Expand Up @@ -146,3 +147,9 @@ def weightshare_copy(pipe):
# some buffers might not be transfered from pipe to copy
copy.to(pipe.device)
return copy


def get_allowed_components(cls: type) -> dict:
params = signature(cls.__init__).parameters
components = [name for name in params.keys() if name != 'self']
return components

0 comments on commit 0fddad9

Please sign in to comment.