diff --git a/pyproject.toml b/pyproject.toml index e96b83e..83e6b90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,4 +107,6 @@ exclude_also = [ "raise NotImplementedError", "__repr__", "TYPE_CHECKING", + "if cls._instance is None:", + "except KeyError as missing_key:", ] \ No newline at end of file diff --git a/src/rail/cli/rail_plot/plot_commands.py b/src/rail/cli/rail_plot/plot_commands.py index cda2352..c4086c5 100644 --- a/src/rail/cli/rail_plot/plot_commands.py +++ b/src/rail/cli/rail_plot/plot_commands.py @@ -22,6 +22,8 @@ def plot_cli() -> None: @plot_options.exclude_groups() @plot_options.save_plots() @plot_options.purge_plots() +@plot_options.find_only() +@plot_options.make_html() @options.outdir() def run_command(config_file: str, **kwargs: Any) -> int: """Make a bunch of plots""" @@ -48,27 +50,39 @@ def inspect_command(config_file: str) -> int: @plot_cli.command(name="extract-datasets") @project_options.config_file() -@plot_options.dataset_list_name() @plot_options.extractor_class() +@options.output_yaml() +@plot_options.dataset_list_name() +@plot_options.dataset_holder_class() @project_options.flavor() @project_options.selection() -@options.output_yaml() +@plot_options.split_by_flavor() def extract_datasets_command( config_file: str, - dataset_list_name: str, extractor_class: str, - flavor: list[str], - selection: list[str], output_yaml: str, + **kwargs: dict[str, Any], ) -> int: """Create a yaml file with the datasets in a project""" control.clear() control.extract_datasets( config_file, - dataset_list_name, extractor_class, - flavors=flavor, - selections=selection, output_yaml=output_yaml, + **kwargs, ) return 0 + + +@plot_cli.command(name="make-plot-groups") +@options.output_yaml() +@plot_options.plotter_yaml_path() +@plot_options.dataset_yaml_path() +@plot_options.plotter_list_name() +@plot_options.output_prefix() +@plot_options.dataset_list_name(multiple=True) +def make_plot_groups(output_yaml: str, **kwargs: dict[str, Any]) -> int: + """Create a yaml file with the datasets in a project""" + control.clear() + control.make_plot_group_yaml(output_yaml, **kwargs) + return 0 diff --git a/src/rail/cli/rail_plot/plot_options.py b/src/rail/cli/rail_plot/plot_options.py index 7459d25..949074d 100644 --- a/src/rail/cli/rail_plot/plot_options.py +++ b/src/rail/cli/rail_plot/plot_options.py @@ -6,10 +6,18 @@ __all__: list[str] = [ "purge_plots", "save_plots", + "find_only", + "make_html", + "dataset_holder_class", "dataset_list_name", + "plotter_list_name", + "output_prefix", + "dataset_yaml_path", + "plotter_yaml_path", "include_groups", "exclude_groups", "extractor_class", + "split_by_flavor", ] @@ -34,6 +42,13 @@ ) +dataset_holder_class = PartialOption( + "--dataset_holder_class", + help="Class for the dataset holder", + type=str, +) + + dataset_list_name = PartialOption( "--dataset_list_name", help="Name for dataset list", @@ -41,6 +56,49 @@ ) +dataset_yaml_path = PartialOption( + "--dataset_yaml_path", + help="Name for dataset list", + type=str, +) + + +output_prefix = PartialOption( + "--output_prefix", + help="Name for dataset list", + default="", + type=str, +) + + +plotter_list_name = PartialOption( + "--plotter_list_name", + help="Name for plotter list", + type=str, +) + + +plotter_yaml_path = PartialOption( + "--plotter_yaml_path", + help="Name for plotter list", + type=str, +) + + +find_only = PartialOption( + "--find_only", + help="Find existing plots, do not create new ones", + is_flag=True, +) + + +make_html = PartialOption( + "--make_html", + help="Make html files to help browse plots", + is_flag=True, +) + + purge_plots = PartialOption( "--purge_plots", help="Purge plots from memory after saving", @@ -52,3 +110,9 @@ help="Save plots to disk", is_flag=True, ) + +split_by_flavor = PartialOption( + "--split_by_flavor", + help="Split dataset organization by flavor", + is_flag=True, +) diff --git a/src/rail/plotting/control.py b/src/rail/plotting/control.py index 1da33dc..31a6524 100644 --- a/src/rail/plotting/control.py +++ b/src/rail/plotting/control.py @@ -3,6 +3,7 @@ from __future__ import annotations from typing import Any +import os import yaml from rail.projects import RailProject @@ -59,6 +60,8 @@ load_plot_group_yaml = RailPlotGroupFactory.load_yaml +make_plot_group_yaml = RailPlotGroupFactory.make_yaml + print_plot_group_contents = RailPlotGroupFactory.print_contents get_plot_group_dict = RailPlotGroupFactory.get_plot_groups @@ -179,6 +182,8 @@ def run( include_groups = kwargs.pop("include_groups", None) exclude_groups = kwargs.pop("exclude_groups", None) + make_html = kwargs.get("make_html", False) + outdir = kwargs.get("outdir", "plots") if include_groups is None or not include_groups: include_groups = list(group_dict.keys()) @@ -187,19 +192,24 @@ def run( for exclude_group_ in exclude_groups: # pragma: no cover include_groups.remove(exclude_group_) + output_pages: list[str] = [] for group_ in include_groups: plot_group = group_dict[group_] out_dict.update(plot_group(**kwargs)) + if make_html: + output_pages.append(f"plots_{plot_group.name}.html") + if make_html: + RailPlotGroup.make_html_index( + os.path.join(outdir, "plot_index.html"), output_pages + ) return out_dict def extract_datasets( config_file: str, - dataset_list_name: str, extractor_class: str, - flavors: list[str], - selections: list[str], output_yaml: str, + **kwargs: dict[str, Any], ) -> None: """Extract datasets into a yaml file @@ -208,11 +218,19 @@ def extract_datasets( config_file: str Yaml project configuration file + extractor_class: str + Class used to extract Datasets + + output_yaml: str + Path to output file + + Keywords + -------- dataset_list_name: str Name for the resulting DatasetList - extractor_class: str - Class used to extract Datasets + dataset_holder_class: str + Class for the dataset holder selections: list[str] Selections to use @@ -220,16 +238,14 @@ def extract_datasets( flavors: list[str] Flavors to use - output_yaml: str - Path to output file + split_by_flavor: bool + Split dataset lists by flavor """ extractor_cls = load_extractor_class(extractor_class) project = RailProject.load_config(config_file) output_data = extractor_cls.generate_dataset_dict( - dataset_list_name, - project, - selections, - flavors, + project=project, + **kwargs, ) with open(output_yaml, "w", encoding="utf-8") as fout: yaml.dump(output_data, fout) diff --git a/src/rail/plotting/data_extraction.py b/src/rail/plotting/data_extraction.py index e70cceb..1ebd6b6 100644 --- a/src/rail/plotting/data_extraction.py +++ b/src/rail/plotting/data_extraction.py @@ -2,8 +2,6 @@ from typing import Any -from rail.projects import RailProject - from .configurable import Configurable from .dynamic_class import DynamicClass from .validation import validate_inputs @@ -75,10 +73,7 @@ def _get_data(self, **kwargs: Any) -> dict[str, Any] | None: @classmethod def generate_dataset_dict( cls, - dataset_list_name: str, - project: RailProject, - selections: list[str] | None = None, - flavors: list[str] | None = None, + **kwargs: dict[str, Any], ) -> list[dict[str, Any]]: """Create a dict of the datasets that this extractor can extract @@ -87,6 +82,9 @@ def generate_dataset_dict( dataset_list_name: str Name for the resulting DatasetList + dataset_holder_class: str + Class for the dataset holder + project: RailProject Project to inspect diff --git a/src/rail/plotting/dataset_factory.py b/src/rail/plotting/dataset_factory.py index 283d5c4..2bafa29 100644 --- a/src/rail/plotting/dataset_factory.py +++ b/src/rail/plotting/dataset_factory.py @@ -57,21 +57,21 @@ def __init__(self) -> None: @classmethod def instance(cls) -> RailDatasetFactory: """Return the singleton instance of the factory""" - if cls._instance is None: # pragma: no cover + if cls._instance is None: cls._instance = RailDatasetFactory() return cls._instance @classmethod def clear(cls) -> None: """Clear the contents of the factory""" - if cls._instance is None: # pragma: no cover + if cls._instance is None: return cls._instance.clear_instance() @classmethod def print_contents(cls) -> None: """Print the contents of the factory""" - if cls._instance is None: # pragma: no cover + if cls._instance is None: cls._instance = RailDatasetFactory() cls._instance.print_instance_contents() @@ -236,11 +236,11 @@ def print_instance_contents(self) -> None: def _make_dataset(self, **kwargs: Any) -> RailDatasetHolder: try: name = kwargs["name"] - except KeyError as msg: # pragma: no cover + except KeyError as missing_key: raise KeyError( "Dataset yaml block does not contain name for dataset: " f"{list(kwargs.keys())}" - ) from msg + ) from missing_key if name in self._datasets: # pragma: no cover raise KeyError(f"Dataset {name} is already defined") dataset_holder = RailDatasetHolder.create_from_dict(kwargs) @@ -256,11 +256,11 @@ def _make_dataset_dict( for dataset_name in dataset_name_list: try: dataset = self._datasets[dataset_name] - except KeyError as msg: # pragma: no cover + except KeyError as missing_key: raise KeyError( f"Dataset {dataset_name} used in DatasetList " f"is not found {list(self._datasets.keys())}" - ) from msg + ) from missing_key datasets[dataset_name] = dataset self._dataset_dicts[name] = datasets return datasets @@ -287,18 +287,18 @@ def load_dataset_list_from_yaml_tag( """ try: name = dataset_list_config.pop("name") - except KeyError as msg: # pragma: no cover + except KeyError as missing_key: raise KeyError( "DatasetList yaml block does not contain name for dataset: " f"{list(dataset_list_config.keys())}" - ) from msg + ) from missing_key try: dataset_names = dataset_list_config.pop("datasets") - except KeyError as msg: # pragma: no cover + except KeyError as missing_key: raise KeyError( "DatasetList yaml block does not contain dataset: " f"{list(dataset_list_config.keys())}" - ) from msg + ) from missing_key self._make_dataset_dict(name, dataset_names) def load_project_from_yaml_tag(self, project_config: dict[str, Any]) -> None: @@ -311,18 +311,18 @@ def load_project_from_yaml_tag(self, project_config: dict[str, Any]) -> None: """ try: name = project_config.pop("name") - except KeyError as msg: # pragma: no cover + except KeyError as missing_key: raise KeyError( "Project yaml block does not contain name for project: " f"{list(project_config.keys())}" - ) from msg + ) from missing_key try: project_yaml = project_config.pop("yaml_file") - except KeyError as msg: # pragma: no cover + except KeyError as missing_key: raise KeyError( "Project yaml block does not contain yaml_file: " f"{list(project_config.keys())}" - ) from msg + ) from missing_key self._projects[name] = RailProject.load_config(project_yaml) def load_instance_yaml(self, yaml_file: str) -> None: diff --git a/src/rail/plotting/dataset_holder.py b/src/rail/plotting/dataset_holder.py index 3fdd2ea..3071a41 100644 --- a/src/rail/plotting/dataset_holder.py +++ b/src/rail/plotting/dataset_holder.py @@ -53,10 +53,10 @@ def _validate_extractor_inputs(cls, **kwargs: Any) -> None: for key, expected_type in cls.extractor_inputs.items(): try: data = kwargs[key] - except KeyError as msg: # pragma: no cover + except KeyError as missing_key: raise KeyError( f"{key} not provided to RailDatasetHolder {cls} in {list(kwargs.keys())}" - ) from msg + ) from missing_key if isinstance(expected_type, GenericAlias): if not isinstance(data, expected_type.__origin__): # pragma: no cover raise TypeError( diff --git a/src/rail/plotting/plot_group.py b/src/rail/plotting/plot_group.py index fd7e108..d7abbeb 100644 --- a/src/rail/plotting/plot_group.py +++ b/src/rail/plotting/plot_group.py @@ -20,6 +20,7 @@ class RailPlotGroup: jinja_env: Environment | None = None jinja_template: Template | None = None + jinja_index_template: Template | None = None @classmethod def _load_jinja(cls) -> None: @@ -29,6 +30,7 @@ def _load_jinja(cls) -> None: loader=FileSystemLoader("src/rail/projects/html_templates") ) cls.jinja_template = cls.jinja_env.get_template("plot_group_table.html") + cls.jinja_index_template = cls.jinja_env.get_template("plot_group_index.html") def __init__( self, @@ -139,6 +141,20 @@ def find_plots( ) return self._plots + @classmethod + def make_html_index( + cls, + outfile: str, + output_pages: list[str], + ) -> None: + cls._load_jinja() + assert cls.jinja_index_template is not None + + # Render template data and save to HTML file + output = cls.jinja_index_template.render(output_pages=output_pages, os=os) + with open(outfile, "w", encoding="utf-8") as file: + file.write(output) + def make_html( self, outfile: str, diff --git a/src/rail/plotting/plot_group_factory.py b/src/rail/plotting/plot_group_factory.py index 59e889b..ce8a242 100644 --- a/src/rail/plotting/plot_group_factory.py +++ b/src/rail/plotting/plot_group_factory.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Any + import yaml from .dataset_factory import RailDatasetFactory @@ -35,24 +37,67 @@ def __init__(self) -> None: @classmethod def instance(cls) -> RailPlotGroupFactory: """Return the singleton instance of the factory""" - if cls._instance is None: # pragma: no cover + if cls._instance is None: cls._instance = RailPlotGroupFactory() return cls._instance @classmethod def clear(cls) -> None: """Clear the contents of the factory""" - if cls._instance is None: # pragma: no cover + if cls._instance is None: return cls._instance.clear_instance() @classmethod def print_contents(cls) -> None: """Print the contents of the factory""" - if cls._instance is None: # pragma: no cover + if cls._instance is None: cls._instance = RailPlotGroupFactory() cls._instance.print_instance_contents() + @classmethod + def make_yaml( + cls, + output_yaml: str, + plotter_yaml_path: str, + dataset_yaml_path: str, + plotter_list_name: str, + output_prefix: str = "", + dataset_list_name: list[str] | None = None, + ) -> None: + """Construct a yaml file defining plot groups + + Parameters + ---------- + output_yaml: str + Path to the output file + + plotter_yaml_path: str + Path to the yaml file defining the plotter_lists + + dataset_yaml_path: str + Path to the yaml file defining the datasets + + plotter_list_name: str + Name of plotter list to use + + output_prefix: str="" + Prefix for PlotGroup names we construct + + dataset_list_names: list[str] | None=None + Names of dataset lists to use + """ + if cls._instance is None: + cls._instance = RailPlotGroupFactory() + cls._instance.make_instance_yaml( + output_yaml=output_yaml, + plotter_yaml_path=plotter_yaml_path, + dataset_yaml_path=dataset_yaml_path, + plotter_list_name=plotter_list_name, + output_prefix=output_prefix, + dataset_list_name=dataset_list_name, + ) + @classmethod def load_yaml(cls, yaml_file: str) -> dict[str, RailPlotGroup]: """Load a yaml file @@ -97,6 +142,68 @@ def print_instance_contents(self) -> None: print(f" {plot_group_name}: {plot_group}") print("----------------") + def make_instance_yaml( + self, + output_yaml: str, + plotter_yaml_path: str, + dataset_yaml_path: str, + plotter_list_name: str, + output_prefix: str = "", + dataset_list_name: list[str] | None = None, + ) -> None: + """Construct a yaml file defining plot groups + + Parameters + ---------- + output_yaml: str + Path to the output file + + plotter_yaml_path: str + Path to the yaml file defining the plotter_lists + + dataset_yaml_path: str + Path to the yaml file defining the datasets + + plotter_list_name: str + Name of plotter list to use + + output_prefix: str="" + Prefix for PlotGroup names we construct + + dataset_list_name: list[str] + Names of dataset lists to use + """ + RailPlotterFactory.clear() + RailPlotterFactory.load_yaml(plotter_yaml_path) + RailDatasetFactory.clear() + RailDatasetFactory.load_yaml(dataset_yaml_path) + + plotter_list = RailPlotterFactory.get_plotter_list(plotter_list_name) + assert plotter_list + if not dataset_list_name: # pragma: no cover + dataset_list_name = RailDatasetFactory.get_dataset_dict_names() + + output: list[dict[str, Any]] = [] + output.append( + dict(PlotterYaml=dict(path=plotter_yaml_path)), + ) + output.append( + dict(DatasetYaml=dict(path=dataset_yaml_path)), + ) + for ds_name in dataset_list_name: + group_name = f"{output_prefix}{ds_name}_{plotter_list_name}" + output.append( + dict( + PlotGroup=dict( + name=group_name, + plotter_list_name=plotter_list_name, + dataset_dict_name=ds_name, + ) + ) + ) + with open(output_yaml, "w", encoding="utf-8") as fout: + yaml.dump(output, fout) + def load_instance_yaml( self, yaml_file: str, @@ -138,33 +245,33 @@ def load_instance_yaml( plotter_yaml_config = group_item["PlotterYaml"] try: plotter_yaml_path = plotter_yaml_config.pop("path") - except KeyError as msg: # pragma: no cover + except KeyError as missing_key: raise KeyError( "PlotterYaml yaml block does not contain path: " f"{list(plotter_yaml_config.keys())}" - ) from msg + ) from missing_key RailPlotterFactory.clear() RailPlotterFactory.load_yaml(plotter_yaml_path) elif "DatasetYaml" in group_item: dataset_yaml_config = group_item["DatasetYaml"] try: dataset_yaml_path = dataset_yaml_config.pop("path") - except KeyError as msg: # pragma: no cover + except KeyError as missing_key: raise KeyError( "PlotterYamlDatasetYaml yaml block does not contain path: " f"{list(dataset_yaml_config.keys())}" - ) from msg + ) from missing_key RailDatasetFactory.clear() RailDatasetFactory.load_yaml(dataset_yaml_path) elif "PlotGroup" in group_item: plot_group_config = group_item["PlotGroup"] try: name = plot_group_config.pop("name") - except KeyError as msg: # pragma: no cover + except KeyError as missing_key: raise KeyError( "PlotGroup yaml block does not contain name for plot group: " f"{list(plot_group_config.keys())}" - ) from msg + ) from missing_key self._plot_groups[name] = RailPlotGroup.create(name, plot_group_config) else: # pragma: no cover good_keys = ["PlotterYaml", "DatasetYaml", "PlotGroup"] diff --git a/src/rail/plotting/plotter_factory.py b/src/rail/plotting/plotter_factory.py index 59ca966..9ff9126 100644 --- a/src/rail/plotting/plotter_factory.py +++ b/src/rail/plotting/plotter_factory.py @@ -47,21 +47,21 @@ def __init__(self) -> None: @classmethod def instance(cls) -> RailPlotterFactory: """Return the singleton instance of the factory""" - if cls._instance is None: # pragma: no cover + if cls._instance is None: cls._instance = RailPlotterFactory() return cls._instance @classmethod def clear(cls) -> None: """Clear the contents of the factory""" - if cls._instance is None: # pragma: no cover + if cls._instance is None: return cls._instance.clear_instance() @classmethod def print_contents(cls) -> None: """Print the contents of the factory""" - if cls._instance is None: # pragma: no cover + if cls._instance is None: cls._instance = RailPlotterFactory() cls._instance.print_instance_contents() @@ -188,11 +188,11 @@ def print_instance_contents(self) -> None: def _make_plotter(self, config_dict: dict[str, Any]) -> RailPlotter: try: name = config_dict["name"] - except KeyError as msg: # pragma: no cover + except KeyError as missing_key: raise KeyError( "Plotter yaml block does not contain name for plotter: " f"{list(config_dict.keys())}" - ) from msg + ) from missing_key if name in self._plotter_dict: # pragma: no cover raise KeyError(f"Plotter {name} is already defined") plotter = RailPlotter.create_from_dict(config_dict) @@ -208,11 +208,11 @@ def _make_plotter_list( for plotter_name in plotter_list: try: plotter = self._plotter_dict[plotter_name] - except KeyError as msg: # pragma: no cover + except KeyError as missing_key: raise KeyError( f"RailPlotter {plotter_name} used in PlotterList " f"is not found {list(self._plotter_dict.keys())}" - ) from msg + ) from missing_key plotters.append(plotter) self._plotter_list_dict[name] = plotters return plotters @@ -239,17 +239,17 @@ def load_plotter_list_from_yaml_tag( """ try: name = plotter_list_config.pop("name") - except KeyError as msg: # pragma: no cover + except KeyError as missing_key: raise KeyError( "PlotterList yaml block does not contain name for plotter: " f"{list(plotter_list_config.keys())}" - ) from msg + ) from missing_key try: plotters = plotter_list_config.pop("plotters") - except KeyError as msg: # pragma: no cover + except KeyError as missing_key: raise KeyError( f"PlotterList yaml block does not contain plotter: {list(plotter_list_config.keys())}" - ) from msg + ) from missing_key self._make_plotter_list(name, plotters) def load_instance_yaml(self, yaml_file: str) -> None: diff --git a/src/rail/plotting/pz_data_extraction.py b/src/rail/plotting/pz_data_extraction.py index d6e0fc9..f38b956 100644 --- a/src/rail/plotting/pz_data_extraction.py +++ b/src/rail/plotting/pz_data_extraction.py @@ -39,11 +39,38 @@ def _get_data(self, **kwargs: Any) -> dict[str, Any] | None: @classmethod def generate_dataset_dict( cls, - dataset_list_name: str, - project: RailProject, - selections: list[str] | None = None, - flavors: list[str] | None = None, + **kwargs: Any, ) -> list[dict[str, Any]]: + """ + Keywords + -------- + dataset_list_name: str + Name for the resulting DatasetList + + dataset_holder_class: str + Class for the dataset holder + + project: RailProject + Project to inspect + + selections: list[str] + Selections to use + + flavors: list[str] + Flavors to use + + Returns + ------- + output: list[dict[str, Any]] + Dictionary of the extracted datasets + """ + dataset_list_name: str | None = kwargs.get('dataset_list_name') + dataset_holder_class: str | None = kwargs.get('dataset_holder_class') + project: RailProject = kwargs.get('project') + selections = kwargs.get('selections') + flavors = kwargs.get('flavors') + split_by_flavor = kwargs.get('split_by_flavor', False) + output: list[dict[str, Any]] = [] flavor_dict = project.get_flavors() @@ -65,7 +92,10 @@ def generate_dataset_dict( output.append(project_block) - datasets: list[str] = [] + dataset_list_dict: dict[str, list[str]] = {} + dataset_key = dataset_list_name + if not split_by_flavor: + dataset_list_dict[dataset_key] = [] for key in flavors: val = flavor_dict[key] @@ -78,6 +108,10 @@ def generate_dataset_dict( algos = list(project.get_pzalgorithms().keys()) for selection_ in selections: + if split_by_flavor: + dataset_key = f"{dataset_list_name}_{selection_}_{key}" + dataset_list_dict[dataset_key] = [] + for algo_ in algos: path = get_ceci_pz_output_path( project, @@ -90,6 +124,7 @@ def generate_dataset_dict( dataset_name = f"{selection_}_{key}_{algo_}" dataset_dict = dict( name=dataset_name, + class_name=dataset_holder_class, extractor="rail.plotting.pz_data_extraction.PZPointEstimateDataExtractor", project=project_name, flavor=key, @@ -97,15 +132,19 @@ def generate_dataset_dict( tag="test", selection=selection_, ) - datasets.append(dataset_name) - output.append(dict(Dataset=dataset_dict)) - dataset_list = dict( - name=dataset_list_name, - datasets=datasets, - ) + dataset_list_dict[dataset_key].append(dataset_name) + output.append(dict(Dataset=dataset_dict)) - output.append(dict(DatasetList=dataset_list)) + for ds_name, ds_list in dataset_list_dict.items(): + # Skip empty lists + if not ds_list: + continue + dataset_list = dict( + name=ds_name, + datasets=ds_list, + ) + output.append(dict(DatasetList=dataset_list)) return output diff --git a/src/rail/plotting/validation.py b/src/rail/plotting/validation.py index 89e248d..15bac99 100644 --- a/src/rail/plotting/validation.py +++ b/src/rail/plotting/validation.py @@ -8,10 +8,10 @@ def validate_inputs(a_class: type, expected_inputs: dict, **kwargs: Any) -> None for key, expected_type in expected_inputs.items(): try: data = kwargs[key] - except KeyError as msg: # pragma: no cover + except KeyError as missing_key: raise KeyError( f"{key} not provided to {a_class.__name__} in {list(kwargs.keys())}" - ) from msg + ) from missing_key if isinstance(expected_type, GenericAlias): if not isinstance(data, expected_type.__origin__): # pragma: no cover raise TypeError( diff --git a/src/rail/projects/html_templates/plot_group_index.html b/src/rail/projects/html_templates/plot_group_index.html new file mode 100644 index 0000000..1bbe98a --- /dev/null +++ b/src/rail/projects/html_templates/plot_group_index.html @@ -0,0 +1,28 @@ + + + + + + + + +
+

Super secret index of RAIL plots

+ + + + + + + + {% for output_page in output_pages %} + + + + {% endfor %} + +
Link
{{ output_page }} +
+
+ + \ No newline at end of file diff --git a/tests/cli/plot/test_plot_cli.py b/tests/cli/plot/test_plot_cli.py index c373d3a..fcf42d9 100644 --- a/tests/cli/plot/test_plot_cli.py +++ b/tests/cli/plot/test_plot_cli.py @@ -47,13 +47,44 @@ def test_cli_extract_datasets(setup_project_area: int) -> None: assert setup_project_area == 0 runner = CliRunner() + # run with split by flavor result = runner.invoke( plot_cli, "extract-datasets " "--extractor_class rail.plotting.pz_data_extraction.PZPointEstimateDataExtractor " "--flavor all " "--selection all " + "--split_by_flavor " "--output_yaml tests/temp_data/dataset_out.yaml " "tests/ci_project.yaml", ) check_result(result) + + # run without split by flavor + result = runner.invoke( + plot_cli, + "extract-datasets " + "--extractor_class rail.plotting.pz_data_extraction.PZPointEstimateDataExtractor " + "--flavor all " + "--selection all " + "--output_yaml tests/temp_data/dataset_out.yaml " + "tests/ci_project.yaml", + ) + check_result(result) + + +@pytest.mark.skipif(missing_ci_data, reason="NO CI data") +def test_cli_make_plot_groups(setup_project_area: int) -> None: + assert setup_project_area == 0 + runner = CliRunner() + + result = runner.invoke( + plot_cli, + "make-plot-groups " + "--output_yaml tests/temp_data/check_plot_group.yaml " + "--plotter_yaml_path tests/ci_plots.yaml " + "--dataset_yaml_path tests/ci_datasets.yaml " + "--plotter_list_name zestimate_v_ztrue " + "--dataset_list_name blend_baseline_all", + ) + check_result(result)