Added filter_{device_put,with_sharding_constraint} and updated parallelism example #617

Closed

wants to merge 1 commit
24 changes: 17 additions & 7 deletions docs/api/filtering/transformations.md
@@ -8,17 +8,17 @@ Most users find that this is a simpler API when working with complicated PyTrees

Likewise, `eqx.filter_grad` will automatically differentiate all floating-point JAX arrays and treat the rest nondifferentiably, etc. Each transformation here just combines [`equinox.partition`][], `jax.{jit, ...}` and [`equinox.combine`][] together.

## Just-in-time compilation
## Compilation

::: equinox.filter_jit

---

::: equinox.filter_make_jaxpr
::: equinox.filter_device_put

---

::: equinox.filter_eval_shape
::: equinox.filter_with_sharding_constraint

## Automatic differentiation

@@ -48,10 +48,6 @@ Likewise, `eqx.filter_grad` will automatically differentiate all floating-point

::: equinox.filter_custom_vjp

---

::: equinox.filter_closure_convert

## Vectorisation and parallelisation

::: equinox.filter_vmap
@@ -63,3 +59,17 @@ Likewise, `eqx.filter_grad` will automatically differentiate all floating-point
## Callbacks

::: equinox.filter_pure_callback

## Tracing

::: equinox.filter_make_jaxpr

---

::: equinox.filter_eval_shape

---

::: equinox.filter_closure_convert
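
As context for reviewing the new `equinox/_sharding.py` module further down: it follows the recipe described at the top of this docs page, i.e. each filtered transformation is just `equinox.partition`, the underlying JAX operation, and `equinox.combine`. Below is a minimal sketch of that recipe. The wrapper name `filter_device_get` is hypothetical (not part of Equinox) and is used purely to illustrate the pattern.

```python
import equinox as eqx
import jax


# Hypothetical `filter_device_get` (not part of Equinox), shown only to illustrate
# the partition/apply/combine recipe that every filtered transformation follows.
def filter_device_get(x):
    # Split the PyTree into its array leaves and everything else.
    dynamic, static = eqx.partition(x, eqx.is_array)
    # Apply the underlying JAX operation to the arrays only.
    dynamic = jax.device_get(dynamic)
    # Reassemble the original PyTree; non-array leaves pass through untouched.
    return eqx.combine(dynamic, static)
```

The two functions added in this PR have exactly this shape, with `jax.device_put` and `jax.lax.with_sharding_constraint` as the wrapped operations.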
4 changes: 4 additions & 0 deletions equinox/__init__.py
@@ -49,6 +49,10 @@
tree_deserialise_leaves as tree_deserialise_leaves,
tree_serialise_leaves as tree_serialise_leaves,
)
from ._sharding import (
filter_device_put as filter_device_put,
filter_with_sharding_constraint as filter_with_sharding_constraint,
)
from ._tree import (
tree_at as tree_at,
tree_check as tree_check,
65 changes: 65 additions & 0 deletions equinox/_sharding.py
@@ -0,0 +1,65 @@
from typing import Any

import jax
import jax.lax as lax
from jaxtyping import PyTree

from ._filters import combine, is_array, partition


def filter_with_sharding_constraint(x: PyTree[Any], shardings):
"""Filtered version of `jax.lax.with_sharding_constraint`. Enforces sharding within
a JIT'd computation. (That is, how an array is split between multiple devices, i.e.
multiple GPUs/TPUs.)

This should always be called *inside* of a JIT'd computation.

This is a strict constraint for the XLA compiler, and not just a hint. It is
typically placed on the inputs of JIT'd computations to assert that they are sharded
in the correct way, and on the output of JIT'd computations to specify how they
should be sharded.

**Arguments:**

- `x`: A PyTree, with potentially a mix of arrays and non-arrays on the leaves. They
will have their shardings constrained.
- `shardings`: A PyTree of sharding specifications. Its structure should be a prefix
of the structure of `x`.

**Returns:**

A copy of `x` with the specified sharding constraints.

!!! Example

See also the [autoparallelism example](../../../examples/parallelism).
"""
dynamic, static = partition(x, is_array)
dynamic = lax.with_sharding_constraint(dynamic, shardings)
return combine(dynamic, static)


def filter_device_put(x: PyTree[Any], device):
"""Filtered version of `jax.device_put`. Places all arrays in `x` on the device.
Non-arrays are unchanged.

This should always be called *outside* of a JIT'd computation.

**Arguments:**

- `x`: A PyTree, with potentially a mix of arrays and non-arrays on the leaves.
- `device`: A specification for how to place `x` on a device. Most typically this is
either a `Device` (as returned by `jax.local_devices`) or a sharding (usually a
`jax.sharding.NamedSharding` or `jax.sharding.PositionalSharding`).

**Returns:**

A copy of `x` that resides on `device`.

!!! Example

See also the [autoparallelism example](../../../examples/parallelism).
"""
dynamic, static = partition(x, is_array)
dynamic = jax.device_put(dynamic, device)
return combine(dynamic, static)
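
To make the intended calling convention concrete, here is a condensed sketch of how the two new functions fit together, closely mirroring the updated notebook below. The toy `MLP` and data are assumptions for illustration only, and the batch size must be divisible by the number of devices.

```python
import equinox as eqx
import jax
import jax.experimental.mesh_utils as mesh_utils
import jax.numpy as jnp
import jax.random as jr
import jax.sharding as jshard

# Toy model and data, purely for illustration. Any Equinox module works, including
# ones with non-array leaves such as activation functions.
model = eqx.nn.MLP(in_size=4, out_size=1, width_size=8, depth=1, key=jr.PRNGKey(0))
x = jnp.zeros((16, 4))  # batch size must be a multiple of the number of devices

num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))
sharding = jshard.PositionalSharding(devices)
replicated = sharding.replicate()

# Outside JIT: replicate the model parameters, shard the data over the batch axis.
model = eqx.filter_device_put(model, replicated)
x = eqx.filter_device_put(x, sharding)


@eqx.filter_jit
def forward(model, x):
    # Inside JIT: assert/enforce the same layout on the inputs.
    model = eqx.filter_with_sharding_constraint(model, replicated)
    x = eqx.filter_with_sharding_constraint(x, sharding)
    return jax.vmap(model)(x)


out = forward(model, x)
```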
166 changes: 118 additions & 48 deletions examples/parallelism.ipynb
@@ -11,20 +11,16 @@
"\n",
"JAX has a number of advanced APIs to support this. The main technique is to \"shard\" an array, so that each device holds part of the array.\n",
"\n",
"In this example, we'll parallelise our computation (usually it's a training step) over 8 devices, so that each device gets 1/8 of the batch of data."
]
},
{
"cell_type": "markdown",
"id": "f2a0bae8-2435-4b37-b1f2-24322cfeb1dd",
"metadata": {},
"source": [
"First let's import everything, and set up our toy problem."
"In this example, we'll perform *data parallelism*. Each GPU gets a fully copy of our model on every GPU, and some fraction of the batch of input data. (In this case it will be 1/8 of the data, as this example is written for 8 GPUs.)\n",
"\n",
"---\n",
"\n",
"First let's import everything, and set up our toy problem. Everything here will be exactly the same as when running on a single device."
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"id": "83bba892-5425-4eed-a7f7-9c325fe5cc53",
"metadata": {},
"outputs": [],
@@ -34,7 +30,7 @@
"import jax.experimental.mesh_utils as mesh_utils\n",
"import jax.numpy as jnp\n",
"import jax.random as jr\n",
"import jax.sharding as sharding\n",
"import jax.sharding as jshard\n",
"import numpy as np\n",
"import optax # https://github.com/deepmind/optax\n",
"\n",
@@ -45,7 +41,7 @@
"hidden_size = 32\n",
"depth = 1\n",
"learning_rate = 3e-4\n",
"num_steps = 10\n",
"num_steps = 50\n",
"batch_size = 16 # must be a multiple of our number of devices.\n",
"\n",
"# Generate some synthetic data\n",
@@ -57,49 +53,67 @@
"opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))\n",
"\n",
"\n",
"# Loss function for a batch of data\n",
"def compute_loss(model, x, y):\n",
" pred_y = jax.vmap(model)(x)\n",
" return jnp.mean((y - pred_y) ** 2)\n",
"\n",
"\n",
"@eqx.filter_jit\n",
"def make_step(model, opt_state, x, y):\n",
" grads = eqx.filter_grad(compute_loss)(model, x, y)\n",
" updates, opt_state = optim.update(grads, opt_state)\n",
" model = eqx.apply_updates(model, updates)\n",
" return model, opt_state"
"# Simple dataloader; randomly slices our dataset and shuffles between epochs.\n",
"# In NumPy for speed, as our dataset is small enough to fit entirely in host memory.\n",
"#\n",
"# For larger datasets (that require loading from disk) then use PyTorch's `DataLoader`\n",
"# or TensorFlow's `tf.data`.\n",
"def train_dataloader(arrays, batch_size):\n",
" dataset_size = arrays[0].shape[0]\n",
" assert all(array.shape[0] == dataset_size for array in arrays)\n",
" indices = np.arange(dataset_size)\n",
" while True:\n",
" perm = np.random.permutation(indices)\n",
" start = 0\n",
" end = batch_size\n",
" while end <= dataset_size:\n",
" batch_perm = perm[start:end]\n",
" yield tuple(array[batch_perm] for array in arrays)\n",
" start = end\n",
" end = start + batch_size"
]
},
{
"cell_type": "markdown",
"id": "0fb345b0-c9b3-44df-94e8-d74c7ad172b8",
"id": "50e663d7-fc20-483d-8d87-c1fe6c773f8c",
"metadata": {},
"source": [
"Here's a very simple dataloader, that randomly shuffles and slices our dataset. We keep everything in pure-NumPy for speed, as this all happens on the host, prior to moving our data to our devices. (Which will often be a cluster of GPUs.)\n",
"Okay, now the interesting things start happening!\n",
"\n",
"First, we're going to arrange to \"donate\" memory, which specifes that we can re-use the memory for our input arrays (e.g. model parameters) to store the output arrays (e.g. updated model parameters). (This isn't technically related to autoparallelism, but it's good practice so you should do it anyway :) )\n",
"\n",
"In practice it's also common to load data using either PyTorch's `DataLoader` or TensorFlow's `tf.data` API; see [here](../mnist/) for more details."
"Second, we're going to use `filter_with_sharding_constraint` to assert (on the inputs) and enforce (on the outputs) how each array is split across each of our devices. As we're doing data parallelism in this example, then we'll be replicating our model parameters on to every device, whilst sharding our data between devices."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "fd94db04-9fe4-4530-808e-945becef9df5",
"execution_count": 2,
"id": "15b2cec2-5925-4f04-89fb-79761288cf14",
"metadata": {},
"outputs": [],
"source": [
"def dataloader(arrays, batch_size):\n",
" dataset_size = arrays[0].shape[0]\n",
" assert all(array.shape[0] == dataset_size for array in arrays)\n",
" indices = np.arange(dataset_size)\n",
" while True:\n",
" perm = np.random.permutation(indices)\n",
" start = 0\n",
" end = batch_size\n",
" while end <= dataset_size:\n",
" batch_perm = perm[start:end]\n",
" yield tuple(array[batch_perm] for array in arrays)\n",
" start = end\n",
" end = start + batch_size"
"@eqx.filter_jit(donate=\"all\")\n",
"def train_step(model, opt_state, x, y, sharding):\n",
" replicated = sharding.replicate()\n",
" model, opt_state = eqx.filter_with_sharding_constraint(\n",
" (model, opt_state), replicated\n",
" )\n",
" x, y = eqx.filter_with_sharding_constraint((x, y), sharding)\n",
"\n",
" grads = eqx.filter_grad(compute_loss)(model, x, y)\n",
" updates, opt_state = optim.update(grads, opt_state)\n",
" model = eqx.apply_updates(model, updates)\n",
"\n",
" model, opt_state = eqx.filter_with_sharding_constraint(\n",
" (model, opt_state), replicated\n",
" )\n",
" return model, opt_state"
]
},
{
@@ -112,7 +126,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"id": "32c6b58e-f72f-4dd4-bf2c-f1dc75643eda",
"metadata": {
"tags": []
@@ -121,11 +135,65 @@
"source": [
"num_devices = len(jax.devices())\n",
"devices = mesh_utils.create_device_mesh((num_devices, 1))\n",
"shard = sharding.PositionalSharding(devices)\n",
"\n",
"for step, (x, y) in zip(range(num_steps), dataloader((xs, ys), batch_size)):\n",
" x, y = jax.device_put((x, y), shard)\n",
" model, opt_state = make_step(model, opt_state, x, y)"
"sharding = jshard.PositionalSharding(devices)\n",
"replicated = sharding.replicate()\n",
"\n",
"model = eqx.filter_device_put(model, replicated)\n",
"for step, (x, y) in zip(\n",
" range(1, num_steps + 1), train_dataloader((xs, ys), batch_size)\n",
"):\n",
" x, y = eqx.filter_device_put((x, y), sharding)\n",
" model, opt_state = train_step(model, opt_state, x, y, sharding)"
]
},
{
"cell_type": "markdown",
"id": "582db8d5-1d5c-4502-b95f-1c951d6c822c",
"metadata": {},
"source": [
"Not strictly related to parallelism, but a common question at this point: if we want to evaluate our model, then we probably don't want to donate its parameters (which would render the model unusable, as all its memory is freed). As such, inference looks like this:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9153b1c2-66f7-47cc-80d0-2f2a9e6d53d4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train loss=0.47004854679107666\n"
]
}
],
"source": [
"def eval_dataloader(arrays, batch_size):\n",
" dataset_size = arrays[0].shape[0]\n",
" assert all(array.shape[0] == dataset_size for array in arrays)\n",
" start = 0\n",
" end = batch_size\n",
" while start < dataset_size:\n",
" yield tuple(array[start:end] for array in arrays)\n",
" start = end\n",
" end = start + batch_size\n",
"\n",
"\n",
"@eqx.filter_jit(donate=\"all-except-first\")\n",
"def evaluate(model, x, y, sharding):\n",
" replicated = sharding.replicate()\n",
" model = eqx.filter_with_sharding_constraint(model, replicated)\n",
" x, y = eqx.filter_with_sharding_constraint((x, y), sharding)\n",
" return compute_loss(model, x, y)\n",
"\n",
"\n",
"loss = 0\n",
"num_batches = 0\n",
"for x, y in eval_dataloader((xs, ys), batch_size):\n",
" loss = loss + evaluate(model, x, y, sharding).item()\n",
" num_batches = num_batches + 1\n",
"print(f\"train loss={loss/num_batches}\")"
]
},
{
@@ -139,18 +207,20 @@
"\n",
"If you ran the above example on a cluster of NVIDIA GPUs, then you can check whether you're using as many GPUs as you expected by running `nvidia-smi` from the command line. You can also use `jax.debug.visualize_array_sharding(array)` to inspect the sharding manually.\n",
"\n",
"One possible optimisation here is to re-use the memory used by the input arrays, to store the output arrays. This often improves speed a little bit. This is disabled by default, but can be enabled by passing `eqx.filter_jit(donate=\"all\")`.\n",
"\n",
"**What about pmap?**\n",
"\n",
"The JAX team have been hard at work introducing these new easy-to-use parallelism features, based around JIT and sharding. These are often faster and more expressive than pmap, so pmap is no longer recommended!\n",
"The JAX team have been hard at work introducing these new easy-to-use parallelism features, based around JIT and sharding. These are often faster and more expressive than pmap, so pmap is no longer recommended.\n",
"\n",
"**Types of parallelism**\n",
"\n",
"There are multiple types of parallelism. In this example we demonstrated _data parallelism_, in which we parallelise over the data. This is one of the simplest to set up, and often very effective.\n",
"\n",
"For completeness we note that there are other kinds of parallelism available -- e.g. model parallelism, which instead places different parts of the model on different devices. A discussion on those is a more advanced topic. :)\n",
"\n",
"**`jax.device_put` vs `eqx.filter_device_put`, `jax.lax.with_sharding_constraint` vs `eqx.filter_with_sharding_constraint`**\n",
"\n",
"These are the usual story in Equinox: we have a filtered version of the operation that leaves any non-arrays alone. In this case, they are used because we have an activation function (i.e. just some arbitrary Python function, which isn't an array) as part of the MLP.\n",
"\n",
"**Further reading**\n",
"\n",
"Equinox works smoothly with all the built-in parallelism APIs provided by JAX. If you want to know more, then the relevant parts of the JAX documentation are:\n",
@@ -163,9 +233,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "jax39",
"display_name": "py311",
"language": "python",
"name": "jax39"
"name": "py311"
},
"language_info": {
"codemirror_mode": {
@@ -177,7 +247,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.11.3"
}
},
"nbformat": 4,
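
Finally, a note on the "`jax.device_put` vs `eqx.filter_device_put`" point in the closing markdown cell: a minimal sketch of why the filtered variant is needed for this model. (That plain `jax.device_put` rejects a function leaf is our expectation of JAX's behaviour, not something demonstrated in the diff; the toy `MLP` is an assumption for illustration.)

```python
import equinox as eqx
import jax
import jax.random as jr

# The MLP PyTree contains non-array leaves, e.g. its activation function.
model = eqx.nn.MLP(in_size=4, out_size=1, width_size=8, depth=1, key=jr.PRNGKey(0))
device = jax.devices()[0]

# We expect this to raise, as a Python function is not a valid JAX type:
# jax.device_put(model, device)

# The filtered version moves only the array leaves and leaves everything else alone:
model = eqx.filter_device_put(model, device)
```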