Marlin fp8 #241
Conversation
optimum/quanto/nn/qmodule.py
Outdated
if self.weight_qtype == qfloat8_e4m3fn and self.activation_qtype is None:
    # Marlin FP8 kernel only supports per-tensor fp8 quantization.
    axis = None
else:
    axis = 0
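For context, a minimal sketch (illustrative only, not quanto's actual quantization code) of what the `axis` choice amounts to when computing absmax scales for a weight: `axis=None` yields a single per-tensor scale, while `axis=0` yields one scale per output feature.

```python
import torch

def absmax_scale(weight: torch.Tensor, axis=None) -> torch.Tensor:
    # Illustrative absmax scale computation for float8_e4m3fn quantization.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    if axis is None:
        amax = weight.abs().max()                      # per-tensor: a single scalar scale
    else:
        amax = weight.abs().amax(dim=1, keepdim=True)  # one scale per output feature (row)
    return amax / fp8_max

w = torch.randn(4096, 11008, dtype=torch.float16)
print(absmax_scale(w, axis=None).shape)  # torch.Size([])      -> what the Marlin FP8 path expects
print(absmax_scale(w, axis=0).shape)     # torch.Size([4096, 1]) -> per-channel scales
```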
Having this control flow here is quite ugly.
I don't understand why we have this restriction while the kernel checks for a vector scale.
Here is the kernel code:
TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2");
TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1),
            " is not size_n = ", size_n);
// Channelwise only for FP8
TORCH_CHECK(b_scales.size(0) == 1);
num_groups = b_scales.size(0);
My message was wrong (I meant vector -> I edited the comment).
So clearly NOT a scalar (this is why in the unit test you had to repeat your scalar scale). I suspect the reason it only works with a scalar scale (the same value for all output features) is some kind of interleaving that is missing.
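In other words, a small sketch of what the checks above imply (assuming `size_n` is the number of output features): the kernel wants a rank-2 `b_scales` of shape `(1, size_n)`, so a per-tensor scale has to be repeated across output features before the call, which is what the unit test does.

```python
import torch

# Assumed setup: size_n = number of output features of the quantized Linear.
size_n = 4096
per_tensor_scale = torch.tensor(0.01, dtype=torch.float16)

# The TORCH_CHECKs above require a rank-2 b_scales with shape (1, size_n),
# so the scalar scale is repeated along the output-feature dimension.
b_scales = per_tensor_scale.reshape(1, 1).expand(1, size_n).contiguous()
assert b_scales.shape == (1, size_n)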
Compared to calling the kernel in isolation, my end-to-end benchmark is slower when going through quanto, due to many overheads. A separate PR may be needed to remove some of them first (or after this one, if we don't care about perf for now). I'll try with a larger model / a different GPU to see whether the Python latency is hidden or not. @dacorvo Am I doing something wrong here?
Can you explain how you deduce there is an overhead here, as compared to the standard call from inside quanto?
For sure. Last week I was using the quanto benchmark script with

    python evaluate_model.py --device cuda --metric decode-latency --quantizer quanto --weights float8_e4m3fn --activations none --dtype fp16 --batch_size 1
    python evaluate_model.py --device cuda --metric decode-latency --quantizer quanto --weights none --activations none --dtype fp16 --batch_size 1

and did not get speedups, but I am unsure why. From the profile, I assume it is some overhead from quanto's dispatch, but it could be something else. The following script gives the profile below: https://gist.github.com/fxmarty/1aff830cdd57aa650412f34bd4076b3b, comparing a linear call through a quanto module and through a direct call.

Benchmarking the same (https://gist.github.com/fxmarty/e449c55e4a1dbf9b1657f395aa542eb4) for a single linear, there indeed seems to be a 10-25% overhead (in my problem setup) from `__torch_function__`, multiple torch.library dispatches, etc. (first and second columns). It is still faster than native fp16 though, so we will have to see.
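For reference, a minimal, self-contained timing harness of the kind used for such a comparison (this is not the gist linked above): the two callables being compared would be the quanto `QLinear` forward on one side and a direct/plain call on the other.

```python
import torch

def cuda_time_ms(fn, *args, warmup=10, iters=100):
    """Mean CUDA latency of fn(*args) in milliseconds, measured with CUDA events."""
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# Example baseline: a plain fp16 linear (stand-in for the "direct call" column).
w16 = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
x = torch.randn(1, 4096, dtype=torch.float16, device="cuda")
print(cuda_time_ms(lambda t: torch.nn.functional.linear(t, w16), x))
# The quanto column would time the corresponding QLinear module's forward instead.
```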
@dacorvo Running it, I get:

To me this is a substantial enough difference to care.

Edit: on A100, for Llama 3 8B, end-to-end (prefill + decode), I do get with mixed fp16+fp8:

^ this is not faster than fp16-fp16 in any case (getting 32.681 ms), which is weird.
That is indeed a substantial difference. It would be great if you could:
|
optimum/quanto/tensor/qbytes.py
Outdated
activation_qtype (`qtype`, defaults to `None`):
    The qtype used for the activations. If one needs to use a different tensor subclass e.g. for weights depending on the activations qtype, this argument must be specified accordingly when calling `QBytesTensor.create`.
tensor_type (`Optional[str]`, defaults to `None`):
    Specifies whether the tensor is to be considered as a `"weight"` or `"activation"`, which may influence the tensor subclass to be used.
At this point we need to know whether we are quantizing a weight or activation, and what is the qtype of the activation so as to pick the correct tensor subclass.
The only reason I see for doing this would be to avoid creating an fp8 packed tensor when the activations might be float8, so as to be able to use scaled_mm later on. For other use cases you can always dequantize the input to be able to call the kernel.

I need to think about this, because at this stage it means the factory method is only ever used when creating quantized weights for Linear layers. This means there might actually be a subclass involved here (like a QLinearBytesTensor). To be honest, this was already the case for AWQBitsTensor and TinyGemmBitsTensor.
Exactly: when calling `QBytesTensor.create` for a float8 activation, or for a float8 weight when the activations are float8 as well.
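A runnable sketch of the dispatch this implies, with a hypothetical helper and subclass name (this is not quanto's actual `QBytesTensor.create`): the Marlin-packed subclass would only be chosen for a per-tensor-scaled fp8 weight with non-quantized activations; otherwise a plain `QBytesTensor` is kept so that e.g. `torch._scaled_mm` remains usable.

```python
from typing import Optional

def pick_weight_subclass(weight_qtype: str, activation_qtype: Optional[str], axis: Optional[int]) -> str:
    """Hypothetical helper illustrating the dispatch discussed above (not quanto's factory)."""
    if weight_qtype == "qfloat8_e4m3fn" and activation_qtype is None and axis is None:
        # fp8 weight, non-quantized activations, per-tensor scale: pack for the Marlin FP8 kernel.
        return "MarlinF8QBytesTensor"  # hypothetical subclass name
    # fp8 activations (or any other combination): keep a plain, unpacked QBytesTensor so that
    # e.g. torch._scaled_mm can still consume the fp8 weight later on.
    return "QBytesTensor"

print(pick_weight_subclass("qfloat8_e4m3fn", None, None))              # -> MarlinF8QBytesTensor
print(pick_weight_subclass("qfloat8_e4m3fn", "qfloat8_e4m3fn", None))  # -> QBytesTensor
```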
Thank you for this pull-request, which challenges the current design more than I would have expected. I need to think a bit about how to address the valid concerns you raised here.
If we put the organization of the code aside, I also think the restriction to per-tensor scale makes the kernel unusable, at least for quanto.
I think this can easily be changed in a later PR. Neither vllm nor tgi supports per-column scales, and yet they achieve nice memory reductions & speedups while claiming no quality loss.
Force-pushed from a29517d to f5222fa
In line with my tests in TGI, we get a decent speedup with this kernel only when using CUDA graphs; I can't really explain why. Using transformers + A100 + an 8B model, measuring decode latency at batch size = 1:
The perplexity benchmark is very decent as well.
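For reference, a minimal, self-contained sketch of CUDA graph capture for a fixed-shape decode step (the numbers above come from the transformers/TGI integrations, not from this snippet; `decode_step` here is just a stand-in fp16 linear). Capture removes the per-step Python and dispatch overhead discussed earlier, which is where the kernel-level gains become visible.

```python
import torch

# Stand-in decode step with static shapes: a single fp16 linear.
w = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
def decode_step(x):
    return torch.nn.functional.linear(x, w)

static_input = torch.randn(1, 4096, dtype=torch.float16, device="cuda")

# Warm up on a side stream before capture, as required by the CUDA graphs API.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_output = decode_step(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture one decode step into a graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = decode_step(static_input)

# Per-step execution: copy fresh data into the static input, then replay.
static_input.copy_(torch.randn_like(static_input))
graph.replay()  # no Python dispatch inside the replay, only the captured kernels
```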
This PR is stale because it has been open 15 days with no activity. Remove stale label or comment or this will be closed in 5 days.
Obsoleted by #296
Fixes #238