diff --git a/birefnetNode.py b/birefnetNode.py index a07b92f..87a03fb 100644 --- a/birefnetNode.py +++ b/birefnetNode.py @@ -164,7 +164,7 @@ def load_model(self, model, device, use_weight=False): return [(biRefNet_model, version)] -class RembgByBiRefNetAdvanced: +class GetMaskByBiRefNet: @classmethod def INPUT_TYPES(cls): @@ -191,26 +191,22 @@ def INPUT_TYPES(cls): "default": "bilinear", "tooltip": "Interpolation method for post-processing mask" }), - "blur_size": ("INT", {"default": 91, "min": 1, "max": 255, "step": 2, }), - "blur_size_two": ("INT", {"default": 7, "min": 1, "max": 255, "step": 2, }), - "fill_color": ("BOOLEAN", {"default": False}), - "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}), - "mask_threshold": ("FLOAT", {"default": 0.000, "min": 0.0, "max": 1.0, "step": 0.001, }), + "mask_threshold": ("FLOAT", {"default": 0.000, "min": 0.0, "max": 1.0, "step": 0.004, }), } } - RETURN_TYPES = ("IMAGE", "MASK",) - RETURN_NAMES = ("image", "mask",) - FUNCTION = "rem_bg" + RETURN_TYPES = ("MASK",) + RETURN_NAMES = ("mask",) + FUNCTION = "get_mask" CATEGORY = "rembg/BiRefNet" - def rem_bg(self, model, images, upscale_method='bilinear', width=1024, height=1024, blur_size=91, blur_size_two=7, fill_color=False, color=None, mask_threshold=0.000): + def get_mask(self, model, images, width=1024, height=1024, upscale_method='bilinear', mask_threshold=0.000): model, version = model model_device_type = next(model.parameters()).device.type b, h, w, c = images.shape image_bchw = images.permute(0, 3, 1, 2) - image_preproc = ImagePreprocessor(resolution=(1024, 1024)) + image_preproc = ImagePreprocessor(resolution=(height, width)) if VERSION[0] == version: im_tensor = image_preproc.old_proc(image_bchw) else: @@ -226,12 +222,48 @@ def rem_bg(self, model, images, upscale_method='bilinear', width=1024, height=10 mask_bchw = torch.cat(_mask_bchw, dim=0) del _mask_bchw # 遮罩大小需还原为与原图一致 - mask = comfy.utils.common_upscale(mask_bchw, w, h, 'bilinear', "disabled") + mask = comfy.utils.common_upscale(mask_bchw, w, h, upscale_method, "disabled") # (b, 1, h, w) if mask_threshold > 0: - out_masks = filter_mask(mask, threshold=mask_threshold) - else: - out_masks = normalize_mask(mask) + mask = filter_mask(mask, threshold=mask_threshold) + # else: + # 似乎几乎无影响 + # mask = normalize_mask(mask) + + return mask.squeeze(1), + + +class BlurFusionForegroundEstimation: + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "images": ("IMAGE",), + "masks": ("MASK",), + "blur_size": ("INT", {"default": 91, "min": 1, "max": 255, "step": 2, }), + "blur_size_two": ("INT", {"default": 7, "min": 1, "max": 255, "step": 2, }), + "fill_color": ("BOOLEAN", {"default": False}), + "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}), + } + } + + RETURN_TYPES = ("IMAGE", "MASK",) + RETURN_NAMES = ("image", "mask",) + FUNCTION = "get_foreground" + CATEGORY = "rembg/BiRefNet" + DESCRIPTION = "Approximate Fast Foreground Colour Estimation. https://github.com/Photoroom/fast-foreground-estimation" + + def get_foreground(self, images, masks, blur_size=91, blur_size_two=7, fill_color=False, color=None): + b, h, w, c = images.shape + if b != masks.shape[0]: + raise ValueError("images and masks must have the same batch size") + + image_bchw = images.permute(0, 3, 1, 2) + + if masks.dim() == 3: + # (b, h, w) => (b, 1, h, w) + out_masks = masks.unsqueeze(1) # (b, c, h, w) _image_masked = refine_foreground(image_bchw, out_masks, r1=blur_size, r2=blur_size_two) @@ -260,6 +292,55 @@ def rem_bg(self, model, images, upscale_method='bilinear', width=1024, height=10 return out_images, out_masks +class RembgByBiRefNetAdvanced(GetMaskByBiRefNet, BlurFusionForegroundEstimation): + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "model": ("BiRefNetMODEL",), + "images": ("IMAGE",), + "width": ("INT", + { + "default": 1024, + "min": 0, + "max": 16384, + "tooltip": "The width of the preprocessed image, does not affect the final output image size" + }), + "height": ("INT", + { + "default": 1024, + "min": 0, + "max": 16384, + "tooltip": "The height of the preprocessed image, does not affect the final output image size" + }), + "upscale_method": (["bislerp", "nearest-exact", "bilinear", "area", "bicubic"], + { + "default": "bilinear", + "tooltip": "Interpolation method for post-processing mask" + }), + "blur_size": ("INT", {"default": 91, "min": 1, "max": 255, "step": 2, }), + "blur_size_two": ("INT", {"default": 7, "min": 1, "max": 255, "step": 2, }), + "fill_color": ("BOOLEAN", {"default": False}), + "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}), + "mask_threshold": ("FLOAT", {"default": 0.000, "min": 0.0, "max": 1.0, "step": 0.001, }), + } + } + + RETURN_TYPES = ("IMAGE", "MASK",) + RETURN_NAMES = ("image", "mask",) + FUNCTION = "rem_bg" + CATEGORY = "rembg/BiRefNet" + + def rem_bg(self, model, images, upscale_method='bilinear', width=1024, height=1024, blur_size=91, blur_size_two=7, fill_color=False, color=None, mask_threshold=0.000): + + masks = super().get_mask(model, images, width, height, upscale_method, mask_threshold) + + out_images, out_masks = super().get_foreground(images, masks=masks[0], blur_size=blur_size, blur_size_two=blur_size_two, fill_color=fill_color, color=color) + + return out_images, out_masks + + class RembgByBiRefNet(RembgByBiRefNetAdvanced): @classmethod @@ -285,6 +366,8 @@ def rem_bg(self, model, images): "LoadRembgByBiRefNetModel": LoadRembgByBiRefNetModel, "RembgByBiRefNet": RembgByBiRefNet, "RembgByBiRefNetAdvanced": RembgByBiRefNetAdvanced, + "GetMaskByBiRefNet": GetMaskByBiRefNet, + "BlurFusionForegroundEstimation": BlurFusionForegroundEstimation, } NODE_DISPLAY_NAME_MAPPINGS = { @@ -292,4 +375,6 @@ def rem_bg(self, model, images): "LoadRembgByBiRefNetModel": "LoadRembgByBiRefNetModel", "RembgByBiRefNet": "RembgByBiRefNet", "RembgByBiRefNetAdvanced": "RembgByBiRefNetAdvanced", + "GetMaskByBiRefNet": "GetMaskByBiRefNet", + "BlurFusionForegroundEstimation": "BlurFusionForegroundEstimation", } diff --git a/doc/base.png b/doc/base.png index 87ad92a..647046d 100644 Binary files a/doc/base.png and b/doc/base.png differ diff --git a/example/workflow_base.png b/example/workflow_base.png index 2794785..3fba62f 100644 Binary files a/example/workflow_base.png and b/example/workflow_base.png differ diff --git a/pyproject.toml b/pyproject.toml index 837f232..ea4f987 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "comfyui_birefnet_ll" -description = "Sync with version of BiRefNet. NODES:AutoDownloadBiRefNetModel, LoadRembgByBiRefNetModel, RembgByBiRefNet." -version = "1.0.7" +description = "Sync with version of BiRefNet. NODES:AutoDownloadBiRefNetModel, LoadRembgByBiRefNetModel, RembgByBiRefNet, RembgByBiRefNetAdvanced, GetMaskByBiRefNet, BlurFusionForegroundEstimation." +version = "1.0.8" license = {file = "LICENSE"} dependencies = ["numpy", "opencv-python", "timm"]