Skip to content

Commit

Permalink
Fix to use pymc 4
Browse files Browse the repository at this point in the history
  • Loading branch information
zaxtax authored and rlouf committed Feb 15, 2022
1 parent ee28976 commit fd6e6f9
Showing 1 changed file with 33 additions and 121 deletions.
154 changes: 33 additions & 121 deletions examples/use_with_pymc3.ipynb → examples/use_with_pymc.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"id": "397995ab"
},
"source": [
"# Use BlackJAX with PyMC3\n",
"# Use BlackJAX with PyMC\n",
"Author: Kaustubh Chaudhari"
]
},
Expand All @@ -25,39 +25,12 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "b260c3fa",
"metadata": {
"id": "3a905211"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/bin/ld: /tmp/tmpewbjudzh/tmp/tmpewbjudzh/source.o: in function `main':\n",
"/tmp/tmpewbjudzh/source.c:6: undefined reference to `cblas_ddot'\n",
"collect2: error: ld returned 1 exit status\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running on PyMC3 v4.0.0b2\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/remi/.virtualenvs/blackjax/lib/python3.9/site-packages/aesara/link/jax/dispatch.py:86: UserWarning: JAX omnistaging couldn't be disabled: Disabling of omnistaging is no longer supported in JAX version 0.2.12 and higher: see https://github.com/google/jax/blob/main/design_notes/omnistaging.md.\n",
" warnings.warn(f\"JAX omnistaging couldn't be disabled: {e}\")\n",
"/home/remi/projects/pymc/pymc/sampling_jax.py:31: UserWarning: This module is experimental.\n",
" warnings.warn(\"This module is experimental.\")\n"
]
}
],
"outputs": [],
"source": [
"import jax\n",
"import numpy as np\n",
Expand All @@ -66,7 +39,7 @@
"\n",
"import blackjax\n",
"\n",
"print(f\"Running on PyMC3 v{pm.__version__}\")"
"print(f\"Running on PyMC v{pm.__version__}\")"
]
},
{
Expand All @@ -83,7 +56,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "0b6aaeb6",
"metadata": {
"id": "imotOe9sUNYF"
Expand All @@ -108,7 +81,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "e82b0be9",
"metadata": {
"id": "PiBv9iOvRK0f"
Expand Down Expand Up @@ -137,7 +110,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"id": "d69ddad1",
"metadata": {
"colab": {
Expand All @@ -147,63 +120,7 @@
"id": "0ZyMxwLFY_ZI",
"outputId": "793af037-31e4-4e55-9c76-231c9d78532d"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Auto-assigning NUTS sampler...\n",
"Initializing NUTS using jitter+adapt_diag...\n",
"Sequential sampling (1 chains in 1 job)\n",
"NUTS: [mu, tau, theta]\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
" </style>\n",
" <progress value='51000' class='' max='51000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" 100.00% [51000/51000 00:29<00:00 Sampling chain 0, 0 divergences]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sampling 1 chain for 1_000 tune and 50_000 draw iterations (1_000 + 50_000 draws total) took 30 seconds.\n",
"Only one chain was sampled, this makes it impossible to run some convergence checks\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 30.8 s, sys: 285 ms, total: 31.1 s\n",
"Wall time: 32.9 s\n"
]
}
],
"outputs": [],
"source": [
"%%time\n",
"\n",
Expand All @@ -223,7 +140,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"id": "de0ad319",
"metadata": {
"colab": {
Expand All @@ -232,22 +149,7 @@
"id": "daQ5OO6aZS9t",
"outputId": "d865c9dc-45ae-4baa-c643-f145492ea4ab"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Compiling...\n",
"Compilation time = 0 days 00:00:00.102013\n",
"Sampling...\n",
"Sampling time = 0 days 00:00:03.760742\n",
"Transforming variables...\n",
"Transformation time = 0 days 00:00:00.020388\n",
"CPU times: user 3.98 s, sys: 43.6 ms, total: 4.03 s\n",
"Wall time: 3.99 s\n"
]
}
],
"outputs": [],
"source": [
"%%time\n",
"\n",
Expand All @@ -273,16 +175,27 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"id": "a1de0d56",
"metadata": {},
"outputs": [],
"source": [
"from pymc.sampling_jax import get_jaxified_graph\n",
"\n",
"rvs = [rv.name for rv in model.value_vars]\n",
"init_position_dict = model.compute_initial_point()\n",
"init_position = [init_position_dict[rv] for rv in rvs]\n",
"\n",
"logprob_fn = pm.sampling_jax.get_jaxified_logp(model)"
"def get_jaxified_logp(model):\n",
"\n",
" logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model.logpt()])\n",
"\n",
" def logp_fn_wrap(x):\n",
" return logp_fn(*x)[0]\n",
"\n",
" return logp_fn_wrap\n",
"\n",
"logprob_fn = get_jaxified_logp(model)"
]
},
{
Expand All @@ -295,21 +208,12 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"id": "e0dcad4d",
"metadata": {
"id": "cTlcZCYmidZ6"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 5.21 s, sys: 10.9 ms, total: 5.22 s\n",
"Wall time: 5.17 s\n"
]
}
],
"outputs": [],
"source": [
"%%time\n",
"\n",
Expand All @@ -333,6 +237,14 @@
"# Sample from the posterior distribution\n",
"states, infos = inference_loop(seed, kernel, last_state, 50_000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aa558fa7-d323-4b4e-813c-a4e8ab8f519a",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down

0 comments on commit fd6e6f9

Please sign in to comment.