amcts.py
from MCTS.sim import *
from MCTS.game import *
from MCTS.mcts import *
from RL.SupervisedPolicy import *
from RL.SupervisedValueNetwork import *
from tester import test_policy_vs_MCTS
import numpy as np

class AMCTSPlayer(ComputerPlayer):
    """
    A player that uses an MCTS algorithm enhanced by a value network and a
    policy network to explore the search tree more efficiently.
    """
    def __init__(self, name, time_limit, policy_agent=None, value_agent=None,
                 rollout_randomness=0.3, v_network_weight=0.5):
        self.policy_agent = policy_agent
        self.value_agent = value_agent
        # Fraction of rollout steps chosen uniformly at random rather than by
        # the policy network.
        self.rollout_randomness = rollout_randomness
        # Scales the contribution of the value network's prediction.
        self.v_network_weight = v_network_weight
        ComputerPlayer.__init__(self, name, self.amcts_algo(), time_limit)
    def amcts_algo(self):
        def uct_heuristic(board):
            # Without a value network, contribute nothing to the UCT score.
            if self.value_agent is None:
                return 0
            board_image = board.visualize_image(makeCurrPlayerRed=True)
            return self.value_agent.predict_value(board_image)

        def default_heuristic(board):
            # Materialize the legal actions so they can be indexed below.
            actions = list(board.get_legal_actions())
            # With probability rollout_randomness (or when no policy network
            # is available), pick a uniformly random move.
            if self.policy_agent is None or np.random.rand() < self.rollout_randomness:
                return np.random.choice(actions)
            # Otherwise pick the legal move the policy network rates highest.
            board_image = board.visualize_image(makeCurrPlayerRed=True)
            column_prob_dist = self.policy_agent.predict_action(board_image)
            legal_column_prob_dist = [column_prob_dist[action.col] for action in actions]
            return actions[np.argmax(legal_column_prob_dist)]

        return lambda board, time_limit: uct_with_heuristics(
            board, time_limit, uct_heuristic, default_heuristic, self.v_network_weight)
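
    # Note: uct_with_heuristics is imported from MCTS.mcts and its internals
    # are not shown in this file. As an illustration only (hypothetical names,
    # not this repo's API), a UCT selection rule blended with the value
    # heuristic might score a child node as:
    #
    #     exploration = np.sqrt(2 * np.log(parent_visits) / child_visits)
    #     score = wins / child_visits + exploration \
    #             + self.v_network_weight * uct_heuristic(child_board)
    #
    # i.e. the value network shifts selection toward positions it rates well,
    # with v_network_weight controlling how much it outweighs raw visit
    # statistics.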


def main():
    # Load the trained policy and value networks; both take 144x144 RGB board
    # images, and the policy head outputs a distribution over 7 columns.
    policy_agent = SupervisedPolicyAgent((144, 144, 3), 7)
    policy_agent.load_train_results()
    value_agent = SupervisedValueNetworkAgent((144, 144, 3))
    value_agent.load_train_results()
    # Play the network-guided player against a plain-MCTS opponent.
    amcts_player = AMCTSPlayer('amcts', 2, policy_agent, value_agent)
    test_policy_vs_MCTS(amcts_player, verbose=True)


if __name__ == '__main__':
    main()
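
# Usage sketch: AMCTSPlayer degrades gracefully when no trained networks are
# available. With policy_agent and value_agent left at their None defaults,
# uct_heuristic contributes 0 and rollouts are uniformly random, so the player
# reduces to plain UCT. This just illustrates the constructor defaults above:
#
#     baseline = AMCTSPlayer('uct-baseline', time_limit=2)
#     test_policy_vs_MCTS(baseline, verbose=True)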