-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinput_pipeline.py
258 lines (214 loc) · 8.4 KB
/
input_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
"""Data generators for COCO-style object detection datasets.

Differences from the Scenic implementation:
* Supports a custom input data range for normalization.
* Generally simpler, cleaner code.
"""
from functools import partial
from typing import Callable, Optional, Tuple

from flax import jax_utils
import jax
import jax.numpy as jnp
import ml_collections
import tensorflow as tf
import tensorflow_datasets as tfds

from dataset_lib import dataset_utils
from dataset_lib.coco_dataset import coco_utils
import transforms
# Per-channel ImageNet RGB mean and standard deviation for images scaled to
# [0, 1] (as used by the pretrained backbone). The red-channel mean is 0.485;
# it was previously truncated to 0.48.
_MEAN_RGB = [0.485, 0.456, 0.406]
_STD_RGB = [0.229, 0.224, 0.225]
def make_coco_transforms(split_name: str, max_size: int = 1333):
  """Builds the image/label preprocessing pipeline for one dataset split.

  All resize scales follow the DETR PyTorch recipe and are multiplied by
  `max_size / 1333`, so a smaller `max_size` shrinks every resolution
  proportionally.

  Args:
    split_name: 'train' (random flip plus multi-scale resize / resize-crop)
      or 'validation' (single deterministic resize).
    max_size: Value passed as `max_size` to the resize transforms and used to
      rescale all scale constants.

  Returns:
    A `transforms.Compose` of the selected steps, always ending with box
    normalization and padding-mask initialization.

  Raises:
    ValueError: If `split_name` is neither 'train' nor 'validation'.
  """
  ratio = max_size / 1333.

  def rescale(sizes):
    return [int(ratio * size) for size in sizes]

  # Multi-scale sizes for RandomResize.
  scales = rescale([480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800])
  # Scales for the RandomResize -> Crop branch, as in the DETR torch code.
  crop_branch_scales = rescale([400, 500, 600])

  normalize_boxes = transforms.NormalizeBoxes()
  init_padding_mask = transforms.InitPaddingMask()

  if split_name == 'train':
    flip = transforms.RandomHorizontalFlip()
    plain_resize = transforms.RandomResize(scales, max_size=max_size)
    resize_crop_resize = transforms.Compose([
        transforms.RandomResize(crop_branch_scales),
        transforms.RandomSizeCrop(int(ratio * 384), int(ratio * 600)),
        transforms.RandomResize(scales, max_size=max_size),
    ])
    # Randomly pick one of the two resize strategies per example.
    leading_steps = [
        flip,
        transforms.RandomSelect(plain_resize, resize_crop_resize),
    ]
  elif split_name == 'validation':
    leading_steps = [transforms.Resize(max(scales), max_size=max_size)]
  else:
    raise ValueError(f'Transforms not defined for split `{split_name}`')

  return transforms.Compose(
      leading_steps + [normalize_boxes, init_padding_mask])
def decode_boxes(boxes, size):
  """Converts normalized yxyx boxes to absolute xyxy boxes clipped to the image.

  Args:
    boxes: Tensor whose last axis holds 4 coordinates in [0, 1], ordered
      (ymin, xmin, ymax, xmax) as stored in TF examples. Assumed float32 —
      TODO confirm for other callers.
    size: Image size as (height, width).

  Returns:
    Tensor with the same leading dims whose last axis is
    (xmin, ymin, xmax, ymax) in pixels, each clipped to [0, width] or
    [0, height].
  """
  height = tf.cast(size[0], tf.float32)
  width = tf.cast(size[1], tf.float32)
  ymin, xmin, ymax, xmax = tf.split(boxes, 4, axis=-1)

  def scale_and_clip(coord, extent):
    return tf.clip_by_value(coord * extent, 0.0, extent)

  return tf.concat([
      scale_and_clip(xmin, width),
      scale_and_clip(ymin, height),
      scale_and_clip(xmax, width),
      scale_and_clip(ymax, height),
  ], axis=-1)
def decode_coco_detection_example(example, input_range=None):
  """Turns a TFDS COCO example into an {'inputs', 'label'} dict.

  The image is decoded if it is still encoded bytes, converted to float32 in
  [0, 1], and then either mapped linearly onto `input_range` or, when
  `input_range` is falsy, standardized with the ImageNet mean/std.

  Args:
    example: TFDS COCO example dict with 'image', 'image/id' and 'objects'
      ('bbox', 'area', 'id', 'is_crowd', 'label') entries.
    input_range: Optional (low, high) pair. If given, pixels are mapped from
      [0, 1] to [low, high] instead of mean/std normalization.

  Returns:
    Dict with:
      'inputs': the float32 image.
      'label': target dict with absolute xyxy 'boxes', 'labels' shifted by +1
        (class 0 is background), 'area', 'objects/id' and 'is_crowd'. Boxes
        with non-positive width or height are removed. It also carries
        'orig_size' (image height/width) and 'image/id'.
  """
  image = example['image']
  if image.dtype == tf.string:
    # Still encoded bytes (e.g. loaded with tfds.decode.SkipDecoding()).
    image = tf.io.decode_image(image, channels=3, expand_animations=False)
  image = tf.image.convert_image_dtype(image, tf.float32)

  if input_range:
    low, high = input_range[0], input_range[1]
    image = image * (high - low) + low
  else:
    mean = tf.constant(_MEAN_RGB, shape=[1, 1, 3], dtype=tf.float32)
    std = tf.constant(_STD_RGB, shape=[1, 1, 3], dtype=tf.float32)
    image = (image - mean) / std

  image_hw = tf.shape(image)[0:2]
  objects = example['objects']
  boxes = decode_boxes(objects['bbox'], image_hw)

  target = {
      'area': objects['area'],
      'boxes': boxes,
      'objects/id': objects['id'],
      'is_crowd': objects['is_crowd'],
      'labels': objects['label'] + 1,  # Class 0 is reserved for background.
  }

  # Keep only boxes with strictly positive width and height.
  has_extent = tf.logical_and(boxes[:, 2] > boxes[:, 0],
                              boxes[:, 3] > boxes[:, 1])
  keep = tf.where(has_extent)[:, 0]
  label = {name: tf.gather(value, keep) for name, value in target.items()}
  label['orig_size'] = image_hw
  label['image/id'] = example['image/id']

  return {'inputs': image, 'label': label}
def load_split_from_tfds(
    builder: tfds.core.DatasetBuilder,
    *,
    train: bool,
    batch_size: int,
    decode_fn: Callable,
    preprocess_fn: Callable,
    max_size: int,
    max_boxes: int,
    shuffle_buffer_size: Optional[int] = None,
    shuffle_seed: Optional[int] = None,
) -> Tuple[tf.data.Dataset, tfds.core.DatasetInfo]:
  """Loads this process's shard of a TFDS split as a padded, batched dataset.

  The 'train' split (when `train`) or the 'validation' split is divided evenly
  across JAX processes, and each process reads only its own slice. Images are
  read as encoded bytes and decoded by `decode_fn`. `preprocess_fn` then
  transforms each example.

  Args:
    builder: TFDS builder of the COCO-style dataset.
    train: If True, shuffle, repeat, and drop the final partial batch.
      Otherwise keep it and cache the batched result.
    batch_size: Number of examples per batch on this process.
    decode_fn: Per-example decoding function (applied first).
    preprocess_fn: Per-example augmentation/preprocessing (applied second).
    max_size: Height and width that images and padding masks are padded to.
    max_boxes: Number of per-object entries each label tensor is padded to.
    shuffle_buffer_size: Shuffle buffer size; required when `train` is True.
    shuffle_seed: Optional shuffle seed.

  Returns:
    A `(dataset, builder.info)` tuple. The dataset repeats indefinitely in
    both modes.

  Raises:
    ValueError: If `train` is True and `shuffle_buffer_size` is None.
  """
  if train and shuffle_buffer_size is None:
    raise ValueError('`shuffle_buffer_size` must be set when `train=True`.')

  split = 'train' if train else 'validation'
  # Each JAX process reads a disjoint, equally sized slice of the split.
  data_range = tfds.even_splits(split, jax.process_count())[jax.process_index()]
  ds = builder.as_dataset(
      data_range,
      shuffle_files=False,
      # Keep images as encoded bytes; `decode_fn` decodes them.
      decoders={'image': tfds.decode.SkipDecoding()})
  options = tf.data.Options()
  options.threading.private_threadpool_size = 48
  ds = ds.with_options(options)
  # Cache the raw (still-encoded) examples.
  ds = ds.cache()

  # Padded shape of every component before batching. 'padding_mask' and
  # 'size' are not produced by the decoder; presumably `preprocess_fn` adds
  # them. Examples larger than these shapes make `padded_batch` fail rather
  # than being truncated.
  padded_shapes = {
      'inputs': [max_size, max_size, 3],
      'padding_mask': [max_size, max_size],
      'label': {
          'boxes': [max_boxes, 4],
          'area': [max_boxes,],
          'objects/id': [max_boxes,],
          'is_crowd': [max_boxes,],
          'labels': [max_boxes,],
          'image/id': [],
          'size': [2,],
          'orig_size': [2,],
      },
  }

  if train:
    ds = ds.shuffle(shuffle_buffer_size, shuffle_seed)
    ds = ds.repeat()
  ds = ds.map(decode_fn, tf.data.AUTOTUNE)
  ds = ds.map(preprocess_fn, tf.data.AUTOTUNE)
  # Training drops the ragged final batch. Evaluation keeps it so every
  # example is seen (presumably padded to full size downstream).
  ds = ds.padded_batch(batch_size, padded_shapes, drop_remainder=bool(train))
  if not train:
    ds = ds.cache()  # WARNING! Only if you have enough memory.
    ds = ds.repeat()
  ds = ds.prefetch(tf.data.AUTOTUNE)
  return ds, builder.info
def build_pipeline(*, rng, batch_size: int, eval_batch_size: int,
                   num_shards: int, dataset_configs: ml_collections.ConfigDict):
  """Builds the COCO train and eval input iterators plus dataset metadata.

  Args:
    rng: Unused.
    batch_size: Train batch size on each JAX process (each process batches
      its own shard of the split).
    eval_batch_size: Eval batch size on each JAX process.
    num_shards: Number of shards the batch dimension is split into.
    dataset_configs: Config with the dataset `name`. It may also set
      'max_size' (default 1333), 'max_boxes' (100),
      'shuffle_buffer_size' (10_000), 'rng_seed' (42), 'input_range' and
      'prefetch_to_device'.

  Returns:
    A `dataset_utils.Dataset` built from the train iterator, the eval
    iterator, `None` in the third position, and a metadata dict.
  """
  dataset_name = dataset_configs.name
  builder = tfds.builder(dataset_name)
  max_size = dataset_configs.get('max_size', 1333)
  max_boxes = dataset_configs.get('max_boxes', 100)
  shuffle_buffer_size = dataset_configs.get('shuffle_buffer_size', 10_000)

  train_preprocess_fn = make_coco_transforms('train', max_size)
  eval_preprocess_fn = make_coco_transforms('validation', max_size)
  decode_fn = partial(
      decode_coco_detection_example,
      input_range=dataset_configs.get('input_range'))

  train_ds, ds_info = load_split_from_tfds(
      builder,
      train=True,
      batch_size=batch_size,
      decode_fn=decode_fn,
      preprocess_fn=train_preprocess_fn,
      max_size=max_size,
      max_boxes=max_boxes,
      shuffle_buffer_size=shuffle_buffer_size,
      shuffle_seed=dataset_configs.get('rng_seed', 42))
  eval_ds, _ = load_split_from_tfds(
      builder,
      train=False,
      batch_size=eval_batch_size,
      decode_fn=decode_fn,
      preprocess_fn=eval_preprocess_fn,
      max_size=max_size,
      max_boxes=max_boxes)

  # Dataset classes are shifted to 1..N; index 0 is the background class.
  num_classes = ds_info.features['objects']['label'].num_classes + 1

  shard_batches = partial(dataset_utils.shard, num_shards=num_shards)

  def to_sharded_numpy(ds, pad_fn):
    """Lazily maps a tf.data pipeline to padded, sharded numpy batches."""
    batches = map(dataset_utils.tf_to_numpy, iter(ds))
    batches = map(pad_fn, batches)
    return map(shard_batches, batches)

  train_iter = to_sharded_numpy(
      train_ds,
      partial(dataset_utils.maybe_pad_batch, train=True,
              batch_size=batch_size))
  eval_iter = to_sharded_numpy(
      eval_ds,
      partial(dataset_utils.maybe_pad_batch, train=False,
              batch_size=eval_batch_size))

  # Only prefetch when the config sets a truthy buffer size.
  prefetch_size = dataset_configs.get('prefetch_to_device')
  if prefetch_size:
    train_iter = jax_utils.prefetch_to_device(train_iter, prefetch_size)
    eval_iter = jax_utils.prefetch_to_device(eval_iter, prefetch_size)

  splits = builder.info.splits
  meta_data = {
      'num_classes': num_classes,
      'input_shape': [-1, max_size, max_size, 3],
      'num_train_examples': splits['train'].num_examples,
      'num_eval_examples': splits['validation'].num_examples,
      'input_dtype': jnp.float32,
      'target_is_onehot': False,
      'label_to_name': coco_utils.get_label_map(dataset_name),
  }
  return dataset_utils.Dataset(train_iter, eval_iter, None, meta_data)