Skip to content

Commit

Permalink
scripted baseline agents
Browse files Browse the repository at this point in the history
  • Loading branch information
ymahlau committed Mar 26, 2024
1 parent fe5a089 commit 993f1d1
Show file tree
Hide file tree
Showing 11 changed files with 947 additions and 9 deletions.
113 changes: 113 additions & 0 deletions scripts/eval_oc/scripted_overcooked.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from datetime import datetime
import itertools
import math
from pathlib import Path
import numpy as np
import pickle
from src.agent.albatross import AlbatrossAgent, AlbatrossAgentConfig
from src.agent.initialization import get_agent_from_config
from src.agent.one_shot import NetworkAgent, NetworkAgentConfig, bc_agent_from_file
from src.agent.scripted import PlaceDishEverywhereAgentConfig, PlaceOnionAgentConfig, PlaceOnionDeliverAgentConfig, PlaceOnionEverywhereAgentConfig
from src.game.overcooked.config import AsymmetricAdvantageOvercookedConfig, CoordinationRingOvercookedConfig, \
CounterCircuitOvercookedConfig, CrampedRoomOvercookedConfig, ForcedCoordinationOvercookedConfig, OvercookedRewardConfig
from src.game.overcooked.overcooked import OvercookedGame
from src.misc.utils import set_seed
from src.network.initialization import get_network_from_file
from src.trainer.az_evaluator import do_evaluation


def eval_scripted_oc(seed: int) -> None:
    """Evaluate a scripted Overcooked agent against an Albatross agent and pickle the results.

    The layout and the scripted agent are selected by (un)commenting the
    corresponding lines below. The response and proxy networks are loaded from
    ``a_saved_runs/overcooked/{resp,proxy}_<prefix>_<seed>/latest.pt`` relative
    to the repository root (three levels above this file). The first value
    returned by ``do_evaluation`` is written with ``pickle`` to
    ``a_data/scripted/<fname>``.

    Args:
        seed: Passed to ``set_seed`` and used to pick the saved network runs.
    """
    print(f'{datetime.now()} - Started eval script', flush=True)
    save_path = Path(__file__).parent.parent.parent / 'a_data' / 'scripted'
    # Create the output directory up front: otherwise a missing directory only
    # surfaces when writing the pickle, after the whole evaluation has run.
    save_path.mkdir(parents=True, exist_ok=True)

    game_cfg, prefix = AsymmetricAdvantageOvercookedConfig(), 'aa'
    # game_cfg, prefix = CoordinationRingOvercookedConfig(), 'co'
    # game_cfg, prefix = CounterCircuitOvercookedConfig(), 'cc'

    # init scripted agent
    scripted_cfg = PlaceOnionAgentConfig()
    # scripted_cfg = PlaceOnionDeliverAgentConfig()
    # scripted_cfg = PlaceOnionEverywhereAgentConfig()
    # scripted_cfg = PlaceDishEverywhereAgentConfig()

    scripted_agent = get_agent_from_config(scripted_cfg)

    # fname = f'{prefix}_dish_everywhere_{seed}.pkl'
    fname = 'tmp.pkl'
    print(f"{fname=}", flush=True)

    set_seed(seed)

    net_path = Path(__file__).parent.parent.parent / 'a_saved_runs' / 'overcooked'
    proxy_path = net_path / f'proxy_{prefix}_{seed}' / 'latest.pt'
    resp_path = net_path / f'resp_{prefix}_{seed}' / 'latest.pt'

    # The response network is loaded here only to obtain its config for the
    # inner NetworkAgentConfig; AlbatrossAgent receives the file paths itself.
    net = get_network_from_file(resp_path).eval()
    alb_network_agent_cfg = NetworkAgentConfig(
        net_cfg=net.cfg,
        temperature_input=True,
        single_temperature=False,
        init_temperatures=[0, 0],
    )
    alb_online_agent_cfg = AlbatrossAgentConfig(
        num_player=2,
        agent_cfg=alb_network_agent_cfg,
        device_str='cpu',
        response_net_path=str(resp_path),
        proxy_net_path=str(proxy_path),
        noise_std=None,
        # fixed_temperatures=[9, 9],
        num_samples=1,
        init_temp=0,
        # num_likelihood_bins=int(2e3),
        # sample_from_likelihood=True,
    )
    alb_online_agent = AlbatrossAgent(alb_online_agent_cfg)

    # Only soup delivery is rewarded; every other reward component is zeroed.
    reward_cfg = OvercookedRewardConfig(
        placement_in_pot=0,
        dish_pickup=0,
        soup_pickup=0,
        soup_delivery=20,
        start_cooking=0,
    )
    game_cfg.reward_cfg = reward_cfg
    game_cfg.temperature_input = True
    game_cfg.single_temperature_input = True
    game_cfg.automatic_cook_start = False
    game = OvercookedGame(game_cfg)

    print(f'{datetime.now()} - Started evaluation of {prefix} with {seed=}', flush=True)
    # The scripted agent is the evaluee; the Albatross agent is its only opponent.
    results, _ = do_evaluation(
        game=game,
        evaluee=scripted_agent,
        opponent_list=[alb_online_agent],
        num_episodes=[100],
        enemy_iterations=0,
        temperature_list=[0.5],
        own_temperature=1,
        prevent_draw=False,
        switch_positions=False,
        verbose_level=1,
    )
    with open(save_path / fname, 'wb') as f:
        pickle.dump(results, f)


def compute_avg():
    """Print per-seed averages of the 'cc_dish_everywhere' results and their mean/std.

    Loads ``cc_dish_everywhere_<seed>.pkl`` for seeds 0..4 from
    ``a_data/scripted`` (three directory levels above this file), stacks them,
    keeps index 0 along the second axis, averages over the last axis, and
    prints the resulting array followed by its mean and standard deviation.
    """
    data_dir = Path(__file__).parent.parent.parent / 'a_data' / 'scripted'
    per_seed = [
        pickle.loads((data_dir / f'cc_dish_everywhere_{seed}.pkl').read_bytes())
        for seed in range(5)
    ]
    # One row per seed after selecting entry 0; reduce the trailing axis.
    arr = np.asarray(per_seed)[:, 0].mean(axis=-1)
    print(arr)
    print(f"{arr.mean()=}")
    print(f"{arr.std()=}")



if __name__ == "__main__":
    # Script entry point: aggregates previously saved results.
    # Call eval_scripted_oc(0) here instead to run the evaluation for seed 0.
    compute_avg()
69 changes: 69 additions & 0 deletions scripts/temp_est/matrix_game.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@


import numpy as np
from src.equilibria.logit import compute_logit_equilibrium
from src.modelling.mle import compute_likelihood, compute_temperature_mle


def main_matrix_game_repeated():
    """Estimate a temperature from observed actions in a 2x2 matrix game and report it.

    Steps:
      1. Compute a logit equilibrium with temperatures [10, 10] (used as the
         "ground truth best play" reference).
      2. From the first returned policy, derive the two action utilities
         ``q_a`` / ``q_b`` for the other player.
      3. Fit a temperature to the observed action sequence ``p2_actions`` with
         ``compute_temperature_mle``.
      4. Recompute the logit equilibrium with the fitted temperature for both
         players and derive the corresponding utilities ``q_a_p1`` / ``q_b_p1``.

    The fitted temperature and the final utilities are printed; nothing is
    returned.
    """
    # Game description shared by both equilibrium computations.
    available_actions = [[0, 1], [0, 1]]
    joint_action_list = [(0, 0), (0, 1), (1, 0), (1, 1)]
    # Row i holds the values for joint_action_list[i]
    # (presumably one column per player -- TODO confirm).
    ja_vals = np.asarray([[4, 4], [0, 0], [1, 1], [2, 2]])

    # compute le with temperature 10 as ground truth best play
    _, pol, _ = compute_logit_equilibrium(
        available_actions=available_actions,
        joint_action_list=joint_action_list,
        joint_action_value_arr=ja_vals,
        num_iterations=int(1e6),
        epsilon=0,
        temperatures=[10, 10],
    )
    p1_logit_pol = pol[0]
    # Expected column-1 value of the second player's action 0 (rows (0,0), (1,0))
    # and action 1 (rows (0,1), (1,1)), weighted by p1_logit_pol.
    q_a = p1_logit_pol[0] * ja_vals[0, 1] + p1_logit_pol[1] * ja_vals[2, 1]
    q_b = p1_logit_pol[0] * ja_vals[1, 1] + p1_logit_pol[1] * ja_vals[3, 1]

    # Observed actions; every observation shares the same utilities.
    p2_actions = [0, 1, 0, 0]
    utils = [[q_a, q_b] for _ in p2_actions]

    cur_temperature = compute_temperature_mle(
        min_temp=-10,
        max_temp=10,
        num_iterations=20,
        chosen_actions=p2_actions,
        utilities=utils,
    )

    _, pol_log, _ = compute_logit_equilibrium(
        available_actions=available_actions,
        joint_action_list=joint_action_list,
        joint_action_value_arr=ja_vals,
        num_iterations=int(1e6),
        epsilon=0,
        temperatures=[cur_temperature, cur_temperature],
    )
    # Utilities of the first player's actions 0 / 1 weighted by pol_log[1].
    # NOTE(review): these read column 1 of ja_vals although they are named for
    # player 1; both columns are identical here so the values are unaffected,
    # but confirm the intended column before changing the payoff matrix.
    q_a_p1 = pol_log[1][0] * ja_vals[0, 1] + pol_log[1][1] * ja_vals[1, 1]
    q_b_p1 = pol_log[1][0] * ja_vals[2, 1] + pol_log[1][1] * ja_vals[3, 1]

    # Report the results instead of relying on a debugger breakpoint.
    print(f"{cur_temperature=}")
    print(f"{q_a_p1=}")
    print(f"{q_b_p1=}")





if __name__ == "__main__":
    # Script entry point: run the matrix-game temperature estimation once.
    main_matrix_game_repeated()
9 changes: 9 additions & 0 deletions src/agent/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
NetworkAgentConfig, LegalRandomAgent, LegalRandomAgentConfig, BCNetworkAgent, BCNetworkAgentConfig
from src.agent.overcooked import GreedyHumanOvercookedAgent, GreedyHumanOvercookedAgentConfig
from src.agent.planning import AStarAgent, AStarAgentConfig
from src.agent.scripted import PlaceDishEverywhereAgent, PlaceDishEverywhereAgentConfig, PlaceOnionAgent, PlaceOnionAgentConfig, PlaceOnionDeliverAgent, PlaceOnionDeliverAgentConfig, PlaceOnionEverywhereAgent, PlaceOnionEverywhereAgentConfig
from src.agent.search_agent import SearchAgent, LookaheadAgent, SearchAgentConfig, LookaheadAgentConfig, \
DoubleSearchAgent, DoubleSearchAgentConfig

Expand All @@ -29,6 +30,14 @@ def get_agent_from_config(agent_cfg: AgentConfig) -> Agent:
return AlbatrossAgent(agent_cfg)
elif isinstance(agent_cfg, BCNetworkAgentConfig):
return BCNetworkAgent(agent_cfg)
elif isinstance(agent_cfg, PlaceOnionAgentConfig):
return PlaceOnionAgent(agent_cfg)
elif isinstance(agent_cfg, PlaceOnionEverywhereAgentConfig):
return PlaceOnionEverywhereAgent(agent_cfg)
elif isinstance(agent_cfg, PlaceDishEverywhereAgentConfig):
return PlaceDishEverywhereAgent(agent_cfg)
elif isinstance(agent_cfg, PlaceOnionDeliverAgentConfig):
return PlaceOnionDeliverAgent(agent_cfg)
else:
raise ValueError(f"Unknown agent type: {agent_cfg}")

Loading

0 comments on commit 993f1d1

Please sign in to comment.