
Commit dd41877

[Liger-kernel] Add an option to use `_apply_liger_kernel_to_instance(…)` to load model (#133)

## Summary

This PR adds an option to use Liger Kernel's `_apply_liger_kernel_to_instance`
when initializing an FSDP worker model.

## Main Changes

1. Added an option to use `liger_kernel.transformers.AutoLigerKernelForCausalLM` (or, as in this diff, `_apply_liger_kernel_to_instance` on the already-loaded model) instead of the default `transformers.AutoModelForCausalLM` when loading a pretrained model; see the sketch after this list
2. Added an e2e test case driven by the new script `tests/e2e/run_qwen_gsm8k_model_rm_liger_kernel.sh`
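
For context, a minimal sketch of the two routes (not code from this PR; the model name is the one used by the test script added here):

```python
# Hedged sketch: two ways to enable Liger kernels on a HF causal LM.
from transformers import AutoModelForCausalLM

# (a) Load-time route: Liger's Auto class applies its patches as part of
#     from_pretrained.
from liger_kernel.transformers import AutoLigerKernelForCausalLM
model = AutoLigerKernelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")

# (b) Instance route (what this PR wires into the FSDP worker): load with
#     the stock Auto class, then monkey-patch the existing instance.
from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
_apply_liger_kernel_to_instance(model=model)
```

The instance route fits the worker better because `_build_model_optimizer` constructs the model itself before casting dtypes and wrapping with FSDP.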

## Related Issue

#96 

## TODO

#97 optimize the memory usage when computing entropy & log_probs

https://github.com/volcengine/verl/blob/6d96fda3d47f057caaa8f494ca7804181903e911/verl/workers/actor/dp_actor.py#L94-L106
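
A possible direction for this TODO (an assumption on my part, not something this PR implements): Liger's fused linear cross-entropy fuses the LM-head projection with the loss, so the full `[num_tokens, vocab_size]` logits tensor never has to be materialized. A minimal sketch, assuming a CUDA device is available for Triton:

```python
# Hypothetical sketch, not part of this PR: fused linear + cross-entropy.
import torch
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

vocab_size, hidden_size, num_tokens = 32000, 1024, 8  # illustrative sizes
hidden = torch.randn(num_tokens, hidden_size, device="cuda", requires_grad=True)
lm_head_weight = torch.randn(vocab_size, hidden_size, device="cuda")
labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")

loss_fn = LigerFusedLinearCrossEntropyLoss()
loss = loss_fn(lm_head_weight, hidden, labels)  # logits never materialized
loss.backward()
```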

---------

Signed-off-by: Hongpeng Guo <[email protected]>
hongpeng-guo authored Jan 30, 2025
1 parent df03aa6 commit dd41877
Showing 5 changed files with 66 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/e2e_gsm8k.yml
@@ -66,3 +66,7 @@ jobs:
run: |
ray stop --force
bash tests/e2e/run_qwen_gsm8k_model_rm_seq_balance.sh
+- name: Running gsm8k e2e with rmpad using model rm with Liger Kernel enabled
+  run: |
+    ray stop --force
+    bash tests/e2e/run_qwen_gsm8k_model_rm_liger_kernel.sh
1 change: 1 addition & 0 deletions pyproject.toml
@@ -42,6 +42,7 @@ dependencies = [
"transformers<4.48",
"vllm<=0.6.3",
"peft",
"liger-kernel",
]

# Optional dependencies (extras_require in setup.py)
1 change: 1 addition & 0 deletions requirements.txt
@@ -13,3 +13,4 @@ tensordict<0.6
transformers<4.48
vllm<=0.6.3
wandb
+liger-kernel
52 changes: 52 additions & 0 deletions tests/e2e/run_qwen_gsm8k_model_rm_liger_kernel.sh
@@ -0,0 +1,52 @@
set -x

export VLLM_ATTENTION_BACKEND=XFORMERS

python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=512 \
data.return_raw_chat=True \
actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
+actor_rollout_ref.model.use_liger=True \
actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size=32 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.grad_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size=128 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
actor_rollout_ref.ref.log_prob_micro_batch_size=128 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
critic.optim.lr=1e-5 \
critic.model.use_remove_padding=True \
critic.optim.lr_warmup_steps_ratio=0.05 \
critic.model.path=Qwen/Qwen2.5-0.5B \
critic.model.enable_gradient_checkpointing=False \
critic.ppo_micro_batch_size=32 \
critic.model.fsdp_config.param_offload=False \
critic.model.fsdp_config.grad_offload=False \
critic.model.fsdp_config.optimizer_offload=False \
reward_model.enable=True \
reward_model.model.path=Qwen/Qwen2.5-0.5B \
reward_model.model.use_remove_padding=True \
reward_model.model.fsdp_config.param_offload=True \
reward_model.micro_batch_size=16 \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console'] \
+trainer.val_before_train=False \
trainer.project_name='verl_example' \
trainer.experiment_name='Qwen2.5-0.5B-ci_hybrid_rm' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.total_training_steps=1 $@
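
A note on the override syntax above: the leading `+` in `+actor_rollout_ref.model.use_liger=True` (and `+trainer.val_before_train=False`) is Hydra's syntax for appending a key that is absent from the base config. The worker then reads the key with a default, so omitting the override keeps the old behavior. A minimal sketch of that read path, with an illustrative config:

```python
# Hedged sketch of how the appended flag is consumed (config is illustrative).
from omegaconf import OmegaConf

model_cfg = OmegaConf.create({"path": "Qwen/Qwen2.5-0.5B"})

# Without the `+...use_liger=True` override the key is absent, so the
# worker's .get() falls back to False and the model is left unpatched.
print(model_cfg.get("use_liger", False))  # False

OmegaConf.update(model_cfg, "use_liger", True)  # roughly what the `+` override does
print(model_cfg.get("use_liger", False))  # True
```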
8 changes: 8 additions & 0 deletions verl/workers/fsdp_workers.py
@@ -145,6 +145,7 @@ def _build_model_optimizer(self,
use_remove_padding=False,
enable_gradient_checkpointing=False,
trust_remote_code=False,
+use_liger=False,
role='actor'):
from verl.utils.model import print_model_size, update_model_config
from verl.utils.torch_dtypes import PrecisionType
@@ -198,6 +199,11 @@ def _build_model_optimizer(self,
config=actor_model_config,
attn_implementation='flash_attention_2',
trust_remote_code=trust_remote_code)
+# Apply Liger kernel to the model if use_liger is set to True
+if use_liger:
+    from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance
+    _apply_liger_kernel_to_instance(model=actor_module)

# some parameters may not be in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2
actor_module.to(torch_dtype)

@@ -338,6 +344,7 @@ def init_model(self):
use_remove_padding=use_remove_padding,
enable_gradient_checkpointing=self.config.model.get('enable_gradient_checkpointing', False),
trust_remote_code=self.config.model.get('trust_remote_code', False),
+use_liger=self.config.model.get('use_liger', False),
role='actor')

# get the original unwrapped module
@@ -370,6 +377,7 @@ def init_model(self):
use_remove_padding=use_remove_padding,
trust_remote_code=self.config.model.get(
'trust_remote_code', False),
+use_liger=self.config.model.get('use_liger', False),
role='ref')[0]
OmegaConf.set_struct(self.config.ref, True)
with open_dict(self.config.ref):
