Skip to content

Commit

Permalink
Fix validation checking for u2net models.
Browse files Browse the repository at this point in the history
  • Loading branch information
nadermx committed Aug 13, 2021
1 parent 04d3ef4 commit b652d91
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 18 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

setup(
name="backgroundremover",
version="0.1.8",
version="0.1.9",
description="Background remover from image and video",
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down
2 changes: 1 addition & 1 deletion src/backgroundremover/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
A library to remove background from videos and images
"""

__version__ = "0.1.8"
__version__ = "0.1.9"
__author__ = 'Johnathan Nader'
__credits__ = 'BackgroundRemover.app'
2 changes: 0 additions & 2 deletions src/backgroundremover/bg.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ class Net(torch.nn.Module):
def __init__(self, model_name):
super(Net, self).__init__()
hasher = Hasher()

model = {
'u2netp': (u2net.U2NETP,
'e4f636406ca4e2af789941e7f139ee2e',
Expand Down Expand Up @@ -162,7 +161,6 @@ def naive_cutout(img, mask):
return cutout


@functools.lru_cache(maxsize=None)
def get_model(model_name):
if model_name == "u2netp":
return detect.load_model(model_name="u2netp")
Expand Down
10 changes: 1 addition & 9 deletions src/backgroundremover/cmd/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,7 @@


def main():
model_path = os.environ.get(
"U2NETP_PATH",
os.path.expanduser(os.path.join("~", ".u2net")),
)
print(model_path)
model_choices = [os.path.splitext(os.path.basename(x))[0] for x in set(glob.glob(model_path + "/*"))]
print('here', model_choices)
if len(model_choices) == 0:
model_choices = ["u2net", "u2net_human_seg", "u2netp"]
model_choices = ["u2net", "u2net_human_seg", "u2netp"]

ap = argparse.ArgumentParser()

Expand Down
10 changes: 5 additions & 5 deletions src/backgroundremover/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,15 +281,15 @@ def transparentvideooverimage(output, overlay, file_path,
sp.run(shlex.split(cmd))
print("Starting alphamerge")
cmd = "nice -10 ffmpeg -y -i %s -i %s -i %s -filter_complex '[0][1]scale2ref[img][vid];[img]setsar=1[img];[vid]nullsink; [img][2]overlay=(W-w)/2:(H-h)/2' -shortest %s" % (
#cmd = "nice -10 ffmpeg -y -i %s -i %s -i %s -filter_complex '[1][0]scale2ref[mask][main];[main][mask]alphamerge=shortest=1[vid];[2:v][vid]overlay[out]' -map [out] -shortest %s" % (
temp_image, file_path, temp_file, output)
sp.run(shlex.split(cmd))
print("Process finished")
return

def download_file_from_google_drive(model, path):
head, tail = os.path.split(path)
os.makedirs(head, exist_ok=True)
URL = "https://drive.google.com/uc?id=%s" % model[2]
if not os.path.exists(path):
head, tail = os.path.split(path)
os.makedirs(head, exist_ok=True)
URL = "https://drive.google.com/uc?id=%s" % model[2]

gdown.download(URL, path, quiet=False)
gdown.download(URL, path, quiet=False)

0 comments on commit b652d91

Please sign in to comment.