
Commit

add details
yzh119 committed Dec 18, 2024
1 parent cec5463 commit 9c3ca61
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions _posts/2024-12-12-flashinfer-v02-release.md
@@ -28,15 +28,15 @@ FlashInfer's standout feature is its highly flexible block-sparse FlashAttention
By leveraging [CuTE](https://github.com/NVIDIA/cutlass/blob/main/media/docs/cute/00_quickstart.md)'s `CustomStride` and `ComposedLayout` abstractions, we have extended vector-sparsity to FlashAttention-3. Inspired by [Cutlass's gather/scatter convolution](https://github.com/NVIDIA/cutlass/tree/e1cd8c7866dd6de02b66a89879795e7d7301aacc/examples/59_ampere_gather_scatter_conv), this was achieved through an elegant modification to the producer's memory loading module.

### Performance Benchmark
-We compared vector-sparse attention [^2] (PageAttention with `page_size=1`) with dense attention [^3] (the variable-length version) under the same problem sizes for both the FA-2 backend (v0.1.*) and FA-3 backend (v0.2). Benchmarks used `head_dim=128`, `causal=True`, varying batch sizes `(B)` and sequence lengths `(L)` with Gaussian-initialized input Q/K/V tensors.
+We compared two attention implementations: PageAttention with `page_size=1` [^2] (using the vector-sparse attention implementation) and variable-length dense attention [^3], benchmarking them under identical problem sizes across both the FA-2 (v0.1.*) and FA-3 (v0.2) backends. Benchmarks used `head_dim=128`, `causal=True`, varying batch sizes `(B)` and sequence lengths `(L)`, with Gaussian-initialized Q/K/V input tensors.

<p align="center">
<img src="/assets/imgs/fa3-template.png" alt="Performance comparison between dense/sparse attention on FA2&3 template" width="800"/>
<br>
Performance comparison between dense/vector-sparse attention on FA-2 and FA-3 templates on H100 SXM5, compiled with CUDA 12.4. y-axis: different settings, x-axis: achieved TFLOPs/s
</p>

-**Results:** Vector-sparse attention achieves 90% of dense attention's throughput under identical conditions. The FA-3 backend consistently outperforms FA-2. Thanks to FlashInfer's stable API, upgrading from FA-2 to FA-3 requires no code changes—just install FlashInfer 0.2.
+**Results:** Vector-sparse attention achieves 90% of dense attention's throughput under identical conditions. The FA-3 backend consistently outperforms FA-2. Thanks to FlashInfer's stable API, upgrading from FA-2 to FA-3 requires no code changes—just install FlashInfer 0.2. The reference benchmark script for reproducing these results is available [here](https://github.com/flashinfer-ai/flashinfer/blob/d7ac8e3ddc6623572c5c0e44af9e50a4c536a76c/benchmarks/bench_hopper_attention.py).
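
For orientation, here is a minimal sketch of the two code paths being compared, assuming the v0.2 wrapper API (`BatchPrefillWithRaggedKVCacheWrapper` for variable-length dense attention, `BatchPrefillWithPagedKVCacheWrapper` with `page_size=1` for vector-sparse attention). The shapes, head counts, and tolerance below are illustrative assumptions rather than the benchmark configuration; refer to the linked script for the exact setup.

```python
import torch
import flashinfer

B, L = 16, 1024                                    # batch size, per-request sequence length
num_qo_heads, num_kv_heads, head_dim = 32, 32, 128
device = "cuda"

# Packed (ragged) Q/K/V: all requests concatenated along the token dimension.
q = torch.randn(B * L, num_qo_heads, head_dim, dtype=torch.half, device=device)
k = torch.randn(B * L, num_kv_heads, head_dim, dtype=torch.half, device=device)
v = torch.randn(B * L, num_kv_heads, head_dim, dtype=torch.half, device=device)
qo_indptr = torch.arange(0, (B + 1) * L, L, dtype=torch.int32, device=device)
kv_indptr = qo_indptr.clone()

# Dense variable-length attention.
dense_workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
dense = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(dense_workspace, "NHD")
dense.plan(qo_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim, causal=True)
out_dense = dense.run(q, k, v)

# Vector-sparse attention: PageAttention with page_size=1, i.e. every "page"
# holds a single token and the page indices form an arbitrary gather list.
# With page_size=1, the page indptr coincides with the token indptr.
paged_kv_cache = torch.stack([k, v], dim=1).unsqueeze(2).contiguous()  # (B*L, 2, 1, H_kv, D)
paged_kv_indices = torch.arange(B * L, dtype=torch.int32, device=device)
paged_kv_last_page_len = torch.ones(B, dtype=torch.int32, device=device)

sparse_workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
sparse = flashinfer.BatchPrefillWithPagedKVCacheWrapper(sparse_workspace, "NHD")
sparse.plan(qo_indptr, kv_indptr, paged_kv_indices, paged_kv_last_page_len,
            num_qo_heads, num_kv_heads, head_dim, 1,  # page_size = 1
            causal=True)
out_sparse = sparse.run(q, paged_kv_cache)

# Both paths compute the same causal attention, up to floating-point error.
torch.testing.assert_close(out_dense, out_sparse, rtol=2e-2, atol=2e-2)
```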

## JIT Compilation for Attention Customization

@@ -90,7 +90,7 @@ We integrated **Cutlass 3.5 SM90 Grouped-GEMM** into our [SegmentGEMM](https://d
KV-Cache can now utilize non-contiguous storage layouts, improving support for [offloading](https://github.com/flashinfer-ai/flashinfer/issues/506).

#### Faster `plan` Functions
-`plan` functions now use non-blocking host-to-device memory transfers, improving performance.
+`plan` functions now use non-blocking host-to-device memory transfers, improving performance. Starting with FlashInfer v0.2, we encourage passing **host tensors** instead of device tensors to reduce synchronization inside the `plan` function.
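
A minimal sketch of this pattern, assuming the paged prefill wrapper and made-up shapes: the planning metadata (indptr arrays, page indices, last-page lengths) stays on the CPU so that `plan` can move it with non-blocking copies instead of synchronizing on device tensors.

```python
import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 16
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace, "NHD")

# Planning metadata lives on the host: plan() reads it without forcing a
# device synchronization and copies it to the GPU with non-blocking transfers.
qo_indptr = torch.tensor([0, 33, 77, 129], dtype=torch.int32)           # CPU tensor
paged_kv_indptr = torch.tensor([0, 3, 6, 10], dtype=torch.int32)        # CPU tensor
paged_kv_indices = torch.arange(10, dtype=torch.int32)                  # CPU tensor
paged_kv_last_page_len = torch.tensor([1, 12, 4], dtype=torch.int32)    # CPU tensor

wrapper.plan(qo_indptr, paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len,
             num_qo_heads, num_kv_heads, head_dim, page_size, causal=True)
```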

#### KV-Cache Append Optimization
KV-Cache append throughput for small batch sizes was improved by parallelizing per element instead of per request. A new API, [get_batch_indices_positions](https://docs.flashinfer.ai/generated/flashinfer.page.get_batch_indices_positions.html), supports this. Note that we made some breaking changes to this API to accommodate different parallelization modes. See [our benchmark](https://github.com/flashinfer-ai/flashinfer/blob/124daea86fcdff4ba64e5b51337d81a46d6068cb/benchmarks/bench_append_paged_kv_cache.py) for the new API usage.
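
For illustration, a small sketch of the helper in isolation, with made-up append lengths: it expands each request's append range into one `(batch index, position)` pair per appended element, which is what enables per-element parallelization. See the linked benchmark for the full append path.

```python
import torch
import flashinfer

# Four requests appending 1, 2, 3, and 4 new tokens respectively
# (10 appended KV entries in total), each sequence having length 5 after the append.
append_indptr = torch.tensor([0, 1, 3, 6, 10], dtype=torch.int32, device="cuda")
seq_lens = torch.tensor([5, 5, 5, 5], dtype=torch.int32, device="cuda")
nnz_kv = 10

# One (request, position) pair per appended element, so the append kernel can
# parallelize over elements rather than over requests.
batch_indices, positions = flashinfer.get_batch_indices_positions(
    append_indptr, seq_lens, nnz_kv
)
# batch_indices -> [0, 1, 1, 2, 2, 2, 3, 3, 3, 3]
# positions     -> [4, 3, 4, 2, 3, 4, 1, 2, 3, 4]  (the trailing slots of each request)
# These two tensors are then passed to the paged KV-cache append kernel.
```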
