Skip to content

Commit

Permalink
add GetMaskByBiRefNet and BlurFusionForegroundEstimation
Browse files Browse the repository at this point in the history
  • Loading branch information
lldacing committed Jan 4, 2025
1 parent 9b7262e commit 7570a99
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 17 deletions.
115 changes: 100 additions & 15 deletions birefnetNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def load_model(self, model, device, use_weight=False):
return [(biRefNet_model, version)]


class RembgByBiRefNetAdvanced:
class GetMaskByBiRefNet:

@classmethod
def INPUT_TYPES(cls):
Expand All @@ -191,26 +191,22 @@ def INPUT_TYPES(cls):
"default": "bilinear",
"tooltip": "Interpolation method for post-processing mask"
}),
"blur_size": ("INT", {"default": 91, "min": 1, "max": 255, "step": 2, }),
"blur_size_two": ("INT", {"default": 7, "min": 1, "max": 255, "step": 2, }),
"fill_color": ("BOOLEAN", {"default": False}),
"color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
"mask_threshold": ("FLOAT", {"default": 0.000, "min": 0.0, "max": 1.0, "step": 0.001, }),
"mask_threshold": ("FLOAT", {"default": 0.000, "min": 0.0, "max": 1.0, "step": 0.004, }),
}
}

RETURN_TYPES = ("IMAGE", "MASK",)
RETURN_NAMES = ("image", "mask",)
FUNCTION = "rem_bg"
RETURN_TYPES = ("MASK",)
RETURN_NAMES = ("mask",)
FUNCTION = "get_mask"
CATEGORY = "rembg/BiRefNet"

def rem_bg(self, model, images, upscale_method='bilinear', width=1024, height=1024, blur_size=91, blur_size_two=7, fill_color=False, color=None, mask_threshold=0.000):
def get_mask(self, model, images, width=1024, height=1024, upscale_method='bilinear', mask_threshold=0.000):
model, version = model
model_device_type = next(model.parameters()).device.type
b, h, w, c = images.shape
image_bchw = images.permute(0, 3, 1, 2)

image_preproc = ImagePreprocessor(resolution=(1024, 1024))
image_preproc = ImagePreprocessor(resolution=(height, width))
if VERSION[0] == version:
im_tensor = image_preproc.old_proc(image_bchw)
else:
Expand All @@ -226,12 +222,48 @@ def rem_bg(self, model, images, upscale_method='bilinear', width=1024, height=10
mask_bchw = torch.cat(_mask_bchw, dim=0)
del _mask_bchw
# 遮罩大小需还原为与原图一致
mask = comfy.utils.common_upscale(mask_bchw, w, h, 'bilinear', "disabled")
mask = comfy.utils.common_upscale(mask_bchw, w, h, upscale_method, "disabled")
# (b, 1, h, w)
if mask_threshold > 0:
out_masks = filter_mask(mask, threshold=mask_threshold)
else:
out_masks = normalize_mask(mask)
mask = filter_mask(mask, threshold=mask_threshold)
# else:
# 似乎几乎无影响
# mask = normalize_mask(mask)

return mask.squeeze(1),


class BlurFusionForegroundEstimation:

@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
"masks": ("MASK",),
"blur_size": ("INT", {"default": 91, "min": 1, "max": 255, "step": 2, }),
"blur_size_two": ("INT", {"default": 7, "min": 1, "max": 255, "step": 2, }),
"fill_color": ("BOOLEAN", {"default": False}),
"color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
}
}

RETURN_TYPES = ("IMAGE", "MASK",)
RETURN_NAMES = ("image", "mask",)
FUNCTION = "get_foreground"
CATEGORY = "rembg/BiRefNet"
DESCRIPTION = "Approximate Fast Foreground Colour Estimation. https://github.com/Photoroom/fast-foreground-estimation"

def get_foreground(self, images, masks, blur_size=91, blur_size_two=7, fill_color=False, color=None):
b, h, w, c = images.shape
if b != masks.shape[0]:
raise ValueError("images and masks must have the same batch size")

image_bchw = images.permute(0, 3, 1, 2)

if masks.dim() == 3:
# (b, h, w) => (b, 1, h, w)
out_masks = masks.unsqueeze(1)

# (b, c, h, w)
_image_masked = refine_foreground(image_bchw, out_masks, r1=blur_size, r2=blur_size_two)
Expand Down Expand Up @@ -260,6 +292,55 @@ def rem_bg(self, model, images, upscale_method='bilinear', width=1024, height=10
return out_images, out_masks


class RembgByBiRefNetAdvanced(GetMaskByBiRefNet, BlurFusionForegroundEstimation):

@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("BiRefNetMODEL",),
"images": ("IMAGE",),
"width": ("INT",
{
"default": 1024,
"min": 0,
"max": 16384,
"tooltip": "The width of the preprocessed image, does not affect the final output image size"
}),
"height": ("INT",
{
"default": 1024,
"min": 0,
"max": 16384,
"tooltip": "The height of the preprocessed image, does not affect the final output image size"
}),
"upscale_method": (["bislerp", "nearest-exact", "bilinear", "area", "bicubic"],
{
"default": "bilinear",
"tooltip": "Interpolation method for post-processing mask"
}),
"blur_size": ("INT", {"default": 91, "min": 1, "max": 255, "step": 2, }),
"blur_size_two": ("INT", {"default": 7, "min": 1, "max": 255, "step": 2, }),
"fill_color": ("BOOLEAN", {"default": False}),
"color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
"mask_threshold": ("FLOAT", {"default": 0.000, "min": 0.0, "max": 1.0, "step": 0.001, }),
}
}

RETURN_TYPES = ("IMAGE", "MASK",)
RETURN_NAMES = ("image", "mask",)
FUNCTION = "rem_bg"
CATEGORY = "rembg/BiRefNet"

def rem_bg(self, model, images, upscale_method='bilinear', width=1024, height=1024, blur_size=91, blur_size_two=7, fill_color=False, color=None, mask_threshold=0.000):

masks = super().get_mask(model, images, width, height, upscale_method, mask_threshold)

out_images, out_masks = super().get_foreground(images, masks=masks[0], blur_size=blur_size, blur_size_two=blur_size_two, fill_color=fill_color, color=color)

return out_images, out_masks


class RembgByBiRefNet(RembgByBiRefNetAdvanced):

@classmethod
Expand All @@ -285,11 +366,15 @@ def rem_bg(self, model, images):
"LoadRembgByBiRefNetModel": LoadRembgByBiRefNetModel,
"RembgByBiRefNet": RembgByBiRefNet,
"RembgByBiRefNetAdvanced": RembgByBiRefNetAdvanced,
"GetMaskByBiRefNet": GetMaskByBiRefNet,
"BlurFusionForegroundEstimation": BlurFusionForegroundEstimation,
}

NODE_DISPLAY_NAME_MAPPINGS = {
"AutoDownloadBiRefNetModel": "AutoDownloadBiRefNetModel",
"LoadRembgByBiRefNetModel": "LoadRembgByBiRefNetModel",
"RembgByBiRefNet": "RembgByBiRefNet",
"RembgByBiRefNetAdvanced": "RembgByBiRefNetAdvanced",
"GetMaskByBiRefNet": "GetMaskByBiRefNet",
"BlurFusionForegroundEstimation": "BlurFusionForegroundEstimation",
}
Binary file modified doc/base.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified example/workflow_base.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "comfyui_birefnet_ll"
description = "Sync with version of BiRefNet. NODES:AutoDownloadBiRefNetModel, LoadRembgByBiRefNetModel, RembgByBiRefNet."
version = "1.0.7"
description = "Sync with version of BiRefNet. NODES:AutoDownloadBiRefNetModel, LoadRembgByBiRefNetModel, RembgByBiRefNet, RembgByBiRefNetAdvanced, GetMaskByBiRefNet, BlurFusionForegroundEstimation."
version = "1.0.8"
license = {file = "LICENSE"}
dependencies = ["numpy", "opencv-python", "timm"]

Expand Down

0 comments on commit 7570a99

Please sign in to comment.