Skip to content

Commit

Permalink
sync BiRefNet code and optimize package import method
Browse files Browse the repository at this point in the history
  • Loading branch information
lldacing committed Nov 18, 2024
1 parent 172e873 commit f089a41
Show file tree
Hide file tree
Showing 27 changed files with 172 additions and 226 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ Some models on GitHub:

## Thanks

[BiRefNet](https://github.com/zhengpeng7/birefnet)
[ZhengPeng7/BiRefNet](https://github.com/zhengpeng7/birefnet)

[dimitribarbot/sd-webui-birefnet](https://github.com/dimitribarbot/sd-webui-birefnet)

2 changes: 1 addition & 1 deletion README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ GitHub上的模型:

## 感谢

[BiRefNet](https://github.com/zhengpeng7/birefnet)
[ZhengPeng7/BiRefNet](https://github.com/zhengpeng7/birefnet)

[dimitribarbot/sd-webui-birefnet](https://github.com/dimitribarbot/sd-webui-birefnet)

12 changes: 3 additions & 9 deletions birefnet/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,8 @@ def __init__(self, bb_index: int = 6) -> None:
self.prompt4loc = ['dense', 'sparse'][0]

# Faster-Training settings
self.load_all = False # Turn it on/off by your case. It may consume a lot of CPU memory. And for multi-GPU (N), it would cost N times the CPU memory to load the data.
self.use_fp16 = False # It may cause nan in training.
self.compile = True and (not self.use_fp16) # 1. Trigger CPU memory leak in some extend, which is an inherent problem of PyTorch.
self.load_all = False # Turn it on/off by your case. It may consume a lot of CPU memory. And for multi-GPU (N), it would cost N times the CPU memory to load the data.
self.compile = True # 1. Trigger CPU memory leak in some extend, which is an inherent problem of PyTorch.
# Machines with > 70GB CPU memory can run the whole training on DIS5K with default setting.
# 2. Higher PyTorch version may fix it: https://github.com/pytorch/pytorch/issues/119607.
# 3. But compile in Pytorch > 2.0.1 seems to bring no acceleration for training.
Expand Down Expand Up @@ -102,6 +101,7 @@ def __init__(self, bb_index: int = 6) -> None:
self.freeze_bb = False
self.model = [
'BiRefNet',
'BiRefNetC2F',
][0]

# TRAINING settings - inactive
Expand Down Expand Up @@ -200,10 +200,4 @@ def __init__(self, bb_index: int = 6) -> None:
# self.save_last = int([l.strip() for l in lines if '"{}")'.format(self.task) in l and 'val_last=' in l][0].split('val_last=')[-1].split()[0])
# self.save_step = int([l.strip() for l in lines if '"{}")'.format(self.task) in l and 'step=' in l][0].split('step=')[-1].split()[0])

def print_task(self) -> None:
# Return task for choosing settings in shell scripts.
print(self.task)

# if __name__ == '__main__':
# config = Config()
# config.print_task()
6 changes: 3 additions & 3 deletions birefnet/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from torch.utils import data
from torchvision import transforms

from birefnet.image_proc import preproc
from birefnet.config import Config
from birefnet.utils import path_to_image
from .image_proc import preproc
from .config import Config
from .utils import path_to_image


Image.MAX_IMAGE_PIXELS = None # remove DecompressionBombWarning
Expand Down
12 changes: 5 additions & 7 deletions birefnet/models/backbones/build_backbone.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import torch
import torch.nn as nn
import safetensors.torch
from collections import OrderedDict
from torchvision.models import vgg16, vgg16_bn, VGG16_Weights, VGG16_BN_Weights, resnet50, ResNet50_Weights
from birefnet.models.backbones.pvt_v2 import pvt_v2_b0, pvt_v2_b1, pvt_v2_b2, pvt_v2_b5
from birefnet.models.backbones.swin_v1 import swin_v1_t, swin_v1_s, swin_v1_b, swin_v1_l
from birefnet.config import Config
from ..backbones.pvt_v2 import pvt_v2_b0, pvt_v2_b1, pvt_v2_b2, pvt_v2_b5
from ..backbones.swin_v1 import swin_v1_t, swin_v1_s, swin_v1_b, swin_v1_l
from ...config import Config


config = Config()
Expand All @@ -27,8 +26,7 @@ def build_backbone(bb_name, pretrained=True, params_settings=''):
return bb

def load_weights(model, model_name):
# safetensors.torch.load_file
save_model = torch.load(config.weights[model_name], map_location='cpu')
save_model = torch.load(config.weights[model_name], map_location='cpu', weights_only=True)
model_dict = model.state_dict()
state_dict = {k: v if v.size() == model_dict[k].size() else model_dict[k] for k, v in save_model.items() if k in model_dict.keys()}
# to ignore the weights with mismatched size when I modify the backbone itself.
Expand All @@ -37,7 +35,7 @@ def load_weights(model, model_name):
sub_item = save_model_keys[0] if len(save_model_keys) == 1 else None
state_dict = {k: v if v.size() == model_dict[k].size() else model_dict[k] for k, v in save_model[sub_item].items() if k in model_dict.keys()}
if not state_dict or not sub_item:
print('Weights are not successully loaded. Check the state dict of weights file.')
print('Weights are not successfully loaded. Check the state dict of weights file.')
return None
else:
print('Found correct weights in the "{}" item of loaded state_dict.'.format(sub_item))
Expand Down
24 changes: 11 additions & 13 deletions birefnet/models/backbones/pvt_v2.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import math
from functools import partial
import torch
import torch.nn as nn
from functools import partial

from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model

import math
try:
# version > 0.6.13
from timm.layers import DropPath, to_2tuple, trunc_normal_
except Exception:
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

from birefnet.config import Config
from ...config import Config

config = Config()

Expand Down Expand Up @@ -383,7 +385,6 @@ def _conv_filter(state_dict, patch_size=16):
return out_dict


## @register_model
class pvt_v2_b0(PyramidVisionTransformerImpr):
def __init__(self, **kwargs):
super(pvt_v2_b0, self).__init__(
Expand All @@ -392,32 +393,30 @@ def __init__(self, **kwargs):
drop_rate=0.0, drop_path_rate=0.1)



## @register_model
class pvt_v2_b1(PyramidVisionTransformerImpr):
def __init__(self, **kwargs):
super(pvt_v2_b1, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)

## @register_model

class pvt_v2_b2(PyramidVisionTransformerImpr):
def __init__(self, in_channels=3, **kwargs):
super(pvt_v2_b2, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1, in_channels=in_channels)

## @register_model

class pvt_v2_b3(PyramidVisionTransformerImpr):
def __init__(self, **kwargs):
super(pvt_v2_b3, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)

## @register_model

class pvt_v2_b4(PyramidVisionTransformerImpr):
def __init__(self, **kwargs):
super(pvt_v2_b4, self).__init__(
Expand All @@ -426,7 +425,6 @@ def __init__(self, **kwargs):
drop_rate=0.0, drop_path_rate=0.1)


## @register_model
class pvt_v2_b5(PyramidVisionTransformerImpr):
def __init__(self, **kwargs):
super(pvt_v2_b5, self).__init__(
Expand Down
6 changes: 5 additions & 1 deletion birefnet/models/backbones/swin_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
try:
# version > 0.6.13
from timm.layers import DropPath, to_2tuple, trunc_normal_
except Exception:
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

from birefnet.config import Config

Expand Down
113 changes: 86 additions & 27 deletions birefnet/models/birefnet.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,31 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from kornia.filters import laplacian

from birefnet.config import Config
from birefnet.dataset import class_labels_TR_sorted
from birefnet.models.backbones.build_backbone import build_backbone
from birefnet.models.modules.decoder_blocks import BasicDecBlk, ResBlk
from birefnet.models.modules.lateral_blocks import BasicLatBlk
from birefnet.models.modules.aspp import ASPP, ASPPDeformable
from birefnet.models.refinement.refiner import Refiner, RefinerPVTInChannels4, RefUNet
from birefnet.models.refinement.stem_layer import StemLayer

from huggingface_hub import PyTorchModelHubMixin

from ..config import Config
from ..dataset import class_labels_TR_sorted
from .backbones.build_backbone import build_backbone
from .modules.decoder_blocks import BasicDecBlk, ResBlk
from .modules.lateral_blocks import BasicLatBlk
from .modules.aspp import ASPP, ASPPDeformable
from .refinement.refiner import Refiner, RefinerPVTInChannels4, RefUNet
from .refinement.stem_layer import StemLayer


def image2patches(image, grid_h=2, grid_w=2, patch_ref=None, transformation='b c (hg h) (wg w) -> (b hg wg) c h w'):
if patch_ref is not None:
grid_h, grid_w = image.shape[-2] // patch_ref.shape[-2], image.shape[-1] // patch_ref.shape[-1]
patches = rearrange(image, transformation, hg=grid_h, wg=grid_w)
return patches

def patches2image(patches, grid_h=2, grid_w=2, patch_ref=None, transformation='(b hg wg) c h w -> b c (hg h) (wg w)'):
if patch_ref is not None:
grid_h, grid_w = patch_ref.shape[-2] // patches[0].shape[-2], patch_ref.shape[-1] // patches[0].shape[-1]
image = rearrange(patches, transformation, hg=grid_h, wg=grid_w)
return image

class BiRefNet(nn.Module):
def __init__(self, bb_pretrained=True, bb_index=6):
Expand Down Expand Up @@ -159,18 +173,6 @@ def __init__(self, channels):
self.gdt_convs_attn_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
self.gdt_convs_attn_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))

def get_patches_batch(self, x, p):
_size_h, _size_w = p.shape[2:]
patches_batch = []
for idx in range(x.shape[0]):
columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1)
patches_x = []
for column_x in columns_x:
patches_x += [p.unsqueeze(0) for p in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)]
patch_sample = torch.cat(patches_x, dim=1)
patches_batch.append(patch_sample)
return torch.cat(patches_batch, dim=0)

def forward(self, features):
if self.training and self.config.out_ref:
outs_gdt_pred = []
Expand All @@ -181,7 +183,7 @@ def forward(self, features):
outs = []

if self.config.dec_ipt:
patches_batch = self.get_patches_batch(x, x4) if self.split else x
patches_batch = image2patches(x, patch_ref=x4, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)
p4 = self.decoder_block4(x4)
m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision and self.training else None
Expand All @@ -202,7 +204,7 @@ def forward(self, features):
_p3 = _p4 + self.lateral_block4(x3)

if self.config.dec_ipt:
patches_batch = self.get_patches_batch(x, _p3) if self.split else x
patches_batch = image2patches(x, patch_ref=_p3, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
_p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1)
p3 = self.decoder_block3(_p3)
m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision and self.training else None
Expand All @@ -228,7 +230,7 @@ def forward(self, features):
_p2 = _p3 + self.lateral_block3(x2)

if self.config.dec_ipt:
patches_batch = self.get_patches_batch(x, _p2) if self.split else x
patches_batch = image2patches(x, patch_ref=_p2, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
_p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1)
p2 = self.decoder_block2(_p2)
m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision and self.training else None
Expand All @@ -249,13 +251,13 @@ def forward(self, features):
_p1 = _p2 + self.lateral_block2(x1)

if self.config.dec_ipt:
patches_batch = self.get_patches_batch(x, _p1) if self.split else x
patches_batch = image2patches(x, patch_ref=_p1, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
_p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1)
_p1 = self.decoder_block1(_p1)
_p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)

if self.config.dec_ipt:
patches_batch = self.get_patches_batch(x, _p1) if self.split else x
patches_batch = image2patches(x, patch_ref=_p1, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
_p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1)
p1_out = self.conv_out1(_p1)

Expand All @@ -277,3 +279,60 @@ def __init__(

def forward(self, x):
return self.conv_out(self.conv1(x))


###########


class BiRefNetC2F(
nn.Module,
PyTorchModelHubMixin,
library_name="birefnet_c2f",
repo_url="https://github.com/ZhengPeng7/BiRefNet_C2F",
tags=['Image Segmentation', 'Background Removal', 'Mask Generation', 'Dichotomous Image Segmentation', 'Camouflaged Object Detection', 'Salient Object Detection']
):
def __init__(self, bb_pretrained=True):
super(BiRefNetC2F, self).__init__()
self.config = Config()
self.epoch = 1
self.grid = 4
self.model_coarse = BiRefNet(bb_pretrained=True)
self.model_fine = BiRefNet(bb_pretrained=True)
self.input_mixer = nn.Conv2d(4, 3, 1, 1, 0)
self.output_mixer_merge_post = nn.Sequential(nn.Conv2d(1, 16, 3, 1, 1), nn.Conv2d(16, 1, 3, 1, 1))

def forward(self, x):
x_ori = x.clone()
########## Coarse ##########
x = F.interpolate(x, size=[s//self.grid for s in self.config.size[::-1]], mode='bilinear', align_corners=True)

if self.training:
scaled_preds, class_preds_lst = self.model_coarse(x)
else:
scaled_preds = self.model_coarse(x)
########## Fine ##########
x_HR_patches = image2patches(x_ori, patch_ref=x, transformation='b c (hg h) (wg w) -> (b hg wg) c h w')
pred = F.interpolate(scaled_preds[-1] if not (self.config.out_ref and self.training) else scaled_preds[1][-1], size=x_ori.shape[2:], mode='bilinear', align_corners=True)
pred_patches = image2patches(pred, patch_ref=x, transformation='b c (hg h) (wg w) -> (b hg wg) c h w')
t = torch.cat([x_HR_patches, pred_patches], dim=1)
x_HR = self.input_mixer(t)

pred_patches = image2patches(pred, patch_ref=x_HR, transformation='b c (hg h) (wg w) -> b (c hg wg) h w')
if self.training:
scaled_preds_HR, class_preds_lst_HR = self.model_fine(x_HR)
else:
scaled_preds_HR = self.model_fine(x_HR)
if self.training:
if self.config.out_ref:
[outs_gdt_pred, outs_gdt_label], outs = scaled_preds
[outs_gdt_pred_HR, outs_gdt_label_HR], outs_HR = scaled_preds_HR
for idx_out, out_HR in enumerate(outs_HR):
outs_HR[idx_out] = self.output_mixer_merge_post(patches2image(out_HR, grid_h=self.grid, grid_w=self.grid, transformation='(b hg wg) c h w -> b c (hg h) (wg w)'))
return [([outs_gdt_pred + outs_gdt_pred_HR, outs_gdt_label + outs_gdt_label_HR], outs + outs_HR), class_preds_lst] # handle gt here
else:
return [
scaled_preds + [self.output_mixer_merge_post(patches2image(scaled_pred_HR, grid_h=self.grid, grid_w=self.grid, transformation='(b hg wg) c h w -> b c (hg h) (wg w)')) for scaled_pred_HR in scaled_preds_HR],
class_preds_lst
]
else:
return scaled_preds + [self.output_mixer_merge_post(patches2image(scaled_pred_HR, grid_h=self.grid, grid_w=self.grid, transformation='(b hg wg) c h w -> b c (hg h) (wg w)')) for scaled_pred_HR in scaled_preds_HR]
4 changes: 2 additions & 2 deletions birefnet/models/modules/aspp.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from birefnet.models.modules.deform_conv import DeformableConv2d
from birefnet.config import Config
from ..modules.deform_conv import DeformableConv2d
from ...config import Config


config = Config()
Expand Down
4 changes: 2 additions & 2 deletions birefnet/models/modules/decoder_blocks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
from birefnet.models.modules.aspp import ASPP, ASPPDeformable
from birefnet.config import Config
from ..modules.aspp import ASPP, ASPPDeformable
from ...config import Config


config = Config()
Expand Down
2 changes: 1 addition & 1 deletion birefnet/models/modules/lateral_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch.nn.functional as F
from functools import partial

from birefnet.config import Config
from ...config import Config


config = Config()
Expand Down
Loading

0 comments on commit f089a41

Please sign in to comment.