-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcutout.py
34 lines (24 loc) · 981 Bytes
/
cutout.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import random
import torch
def _gen_cutout_coord(height, width, size):
height_loc = random.randint(0, height - 1)
width_loc = random.randint(0, width - 1)
upper_coord = (max(0, height_loc - size // 2),
max(0, width_loc - size // 2))
lower_coord = (min(height, height_loc + size // 2),
min(width, width_loc + size // 2))
return upper_coord, lower_coord
class Cutout(torch.nn.Module):
    """Zero out one randomly placed square patch of the input tensor.

    The last two dimensions of the input are treated as (height, width). The
    same patch is zeroed across every leading dimension (e.g. all channels,
    and every image if a batched tensor is passed).

    The patch position comes from Python's ``random`` module (through
    ``_gen_cutout_coord``), so it follows ``random.seed``, not
    ``torch.manual_seed``.

    Args:
        size: Side length, in pixels, of the square patch before it is
            clipped to the image borders. Must be >= 1.

    Raises:
        ValueError: If ``size`` is less than 1.
    """

    def __init__(self, size=16):
        super().__init__()
        # Validate eagerly: a non-positive size can only ever yield an empty
        # patch, which would otherwise surface much later in `forward`.
        if size < 1:
            raise ValueError(f"Cutout size must be >= 1, got {size!r}")
        self.size = size

    def forward(self, img):
        """Return a copy of ``img`` with a random square patch set to zero.

        Args:
            img: Tensor whose last two dimensions are (height, width).

        Returns:
            A new tensor with the same shape and dtype as ``img``; the input
            is not modified.

        Raises:
            RuntimeError: If the sampled patch is empty.
        """
        h, w = img.shape[-2:]
        (top, left), (bottom, right) = _gen_cutout_coord(h, w, self.size)
        # Explicit check instead of `assert`, so it is not stripped by `-O`.
        if bottom <= top or right <= left:
            raise RuntimeError(
                f"Cutout produced an empty patch "
                f"(size={self.size}, image={h}x{w})"
            )
        # Zero the patch on a copy instead of multiplying by a 0/1 mask:
        # `inf * 0` and `nan * 0` are NaN, so a mask would not clear
        # non-finite values, and this also avoids allocating a full mask.
        out = img.clone()
        out[..., top:bottom, left:right] = 0
        return out