Skip to content

Commit

Permalink
Eac/split by flavor (#4)
Browse files Browse the repository at this point in the history
* split by flavor

* split by flavor

* split by flavor

* split by flavor

* split by flavor

* more stuff

* more stuff

* more stuff

* stuff

* stuff

* stuff

* added html index

* added html index

* added html index

* added html index

* added html index

* huh
  • Loading branch information
eacharles authored Jan 22, 2025
1 parent acbe14b commit d148509
Show file tree
Hide file tree
Showing 14 changed files with 391 additions and 76 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,6 @@ exclude_also = [
"raise NotImplementedError",
"__repr__",
"TYPE_CHECKING",
"if cls._instance is None:",
"except KeyError as missing_key:",
]
30 changes: 22 additions & 8 deletions src/rail/cli/rail_plot/plot_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def plot_cli() -> None:
@plot_options.exclude_groups()
@plot_options.save_plots()
@plot_options.purge_plots()
@plot_options.find_only()
@plot_options.make_html()
@options.outdir()
def run_command(config_file: str, **kwargs: Any) -> int:
"""Make a bunch of plots"""
Expand All @@ -48,27 +50,39 @@ def inspect_command(config_file: str) -> int:

@plot_cli.command(name="extract-datasets")
@project_options.config_file()
@plot_options.dataset_list_name()
@plot_options.extractor_class()
@options.output_yaml()
@plot_options.dataset_list_name()
@plot_options.dataset_holder_class()
@project_options.flavor()
@project_options.selection()
@options.output_yaml()
@plot_options.split_by_flavor()
def extract_datasets_command(
config_file: str,
dataset_list_name: str,
extractor_class: str,
flavor: list[str],
selection: list[str],
output_yaml: str,
**kwargs: dict[str, Any],
) -> int:
"""Create a yaml file with the datasets in a project"""
control.clear()
control.extract_datasets(
config_file,
dataset_list_name,
extractor_class,
flavors=flavor,
selections=selection,
output_yaml=output_yaml,
**kwargs,
)
return 0


@plot_cli.command(name="make-plot-groups")
@options.output_yaml()
@plot_options.plotter_yaml_path()
@plot_options.dataset_yaml_path()
@plot_options.plotter_list_name()
@plot_options.output_prefix()
@plot_options.dataset_list_name(multiple=True)
def make_plot_groups(output_yaml: str, **kwargs: dict[str, Any]) -> int:
"""Create a yaml file with the datasets in a project"""
control.clear()
control.make_plot_group_yaml(output_yaml, **kwargs)
return 0
64 changes: 64 additions & 0 deletions src/rail/cli/rail_plot/plot_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,18 @@
__all__: list[str] = [
"purge_plots",
"save_plots",
"find_only",
"make_html",
"dataset_holder_class",
"dataset_list_name",
"plotter_list_name",
"output_prefix",
"dataset_yaml_path",
"plotter_yaml_path",
"include_groups",
"exclude_groups",
"extractor_class",
"split_by_flavor",
]


Expand All @@ -34,13 +42,63 @@
)


dataset_holder_class = PartialOption(
"--dataset_holder_class",
help="Class for the dataset holder",
type=str,
)


dataset_list_name = PartialOption(
"--dataset_list_name",
help="Name for dataset list",
type=str,
)


dataset_yaml_path = PartialOption(
"--dataset_yaml_path",
help="Name for dataset list",
type=str,
)


output_prefix = PartialOption(
"--output_prefix",
help="Name for dataset list",
default="",
type=str,
)


plotter_list_name = PartialOption(
"--plotter_list_name",
help="Name for plotter list",
type=str,
)


plotter_yaml_path = PartialOption(
"--plotter_yaml_path",
help="Name for plotter list",
type=str,
)


find_only = PartialOption(
"--find_only",
help="Find existing plots, do not create new ones",
is_flag=True,
)


make_html = PartialOption(
"--make_html",
help="Make html files to help browse plots",
is_flag=True,
)


purge_plots = PartialOption(
"--purge_plots",
help="Purge plots from memory after saving",
Expand All @@ -52,3 +110,9 @@
help="Save plots to disk",
is_flag=True,
)

split_by_flavor = PartialOption(
"--split_by_flavor",
help="Split dataset organization by flavor",
is_flag=True,
)
38 changes: 27 additions & 11 deletions src/rail/plotting/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from typing import Any
import os
import yaml

from rail.projects import RailProject
Expand Down Expand Up @@ -59,6 +60,8 @@

load_plot_group_yaml = RailPlotGroupFactory.load_yaml

make_plot_group_yaml = RailPlotGroupFactory.make_yaml

print_plot_group_contents = RailPlotGroupFactory.print_contents

get_plot_group_dict = RailPlotGroupFactory.get_plot_groups
Expand Down Expand Up @@ -179,6 +182,8 @@ def run(

include_groups = kwargs.pop("include_groups", None)
exclude_groups = kwargs.pop("exclude_groups", None)
make_html = kwargs.get("make_html", False)
outdir = kwargs.get("outdir", "plots")

if include_groups is None or not include_groups:
include_groups = list(group_dict.keys())
Expand All @@ -187,19 +192,24 @@ def run(
for exclude_group_ in exclude_groups: # pragma: no cover
include_groups.remove(exclude_group_)

output_pages: list[str] = []
for group_ in include_groups:
plot_group = group_dict[group_]
out_dict.update(plot_group(**kwargs))
if make_html:
output_pages.append(f"plots_{plot_group.name}.html")
if make_html:
RailPlotGroup.make_html_index(
os.path.join(outdir, "plot_index.html"), output_pages
)
return out_dict


def extract_datasets(
config_file: str,
dataset_list_name: str,
extractor_class: str,
flavors: list[str],
selections: list[str],
output_yaml: str,
**kwargs: dict[str, Any],
) -> None:
"""Extract datasets into a yaml file
Expand All @@ -208,28 +218,34 @@ def extract_datasets(
config_file: str
Yaml project configuration file
extractor_class: str
Class used to extract Datasets
output_yaml: str
Path to output file
Keywords
--------
dataset_list_name: str
Name for the resulting DatasetList
extractor_class: str
Class used to extract Datasets
dataset_holder_class: str
Class for the dataset holder
selections: list[str]
Selections to use
flavors: list[str]
Flavors to use
output_yaml: str
Path to output file
split_by_flavor: bool
Split dataset lists by flavor
"""
extractor_cls = load_extractor_class(extractor_class)
project = RailProject.load_config(config_file)
output_data = extractor_cls.generate_dataset_dict(
dataset_list_name,
project,
selections,
flavors,
project=project,
**kwargs,
)
with open(output_yaml, "w", encoding="utf-8") as fout:
yaml.dump(output_data, fout)
10 changes: 4 additions & 6 deletions src/rail/plotting/data_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from typing import Any

from rail.projects import RailProject

from .configurable import Configurable
from .dynamic_class import DynamicClass
from .validation import validate_inputs
Expand Down Expand Up @@ -75,10 +73,7 @@ def _get_data(self, **kwargs: Any) -> dict[str, Any] | None:
@classmethod
def generate_dataset_dict(
cls,
dataset_list_name: str,
project: RailProject,
selections: list[str] | None = None,
flavors: list[str] | None = None,
**kwargs: dict[str, Any],
) -> list[dict[str, Any]]:
"""Create a dict of the datasets that this extractor can extract
Expand All @@ -87,6 +82,9 @@ def generate_dataset_dict(
dataset_list_name: str
Name for the resulting DatasetList
dataset_holder_class: str
Class for the dataset holder
project: RailProject
Project to inspect
Expand Down
30 changes: 15 additions & 15 deletions src/rail/plotting/dataset_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,21 @@ def __init__(self) -> None:
@classmethod
def instance(cls) -> RailDatasetFactory:
"""Return the singleton instance of the factory"""
if cls._instance is None: # pragma: no cover
if cls._instance is None:
cls._instance = RailDatasetFactory()
return cls._instance

@classmethod
def clear(cls) -> None:
"""Clear the contents of the factory"""
if cls._instance is None: # pragma: no cover
if cls._instance is None:
return
cls._instance.clear_instance()

@classmethod
def print_contents(cls) -> None:
"""Print the contents of the factory"""
if cls._instance is None: # pragma: no cover
if cls._instance is None:
cls._instance = RailDatasetFactory()
cls._instance.print_instance_contents()

Expand Down Expand Up @@ -236,11 +236,11 @@ def print_instance_contents(self) -> None:
def _make_dataset(self, **kwargs: Any) -> RailDatasetHolder:
try:
name = kwargs["name"]
except KeyError as msg: # pragma: no cover
except KeyError as missing_key:
raise KeyError(
"Dataset yaml block does not contain name for dataset: "
f"{list(kwargs.keys())}"
) from msg
) from missing_key
if name in self._datasets: # pragma: no cover
raise KeyError(f"Dataset {name} is already defined")
dataset_holder = RailDatasetHolder.create_from_dict(kwargs)
Expand All @@ -256,11 +256,11 @@ def _make_dataset_dict(
for dataset_name in dataset_name_list:
try:
dataset = self._datasets[dataset_name]
except KeyError as msg: # pragma: no cover
except KeyError as missing_key:
raise KeyError(
f"Dataset {dataset_name} used in DatasetList "
f"is not found {list(self._datasets.keys())}"
) from msg
) from missing_key
datasets[dataset_name] = dataset
self._dataset_dicts[name] = datasets
return datasets
Expand All @@ -287,18 +287,18 @@ def load_dataset_list_from_yaml_tag(
"""
try:
name = dataset_list_config.pop("name")
except KeyError as msg: # pragma: no cover
except KeyError as missing_key:
raise KeyError(
"DatasetList yaml block does not contain name for dataset: "
f"{list(dataset_list_config.keys())}"
) from msg
) from missing_key
try:
dataset_names = dataset_list_config.pop("datasets")
except KeyError as msg: # pragma: no cover
except KeyError as missing_key:
raise KeyError(
"DatasetList yaml block does not contain dataset: "
f"{list(dataset_list_config.keys())}"
) from msg
) from missing_key
self._make_dataset_dict(name, dataset_names)

def load_project_from_yaml_tag(self, project_config: dict[str, Any]) -> None:
Expand All @@ -311,18 +311,18 @@ def load_project_from_yaml_tag(self, project_config: dict[str, Any]) -> None:
"""
try:
name = project_config.pop("name")
except KeyError as msg: # pragma: no cover
except KeyError as missing_key:
raise KeyError(
"Project yaml block does not contain name for project: "
f"{list(project_config.keys())}"
) from msg
) from missing_key
try:
project_yaml = project_config.pop("yaml_file")
except KeyError as msg: # pragma: no cover
except KeyError as missing_key:
raise KeyError(
"Project yaml block does not contain yaml_file: "
f"{list(project_config.keys())}"
) from msg
) from missing_key
self._projects[name] = RailProject.load_config(project_yaml)

def load_instance_yaml(self, yaml_file: str) -> None:
Expand Down
4 changes: 2 additions & 2 deletions src/rail/plotting/dataset_holder.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ def _validate_extractor_inputs(cls, **kwargs: Any) -> None:
for key, expected_type in cls.extractor_inputs.items():
try:
data = kwargs[key]
except KeyError as msg: # pragma: no cover
except KeyError as missing_key:
raise KeyError(
f"{key} not provided to RailDatasetHolder {cls} in {list(kwargs.keys())}"
) from msg
) from missing_key
if isinstance(expected_type, GenericAlias):
if not isinstance(data, expected_type.__origin__): # pragma: no cover
raise TypeError(
Expand Down
Loading

0 comments on commit d148509

Please sign in to comment.