Commit
Merge pull request #93 from noskill/main
support for batch generation
Necr0x0Der authored Jan 23, 2025
2 parents 4256ddc + c784cdd commit 00d6cf1
Showing 4 changed files with 176 additions and 107 deletions.
63 changes: 39 additions & 24 deletions multigen/pipes.py
@@ -362,6 +362,14 @@ def prepare_inputs(self, inputs):
self.try_set_scheduler(inputs)
return kwargs

@property
def pad(self):
pad = 8
if hasattr(self.pipe, 'image_processor'):
if hasattr(self.pipe.image_processor, 'vae_scale_factor'):
pad = self.pipe.image_processor.vae_scale_factor
return pad
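
The new pad property feeds util.pad_image_to_multiple, which replaces the old pad_image_to_multiple_of_8 helper but is not part of this diff. A minimal sketch of such a helper, assuming it pads with black and keeps the original image in the top-left corner (the implementation below is an assumption, not the repository's code):

from PIL import Image

def pad_image_to_multiple(image: Image.Image, multiple: int = 8) -> Image.Image:
    # Round each dimension up to the next multiple and paste the original
    # image into the top-left corner of the enlarged canvas.
    width, height = image.size
    new_width = ((width + multiple - 1) // multiple) * multiple
    new_height = ((height + multiple - 1) // multiple) * multiple
    padded = Image.new(image.mode, (new_width, new_height))
    padded.paste(image, (0, 0))
    return padded

The pipes below crop each result back to self._original_size after generation, which removes this padding again.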


class Prompt2ImPipe(BasePipe):
"""
@@ -411,7 +419,7 @@ def gen(self, inputs: dict):
"""
kwargs = self.prepare_inputs(inputs)
logging.debug("Prompt2ImPipe.gen calling pipe")
image = self.pipe(**kwargs).images[0]
image = self.pipe(**kwargs).images
return image
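
Because gen() now returns the whole .images list rather than a single image, callers iterate over the result. A hedged usage sketch; pipe is assumed to be an already configured Prompt2ImPipe, and num_images_per_prompt is simply forwarded to the underlying diffusers pipeline:

# `pipe` is assumed to be an already configured Prompt2ImPipe instance.
images = pipe.gen({"prompt": "a mountain lake at dawn",
                   "num_images_per_prompt": 2})
for i, img in enumerate(images):
    img.save(f"prompt2im_{i}.png")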


@@ -446,7 +454,7 @@ def setup(self, fimage, image=None, strength=0.75,
self._input_image = self.scale_image(self._input_image, scale)
self._original_size = self._input_image.size
logging.debug("origin image size {self._original_size}")
self._input_image = util.pad_image_to_multiple_of_8(self._input_image)
self._input_image = util.pad_image_to_multiple(self._input_image, self.pad)
self.pipe_params.update({
"width": self._input_image.width if width is None else width,
"height": self._input_image.height if height is None else height,
@@ -501,10 +509,12 @@ def gen(self, inputs: dict):
# so we update kwargs with inputs after pipe_params
kwargs.update({"image": self._input_image})
self.try_set_scheduler(kwargs)
image = self.pipe(**kwargs).images[0]
logging.debug(f'generated image {image}')
result = image.crop((0, 0, self._original_size[0], self._original_size[1]))
return result
res = []
for image in self.pipe(**kwargs).images:
logging.debug(f'generated image {image}')
result = image.crop((0, 0, self._original_size[0], self._original_size[1]))
res.append(result)
return res


class MaskedIm2ImPipe(Im2ImPipe):
@@ -555,8 +565,8 @@ def __verify_from_pipe(self, cls, pipe, **args):
source_components = set(pipe.components.keys())
target_components = set(allowed)

logging.debug("Missing components: ", target_components - source_components)
logging.debug("Extra components: ", source_components - target_components)
logging.debug("Missing components: " + str(target_components - source_components))
logging.debug("Extra components: " + str(source_components - target_components))
return cls(**{k: v for (k, v) in pipe.components.items() if k in allowed}, **args)

def setup(self, image=None, image_painted=None, mask=None, blur=4,
@@ -597,12 +607,13 @@ def setup(self, image=None, image_painted=None, mask=None, blur=4,
input_image = self._image_painted if self._image_painted is not None else self._original_image

super().setup(fimage=None, image=input_image, scale=scale, **kwargs)

if self._original_image is not None:
self._original_image = self.scale_image(self._original_image, scale)
self._original_image = util.pad_image_to_multiple_of_8(self._original_image)
self._original_image = util.pad_image_to_multiple(self._original_image, self.pad)
if self._image_painted is not None:
self._image_painted = self.scale_image(self._image_painted, scale)
self._image_painted = util.pad_image_to_multiple_of_8(self._image_painted)
self._image_painted = util.pad_image_to_multiple(self._image_painted, self.pad)

# there are two options:
# 1. mask is provided
@@ -619,7 +630,7 @@ def setup(self, image=None, image_painted=None, mask=None, blur=4,
pil_mask = Image.fromarray(mask)
if pil_mask.mode != "L":
pil_mask = pil_mask.convert("L")
pil_mask = util.pad_image_to_multiple_of_8(pil_mask)
pil_mask = util.pad_image_to_multiple(pil_mask, self.pad)
self._mask = pil_mask
self._mask_blur = self.blur_mask(pil_mask, blur)
self._mask_compose = self.blur_mask(pil_mask.crop((0, 0, self._original_size[0], self._original_size[1]))
@@ -644,13 +655,15 @@ def gen(self, inputs):
if 'sample_mode' not in inputs:
inputs['sample_mode'] = self._sample_mode
inputs['original_image'] = normalised
img_gen = super().gen(inputs)

# compose with original using mask
img_compose = self._mask_compose * img_gen + (1 - self._mask_compose) * self._original_image.crop((0, 0, self._original_size[0], self._original_size[1]))
# convert to PIL image
img_compose = Image.fromarray(img_compose.astype(np.uint8))
return img_compose
images = super().gen(inputs)
res = []
for img_gen in images:
# compose with original using mask
img_compose = self._mask_compose * img_gen + (1 - self._mask_compose) * self._original_image.crop((0, 0, self._original_size[0], self._original_size[1]))
# convert to PIL image
img_compose = Image.fromarray(img_compose.astype(np.uint8))
res.append(img_compose)
return res
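
The loop above is a per-pixel alpha blend: the blurred mask weights generated pixels against the original ones. A toy numpy example of the same arithmetic (values are purely illustrative):

import numpy as np

mask = np.array([[0.0, 0.5, 1.0]])            # 0 keeps the original pixel, 1 keeps the generated one
generated = np.array([[10.0, 10.0, 10.0]])
original = np.array([[200.0, 200.0, 200.0]])
blended = mask * generated + (1 - mask) * original
# blended -> [[200., 105., 10.]]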


class Cond2ImPipe(BasePipe):
@@ -863,7 +876,7 @@ def setup(self, fimage, width=None, height=None,
image = Image.open(fimage).convert("RGB") if image is None else image
self._original_size = image.size
self._use_input_size = width is None or height is None
image = util.pad_image_to_multiple_of_8(image)
image = util.pad_image_to_multiple(image, self.pad)
self._condition_image = [image]
self._input_image = [image]
if cscales is None:
@@ -914,10 +927,12 @@ def gen(self, inputs):
inputs = self.prepare_inputs(inputs)
inputs.update({"image": self._input_image,
"control_image": self._condition_image})
image = self.pipe(**inputs).images[0]
result = image.crop((0, 0, self._original_size[0] if self._use_input_size else inputs.get('height'),
res = []
for image in self.pipe(**inputs).images:
result = image.crop((0, 0, self._original_size[0] if self._use_input_size else inputs.get('height'),
self._original_size[1] if self._use_input_size else inputs.get('width') ))
return result
res.append(result)
return res


class CIm2ImPipe(Cond2ImPipe):
@@ -1038,7 +1053,7 @@ def _proc_cimg(self, oriImg):
condition_image += [Image.fromarray(formatted)]
else:
condition_image += [Image.fromarray(oriImg)]
return condition_image
return [c.resize((oriImg.shape[1], oriImg.shape[0])) for c in condition_image]


class InpaintingPipe(MaskedIm2ImPipe):
@@ -1126,7 +1141,7 @@ def gen(self, inputs):
"mask_image": self._mask_image,
"control_image": self._control_image
})
image = self.pipe(**inputs).images[0]
image = self.pipe(**inputs).images
return image

def _make_inpaint_condition(self, image, image_mask):
133 changes: 87 additions & 46 deletions multigen/sessions.py
@@ -7,11 +7,9 @@


class GenSession:

def __init__(self, session_dir, pipe, config: Cfgen, name_prefix=""):
"""
Initialize a GenSession instance.
Args:
session_dir (str):
The directory to store the session files.
@@ -28,6 +26,8 @@ def __init__(self, session_dir, pipe, config: Cfgen, name_prefix=""):
self.confg = config
self.last_conf = None
self.name_prefix = name_prefix
# Check if sequential CPU offloading is enabled
self.offload_gpu_id = getattr(pipe, 'offload_gpu_id', None)

def get_last_conf(self):
conf = {**self.last_conf}
@@ -36,9 +36,9 @@ def get_last_conf(self):
'feedback': '?',
'cversion': "0.0.1"})
return conf

def get_last_file_prefix(self):
idxs = self.name_prefix + str(self.last_index).zfill(5)
def get_file_prefix(self, index):
idxs = self.name_prefix + str(index).zfill(5)
f_prefix = os.path.join(self.session_dir, idxs)
if os.path.isfile(f_prefix + ".txt"):
cnt = 1
@@ -47,17 +47,15 @@ def get_last_file_prefix(self):
f_prefix += "_" + str(cnt)
return f_prefix

def save_last_conf(self):
self.last_cfg_name = self.get_last_file_prefix() + ".txt"
with open(self.last_cfg_name, 'w') as f:
print(json.dumps(self.get_last_conf(), indent=4), file=f)

def gen_sess(self, add_count = 0, save_img=True,
drop_cfg=False, force_collect=False,
callback=None, save_metadata=False):
def save_conf(self, index, conf):
cfg_name = self.get_file_prefix(index) + ".txt"
with open(cfg_name, 'w') as f:
print(json.dumps(conf, indent=4), file=f)

def gen_sess(self, add_count=0, save_img=True, drop_cfg=False,
force_collect=False, callback=None, save_metadata=False):
"""
Run image generation session
Run image generation session.
Args:
add_count (int, *optional*):
The number of additional iterations to add. Defaults to 0.
@@ -71,43 +69,86 @@ def gen_sess(self, add_count = 0, save_img=True,
A callback function to be called after each iteration. Defaults to None.
save_metadata (bool, *optional*):
Whether to save metadata in the image EXIF. Defaults to False.
Returns:
List[Image.Image]: The generated images if `save_img` is False or `force_collect` is True.
"""
self.confg.max_count += add_count
self.confg.start_count = self.confg.count
self.last_img_name = None
self.last_cfg_name = None
images = None
images = []

if save_img:
os.makedirs(self.session_dir, exist_ok=True)
# collecting images to return if requested or images are not saved
if not save_img or force_collect:
images = []
logging.info(f"add count = {add_count}")
jk = 0
for inputs in self.confg:
self.last_index = self.confg.count - 1
self.last_conf = {**inputs}
# TODO: multiple inputs?
inputs['generator'] = torch.Generator().manual_seed(inputs['generator'])
logging.debug("start generation")
image = self.pipe.gen(inputs)
if save_img:
self.last_img_name = self.get_last_file_prefix() + ".png"
exif = None
if save_metadata:
exif = util.create_exif_metadata(image, json.dumps(self.get_last_conf()))
image.save(self.last_img_name, exif=exif)
if not save_img or force_collect:
images += [image]
# saving cfg only if images are saved and dropping is not requested
if save_img and not drop_cfg:
self.save_last_conf()
if callback is not None:
logging.debug("call callback after generation")
callback()
jk += 1
logging.debug(f"done iteration {jk}")
return images

# Determine batch size
if self.offload_gpu_id is not None:
# Sequential CPU offloading is enabled, set batch_size to a reasonable number
batch_size = 8 # You can adjust this value based on your environment
else:
batch_size = 1 # Process one input at a time

logging.info(f"Starting generation with batch_size = {batch_size}")
confg_iter = iter(self.confg)
index = self.confg.start_count

while True:
batch_inputs_list = []
# Collect inputs into batch
for _ in range(batch_size):
try:
inputs = next(confg_iter)
except StopIteration:
break # No more inputs
batch_inputs_list.append(inputs)

if not batch_inputs_list:
break # All inputs have been processed

# Prepare batch inputs
batch_inputs_dict = {}
for key in batch_inputs_list[0]:
batch_inputs_dict[key] = [input[key] for input in batch_inputs_list]

# Adjust 'generator' field with manual seeds
batch_generators = []
for seed in batch_inputs_dict.get('generator', [None] * len(batch_inputs_list)):
if seed is not None:
batch_generators.append(torch.Generator().manual_seed(seed))
else:
batch_generators.append(torch.Generator())
batch_inputs_dict['generator'] = batch_generators

# Generate images
batch_images = self.pipe.gen(batch_inputs_dict)

# Process generated images
for i, image in enumerate(batch_images):
idx = index + i
self.last_index = idx
self.last_conf = {**batch_inputs_list[i % len(batch_inputs_list)]}
self.last_conf.update(self.pipe.get_config())
self.last_conf.update({'feedback': '?', 'cversion': '0.0.1'})

if save_img:
f_prefix = self.get_file_prefix(idx)
img_name = f_prefix + ".png"
exif = None
if save_metadata:
exif = util.create_exif_metadata(image, json.dumps(self.get_last_conf()))
image.save(img_name, exif=exif)
self.last_img_name = img_name
if not drop_cfg:
# Save configuration
self.save_conf(idx, self.get_last_conf())
if not save_img or force_collect:
images.append(image)
if callback is not None:
logging.debug("Call callback after generation")
callback()

index += len(batch_images)
logging.debug(f"Processed batch up to index {index}")

logging.debug(f"Generation session completed.")
return images if images else None
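
For context, a hedged end-to-end sketch of running a batched session. GenSession and gen_sess mirror the signatures in this diff, while the Prompt2ImPipe constructor, setup arguments, Cfgen signature, and import paths are assumptions for illustration only:

from multigen import Cfgen, GenSession                    # import paths are an assumption
from multigen.pipes import Prompt2ImPipe

pipe = Prompt2ImPipe("runwayml/stable-diffusion-v1-5")    # illustrative model id
pipe.setup(width=768, height=768)                         # assumed setup signature
cfg = Cfgen(prompts=["a watercolor fox"], max_count=4)    # assumed Cfgen signature
session = GenSession("./session_out", pipe, cfg)
images = session.gen_sess(save_img=False)                 # returns PIL images when not saving to disk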