Skip to content

Commit

Permalink
FIX: switch back to jacfwd
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer committed Aug 11, 2023
1 parent b5aed2d commit 922a85f
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@
"ipywidgets",
"isinstance",
"isort",
"jacrev",
"jacfwd",
"jaxlib",
"joinpath",
"juliaup",
Expand Down
6 changes: 3 additions & 3 deletions docs/report/draft.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@
"outputs": [],
"source": [
"func_with_data_inserted = Partial(intensity_func.function, *data_columns.values())\n",
"gradient_func = jax.jacrev(\n",
"gradient_func = jax.jacfwd(\n",
" func_with_data_inserted,\n",
" argnums=range(len(parameter_values)),\n",
")\n",
Expand Down Expand Up @@ -578,7 +578,7 @@
"import jax.numpy as jnp\n",
"\n",
"\n",
"# @jax.jit # Do not JIT here, otherwise jax.jacrev crashes!\n",
"# @jax.jit # Do not JIT here, otherwise jax.jacfwd crashes!\n",
"def estimator(args):\n",
" data_intensities = func_with_data_inserted(*args)\n",
" phsp_intensities = func_with_phsp_inserted(*args)\n",
Expand Down Expand Up @@ -640,7 +640,7 @@
},
"outputs": [],
"source": [
"estimator_gradient = jax.jacrev(estimator)"
"estimator_gradient = jax.jacfwd(estimator)"
]
},
{
Expand Down

0 comments on commit 922a85f

Please sign in to comment.