Skip to content

Commit

Permalink
feat: support data=
Browse files Browse the repository at this point in the history
  • Loading branch information
henryiii committed Mar 9, 2021
1 parent 3827e75 commit 25f1c68
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 0 deletions.
3 changes: 3 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ Version 2.2.0
* Support boost-histogram 1.0. Better plain reprs. Full Static Typing.
`#137 <https://github.com/scikit-hep/hist/pull/137>`_

* Support ``data=`` when construction a histogram to copy in initial data.
`#142 <https://github.com/scikit-hep/hist/pull/142>`_

* Support ``Hist.from_columns``, for simple conversion of DataFrames and similar structures
`#140 <https://github.com/scikit-hep/hist/pull/140>`_

Expand Down
4 changes: 4 additions & 0 deletions src/hist/basehist.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(
*args: Union[AxisProtocol, Storage, str, Tuple[int, float, float]],
storage: Optional[Union[Storage, str]] = None,
metadata: Any = None,
data: Optional[np.ndarray] = None,
) -> None:
"""
Initialize BaseHist object. Axis params can contain the names.
Expand Down Expand Up @@ -97,6 +98,9 @@ def __init__(
if not ax.label:
ax.label = f"Axis {i}"

if data is not None:
self[...] = data # type: ignore

def _generate_axes_(self) -> NamedAxesTuple:
"""
This is called to fill in the axes. Subclasses can override it if they need
Expand Down
29 changes: 29 additions & 0 deletions tests/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,3 +821,32 @@ def test_from_columns(named_hist):

with pytest.raises(TypeError):
named_hist.from_columns(columns, (axis.Integer(1, 5), "y"), weight="data")


def test_from_array(named_hist):
h = Hist(
axis.Regular(10, 1, 2, name="A"),
axis.Regular(7, 1, 3, name="B"),
data=np.ones((10, 7)),
)
assert h.values() == approx(np.ones((10, 7)))
assert h.sum() == approx(70)
assert h.sum(flow=True) == approx(70)

h = Hist(
axis.Regular(10, 1, 2, name="A"),
axis.Regular(7, 1, 3, name="B"),
data=np.ones((12, 9)),
)

assert h.values(flow=False) == approx(np.ones((10, 7)))
assert h.values(flow=True) == approx(np.ones((12, 9)))
assert h.sum() == approx(70)
assert h.sum(flow=True) == approx(12 * 9)

with pytest.raises(ValueError):
h = Hist(
axis.Regular(10, 1, 2, name="A"),
axis.Regular(7, 1, 3, name="B"),
data=np.ones((11, 9)),
)

0 comments on commit 25f1c68

Please sign in to comment.