
Add Mamba Block #656

Closed · wants to merge 11 commits

Conversation

Artur-Galstyan
Contributor

This PR adds the Mamba Block.

It's still an early draft. Along with the Mamba block, the full Mamba model will be shown in the docs as an example.

@Artur-Galstyan
Contributor Author

Artur-Galstyan commented Mar 4, 2024

@patrick-kidger This isn't 100% done yet, of course, but the example mamba.ipynb is almost done. The actual implementation of the SelectiveStateSpaceModel module is in there and will later be moved to its own file. Besides the TODO in the example, what do you think of it? Anything extra you would like to see there?

BTW: Here you can run and train the Mamba model on TinyShakespeare if you want :)

y = jnp.einsum("d n,n -> d", x, C[i, :])
return x, y

_, ys = jax.lax.scan(step, x_res, jnp.arange(seq_len))
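For context, here is a self-contained sketch of the sequential recurrence the quoted lines implement. The names `x0`, `A_bar`, and `Bx` and the exact shapes are hypothetical here, filled in to make the fragment runnable; the actual PR code differs:

```python
import jax
import jax.numpy as jnp


def selective_scan(x0, A_bar, Bx, C):
    """Sequential diagonal SSM recurrence:
        x_i = A_bar[i] * x_{i-1} + Bx[i]
        y_i = x_i @ C[i]
    Shapes (illustrative): x0 (d, n); A_bar, Bx (seq_len, d, n); C (seq_len, n).
    """
    seq_len = C.shape[0]

    def step(x, i):
        x = A_bar[i] * x + Bx[i]                  # elementwise state update
        y = jnp.einsum("d n,n -> d", x, C[i, :])  # project state to output
        return x, y

    _, ys = jax.lax.scan(step, x0, jnp.arange(seq_len))
    return ys  # (seq_len, d)
```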
Contributor

This should really be done with jax.lax.associative_scan to compute the scan in parallel. That is more faithful to the original and is much faster for long sequences on parallel accelerators (e.g. GPUs). It's unclear to me, based on the Griffin paper, whether JAX's associative scan will also be faster than a linear scan on TPUs, so perhaps there should be an option to choose which implementation to use.

See the S5 JAX implementation for a very similar use of associative scan: https://github.com/lindermanlab/S5/blob/main/s5/ssm.py
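As a rough sketch of what this suggestion amounts to (illustrative names, not the PR's code): a diagonal recurrence h_t = a_t * h_{t-1} + b_t is a composition of affine maps, which is associative, so all prefix states can be computed with `jax.lax.associative_scan`:

```python
import jax
import jax.numpy as jnp


def affine_op(elem_i, elem_j):
    # Compose two updates h -> a*h + b, applied in order i then j:
    # h -> a_j*(a_i*h + b_i) + b_j = (a_j*a_i)*h + (a_j*b_i + b_j)
    a_i, b_i = elem_i
    a_j, b_j = elem_j
    return a_j * a_i, a_j * b_i + b_j


def parallel_linear_recurrence(a, b):
    """All prefix states h_t = a_t * h_{t-1} + b_t (with h_0 = 0), in parallel.
    a, b: arrays of shape (seq_len, ...), combined elementwise."""
    _, h = jax.lax.associative_scan(affine_op, (a, b))
    return h
```

Because the combined element at position t is the affine map h_0 -> a_cum * h_0 + b_cum over the whole prefix, the `b` component alone is the state when h_0 = 0.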

Owner

I'm not sure about this, actually. JAX's associative scan appears to completely unroll the whole operation:

https://github.com/google/jax/blob/63ceb5f539c45fe00766634ce7b01ea5176e0cc4/jax/_src/lax/control_flow/loops.py#L2171-L2186

I think there's probably a better way to implement an associative scan. Probably doing something like equinox/internal/loop/bounded.py, or possibly a lax.scan over the levels of the tree.
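One possible shape for the "scan over the levels of the tree" idea is a Hillis–Steele style inclusive scan: a Python loop over ceil(log2(n)) doubling strides, so only O(log n) combine steps are traced rather than the whole sequence. This is a sketch under those assumptions, not Equinox's implementation:

```python
import math

import jax.numpy as jnp


def affine_combine(prev, cur):
    # Compose h -> a*h + b updates, applying `prev` then `cur`.
    a_p, b_p = prev
    a_c, b_c = cur
    return a_c * a_p, a_c * b_p + b_c


def log_depth_scan(a, b):
    """Inclusive scan of x_t = a_t * x_{t-1} + b_t with x_0 = 0,
    using ceil(log2(n)) combine levels instead of n sequential steps."""
    n = a.shape[0]
    num_levels = math.ceil(math.log2(n)) if n > 1 else 0
    for level in range(num_levels):
        shift = 1 << level
        # The identity of the affine composition is (a=1, b=0); pad the
        # front so positions t < shift combine with the identity.
        a_prev = jnp.concatenate([jnp.ones_like(a[:shift]), a[:-shift]])
        b_prev = jnp.concatenate([jnp.zeros_like(b[:shift]), b[:-shift]])
        a, b = affine_combine((a_prev, b_prev), (a, b))
    return b  # b[t] equals x_t, since x_0 = 0
```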

equinox/nn/_selective_state_space_models.py (outdated; resolved)
examples/mamba.ipynb (outdated; resolved)
@Artur-Galstyan Artur-Galstyan changed the base branch from main to dev March 23, 2024 14:00
@Artur-Galstyan Artur-Galstyan marked this pull request as ready for review March 23, 2024 15:32
@Artur-Galstyan
Contributor Author

Artur-Galstyan commented Mar 23, 2024

Alright, I've added those changes, added them to the docs, and added a little training script for the Mamba example along with some model outputs. But I think there's still a faulty test here; I might have to rebase onto dev if it still fails.

Edit: Nvm, they passed.

@patrick-kidger patrick-kidger deleted the branch patrick-kidger:dev April 14, 2024 12:54

4 participants