Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
nadermx committed Apr 12, 2024
2 parents f63ef14 + 7d1aaf0 commit c590858
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 44 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.idea/
# Created by https://www.toptal.com/developers/gitignore/api/python
# Edit at https://www.toptal.com/developers/gitignore?templates=python

Expand Down
21 changes: 21 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,27 @@ change the model for different background removal methods between `u2netp`, `u2n
backgroundremover -i "/path/to/video.mp4" -m "u2net_human_seg" -fl 150 -tv -o "output.mov"
```

## As a library
### Remove background image

```
from backgroundremover.bg import remove
def remove_bg(src_img_path, out_img_path):
model_choices = ["u2net", "u2net_human_seg", "u2netp"]
f = open(src_img_path, "rb")
data = f.read()
img = remove(data, model_name=model_choices[0],
alpha_matting=True,
alpha_matting_foreground_threshold=240,
alpha_matting_background_threshold=10,
alpha_matting_erode_structure_size=10,
alpha_matting_base_size=1000)
f.close()
f = open(out_img_path, "wb")
f.write(img)
f.close()
```

## Todo

- convert logic from video to image to utilize more GPU on image removal
Expand Down
8 changes: 4 additions & 4 deletions backgroundremover/bg.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import torch.nn.functional
from hsh.library.hash import Hasher
from .u2net import detect, u2net
from . import utilities
from . import github

# closes https://github.com/nadermx/backgroundremover/issues/18
# closes https://github.com/nadermx/backgroundremover/issues/112
Expand Down Expand Up @@ -56,7 +56,7 @@ def __init__(self, model_name):
if (
not os.path.exists(path)
):
utilities.download_files_from_github(
github.download_files_from_github(
path, model_name
)

Expand All @@ -70,7 +70,7 @@ def __init__(self, model_name):
not os.path.exists(path)
#or hasher.md5(path) != "09fb4e49b7f785c9f855baf94916840a"
):
utilities.download_files_from_github(
github.download_files_from_github(
path, model_name
)

Expand All @@ -84,7 +84,7 @@ def __init__(self, model_name):
not os.path.exists(path)
#or hasher.md5(path) != "347c3d51b01528e5c6c071e3cff1cb55"
):
utilities.download_files_from_github(
github.download_files_from_github(
path, model_name
)
else:
Expand Down
38 changes: 38 additions & 0 deletions backgroundremover/github.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import os
import requests


def download_files_from_github(path, model_name):
if model_name not in ["u2net", "u2net_human_seg", "u2netp"]:
print("Invalid model name, please use 'u2net' or 'u2net_human_seg' or 'u2netp'")
return
print(f"downloading model [{model_name}] to {path} ...")
urls = []
if model_name == "u2net":
urls = ['https://github.com/nadermx/backgroundremover/raw/main/models/u2aa',
'https://github.com/nadermx/backgroundremover/raw/main/models/u2ab',
'https://github.com/nadermx/backgroundremover/raw/main/models/u2ac',
'https://github.com/nadermx/backgroundremover/raw/main/models/u2ad']
elif model_name == "u2net_human_seg":
urls = ['https://github.com/nadermx/backgroundremover/raw/main/models/u2haa',
'https://github.com/nadermx/backgroundremover/raw/main/models/u2hab',
'https://github.com/nadermx/backgroundremover/raw/main/models/u2hac',
'https://github.com/nadermx/backgroundremover/raw/main/models/u2had']
elif model_name == 'u2netp':
urls = ['https://github.com/nadermx/backgroundremover/raw/main/models/u2netp.pth']
try:
os.makedirs(os.path.expanduser("~/.u2net"), exist_ok=True)
except Exception as e:
print(f"Error creating directory: {e}")
return

try:

with open(path, 'wb') as out_file:
for i, url in enumerate(urls):
print(f'downloading part {i+1} of {model_name}')
part_content = requests.get(url)
out_file.write(part_content.content)
print(f'finished downloading part {i+1} of {model_name}')
except Exception as e:
print(e)
2 changes: 1 addition & 1 deletion backgroundremover/u2net/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def __call__(self, sample):
# change the r,g,b to b,r,g from [0,255] to [0,1]
# transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
tmpImg = tmpImg.transpose((2, 0, 1))
tmpLbl = label.transpose((2, 0, 1))
tmpLbl = tmpLbl.transpose((2, 0, 1))

return {
"imidx": torch.from_numpy(imidx),
Expand Down
12 changes: 8 additions & 4 deletions backgroundremover/u2net/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from torchvision import transforms

from . import data_loader, u2net
from .. import utilities
from .. import github


def load_model(model_name: str = "u2net"):
hasher = Hasher()
Expand Down Expand Up @@ -38,7 +39,7 @@ def load_model(model_name: str = "u2net"):
not os.path.exists(path)
#or hasher.md5(path) != "e4f636406ca4e2af789941e7f139ee2e"
):
utilities.download_files_from_github(
github.download_files_from_github(
path, model_name
)

Expand All @@ -48,11 +49,14 @@ def load_model(model_name: str = "u2net"):
"U2NET_PATH",
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
)

print(f"DEBUG: path to be checked: {path}")

if (
not os.path.exists(path)
#or hasher.md5(path) != "09fb4e49b7f785c9f855baf94916840a"
):
utilities.download_files_from_github(
github.download_files_from_github(
path, model_name
)

Expand All @@ -66,7 +70,7 @@ def load_model(model_name: str = "u2net"):
not os.path.exists(path)
#or hasher.md5(path) != "347c3d51b01528e5c6c071e3cff1cb55"
):
utilities.download_files_from_github(
github.download_files_from_github(
path, model_name
)

Expand Down
35 changes: 0 additions & 35 deletions backgroundremover/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,38 +328,3 @@ def transparentvideooverimage(output, overlay, file_path,
except PermissionError:
pass
return

def download_files_from_github(path, model_name):
if model_name not in ["u2net", "u2net_human_seg", "u2netp"]:
print("Invalid model name, please use 'u2net' or 'u2net_human_seg' or 'u2netp'")
return
print(f"downloading model [{model_name}] to {path} ...")
urls = []
if model_name == "u2net":
urls = ['https://github.com/nadermx/backgroundremover/raw/main/models/u2aa',
'https://github.com/nadermx/backgroundremover/raw/main/models/u2ab',
'https://github.com/nadermx/backgroundremover/raw/main/models/u2ac',
'https://github.com/nadermx/backgroundremover/raw/main/models/u2ad']
elif model_name == "u2net_human_seg":
urls = ['https://github.com/nadermx/backgroundremover/raw/main/models/u2haa',
'https://github.com/nadermx/backgroundremover/raw/main/models/u2hab',
'https://github.com/nadermx/backgroundremover/raw/main/models/u2hac',
'https://github.com/nadermx/backgroundremover/raw/main/models/u2had']
elif model_name == 'u2netp':
urls = ['https://github.com/nadermx/backgroundremover/raw/main/models/u2netp.pth']
try:
os.makedirs(os.path.expanduser("~/.u2net"), exist_ok=True)
except Exception as e:
print(f"Error creating directory: {e}")
return

try:

with open(path, 'wb') as out_file:
for i, url in enumerate(urls):
print(f'downloading part {i+1} of {model_name}')
part_content = requests.get(url)
out_file.write(part_content.content)
print(f'finished downloading part {i+1} of {model_name}')
except Exception as e:
print(e)

0 comments on commit c590858

Please sign in to comment.