-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmnist_dataset.py
69 lines (62 loc) · 2.71 KB
/
mnist_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import pandas as pd
from torch.utils.data import Dataset
import numpy as np
from tqdm import tqdm
class MNISTDataset(Dataset):
    """Pre-built dataset of MNIST image pairs for similarity training.

    All samples are generated once, in ``__init__``, and then served
    unchanged on every epoch. Each sample is a tuple
    ``(first, second, dis, label)``:

    * training mode (``is_test=False``): ``first`` is the image of the row,
      ``second`` is a randomly drawn partner image, ``dis`` is ``1.0`` when
      both images share a label and ``0.0`` otherwise, and ``label`` is the
      label of ``first``. Column 0 of ``data_df`` must be the ``label``
      column and the remaining columns the flattened image pixels.
    * test mode (``is_test=True``): every column of ``data_df`` is a pixel;
      ``second``, ``dis`` and ``label`` are the placeholder ``-1``.

    Args:
        data_df: Frame holding one flattened image per row (see above).
        transform: Optional callable applied to each image after it has
            been cast to ``float32``. Placeholders are never transformed.
        is_test: Build unlabeled single-image samples instead of pairs.
        show_progress: Wrap the build loop in a ``tqdm`` progress bar.

    Note:
        Pairs are drawn with ``np.random``; seed it for reproducible pairs.
        A same-label partner is requested with probability 0.5 and may be
        the row itself. If a different-label partner is requested while
        ``data_df`` contains a single label, ``np.random.randint`` raises
        ``ValueError`` because there is nothing to draw from.
    """

    # Height and width each flattened image row is reshaped to.
    IMAGE_SHAPE = (28, 28)

    def __init__(self, data_df: pd.DataFrame, transform=None, is_test=False,
                 show_progress=True):
        super().__init__()
        self.transform = transform
        self.is_test = is_test

        # One shared matrix for the whole frame; iterating its rows is far
        # cheaper than DataFrame.iterrows(), which builds a Series per row.
        values = data_df.to_numpy()

        # For every label keep only the *row indices* of same-label and
        # different-label samples instead of full copies of those rows.
        same_label_rows = {}
        other_label_rows = {}
        if not is_test:
            label_column = data_df.label.to_numpy()
            for lab in data_df.label.unique():
                same_label_rows[lab] = np.flatnonzero(label_column == lab)
                other_label_rows[lab] = np.flatnonzero(label_column != lab)

        rows = tqdm(values, total=len(values)) if show_progress else values

        dataset = []
        for data in rows:
            if is_test:
                # No label column and no partner in test mode.
                label = second = dis = -1
                first = data.reshape(self.IMAGE_SHAPE)
            else:
                label = data[0]
                first = data[1:].reshape(self.IMAGE_SHAPE)
                # Coin flip: 0 -> same-label partner, 1 -> different label.
                if np.random.randint(0, 2) == 0:
                    candidates = same_label_rows[label]
                else:
                    candidates = other_label_rows[label]
                partner = values[
                    candidates[np.random.randint(0, len(candidates))]
                ]
                # Target similarity: 1 for a matching pair, 0 otherwise.
                dis = 1.0 if partner[0] == label else 0.0
                second = partner[1:].reshape(self.IMAGE_SHAPE)

            if transform is not None:
                first = transform(first.astype(np.float32))
                # Decide on the mode flag, not on the identity of an int
                # literal: in test mode ``second`` is only a placeholder.
                if not is_test:
                    second = transform(second.astype(np.float32))

            dataset.append((first, second, dis, label))

        self.dataset = dataset

    def __len__(self):
        """Return the number of pre-built samples."""
        return len(self.dataset)

    def __getitem__(self, i):
        """Return the pre-built ``(first, second, dis, label)`` sample ``i``."""
        return self.dataset[i]