-
Notifications
You must be signed in to change notification settings - Fork 98
/
utils.py
97 lines (76 loc) · 2.69 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import logging
import os
import torch
from torchvision import transforms
import numpy as np
import random
import cv2
from PIL import Image
def path_to_image(path, size=(1024, 1024), color_type='rgb'):
    """Load an image file with OpenCV and return it as a PIL image.

    Args:
        path: Filesystem path of the image to read.
        size: Target size passed to ``cv2.resize`` as ``dsize``, which OpenCV
            interprets as ``(width, height)``. Pass a falsy value (e.g.
            ``None``) to keep the file's original resolution.
        color_type: ``'rgb'`` (default) for a 3-channel RGB image or
            ``'gray'`` for a single-channel grayscale image. Compared
            case-insensitively.

    Returns:
        A ``PIL.Image.Image`` in mode ``'RGB'`` or ``'L'``. Returns ``None``
        (after printing a hint) when ``color_type`` is not recognised.

    Raises:
        FileNotFoundError: If ``path`` is not an existing regular file.
        ValueError: If OpenCV cannot decode the file as an image.
    """
    mode = color_type.lower()
    if mode not in ('rgb', 'gray'):
        print('Select the color_type to return, either to RGB or gray image.')
        return None

    # cv2.imread reports failure by returning None instead of raising, which
    # previously surfaced later as an opaque cv2.error from resize/cvtColor.
    # Fail early with clear messages instead.
    if not os.path.isfile(path):
        raise FileNotFoundError(f'Image file not found: {path}')
    if mode == 'rgb':
        image = cv2.imread(path)  # default flag decodes to BGR channel order
    else:
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise ValueError(f'OpenCV could not decode image file: {path}')

    if size:
        image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)

    if mode == 'rgb':
        # OpenCV stores colour pixels as BGR; PIL expects RGB.
        return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)).convert('RGB')
    return Image.fromarray(image).convert('L')
def check_state_dict(state_dict, unwanted_prefix='_orig_mod.'):
    """Remove ``unwanted_prefix`` from the keys of ``state_dict``, in place.

    Keys that do not start with the prefix are left as they are. Each renamed
    entry is popped and re-inserted, so it moves to the end of the dict's
    iteration order. If a stripped key already exists, the value from the
    prefixed key overwrites it. The default prefix is presumably the one
    ``torch.compile`` adds to wrapped-module parameters; confirm for other
    callers.

    Returns:
        The same dict object that was passed in, after mutation.
    """
    cut = len(unwanted_prefix)
    prefixed_keys = [key for key in state_dict if key.startswith(unwanted_prefix)]
    for key in prefixed_keys:
        state_dict[key[cut:]] = state_dict.pop(key)
    return state_dict
def generate_smoothed_gt(gts, epsilon=0.001):
    """Apply label smoothing to ground-truth values.

    Computes ``(1 - epsilon) * gts + epsilon / 2``. A target of 0 becomes
    ``epsilon / 2`` and a target of 1 becomes ``1 - epsilon / 2``. Only
    arithmetic operators are used, so scalars, NumPy arrays and torch tensors
    all work.

    Args:
        gts: Ground-truth values (scalar, array or tensor).
        epsilon: Smoothing strength. Defaults to 0.001, the value previously
            hard-coded here.

    Returns:
        The smoothed values, with the same type as ``gts``.
    """
    return (1 - epsilon) * gts + epsilon / 2
class Logger():
    """Send INFO-level messages to a log file and to the console.

    Each instance attaches one file handler and one stream handler to the
    process-wide ``logging.getLogger('BiRefNet')`` logger. Call ``close()``
    when finished: it detaches both handlers. Without that, a later ``Logger``
    instance would add another pair and every message would be emitted
    multiple times.
    """

    def __init__(self, path="log.txt"):
        """Attach a truncating file handler for ``path`` and a console handler."""
        self.logger = logging.getLogger('BiRefNet')
        formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
        # Mode "w" truncates any existing file at ``path``.
        self.file_handler = logging.FileHandler(path, "w")
        self.file_handler.setFormatter(formatter)
        # NOTE: StreamHandler() with no argument writes to sys.stderr, despite
        # the attribute name.
        self.stdout_handler = logging.StreamHandler()
        self.stdout_handler.setFormatter(formatter)
        self.logger.addHandler(self.file_handler)
        self.logger.addHandler(self.stdout_handler)
        self.logger.setLevel(logging.INFO)
        # Keep records away from ancestor (e.g. root) handlers to avoid
        # printing them twice.
        self.logger.propagate = False

    def info(self, txt):
        """Log ``txt`` at INFO level."""
        self.logger.info(txt)

    def close(self):
        """Detach this instance's handlers from the shared logger and close them."""
        for handler in (self.file_handler, self.stdout_handler):
            self.logger.removeHandler(handler)
            handler.close()
class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes:
        val: The most recent value passed to ``update``.
        sum: Running total of ``val * n`` since the last reset.
        count: Running total of the ``n`` weights since the last reset.
        avg: ``sum / count``. While ``count`` is 0 it keeps its previous
            value (0.0 after a reset).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all statistics."""
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0.0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard: calling update(val, n=0) on a fresh or reset meter used to
        # raise ZeroDivisionError.
        if self.count:
            self.avg = self.sum / self.count
def save_checkpoint(state, path, filename="latest.pth"):
    """Serialize ``state`` with ``torch.save`` to ``<path>/<filename>``.

    The directory ``path`` must already exist; it is not created here. An
    existing file with the same name is overwritten.
    """
    destination = os.path.join(path, filename)
    torch.save(state, destination)
def save_tensor_img(tenor_im, path):
    """Save an image tensor to ``path`` via torchvision's ``ToPILImage``.

    The tensor is moved to CPU and copied. A leading dimension of size 1 (for
    example a batch of one) is squeezed away before conversion. Presumably the
    result is (C, H, W) or (H, W) as ToPILImage expects; confirm with callers.
    The output format is chosen by PIL from the extension of ``path``.
    """
    to_pil = transforms.ToPILImage()
    single = tenor_im.cpu().clone().squeeze(0)
    to_pil(single).save(path)
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also sets ``torch.backends.cudnn.deterministic = True`` so cuDNN picks
    deterministic algorithms. ``torch.backends.cudnn.benchmark`` is left
    unchanged here.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True