From 922a85f85064f56147d007f88540f508beeca25e Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 11 Aug 2023 14:33:30 +0200 Subject: [PATCH] FIX: switch back to `jacfwd` --- .cspell.json | 2 +- docs/report/draft.ipynb | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.cspell.json b/.cspell.json index 38391e16..b859ff23 100644 --- a/.cspell.json +++ b/.cspell.json @@ -213,7 +213,7 @@ "ipywidgets", "isinstance", "isort", - "jacrev", + "jacfwd", "jaxlib", "joinpath", "juliaup", diff --git a/docs/report/draft.ipynb b/docs/report/draft.ipynb index 835e73c2..7742fc2e 100644 --- a/docs/report/draft.ipynb +++ b/docs/report/draft.ipynb @@ -492,7 +492,7 @@ "outputs": [], "source": [ "func_with_data_inserted = Partial(intensity_func.function, *data_columns.values())\n", - "gradient_func = jax.jacrev(\n", + "gradient_func = jax.jacfwd(\n", " func_with_data_inserted,\n", " argnums=range(len(parameter_values)),\n", ")\n", @@ -578,7 +578,7 @@ "import jax.numpy as jnp\n", "\n", "\n", - "# @jax.jit # Do not JIT here, otherwise jax.jacrev crashes!\n", + "# @jax.jit # Do not JIT here, otherwise jax.jacfwd crashes!\n", "def estimator(args):\n", " data_intensities = func_with_data_inserted(*args)\n", " phsp_intensities = func_with_phsp_inserted(*args)\n", @@ -640,7 +640,7 @@ }, "outputs": [], "source": [ - "estimator_gradient = jax.jacrev(estimator)" + "estimator_gradient = jax.jacfwd(estimator)" ] }, {