Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
pre-commit on example
Browse files Browse the repository at this point in the history
thibmonsel committed Aug 16, 2023
1 parent 9f0aaa1 commit 2b24da8
Showing 1 changed file with 18 additions and 11 deletions.
29 changes: 18 additions & 11 deletions examples/neural_dde.ipynb
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Neural DDE"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -28,6 +30,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -53,6 +56,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -69,12 +73,12 @@
"outputs": [],
"source": [
"delays = diffrax.Delays(\n",
" delays=[lambda t, y, args: 0.2],\n",
" initial_discontinuities=jnp.array([0.0])\n",
" )"
" delays=[lambda t, y, args: 0.2], initial_discontinuities=jnp.array([0.0])\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -106,6 +110,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -114,6 +119,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -145,10 +151,11 @@
" stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),\n",
" saveat=diffrax.SaveAt(ts=ts, dense=True),\n",
" )\n",
" return solution.ys\n"
" return solution.ys"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -164,7 +171,6 @@
"def _get_data(ts, *, key):\n",
" y0 = jrandom.uniform(key, (2,), minval=0.1, maxval=2.0)\n",
"\n",
" \n",
" def vector_field(t, y, args, history):\n",
" return jnp.array(\n",
" [\n",
@@ -179,13 +185,13 @@
" t0=ts[0],\n",
" t1=ts[-1],\n",
" dt0=ts[1] - ts[0],\n",
" y0=lambda t : y0,\n",
" y0=lambda t: y0,\n",
" adjoint=diffrax.NoAdjoint(),\n",
" stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),\n",
" saveat=diffrax.SaveAt(ts=ts, dense=True),\n",
" delays=delays,\n",
" )\n",
" \n",
"\n",
" return sol.ys\n",
"\n",
"\n",
@@ -213,6 +219,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -230,8 +237,8 @@
" batch_size=32,\n",
" width_size=32,\n",
" depth=2,\n",
" tot_steps = 500,\n",
" lr = 10e-3,\n",
" tot_steps=500,\n",
" lr=10e-3,\n",
" seed=5678,\n",
" plot=True,\n",
" print_every=100,\n",
@@ -262,7 +269,7 @@
" range(tot_steps), dataloader((ys,), batch_size, key=loader_key)\n",
" ):\n",
" start = time.time()\n",
" loss, model, opt_state = make_step(_ts, yi, model, opt_state)\n",
" loss, model, opt_state = make_step(ts, yi, model, opt_state)\n",
" end = time.time()\n",
" if (step % print_every) == 0 or step == tot_steps - 1:\n",
" print(f\"Step: {step}, Loss: {loss}, Computation time: {end - start}\")\n",
@@ -278,7 +285,7 @@
" plt.savefig(\"neural_ode.png\")\n",
" plt.show()\n",
"\n",
" return ts, ys, model\n"
" return ts, ys, model"
]
}
],

0 comments on commit 2b24da8

Please sign in to comment.