
Commit

modify to pydocstyle linting
dhyeythumar committed Oct 1, 2020
1 parent a50981f commit eb67fb8
Showing 3 changed files with 266 additions and 167 deletions.
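
The commit title refers to pydocstyle, the Python docstring style checker. As a point of reference, here is a minimal sketch of how the changed files could be checked (an assumption — the repository's actual lint setup is not shown on this page):

# Hypothetical check via pydocstyle's Python API; the project may instead run
# the pydocstyle command line or a pre-commit hook.
import pydocstyle

for error in pydocstyle.check(["statistics.py", "test.py"]):
    print(error)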
69 changes: 47 additions & 22 deletions statistics.py
@@ -4,26 +4,36 @@

class Memory:
"""
This memory class is used to store the data for Tensorboard summary
and Terminal logs.
This class is used to store the data for Tensorboard summary and Terminal logs.
The length of the data stored equals the SUMMARY_FREQ used while training.
Data of length BUFFER_SIZE is crunched to a single value before being stored in this class.
"""
def __init__(self, RUN_ID):

def __init__(self, RUN_ID):
self.base_tb_dir = "./training_data/summaries/" + RUN_ID
self.writer = SummaryWriter(self.base_tb_dir)

# lists to store data length = SUMMARY_FREQ
self.rewards = []
self.episode_lens = []
self.actor_losses = []
self.critic_losses = []
self.advantages = []
self.actor_lrs = [] # actor learning rate
self.critic_lrs = [] # critic learning rate

def add_data(self, reward, episode_len, actor_loss, critic_loss, advantage, actor_lr, critic_lr):
self.rewards = []
self.episode_lens = []
self.actor_losses = []
self.critic_losses = []
self.advantages = []
self.actor_lrs = [] # actor learning rate
self.critic_lrs = [] # critic learning rate

def add_data(
self,
reward,
episode_len,
actor_loss,
critic_loss,
advantage,
actor_lr,
critic_lr,
):
"""Add data for tensorboard and terminal logging."""
self.rewards.append(reward)
self.episode_lens.append(episode_len)
self.actor_losses.append(actor_loss)
@@ -33,6 +43,7 @@ def add_data(self, reward, episode_len, actor_loss, critic_loss, advantage, acto
self.critic_lrs.append(critic_lr)

def clear_memory(self):
"""Clear the collected data."""
self.rewards.clear()
self.episode_lens.clear()
self.actor_losses.clear()
@@ -42,21 +53,35 @@ def clear_memory(self):
self.critic_lrs.clear()

def terminal_logs(self, step):
if (len(self.rewards) == 0):
"""Display logs on terminal."""
if len(self.rewards) == 0:
self.rewards.append(0)

print("[INFO]\tSteps: {}\tMean Reward: {:0.3f}\tStd of Reward: {:0.3f}".format(step, np.mean(self.rewards), np.std(self.rewards)))
print(
"[INFO]\tSteps: {}\tMean Reward: {:0.3f}\tStd of Reward: {:0.3f}".format(
step, np.mean(self.rewards), np.std(self.rewards)
)
)

def tensorboard_logs(self, step):
self.writer.add_scalar('Environment/Cumulative_reward', np.mean(self.rewards), step)
self.writer.add_scalar('Environment/Episode_length', np.mean(self.episode_lens), step)
"""Store the logs for tensorboard vis."""
self.writer.add_scalar(
"Environment/Cumulative_reward", np.mean(self.rewards), step
)
self.writer.add_scalar(
"Environment/Episode_length", np.mean(self.episode_lens), step
)

self.writer.add_scalar(
"Learning_rate/Actor_model", np.mean(self.actor_lrs), step
)
self.writer.add_scalar(
"Learning_rate/Critic_model", np.mean(self.critic_lrs), step
)

self.writer.add_scalar('Learning_rate/Actor_model', np.mean(self.actor_lrs), step)
self.writer.add_scalar('Learning_rate/Critic_model', np.mean(self.critic_lrs), step)
self.writer.add_scalar("Loss/Policy_loss", np.mean(self.actor_losses), step)
self.writer.add_scalar("Loss/Value_loss", np.mean(self.critic_losses), step)

self.writer.add_scalar('Loss/Policy_loss', np.mean(self.actor_losses), step)
self.writer.add_scalar('Loss/Value_loss', np.mean(self.critic_losses), step)
self.writer.add_scalar("Policy/Value_estimate", np.mean(self.advantages), step)

self.writer.add_scalar('Policy/Value_estimate', np.mean(self.advantages), step)

self.clear_memory()
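
A minimal usage sketch (not part of this commit) of how the refactored Memory class might be driven from a training loop; SUMMARY_FREQ and the per-update values below are hypothetical placeholders for whatever the PPO training code actually produces:

# Hypothetical driver loop; it only exercises the Memory API shown above.
from statistics import Memory  # the statistics.py in this repository

SUMMARY_FREQ = 1000  # assumed summary interval
memory = Memory(RUN_ID="train-1")

for step in range(1, 10001):
    # One crunched data point per policy update (placeholder values).
    memory.add_data(
        reward=0.0,
        episode_len=0,
        actor_loss=0.0,
        critic_loss=0.0,
        advantage=0.0,
        actor_lr=3e-4,
        critic_lr=1e-3,
    )
    if step % SUMMARY_FREQ == 0:
        memory.terminal_logs(step)
        memory.tensorboard_logs(step)  # also clears the collected data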
63 changes: 39 additions & 24 deletions test.py
@@ -1,5 +1,7 @@
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfigurationChannel
from mlagents_envs.side_channel.engine_configuration_channel import (
EngineConfigurationChannel,
)
from mlagents_envs.exception import (
UnityEnvironmentException,
UnityCommunicationException,
@@ -9,52 +11,57 @@
import numpy as np
from typing import Deque, Dict, List, Tuple
from keras.models import load_model
import keras.backend as K
import tensorflow as tf

# import keras.backend as K
# import tensorflow as tf


# Name of the Unity environment binary to be launched
ENV_NAME = "./rl_env_binary/Windows_build/Learning-Agents--r1"
RUN_ID = "train-1"
ENV_NAME = "./rl_env_binary/Windows_build/Learning-Agents--r1"
RUN_ID = "train-1"


class Test_FindflagAgent:

def __init__(self, env: UnityEnvironment):

MODEL_NAME = self.get_model_name()
self.env = env
self.env.reset() # without this env won't work
self.env.reset() # without this env won't work
self.behavior_name = self.env.get_behavior_names()[0]
self.behavior_spec = self.env.get_behavior_spec(self.behavior_name)
self.state_dims = self.behavior_spec.observation_shapes[0][0]
self.n_actions = self.behavior_spec.action_size

self.actor = load_model(MODEL_NAME, custom_objects={'loss': 'categorical_hinge'})
self.actor = load_model(
MODEL_NAME, custom_objects={"loss": "categorical_hinge"}
)

def get_model_name(self) -> str:
"""Get the latest saved actor model name."""
_dir = "./training_data/model/" + RUN_ID
basepath = Path(_dir)
files_in_basepath = (entry for entry in basepath.iterdir() if entry.is_file())

# get the latest actor's saved model file name.
for item in files_in_basepath:
if (item.name.find("actor") != -1):
if item.name.find("actor") != -1:
name = _dir + "/" + item.name

print("-"*100)
print("-" * 100)
print("\t\tUsing {} saved model for testing.".format(name))
print("-"*100)
print("-" * 100)
return name

def check_done(self, step_result) -> bool:
"""Return the done status for env reset."""
if len(step_result[1]) != 0:
return True
else:
return False

def step(self, action: np.ndarray) -> Tuple[np.ndarray, np.float64, bool]:
"""Return the next_state, reward and done response of the env."""
"""
Apply the actions to the env, step the env, and return a new set of experience.
Return the next_state, reward and done response of the env.
"""
self.env.set_actions(self.behavior_name, action)
self.env.step()
step_result = self.env.get_steps(self.behavior_name)
@@ -67,22 +74,25 @@ def step(self, action: np.ndarray) -> Tuple[np.ndarray, np.float64, bool]:
reward = step_result[1].reward[0]
return next_state, reward, done

def get_action(self, state: np.ndarray) -> np.ndarray:
def get_action(self, action_probs: np.ndarray) -> np.ndarray:
"""Get actions from action probabilities."""
n_agents = 1 # only 1 agent is used in the env.

action_probs = self.actor.predict(state, steps=1) # (1, 2)
action = action_probs[0]
action = np.clip(action, -1, 1) # just for confirmation
return np.reshape(action, (1, self.n_actions))
return np.reshape(action, (n_agents, self.n_actions))

def test(self) -> None:
"""Test the trained Actor model."""
self.env.reset()
step_result = self.env.get_steps(self.behavior_name)
state = step_result[0].obs[0]
score = 0

try:
while True:
action = self.get_action(state)
action_probs = self.actor.predict(state, steps=1) # (1, 2)
action = self.get_action(action_probs)
next_state, reward, done = self.step(action)
state = next_state
score += reward
@@ -96,16 +106,21 @@ def test(self) -> None:
UnityEnvironmentException,
UnityCommunicatorStoppedException,
) as ex:
print("-"*100)
print("-" * 100)
print("\t\tException has occurred !!\tTesting was interrupted.")
print("-"*100)
print("-" * 100)
self.env.close()

if __name__ == '__main__':

if __name__ == "__main__":
engine_config_channel = EngineConfigurationChannel()
engine_config_channel.set_configuration_parameters(width=1800, height=900, time_scale=1.0)
engine_config_channel.set_configuration_parameters(
width=1800, height=900, time_scale=1.0
)

env = UnityEnvironment(file_name=ENV_NAME, seed=2, side_channels=[engine_config_channel])
env = UnityEnvironment(
file_name=ENV_NAME, seed=2, side_channels=[engine_config_channel]
)

agent = Test_FindflagAgent(env)
agent.test()
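
For clarity, a small sketch (an assumption, not part of the commit) of the behaviour after the get_action refactor: the actor's prediction now happens inside test(), and get_action only clips the continuous values to [-1, 1] and reshapes them into the (n_agents, n_actions) layout expected by env.set_actions:

import numpy as np

# Hypothetical actor output for a single agent with two continuous actions.
action_probs = np.array([[1.3, -0.2]])    # shape (1, 2)
action = np.clip(action_probs[0], -1, 1)  # -> array([ 1. , -0.2])
action = np.reshape(action, (1, 2))       # (n_agents, n_actions) for env.set_actions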

