From d36c1c7552ab11e934a9d3eaf85bb05fcb410256 Mon Sep 17 00:00:00 2001
From: sunyi001 <1659275352@qq.com>
Date: Sat, 22 Feb 2025 16:57:01 +0800
Subject: [PATCH] support ASCEND NPU

---
 docs/ascend/ascend.md                     | 16 +++----
 docs/ascend/images/loss_comparison.png    | Bin 10549 -> 0 bytes
 examples/grpo_trainer/run_qwen2-7b_npu.sh | 41 ------------------
 pyproject.toml                            |  1 +
 requirements-npu.txt                      |  4 +-
 verl/bert_padding.py                      | 46 +++++++++++++++-----
 verl/trainer/fsdp_sft_trainer.py          | 37 ++++++++--------
 verl/utils/device.py                      | 35 +++++++++++++---
 verl/utils/fsdp_utils.py                  |  6 +--
 verl/workers/fsdp_workers.py              | 49 +++++++++++-----------
 10 files changed, 123 insertions(+), 112 deletions(-)
 delete mode 100644 docs/ascend/images/loss_comparison.png
 delete mode 100644 examples/grpo_trainer/run_qwen2-7b_npu.sh

diff --git a/docs/ascend/ascend.md b/docs/ascend/ascend.md
index fd96a20e..c10ef90b 100644
--- a/docs/ascend/ascend.md
+++ b/docs/ascend/ascend.md
@@ -1,6 +1,6 @@
 # veRL x Ascend
 
-We have added support for Huawei Ascend devices to verRL; using veRL on Huawei Ascend devices is almost identical to using it on NVIDIA GPUs.
+We have added support for Huawei Ascend devices to veRL; using veRL on Huawei Ascend devices is almost identical to using it on NVIDIA GPUs.
 
 ## Hardware Support
 
@@ -22,7 +22,7 @@
 ### Install from Source
 
 ```shell
-git clone -b vllm-0.7-npu https://github.com/as12138/verl.git
+git clone https://github.com/volcengine/verl.git
 cd verl
 pip install -r requirements-npu.txt
 pip install -e .
@@ -50,17 +50,17 @@
 
 Empirically, under identical configurations we expect the loss measured on Huawei Ascend devices to deviate from the loss measured on NVIDIA GPUs by less than 2% on average, computed as follows:
 
-![loss_comparison](./images/loss_comparison.png)
+![loss_comparison](https://github.com/eric-haibin-lin/verl-community/tree/main/docs/loss_comparison.png)
 
 Here, N denotes the number of training steps. For more information, see the [accuracy comparison guide](https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/LMaccuracy_0001.html).
 
 ### Progress
 
-| Algorithm | Status                                                        |
-|------|---------------------------------------------------------------|
-| SFT  | Supported                                                     |
-| PPO  | Supported                                                     |
-| GRPO | Supported                                                     |
+| Algorithm | Status |
+|:------|:----|
+| SFT | Supported |
+| PPO | Supported |
+| GRPO | Supported |
 
 > Additional notes:
 
diff --git a/docs/ascend/images/loss_comparison.png b/docs/ascend/images/loss_comparison.png
deleted file mode 100644
index de073429eecdb350c8f00b57db7e7c2d7b1add95..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 10549
[10549-byte binary literal omitted]
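The 2% criterion above compares the two training runs step by step. As a minimal sketch of how such a check could be scripted — assuming the deleted loss_comparison.png depicted the mean relative error over the N training steps mentioned in the doc, and with every name below hypothetical rather than part of this patch:

```python
# Hypothetical helper: mean relative error between per-step loss curves.
# Assumes the formula in loss_comparison.png was
#   (1/N) * sum_n |loss_npu[n] - loss_gpu[n]| / |loss_gpu[n]|
from typing import Sequence


def mean_relative_loss_error(npu_loss: Sequence[float], gpu_loss: Sequence[float]) -> float:
    """Average relative gap between two loss curves logged over the same N steps."""
    assert len(npu_loss) == len(gpu_loss) and len(gpu_loss) > 0
    n = len(gpu_loss)
    return sum(abs(a - b) / abs(b) for a, b in zip(npu_loss, gpu_loss)) / n


# The doc's acceptance bar: average error below 2%.
# ok = mean_relative_loss_error(npu_losses, gpu_losses) < 0.02
```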
zudH}|w!Cb>q&98;pCyL&KPqjOq-hG}_;gQ`{Z5d>AT5H9Mztj4KxF^(ziX6{ZJj_V z7Vt|TQ%sy?4bcDCGJ@Opx-Q=fassmhBoIIu=B(RiP@bH<&^#h*`8}Q*bhdal9z$dD z_@6-s@o(*Xafh=2.10", "tensordict<0.6", "transformers", + "vllm<=0.7.3", 'wandb', ] diff --git a/requirements-npu.txt b/requirements-npu.txt index 0ad7f301..b5bc661d 100644 --- a/requirements-npu.txt +++ b/requirements-npu.txt @@ -14,5 +14,5 @@ ray tensordict<0.6 transformers wandb -vllm -vllm-ascend +vllm==0.7.1 +vllm-ascend==0.7.1rc1 diff --git a/verl/bert_padding.py b/verl/bert_padding.py index d7584beb..60d7df5d 100644 --- a/verl/bert_padding.py +++ b/verl/bert_padding.py @@ -1,4 +1,31 @@ # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py +# Copyright (c) 2023, Tri Dao. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import torch import torch.nn.functional as F @@ -6,6 +33,7 @@ class IndexFirstAxis(torch.autograd.Function): + @staticmethod def forward(ctx, input, indices): ctx.save_for_backward(indices) @@ -14,9 +42,8 @@ def forward(ctx, input, indices): second_dim = other_shape.numel() # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. # return input[indices] - return torch.gather( - rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim) - ).reshape(-1, *other_shape) + return torch.gather(rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", + d=second_dim)).reshape(-1, *other_shape) @staticmethod def backward(ctx, grad_output): @@ -39,14 +66,13 @@ def backward(ctx, grad_output): class IndexPutFirstAxis(torch.autograd.Function): + @staticmethod def forward(ctx, values, indices, first_axis_dim): ctx.save_for_backward(indices) assert indices.ndim == 1 assert values.ndim >= 2 - output = torch.zeros( - first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype - ) + output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype) # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. 
         output[indices] = values
         # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
@@ -65,6 +91,7 @@ def backward(ctx, grad_output):
 
 
 class IndexFirstAxisResidual(torch.autograd.Function):
+
     @staticmethod
     def forward(ctx, input, indices):
         ctx.save_for_backward(indices)
@@ -182,9 +209,8 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_leng
     """
     length = attention_mask_in_length.sum(dim=-1)
    seqlen = attention_mask_in_length.size(-1)
-    attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length),
-                                                                                              seqlen) < length.unsqueeze(
-        1)
+    attention_mask_2d = torch.arange(seqlen, device=length.device,
+                                     dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1)
     real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
     seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
     indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
@@ -217,4 +243,4 @@ def pad_input(hidden_states, indices, batch, seqlen):
     # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
     # output[indices] = hidden_states
     output = index_put_first_axis(hidden_states, indices, batch * seqlen)
-    return rearrange(output, "(b s) ... -> b s ...", b=batch)
\ No newline at end of file
+    return rearrange(output, "(b s) ... -> b s ...", b=batch)
diff --git a/verl/trainer/fsdp_sft_trainer.py b/verl/trainer/fsdp_sft_trainer.py
index 48c9a79f..1dce190b 100644
--- a/verl/trainer/fsdp_sft_trainer.py
+++ b/verl/trainer/fsdp_sft_trainer.py
@@ -209,12 +209,12 @@ def _build_model_optimizer(self):
         init_context = get_init_weight_context_manager(use_meta_tensor=not config.tie_word_embeddings)
 
         with init_context():
-            self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(local_model_path,
-                                                                               config=config,
-                                                                               torch_dtype=torch.float32,
-                                                                               attn_implementation='flash_attention_2'
-                                                                               if is_cuda_available else 'sdpa',
-                                                                               trust_remote_code=trust_remote_code)
+            self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
+                local_model_path,
+                config=config,
+                torch_dtype=torch.float32,
+                attn_implementation='flash_attention_2' if is_cuda_available else 'sdpa',
+                trust_remote_code=trust_remote_code)
 
         # Apply Liger kernel if use_liger is enabled
         if self.config.model.get('use_liger', False):
@@ -253,17 +253,17 @@ def _build_model_optimizer(self):
         else:
             cpu_offload = CPUOffload(offload_params=self.config.model.fsdp_config.offload_params)
 
-        self.fsdp_model = FSDP(module=self.model,
-                               auto_wrap_policy=auto_wrap_policy,
-                               param_init_fn=init_fn,
-                               sharding_strategy=ShardingStrategy.FULL_SHARD,
-                               mixed_precision=mixed_precision,
-                               device_mesh=self.device_mesh,
-                               sync_module_states=True,
-                               device_id=torch.cuda.current_device() if is_cuda_available else
-                               torch.npu.current_device(),
-                               cpu_offload=cpu_offload,
-                               use_orig_params=False)
+        self.fsdp_model = FSDP(
+            module=self.model,
+            auto_wrap_policy=auto_wrap_policy,
+            param_init_fn=init_fn,
+            sharding_strategy=ShardingStrategy.FULL_SHARD,
+            mixed_precision=mixed_precision,
+            device_mesh=self.device_mesh,
+            sync_module_states=True,
+            device_id=torch.cuda.current_device() if is_cuda_available else torch.npu.current_device(),
+            cpu_offload=cpu_offload,
+            use_orig_params=False)
 
         log_gpu_memory_usage('After FSDP wrapping', logger=logger)
 
@@ -489,7 +489,8 @@ def fit(self):
         # Perform final validation
         val_losses = []
         for val_data in self.val_dataloader:
-            val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).to(self.device)
+            val_data = TensorDict(val_data,
+                                  batch_size=self.config.data.micro_batch_size_per_gpu).to(self.device)
             val_loss = self.validation_step(val_data)
             val_losses.append(val_loss)
         if rank == 0:
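The fit() hunk above only re-wraps a long line; the underlying pattern is that each validation micro-batch (a plain dict of tensors) is wrapped in a TensorDict so the whole batch can be moved to the training device in one call. A small illustrative sketch, with shapes and tensor names invented for the example (in the trainer, the batch size comes from config.data.micro_batch_size_per_gpu and the target is self.device):

```python
# Illustrative only: mirrors the TensorDict usage in fsdp_sft_trainer.py.
import torch
from tensordict import TensorDict

micro_batch_size_per_gpu = 4  # stand-in for self.config.data.micro_batch_size_per_gpu
val_data = {
    "input_ids": torch.randint(0, 32000, (micro_batch_size_per_gpu, 128)),
    "attention_mask": torch.ones(micro_batch_size_per_gpu, 128, dtype=torch.long),
}
# batch_size is passed exactly as in the trainer; .to() moves every tensor at once
# (in the trainer the destination is a CUDA or NPU device rather than "cpu").
batch = TensorDict(val_data, batch_size=micro_batch_size_per_gpu).to("cpu")
```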
diff --git a/verl/utils/device.py b/verl/utils/device.py
index 55344e5d..0a11d5e2 100644
--- a/verl/utils/device.py
+++ b/verl/utils/device.py
@@ -1,9 +1,34 @@
 # This code is inspired by the torchtune.
 # https://github.com/pytorch/torchtune/blob/main/torchtune/utils/_device.py
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without modification,
+# are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this list
+# of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice, this
+# list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its contributors may
+# be used to endorse or promote products derived from this software without specific
+# prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
+# OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
+# SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
+# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
+# BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
+# DAMAGE.
 
-import os
 import logging
-from enum import Enum
 from typing import Optional
 
 import torch
@@ -69,7 +94,5 @@ def get_torch_device() -> any:
     try:
         return getattr(torch, device_name)
     except AttributeError:
-        logger.warning(
-            f"Device namespace '{device_name}' not found in torch, try to load torch.cuda."
-        )
-        return torch.cuda
\ No newline at end of file
+        logger.warning(f"Device namespace '{device_name}' not found in torch, try to load torch.cuda.")
+        return torch.cuda
diff --git a/verl/utils/fsdp_utils.py b/verl/utils/fsdp_utils.py
index 7120418d..2c50e7dc 100644
--- a/verl/utils/fsdp_utils.py
+++ b/verl/utils/fsdp_utils.py
@@ -31,9 +31,9 @@ def init_fn(x: torch.nn.Module):
     if not torch.distributed.get_rank() == 0:
-        x = x.to_empty(device=torch.cuda.current_device() if is_cuda_available else torch.npu.current_device() if is_cuda_available else torch.npu.current_device(),
+        x = x.to_empty(device=torch.cuda.current_device() if is_cuda_available else torch.npu.current_device(),
                        recurse=False)
-        torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache() if is_cuda_available else torch.npu.empty_cache() if is_cuda_available else torch.npu.empty_cache()
+        torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache()
     return x
 
 
@@ -127,7 +127,7 @@ def offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True):
         flat_param._local_shard = flat_param.data
         assert id(flat_param._local_shard) != id(flat_param.data)
     if empty_cache:
-        torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache() if is_cuda_available else torch.npu.empty_cache()
+        torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache()
 
 
 @torch.no_grad()
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index 2245f639..c2230a31 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -197,12 +197,12 @@ def _build_model_optimizer(self,
 
         with init_context(), warnings.catch_warnings():
             warnings.simplefilter("ignore")
-            actor_module = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=local_path,
-                                                                torch_dtype=torch_dtype,
-                                                                config=actor_model_config,
-                                                                attn_implementation='flash_attention_2'
-                                                                if is_cuda_available else 'sdpa',
-                                                                trust_remote_code=trust_remote_code)
+            actor_module = AutoModelForCausalLM.from_pretrained(
+                pretrained_model_name_or_path=local_path,
+                torch_dtype=torch_dtype,
+                config=actor_model_config,
+                attn_implementation='flash_attention_2' if is_cuda_available else 'sdpa',
+                trust_remote_code=trust_remote_code)
 
             # Apply Liger kernel to the model if use_liger is set to True
             if use_liger:
                 from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance
@@ -411,8 +411,9 @@ def update_actor(self, data: DataProto):
         if self._is_offload_param:
             load_fsdp_model_to_gpu(self.actor_module_fsdp)
         if self._is_offload_optimizer:
-            load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=torch.cuda.current_device()
-                                if is_cuda_available else torch.npu.current_device())
+            load_fsdp_optimizer(
+                optimizer=self.actor_optimizer,
+                device_id=torch.cuda.current_device() if is_cuda_available else torch.npu.current_device())
 
         data.batch = data.batch.to(DEVICE)
 
@@ -668,12 +669,12 @@ def _build_critic_model_optimizer(self, config):
             warnings.simplefilter("ignore")
             setattr(critic_model_config, 'classifier_dropout', 0.)
             setattr(critic_model_config, 'hidden_dropout', '0')
-            critic_module = AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path=local_path,
-                                                                            torch_dtype=torch_dtype,
-                                                                            config=critic_model_config,
-                                                                            attn_implementation='flash_attention_2'
-                                                                            if is_cuda_available else 'sdpa',
-                                                                            trust_remote_code=trust_remote_code)
+            critic_module = AutoModelForTokenClassification.from_pretrained(
+                pretrained_model_name_or_path=local_path,
+                torch_dtype=torch_dtype,
+                config=critic_model_config,
+                attn_implementation='flash_attention_2' if is_cuda_available else 'sdpa',
+                trust_remote_code=trust_remote_code)
 
             # some parameters may not in torch_dtype
             critic_module.to(torch_dtype)
@@ -710,8 +711,7 @@ def _build_critic_model_optimizer(self, config):
                              param_init_fn=init_fn,
                              use_orig_params=False,
                              auto_wrap_policy=auto_wrap_policy,
-                             device_id=torch.cuda.current_device() if is_cuda_available else torch.npu.current_device()
-                             if is_cuda_available else torch.npu.current_device(),
+                             device_id=torch.cuda.current_device() if is_cuda_available else torch.npu.current_device(),
                              sharding_strategy=sharding_strategy,
                              mixed_precision=mixed_precision,
                              sync_module_states=True,
@@ -792,8 +792,9 @@ def update_critic(self, data: DataProto):
         if self._is_offload_param:
             load_fsdp_model_to_gpu(self.critic_module)
         if self._is_offload_optimizer:
-            load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=torch.cuda.current_device()
-                                if is_cuda_available else torch.npu.current_device())
+            load_fsdp_optimizer(
+                optimizer=self.critic_optimizer,
+                device_id=torch.cuda.current_device() if is_cuda_available else torch.npu.current_device())
 
         # perform forward computation
         with self.ulysses_sharding_manager:
@@ -923,12 +924,12 @@ def _build_model(self, config):
         with init_context(), warnings.catch_warnings():
             warnings.simplefilter("ignore")
             setattr(model_config, 'classifier_dropout', 0.)
-            reward_module = AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path=local_path,
-                                                                            config=model_config,
-                                                                            torch_dtype=torch.bfloat16,
-                                                                            attn_implementation='flash_attention_2'
-                                                                            if is_cuda_available else 'sdpa',
-                                                                            trust_remote_code=trust_remote_code)
+            reward_module = AutoModelForTokenClassification.from_pretrained(
+                pretrained_model_name_or_path=local_path,
+                config=model_config,
+                torch_dtype=torch.bfloat16,
+                attn_implementation='flash_attention_2' if is_cuda_available else 'sdpa',
+                trust_remote_code=trust_remote_code)
         reward_module.to(torch.bfloat16)
 
         auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config)
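The fsdp_workers.py hunks above keep reformatting the same `torch.cuda.current_device() if is_cuda_available else torch.npu.current_device()` ternary, and the fsdp_utils.py hunk fixes a spot where that ternary had been accidentally duplicated into dead branches. The patch's own `get_torch_device()` helper in verl/utils/device.py already centralizes this dispatch; a minimal sketch of leaning on it instead of repeating the ternary — not something this patch itself does:

```python
# Sketch, assuming torch_npu is installed on Ascend hosts (which registers the
# torch.npu namespace); elsewhere get_torch_device() falls back to torch.cuda,
# as the warning branch in verl/utils/device.py shows.
from verl.utils.device import get_torch_device

dev = get_torch_device()           # e.g. torch.cuda on GPU hosts, torch.npu on Ascend
device_id = dev.current_device()   # replaces the repeated is_cuda_available ternary
dev.empty_cache()                  # same call works for both backends
```

Because torch.cuda and torch.npu expose the same method names here, routing through one helper keeps call sites backend-agnostic and avoids the copy-paste errors these hunks clean up.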