diff --git a/docs/api/filtering/transformations.md b/docs/api/filtering/transformations.md index c50457f9..51673616 100644 --- a/docs/api/filtering/transformations.md +++ b/docs/api/filtering/transformations.md @@ -8,17 +8,17 @@ Most users find that this is a simpler API when working with complicated PyTrees Likewise, `eqx.filter_grad` will automatically differentiate all floating-point JAX arrays and treat the rest nondifferentiably, etc. Each transformation here just combines [`equinox.partition`][], `jax.{jit, ...}` and [`equinox.combine`][] together. -## Just-in-time compilation +## Compilation ::: equinox.filter_jit --- -::: equinox.filter_make_jaxpr +::: equinox.filter_device_put --- -::: equinox.filter_eval_shape +::: equinox.filter_with_sharding_constraint ## Automatic differentiation @@ -48,10 +48,6 @@ Likewise, `eqx.filter_grad` will automatically differentiate all floating-point ::: equinox.filter_custom_vjp ---- - -::: equinox.filter_closure_convert - ## Vectorisation and parallelisation ::: equinox.filter_vmap @@ -63,3 +59,17 @@ Likewise, `eqx.filter_grad` will automatically differentiate all floating-point ## Callbacks ::: equinox.filter_pure_callback + +# Tracing + +--- + +::: equinox.filter_make_jaxpr + +--- + +::: equinox.filter_eval_shape + +--- + +::: equinox.filter_closure_convert diff --git a/equinox/__init__.py b/equinox/__init__.py index a39b0406..e95472e5 100644 --- a/equinox/__init__.py +++ b/equinox/__init__.py @@ -49,6 +49,10 @@ tree_deserialise_leaves as tree_deserialise_leaves, tree_serialise_leaves as tree_serialise_leaves, ) +from ._sharding import ( + filter_device_put as filter_device_put, + filter_with_sharding_constraint as filter_with_sharding_constraint, +) from ._tree import ( tree_at as tree_at, tree_check as tree_check, diff --git a/equinox/_sharding.py b/equinox/_sharding.py new file mode 100644 index 00000000..aa8fee52 --- /dev/null +++ b/equinox/_sharding.py @@ -0,0 +1,65 @@ +from typing import Any + +import jax +import jax.lax as lax +from jaxtyping import PyTree + +from ._filters import combine, is_array, partition + + +def filter_with_sharding_constraint(x: PyTree[Any], shardings): + """Filtered version of `jax.lax.with_sharding_constraint`. Enforces sharding within + a JIT'd computation. (That is, how an array is split between multiple devices, i.e. + multiple GPUs/TPUs.) + + This should always be called *inside* of a JIT'd computation. + + This is a strict constraint for the XLA compiler, and not just a hint. It is + typically placed on the inputs of JIT'd computations to assert that they are sharded + in the correct way, and on the output of JIT'd computations to specify how they + should be sharded. + + **Arguments:** + + - `x`: A PyTree, with potentially a mix of arrays and non-arrays on the leaves. They + will have their shardings constrained. + - `shardings`: a PyTree of sharding specifications. The structure should be a prefix + of `x`. + + **Returns:** + + A copy of `x` with the specified sharding constraints. + + !!! Example + + See also the [autoparallelism example](../../../examples/parallelism). + """ + dynamic, static = partition(x, is_array) + dynamic = lax.with_sharding_constraint(dynamic, shardings) + return combine(dynamic, static) + + +def filter_device_put(x: PyTree[Any], device): + """Filtered version of `jax.device_put`. Places all arrays in `x` on the device. + Non-arrays are unchanged. + + This should always be called *outside* of a JIT'd computation. + + **Arguments:** + + - `x`: A PyTree, with potentially a mix of arrays and non-arrays on the leaves. + - `device`: A specification for how to place `x` on a device. Most typically this is + either a `Device` (as returned by `jax.local_devices`) or a sharding (usually a + `jax.sharding.NamedSharding` or `jax.sharding.PositionalSharding`). + + **Returns:** + + A copy of `x` that resides on `device`. + + !!! Example + + See also the [autoparallelism example](../../../examples/parallelism). + """ + dynamic, static = partition(x, is_array) + dynamic = jax.device_put(dynamic, device) + return combine(dynamic, static) diff --git a/examples/parallelism.ipynb b/examples/parallelism.ipynb index d6be6dd5..ecfd5431 100644 --- a/examples/parallelism.ipynb +++ b/examples/parallelism.ipynb @@ -11,20 +11,16 @@ "\n", "JAX has a number of advanced APIs to support this. The main technique is to \"shard\" an array, so that each device holds part of the array.\n", "\n", - "In this example, we'll parallelise our computation (usually it's a training step) over 8 devices, so that each device gets 1/8 of the batch of data." - ] - }, - { - "cell_type": "markdown", - "id": "f2a0bae8-2435-4b37-b1f2-24322cfeb1dd", - "metadata": {}, - "source": [ - "First let's import everything, and set up our toy problem." + "In this example, we'll perform *data parallelism*. Each GPU gets a fully copy of our model on every GPU, and some fraction of the batch of input data. (In this case it will be 1/8 of the data, as this example is written for 8 GPUs.)\n", + "\n", + "---\n", + "\n", + "First let's import everything, and set up our toy problem. Everything here will be exactly the same as when running on a single device." ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "83bba892-5425-4eed-a7f7-9c325fe5cc53", "metadata": {}, "outputs": [], @@ -34,7 +30,7 @@ "import jax.experimental.mesh_utils as mesh_utils\n", "import jax.numpy as jnp\n", "import jax.random as jr\n", - "import jax.sharding as sharding\n", + "import jax.sharding as jshard\n", "import numpy as np\n", "import optax # https://github.com/deepmind/optax\n", "\n", @@ -45,7 +41,7 @@ "hidden_size = 32\n", "depth = 1\n", "learning_rate = 3e-4\n", - "num_steps = 10\n", + "num_steps = 50\n", "batch_size = 16 # must be a multiple of our number of devices.\n", "\n", "# Generate some synthetic data\n", @@ -57,49 +53,67 @@ "opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))\n", "\n", "\n", + "# Loss function for a batch of data\n", "def compute_loss(model, x, y):\n", " pred_y = jax.vmap(model)(x)\n", " return jnp.mean((y - pred_y) ** 2)\n", "\n", "\n", - "@eqx.filter_jit\n", - "def make_step(model, opt_state, x, y):\n", - " grads = eqx.filter_grad(compute_loss)(model, x, y)\n", - " updates, opt_state = optim.update(grads, opt_state)\n", - " model = eqx.apply_updates(model, updates)\n", - " return model, opt_state" + "# Simple dataloader; randomly slices our dataset and shuffles between epochs.\n", + "# In NumPy for speed, as our dataset is small enough to fit entirely in host memory.\n", + "#\n", + "# For larger datasets (that require loading from disk) then use PyTorch's `DataLoader`\n", + "# or TensorFlow's `tf.data`.\n", + "def train_dataloader(arrays, batch_size):\n", + " dataset_size = arrays[0].shape[0]\n", + " assert all(array.shape[0] == dataset_size for array in arrays)\n", + " indices = np.arange(dataset_size)\n", + " while True:\n", + " perm = np.random.permutation(indices)\n", + " start = 0\n", + " end = batch_size\n", + " while end <= dataset_size:\n", + " batch_perm = perm[start:end]\n", + " yield tuple(array[batch_perm] for array in arrays)\n", + " start = end\n", + " end = start + batch_size" ] }, { "cell_type": "markdown", - "id": "0fb345b0-c9b3-44df-94e8-d74c7ad172b8", + "id": "50e663d7-fc20-483d-8d87-c1fe6c773f8c", "metadata": {}, "source": [ - "Here's a very simple dataloader, that randomly shuffles and slices our dataset. We keep everything in pure-NumPy for speed, as this all happens on the host, prior to moving our data to our devices. (Which will often be a cluster of GPUs.)\n", + "Okay, now the interesting things start happening!\n", + "\n", + "First, we're going to arrange to \"donate\" memory, which specifes that we can re-use the memory for our input arrays (e.g. model parameters) to store the output arrays (e.g. updated model parameters). (This isn't technically related to autoparallelism, but it's good practice so you should do it anyway :) )\n", "\n", - "In practice it's also common to load data using either PyTorch's `DataLoader` or TensorFlow's `tf.data` API; see [here](../mnist/) for more details." + "Second, we're going to use `filter_with_sharding_constraint` to assert (on the inputs) and enforce (on the outputs) how each array is split across each of our devices. As we're doing data parallelism in this example, then we'll be replicating our model parameters on to every device, whilst sharding our data between devices." ] }, { "cell_type": "code", - "execution_count": 3, - "id": "fd94db04-9fe4-4530-808e-945becef9df5", + "execution_count": 2, + "id": "15b2cec2-5925-4f04-89fb-79761288cf14", "metadata": {}, "outputs": [], "source": [ - "def dataloader(arrays, batch_size):\n", - " dataset_size = arrays[0].shape[0]\n", - " assert all(array.shape[0] == dataset_size for array in arrays)\n", - " indices = np.arange(dataset_size)\n", - " while True:\n", - " perm = np.random.permutation(indices)\n", - " start = 0\n", - " end = batch_size\n", - " while end <= dataset_size:\n", - " batch_perm = perm[start:end]\n", - " yield tuple(array[batch_perm] for array in arrays)\n", - " start = end\n", - " end = start + batch_size" + "@eqx.filter_jit(donate=\"all\")\n", + "def train_step(model, opt_state, x, y, sharding):\n", + " replicated = sharding.replicate()\n", + " model, opt_state = eqx.filter_with_sharding_constraint(\n", + " (model, opt_state), replicated\n", + " )\n", + " x, y = eqx.filter_with_sharding_constraint((x, y), sharding)\n", + "\n", + " grads = eqx.filter_grad(compute_loss)(model, x, y)\n", + " updates, opt_state = optim.update(grads, opt_state)\n", + " model = eqx.apply_updates(model, updates)\n", + "\n", + " model, opt_state = eqx.filter_with_sharding_constraint(\n", + " (model, opt_state), replicated\n", + " )\n", + " return model, opt_state" ] }, { @@ -112,7 +126,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "id": "32c6b58e-f72f-4dd4-bf2c-f1dc75643eda", "metadata": { "tags": [] @@ -121,11 +135,65 @@ "source": [ "num_devices = len(jax.devices())\n", "devices = mesh_utils.create_device_mesh((num_devices, 1))\n", - "shard = sharding.PositionalSharding(devices)\n", - "\n", - "for step, (x, y) in zip(range(num_steps), dataloader((xs, ys), batch_size)):\n", - " x, y = jax.device_put((x, y), shard)\n", - " model, opt_state = make_step(model, opt_state, x, y)" + "sharding = jshard.PositionalSharding(devices)\n", + "replicated = sharding.replicate()\n", + "\n", + "model = eqx.filter_device_put(model, replicated)\n", + "for step, (x, y) in zip(\n", + " range(1, num_steps + 1), train_dataloader((xs, ys), batch_size)\n", + "):\n", + " x, y = eqx.filter_device_put((x, y), sharding)\n", + " model, opt_state = train_step(model, opt_state, x, y, sharding)" + ] + }, + { + "cell_type": "markdown", + "id": "582db8d5-1d5c-4502-b95f-1c951d6c822c", + "metadata": {}, + "source": [ + "Not strictly related to parallelism, but a common question at this point: if we want to evaluate our model, then we probably don't want to donate its parameters (which would render the model unusable, as all its memory is freed). As such, inference looks like this:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9153b1c2-66f7-47cc-80d0-2f2a9e6d53d4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train loss=0.47004854679107666\n" + ] + } + ], + "source": [ + "def eval_dataloader(arrays, batch_size):\n", + " dataset_size = arrays[0].shape[0]\n", + " assert all(array.shape[0] == dataset_size for array in arrays)\n", + " start = 0\n", + " end = batch_size\n", + " while start < dataset_size:\n", + " yield tuple(array[start:end] for array in arrays)\n", + " start = end\n", + " end = start + batch_size\n", + "\n", + "\n", + "@eqx.filter_jit(donate=\"all-except-first\")\n", + "def evaluate(model, x, y, sharding):\n", + " replicated = sharding.replicate()\n", + " model = eqx.filter_with_sharding_constraint(model, replicated)\n", + " x, y = eqx.filter_with_sharding_constraint((x, y), sharding)\n", + " return compute_loss(model, x, y)\n", + "\n", + "\n", + "loss = 0\n", + "num_batches = 0\n", + "for x, y in eval_dataloader((xs, ys), batch_size):\n", + " loss = loss + evaluate(model, x, y, sharding).item()\n", + " num_batches = num_batches + 1\n", + "print(f\"train loss={loss/num_batches}\")" ] }, { @@ -139,11 +207,9 @@ "\n", "If you ran the above example on a cluster of NVIDIA GPUs, then you can check whether you're using as many GPUs as you expected by running `nvidia-smi` from the command line. You can also use `jax.debug.visualize_array_sharding(array)` to inspect the sharding manually.\n", "\n", - "One possible optimisation here is to re-use the memory used by the input arrays, to store the output arrays. This often improves speed a little bit. This is disabled by default, but can be enabled by passing `eqx.filter_jit(donate=\"all\")`.\n", - "\n", "**What about pmap?**\n", "\n", - "The JAX team have been hard at work introducing these new easy-to-use parallelism features, based around JIT and sharding. These are often faster and more expressive than pmap, so pmap is no longer recommended!\n", + "The JAX team have been hard at work introducing these new easy-to-use parallelism features, based around JIT and sharding. These are often faster and more expressive than pmap, so pmap is no longer recommended.\n", "\n", "**Types of parallelism**\n", "\n", @@ -151,6 +217,10 @@ "\n", "For completeness we note that there are other kinds of parallelism available -- e.g. model parallelism, which instead places different parts of the model on different devices. A discussion on those is a more advanced topic. :)\n", "\n", + "**`jax.device_put` vs `eqx.filter_device_put`, `jax.lax.with_sharding_constraint` vs `eqx.filter_with_sharding_constraint`**\n", + "\n", + "These are the usual story in Equinox: we have a filtered version of the operation that leaves any non-arrays alone. In this case, they are used because we have an activation function (i.e. just some arbitrary Python function, which isn't an array) as part of the MLP.\n", + "\n", "**Further reading**\n", "\n", "Equinox works smoothly with all the built-in parallelism APIs provided by JAX. If you want to know more, then the relevant parts of the JAX documentation are:\n", @@ -163,9 +233,9 @@ ], "metadata": { "kernelspec": { - "display_name": "jax39", + "display_name": "py311", "language": "python", - "name": "jax39" + "name": "py311" }, "language_info": { "codemirror_mode": { @@ -177,7 +247,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.16" + "version": "3.11.3" } }, "nbformat": 4, diff --git a/tests/test_sharding.py b/tests/test_sharding.py new file mode 100644 index 00000000..23ca1378 --- /dev/null +++ b/tests/test_sharding.py @@ -0,0 +1,18 @@ +import equinox as eqx +import jax +import jax.random as jr + + +[cpu] = jax.local_devices(backend="cpu") +sharding = jax.sharding.PositionalSharding([cpu]) + + +def test_sharding(): + mlp = eqx.nn.MLP(2, 2, 2, 2, key=jr.PRNGKey(0)) + eqx.filter_device_put(mlp, cpu) + + @eqx.filter_jit + def f(x): + return eqx.filter_with_sharding_constraint(x, sharding) + + f(mlp)