Add Mamba Block #656 (Closed)

Commits (11):
- 1afb2e4: State space model start (Artur-Galstyan)
- 8aa2c72: added more docs! (Artur-Galstyan)
- 2e6d8f6: added mamba block, need to test (Artur-Galstyan)
- 572fda4: started with example (Artur-Galstyan)
- b7edd57: included more graphs in the example (Artur-Galstyan)
- 1a20cc6: more mamba example (Artur-Galstyan)
- b76ecd8: mamba progress (Artur-Galstyan)
- 5f02a8c: added mamba example (Artur-Galstyan)
- 030bf77: subkey wrong count (Artur-Galstyan)
- 0747c0a: Merge branch 'patrick-kidger:main' into mamba (Artur-Galstyan)
- 3504d77: added docs and more examples (Artur-Galstyan)
Files changed:

`.gitignore`:

```diff
@@ -12,3 +12,5 @@ examples/CIFAR
 examples/MNIST
 examples/multipart_serialised.eqx
 .python-version
+.DS_Store
+.ruff_cache
```
New documentation page:

```diff
@@ -0,0 +1,7 @@
+# State Spaces
+
+::: equinox.nn.SelectiveStateSpaceModel
+    selection:
+        members:
+            - __init__
+            - __call__
```
New file implementing `equinox.nn.SelectiveStateSpaceModel`:

```python
import math
from typing import Literal, Union

import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, PRNGKeyArray

from .._module import field, Module
from ._conv import Conv1d
from ._linear import Linear


def _selective_scan(
    u: Float[Array, "seq_len d_inner"],
    delta: Float[Array, "seq_len d_inner"],
    A: Float[Array, "d_inner state_space_dims"],
    B: Float[Array, "seq_len state_space_dims"],
    C: Float[Array, "seq_len state_space_dims"],
    D: Float[Array, " d_inner"],
):
    seq_len, _ = u.shape
    d_inner, state_space_dims = A.shape

    # Discretise the continuous-time parameters A and B with the
    # (input-dependent) per-timestep step size `delta`.
    delta_A = jnp.exp(jnp.einsum("l d,d n -> l d n", delta, A))
    delta_B_u = jnp.einsum("l d,l n,l d -> l d n", delta, B, u)

    x_res = jnp.zeros(shape=(d_inner, state_space_dims))

    # Evaluate the linear recurrence
    #     x_i = delta_A[i] * x_{i-1} + delta_B_u[i],    y_i = x_i @ C[i]
    # sequentially over the sequence axis.
    def step(x, i):
        x = delta_A[i] * x + delta_B_u[i]
        y = jnp.einsum("d n,n -> d", x, C[i, :])
        return x, y

    _, ys = jax.lax.scan(step, x_res, jnp.arange(seq_len))

    # Skip connection, scaled elementwise by D.
    ys = ys + u * D
    return ys


class SelectiveStateSpaceModel(Module, strict=True):
    r"""State space model with selective scan. This is an implementation of
    the Mamba block from the paper
    "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" [1].

    ??? cite

        [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752)

        ```bibtex
        @misc{gu2023mamba,
            title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
            author={Albert Gu and Tri Dao},
            year={2023},
            eprint={2312.00752},
            archivePrefix={arXiv},
            primaryClass={cs.LG}
        }
        ```
    """

    n_input_dims: int = field(static=True)
    state_space_dims: int = field(static=True)

    d_inner: int = field(static=True)
    d_conv: int = field(static=True)

    expand: int = field(static=True)
    dt_rank: int = field(static=True)
    pad_vocab_size_multiple: int = field(static=True)

    in_proj: Linear
    conv1d: Conv1d

    x_proj: Linear
    dt_proj: Linear

    A_log: Array
    D: Array

    out_proj: Linear

    def __init__(
        self,
        n_input_dims: int,
        state_space_dims: int,
        expand: int,
        d_conv: int,
        dt_rank: Union[int, Literal["auto"]],
        pad_vocab_size_multiple: int = 8,
        use_bias_in_proj: bool = True,
        use_bias_conv1d: bool = True,
        use_bias_out_proj: bool = True,
        *,
        key: PRNGKeyArray,
    ):
        r"""**Arguments:**

        - `n_input_dims`: The dimension of the input.
        - `state_space_dims`: The dimension of the SSM (refers to $N$ in [1]).
        - `expand`: The expansion factor of the inner dimension (refers to $E$
            in [1]).
        - `d_conv`: The kernel size of the convolutional layer.
        - `dt_rank`: The rank of delta. If `"auto"`, it is set to
            `ceil(n_input_dims / state_space_dims)`.
        - `pad_vocab_size_multiple`: The multiple to which the vocabulary size
            is padded.
        - `use_bias_in_proj`: Whether to use a bias in the input projection
            layer.
        - `use_bias_conv1d`: Whether to use a bias in the convolutional layer.
        - `use_bias_out_proj`: Whether to use a bias in the output projection
            layer.
        - `key`: The PRNG key used for parameter initialisation.
        """
        self.n_input_dims = n_input_dims
        self.state_space_dims = state_space_dims

        self.d_conv = d_conv
        self.expand = expand

        self.d_inner = int(self.expand * self.n_input_dims)

        self.pad_vocab_size_multiple = pad_vocab_size_multiple

        if dt_rank == "auto":
            self.dt_rank = math.ceil(self.n_input_dims / self.state_space_dims)
        else:
            self.dt_rank = dt_rank

        (
            key,
            linear_key,
            conv1d_key,
            x_proj_key,
            dt_proj_key,
            out_proj_key,
        ) = jax.random.split(key, 6)

        self.in_proj = Linear(
            n_input_dims,
            self.d_inner * 2,
            use_bias=use_bias_in_proj,
            key=linear_key,
        )

        # Depthwise convolution; `padding=d_conv - 1` together with the
        # truncation to `seq_len` in __call__ makes it causal.
        self.conv1d = Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            use_bias=use_bias_conv1d,
            groups=self.d_inner,
            padding=d_conv - 1,
            key=conv1d_key,
        )

        self.x_proj = Linear(
            self.d_inner,
            self.dt_rank + state_space_dims * 2,
            use_bias=False,
            key=x_proj_key,
        )

        self.dt_proj = Linear(
            self.dt_rank, self.d_inner, use_bias=True, key=dt_proj_key
        )

        # S4D-real initialisation: every row of A is (1, 2, ..., N).
        A = (
            jnp.repeat(jnp.arange(1, self.state_space_dims + 1), self.d_inner)
            .reshape(self.state_space_dims, self.d_inner)
            .transpose()
        )

        self.A_log = jnp.log(A)
        self.D = jnp.ones(self.d_inner)
        self.out_proj = Linear(
            self.d_inner,
            self.n_input_dims,
            use_bias=use_bias_out_proj,
            key=out_proj_key,
        )

    @jax.named_scope("eqx.nn.SelectiveStateSpaceModel")
    def __call__(self, x: Float[Array, "seq_len n_input_dims"]) -> Array:
        r"""**Arguments:**

        - `x`: The input sequence. Should be a JAX array of shape
            `(seq_len, n_input_dims)`.

        **Returns:**

        - A JAX array of shape `(seq_len, n_input_dims)`.
        """
        seq_len, d = x.shape
        if d != self.n_input_dims:
            raise ValueError(
                f"Input dimension mismatch: expected {self.n_input_dims}, got {d}"
            )
        x_and_res = jax.vmap(self.in_proj)(x)
        (x, res) = jnp.split(x_and_res, 2, axis=-1)

        # Causal depthwise convolution over the sequence axis.
        x = jnp.transpose(x)
        x = self.conv1d(x)[:, :seq_len]
        x = jnp.transpose(x)
        x = jax.nn.silu(x)

        y = self._ssm(x)
        y = y * jax.nn.silu(res)

        output = jax.vmap(self.out_proj)(y)
        return output

    def _ssm(self, x: Float[Array, "seq_len d_inner"]) -> Array:
        A = -jnp.exp(self.A_log)
        D = self.D

        # Project x to the input-dependent parameters (delta, B, C).
        x_delta_b_c = jax.vmap(self.x_proj)(x)

        split_indices = [
            self.dt_rank,
            self.dt_rank + self.state_space_dims,
        ]
        delta, B, C = jnp.split(x_delta_b_c, split_indices, axis=-1)
        delta = jax.nn.softplus(jax.vmap(self.dt_proj)(delta))

        y = _selective_scan(x, delta, A, B, C, D)
        return y
```
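For reference, a minimal usage sketch, not part of the diff: it assumes this branch is installed so that the class is exposed as `equinox.nn.SelectiveStateSpaceModel` (as the docs page above suggests); the hyperparameter values are arbitrary.

```python
import jax
import equinox as eqx

key = jax.random.PRNGKey(0)
model_key, data_key = jax.random.split(key)

ssm = eqx.nn.SelectiveStateSpaceModel(
    n_input_dims=16,     # dimension of each input vector
    state_space_dims=4,  # N in the Mamba paper
    expand=2,            # d_inner = expand * n_input_dims = 32
    d_conv=4,            # kernel size of the depthwise convolution
    dt_rank="auto",      # ceil(16 / 4) = 4
    key=model_key,
)

x = jax.random.normal(data_key, (20, 16))  # (seq_len, n_input_dims)
y = ssm(x)                                 # (seq_len, n_input_dims)
assert y.shape == (20, 16)
```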
Review comment:
This should really be done with `jax.lax.associative_scan`, which computes the scan in parallel. That is more faithful to the original and much faster for long sequences on parallel accelerators (e.g. GPUs). It's unclear to me, based on the Griffin paper, whether JAX's associative scan will also be faster than a linear scan on TPUs, so perhaps there should be an option to choose which implementation to use. See the S5 JAX implementation for a very similar implementation with associative scan: https://github.com/lindermanlab/S5/blob/main/s5/ssm.py
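For context, the recurrence in `_selective_scan` composes affine maps $x \mapsto ax + b$, and composition of affine maps is associative, so it can indeed be expressed with `jax.lax.associative_scan`. A minimal sketch under that assumption (the helper name is hypothetical; `delta_A`, `delta_B_u` and `C` are the arrays computed inside `_selective_scan`, and the `u * D` skip connection is omitted):

```python
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float


def _associative_selective_scan(
    delta_A: Float[Array, "seq_len d_inner state_space_dims"],
    delta_B_u: Float[Array, "seq_len d_inner state_space_dims"],
    C: Float[Array, "seq_len state_space_dims"],
) -> Float[Array, "seq_len d_inner"]:
    # Each timestep applies x -> a * x + b to the hidden state. Two such
    # maps compose to another affine map:
    #     (a2, b2) after (a1, b1)  ==  (a2 * a1, a2 * b1 + b2)
    # and this composition is associative, so the whole recurrence is a
    # parallel prefix scan.
    def combine(left, right):
        a_l, b_l = left
        a_r, b_r = right
        return a_r * a_l, a_r * b_l + b_r

    # Inclusive prefix scan over the sequence axis. Since x_{-1} = 0, the
    # second component of the i-th prefix is exactly the hidden state x_i.
    _, xs = jax.lax.associative_scan(combine, (delta_A, delta_B_u))
    return jnp.einsum("l d n,l n -> l d", xs, C)
```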
Reply:
I'm not sure about this, actually. JAX's associative scan appears to completely unroll the whole operation:
https://github.com/google/jax/blob/63ceb5f539c45fe00766634ce7b01ea5176e0cc4/jax/_src/lax/control_flow/loops.py#L2171-L2186
I think there's probably a better way to implement an associative scan: probably something like `equinox/internal/loop/bounded.py`, or possibly a `lax.scan` over the levels of the tree.
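As a very rough illustration of the "`lax.scan` over the levels of the tree" idea (everything here is a hypothetical sketch, not the approach used in `equinox/internal/loop/bounded.py`): a Hillis-Steele prefix scan does O(n log n) work rather than the O(n) of a work-efficient scan, but driving its levels with `lax.scan` keeps the traced program O(log n) in size instead of unrolling it.

```python
import math

import jax
import jax.numpy as jnp


def _scan_over_levels(a, b):
    # Parallel prefix scan for x_i = a[i] * x_{i-1} + b[i], with the
    # ceil(log2 n) levels of the tree driven by `lax.scan` instead of
    # being unrolled into the trace.
    n = a.shape[0]
    index = jnp.arange(n)

    def level(carry, shift):
        a, b = carry
        # Combine each position with the prefix ending `shift` steps
        # earlier; positions i < shift combine with the identity (1, 0).
        valid = (index >= shift).reshape((n,) + (1,) * (a.ndim - 1))
        a_prev = jnp.where(valid, jnp.roll(a, shift, axis=0), 1.0)
        b_prev = jnp.where(valid, jnp.roll(b, shift, axis=0), 0.0)
        return (a * a_prev, a * b_prev + b), None

    shifts = 2 ** jnp.arange(max(math.ceil(math.log2(n)), 0))
    (_, b), _ = jax.lax.scan(level, (a, b), shifts)
    return b  # b[i] is now the hidden state x_i
```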