
# Accelerated Scan


This package implements the fastest first-order parallel associative scan on the GPU, for both the forward and backward passes.

The scan efficiently solves first-order recurrences of the form `x[t] = gate[t] * x[t-1] + token[t]`, common in state space models and linear RNNs.
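
For reference, this is what the recurrence looks like as a plain sequential loop in PyTorch (an illustrative helper, not part of the package API):

```python
import torch

def sequential_scan(gates: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    # gates, tokens: (batch, dim, seqlen)
    out = torch.empty_like(tokens)
    x = torch.zeros_like(tokens[..., 0])
    for t in range(tokens.shape[-1]):
        # x[t] = gate[t] * x[t-1] + token[t]
        x = gates[..., t] * x + tokens[..., t]
        out[..., t] = x
    return out
```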

The accelerated_scan.warp C++ CUDA kernel uses a chunked processing algorithm that leverages the fastest GPU communication primitives available at each level of the hierarchy: warp shuffles within warps of 32 threads, and shared memory (SRAM) between warps within a thread block. Each sequence, one per channel dimension, is confined to a single thread block.

The derivation of the chunked scan extends the tree-level Blelloch algorithm to the block level; see the sketch below.
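
To illustrate the idea (a minimal sketch in plain PyTorch, not the kernel itself): pairs `(gate, token)` combine associatively, so each chunk can be scanned independently with a zero initial state and the inter-chunk carries applied afterwards. The helper below is hypothetical and assumes `seqlen` is divisible by `chunk`.

```python
import torch

def chunked_scan(gates, tokens, chunk=64):
    # Pairs (gate, token) combine associatively:
    #   (a1, b1) ∘ (a2, b2) = (a1 * a2, b1 * a2 + b2)
    # which is what lets both the tree-level Blelloch scan and this
    # chunk-level variant split the sequence and stitch the pieces.
    B, D, T = tokens.shape
    g = gates.reshape(B, D, T // chunk, chunk)
    x = tokens.reshape(B, D, T // chunk, chunk)

    # 1) Scan inside each chunk with a zero initial state; chunks are
    #    independent. A holds running gate products, out the partial states.
    A = torch.cumprod(g, dim=-1)
    out = torch.empty_like(x)
    b = torch.zeros(B, D, T // chunk, device=tokens.device, dtype=tokens.dtype)
    for t in range(chunk):
        b = g[..., t] * b + x[..., t]
        out[..., t] = b

    # 2) Carry the final state of each chunk into the next:
    #    true_out = partial_out + A * carry.
    result = torch.empty_like(out)
    carry = torch.zeros(B, D, device=tokens.device, dtype=tokens.dtype)
    for k in range(T // chunk):
        result[..., k, :] = out[..., k, :] + A[..., k, :] * carry.unsqueeze(-1)
        carry = result[..., k, -1]
    return result.reshape(B, D, T)
```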

A similar implementation is available in accelerated_scan.triton, built on Triton's tl.associative_scan primitive. It requires Triton 2.2 for its enable_fp_fusion flag.

## Quick Start

```sh
pip install accelerated-scan
```

```python
import torch
from accelerated_scan.warp import scan  # a pure C++ CUDA kernel, faster than CUB
#from accelerated_scan.triton import scan  # uses tl.associative_scan
#from accelerated_scan.ref import scan  # reference PyTorch implementation

# sequence length must be a power of 2 between 32 and 65536
# hit me up if you need different lengths!

batch_size, dim, seqlen = 3, 1536, 4096
gates = 0.999 + 0.001 * torch.rand(batch_size, dim, seqlen, device="cuda")
tokens = torch.rand(batch_size, dim, seqlen, device="cuda")

out = scan(gates, tokens)
```

To ensure numerical equivalence, a reference implementation of the tree scan is provided in PyTorch (accelerated_scan.ref). It can be sped up using torch.compile.
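
For example, the kernel output can be checked against the reference directly (a minimal sketch; tolerances may need loosening for long sequences):

```python
import torch
from accelerated_scan.warp import scan
from accelerated_scan.ref import scan as ref_scan

gates = 0.999 + 0.001 * torch.rand(3, 1536, 4096, device="cuda")
tokens = torch.rand(3, 1536, 4096, device="cuda")

# raises if the CUDA kernel and the PyTorch reference disagree
torch.testing.assert_close(scan(gates, tokens), ref_scan(gates, tokens))
```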

## Benchmarks

![Benchmark plot](bench.png)

See more benchmarks in nanokitchen: https://github.com/proger/nanokitchen

Forward pass times in milliseconds for shape (8, 1536, seqlen), inference mode:

| SEQUENCE_LENGTH | accelerated_scan.triton (triton 2.2.0) | accelerated_scan.ref | accelerated_scan.warp |
|----------------:|---------------------------------------:|---------------------:|----------------------:|
|             128 |                                0.027382 |             0.380874 |              0.026844 |
|             256 |                                0.049104 |             0.567916 |              0.048593 |
|             512 |                                0.093008 |             1.067906 |              0.092923 |
|            1024 |                                0.181856 |             2.048471 |              0.183581 |
|            2048 |                                0.358250 |             3.995369 |              0.355414 |
|            4096 |                                0.713511 |             7.897022 |              0.714536 |
|            8192 |                                1.433052 |            15.698944 |              1.411390 |
|           16384 |                                3.260965 |            31.305046 |              2.817152 |
|           32768 |                               31.459671 |            62.557182 |              5.645697 |
|           65536 |                               66.787331 |           125.208572 |             11.297921 |
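
As a rough sketch, a forward pass can be timed with CUDA events like this (the actual benchmark harness in nanokitchen may differ):

```python
import torch
from accelerated_scan.warp import scan

def bench(seqlen: int, iters: int = 100) -> float:
    gates = 0.999 + 0.001 * torch.rand(8, 1536, seqlen, device="cuda")
    tokens = torch.rand(8, 1536, seqlen, device="cuda")
    with torch.inference_mode():
        for _ in range(10):  # warmup
            scan(gates, tokens)
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            scan(gates, tokens)
        end.record()
        torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call

for n in (128, 256, 512, 1024):
    print(n, bench(n))
```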

## Notes on Precision

When gates and tokens are sampled uniformly from [0, 1), the limited precision of bfloat16 dominates the error (compared to the reference implementation):

![Maximum absolute error](max-abs-error.png)
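
A sketch of how such an error measurement can be reproduced, assuming the kernel accepts bfloat16 inputs (the exact measurement script may differ):

```python
import torch
from accelerated_scan.warp import scan
from accelerated_scan.ref import scan as ref_scan

# gates and tokens sampled uniformly from [0, 1), as in the plot above
gates = torch.rand(3, 1536, 4096, device="cuda")
tokens = torch.rand(3, 1536, 4096, device="cuda")

# run the kernel in bfloat16 and compare against a float32 reference
out_bf16 = scan(gates.bfloat16(), tokens.bfloat16()).float()
out_ref = ref_scan(gates, tokens)
print("max abs error:", (out_bf16 - out_ref).abs().max().item())
```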