Commit ac897ff
Bugfix: take the clone and restore functions into account in the frame buffer wrapper. Updated for gridenvs changes.
mjunyentb committed Jul 29, 2020
1 parent 6d95e00 commit ac897ff
Showing 6 changed files with 36 additions and 14 deletions.
27 changes: 24 additions & 3 deletions atari_wrappers.py
@@ -2,17 +2,31 @@
import cv2
from collections import deque

-class ResizeImage(gym.ObservationWrapper):
+# Gym wrapper with clone/restore state
+class Wrapper(gym.Wrapper):
+    def clone_state(self):
+        return self.env.clone_state()

+    def restore_state(self, state):
+        self.env.restore_state(state)


+class ResizeImage(Wrapper):
    def __init__(self, env, new_size):
        super(ResizeImage, self).__init__(env)
        self.resize_fn = lambda obs: cv2.resize(obs, dsize=new_size, interpolation=cv2.INTER_LINEAR)
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=new_size)

-    def observation(self, observation):
+    def reset(self, **kwargs):
+        observation = self.env.reset(**kwargs)
        return self.resize_fn(observation)

+    def step(self, action):
+        observation, reward, done, info = self.env.step(action)
+        return self.resize_fn(observation), reward, done, info


-class FrameBuffer(gym.Wrapper):
+class FrameBuffer(Wrapper):
    def __init__(self, env, buffer_size):
        assert (buffer_size > 0)
        super(FrameBuffer, self).__init__(env)
@@ -36,6 +50,13 @@ def observation(self):
        # Return a list instead of a numpy array to reduce space in memory when storing the same frame more than once
        return list(self.observations)

+    def clone_state(self):
+        return (tuple(self.observations), self.env.clone_state())

+    def restore_state(self, state):
+        assert len(state[0]) == len(self.observations)
+        self.observations.extend(state[0])
+        return self.env.restore_state(state[1])

def is_atari_env(env):
    import gym.envs.atari
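
For context, a minimal usage sketch of the new clone/restore path through the wrappers (not part of this diff). It assumes the innermost environment exposes clone_state()/restore_state() as gridenvs environments do, that FrameBuffer.reset fills its frame deque, and that the env id and frame size below are placeholders:

import gym
import gridenvs.examples   # assumed to register the GE_* environments
from atari_wrappers import ResizeImage, FrameBuffer

env = FrameBuffer(ResizeImage(gym.make("GE_MazeKeyDoor-v0"), new_size=(84, 84)),
                  buffer_size=2)
env.reset()

snapshot = env.clone_state()                     # (buffered frames, inner env state)
obs, reward, done, info = env.step(env.action_space.sample())
env.restore_state(snapshot)                      # rewinds the frame deque and the inner env
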
3 changes: 2 additions & 1 deletion online_planning.py
@@ -63,7 +63,8 @@ def softmax_Q_tree_policy(tree, n_actions, discount_factor, temp=0):

episode_done = current_root_data["done"]
steps_cnt += 1
-print(actor.tree.root.data["s"]["world"], "Action: ", current_root_data["a"], "Reward: ", current_root_data["r"],
+print("\n".join([" ".join(row) for row in env.unwrapped.get_char_matrix(actor.tree.root.data["s"])]),
+"Action: ", current_root_data["a"], "Reward: ", current_root_data["r"],
"Simulator steps:", actor.nodes_generated, "Planning steps:", steps_cnt, "\n")

print("It took %i steps but the problem can be solved in 13." % steps_cnt)
3 changes: 2 additions & 1 deletion online_planning_learning.py
@@ -121,7 +121,8 @@ def network_policy(node, branching_factor):

steps_cnt += 1
episode_steps +=1
-print(actor.tree.root.data["s"]["world"], "Reward: ", r, "Simulator steps:", actor.nodes_generated,
+print("\n".join([" ".join(row) for row in env.unwrapped.get_char_matrix(actor.tree.root.data["s"])]),
+"Reward: ", r, "Simulator steps:", actor.nodes_generated,
"Planning steps:", steps_cnt, "Loss:", loss.numpy(), "\n")
if episode_done:
print("Problem solved in %i steps (min 13 steps)."%episode_steps)
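
Both scripts now render the board through env.unwrapped.get_char_matrix(...) instead of indexing the old state["world"]. A small helper capturing the same join pattern, assuming get_char_matrix(state) returns a 2D grid of one-character strings (which the updated print calls imply):

def render_char_matrix(env, state):
    # Join the character grid row by row, exactly as the updated print statements do.
    rows = env.unwrapped.get_char_matrix(state)
    return "\n".join(" ".join(row) for row in rows)

# Usage, mirroring the planning loops above:
# print(render_char_matrix(env, actor.tree.root.data["s"]), "Reward: ", r)
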
11 changes: 5 additions & 6 deletions piIW_alphazero.py
@@ -83,15 +83,15 @@ def alphazero_planning_step(episode_transitions):


# pi-IW planning step function with given hyperparameters
-def get_pi_iw_planning_step_fn(actor, planner, policy_fn, tree_budget, discount_factor):
+def get_pi_iw_planning_step_fn(actor, planner, policy_fn, tree_budget, discount_factor, temp):
    def pi_iw_planning_step(episode_transitions):
        nodes_before_planning = len(actor.tree)
        budget_fn = lambda: len(actor.tree) - nodes_before_planning == tree_budget
        planner.plan(tree=actor.tree,
                     successor_fn=actor.generate_successor,
                     stop_condition_fn=budget_fn,
                     policy_fn=policy_fn)
-        return softmax_Q_tree_policy(actor.tree, actor.tree.branching_factor, discount_factor, temp=0)
+        return softmax_Q_tree_policy(actor.tree, actor.tree.branching_factor, discount_factor, temp=temp)
    return pi_iw_planning_step


@@ -134,7 +134,6 @@ def run_episode(plan_step_fn, learner, dataset, cache_subtree, add_returns, prep

    return episode_rewards


class TrainStats:
    def __init__(self):
        self.last_interactions = 0
@@ -173,7 +172,7 @@ def report(self, episode_rewards, total_interactions):
discount_factor = 0.99
puct_factor = 0.5 # AlphaZero
first_moves_temp = np.inf # AlphaZero
-policy_temp = 1 # AlphaZero
+policy_temp = 1 # pi-IW and AlphaZero
cache_subtree = True
batch_size = 32
learning_rate = 0.0007
@@ -186,7 +185,6 @@ def report(self, episode_rewards, total_interactions):
rmsprop_epsilon = 0.1
frameskip_atari = 15


logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser()
@@ -247,7 +245,8 @@ def report(self, episode_rewards, total_interactions):
planner=planner,
policy_fn=network_policy,
tree_budget=tree_budget,
-discount_factor=discount_factor)
+discount_factor=discount_factor,
+temp=policy_temp)
learner = SupervisedPolicy(model, optimizer, regularization_factor=regularization_factor, use_graph=True)

# Initialize experience replay: run complete episodes until we exceed both batch_size and dataset_min_transitions
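
get_pi_iw_planning_step_fn now threads the policy temperature down to softmax_Q_tree_policy instead of hard-coding temp=0. The repository's softmax_Q_tree_policy builds its Q-values from the tree; the sketch below only illustrates the effect of the temp argument on a given Q-vector, and treating temp=0 as the greedy limit is an assumption about how that case is handled:

import numpy as np

def softmax_with_temperature(q_values, temp):
    q = np.asarray(q_values, dtype=np.float64)
    if temp == 0:
        greedy = (q == q.max()).astype(np.float64)   # greedy limit: one-hot, ties split evenly
        return greedy / greedy.sum()
    z = (q - q.max()) / temp                         # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

print(softmax_with_temperature([1.0, 2.0, 0.5], temp=1))   # soft target policy
print(softmax_with_temperature([1.0, 2.0, 0.5], temp=0))   # argmax, as with the old temp=0
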
2 changes: 1 addition & 1 deletion planning_step.py
@@ -3,7 +3,7 @@ def features_to_atoms(feature_vector):

# Define how we will extract features
def gridenvs_BASIC_features(env, node):
node.data["features"] = features_to_atoms(env.unwrapped.state["world"].get_colors().flatten())
node.data["features"] = features_to_atoms(env.unwrapped.get_colors().flatten())

if __name__ == "__main__":
    import gym
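
The BASIC feature extractor now reads the color grid straight from the unwrapped gridenvs environment rather than through state["world"]. A self-contained sketch with the assumptions spelled out: get_colors() is taken to return a 2D array of per-cell color ids, and the body of features_to_atoms (outside this hunk) is assumed to enumerate (index, value) atoms in the usual IW style:

import numpy as np

def features_to_atoms(feature_vector):
    # Assumption: IW-style atoms are (position, value) pairs over the flattened grid.
    return [(i, int(v)) for i, v in enumerate(feature_vector)]

def gridenvs_BASIC_features(env, node):
    colors = np.asarray(env.unwrapped.get_colors())   # assumed (H, W) array of color ids
    node.data["features"] = features_to_atoms(colors.flatten())
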
4 changes: 2 additions & 2 deletions tree.py
@@ -154,7 +154,7 @@ def __init__(self, env, observe_fn=None):
    def generate_successor(self, node, action):
        assert not self._done, "Trying to generate nodes, but either the episode is over or hasn't started yet. Please use reset()."
        if self.last_node is not node:
-            self.env.unwrapped.restore_state(node.data["s"])
+            self.env.restore_state(node.data["s"])

        # Perform step
        next_obs, r, end_of_episode, info = self.env.step(action)
@@ -185,7 +185,7 @@ def reset(self):
        return self.tree

    def _observe(self, node):
-        node.data["s"] = self.env.unwrapped.clone_state()
+        node.data["s"] = self.env.clone_state()
        self.observe_fn(self.env, node)
        self.last_node = node

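
Because node.data["s"] is now a snapshot taken through the wrapper chain rather than from env.unwrapped, restoring it also rewinds wrapper-held state such as FrameBuffer's frame deque. A standalone sketch of the save/restore pattern the actor relies on (the node layout is simplified; observe_fn and tree bookkeeping are elided):

def expand(env, node, action):
    # Rewind the wrapped env (emulator state plus wrapper state) to this node...
    env.restore_state(node.data["s"])
    obs, reward, done, info = env.step(action)
    # ...and snapshot the child through the same wrapper chain.
    return {"a": action, "r": reward, "done": done, "obs": obs, "s": env.clone_state()}
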
