Skip to content

Commit

Permalink
add test and lint
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed Dec 24, 2024
1 parent 62961a8 commit c0050b4
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 16 deletions.
4 changes: 3 additions & 1 deletion src/arviz_plots/backend/bokeh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np
from bokeh.layouts import GridBox, gridplot
from bokeh.models import GridPlot, Range1d, Title, Span
from bokeh.models import GridPlot, Range1d, Span, Title
from bokeh.plotting import figure
from bokeh.plotting import show as _show

Expand Down Expand Up @@ -364,13 +364,15 @@ def fill_between_y(x, y_bottom, y_top, target, **artist_kws):
y_top = y_top.item()
return target.varea(x=x, y1=y_bottom, y2=y_top, **artist_kws)


def vline(x, target, *, color=unset, alpha=unset, width=unset, linestyle=unset, **artist_kws):
"""Interface to bokeh for a vertical line spanning the whole axes."""
kwargs = {"line_color": color, "line_alpha": alpha, "line_width": width, "line_dash": linestyle}
span_element = Span(location=x, dimension="height", **_filter_kwargs(kwargs, artist_kws))
target.add_layout(span_element)
return span_element


def hline(y, target, *, color=unset, alpha=unset, width=unset, linestyle=unset, **artist_kws):
"""Interface to bokeh for a horizontal line spanning the whole axes."""
kwargs = {"line_color": color, "line_alpha": alpha, "line_width": width, "line_dash": linestyle}
Expand Down
2 changes: 2 additions & 0 deletions src/arviz_plots/backend/matplotlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,12 +304,14 @@ def fill_between_y(x, y_bottom, y_top, target, **artist_kws):
artist_kws.setdefault("linewidth", 0)
return target.fill_between(x, y_bottom, y_top, **artist_kws)


def vline(x, target, *, color=unset, alpha=unset, width=unset, linestyle=unset, **artist_kws):
"""Interface to matplotlib for a vertical line spanning the whole axes."""
artist_kws.setdefault("zorder", 2)
kwargs = {"color": color, "alpha": alpha, "linewidth": width, "linestyle": linestyle}
return target.axvline(x, **_filter_kwargs(kwargs, Line2D, artist_kws))


def hline(y, target, *, color=unset, alpha=unset, width=unset, linestyle=unset, **artist_kws):
"""Interface to matplotlib for a horizontal line spanning the whole axes."""
artist_kws.setdefault("zorder", 2)
Expand Down
3 changes: 3 additions & 0 deletions src/arviz_plots/backend/none/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ def fill_between_y(x, y_bottom, y_top, target, *, color=unset, alpha=unset, **ar
target.append(artist_element)
return artist_element


def vline(x, target, *, color=unset, alpha=unset, width=unset, linestyle=unset, **artist_kws):
"""Interface to a vertical line spanning the whole axes."""
kwargs = {"color": color, "alpha": alpha, "width": width, "linestyle": linestyle}
Expand All @@ -341,6 +342,7 @@ def vline(x, target, *, color=unset, alpha=unset, width=unset, linestyle=unset,
target.append(artist_element)
return artist_element


def hline(y, target, *, color=unset, alpha=unset, width=unset, linestyle=unset, **artist_kws):
"""Interface to a horizontal line spanning the whole axes."""
kwargs = {"color": color, "alpha": alpha, "width": width, "linestyle": linestyle}
Expand All @@ -354,6 +356,7 @@ def hline(y, target, *, color=unset, alpha=unset, width=unset, linestyle=unset,
target.append(artist_element)
return artist_element


# general plot appeareance
def title(string, target, *, size=unset, color=unset, **artist_kws):
"""Interface to adding a title to a plot."""
Expand Down
11 changes: 9 additions & 2 deletions src/arviz_plots/backend/plotly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,21 +438,28 @@ def fill_between_y(x, y_bottom, y_top, target, *, color=unset, alpha=unset, **ar
target.add_trace(second_line_with_fill)
return second_line_with_fill


def vline(x, target, *, color=unset, alpha=unset, width=unset, linestyle=unset, **artist_kws):
"""Interface to plotly for a vertical line spanning the whole axes."""
artist_kws.setdefault("showlegend", False)
line_kwargs = {"color": color, "width": width, "dash": linestyle}
line_artist_kws = artist_kws.pop("line", {}).copy()
kwargs = {"opacity": alpha}
return target.add_vline(x, line=_filter_kwargs(line_kwargs, line_artist_kws), **_filter_kwargs(kwargs, artist_kws))
return target.add_vline(
x, line=_filter_kwargs(line_kwargs, line_artist_kws), **_filter_kwargs(kwargs, artist_kws)
)


def hline(y, target, *, color=unset, alpha=unset, width=unset, linestyle=unset, **artist_kws):
"""Interface to plotly for a horizontal line spanning the whole axes."""
artist_kws.setdefault("showlegend", False)
line_kwargs = {"color": color, "width": width, "dash": linestyle}
line_artist_kws = artist_kws.pop("line", {}).copy()
kwargs = {"opacity": alpha}
return target.add_hline(y, line=_filter_kwargs(line_kwargs, line_artist_kws), **_filter_kwargs(kwargs, artist_kws))
return target.add_hline(
y, line=_filter_kwargs(line_kwargs, line_artist_kws), **_filter_kwargs(kwargs, artist_kws)
)


# general plot appeareance
def title(string, target, *, size=unset, color=unset, **artist_kws):
Expand Down
31 changes: 18 additions & 13 deletions src/arviz_plots/plots/convergencedistplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,9 @@ def plot_convergence_dist(

ref_line_kwargs = copy(plot_kwargs.get("ref_line", {}))
if ref_line_kwargs is False:
raise ValueError("plot_kwargs['ref_line'] can't be False, use ref_line=False to remove this element")
raise ValueError(
"plot_kwargs['ref_line'] can't be False, use ref_line=False to remove this element"
)

if pc_kwargs is None:
pc_kwargs = {}
Expand All @@ -154,7 +156,9 @@ def plot_convergence_dist(
aes_map = aes_map.copy()

if diagnostics is None:
diagnostics = ["ess_bulk", "ess_tail", "rhat_rank"]
diagnostics = ["ess_bulk", "ess_tail", "rhat"]
elif isinstance(diagnostics, str):
diagnostics = [diagnostics]

if backend is None:
if plot_collection is None:
Expand Down Expand Up @@ -208,13 +212,14 @@ def plot_convergence_dist(
ess_ref = dt.sizes["chain"] * 100
# is this valid for all r_hat methods? Do we want to correct for multiple comparisons?
r_hat_ref = 1.01
ref_ds = xr.Dataset({diagnostic: ess_ref if "ess" in diagnostic else r_hat_ref for diagnostic in distribution.data_vars})
ref_ds = xr.Dataset(
{
diagnostic: ess_ref if "ess" in diagnostic else r_hat_ref
for diagnostic in distribution.data_vars
}
)
plot_collection.map(
vline,
"ref_line",
data=ref_ds,
ignore_aes=ref_ignore,
**ref_line_kwargs
vline, "ref_line", data=ref_ds, ignore_aes=ref_ignore, **ref_line_kwargs
)

return plot_collection
Expand All @@ -228,16 +233,16 @@ def _compute_diagnostics(dt, diagnostics, sample_dims):
method = diagnostic.split("_", 1)[1].split("(", 1)[0]
if method in {"tail", "quantile", "local"} and "(" in diagnostic:
prob = [float(p) for p in diagnostic.split("(", 1)[1].rstrip(")").split(", ")]
diagnostic_values[diagnostic] = (
dt.azstats.ess(method=method, prob=prob, dims=sample_dims).to_stacked_array("label", sample_dims=[])
)
diagnostic_values[diagnostic] = dt.azstats.ess(
method=method, prob=prob, dims=sample_dims
).to_stacked_array("label", sample_dims=[])
elif "rhat" in diagnostic:
kwargs = {"dims": sample_dims}
if diagnostic != "rhat":
method = diagnostic.split("_", 1)[1]
kwargs.update({"method": method})
diagnostic_values[diagnostic] = (
dt.azstats.rhat(**kwargs).to_stacked_array("label", sample_dims=[])
diagnostic_values[diagnostic] = dt.azstats.rhat(**kwargs).to_stacked_array(
"label", sample_dims=[]
)
else:
warnings.warn(f"{diagnostic} is not recognized as a valid diagnostic")
Expand Down
3 changes: 3 additions & 0 deletions src/arviz_plots/visuals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,19 @@ def ecdf_line(values, target, backend, **kwargs):
plot_backend = import_module(f"arviz_plots.backend.{backend}")
return plot_backend.line(values.sel(plot_axis="x"), values.sel(plot_axis="y"), target, **kwargs)


def vline(values, target, backend, **kwargs):
"""Plot a vertical line that spans the whole figure independently of zoom."""
plot_backend = import_module(f"arviz_plots.backend.{backend}")
return plot_backend.vline(values.item(), target, **kwargs)


def hline(values, target, backend, **kwargs):
"""Plot a horizontal line that spans the whole figure independently of zoom."""
plot_backend = import_module(f"arviz_plots.backend.{backend}")
return plot_backend.hline(values.item(), target, **kwargs)


def fill_between_y(da, target, backend, *, x=None, y_bottom=None, y=None, y_top=None, **kwargs):
"""Fill the region between to given y values."""
if "kwarg" in da.dims:
Expand Down
51 changes: 51 additions & 0 deletions tests/test_hypothesis_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from scipy.stats import halfnorm, norm

from arviz_plots import (
plot_convergence_dist,
plot_dist,
plot_ess,
plot_ess_evolution,
Expand Down Expand Up @@ -354,3 +355,53 @@ def test_plot_psense(datatree, alphas, kind, point_estimate, ci_kind, plot_kwarg
assert all(key not in child for child in pc.viz.children.values())
elif key != "remove_axis":
assert all(key in child for child in pc.viz.children.values())


@given(
plot_kwargs=st.fixed_dictionaries(
{},
optional={
"kind": plot_kwargs_value,
"ref_line": plot_kwargs_value_no_false,
"title": plot_kwargs_value,
"remove_axis": st.just(False),
},
),
diagnostics=st.sampled_from(
[
# fmt: off
None, "rhat", "rhat_rank", "rhat_folded", "rhat_z_scale", "rhat_split",
"rhat_identity", "ess_bulk", "ess_tail", "ess_mean", "ess_sd", "ess_quantile",
"ess_local", "ess_median", "ess_mad", "ess_z_scale", "ess_folded", "ess_identity"
# fmt: on
]
),
kind=kind_value,
ref_line=st.booleans(),
)
def test_plot_convergence_dist(datatree, diagnostics, kind, ref_line, plot_kwargs):
kind_kwargs = plot_kwargs.pop("kind", None)
if kind_kwargs is not None:
plot_kwargs[kind] = kind_kwargs
pc = plot_convergence_dist(
datatree,
diagnostics=diagnostics,
backend="none",
kind=kind,
ref_line=ref_line,
plot_kwargs=plot_kwargs,
)
assert all("plot" in child for child in pc.viz.children.values())
if diagnostics is None:
diagnostics = ["ess_bulk", "ess_tail", "rhat"]
assert [diagnostic in pc.viz.children for diagnostic in diagnostics]
for key, value in plot_kwargs.items():
if value is False:
assert all(key not in child for child in pc.viz.children.values())
elif key == "ref_line":
if ref_line:
assert all(key in child for child in pc.viz.children.values())
else:
assert all(key not in child for child in pc.viz.children.values())
elif key != "remove_axis":
assert all(key in child for child in pc.viz.children.values())

0 comments on commit c0050b4

Please sign in to comment.