Add Mamba Block #656
Conversation
@patrick-kidger This isn't 100% done yet of course, but the example … BTW: Here you can run and train the Mamba model on TinyShakespeare if you want :)
y = jnp.einsum("d n,n -> d", x, C[i, :])
return x, y

_, ys = jax.lax.scan(step, x_res, jnp.arange(seq_len))
This should really be done with jax.lax.associative_scan for computing the scan in parallel. That is more faithful to the original and is much faster for long sequences on parallel-processing accelerators (e.g. GPU). It's unclear to me, based on the Griffin paper, whether JAX's associative scan will also be faster than a linear scan on TPUs, so perhaps there should be an option to choose which implementation to use.
See the S5 JAX implementation for a very similar implementation with associative scan: https://github.com/lindermanlab/S5/blob/main/s5/ssm.py
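For reference, a minimal sketch of the S5-style parallel form being suggested: the SSM recurrence x_t = a_t * x_{t-1} + b_t is a composition of affine maps, which is associative, so all states can be computed with jax.lax.associative_scan. The names here (binary_op, parallel_ssm) are illustrative only, not from this PR.

```python
import jax
import jax.numpy as jnp


def binary_op(elem_left, elem_right):
    # Compose two elementwise affine maps x -> a*x + b.
    # This composition is associative, which is all associative_scan needs.
    a_l, b_l = elem_left
    a_r, b_r = elem_right
    return a_r * a_l, a_r * b_l + b_r


def parallel_ssm(a, b):
    # a, b: (seq_len, state_dim). Returns every state of
    # x_t = a[t] * x_{t-1} + b[t] with x_0 = 0, computed in parallel.
    _, xs = jax.lax.associative_scan(binary_op, (a, b))
    return xs
```

This replaces the O(seq_len) sequential dependency of jax.lax.scan with O(log seq_len) parallel combine steps, which is where the GPU speedup for long sequences comes from.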
I'm not sure about this, actually. JAX's associative scan appears to completely unroll the whole operation.
I think there's probably a better way to implement an associative scan. Probably doing something like equinox/internal/loop/bounded.py, or possibly a lax.scan over the levels of the tree.
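To illustrate the "levels of the tree" idea: for the linear recurrence x_t = a_t * x_{t-1} + b_t, a Hillis–Steele-style doubling scan needs only ceil(log2(n)) combine steps. This sketch uses a plain Python while loop, which still unrolls at trace time (one iteration per tree level); the bounded.py or lax.scan-over-levels approach suggested in the comment would avoid that unrolling. The function name is hypothetical.

```python
import jax.numpy as jnp


def hillis_steele_linear_scan(a, b):
    """All prefixes of x_t = a[t] * x_{t-1} + b[t] (with x_0 = 0),
    in ceil(log2(n)) combine steps instead of n sequential ones."""
    n = a.shape[0]
    d = 1
    while d < n:  # one iteration per level of the tree
        # Pair each position with the partial result d steps behind it;
        # positions < d combine with the identity element (a=1, b=0).
        a_prev = jnp.concatenate([jnp.ones(d, a.dtype), a[:-d]])
        b_prev = jnp.concatenate([jnp.zeros(d, b.dtype), b[:-d]])
        a, b = a * a_prev, a * b_prev + b
        d *= 2
    return b  # b[t] now holds x_t
```

Only the number of levels, ceil(log2(n)), ends up in the trace, so the compiled program stays small compared to fully unrolling all n combines.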
Alright, I've added those changes in, added those to the docs, and added a little training script for the Mamba example and some model outputs too. But I think there's still a faulty test here. I might have to… Edit: Nvm, they passed.
This PR adds the Mamba Block.
It's still in an early draft state. Along with the Mamba block, the full Mamba model will be shown in the docs as an example.