MiniMax-01 scales linear attention to a large-scale model (456B parameters), and FlashInfer should support it.
The prefill (forward) computation of lightning attention can be summarized as follows.
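A hedged sketch of the tiled formulation implied by the discussion below, with tile index $t$; any decay factors used by lightning attention are omitted for simplicity:

$$
\mathbf{O}_t = \mathbf{O}_t^{\mathrm{intra}} + \mathbf{O}_t^{\mathrm{inter}}, \qquad
\mathbf{O}_t^{\mathrm{intra}} = \big[(\mathbf{Q}_t \mathbf{K}_t^{\top}) \odot \mathbf{M}\big]\mathbf{V}_t, \qquad
\mathbf{O}_t^{\mathrm{inter}} = \mathbf{Q}_t\,\mathbf{KV}_{t-1}, \qquad
\mathbf{KV}_t = \mathbf{KV}_{t-1} + \mathbf{K}_t^{\top}\mathbf{V}_t
$$

where $\mathbf{M}$ is the causal mask within a tile and $\mathbf{KV}_0 = \mathbf{0}$.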
The computation of `O_intra` for each tile is completely independent, so we can reuse our existing attention kernel by setting `use_softmax=False` in our attention variant class.
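As a reference for what that variant would compute, here is a minimal PyTorch sketch of the per-tile `O_intra`; the function name and shapes are illustrative, not FlashInfer API, and it assumes a plain causal mask with no softmax:

```python
import torch

def o_intra_reference(q, k, v):
    """Per-tile O_intra: causally masked (Q K^T) V with no softmax.
    q, k, v: [tile_size, head_dim] for a single tile of one request/head."""
    scores = q @ k.transpose(-1, -2)                                 # [tile, tile]
    mask = torch.tril(torch.ones_like(scores, dtype=torch.bool))     # within-tile causal mask
    scores = scores.masked_fill(~mask, 0.0)                          # no softmax, just zero out future positions
    return scores @ v                                                # [tile, head_dim]
```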
The computation of `O_inter` is basically a scan operation: we can either perform the entire loop per request within a single CTA, or use split-K. In the second case, we split the sequence length N into chunks: we first compute the `KV` matrix of each chunk, compute the cumulative sum of `KV` over chunks, and then compute the `O_inter` of all tiles independently. The split-K chunk size can be selected adaptively to strike a balance between the `O_inter` overhead (determined by the number of chunks) and the `O_intra` computation overhead (determined by the chunk size). `KV` should be kept in fp32, considering the accumulation precision required for long context.
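A minimal PyTorch sketch of the split-K path, assuming for simplicity that tiles and chunks have the same size (names and shapes are illustrative):

```python
import torch

def o_inter_splitk(q, k, v, chunk_size):
    """Split-K O_inter: per-chunk KV = K_c^T V_c, exclusive prefix sum over
    chunks, then O_inter of every chunk computed independently.
    q, k, v: [n, head_dim] for one request/head. KV states kept in fp32."""
    n, d = q.shape
    num_chunks = (n + chunk_size - 1) // chunk_size
    kv_per_chunk = torch.zeros(num_chunks, d, d, dtype=torch.float32, device=q.device)
    for c in range(num_chunks):
        s, e = c * chunk_size, min((c + 1) * chunk_size, n)
        kv_per_chunk[c] = k[s:e].T.float() @ v[s:e].float()          # chunk-local K^T V
    # exclusive prefix sum: chunk c only sees keys/values strictly before it
    kv_prefix = torch.cumsum(kv_per_chunk, dim=0) - kv_per_chunk
    o_inter = torch.empty(n, d, dtype=torch.float32, device=q.device)
    for c in range(num_chunks):
        s, e = c * chunk_size, min((c + 1) * chunk_size, n)
        o_inter[s:e] = q[s:e].float() @ kv_prefix[c]                 # Q_t @ KV_{t-1}
    return o_inter
```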
For decode, there is no need to maintain a KV-Cache in the page table; we just need to keep one `KV` (d×d) matrix per request, accumulating `KV` by `K_i^T V_i` at step i. It is still possible to maintain a unified page layout for the softmax attention layers' KV-Cache and the linear attention layers' `KV` state; in that case, we can add gather GEMM operators to FlashInfer for the `O_inter` computation.
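A hedged sketch of one decode step per request (illustrative names, no decay, single head):

```python
import torch

def linear_attn_decode_step(q_i, k_i, v_i, kv_state):
    """One decode step: accumulate K_i^T V_i into the per-request [d, d]
    fp32 KV state (no page table required) and emit O_i = Q_i @ KV.
    q_i, k_i, v_i: [head_dim]; kv_state is updated in place and returned."""
    kv_state += torch.outer(k_i.float(), v_i.float())  # KV += K_i^T V_i
    o_i = q_i.float() @ kv_state                       # O_inter for the new token
    return o_i, kv_state
```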