Skip to content

Commit

Permalink
Merge pull request #64 from noskill/quickfix
Browse files Browse the repository at this point in the history
use SDXL flag for controlnet loading
  • Loading branch information
Necr0x0Der authored Jun 18, 2024
2 parents 43d8eca + 57815d2 commit 496b1d7
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
9 changes: 6 additions & 3 deletions multigen/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .worker_base import ServiceThreadBase
from .prompting import Cfgen
from .sessions import GenSession
from .pipes import Prompt2ImPipe
from .pipes import Prompt2ImPipe, ControlnetType


class ServiceThread(ServiceThreadBase):
Expand All @@ -22,11 +22,14 @@ def _get_pipeline(self, pipe_class, model_id, cnet=None, xl=False):
pipe = pipe_class(model_id, pipe=pipeline)
else:
pipeline = self._loader.get_pipeline(model_id)
cnet_type = ControlnetType.SD
if xl:
cnet_type = ControlnetType.SDXL
if pipeline is None or 'controlnet' not in pipeline.components:
pipe = pipe_class(model_id, ctypes=[cnet])
pipe = pipe_class(model_id, ctypes=[cnet], model_type=cnet_type)
self._loader.register_pipeline(pipe.pipe, model_id)
else:
pipe = pipe_class(model_id, pipe=pipeline)
pipe = pipe_class(model_id, pipe=pipeline, model_type=cnet_type)
return pipe

def run(self):
Expand Down
2 changes: 1 addition & 1 deletion multigen/worker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, cfg_file):
for p in self.models['pipes']:
class_name = self.models['pipes'][p]['classname']
self._pipe_name_to_pipe[p] = globals()[class_name]
print(self._pipe_name_to_pipe)
self.logger.debug('pipe name to pipe %s', str(self._pipe_name_to_pipe))
self.logger.info('service is running')

@property
Expand Down

0 comments on commit 496b1d7

Please sign in to comment.