diff --git a/pyirf/tests/test_utils.py b/pyirf/tests/test_utils.py index 78c26ee79..d0677b62e 100644 --- a/pyirf/tests/test_utils.py +++ b/pyirf/tests/test_utils.py @@ -35,6 +35,10 @@ def test_cone_solid_angle(): def test_check_table(): from pyirf.utils import check_table + # works with Table as well as QTable + check_table(Table({"foo": [1]}), required_columns=["foo"]) + check_table(Table({"foo": [1] * u.m}), required_units={"foo": u.m}) + check_table(QTable({"foo": [1] * u.m}), required_units={"foo": u.m}) t = Table({"bar": [0, 1, 2] * u.TeV}) @@ -96,12 +100,14 @@ def test_calculate_source_fov_offset(): from pyirf.utils import calculate_source_fov_offset a = u.Quantity([1.0], u.deg) - t = QTable({ - 'pointing_az': a, - 'pointing_alt': a, - 'true_az': a, - 'true_alt': a, - }) + t = QTable( + { + "pointing_az": a, + "pointing_alt": a, + "true_az": a, + "true_alt": a, + } + ) assert u.isclose(calculate_source_fov_offset(t), 0.0 * u.deg) @@ -110,12 +116,16 @@ def test_check_histograms(): from pyirf.binning import create_histogram_table from pyirf.utils import check_histograms - events1 = QTable({ - 'reco_energy': [1, 1, 10, 100, 100, 100] * u.TeV, - }) - events2 = QTable({ - 'reco_energy': [100, 100, 100] * u.TeV, - }) + events1 = QTable( + { + "reco_energy": [1, 1, 10, 100, 100, 100] * u.TeV, + } + ) + events2 = QTable( + { + "reco_energy": [100, 100, 100] * u.TeV, + } + ) bins = [0.5, 5, 50, 500] * u.TeV hist1 = create_histogram_table(events1, bins) diff --git a/pyirf/utils.py b/pyirf/utils.py index f3a865a80..26c37b06f 100644 --- a/pyirf/utils.py +++ b/pyirf/utils.py @@ -143,7 +143,7 @@ def check_table(table, required_columns=None, required_units=None): Parameters ---------- - table: astropy.table.QTable + table: astropy.table.Table Table to check required_columns: iterable[str] Column names that are required to be present @@ -157,6 +157,7 @@ def check_table(table, required_columns=None, required_units=None): as keys in ``required_units are`` not present in the table. WrongColumnUnit: if any column has the wrong unit """ + table = QTable(table) if required_columns is not None: missing = set(required_columns) - set(table.colnames) if missing: