From fd6e6f95d1d49bce8352d89e6a9719cbb5d80ce7 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Mon, 14 Feb 2022 18:54:32 +0000 Subject: [PATCH] Fix to use pymc 4 --- ...e_with_pymc3.ipynb => use_with_pymc.ipynb} | 154 ++++-------------- 1 file changed, 33 insertions(+), 121 deletions(-) rename examples/{use_with_pymc3.ipynb => use_with_pymc.ipynb} (59%) diff --git a/examples/use_with_pymc3.ipynb b/examples/use_with_pymc.ipynb similarity index 59% rename from examples/use_with_pymc3.ipynb rename to examples/use_with_pymc.ipynb index 1d171a45b..96e7d7ffe 100644 --- a/examples/use_with_pymc3.ipynb +++ b/examples/use_with_pymc.ipynb @@ -7,7 +7,7 @@ "id": "397995ab" }, "source": [ - "# Use BlackJAX with PyMC3\n", + "# Use BlackJAX with PyMC\n", "Author: Kaustubh Chaudhari" ] }, @@ -25,39 +25,12 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "b260c3fa", "metadata": { "id": "3a905211" }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/bin/ld: /tmp/tmpewbjudzh/tmp/tmpewbjudzh/source.o: in function `main':\n", - "/tmp/tmpewbjudzh/source.c:6: undefined reference to `cblas_ddot'\n", - "collect2: error: ld returned 1 exit status\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Running on PyMC3 v4.0.0b2\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/remi/.virtualenvs/blackjax/lib/python3.9/site-packages/aesara/link/jax/dispatch.py:86: UserWarning: JAX omnistaging couldn't be disabled: Disabling of omnistaging is no longer supported in JAX version 0.2.12 and higher: see https://github.com/google/jax/blob/main/design_notes/omnistaging.md.\n", - " warnings.warn(f\"JAX omnistaging couldn't be disabled: {e}\")\n", - "/home/remi/projects/pymc/pymc/sampling_jax.py:31: UserWarning: This module is experimental.\n", - " warnings.warn(\"This module is experimental.\")\n" - ] - } - ], + "outputs": [], "source": [ "import jax\n", "import numpy as np\n", @@ -66,7 +39,7 @@ "\n", "import blackjax\n", "\n", - "print(f\"Running on PyMC3 v{pm.__version__}\")" + "print(f\"Running on PyMC v{pm.__version__}\")" ] }, { @@ -83,7 +56,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "0b6aaeb6", "metadata": { "id": "imotOe9sUNYF" @@ -108,7 +81,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "e82b0be9", "metadata": { "id": "PiBv9iOvRK0f" @@ -137,7 +110,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "d69ddad1", "metadata": { "colab": { @@ -147,63 +120,7 @@ "id": "0ZyMxwLFY_ZI", "outputId": "793af037-31e4-4e55-9c76-231c9d78532d" }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Auto-assigning NUTS sampler...\n", - "Initializing NUTS using jitter+adapt_diag...\n", - "Sequential sampling (1 chains in 1 job)\n", - "NUTS: [mu, tau, theta]\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " 100.00% [51000/51000 00:29<00:00 Sampling chain 0, 0 divergences]\n", - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Sampling 1 chain for 1_000 tune and 50_000 draw iterations (1_000 + 50_000 draws total) took 30 seconds.\n", - "Only one chain was sampled, this makes it impossible to run some convergence checks\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 30.8 s, sys: 285 ms, total: 31.1 s\n", - "Wall time: 32.9 s\n" - ] - } - ], + "outputs": [], "source": [ "%%time\n", "\n", @@ -223,7 +140,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "de0ad319", "metadata": { "colab": { @@ -232,22 +149,7 @@ "id": "daQ5OO6aZS9t", "outputId": "d865c9dc-45ae-4baa-c643-f145492ea4ab" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Compiling...\n", - "Compilation time = 0 days 00:00:00.102013\n", - "Sampling...\n", - "Sampling time = 0 days 00:00:03.760742\n", - "Transforming variables...\n", - "Transformation time = 0 days 00:00:00.020388\n", - "CPU times: user 3.98 s, sys: 43.6 ms, total: 4.03 s\n", - "Wall time: 3.99 s\n" - ] - } - ], + "outputs": [], "source": [ "%%time\n", "\n", @@ -273,16 +175,27 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "a1de0d56", "metadata": {}, "outputs": [], "source": [ + "from pymc.sampling_jax import get_jaxified_graph\n", + "\n", "rvs = [rv.name for rv in model.value_vars]\n", "init_position_dict = model.compute_initial_point()\n", "init_position = [init_position_dict[rv] for rv in rvs]\n", "\n", - "logprob_fn = pm.sampling_jax.get_jaxified_logp(model)" + "def get_jaxified_logp(model):\n", + "\n", + " logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model.logpt()])\n", + "\n", + " def logp_fn_wrap(x):\n", + " return logp_fn(*x)[0]\n", + "\n", + " return logp_fn_wrap\n", + "\n", + "logprob_fn = get_jaxified_logp(model)" ] }, { @@ -295,21 +208,12 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "e0dcad4d", "metadata": { "id": "cTlcZCYmidZ6" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 5.21 s, sys: 10.9 ms, total: 5.22 s\n", - "Wall time: 5.17 s\n" - ] - } - ], + "outputs": [], "source": [ "%%time\n", "\n", @@ -333,6 +237,14 @@ "# Sample from the posterior distribution\n", "states, infos = inference_loop(seed, kernel, last_state, 50_000)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aa558fa7-d323-4b4e-813c-a4e8ab8f519a", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": {