diff --git a/piIW_alphazero.py b/piIW_alphazero.py
index f8c2574..d189c50 100644
--- a/piIW_alphazero.py
+++ b/piIW_alphazero.py
@@ -83,7 +83,7 @@ def alphazero_planning_step(episode_transitions):
 
 
 # pi-IW planning step function with given hyperparameters
-def get_pi_iw_planning_step_fn(actor, planner, policy_fn, tree_budget, discount_factor):
+def get_pi_iw_planning_step_fn(actor, planner, policy_fn, tree_budget, discount_factor, temp):
     def pi_iw_planning_step(episode_tranistions):
         nodes_before_planning = len(actor.tree)
         budget_fn = lambda: len(actor.tree) - nodes_before_planning == tree_budget
@@ -91,7 +91,7 @@ def pi_iw_planning_step(episode_tranistions):
                      successor_fn=actor.generate_successor,
                      stop_condition_fn=budget_fn,
                      policy_fn=policy_fn)
-        return softmax_Q_tree_policy(actor.tree, actor.tree.branching_factor, discount_factor, temp=0)
+        return softmax_Q_tree_policy(actor.tree, actor.tree.branching_factor, discount_factor, temp=temp)
     return pi_iw_planning_step
 
 
@@ -172,7 +172,7 @@ def report(self, episode_rewards, total_interactions):
     discount_factor = 0.99
     puct_factor = 0.5 # AlphaZero
     first_moves_temp = np.inf # AlphaZero
-    policy_temp = 1 # AlphaZero
+    policy_temp = 1 # pi-IW and AlphaZero
     cache_subtree = True
     batch_size = 32
     learning_rate = 0.0007
@@ -245,7 +245,8 @@ def report(self, episode_rewards, total_interactions):
                                               planner=planner,
                                               policy_fn=network_policy,
                                               tree_budget=tree_budget,
-                                              discount_factor=discount_factor)
+                                              discount_factor=discount_factor,
+                                              temp=policy_temp)
     learner = SupervisedPolicy(model, optimizer, regularization_factor=regularization_factor, use_graph=True)
 
     # Initialize experience replay: run complete episodes until we exceed both batch_size and dataset_min_transitions
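
For context on why temp has to be threaded through: softmax_Q_tree_policy turns the root Q-values into the action distribution used by the planning step, and the temperature controls how peaked that distribution is. The sketch below is illustrative only, assuming a hypothetical softmax_q helper rather than the repository's softmax_Q_tree_policy; it shows the behavior the change enables, i.e. temp=0 (the previously hard-coded value) collapses to a greedy one-hot policy, while temp=policy_temp > 0 keeps some probability on suboptimal actions.

import numpy as np

# Illustrative sketch only: temperature-controlled softmax over Q-values.
# softmax_q and its signature are hypothetical, not the repo's implementation.
def softmax_q(q_values, temp):
    q = np.asarray(q_values, dtype=np.float64)
    if temp == 0:
        # Greedy limit: all probability mass on the maximizing action(s)
        best = (q == q.max()).astype(np.float64)
        return best / best.sum()
    z = (q - q.max()) / temp  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum()

print(softmax_q([1.0, 0.5, 0.0], temp=0))  # -> [1. 0. 0.]
print(softmax_q([1.0, 0.5, 0.0], temp=1))  # -> ~[0.51 0.31 0.19]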