From 25f1c684cac47e63c99e4ee3dc5ad1f0edf44542 Mon Sep 17 00:00:00 2001 From: Henry Schreiner Date: Tue, 9 Mar 2021 17:41:41 -0500 Subject: [PATCH] feat: support data= --- docs/changelog.rst | 3 +++ src/hist/basehist.py | 4 ++++ tests/test_general.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 36 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index 21daea74..527c14a0 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -7,6 +7,9 @@ Version 2.2.0 * Support boost-histogram 1.0. Better plain reprs. Full Static Typing. `#137 `_ +* Support ``data=`` when construction a histogram to copy in initial data. + `#142 `_ + * Support ``Hist.from_columns``, for simple conversion of DataFrames and similar structures `#140 `_ diff --git a/src/hist/basehist.py b/src/hist/basehist.py index 0552c19f..64d8bf67 100644 --- a/src/hist/basehist.py +++ b/src/hist/basehist.py @@ -61,6 +61,7 @@ def __init__( *args: Union[AxisProtocol, Storage, str, Tuple[int, float, float]], storage: Optional[Union[Storage, str]] = None, metadata: Any = None, + data: Optional[np.ndarray] = None, ) -> None: """ Initialize BaseHist object. Axis params can contain the names. @@ -97,6 +98,9 @@ def __init__( if not ax.label: ax.label = f"Axis {i}" + if data is not None: + self[...] = data # type: ignore + def _generate_axes_(self) -> NamedAxesTuple: """ This is called to fill in the axes. Subclasses can override it if they need diff --git a/tests/test_general.py b/tests/test_general.py index 7d369e6d..6d9e2df9 100644 --- a/tests/test_general.py +++ b/tests/test_general.py @@ -821,3 +821,32 @@ def test_from_columns(named_hist): with pytest.raises(TypeError): named_hist.from_columns(columns, (axis.Integer(1, 5), "y"), weight="data") + + +def test_from_array(named_hist): + h = Hist( + axis.Regular(10, 1, 2, name="A"), + axis.Regular(7, 1, 3, name="B"), + data=np.ones((10, 7)), + ) + assert h.values() == approx(np.ones((10, 7))) + assert h.sum() == approx(70) + assert h.sum(flow=True) == approx(70) + + h = Hist( + axis.Regular(10, 1, 2, name="A"), + axis.Regular(7, 1, 3, name="B"), + data=np.ones((12, 9)), + ) + + assert h.values(flow=False) == approx(np.ones((10, 7))) + assert h.values(flow=True) == approx(np.ones((12, 9))) + assert h.sum() == approx(70) + assert h.sum(flow=True) == approx(12 * 9) + + with pytest.raises(ValueError): + h = Hist( + axis.Regular(10, 1, 2, name="A"), + axis.Regular(7, 1, 3, name="B"), + data=np.ones((11, 9)), + )