
Commit

Update getting started doc (#64)
jakevdp authored Oct 15, 2024
1 parent 447e3a8 commit 1865960
Showing 2 changed files with 14 additions and 12 deletions.
13 changes: 7 additions & 6 deletions docs/getting_started_with_jax_for_AI.ipynb
@@ -21,7 +21,7 @@
"source": [
"## Who is this tutorial for?\n",
"\n",
"This tutorial is for those who want to get started using JAX to build and train neural network models. It assumes some familiarity with numerical computing in Python with [NumPy](http://numpy.org), and assumes some conceptual familiarity with defining, training, and evaluating machine learning models."
"This tutorial is for those who want to get started using the JAX AI stack to build and train neural network models. It assumes some familiarity with numerical computing in Python with [NumPy](http://numpy.org), and assumes some conceptual familiarity with defining, training, and evaluating machine learning models."
]
},
{
@@ -50,7 +50,7 @@
"source": [
"## Example: a simple neural network with flax\n",
"\n",
"We'll start with a very quick example of what it looks like to use JAX with the [flax](https://flax.readthedocs.io) framework to define and train a very simple neural network on a hand-written digits dataset."
"We'll start with a very quick example of what it looks like to use JAX with the [flax](https://flax.readthedocs.io) framework to define and train a very simple neural network to recognize hand-written digits."
]
},
{
@@ -135,7 +135,8 @@
"id": "Z3l45KgtfUUo"
},
"source": [
"Let's split these into a training and testing set, and convert these splits into JAX arrays which will be ready to feed into our model:"
"Let's split these into a training and testing set, and convert these splits into JAX arrays which will be ready to feed into our model.\n",
"We'll make use of the `jax.numpy` module, which provides a familiar NumPy-style API around JAX operations:"
]
},
{
@@ -182,7 +183,7 @@
"source": [
"### Defining the flax model\n",
"\n",
"We can now use the [Flax](http://flax.readthedocs.io) package to create a simple [Feedforward](https://en.wikipedia.org/wiki/Feedforward_neural_network) neural network with one hidden layer, and use a *scaled exponential linear unit* (SELU) activation function:"
"We can now use the [Flax](http://flax.readthedocs.io) package to create a simple [Feedforward](https://en.wikipedia.org/wiki/Feedforward_neural_network) neural network with one hidden layer, and use a *scaled exponential linear unit* (SELU) activation function."
]
},
{
@@ -283,8 +284,8 @@
" optimizer: nnx.Optimizer,\n",
" data: jax.Array,\n",
" labels: jax.Array):\n",
" loss_gradient = nnx.value_and_grad(loss_fun, has_aux=True) # gradient transform!\n",
" (loss, logits), grads = loss_gradient(model, data, labels)\n",
" loss_gradient = nnx.grad(loss_fun, has_aux=True) # gradient transform!\n",
" grads, logits = loss_gradient(model, data, labels)\n",
" optimizer.update(grads) # inplace update"
]
},
13 changes: 7 additions & 6 deletions docs/getting_started_with_jax_for_AI.md
@@ -23,7 +23,7 @@ kernelspec:

## Who is this tutorial for?

-This tutorial is for those who want to get started using JAX to build and train neural network models. It assumes some familiarity with numerical computing in Python with [NumPy](http://numpy.org), and assumes some conceptual familiarity with defining, training, and evaluating machine learning models.
+This tutorial is for those who want to get started using the JAX AI stack to build and train neural network models. It assumes some familiarity with numerical computing in Python with [NumPy](http://numpy.org), and assumes some conceptual familiarity with defining, training, and evaluating machine learning models.

+++ {"id": "1Y92oUSGeoRz"}

@@ -42,7 +42,7 @@ Once you've worked through this content, you may wish to visit http://jax.readth

## Example: a simple neural network with flax

-We'll start with a very quick example of what it looks like to use JAX with the [flax](https://flax.readthedocs.io) framework to define and train a very simple neural network on a hand-written digits dataset.
+We'll start with a very quick example of what it looks like to use JAX with the [flax](https://flax.readthedocs.io) framework to define and train a very simple neural network to recognize hand-written digits.

+++ {"id": "pOlnhK-EioSk"}

@@ -82,7 +82,8 @@ for i, ax in enumerate(axes.flat):

+++ {"id": "Z3l45KgtfUUo"}

-Let's split these into a training and testing set, and convert these splits into JAX arrays which will be ready to feed into our model:
+Let's split these into a training and testing set, and convert these splits into JAX arrays which will be ready to feed into our model.
+We'll make use of the `jax.numpy` module, which provides a familiar NumPy-style API around JAX operations:
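For readers unfamiliar with `jax.numpy`, its API mirrors NumPy closely, so the split-and-convert step amounts to wrapping each split in `jnp.asarray`. A minimal sketch of what this can look like (the use of scikit-learn's `load_digits` and `train_test_split` is an assumption for illustration, not taken from this diff):

```python
# Illustrative sketch only -- assumes the digits data comes from scikit-learn.
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
import jax.numpy as jnp

digits = load_digits()
splits = train_test_split(digits.images, digits.target, random_state=0)
# jnp.asarray turns each NumPy split into a JAX array; the familiar
# NumPy-style indexing and arithmetic then apply unchanged.
images_train, images_test, label_train, label_test = map(jnp.asarray, splits)
print(f"{images_train.shape=} {label_train.shape=}")
print(f"{images_test.shape=} {label_test.shape=}")
```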

```{code-cell}
:id: 6jrYisoPh6TL
@@ -105,7 +106,7 @@ print(f"{images_test.shape=} {label_test.shape=}")

### Defining the flax model

-We can now use the [Flax](http://flax.readthedocs.io) package to create a simple [Feedforward](https://en.wikipedia.org/wiki/Feedforward_neural_network) neural network with one hidden layer, and use a *scaled exponential linear unit* (SELU) activation function:
+We can now use the [Flax](http://flax.readthedocs.io) package to create a simple [Feedforward](https://en.wikipedia.org/wiki/Feedforward_neural_network) neural network with one hidden layer, and use a *scaled exponential linear unit* (SELU) activation function.
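As a rough sketch of what such a model can look like in Flax NNX, assuming `nnx.Linear` layers and `jax.nn.selu` (the class name and layer sizes below are illustrative, not the tutorial's actual definition):

```python
from flax import nnx
import jax

class SimpleNN(nnx.Module):
  """A feedforward network: one hidden layer with a SELU activation."""

  def __init__(self, n_features: int = 64, n_hidden: int = 100,
               n_targets: int = 10, *, rngs: nnx.Rngs):
    self.layer1 = nnx.Linear(n_features, n_hidden, rngs=rngs)
    self.layer2 = nnx.Linear(n_hidden, n_targets, rngs=rngs)

  def __call__(self, x: jax.Array) -> jax.Array:
    x = jax.nn.selu(self.layer1(x))  # scaled exponential linear unit
    return self.layer2(x)            # logits for the 10 digit classes

model = SimpleNN(rngs=nnx.Rngs(0))
```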

```{code-cell}
:id: U77VMQwRjTfH
@@ -165,8 +166,8 @@ def train_step(
optimizer: nnx.Optimizer,
data: jax.Array,
labels: jax.Array):
-loss_gradient = nnx.value_and_grad(loss_fun, has_aux=True) # gradient transform!
-(loss, logits), grads = loss_gradient(model, data, labels)
+loss_gradient = nnx.grad(loss_fun, has_aux=True) # gradient transform!
+grads, logits = loss_gradient(model, data, labels)
optimizer.update(grads) # inplace update
```
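The substantive change in this hunk swaps `nnx.value_and_grad` for `nnx.grad`, and the unpacking changes with it: with `has_aux=True`, `value_and_grad` returns `((loss, aux), grads)` while `grad` returns `(grads, aux)`, so the loss value itself is no longer captured in `train_step`. A minimal sketch of the two conventions using plain JAX (the `flax.nnx` wrappers are assumed to follow the same ordering):

```python
import jax
import jax.numpy as jnp

def loss_fun(w, x, y):
  pred = x @ w                      # toy linear model
  loss = jnp.mean((pred - y) ** 2)
  return loss, pred                 # primary output first, auxiliary second

w = jnp.ones(3)
x = jnp.ones((5, 3))
y = jnp.zeros(5)

# value_and_grad with has_aux=True: ((loss, aux), grads)
(loss, pred), grads = jax.value_and_grad(loss_fun, has_aux=True)(w, x, y)

# grad with has_aux=True: (grads, aux) -- the loss value is not returned
grads, pred = jax.grad(loss_fun, has_aux=True)(w, x, y)
```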

