-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathGame.java
153 lines (124 loc) · 4.65 KB
/
Game.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import ml.classifiers.GeneticNN;
import javax.swing.*;
import java.awt.*;
import java.util.ArrayList;
import java.util.Collections;
/**
 * Drives a genetic algorithm over {@code GeneticNN} networks that play Snake.
 *
 * <p>Each generation opens one {@code Snake} window per network. It then blocks until
 * {@code Board.numFinished} reaches the cumulative expected count. Finally it breeds the next
 * generation from the two networks at the top end of the {@code GeneticNN.byFitness()} ordering.
 */
public class Game {
    /** Not referenced within this class; may be read elsewhere in the project (TODO confirm). */
    public static int iterations = 10;

    /** Value passed to {@code GeneticNN.mutate} for every child network produced by crossover. */
    private static final double MUTATION_RATE = 0.01;

    /** How long the generation barrier sleeps between polls of {@code Board.numFinished}. */
    private static final long POLL_INTERVAL_MILLIS = 10;

    // parameters for the genetic algorithm
    private int numFeatures = 7;
    private int numNetworks = 20;
    private int numGenerations = 10;
    public int genNum = 1;
    // parameters for each neural network in the genetic algorithm
    ArrayList<GeneticNN> networkList = new ArrayList<>();
    /**
     * Value {@code Board.numFinished} must reach before the current generation counts as done.
     * It grows by {@code numNetworks} after every generation, so the counter appears to be
     * cumulative across generations.
     */
    private int total;
    public int numLayers = 1;
    public int numHidden = 25;

    /**
     * Runs the game: builds the first generation, then evolves it.
     *
     * @param args unused
     * @throws AWTException propagated from {@link #start()}
     */
    public static void main(String[] args) throws AWTException {
        Game game = new Game();
        game.start();
    }

    /**
     * Creates the first generation of {@code numNetworks} networks.
     *
     * <p>Each network is constructed with the current hidden/layer/feature settings and
     * {@code genNum}, then has {@code train()} called on it.
     */
    public Game() {
        // The first generation is complete once numNetworks games have reported in.
        total = numNetworks;
        for (int i = 0; i < numNetworks; i++) {
            GeneticNN network = new GeneticNN(numHidden, numLayers, numFeatures, genNum);
            network.train();
            networkList.add(network);
        }
    }

    /**
     * Runs the initial generation, then breeds and runs {@code numGenerations} further generations.
     *
     * <p>In total, {@code numGenerations + 1} generations are played.
     *
     * @throws AWTException propagated from {@link #runGen(ArrayList)}
     */
    public void start() throws AWTException {
        runGen(networkList);
        for (int g = 1; g <= numGenerations; g++) {
            networkList = nextGen(networkList);
            runGen(networkList);
        }
    }

    /**
     * Lets one generation of snakes play the game so they accumulate fitness.
     *
     * <p>One {@code Snake} window is created for each of the first {@code numNetworks} entries of
     * {@code networkList}. The windows are shown on the AWT event thread. This method then blocks
     * until every game in this generation has reported completion through {@code Board.numFinished}.
     *
     * @param networkList the networks to evaluate; must contain at least {@code numNetworks} entries
     * @throws IllegalArgumentException if {@code networkList} has fewer than {@code numNetworks} entries
     * @throws IllegalStateException if the calling thread is interrupted while waiting
     * @throws AWTException declared for API compatibility; not thrown by the code visible here
     */
    public void runGen(ArrayList<GeneticNN> networkList) throws AWTException {
        // The completion barrier below counts exactly numNetworks games. A shorter list would
        // otherwise fail mid-launch with an opaque IndexOutOfBoundsException.
        if (networkList.size() < numNetworks) {
            throw new IllegalArgumentException(
                    "expected at least " + numNetworks + " networks but got " + networkList.size());
        }
        // run each network on the game
        for (int i = 0; i < numNetworks; i++) {
            Snake snake = new Snake();
            Board board = snake.getBoard();
            board.setNumFeatures(numFeatures);
            board.setNetwork(networkList.get(i));
            // Swing components must be shown from the event dispatch thread.
            EventQueue.invokeLater(() -> {
                JFrame frame = snake;
                board.setJFrame(frame);
                frame.setVisible(true);
            });
        }
        // wait until all the networks have finished before moving on
        awaitFinished(total);
        total += numNetworks;
        genNum++;
    }

    /**
     * Blocks the calling thread until {@code Board.numFinished} is at least {@code target}.
     *
     * <p>It polls with a short sleep rather than spinning, so it does not burn a CPU core while the
     * games run.
     *
     * @param target the completion count to wait for
     * @throws IllegalStateException if the thread is interrupted; the interrupt flag is restored first
     */
    private static void awaitFinished(int target) {
        while (Board.numFinished < target) {
            // NOTE(review): the empty print is retained from the original loop. It takes the
            // PrintStream lock on every pass, which in practice forces a fresh read of
            // Board.numFinished. The JMM does not guarantee visibility unless that field is
            // volatile (or an AtomicInteger). Making it volatile in Board is the proper fix — confirm.
            System.out.print("");
            try {
                Thread.sleep(POLL_INTERVAL_MILLIS);
            } catch (InterruptedException e) {
                // Keep the interrupt visible to callers; continuing to loop would spin, because every
                // later sleep() would throw immediately.
                Thread.currentThread().interrupt();
                throw new IllegalStateException("interrupted while waiting for games to finish", e);
            }
        }
    }

    /**
     * Creates the next generation from the two fittest networks of the generation that just ran.
     *
     * <p>It sorts {@code networkList} in place with {@code GeneticNN.byFitness()}. It then prints
     * the mean {@code appleCount()} of the generation. Finally it breeds the children from the last
     * two networks of the sorted list.
     *
     * @param networkList the networks in this generation (sorted in place); needs at least 2 entries
     * @return the {@code numNetworks} children created from this generation
     * @throws IllegalArgumentException if {@code networkList} has fewer than 2 entries
     */
    public ArrayList<GeneticNN> nextGen(ArrayList<GeneticNN> networkList) {
        int size = networkList.size();
        if (size < 2) {
            throw new IllegalArgumentException("need at least 2 networks to breed, got " + size);
        }
        Collections.sort(networkList, GeneticNN.byFitness());

        // prints average performance per generation (averaged over the list actually passed in)
        double appleTotal = 0;
        for (GeneticNN net : networkList) {
            appleTotal += net.appleCount();
        }
        System.out.println(appleTotal / size);

        // Take the two networks at the end of the sorted list as parents. This matches the
        // original "top 2 fittest" intent, assuming byFitness() sorts ascending — TODO confirm.
        GeneticNN best = networkList.get(size - 1);
        GeneticNN secondBest = networkList.get(size - 2);
        // crossover builds a fresh list, so it can be returned without copying.
        return crossover(best, secondBest);
    }

    /**
     * Creates {@code numNetworks} children from two parent networks.
     *
     * <p>Each child is a new {@code GeneticNN} built with the current settings and {@code genNum}.
     * It receives {@code crossOver} of both parents' input, hidden-layer and output weight tables,
     * and then {@code mutate(MUTATION_RATE)}.
     *
     * @param net1 first parent
     * @param net2 second parent
     * @return the children created from the first and second networks
     */
    private ArrayList<GeneticNN> crossover(GeneticNN net1, GeneticNN net2) {
        ArrayList<GeneticNN> networks = new ArrayList<>(numNetworks);
        // weights of the first parent
        double[][] input1 = net1.getInputTable();
        double[][] output1 = net1.getOutputTable();
        double[][][] layers1 = net1.getLayerTable();
        // weights of the second parent
        double[][] input2 = net2.getInputTable();
        double[][] output2 = net2.getOutputTable();
        double[][][] layers2 = net2.getLayerTable();
        // create numNetworks number of new children
        for (int num = 0; num < numNetworks; num++) {
            GeneticNN child = new GeneticNN(numHidden, numLayers, numFeatures, genNum);
            child.crossOver(input1, layers1, output1, input2, layers2, output2);
            child.mutate(MUTATION_RATE);
            networks.add(child);
        }
        return networks;
    }
}