From 1afb2e4afdb0e042ed70845f6598da45fd7447d0 Mon Sep 17 00:00:00 2001 From: "Artur A. Galstyan" Date: Wed, 14 Feb 2024 00:14:30 +0100 Subject: [PATCH 01/10] State space model start --- .gitignore | 2 + equinox/nn/_state_space_models.py | 99 +++++++++++++++++++++++++++++++ 2 files changed, 101 insertions(+) create mode 100644 equinox/nn/_state_space_models.py diff --git a/.gitignore b/.gitignore index ebeaa306..817b545d 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,5 @@ examples/CIFAR examples/MNIST examples/multipart_serialised.eqx .python-version +.DS_Store + diff --git a/equinox/nn/_state_space_models.py b/equinox/nn/_state_space_models.py new file mode 100644 index 00000000..7b528f4d --- /dev/null +++ b/equinox/nn/_state_space_models.py @@ -0,0 +1,99 @@ +import math +from typing import Literal, Union + +import jax +from jaxtyping import Array, PRNGKeyArray + +from .._module import field, Module +from ._conv import Conv1d +from ._linear import Linear + + +class StateSpaceModel(Module, strict=True): + d_inner: int = field(static=True) + d_state: int = field(static=True) + d_conv: int = field(static=True) + n_embd: int = field(static=True) + n_dims: int = field(static=True) + expand: int = field(static=True) + dt_rank: int = field(static=True) + pad_vocab_size_multiple: int = field(static=True) + + in_proj: Linear + conv1d: Conv1d + + x_proj: Linear + dt_proj: Linear + + A_log: Array + D: Array + + out_proj: Linear + + def __init__( + self, + n_embd: int, + expand: int, + d_state: int, + d_conv: int, + dt_rank: Union[int, Literal["auto"]], + pad_vocab_size_multiple: int = 8, + n_dims: int = 256, + use_bias_in_proj: bool = True, + use_bias_conv1d: bool = True, + *, + key: PRNGKeyArray, + ): + self.d_state = d_state + self.d_conv = d_conv + self.n_embd = n_embd + self.expand = expand + self.n_dims = n_dims + self.d_inner = int(self.expand * self.n_embd) + self.pad_vocab_size_multiple = pad_vocab_size_multiple + + if dt_rank == "auto": + self.dt_rank = math.ceil(self.n_embd / self.d_state) + + if self.n_dims % self.pad_vocab_size_multiple != 0: + self.n_dims += ( + self.pad_vocab_size_multiple + - self.n_dims % self.pad_vocab_size_multiple + ) + + ( + key, + linear_key, + conv1d_key, + x_proj_key, + dt_proj_key, + out_proj_key, + ) = jax.random.split(key, 6) + + self.in_proj = Linear( + n_embd, + self.d_inner * 2, + use_bias=use_bias_in_proj, + key=linear_key, + ) + + self.conv1d = Conv1d( + in_channels=self.d_inner, + out_channels=self.d_inner, + kernel_size=d_conv, + use_bias=use_bias_conv1d, + groups=self.d_inner, + padding=d_conv - 1, + key=conv1d_key, + ) + + self.x_proj = Linear( + self.d_inner, + self.dt_rank + d_state * 2, + use_bias=False, + key=x_proj_key, + ) + + @jax.named_scope("eqx.nn.StateSpaceModel") + def __call__(self) -> Array: + raise NotImplementedError From 8aa2c72cab81e2d0fc66f3889bc91335d72ed948 Mon Sep 17 00:00:00 2001 From: "Artur A. Galstyan" Date: Wed, 14 Feb 2024 23:02:44 +0100 Subject: [PATCH 02/10] added more docs! 
--- .gitignore | 2 +- equinox/nn/_selective_state_space_models.py | 133 ++++++++++++++++++++ equinox/nn/_state_space_models.py | 99 --------------- 3 files changed, 134 insertions(+), 100 deletions(-) create mode 100644 equinox/nn/_selective_state_space_models.py delete mode 100644 equinox/nn/_state_space_models.py diff --git a/.gitignore b/.gitignore index 817b545d..978ba5f8 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,4 @@ examples/MNIST examples/multipart_serialised.eqx .python-version .DS_Store - +.ruff_cache diff --git a/equinox/nn/_selective_state_space_models.py b/equinox/nn/_selective_state_space_models.py new file mode 100644 index 00000000..98c05a8e --- /dev/null +++ b/equinox/nn/_selective_state_space_models.py @@ -0,0 +1,133 @@ +import math +from typing import Literal, Union + +import jax +import jax.numpy as jnp +from jaxtyping import Array, PRNGKeyArray + +from .._module import field, Module +from ._conv import Conv1d +from ._linear import Linear + + +class SelectiveStateSpaceModel(Module, strict=True): + """ + State Space Model with Selective Scan. This is the implementation of the + Mamba Block from the paper + "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" [1]. + + [1] Albert Gu and Tri Dao, Mamba: Linear-Time Sequence Modeling + with Selective State Spaces, 2023 + """ + + n_input_dims: int = field(static=True) + state_space_dims: int = field(static=True) + + d_inner: int = field(static=True) + d_conv: int = field(static=True) + + expand: int = field(static=True) + dt_rank: int = field(static=True) + pad_vocab_size_multiple: int = field(static=True) + + in_proj: Linear + conv1d: Conv1d + + x_proj: Linear + dt_proj: Linear + + A_log: Array + D: Array + + out_proj: Linear + + def __init__( + self, + n_input_dims: int, + state_space_dims: int, + expand: int, + d_conv: int, + dt_rank: Union[int, Literal["auto"]], + pad_vocab_size_multiple: int = 8, + use_bias_in_proj: bool = True, + use_bias_conv1d: bool = True, + use_bias_out_proj: bool = True, + *, + key: PRNGKeyArray, + ): + """ + Args: + n_input_dims: The dimension of the input. + state_space_dims: The dimension of the SSM (refers to 'N' in [1]). + expand: The expansion factor of the inner dimension (refers to 'E' in [1]). + d_conv: The kernel size of the convolutional layer + dt_rank: The rank of delta. If "auto", it will be + set to ceil(n_input_dims / state_space_dims). 
+ pad_vocab_size_multiple: The multiple of the vocabulary size + + """ + self.n_input_dims = n_input_dims + self.state_space_dims = state_space_dims + + self.d_conv = d_conv + self.expand = expand + + self.d_inner = int(self.expand * self.n_input_dims) + + self.pad_vocab_size_multiple = pad_vocab_size_multiple + + if dt_rank == "auto": + self.dt_rank = math.ceil(self.n_input_dims / self.state_space_dims) + + ( + key, + linear_key, + conv1d_key, + x_proj_key, + dt_proj_key, + out_proj_key, + ) = jax.random.split(key, 6) + + self.in_proj = Linear( + n_input_dims, + self.d_inner * 2, + use_bias=use_bias_in_proj, + key=linear_key, + ) + + self.conv1d = Conv1d( + in_channels=self.d_inner, + out_channels=self.d_inner, + kernel_size=d_conv, + use_bias=use_bias_conv1d, + groups=self.d_inner, + padding=d_conv - 1, + key=conv1d_key, + ) + + self.x_proj = Linear( + self.d_inner, + self.dt_rank + state_space_dims * 2, + use_bias=False, + key=x_proj_key, + ) + + self.dt_proj = Linear( + self.dt_rank, self.d_inner, use_bias=True, key=dt_proj_key + ) + + A = jnp.repeat(jnp.arange(1, self.state_space_dims + 1), self.d_inner).reshape( + self.d_inner, self.state_space_dims + ) + self.A_log = jnp.log(A) + self.D = jnp.ones(self.d_inner) + self.out_proj = Linear( + self.d_inner, + self.n_input_dims, + use_bias=use_bias_out_proj, + key=x_proj_key, + ) + + @jax.named_scope("eqx.nn.StateSpaceModel") + def __call__(self) -> Array: + raise NotImplementedError diff --git a/equinox/nn/_state_space_models.py b/equinox/nn/_state_space_models.py deleted file mode 100644 index 7b528f4d..00000000 --- a/equinox/nn/_state_space_models.py +++ /dev/null @@ -1,99 +0,0 @@ -import math -from typing import Literal, Union - -import jax -from jaxtyping import Array, PRNGKeyArray - -from .._module import field, Module -from ._conv import Conv1d -from ._linear import Linear - - -class StateSpaceModel(Module, strict=True): - d_inner: int = field(static=True) - d_state: int = field(static=True) - d_conv: int = field(static=True) - n_embd: int = field(static=True) - n_dims: int = field(static=True) - expand: int = field(static=True) - dt_rank: int = field(static=True) - pad_vocab_size_multiple: int = field(static=True) - - in_proj: Linear - conv1d: Conv1d - - x_proj: Linear - dt_proj: Linear - - A_log: Array - D: Array - - out_proj: Linear - - def __init__( - self, - n_embd: int, - expand: int, - d_state: int, - d_conv: int, - dt_rank: Union[int, Literal["auto"]], - pad_vocab_size_multiple: int = 8, - n_dims: int = 256, - use_bias_in_proj: bool = True, - use_bias_conv1d: bool = True, - *, - key: PRNGKeyArray, - ): - self.d_state = d_state - self.d_conv = d_conv - self.n_embd = n_embd - self.expand = expand - self.n_dims = n_dims - self.d_inner = int(self.expand * self.n_embd) - self.pad_vocab_size_multiple = pad_vocab_size_multiple - - if dt_rank == "auto": - self.dt_rank = math.ceil(self.n_embd / self.d_state) - - if self.n_dims % self.pad_vocab_size_multiple != 0: - self.n_dims += ( - self.pad_vocab_size_multiple - - self.n_dims % self.pad_vocab_size_multiple - ) - - ( - key, - linear_key, - conv1d_key, - x_proj_key, - dt_proj_key, - out_proj_key, - ) = jax.random.split(key, 6) - - self.in_proj = Linear( - n_embd, - self.d_inner * 2, - use_bias=use_bias_in_proj, - key=linear_key, - ) - - self.conv1d = Conv1d( - in_channels=self.d_inner, - out_channels=self.d_inner, - kernel_size=d_conv, - use_bias=use_bias_conv1d, - groups=self.d_inner, - padding=d_conv - 1, - key=conv1d_key, - ) - - self.x_proj = Linear( - self.d_inner, - 
self.dt_rank + d_state * 2, - use_bias=False, - key=x_proj_key, - ) - - @jax.named_scope("eqx.nn.StateSpaceModel") - def __call__(self) -> Array: - raise NotImplementedError From 2e6d8f659d5d29c18b82af3e3bd99971f86022aa Mon Sep 17 00:00:00 2001 From: "Artur A. Galstyan" Date: Thu, 15 Feb 2024 21:23:08 +0100 Subject: [PATCH 03/10] added mamba block, need to test --- equinox/nn/_selective_state_space_models.py | 67 ++++++++++++++++++++- 1 file changed, 64 insertions(+), 3 deletions(-) diff --git a/equinox/nn/_selective_state_space_models.py b/equinox/nn/_selective_state_space_models.py index 98c05a8e..f09e7474 100644 --- a/equinox/nn/_selective_state_space_models.py +++ b/equinox/nn/_selective_state_space_models.py @@ -3,13 +3,41 @@ import jax import jax.numpy as jnp -from jaxtyping import Array, PRNGKeyArray +from jaxtyping import Array, Float, PRNGKeyArray from .._module import field, Module from ._conv import Conv1d from ._linear import Linear +def _selective_scan( + u: Float[Array, "seq_len d_inner"], + delta: Float[Array, "seq_len d_inner"], + A: Float[Array, "d_inner state_space_dims"], + B: Float[Array, "seq_len state_space_dims"], + C: Float[Array, "seq_len state_space_dims"], + D: Float[Array, "d_inner"], # noqa +): + seq_len, _ = u.shape + d_inner, state_space_dims = A.shape + + delta_A = jnp.exp(jnp.einsum("l d,d n -> l d n", delta, A)) + delta_B_u = jnp.einsum("l d,l n,l d -> l d n", delta, B, u) + + x_res = jnp.zeros(shape=(d_inner, state_space_dims)) + + def step(x, i): + x = delta_A[i] * x + delta_B_u[i] + + y = jnp.einsum("d n,n -> d", x, C[i, :]) + return x, y + + _, ys = jax.lax.scan(step, x_res, jnp.arange(seq_len)) + + ys = ys + u * D + return ys + + class SelectiveStateSpaceModel(Module, strict=True): """ State Space Model with Selective Scan. This is the implementation of the @@ -129,5 +157,38 @@ def __init__( ) @jax.named_scope("eqx.nn.StateSpaceModel") - def __call__(self) -> Array: - raise NotImplementedError + def __call__(self, x: Float[Array, "seq_len n_input_dims"]) -> Array: + seq_len, d = x.shape + if d != self.n_input_dims: + raise ValueError( + f"Input dimension mismatch: expected {self.n_input_dims}, got {d}" + ) + x_and_res = jax.vmap(self.in_proj)(x) + (x, res) = jnp.split(x_and_res, 2, axis=-1) + + x = jnp.transpose(x) + x = self.conv1d(x)[:, :seq_len] + x = jnp.transpose(x) + x = jax.nn.silu(x) + + y = self._ssm(x) + y = y * jax.nn.silu(res) + + output = jax.vmap(self.out_proj)(y) + return output + + def _ssm(self, x: Float[Array, "seq_len d_inner"]) -> Array: + A = -jnp.exp(self.A_log) + D = self.D + + x_delta_b_c = jax.vmap(self.x_proj)(x) + + split_indices = [ + self.dt_rank, + self.dt_rank + self.state_space_dims, + ] + delta, B, C = jnp.split(x_delta_b_c, split_indices, axis=-1) + delta = jax.nn.softplus(jax.vmap(self.dt_proj)(delta)) + + y = _selective_scan(x, delta, A, B, C, D) + return y From 572fda43e100eab673745f2a2bb4800c4372fe7b Mon Sep 17 00:00:00 2001 From: "Artur A. 
Galstyan" Date: Fri, 16 Feb 2024 17:38:42 +0100 Subject: [PATCH 04/10] started with example --- equinox/nn/_selective_state_space_models.py | 6 +- examples/mamba.ipynb | 203 ++++++++++++++++++++ imgs/Mamba1.drawio.svg | 4 + 3 files changed, 212 insertions(+), 1 deletion(-) create mode 100644 examples/mamba.ipynb create mode 100644 imgs/Mamba1.drawio.svg diff --git a/equinox/nn/_selective_state_space_models.py b/equinox/nn/_selective_state_space_models.py index f09e7474..2c1b74ec 100644 --- a/equinox/nn/_selective_state_space_models.py +++ b/equinox/nn/_selective_state_space_models.py @@ -92,6 +92,10 @@ def __init__( dt_rank: The rank of delta. If "auto", it will be set to ceil(n_input_dims / state_space_dims). pad_vocab_size_multiple: The multiple of the vocabulary size + use_bias_in_proj: Whether to use bias in the input projection layer. + use_bias_conv1d: Whether to use bias in the convolutional layer. + use_bias_out_proj: Whether to use bias in the output projection layer. + key: The PRNG key. """ self.n_input_dims = n_input_dims @@ -156,7 +160,7 @@ def __init__( key=x_proj_key, ) - @jax.named_scope("eqx.nn.StateSpaceModel") + @jax.named_scope("eqx.nn.SelectiveStateSpaceModel") def __call__(self, x: Float[Array, "seq_len n_input_dims"]) -> Array: seq_len, d = x.shape if d != self.n_input_dims: diff --git a/examples/mamba.ipynb b/examples/mamba.ipynb new file mode 100644 index 00000000..d3bf3826 --- /dev/null +++ b/examples/mamba.ipynb @@ -0,0 +1,203 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f0591c0a-21c6-4ff4-b99f-4628dcc076df", + "metadata": {}, + "source": [ + "# Mamba\n", + "\n", + "In this example, we will implement the new Mamba model from Albert Gu and Tri Dao [[1]](https://arxiv.org/abs/2312.00752) by utilising the new `SelectiveStateSpaceModel` layer. \n", + "\n", + "In this example, you will learn the following:\n", + "\n", + " - how to implement Mamba\n", + " - how to use a shared layer\n", + "\n", + "Special thanks and cretits go to John (Zhiyao) Ma and his excellent Mamba implementation in PyTorch, which served as a great inspriration and foundation for this Equinox version. Go check it out [here](https://github.com/johnma2006/mamba-minimal).\n", + "\n", + "The original implementation includes **a lot** of CUDA code [[2]](https://github.com/state-spaces/mamba) to optimise the so-called `selective_scan` algorithm, but this first iteration of the `SelectiveStateSpaceModel` implementation is not as heavily optimised. However, in future iterations, by using some clever Pallas code, we can get to the same performance. " + ] + }, + { + "cell_type": "markdown", + "id": "b418d2b9-80f8-4048-9b8b-d8e958035705", + "metadata": {}, + "source": [ + "The following image shows the high level architecture of Mamba which we will implement." + ] + }, + { + "cell_type": "markdown", + "id": "660315df-3e77-4e3c-a793-c153840f094c", + "metadata": {}, + "source": [ + "
\n", + " \n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "b9a8ed93-4946-4758-a228-cc39326b8c9d", + "metadata": {}, + "source": [ + "Before dive into the `ResidualBlock` part, which contains the main `SelectiveStateSpaceModel` code, let's quickly build everything around it first. Also note that the weights of the embedding layer and the final linear layer are shared! This is not a problem though, because we can use `eqx.nn.Shared` to implement this." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "ab229dca-2c2a-46ee-8f24-eedce5c06e18", + "metadata": {}, + "outputs": [], + "source": [ + "import equinox as eqx\n", + "import jax\n", + "from jaxtyping import Array, Float, Int, PRNGKeyArray" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "dae7f1eb-ad21-4e6f-a6a1-352eea03d414", + "metadata": {}, + "outputs": [ + { + "ename": "AttributeError", + "evalue": "module 'equinox.nn' has no attribute 'RMSNorm'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[4], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;43;01mclass\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;21;43;01mMamba\u001b[39;49;00m\u001b[43m(\u001b[49m\u001b[43meqx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mModule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstrict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[43mlayers\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43meqx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mSequential\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[43mnormalization\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43meqx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mRMSNorm\u001b[49m\n", + "Cell \u001b[0;32mIn[4], line 3\u001b[0m, in \u001b[0;36mMamba\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mMamba\u001b[39;00m(eqx\u001b[38;5;241m.\u001b[39mModule, strict\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m):\n\u001b[1;32m 2\u001b[0m layers: eqx\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mSequential\n\u001b[0;32m----> 3\u001b[0m normalization: \u001b[43meqx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mRMSNorm\u001b[49m\n\u001b[1;32m 4\u001b[0m shared_emb_lm_head: eqx\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mShared\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, n_layers: \u001b[38;5;28mint\u001b[39m, n_dims: \u001b[38;5;28mint\u001b[39m, n_embd: \u001b[38;5;28mint\u001b[39m, \u001b[38;5;241m*\u001b[39m, key: PRNGKeyArray):\n", + "\u001b[0;31mAttributeError\u001b[0m: module 'equinox.nn' has no attribute 'RMSNorm'" + ] + } + ], + "source": [ + "class Mamba(eqx.Module, strict=True):\n", + " layers: eqx.nn.Sequential\n", + " normalization: eqx.nn.RMSNorm\n", + " shared_emb_lm_head: eqx.nn.Shared\n", + "\n", + " def __init__(self, n_layers: int, n_dims: int, n_embd: int, *, key: PRNGKeyArray):\n", + " key, *subkeys = jax.random.split(key, 
1 + n_layers)\n", + " self.layers = eqx.nn.Sequential(\n", + " [ResidualBlock(key=subkeys[i + 1]) for i in range(n_layers)],\n", + " )\n", + " self.normalization = eqx.nn.RMSNorm(n_embd)\n", + "\n", + " embedding = eqx.nn.Embedding(n_dims, n_embd, key=subkeys[0])\n", + " lm_head = eqx.nn.Linear(\n", + " n_embd,\n", + " n_dims,\n", + " use_bias=False,\n", + " key=subkeys[-1],\n", + " )\n", + " where = lambda embed_and_lin: embed_and_lin[1].weight\n", + " get = lambda embed_and_lin: embed_and_lin[0].weight\n", + " self.shared_emb_lm_head = eqx.nn.Shared(\n", + " (embedding, lm_head), where=where, get=get\n", + " )\n", + "\n", + " def __call__(\n", + " self,\n", + " x: Int[Array, \"seq_len\"], # noqa\n", + " *,\n", + " key: PRNGKeyArray = None,\n", + " ) -> Float[Array, \"seq_len n_dims\"]: # noqa\n", + " embedding, linear = self.shared_emb_lm_head()\n", + " x = jax.vmap(embedding)(x)\n", + "\n", + " x = self.layers(x)\n", + " x = jax.vmap(self.normalization)(x)\n", + " logits = jax.vmap(linear)(x)\n", + " return logits" + ] + }, + { + "cell_type": "markdown", + "id": "d557fa4a-1fa4-4e72-ba77-d0ef7f946cdf", + "metadata": {}, + "source": [ + "We haven't implementated `ResidualBlock` yet, but we will get there soon. Note the usage of `eqx.nn.Shared`:\n", + "\n", + "```python\n", + " # Embedding layer\n", + " embedding = eqx.nn.Embedding(\n", + " n_dims, n_embd, key=subkeys[0]\n", + " )\n", + " # Linear layer\n", + " lm_head = eqx.nn.Linear(\n", + " n_embd,\n", + " n_dims,\n", + " use_bias=False,\n", + " key=subkeys[-1],\n", + " )\n", + " # refers to the linear weights\n", + " where = lambda embed_and_lin: embed_and_lin[1].weight \n", + "\n", + " # refers to the embedding weights\n", + " get = lambda embed_and_lin: embed_and_lin[0].weight\n", + "\n", + " # Create a shared layer\n", + " self.shared_emb_lm_head = eqx.nn.Shared(\n", + " (embedding, lm_head), where=where, get=get\n", + " )\n", + "```\n", + "\n", + "And to use the shared layers, we have to get them first out of the shared layer:\n", + "\n", + "```python\n", + " embedding, linear = self.shared_emb_lm_head()\n", + " # embedding and linear are eqx.nn.Embedding and eqx.nn.Linear respectively\n", + " # proceed usage as usual\n", + "```\n", + "\n", + "Let's continue with the `ResidualBlock`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d1ed52a9-abea-43e8-822a-9d75ac8ae480", + "metadata": {}, + "outputs": [], + "source": [ + "class ResidualBlock(eqx.Module):\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8aa05c2c-d8d0-484f-9d29-430b41bf4181", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/imgs/Mamba1.drawio.svg b/imgs/Mamba1.drawio.svg new file mode 100644 index 00000000..ccce3599 --- /dev/null +++ b/imgs/Mamba1.drawio.svg @@ -0,0 +1,4 @@ + + + +
[Figure Mamba1.drawio.svg: input (seq_len, n_dims) -> Embedding Dimension (n_dims -> n_embd) -> Residual Block (x N layers) -> Normalization -> Linear (n_embd -> n_dims) -> output (seq_len, n_dims)]
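Editor's note: the weight tying that this patch's notebook sets up via `eqx.nn.Shared` can be exercised in isolation. A minimal sketch, using toy sizes (`n_dims=256`, `n_embd=64`) that are illustrative and not taken from the patch:

```python
import equinox as eqx
import jax

emb_key, head_key = jax.random.split(jax.random.PRNGKey(0))

# An Embedding weight is (n_dims, n_embd); a Linear(n_embd -> n_dims) weight is
# (n_dims, n_embd) too, so the two parameters can share one array.
embedding = eqx.nn.Embedding(256, 64, key=emb_key)
lm_head = eqx.nn.Linear(64, 256, use_bias=False, key=head_key)

shared = eqx.nn.Shared(
    (embedding, lm_head),
    where=lambda pair: pair[1].weight,  # the parameter that gets overwritten
    get=lambda pair: pair[0].weight,  # the parameter that is shared into it
)

# Calling the Shared module materialises both layers with the tied weight.
embedding_out, lm_head_out = shared()
assert (embedding_out.weight == lm_head_out.weight).all()
```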
\ No newline at end of file From b7edd5739c823271434e8d7e1466496ac085cdb8 Mon Sep 17 00:00:00 2001 From: "Artur A. Galstyan" Date: Sun, 25 Feb 2024 17:47:00 +0100 Subject: [PATCH 05/10] included more graphs in the example --- examples/mamba.ipynb | 79 +++++++++++++++++++++++++++++++----------- imgs/Mamba2.drawio.svg | 4 +++ imgs/Mamba3.drawio.svg | 4 +++ 3 files changed, 67 insertions(+), 20 deletions(-) create mode 100644 imgs/Mamba2.drawio.svg create mode 100644 imgs/Mamba3.drawio.svg diff --git a/examples/mamba.ipynb b/examples/mamba.ipynb index d3bf3826..20e57057 100644 --- a/examples/mamba.ipynb +++ b/examples/mamba.ipynb @@ -52,6 +52,8 @@ "metadata": {}, "outputs": [], "source": [ + "from typing import Optional\n", + "\n", "import equinox as eqx\n", "import jax\n", "from jaxtyping import Array, Float, Int, PRNGKeyArray" @@ -59,23 +61,10 @@ }, { "cell_type": "code", - "execution_count": 4, - "id": "dae7f1eb-ad21-4e6f-a6a1-352eea03d414", + "execution_count": 3, + "id": "91463596-114f-45f7-a2f0-b1ae8d16f25f", "metadata": {}, - "outputs": [ - { - "ename": "AttributeError", - "evalue": "module 'equinox.nn' has no attribute 'RMSNorm'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[4], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;43;01mclass\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;21;43;01mMamba\u001b[39;49;00m\u001b[43m(\u001b[49m\u001b[43meqx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mModule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstrict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[43mlayers\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43meqx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mSequential\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[43mnormalization\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43meqx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mRMSNorm\u001b[49m\n", - "Cell \u001b[0;32mIn[4], line 3\u001b[0m, in \u001b[0;36mMamba\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mMamba\u001b[39;00m(eqx\u001b[38;5;241m.\u001b[39mModule, strict\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m):\n\u001b[1;32m 2\u001b[0m layers: eqx\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mSequential\n\u001b[0;32m----> 3\u001b[0m normalization: \u001b[43meqx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mRMSNorm\u001b[49m\n\u001b[1;32m 4\u001b[0m shared_emb_lm_head: eqx\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mShared\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, n_layers: \u001b[38;5;28mint\u001b[39m, n_dims: \u001b[38;5;28mint\u001b[39m, n_embd: \u001b[38;5;28mint\u001b[39m, \u001b[38;5;241m*\u001b[39m, key: PRNGKeyArray):\n", - "\u001b[0;31mAttributeError\u001b[0m: module 'equinox.nn' has no attribute 'RMSNorm'" - ] - } - ], + "outputs": [], "source": [ "class Mamba(eqx.Module, strict=True):\n", " layers: eqx.nn.Sequential\n", 
@@ -106,7 +95,7 @@ " self,\n", " x: Int[Array, \"seq_len\"], # noqa\n", " *,\n", - " key: PRNGKeyArray = None,\n", + " key: Optional[PRNGKeyArray] = None,\n", " ) -> Float[Array, \"seq_len n_dims\"]: # noqa\n", " embedding, linear = self.shared_emb_lm_head()\n", " x = jax.vmap(embedding)(x)\n", @@ -159,21 +148,71 @@ "Let's continue with the `ResidualBlock`." ] }, + { + "cell_type": "markdown", + "id": "6b4b5de1-8300-4fc4-934e-3b0194bf8372", + "metadata": {}, + "source": [ + "Here's an overview of what the components of the `ResidualBlock` will look like.\n", + "\n", + "
\n", + " \n", + "
\n", + "\n", + "As you can see, we keep diving further into the model. Let's implement this `ResidualBlock` now." + ] + }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "d1ed52a9-abea-43e8-822a-9d75ac8ae480", "metadata": {}, "outputs": [], "source": [ - "class ResidualBlock(eqx.Module):\n", + "class ResidualBlock(eqx.Module, strict=True):\n", + " mamba_block: MambaBlock\n", + " rns_norm: eqx.nn.RMSNorm\n", + "\n", + " def __init__(self, n_embd: int, *, key: PRNGKeyArray):\n", + " \n", + " self.mamba_block = MambaBlock(\n", + " key=key,\n", + " )\n", + " self.rns_norm = eqx.nn.RMSNorm(n_embd)\n", + "\n", + " def __call__(\n", + " self, x: Float[Array, \"seq_len n_embd\"], *, key: Optional[PRNGKeyArray] = None\n", + " ) -> Array:\n", + " return self.mamba_block(jax.vmap(self.rns_norm)(x)) + x" + ] + }, + { + "cell_type": "markdown", + "id": "d8562a8b-3bea-4997-8c53-be33d66893f1", + "metadata": {}, + "source": [ + "We're getting closer and closer to the heart of the Mamba model. Let's look at what the `MambaBlock` looks like. This time, I've included the shapes of the matrices as they traverse through all kinds of transformations. \n", + "\n", + "
\n", + " \n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6b57d346-18d0-4d29-b35d-e1a6e99c6a75", + "metadata": {}, + "outputs": [], + "source": [ + "class MambaBlock(eqx.Module):\n", " pass" ] }, { "cell_type": "code", "execution_count": null, - "id": "8aa05c2c-d8d0-484f-9d29-430b41bf4181", + "id": "e84076b2-9808-4916-8580-d789cdc49c05", "metadata": {}, "outputs": [], "source": [] diff --git a/imgs/Mamba2.drawio.svg b/imgs/Mamba2.drawio.svg new file mode 100644 index 00000000..758b254e --- /dev/null +++ b/imgs/Mamba2.drawio.svg @@ -0,0 +1,4 @@ + + + +
[Figure Mamba2.drawio.svg: (seq_len, n_embd) -> Normalisation -> MambaBlock -> (seq_len, n_dims)]
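Editor's note: the figure above is the whole `ResidualBlock` contract: pre-normalise, transform, add the input back. A minimal sketch of that pattern, with a stand-in transformation and toy sizes that are not from the patch:

```python
import equinox as eqx
import jax
import jax.numpy as jnp

seq_len, n_embd = 5, 8  # toy sizes
norm = eqx.nn.RMSNorm(n_embd)

# Stand-in for the MambaBlock: any (seq_len, n_embd) -> (seq_len, n_embd) map.
mamba_block = lambda z: z * 2.0

x = jnp.ones((seq_len, n_embd))
# RMSNorm acts on a single token, so it is vmapped over the sequence axis,
# matching the notebook's ResidualBlock.__call__.
y = mamba_block(jax.vmap(norm)(x)) + x
assert y.shape == (seq_len, n_embd)
```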
\ No newline at end of file diff --git a/imgs/Mamba3.drawio.svg b/imgs/Mamba3.drawio.svg new file mode 100644 index 00000000..1edc7487 --- /dev/null +++ b/imgs/Mamba3.drawio.svg @@ -0,0 +1,4 @@ + + + +
[Figure Mamba3.drawio.svg: x (seq_len, n_embd) -> Input Projection (n_embd -> 2 * d_inner) -> (seq_len, 2 * d_inner) -> Split into x and residual, each (seq_len, d_inner); x -> Conv1d -> Truncate to seq_len -> Silu -> SSM; residual -> Silu; x * residual (seq_len, d_inner) -> Output Projection (d_inner -> n_embd) -> x (seq_len, n_embd)]
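Editor's note: one detail of the figure above that is easy to miss is the Conv1d step. `eqx.nn.Conv1d` expects `(channels, length)`, and padding by `d_conv - 1` on both ends makes the output longer than the input, so the block transposes, convolves, truncates back to `seq_len`, and transposes again. A small sketch with toy sizes (not from the patch):

```python
import equinox as eqx
import jax
import jax.numpy as jnp

seq_len, d_inner, d_conv = 10, 4, 3  # toy sizes
conv = eqx.nn.Conv1d(
    in_channels=d_inner,
    out_channels=d_inner,
    kernel_size=d_conv,
    groups=d_inner,  # depthwise: each channel gets its own filter
    padding=d_conv - 1,  # pads both ends; the truncation below keeps it causal
    key=jax.random.PRNGKey(0),
)

x = jnp.ones((seq_len, d_inner))
x = jnp.transpose(x)  # (d_inner, seq_len), as eqx.nn.Conv1d expects
x = conv(x)[:, :seq_len]  # drop the trailing padded positions
x = jnp.transpose(x)  # back to (seq_len, d_inner)
assert x.shape == (seq_len, d_inner)
```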
\ No newline at end of file From 1a20cc6f9692aeb3235bf534f76f057f46c1745d Mon Sep 17 00:00:00 2001 From: "Artur A. Galstyan" Date: Thu, 29 Feb 2024 22:50:48 +0100 Subject: [PATCH 06/10] more mamba example --- examples/mamba.ipynb | 85 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 79 insertions(+), 6 deletions(-) diff --git a/examples/mamba.ipynb b/examples/mamba.ipynb index 20e57057..be1bf0f4 100644 --- a/examples/mamba.ipynb +++ b/examples/mamba.ipynb @@ -47,7 +47,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "id": "ab229dca-2c2a-46ee-8f24-eedce5c06e18", "metadata": {}, "outputs": [], @@ -61,7 +61,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "id": "91463596-114f-45f7-a2f0-b1ae8d16f25f", "metadata": {}, "outputs": [], @@ -164,7 +164,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 11, "id": "d1ed52a9-abea-43e8-822a-9d75ac8ae480", "metadata": {}, "outputs": [], @@ -174,7 +174,6 @@ " rns_norm: eqx.nn.RMSNorm\n", "\n", " def __init__(self, n_embd: int, *, key: PRNGKeyArray):\n", - " \n", " self.mamba_block = MambaBlock(\n", " key=key,\n", " )\n", @@ -200,19 +199,93 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 10, "id": "6b57d346-18d0-4d29-b35d-e1a6e99c6a75", "metadata": {}, "outputs": [], "source": [ "class MambaBlock(eqx.Module):\n", + " in_proj: eqx.nn.Linear\n", + " conv1d: eqx.nn.Conv1d\n", + " ssm: SSM\n", + " out_proj: eqx.nn.Linear\n", + "\n", + " def __init__(\n", + " self,\n", + " n_embd: int,\n", + " d_inner: int,\n", + " dt_rank: int,\n", + " d_conv: int,\n", + " use_in_projection_bias: bool=True, \n", + " use_conv_bias: bool=True,\n", + " use_out_proj_bias: bool = True,\n", + " *,\n", + " key: PRNGKeyArray,\n", + " ):\n", + " (\n", + " key,\n", + " linear_key,\n", + " conv1d_key,\n", + " ssm_key,\n", + " out_proj_key,\n", + " ) = jax.random.split(key, 5)\n", + "\n", + " self.in_proj = eqx.nn.Linear(\n", + " n_embd,\n", + " d_inner * 2,\n", + " use_bias=use_in_projection_bias,\n", + " key=linear_key,\n", + " )\n", + "\n", + " self.conv1d = eqx.nn.Conv1d(\n", + " in_channels=d_inner,\n", + " out_channels=d_inner,\n", + " kernel_size=d_conv,\n", + " use_bias=use_conv_bias,\n", + " groups=d_inner,\n", + " padding=d_conv - 1,\n", + " key=conv1d_key,\n", + " )\n", + " self.ssm = SSM(key=ssm_key)\n", + " self.out_proj = eqx.nn.Linear(\n", + " d_inner,\n", + " n_embd,\n", + " use_bias=use_out_proj_bias,\n", + " key=out_proj_key,\n", + " )\n", + "\n", + " def __call__(self, x: Array):\n", + " seq_len, d = x.shape\n", + " x_and_res = jax.vmap(self.in_proj)(x)\n", + "\n", + " (x, res) = jnp.split(x_and_res, 2, axis=-1)\n", + " x = jnp.transpose(x)\n", + " x = self.conv1d(x)[:, :seq_len]\n", + " x = jnp.transpose(x)\n", + " x = jax.nn.silu(x)\n", + "\n", + " y = self.ssm(x)\n", + " y = y * jax.nn.silu(res)\n", + "\n", + " output = jax.vmap(self.out_proj)(y)\n", + " return output" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "e84076b2-9808-4916-8580-d789cdc49c05", + "metadata": {}, + "outputs": [], + "source": [ + "class SSM(eqx.Module):\n", " pass" ] }, { "cell_type": "code", "execution_count": null, - "id": "e84076b2-9808-4916-8580-d789cdc49c05", + "id": "1196448d-c659-44d2-9527-6b29723b2c59", "metadata": {}, "outputs": [], "source": [] From b76ecd8c04786b4e1aa6109aa9ba495791d827cf Mon Sep 17 00:00:00 2001 From: "Artur A. 
Galstyan" Date: Sun, 3 Mar 2024 11:43:40 +0100 Subject: [PATCH 07/10] mamba progress --- examples/mamba.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/mamba.ipynb b/examples/mamba.ipynb index be1bf0f4..d7d51d6c 100644 --- a/examples/mamba.ipynb +++ b/examples/mamba.ipynb @@ -216,8 +216,8 @@ " d_inner: int,\n", " dt_rank: int,\n", " d_conv: int,\n", - " use_in_projection_bias: bool=True, \n", - " use_conv_bias: bool=True,\n", + " use_in_projection_bias: bool = True,\n", + " use_conv_bias: bool = True,\n", " use_out_proj_bias: bool = True,\n", " *,\n", " key: PRNGKeyArray,\n", From 5f02a8c4ce96ad05a835050639adf70ca22b9fb2 Mon Sep 17 00:00:00 2001 From: "Artur A. Galstyan" Date: Mon, 4 Mar 2024 22:36:39 +0100 Subject: [PATCH 08/10] added mamba example --- examples/mamba.ipynb | 457 ++++++++++++++++++++++++++++++----------- imgs/Mamba4.drawio.svg | 4 + 2 files changed, 342 insertions(+), 119 deletions(-) create mode 100644 imgs/Mamba4.drawio.svg diff --git a/examples/mamba.ipynb b/examples/mamba.ipynb index d7d51d6c..58a04a0c 100644 --- a/examples/mamba.ipynb +++ b/examples/mamba.ipynb @@ -1,5 +1,20 @@ { "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "ab229dca-2c2a-46ee-8f24-eedce5c06e18", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Optional\n", + "\n", + "import equinox as eqx\n", + "import jax\n", + "import jax.numpy as jnp\n", + "from jaxtyping import Array, Float, Int, PRNGKeyArray" + ] + }, { "cell_type": "markdown", "id": "f0591c0a-21c6-4ff4-b99f-4628dcc076df", @@ -39,167 +54,188 @@ }, { "cell_type": "markdown", - "id": "b9a8ed93-4946-4758-a228-cc39326b8c9d", + "id": "7f2b1134-b021-4fae-97d7-90c6afb66cd5", "metadata": {}, "source": [ - "Before dive into the `ResidualBlock` part, which contains the main `SelectiveStateSpaceModel` code, let's quickly build everything around it first. Also note that the weights of the embedding layer and the final linear layer are shared! This is not a problem though, because we can use `eqx.nn.Shared` to implement this." + "If we zoom into the `ResidualBlock`, we find the following:" ] }, { - "cell_type": "code", - "execution_count": 4, - "id": "ab229dca-2c2a-46ee-8f24-eedce5c06e18", + "cell_type": "markdown", + "id": "6b4b5de1-8300-4fc4-934e-3b0194bf8372", "metadata": {}, - "outputs": [], "source": [ - "from typing import Optional\n", + "
\n", + " \n", + "
\n", "\n", - "import equinox as eqx\n", - "import jax\n", - "from jaxtyping import Array, Float, Int, PRNGKeyArray" + "As you can see, we keep diving further into the model. Let's implement this `ResidualBlock` now. Let's keep on zooming until we get to the deepest component - at which point we can start to implement everything and work our way back up. Let's keep going." ] }, { - "cell_type": "code", - "execution_count": 5, - "id": "91463596-114f-45f7-a2f0-b1ae8d16f25f", + "cell_type": "markdown", + "id": "d8562a8b-3bea-4997-8c53-be33d66893f1", "metadata": {}, - "outputs": [], "source": [ - "class Mamba(eqx.Module, strict=True):\n", - " layers: eqx.nn.Sequential\n", - " normalization: eqx.nn.RMSNorm\n", - " shared_emb_lm_head: eqx.nn.Shared\n", - "\n", - " def __init__(self, n_layers: int, n_dims: int, n_embd: int, *, key: PRNGKeyArray):\n", - " key, *subkeys = jax.random.split(key, 1 + n_layers)\n", - " self.layers = eqx.nn.Sequential(\n", - " [ResidualBlock(key=subkeys[i + 1]) for i in range(n_layers)],\n", - " )\n", - " self.normalization = eqx.nn.RMSNorm(n_embd)\n", - "\n", - " embedding = eqx.nn.Embedding(n_dims, n_embd, key=subkeys[0])\n", - " lm_head = eqx.nn.Linear(\n", - " n_embd,\n", - " n_dims,\n", - " use_bias=False,\n", - " key=subkeys[-1],\n", - " )\n", - " where = lambda embed_and_lin: embed_and_lin[1].weight\n", - " get = lambda embed_and_lin: embed_and_lin[0].weight\n", - " self.shared_emb_lm_head = eqx.nn.Shared(\n", - " (embedding, lm_head), where=where, get=get\n", - " )\n", - "\n", - " def __call__(\n", - " self,\n", - " x: Int[Array, \"seq_len\"], # noqa\n", - " *,\n", - " key: Optional[PRNGKeyArray] = None,\n", - " ) -> Float[Array, \"seq_len n_dims\"]: # noqa\n", - " embedding, linear = self.shared_emb_lm_head()\n", - " x = jax.vmap(embedding)(x)\n", + "We're getting closer and closer to the heart of the Mamba model. Let's look at what the `MambaBlock` looks like. This time, I've included the shapes of the matrices as they traverse through all kinds of transformations. \n", "\n", - " x = self.layers(x)\n", - " x = jax.vmap(self.normalization)(x)\n", - " logits = jax.vmap(linear)(x)\n", - " return logits" + "
\n", + " \n", + "
" ] }, { "cell_type": "markdown", - "id": "d557fa4a-1fa4-4e72-ba77-d0ef7f946cdf", + "id": "a58c528d-2f7a-4696-808e-f1fe9db2110d", "metadata": {}, "source": [ - "We haven't implementated `ResidualBlock` yet, but we will get there soon. Note the usage of `eqx.nn.Shared`:\n", + "Most of the parts we need are already present in Equinox's library. What's missing though is the new `SelectiveStateSpaceModel` (abbreviated as `SSM` above). Everything in green are trainable parameters. \n", "\n", - "```python\n", - " # Embedding layer\n", - " embedding = eqx.nn.Embedding(\n", - " n_dims, n_embd, key=subkeys[0]\n", - " )\n", - " # Linear layer\n", - " lm_head = eqx.nn.Linear(\n", - " n_embd,\n", - " n_dims,\n", - " use_bias=False,\n", - " key=subkeys[-1],\n", - " )\n", - " # refers to the linear weights\n", - " where = lambda embed_and_lin: embed_and_lin[1].weight \n", - "\n", - " # refers to the embedding weights\n", - " get = lambda embed_and_lin: embed_and_lin[0].weight\n", - "\n", - " # Create a shared layer\n", - " self.shared_emb_lm_head = eqx.nn.Shared(\n", - " (embedding, lm_head), where=where, get=get\n", - " )\n", - "```\n", - "\n", - "And to use the shared layers, we have to get them first out of the shared layer:\n", - "\n", - "```python\n", - " embedding, linear = self.shared_emb_lm_head()\n", - " # embedding and linear are eqx.nn.Embedding and eqx.nn.Linear respectively\n", - " # proceed usage as usual\n", - "```\n", - "\n", - "Let's continue with the `ResidualBlock`." + "
\n", + " \n", + "
" ] }, { "cell_type": "markdown", - "id": "6b4b5de1-8300-4fc4-934e-3b0194bf8372", + "id": "3c88cd1b-eb91-4af1-b537-f209342a8bb1", "metadata": {}, "source": [ - "Here's an overview of what the components of the `ResidualBlock` will look like.\n", + "Alright! This is the deepest we can get. We've reached the point at which we have all needed components available to us (except for the `selective_scan` function, but that's not a problem). Let's start with the `SelectiveStateSpaceModel` and then work our way back up again." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "9efd5640-cdb6-427d-be3d-860cf371e357", + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'Float' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[1], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mselective_scan\u001b[39m(\n\u001b[0;32m----> 2\u001b[0m x: \u001b[43mFloat\u001b[49m[Array, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mseq_length d_inner\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[1;32m 3\u001b[0m delta: Float[Array, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mseq_length d_inner\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[1;32m 4\u001b[0m A: Float[Array, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124md_inner d_state\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[1;32m 5\u001b[0m B: Float[Array, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mseq_length d_state\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[1;32m 6\u001b[0m C: Float[Array, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mseq_length d_state\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[1;32m 7\u001b[0m D: Float[Array, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m d_inner\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[1;32m 8\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Float[Array, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mseq_length d_inner\u001b[39m\u001b[38;5;124m\"\u001b[39m]:\n\u001b[1;32m 9\u001b[0m L, d_inner \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mshape\n\u001b[1;32m 10\u001b[0m _, d_state \u001b[38;5;241m=\u001b[39m A\u001b[38;5;241m.\u001b[39mshape\n", + "\u001b[0;31mNameError\u001b[0m: name 'Float' is not defined" + ] + } + ], + "source": [ + "def selective_scan(\n", + " x: Float[Array, \"seq_length d_inner\"],\n", + " delta: Float[Array, \"seq_length d_inner\"],\n", + " A: Float[Array, \"d_inner d_state\"],\n", + " B: Float[Array, \"seq_length d_state\"],\n", + " C: Float[Array, \"seq_length d_state\"],\n", + " D: Float[Array, \" d_inner\"],\n", + ") -> Float[Array, \"seq_length d_inner\"]:\n", + " L, d_inner = x.shape\n", + " _, d_state = A.shape\n", + " delta_A = jnp.exp(jnp.einsum(\"l d,d n -> l d n\", delta, A))\n", + " delta_B_u = jnp.einsum(\"l d,l n,l d -> l d n\", delta, B, x)\n", "\n", - "
\n", - " \n", - "
\n", + " x_res = jnp.zeros(shape=(d_inner, d_state))\n", + "\n", + " def step(x, i):\n", + " x = delta_A[i] * x + delta_B_u[i]\n", + "\n", + " y = jnp.einsum(\"d n,n -> d\", x, C[i, :])\n", + " return x, y\n", + "\n", + " _, ys = jax.lax.scan(step, x_res, jnp.arange(L))\n", "\n", - "As you can see, we keep diving further into the model. Let's implement this `ResidualBlock` now." + " ys = ys + x * D\n", + " return ys" ] }, { "cell_type": "code", - "execution_count": 11, - "id": "d1ed52a9-abea-43e8-822a-9d75ac8ae480", + "execution_count": null, + "id": "d0856563-ef82-4deb-91cf-f34f6c8793bb", "metadata": {}, "outputs": [], "source": [ - "class ResidualBlock(eqx.Module, strict=True):\n", - " mamba_block: MambaBlock\n", - " rns_norm: eqx.nn.RMSNorm\n", + "class SelectiveStateSpaceModel(eqx.Module, strict=True):\n", + " input_proj: eqx.nn.Linear\n", + " delta_proj: eqx.nn.Linear\n", + " A_log: Float[Array, \"d_inner d_state\"]\n", + " D: Float[Array, \" d_inner\"]\n", "\n", - " def __init__(self, n_embd: int, *, key: PRNGKeyArray):\n", - " self.mamba_block = MambaBlock(\n", - " key=key,\n", + " d_inner: int = eqx.field(static=True)\n", + " dt_rank: int = eqx.field(static=True)\n", + " d_state: int = eqx.field(static=True)\n", + "\n", + " def __init__(\n", + " self,\n", + " d_inner: int,\n", + " dt_rank: int,\n", + " d_state: int,\n", + " use_input_proj_bias: bool = False,\n", + " use_delta_proj_bias: bool = False,\n", + " *,\n", + " key: PRNGKeyArray,\n", + " ):\n", + " self.d_inner = d_inner\n", + " self.dt_rank = dt_rank\n", + " self.d_state = d_state\n", + " (\n", + " key,\n", + " input_proj_key,\n", + " delta_proj_key,\n", + " ) = jax.random.split(key, 3)\n", + " self.input_proj = eqx.nn.Linear(\n", + " d_inner,\n", + " dt_rank + d_state * 2,\n", + " use_bias=use_input_proj_bias,\n", + " key=input_proj_key,\n", " )\n", - " self.rns_norm = eqx.nn.RMSNorm(n_embd)\n", "\n", - " def __call__(\n", - " self, x: Float[Array, \"seq_len n_embd\"], *, key: Optional[PRNGKeyArray] = None\n", - " ) -> Array:\n", - " return self.mamba_block(jax.vmap(self.rns_norm)(x)) + x" + " self.delta_proj = eqx.nn.Linear(\n", + " dt_rank, d_inner, use_bias=use_delta_proj_bias, key=delta_proj_key\n", + " )\n", + " A = jnp.repeat(jnp.arange(1, d_state + 1), d_inner).reshape(d_inner, d_state)\n", + " self.A_log = jnp.log(A)\n", + " self.D = jnp.ones(d_inner)\n", + "\n", + " def __call__(self, x: Float[Array, \"seq_length d_inner\"]):\n", + " A = -jnp.exp(self.A_log)\n", + " D = self.D\n", + "\n", + " delta_b_c = jax.vmap(self.input_proj)(x)\n", + "\n", + " split_indices = [\n", + " self.dt_rank,\n", + " self.dt_rank + self.d_state,\n", + " ]\n", + " delta, B, C = jnp.split(delta_b_c, split_indices, axis=-1)\n", + " delta = jax.nn.softplus(jax.vmap(self.delta_proj)(delta))\n", + "\n", + " y = selective_scan(x, delta, A, B, C, D)\n", + " return y" ] }, { "cell_type": "markdown", - "id": "d8562a8b-3bea-4997-8c53-be33d66893f1", + "id": "eb44075c-e70c-4d93-8f6b-1b160cc7e7dd", "metadata": {}, "source": [ - "We're getting closer and closer to the heart of the Mamba model. Let's look at what the `MambaBlock` looks like. This time, I've included the shapes of the matrices as they traverse through all kinds of transformations. \n", - "\n", - "
\n", - " \n", - "
" + "## Detour: State Space Models\n", + "___TODO___: Explain SSMs in general!" + ] + }, + { + "cell_type": "markdown", + "id": "2b48addc-b78f-429b-9299-d1e283f3bd76", + "metadata": {}, + "source": [ + "Armed with the `SSM`, we can now implement the `MambaBlock` part. See the images above for where we are right now!" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "6b57d346-18d0-4d29-b35d-e1a6e99c6a75", "metadata": {}, "outputs": [], @@ -207,7 +243,7 @@ "class MambaBlock(eqx.Module):\n", " in_proj: eqx.nn.Linear\n", " conv1d: eqx.nn.Conv1d\n", - " ssm: SSM\n", + " ssm: SelectiveStateSpaceModel\n", " out_proj: eqx.nn.Linear\n", "\n", " def __init__(\n", @@ -219,6 +255,8 @@ " use_in_projection_bias: bool = True,\n", " use_conv_bias: bool = True,\n", " use_out_proj_bias: bool = True,\n", + " ssm_use_delta_proj_bias: bool = False,\n", + " ssm_use_input_proj_bias: bool = False,\n", " *,\n", " key: PRNGKeyArray,\n", " ):\n", @@ -246,7 +284,14 @@ " padding=d_conv - 1,\n", " key=conv1d_key,\n", " )\n", - " self.ssm = SSM(key=ssm_key)\n", + " self.ssm = SelectiveStateSpaceModel(\n", + " d_inner=d_inner,\n", + " dt_rank=dt_rank,\n", + " d_state=d_inner,\n", + " use_delta_proj_bias=ssm_use_delta_proj_bias,\n", + " use_input_proj_bias=ssm_use_input_proj_bias,\n", + " key=ssm_key,\n", + " )\n", " self.out_proj = eqx.nn.Linear(\n", " d_inner,\n", " n_embd,\n", @@ -271,21 +316,195 @@ " return output" ] }, + { + "cell_type": "markdown", + "id": "e8a7543f-5bc9-480e-bab4-7f839cc8faad", + "metadata": {}, + "source": [ + "Now, we can wrap the `MambaBlock` into the `ResidualBlock` -- as the name suggests, this has a residual connection (or in non-_sciency_ words: it adds the original input to the transformation)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d1ed52a9-abea-43e8-822a-9d75ac8ae480", + "metadata": {}, + "outputs": [], + "source": [ + "class ResidualBlock(eqx.Module, strict=True):\n", + " mamba_block: MambaBlock\n", + " rns_norm: eqx.nn.RMSNorm\n", + "\n", + " def __init__(\n", + " self,\n", + " n_embd: int,\n", + " d_inner: int,\n", + " dt_rank: int,\n", + " d_conv: int,\n", + " use_in_projection_bias: bool = True,\n", + " use_conv_bias: bool = True,\n", + " use_out_proj_bias: bool = True,\n", + " ssm_use_delta_proj_bias: bool = False,\n", + " ssm_use_input_proj_bias: bool = False,\n", + " *,\n", + " key: PRNGKeyArray,\n", + " ):\n", + " self.mamba_block = MambaBlock(\n", + " n_embd=n_embd,\n", + " d_inner=d_inner,\n", + " dt_rank=dt_rank,\n", + " d_conv=d_conv,\n", + " use_in_projection_bias=use_in_projection_bias,\n", + " use_conv_bias=use_conv_bias,\n", + " use_out_proj_bias=use_out_proj_bias,\n", + " ssm_use_delta_proj_bias=ssm_use_delta_proj_bias,\n", + " ssm_use_input_proj_bias=ssm_use_input_proj_bias,\n", + " key=key,\n", + " )\n", + " self.rns_norm = eqx.nn.RMSNorm(n_embd)\n", + "\n", + " def __call__(\n", + " self, x: Float[Array, \"seq_len n_embd\"], *, key: Optional[PRNGKeyArray] = None\n", + " ) -> Array:\n", + " return self.mamba_block(jax.vmap(self.rns_norm)(x)) + x" + ] + }, + { + "cell_type": "markdown", + "id": "b9a8ed93-4946-4758-a228-cc39326b8c9d", + "metadata": {}, + "source": [ + "We've arrived at the highest point again. We can put everything into the `Mamba` class now. Note that the weights of the embedding layer and the final linear layer are shared! This is not a problem though, because we can use `eqx.nn.Shared` to implement this." 
+ ] + }, { "cell_type": "code", - "execution_count": 9, - "id": "e84076b2-9808-4916-8580-d789cdc49c05", + "execution_count": null, + "id": "91463596-114f-45f7-a2f0-b1ae8d16f25f", "metadata": {}, "outputs": [], "source": [ - "class SSM(eqx.Module):\n", - " pass" + "class Mamba(eqx.Module, strict=True):\n", + " layers: eqx.nn.Sequential\n", + " normalization: eqx.nn.RMSNorm\n", + " shared_emb_lm_head: eqx.nn.Shared\n", + "\n", + " def __init__(\n", + " self,\n", + " n_layers: int,\n", + " n_dims: int,\n", + " n_embd: int,\n", + " d_inner: int,\n", + " dt_rank: int,\n", + " d_conv: int,\n", + " use_in_projection_bias: bool = True,\n", + " use_conv_bias: bool = True,\n", + " use_out_proj_bias: bool = True,\n", + " ssm_use_delta_proj_bias: bool = False,\n", + " ssm_use_input_proj_bias: bool = False,\n", + " *,\n", + " key: PRNGKeyArray,\n", + " ):\n", + " key, *subkeys = jax.random.split(key, 1 + n_layers)\n", + " self.layers = eqx.nn.Sequential(\n", + " [\n", + " ResidualBlock(\n", + " n_embd=n_embd,\n", + " d_inner=d_inner,\n", + " dt_rank=dt_rank,\n", + " d_conv=d_conv,\n", + " use_in_projection_bias=use_in_projection_bias,\n", + " use_conv_bias=use_conv_bias,\n", + " use_out_proj_bias=use_out_proj_bias,\n", + " ssm_use_delta_proj_bias=ssm_use_delta_proj_bias,\n", + " ssm_use_input_proj_bias=ssm_use_input_proj_bias,\n", + " key=subkeys[i + 1],\n", + " )\n", + " for i in range(n_layers)\n", + " ],\n", + " )\n", + " self.normalization = eqx.nn.RMSNorm(n_embd)\n", + "\n", + " embedding = eqx.nn.Embedding(n_dims, n_embd, key=subkeys[0])\n", + " lm_head = eqx.nn.Linear(\n", + " n_embd,\n", + " n_dims,\n", + " use_bias=False,\n", + " key=subkeys[-1],\n", + " )\n", + " where = lambda embed_and_lin: embed_and_lin[1].weight\n", + " get = lambda embed_and_lin: embed_and_lin[0].weight\n", + " self.shared_emb_lm_head = eqx.nn.Shared(\n", + " (embedding, lm_head), where=where, get=get\n", + " )\n", + "\n", + " def __call__(\n", + " self,\n", + " x: Int[Array, \"seq_len\"], # noqa\n", + " *,\n", + " key: Optional[PRNGKeyArray] = None,\n", + " ) -> Float[Array, \"seq_len n_dims\"]: # noqa\n", + " embedding, linear = self.shared_emb_lm_head()\n", + " x = jax.vmap(embedding)(x)\n", + "\n", + " x = self.layers(x)\n", + " x = jax.vmap(self.normalization)(x)\n", + " logits = jax.vmap(linear)(x)\n", + " return logits" + ] + }, + { + "cell_type": "markdown", + "id": "d557fa4a-1fa4-4e72-ba77-d0ef7f946cdf", + "metadata": {}, + "source": [ + "Note the usage of `eqx.nn.Shared`:\n", + "\n", + "```python\n", + " # Embedding layer\n", + " embedding = eqx.nn.Embedding(\n", + " n_dims, n_embd, key=subkeys[0]\n", + " )\n", + " # Linear layer\n", + " lm_head = eqx.nn.Linear(\n", + " n_embd,\n", + " n_dims,\n", + " use_bias=False,\n", + " key=subkeys[-1],\n", + " )\n", + " # refers to the linear weights\n", + " where = lambda embed_and_lin: embed_and_lin[1].weight \n", + "\n", + " # refers to the embedding weights\n", + " get = lambda embed_and_lin: embed_and_lin[0].weight\n", + "\n", + " # Create a shared layer\n", + " self.shared_emb_lm_head = eqx.nn.Shared(\n", + " (embedding, lm_head), where=where, get=get\n", + " )\n", + "```\n", + "\n", + "And to use the shared layers, we have to get them first out of the shared layer:\n", + "\n", + "```python\n", + " embedding, linear = self.shared_emb_lm_head()\n", + " # embedding and linear are eqx.nn.Embedding and eqx.nn.Linear respectively\n", + " # proceed usage as usual\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "af07491e-70ea-43b3-9139-0940caf1dd50", + 
"metadata": {}, + "source": [ + "Excellent! We have successfully implemented the Mamba model!" ] }, { "cell_type": "code", "execution_count": null, - "id": "1196448d-c659-44d2-9527-6b29723b2c59", + "id": "e96aabdf-081a-42b7-8784-2e1809f7562a", "metadata": {}, "outputs": [], "source": [] diff --git a/imgs/Mamba4.drawio.svg b/imgs/Mamba4.drawio.svg new file mode 100644 index 00000000..197638e3 --- /dev/null +++ b/imgs/Mamba4.drawio.svg @@ -0,0 +1,4 @@ + + + +
[Figure Mamba4.drawio.svg: x (seq_len, d_inner) -> input_projection (d_inner -> dt_rank + d_state * 2) -> Split into delta (seq_len, dt_rank), B (seq_len, d_state), C (seq_len, d_state); delta -> delta_projection (dt_rank -> d_inner) -> Softplus -> delta (seq_len, d_inner); A (d_inner, d_state) and D (d_inner,) are trainable parameters; selective_scan -> x' (seq_len, d_inner)]
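Editor's note: the `selective_scan` that the figure above bottoms out in is just a linear recurrence. A numeric sketch of what it computes, mirroring the notebook's implementation with toy shapes (none of these values come from the patch):

```python
import jax
import jax.numpy as jnp

L, d_inner, d_state = 5, 2, 3  # toy sizes
ku, kb, kc = jax.random.split(jax.random.PRNGKey(0), 3)

u = jax.random.normal(ku, (L, d_inner))  # the input sequence
delta = jnp.ones((L, d_inner)) * 0.1  # positive step sizes (softplus output in the model)
A = -jnp.ones((d_inner, d_state))  # negative, so exp(delta * A) < 1 (stable)
B = jax.random.normal(kb, (L, d_state))
C = jax.random.normal(kc, (L, d_state))
D = jnp.ones(d_inner)

# Discretise: A_bar_t = exp(delta_t * A) and B_bar_t u_t = delta_t * B_t * u_t.
delta_A = jnp.exp(jnp.einsum("l d,d n -> l d n", delta, A))
delta_B_u = jnp.einsum("l d,l n,l d -> l d n", delta, B, u)

def step(x, i):
    x = delta_A[i] * x + delta_B_u[i]  # x_t = A_bar_t * x_{t-1} + B_bar_t u_t
    y = jnp.einsum("d n,n -> d", x, C[i])  # y_t = C_t x_t
    return x, y

_, ys = jax.lax.scan(step, jnp.zeros((d_inner, d_state)), jnp.arange(L))
ys = ys + u * D  # skip connection through D
assert ys.shape == (L, d_inner)
```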
\ No newline at end of file From 030bf7733adbbd2df48e6d1a8700c30f6f3f7963 Mon Sep 17 00:00:00 2001 From: "Artur A. Galstyan" Date: Mon, 4 Mar 2024 22:44:50 +0100 Subject: [PATCH 09/10] subkey wrong count --- examples/mamba.ipynb | 34 +++++++++++----------------------- 1 file changed, 11 insertions(+), 23 deletions(-) diff --git a/examples/mamba.ipynb b/examples/mamba.ipynb index 58a04a0c..bbc33de0 100644 --- a/examples/mamba.ipynb +++ b/examples/mamba.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "ab229dca-2c2a-46ee-8f24-eedce5c06e18", "metadata": {}, "outputs": [], @@ -106,22 +106,10 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "id": "9efd5640-cdb6-427d-be3d-860cf371e357", "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'Float' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[1], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mselective_scan\u001b[39m(\n\u001b[0;32m----> 2\u001b[0m x: \u001b[43mFloat\u001b[49m[Array, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mseq_length d_inner\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[1;32m 3\u001b[0m delta: Float[Array, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mseq_length d_inner\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[1;32m 4\u001b[0m A: Float[Array, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124md_inner d_state\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[1;32m 5\u001b[0m B: Float[Array, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mseq_length d_state\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[1;32m 6\u001b[0m C: Float[Array, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mseq_length d_state\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[1;32m 7\u001b[0m D: Float[Array, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m d_inner\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[1;32m 8\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Float[Array, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mseq_length d_inner\u001b[39m\u001b[38;5;124m\"\u001b[39m]:\n\u001b[1;32m 9\u001b[0m L, d_inner \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mshape\n\u001b[1;32m 10\u001b[0m _, d_state \u001b[38;5;241m=\u001b[39m A\u001b[38;5;241m.\u001b[39mshape\n", - "\u001b[0;31mNameError\u001b[0m: name 'Float' is not defined" - ] - } - ], + "outputs": [], "source": [ "def selective_scan(\n", " x: Float[Array, \"seq_length d_inner\"],\n", @@ -152,7 +140,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "d0856563-ef82-4deb-91cf-f34f6c8793bb", "metadata": {}, "outputs": [], @@ -235,7 +223,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "6b57d346-18d0-4d29-b35d-e1a6e99c6a75", "metadata": {}, "outputs": [], @@ -326,7 +314,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "d1ed52a9-abea-43e8-822a-9d75ac8ae480", "metadata": {}, "outputs": [], @@ -405,7 +393,7 @@ " *,\n", " key: PRNGKeyArray,\n", " ):\n", - " key, *subkeys = jax.random.split(key, 1 + n_layers)\n", + " key, *subkeys = jax.random.split(key, 3 + n_layers)\n", " self.layers = eqx.nn.Sequential(\n", " [\n", " ResidualBlock(\n", @@ -418,19 +406,19 @@ " use_out_proj_bias=use_out_proj_bias,\n", " 
ssm_use_delta_proj_bias=ssm_use_delta_proj_bias,\n", " ssm_use_input_proj_bias=ssm_use_input_proj_bias,\n", - " key=subkeys[i + 1],\n", + " key=subkeys[i],\n", " )\n", " for i in range(n_layers)\n", " ],\n", " )\n", " self.normalization = eqx.nn.RMSNorm(n_embd)\n", "\n", - " embedding = eqx.nn.Embedding(n_dims, n_embd, key=subkeys[0])\n", + " embedding = eqx.nn.Embedding(n_dims, n_embd, key=subkeys[n_layers])\n", " lm_head = eqx.nn.Linear(\n", " n_embd,\n", " n_dims,\n", " use_bias=False,\n", - " key=subkeys[-1],\n", + " key=subkeys[n_layers + 1],\n", " )\n", " where = lambda embed_and_lin: embed_and_lin[1].weight\n", " get = lambda embed_and_lin: embed_and_lin[0].weight\n", @@ -463,7 +451,7 @@ "```python\n", " # Embedding layer\n", " embedding = eqx.nn.Embedding(\n", - " n_dims, n_embd, key=subkeys[0]\n", + " n_dims, n_embd, key=subkeys[n_layers]\n", " )\n", " # Linear layer\n", " lm_head = eqx.nn.Linear(\n", From 3504d77d30bfe07530befa0c4e432acacf1a7a53 Mon Sep 17 00:00:00 2001 From: Artur Galstyan Date: Sat, 23 Mar 2024 16:31:04 +0100 Subject: [PATCH 10/10] added docs and more examples --- docs/api/nn/state_spaces.md | 7 + equinox/nn/__init__.py | 3 + equinox/nn/_selective_state_space_models.py | 62 +++- examples/mamba.ipynb | 376 ++++++++++++++++++-- examples/score_based_diffusion.ipynb | 6 +- mkdocs.yml | 260 +++++++------- 6 files changed, 532 insertions(+), 182 deletions(-) create mode 100644 docs/api/nn/state_spaces.md diff --git a/docs/api/nn/state_spaces.md b/docs/api/nn/state_spaces.md new file mode 100644 index 00000000..d256c98a --- /dev/null +++ b/docs/api/nn/state_spaces.md @@ -0,0 +1,7 @@ +# State Spaces + +::: equinox.nn.SelectiveStateSpaceModel + selection: + members: + - __init__ + - __call__ diff --git a/equinox/nn/__init__.py b/equinox/nn/__init__.py index 8ab04a0d..9e940510 100644 --- a/equinox/nn/__init__.py +++ b/equinox/nn/__init__.py @@ -38,6 +38,9 @@ Pool as Pool, ) from ._rnn import GRUCell as GRUCell, LSTMCell as LSTMCell +from ._selective_state_space_models import ( + SelectiveStateSpaceModel as SelectiveStateSpaceModel, +) from ._sequential import ( Lambda as Lambda, Sequential as Sequential, diff --git a/equinox/nn/_selective_state_space_models.py b/equinox/nn/_selective_state_space_models.py index 2c1b74ec..1055a6d0 100644 --- a/equinox/nn/_selective_state_space_models.py +++ b/equinox/nn/_selective_state_space_models.py @@ -16,7 +16,7 @@ def _selective_scan( A: Float[Array, "d_inner state_space_dims"], B: Float[Array, "seq_len state_space_dims"], C: Float[Array, "seq_len state_space_dims"], - D: Float[Array, "d_inner"], # noqa + D: Float[Array, " d_inner"], ): seq_len, _ = u.shape d_inner, state_space_dims = A.shape @@ -39,13 +39,25 @@ def step(x, i): class SelectiveStateSpaceModel(Module, strict=True): - """ + r""" State Space Model with Selective Scan. This is the implementation of the Mamba Block from the paper "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" [1]. - [1] Albert Gu and Tri Dao, Mamba: Linear-Time Sequence Modeling - with Selective State Spaces, 2023 + + ??? 
cite + [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) + ```bibtex + @misc{ + gu2023mamba, + title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces}, + author={Albert Gu and Tri Dao}, + year={2023}, + eprint={2312.00752}, + archivePrefix={arXiv}, + primaryClass={cs.LG} + } + ``` """ n_input_dims: int = field(static=True) @@ -83,19 +95,19 @@ def __init__( *, key: PRNGKeyArray, ): - """ - Args: - n_input_dims: The dimension of the input. - state_space_dims: The dimension of the SSM (refers to 'N' in [1]). - expand: The expansion factor of the inner dimension (refers to 'E' in [1]). - d_conv: The kernel size of the convolutional layer - dt_rank: The rank of delta. If "auto", it will be - set to ceil(n_input_dims / state_space_dims). - pad_vocab_size_multiple: The multiple of the vocabulary size - use_bias_in_proj: Whether to use bias in the input projection layer. - use_bias_conv1d: Whether to use bias in the convolutional layer. - use_bias_out_proj: Whether to use bias in the output projection layer. - key: The PRNG key. + r"""**Arguments:** + + - `n_input_dims`: The dimension of the input. + - `state_space_dims`: The dimension of the SSM (refers to $N$ in [1]). + - `expand`: The expansion factor of the inner dimension (refers to $E$ in [1]). + - `d_conv`: The kernel size of the convolutional layer + - `dt_rank`: The rank of delta. If "auto", it will be set to + ceil(n_input_dims / state_space_dims). + - `pad_vocab_size_multiple`: The multiple of the vocabulary size + - `use_bias_in_proj`: Whether to use bias in the input projection layer. + - `use_bias_conv1d`: Whether to use bias in the convolutional layer. + - `use_bias_out_proj`: Whether to use bias in the output projection layer. + - `key`: The PRNG key. """ self.n_input_dims = n_input_dims @@ -148,8 +160,10 @@ def __init__( self.dt_rank, self.d_inner, use_bias=True, key=dt_proj_key ) - A = jnp.repeat(jnp.arange(1, self.state_space_dims + 1), self.d_inner).reshape( - self.d_inner, self.state_space_dims + A = ( + jnp.repeat(jnp.arange(1, self.state_space_dims + 1), self.d_inner) + .reshape(self.state_space_dims, self.d_inner) + .transpose() ) self.A_log = jnp.log(A) self.D = jnp.ones(self.d_inner) @@ -162,6 +176,16 @@ def __init__( @jax.named_scope("eqx.nn.SelectiveStateSpaceModel") def __call__(self, x: Float[Array, "seq_len n_input_dims"]) -> Array: + r"""**Arguments:** + + - `x`: The input sequence. Should be a JAX array of + shape `(seq_len, n_input_dims)`. + + **Returns:** + + - A JAX array of shape `(seq_len, n_input_dims)`. + + """ seq_len, d = x.shape if d != self.n_input_dims: raise ValueError( diff --git a/examples/mamba.ipynb b/examples/mamba.ipynb index bbc33de0..737ede4b 100644 --- a/examples/mamba.ipynb +++ b/examples/mamba.ipynb @@ -7,12 +7,12 @@ "metadata": {}, "outputs": [], "source": [ - "from typing import Optional\n", + "from typing import Any, Optional\n", "\n", "import equinox as eqx\n", "import jax\n", "import jax.numpy as jnp\n", - "from jaxtyping import Array, Float, Int, PRNGKeyArray" + "from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree" ] }, { @@ -31,6 +31,14 @@ "\n", "Special thanks and cretits go to John (Zhiyao) Ma and his excellent Mamba implementation in PyTorch, which served as a great inspriration and foundation for this Equinox version. 
Go check it out [here](https://github.com/johnma2006/mamba-minimal).\n", "\n", + "Author: [Artur Galstyan](https://github.com/artur-galstyan)" + ] + }, + { + "cell_type": "markdown", + "id": "9f1861d8-d012-46d0-9a72-23d940674d2d", + "metadata": {}, + "source": [ "The original implementation includes **a lot** of CUDA code [[2]](https://github.com/state-spaces/mamba) to optimise the so-called `selective_scan` algorithm, but this first iteration of the `SelectiveStateSpaceModel` implementation is not as heavily optimised. However, in future iterations, by using some clever Pallas code, we can get to the same performance. " ] }, @@ -47,9 +55,7 @@ "id": "660315df-3e77-4e3c-a793-c153840f094c", "metadata": {}, "source": [ - "
\n", - " \n", - "
" + "![Mamba](../imgs/Mamba1.drawio.svg)" ] }, { @@ -65,9 +71,7 @@ "id": "6b4b5de1-8300-4fc4-934e-3b0194bf8372", "metadata": {}, "source": [ - "
\n", - " \n", - "
\n", + "![Mamba](../imgs/Mamba2.drawio.svg)\n", "\n", "As you can see, we keep diving further into the model. Let's implement this `ResidualBlock` now. Let's keep on zooming until we get to the deepest component - at which point we can start to implement everything and work our way back up. Let's keep going." ] @@ -79,9 +83,7 @@ "source": [ "We're getting closer and closer to the heart of the Mamba model. Let's look at what the `MambaBlock` looks like. This time, I've included the shapes of the matrices as they traverse through all kinds of transformations. \n", "\n", - "
\n", - " \n", - "
" + "![Mamba](../imgs/Mamba3.drawio.svg)" ] }, { @@ -91,9 +93,7 @@ "source": [ "Most of the parts we need are already present in Equinox's library. What's missing though is the new `SelectiveStateSpaceModel` (abbreviated as `SSM` above). Everything in green are trainable parameters. \n", "\n", - "
\n", - " \n", - "
" + "![Mamba](../imgs/Mamba4.drawio.svg)" ] }, { @@ -183,7 +183,11 @@ " self.delta_proj = eqx.nn.Linear(\n", " dt_rank, d_inner, use_bias=use_delta_proj_bias, key=delta_proj_key\n", " )\n", - " A = jnp.repeat(jnp.arange(1, d_state + 1), d_inner).reshape(d_inner, d_state)\n", + " A = (\n", + " jnp.repeat(jnp.arange(1, d_state + 1), d_inner)\n", + " .reshape(d_state, d_inner)\n", + " .transpose()\n", + " )\n", " self.A_log = jnp.log(A)\n", " self.D = jnp.ones(d_inner)\n", "\n", @@ -204,21 +208,12 @@ " return y" ] }, - { - "cell_type": "markdown", - "id": "eb44075c-e70c-4d93-8f6b-1b160cc7e7dd", - "metadata": {}, - "source": [ - "## Detour: State Space Models\n", - "___TODO___: Explain SSMs in general!" - ] - }, { "cell_type": "markdown", "id": "2b48addc-b78f-429b-9299-d1e283f3bd76", "metadata": {}, "source": [ - "Armed with the `SSM`, we can now implement the `MambaBlock` part. See the images above for where we are right now!" + "Armed with the `SSM`, we can now implement the `MambaBlock` part. See the images above for where we are right now! Thankfully, Equinox also includes the implementation above, which we can access from `eqx.nn.SelectiveStateSpaceModel`. " ] }, { @@ -367,7 +362,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "91463596-114f-45f7-a2f0-b1ae8d16f25f", "metadata": {}, "outputs": [], @@ -486,16 +481,335 @@ "id": "af07491e-70ea-43b3-9139-0940caf1dd50", "metadata": {}, "source": [ - "Excellent! We have successfully implemented the Mamba model!" + "Excellent! We have successfully implemented the Mamba model! From here we can train the model for example on the [TinyShakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) dataset. We can use [Jaxonloader](https://github.com/Artur-Galstyan/jaxonloader), which provides the necessary preprocessing steps for us, so we can start training! Let's install it." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "24888c56-1c00-48f4-9e83-c7eaf5ce7b94", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[2mAudited \u001b[1m2 packages\u001b[0m in 16ms\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/arturgalstyan/.pyenv/versions/3.11.8/lib/python3.11/pty.py:89: RuntimeWarning: os.fork() was called. 
os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " pid, fd = os.forkpty()\n" + ] + } + ], + "source": [ + "!uv pip install jaxonloader optax\n", + "# or simply\n", + "# !pip install jaxonloader optax\n", + "# if you don't use uv" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "6edf447f-99cb-4335-829b-bfe293296b57", + "metadata": {}, + "outputs": [], + "source": [ + "import functools as ft\n", + "import math\n", + "from collections.abc import Callable\n", + "\n", + "import optax\n", + "from jaxonloader import (\n", + " get_tiny_shakespeare,\n", + " Index,\n", + " JaxonDataLoader,\n", + " JITJaxonDataLoader,\n", + " make,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "d760872d-8f4d-4522-90d4-591e88e4b563", + "metadata": {}, + "outputs": [], + "source": [ + "train_dataset, test_dataset, vocab_size, encode, decode = get_tiny_shakespeare()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "2de3d0a4-601e-4ee6-a314-ee4c3e66a0f2", + "metadata": {}, + "outputs": [], + "source": [ + "def train(\n", + " train_dataloader: JaxonDataLoader | JITJaxonDataLoader,\n", + " train_index: Index,\n", + " learning_rate: float,\n", + " model: PyTree,\n", + " key: PRNGKeyArray,\n", + " early_stop: int | None = None,\n", + " log_every: Optional[int] = 100,\n", + ") -> PyTree:\n", + " optimizer = optax.adamw(learning_rate=learning_rate)\n", + " opt_state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))\n", + " loss_value = 0\n", + " i = 0\n", + " while it := train_dataloader(train_index):\n", + " x, train_index, done = it\n", + " if done:\n", + " break\n", + " x, y = jnp.split(x, 2, axis=1)\n", + " key, subkey = jax.random.split(key)\n", + " model, opt_state, loss_value = step(\n", + " model, opt_state, x, y, optimizer, key=subkey\n", + " )\n", + " if log_every is not None and i % log_every == 0:\n", + " print(f\"Loss: {loss_value}\")\n", + " if early_stop is not None and i > early_stop:\n", + " break\n", + "\n", + " i += 1\n", + " print(\"Finished training\")\n", + " print(f\"Final loss: {loss_value}\")\n", + " return model\n", + "\n", + "\n", + "def loss_fn(\n", + " model: PyTree,\n", + " x: Int[Array, \"batch_size max_seq_len n_dims\"],\n", + " labels: Int[Array, \"batch_size max_seq_len n_dims\"],\n", + " key: Optional[PRNGKeyArray],\n", + ") -> Array:\n", + " partial_model = ft.partial(model, key=key)\n", + " logits = eqx.filter_vmap(partial_model)(x)\n", + " return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits, labels))\n", + "\n", + "\n", + "@eqx.filter_jit\n", + "def step(\n", + " model: PyTree,\n", + " opt_state: PyTree,\n", + " x: Array,\n", + " y: Array,\n", + " optimizer: optax.GradientTransformation,\n", + " key: PRNGKeyArray,\n", + ") -> tuple[PyTree, PyTree, Any]:\n", + " loss, grads = eqx.filter_value_and_grad(loss_fn)(model, x, y, key)\n", + " updates, opt_state = optimizer.update(grads, opt_state, model)\n", + " model = eqx.apply_updates(model, updates)\n", + "\n", + " return model, opt_state, loss" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "ae127d75-45d1-47b1-899c-6190632faf9d", + "metadata": {}, + "outputs": [], + "source": [ + "train_dataloader, index = make(train_dataset, batch_size=64, jit=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "992acff9-5eb4-4aa3-ba2b-8b13bddd4f49", + "metadata": {}, + "outputs": [], + "source": [ + "def train_mamba(\n", + " 
train_dataloader,\n", + " train_index,\n", + " n_dims: int,\n", + " n_embd: int,\n", + " expand: int,\n", + " d_state: int,\n", + " n_layers: int,\n", + " d_conv: int,\n", + " learning_rate: float,\n", + " early_stop: int,\n", + " key: PRNGKeyArray,\n", + "):\n", + " mamba = Mamba(\n", + " n_layers=n_layers,\n", + " n_dims=n_dims,\n", + " n_embd=n_embd,\n", + " d_inner=int(expand * n_embd),\n", + " dt_rank=math.ceil(n_embd / d_state),\n", + " d_conv=d_conv,\n", + " key=key,\n", + " )\n", + " key, subkey = jax.random.split(key)\n", + " mamba = train(\n", + " train_dataloader,\n", + " train_index,\n", + " learning_rate,\n", + " mamba,\n", + " subkey,\n", + " early_stop=early_stop,\n", + " )\n", + " return mamba" ] }, { "cell_type": "code", - "execution_count": null, - "id": "e96aabdf-081a-42b7-8784-2e1809f7562a", + "execution_count": 13, + "id": "88679246-7eff-491b-a115-64be758318f8", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "n_dims = vocab_size\n", + "n_embd = 64\n", + "learning_rate = 3e-4\n", + "num_heads = 4\n", + "n_layers = 3\n", + "d_state = 16\n", + "d_conv = 4\n", + "expand = 2\n", + "early_stop = 1000\n", + "key = jax.random.PRNGKey(222)" + ] + }, + { + "cell_type": "markdown", + "id": "445edf0c-e48b-477b-a890-48ac6118fe85", + "metadata": {}, + "source": [ + "You probably won't get very far using those parameters above and you'll have to increase these numbers to achieve greater performance. \n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "592458f1-9680-4cb5-9173-3bf5452d58bb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loss: 59.77830123901367\n", + "Loss: 4.061922073364258\n", + "Loss: 2.9172539710998535\n", + "Loss: 2.7635631561279297\n", + "Loss: 2.735884189605713\n", + "Loss: 2.3609986305236816\n", + "Loss: 2.490962505340576\n", + "Loss: 2.332058906555176\n", + "Loss: 2.2114100456237793\n", + "Loss: 2.3747291564941406\n", + "Loss: 2.366851329803467\n", + "Finished training\n", + "Final loss: 2.29056978225708\n" + ] + } + ], + "source": [ + "mamba = train_mamba(\n", + " train_dataloader,\n", + " index,\n", + " n_dims,\n", + " n_embd,\n", + " expand,\n", + " d_state,\n", + " n_layers,\n", + " d_conv,\n", + " learning_rate,\n", + " early_stop,\n", + " key\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "eb6a9732-7d0e-4593-8a5d-46a26b3f5e58", + "metadata": {}, + "outputs": [], + "source": [ + "def generate_text(\n", + " model: PyTree,\n", + " max_seq_len: int,\n", + " max_new_tokens: int,\n", + " decode: Callable,\n", + " vocab_size: int,\n", + " print_to_console: bool = True,\n", + " random_key_seed: int = 0,\n", + ") -> tuple[list[str], list[int]]:\n", + " jitted_model = eqx.filter_jit(model)\n", + " x = jnp.zeros((max_seq_len,), dtype=jnp.int32)\n", + " key = jax.random.PRNGKey(random_key_seed)\n", + " tokens = []\n", + " decoded_tokens = []\n", + " for _ in range(max_new_tokens):\n", + " key, subkey, model_key = jax.random.split(key, 3)\n", + " logits = jitted_model(x, key=model_key)\n", + " logits = logits[-1, :]\n", + " probs = jax.nn.softmax(logits, axis=-1)\n", + "\n", + " next_token = jax.random.choice(\n", + " subkey,\n", + " jnp.arange(len(probs)),\n", + " p=probs,\n", + " )\n", + " next_token = jnp.array(next_token, dtype=jnp.int32).reshape((1,))\n", + " next_token = min(next_token.item(), vocab_size - 1)\n", + "\n", + " if print_to_console:\n", + " print(decode([next_token]), end=\"\")\n", + "\n", + " tokens.append(next_token)\n", + " 
decoded_tokens.append(decode([next_token]))\n", + "\n", + " x = jnp.concatenate((x[1:], jnp.array([next_token])))\n", + " return decoded_tokens, tokens" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "a60c579b-d18d-49c8-b87f-2b24daf5d442", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "OfRT:\n", + "Thet buce,\n", + "Whaduti fot se w,\n", + "LI:\n", + "An\n", + "Ber?\n", + "\n", + "NOLIOLONULLEO:\n", + "Os to myse, fFieln torm's\n", + "Durme mence th thates myt hit bere hou ps prore onoth but to come beat inwas notven bestot bucn mend not thats " + ] + } + ], + "source": [ + "text = generate_text(mamba, 8, 200, decode, vocab_size) # noqa" + ] + }, + { + "cell_type": "markdown", + "id": "4db62c07-d466-4fcd-a7d1-031cb300ed78", + "metadata": {}, + "source": [ + "Truly, a magnificient masterpiece. " + ] } ], "metadata": { @@ -514,7 +828,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.6" + "version": "3.11.8" } }, "nbformat": 4, diff --git a/examples/score_based_diffusion.ipynb b/examples/score_based_diffusion.ipynb index 9dc062dd..02b335e7 100644 --- a/examples/score_based_diffusion.ipynb +++ b/examples/score_based_diffusion.ipynb @@ -508,9 +508,9 @@ ], "metadata": { "kernelspec": { - "display_name": "jax0227", + "display_name": "Python 3 (ipykernel)", "language": "python", - "name": "jax0227" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -522,7 +522,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.11.8" } }, "nbformat": 4, diff --git a/mkdocs.yml b/mkdocs.yml index bf5bda76..07a99afd 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,34 +1,34 @@ theme: - name: material - features: - - navigation.sections # Sections are included in the navigation on the left. - - toc.integrate # Table of contents is integrated on the left; does not appear separately on the right. - - header.autohide # header disappears as you scroll - palette: - # Light mode / dark mode - # We deliberately don't automatically use `media` to check a user's preferences. We default to light mode as - # (a) it looks more professional, and (b) is more obvious about the fact that it offers a (dark mode) toggle. - - scheme: default - primary: white - accent: amber - toggle: - icon: material/weather-night - name: Switch to dark mode - - scheme: slate - primary: black - accent: amber - toggle: - icon: material/weather-sunny - name: Switch to light mode - icon: - repo: fontawesome/brands/github # GitHub logo in top right - logo: "material/circle-opacity" # Equinox logo in top left - favicon: "_static/favicon.png" - custom_dir: "docs/_overrides" # Overriding part of the HTML + name: material + features: + - navigation.sections # Sections are included in the navigation on the left. + - toc.integrate # Table of contents is integrated on the left; does not appear separately on the right. + - header.autohide # header disappears as you scroll + palette: + # Light mode / dark mode + # We deliberately don't automatically use `media` to check a user's preferences. We default to light mode as + # (a) it looks more professional, and (b) is more obvious about the fact that it offers a (dark mode) toggle. 
+ - scheme: default + primary: white + accent: amber + toggle: + icon: material/weather-night + name: Switch to dark mode + - scheme: slate + primary: black + accent: amber + toggle: + icon: material/weather-sunny + name: Switch to light mode + icon: + repo: fontawesome/brands/github # GitHub logo in top right + logo: "material/circle-opacity" # Equinox logo in top left + favicon: "_static/favicon.png" + custom_dir: "docs/_overrides" # Overriding part of the HTML - # These additions are my own custom ones, having overridden a partial. - twitter_name: "@PatrickKidger" - twitter_url: "https://twitter.com/PatrickKidger" + # These additions are my own custom ones, having overridden a partial. + twitter_name: "@PatrickKidger" + twitter_url: "https://twitter.com/PatrickKidger" site_name: Equinox site_description: The documentation for the Equinox software library. @@ -37,114 +37,116 @@ site_url: https://docs.kidger.site/equinox repo_url: https://github.com/patrick-kidger/equinox repo_name: patrick-kidger/equinox -edit_uri: "" # No edit button, as some of our pages are in /docs and some in /examples via symlink, so it's impossible for them all to be accurate +edit_uri: "" # No edit button, as some of our pages are in /docs and some in /examples via symlink, so it's impossible for them all to be accurate -strict: true # Don't allow warnings during the build process +strict: true # Don't allow warnings during the build process -extra_javascript: - # The below three make MathJax work, see https://squidfunk.github.io/mkdocs-material/reference/mathjax/ - - _static/mathjax.js - - https://polyfill.io/v3/polyfill.min.js?features=es6 - - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js +extra_javascript: + # The below three make MathJax work, see https://squidfunk.github.io/mkdocs-material/reference/mathjax/ + - _static/mathjax.js + - https://polyfill.io/v3/polyfill.min.js?features=es6 + - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js extra_css: - - _static/custom_css.css + - _static/custom_css.css markdown_extensions: - - pymdownx.arithmatex: # Render LaTeX via MathJax - generic: true - - pymdownx.superfences # Seems to enable syntax highlighting when used with the Material theme. - - pymdownx.details # Allowing hidden expandable regions denoted by ??? - - pymdownx.snippets: # Include one Markdown file into another - base_path: docs - - admonition - - toc: - permalink: "¤" # Adds a clickable permalink to each section heading - toc_depth: 4 + - pymdownx.arithmatex: # Render LaTeX via MathJax + generic: true + - pymdownx.superfences # Seems to enable syntax highlighting when used with the Material theme. + - pymdownx.details # Allowing hidden expandable regions denoted by ??? 
+ - pymdownx.snippets: # Include one Markdown file into another + base_path: docs + - admonition + - toc: + permalink: "¤" # Adds a clickable permalink to each section heading + toc_depth: 4 plugins: - - search # default search plugin; needs manually re-enabling when using any other plugins - - autorefs # Cross-links to headings - - include_exclude_files: - include: - - ".htaccess" - exclude: - - "_overrides" - - "examples/MNIST" - - "examples/bert_checkpoint.eqx" - - mknotebooks # Jupyter notebooks - - mkdocstrings: - handlers: - python: - setup_commands: - - import pytkdocs_tweaks - - pytkdocs_tweaks.main() - - import jaxtyping - - jaxtyping.set_array_name_format("array") - - import jax - - jax.ShapeDtypeStruct.__module__ = "jax" - - jax.core.ClosedJaxpr.__module__ = "jax.core" + - search # default search plugin; needs manually re-enabling when using any other plugins + - autorefs # Cross-links to headings + - include_exclude_files: + include: + - ".htaccess" + exclude: + - "_overrides" + - "examples/MNIST" + - "examples/bert_checkpoint.eqx" + - mknotebooks # Jupyter notebooks + - mkdocstrings: + handlers: + python: + setup_commands: + - import pytkdocs_tweaks + - pytkdocs_tweaks.main() + - import jaxtyping + - jaxtyping.set_array_name_format("array") + - import jax + - jax.ShapeDtypeStruct.__module__ = "jax" + - jax.core.ClosedJaxpr.__module__ = "jax.core" - selection: - inherited_members: true # Allow looking up inherited methods - rendering: - show_root_heading: true # actually display anything at all... - show_root_full_path: true # display "diffrax.asdf" not just "asdf" - show_if_no_docstring: true - show_signature_annotations: true - show_source: false # don't include source code - members_order: source # order methods according to their order of definition in the source code, not alphabetical order - heading_level: 4 + selection: + inherited_members: true # Allow looking up inherited methods + rendering: + show_root_heading: true # actually display anything at all... + show_root_full_path: true # display "diffrax.asdf" not just "asdf" + show_if_no_docstring: true + show_signature_annotations: true + show_source: false # don't include source code + members_order: source # order methods according to their order of definition in the source code, not alphabetical order + heading_level: 4 nav: - - 'index.md' - - 'all-of-equinox.md' - - Examples: - - Introductory: - - CNN on MNIST: 'examples/mnist.ipynb' - - Train RNN: 'examples/train_rnn.ipynb' - - Advanced: - - Generative score-based diffusion: 'examples/score_based_diffusion.ipynb' - - BERT language model: 'examples/bert.ipynb' - - U-Net implementation: 'examples/unet.ipynb' - - Vision transformer: 'examples/vision_transformer.ipynb' - - Image GAN: 'examples/deep_convolutional_gan.ipynb' - - Features: - - Freezing parameters: 'examples/frozen_layer.ipynb' - - Compatibility with init-apply libraries: 'examples/init_apply.ipynb' - - Stateful operations (e.g. BatchNorm): 'examples/stateful.ipynb' - - Autoparallelism (e.g. 
multi-GPU): 'examples/parallelism.ipynb' - - Serialisation (with hyperparameters): 'examples/serialisation.ipynb' - - Basic API: - - Modules: - - 'api/module/module.md' - - 'api/module/advanced_fields.md' - - Neural network layers: - - 'api/nn/linear.md' - - 'api/nn/conv.md' - - 'api/nn/rnn.md' - - 'api/nn/attention.md' - - 'api/nn/activations.md' - - 'api/nn/pool.md' - - 'api/nn/dropout.md' - - 'api/nn/normalisation.md' - - 'api/nn/embedding.md' - - 'api/nn/mlp.md' - - 'api/nn/sequential.md' - - 'api/nn/inference.md' - - 'api/nn/shared.md' - - 'api/nn/stateful.md' - - 'api/transformations.md' - - 'api/manipulation.md' - - Advanced API: - - 'api/caches.md' - - 'api/debug.md' - - 'api/enumerations.md' - - 'api/errors.md' - - 'api/pretty-printing.md' - - 'api/serialisation.md' - - Misc: - - 'faq.md' - - 'tricks.md' - # - 'pattern.md' - - 'citation.md' + - "index.md" + - "all-of-equinox.md" + - Examples: + - Introductory: + - CNN on MNIST: "examples/mnist.ipynb" + - Train RNN: "examples/train_rnn.ipynb" + - Advanced: + - Generative score-based diffusion: "examples/score_based_diffusion.ipynb" + - BERT language model: "examples/bert.ipynb" + - U-Net implementation: "examples/unet.ipynb" + - Vision transformer: "examples/vision_transformer.ipynb" + - Image GAN: "examples/deep_convolutional_gan.ipynb" + - MAMBA: "examples/mamba.ipynb" + - Features: + - Freezing parameters: "examples/frozen_layer.ipynb" + - Compatibility with init-apply libraries: "examples/init_apply.ipynb" + - Stateful operations (e.g. BatchNorm): "examples/stateful.ipynb" + - Autoparallelism (e.g. multi-GPU): "examples/parallelism.ipynb" + - Serialisation (with hyperparameters): "examples/serialisation.ipynb" + - Basic API: + - Modules: + - "api/module/module.md" + - "api/module/advanced_fields.md" + - Neural network layers: + - "api/nn/linear.md" + - "api/nn/conv.md" + - "api/nn/rnn.md" + - "api/nn/attention.md" + - "api/nn/state_spaces.md" + - "api/nn/activations.md" + - "api/nn/pool.md" + - "api/nn/dropout.md" + - "api/nn/normalisation.md" + - "api/nn/embedding.md" + - "api/nn/mlp.md" + - "api/nn/sequential.md" + - "api/nn/inference.md" + - "api/nn/shared.md" + - "api/nn/stateful.md" + - "api/transformations.md" + - "api/manipulation.md" + - Advanced API: + - "api/caches.md" + - "api/debug.md" + - "api/enumerations.md" + - "api/errors.md" + - "api/pretty-printing.md" + - "api/serialisation.md" + - Misc: + - "faq.md" + - "tricks.md" + # - 'pattern.md' + - "citation.md"
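
A quick end-to-end check of the layer this series adds. The snippet below is a sketch, not part of the patch: it assumes only the constructor arguments and the `(seq_len, n_input_dims) -> (seq_len, n_input_dims)` shape contract documented in the diff above, and the hyperparameter values are arbitrary.

```python
import equinox as eqx
import jax

key = jax.random.PRNGKey(0)
model_key, data_key = jax.random.split(key)

# Arbitrary hyperparameters; any values satisfying the documented
# constraints should behave the same way.
ssm = eqx.nn.SelectiveStateSpaceModel(
    n_input_dims=32,      # size of each input vector
    state_space_dims=16,  # 'N' in the Mamba paper
    expand=2,             # 'E' in the Mamba paper
    d_conv=4,             # kernel size of the depthwise Conv1d
    dt_rank="auto",       # ceil(n_input_dims / state_space_dims)
    key=model_key,
)

x = jax.random.normal(data_key, (10, 32))  # (seq_len, n_input_dims)
y = ssm(x)
assert y.shape == x.shape  # the layer preserves the input shape
```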
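
Since the series drops the notebook's "Detour: State Space Models" TODO cell, here is the recurrence behind `_selective_scan` in brief: after discretisation with A_bar_t = exp(delta_t * A) (zero-order hold) and B_bar_t u_t ~= delta_t * B_t * u_t (Euler), the state evolves as h_t = A_bar_t * h_{t-1} + delta_t * B_t * u_t, with read-out y_t = C_t h_t + D u_t. The following is a hedged, stand-alone sketch using `jax.lax.scan`; the in-tree `_selective_scan` is structured differently but computes the same recurrence, with shapes taken from the type annotations in the diff.

```python
import jax
import jax.numpy as jnp


def selective_scan_reference(u, delta, A, B, C, D):
    # Shapes, following the annotations in the patch:
    #   u, delta: (seq_len, d_inner)
    #   A:        (d_inner, d_state)
    #   B, C:     (seq_len, d_state)
    #   D:        (d_inner,)
    # Discretise the continuous-time parameters.
    delta_A = jnp.exp(jnp.einsum("ld,dn->ldn", delta, A))
    delta_B_u = jnp.einsum("ld,ln,ld->ldn", delta, B, u)

    def step(h, scan_inputs):
        delta_A_t, delta_B_u_t, C_t = scan_inputs
        h = delta_A_t * h + delta_B_u_t    # state update: (d_inner, d_state)
        y = jnp.einsum("dn,n->d", h, C_t)  # read-out: (d_inner,)
        return h, y

    h0 = jnp.zeros_like(A)  # initial state, one d_state vector per channel
    _, ys = jax.lax.scan(step, h0, (delta_A, delta_B_u, C))
    return ys + u * D  # D acts as a learned skip connection
```

Keeping the scan in pure JAX keeps it differentiable and `vmap`-able; the CUDA kernels in the reference implementation mentioned in the notebook exist to make this same recurrence fast, not to change what it computes.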