Skip to content

Commit

Permalink
fix the custom_preprocessor issue arising from training for centraliz…
Browse files Browse the repository at this point in the history
…ed paradigm
  • Loading branch information
RutvikGupta authored and Gamenot committed May 3, 2022
1 parent f3187d7 commit 5f4842f
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 11 deletions.
7 changes: 2 additions & 5 deletions baselines/marl_benchmark/marl_benchmark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,12 @@ def gen_config(**kwargs):
"obs_space": gym.spaces.Tuple([obs_space] * agent_missions_count),
"act_space": gym.spaces.Tuple([act_space] * agent_missions_count),
"groups": {"group": agent_ids},
"model": config["policy"][-1],
}
)
tune_config.update({"model": config["policy"][-1]})

policies = {}
for k in agents:
policies[k] = config["policy"][:-1] + (
{**config["policy"][-1], "agent_id": k},
)
policies[k] = config["policy"][:-1] + ({**config["policy"][-1], "agent_id": k},)
tune_config.update(
{
"multiagent": {
Expand Down
7 changes: 3 additions & 4 deletions baselines/marl_benchmark/marl_benchmark/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,10 @@ def main(
tune_config = config["run"]["config"]
trainer_cls = config["trainer"]
trainer_config = {"env_config": config["env_config"]}
if paradigm != "centralized":
trainer_config.update({"multiagent": tune_config["multiagent"]})
else:
trainer_config.update({"model": tune_config["model"]})
if paradigm == "centralized":
trainer_config["model"] = config["policy"][-1]

trainer_config.update({"multiagent": tune_config["multiagent"]})
trainer = trainer_cls(env=tune_config["env"], config=trainer_config)
trainer_config["evaluation_interval"] = True
trainer.setup(trainer_config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import gym
import numpy as np
from ray import logger
from ray.rllib.models import ModelCatalog
from ray.rllib.models import Preprocessor
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.utils.annotations import override
Expand Down Expand Up @@ -144,7 +145,8 @@ def stack_frames(frames):

@staticmethod
def get_preprocessor():
return TupleStackingPreprocessor
ModelCatalog.register_custom_preprocessor("my_prep", TupleStackingPreprocessor)
return "my_prep"

def _get_observations(self, raw_frames):
"""Update frame stack with given single frames,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
# THE SOFTWARE.
import copy

from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.env.constants import GROUP_INFO, GROUP_REWARDS
from ray.rllib.env.group_agents_wrapper import _GroupAgentsWrapper

Expand Down

0 comments on commit 5f4842f

Please sign in to comment.