diff --git a/atari_wrappers.py b/atari_wrappers.py
index bef4b4c..e4648a9 100644
--- a/atari_wrappers.py
+++ b/atari_wrappers.py
@@ -2,17 +2,31 @@ import cv2
 from collections import deque
 
 
-class ResizeImage(gym.ObservationWrapper):
+# Gym wrapper with clone/restore state
+class Wrapper(gym.Wrapper):
+    def clone_state(self):
+        return self.env.clone_state()
+
+    def restore_state(self, state):
+        self.env.restore_state(state)
+
+
+class ResizeImage(Wrapper):
     def __init__(self, env, new_size):
         super(ResizeImage, self).__init__(env)
         self.resize_fn = lambda obs: cv2.resize(obs, dsize=new_size, interpolation=cv2.INTER_LINEAR)
         self.observation_space = gym.spaces.Box(low=0, high=255, shape=new_size)
 
-    def observation(self, observation):
+    def reset(self, **kwargs):
+        observation = self.env.reset(**kwargs)
         return self.resize_fn(observation)
 
+    def step(self, action):
+        observation, reward, done, info = self.env.step(action)
+        return self.resize_fn(observation), reward, done, info
 
-class FrameBuffer(gym.Wrapper):
+
+class FrameBuffer(Wrapper):
     def __init__(self, env, buffer_size):
         assert (buffer_size > 0)
         super(FrameBuffer, self).__init__(env)
@@ -36,6 +50,13 @@ def observation(self):
         # Return a list instead of a numpy array to reduce space in memory when storing the same frame more than once
         return list(self.observations)
 
+    def clone_state(self):
+        return (tuple(self.observations), self.env.clone_state())
+
+    def restore_state(self, state):
+        assert len(state[0]) == len(self.observations)
+        self.observations.extend(state[0])
+        return self.env.restore_state(state[1])
 
 def is_atari_env(env):
     import gym.envs.atari
diff --git a/online_planning.py b/online_planning.py
index 5dcce09..e9df3cc 100644
--- a/online_planning.py
+++ b/online_planning.py
@@ -63,7 +63,8 @@ def softmax_Q_tree_policy(tree, n_actions, discount_factor, temp=0):
 
     episode_done = current_root_data["done"]
     steps_cnt += 1
-    print(actor.tree.root.data["s"]["world"], "Action: ", current_root_data["a"], "Reward: ", current_root_data["r"],
+    print("\n".join([" ".join(row) for row in env.unwrapped.get_char_matrix(actor.tree.root.data["s"])]),
+          "Action: ", current_root_data["a"], "Reward: ", current_root_data["r"],
           "Simulator steps:", actor.nodes_generated, "Planning steps:", steps_cnt, "\n")
 print("It took %i steps but the problem can be solved in 13."
       % steps_cnt)
diff --git a/online_planning_learning.py b/online_planning_learning.py
index f751faa..790d93f 100644
--- a/online_planning_learning.py
+++ b/online_planning_learning.py
@@ -121,7 +121,8 @@ def network_policy(node, branching_factor):
 
     steps_cnt += 1
    episode_steps +=1
-    print(actor.tree.root.data["s"]["world"], "Reward: ", r, "Simulator steps:", actor.nodes_generated,
+    print("\n".join([" ".join(row) for row in env.unwrapped.get_char_matrix(actor.tree.root.data["s"])]),
+          "Reward: ", r, "Simulator steps:", actor.nodes_generated,
           "Planning steps:", steps_cnt, "Loss:", loss.numpy(), "\n")
     if episode_done:
         print("Problem solved in %i steps (min 13 steps)."%episode_steps)
diff --git a/piIW_alphazero.py b/piIW_alphazero.py
index ff0b5e5..d189c50 100644
--- a/piIW_alphazero.py
+++ b/piIW_alphazero.py
@@ -83,7 +83,7 @@ def alphazero_planning_step(episode_transitions):
 
 
 # pi-IW planning step function with given hyperparameters
-def get_pi_iw_planning_step_fn(actor, planner, policy_fn, tree_budget, discount_factor):
+def get_pi_iw_planning_step_fn(actor, planner, policy_fn, tree_budget, discount_factor, temp):
     def pi_iw_planning_step(episode_tranistions):
         nodes_before_planning = len(actor.tree)
         budget_fn = lambda: len(actor.tree) - nodes_before_planning == tree_budget
@@ -91,7 +91,7 @@ def pi_iw_planning_step(episode_tranistions):
                      successor_fn=actor.generate_successor,
                      stop_condition_fn=budget_fn,
                      policy_fn=policy_fn)
-        return softmax_Q_tree_policy(actor.tree, actor.tree.branching_factor, discount_factor, temp=0)
+        return softmax_Q_tree_policy(actor.tree, actor.tree.branching_factor, discount_factor, temp=temp)
     return pi_iw_planning_step
 
 
@@ -134,7 +134,6 @@ def run_episode(plan_step_fn, learner, dataset, cache_subtree, add_returns, prep
     return episode_rewards
 
 
-
 class TrainStats:
     def __init__(self):
         self.last_interactions = 0
@@ -173,7 +172,7 @@ def report(self, episode_rewards, total_interactions):
     discount_factor = 0.99
     puct_factor = 0.5 # AlphaZero
     first_moves_temp = np.inf # AlphaZero
-    policy_temp = 1 # AlphaZero
+    policy_temp = 1 # pi-IW and AlphaZero
     cache_subtree = True
     batch_size = 32
     learning_rate = 0.0007
@@ -186,7 +185,6 @@ def report(self, episode_rewards, total_interactions):
     rmsprop_epsilon = 0.1
     frameskip_atari = 15
 
-
     logger = logging.getLogger(__name__)
 
     parser = argparse.ArgumentParser()
@@ -247,7 +245,8 @@ def report(self, episode_rewards, total_interactions):
                                                   planner=planner,
                                                   policy_fn=network_policy,
                                                   tree_budget=tree_budget,
-                                                  discount_factor=discount_factor)
+                                                  discount_factor=discount_factor,
+                                                  temp=policy_temp)
         learner = SupervisedPolicy(model, optimizer, regularization_factor=regularization_factor, use_graph=True)
 
         # Initialize experience replay: run complete episodes until we exceed both batch_size and dataset_min_transitions
diff --git a/planning_step.py b/planning_step.py
index 8c4c2c5..de895af 100644
--- a/planning_step.py
+++ b/planning_step.py
@@ -3,7 +3,7 @@ def features_to_atoms(feature_vector):
 
 # Define how we will extract features
 def gridenvs_BASIC_features(env, node):
-    node.data["features"] = features_to_atoms(env.unwrapped.state["world"].get_colors().flatten())
+    node.data["features"] = features_to_atoms(env.unwrapped.get_colors().flatten())
 
 if __name__ == "__main__":
     import gym
diff --git a/tree.py b/tree.py
index 704bb77..9b1f007 100644
--- a/tree.py
+++ b/tree.py
@@ -154,7 +154,7 @@ def __init__(self, env, observe_fn=None):
     def generate_successor(self, node, action):
         assert not self._done, "Trying to generate nodes, but either the episode is over or hasn't started yet. Please use reset()."
         if self.last_node is not node:
-            self.env.unwrapped.restore_state(node.data["s"])
+            self.env.restore_state(node.data["s"])
 
         # Perform step
         next_obs, r, end_of_episode, info = self.env.step(action)
@@ -185,7 +185,7 @@ def reset(self):
         return self.tree
 
     def _observe(self, node):
-        node.data["s"] = self.env.unwrapped.clone_state()
+        node.data["s"] = self.env.clone_state()
         self.observe_fn(self.env, node)
         self.last_node = node
 
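
The following is a minimal, hypothetical sketch (not part of the patch) of the clone/restore protocol that the new Wrapper class in atari_wrappers.py forwards and that tree.py now calls through the wrapper chain (env.clone_state() / env.restore_state() instead of env.unwrapped.*). ToyEnv is an illustrative stand-in for an environment whose innermost class exposes those two methods, as the environments used with this repository do; only the Wrapper import comes from the patched files.

# Hypothetical usage sketch: clone/restore forwarded through the wrapper chain.
# ToyEnv is a placeholder environment, not part of the repository.
import gym
from atari_wrappers import Wrapper

class ToyEnv(gym.Env):
    """Counter environment: the whole simulator state is a single integer."""
    def __init__(self):
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t

    def step(self, action):
        self.t += 1
        return self.t, 0.0, False, {}

    def clone_state(self):
        return self.t

    def restore_state(self, state):
        self.t = state

env = Wrapper(Wrapper(ToyEnv()))  # each layer forwards clone/restore to the inner env
env.reset()
snapshot = env.clone_state()      # snapshot taken at the root of a lookahead
env.step(0)                       # expand one successor
env.restore_state(snapshot)       # rewind before expanding a different branch
assert env.clone_state() == snapshot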