
Marlin fp8 #241

Closed. fxmarty wants to merge 19 commits from the marlin-fp8 branch.

Conversation

@fxmarty fxmarty (Contributor) commented Jul 12, 2024

Fixes #238

@fxmarty fxmarty requested a review from dacorvo as a code owner July 12, 2024 11:50
Comment on lines 226 to 238
if self.weight_qtype == qfloat8_e4m3fn and self.activation_qtype is None:
# Marlin FP8 kernel only supports per-tensor fp8 quantization.
axis = None
else:
axis = 0
fxmarty (Contributor, Author)

Having this control flow here is quite ugly.

@dacorvo dacorvo (Collaborator) Jul 12, 2024

I don't understand why we have this restriction while the kernel checks for a vector scale.

@dacorvo dacorvo (Collaborator) Jul 17, 2024

Here is the kernel code:

  TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2");
  TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1),
              " is not size_n = ", size_n);
  // Channelwise only for FP8
  TORCH_CHECK(b_scales.size(0) == 1);
  num_groups = b_scales.size(0);

My earlier message was wrong (I meant vector, so I edited the comment).
So the expected scale is clearly NOT a scalar (which is why, in the unit test, you had to repeat your scalar scale). I suspect the reason it only works with a scalar scale (the same value for all output features) is that some kind of interleaving is missing.
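
For illustration, here is a minimal sketch (not the PR's actual code; the shapes and the scale value are assumptions) of what this check implies for the scale passed to the kernel:

    # Minimal sketch: a per-tensor fp8 scale has to be expanded to the
    # (1, size_n) channel-wise layout that the TORCH_CHECKs above require.
    import torch

    out_features = 4096  # size_n in the kernel's terminology

    # Per-tensor quantization produces a single scalar scale...
    scale = torch.tensor(0.01, dtype=torch.float16)

    # ...so it has to be repeated across all output features before the call.
    b_scales = scale.reshape(1, 1).repeat(1, out_features)
    assert b_scales.shape == (1, out_features)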

@fxmarty fxmarty (Contributor, Author) commented Jul 12, 2024

Compared to calling the kernel in isolation, in my configuration the end-to-end benchmark is slower when using quanto, due to many overheads.
[image: benchmark results]

A separate PR may be needed to remove some of them, either before this one or after it (if we don't care about performance for now).

I'll try with a larger model / a different GPU to see whether the Python latency is hidden or not.

@dacorvo Am I doing something wrong here?

@dacorvo dacorvo (Collaborator) commented Jul 17, 2024

> Compared to calling the kernel in isolation, in my configuration the end-to-end benchmark is slower when using quanto, due to many overheads. [image: benchmark results]
>
> A separate PR may be needed to remove some of them, either before this one or after it (if we don't care about performance for now).
>
> I'll try with a larger model / a different GPU to see whether the Python latency is hidden or not.
>
> @dacorvo Am I doing something wrong here?

Can you explain how you deduce there is an overhead here, compared to the standard call from inside torch.nn.functional.linear to torch.ops.aten.mm? By the way, what tool did you use?

@fxmarty fxmarty (Contributor, Author) commented Jul 17, 2024

For sure. Last week I was using quanto's benchmark script with:

python evaluate_model.py --device cuda --metric decode-latency --quantizer quanto --weights float8_e4m3fn --activations none --dtype fp16 --batch_size 1
python evaluate_model.py --device cuda --metric decode-latency --quantizer quanto --weights none --activations none --dtype fp16 --batch_size 1

and did not see speedups, but I am unsure why. From the profile, I assume it is some overhead from quanto's dispatch, but it could be something else.

The following script gives the profile below: https://gist.github.com/fxmarty/1aff830cdd57aa650412f34bd4076b3b. It compares a linear call going through a quanto module, through a direct torch.ops.quanto_ext.fp8_gemm call, and through fp16 torch.nn.functional.linear.

[image: profiler output]

Benchmarking the same setup (https://gist.github.com/fxmarty/e449c55e4a1dbf9b1657f395aa542eb4) for a single linear, there indeed seems to be a 10-25% overhead (in my problem setup) from __torch_function__, multiple torch.library dispatches, etc. (first and second columns). It is still faster than native fp16, though, so we will have to see.
[image: per-linear benchmark results]
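
For reference, here is a minimal sketch of how such a per-layer comparison can be timed (this is not the gist's script; the quantize/freeze/qfloat8 usage from optimum-quanto and the shapes are assumptions):

    # Hedged sketch: timing an fp16 nn.Linear against the same layer with
    # quanto fp8 weights. Shapes and iteration counts are illustrative.
    import torch
    import torch.nn as nn
    import torch.utils.benchmark as benchmark
    from optimum.quanto import quantize, freeze, qfloat8

    torch.set_grad_enabled(False)  # inference-only timing
    x = torch.randn(1, 4096, dtype=torch.float16, device="cuda")

    fp16_model = nn.Sequential(nn.Linear(4096, 4096, bias=False)).to(torch.float16).to("cuda")

    fp8_model = nn.Sequential(nn.Linear(4096, 4096, bias=False)).to(torch.float16).to("cuda")
    quantize(fp8_model, weights=qfloat8)  # replace the Linear with a quantized QLinear
    freeze(fp8_model)                     # materialize the quantized weights

    for label, model in [("fp16 linear", fp16_model), ("quanto fp8 linear", fp8_model)]:
        timer = benchmark.Timer(stmt="model(x)", globals={"model": model, "x": x}, label=label)
        print(timer.timeit(200))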

@fxmarty fxmarty (Contributor, Author) commented Jul 19, 2024

@dacorvo Running python evaluate_model.py --device cuda --metric decode-latency --quantizer quanto --weights float8_e4m3fn --activations none --dtype fp16 --batch_size 1 --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 on my laptop, comparing b8dbdf0 and d52e44a (which avoids the dispatch when it is not necessary),

I get:

To me this is a substantial enough difference to care.

Edit: on A100, for llama 3 8B, with end-to-end (prefill + decode), I do get with mixed fp16+fp8:

^ this is not faster than fp16-fp16 in any case (which gets 32.681 ms), which is weird.

@dacorvo dacorvo (Collaborator) commented Jul 19, 2024

> @dacorvo Running python evaluate_model.py --device cuda --metric decode-latency --quantizer quanto --weights float8_e4m3fn --activations none --dtype fp16 --batch_size 1 --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 on my laptop, comparing b8dbdf0 and d52e44a (which avoids the dispatch when it is not necessary), I get:
>
> To me this is a substantial enough difference to care.
>
> Edit: on A100, for llama 3 8B, end-to-end (prefill + decode), I do get with mixed fp16+fp8:
>
> ^ this is not faster than fp16-fp16 in any case (which gets 32.681 ms), which is weird.

That is indeed a substantial difference. It would be great if you could:

  • run the unit tests to identify the extent of the consequences of your changes on other features,
  • rebase your branch,
  • maybe clean up the branch, group your commits into meaningful groups of changes and use conventional commits, so that we can run the CI.

Comment on lines 72 to 75
activation_qtype (`qtype`, defaults to `None`):
The qtype used for the activations. If one needs to use a different tensor subclass e.g. for weights depending on the activations qtype, this argument must be specified accordingly when calling `QBytesTensor.create`.
tensor_type (`Optional[str]`, defaults to `None`):
Specifies whether the tensor is to be considered as a `"weight"` or `"activation"`, which may influence the tensor subclass to be used.
fxmarty (Contributor, Author)

At this point we need to know whether we are quantizing a weight or an activation, and what the qtype of the activations is, so as to pick the correct tensor subclass.

dacorvo (Collaborator)

The only reason I see for doing this would be to avoid creating an fp8 packed tensor when the activations might be float8, so that scaled_mm can be used later on. For other use cases you can always dequantize the input in order to call the kernel.
I need to think about this, because at this stage it means the factory method is only ever used when creating quantized weights for Linear layers. This means there might actually be a subclass involved here (like a QLinearBytesTensor). To be honest, this was already the case for AWQBitsTensor and TinyGemmBitsTensor.

@fxmarty fxmarty (Contributor, Author) Jul 24, 2024

Exactly: when calling QBytesTensor.create for a float8 activation, or for a float8 weight when the activations are float8 as well.
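
To make the discussion concrete, here is a rough, self-contained sketch of that selection logic (not the PR's implementation; the class names and the create() signature are hypothetical):

    # Hypothetical sketch of the subclass choice discussed above.
    class QBytesTensor:
        """Generic 8-bit quantized tensor (placeholder for the real subclass)."""
        def __init__(self, data, scale):
            self.data, self.scale = data, scale

    class MarlinF8QBytesTensor(QBytesTensor):
        """fp8 weights packed for the Marlin fp8 kernel (hypothetical name)."""

    def create(data, scale, qtype="float8_e4m3fn", axis=None,
               activation_qtype=None, tensor_type=None):
        # Weight-only fp8 with a per-tensor scale: pack for the Marlin kernel.
        if (tensor_type == "weight" and qtype == "float8_e4m3fn"
                and activation_qtype is None and axis is None):
            return MarlinF8QBytesTensor(data, scale)
        # fp8 activations, or fp8 weights combined with fp8 activations: keep
        # the plain subclass so that a scaled_mm path can be used later on.
        return QBytesTensor(data, scale)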

@dacorvo dacorvo (Collaborator) left a comment

Thank you for this pull request, which challenges the current design more than I would have expected. I need to think a bit about how to address the valid concerns you raised here.
Putting the organization of the code aside, I also think the restriction to a per-tensor scale makes the kernel unusable, at least for quanto.

optimum/quanto/library/extensions/extension.py (Outdated)
optimum/quanto/nn/qlinear.py (Outdated)
optimum/quanto/nn/qmodule.py (Outdated)
optimum/quanto/tensor/marlin/fp8_packed.py (Outdated)
optimum/quanto/tensor/qbytes_ops.py (Outdated)
optimum/quanto/tensor/qtensor.py (Outdated)
optimum/quanto/tensor/qactivation.py (Outdated)
optimum/quanto/tensor/qtensor_func.py (Outdated)
optimum/quanto/tensor/quantizers/symmetric.py (Outdated)
@fxmarty fxmarty (Contributor, Author) commented Jul 24, 2024

> the restriction to a per-tensor scale makes the kernel unusable

I think this can easily be changed in a later PR. Neither vLLM nor TGI supports per-column scales, and yet they achieve nice memory reductions and speedups with, reportedly, no quality loss.

@fxmarty fxmarty force-pushed the marlin-fp8 branch 3 times, most recently from a29517d to f5222fa on July 24, 2024 at 18:00
@fxmarty fxmarty (Contributor, Author) commented Jul 25, 2024

In line with my tests in TGI, we get a decent speedup with this kernel only when using CUDA graphs. I can't really explain why.

Using transformers + A100 + 8B model & measuring decode latency at batch size = 1:

fp8 + torch.compile + static cache
Peak memory during benchmark: 9.7056 GB
Average decode latency per token: 8.833560943603516 ms

fp16 + torch.compile + static cache
Peak memory during benchmark: 15.3295 GB
Average decode latency per token: 12.276221656799317 ms

fp8 + eager
Peak memory during benchmark: 8.5034 GB
Average decode latency per token: 33.871917724609375 ms

fp16 + eager
Peak memory during benchmark: 14.9994 GB
Average decode latency per token: 32.10729331970215 ms

The perplexity benchmark is very decent as well.
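
For context, the compiled configuration above roughly corresponds to the following sketch (the model name, prompt, and token counts are illustrative assumptions, not the exact benchmark script):

    # Hedged sketch: decode with quanto fp8 weights, a static KV cache and
    # torch.compile, so that decoding can run through CUDA graphs.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from optimum.quanto import quantize, freeze, qfloat8

    model_id = "meta-llama/Meta-Llama-3-8B"  # illustrative 8B model
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

    quantize(model, weights=qfloat8)  # fp8 weights, fp16 activations
    freeze(model)

    model.generation_config.cache_implementation = "static"
    model.forward = torch.compile(model.forward, mode="reduce-overhead")

    inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")
    _ = model.generate(**inputs, max_new_tokens=16)    # warmup / compilation
    output = model.generate(**inputs, max_new_tokens=128)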

@fxmarty fxmarty requested a review from dacorvo July 29, 2024 15:38

This PR is stale because it has been open 15 days with no activity. Remove stale label or comment or this will be closed in 5 days.

@dacorvo dacorvo (Collaborator) commented Aug 28, 2024

Obsoleted by #296

@dacorvo dacorvo closed this Aug 28, 2024
@dacorvo dacorvo deleted the marlin-fp8 branch September 13, 2024 12:50
Successfully merging this pull request may close these issues.

Integrate marlin fp16/bf16-float8 matrix multiplication kernel
2 participants