Skip to content

Commit

Permalink
tournament
Browse files Browse the repository at this point in the history
  • Loading branch information
ymahlau committed Jan 9, 2024
1 parent 52ed1de commit 5f0eaff
Show file tree
Hide file tree
Showing 12 changed files with 384 additions and 27 deletions.
169 changes: 169 additions & 0 deletions scripts/depth/eval_tournament.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
from datetime import datetime
import itertools
import math
from pathlib import Path
import pickle
import random
import numpy as np
from src.agent import Agent
from src.agent.albatross import AlbatrossAgent, AlbatrossAgentConfig
from src.agent.initialization import get_agent_from_config
from src.agent.one_shot import NetworkAgentConfig
from src.agent.search_agent import AreaControlSearchAgentConfig
from src.game.actions import sample_individual_actions

from src.game.battlesnake.bootcamp.test_envs_7x7 import survive_on_7x7, survive_on_7x7_4_player, survive_on_7x7_constrictor, survive_on_7x7_constrictor_4_player
from src.game.game import Game
from src.game.initialization import get_game_from_config
from src.game.utils import step_with_draw_prevention
from src.misc.utils import set_seed
from src.network.initialization import get_network_from_file


def play_single_game(
    game: Game,
    agent_list: list[Agent],
    iterations: list[int],
    temperatures: list[float],
    prevent_draw: bool,
    verbose_level: int = 0,
):
    """Play one episode of ``game`` with the given agents and report the outcome.

    The game and every agent are reset before the episode starts and again
    after the final rewards have been read.

    Args:
        game: Environment to play on. It is reset and stepped in place.
        agent_list: Agents indexed by player id; ``agent_list[p]`` acts for player ``p``.
        iterations: Per-player value forwarded to the agent as ``iterations``.
        temperatures: Per-player temperature forwarded to ``sample_individual_actions``.
        prevent_draw: If true, step through ``step_with_draw_prevention`` instead
            of ``game.step``.
        verbose_level: ``>= 1`` prints the final cumulative rewards; ``>= 2``
            additionally prints every policy and joint action and renders the game.

    Returns:
        Tuple ``(cum_rewards, step_counter)``: the value of
        ``game.get_cum_rewards()`` at the end of the episode and the number of
        joint steps that were executed.
    """

    def _reset_everything():
        # Same order as a fresh start: environment first, then each agent.
        game.reset()
        for participant in agent_list:
            participant.reset_episode()

    _reset_everything()
    num_steps = 0
    while not game.is_terminal():
        chosen_actions: list[int] = []
        for player in game.players_at_turn():
            policy, _ = agent_list[player](game, player=player, iterations=iterations[player])
            # Remove illegal moves and renormalise (in place, as before).
            policy[game.illegal_actions(player)] = 0
            policy /= policy.sum()
            if verbose_level >= 2:
                print(policy, flush=True)
            sampled = sample_individual_actions(policy[np.newaxis, ...], temperatures[player])
            chosen_actions.append(sampled[0])

        joint_action = tuple(chosen_actions)
        if prevent_draw:
            step_with_draw_prevention(game, joint_action)
        else:
            game.step(joint_action)

        if verbose_level >= 2:
            print(chosen_actions, flush=True)
            game.render()
            print('#########################', flush=True)
        num_steps += 1

    # Read the rewards of all players before the game state is wiped below.
    final_rewards = game.get_cum_rewards()
    if verbose_level >= 1:
        print(f"{datetime.now()}: {final_rewards}", flush=True)
    _reset_everything()
    return final_rewards, num_steps



def play_tournament(experiment_id: int):
    """Play one part of the mixed-agent tournament and pickle the results.

    ``experiment_id`` indexes the Cartesian product ``seeds x parts``
    (seed-major), i.e. ``experiment_id = seed * num_parts + part``. For the
    selected seed the response/proxy/AlphaZero networks are loaded, and
    ``num_games_per_part`` games are played. In every game the participants
    are drawn without replacement from the pool of
    ``len(depths)`` area-control search budgets plus Albatross plus AlphaZero.

    After every game the full result list so far is written to
    ``<save_path>/<base_name>_<prefix>_<seed>_<part>.pkl``; each entry is a
    tuple ``(sampled_indices, cum_rewards, episode_length)``.

    Args:
        experiment_id: Index into the ``seeds x parts`` product.

    Raises:
        IndexError: If ``experiment_id`` is outside the product.
    """
    num_seeds = 5
    num_parts = 20

    # Pool index -> iteration budget passed to the agent with that index.
    depths = np.asarray(list(range(200, 2001, 200)), dtype=int)
    depth_dict = {x: d for x, d in enumerate(depths)}
    depth_dict[len(depths)] = 1  # albatross
    depth_dict[len(depths) + 1] = 1  # alphaZero

    save_path = Path(__file__).parent.parent.parent / 'a_data' / 'bs_depth'
    base_name = 'trnmt_small'
    prefix = '4nd7'
    num_games_per_part = 100

    game_dict = {
        '4nd7': survive_on_7x7_4_player(),
        'd7': survive_on_7x7_constrictor(),
        'nd7': survive_on_7x7(),
        '4d7': survive_on_7x7_constrictor_4_player(),
    }

    pref_lists = [
        list(range(num_seeds)),
        list(range(int(num_parts))),
    ]
    prod = list(itertools.product(*pref_lists))
    seed, cur_part = prod[experiment_id]
    num_agents = 4 if prefix.startswith("4") else 2

    # NOTE(review): this formula yields 0 for every seed when cur_part == 0 and
    # collides for other (seed, part) pairs as well. It is kept unchanged so
    # that already-recorded runs stay reproducible.
    set_seed((seed + 1) * cur_part)
    game_cfg = game_dict[prefix]

    net_path = Path(__file__).parent.parent.parent.parent.parent / 'a_saved_runs' / 'battlesnake'
    resp_path = net_path / f'{prefix}_resp_{seed}' / 'latest.pt'
    proxy_path = net_path / f'{prefix}_proxy_{seed}' / 'latest.pt'
    az_path = net_path / f'{prefix}_{seed}' / 'latest.pt'

    # Albatross agent built from the response and proxy networks of this seed.
    net = get_network_from_file(resp_path).eval()
    alb_network_agent_cfg = NetworkAgentConfig(
        net_cfg=net.cfg,
        temperature_input=True,
        single_temperature=False,
        init_temperatures=[0] * num_agents,
    )
    alb_online_agent_cfg = AlbatrossAgentConfig(
        num_player=num_agents,
        agent_cfg=alb_network_agent_cfg,
        device_str='cpu',
        response_net_path=str(resp_path),
        proxy_net_path=str(proxy_path),
        noise_std=None,
        num_samples=1,
        init_temp=5,
    )
    alb_online_agent = AlbatrossAgent(alb_online_agent_cfg)

    # AlphaZero agent: plain network agent with the loaded weights swapped in.
    net = get_network_from_file(az_path).eval()
    az_cfg = NetworkAgentConfig(net_cfg=net.cfg)
    az_agent = get_agent_from_config(az_cfg)
    az_agent.replace_net(net)

    base_agent_cfg = AreaControlSearchAgentConfig()
    base_agent = get_agent_from_config(base_agent_cfg)

    # All search-depth entries share one agent instance; the budget that
    # distinguishes them comes from depth_dict and is passed on every call.
    agent_dict = {idx: base_agent for idx in range(len(depths))}
    agent_dict[len(depths)] = alb_online_agent  # albatross
    agent_dict[len(depths) + 1] = az_agent  # alphaZero

    game_cfg.ec.temperature_input = False
    game_cfg.ec.single_temperature_input = False
    game = get_game_from_config(game_cfg)

    # Make sure the first checkpoint cannot fail on a missing directory after
    # a whole game has already been played.
    save_path.mkdir(parents=True, exist_ok=True)
    full_save_path = save_path / f'{base_name}_{prefix}_{seed}_{cur_part}.pkl'

    result_list = []
    for game_idx in range(num_games_per_part):
        # Sample one agent per player, without replacement. The count follows
        # num_agents so that 2-player prefixes get 2 participants.
        sampled_indices = random.sample(range(len(depth_dict)), num_agents)
        cur_agent_list = [agent_dict[idx] for idx in sampled_indices]
        cur_iterations = [depth_dict[idx] for idx in sampled_indices]

        cur_result, cur_length = play_single_game(
            game=game,
            agent_list=cur_agent_list,
            iterations=cur_iterations,
            temperatures=[math.inf for _ in range(num_agents)],
            prevent_draw=False,
            verbose_level=0,
        )
        result_list.append((sampled_indices, cur_result, cur_length))

        # Checkpoint after every game so partial results survive an abort.
        with open(full_save_path, 'wb') as f:
            pickle.dump(result_list, f)
        print(f"{datetime.now()} - {game_idx}: {sampled_indices} - {cur_result}, {cur_length}")





# Script entry point: run the first (seed 0, part 0) tournament experiment.
if __name__ == "__main__":
    play_tournament(experiment_id=0)
21 changes: 13 additions & 8 deletions scripts/depth/evaluate_bs_depth.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@


def evaluate_bs_depth_func(experiment_id: int):
num_games_per_part = 100
num_parts = 10
search_iterations = np.arange(50, 2001, 50)
# search_iterations = np.asarray([500])
save_path = Path(__file__).parent.parent.parent / 'a_data' / 'bs_depth'
base_name = 'bs_alb_fixed'
eval_az = False
# save_path = Path(__file__).parent.parent.parent / 'a_data' / 'temp'
base_name = 'nodraw_bs'
eval_az = True

game_dict = {
'4nd7': survive_on_7x7_4_player(),
Expand All @@ -41,6 +42,10 @@ def evaluate_bs_depth_func(experiment_id: int):
prod = list(itertools.product(*pref_lists))
prefix, seed, cur_game_id = prod[experiment_id]
assert isinstance(prefix, str)
num_games_per_part = 100
if 'n' in prefix:
num_games_per_part = 50

# we do not want to set the same seed in every game and repeat the same play.
# Therefore, set a different seed for every game and base seed
set_seed((seed + 1) * cur_game_id)
Expand Down Expand Up @@ -107,8 +112,8 @@ def evaluate_bs_depth_func(experiment_id: int):
# cur_log_path = save_path / f'{base_name}_log_{prefix}_{seed}_{cur_game_id}_{cur_iterations}.pkl'
# alb_online_agent.cfg.estimate_log_path = str(cur_log_path)

cur_temp = mean_temps[cur_iterations.item()]
alb_online_agent.cfg.fixed_temperatures = [cur_temp for _ in range(game_cfg.num_players)]
# cur_temp = mean_temps[cur_iterations.item()]
# alb_online_agent.cfg.fixed_temperatures = [cur_temp for _ in range(game_cfg.num_players)]

print(f'Started evaluation with: {iteration_idx=}, {cur_iterations=}')

Expand All @@ -120,7 +125,7 @@ def evaluate_bs_depth_func(experiment_id: int):
enemy_iterations=cur_iterations,
temperature_list=[math.inf],
own_temperature=math.inf,
prevent_draw=True,
prevent_draw=False,
switch_positions=False,
verbose_level=1,
own_iterations=1,
Expand All @@ -137,7 +142,7 @@ def evaluate_bs_depth_func(experiment_id: int):
enemy_iterations=cur_iterations,
temperature_list=[math.inf],
own_temperature=math.inf,
prevent_draw=True,
prevent_draw=False,
switch_positions=False,
verbose_level=1,
own_iterations=1,
Expand All @@ -159,4 +164,4 @@ def evaluate_bs_depth_func(experiment_id: int):


if __name__ == '__main__':
evaluate_bs_depth_func(175)
evaluate_bs_depth_func(0)
File renamed without changes.
File renamed without changes.
File renamed without changes.
39 changes: 26 additions & 13 deletions scripts/plotting/plot_depth_bs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from matplotlib import pyplot as plt

import numpy as np
import scipy
import seaborn
from src.misc.const import COLORS, LIGHT_COLORS, LINESTYLES

Expand All @@ -17,23 +18,25 @@ def plot_bs_depth():
num_seeds = 5
depths = np.asarray(list(range(50, 2001, 50)), dtype=int)

# game_abbrevs = ['4nd7', 'd7', 'nd7', '4d7']
game_abbrevs = ['4d7']
# prefix -> (alb, az)
base_names = {
# 'd7': ('bs_az_alb_area_50_to_2000_inf_100games', 'bs_az_alb_area_50_to_2000_inf_100games'),
# '4d7': ('bs_az_alb_area_50_to_2000_inf_100games', 'bs_az_alb_area_50_to_2000_inf_100games'),
# '4nd7': ('bs_az_alb_area_50_to_2000_inf_100games', 'bs_az_alb_area_50_to_2000_inf_100games'),
'nd7': ('bs_az_alb_area_50_to_2000_100games_retrained', 'bs_az_alb_area_50_to_2000_inf_100games')
}

alb_base_name = 'bs_az_alb_area_50_to_2000_inf_100games'
az_base_name = 'bs_az_alb_area_50_to_2000_inf_100games'

for abbrev in game_abbrevs:
for abbrev, (alb_base, az_base) in base_names.items():
full_list_alb, full_list_az, length_list_alb, length_list_az = [], [], [], []
for seed in range(num_seeds):
for part in range(num_parts):
file_name_alb = f'{alb_base_name}_{abbrev}_{seed}_{part}.pkl'
file_name_alb = f'{alb_base}_{abbrev}_{seed}_{part}.pkl'
with open(data_path / file_name_alb, 'rb') as f:
cur_dict = pickle.load(f)
full_list_alb.append(cur_dict['results_alb'])
length_list_alb.append(cur_dict['lengths_alb'])

file_name_az = f'{az_base_name}_{abbrev}_{seed}_{part}.pkl'
file_name_az = f'{az_base}_{abbrev}_{seed}_{part}.pkl'
with open(data_path / file_name_az, 'rb') as f:
cur_dict = pickle.load(f)
full_list_az.append(cur_dict['results_az'])
Expand All @@ -43,10 +46,19 @@ def plot_bs_depth():
length_arr_alb = np.concatenate(length_list_alb, axis=2)[:, 0, :]
length_arr_az = np.concatenate(length_list_az, axis=2)[:, 0, :]

# discount = 0.99
# full_arr_alb = np.power(discount, length_arr_alb) * full_arr_alb
# full_arr_az = np.power(discount, length_arr_az) * full_arr_az

full_arr_alb = full_arr_alb.reshape(len(depths), num_seeds, -1).mean(axis=-1)
full_arr_az = full_arr_az.reshape(len(depths), num_seeds, -1).mean(axis=-1)
length_arr_alb = length_arr_alb.reshape(len(depths), num_seeds, -1).mean(axis=-1)
length_arr_az = length_arr_az.reshape(len(depths), num_seeds, -1).mean(axis=-1)

if abbrev == '4d7':
full_arr_alb = scipy.signal.savgol_filter(full_arr_alb, window_length=5, polyorder=1, axis=0)
full_arr_az = scipy.signal.savgol_filter(full_arr_az, window_length=5, polyorder=1, axis=0)

# length_arr_alb = length_arr_alb.reshape(len(depths), num_seeds, -1).mean(axis=-1)
# length_arr_az = length_arr_az.reshape(len(depths), num_seeds, -1).mean(axis=-1)

plt.clf()
plt.figure(dpi=600)
Expand Down Expand Up @@ -74,16 +86,17 @@ def plot_bs_depth():
label='Albatross',
)

fontsize = 'medium'
fontsize = 'large'
plt.xlabel('Enemy Search Iterations', fontsize=fontsize)
plt.ylabel('Reward', fontsize=fontsize)
plt.xlim(depths[0], depths[-1])
plt.xticks(fontsize=fontsize)
plt.yticks(fontsize=fontsize)
plt.legend(fontsize=fontsize)
if abbrev == 'd7':
plt.legend(fontsize='x-large')
plt.tight_layout()
# plt.savefig(img_path / f'inf_100g_{abbrev}_depths.png')
plt.savefig(img_path / f'inf_100g_{abbrev}_depths.pdf')
plt.savefig(img_path / f'bs_depth_{abbrev}.pdf', bbox_inches='tight', pad_inches=0.0)


if __name__ == '__main__':
Expand Down
34 changes: 32 additions & 2 deletions scripts/temp.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,37 @@ def main():
# print(game.get_cum_rewards())


if __name__ == '__main__':
main()
def main2():
    """Load all '500_nodraw' result parts for 'nd7' and stack them into arrays.

    Reads one pickle per (seed, part) pair from ``a_data/temp``, collects the
    ``results_*`` and ``lengths_*`` entries for Albatross and AlphaZero, and
    concatenates each collection along axis 2, keeping index 0 of axis 1.
    Nothing is returned; the arrays only exist as locals for inspection.
    """
    save_path = Path(__file__).parent.parent / 'a_data' / 'temp'
    base_name = '500_nodraw'
    prefix = "nd7"

    num_seeds = 5
    num_parts = 10

    def _stack(chunks):
        # Join all parts along the last axis, then keep entry 0 of axis 1.
        return np.concatenate(chunks, axis=2)[:, 0, :]

    full_list_alb = []
    full_list_az = []
    length_list_alb = []
    length_list_az = []
    for seed in range(num_seeds):
        for part in range(num_parts):
            cur_path = save_path / f'{base_name}_{prefix}_{seed}_{part}.pkl'
            with open(cur_path, 'rb') as f:
                cur_dict = pickle.load(f)
            full_list_alb.append(cur_dict['results_alb'])
            length_list_alb.append(cur_dict['lengths_alb'])
            full_list_az.append(cur_dict['results_az'])
            length_list_az.append(cur_dict['lengths_az'])

    full_arr_alb = _stack(full_list_alb)
    full_arr_az = _stack(full_list_az)
    length_arr_alb = _stack(length_list_alb)
    length_arr_az = _stack(length_list_az)

    # No-op statement; presumably a convenient line for a debugger breakpoint.
    a = 1


if __name__ == '__main__':
main2()

Loading

0 comments on commit 5f0eaff

Please sign in to comment.