
Feature/add remax support #234

Open · wants to merge 6 commits into main

Conversation

@liziniu commented Feb 9, 2025

Description

Added ReMax support to verl. ReMax is a simple, efficient, and stable RL algorithm customized for LLM training, with theoretical guarantees for variance reduction.

The HybridFlow paper included experiments with ReMax, but verl did not yet provide an implementation; this PR adds one.
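For readers unfamiliar with the method: the variance reduction in ReMax comes from a per-prompt baseline, the reward of a greedily decoded response, which is subtracted from the reward of the sampled response before the REINFORCE-style update. A toy sketch of that baseline step (the names below are illustrative, not verl's API):

```python
import torch

def remax_advantage(sampled_reward: torch.Tensor, greedy_reward: torch.Tensor) -> torch.Tensor:
    """Per-prompt ReMax 'advantage': reward of the sampled response minus the
    reward of a greedy-decoded response for the same prompt (variance-reducing baseline)."""
    return sampled_reward - greedy_reward

# Illustrative usage: one scalar reward per prompt in the batch.
sampled_reward = torch.tensor([0.8, 0.1, 1.0])  # rewards of stochastically sampled responses
greedy_reward = torch.tensor([0.6, 0.3, 1.0])   # rewards of greedy (argmax) responses
adv = remax_advantage(sampled_reward, greedy_reward)  # tensor([ 0.2, -0.2,  0.0])
```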

Changes

  • Added RayReMaxTrainer implementation
  • Added example scripts for ReMax training
  • Added documentation for ReMax algorithm

Testing

  • Tested ReMax example scripts with Qwen models

Validation reward when optimizing Qwen2.5-3B-Instruct on the GSM8K dataset:

(Screenshot, 2025-02-09: validation reward curve)

The curve demonstrates the effectiveness of ReMax, though its performance could be further improved with hyperparameter tuning.

Documentation

  • Added ReMax documentation
  • Updated example configurations

Checklist

  • Code follows project's style guidelines (yapf formatted)
  • Tests added/updated and passing
  • Documentation updated
  • Example scripts added

@vermouth1992 (Collaborator) commented Feb 9, 2025

Hi @liziniu, thank you for your contribution! Based on our implementation, I think ReMax can be implemented by adding a few lines to the existing PPO/GRPO/Reinforce implementation rather than writing a new trainer, which would make maintenance easier. Correct me if this is invalid.

@PeterSH6 (Collaborator) commented Feb 9, 2025

+1. From my understanding, ReMax can be implemented similarly to REINFORCE++ with a different advantage estimator. See the REINFORCE++ implementation: #228
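For illustration only: a ReMax-style advantage estimator could sit next to the existing estimators, broadcasting the sequence-level quantity (sampled reward minus greedy-baseline reward) over the response tokens. A minimal sketch, assuming a core_algos.py-style signature; the function and argument names are hypothetical, not necessarily what this PR merges:

```python
import torch

def compute_remax_outcome_advantage(token_level_rewards: torch.Tensor,
                                    reward_baselines: torch.Tensor,
                                    eos_mask: torch.Tensor):
    """Sketch of a ReMax outcome advantage (hypothetical signature).

    token_level_rewards: (bs, response_len) rewards assigned to response tokens
    reward_baselines:    (bs,) scalar reward of a greedy rollout per prompt
    eos_mask:            (bs, response_len) 1 for valid response tokens, 0 for padding
    """
    with torch.no_grad():
        # Sequence-level return of the sampled response.
        returns = (token_level_rewards * eos_mask).sum(dim=-1)              # (bs,)
        # Subtract the greedy baseline and broadcast over valid tokens.
        advantages = (returns - reward_baselines).unsqueeze(-1) * eos_mask  # (bs, response_len)
    # Monte-Carlo outcome estimators typically return the same tensor as advantages and returns.
    return advantages, advantages
```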

@@ -41,7 +41,7 @@ verl is fast with:
 - **vLLM** and **TGI** for rollout generation, **SGLang** support coming soon.
 - huggingface models support
 - Supervised fine-tuning
-- Reinforcement learning from human feedback with [PPO](https://github.com/volcengine/verl/tree/main/examples/ppo_trainer) and [GRPO](https://github.com/volcengine/verl/tree/main/examples/grpo_trainer)
+- Reinforcement learning from human feedback with [PPO](https://github.com/volcengine/verl/tree/main/examples/ppo_trainer), [GRPO](https://github.com/volcengine/verl/tree/main/examples/grpo_trainer), and [ReMax](https://github.com/volcengine/verl/tree/main/examples/remax_trainer)
A collaborator commented on this diff:
If you already have the training log and wandb run, would you mind adding one more record to docs/experiment/ppo.rst to include ReMax? It would help the community track whether the experiment can be reproduced in future versions. We can do that in the next PR.

@liziniu (Author) replied:
Yes. A preliminary result on Qwen2.5-3B has been added, and more results will come later.

@liziniu (Author) commented Feb 10, 2025

@vermouth1992 @PeterSH6 I see. Let me rework the code with minimal changes to the PPO trainer.

@liziniu (Author) commented Feb 10, 2025

Hi, @vermouth1992 @PeterSH6

I have completed the implementation of ReMax support. The changes include:

  1. Removed the new trainer for ReMax
  2. Implemented ReMax on top of the PPO trainer (rough idea sketched below)
  3. Updated the preliminary result in docs/experiment/ppo.rst
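A rough sketch of the extra rollout this integration needs: for each prompt, one greedy decode is generated alongside the sampled rollout, and its reward serves as the baseline. The snippet below uses plain Hugging Face generate purely for illustration (the model name and generation settings are assumptions, not the verl rollout path, which uses vLLM workers):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative model choice; any small instruct model works for a smoke test.
name = "Qwen/Qwen2.5-0.5B-Instruct"
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

prompt = tok("Compute 12 * 7 and explain briefly.", return_tensors="pt")
sampled = model.generate(**prompt, do_sample=True, max_new_tokens=64)   # stochastic rollout
greedy = model.generate(**prompt, do_sample=False, max_new_tokens=64)   # greedy baseline rollout
# A reward function scores both decodes; ReMax uses (reward(sampled) - reward(greedy))
# as the advantage for the policy-gradient update.
```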

The code follows the project's style guidelines.

Please review when you have a chance. Let me know if any changes or clarifications are needed.

Thank you for your time!

Review comment on verl/trainer/ppo/core_algos.py (outdated, resolved)
@vermouth1992 (Collaborator) commented Feb 10, 2025

Could you add a CI job that runs ReMax with Qwen 0.5B to protect this functionality? You can follow the example here: https://github.com/volcengine/verl/blob/main/.github/workflows/e2e_gsm8k.yml#L69
