-
Notifications
You must be signed in to change notification settings - Fork 74
/
dataset.py
102 lines (91 loc) · 4.37 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import glob
import random
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
# Per-channel mean / standard deviation that torchvision's pre-trained
# ImageNet models expect their (ToTensor-scaled) inputs to be normalized with.
mean, std = (
    np.array([0.485, 0.456, 0.406]),
    np.array([0.229, 0.224, 0.225]),
)
# The base-class expression is evaluated before the name ``Dataset`` is rebound,
# so this subclasses ``torch.utils.data.Dataset`` and then shadows it.
class Dataset(Dataset):
    """Video-clip dataset built from folders of extracted ``.jpg`` frames.

    Split lists and class indices follow the UCF-101 ``ucfTrainTestlist`` layout:
    ``classInd.txt`` maps 1-based indices to activity names, and
    ``trainlist0N.txt`` / ``testlist0N.txt`` list the videos of split ``N``.
    Each listed video is expected as a folder ``<dataset_path>/<Activity>/<video>``
    containing ``<frame_number>.jpg`` files.

    Args:
        dataset_path: Root directory holding the per-video frame folders.
        split_path: Directory containing ``classInd.txt`` and the split lists.
        split_number: Which split list to read (1, 2 or 3).
        input_shape: Its last two entries are used as the ``Resize`` target size.
        sequence_length: Frames per returned clip, or ``None`` to keep all frames.
        training: Read the train list and enable random temporal sampling and
            horizontal-flip augmentation; otherwise read the test list and
            sample deterministically.
    """

    def __init__(self, dataset_path, split_path, split_number, input_shape, sequence_length, training):
        self.training = training
        self.label_index = self._extract_label_mapping(split_path)
        self.sequences = self._extract_sequence_paths(dataset_path, split_path, split_number, training)
        self.sequence_length = sequence_length
        # Sorted names of the activities that actually occur in this split.
        # NOTE(review): targets come from ``label_index`` (classInd.txt), so
        # ``num_classes`` only covers the target range if every class appears here.
        self.label_names = sorted({self._activity_from_path(seq_path) for seq_path in self.sequences})
        self.num_classes = len(self.label_names)
        self.transform = transforms.Compose(
            [
                transforms.Resize(input_shape[-2:], interpolation=Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )

    def _extract_label_mapping(self, split_path="data/ucfTrainTestlist"):
        """Map each activity name in ``classInd.txt`` to a zero-based class index.

        Lines have the form ``"<1-based index> <ActivityName>"``; blank lines
        (such as a trailing empty line) are skipped instead of crashing.
        """
        label_mapping = {}
        with open(os.path.join(split_path, "classInd.txt"), encoding="utf-8") as file:
            for line in file:
                if not line.strip():
                    continue
                label, action = line.split()
                label_mapping[action] = int(label) - 1
        return label_mapping

    def _extract_sequence_paths(
        self, dataset_path, split_path="data/ucfTrainTestlist", split_number=1, training=True
    ):
        """Return the frame-folder path of every video listed in the selected split.

        Reads ``trainlist0{split_number}.txt`` when ``training`` is true, else
        ``testlist0{split_number}.txt``. For each non-blank line, the ``.avi``
        suffix and anything after it (e.g. a label column) are dropped and the
        remaining relative path is joined onto ``dataset_path``.
        """
        assert split_number in [1, 2, 3], "Split number has to be one of {1, 2, 3}"
        fn = f"trainlist0{split_number}.txt" if training else f"testlist0{split_number}.txt"
        sequence_paths = []
        with open(os.path.join(split_path, fn), encoding="utf-8") as file:
            for line in file:
                line = line.strip()
                if not line:
                    # Blank lines would otherwise become the bogus path ``dataset_path/``.
                    continue
                seq_name = line.split(".avi")[0]
                sequence_paths.append(os.path.join(dataset_path, seq_name))
        return sequence_paths

    def _activity_from_path(self, path):
        """Return the activity name, i.e. the name of the folder that contains ``path``.

        Uses ``os.path`` so it also works where paths use ``\\`` separators.
        """
        return os.path.basename(os.path.dirname(path))

    def _frame_number(self, image_path):
        """Return the integer frame number from a ``<number>.jpg`` file path."""
        return int(os.path.basename(image_path).split(".jpg")[0])

    def _pad_to_length(self, sequence):
        """Left-pad ``sequence`` with copies of its first item up to ``self.sequence_length``.

        Returns a new list and leaves the argument untouched. Sequences that are
        already long enough, or any sequence when ``sequence_length`` is ``None``,
        come back with the same items.

        Raises:
            ValueError: if ``sequence`` is empty (there is no frame to pad with).
        """
        if not sequence:
            raise ValueError("Cannot pad an empty frame sequence")
        if self.sequence_length is None:
            return list(sequence)
        missing = max(self.sequence_length - len(sequence), 0)
        return [sequence[0]] * missing + list(sequence)

    def __getitem__(self, index):
        """Return ``(frames, target)`` for the video at ``index`` (taken modulo ``len(self)``).

        ``frames`` stacks the transformed frames along a new leading dimension;
        ``target`` is the zero-based class index from ``classInd.txt``.
        Training: random stride and start frame, horizontal flip with p=0.5.
        Evaluation: start at the first frame with a stride spreading the clip
        over the whole video, no flip.

        Raises:
            FileNotFoundError: if the video folder contains no ``.jpg`` frames.
        """
        sequence_path = self.sequences[index % len(self)]
        # Numeric sort: a plain string sort would order "10.jpg" before "2.jpg".
        pattern = os.path.join(glob.escape(sequence_path), "*.jpg")
        image_paths = sorted(glob.glob(pattern), key=self._frame_number)
        if not image_paths:
            raise FileNotFoundError(f"No .jpg frames found in {sequence_path!r}")
        # Videos shorter than ``sequence_length`` are padded by repeating the first frame.
        image_paths = self._pad_to_length(image_paths)
        num_frames = len(image_paths)

        # PyTorch's DataLoader reseeds Python's ``random`` in every worker; NumPy's
        # global RNG could be duplicated across workers on older releases, which
        # would repeat the same "random" augmentation in each worker.
        if self.sequence_length is None:
            # Keep every frame (previously crashed with TypeError in training mode).
            start_i, sample_interval = 0, 1
        elif self.training:
            # Stride in [1, num_frames // L]; start chosen so L strided frames still fit.
            sample_interval = random.randint(1, num_frames // self.sequence_length)
            start_i = random.randint(0, num_frames - sample_interval * self.sequence_length)
        else:
            start_i = 0
            sample_interval = num_frames // self.sequence_length
        flip = self.training and random.random() < 0.5

        # Every ``sample_interval``-th frame from ``start_i``, capped at
        # ``sequence_length`` (a ``None`` stop keeps them all).
        selected_paths = image_paths[start_i::sample_interval][: self.sequence_length]
        frames = []
        for image_path in selected_paths:
            with Image.open(image_path) as image:
                # Force 3 channels so grayscale/RGBA frames match the 3-channel Normalize stats.
                image_tensor = self.transform(image.convert("RGB"))
            if flip:
                image_tensor = torch.flip(image_tensor, (-1,))
            frames.append(image_tensor)
        image_sequence = torch.stack(frames)
        target = self.label_index[self._activity_from_path(sequence_path)]
        return image_sequence, target

    def __len__(self):
        """Return the number of videos in the selected split."""
        return len(self.sequences)