-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathsvm_lda.py
48 lines (38 loc) · 1.59 KB
/
svm_lda.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#!/usr/bin/env python
from sklearn import svm, lda
import numpy as np
class Classify:
    ''' Classify a sample into one of ten named positions (BLINK, UP, ...).

    Each call to svm_predict / lda_predict fits a fresh classifier on the
    stored training data (keeping it in self.clf) and returns the position
    name predicted for a single unlabelled sample.
    '''

    def __init__(self, labelled, targets=None):
        ''' Set up the dataset and target as numpy arrays.

        labelled -- training samples; each element is flattened into one
                    feature row, so samples are expected to share a size.
        targets  -- optional label for each row, each a key of
                    self.positions. Defaults to the keys 1..10 in ascending
                    order, i.e. exactly one sample per position with row i
                    (0-based) labelled i + 1. Pass repeated labels to supply
                    several samples per position (needed by lda_predict).
        '''
        self.positions = {1: 'BLINK',
                          2: 'UP',
                          3: 'UP-RIGHT',
                          4: 'RIGHT',
                          5: 'DOWN-RIGHT',
                          6: 'DOWN',
                          7: 'DOWN-LEFT',
                          8: 'LEFT',
                          9: 'UP-LEFT',
                          10: 'STRAIGHT'}
        self.dataset = np.array(labelled)
        # One row per sample: collapse any trailing dimensions into features.
        self.dataset = self.dataset.reshape((len(self.dataset), -1))
        if targets is None:
            # On Python 3 dict.keys() is a view, and np.array() of a view is
            # a 0-d object array rather than a label vector; materialise the
            # keys explicitly, in the order the rows are assumed to follow.
            targets = sorted(self.positions)
        self.targets = np.array(targets)
        self.clf = None

    @staticmethod
    def _as_sample(unlabelled):
        ''' Return unlabelled flattened into a single-row 2-D array. '''
        # scikit-learn's predict() expects (n_samples, n_features); a flat
        # 1-D vector is rejected by current releases.
        return np.asarray(unlabelled).reshape(1, -1)

    def _predict(self, clf, unlabelled):
        ''' Fit clf on the stored data and name its prediction for one sample.

        The fitted classifier is kept in self.clf. Returns the
        self.positions value for the predicted label.
        '''
        self.clf = clf
        self.clf.fit(self.dataset, self.targets)
        target = self.clf.predict(self._as_sample(unlabelled))
        return self.positions[target[0]]

    def svm_predict(self, unlabelled):
        ''' Use a Support Vector Machine for classification.

        Fits a default svm.SVC() and returns the position name (e.g. 'LEFT')
        predicted for the single sample unlabelled.
        '''
        return self._predict(svm.SVC(), unlabelled)

    def lda_predict(self, unlabelled):
        ''' Use Linear Discriminant Analysis for classification.

        WARNING: Will only work when we have multiple data samples
        for each position (i.e., two for left, two for right, etc.);
        construct the instance with a matching `targets` list for that.
        Returns the position name predicted for the single sample unlabelled.
        '''
        try:
            # Current scikit-learn location of the LDA estimator.
            from sklearn.discriminant_analysis import (
                LinearDiscriminantAnalysis)
        except ImportError:
            # Older scikit-learn only provides it as sklearn.lda.LDA.
            LinearDiscriminantAnalysis = lda.LDA
        return self._predict(LinearDiscriminantAnalysis(), unlabelled)