
Superoptimizers #62

Open
jafioti opened this issue May 28, 2024 · 0 comments
jafioti commented May 28, 2024

We're looking into superoptimizers, specifically adding a luminal_gpu crate with a gpu-agnostic superoptimizer pass, which will then feed into downstream gpu compilers in luminal_cuda and luminal_metal.

So far, I'm specifically taking ideas from Mirage (https://www.cs.cmu.edu/~zhihaoj2/papers/mirage.pdf), EinNet (https://www.usenix.org/system/files/osdi23-zheng.pdf) and Tensat (https://www.cl.cam.ac.uk/~ey204/teaching/ACS/R244_2022_2023/papers/YANG_MLSYS_2021.pdf). The goal here is to define tensor operations in terms of a linear algebra language, ideally even a scalar-based language (see EinNet) which codifies loop structure explicitly, and then use equality saturation to apply a set of simple rewrite rules simultaneously to discover exponentially many equivalent expressions in linear space. Since these rewrite rules are symbolically guaranteed to generate semantically equivalent expressions, we don't need to worry about testing for correctness like Mirage does.

My current thinking for an approach:

Language

We turn linear algebra into a symbolic language, similar to how EinNet does theirs, where Matmul can be represented like this:

A:   (mxk)
B:   (kxn)
BT:  (nxk)   = permute[1,0](B)
Ae:  (mxnxk) = expand[1;n](A)
BTe: (mxnxk) = expand[0;m](permute[1,0](B))
C = sum[2](Ae * BTe)
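As a sanity check, here's a minimal sketch (plain Python, no dependencies; all function names are hypothetical) verifying that the permute/expand/sum decomposition above computes an ordinary matmul:

```python
# Check that C = sum[2](expand[1;n](A) * expand[0;m](permute[1,0](B)))
# agrees with a reference matmul on a small example.

def permute10(B):
    # permute[1,0]: transpose a (k x n) matrix to (n x k)
    return [list(row) for row in zip(*B)]

def matmul_via_decomposition(A, B):
    m, k = len(A), len(A[0])
    n = len(B[0])
    BT = permute10(B)  # (n x k)
    # expand[1;n](A)  -> Ae[i][j][p]  = A[i][p]   (broadcast over j)
    # expand[0;m](BT) -> BTe[i][j][p] = BT[j][p]  (broadcast over i)
    # sum[2] reduces the trailing k axis of the elementwise product.
    return [[sum(A[i][p] * BT[j][p] for p in range(k))
             for j in range(n)] for i in range(m)]

def matmul_reference(A, B):
    return [[sum(a * b for a, b in zip(row, col))
             for col in zip(*B)] for row in A]

A = [[1, 2, 3], [4, 5, 6]]       # (2 x 3)
B = [[7, 8], [9, 10], [11, 12]]  # (3 x 2)
assert matmul_via_decomposition(A, B) == matmul_reference(A, B)
```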

Equality Saturation

We use e-graphs and simple rewrite rules to transform the expressions in this language into various semantically identical forms. Since loops are codified in this language, we can apply rewrites to do loop transformations as well, thereby unifying schedule and algorithm optimizations. We will build an e-graph until it's saturated (or perhaps use a beam width to limit the search space).
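To illustrate the idea (this is not a real e-graph, which would share subterms in equivalence classes; it's a naive breadth-first rewrite search over whole expressions with a beam width, and all rules and names are made up for illustration):

```python
# Expressions are nested tuples like ("mul", x, y). Each rule maps an
# expression to a rewritten form, or None if it doesn't apply at the root.

def commute_mul(e):
    if isinstance(e, tuple) and e[0] == "mul":
        return ("mul", e[2], e[1])

def mul_one(e):
    if isinstance(e, tuple) and e[0] == "mul" and e[2] == 1:
        return e[1]

RULES = [commute_mul, mul_one]

def rewrites(e):
    # All single-step rewrites of e, at the root or inside a child.
    out = []
    for rule in RULES:
        r = rule(e)
        if r is not None:
            out.append(r)
    if isinstance(e, tuple):
        for i in range(1, len(e)):
            for sub in rewrites(e[i]):
                out.append(e[:i] + (sub,) + e[i + 1:])
    return out

def saturate(expr, beam=64):
    # Explore equivalent forms until no new ones appear; the beam width
    # caps how many new forms are expanded per round.
    seen = {expr}
    frontier = [expr]
    while frontier:
        nxt = []
        for e in frontier:
            for r in rewrites(e):
                if r not in seen:
                    seen.add(r)
                    nxt.append(r)
        frontier = nxt[:beam]
    return seen

forms = saturate(("mul", ("mul", "x", 1), "y"))
assert ("mul", "x", "y") in forms  # x*1*y simplifies to x*y
```

A real implementation would use an e-graph (e.g. the egg crate in Rust) so that exponentially many forms share storage; this sketch only shows how a small rule set already multiplies the candidate set.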

Extraction

Given a graph, how can the cost / runtime of that graph be determined? One simple way is to randomly generate a few representative inputs, run them through the graph, and measure runtime. This is what tinygrad does (and I think torch.compile). The reason I don't like that much is that I want to make a device-agnostic search stage, so it won't have access to the actual device. Also, running a graph with a representative workload takes a long time, which severely limits the number of candidates you can search.

It should be possible to work out all data movement in a graph from global to shared to register memory, and count how many bytes are moving. Weighted appropriately, I wonder if this would be an adequate proxy for cost?

This memory-movement-minimization idea is the main thing that needs to be proven out. Some issues still remain: it doesn't capture occupancy or coalesced memory accesses, which matter a lot. I don't know of any other symbolic cost function out there that doesn't involve learning a cost model.
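A minimal sketch of the bytes-moved proxy, under a deliberately crude assumption: every op reads all its inputs from and writes its output to global memory (so fusion shows up as fewer round trips, while occupancy and coalescing are ignored, as noted above). The shapes, op names, and 4-byte dtype are made up for illustration:

```python
from math import prod

DTYPE_BYTES = 4

def bytes_moved(graph):
    # graph: list of (op_name, input_shapes, output_shape)
    total = 0
    for _op, inputs, output in graph:
        total += sum(prod(s) for s in inputs) * DTYPE_BYTES  # global reads
        total += prod(output) * DTYPE_BYTES                  # global writes
    return total

# Unfused form of the matmul language above: materialize the expanded
# elementwise product, then reduce it.
unfused = [
    ("mul", [(64, 64, 64), (64, 64, 64)], (64, 64, 64)),
    ("sum", [(64, 64, 64)], (64, 64)),
]
# Fused form: one kernel reading A and B, writing C directly.
fused = [("matmul", [(64, 64), (64, 64)], (64, 64))]

# The proxy correctly prefers the fused form.
assert bytes_moved(fused) < bytes_moved(unfused)
```

Extraction would then pick the e-graph term minimizing this count, weighted per memory level once global/shared/register movement is tracked separately.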

If anyone has any ideas or papers they think can add value, please throw them here!

@jafioti jafioti pinned this issue May 28, 2024