-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathKNN.py
72 lines (59 loc) · 2.24 KB
/
KNN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import numpy as np
import scipy.spatial.distance as dist
import math
class KNN(object):
    """Bag-level k-nearest-neighbour classifier for multiple-instance data.

    Each sample is a *bag* (a collection of instance vectors). Bag-to-bag
    distances come from the module-level ``_min_hau_bag`` function, and each
    test bag is assigned the median of the labels of its ``k`` nearest
    training bags.
    """

    def __init__(self):
        self._bags = None             # training bags, stored by fit()
        self._bag_predictions = None  # NOTE(review): not read or written elsewhere in this class
        self._labels = None           # training labels, stored by fit()
        self._full_bags = None        # NOTE(review): not read or written elsewhere in this class
        self._DM = None               # test-by-train distance matrix from the last predict()
        self._K = None                # neighbour count, stored by fit()

    def fit(self, train_bags, train_labels, **kwargs):
        """Store the training bags, their labels and the neighbour count.

        Args:
            train_bags: sequence of training bags, indexable by position.
            train_labels: per-bag labels; ``train_labels[i][0]`` is used as
                the label of training bag ``i``.
            **kwargs: must contain ``k``, the number of neighbours (>= 1).

        Raises:
            KeyError: if ``k`` is not supplied.
            ValueError: if ``k`` is smaller than 1.
        """
        k = kwargs['k']
        # Reject k < 1 here; otherwise predict() fails later with an
        # unhelpful IndexError on an empty vote list.
        if k < 1:
            raise ValueError("k must be at least 1, got %r" % (k,))
        self._bags = train_bags
        self._labels = train_labels
        self._K = k

    def predict(self, Testbags):
        """Predict one label per bag in ``Testbags``.

        Args:
            Testbags: sequence of test bags.

        Returns:
            1-D numpy array of predicted labels, in the order of
            ``Testbags``. Numeric labels are returned as float64.

        Raises:
            ValueError: if there are test bags but no training bags.
        """
        print("Starting KNN")
        self._DM = self.DistanceMatrix(self._bags, Testbags)
        predictions = []
        for row in self._DM:
            # Positions of the k training bags closest to this test bag.
            nearest = np.array(row).argsort()[:self._K]
            votes = sorted(self._labels[idx][0] for idx in nearest)
            if not votes:
                raise ValueError("cannot predict: the model has no training bags")
            # Median label (the majority label for binary labels and odd k).
            # Index by len(votes), not k, so k > number of training bags works.
            predictions.append(votes[len(votes) // 2])
        # A single np.append onto an empty float array gives the same dtype
        # promotion as appending element by element, without O(n^2) copying.
        return np.append(np.array([]), predictions)

    def DistanceMatrix(self, train_bags, test_bags):
        """Return the test-by-train bag distance matrix as a list of lists.

        Entry ``[i][j]`` is ``_min_hau_bag(test_bags[i], train_bags[j])``,
        so the result has ``len(test_bags)`` rows of ``len(train_bags)``
        entries each.
        """
        return [
            [_min_hau_bag(test_bag, train_bag) for train_bag in train_bags]
            for test_bag in test_bags
        ]
def _min_hau_bag(X, Y):
    """Return the minimal Hausdorff distance between bags ``X`` and ``Y``.

    This is the smallest Euclidean distance between any instance of ``X``
    and any instance of ``Y``, so it is symmetric in its arguments.

    Args:
        X: iterable of instance vectors (1-D array-likes).
        Y: iterable of instance vectors (1-D array-likes).

    Returns:
        The minimum pairwise Euclidean distance as a float.

    Raises:
        ValueError: if either bag is empty.
    """
    # One flat generator over all (x, y) pairs. Calling min() on a list of
    # per-instance distance lists would compare those lists lexicographically
    # and could miss the closest pair.
    return min(dist.euclidean(x, y) for x in X for y in Y)