-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathTests.java
44 lines (44 loc) · 2.1 KB
/
Tests.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import java.util.Arrays;
/**
 * Manual driver for exercising {@code Network} on XOR-style data.
 *
 * <p>Only the softmax / log-loss experiment is currently active; earlier experiments are kept
 * below as commented-out recipes.
 */
public class Tests {

    /**
     * Runs the softmax / log-loss XOR experiment.
     *
     * <p>It loads the saved model {@code XOR-SoftmaxLogLoss-BackpropagationTrained.model} and
     * trains it for a single epoch over the four XOR cases. It prints the value returned by
     * {@code trainOneEpoch} (labelled as that epoch's error). It then prints the network
     * description and the network's output for every input case.
     *
     * @param args not used
     */
    public static void main(String[] args) {
        /*
         * Disabled experiment 1 - forward propagation through a pre-made model:
         *
         *   Network test1 = new Network("SampleXOR.model");
         *   System.out.println(Arrays.toString(test1.evaluate(new double[] {0, 0})));
         *   System.out.println(Arrays.toString(test1.evaluate(new double[] {0, 1})));
         *   System.out.println(Arrays.toString(test1.evaluate(new double[] {1, 0})));
         *   System.out.println(Arrays.toString(test1.evaluate(new double[] {1, 1})));
         *
         * Disabled experiment 2 - train a {2, 5, 1} sigmoid network on XOR from scratch
         * (alternatively load "XOR-Sigmoid-BackpropagationTrained.model"):
         *
         *   Network test2 = new Network(3, 0.1, new int[] {2, 5, 1},
         *           new String[] {"identity", "sigmoid", "sigmoid"});
         *   test2.training_mode = 0;
         *   test2.randomize_weights();
         *   double[][] test2cases = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
         *   double[][] test2correct = {{0}, {1}, {1}, {0}};
         *   for (int i = 0; i < 10000; i++) {
         *       System.out.println("Error for epoch " + i + ": "
         *               + test2.trainOneEpoch(test2cases, test2correct));
         *   }
         *   System.out.println(test2.outputNetwork(false));
         *   for (int i = 0; i < 4; i++) {
         *       System.out.println(Arrays.toString(test2.evaluate(test2cases[i])));
         *   }
         *
         * To run the active experiment from a freshly initialised network instead of the
         * saved model, replace the constructor call below with:
         *
         *   Network net = new Network(3, 0.1, new int[] {2, 5, 2},
         *           new String[] {"identity", "lrelu", "softmax"});
         *   net.training_mode = 0;
         *   net.randomize_weights();
         *   net.loss_function = new LogLoss();
         */

        // XOR encoded one-hot: a result of 0 sets output node 0 to 1,
        // a result of 1 sets output node 1 to 1.
        final double[][] inputs = {
            {0, 0},
            {0, 1},
            {1, 0},
            {1, 1}
        };
        final double[][] targets = {
            {1, 0},
            {0, 1},
            {0, 1},
            {1, 0}
        };

        Network net = new Network("XOR-SoftmaxLogLoss-BackpropagationTrained.model");

        final int epochs = 1;
        for (int epoch = 0; epoch < epochs; epoch++) {
            System.out.println("Error for epoch " + epoch + ": " + net.trainOneEpoch(inputs, targets));
        }

        System.out.println(net.outputNetwork(false));

        // Same array that was trained on, visited in index order.
        for (double[] input : inputs) {
            System.out.println(Arrays.toString(net.evaluate(input)));
        }
    }
}