Integrated Rotary Positional Embeddings (RoPEs) into flash_attn_kvcache #83

Open · wants to merge 26 commits into base: main_perf
Conversation

@alexkranias-amd commented Sep 27, 2024

Motivation

Original Paper: RoFormer: Enhanced Transformer with Rotary Position Embedding

Rotary Positional Embeddings (RoPEs) are a common positional embedding type used in many transformer models today.

RoPEs work by applying a unique rotation transformation to the vectors that represent each token within our q and k tensors, based on each token's position $$m$$ in the sequence.

To compute attention, we must first compute $$\text{matmul}(Q, K^T)$$, which effectively takes the dot product between the vector embeddings of tokens in $$Q$$ and $$K^T$$. Given two tokens at positions $$i$$ and $$j$$, the closer $$i$$ and $$j$$ are to each other, the more similarly their embedding vectors are rotated, and the dot product between them is largely unchanged. The further apart the two tokens are, the more the transformations applied to their embedding vectors diverge, which causes the dot product to decay. As the dot product decays, so does the attention weight between the two tokens, so the model effectively learns that, for a given token, nearby tokens should be paid more attention than tokens much further away.

Dot Product Decay
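This property is easy to check outside the kernel. Below is a minimal NumPy sketch (an illustration, not code from this PR; `rope_rotate` and `base` are assumed names): it rotates the 2-dimensional chunks of q and k by position-dependent angles and shows that the rotated dot product depends only on the offset between the two positions, which is the quantity whose magnitude tends to decay as that offset grows.

```python
# Minimal RoPE sketch for illustration only (not the Triton kernel in this PR).
import numpy as np

def rope_rotate(x, m, base=10000.0):
    """Rotate the 2-D chunks of a head vector x of shape (head_dim,) at position m."""
    d = x.shape[-1]
    theta = base ** (-np.arange(0, d, 2) / d)   # one frequency per 2-D chunk
    cos, sin = np.cos(m * theta), np.sin(m * theta)
    x1, x2 = x[0::2], x[1::2]                   # the two components of each chunk
    out = np.empty_like(x)
    out[0::2] = x1 * cos - x2 * sin             # standard 2x2 rotation per chunk
    out[1::2] = x1 * sin + x2 * cos
    return out

rng = np.random.default_rng(0)
q, k = rng.standard_normal(64), rng.standard_normal(64)

# Same relative offset, different absolute positions -> identical dot product.
for m in (0, 10, 1000):
    print(m, np.dot(rope_rotate(q, m), rope_rotate(k, m + 5)))
```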

A more detailed explanation

Fundamentally, RoPEs work by dividing the embedding space of our q and k vectors (the $$\text{head\_dim}$$) into many 2-dimensional chunks. Each chunk can be thought of as a vector subcomponent of q and k projected onto a 2-dimensional plane that exists within the higher-dimensional space of the q and k embeddings. RoPE "rotates" these planar chunks of our q and k vectors uniquely based on the index of the token in the sequence: each chunk is rotated by a unique angle $$\theta_{m,i}$$ determined by the token's index $$m$$ in the sequence and the index $$i \le d/2$$ of the chunk within the head dimension $$d$$.

RoPE Implementation Details
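Concretely, an implementation precomputes one angle $$\theta_{m,i} = m \cdot \text{base}^{-2i/d}$$ per (position $$m$$, chunk $$i$$) pair and caches its cosine and sine. The sketch below is an illustration with assumed names (`build_rope_cache`, `base`); the (seqlen, head_dim / 2) layout matches what Tri Dao's rotary helper expects for its cos/sin arguments.

```python
import torch

def build_rope_cache(seqlen, head_dim, base=10000.0, device="cpu", dtype=torch.float32):
    """Precompute cos/sin tables of shape (seqlen, head_dim // 2)."""
    # theta_i = base^(-2i / head_dim) for each 2-D chunk i
    inv_freq = 1.0 / base ** (torch.arange(0, head_dim, 2, device=device, dtype=dtype) / head_dim)
    positions = torch.arange(seqlen, device=device, dtype=dtype)
    angles = torch.outer(positions, inv_freq)   # angle theta_{m,i} for every (m, i)
    return torch.cos(angles), torch.sin(angles)

cos, sin = build_rope_cache(seqlen=2048, head_dim=128)
print(cos.shape, sin.shape)   # torch.Size([2048, 64]) each
```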

Implementation

RoPE is applied to Q and K at every attention layer. For developing a kernel there are two options:

  1. Rotate Q and K using one kernel, then pass in the new rotated Q and K vectors into our flash_attn_kernel
  2. Fuse RoPE into our flash_attn_kernel

Since Tri Dao already had a functional, separate RoPE kernel, I implemented approach 1 first.

Separate RoPE and FlashAttention Kernels

We use `from flash_attn.layers.rotary import apply_rotary_emb`.

Within `class _attention(torch.autograd.Function)`, before calling `splitk_flash_attn`, we rotate `q` and `input_metadata.k_new` by calling `apply_rotary_emb`, which launches a Triton kernel.
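For reference, here is a hedged sketch of that call sequence. `apply_rotary_emb` is the imported flash_attn helper; the wrapper name, the `splitk_flash_attn` call signature, and the use of `cache_seqlens` as the rotary offset are assumptions for illustration rather than the exact diff in this PR.

```python
from flash_attn.layers.rotary import apply_rotary_emb

def rotate_then_splitk(q, input_metadata, cos, sin, cache_seqlens, splitk_flash_attn):
    # q:                     (batch, seqlen_q, nheads, head_dim)
    # input_metadata.k_new:  (batch, seqlen_new, nheads_k, head_dim)
    # cos / sin:             (rotary_seqlen, head_dim // 2), precomputed as above
    # cache_seqlens offsets the rotation angles so new tokens are rotated by their
    # absolute position in the KV cache, not their position within the new chunk.
    q = apply_rotary_emb(q, cos, sin, interleaved=False, seqlen_offsets=cache_seqlens)
    input_metadata.k_new = apply_rotary_emb(
        input_metadata.k_new, cos, sin, interleaved=False, seqlen_offsets=cache_seqlens
    )
    # With q and k_new already rotated, the attention kernel itself is unchanged.
    return splitk_flash_attn(q, input_metadata)
```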

Fused RoPE into FlashAttention

TODO

More Notes

Additional notes can be found in the following issue: https://github.com/ROCm/triton-internal/issues/33

micmelesse and others added 25 commits August 29, 2024 09:18
* flash_attn_func works

Compress

This is a combination of 12 commits.

add scripts

save

add our kernel

import our kernel

round trip

use bshd layout

figure out segfault

fix

show backward failure with prints

save backward work

run forward only

test smallest config on everything

add test

fix

remove pre commit

install triton

skip dropout

pin d

32 factor d

just run power of 2

remove timeout

run serially

clean up

clean up 2

* Varlen works

This is a combination of 6 commits.

save

some tests passing

enable more

enable everything

move around

alibi works

* keep interface and kernel separate

* clean up
* Compress kvcache work

This is a combination of 11 commits.

kvcache work

This is a combination of 4 commits.

kvcache is not supported

save

save decode

save

clean up merge

save cases

save

save

save

save

key mask on triton side

fix q size issue

test combos

save

* fix causal. use cache_seqlens

* clean and test what works

* some configs work on new_kv but fails on 1,8

* cache overwrite correct

* new_kv works more or less

* test local

* work on paged kv attention

* prefill paged attention

* fix has_batch_idx and skip local and rotary emb

* save

* save

* save

* save

* handle new_kv when paged kv cache

* all except has_batch_idx works

* major options are green

* test all

* add tests

* save

* clean up

* minor clean up

* simplest config

* save debug true

* save

* refactor slightly

* save work

* need key masking

* force hip

* use is_hip

* save

* fix cache_seq_len issue

* work on new_kv

* pass new_kv data

* save

* benchmark fwd only

* disable debug

* pandas pdf

* save

* set methods

* record number of heads

* use configs

* flexiable dim, n-heads, headofdim

* better benchmarking

* basic inplace update working

* works upto 64

* new_kv supported!

* test case for has_batch_idx

* has_batch_idx works!

* save

* save

* save

* save ref

* fix mqa and gqa by duplicating

* GQA and MQA working by kernel modifications

* fix new_kv with gqa

* cache index

* deal with nans on fwd_splitk

* save

* causal working on basic case

* causal works!

* alibi works!

* clean up

* clean prefill changes

* remove bwd stuff

* limit decode test to test_op_fwd

* add ref

* use bfloat
Fixes after rebase

rebase fixes

deal with kvcache failure

new run for branch

cancel-in-progress

fix varlen_fwd bug
* Clean

Clean

This is a combination of 4 commits.

clean 1

clean 2

clean more

match main

typo fix

* use is_hip()

* clean up more

* skip odd d only

* fix bug

* skip randomly

* use Flag

* update readme

* remove quantization

* remove bwd

* minor

* print

* remove verbose print

* quantize zeroes out the d stride
- added a pyskip for an odd case of using mha_type: "gqa"
- changed batch_size=1 and nheads=1
@micmelesse force-pushed the main_perf branch 2 times, most recently from 5d03d58 to 730d260 on October 28, 2024 at 19:31