diff --git a/baselines/marl_benchmark/marl_benchmark/__init__.py b/baselines/marl_benchmark/marl_benchmark/__init__.py index f3c499c3a6..18b90c1dcf 100644 --- a/baselines/marl_benchmark/marl_benchmark/__init__.py +++ b/baselines/marl_benchmark/marl_benchmark/__init__.py @@ -57,15 +57,12 @@ def gen_config(**kwargs): "obs_space": gym.spaces.Tuple([obs_space] * agent_missions_count), "act_space": gym.spaces.Tuple([act_space] * agent_missions_count), "groups": {"group": agent_ids}, + "model": config["policy"][-1], } ) - tune_config.update({"model": config["policy"][-1]}) - policies = {} for k in agents: - policies[k] = config["policy"][:-1] + ( - {**config["policy"][-1], "agent_id": k}, - ) + policies[k] = config["policy"][:-1] + ({**config["policy"][-1], "agent_id": k},) tune_config.update( { "multiagent": { diff --git a/baselines/marl_benchmark/marl_benchmark/evaluate.py b/baselines/marl_benchmark/marl_benchmark/evaluate.py index 15c62bcf86..caca9d9e0e 100644 --- a/baselines/marl_benchmark/marl_benchmark/evaluate.py +++ b/baselines/marl_benchmark/marl_benchmark/evaluate.py @@ -86,11 +86,10 @@ def main( tune_config = config["run"]["config"] trainer_cls = config["trainer"] trainer_config = {"env_config": config["env_config"]} - if paradigm != "centralized": - trainer_config.update({"multiagent": tune_config["multiagent"]}) - else: - trainer_config.update({"model": tune_config["model"]}) + if paradigm == "centralized": + trainer_config["model"] = config["policy"][-1] + trainer_config.update({"multiagent": tune_config["multiagent"]}) trainer = trainer_cls(env=tune_config["env"], config=trainer_config) trainer_config["evaluation_interval"] = True trainer.setup(trainer_config) diff --git a/baselines/marl_benchmark/marl_benchmark/wrappers/rllib/frame_stack.py b/baselines/marl_benchmark/marl_benchmark/wrappers/rllib/frame_stack.py index 3124ca340e..9467c9af0a 100644 --- a/baselines/marl_benchmark/marl_benchmark/wrappers/rllib/frame_stack.py +++ b/baselines/marl_benchmark/marl_benchmark/wrappers/rllib/frame_stack.py @@ -25,6 +25,7 @@ import gym import numpy as np from ray import logger +from ray.rllib.models import ModelCatalog from ray.rllib.models import Preprocessor from ray.rllib.models.preprocessors import get_preprocessor from ray.rllib.utils.annotations import override @@ -144,7 +145,8 @@ def stack_frames(frames): @staticmethod def get_preprocessor(): - return TupleStackingPreprocessor + ModelCatalog.register_custom_preprocessor("my_prep", TupleStackingPreprocessor) + return "my_prep" def _get_observations(self, raw_frames): """Update frame stack with given single frames, diff --git a/baselines/marl_benchmark/marl_benchmark/wrappers/rllib/group.py b/baselines/marl_benchmark/marl_benchmark/wrappers/rllib/group.py index d4e26c86e4..33e03a1bcd 100644 --- a/baselines/marl_benchmark/marl_benchmark/wrappers/rllib/group.py +++ b/baselines/marl_benchmark/marl_benchmark/wrappers/rllib/group.py @@ -21,7 +21,6 @@ # THE SOFTWARE. import copy -from ray.rllib.agents.dqn import DQNTrainer from ray.rllib.env.constants import GROUP_INFO, GROUP_REWARDS from ray.rllib.env.group_agents_wrapper import _GroupAgentsWrapper