-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
38 lines (27 loc) · 814 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import numpy as np
import sys
import tensorflow as tf
from sklearn import neural_network
# local imports
from constants import *
from game import *
from neural_network import NeuralNetwork
from envs.deep_snake_env import DeepSnakeEnv
# Run flags come from `constants`; command-line arguments are not consulted.
should_train = SHOULD_TRAIN
should_plot = SHOULD_PLOT

# Seed NumPy and TensorFlow with the same value so runs are reproducible.
np.random.seed(RANDOM_SEED)
tf.random.set_seed(RANDOM_SEED)

# Build and seed the snake environment, then reset it to its initial state.
env = DeepSnakeEnv()
env.seed(RANDOM_SEED)
env.reset()

# Path the network is constructed from and, when training is enabled,
# saved back to.
model_name = f'trained_models/{TRAIN_MODEL}'

try:
    # Named `agent` rather than `neural_network` so it does not rebind the
    # `neural_network` name imported at the top of this file.
    agent = NeuralNetwork(env, model_name)
    # NOTE(review): train() runs regardless of SHOULD_TRAIN; the flag only
    # gates saving below. Confirm whether training itself should be skipped
    # when SHOULD_TRAIN is false.
    agent.train(episodes=N_EPISODES)
    if should_train:
        agent.save_model(model_name)
    if should_plot:
        agent.plot()
finally:
    # Always release the environment, even if building, training, saving or
    # plotting the model raises.
    env.close()