Skip to content

Commit

Permalink
cooperative battlesnake experiment
Browse files Browse the repository at this point in the history
  • Loading branch information
ymahlau committed Mar 26, 2024
1 parent a3bbfa5 commit 32257f4
Show file tree
Hide file tree
Showing 42 changed files with 16,690 additions and 38 deletions.
413 changes: 413 additions & 0 deletions config/cfg_4dc11_0.yaml

Large diffs are not rendered by default.

500 changes: 500 additions & 0 deletions config/cfg_4dc11_1.yaml

Large diffs are not rendered by default.

500 changes: 500 additions & 0 deletions config/cfg_4dc11_2.yaml

Large diffs are not rendered by default.

500 changes: 500 additions & 0 deletions config/cfg_4dc11_3.yaml

Large diffs are not rendered by default.

500 changes: 500 additions & 0 deletions config/cfg_4dc11_4.yaml

Large diffs are not rendered by default.

474 changes: 474 additions & 0 deletions config/cfg_4dc11_proxy_0.yaml

Large diffs are not rendered by default.

561 changes: 561 additions & 0 deletions config/cfg_4dc11_proxy_1.yaml

Large diffs are not rendered by default.

561 changes: 561 additions & 0 deletions config/cfg_4dc11_proxy_2.yaml

Large diffs are not rendered by default.

561 changes: 561 additions & 0 deletions config/cfg_4dc11_proxy_3.yaml

Large diffs are not rendered by default.

561 changes: 561 additions & 0 deletions config/cfg_4dc11_proxy_4.yaml

Large diffs are not rendered by default.

455 changes: 455 additions & 0 deletions config/cfg_4dc11_resp_0.yaml

Large diffs are not rendered by default.

542 changes: 542 additions & 0 deletions config/cfg_4dc11_resp_1.yaml

Large diffs are not rendered by default.

542 changes: 542 additions & 0 deletions config/cfg_4dc11_resp_2.yaml

Large diffs are not rendered by default.

542 changes: 542 additions & 0 deletions config/cfg_4dc11_resp_3.yaml

Large diffs are not rendered by default.

542 changes: 542 additions & 0 deletions config/cfg_4dc11_resp_4.yaml

Large diffs are not rendered by default.

413 changes: 413 additions & 0 deletions config/cfg_4dc9_0.yaml

Large diffs are not rendered by default.

500 changes: 500 additions & 0 deletions config/cfg_4dc9_1.yaml

Large diffs are not rendered by default.

500 changes: 500 additions & 0 deletions config/cfg_4dc9_2.yaml

Large diffs are not rendered by default.

500 changes: 500 additions & 0 deletions config/cfg_4dc9_3.yaml

Large diffs are not rendered by default.

500 changes: 500 additions & 0 deletions config/cfg_4dc9_4.yaml

Large diffs are not rendered by default.

474 changes: 474 additions & 0 deletions config/cfg_4dc9_proxy_0.yaml

Large diffs are not rendered by default.

561 changes: 561 additions & 0 deletions config/cfg_4dc9_proxy_1.yaml

Large diffs are not rendered by default.

561 changes: 561 additions & 0 deletions config/cfg_4dc9_proxy_2.yaml

Large diffs are not rendered by default.

561 changes: 561 additions & 0 deletions config/cfg_4dc9_proxy_3.yaml

Large diffs are not rendered by default.

561 changes: 561 additions & 0 deletions config/cfg_4dc9_proxy_4.yaml

Large diffs are not rendered by default.

455 changes: 455 additions & 0 deletions config/cfg_4dc9_resp_0.yaml

Large diffs are not rendered by default.

542 changes: 542 additions & 0 deletions config/cfg_4dc9_resp_1.yaml

Large diffs are not rendered by default.

542 changes: 542 additions & 0 deletions config/cfg_4dc9_resp_2.yaml

Large diffs are not rendered by default.

542 changes: 542 additions & 0 deletions config/cfg_4dc9_resp_3.yaml

Large diffs are not rendered by default.

542 changes: 542 additions & 0 deletions config/cfg_4dc9_resp_4.yaml

Large diffs are not rendered by default.

343 changes: 343 additions & 0 deletions scripts/training/generate_training_cfg_bs_coop_az.py

Large diffs are not rendered by default.

353 changes: 353 additions & 0 deletions scripts/training/generate_training_cfg_bs_coop_proxy.py

Large diffs are not rendered by default.

354 changes: 354 additions & 0 deletions scripts/training/generate_training_cfg_bs_coop_resp.py

Large diffs are not rendered by default.

15 changes: 9 additions & 6 deletions src/game/battlesnake/battlesnake_rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def __call__(
@dataclass
class CooperationBattleSnakeRewardConfig(BattleSnakeRewardConfig):
living_reward: float = field(default=0.02)
terminal_reward: float = -0.5
terminal_reward: float = -0.25

class BattleSnakeRewardFunctionCooperation(BattleSnakeRewardFunction):
def __init__(self, cfg: CooperationBattleSnakeRewardConfig):
Expand All @@ -154,12 +154,15 @@ def __call__(
) -> np.ndarray:
num_at_turn = len(players_at_turn)
num_at_turn_last = len(players_at_turn_last)
num_dead = num_at_turn_last - num_at_turn
# all players get negative terminal reward if a snake died
if num_at_turn != num_at_turn_last:
rewards = np.ones(shape=(num_players,), dtype=float) * self.cfg.terminal_reward
# add living reward if no one died
else:
rewards = self.cfg.living_reward * np.ones(shape=(num_players,), dtype=float)
rewards = np.zeros(shape=(num_players,), dtype=float)
player_died = [p for p in players_at_turn_last if p not in players_at_turn]
for p in player_died:
rewards[p] += self.cfg.terminal_reward * num_at_turn_last
for p in players_at_turn:
rewards[p] += self.cfg.terminal_reward * num_dead
rewards[p] += self.cfg.living_reward
return rewards


Expand Down
17 changes: 16 additions & 1 deletion src/game/battlesnake/bootcamp/test_envs_11x11.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from src.game.battlesnake.battlesnake_conf import BattleSnakeConfig
from src.game.battlesnake.battlesnake_enc import SimpleBattleSnakeEncodingConfig, VanillaBattleSnakeEncodingConfig, \
SimpleConstrictorEncodingConfig
from src.game.battlesnake.battlesnake_rewards import StandardBattleSnakeRewardConfig, KillBattleSnakeRewardConfig
from src.game.battlesnake.battlesnake_rewards import CooperationBattleSnakeRewardConfig, StandardBattleSnakeRewardConfig, KillBattleSnakeRewardConfig


def perform_choke_11x11(centered: bool) -> BattleSnakeConfig:
Expand Down Expand Up @@ -116,3 +116,18 @@ def survive_on_11x11_4_player_royale() -> BattleSnakeConfig:
)
return gc

def survive_on_11x11_constrictor_4_player_coop() -> BattleSnakeConfig:
    """Build the 11x11, four-player constrictor config with cooperation rewards.

    Uses the simple constrictor encoding and the default
    ``CooperationBattleSnakeRewardConfig``; illegal actions stay illegal
    (``all_actions_legal=False``).
    """
    return BattleSnakeConfig(
        w=11,
        h=11,
        num_players=4,
        ec=SimpleConstrictorEncodingConfig(),
        reward_cfg=CooperationBattleSnakeRewardConfig(),
        all_actions_legal=False,
        constrictor=True,
    )

19 changes: 18 additions & 1 deletion src/game/battlesnake/bootcamp/test_envs_9x9.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from src.game.battlesnake.battlesnake_conf import BattleSnakeConfig
from src.game.battlesnake.battlesnake_enc import BestBattleSnakeEncodingConfig
from src.game.battlesnake.battlesnake_enc import BestBattleSnakeEncodingConfig, SimpleConstrictorEncodingConfig
from src.game.battlesnake.battlesnake_rewards import CooperationBattleSnakeRewardConfig


def survive_on_9x9_constrictor_4_player() -> BattleSnakeConfig:
Expand All @@ -23,3 +24,19 @@ def survive_on_9x9_constrictor_4_player() -> BattleSnakeConfig:
constrictor=True,
)
return gc


def survive_on_9x9_constrictor_4_player_coop() -> BattleSnakeConfig:
    """Build the 9x9, four-player constrictor config with cooperation rewards.

    Uses the simple constrictor encoding and the default
    ``CooperationBattleSnakeRewardConfig``; illegal actions stay illegal
    (``all_actions_legal=False``).
    """
    return BattleSnakeConfig(
        w=9,
        h=9,
        num_players=4,
        ec=SimpleConstrictorEncodingConfig(),
        reward_cfg=CooperationBattleSnakeRewardConfig(),
        all_actions_legal=False,
        constrictor=True,
    )
49 changes: 49 additions & 0 deletions src/network/mobilenet_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,55 @@ class MobileNetConfig7x7Incumbent(MobileNetConfig):
value_head_cfg: HeadConfig = field(
default_factory=lambda: HeadConfig(num_layers=1, final_activation=ActivationType.TANH)
)


# Layer table for the 9x9 incumbent network. Columns per row:
# in_channels, exp_channels, out_channels, kernel_size, stride, se
# NOTE(review): the trailing numeric comments are the original author's
# annotations (presumably spatial sizes after a block) -- confirm against
# the encoding actually used for 9x9 boards.
incumbent_9x9 = [
    [64, 128, 64, 3, 1, 0],
    [64, 128, 64, 5, 1, 0],  # 19
    [64, 192, 128, 3, 2, 0],  # 10
    [128, 320, 128, 3, 1, 1],
    [128, 320, 128, 5, 1, 1],
    [128, 320, 192, 3, 2, 1],  # 5
    [192, 384, 192, 3, 1, 1],
    [192, 384, 192, 3, 2, 1],  # 3
    [192, 384, 192, 3, 1, 1],
]


@dataclass
class MobileNetConfig9x9Incumbent(MobileNetConfig):
    """MobileNet config whose layers default to the ``incumbent_9x9`` table.

    Both heads use a single layer; the policy head is built with
    ``ActivationType.NONE`` and the value head with ``ActivationType.TANH``
    as final activation.
    """
    # Hand out a fresh nested copy per instance: returning the module-level
    # list itself would let an in-place edit of one config's layer_specs
    # silently change the table and every other config built from it.
    layer_specs: list[list[int]] = field(
        default_factory=lambda: [list(spec) for spec in incumbent_9x9]
    )
    lff_features: bool = field(default=False)
    lff_feature_expansion: int = field(default=27)
    policy_head_cfg: HeadConfig = field(
        default_factory=lambda: HeadConfig(num_layers=1, final_activation=ActivationType.NONE)
    )
    value_head_cfg: HeadConfig = field(
        default_factory=lambda: HeadConfig(num_layers=1, final_activation=ActivationType.TANH)
    )

# Layer table for the 11x11 incumbent network. Columns per row:
# in_channels, exp_channels, out_channels, kernel_size, stride, se
# NOTE(review): the trailing numeric comments are the original author's
# annotations (presumably spatial sizes after a block) -- confirm against
# the encoding actually used for 11x11 boards.
incumbent_11x11 = [
    [64, 128, 64, 3, 1, 0],
    [64, 128, 64, 5, 1, 0],  # 21
    [64, 192, 128, 3, 2, 0],  # 11
    [128, 320, 128, 3, 1, 1],
    [128, 320, 128, 5, 1, 1],
    [128, 320, 192, 3, 2, 1],  # 6
    [192, 384, 192, 3, 1, 1],
    [192, 384, 192, 3, 2, 1],  # 3
    [192, 384, 192, 3, 1, 1],
]


@dataclass
class MobileNetConfig11x11Incumbent(MobileNetConfig):
    """MobileNet config whose layers default to the ``incumbent_11x11`` table.

    Both heads use a single layer; the policy head is built with
    ``ActivationType.NONE`` and the value head with ``ActivationType.TANH``
    as final activation.
    """
    # Hand out a fresh nested copy per instance: returning the module-level
    # list itself would let an in-place edit of one config's layer_specs
    # silently change the table and every other config built from it.
    layer_specs: list[list[int]] = field(
        default_factory=lambda: [list(spec) for spec in incumbent_11x11]
    )
    lff_features: bool = field(default=False)
    lff_feature_expansion: int = field(default=27)
    policy_head_cfg: HeadConfig = field(
        default_factory=lambda: HeadConfig(num_layers=1, final_activation=ActivationType.NONE)
    )
    value_head_cfg: HeadConfig = field(
        default_factory=lambda: HeadConfig(num_layers=1, final_activation=ActivationType.TANH)
    )


extrapolated_11x11 = [
Expand Down
10 changes: 8 additions & 2 deletions src/trainer/az_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,11 @@ def _train(self) -> None:
if self.cfg.temperature_input and not self.cfg.single_sbr_temperature:
# proxy
seed_server = random.randint(0, 2 ** 32 - 1)
cur_gpu_idx = gpu_idx
if self.cfg.merge_proxy_response_gpu:
cur_gpu_idx -= 1
if self.cfg.merge_inference_update_gpu:
cur_gpu_idx = 0
kwargs_inference = {
'trainer_cfg': self.cfg,
'const_net_path': self.cfg.proxy_net_path,
Expand All @@ -250,15 +255,16 @@ def _train(self) -> None:
'input_arr': input_list[arr_idx],
'output_arr': output_list[arr_idx],
'cpu_list': cpu_list_inference,
'gpu_idx': gpu_idx if not self.cfg.merge_inference_update_gpu else 0,
'gpu_idx': cur_gpu_idx,
'prev_run_dir': Path(self.cfg.prev_run_dir) if self.cfg.prev_run_dir is not None else None,
'prev_run_idx': self.cfg.prev_run_idx,
}
p = mp.Process(target=run_inference_server, kwargs=kwargs_inference)
p.start()
process_list.append(p)
arr_idx += 1
gpu_idx += 1
if not self.cfg.merge_proxy_response_gpu:
gpu_idx += 1
if self.cfg.max_cpu_inference_server is not None:
cpu_counter += self.cfg.max_cpu_inference_server
# cpu list for distributor, collector, saver and logger
Expand Down
1 change: 1 addition & 0 deletions src/trainer/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,4 @@ class AlphaZeroTrainerConfig:
compile_model: bool = False
compile_mode: str = 'reduce-overhead' # Can also be max_autotune (currently does not work on rtx3090
merge_inference_update_gpu: bool = False # updater and inference server use same gpu
merge_proxy_response_gpu: bool = False
29 changes: 2 additions & 27 deletions start_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from src.misc.serialization import deserialize_dataclass
from src.trainer.az_trainer import AlphaZeroTrainer, AlphaZeroTrainerConfig

# @hydra.main(version_base=None, config_name='config', config_path=str(Path(__file__).parent / 'config_generated'))
@hydra.main(version_base=None, config_name='cfg_4dc9_0', config_path=str(Path(__file__).parent / 'config'))
def main(cfg: AlphaZeroTrainerConfig):
# torch.set_num_threads(1)
# os.environ["OMP_NUM_THREADS"] = "1"
Expand All @@ -28,29 +28,4 @@ def main(cfg: AlphaZeroTrainerConfig):
if __name__ == '__main__':
mp.set_start_method('spawn', force=True) # this is important for using CUDA
print(f"{mp.get_start_method()=}")
config_path = Path(__file__).parent / 'config'
config_name = 'config'

if len(sys.argv) > 3 and sys.argv[1].startswith("config="):
config_prefix = sys.argv[1].split("=")[-1]
sys.argv.pop(1)
arr_id = int(sys.argv[1])
sys.argv.pop(1)

pref_lists = [
# list(range(1, 6)),
# [1] + list(range(5, 51, 5)),
['aa', 'cc', 'co', 'cr', 'fc'],
list(range(5)),
]
prod = list(itertools.product(*pref_lists))
tpl = prod[arr_id]
# config_name = f"{config_prefix}_{tpl[0]}_{tpl[1]}_{tpl[2]}"
config_name = f"{config_prefix}_{tpl[0]}_{tpl[1]}"
# config_name = f"{config_prefix}_{prefix_arr[t]}_{seed}"
# config_name = f"{config_prefix}_{seed}_{prefix_arr[t]}"
elif len(sys.argv) > 2 and sys.argv[1].startswith("config="):
config_name = sys.argv[1].split("=")[-1]
sys.argv.pop(1)
print(f"{config_name=}", flush=True)
hydra.main(config_path=str(config_path), config_name=config_name, version_base=None)(main)()
main()
15 changes: 15 additions & 0 deletions test/game/battlesnake/test_constrictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from src.game.battlesnake.battlesnake import BattleSnakeGame, UP, RIGHT, LEFT, DOWN
from src.game.battlesnake.battlesnake_conf import BattleSnakeConfig
from src.game.battlesnake.bootcamp.test_envs_9x9 import survive_on_9x9_constrictor_4_player_coop


class TestConstrictor(unittest.TestCase):
Expand Down Expand Up @@ -70,4 +71,18 @@ def test_bool_game_array(self):
arr = game.get_bool_board_matrix()
print(arr)
self.assertEqual(7, np.sum(arr))

def test_coop_9x9(self):
    """Smoke-test the 9x9 cooperative constrictor game.

    Builds the game with all actions declared legal, renders the initial
    board, then sends every snake UP twice, rendering the board and printing
    the returned rewards after each step. Makes no assertions.
    """
    game_cfg = survive_on_9x9_constrictor_4_player_coop()
    game_cfg.all_actions_legal = True
    game = BattleSnakeGame(game_cfg)
    game.render()

    joint_action = (UP, UP, UP, UP)
    for _step in range(2):
        # only the first element of the step result (the rewards) is inspected
        r, _, _ = game.step(joint_action)
        game.render()
        print(f"{r=}")

15 changes: 14 additions & 1 deletion test/network/test_mobile_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
from src.game.battlesnake.bootcamp.test_envs_3x3 import perform_choke_2_player
from src.game.battlesnake.bootcamp.test_envs_5x5 import perform_choke_5x5_4_player, survive_on_5x5_constrictor
from src.game.battlesnake.bootcamp.test_envs_7x7 import survive_on_7x7
from src.game.battlesnake.bootcamp.test_envs_9x9 import survive_on_9x9_constrictor_4_player_coop
from src.game.initialization import get_game_from_config
from src.game.overcooked_slow.layouts import CrampedRoomOvercookedSlowConfig, AsymmetricAdvantageOvercookedSlowConfig
from src.network.initialization import get_network_from_config
from src.network.mobilenet_v3 import MobileNetConfig3x3, MobileNetConfig7x7, MobileNetConfig5x5, \
MobileNetConfig5x5Large, MobileNetConfig11x11, MobileNetConfig11x11Extrapolated, MobileNetConfig5x5Extrapolated, \
MobileNetConfig7x7Incumbent, MobileNetConfigOvercookedCramped, MobileNetConfigOvercookedAsymmetricAdvantage
MobileNetConfig7x7Incumbent, MobileNetConfig9x9Incumbent, MobileNetConfigOvercookedCramped, MobileNetConfigOvercookedAsymmetricAdvantage
from src.network.resnet import ResNetConfig3x3, ResNetConfig7x7Large
from src.network.vision_net import EquivarianceType

Expand Down Expand Up @@ -191,4 +192,16 @@ def test_mobile_overcooked_asym(self):
in_tensor, _, _ = game.get_obs()
out = net(torch.tensor(in_tensor))
print(f"{out=}")

def test_mobile_9x9(self):
    """Smoke-test the 9x9 incumbent MobileNet on a cooperative game observation.

    Renders the game, builds the network from ``MobileNetConfig9x9Incumbent``,
    prints its parameter count, and prints the output of one forward pass on
    the current observation. Makes no assertions.
    """
    game_cfg = survive_on_9x9_constrictor_4_player_coop()
    game = get_game_from_config(game_cfg)
    game.render()

    net = get_network_from_config(MobileNetConfig9x9Incumbent(game_cfg=game_cfg))
    print(f"{net.num_params()=}")

    # only the first element returned by get_obs() is fed to the network
    obs, _, _ = game.get_obs()
    out = net(torch.tensor(obs))
    print(f"{out=}")

0 comments on commit 32257f4

Please sign in to comment.