Commit e03dd6e

[RLlib] New ConnectorV3 API #5: PPO runs in single-agent mode in this API stack. (ray-project#42272)
sven1977 authored Jan 19, 2024
1 parent 88a35bc commit e03dd6e
Showing 42 changed files with 1,140 additions and 605 deletions.
27 changes: 12 additions & 15 deletions rllib/BUILD
@@ -150,16 +150,6 @@ py_test(
# --------------------------------------------------------------------

# APPO
py_test(
name = "learning_tests_cartpole_appo_no_vtrace",
main = "tests/run_regression_tests.py",
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete"],
size = "medium", # bazel may complain about it being too long sometimes - medium is on purpose as some frameworks take longer
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/appo/cartpole-appo.yaml"],
args = ["--dir=tuned_examples/appo"]
)

py_test(
name = "learning_tests_cartpole_appo_w_rl_modules_and_learner",
main = "tests/run_regression_tests.py",
@@ -177,7 +167,7 @@ py_test(
size = "medium",
srcs = ["tests/run_regression_tests.py"],
data = [
"tuned_examples/appo/cartpole-appo-vtrace-separate-losses.py"
"tuned_examples/appo/cartpole-appo-separate-losses.py"
],
args = ["--dir=tuned_examples/appo"]
)
@@ -208,17 +198,17 @@ py_test(
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete", "fake_gpus"],
size = "medium",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/appo/cartpole-appo-vtrace-fake-gpus.yaml"],
data = ["tuned_examples/appo/cartpole-appo-fake-gpus.yaml"],
args = ["--dir=tuned_examples/appo"]
)

py_test(
name = "learning_tests_stateless_cartpole_appo_vtrace",
name = "learning_tests_stateless_cartpole_appo",
main = "tests/run_regression_tests.py",
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete"],
size = "enormous",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/appo/stateless-cartpole-appo-vtrace.py"],
data = ["tuned_examples/appo/stateless_cartpole_appo.py"],
args = ["--dir=tuned_examples/appo"]
)

@@ -1453,6 +1443,13 @@ py_test(
srcs = ["utils/exploration/tests/test_explorations.py"]
)

py_test(
name = "test_value_predictions",
tags = ["team:rllib", "utils"],
size = "small",
srcs = ["utils/postprocessing/tests/test_value_predictions.py"]
)

py_test(
name = "test_random_encoder",
tags = ["team:rllib", "utils"],
@@ -1461,7 +1458,7 @@
)

py_test(
name = "utils/tests/test_torch_utils",
name = "test_torch_utils",
tags = ["team:rllib", "utils", "gpu"],
size = "medium",
srcs = ["utils/tests/test_torch_utils.py"]
25 changes: 19 additions & 6 deletions rllib/algorithms/algorithm.py
@@ -25,6 +25,7 @@
Set,
Tuple,
Type,
TYPE_CHECKING,
Union,
)

@@ -46,7 +47,6 @@
collect_metrics,
summarize_episodes,
)
from ray.rllib.evaluation.postprocessing_v2 import postprocess_episodes_to_sample_batch
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.execution.train_ops import multi_gpu_train_one_step, train_one_step
@@ -129,6 +129,8 @@
from ray.util.timer import _Timer
from ray.tune.registry import get_trainable_cls

if TYPE_CHECKING:
from ray.rllib.core.learner.learner_group import LearnerGroup

try:
from ray.rllib.extensions import AlgorithmBase
@@ -449,6 +451,9 @@ def __init__(
# Placeholder for a local replay buffer instance.
self.local_replay_buffer = None

# Placeholder for our LearnerGroup responsible for updating the RLModule(s).
self.learner_group: Optional["LearnerGroup"] = None

# Create a default logger creator if no logger_creator is specified
if logger_creator is None:
# Default logdir prefix containing the agent's name and the
@@ -1410,7 +1415,12 @@ def remote_fn(worker):
worker.set_weights(
weights=ray.get(weights_ref), weights_seq_no=weights_seq_no
)
episodes = worker.sample(explore=False)
# By episode: Run always only one episode per remote call.
# By timesteps: By default EnvRunner runs for the configured number of
# timesteps (based on `rollout_fragment_length` and `num_envs_per_worker`).
episodes = worker.sample(
explore=False, num_episodes=1 if unit == "episodes" else None
)
metrics = worker.get_metrics()
return episodes, metrics, weights_seq_no

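For context, a minimal configuration sketch (not part of the diff; assuming the `unit` above is driven by `evaluation_duration_unit`) showing how evaluation switches between counting episodes and counting timesteps:

    from ray.rllib.algorithms.ppo import PPOConfig

    config = (
        PPOConfig()
        .evaluation(
            evaluation_duration=10,
            # "episodes": each remote `worker.sample()` call above returns exactly one episode.
            # "timesteps": each call returns the EnvRunner's default rollout
            # (rollout_fragment_length * num_envs_per_worker timesteps).
            evaluation_duration_unit="episodes",
        )
    )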
@@ -1449,11 +1459,13 @@ def remote_fn(worker):
rollout_metrics.extend(metrics)
i += 1

# Convert our list of Episodes to a single SampleBatch.
batch = postprocess_episodes_to_sample_batch(episodes)
# Collect steps stats.
_agent_steps = batch.agent_steps()
_env_steps = batch.env_steps()
# TODO (sven): Solve for proper multi-agent env/agent steps counting.
# Once we have multi-agent support on EnvRunner stack, we can simply do:
# `len(episode)` for env steps and `episode.num_agent_steps()` for agent
# steps.
_agent_steps = sum(len(e) for e in episodes)
_env_steps = sum(len(e) for e in episodes)

# Only complete episodes done by eval workers.
if unit == "episodes":
@@ -1467,6 +1479,7 @@
)

if self.reward_estimators:
batch = concat_samples([e.get_sample_batch() for e in episodes])
all_batches.append(batch)

agent_steps_this_iter += _agent_steps
71 changes: 48 additions & 23 deletions rllib/algorithms/algorithm_config.py
@@ -363,6 +363,8 @@ def __init__(self, algo_class=None):
self.grad_clip = None
self.grad_clip_by = "global_norm"
self.train_batch_size = 32
# Simple logic for now: If None, use `train_batch_size`.
self.train_batch_size_per_learner = None
# TODO (sven): Unsolved problem with RLModules sometimes requiring settings from
# the main AlgorithmConfig. We should not require the user to provide those
# settings in both, the AlgorithmConfig (as property) AND the model config
@@ -871,6 +873,7 @@ def build_env_to_module_connector(self, env):
return pipeline

def build_module_to_env_connector(self, env):

from ray.rllib.connectors.module_to_env import (
DefaultModuleToEnv,
ModuleToEnvPipeline,
@@ -1333,11 +1336,11 @@ def environment(
Tuple[value1, value2]: Clip at value1 and value2.
normalize_actions: If True, RLlib will learn entirely inside a normalized
action space (0.0 centered with small stddev; only affecting Box
components). We will unsquash actions (and clip, just in case) to the
components). RLlib will unsquash actions (and clip, just in case) to the
bounds of the env's action space before sending actions back to the env.
clip_actions: If True, RLlib will clip actions according to the env's bounds
before sending them back to the env.
TODO: (sven) This option should be deprecated and always be False.
clip_actions: If True, the RLlib default ModuleToEnv connector will clip
actions according to the env's bounds (before sending them into the
`env.step()` call).
disable_env_checking: If True, disable the environment pre-checking module.
is_atari: This config can be used to explicitly specify whether the env is
an Atari env or not. If not specified, RLlib will try to auto-detect
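For reference, a minimal usage sketch (not part of the diff) of the two action-space options documented above; the env name and values are illustrative only:

    from ray.rllib.algorithms.ppo import PPOConfig

    config = (
        PPOConfig()
        # Learn in a normalized Box action space and let RLlib unsquash (and, if
        # needed, clip) actions back to the env's bounds before `env.step()`.
        .environment("Pendulum-v1", normalize_actions=True, clip_actions=False)
    )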
@@ -1678,6 +1681,7 @@ def training(
grad_clip: Optional[float] = NotProvided,
grad_clip_by: Optional[str] = NotProvided,
train_batch_size: Optional[int] = NotProvided,
train_batch_size_per_learner: Optional[int] = NotProvided,
model: Optional[dict] = NotProvided,
optimizer: Optional[dict] = NotProvided,
max_requests_in_flight_per_sampler_worker: Optional[int] = NotProvided,
@@ -1726,7 +1730,16 @@ def training(
the shapes of these tensors are).
grad_clip_by: See `grad_clip` for the effect of this setting on gradient
clipping. Allowed values are `value`, `norm`, and `global_norm`.
train_batch_size: Training batch size, if applicable.
train_batch_size_per_learner: Train batch size per individual Learner
worker. This setting only applies to the new API stack. The number
of Learner workers can be set via `config.resources(
num_learner_workers=...)`. The total effective batch size is then
`num_learner_workers` x `train_batch_size_per_learner` and can
be accessed via the property `AlgorithmConfig.total_train_batch_size`.
train_batch_size: Training batch size, if applicable. When on the new API
stack, this setting should no longer be used. Instead, use
`train_batch_size_per_learner` (in combination with
`num_learner_workers`).
model: Arguments passed into the policy model. See models/catalog.py for a
full list of the available model options.
TODO: Provide ModelConfig objects instead of dicts.
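A minimal sketch (not part of the diff; numbers are illustrative) of how the new per-Learner batch size composes with the number of Learner workers:

    from ray.rllib.algorithms.ppo import PPOConfig

    config = (
        PPOConfig()
        .resources(num_learner_workers=2)
        .training(train_batch_size_per_learner=2000)
    )
    # Total effective batch size = num_learner_workers * train_batch_size_per_learner.
    assert config.total_train_batch_size == 4000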
@@ -1766,6 +1779,8 @@
"or 'global_norm'!"
)
self.grad_clip_by = grad_clip_by
if train_batch_size_per_learner is not NotProvided:
self.train_batch_size_per_learner = train_batch_size_per_learner
if train_batch_size is not NotProvided:
self.train_batch_size = train_batch_size
if model is not NotProvided:
@@ -2716,20 +2731,29 @@ def uses_new_env_runners(self):
self.env_runner_cls, RolloutWorker
)

@property
def total_train_batch_size(self):
if self.train_batch_size_per_learner is not None:
return self.train_batch_size_per_learner * (self.num_learner_workers or 1)
else:
return self.train_batch_size

# TODO: Make rollout_fragment_length as read-only property and replace the current
# self.rollout_fragment_length a private variable.
def get_rollout_fragment_length(self, worker_index: int = 0) -> int:
"""Automatically infers a proper rollout_fragment_length setting if "auto".
Uses the simple formula:
`rollout_fragment_length` = `train_batch_size` /
`rollout_fragment_length` = `total_train_batch_size` /
(`num_envs_per_worker` * `num_rollout_workers`)
If result is a fraction AND `worker_index` is provided, will make
those workers add additional timesteps, such that the overall batch size (across
the workers) will add up to exactly the `train_batch_size`.
the workers) will add up to exactly the `total_train_batch_size`.
Returns:
The user-provided `rollout_fragment_length` or a computed one (if user
provided value is "auto"), making sure `train_batch_size` is reached
provided value is "auto"), making sure `total_train_batch_size` is reached
exactly in each iteration.
"""
if self.rollout_fragment_length == "auto":
@@ -2739,11 +2763,11 @@ def get_rollout_fragment_length(self, worker_index: int = 0) -> int:
# 4 workers, 3 envs per worker, 2500 train batch size:
# -> 2500 / 12 -> 208.333 -> diff=4 (208 * 12 = 2496)
# -> worker 1: 209, workers 2-4: 208
rollout_fragment_length = self.train_batch_size / (
rollout_fragment_length = self.total_train_batch_size / (
self.num_envs_per_worker * (self.num_rollout_workers or 1)
)
if int(rollout_fragment_length) != rollout_fragment_length:
diff = self.train_batch_size - int(
diff = self.total_train_batch_size - int(
rollout_fragment_length
) * self.num_envs_per_worker * (self.num_rollout_workers or 1)
if (worker_index * self.num_envs_per_worker) <= diff:
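As a standalone worked example of the "auto" formula above (a sketch, not the RLlib implementation), using the numbers from the code comment: 4 workers, 3 envs per worker, total train batch size 2500:

    def auto_rollout_fragment_length(total_train_batch_size, num_rollout_workers,
                                     num_envs_per_worker, worker_index):
        # Floored per-env fragment length.
        frag = total_train_batch_size / (num_envs_per_worker * num_rollout_workers)
        if int(frag) != frag:
            # Timesteps still missing after flooring; the first workers pick them up.
            diff = total_train_batch_size - int(frag) * num_envs_per_worker * num_rollout_workers
            if worker_index * num_envs_per_worker <= diff:
                return int(frag) + 1
        return int(frag)

    # 2500 / 12 = 208.33 -> worker 1: 209, workers 2-4: 208
    print([auto_rollout_fragment_length(2500, 4, 3, i) for i in range(1, 5)])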
@@ -3095,36 +3119,38 @@ def validate_train_batch_size_vs_rollout_fragment_length(self) -> None:
Raises:
ValueError: If there is a mismatch between user provided
`rollout_fragment_length` and `train_batch_size`.
`rollout_fragment_length` and `total_train_batch_size`.
"""
if (
self.rollout_fragment_length != "auto"
and not self.in_evaluation
and self.train_batch_size > 0
and self.total_train_batch_size > 0
):
min_batch_size = (
max(self.num_rollout_workers, 1)
* self.num_envs_per_worker
* self.rollout_fragment_length
)
batch_size = min_batch_size
while batch_size < self.train_batch_size:
while batch_size < self.total_train_batch_size:
batch_size += min_batch_size
if (
batch_size - self.train_batch_size > 0.1 * self.train_batch_size
or batch_size - min_batch_size - self.train_batch_size
> (0.1 * self.train_batch_size)
if batch_size - self.total_train_batch_size > (
0.1 * self.total_train_batch_size
) or batch_size - min_batch_size - self.total_train_batch_size > (
0.1 * self.total_train_batch_size
):
suggested_rollout_fragment_length = self.train_batch_size // (
suggested_rollout_fragment_length = self.total_train_batch_size // (
self.num_envs_per_worker * (self.num_rollout_workers or 1)
)
raise ValueError(
f"Your desired `train_batch_size` ({self.train_batch_size}) or a "
"value 10% off of that cannot be achieved with your other "
"Your desired `total_train_batch_size` "
f"({self.total_train_batch_size}={self.num_learner_workers} "
f"learners x {self.train_batch_size_per_learner}) "
"or a value 10% off of that cannot be achieved with your other "
f"settings (num_rollout_workers={self.num_rollout_workers}; "
f"num_envs_per_worker={self.num_envs_per_worker}; "
f"rollout_fragment_length={self.rollout_fragment_length})! "
"Try setting `rollout_fragment_length` to 'auto' OR "
"Try setting `rollout_fragment_length` to 'auto' OR to a value of "
f"{suggested_rollout_fragment_length}."
)

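For intuition, a simplified standalone sketch (not the RLlib implementation) of the idea behind this check: one synchronous sampling round always yields a fixed number of timesteps, and the requested total batch size must be reachable by some multiple of that within a 10% margin:

    def batch_size_reachable(total_train_batch_size, num_rollout_workers,
                             num_envs_per_worker, rollout_fragment_length):
        # Timesteps produced by a single synchronous sampling round.
        min_batch = max(num_rollout_workers, 1) * num_envs_per_worker * rollout_fragment_length
        # Smallest multiple of `min_batch` that reaches the requested total.
        batch = min_batch
        while batch < total_train_batch_size:
            batch += min_batch
        # Accept if the overshoot stays within 10% of the requested total.
        return batch - total_train_batch_size <= 0.1 * total_train_batch_size

    print(batch_size_reachable(4000, 2, 1, 512))   # 2*1*512=1024 -> 4096, ~2.4% over -> True
    print(batch_size_reachable(4000, 2, 1, 3000))  # 2*1*3000=6000 -> 50% over -> False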
@@ -3580,8 +3606,7 @@ def _validate_evaluation_settings(self):
"""Checks, whether evaluation related settings make sense."""
if (
self.evaluation_interval
and self.env_runner_cls is not None
and not issubclass(self.env_runner_cls, RolloutWorker)
and self.uses_new_env_runners
and not self.enable_async_evaluation
):
raise ValueError(
13 changes: 0 additions & 13 deletions rllib/algorithms/appo/tests/test_appo.py
@@ -26,19 +26,6 @@ def test_appo_compilation(self):
num_iterations = 2

for _ in framework_iterator(config):
print("w/o v-trace")
config.vtrace = False
algo = config.build(env="CartPole-v1")
for i in range(num_iterations):
results = algo.train()
print(results)
check_train_results(results)

check_compute_single_action(algo)
algo.stop()

print("w/ v-trace")
config.vtrace = True
algo = config.build(env="CartPole-v1")
for i in range(num_iterations):
results = algo.train()
23 changes: 13 additions & 10 deletions rllib/algorithms/appo/tf/appo_tf_learner.py
@@ -8,9 +8,9 @@
OLD_ACTION_DIST_LOGITS_KEY,
)
from ray.rllib.algorithms.appo.appo_learner import AppoLearner
from ray.rllib.algorithms.impala.tf.impala_tf_learner import ImpalaTfLearner
from ray.rllib.algorithms.impala.tf.vtrace_tf_v2 import make_time_major, vtrace_tf2
from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY, ENTROPY_KEY
from ray.rllib.core.learner.tf.tf_learner import TfLearner
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.nested_dict import NestedDict
@@ -19,10 +19,10 @@
_, tf, _ = try_import_tf()


class APPOTfLearner(AppoLearner, TfLearner):
class APPOTfLearner(AppoLearner, ImpalaTfLearner):
"""Implements APPO loss / update logic on top of ImpalaTfLearner."""

@override(TfLearner)
@override(ImpalaTfLearner)
def compute_loss_for_module(
self,
*,
@@ -72,12 +72,15 @@ def compute_loss_for_module(
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
)
bootstrap_values_time_major = make_time_major(
batch[SampleBatch.VALUES_BOOTSTRAPPED],
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
)
bootstrap_value = bootstrap_values_time_major[-1]
if self.config.uses_new_env_runners:
bootstrap_values = batch[SampleBatch.VALUES_BOOTSTRAPPED]
else:
bootstrap_values_time_major = make_time_major(
batch[SampleBatch.VALUES_BOOTSTRAPPED],
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
)
bootstrap_values = bootstrap_values_time_major[-1]

# The discount factor that is used should be gamma except for timesteps where
# the episode is terminated. In that case, the discount factor should be 0.
Expand All @@ -100,7 +103,7 @@ def compute_loss_for_module(
discounts=discounts_time_major,
rewards=rewards_time_major,
values=values_time_major,
bootstrap_value=bootstrap_value,
bootstrap_values=bootstrap_values,
clip_pg_rho_threshold=config.vtrace_clip_pg_rho_threshold,
clip_rho_threshold=config.vtrace_clip_rho_threshold,
)
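To illustrate the branch above with plain arrays (a sketch under assumed shapes, not RLlib code): on the old stack the bootstrap values arrive as a batch-major, per-timestep column that is transposed to time-major (as `make_time_major` does) before keeping only the last time step, while the new EnvRunner stack already provides one bootstrap value per sequence:

    import numpy as np

    B, T = 2, 4  # B sequences, each of rollout fragment length T

    # Old stack: per-timestep column, batch-major [B, T] (values made up here).
    values_bootstrapped = np.arange(B * T, dtype=np.float32).reshape(B, T)

    # [B, T] -> [T, B] (make_time_major equivalent), then keep the last time step.
    bootstrap_values_old_stack = values_bootstrapped.T[-1]    # shape [B]

    # New stack: the batch already holds exactly one bootstrap value per sequence.
    bootstrap_values_new_stack = values_bootstrapped[:, -1]   # shape [B]

    assert np.allclose(bootstrap_values_old_stack, bootstrap_values_new_stack)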