Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save generation images (step, solo) #717

Closed
wants to merge 8 commits into from
7 changes: 7 additions & 0 deletions aaaaaa/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,3 +710,10 @@ def controlnet(w: Widgets, n: int, is_img2img: bool):
interactive=controlnet_exists,
elem_id=eid("ad_controlnet_guidance_end"),
)

with gr.Column(variant="compact"):
w.ad_solo_generation = gr.Checkbox(
label="Solo generation" + suffix(n),
value=False,
elem_id=eid("ad_solo_generation"),
)
2 changes: 2 additions & 0 deletions adetailer/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
ad_controlnet_weight: confloat(ge=0.0, le=1.0) = 1.0
ad_controlnet_guidance_start: confloat(ge=0.0, le=1.0) = 0.0
ad_controlnet_guidance_end: confloat(ge=0.0, le=1.0) = 1.0
ad_solo_generation = False
is_api: bool = True

@validator("is_api", pre=True)
Expand Down Expand Up @@ -252,6 +253,7 @@ def need_skip(self) -> bool:
("ad_controlnet_weight", "ADetailer ControlNet weight"),
("ad_controlnet_guidance_start", "ADetailer ControlNet guidance start"),
("ad_controlnet_guidance_end", "ADetailer ControlNet guidance end"),
("ad_solo_generation", "ADetailer solo generation"),
]

_args = [Arg(*args) for args in _all_args]
Expand Down
96 changes: 88 additions & 8 deletions scripts/!adetailer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
ADetailerArgs,
InpaintBBoxMatchMode,
SkipImg2ImgOrig,
_all_args,
)
from adetailer.common import PredictOutput, ensure_pil_image, safe_mkdir
from adetailer.mask import (
Expand Down Expand Up @@ -563,7 +564,7 @@ def save_image(self, p, image, *, condition: str, suffix: str) -> None:
save_prompt = p.prompt
seed, _ = self.get_seed(p)

if opts.data.get(condition, False):
if condition is True or opts.data.get(condition, False):
ad_save_images_dir: str = opts.data.get("ad_save_images_dir", "")

if not ad_save_images_dir.strip():
Expand Down Expand Up @@ -885,6 +886,22 @@ def _postprocess_image_inner(

return False

def _postprocess_fork(self, init_image, p, pp: PPImage, args, n):
processed = False
save_incremental = None
save_solo = None
if not args.need_skip():
if not args.ad_solo_generation:
processed = self._postprocess_image_inner(p, pp, args, n=n[0])
if n[0] < n[1]:
save_incremental = (n[0], copy(pp.image))
elif not args.need_skip() and args.ad_solo_generation:
pp_solo = copy(pp)
pp_solo.image = init_image
if self._postprocess_image_inner(p, pp_solo, args, n=n[0]):
save_solo = (n[0], pp_solo.image, args)
return (processed, save_incremental, save_solo)

@rich_traceback
def postprocess_image(self, p, pp: PPImage, *args_):
if getattr(p, "_ad_disabled", False) or not self.is_ad_enabled(*args_):
Expand All @@ -901,17 +918,50 @@ def postprocess_image(self, p, pp: PPImage, *args_):
with preserve_prompts(p):
p.scripts.postprocess(copy(p), dummy)

last_index = self._find_last_index(arg_list)

save_incrementals = []
save_solos = []

is_processed = False
with CNHijackRestore(), pause_total_tqdm(), cn_allow_script_control():
for n, args in enumerate(arg_list):
if args.need_skip():
continue
is_processed |= self._postprocess_image_inner(p, pp, args, n=n)
fork_result = self._postprocess_fork(
init_image, p, pp, args, (n, last_index)
)
is_processed |= fork_result[0]
save_incrementals.append(fork_result[1])
save_solos.append(fork_result[2])

if is_processed:
if not is_skip_img2img(p):
self.save_image(
p,
init_image,
condition="ad_save_images_before",
suffix="-ad-before",
)

if is_processed and not is_skip_img2img(p):
self.save_image(
p, init_image, condition="ad_save_images_before", suffix="-ad-before"
)
for save in filter(None, save_incrementals):
self.save_image(
p,
save[1],
condition="ad_save_step_images",
suffix=f"-ad-step-{save[0]+1}",
)

all_extra_params = p.extra_generation_params
for save in filter(None, save_solos):
p.extra_generation_params = self._fix_extra_generation_params(
all_extra_params, save[2]
)
self.save_image(
p,
save[1],
condition=True,
suffix=f"-ad-solo-{save[0]+1}",
)
p.extra_generation_params = all_extra_params

if need_call_process(p):
with preserve_prompts(p):
Expand All @@ -921,6 +971,31 @@ def postprocess_image(self, p, pp: PPImage, *args_):

self.write_params_txt(params_txt_content)

def _find_last_index(self, arg_list):
last_index = 0
for n, args in enumerate(arg_list):
if not args.need_skip() and args.ad_solo_generation is False:
last_index = n
return last_index

def _fix_extra_generation_params(self, params: dict, args: ADetailerArgs):
ad_params = {}
for params_k in list(params.keys()):
found = False
for _, (_, v) in enumerate(_all_args):
if v in params_k:
found = True
break
if not found:
ad_params[params_k] = params[params_k]

for _, (k, v) in enumerate(_all_args):
if hasattr(args, k):
args_v = getattr(args, k)
if args_v is not None and args_v != "":
ad_params[v] = args_v
return ad_params


def on_after_component(component, **_kwargs):
global txt2img_submit_button, img2img_submit_button
Expand Down Expand Up @@ -977,6 +1052,11 @@ def on_ui_settings():
shared.OptionInfo(False, "Save images before ADetailer", section=section),
)

shared.opts.add_option(
"ad_save_step_images",
shared.OptionInfo(False, "Save incremental step images", section=section),
)

shared.opts.add_option(
"ad_only_selected_scripts",
shared.OptionInfo(
Expand Down
Loading