Memory bandwidth efficient sparse tree attention
- (precompute) chunk the tree into query blocks
- (precompute) compute unique ancestors, attention mask, and leaves for each block
- (runtime) only load keys and values for the query block's unique ancestors and leaves (see the sketch below)
- (runtime) go fast
- A100 Colab Benchmark
go forth, search the tree of possible futures
![Screen Shot 2024-02-25 at 14 52 59](https://private-user-images.githubusercontent.com/26588632/307627685-f793cd9b-4bf2-48ea-93e3-c0819422a1a4.png)
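A minimal, non-fused PyTorch sketch of the idea (the `parents` array layout, the function names, and the node-numbering convention are assumptions made for illustration, not the actual kernel interface):

```python
import torch

def precompute_block_metadata(parents, block_size):
    """For each query block, find the unique ancestor KV indices and the
    boolean attention mask restricted to those indices.

    `parents[i]` is the parent of tree node i (-1 for the root); nodes are
    assumed to be numbered so that parents[i] < i. This layout is an
    assumption for the sketch, not the kernel's actual input format.
    """
    n = len(parents)
    # ancestor sets (inclusive of the node itself), built by following parents
    ancestors = []
    for i in range(n):
        lineage, j = [], i
        while j != -1:
            lineage.append(j)
            j = parents[j]
        ancestors.append(set(lineage))

    blocks = []
    for start in range(0, n, block_size):
        q_idx = list(range(start, min(start + block_size, n)))
        # union of ancestors over the block's queries = the KVs this block must load
        kv_idx = sorted(set().union(*(ancestors[q] for q in q_idx)))
        # mask[q, k] = True iff kv_idx[k] is an ancestor of (or equal to) query q
        mask = torch.zeros(len(q_idx), len(kv_idx), dtype=torch.bool)
        for qi, q in enumerate(q_idx):
            for ki, k in enumerate(kv_idx):
                mask[qi, ki] = k in ancestors[q]
        blocks.append((torch.tensor(q_idx), torch.tensor(kv_idx), mask))
    return blocks


def sparse_tree_attention(q, k, v, blocks, scale):
    """Reference (non-fused) attention: each query block only gathers its own KVs."""
    out = torch.empty_like(q)
    for q_idx, kv_idx, mask in blocks:
        qb, kb, vb = q[q_idx], k[kv_idx], v[kv_idx]          # gather per block
        scores = (qb @ kb.transpose(-1, -2)) * scale
        scores = scores.masked_fill(~mask, float("-inf"))
        out[q_idx] = torch.softmax(scores, dim=-1) @ vb
    return out
```

In a fused kernel the same per-block gather happens on-chip, so only each block's unique ancestor and leaf rows of K and V need to be read from HBM.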
Notes on precomputation:
- This can probably be made fast enough to run at inference time with a bit more work: even for a dynamic tree structure (i.e. one that depends on the model's output), these kernel inputs only need to be computed once per tree, and are then reused by every attention layer in the model
- Static tree structures are still useful: Medusa uses a size-256 static, left-weighted tree, populated via Cartesian products of its multiple top-k output heads, to accelerate batch-size-1 inference by ~3x
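As a rough illustration of how such a static tree can be populated from Cartesian products of per-head top-k candidates (the head count, k values, and token ids below are made up, and Medusa's actual tree is additionally pruned/left-weighted):

```python
from itertools import product

def build_static_tree(topk_per_head):
    """Build a (node -> parent) table from Cartesian products of per-head
    top-k candidate tokens. `topk_per_head[h]` lists the candidate token ids
    proposed by speculative head h. Illustrative only.
    """
    parents, node_of_prefix = [-1], {(): 0}   # node 0 is the root (current token)
    for cand in product(*topk_per_head):
        for depth in range(1, len(cand) + 1):
            prefix = cand[:depth]
            if prefix not in node_of_prefix:
                node_of_prefix[prefix] = len(parents)
                parents.append(node_of_prefix[prefix[:-1]])
    return parents, node_of_prefix

# e.g. 3 heads, each contributing a few candidate tokens (hypothetical values)
parents, _ = build_static_tree([[11, 12], [21, 22], [31, 32, 33]])
print(len(parents))  # root + 2 + 4 + 12 = 19 nodes
```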
Todo:
- Organize blocks based on DFS ordering to minimize the number of blocks that need to load the same ancestor KVs (i.e. maximize the shared lineage within each block)
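A possible sketch of that ordering step, assuming the same `parents` layout as above (not the repo's implementation): relabel nodes in depth-first order before chunking into query blocks, so that nodes landing in the same block lie on nearby root-to-leaf paths:

```python
def dfs_order(parents):
    """Return node indices in depth-first order so that consecutive nodes,
    and hence the query blocks formed from them, share as much lineage as
    possible. Illustrative sketch only.
    """
    children = [[] for _ in parents]
    root = None
    for node, p in enumerate(parents):
        if p == -1:
            root = node
        else:
            children[p].append(node)
    order, stack = [], [root]
    while stack:
        node = stack.pop()
        order.append(node)
        stack.extend(reversed(children[node]))  # visit children left to right
    return order
```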
Credits: