Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eac/thursday #8

Merged
merged 2 commits into from
Jan 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 198 additions & 1 deletion src/rail/plotting/data_extraction_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import os
from typing import Any
import glob

import numpy as np
import tables_io
Expand Down Expand Up @@ -67,6 +68,29 @@ def extract_z_point(
return z_estimates


def extract_z_pdf(
    filepath: str,
) -> qp.Ensemble:
    """Extract the pdf estimates of redshifts from a file

    Parameters
    ----------
    filepath: str
        Path to the qp file with the redshift pdfs

    Returns
    -------
    z_pdf: qp.Ensemble
        Redshift pdfs in question

    Notes
    -----
    This assumes the pdfs are stored in a file that `qp.read` can load
    """
    return qp.read(filepath)


def extract_multiple_z_point(
filepaths: dict[str, str],
colname: str = "zmode",
Expand Down Expand Up @@ -212,6 +236,88 @@ def get_ceci_pz_output_path(
return outpath if os.path.exists(outpath) else None


def get_ceci_nz_output_paths(
    project: RailProject,
    selection: str,
    flavor: str,
    algo: str,
    classifier: str,
    summarizer: str,
) -> list[str]:
    """Get the paths to the files with n(z) estimates
    for a particular analysis selection and flavor

    Parameters
    ----------
    project: RailProject
        Object with information about the structure of the current project

    selection: str
        Data selection in question, e.g., 'gold', or 'blended'

    flavor: str
        Analysis flavor in question, e.g., 'baseline' or 'zCosmos'

    algo: str
        Algorithm we want the estimates for, e.g., 'knn', 'bpz', etc...

    classifier: str
        Algorithm we use to make tomographic bins

    summarizer: str
        Algorithm we use to go from p(z) to n(z)

    Returns
    -------
    paths: list[str]
        Paths to data, one per matching file; empty if nothing matches

    Notes
    -----
    The paths are sorted lexicographically, so 'bin10' sorts before 'bin2'.
    """
    outdir = project.get_path("ceci_output_dir", selection=selection, flavor=flavor)
    basename = f"single_NZ_summarize_{algo}_{classifier}_bin*_{summarizer}.hdf5"
    # Only the bin index is meant to be a wildcard: escape the directory so
    # that glob metacharacters in the project path are matched literally.
    pattern = os.path.join(glob.escape(outdir), basename)
    return sorted(glob.glob(pattern))


def get_ceci_true_nz_output_paths(
    project: RailProject,
    selection: str,
    flavor: str,
    algo: str,
    classifier: str,
) -> list[str]:
    """Get the paths to the files with the true n(z)
    for a particular analysis selection and flavor

    Parameters
    ----------
    project: RailProject
        Object with information about the structure of the current project

    selection: str
        Data selection in question, e.g., 'gold', or 'blended'

    flavor: str
        Analysis flavor in question, e.g., 'baseline' or 'zCosmos'

    algo: str
        Algorithm we want the true n(z) for, e.g., 'knn', 'bpz', etc...

    classifier: str
        Algorithm we use to make tomographic bins

    Returns
    -------
    paths: list[str]
        Paths to data, one per matching file; empty if nothing matches

    Notes
    -----
    The paths are sorted lexicographically, so 'bin10' sorts before 'bin2'.
    """
    outdir = project.get_path("ceci_output_dir", selection=selection, flavor=flavor)
    basename = f"true_NZ_true_nz_{algo}_{classifier}_bin*.hdf5"
    # Only the bin index is meant to be a wildcard: escape the directory so
    # that glob metacharacters in the project path are matched literally.
    pattern = os.path.join(glob.escape(outdir), basename)
    return sorted(glob.glob(pattern))


def get_pz_point_estimate_data(
project: RailProject,
selection: str,
Expand Down Expand Up @@ -258,10 +364,10 @@ def get_multi_pz_point_estimate_data(
point_estimate_infos: dict[str, dict[str, Any]],
) -> dict[str, Any] | None:
"""Get the true redshifts and point estimates

for several analysis variants

This checks that they all have the same redshifts

Parameters
----------
point_estimate_infos: dict[str, dict[str, Any]]
Expand Down Expand Up @@ -292,3 +398,94 @@ def get_multi_pz_point_estimate_data(
return None
pz_data = make_z_true_multi_z_point_dict(ztrue_data, point_estimates)
return pz_data


def get_tomo_bins_nz_estimate_data(
    project: RailProject,
    selection: str,
    flavor: str,
    algo: str,
    classifier: str,
    summarizer: str,
) -> qp.Ensemble:
    """Read and combine the n(z) estimates for the tomographic bins

    Parameters
    ----------
    project: RailProject
        Object with information about the structure of the current project

    selection: str
        Data selection in question, e.g., 'gold', or 'blended'

    flavor: str
        Analysis flavor in question, e.g., 'baseline' or 'zCosmos'

    algo: str
        Algorithm we want the estimates for, e.g., 'knn', 'bpz', etc...

    classifier: str
        Algorithm we use to make tomographic bins

    summarizer: str
        Algorithm we use to go from p(z) to n(z)

    Returns
    -------
    nz_data: qp.Ensemble
        Concatenation of the ensembles read from each matching file
    """
    nz_files = get_ceci_nz_output_paths(
        project=project,
        selection=selection,
        flavor=flavor,
        algo=algo,
        classifier=classifier,
        summarizer=summarizer,
    )

    # Read the files in the order returned by the path lookup
    per_file_ensembles = []
    for nz_file in nz_files:
        per_file_ensembles.append(extract_z_pdf(nz_file))
    return qp.concatenate(per_file_ensembles)


def get_tomo_bins_true_nz_data(
    project: RailProject,
    selection: str,
    flavor: str,
    algo: str,
    classifier: str,
) -> qp.Ensemble:
    """Read and combine the true n(z) for the tomographic bins

    Parameters
    ----------
    project: RailProject
        Object with information about the structure of the current project

    selection: str
        Data selection in question, e.g., 'gold', or 'blended'

    flavor: str
        Analysis flavor in question, e.g., 'baseline' or 'zCosmos'

    algo: str
        Algorithm we want the true n(z) for, e.g., 'knn', 'bpz', etc...

    classifier: str
        Algorithm we use to make tomographic bins

    Returns
    -------
    nz_data: qp.Ensemble
        Concatenation of the ensembles read from each matching file
    """
    truth_files = get_ceci_true_nz_output_paths(
        project=project,
        selection=selection,
        flavor=flavor,
        algo=algo,
        classifier=classifier,
    )

    # Read the files in the order returned by the path lookup
    per_file_ensembles = []
    for truth_file in truth_files:
        per_file_ensembles.append(extract_z_pdf(truth_file))
    return qp.concatenate(per_file_ensembles)
2 changes: 1 addition & 1 deletion src/rail/plotting/dataset_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class RailDatasetFactory:
- Dataset:
name: blend_baseline_test
class: rail.plotting.project_dataset_holder.RailProjectDatasetHolder
exctractor: rail.plottings.pz_data_extractor.PZPointEstimateDataExtractorxs
exctractor: rail.plottings.pz_data_extractor.PZPointEstimateDataExtractor
project: some_project
selection: blend
flavor: baseline
Expand Down
44 changes: 44 additions & 0 deletions src/rail/plotting/nz_data_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import annotations

from typing import Any


from rail.projects import RailProject

from .data_extractor import RailProjectDataExtractor
from .data_extraction_funcs import (
get_tomo_bins_nz_estimate_data,
get_tomo_bins_true_nz_data,
)


class NZTomoBinDataExtractor(RailProjectDataExtractor):
    """Class to extract the true n(z) and the n(z) tomo bin estimates
    from a RailProject.

    This will return a dict:

    truth: qp.Ensemble
        True n(z) for each tomo bin

    nz_estimates: qp.Ensemble
        n(z) estimates for each tomo bin
    """

    inputs: dict = {
        "project": RailProject,
        "selection": str,
        "flavor": str,
        "algo": str,
        "classifier": str,
        "summarizer": str,
    }

    def _get_data(self, **kwargs: Any) -> dict[str, Any]:
        # The truth lookup takes every input except the summarizer, so it
        # gets its own copy of the arguments with that key removed.
        truth_kwargs = dict(kwargs)
        del truth_kwargs["summarizer"]

        estimates = get_tomo_bins_nz_estimate_data(**kwargs)
        truth = get_tomo_bins_true_nz_data(**truth_kwargs)
        return {"nz_estimates": estimates, "truth": truth}
79 changes: 79 additions & 0 deletions src/rail/plotting/nz_plotters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from __future__ import annotations

import os
from typing import Any
import numpy as np
import qp
from matplotlib import pyplot as plt
from ceci.config import StageParameter

from .plotter import RailPlotter
from .plot_holder import RailPlotHolder
from .dataset_holder import RailDatasetHolder


class NZPlotterTomoBins(RailPlotter):
    """Class to plot the true and estimated n(z) curves of all the
    tomographic bins on a single set of axes"""

    config_options: dict[str, StageParameter] = RailPlotter.config_options.copy()
    config_options.update(
        z_min=StageParameter(float, 0.0, fmt="%0.2f", msg="Minimum Redshift"),
        z_max=StageParameter(float, 3.0, fmt="%0.2f", msg="Maximum Redshift"),
        n_zbins=StageParameter(int, 150, fmt="%i", msg="Number of z bins"),
    )

    inputs: dict = {
        "truth": qp.Ensemble,
        "nz_estimates": qp.Ensemble,
    }

    def _make_plot(
        self,
        prefix: str,
        truth: qp.Ensemble,
        nz_estimates: qp.Ensemble,
        dataset_holder: RailDatasetHolder | None = None,
    ) -> RailPlotHolder:
        """Draw the truth and estimate curves and wrap the figure in a holder"""
        figure, axes = plt.subplots()

        # Redshift grid the pdfs are evaluated on: n_zbins + 1 points
        # spanning [z_min, z_max]
        cfg = self.config
        z_grid = np.linspace(cfg.z_min, cfg.z_max, cfg.n_zbins + 1)
        true_curves = truth.pdf(z_grid)
        estimate_curves = nz_estimates.pdf(z_grid)

        # The number of curve pairs drawn is set by the truth ensemble
        for idx in range(truth.npdf):
            axes.plot(z_grid, true_curves[idx], "-")
            axes.plot(z_grid, estimate_curves[idx])

        axes.set_xlabel("z")
        axes.set_ylabel("n(z)")
        return RailPlotHolder(
            name=self._make_full_plot_name(prefix, ""),
            figure=figure,
            plotter=self,
            dataset_holder=dataset_holder,
        )

    def _make_plots(self, prefix: str, **kwargs: Any) -> dict[str, RailPlotHolder]:
        """Make the plot, or with `find_only` set, just describe where it lives"""
        find_only = kwargs.get("find_only", False)
        figtype = kwargs.get("figtype", "png")
        dataset_holder = kwargs.get("dataset_holder")
        # Both inputs are looked up unconditionally, so they are required
        # even when find_only is set
        truth: qp.Ensemble = kwargs["truth"]
        nz_estimates: qp.Ensemble = kwargs["nz_estimates"]

        if not find_only:
            plot = self._make_plot(
                prefix=prefix,
                truth=truth,
                nz_estimates=nz_estimates,
                dataset_holder=dataset_holder,
            )
            return {plot.name: plot}

        # find_only: no figure is made, only a holder pointing at the
        # path under the dataset's name
        plot_name = self._make_full_plot_name(prefix, "")
        assert dataset_holder
        plot_path = os.path.join(dataset_holder.config.name, f"{plot_name}.{figtype}")
        plot = RailPlotHolder(
            name=plot_name,
            path=plot_path,
            plotter=self,
            dataset_holder=dataset_holder,
        )
        return {plot.name: plot}
Loading
Loading