Allow scalar broadcasting in VisitorRowBroadcast and VisitorColBroadcast #1539
Very happy to add unit tests and put in the work to get this PR into a landable state. But first hoping to get some high-level feedback on whether this is the right approach or a reasonable thing to do. Thanks! |
cc @mnicely |
Hi @tlrmchlsmth thanks for your contribution. I'm working on int8 GEMM with dequant fusion.
// inputs
// A [M, K] int8
// B [N, K] int8
// alphaCol [M, 1] fp32
// alphaRow [1, N] fp32
// outputs
// mat [M, N] fp32
// alphaCol [M, 1] fp32
using V1Broadcast = cutlass::epilogue::threadblock::VisitorColBroadcast<
OutputTileThreadMap, ElementC,
cute::Stride<int32_t, _1, _0> // StrideMNL
>;
// alphaRow [1, N] fp32
using V2Broadcast = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, ElementC,
cute::Stride<_0, _1, int32_t> // StrideMNL
>;
I don’t quite understand this PR. Regarding this issue, could you please provide some examples? In what situations won’t it work, and in what situations will it work based on this PR? |
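For readers following the shapes in the comment above, here is a minimal host-side reference of an int8 GEMM with fused dequantization. It is only a sketch: the function name `gemm_dequant_reference` is hypothetical, and it assumes both scales are applied multiplicatively to the int32 accumulator, which the comment does not state explicitly.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Unoptimized reference, not CUTLASS code.
// A is [M, K] row-major and B is [N, K] row-major, so the accumulator is A * B^T.
// Assumption: out[i][j] = alpha_col[i] * alpha_row[j] * acc[i][j].
std::vector<float> gemm_dequant_reference(
    const std::vector<std::int8_t>& A, const std::vector<std::int8_t>& B,
    const std::vector<float>& alpha_col,  // [M, 1], one scale per output row
    const std::vector<float>& alpha_row,  // [1, N], one scale per output column
    std::size_t M, std::size_t N, std::size_t K) {
  assert(A.size() == M * K && B.size() == N * K);
  assert(alpha_col.size() == M && alpha_row.size() == N);
  std::vector<float> out(M * N);
  for (std::size_t i = 0; i < M; ++i) {
    for (std::size_t j = 0; j < N; ++j) {
      std::int32_t acc = 0;  // accumulate in int32 to avoid int8 overflow
      for (std::size_t k = 0; k < K; ++k) {
        acc += static_cast<std::int32_t>(A[i * K + k]) *
               static_cast<std::int32_t>(B[j * K + k]);
      }
      out[i * N + j] = alpha_col[i] * alpha_row[j] * static_cast<float>(acc);
    }
  }
  return out;
}
```

In a fused kernel the last line of the inner loop is what the column and row broadcast visitors would supply in the epilogue.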
@Hongbosherlock In the second, the row and column broadcast epilogues ( I tried the approach you suggest for cutlass 2.0 but couldn't get it to compile. If you have a full working example, I'd like to see it :) Anyway, the same approach won't work for cutlass 3.0, as you will fail this static assert |
@hwu36 can we ask Zhaodong to merge this? I don't know his GitHub username |
JFYI I did end up going in a different direction with these epilogue changes. See vllm-project/vllm#5137 -- I found that it was much nicer for a variety of reasons if both the scalar and the vector broadcast cases take a |
bool guard = get<1>(coord_v(i)) < n;
cutlass::arch::global_load<VecType, sizeof(VecType)>(dst_v(i), (void const*)&src_v(i), guard);
}
} else {
Nit: New line after branch close.
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
if(get<1>(coord_v(i)) < n)
{
Nit: no new line before brace open
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
if(get<1>(coord_v(i)) < n)
Nit: spacing if (get
copy_if(pred, tC_gCol, tC_rCol);

if (params_ptr->ptr_col) {
// In this case we are loading from a column vector and broadcasting
A design question: this isn't really a scalar operation anymore. Does it make sense to extend this visitor, or to replace this with a vector broadcast instead that then has a broadcasting layout for its data?
replace this with a vector broadcast instead that then has a broadcasting layout for its data
Could you point me to an example of this?
thanks for taking a look at the PR BTW :)
Hi @tlrmchlsmth, thanks for the PR! One question I have is that can we use the |
@apuaaChen We could totally do that, but then, in order to have a kernel for every case of fp8 quantized GEMM that we need to support, we would need 4x the number of kernels. The activations can have per-tensor or per-token scales, and the weights can have per-tensor or per-output-channel scales. So this PR lets us pick another point in the binary size/performance tradeoff space.
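To make the four combinations concrete, here is a rough host-side sketch in which one routine covers every pairing of activation and weight scale granularity. The `Scale` struct and `dequantize` function are hypothetical illustrations, not part of CUTLASS or vLLM, and the accumulator is assumed to already be in fp32.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical scale descriptor: a vector when `ptr` is set, a single
// per-tensor value otherwise.
struct Scale {
  const float* ptr = nullptr;  // per-token (length m) or per-channel (length n)
  float scalar = 1.0f;         // used only when ptr is nullptr
  float at(std::size_t i) const { return ptr ? ptr[i] : scalar; }
};

// Applies activation scales along rows and weight scales along columns of an
// m x n row-major accumulator. The same code path serves all four cases.
std::vector<float> dequantize(const std::vector<float>& acc,
                              std::size_t m, std::size_t n,
                              const Scale& a_scale, const Scale& b_scale) {
  assert(acc.size() == m * n);
  std::vector<float> out(m * n);
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      out[i * n + j] = acc[i * n + j] * a_scale.at(i) * b_scale.at(j);
    }
  }
  return out;
}
```

The tradeoff mentioned in the comment shows up as the runtime branch inside `at`: one binary handles every case, at the cost of a check that a specialized kernel would not need.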
@tlrmchlsmth Got it! Let me merge it. Thanks for the explanation. |
@apuaaChen while @tlrmchlsmth ended up using a custom visitor that loads both a row and a scalar from the |
I guess could you let me know if you plan to merge it, or if there's any cleanup you want me to do before we merge. I also have a version with a boolean |
@ProExpertProg Please push your changes to this branch. I will first merge your updates to our internal repo. After the CI is passed, I can get your PR merged, thanks! |
80a5654 to 0b6c76e
Perfect, thank you!! |
And please don't hesitate to ask for any changes or improved comments, and feel free to make edits yourself if there are any style/formatting issues. |
This PR has been labeled |
@apuaaChen were you able to get the PR run on the internal CI? |
Yes, it passed the internal CI. I’m combining it with a few other fixes right now |
This PR has been labeled |
This PR addresses an inconsistency between the VisitorRowBroadcast/VisitorColBroadcast epilogues and the SM90RowBroadcast/SM90ColBroadcast epilogues.
The inconsistency is that the SM90 epilogues can handle either row/column broadcasting or scalar broadcasting, the latter by passing in a nullptr for the first argument and a float for the second, while the visitor epilogues cannot. This PR adds this functionality to the visitor epilogues.
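The pointer-or-scalar convention can be sketched in plain C++ as follows. This is a simplified host-side illustration, not the CUTLASS visitor code: the names `RowBroadcastArgs`, `null_default`, `load_row_broadcast`, and `scale_columns` are assumptions chosen for the example.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical argument struct mirroring the "nullptr plus float" convention.
struct RowBroadcastArgs {
  const float* ptr_row = nullptr;  // per-column values of length n, or nullptr
  float null_default = 0.0f;       // scalar broadcast when ptr_row is nullptr
};

// Returns the value that is broadcast down every row of column `col`.
inline float load_row_broadcast(const RowBroadcastArgs& args, std::size_t col) {
  return args.ptr_row ? args.ptr_row[col] : args.null_default;
}

// Multiplies an m x n row-major accumulator in place by the broadcast value.
inline void scale_columns(std::vector<float>& acc, std::size_t m, std::size_t n,
                          const RowBroadcastArgs& args) {
  assert(acc.size() == m * n);
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      acc[i * n + j] *= load_row_broadcast(args, j);
    }
  }
}
```

Passing a non-null `ptr_row` gives the vector broadcast; leaving it null and setting `null_default` gives the scalar broadcast, with no change to the calling code.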
I am using this for quantized GEMMs that can handle either per-token/per-channel quantization or per-tensor quantization without compiling and distributing multiple kernels to handle all cases.
For reference, I ran into this issue when developing vllm-project/vllm#4749