-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathSVM Classifier for Object Detection.m
63 lines (49 loc) · 1.19 KB
/
SVM Classifier for Object Detection.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
% Set up VLFeat and load the training/validation data from HW2_Utils.
run('vlfeat-0.9.21/toolbox/vl_setup.m');
[trD, trLb, valD, valLb, trRegs, valRegs] = HW2_Utils.getPosAndRandomNeg();

% Soft-margin penalty C: upper bound on every dual variable alpha_i.
c = 10;
x = trD;    % d x n feature matrix, one training example per column
y = trLb;   % n x 1 training labels (assumed +1/-1 -- TODO confirm)
[d, n] = size(trD);

% Soft-margin SVM dual written as a quadprog minimisation:
%   min_alpha  0.5 * alpha' * H * alpha + f' * alpha
%   s.t.       y' * alpha = 0,   0 <= alpha_i <= C
% f = -1 turns "maximise sum(alpha)" into a minimisation.
f = -ones(n, 1);

% H(i, j) = y_i * y_j * <x_i, x_j>. The Gram matrix x' * x computes every inner
% product in one BLAS call instead of n^2 interpreted dot() calls. The result is
% cast to double because quadprog expects double input, matching the original
% double-preallocated matrix.
h = double((x' * x) .* (y(:) * y(:)'));

% No inequality constraints beyond the box bounds.
A = [];
b = [];
% Single equality constraint: sum_i y_i * alpha_i = 0.
A_eq = trLb';
b_eq = 0;
% Box constraints 0 <= alpha_i <= C.
lb = zeros(n, 1);
ub = c * ones(n, 1);
% Solve the SVM dual QP for the Lagrange multipliers alpha (n x 1).
[alpha, f_val, exitflag] = quadprog(h, f, A, b, A_eq, b_eq, lb, ub);
%disp(f_val);
if exitflag <= 0
    % A non-positive exit flag means quadprog did not converge to a solution;
    % the alphas below may be meaningless, so make that visible.
    warning('quadprog exited with flag %d; SVM solution may be invalid.', exitflag);
end

% Primal weight vector from the dual: w = sum_i alpha_i * y_i * x_i  (d x 1).
y_col = y(:);
w = x * (y_col .* alpha);

% Bias from the KKT conditions. For a free support vector (0 < alpha_i < C) the
% margin constraint is tight: y_i * (w' * x_i + b) = 1, i.e. b = y_i - w' * x_i.
% A sample with alpha ~ 0 (not a support vector) or alpha = C (inside the margin)
% does not satisfy this equality, so the bias must come from free support
% vectors. Averaging over all of them reduces numerical noise from the solver.
sv_tol = 1e-4;   % alphas within this distance of a bound are treated as on it
free_sv = (alpha > sv_tol) & (alpha < c - sv_tol);
if ~any(free_sv)
    % Every support vector sits at the upper bound C; the KKT conditions then
    % only bracket b, so fall back to the mean over all support vectors.
    free_sv = alpha > sv_tol;
end
if ~any(free_sv)
    error('SVM dual returned no support vectors (all alpha <= %g); cannot compute bias.', sv_tol);
end
bias = mean(y_col(free_sv) - (x(:, free_sv)' * w));
% Linear SVM decision values on the validation set: w' * valD is a 1 x m row,
% one score per column of valD.
y_pred = (w' * valD) + bias;

% Count every prediction. size(y_pred, 1) is 1 for this row vector, so numel is
% required to threshold and score all validation examples, not just the first.
vn = numel(y_pred);

% Hard labels: a negative score maps to -1; anything else (including 0) maps to +1.
y_pred = 1 - 2 * (y_pred < 0);

if numel(valLb) ~= vn
    error('Got %d predictions but %d validation labels.', vn, numel(valLb));
end

% Compare as column vectors so the orientation of valLb does not matter.
correct_predictions = sum(y_pred(:) == valLb(:));
wrong_predictions = vn - correct_predictions;
disp(correct_predictions);
disp(wrong_predictions);
% Hand the learned linear model (w, bias) to HW2_Utils for the "val" split and
% score the resulting output file. Presumably genRsltFile writes detections to
% the named result file and cmpAP returns average precision plus the
% precision/recall values -- semantics live in HW2_Utils; confirm there.
% The genRsltFile call is intentionally left without a semicolon, so any
% value it returns is echoed to the console.
rslt_name = "question_4_4_1_output";
HW2_Utils.genRsltFile(w, bias, "val", rslt_name)
[ap, prec, rec] = HW2_Utils.cmpAP(rslt_name, "val");