Support single GPU RLVR (#567)
* Support single GPU RLVR

* quick change

* quick change

* quick push

* quick fix on docs

* update script

* update docs
vwxyzjn authored Feb 16, 2025
1 parent 4737499 commit 7489742
Showing 11 changed files with 70 additions and 39 deletions.
README.md (3 additions, 3 deletions)
@@ -124,9 +124,9 @@ bash scripts/train/dpo/tulu_preference_mix.sh
### Reinforcement Learning with Verifiable Rewards (RLVR)

```bash
# quick debugging run using 2 GPU (1 for inference, 1 for training)
# here we are using `HuggingFaceTB/SmolLM2-360M-Instruct`; it's prob not
# gonna work, but it's easy to test run and print stuff.
# quick debugging run using 1 GPU (0.5 for inference, 0.5 for training)
# here we are using `HuggingFaceTB/SmolLM-135M-Instruct`; it's prob not
# gonna train good models, but it's easy to test run and print stuff.
bash scripts/train/rlvr/mini.sh
bash scripts/train/rlvr/grpo_mini.sh

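A quick way to see the "0.5 for inference, 0.5 for training" budget on your own card before launching the debug scripts — an illustrative sketch, not part of this commit, assuming only PyTorch and one visible CUDA device:

```python
# sketch: estimate the memory budget behind the "0.5 for inference, 0.5 for training" split
import torch

assert torch.cuda.is_available(), "the single-GPU RLVR debug run needs one CUDA device"
total_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3

# vLLM is capped via --vllm_gpu_memory_utilization 0.5; the trainer gets roughly the rest
vllm_gib = 0.5 * total_gib
trainer_gib = total_gib - vllm_gib
print(f"GPU 0: {total_gib:.1f} GiB total -> ~{vllm_gib:.1f} GiB for vLLM, ~{trainer_gib:.1f} GiB for training")
```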
docs/ai2_internal.md (2 additions, 2 deletions)
@@ -260,7 +260,7 @@ python mason.py \
--stop_token eos \
--temperature 1.0 \
--ground_truths_key ground_truth \
--chat_template tulu \
--chat_template_name tulu \
--sft_messages_key messages \
--learning_rate 3e-7 \
--total_episodes 200000 \
@@ -313,7 +313,7 @@ python mason.py \
--stop_token eos \
--temperature 1.0 \
--ground_truths_key ground_truth \
--chat_template tulu \
--chat_template_name tulu \
--sft_messages_key messages \
--learning_rate 5e-7 \
--total_episodes 1000000 \
docs/archived_dev_scripts/olmo2_1124.sh (12 additions, 12 deletions)
@@ -256,7 +256,7 @@ python mason.py \
--stop_token eos \
--temperature 1.0 \
--ground_truths_key ground_truth \
--chat_template tulu \
--chat_template_name tulu \
--sft_messages_key messages \
--total_episodes 100000 \
--penalty_reward_value -10.0 \
@@ -315,7 +315,7 @@ python mason.py \
--stop_token eos \
--temperature 1.0 \
--ground_truths_key ground_truth \
--chat_template tulu \
--chat_template_name tulu \
--sft_messages_key messages \
--total_episodes 200000 \
--penalty_reward_value -10.0 \
@@ -379,7 +379,7 @@ python mason.py \
--dataset_eval_splits test_prefs \
--model_name_or_path allenai/open_instruct_dev \
--model_revision 1206_finetune_epoch_2_lr_1e-5_loss_type_sum__4__1733525407 \
--chat_template tulu \
--chat_template_name tulu \
--learning_rate 3e-6 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
@@ -434,7 +434,7 @@ python mason.py \
--dataset_eval_splits test_prefs \
--model_name_or_path allenai/open_instruct_dev \
--model_revision 1208_bsz64_13b_finetune_epoch_2_lr_5e-6_loss_type_sum__1__1733711678 \
--chat_template tulu \
--chat_template_name tulu \
--learning_rate 3e-6 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
@@ -489,7 +489,7 @@ python mason.py \
--stop_token eos \
--temperature 1.0 \
--ground_truths_key ground_truth \
--chat_template tulu \
--chat_template_name tulu \
--sft_messages_key messages \
--total_episodes 200000 \
--penalty_reward_value -10.0 \
@@ -547,7 +547,7 @@ python mason.py \
--stop_token eos \
--temperature 1.0 \
--ground_truths_key ground_truth \
--chat_template tulu \
--chat_template_name tulu \
--sft_messages_key messages \
--total_episodes 200000 \
--penalty_reward_value -10.0 \
@@ -605,7 +605,7 @@ python mason.py \
--stop_token eos \
--temperature 1.0 \
--ground_truths_key ground_truth \
--chat_template tulu \
--chat_template_name tulu \
--sft_messages_key messages \
--total_episodes 200000 \
--penalty_reward_value -10.0 \
@@ -664,7 +664,7 @@ python mason.py \
--stop_token eos \
--temperature 1.0 \
--ground_truths_key ground_truth \
--chat_template tulu \
--chat_template_name tulu \
--sft_messages_key messages \
--total_episodes 200000 \
--penalty_reward_value -10.0 \
@@ -722,7 +722,7 @@ python mason.py \
--stop_token eos \
--temperature 1.0 \
--ground_truths_key ground_truth \
--chat_template tulu \
--chat_template_name tulu \
--sft_messages_key messages \
--total_episodes 200000 \
--penalty_reward_value -10.0 \
@@ -782,7 +782,7 @@ python mason.py \
--stop_token eos \
--temperature 1.0 \
--ground_truths_key ground_truth \
--chat_template tulu \
--chat_template_name tulu \
--sft_messages_key messages \
--total_episodes 200000 \
--penalty_reward_value -10.0 \
@@ -840,7 +840,7 @@ python mason.py \
--stop_token eos \
--temperature 1.0 \
--ground_truths_key ground_truth \
--chat_template tulu \
--chat_template_name tulu \
--sft_messages_key messages \
--total_episodes 200000 \
--penalty_reward_value -10.0 \
@@ -898,7 +898,7 @@ python mason.py \
--stop_token eos \
--temperature 1.0 \
--ground_truths_key ground_truth \
--chat_template tulu \
--chat_template_name tulu \
--sft_messages_key messages \
--total_episodes 200000 \
--penalty_reward_value -10.0 \
docs/archived_dev_scripts/olmoe_0125.sh (4 additions, 4 deletions)
@@ -411,7 +411,7 @@ python mason.py \
--stop_token eos \
--temperature 1.0 \
--ground_truths_key ground_truth \
--chat_template tulu \
--chat_template_name tulu \
--sft_messages_key messages \
--learning_rate 3e-7 \
--total_episodes 200000 \
@@ -464,7 +464,7 @@ python mason.py \
--stop_token eos \
--temperature 1.0 \
--ground_truths_key ground_truth \
--chat_template tulu \
--chat_template_name tulu \
--sft_messages_key messages \
--learning_rate 3e-7 \
--total_episodes 200000 \
@@ -517,7 +517,7 @@ python mason.py \
--stop_token eos \
--temperature 1.0 \
--ground_truths_key ground_truth \
--chat_template tulu \
--chat_template_name tulu \
--sft_messages_key messages \
--learning_rate 3e-7 \
--total_episodes 200000 \
@@ -571,7 +571,7 @@ python mason.py \
--stop_token eos \
--temperature 1.0 \
--ground_truths_key ground_truth \
--chat_template tulu \
--chat_template_name tulu \
--sft_messages_key messages \
--learning_rate 3e-7 \
--total_episodes 200000 \
docs/tulu3.md (6 additions, 6 deletions)
@@ -275,7 +275,7 @@ python open_instruct/ppo_vllm_thread_ray_gtrl.py \
--stop_token eos \
--temperature 1.0 \
--ground_truths_key ground_truth \
--chat_template tulu \
--chat_template_name tulu \
--sft_messages_key messages \
--learning_rate 3e-7 \
--total_episodes 10000000 \
@@ -312,10 +312,10 @@ Couple of notes:

```bash
source configs/beaker_configs/ray_node_setup.sh && python open_instruct/ppo_vllm_thread_ray_gtrl.py \
--dataset_mixer '{"allenai/RLVR-GSM-MATH-IF-Mixed-Constraints": 1.0}' \
--dataset_train_splits train \
--dataset_eval_mixer '{"allenai/RLVR-GSM-MATH-IF-Mixed-Constraints": 128}' \
--dataset_eval_splits train \
--dataset_mixer_list allenai/RLVR-GSM-MATH-IF-Mixed-Constraints 1.0 \
--dataset_mixer_list_splits train \
--dataset_mixer_eval_list allenai/RLVR-GSM-MATH-IF-Mixed-Constraints 16 \
--dataset_mixer_eval_list_splits train \
--max_token_length 2048 \
--max_prompt_token_length 2048 \
--response_length 2048 \
@@ -330,7 +330,7 @@ source configs/beaker_configs/ray_node_setup.sh && python open_instruct/ppo_vllm
--stop_token eos \
--temperature 1.0 \
--ground_truths_key ground_truth \
--chat_template tulu \
--chat_template_name tulu \
--sft_messages_key messages \
--learning_rate 1e-7 \
--total_episodes 400000 \
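The docs move from the old JSON-dict flags (`--dataset_mixer`, `--dataset_eval_mixer`) to flat, space-separated `--dataset_mixer_list` / `--dataset_mixer_eval_list` arguments. A minimal sketch of one plausible way such a list is paired up — the function name and interpretation below are illustrative, not the repo's actual parser; the value appears to be either a sampling fraction (e.g. `1.0`) or an absolute example count (e.g. `16`):

```python
# sketch: pair up a flat "--dataset_mixer_list name value name value ..." argument
from typing import List, Tuple


def parse_mixer_list(tokens: List[str]) -> List[Tuple[str, float]]:
    if len(tokens) % 2 != 0:
        raise ValueError("expected alternating dataset names and weights/sample counts")
    # even positions are dataset names, odd positions are fractions or counts
    return [(tokens[i], float(tokens[i + 1])) for i in range(0, len(tokens), 2)]


print(parse_mixer_list(["allenai/RLVR-GSM-MATH-IF-Mixed-Constraints", "1.0"]))
print(parse_mixer_list(["allenai/RLVR-GSM-MATH-IF-Mixed-Constraints", "16"]))
```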
open_instruct/grpo_vllm_thread_ray_gtrl.py (10 additions, 1 deletion)
@@ -245,6 +245,8 @@ class Args:
# ray
actor_num_gpus_per_node: List[int] = field(default_factory=lambda: [1])
"""number of gpus per node for actor"""
single_gpu_mode: bool = False
"""whether to collocate vLLM and actor on the same node (mostly for debugging purposes)"""
vllm_num_engines: int = 1
"""number of vLLM Engines, set to 0 to disable vLLM"""
vllm_tensor_parallel_size: int = 1
@@ -253,6 +255,8 @@
"""whether to enforce eager mode for vLLM -- slow inference but needed for multi-node"""
vllm_sync_backend: str = "nccl"
"""DeepSpeed -> vLLM weight sync backend"""
vllm_gpu_memory_utilization: float = 0.9
"""vLLM GPU memory utilization"""
enable_prefix_caching: bool = False
"""whether to enable prefix caching"""
deepspeed_stage: int = 0
@@ -1492,11 +1496,12 @@ def __init__(
pg: PlacementGroup,
ray_process_cls: RayProcess,
num_gpus_per_node: List[int],
single_gpu_mode: bool,
):
self.pg = pg
self.ray_process_cls = ray_process_cls
self.num_gpus_per_node = num_gpus_per_node
self.num_gpus_per_actor = 1
self.num_gpus_per_actor = 0.48 if single_gpu_mode else 1
self.num_cpus_per_actor = 4
self.models = []
world_size = sum(self.num_gpus_per_node)
@@ -1630,6 +1635,7 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig):
pg,
PolicyTrainerRayProcess,
args.actor_num_gpus_per_node,
args.single_gpu_mode,
)
wandb_url = wandb.run.get_url() if args.with_tracking else None
inits.extend(
args.seed,
args.enable_prefix_caching,
max_len,
args.vllm_gpu_memory_utilization,
args.single_gpu_mode,
pg=pg if args.single_gpu_mode else None,
)

metrics_queue = RayQueue()
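The key change for `single_gpu_mode` is that the trainer actor now requests a fractional GPU (`0.48`) instead of a whole one, and the vLLM engine does the same, so Ray can pack both onto a single physical device (0.48 + 0.48 = 0.96 leaves a little headroom below 1.0). A standalone sketch of that packing behaviour — illustrative only, not the repo's classes:

```python
# sketch: two actors that each ask for 0.48 of a GPU end up sharing one physical device
import ray

ray.init(num_gpus=1)  # treat the machine as having exactly one (logical) GPU


@ray.remote(num_gpus=0.48)
class Worker:
    def device_ids(self):
        return ray.get_gpu_ids()  # the GPU indices Ray assigned to this actor


trainer = Worker.remote()
inference = Worker.remote()
# both actors report GPU 0, since 0.48 + 0.48 <= 1.0 lets Ray colocate them
print(ray.get([trainer.device_ids.remote(), inference.device_ids.remote()]))
```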
open_instruct/ppo_vllm_thread_ray_gtrl.py (10 additions, 1 deletion)
@@ -253,6 +253,8 @@ class Args:
# ray
actor_num_gpus_per_node: List[int] = field(default_factory=lambda: [1])
"""number of gpus per node for actor"""
single_gpu_mode: bool = False
"""whether to collocate vLLM and actor on the same node (mostly for debugging purposes)"""
vllm_num_engines: int = 1
"""number of vLLM Engines, set to 0 to disable vLLM"""
vllm_tensor_parallel_size: int = 1
@@ -261,6 +263,8 @@
"""whether to enforce eager mode for vLLM -- slow inference but needed for multi-node"""
vllm_sync_backend: str = "nccl"
"""DeepSpeed -> vLLM weight sync backend"""
vllm_gpu_memory_utilization: float = 0.9
"""vLLM GPU memory utilization"""
enable_prefix_caching: bool = False
"""whether to enable prefix caching"""
deepspeed_stage: int = 0
@@ -1571,11 +1575,12 @@ def __init__(
pg: PlacementGroup,
ray_process_cls: RayProcess,
num_gpus_per_node: List[int],
single_gpu_mode: bool,
):
self.pg = pg
self.ray_process_cls = ray_process_cls
self.num_gpus_per_node = num_gpus_per_node
self.num_gpus_per_actor = 1
self.num_gpus_per_actor = 0.48 if single_gpu_mode else 1
self.num_cpus_per_actor = 4
self.models = []
world_size = sum(self.num_gpus_per_node)
@@ -1709,6 +1714,7 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig):
pg,
PolicyTrainerRayProcess,
args.actor_num_gpus_per_node,
args.single_gpu_mode,
)
wandb_url = wandb.run.get_url() if args.with_tracking else None
inits.extend(
args.seed,
args.enable_prefix_caching,
max_len,
args.vllm_gpu_memory_utilization,
args.single_gpu_mode,
pg=pg if args.single_gpu_mode else None,
)

metrics_queue = RayQueue()
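The new `vllm_gpu_memory_utilization` argument is forwarded to the vLLM engine so that, in single-GPU mode, inference only claims part of the card and the DeepSpeed trainer can use the rest. A minimal sketch of the underlying vLLM knob — the model name and values are examples, not prescribed by this commit:

```python
# sketch: cap vLLM's share of the GPU so a trainer can coexist on the same device
from vllm import LLM, SamplingParams

llm = LLM(
    model="HuggingFaceTB/SmolLM-135M-Instruct",
    gpu_memory_utilization=0.5,  # weights + KV cache limited to ~50% of GPU memory
    enforce_eager=True,          # skip CUDA graph capture, mirroring --vllm_enforce_eager
)
print(llm.generate(["1 + 1 ="], SamplingParams(max_tokens=8))[0].outputs[0].text)
```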
open_instruct/vllm_utils2.py (12 additions, 5 deletions)
@@ -206,15 +206,21 @@ def create_vllm_engines(
seed: int,
enable_prefix_caching: bool,
max_model_len: int,
vllm_gpu_memory_utilization: float = 0.9,
single_gpu_mode: bool = False,
pg: Optional[ray.util.placement_group] = None,
):
vllm_engines = []
for i in range(num_engines):
# When tensor_parallel_size=1, vLLM init model in LLMEngine directly, assign 1 GPU for it.
num_gpus = int(tensor_parallel_size == 1)
scheduling_strategy = None

if tensor_parallel_size > 1:
bundles = [{"GPU": 1, "CPU": 1}] * tensor_parallel_size
if pg is not None:
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=pg, placement_group_capture_child_tasks=True
)
elif tensor_parallel_size > 1:
bundles = [{"GPU": 1, "CPU": 4}] * tensor_parallel_size
pg = placement_group(bundles)
ray.get(pg.ready())

print(f"vllm: {num_gpus=}, {num_engines=}")
vllm_engines.append(
LLMRayActor.options(
num_cpus=1,
num_gpus=num_gpus,
num_cpus=4,
num_gpus=0.48 if single_gpu_mode else num_gpus,
scheduling_strategy=scheduling_strategy,
).remote(
pretrain,
Expand All @@ -238,6 +244,7 @@ def create_vllm_engines(
seed=seed + i,
enable_prefix_caching=enable_prefix_caching,
max_model_len=max_model_len,
gpu_memory_utilization=vllm_gpu_memory_utilization,
)
)

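`create_vllm_engines` can now be handed an existing placement group: when single-GPU mode passes one in, the vLLM actor is scheduled into the same GPU bundle the trainer occupies instead of creating its own. A self-contained sketch of that scheduling pattern (actor body and resource numbers are illustrative):

```python
# sketch: schedule two fractional-GPU actors into one shared placement-group bundle
import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

ray.init(num_gpus=1)

pg = placement_group([{"GPU": 1, "CPU": 2}])  # one bundle = one physical GPU plus some CPUs
ray.get(pg.ready())

strategy = PlacementGroupSchedulingStrategy(
    placement_group=pg,
    placement_group_capture_child_tasks=True,  # child tasks stay inside the bundle too
)


@ray.remote
class Dummy:
    def where(self):
        return ray.get_gpu_ids()


trainer = Dummy.options(num_cpus=1, num_gpus=0.48, scheduling_strategy=strategy).remote()
vllm_engine = Dummy.options(num_cpus=1, num_gpus=0.48, scheduling_strategy=strategy).remote()
print(ray.get([trainer.where.remote(), vllm_engine.where.remote()]))  # both land on the same GPU
```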
scripts/train/rlvr/grpo_mini.sh (5 additions, 2 deletions)
@@ -1,7 +1,7 @@
python open_instruct/grpo_vllm_thread_ray_gtrl.py \
--dataset_mixer_list allenai/RLVR-GSM 1.0 \
--dataset_mixer_list_splits train \
--dataset_mixer_eval_list allenai/RLVR-GSM 1.0 \
--dataset_mixer_eval_list allenai/RLVR-GSM 16 \
--dataset_mixer_eval_list_splits train \
--max_token_length 1023 \
--max_prompt_token_length 1024 \
@@ -33,5 +33,8 @@ python open_instruct/grpo_vllm_thread_ray_gtrl.py \
--save_freq 100 \
--reward_model_multiplier 0.0 \
--gradient_checkpointing \
--single_gpu_mode \
--vllm_sync_backend gloo \
--vllm_gpu_memory_utilization 0.5 \
--vllm_enforce_eager \
--with_tracking
# --with_tracking
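The script also switches the trainer-to-vLLM weight-sync backend to Gloo (`--vllm_sync_backend gloo`), presumably because two ranks sharing one CUDA device is a configuration NCCL generally refuses. A toy sketch of what a Gloo-based broadcast of updated weights can look like — the function and rendezvous settings are illustrative, not the repo's implementation:

```python
# sketch: broadcast updated parameters over a Gloo process group
# (Gloo moves tensors over CPU, so it works even when both ranks share a GPU)
import os
import torch
import torch.distributed as dist


def sync_weights(model: torch.nn.Module, rank: int, world_size: int) -> None:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
    for param in model.parameters():
        tensor = param.data.cpu()      # stage on CPU for the Gloo transport
        dist.broadcast(tensor, src=0)  # rank 0 (the trainer) holds the fresh weights
        param.data.copy_(tensor)
    dist.destroy_process_group()
```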
scripts/train/rlvr/mini.sh (5 additions, 2 deletions)
@@ -12,7 +12,7 @@ python open_instruct/ppo_vllm_thread_ray_gtrl.py \
--stop_token eos \
--temperature 1.0 \
--ground_truths_key ground_truth \
--chat_template tulu \
--chat_template_name tulu \
--sft_messages_key messages \
--learning_rate 3e-7 \
--total_episodes 10000 \
@@ -33,5 +33,8 @@ python open_instruct/ppo_vllm_thread_ray_gtrl.py \
--save_freq 100 \
--reward_model_multiplier 0.0 \
--gradient_checkpointing \
--single_gpu_mode \
--vllm_sync_backend gloo \
--vllm_gpu_memory_utilization 0.5 \
--vllm_enforce_eager \
--with_tracking
--with_tracking
scripts/train/rlvr/tulu_rlvr.sh (1 addition, 1 deletion)
@@ -12,7 +12,7 @@ python open_instruct/ppo_vllm_thread_ray_gtrl.py \
--stop_token eos \
--temperature 1.0 \
--ground_truths_key ground_truth \
--chat_template tulu \
--chat_template_name tulu \
--sft_messages_key messages \
--learning_rate 3e-7 \
--total_episodes 10000000 \