-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathimagenet_augmentation.py
112 lines (87 loc) · 4.49 KB
/
imagenet_augmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# this code is taken from https://github.com/kakaobrain/fast-autoaugment
import math
import random
import torch
class EfficientNetRandomCrop:
    """Random-area, random-aspect-ratio crop in the style of EfficientNet preprocessing.

    Mirrors the crop-sampling logic of TensorFlow's ``sample_distorted_bounding_box``
    kernel, with coverage measured against the whole image. Up to ``max_attempts``
    candidate crops are sampled; the first one satisfying every constraint is
    returned. If no candidate succeeds, or the accepted candidate is the entire
    image, ``EfficientNetCenterCrop(imgsize)`` is applied instead.

    Args:
        imgsize: Size forwarded to the center-crop fallback.
        min_covered: Minimum fraction of the image area the crop must cover.
        aspect_ratio_range: ``(min, max)`` range of the crop's width / height ratio.
        area_range: ``(min, max)`` range of the crop area as a fraction of the image area.
        max_attempts: Number of candidate crops tried before falling back.
    """

    def __init__(self, imgsize, min_covered=0.1, aspect_ratio_range=(3. / 4, 4. / 3),
                 area_range=(0.08, 1.0), max_attempts=10):
        assert 0.0 < min_covered
        assert 0 < aspect_ratio_range[0] <= aspect_ratio_range[1]
        assert 0 < area_range[0] <= area_range[1]
        assert 1 <= max_attempts
        self.min_covered = min_covered
        self.aspect_ratio_range = aspect_ratio_range
        self.area_range = area_range
        self.max_attempts = max_attempts
        self._fallback = EfficientNetCenterCrop(imgsize)

    def __call__(self, img):
        """Return a random crop of ``img``, or the center-crop fallback.

        Args:
            img: Image exposing ``size -> (width, height)`` and
                ``crop((left, upper, right, lower))``, e.g. a PIL Image.

        Returns:
            The cropped image.
        """
        # https://github.com/tensorflow/tensorflow/blob/9274bcebb31322370139467039034f8ff852b004/tensorflow/core/kernels/sample_distorted_bounding_box_op.cc#L111
        original_width, original_height = img.size
        original_area = original_width * original_height
        min_area = self.area_range[0] * original_area
        max_area = self.area_range[1] * original_area
        for _ in range(self.max_attempts):
            aspect_ratio = random.uniform(*self.aspect_ratio_range)
            height = int(round(math.sqrt(min_area / aspect_ratio)))
            max_height = int(round(math.sqrt(max_area / aspect_ratio)))
            # The crop width is round(height * aspect_ratio), so the width bound
            # must be tested on the *rounded* product; an unrounded test would
            # discard the largest heights whose rounded width still fits.
            if round(max_height * aspect_ratio) > original_width:
                # Largest max_height with round(max_height * aspect_ratio) <= original_width.
                max_height = int((original_width + 0.5 - 1e-7) / aspect_ratio)
                # Guard against floating-point error in the division above.
                if round(max_height * aspect_ratio) > original_width:
                    max_height -= 1
            max_height = min(max_height, original_height)
            height = min(height, max_height)
            # Uniform over the closed integer range [height, max_height];
            # rounding a continuous uniform draw would halve the endpoint weights.
            height = random.randint(height, max_height)
            width = int(round(height * aspect_ratio))
            area = width * height
            if area < min_area or area > max_area:
                continue
            if width > original_width or height > original_height:
                continue
            if width <= 0 or height <= 0:
                continue
            if area < self.min_covered * original_area:
                continue
            if width == original_width and height == original_height:
                return self._fallback(img)  # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/preprocessing.py#L102
            x = random.randint(0, original_width - width)
            y = random.randint(0, original_height - height)
            return img.crop((x, y, x + width, y + height))
        return self._fallback(img)
class EfficientNetCenterCrop:
    """Center crop used by EfficientNet-style evaluation preprocessing.

    Crops a centered square whose side is (approximately)
    ``imgsize / (imgsize + crop_padding)`` times the image's shorter side.
    The image is not resized here; callers resize separately if needed.

    Args:
        imgsize: Image size the crop ratio is computed for.
        crop_padding: Extra pixels added to ``imgsize`` in the ratio's
            denominator. The default of 32 reproduces the original behavior.
    """

    def __init__(self, imgsize, crop_padding=32):
        self.imgsize = imgsize
        self.crop_padding = crop_padding

    def __call__(self, img):
        """Crop the centered square out of ``img``.

        Args:
            img: Image exposing ``size -> (width, height)`` and
                ``crop((left, upper, right, lower))``, e.g. a PIL Image.
                (0, 0) is the top-left corner.

        Returns:
            The cropped image.
        """
        image_width, image_height = img.size
        image_short = min(image_width, image_height)
        crop_size = float(self.imgsize) / (self.imgsize + self.crop_padding) * image_short
        crop_top = int(round((image_height - crop_size) / 2.))
        crop_left = int(round((image_width - crop_size) / 2.))
        # Round the side length once so the box has integer corners and is
        # square; rounding each float corner separately can yield e.g. 263x262.
        crop_side = int(round(crop_size))
        return img.crop((crop_left, crop_top, crop_left + crop_side, crop_top + crop_side))
class Lighting(torch.nn.Module):
    """Lighting noise (AlexNet-style, PCA-based color augmentation).

    Each call draws ``alpha ~ N(0, alphastd^2)`` (one value per principal
    component) and adds ``eigvec @ (alpha * eigval)`` to every pixel.
    The hard-coded eigenvalues/eigenvectors are presumably the RGB PCA of
    ImageNet pixels (the commonly used values); confirm against their source.

    Args:
        alphastd: Standard deviation of ``alpha``; 0 disables the noise.
    """

    def __init__(self, alphastd=0.1):
        super().__init__()
        self.alphastd = alphastd
        # Registered as buffers so they follow ``module.to(...)`` and are saved in state_dict.
        self.register_buffer('eigval', torch.tensor([0.2175, 0.0188, 0.0045]))
        self.register_buffer('eigvec', torch.tensor([
            [-0.5675, 0.7192, 0.4009],
            [-0.5808, -0.0045, -0.8140],
            [-0.5836, -0.6948, 0.4203],
        ]))

    def forward(self, img):
        """Add PCA lighting noise to ``img``.

        Args:
            img: Floating-point tensor into which a ``(3, 1, 1)`` noise tensor
                can be expanded, e.g. ``(3, H, W)``; the 3 entries are added
                per channel.

        Returns:
            ``img + noise`` with the same dtype and device as ``img``;
            ``img`` itself when ``alphastd == 0``.
        """
        if self.alphastd == 0:
            return img
        # One coefficient per principal component, drawn on img's device/dtype.
        alpha = img.new_empty(3).normal_(0, self.alphastd)
        # Cast BOTH constants to img: leaving eigval on the module's dtype/device
        # up-casts the result or fails with a device mismatch.
        eigval = self.eigval.to(img)
        eigvec = self.eigvec.to(img)
        rgb = eigvec.mv(alpha * eigval)
        return img.add(rgb.view(3, 1, 1).expand_as(img))