-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathdataset_qp.py
63 lines (51 loc) · 3.06 KB
/
dataset_qp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
import util
import datasets_util
class DatasetQP(torch.nn.Module):
    """Dataset of (query, positive) image pairs used for training.

    Each item is a ``(query, positive)`` pair of images loaded and transformed
    with ``datasets_util.open_image_and_apply_transform_train``.

    NOTE(review): this class subclasses ``torch.nn.Module`` although it only
    implements the dataset protocol (``__getitem__`` / ``__len__``);
    ``torch.utils.data.Dataset`` would be the conventional base. The original
    base class is kept so existing callers are unaffected.
    """

    def __init__(self, model, global_features_dim, geoloc_train_dataset, qp_threshold, dataset_name):
        """Build the list of query-positive pairs.

        Parameters
        ----------
        model : nn.Module, used to compute the query-positive pairs
            (only used when ``dataset_name == "pitts30k"``).
        global_features_dim : int, dimension of the global features generated
            by the model (only used when ``dataset_name == "pitts30k"``).
        geoloc_train_dataset : dataset_geoloc.GeolocDataset, containing the
            queries and gallery images.
        qp_threshold : float, for "pitts30k" only pairs whose distance (in
            features space) is strictly below this threshold are kept.
            Ignored for "msls".
        dataset_name : str, either "msls" or "pitts30k".

        Raises
        ------
        ValueError
            If ``dataset_name`` is not one of the supported names.
        """
        super().__init__()
        self.dataset_name = dataset_name
        if self.dataset_name == "msls":
            # Entries are [query_path, positive_path]; no distance is computed.
            self.query_positive_distances = self._msls_pairs(geoloc_train_dataset)
        elif self.dataset_name == "pitts30k":
            # Entries are (query_path, positive_path, distance).
            self.query_positive_distances = self._pitts30k_pairs(
                model, global_features_dim, geoloc_train_dataset, qp_threshold
            )
        else:
            # Previously an unknown name produced an object that failed later
            # with AttributeError / UnboundLocalError; fail fast instead.
            raise ValueError(
                f"Unsupported dataset_name {dataset_name!r}; expected 'msls' or 'pitts30k'"
            )

    @staticmethod
    def _msls_pairs(geoloc_train_dataset):
        """Return [query_path, positive_path] lists from the dataset's positives."""
        positives = geoloc_train_dataset.get_positives()
        queries_paths = geoloc_train_dataset.queries_paths
        gallery_paths = geoloc_train_dataset.gallery_paths
        # positives[q] holds the gallery indices that are positives of query q.
        return [
            [queries_paths[query_index], gallery_paths[gallery_index]]
            for query_index, positives_of_query in enumerate(positives)
            for gallery_index in positives_of_query
        ]

    @staticmethod
    def _pitts30k_pairs(model, global_features_dim, geoloc_train_dataset, qp_threshold):
        """Return (query_path, positive_path, distance) for correct predictions below qp_threshold."""
        # Compute predictions with the given model on the given dataset.
        _, _, predictions, correct_bool_mat, distances, _, _ = util.compute_features(
            geoloc_train_dataset, model, global_features_dim
        )
        num_preds = predictions.shape[1]
        queries_paths = geoloc_train_dataset.queries_paths
        gallery_paths = geoloc_train_dataset.gallery_paths
        pairs = []
        for query_index in range(geoloc_train_dataset.queries_num):
            query_path = queries_paths[query_index]
            for pred_index in range(num_preds):
                # Only predictions flagged as correct count as positives.
                if correct_bool_mat[query_index, pred_index] == 1:
                    distance = distances[query_index, pred_index]
                    # Keep only pairs that are close enough in feature space.
                    if distance < qp_threshold:
                        positive = predictions[query_index, pred_index]
                        pairs.append((query_path, gallery_paths[positive], distance))
        return pairs

    def __getitem__(self, index):
        """Load and transform the query and positive images of pair ``index``."""
        # msls entries are [query, positive]; pitts30k entries are
        # (query, positive, distance). The paths are the first two fields in both.
        entry = self.query_positive_distances[index]
        query_path, positive_path = entry[0], entry[1]
        query = datasets_util.open_image_and_apply_transform_train(query_path)
        positive = datasets_util.open_image_and_apply_transform_train(positive_path)
        return query, positive

    def __len__(self):
        """Return the number of query-positive pairs."""
        return len(self.query_positive_distances)