From bf83fa6c4c70f79e8562c637db9416fd9209fcd4 Mon Sep 17 00:00:00 2001 From: Vitor Shen <17490173+shenvitor@users.noreply.github.com> Date: Tue, 6 Aug 2024 15:29:13 +0200 Subject: [PATCH 1/7] first commit --- docs/ampform/LambdaKpi0.ipynb | 118 +++++++++++++++++----------------- 1 file changed, 59 insertions(+), 59 deletions(-) diff --git a/docs/ampform/LambdaKpi0.ipynb b/docs/ampform/LambdaKpi0.ipynb index a1134f5..4b705a0 100644 --- a/docs/ampform/LambdaKpi0.ipynb +++ b/docs/ampform/LambdaKpi0.ipynb @@ -426,65 +426,6 @@ "phsp = helicity_transformer(phsp_momenta)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "jupyter": { - "source_hidden": true - }, - "tags": [ - "hide-input" - ] - }, - "outputs": [], - "source": [ - "%config InlineBackend.figure_formats = ['png']\n", - "bin_values, xedges, yedges = jnp.histogram2d(\n", - " phsp[\"m_01\"].real ** 2,\n", - " phsp[\"m_12\"].real ** 2,\n", - " bins=200,\n", - " weights=intensity_func(phsp),\n", - " density=True,\n", - ")\n", - "bin_values = jnp.where(bin_values < 1e-6, jnp.nan, bin_values)\n", - "X, Y = jnp.meshgrid(xedges[:-1], yedges[:-1])\n", - "\n", - "fig, ax = plt.subplots(dpi=150)\n", - "mesh = ax.pcolormesh(X, Y, bin_values.T, cmap=\"jet\", vmax=0.15)\n", - "ax.set_title(\"Model-weighted Phase space Dalitz Plot\")\n", - "ax.set_xlabel(R\"$m^2(\\Lambda K^+)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n", - "ax.set_ylabel(R\"$m^2(K^+ \\pi^0)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n", - "fig.colorbar(mesh, ax=ax)\n", - "fig.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "jupyter": { - "source_hidden": true - }, - "tags": [ - "hide-output", - "hide-input" - ] - }, - "outputs": [], - "source": [ - "resonances = defaultdict(set)\n", - "for transition in reaction.transitions:\n", - " topology = transition.topology\n", - " top_decay_products = topology.get_edge_ids_outgoing_from_node(0)\n", - " (resonance_id, resonance), *_ = transition.intermediate_states.items()\n", - " recoil_id, *_ = top_decay_products - {resonance_id}\n", - " resonances[recoil_id].add(resonance.particle)\n", - "resonances = {k: sorted(v, key=lambda p: p.mass) for k, v in resonances.items()}\n", - "{k: [p.name for p in v] for k, v in resonances.items()}" - ] - }, { "cell_type": "code", "execution_count": null, @@ -548,6 +489,65 @@ "UI = w.Tab(children=tabs, titles=tuple(categorized_sliders_m))" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "%config InlineBackend.figure_formats = ['png']\n", + "bin_values, xedges, yedges = jnp.histogram2d(\n", + " phsp[\"m_01\"].real ** 2,\n", + " phsp[\"m_12\"].real ** 2,\n", + " bins=200,\n", + " weights=intensity_func(phsp),\n", + " density=True,\n", + ")\n", + "bin_values = jnp.where(bin_values < 1e-6, jnp.nan, bin_values)\n", + "X, Y = jnp.meshgrid(xedges[:-1], yedges[:-1])\n", + "\n", + "fig, ax = plt.subplots(dpi=150)\n", + "mesh = ax.pcolormesh(X, Y, bin_values.T, cmap=\"jet\", vmax=0.15)\n", + "ax.set_title(\"Model-weighted Phase space Dalitz Plot\")\n", + "ax.set_xlabel(R\"$m^2(\\Lambda K^+)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n", + "ax.set_ylabel(R\"$m^2(K^+ \\pi^0)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n", + "fig.colorbar(mesh, ax=ax)\n", + "fig.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-output", + "hide-input" + ] + }, + "outputs": [], + "source": [ + "resonances = defaultdict(set)\n", + "for transition in reaction.transitions:\n", + " topology = transition.topology\n", + " top_decay_products = topology.get_edge_ids_outgoing_from_node(0)\n", + " (resonance_id, resonance), *_ = transition.intermediate_states.items()\n", + " recoil_id, *_ = top_decay_products - {resonance_id}\n", + " resonances[recoil_id].add(resonance.particle)\n", + "resonances = {k: sorted(v, key=lambda p: p.mass) for k, v in resonances.items()}\n", + "{k: [p.name for p in v] for k, v in resonances.items()}" + ] + }, { "cell_type": "code", "execution_count": null, From df73885bcd41cd36264f06e976b7098c0a8aeddf Mon Sep 17 00:00:00 2001 From: Vitor Shen <17490173+shenvitor@users.noreply.github.com> Date: Wed, 7 Aug 2024 17:02:11 +0200 Subject: [PATCH 2/7] Add sliders widget for Dalitz plot --- docs/ampform/LambdaKpi0.ipynb | 58 +++++++++++++++++++++++++++++++++-- 1 file changed, 55 insertions(+), 3 deletions(-) diff --git a/docs/ampform/LambdaKpi0.ipynb b/docs/ampform/LambdaKpi0.ipynb index 4b705a0..e02a0e3 100644 --- a/docs/ampform/LambdaKpi0.ipynb +++ b/docs/ampform/LambdaKpi0.ipynb @@ -77,7 +77,9 @@ }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, "source": [ "### Particle definitions" ] @@ -160,7 +162,9 @@ }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, "source": [ "### Initial state definition" ] @@ -282,7 +286,9 @@ }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, "source": [ "## Formulate amplitude model" ] @@ -489,6 +495,52 @@ "UI = w.Tab(children=tabs, titles=tuple(categorized_sliders_m))" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "%matplotlib widget\n", + "%config InlineBackend.figure_formats = ['png']\n", + "\n", + "\n", + "def update_histogram(**parameters):\n", + " for symbol, value in parameters.items():\n", + " model.parameter_defaults[symbol] = value\n", + "\n", + " intensity_weights = jnp.array(intensity_func(phsp))\n", + "\n", + " bin_values, xedges, yedges = jnp.histogram2d(\n", + " phsp[\"m_01\"].real ** 2,\n", + " phsp[\"m_12\"].real ** 2,\n", + " bins=200,\n", + " weights=intensity_weights,\n", + " density=True,\n", + " )\n", + " bin_values = jnp.where(bin_values < 1e-6, jnp.nan, bin_values)\n", + " x, y = jnp.meshgrid(xedges[:-1], yedges[:-1])\n", + " fig, ax = plt.subplots(dpi=150)\n", + " mesh = ax.pcolormesh(x, y, bin_values.T, cmap=\"jet\", vmax=0.15)\n", + " ax.set_title(\"Model-weighted Phase space Dalitz Plot\")\n", + " ax.set_xlabel(R\"$m^2(\\Lambda K^+)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n", + " ax.set_ylabel(R\"$m^2(K^+ \\pi^0)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n", + " fig.colorbar(mesh, ax=ax)\n", + " fig.tight_layout()\n", + " plt.show()\n", + "\n", + "\n", + "interactive_plot = w.interactive_output(update_histogram, sliders)\n", + "display(UI, interactive_plot)" + ] + }, { "cell_type": "code", "execution_count": null, From c7a514e02c473d551f007deff46bf9443607ea59 Mon Sep 17 00:00:00 2001 From: Vitor Shen <17490173+shenvitor@users.noreply.github.com> Date: Wed, 7 Aug 2024 17:04:07 +0200 Subject: [PATCH 3/7] remove the original static default Dalitz plot --- docs/ampform/LambdaKpi0.ipynb | 37 ----------------------------------- 1 file changed, 37 deletions(-) diff --git a/docs/ampform/LambdaKpi0.ipynb b/docs/ampform/LambdaKpi0.ipynb index e02a0e3..01a01d1 100644 --- a/docs/ampform/LambdaKpi0.ipynb +++ b/docs/ampform/LambdaKpi0.ipynb @@ -545,43 +545,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "jupyter": { - "source_hidden": true - }, - "tags": [ - "hide-input" - ] - }, - "outputs": [], - "source": [ - "%config InlineBackend.figure_formats = ['png']\n", - "bin_values, xedges, yedges = jnp.histogram2d(\n", - " phsp[\"m_01\"].real ** 2,\n", - " phsp[\"m_12\"].real ** 2,\n", - " bins=200,\n", - " weights=intensity_func(phsp),\n", - " density=True,\n", - ")\n", - "bin_values = jnp.where(bin_values < 1e-6, jnp.nan, bin_values)\n", - "X, Y = jnp.meshgrid(xedges[:-1], yedges[:-1])\n", - "\n", - "fig, ax = plt.subplots(dpi=150)\n", - "mesh = ax.pcolormesh(X, Y, bin_values.T, cmap=\"jet\", vmax=0.15)\n", - "ax.set_title(\"Model-weighted Phase space Dalitz Plot\")\n", - "ax.set_xlabel(R\"$m^2(\\Lambda K^+)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n", - "ax.set_ylabel(R\"$m^2(K^+ \\pi^0)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n", - "fig.colorbar(mesh, ax=ax)\n", - "fig.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "jupyter": { - "source_hidden": true - }, "tags": [ "hide-output", "hide-input" From ee016e5fa8215a1ed904b7e8f8de14c2f4be603e Mon Sep 17 00:00:00 2001 From: Vitor Shen <17490173+shenvitor@users.noreply.github.com> Date: Thu, 8 Aug 2024 13:00:58 +0200 Subject: [PATCH 4/7] update 2d hist update to update in the same canvas --- docs/ampform/LambdaKpi0.ipynb | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/docs/ampform/LambdaKpi0.ipynb b/docs/ampform/LambdaKpi0.ipynb index 01a01d1..e16df6e 100644 --- a/docs/ampform/LambdaKpi0.ipynb +++ b/docs/ampform/LambdaKpi0.ipynb @@ -499,9 +499,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "jupyter": { - "source_hidden": true - }, "tags": [ "hide-input" ] @@ -510,9 +507,16 @@ "source": [ "%matplotlib widget\n", "%config InlineBackend.figure_formats = ['png']\n", + "fig_2d, ax_2d = plt.subplots(dpi=200)\n", + "ax_2d.set_title(\"Model-weighted Phase space Dalitz Plot\")\n", + "ax_2d.set_xlabel(R\"$m^2(\\Lambda K^+)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n", + "ax_2d.set_ylabel(R\"$m^2(K^+ \\pi^0)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n", + "\n", + "mesh = None\n", "\n", "\n", "def update_histogram(**parameters):\n", + " global mesh\n", " for symbol, value in parameters.items():\n", " model.parameter_defaults[symbol] = value\n", "\n", @@ -527,16 +531,17 @@ " )\n", " bin_values = jnp.where(bin_values < 1e-6, jnp.nan, bin_values)\n", " x, y = jnp.meshgrid(xedges[:-1], yedges[:-1])\n", - " fig, ax = plt.subplots(dpi=150)\n", - " mesh = ax.pcolormesh(x, y, bin_values.T, cmap=\"jet\", vmax=0.15)\n", - " ax.set_title(\"Model-weighted Phase space Dalitz Plot\")\n", - " ax.set_xlabel(R\"$m^2(\\Lambda K^+)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n", - " ax.set_ylabel(R\"$m^2(K^+ \\pi^0)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n", - " fig.colorbar(mesh, ax=ax)\n", - " fig.tight_layout()\n", - " plt.show()\n", + "\n", + " if mesh is None:\n", + " mesh = ax_2d.pcolormesh(x, y, bin_values.T, cmap=\"jet\", vmax=0.15)\n", + " else:\n", + " mesh.set_array(bin_values.T)\n", + "\n", + " fig_2d.canvas.draw_idle()\n", "\n", "\n", + "fig_2d.colorbar(mesh, ax=ax_2d)\n", + "fig_2d.tight_layout()\n", "interactive_plot = w.interactive_output(update_histogram, sliders)\n", "display(UI, interactive_plot)" ] @@ -581,7 +586,7 @@ "%matplotlib widget\n", "%config InlineBackend.figure_formats = ['svg']\n", "\n", - "fig, axes = plt.subplots(figsize=(11, 3.5), ncols=3, sharey=True)\n", + "fig, axes = plt.subplots(figsize=(12, 4), ncols=3, sharey=True)\n", "fig.canvas.toolbar_visible = False\n", "fig.canvas.header_visible = False\n", "fig.canvas.footer_visible = False\n", From afb850e98054ef01ef3b12ca522f9dfeecf7ff6e Mon Sep 17 00:00:00 2001 From: Vitor Shen <17490173+shenvitor@users.noreply.github.com> Date: Thu, 8 Aug 2024 13:24:16 +0200 Subject: [PATCH 5/7] to make color bar update accordingly move resonances output above --- docs/ampform/LambdaKpi0.ipynb | 51 ++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 24 deletions(-) diff --git a/docs/ampform/LambdaKpi0.ipynb b/docs/ampform/LambdaKpi0.ipynb index e16df6e..e7f70ad 100644 --- a/docs/ampform/LambdaKpi0.ipynb +++ b/docs/ampform/LambdaKpi0.ipynb @@ -432,6 +432,28 @@ "phsp = helicity_transformer(phsp_momenta)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-output", + "hide-input" + ] + }, + "outputs": [], + "source": [ + "resonances = defaultdict(set)\n", + "for transition in reaction.transitions:\n", + " topology = transition.topology\n", + " top_decay_products = topology.get_edge_ids_outgoing_from_node(0)\n", + " (resonance_id, resonance), *_ = transition.intermediate_states.items()\n", + " recoil_id, *_ = top_decay_products - {resonance_id}\n", + " resonances[recoil_id].add(resonance.particle)\n", + "resonances = {k: sorted(v, key=lambda p: p.mass) for k, v in resonances.items()}\n", + "{k: [p.name for p in v] for k, v in resonances.items()}" + ] + }, { "cell_type": "code", "execution_count": null, @@ -499,6 +521,9 @@ "cell_type": "code", "execution_count": null, "metadata": { + "jupyter": { + "source_hidden": true + }, "tags": [ "hide-input" ] @@ -534,40 +559,18 @@ "\n", " if mesh is None:\n", " mesh = ax_2d.pcolormesh(x, y, bin_values.T, cmap=\"jet\", vmax=0.15)\n", + " fig_2d.colorbar(mesh, ax=ax_2d)\n", " else:\n", " mesh.set_array(bin_values.T)\n", - "\n", + " mesh.set_clim(vmin=jnp.nanmin(bin_values), vmax=0.15)\n", " fig_2d.canvas.draw_idle()\n", "\n", "\n", - "fig_2d.colorbar(mesh, ax=ax_2d)\n", "fig_2d.tight_layout()\n", "interactive_plot = w.interactive_output(update_histogram, sliders)\n", "display(UI, interactive_plot)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [ - "hide-output", - "hide-input" - ] - }, - "outputs": [], - "source": [ - "resonances = defaultdict(set)\n", - "for transition in reaction.transitions:\n", - " topology = transition.topology\n", - " top_decay_products = topology.get_edge_ids_outgoing_from_node(0)\n", - " (resonance_id, resonance), *_ = transition.intermediate_states.items()\n", - " recoil_id, *_ = top_decay_products - {resonance_id}\n", - " resonances[recoil_id].add(resonance.particle)\n", - "resonances = {k: sorted(v, key=lambda p: p.mass) for k, v in resonances.items()}\n", - "{k: [p.name for p in v] for k, v in resonances.items()}" - ] - }, { "cell_type": "code", "execution_count": null, From afd13210b3f1502b90fdf849995024925845687a Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Thu, 8 Aug 2024 14:12:20 +0200 Subject: [PATCH 6/7] FIX: create colorbar after plot --- docs/ampform/LambdaKpi0.ipynb | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/docs/ampform/LambdaKpi0.ipynb b/docs/ampform/LambdaKpi0.ipynb index e7f70ad..7bd301e 100644 --- a/docs/ampform/LambdaKpi0.ipynb +++ b/docs/ampform/LambdaKpi0.ipynb @@ -542,11 +542,8 @@ "\n", "def update_histogram(**parameters):\n", " global mesh\n", - " for symbol, value in parameters.items():\n", - " model.parameter_defaults[symbol] = value\n", - "\n", - " intensity_weights = jnp.array(intensity_func(phsp))\n", - "\n", + " intensity_func.update_parameters(parameters)\n", + " intensity_weights = intensity_func(phsp)\n", " bin_values, xedges, yedges = jnp.histogram2d(\n", " phsp[\"m_01\"].real ** 2,\n", " phsp[\"m_12\"].real ** 2,\n", @@ -556,18 +553,16 @@ " )\n", " bin_values = jnp.where(bin_values < 1e-6, jnp.nan, bin_values)\n", " x, y = jnp.meshgrid(xedges[:-1], yedges[:-1])\n", - "\n", " if mesh is None:\n", " mesh = ax_2d.pcolormesh(x, y, bin_values.T, cmap=\"jet\", vmax=0.15)\n", - " fig_2d.colorbar(mesh, ax=ax_2d)\n", " else:\n", " mesh.set_array(bin_values.T)\n", - " mesh.set_clim(vmin=jnp.nanmin(bin_values), vmax=0.15)\n", " fig_2d.canvas.draw_idle()\n", "\n", "\n", - "fig_2d.tight_layout()\n", "interactive_plot = w.interactive_output(update_histogram, sliders)\n", + "fig_2d.tight_layout()\n", + "fig_2d.colorbar(mesh, ax=ax_2d)\n", "display(UI, interactive_plot)" ] }, From 6fef0df88606f0ccd34897b79054cea0798e5d40 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Thu, 8 Aug 2024 14:12:55 +0200 Subject: [PATCH 7/7] MAINT: reduce diff --- docs/ampform/LambdaKpi0.ipynb | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/docs/ampform/LambdaKpi0.ipynb b/docs/ampform/LambdaKpi0.ipynb index 7bd301e..96cdb08 100644 --- a/docs/ampform/LambdaKpi0.ipynb +++ b/docs/ampform/LambdaKpi0.ipynb @@ -77,9 +77,7 @@ }, { "cell_type": "markdown", - "metadata": { - "jp-MarkdownHeadingCollapsed": true - }, + "metadata": {}, "source": [ "### Particle definitions" ] @@ -162,9 +160,7 @@ }, { "cell_type": "markdown", - "metadata": { - "jp-MarkdownHeadingCollapsed": true - }, + "metadata": {}, "source": [ "### Initial state definition" ] @@ -286,9 +282,7 @@ }, { "cell_type": "markdown", - "metadata": { - "jp-MarkdownHeadingCollapsed": true - }, + "metadata": {}, "source": [ "## Formulate amplitude model" ] @@ -584,7 +578,7 @@ "%matplotlib widget\n", "%config InlineBackend.figure_formats = ['svg']\n", "\n", - "fig, axes = plt.subplots(figsize=(12, 4), ncols=3, sharey=True)\n", + "fig, axes = plt.subplots(figsize=(11, 3.5), ncols=3, sharey=True)\n", "fig.canvas.toolbar_visible = False\n", "fig.canvas.header_visible = False\n", "fig.canvas.footer_visible = False\n",