From f9cdc3dfc1947abf4fef17ce7526bfbf427e14a5 Mon Sep 17 00:00:00 2001 From: golmschenk Date: Mon, 29 Apr 2024 22:15:52 -0400 Subject: [PATCH 01/21] Move internal structure to a separate place than the public API --- docs/source/conf.py | 4 +++ docs/source/reference_index.md | 21 +++++++++++-- .../download_spoc_transit_light_curves.py | 5 ++-- examples/transit_dataset.py | 4 +-- examples/transit_finite_dataset_test.py | 4 +-- examples/transit_infer.py | 10 +++---- examples/transit_infinite_dataset_test.py | 10 +++---- examples/transit_light_curve_visualization.py | 2 +- examples/transit_train.py | 5 ++-- src/qusi/data.py | 17 +++++++++++ src/qusi/experimental/__init__.py | 0 src/qusi/experimental/application/__init__.py | 0 src/qusi/experimental/application/tess.py | 14 +++++++++ src/qusi/experimental/session.py | 5 ++++ src/qusi/internal/__init__.py | 0 src/qusi/{ => internal}/device.py | 5 ++++ .../finite_standard_light_curve_dataset.py | 10 +++++-- ...tandard_light_curve_observation_dataset.py | 4 +-- .../{ => internal}/finite_test_session.py | 2 +- src/qusi/{ => internal}/hadryss_model.py | 30 ++++++++++++------- src/qusi/{ => internal}/infer_session.py | 2 +- src/qusi/{ => internal}/light_curve.py | 0 .../{ => internal}/light_curve_collection.py | 4 +-- .../{ => internal}/light_curve_dataset.py | 8 ++--- .../{ => internal}/light_curve_observation.py | 2 +- .../{ => internal}/light_curve_transforms.py | 2 +- src/qusi/{ => internal}/logging.py | 0 .../single_dense_layer_model.py | 0 .../toy_light_curve_collection.py | 8 ++--- .../train_hyperparameter_configuration.py | 0 .../train_logging_configuration.py | 0 src/qusi/{ => internal}/train_session.py | 10 +++---- .../train_system_configuration.py | 0 src/qusi/{ => internal}/wandb_liaison.py | 0 src/qusi/model.py | 8 +++++ src/qusi/session.py | 14 +++++++++ .../test_toy_infer_session.py | 10 +++---- .../test_toy_train_session.py | 10 +++---- tests/unit_tests/test_hydryss_model.py | 2 +- tests/unit_tests/test_light_curve_dataset.py | 2 +- 40 files changed, 165 insertions(+), 69 deletions(-) create mode 100644 src/qusi/data.py create mode 100644 src/qusi/experimental/__init__.py create mode 100644 src/qusi/experimental/application/__init__.py create mode 100644 src/qusi/experimental/application/tess.py create mode 100644 src/qusi/experimental/session.py create mode 100644 src/qusi/internal/__init__.py rename src/qusi/{ => internal}/device.py (68%) rename src/qusi/{ => internal}/finite_standard_light_curve_dataset.py (82%) rename src/qusi/{ => internal}/finite_standard_light_curve_observation_dataset.py (90%) rename src/qusi/{ => internal}/finite_test_session.py (94%) rename src/qusi/{ => internal}/hadryss_model.py (90%) rename src/qusi/{ => internal}/infer_session.py (93%) rename src/qusi/{ => internal}/light_curve.py (100%) rename src/qusi/{ => internal}/light_curve_collection.py (98%) rename src/qusi/{ => internal}/light_curve_dataset.py (98%) rename src/qusi/{ => internal}/light_curve_observation.py (92%) rename src/qusi/{ => internal}/light_curve_transforms.py (92%) rename src/qusi/{ => internal}/logging.py (100%) rename src/qusi/{ => internal}/single_dense_layer_model.py (100%) rename src/qusi/{ => internal}/toy_light_curve_collection.py (93%) rename src/qusi/{ => internal}/train_hyperparameter_configuration.py (100%) rename src/qusi/{ => internal}/train_logging_configuration.py (100%) rename src/qusi/{ => internal}/train_session.py (95%) rename src/qusi/{ => internal}/train_system_configuration.py (100%) rename src/qusi/{ => internal}/wandb_liaison.py (100%) create mode 100644 src/qusi/model.py create mode 100644 src/qusi/session.py diff --git a/docs/source/conf.py b/docs/source/conf.py index 8f582c59..f264459e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -26,6 +26,10 @@ templates_path = ["_templates"] exclude_patterns = [] source_suffix = [".rst", ".md"] +autodoc_class_signature = 'separated' +autodoc_default_options = { + 'special-members': None, +} # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output diff --git a/docs/source/reference_index.md b/docs/source/reference_index.md index 0d342294..162c5229 100644 --- a/docs/source/reference_index.md +++ b/docs/source/reference_index.md @@ -1,6 +1,23 @@ # Reference ```{eval-rst} -.. automodule:: qusi.light_curve - :members: +.. autoclass:: qusi.data.LightCurve + :members: new +.. autoclass:: qusi.data.LightCurveCollection + :members: new +.. autoclass:: qusi.data.LightCurveDataset + :members: new +.. autoclass:: qusi.data.LightCurveObservationCollection + :members: new +.. autoclass:: qusi.data.FiniteStandardLightCurveDataset + :members: new +.. autoclass:: qusi.data.FiniteStandardLightCurveObservationDataset + :members: new +.. autoclass:: qusi.model.Hadryss + :members: new +.. autofunction:: qusi.session.get_device +.. autofunction:: qusi.session.infer_session +.. autofunction:: qusi.session.train_session +.. autoclass:: qusi.session.TrainHyperparameterConfiguration + :members: new ``` diff --git a/examples/download_spoc_transit_light_curves.py b/examples/download_spoc_transit_light_curves.py index 418ec051..0bff68ee 100644 --- a/examples/download_spoc_transit_light_curves.py +++ b/examples/download_spoc_transit_light_curves.py @@ -2,11 +2,12 @@ import numpy as np -from ramjet.data_interface.tess_data_interface import ( +from qusi.experimental.application.tess import ( download_spoc_light_curves_for_tic_ids, get_spoc_tic_id_list_from_mast, + TessToiDataInterface, + ToiColumns, ) -from ramjet.data_interface.tess_toi_data_interface import TessToiDataInterface, ToiColumns def main(): diff --git a/examples/transit_dataset.py b/examples/transit_dataset.py index 37265bf0..ad30126d 100644 --- a/examples/transit_dataset.py +++ b/examples/transit_dataset.py @@ -1,8 +1,6 @@ from pathlib import Path -from qusi.finite_standard_light_curve_observation_dataset import FiniteStandardLightCurveObservationDataset -from qusi.light_curve_collection import LightCurveObservationCollection -from qusi.light_curve_dataset import LightCurveDataset +from qusi.data import FiniteStandardLightCurveObservationDataset, LightCurveDataset, LightCurveObservationCollection from ramjet.photometric_database.tess_two_minute_cadence_light_curve import TessMissionLightCurve diff --git a/examples/transit_finite_dataset_test.py b/examples/transit_finite_dataset_test.py index 6c6479ae..f09b1aaf 100644 --- a/examples/transit_finite_dataset_test.py +++ b/examples/transit_finite_dataset_test.py @@ -2,8 +2,8 @@ from torch.nn import BCELoss from torchmetrics.classification import BinaryAccuracy -from qusi.finite_test_session import finite_datasets_test_session, get_device -from qusi.hadryss_model import Hadryss +from qusi.session import finite_datasets_test_session, get_device +from qusi.model import Hadryss from transit_dataset import get_transit_finite_test_dataset diff --git a/examples/transit_infer.py b/examples/transit_infer.py index 80673fdb..d8dfec87 100644 --- a/examples/transit_infer.py +++ b/examples/transit_infer.py @@ -3,12 +3,10 @@ import numpy as np import torch -from qusi.finite_standard_light_curve_dataset import FiniteStandardLightCurveDataset -from qusi.hadryss_model import Hadryss -from qusi.infer_session import infer_session -from qusi.device import get_device -from qusi.light_curve_collection import LightCurveCollection -from ramjet.photometric_database.tess_two_minute_cadence_light_curve import TessMissionLightCurve +from qusi.data import FiniteStandardLightCurveDataset, LightCurveCollection +from qusi.model import Hadryss +from qusi.session import get_device, infer_session +from qusi.experimental.application.tess import TessMissionLightCurve def get_infer_paths(): diff --git a/examples/transit_infinite_dataset_test.py b/examples/transit_infinite_dataset_test.py index 6f6e34c5..4fa2b0f6 100644 --- a/examples/transit_infinite_dataset_test.py +++ b/examples/transit_infinite_dataset_test.py @@ -7,11 +7,11 @@ from torch.utils.data import DataLoader from torchmetrics.classification import BinaryAccuracy -from qusi.hadryss_model import Hadryss -from qusi.device import get_device -from qusi.light_curve_collection import LightCurveObservationCollection -from qusi.light_curve_dataset import LightCurveDataset -from ramjet.photometric_database.tess_two_minute_cadence_light_curve import TessMissionLightCurve +from qusi.model import Hadryss +from qusi.session import get_device +from qusi.data import LightCurveObservationCollection +from qusi.data import LightCurveDataset +from qusi.experimental.application.tess import TessMissionLightCurve def get_negative_test_paths(): diff --git a/examples/transit_light_curve_visualization.py b/examples/transit_light_curve_visualization.py index 1175da4d..42d0a552 100644 --- a/examples/transit_light_curve_visualization.py +++ b/examples/transit_light_curve_visualization.py @@ -3,7 +3,7 @@ from bokeh.io import show from bokeh.plotting import figure as Figure -from ramjet.photometric_database.tess_two_minute_cadence_light_curve import TessMissionLightCurve +from qusi.experimental.application.tess import TessMissionLightCurve def main(): diff --git a/examples/transit_train.py b/examples/transit_train.py index 79054045..77a8b5d0 100644 --- a/examples/transit_train.py +++ b/examples/transit_train.py @@ -1,6 +1,5 @@ -from qusi.hadryss_model import Hadryss -from qusi.train_hyperparameter_configuration import TrainHyperparameterConfiguration -from qusi.train_session import train_session +from qusi.model import Hadryss +from qusi.session import TrainHyperparameterConfiguration, train_session from transit_dataset import get_transit_train_dataset, get_transit_validation_dataset diff --git a/src/qusi/data.py b/src/qusi/data.py new file mode 100644 index 00000000..03d77d86 --- /dev/null +++ b/src/qusi/data.py @@ -0,0 +1,17 @@ +""" +Data related public interface. +""" +from qusi.internal.finite_standard_light_curve_dataset import FiniteStandardLightCurveDataset +from qusi.internal.finite_standard_light_curve_observation_dataset import FiniteStandardLightCurveObservationDataset +from qusi.internal.light_curve import LightCurve +from qusi.internal.light_curve_dataset import LightCurveDataset +from qusi.internal.light_curve_collection import LightCurveObservationCollection, LightCurveCollection + +__all__ = [ + 'FiniteStandardLightCurveDataset', + 'FiniteStandardLightCurveObservationDataset', + 'LightCurve', + 'LightCurveCollection', + 'LightCurveDataset', + 'LightCurveObservationCollection', +] diff --git a/src/qusi/experimental/__init__.py b/src/qusi/experimental/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/qusi/experimental/application/__init__.py b/src/qusi/experimental/application/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/qusi/experimental/application/tess.py b/src/qusi/experimental/application/tess.py new file mode 100644 index 00000000..419a83b7 --- /dev/null +++ b/src/qusi/experimental/application/tess.py @@ -0,0 +1,14 @@ +from ramjet.data_interface.tess_data_interface import ( + download_spoc_light_curves_for_tic_ids, + get_spoc_tic_id_list_from_mast, +) +from ramjet.data_interface.tess_toi_data_interface import TessToiDataInterface, ToiColumns +from ramjet.photometric_database.tess_two_minute_cadence_light_curve import TessMissionLightCurve + +__all__ = [ + 'download_spoc_light_curves_for_tic_ids', + 'get_spoc_tic_id_list_from_mast', + 'TessMissionLightCurve', + 'TessToiDataInterface', + 'ToiColumns', +] diff --git a/src/qusi/experimental/session.py b/src/qusi/experimental/session.py new file mode 100644 index 00000000..aa01d435 --- /dev/null +++ b/src/qusi/experimental/session.py @@ -0,0 +1,5 @@ +from qusi.internal.finite_test_session import finite_datasets_test_session + +__all__ = [ + 'finite_datasets_test_session', +] \ No newline at end of file diff --git a/src/qusi/internal/__init__.py b/src/qusi/internal/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/qusi/device.py b/src/qusi/internal/device.py similarity index 68% rename from src/qusi/device.py rename to src/qusi/internal/device.py index 37455a22..fc725a11 100644 --- a/src/qusi/device.py +++ b/src/qusi/internal/device.py @@ -3,6 +3,11 @@ def get_device() -> Device: + """ + Gets the available device for PyTorch to run on. + + :return: The device. + """ if torch.cuda.is_available(): device = torch.device("cuda") else: diff --git a/src/qusi/finite_standard_light_curve_dataset.py b/src/qusi/internal/finite_standard_light_curve_dataset.py similarity index 82% rename from src/qusi/finite_standard_light_curve_dataset.py rename to src/qusi/internal/finite_standard_light_curve_dataset.py index 8a941663..c3f7f6cb 100644 --- a/src/qusi/finite_standard_light_curve_dataset.py +++ b/src/qusi/internal/finite_standard_light_curve_dataset.py @@ -6,8 +6,8 @@ from torch.utils.data import Dataset from typing_extensions import Self -from qusi.light_curve_collection import LightCurveCollection -from qusi.light_curve_dataset import default_light_curve_post_injection_transform +from qusi.internal.light_curve_collection import LightCurveCollection +from qusi.internal.light_curve_dataset import default_light_curve_post_injection_transform @dataclass @@ -19,6 +19,12 @@ class FiniteStandardLightCurveDataset(Dataset): @classmethod def new(cls, light_curve_collections: list[LightCurveCollection]) -> Self: + """ + Creates a new `FiniteStandardLightCurveDataset`. + + :param light_curve_collections: The light curve collections to include in the dataset. + :return: The dataset. + """ length = 0 collection_start_indexes: list[int] = [] for light_curve_collection in light_curve_collections: diff --git a/src/qusi/finite_standard_light_curve_observation_dataset.py b/src/qusi/internal/finite_standard_light_curve_observation_dataset.py similarity index 90% rename from src/qusi/finite_standard_light_curve_observation_dataset.py rename to src/qusi/internal/finite_standard_light_curve_observation_dataset.py index 61555b71..c64f4b66 100644 --- a/src/qusi/finite_standard_light_curve_observation_dataset.py +++ b/src/qusi/internal/finite_standard_light_curve_observation_dataset.py @@ -6,8 +6,8 @@ from torch.utils.data import Dataset from typing_extensions import Self -from qusi.light_curve_collection import LightCurveObservationCollection -from qusi.light_curve_dataset import default_light_curve_observation_post_injection_transform +from qusi.internal.light_curve_collection import LightCurveObservationCollection +from qusi.internal.light_curve_dataset import default_light_curve_observation_post_injection_transform @dataclass diff --git a/src/qusi/finite_test_session.py b/src/qusi/internal/finite_test_session.py similarity index 94% rename from src/qusi/finite_test_session.py rename to src/qusi/internal/finite_test_session.py index de3c48ea..c623bf53 100644 --- a/src/qusi/finite_test_session.py +++ b/src/qusi/internal/finite_test_session.py @@ -3,7 +3,7 @@ from torch.types import Device from torch.utils.data import DataLoader -from qusi.finite_standard_light_curve_observation_dataset import FiniteStandardLightCurveObservationDataset +from qusi.internal.finite_standard_light_curve_observation_dataset import FiniteStandardLightCurveObservationDataset def finite_datasets_test_session( diff --git a/src/qusi/hadryss_model.py b/src/qusi/internal/hadryss_model.py similarity index 90% rename from src/qusi/hadryss_model.py rename to src/qusi/internal/hadryss_model.py index 61f09d1d..70c8c39e 100644 --- a/src/qusi/hadryss_model.py +++ b/src/qusi/internal/hadryss_model.py @@ -18,6 +18,10 @@ class Hadryss(Module): + """ + A 1D convolutional neural network model for light curve data that will auto-size itself for a given input light + curve length. + """ def __init__(self, input_length: int): super().__init__() self.input_length: int = input_length @@ -124,6 +128,12 @@ def forward(self, x: Tensor) -> Tensor: @classmethod def new(cls, input_length: int = 2500) -> Self: + """ + Creates a new Hadryss model. + + :param input_length: The length of the input to auto-size the network to. + :return: The model. + """ instance = cls(input_length=input_length) return instance @@ -145,16 +155,16 @@ def determine_block_pooling_sizes_and_dense_size(self) -> (list[int], int): class LightCurveNetworkBlock(Module): def __init__( - self, - input_channels: int, - output_channels: int, - kernel_size: int, - pooling_size: int, - dropout_rate: float = 0.0, - *, - batch_normalization: bool = False, - spatial: bool = True, - length: int | None = None, + self, + input_channels: int, + output_channels: int, + kernel_size: int, + pooling_size: int, + dropout_rate: float = 0.0, + *, + batch_normalization: bool = False, + spatial: bool = True, + length: int | None = None, ): super().__init__() self.leaky_relu = LeakyReLU() diff --git a/src/qusi/infer_session.py b/src/qusi/internal/infer_session.py similarity index 93% rename from src/qusi/infer_session.py rename to src/qusi/internal/infer_session.py index bc2a2f7b..ae027aba 100644 --- a/src/qusi/infer_session.py +++ b/src/qusi/internal/infer_session.py @@ -4,7 +4,7 @@ from torch.types import Device from torch.utils.data import DataLoader -from qusi.finite_standard_light_curve_dataset import FiniteStandardLightCurveDataset +from qusi.internal.finite_standard_light_curve_dataset import FiniteStandardLightCurveDataset def infer_session( diff --git a/src/qusi/light_curve.py b/src/qusi/internal/light_curve.py similarity index 100% rename from src/qusi/light_curve.py rename to src/qusi/internal/light_curve.py diff --git a/src/qusi/light_curve_collection.py b/src/qusi/internal/light_curve_collection.py similarity index 98% rename from src/qusi/light_curve_collection.py rename to src/qusi/internal/light_curve_collection.py index d5059d5a..329c7071 100644 --- a/src/qusi/light_curve_collection.py +++ b/src/qusi/internal/light_curve_collection.py @@ -11,8 +11,8 @@ import numpy.typing as npt from typing_extensions import Self -from qusi.light_curve import LightCurve -from qusi.light_curve_observation import LightCurveObservation +from qusi.internal.light_curve import LightCurve +from qusi.internal.light_curve_observation import LightCurveObservation if TYPE_CHECKING: from collections.abc import Iterable, Iterator diff --git a/src/qusi/light_curve_dataset.py b/src/qusi/internal/light_curve_dataset.py similarity index 98% rename from src/qusi/light_curve_dataset.py rename to src/qusi/internal/light_curve_dataset.py index 91c6adc0..d8e61959 100644 --- a/src/qusi/light_curve_dataset.py +++ b/src/qusi/internal/light_curve_dataset.py @@ -20,17 +20,17 @@ from torch.utils.data import IterableDataset from typing_extensions import Self -from qusi.light_curve import ( +from qusi.internal.light_curve import ( LightCurve, randomly_roll_light_curve, remove_nan_flux_data_points_from_light_curve, ) -from qusi.light_curve_observation import ( +from qusi.internal.light_curve_observation import ( LightCurveObservation, randomly_roll_light_curve_observation, remove_nan_flux_data_points_from_light_curve_observation, ) -from qusi.light_curve_transforms import ( +from qusi.internal.light_curve_transforms import ( from_light_curve_observation_to_fluxes_array_and_label_array, pair_array_to_tensor, ) @@ -38,7 +38,7 @@ if TYPE_CHECKING: from collections.abc import Iterable, Iterator - from qusi.light_curve_collection import LightCurveObservationCollection + from qusi.internal.light_curve_collection import LightCurveObservationCollection class LightCurveDataset(IterableDataset): diff --git a/src/qusi/light_curve_observation.py b/src/qusi/internal/light_curve_observation.py similarity index 92% rename from src/qusi/light_curve_observation.py rename to src/qusi/internal/light_curve_observation.py index 5d88dd18..042f4762 100644 --- a/src/qusi/light_curve_observation.py +++ b/src/qusi/internal/light_curve_observation.py @@ -3,7 +3,7 @@ from typing_extensions import Self -from qusi.light_curve import LightCurve, randomly_roll_light_curve, remove_nan_flux_data_points_from_light_curve +from qusi.internal.light_curve import LightCurve, randomly_roll_light_curve, remove_nan_flux_data_points_from_light_curve @dataclass diff --git a/src/qusi/light_curve_transforms.py b/src/qusi/internal/light_curve_transforms.py similarity index 92% rename from src/qusi/light_curve_transforms.py rename to src/qusi/internal/light_curve_transforms.py index 4aeb703e..c9b10b73 100644 --- a/src/qusi/light_curve_transforms.py +++ b/src/qusi/internal/light_curve_transforms.py @@ -3,7 +3,7 @@ import torch from torch import Tensor -from qusi.light_curve_observation import LightCurveObservation +from qusi.internal.light_curve_observation import LightCurveObservation def from_light_curve_observation_to_fluxes_array_and_label_array( diff --git a/src/qusi/logging.py b/src/qusi/internal/logging.py similarity index 100% rename from src/qusi/logging.py rename to src/qusi/internal/logging.py diff --git a/src/qusi/single_dense_layer_model.py b/src/qusi/internal/single_dense_layer_model.py similarity index 100% rename from src/qusi/single_dense_layer_model.py rename to src/qusi/internal/single_dense_layer_model.py diff --git a/src/qusi/toy_light_curve_collection.py b/src/qusi/internal/toy_light_curve_collection.py similarity index 93% rename from src/qusi/toy_light_curve_collection.py rename to src/qusi/internal/toy_light_curve_collection.py index df3dd764..d22a0afe 100644 --- a/src/qusi/toy_light_curve_collection.py +++ b/src/qusi/internal/toy_light_curve_collection.py @@ -2,13 +2,13 @@ import numpy as np -from qusi.finite_standard_light_curve_dataset import FiniteStandardLightCurveDataset -from qusi.light_curve import LightCurve -from qusi.light_curve_collection import ( +from qusi.internal.finite_standard_light_curve_dataset import FiniteStandardLightCurveDataset +from qusi.internal.light_curve import LightCurve +from qusi.internal.light_curve_collection import ( LightCurveObservationCollection, create_constant_label_for_path_function, LightCurveCollection, ) -from qusi.light_curve_dataset import LightCurveDataset +from qusi.internal.light_curve_dataset import LightCurveDataset class ToyLightCurve: diff --git a/src/qusi/train_hyperparameter_configuration.py b/src/qusi/internal/train_hyperparameter_configuration.py similarity index 100% rename from src/qusi/train_hyperparameter_configuration.py rename to src/qusi/internal/train_hyperparameter_configuration.py diff --git a/src/qusi/train_logging_configuration.py b/src/qusi/internal/train_logging_configuration.py similarity index 100% rename from src/qusi/train_logging_configuration.py rename to src/qusi/internal/train_logging_configuration.py diff --git a/src/qusi/train_session.py b/src/qusi/internal/train_session.py similarity index 95% rename from src/qusi/train_session.py rename to src/qusi/internal/train_session.py index c62e464e..13eb3989 100644 --- a/src/qusi/train_session.py +++ b/src/qusi/internal/train_session.py @@ -12,11 +12,11 @@ from torchmetrics.classification import BinaryAccuracy import wandb -from qusi.light_curve_dataset import InterleavedDataset, LightCurveDataset -from qusi.logging import set_up_default_logger -from qusi.train_hyperparameter_configuration import TrainHyperparameterConfiguration -from qusi.train_logging_configuration import TrainLoggingConfiguration -from qusi.wandb_liaison import wandb_commit, wandb_init, wandb_log +from qusi.internal.light_curve_dataset import InterleavedDataset, LightCurveDataset +from qusi.internal.logging import set_up_default_logger +from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration +from qusi.internal.train_logging_configuration import TrainLoggingConfiguration +from qusi.internal.wandb_liaison import wandb_commit, wandb_init, wandb_log logger = logging.getLogger(__name__) diff --git a/src/qusi/train_system_configuration.py b/src/qusi/internal/train_system_configuration.py similarity index 100% rename from src/qusi/train_system_configuration.py rename to src/qusi/internal/train_system_configuration.py diff --git a/src/qusi/wandb_liaison.py b/src/qusi/internal/wandb_liaison.py similarity index 100% rename from src/qusi/wandb_liaison.py rename to src/qusi/internal/wandb_liaison.py diff --git a/src/qusi/model.py b/src/qusi/model.py new file mode 100644 index 00000000..02ecb910 --- /dev/null +++ b/src/qusi/model.py @@ -0,0 +1,8 @@ +""" +Neural network model related public interface. +""" +from qusi.internal.hadryss_model import Hadryss + +__all__ = [ + 'Hadryss', +] diff --git a/src/qusi/session.py b/src/qusi/session.py new file mode 100644 index 00000000..c8569e33 --- /dev/null +++ b/src/qusi/session.py @@ -0,0 +1,14 @@ +""" +Session related public interface. +""" +from qusi.internal.device import get_device +from qusi.internal.infer_session import infer_session +from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration +from qusi.internal.train_session import train_session + +__all__ = [ + 'get_device', + 'infer_session', + 'TrainHyperparameterConfiguration', + 'train_session', +] diff --git a/tests/end_to_end_tests/test_toy_infer_session.py b/tests/end_to_end_tests/test_toy_infer_session.py index 9c3f046f..72384bc6 100644 --- a/tests/end_to_end_tests/test_toy_infer_session.py +++ b/tests/end_to_end_tests/test_toy_infer_session.py @@ -3,13 +3,13 @@ import numpy as np -from qusi.infer_session import infer_session -from qusi.device import get_device -from qusi.light_curve_dataset import ( +from qusi.internal.infer_session import infer_session +from qusi.internal.device import get_device +from qusi.internal.light_curve_dataset import ( default_light_curve_post_injection_transform, ) -from qusi.single_dense_layer_model import SingleDenseLayerBinaryClassificationModel -from qusi.toy_light_curve_collection import get_toy_finite_light_curve_dataset +from qusi.internal.single_dense_layer_model import SingleDenseLayerBinaryClassificationModel +from qusi.internal.toy_light_curve_collection import get_toy_finite_light_curve_dataset def test_toy_infer_session(): diff --git a/tests/end_to_end_tests/test_toy_train_session.py b/tests/end_to_end_tests/test_toy_train_session.py index 4833d8ca..47e17f0d 100644 --- a/tests/end_to_end_tests/test_toy_train_session.py +++ b/tests/end_to_end_tests/test_toy_train_session.py @@ -1,13 +1,13 @@ import os from functools import partial -from qusi.light_curve_dataset import ( +from qusi.internal.light_curve_dataset import ( default_light_curve_observation_post_injection_transform, ) -from qusi.single_dense_layer_model import SingleDenseLayerBinaryClassificationModel -from qusi.toy_light_curve_collection import get_toy_dataset -from qusi.train_hyperparameter_configuration import TrainHyperparameterConfiguration -from qusi.train_session import train_session +from qusi.internal.single_dense_layer_model import SingleDenseLayerBinaryClassificationModel +from qusi.internal.toy_light_curve_collection import get_toy_dataset +from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration +from qusi.internal.train_session import train_session def test_toy_train_session(): diff --git a/tests/unit_tests/test_hydryss_model.py b/tests/unit_tests/test_hydryss_model.py index 7e62ec5c..b95dc2fe 100644 --- a/tests/unit_tests/test_hydryss_model.py +++ b/tests/unit_tests/test_hydryss_model.py @@ -1,6 +1,6 @@ import torch -from qusi.hadryss_model import Hadryss +from qusi.internal.hadryss_model import Hadryss def test_lengths_give_correct_output_size(): diff --git a/tests/unit_tests/test_light_curve_dataset.py b/tests/unit_tests/test_light_curve_dataset.py index 62f55e28..b2e5897f 100644 --- a/tests/unit_tests/test_light_curve_dataset.py +++ b/tests/unit_tests/test_light_curve_dataset.py @@ -2,7 +2,7 @@ from itertools import islice from unittest.mock import Mock -from qusi.light_curve_dataset import ( +from qusi.internal.light_curve_dataset import ( contains_injected_dataset, interleave_infinite_iterators, is_injected_dataset, From 5c435e311140140ece4665d23cee39d87fe79f45 Mon Sep 17 00:00:00 2001 From: golmschenk Date: Wed, 1 May 2024 23:50:41 -0400 Subject: [PATCH 02/21] Add required keyword parameters --- examples/transit_infinite_dataset_test.py | 2 +- src/qusi/internal/finite_test_session.py | 11 +- src/qusi/internal/hadryss_model.py | 2 +- src/qusi/internal/infer_session.py | 2 +- src/qusi/internal/light_curve_dataset.py | 124 +++++++----------- .../train_hyperparameter_configuration.py | 19 +-- .../internal/train_logging_configuration.py | 9 +- src/qusi/internal/train_session.py | 31 ++--- .../internal/train_system_configuration.py | 6 +- src/qusi/session.py | 7 + 10 files changed, 103 insertions(+), 110 deletions(-) diff --git a/examples/transit_infinite_dataset_test.py b/examples/transit_infinite_dataset_test.py index 4fa2b0f6..98484286 100644 --- a/examples/transit_infinite_dataset_test.py +++ b/examples/transit_infinite_dataset_test.py @@ -60,7 +60,7 @@ def main(): def infinite_datasets_test_session(test_datasets: list[LightCurveDataset], model: Module, - metric_functions: list[Module], batch_size: int, device: Device, steps: int): + metric_functions: list[Module], *, batch_size: int, device: Device, steps: int): test_dataloaders: list[DataLoader] = [] for test_dataset in test_datasets: test_dataloaders.append(DataLoader(test_dataset, batch_size=batch_size, pin_memory=True)) diff --git a/src/qusi/internal/finite_test_session.py b/src/qusi/internal/finite_test_session.py index c623bf53..fde27903 100644 --- a/src/qusi/internal/finite_test_session.py +++ b/src/qusi/internal/finite_test_session.py @@ -7,11 +7,12 @@ def finite_datasets_test_session( - test_datasets: list[FiniteStandardLightCurveObservationDataset], - model: Module, - metric_functions: list[Module], - batch_size: int, - device: Device, + test_datasets: list[FiniteStandardLightCurveObservationDataset], + model: Module, + *, + metric_functions: list[Module], + batch_size: int, + device: Device, ): test_dataloaders: list[DataLoader] = [] for test_dataset in test_datasets: diff --git a/src/qusi/internal/hadryss_model.py b/src/qusi/internal/hadryss_model.py index 70c8c39e..c4701d81 100644 --- a/src/qusi/internal/hadryss_model.py +++ b/src/qusi/internal/hadryss_model.py @@ -22,7 +22,7 @@ class Hadryss(Module): A 1D convolutional neural network model for light curve data that will auto-size itself for a given input light curve length. """ - def __init__(self, input_length: int): + def __init__(self, *, input_length: int): super().__init__() self.input_length: int = input_length pooling_sizes, dense_size = self.determine_block_pooling_sizes_and_dense_size() diff --git a/src/qusi/internal/infer_session.py b/src/qusi/internal/infer_session.py index ae027aba..e9499779 100644 --- a/src/qusi/internal/infer_session.py +++ b/src/qusi/internal/infer_session.py @@ -8,7 +8,7 @@ def infer_session( - infer_datasets: list[FiniteStandardLightCurveDataset], model: Module, batch_size: int, device: Device + infer_datasets: list[FiniteStandardLightCurveDataset], model: Module, *, batch_size: int, device: Device ) -> list[np.ndarray]: infer_dataloaders: list[DataLoader] = [] for infer_dataset in infer_datasets: diff --git a/src/qusi/internal/light_curve_dataset.py b/src/qusi/internal/light_curve_dataset.py index d8e61959..f2d70a0f 100644 --- a/src/qusi/internal/light_curve_dataset.py +++ b/src/qusi/internal/light_curve_dataset.py @@ -47,25 +47,17 @@ class LightCurveDataset(IterableDataset): """ def __init__( - self, - standard_light_curve_collections: list[LightCurveObservationCollection], - injectee_light_curve_collections: list[LightCurveObservationCollection], - injectable_light_curve_collections: list[LightCurveObservationCollection], - post_injection_transform: Callable[[Any], Any], + self, + standard_light_curve_collections: list[LightCurveObservationCollection], + injectee_light_curve_collections: list[LightCurveObservationCollection], + injectable_light_curve_collections: list[LightCurveObservationCollection], + post_injection_transform: Callable[[Any], Any], ): - self.standard_light_curve_collections: list[ - LightCurveObservationCollection - ] = standard_light_curve_collections - self.injectee_light_curve_collections: list[ - LightCurveObservationCollection - ] = injectee_light_curve_collections + self.standard_light_curve_collections: list[LightCurveObservationCollection] = standard_light_curve_collections + self.injectee_light_curve_collections: list[LightCurveObservationCollection] = injectee_light_curve_collections self.injectable_light_curve_collections: list[ - LightCurveObservationCollection - ] = injectable_light_curve_collections - if ( - len(self.standard_light_curve_collections) == 0 - and len(self.injectee_light_curve_collections) == 0 - ): + LightCurveObservationCollection] = injectable_light_curve_collections + if len(self.standard_light_curve_collections) == 0 and len(self.injectee_light_curve_collections) == 0: error_message = ( "Either the standard or injectee light curve collection lists must not be empty. " "Both were empty." @@ -99,27 +91,18 @@ def __iter__(self): loop_iter_function(injectee_collection.observation_iter), LightCurveCollectionType.INJECTEE, ) - base_light_curve_collection_iter_and_type_pairs.append( - base_light_curve_collection_iter_and_type_pair - ) + base_light_curve_collection_iter_and_type_pairs.append(base_light_curve_collection_iter_and_type_pair) injectable_light_curve_collection_iters: list[ Iterator[LightCurveObservation] ] = [] for injectable_collection in self.injectable_light_curve_collections: - injectable_light_curve_collection_iter = loop_iter_function( - injectable_collection.observation_iter - ) - injectable_light_curve_collection_iters.append( - injectable_light_curve_collection_iter - ) + injectable_light_curve_collection_iter = loop_iter_function(injectable_collection.observation_iter) + injectable_light_curve_collection_iters.append(injectable_light_curve_collection_iter) while True: for ( - base_light_curve_collection_iter_and_type_pair + base_light_curve_collection_iter_and_type_pair ) in base_light_curve_collection_iter_and_type_pairs: - ( - base_collection_iter, - collection_type, - ) = base_light_curve_collection_iter_and_type_pair + (base_collection_iter, collection_type) = base_light_curve_collection_iter_and_type_pair if collection_type in [ LightCurveCollectionType.STANDARD, LightCurveCollectionType.STANDARD_AND_INJECTEE, @@ -135,9 +118,7 @@ def __iter__(self): LightCurveCollectionType.INJECTEE, LightCurveCollectionType.STANDARD_AND_INJECTEE, ]: - for ( - injectable_light_curve_collection_iter - ) in injectable_light_curve_collection_iters: + for (injectable_light_curve_collection_iter) in injectable_light_curve_collection_iters: injectable_light_curve = next( injectable_light_curve_collection_iter ) @@ -152,18 +133,16 @@ def __iter__(self): @classmethod def new( - cls, - standard_light_curve_collections: list[LightCurveObservationCollection] - | None = None, - injectee_light_curve_collections: list[LightCurveObservationCollection] - | None = None, - injectable_light_curve_collections: list[LightCurveObservationCollection] - | None = None, - post_injection_transform: Callable[[Any], Any] | None = None, + cls, + *, + standard_light_curve_collections: list[LightCurveObservationCollection] | None = None, + injectee_light_curve_collections: list[LightCurveObservationCollection] | None = None, + injectable_light_curve_collections: list[LightCurveObservationCollection] | None = None, + post_injection_transform: Callable[[Any], Any] | None = None, ) -> Self: if ( - standard_light_curve_collections is None - and injectee_light_curve_collections is None + standard_light_curve_collections is None + and injectee_light_curve_collections is None ): error_message = ( "Either the standard or injectee light curve collection lists must be specified. " @@ -190,8 +169,8 @@ def new( def inject_light_curve( - injectee_observation: LightCurveObservation, - injectable_observation: LightCurveObservation, + injectee_observation: LightCurveObservation, + injectable_observation: LightCurveObservation, ) -> LightCurveObservation: ( fluxes_with_injected_signal, @@ -298,7 +277,7 @@ def __iter__(self): def default_light_curve_observation_post_injection_transform( - x: LightCurveObservation, length: int + x: LightCurveObservation, length: int ) -> (Tensor, Tensor): x = remove_nan_flux_data_points_from_light_curve_observation(x) x = randomly_roll_light_curve_observation(x) @@ -326,7 +305,7 @@ def normalize_tensor_by_modified_z_score(tensor: Tensor) -> Tensor: median_absolute_deviation_from_median = torch.median(absolute_deviation_from_median) if median_absolute_deviation_from_median != 0: modified_z_score = ( - 0.6745 * deviation_from_median / median_absolute_deviation_from_median + 0.6745 * deviation_from_median / median_absolute_deviation_from_median ) else: modified_z_score = torch.zeros_like(tensor) @@ -334,10 +313,10 @@ def normalize_tensor_by_modified_z_score(tensor: Tensor) -> Tensor: def make_fluxes_and_label_array_uniform_length( - arrays: tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]], - length: int, - *, - randomize: bool = True, + arrays: tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]], + length: int, + *, + randomize: bool = True, ) -> (np.ndarray, np.ndarray): fluxes, label = arrays uniform_length_times = make_uniform_length( @@ -347,7 +326,7 @@ def make_fluxes_and_label_array_uniform_length( def make_uniform_length( - example: np.ndarray, length: int, *, randomize: bool = True + example: np.ndarray, length: int, *, randomize: bool = True ) -> np.ndarray: """Makes the example a specific length, by clipping those too large and repeating those too small.""" if len(example.shape) not in [1, 2]: # Only tested for 1D and 2D cases. @@ -395,14 +374,13 @@ class BaselineFluxEstimationMethod(Enum): def inject_signal_into_light_curve_with_intermediates( - light_curve_times: npt.NDArray[np.float64], - light_curve_fluxes: npt.NDArray[np.float64], - signal_times: npt.NDArray[np.float64], - signal_magnifications: npt.NDArray[np.float64], - out_of_bounds_injection_handling_method: OutOfBoundsInjectionHandlingMethod = ( - OutOfBoundsInjectionHandlingMethod.ERROR - ), - baseline_flux_estimation_method: BaselineFluxEstimationMethod = BaselineFluxEstimationMethod.MEDIAN, + light_curve_times: npt.NDArray[np.float64], + light_curve_fluxes: npt.NDArray[np.float64], + signal_times: npt.NDArray[np.float64], + signal_magnifications: npt.NDArray[np.float64], + out_of_bounds_injection_handling_method: OutOfBoundsInjectionHandlingMethod = ( + OutOfBoundsInjectionHandlingMethod.ERROR), + baseline_flux_estimation_method: BaselineFluxEstimationMethod = BaselineFluxEstimationMethod.MEDIAN, ) -> (npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]): """ Injects a synthetic magnification signal into real light curve fluxes. @@ -423,12 +401,12 @@ def inject_signal_into_light_curve_with_intermediates( light_curve_time_length = np.max(relative_light_curve_times) time_length_difference = light_curve_time_length - signal_time_length signal_start_offset = ( - np.random.random() * time_length_difference - ) + minimum_light_curve_time + np.random.random() * time_length_difference + ) + minimum_light_curve_time offset_signal_times = relative_signal_times + signal_start_offset if ( - baseline_flux_estimation_method - == BaselineFluxEstimationMethod.MEDIAN_ABSOLUTE_DEVIATION + baseline_flux_estimation_method + == BaselineFluxEstimationMethod.MEDIAN_ABSOLUTE_DEVIATION ): baseline_flux = stats.median_abs_deviation(light_curve_fluxes) baseline_to_median_absolute_deviation_ratio = ( @@ -439,16 +417,16 @@ def inject_signal_into_light_curve_with_intermediates( baseline_flux = np.median(light_curve_fluxes) signal_fluxes = (signal_magnifications * baseline_flux) - baseline_flux if ( - out_of_bounds_injection_handling_method - is OutOfBoundsInjectionHandlingMethod.RANDOM_INJECTION_LOCATION + out_of_bounds_injection_handling_method + is OutOfBoundsInjectionHandlingMethod.RANDOM_INJECTION_LOCATION ): signal_flux_interpolator = interp1d( offset_signal_times, signal_fluxes, bounds_error=False, fill_value=0 ) elif ( - out_of_bounds_injection_handling_method - is OutOfBoundsInjectionHandlingMethod.REPEAT_SIGNAL - and time_length_difference > 0 + out_of_bounds_injection_handling_method + is OutOfBoundsInjectionHandlingMethod.REPEAT_SIGNAL + and time_length_difference > 0 ): before_signal_gap = signal_start_offset - minimum_light_curve_time after_signal_gap = time_length_difference - before_signal_gap @@ -465,13 +443,13 @@ def inject_signal_into_light_curve_with_intermediates( repeated_signal_times = None for repeat_index in range(-before_repeats_needed, after_repeats_needed + 1): repeat_signal_start_offset = ( - signal_time_length + minimum_signal_time_step - ) * repeat_index + signal_time_length + minimum_signal_time_step + ) * repeat_index if repeated_signal_times is None: repeated_signal_times = offset_signal_times + repeat_signal_start_offset else: repeat_index_signal_times = ( - offset_signal_times + repeat_signal_start_offset + offset_signal_times + repeat_signal_start_offset ) repeated_signal_times = np.concatenate( [repeated_signal_times, repeat_index_signal_times] diff --git a/src/qusi/internal/train_hyperparameter_configuration.py b/src/qusi/internal/train_hyperparameter_configuration.py index 2f5dd864..bd975a62 100644 --- a/src/qusi/internal/train_hyperparameter_configuration.py +++ b/src/qusi/internal/train_hyperparameter_configuration.py @@ -22,15 +22,16 @@ class TrainHyperparameterConfiguration: @classmethod def new( - cls, - learning_rate: float = 1e-4, - optimizer_epsilon: float = 1e-7, - weight_decay: float = 0.0001, - batch_size: int = 100, - train_steps_per_cycle: int = 100, - validation_steps_per_cycle: int = 10, - cycles: int = 5000, - norm_based_gradient_clip: float = 1.0, + cls, + *, + learning_rate: float = 1e-4, + optimizer_epsilon: float = 1e-7, + weight_decay: float = 0.0001, + batch_size: int = 100, + train_steps_per_cycle: int = 100, + validation_steps_per_cycle: int = 10, + cycles: int = 5000, + norm_based_gradient_clip: float = 1.0, ): return cls( learning_rate=learning_rate, diff --git a/src/qusi/internal/train_logging_configuration.py b/src/qusi/internal/train_logging_configuration.py index 2272a44a..28c9cc49 100644 --- a/src/qusi/internal/train_logging_configuration.py +++ b/src/qusi/internal/train_logging_configuration.py @@ -20,10 +20,11 @@ class TrainLoggingConfiguration: @classmethod def new( - cls, - wandb_project: str | None = None, - wandb_entity: str | None = None, - additional_log_dictionary: dict[str, Any] | None = None, + cls, + *, + wandb_project: str | None = None, + wandb_entity: str | None = None, + additional_log_dictionary: dict[str, Any] | None = None, ): if additional_log_dictionary is None: additional_log_dictionary = {} diff --git a/src/qusi/internal/train_session.py b/src/qusi/internal/train_session.py index 13eb3989..bdd711c6 100644 --- a/src/qusi/internal/train_session.py +++ b/src/qusi/internal/train_session.py @@ -22,13 +22,14 @@ def train_session( - train_datasets: list[LightCurveDataset], - validation_datasets: list[LightCurveDataset], - model: Module, - loss_function: Module | None = None, - metric_functions: list[Module] | None = None, - hyperparameter_configuration: TrainHyperparameterConfiguration | None = None, - logging_configuration: TrainLoggingConfiguration | None = None, + train_datasets: list[LightCurveDataset], + validation_datasets: list[LightCurveDataset], + model: Module, + loss_function: Module | None = None, + metric_functions: list[Module] | None = None, + *, + hyperparameter_configuration: TrainHyperparameterConfiguration | None = None, + logging_configuration: TrainLoggingConfiguration | None = None, ): if hyperparameter_configuration is None: hyperparameter_configuration = TrainHyperparameterConfiguration.new() @@ -112,13 +113,13 @@ def train_session( def train_phase( - dataloader, - model, - loss_function, - metric_functions: list[Module], - optimizer, - steps, - device, + dataloader, + model, + loss_function, + metric_functions: list[Module], + optimizer, + steps, + device, ): model.train() total_loss = 0 @@ -173,7 +174,7 @@ def get_metric_name(metric_function): def validation_phase( - dataloader, model, loss_function, metric_functions: list[Module], steps, device + dataloader, model, loss_function, metric_functions: list[Module], steps, device ): model.eval() validation_loss = 0 diff --git a/src/qusi/internal/train_system_configuration.py b/src/qusi/internal/train_system_configuration.py index 7e5cf877..608f68a5 100644 --- a/src/qusi/internal/train_system_configuration.py +++ b/src/qusi/internal/train_system_configuration.py @@ -14,5 +14,9 @@ class TrainSystemConfiguration: preprocessing_processes_per_train_process: int @classmethod - def new(cls, preprocessing_processes_per_train_process: int = 10): + def new( + cls, + *, + preprocessing_processes_per_train_process: int = 10 + ): return cls(preprocessing_processes_per_train_process=preprocessing_processes_per_train_process) diff --git a/src/qusi/session.py b/src/qusi/session.py index c8569e33..2954a142 100644 --- a/src/qusi/session.py +++ b/src/qusi/session.py @@ -2,13 +2,20 @@ Session related public interface. """ from qusi.internal.device import get_device +from qusi.internal.finite_test_session import finite_datasets_test_session from qusi.internal.infer_session import infer_session from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration +from qusi.internal.train_logging_configuration import TrainLoggingConfiguration +from qusi.internal.train_system_configuration import TrainSystemConfiguration from qusi.internal.train_session import train_session __all__ = [ + 'finite_datasets_test_session', 'get_device', 'infer_session', 'TrainHyperparameterConfiguration', + 'TrainLoggingConfiguration', + 'TrainSystemConfiguration', 'train_session', ] + From 0c570c1fe3ba302ecb37c3220a75aa7f9bc99055 Mon Sep 17 00:00:00 2001 From: golmschenk Date: Wed, 1 May 2024 23:54:25 -0400 Subject: [PATCH 03/21] Add qodana --- qodana.yaml | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 qodana.yaml diff --git a/qodana.yaml b/qodana.yaml new file mode 100644 index 00000000..c28d57d6 --- /dev/null +++ b/qodana.yaml @@ -0,0 +1,12 @@ +version: "1.0" + +profile: + name: qodana.starter + +exclude: + - name: All + paths: + - src/ramjet + - examples + +linter: jetbrains/qodana-python:latest From d8f48fb05c95a5ba8d0583b42e52d12e47c08ad1 Mon Sep 17 00:00:00 2001 From: golmschenk Date: Thu, 2 May 2024 13:23:55 -0400 Subject: [PATCH 04/21] Add the default transforms to the public API and update the docs to use the observation transforms --- docs/source/tutorials/crafting_standard_datasets.md | 10 +++++----- src/qusi/data.py | 5 ++++- src/qusi/internal/light_curve_dataset.py | 4 ++-- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/docs/source/tutorials/crafting_standard_datasets.md b/docs/source/tutorials/crafting_standard_datasets.md index 2c1097e7..47358c23 100644 --- a/docs/source/tutorials/crafting_standard_datasets.md +++ b/docs/source/tutorials/crafting_standard_datasets.md @@ -12,26 +12,26 @@ However, the uniform length is set to a specific default value. A good choice fo ```python from functools import partial -from qusi.light_curve_dataset import default_light_curve_post_injection_transform +from qusi.data import default_light_curve_post_injection_transform ``` Then, were we specify the construction of our dataset, we'll add an additional input parameter. So taking what we had in the previous tutorial, we can now change the dataset creation statement to: ```python -train_light_curve_dataset = LightCurveDataset.new( +train_light_curve_dataset = LightCurveObservationDataset.new( standard_light_curve_collections=[positive_train_light_curve_collection, negative_train_light_curve_collection] post_injection_transform=partial(default_light_curve_post_injection_transform, length=4000) ) ``` -Let's clarify what's happening here. The `LightCurveDataset.new()` constructor takes as input a parameter called `post_injection_transform`. This function will be applied to our light curves before they get handed to the NN. `default_light_curve_post_injection_transform` is the default set of preprocessing transforms `qusi` uses. We'll look at these transforms in more detail in the next section. `partial` is a Python builtin function, that takes another function as input, along with a parameter of that function, and creates a new function with that parameter prefilled. So `partial(default_light_curve_post_injection_transform, length=4000)` is taking our default transforms, setting the uniforming lengthening step to 4000, then giving us back the updated function, which we can then give to the dataset. The advantage to this approach is that `post_injection_transform` is completely flexible, as we'll explore more in the next section. +Let's clarify what's happening here. The `LightCurveObservationDataset.new()` constructor takes as input a parameter called `post_injection_transform`. This function will be applied to our light curves before they get handed to the NN. `default_light_curve_observation_post_injection_transform` is the default set of preprocessing transforms `qusi` uses. We'll look at these transforms in more detail in the next section. `partial` is a Python builtin function, that takes another function as input, along with a parameter of that function, and creates a new function with that parameter prefilled. So `partial(default_light_curve_observation_post_injection_transform, length=4000)` is taking our default transforms, setting the uniforming lengthening step to 4000, then giving us back the updated function, which we can then give to the dataset. The advantage to this approach is that `post_injection_transform` is completely flexible, as we'll explore more in the next section. Before we run the updated code, we also need to use a NN model which expects our new input size. Luckily, `qusi` has NN architectures that automatically resize their components for a given input size. So the only other change from the existing code is to change `Hadryss.new()` to `Hadryss.new(input_length=4000)`. ## Modifying the preprocessing -In the previous section, we only changed the length of that the uniform lengthening preprocessing transform is using. However, we still used all the remaining default preprocessing steps that are contained in `default_light_curve_post_injection_transform`. Let's take a look at what the default one does. It looks like: +In the previous section, we only changed the length of that the uniform lengthening preprocessing transform is using. However, we still used all the remaining default preprocessing steps that are contained in `default_light_curve_observation_post_injection_transform`. Let's take a look at what the default one does. It looks like: ```python def default_light_curve_observation_post_injection_transform(x: LightCurveObservation, length: int) -> (Tensor, Tensor): @@ -46,4 +46,4 @@ def default_light_curve_observation_post_injection_transform(x: LightCurveObserv It's a function that takes in a `LightCurveObservation` and spits out two `Tensor`s, one for the fluxes and one for the label to predict. Most of the data transform functions within have names that are largely descriptive, but we'll walk through them anyway. `remove_nan_flux_data_points_from_light_curve_observation` removes time steps from a `LightCurveObservation` where the flux is NaN. `randomly_roll_light_curve_observation` randomly rolls the light curve (a random cut is made and the two segments' order is swapped). `from_light_curve_observation_to_fluxes_array_and_label_array` extracts two NumPy arrays from a `LightCurveObservation`, one for the fluxes and one from the label (which in this case will be an array with a single value). `make_fluxes_and_label_array_uniform_length` performs the uniform lengthening we discussed in the previous section. `pair_array_to_tensor` converts the pair of NumPy arrays to a pair of PyTorch tensors (PyTorch's equivalent of an array). `normalize_tensor_by_modified_z_score` normalizes a tensor via based on the median absolute deviation. Notice, this is only applied to the flux tensor, not the label tensor. -It's worth noting, `default_light_curve_post_injection_transform` is just a function that can be replaced as desired. To remove one of the preprocessing steps or add in an addition one, we can simply make a modified version of this function. Additionally, `qusi` does not require the transform function to output only the fluxes and a binary label. The `Hadryss` NN model expects these two types of values for training, but other models may take advantage of the times of the light curve, or they may predict multi-class or regression labels. +It's worth noting, `default_light_curve_observation_post_injection_transform` is just a function that can be replaced as desired. To remove one of the preprocessing steps or add in an addition one, we can simply make a modified version of this function. Additionally, `qusi` does not require the transform function to output only the fluxes and a binary label. The `Hadryss` NN model expects these two types of values for training, but other models may take advantage of the times of the light curve, or they may predict multi-class or regression labels. diff --git a/src/qusi/data.py b/src/qusi/data.py index 03d77d86..d56669d9 100644 --- a/src/qusi/data.py +++ b/src/qusi/data.py @@ -4,7 +4,8 @@ from qusi.internal.finite_standard_light_curve_dataset import FiniteStandardLightCurveDataset from qusi.internal.finite_standard_light_curve_observation_dataset import FiniteStandardLightCurveObservationDataset from qusi.internal.light_curve import LightCurve -from qusi.internal.light_curve_dataset import LightCurveDataset +from qusi.internal.light_curve_dataset import LightCurveDataset, default_light_curve_post_injection_transform, \ + default_light_curve_observation_post_injection_transform from qusi.internal.light_curve_collection import LightCurveObservationCollection, LightCurveCollection __all__ = [ @@ -14,4 +15,6 @@ 'LightCurveCollection', 'LightCurveDataset', 'LightCurveObservationCollection', + 'default_light_curve_post_injection_transform', + 'default_light_curve_observation_post_injection_transform', ] diff --git a/src/qusi/internal/light_curve_dataset.py b/src/qusi/internal/light_curve_dataset.py index f2d70a0f..b12818e9 100644 --- a/src/qusi/internal/light_curve_dataset.py +++ b/src/qusi/internal/light_curve_dataset.py @@ -277,7 +277,7 @@ def __iter__(self): def default_light_curve_observation_post_injection_transform( - x: LightCurveObservation, length: int + x: LightCurveObservation, *, length: int ) -> (Tensor, Tensor): x = remove_nan_flux_data_points_from_light_curve_observation(x) x = randomly_roll_light_curve_observation(x) @@ -288,7 +288,7 @@ def default_light_curve_observation_post_injection_transform( return x -def default_light_curve_post_injection_transform(x: LightCurve, length: int) -> Tensor: +def default_light_curve_post_injection_transform(x: LightCurve, *, length: int) -> Tensor: x = remove_nan_flux_data_points_from_light_curve(x) x = randomly_roll_light_curve(x) x = x.fluxes From ff5bf0fa6555b51f90940b34208a7334ee67fcca Mon Sep 17 00:00:00 2001 From: golmschenk Date: Thu, 2 May 2024 13:33:56 -0400 Subject: [PATCH 05/21] Add the available transforms to the public API --- .../tutorials/crafting_standard_datasets.md | 2 +- src/qusi/data.py | 7 ++----- src/qusi/transform.py | 19 +++++++++++++++++++ 3 files changed, 22 insertions(+), 6 deletions(-) create mode 100644 src/qusi/transform.py diff --git a/docs/source/tutorials/crafting_standard_datasets.md b/docs/source/tutorials/crafting_standard_datasets.md index 47358c23..38cfcd8e 100644 --- a/docs/source/tutorials/crafting_standard_datasets.md +++ b/docs/source/tutorials/crafting_standard_datasets.md @@ -12,7 +12,7 @@ However, the uniform length is set to a specific default value. A good choice fo ```python from functools import partial -from qusi.data import default_light_curve_post_injection_transform +from qusi.transform import default_light_curve_post_injection_transform ``` Then, were we specify the construction of our dataset, we'll add an additional input parameter. So taking what we had in the previous tutorial, we can now change the dataset creation statement to: diff --git a/src/qusi/data.py b/src/qusi/data.py index d56669d9..bf0dd4ec 100644 --- a/src/qusi/data.py +++ b/src/qusi/data.py @@ -1,12 +1,11 @@ """ -Data related public interface. +Data structure related public interface. """ from qusi.internal.finite_standard_light_curve_dataset import FiniteStandardLightCurveDataset from qusi.internal.finite_standard_light_curve_observation_dataset import FiniteStandardLightCurveObservationDataset from qusi.internal.light_curve import LightCurve -from qusi.internal.light_curve_dataset import LightCurveDataset, default_light_curve_post_injection_transform, \ - default_light_curve_observation_post_injection_transform from qusi.internal.light_curve_collection import LightCurveObservationCollection, LightCurveCollection +from qusi.internal.light_curve_dataset import LightCurveDataset __all__ = [ 'FiniteStandardLightCurveDataset', @@ -15,6 +14,4 @@ 'LightCurveCollection', 'LightCurveDataset', 'LightCurveObservationCollection', - 'default_light_curve_post_injection_transform', - 'default_light_curve_observation_post_injection_transform', ] diff --git a/src/qusi/transform.py b/src/qusi/transform.py new file mode 100644 index 00000000..ff89742b --- /dev/null +++ b/src/qusi/transform.py @@ -0,0 +1,19 @@ +""" +Data transform related public interface. +""" +from qusi.internal.light_curve_dataset import default_light_curve_post_injection_transform, \ + default_light_curve_observation_post_injection_transform, make_fluxes_and_label_array_uniform_length +from qusi.internal.light_curve_observation import remove_nan_flux_data_points_from_light_curve_observation, \ + randomly_roll_light_curve_observation +from qusi.internal.light_curve_transforms import from_light_curve_observation_to_fluxes_array_and_label_array, \ + pair_array_to_tensor + +__all__ = [ + 'default_light_curve_post_injection_transform', + 'default_light_curve_observation_post_injection_transform', + 'from_light_curve_observation_to_fluxes_array_and_label_array', + 'make_fluxes_and_label_array_uniform_length', + 'pair_array_to_tensor', + 'randomly_roll_light_curve_observation', + 'remove_nan_flux_data_points_from_light_curve_observation', +] From b3a97d995e71dace6a514bb5abde8e97e6898f7c Mon Sep 17 00:00:00 2001 From: golmschenk Date: Thu, 2 May 2024 13:36:25 -0400 Subject: [PATCH 06/21] Remove ramjet import in example scripts to the experimental package --- examples/transit_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/transit_dataset.py b/examples/transit_dataset.py index ad30126d..02db37a0 100644 --- a/examples/transit_dataset.py +++ b/examples/transit_dataset.py @@ -1,7 +1,7 @@ from pathlib import Path from qusi.data import FiniteStandardLightCurveObservationDataset, LightCurveDataset, LightCurveObservationCollection -from ramjet.photometric_database.tess_two_minute_cadence_light_curve import TessMissionLightCurve +from qusi.experimental.application.tess import TessMissionLightCurve def get_positive_train_paths(): From 1321926a5a8c72839602fdc0b9d8799dadd324be Mon Sep 17 00:00:00 2001 From: golmschenk Date: Fri, 3 May 2024 17:54:13 -0400 Subject: [PATCH 07/21] Add inputs of post_injection_transforms --- ...sit_identification_dataset_construction.md | 20 ++++------------ .../tutorials/crafting_standard_datasets.md | 6 ++--- examples/transit_dataset.py | 4 ++-- .../finite_standard_light_curve_dataset.py | 14 +++++++++-- ...tandard_light_curve_observation_dataset.py | 24 +++++++++++++++---- src/qusi/internal/light_curve_dataset.py | 8 +++---- 6 files changed, 46 insertions(+), 30 deletions(-) diff --git a/docs/source/tutorials/basic_transit_identification_dataset_construction.md b/docs/source/tutorials/basic_transit_identification_dataset_construction.md index 1a61e93d..4b9984cf 100644 --- a/docs/source/tutorials/basic_transit_identification_dataset_construction.md +++ b/docs/source/tutorials/basic_transit_identification_dataset_construction.md @@ -56,10 +56,7 @@ Note, `qusi` expects the label functions to take in a `Path` object as input, ev Now we're going to join the various functions we've just defined into `LightCurveObservationCollection`s. For the case of positive train light curves, this looks like: ```python -positive_train_light_curve_collection = LightCurveObservationCollection.new( - get_paths_function=get_positive_train_paths, - load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path, - load_label_from_path_function=positive_label_function) +positive_train_light_curve_collection = LightCurveObservationCollection.new() ``` This defines a collection of labeled light curves where `qusi` knows how to obtain the paths, how to load the times and fluxes of the light curves, and how to load the labels. This `LightCurveObservationCollection.new(...` function takes in the three pieces we just built earlier. Note that you pass in the functions themselves, not the output of the functions. So for the `get_paths_function` parameter, we pass `get_positive_train_paths`, not `get_positive_train_paths()` (notice the difference in parenthesis). `qusi` will call these functions internally. However, the above bit of code is not by itself in `examples/transit_dataset.py` as the rest of the code in this tutorial was. This is because `qusi` doesn't use this collection by itself. It uses it as part of a dataset. We will explain why there's this extra layer in a moment. @@ -70,17 +67,10 @@ Finally, we build the dataset `qusi` uses to train the network. First, we'll tak ```python def get_transit_train_dataset(): - positive_train_light_curve_collection = LightCurveObservationCollection.new( - get_paths_function=get_positive_train_paths, - load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path, - load_label_from_path_function=positive_label_function) - negative_train_light_curve_collection = LightCurveObservationCollection.new( - get_paths_function=get_negative_train_paths, - load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path, - load_label_from_path_function=negative_label_function) - train_light_curve_dataset = LightCurveDataset.new( - standard_light_curve_collections=[positive_train_light_curve_collection, - negative_train_light_curve_collection]) + positive_train_light_curve_collection = LightCurveObservationCollection.new() + negative_train_light_curve_collection = LightCurveObservationCollection.new() + train_light_curve_dataset = LightCurveDataset.new(light_curve_collections=[positive_train_light_curve_collection, + negative_train_light_curve_collection]) return train_light_curve_dataset ``` diff --git a/docs/source/tutorials/crafting_standard_datasets.md b/docs/source/tutorials/crafting_standard_datasets.md index 38cfcd8e..13578afd 100644 --- a/docs/source/tutorials/crafting_standard_datasets.md +++ b/docs/source/tutorials/crafting_standard_datasets.md @@ -19,9 +19,9 @@ Then, were we specify the construction of our dataset, we'll add an additional i ```python train_light_curve_dataset = LightCurveObservationDataset.new( - standard_light_curve_collections=[positive_train_light_curve_collection, - negative_train_light_curve_collection] - post_injection_transform=partial(default_light_curve_post_injection_transform, length=4000) + light_curve_collections=[positive_train_light_curve_collection, + negative_train_light_curve_collection]) +post_injection_transform = partial(default_light_curve_post_injection_transform, length=4000) ) ``` diff --git a/examples/transit_dataset.py b/examples/transit_dataset.py index 02db37a0..918f15de 100644 --- a/examples/transit_dataset.py +++ b/examples/transit_dataset.py @@ -81,6 +81,6 @@ def get_transit_finite_test_dataset(): load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path, load_label_from_path_function=negative_label_function) test_light_curve_dataset = FiniteStandardLightCurveObservationDataset.new( - standard_light_curve_collections=[positive_test_light_curve_collection, - negative_test_light_curve_collection]) + light_curve_collections=[positive_test_light_curve_collection, + negative_test_light_curve_collection]) return test_light_curve_dataset diff --git a/src/qusi/internal/finite_standard_light_curve_dataset.py b/src/qusi/internal/finite_standard_light_curve_dataset.py index c3f7f6cb..6c84e9ec 100644 --- a/src/qusi/internal/finite_standard_light_curve_dataset.py +++ b/src/qusi/internal/finite_standard_light_curve_dataset.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass from functools import partial from typing import Any, Callable @@ -18,13 +20,21 @@ class FiniteStandardLightCurveDataset(Dataset): collection_start_indexes: list[int] @classmethod - def new(cls, light_curve_collections: list[LightCurveCollection]) -> Self: + def new( + cls, + light_curve_collections: list[LightCurveCollection], + *, + post_injection_transform: Callable[[Any], Any] | None = None, + ) -> Self: """ Creates a new `FiniteStandardLightCurveDataset`. :param light_curve_collections: The light curve collections to include in the dataset. + :param post_injection_transform: Transforms to the data to occur after injection. :return: The dataset. """ + if post_injection_transform is None: + post_injection_transform = partial(default_light_curve_post_injection_transform, length=2500) length = 0 collection_start_indexes: list[int] = [] for light_curve_collection in light_curve_collections: @@ -33,7 +43,7 @@ def new(cls, light_curve_collections: list[LightCurveCollection]) -> Self: length += standard_light_curve_collection_length instance = cls( standard_light_curve_collections=light_curve_collections, - post_injection_transform=partial(default_light_curve_post_injection_transform, length=2500), + post_injection_transform=post_injection_transform, length=length, collection_start_indexes=collection_start_indexes, ) diff --git a/src/qusi/internal/finite_standard_light_curve_observation_dataset.py b/src/qusi/internal/finite_standard_light_curve_observation_dataset.py index c64f4b66..21a44a14 100644 --- a/src/qusi/internal/finite_standard_light_curve_observation_dataset.py +++ b/src/qusi/internal/finite_standard_light_curve_observation_dataset.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass from functools import partial from typing import Any, Callable @@ -18,16 +20,30 @@ class FiniteStandardLightCurveObservationDataset(Dataset): collection_start_indexes: list[int] @classmethod - def new(cls, standard_light_curve_collections: list[LightCurveObservationCollection]) -> Self: + def new( + cls, + light_curve_collections: list[LightCurveObservationCollection], + *, + post_injection_transform: Callable[[Any], Any] | None = None, + ) -> Self: + """ + Creates a new `FiniteStandardLightCurveObservationDataset`. + + :param light_curve_collections: The light curve observation collections to include in the dataset. + :param post_injection_transform: Transforms to the data to occur after injection. + :return: The dataset. + """ + if post_injection_transform is None: + post_injection_transform = partial(default_light_curve_observation_post_injection_transform, length=2500) length = 0 collection_start_indexes: list[int] = [] - for standard_light_curve_collection in standard_light_curve_collections: + for standard_light_curve_collection in light_curve_collections: standard_light_curve_collection_length = len(list(standard_light_curve_collection.path_getter.get_paths())) collection_start_indexes.append(length) length += standard_light_curve_collection_length instance = cls( - standard_light_curve_collections=standard_light_curve_collections, - post_injection_transform=partial(default_light_curve_observation_post_injection_transform, length=2500), + standard_light_curve_collections=light_curve_collections, + post_injection_transform=post_injection_transform, length=length, collection_start_indexes=collection_start_indexes, ) diff --git a/src/qusi/internal/light_curve_dataset.py b/src/qusi/internal/light_curve_dataset.py index b12818e9..0c8112e0 100644 --- a/src/qusi/internal/light_curve_dataset.py +++ b/src/qusi/internal/light_curve_dataset.py @@ -49,6 +49,7 @@ class LightCurveDataset(IterableDataset): def __init__( self, standard_light_curve_collections: list[LightCurveObservationCollection], + *, injectee_light_curve_collections: list[LightCurveObservationCollection], injectable_light_curve_collections: list[LightCurveObservationCollection], post_injection_transform: Callable[[Any], Any], @@ -134,8 +135,8 @@ def __iter__(self): @classmethod def new( cls, - *, standard_light_curve_collections: list[LightCurveObservationCollection] | None = None, + *, injectee_light_curve_collections: list[LightCurveObservationCollection] | None = None, injectable_light_curve_collections: list[LightCurveObservationCollection] | None = None, post_injection_transform: Callable[[Any], Any] | None = None, @@ -276,9 +277,8 @@ def __iter__(self): break -def default_light_curve_observation_post_injection_transform( - x: LightCurveObservation, *, length: int -) -> (Tensor, Tensor): +def default_light_curve_observation_post_injection_transform(x: LightCurveObservation, *, length: int + ) -> (Tensor, Tensor): x = remove_nan_flux_data_points_from_light_curve_observation(x) x = randomly_roll_light_curve_observation(x) x = from_light_curve_observation_to_fluxes_array_and_label_array(x) From 425159551c27a6de57b161bda12b5ddd845d5203 Mon Sep 17 00:00:00 2001 From: golmschenk Date: Mon, 6 May 2024 17:20:09 -0400 Subject: [PATCH 08/21] Add a docstring to all public API methods --- .../tutorials/crafting_standard_datasets.md | 11 +-- .../finite_standard_light_curve_dataset.py | 6 +- ...tandard_light_curve_observation_dataset.py | 6 +- src/qusi/internal/finite_test_session.py | 25 +++--- src/qusi/internal/infer_session.py | 15 +++- src/qusi/internal/light_curve.py | 14 ++++ src/qusi/internal/light_curve_collection.py | 5 ++ src/qusi/internal/light_curve_dataset.py | 79 +++++++++++++------ src/qusi/internal/light_curve_observation.py | 14 ++++ src/qusi/internal/light_curve_transforms.py | 12 +++ .../train_hyperparameter_configuration.py | 46 ++++++++--- .../internal/train_logging_configuration.py | 8 ++ src/qusi/internal/train_session.py | 13 ++- .../internal/train_system_configuration.py | 8 ++ src/qusi/transform.py | 7 +- 15 files changed, 214 insertions(+), 55 deletions(-) diff --git a/docs/source/tutorials/crafting_standard_datasets.md b/docs/source/tutorials/crafting_standard_datasets.md index 13578afd..3baaeaa2 100644 --- a/docs/source/tutorials/crafting_standard_datasets.md +++ b/docs/source/tutorials/crafting_standard_datasets.md @@ -31,19 +31,20 @@ Before we run the updated code, we also need to use a NN model which expects our ## Modifying the preprocessing -In the previous section, we only changed the length of that the uniform lengthening preprocessing transform is using. However, we still used all the remaining default preprocessing steps that are contained in `default_light_curve_observation_post_injection_transform`. Let's take a look at what the default one does. It looks like: +In the previous section, we only changed the length of that the uniform lengthening preprocessing transform is using. However, we still used all the remaining default preprocessing steps that are contained in `default_light_curve_observation_post_injection_transform`. Let's take a look at what the default one does. It looks something like: ```python -def default_light_curve_observation_post_injection_transform(x: LightCurveObservation, length: int) -> (Tensor, Tensor): +def default_light_curve_observation_post_injection_transform(x: LightCurveObservation, length: int, randomize: bool = True) -> (Tensor, Tensor): x = remove_nan_flux_data_points_from_light_curve_observation(x) - x = randomly_roll_light_curve_observation(x) + if randomize: + x = randomly_roll_light_curve_observation(x) x = from_light_curve_observation_to_fluxes_array_and_label_array(x) - x = make_fluxes_and_label_array_uniform_length(x, length=length) + x = (make_uniform_length(x[0], length=length, randomize=randomize), x[1]) # Make the fluxes a uniform length. x = pair_array_to_tensor(x) x = (normalize_tensor_by_modified_z_score(x[0]), x[1]) return x ``` -It's a function that takes in a `LightCurveObservation` and spits out two `Tensor`s, one for the fluxes and one for the label to predict. Most of the data transform functions within have names that are largely descriptive, but we'll walk through them anyway. `remove_nan_flux_data_points_from_light_curve_observation` removes time steps from a `LightCurveObservation` where the flux is NaN. `randomly_roll_light_curve_observation` randomly rolls the light curve (a random cut is made and the two segments' order is swapped). `from_light_curve_observation_to_fluxes_array_and_label_array` extracts two NumPy arrays from a `LightCurveObservation`, one for the fluxes and one from the label (which in this case will be an array with a single value). `make_fluxes_and_label_array_uniform_length` performs the uniform lengthening we discussed in the previous section. `pair_array_to_tensor` converts the pair of NumPy arrays to a pair of PyTorch tensors (PyTorch's equivalent of an array). `normalize_tensor_by_modified_z_score` normalizes a tensor via based on the median absolute deviation. Notice, this is only applied to the flux tensor, not the label tensor. +It's a function that takes in a `LightCurveObservation` and spits out two `Tensor`s, one for the fluxes and one for the label to predict. Most of the data transform functions within have names that are largely descriptive, but we'll walk through them anyway. `remove_nan_flux_data_points_from_light_curve_observation` removes time steps from a `LightCurveObservation` where the flux is NaN. `randomly_roll_light_curve_observation` randomly rolls the light curve (a random cut is made and the two segments' order is swapped). `from_light_curve_observation_to_fluxes_array_and_label_array` extracts two NumPy arrays from a `LightCurveObservation`, one for the fluxes and one from the label (which in this case will be an array with a single value). `make_uniform_length` performs the uniform lengthening on the fluxes as we discussed in the previous section. `pair_array_to_tensor` converts the pair of NumPy arrays to a pair of PyTorch tensors (PyTorch's equivalent of an array). `normalize_tensor_by_modified_z_score` normalizes a tensor via based on the median absolute deviation. Notice, this is only applied to the flux tensor, not the label tensor. The `randomize` parameter enables or disables randomization of the functions which may include randomization. During training, randomization should be on to make sure we get variation in training observations. During evaluation and inference, it should be off to get repeatable results. It's worth noting, `default_light_curve_observation_post_injection_transform` is just a function that can be replaced as desired. To remove one of the preprocessing steps or add in an addition one, we can simply make a modified version of this function. Additionally, `qusi` does not require the transform function to output only the fluxes and a binary label. The `Hadryss` NN model expects these two types of values for training, but other models may take advantage of the times of the light curve, or they may predict multi-class or regression labels. diff --git a/src/qusi/internal/finite_standard_light_curve_dataset.py b/src/qusi/internal/finite_standard_light_curve_dataset.py index 6c84e9ec..a1815174 100644 --- a/src/qusi/internal/finite_standard_light_curve_dataset.py +++ b/src/qusi/internal/finite_standard_light_curve_dataset.py @@ -14,6 +14,9 @@ @dataclass class FiniteStandardLightCurveDataset(Dataset): + """ + A finite light curve dataset without injection. + """ standard_light_curve_collections: list[LightCurveCollection] post_injection_transform: Callable[[Any], Any] length: int @@ -34,7 +37,8 @@ def new( :return: The dataset. """ if post_injection_transform is None: - post_injection_transform = partial(default_light_curve_post_injection_transform, length=2500) + post_injection_transform = partial(default_light_curve_post_injection_transform, length=2500, + randomize=False) length = 0 collection_start_indexes: list[int] = [] for light_curve_collection in light_curve_collections: diff --git a/src/qusi/internal/finite_standard_light_curve_observation_dataset.py b/src/qusi/internal/finite_standard_light_curve_observation_dataset.py index 21a44a14..437fd9ea 100644 --- a/src/qusi/internal/finite_standard_light_curve_observation_dataset.py +++ b/src/qusi/internal/finite_standard_light_curve_observation_dataset.py @@ -14,6 +14,9 @@ @dataclass class FiniteStandardLightCurveObservationDataset(Dataset): + """ + A finite light curve observation dataset without injection. + """ standard_light_curve_collections: list[LightCurveObservationCollection] post_injection_transform: Callable[[Any], Any] length: int @@ -34,7 +37,8 @@ def new( :return: The dataset. """ if post_injection_transform is None: - post_injection_transform = partial(default_light_curve_observation_post_injection_transform, length=2500) + post_injection_transform = partial(default_light_curve_observation_post_injection_transform, length=2500, + randomize=False) length = 0 collection_start_indexes: list[int] = [] for standard_light_curve_collection in light_curve_collections: diff --git a/src/qusi/internal/finite_test_session.py b/src/qusi/internal/finite_test_session.py index fde27903..8430ec60 100644 --- a/src/qusi/internal/finite_test_session.py +++ b/src/qusi/internal/finite_test_session.py @@ -9,11 +9,22 @@ def finite_datasets_test_session( test_datasets: list[FiniteStandardLightCurveObservationDataset], model: Module, - *, metric_functions: list[Module], - batch_size: int, - device: Device, + *, + batch_size: int = 100, + device: Device = torch.device('cpu'), ): + """ + Runs a test session on finite datasets. + + :param test_datasets: A list of datasets to run the test session on. + :param model: A model to perform the inference. + :param metric_functions: A metrics to test. + :param batch_size: A batch size to use during testing. + :param device: A device to run the model on. + :return: A list of arrays, with one array for each test dataset, with each array containing an element for each + metric that was tested. + """ test_dataloaders: list[DataLoader] = [] for test_dataset in test_datasets: test_dataloader = DataLoader(test_dataset, batch_size=batch_size, pin_memory=True) @@ -26,14 +37,6 @@ def finite_datasets_test_session( return results -def get_device(): - if torch.cuda.is_available(): - device = torch.device("cuda") - else: - device = torch.device("cpu") - return device - - def finite_dataset_test_phase(dataloader, model: Module, metric_functions: list[Module], device: Device): batch_count = 0 metric_totals = torch.zeros(size=[len(metric_functions)]) diff --git a/src/qusi/internal/infer_session.py b/src/qusi/internal/infer_session.py index e9499779..32f9d936 100644 --- a/src/qusi/internal/infer_session.py +++ b/src/qusi/internal/infer_session.py @@ -8,8 +8,21 @@ def infer_session( - infer_datasets: list[FiniteStandardLightCurveDataset], model: Module, *, batch_size: int, device: Device + infer_datasets: list[FiniteStandardLightCurveDataset], + model: Module, + *, + batch_size: int, + device: Device, ) -> list[np.ndarray]: + """ + Runs an infer session on finite datasets. + + :param infer_datasets: The list of datasets to run the infer session on. + :param model: The model to perform the inference. + :param batch_size: The batch size to use during inference. + :param device: The device to run the model on. + :return: A list of arrays with each element being the array predicted for each light curve in the dataset. + """ infer_dataloaders: list[DataLoader] = [] for infer_dataset in infer_datasets: infer_dataloader = DataLoader(infer_dataset, batch_size=batch_size, pin_memory=True) diff --git a/src/qusi/internal/light_curve.py b/src/qusi/internal/light_curve.py index baa1b309..d866dda1 100644 --- a/src/qusi/internal/light_curve.py +++ b/src/qusi/internal/light_curve.py @@ -34,6 +34,13 @@ def new(cls, times: npt.NDArray[np.float32], fluxes: npt.NDArray[np.float32]) -> def remove_nan_flux_data_points_from_light_curve(light_curve: LightCurve) -> LightCurve: + """ + Removes the NaN values from a light curve in a light curve. If there is a NaN in either the times or the + fluxes, both corresponding values are removed. + + :param light_curve: The light curve. + :return: The light curve with NaN values removed. + """ light_curve = deepcopy(light_curve) nan_flux_indexes = np.isnan(light_curve.fluxes) light_curve.fluxes = light_curve.fluxes[~nan_flux_indexes] @@ -42,6 +49,13 @@ def remove_nan_flux_data_points_from_light_curve(light_curve: LightCurve) -> Lig def randomly_roll_light_curve(light_curve: LightCurve) -> LightCurve: + """ + Randomly rolls a light curve. That is, a random position in the light curve is chosen, the light curve + is split at that point, and the order of the two halves are swapped. + + :param light_curve: The light curve. + :return: The rolled light curve. + """ light_curve = deepcopy(light_curve) random_index = np.random.randint(light_curve.times.shape[0]) light_curve.times = np.roll(light_curve.times, random_index) diff --git a/src/qusi/internal/light_curve_collection.py b/src/qusi/internal/light_curve_collection.py index 329c7071..531810ab 100644 --- a/src/qusi/internal/light_curve_collection.py +++ b/src/qusi/internal/light_curve_collection.py @@ -117,6 +117,8 @@ class LightCurveCollection( LightCurveCollectionBase, LightCurveObservationIndexableBase ): """ + A collection of light curves, including where to find paths to the data and how to load the data. + :ivar path_getter: The PathIterableBase object for the collection. :ivar load_times_and_fluxes_from_path_function: The function to load the times and fluxes from the light curve. """ @@ -178,6 +180,9 @@ class LightCurveObservationCollection( LightCurveObservationCollectionBase, LightCurveObservationIndexableBase ): """ + A collection of light curve observations. Includes where to find the light curve data paths, and how to load + the times, fluxes, and label data. + :ivar path_getter: The PathGetterBase object for the collection. :ivar light_curve_collection: The LightCurveCollectionBase object for the collection. :ivar load_label_from_path_function: The function to load the label for the light curve. diff --git a/src/qusi/internal/light_curve_dataset.py b/src/qusi/internal/light_curve_dataset.py index 0c8112e0..d2423274 100644 --- a/src/qusi/internal/light_curve_dataset.py +++ b/src/qusi/internal/light_curve_dataset.py @@ -43,7 +43,7 @@ class LightCurveDataset(IterableDataset): """ - A dataset of light curve data. + A dataset of light curves. Includes cases where light curves can be injected into one another. """ def __init__( @@ -141,6 +141,16 @@ def new( injectable_light_curve_collections: list[LightCurveObservationCollection] | None = None, post_injection_transform: Callable[[Any], Any] | None = None, ) -> Self: + """ + Creates a new light curve dataset. + + :param standard_light_curve_collections: The light curve collections to be used without injection. + :param injectee_light_curve_collections: The light curve collections that other light curves will be injected + into. + :param injectable_light_curve_collections: The light curve collections that will be injected into other light + curves. + :return: The light curve dataset. + """ if ( standard_light_curve_collections is None and injectee_light_curve_collections is None @@ -231,7 +241,7 @@ class LightCurveCollectionType(Enum): class InterleavedDataset(IterableDataset): def __init__(self, *datasets: IterableDataset): - self.datasets: tuple[IterableDataset] = datasets + self.datasets: tuple[IterableDataset, ...] = datasets @classmethod def new(cls, *datasets: IterableDataset): @@ -246,7 +256,7 @@ def __iter__(self): class ConcatenatedIterableDataset(IterableDataset): def __init__(self, *datasets: IterableDataset): - self.datasets: tuple[IterableDataset] = datasets + self.datasets: tuple[IterableDataset, ...] = datasets @classmethod def new(cls, *datasets: IterableDataset): @@ -277,28 +287,64 @@ def __iter__(self): break -def default_light_curve_observation_post_injection_transform(x: LightCurveObservation, *, length: int - ) -> (Tensor, Tensor): +def default_light_curve_observation_post_injection_transform( + x: LightCurveObservation, + *, + length: int, + randomize: bool = True, +) -> (Tensor, Tensor): + """ + The default light curve observation post injection transforms. A set of transforms that is expected to work well for + a variety of use cases. + + :param x: The light curve observation to be transformed. + :param length: The length to make all light curves. + :param randomize: Whether to have randomization in the transforms. + :return: The transformed light curve observation. + """ x = remove_nan_flux_data_points_from_light_curve_observation(x) - x = randomly_roll_light_curve_observation(x) + if randomize: + x = randomly_roll_light_curve_observation(x) x = from_light_curve_observation_to_fluxes_array_and_label_array(x) - x = make_fluxes_and_label_array_uniform_length(x, length=length) + x = (make_uniform_length(x[0], length=length, randomize=randomize), x[1]) # Make the fluxes a uniform length. x = pair_array_to_tensor(x) x = (normalize_tensor_by_modified_z_score(x[0]), x[1]) return x -def default_light_curve_post_injection_transform(x: LightCurve, *, length: int) -> Tensor: +def default_light_curve_post_injection_transform( + x: LightCurve, + *, + length: int, + randomize: bool = True, +) -> Tensor: + """ + The default light curve post injection transforms. A set of transforms that is expected to work well for a variety + of use cases. + + :param x: The light curve to be transformed. + :param length: The length to make all light curves. + :param randomize: Whether to have randomization in the transforms. + :return: The transformed light curve. + """ x = remove_nan_flux_data_points_from_light_curve(x) - x = randomly_roll_light_curve(x) + if randomize: + x = randomly_roll_light_curve(x) x = x.fluxes - x = make_uniform_length(x, length=length) + x = make_uniform_length(x, length=length, randomize=randomize) x = torch.tensor(x, dtype=torch.float32) x = normalize_tensor_by_modified_z_score(x) return x def normalize_tensor_by_modified_z_score(tensor: Tensor) -> Tensor: + """ + Normalizes a tensor by a modified z-score. That is, normalizes the values of the tensor based on the median + absolute deviation. + + :param tensor: The tensor to normalize. + :return: The normalized tensor. + """ median = torch.median(tensor) deviation_from_median = tensor - median absolute_deviation_from_median = torch.abs(deviation_from_median) @@ -312,19 +358,6 @@ def normalize_tensor_by_modified_z_score(tensor: Tensor) -> Tensor: return modified_z_score -def make_fluxes_and_label_array_uniform_length( - arrays: tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]], - length: int, - *, - randomize: bool = True, -) -> (np.ndarray, np.ndarray): - fluxes, label = arrays - uniform_length_times = make_uniform_length( - fluxes, length=length, randomize=randomize - ) - return uniform_length_times, label - - def make_uniform_length( example: np.ndarray, length: int, *, randomize: bool = True ) -> np.ndarray: diff --git a/src/qusi/internal/light_curve_observation.py b/src/qusi/internal/light_curve_observation.py index 042f4762..aa1cc86f 100644 --- a/src/qusi/internal/light_curve_observation.py +++ b/src/qusi/internal/light_curve_observation.py @@ -34,6 +34,13 @@ def new(cls, light_curve: LightCurve, label: int) -> Self: def remove_nan_flux_data_points_from_light_curve_observation( light_curve_observation: LightCurveObservation, ) -> LightCurveObservation: + """ + Removes the NaN values from a light curve in a light curve observation. If there is a NaN in either the times or the + fluxes, both corresponding values are removed. + + :param light_curve_observation: The light curve observation. + :return: The light curve observation with NaN values removed. + """ light_curve_observation = deepcopy(light_curve_observation) light_curve_observation.light_curve = remove_nan_flux_data_points_from_light_curve( light_curve_observation.light_curve @@ -42,6 +49,13 @@ def remove_nan_flux_data_points_from_light_curve_observation( def randomly_roll_light_curve_observation(light_curve_observation: LightCurveObservation) -> LightCurveObservation: + """ + Randomly rolls a light curve observation. That is, a random position in the light curve is chosen, the light curve + is split at that point, and the order of the two halves are swapped. + + :param light_curve_observation: The light curve observation. + :return: The light curve observation with the rolled light curve. + """ light_curve_observation = deepcopy(light_curve_observation) light_curve_observation.light_curve = randomly_roll_light_curve(light_curve_observation.light_curve) return light_curve_observation diff --git a/src/qusi/internal/light_curve_transforms.py b/src/qusi/internal/light_curve_transforms.py index c9b10b73..dae294da 100644 --- a/src/qusi/internal/light_curve_transforms.py +++ b/src/qusi/internal/light_curve_transforms.py @@ -9,6 +9,12 @@ def from_light_curve_observation_to_fluxes_array_and_label_array( light_curve_observation: LightCurveObservation, ) -> (npt.NDArray[np.float32], npt.NDArray[np.float32]): + """ + Extracts the fluxes and label from a light curve observation. + + :param light_curve_observation: The light curve observation. + :return: The fluxes and label array. + """ fluxes = light_curve_observation.light_curve.fluxes label = light_curve_observation.label return fluxes, np.array(label, dtype=np.float32) @@ -17,6 +23,12 @@ def from_light_curve_observation_to_fluxes_array_and_label_array( def pair_array_to_tensor( arrays: tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]], ) -> (Tensor, Tensor): + """ + Converts a pair of arrays to a pair of tensors. + + :param arrays: The arrays to convert. + :return: The tensors. + """ return torch.tensor(arrays[0], dtype=torch.float32), torch.tensor( arrays[1], dtype=torch.float32 ) diff --git a/src/qusi/internal/train_hyperparameter_configuration.py b/src/qusi/internal/train_hyperparameter_configuration.py index bd975a62..09bfb38e 100644 --- a/src/qusi/internal/train_hyperparameter_configuration.py +++ b/src/qusi/internal/train_hyperparameter_configuration.py @@ -6,33 +6,59 @@ class TrainHyperparameterConfiguration: """ Hyperparameter configuration settings for a train session. - :ivar batch_size: The size of the batch for each train process. Each training step will use a number of examples + :ivar cycles: The number of cycles to run. Cycles consist of one set of training steps and one set of validation + steps. They can be seen as analogous to epochs. However, as qusi datasets are often + infinite or have different length sub-collections, there is not always the exact equivalent of an + epoch, so cycles are used instead. + :ivar train_steps_per_cycle: The number of training steps per cycle. + :ivar validation_steps_per_cycle: The number of validation steps per cycle. + :ivar batch_size: The size of the batch for each train process. Each training step will use a number of observations equal to this value multiplied by the number of train processes. - :ivar cycles: The number of train cycles to run. + :ivar learning_rate: The learning rate. + :ivar optimizer_epsilon: The epsilon to be used by the optimizer. + :ivar weight_decay: The weight decay of the optimizer. + :ivar norm_based_gradient_clip: The norm based gradient clipping value. """ - learning_rate: float - optimizer_epsilon: float - weight_decay: float - batch_size: int cycles: int train_steps_per_cycle: int validation_steps_per_cycle: int + batch_size: int + learning_rate: float + optimizer_epsilon: float + weight_decay: float norm_based_gradient_clip: float @classmethod def new( cls, *, + cycles: int = 5000, + train_steps_per_cycle: int = 100, + validation_steps_per_cycle: int = 10, + batch_size: int = 100, learning_rate: float = 1e-4, optimizer_epsilon: float = 1e-7, weight_decay: float = 0.0001, - batch_size: int = 100, - train_steps_per_cycle: int = 100, - validation_steps_per_cycle: int = 10, - cycles: int = 5000, norm_based_gradient_clip: float = 1.0, ): + """ + Creates a new `TrainHyperparameterConfiguration`. + + :param cycles: The number of cycles to run. Cycles consist of one set of training steps and one set of validation + steps. They can be seen as analogous to epochs. However, as qusi datasets are often + infinite or have different length sub-collections, there is not always the exact equivalent of an + epoch, so cycles are used instead. + :param train_steps_per_cycle: The number of training steps per cycle. + :param validation_steps_per_cycle: The number of validation steps per cycle. + :param batch_size: The size of the batch for each train process. Each training step will use a number of observations + equal to this value multiplied by the number of train processes. + :param learning_rate: The learning rate. + :param optimizer_epsilon: The epsilon to be used by the optimizer. + :param weight_decay: The weight decay of the optimizer. + :param norm_based_gradient_clip: The norm based gradient clipping value. + :return: The hyperparameter configuration. + """ return cls( learning_rate=learning_rate, optimizer_epsilon=optimizer_epsilon, diff --git a/src/qusi/internal/train_logging_configuration.py b/src/qusi/internal/train_logging_configuration.py index 28c9cc49..6664c283 100644 --- a/src/qusi/internal/train_logging_configuration.py +++ b/src/qusi/internal/train_logging_configuration.py @@ -26,6 +26,14 @@ def new( wandb_entity: str | None = None, additional_log_dictionary: dict[str, Any] | None = None, ): + """ + Creates a `TrainLoggingConfiguration`. + + :param wandb_project: The wandb project to log to. + :param wandb_entity: The wandb entity to log to. + :param additional_log_dictionary: The dictionary of additional values to log. + :return: The `TrainLoggingConfiguration`. + """ if additional_log_dictionary is None: additional_log_dictionary = {} return cls( diff --git a/src/qusi/internal/train_session.py b/src/qusi/internal/train_session.py index bdd711c6..4366d4a6 100644 --- a/src/qusi/internal/train_session.py +++ b/src/qusi/internal/train_session.py @@ -30,7 +30,18 @@ def train_session( *, hyperparameter_configuration: TrainHyperparameterConfiguration | None = None, logging_configuration: TrainLoggingConfiguration | None = None, -): +) -> None: + """ + Runs a training session. + + :param train_datasets: The datasets to train on. + :param validation_datasets: The datasets to validate on. + :param model: The model to train. + :param loss_function: The loss function to train the model on. + :param metric_functions: A list of metric functions to record during the training process. + :param hyperparameter_configuration: The configuration of the hyperparameters + :param logging_configuration: The configuration of the logging. + """ if hyperparameter_configuration is None: hyperparameter_configuration = TrainHyperparameterConfiguration.new() if logging_configuration is None: diff --git a/src/qusi/internal/train_system_configuration.py b/src/qusi/internal/train_system_configuration.py index 608f68a5..ac3153b1 100644 --- a/src/qusi/internal/train_system_configuration.py +++ b/src/qusi/internal/train_system_configuration.py @@ -19,4 +19,12 @@ def new( *, preprocessing_processes_per_train_process: int = 10 ): + """ + Creates a `TrainSystemConfiguration`. + + :param preprocessing_processes_per_train_process: The number of processes that are started to preprocess the data + per train process. The train session will create this many processes for each the train data and the validation + data. + :return: The `TrainSystemConfiguration`. + """ return cls(preprocessing_processes_per_train_process=preprocessing_processes_per_train_process) diff --git a/src/qusi/transform.py b/src/qusi/transform.py index ff89742b..3f929510 100644 --- a/src/qusi/transform.py +++ b/src/qusi/transform.py @@ -1,8 +1,9 @@ """ Data transform related public interface. """ +from qusi.internal.light_curve import randomly_roll_light_curve, remove_nan_flux_data_points_from_light_curve from qusi.internal.light_curve_dataset import default_light_curve_post_injection_transform, \ - default_light_curve_observation_post_injection_transform, make_fluxes_and_label_array_uniform_length + default_light_curve_observation_post_injection_transform, make_uniform_length from qusi.internal.light_curve_observation import remove_nan_flux_data_points_from_light_curve_observation, \ randomly_roll_light_curve_observation from qusi.internal.light_curve_transforms import from_light_curve_observation_to_fluxes_array_and_label_array, \ @@ -12,8 +13,10 @@ 'default_light_curve_post_injection_transform', 'default_light_curve_observation_post_injection_transform', 'from_light_curve_observation_to_fluxes_array_and_label_array', - 'make_fluxes_and_label_array_uniform_length', + 'make_uniform_length', 'pair_array_to_tensor', + 'randomly_roll_light_curve', 'randomly_roll_light_curve_observation', + 'remove_nan_flux_data_points_from_light_curve', 'remove_nan_flux_data_points_from_light_curve_observation', ] From 6c23e8454d8662d5e877b0197fa957e2a3adba56 Mon Sep 17 00:00:00 2001 From: golmschenk Date: Mon, 6 May 2024 17:25:46 -0400 Subject: [PATCH 09/21] Add note about randomization --- docs/source/tutorials/crafting_standard_datasets.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/tutorials/crafting_standard_datasets.md b/docs/source/tutorials/crafting_standard_datasets.md index 3baaeaa2..89273d95 100644 --- a/docs/source/tutorials/crafting_standard_datasets.md +++ b/docs/source/tutorials/crafting_standard_datasets.md @@ -45,6 +45,6 @@ def default_light_curve_observation_post_injection_transform(x: LightCurveObserv return x ``` -It's a function that takes in a `LightCurveObservation` and spits out two `Tensor`s, one for the fluxes and one for the label to predict. Most of the data transform functions within have names that are largely descriptive, but we'll walk through them anyway. `remove_nan_flux_data_points_from_light_curve_observation` removes time steps from a `LightCurveObservation` where the flux is NaN. `randomly_roll_light_curve_observation` randomly rolls the light curve (a random cut is made and the two segments' order is swapped). `from_light_curve_observation_to_fluxes_array_and_label_array` extracts two NumPy arrays from a `LightCurveObservation`, one for the fluxes and one from the label (which in this case will be an array with a single value). `make_uniform_length` performs the uniform lengthening on the fluxes as we discussed in the previous section. `pair_array_to_tensor` converts the pair of NumPy arrays to a pair of PyTorch tensors (PyTorch's equivalent of an array). `normalize_tensor_by_modified_z_score` normalizes a tensor via based on the median absolute deviation. Notice, this is only applied to the flux tensor, not the label tensor. The `randomize` parameter enables or disables randomization of the functions which may include randomization. During training, randomization should be on to make sure we get variation in training observations. During evaluation and inference, it should be off to get repeatable results. +It's a function that takes in a `LightCurveObservation` and spits out two `Tensor`s, one for the fluxes and one for the label to predict. Most of the data transform functions within have names that are largely descriptive, but we'll walk through them anyway. `remove_nan_flux_data_points_from_light_curve_observation` removes time steps from a `LightCurveObservation` where the flux is NaN. `randomly_roll_light_curve_observation` randomly rolls the light curve (a random cut is made and the two segments' order is swapped). `from_light_curve_observation_to_fluxes_array_and_label_array` extracts two NumPy arrays from a `LightCurveObservation`, one for the fluxes and one from the label (which in this case will be an array with a single value). `make_uniform_length` performs the uniform lengthening on the fluxes as we discussed in the previous section. `pair_array_to_tensor` converts the pair of NumPy arrays to a pair of PyTorch tensors (PyTorch's equivalent of an array). `normalize_tensor_by_modified_z_score` normalizes a tensor via based on the median absolute deviation. Notice, this is only applied to the flux tensor, not the label tensor. The `randomize` parameter enables or disables randomization of the functions which may include randomization. During training, randomization should be on to make sure we get variation in training observations. During evaluation and inference, it should be off to get repeatable results. In our previous example, to keep the code simple, we did not disable randomization for the validation dataset. Although in most cases it will not make a major difference, randomization should be disabled on the validation dataset. It should only be enabled for the training dataset. It's worth noting, `default_light_curve_observation_post_injection_transform` is just a function that can be replaced as desired. To remove one of the preprocessing steps or add in an addition one, we can simply make a modified version of this function. Additionally, `qusi` does not require the transform function to output only the fluxes and a binary label. The `Hadryss` NN model expects these two types of values for training, but other models may take advantage of the times of the light curve, or they may predict multi-class or regression labels. From 8406a76f2a564dc686c3fb823defa5498baecd57 Mon Sep 17 00:00:00 2001 From: golmschenk Date: Mon, 6 May 2024 17:49:42 -0400 Subject: [PATCH 10/21] Add an infer case --- examples/download_spoc_transit_light_curves.py | 7 +++++++ examples/transit_dataset.py | 4 ++++ examples/transit_infer.py | 13 +------------ 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/download_spoc_transit_light_curves.py b/examples/download_spoc_transit_light_curves.py index 0bff68ee..46a9976a 100644 --- a/examples/download_spoc_transit_light_curves.py +++ b/examples/download_spoc_transit_light_curves.py @@ -1,3 +1,4 @@ +import shutil from pathlib import Path import numpy as np @@ -60,6 +61,12 @@ def main(): download_directory=Path('data/spoc_transit_experiment/test/negatives'), sectors=sectors, limit=600) + # In this toy example, we reuse our test light curves as infer light curves. In a real world case, you would likely + # want to infer on cases you don't already know the answer for. + infer_directory_path = Path('data/spoc_transit_experiment/infer') + infer_directory_path.mkdir(exist_ok=True, parents=True) + for light_curve_path in Path('data/spoc_transit_experiment/test').glob('**/*.fits'): + shutil.copy(light_curve_path, infer_directory_path.joinpath(light_curve_path.name)) if __name__ == '__main__': diff --git a/examples/transit_dataset.py b/examples/transit_dataset.py index 918f15de..bc9d9e20 100644 --- a/examples/transit_dataset.py +++ b/examples/transit_dataset.py @@ -28,6 +28,10 @@ def get_positive_test_paths(): return list(Path('data/spoc_transit_experiment/test/positives').glob('*.fits')) +def get_infer_paths(): + return list(Path('data/spoc_transit_experiment/infer').glob('*.fits')) + + def load_times_and_fluxes_from_path(path): light_curve = TessMissionLightCurve.from_path(path) return light_curve.times, light_curve.fluxes diff --git a/examples/transit_infer.py b/examples/transit_infer.py index d8dfec87..7e2a3923 100644 --- a/examples/transit_infer.py +++ b/examples/transit_infer.py @@ -1,22 +1,11 @@ from pathlib import Path -import numpy as np import torch from qusi.data import FiniteStandardLightCurveDataset, LightCurveCollection from qusi.model import Hadryss from qusi.session import get_device, infer_session -from qusi.experimental.application.tess import TessMissionLightCurve - - -def get_infer_paths(): - return (list(Path('data/spoc_transit_experiment/test/negatives').glob('*.fits')) + - list(Path('data/spoc_transit_experiment/test/positives').glob('*.fits'))) - - -def load_times_and_fluxes_from_path(path: Path) -> (np.ndarray, np.ndarray): - light_curve = TessMissionLightCurve.from_path(path) - return light_curve.times, light_curve.fluxes +from transit_dataset import load_times_and_fluxes_from_path def main(): From 4dc9e1f6baef83fc2e2e427a9c2cef27d401612a Mon Sep 17 00:00:00 2001 From: golmschenk Date: Mon, 6 May 2024 18:17:53 -0400 Subject: [PATCH 11/21] Add the infinite_datasets_test_session --- .../infinite_datasets_test_session.py | 51 +++++++++++++++++++ src/qusi/session.py | 2 + 2 files changed, 53 insertions(+) create mode 100644 src/qusi/internal/infinite_datasets_test_session.py diff --git a/src/qusi/internal/infinite_datasets_test_session.py b/src/qusi/internal/infinite_datasets_test_session.py new file mode 100644 index 00000000..5abddb72 --- /dev/null +++ b/src/qusi/internal/infinite_datasets_test_session.py @@ -0,0 +1,51 @@ +from torch.nn import Module +from torch.types import Device +from torch.utils.data import DataLoader +from wandb.wandb_torch import torch + +from qusi.internal.light_curve_dataset import LightCurveDataset + + +def infinite_datasets_test_session(test_datasets: list[LightCurveDataset], model: Module, + metric_functions: list[Module], *, batch_size: int, device: Device, steps: int): + """ + Runs a test session on finite datasets. + + :param test_datasets: A list of datasets to run the test session on. + :param model: A model to perform the inference. + :param metric_functions: A metrics to test. + :param batch_size: A batch size to use during testing. + :param device: A device to run the model on. + :param steps: The number of steps to run on the infinite datasets. + :return: A list of arrays, with one array for each test dataset, with each array containing an element for each + metric that was tested. + """ + test_dataloaders: list[DataLoader] = [] + for test_dataset in test_datasets: + test_dataloaders.append(DataLoader(test_dataset, batch_size=batch_size, pin_memory=True)) + model.eval() + results = [] + for test_dataloader in test_dataloaders: + result = infinite_dataset_test_phase(test_dataloader, model, metric_functions, device=device, steps=steps) + results.append(result) + return results + + +def infinite_dataset_test_phase(dataloader, model: Module, metric_functions: list[Module], device: Device, steps: int): + batch_count = 0 + metric_totals = torch.zeros(size=[len(metric_functions)]) + model.eval() + with torch.no_grad(): + for input_features, targets in dataloader: + input_features = input_features.to(device, non_blocking=True) + targets = targets.to(device, non_blocking=True) + predicted_targets = model(input_features) + for metric_function_index, metric_function in enumerate(metric_functions): + batch_metric_value = metric_function(predicted_targets.to(device, non_blocking=True), + targets) + metric_totals[metric_function_index] += batch_metric_value.to('cpu', non_blocking=True) + batch_count += 1 + if batch_count >= steps: + break + cycle_metric_values = metric_totals / batch_count + return cycle_metric_values diff --git a/src/qusi/session.py b/src/qusi/session.py index 2954a142..8d883245 100644 --- a/src/qusi/session.py +++ b/src/qusi/session.py @@ -4,6 +4,7 @@ from qusi.internal.device import get_device from qusi.internal.finite_test_session import finite_datasets_test_session from qusi.internal.infer_session import infer_session +from qusi.internal.infinite_datasets_test_session import infinite_datasets_test_session from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration from qusi.internal.train_logging_configuration import TrainLoggingConfiguration from qusi.internal.train_system_configuration import TrainSystemConfiguration @@ -13,6 +14,7 @@ 'finite_datasets_test_session', 'get_device', 'infer_session', + 'infinite_datasets_test_session', 'TrainHyperparameterConfiguration', 'TrainLoggingConfiguration', 'TrainSystemConfiguration', From aa298a8f62b021ad0f157c0534d5cb506b3afe9b Mon Sep 17 00:00:00 2001 From: golmschenk Date: Wed, 8 May 2024 19:35:25 -0400 Subject: [PATCH 12/21] Make the default TESS SPOC light curve processed length be 3500 --- src/qusi/internal/finite_standard_light_curve_dataset.py | 2 +- .../internal/finite_standard_light_curve_observation_dataset.py | 2 +- src/qusi/internal/hadryss_model.py | 2 +- src/qusi/internal/light_curve_dataset.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/qusi/internal/finite_standard_light_curve_dataset.py b/src/qusi/internal/finite_standard_light_curve_dataset.py index a1815174..f2333cba 100644 --- a/src/qusi/internal/finite_standard_light_curve_dataset.py +++ b/src/qusi/internal/finite_standard_light_curve_dataset.py @@ -37,7 +37,7 @@ def new( :return: The dataset. """ if post_injection_transform is None: - post_injection_transform = partial(default_light_curve_post_injection_transform, length=2500, + post_injection_transform = partial(default_light_curve_post_injection_transform, length=3500, randomize=False) length = 0 collection_start_indexes: list[int] = [] diff --git a/src/qusi/internal/finite_standard_light_curve_observation_dataset.py b/src/qusi/internal/finite_standard_light_curve_observation_dataset.py index 437fd9ea..c1918361 100644 --- a/src/qusi/internal/finite_standard_light_curve_observation_dataset.py +++ b/src/qusi/internal/finite_standard_light_curve_observation_dataset.py @@ -37,7 +37,7 @@ def new( :return: The dataset. """ if post_injection_transform is None: - post_injection_transform = partial(default_light_curve_observation_post_injection_transform, length=2500, + post_injection_transform = partial(default_light_curve_observation_post_injection_transform, length=3500, randomize=False) length = 0 collection_start_indexes: list[int] = [] diff --git a/src/qusi/internal/hadryss_model.py b/src/qusi/internal/hadryss_model.py index c4701d81..c8b54908 100644 --- a/src/qusi/internal/hadryss_model.py +++ b/src/qusi/internal/hadryss_model.py @@ -127,7 +127,7 @@ def forward(self, x: Tensor) -> Tensor: return x @classmethod - def new(cls, input_length: int = 2500) -> Self: + def new(cls, input_length: int = 3500) -> Self: """ Creates a new Hadryss model. diff --git a/src/qusi/internal/light_curve_dataset.py b/src/qusi/internal/light_curve_dataset.py index d2423274..f320cd5c 100644 --- a/src/qusi/internal/light_curve_dataset.py +++ b/src/qusi/internal/light_curve_dataset.py @@ -168,7 +168,7 @@ def new( injectable_light_curve_collections = [] if post_injection_transform is None: post_injection_transform = partial( - default_light_curve_observation_post_injection_transform, length=2500 + default_light_curve_observation_post_injection_transform, length=3500 ) instance = cls( standard_light_curve_collections=standard_light_curve_collections, From 95a5b906967ce42af829d6e27c406dc22c1b273f Mon Sep 17 00:00:00 2001 From: golmschenk Date: Wed, 8 May 2024 19:42:51 -0400 Subject: [PATCH 13/21] Remove examples from the main project repository to move them to a separate repository --- .../download_spoc_transit_light_curves.py | 73 -------------- examples/transit_dataset.py | 90 ----------------- examples/transit_finite_dataset_test.py | 23 ----- examples/transit_infer.py | 32 ------- examples/transit_infinite_dataset_test.py | 96 ------------------- examples/transit_light_curve_visualization.py | 20 ---- examples/transit_train.py | 17 ---- 7 files changed, 351 deletions(-) delete mode 100644 examples/download_spoc_transit_light_curves.py delete mode 100644 examples/transit_dataset.py delete mode 100644 examples/transit_finite_dataset_test.py delete mode 100644 examples/transit_infer.py delete mode 100644 examples/transit_infinite_dataset_test.py delete mode 100644 examples/transit_light_curve_visualization.py delete mode 100644 examples/transit_train.py diff --git a/examples/download_spoc_transit_light_curves.py b/examples/download_spoc_transit_light_curves.py deleted file mode 100644 index 46a9976a..00000000 --- a/examples/download_spoc_transit_light_curves.py +++ /dev/null @@ -1,73 +0,0 @@ -import shutil -from pathlib import Path - -import numpy as np - -from qusi.experimental.application.tess import ( - download_spoc_light_curves_for_tic_ids, - get_spoc_tic_id_list_from_mast, - TessToiDataInterface, - ToiColumns, -) - - -def main(): - print('Retrieving metadata...') - spoc_target_tic_ids = get_spoc_tic_id_list_from_mast() - tess_toi_data_interface = TessToiDataInterface() - positive_tic_ids = tess_toi_data_interface.toi_dispositions[ - tess_toi_data_interface.toi_dispositions[ToiColumns.disposition.value] != 'FP'][ToiColumns.tic_id.value] - negative_tic_ids = list(set(spoc_target_tic_ids) - set(positive_tic_ids)) - positive_tic_ids_splits = np.split( - np.array(positive_tic_ids), [int(len(positive_tic_ids) * 0.8), int(len(positive_tic_ids) * 0.9)]) - positive_train_tic_ids = positive_tic_ids_splits[0].tolist() - positive_validation_tic_ids = positive_tic_ids_splits[1].tolist() - positive_test_tic_ids = positive_tic_ids_splits[2].tolist() - negative_tic_ids_splits = np.split( - np.array(negative_tic_ids), [int(len(negative_tic_ids) * 0.8), int(len(negative_tic_ids) * 0.9)]) - negative_train_tic_ids = negative_tic_ids_splits[0].tolist() - negative_validation_tic_ids = negative_tic_ids_splits[1].tolist() - negative_test_tic_ids = negative_tic_ids_splits[2].tolist() - sectors = list(range(27, 56)) - - print('Downloading light curves...') - download_spoc_light_curves_for_tic_ids( - tic_ids=positive_train_tic_ids, - download_directory=Path('data/spoc_transit_experiment/train/positives'), - sectors=sectors, - limit=2000) - download_spoc_light_curves_for_tic_ids( - tic_ids=negative_train_tic_ids, - download_directory=Path('data/spoc_transit_experiment/train/negatives'), - sectors=sectors, - limit=6000) - download_spoc_light_curves_for_tic_ids( - tic_ids=positive_validation_tic_ids, - download_directory=Path('data/spoc_transit_experiment/validation/positives'), - sectors=sectors, - limit=200) - download_spoc_light_curves_for_tic_ids( - tic_ids=negative_validation_tic_ids, - download_directory=Path('data/spoc_transit_experiment/validation/negatives'), - sectors=sectors, - limit=600) - download_spoc_light_curves_for_tic_ids( - tic_ids=positive_test_tic_ids, - download_directory=Path('data/spoc_transit_experiment/test/positives'), - sectors=sectors, - limit=200) - download_spoc_light_curves_for_tic_ids( - tic_ids=negative_test_tic_ids, - download_directory=Path('data/spoc_transit_experiment/test/negatives'), - sectors=sectors, - limit=600) - # In this toy example, we reuse our test light curves as infer light curves. In a real world case, you would likely - # want to infer on cases you don't already know the answer for. - infer_directory_path = Path('data/spoc_transit_experiment/infer') - infer_directory_path.mkdir(exist_ok=True, parents=True) - for light_curve_path in Path('data/spoc_transit_experiment/test').glob('**/*.fits'): - shutil.copy(light_curve_path, infer_directory_path.joinpath(light_curve_path.name)) - - -if __name__ == '__main__': - main() diff --git a/examples/transit_dataset.py b/examples/transit_dataset.py deleted file mode 100644 index bc9d9e20..00000000 --- a/examples/transit_dataset.py +++ /dev/null @@ -1,90 +0,0 @@ -from pathlib import Path - -from qusi.data import FiniteStandardLightCurveObservationDataset, LightCurveDataset, LightCurveObservationCollection -from qusi.experimental.application.tess import TessMissionLightCurve - - -def get_positive_train_paths(): - return list(Path('data/spoc_transit_experiment/train/positives').glob('*.fits')) - - -def get_negative_train_paths(): - return list(Path('data/spoc_transit_experiment/train/negatives').glob('*.fits')) - - -def get_positive_validation_paths(): - return list(Path('data/spoc_transit_experiment/validation/positives').glob('*.fits')) - - -def get_negative_validation_paths(): - return list(Path('data/spoc_transit_experiment/validation/negatives').glob('*.fits')) - - -def get_negative_test_paths(): - return list(Path('data/spoc_transit_experiment/test/negatives').glob('*.fits')) - - -def get_positive_test_paths(): - return list(Path('data/spoc_transit_experiment/test/positives').glob('*.fits')) - - -def get_infer_paths(): - return list(Path('data/spoc_transit_experiment/infer').glob('*.fits')) - - -def load_times_and_fluxes_from_path(path): - light_curve = TessMissionLightCurve.from_path(path) - return light_curve.times, light_curve.fluxes - - -def positive_label_function(path): - return 1 - - -def negative_label_function(path): - return 0 - - -def get_transit_train_dataset(): - positive_train_light_curve_collection = LightCurveObservationCollection.new( - get_paths_function=get_positive_train_paths, - load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path, - load_label_from_path_function=positive_label_function) - negative_train_light_curve_collection = LightCurveObservationCollection.new( - get_paths_function=get_negative_train_paths, - load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path, - load_label_from_path_function=negative_label_function) - train_light_curve_dataset = LightCurveDataset.new( - standard_light_curve_collections=[positive_train_light_curve_collection, - negative_train_light_curve_collection]) - return train_light_curve_dataset - - -def get_transit_validation_dataset(): - positive_validation_light_curve_collection = LightCurveObservationCollection.new( - get_paths_function=get_positive_validation_paths, - load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path, - load_label_from_path_function=positive_label_function) - negative_validation_light_curve_collection = LightCurveObservationCollection.new( - get_paths_function=get_negative_validation_paths, - load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path, - load_label_from_path_function=negative_label_function) - validation_light_curve_dataset = LightCurveDataset.new( - standard_light_curve_collections=[positive_validation_light_curve_collection, - negative_validation_light_curve_collection]) - return validation_light_curve_dataset - - -def get_transit_finite_test_dataset(): - positive_test_light_curve_collection = LightCurveObservationCollection.new( - get_paths_function=get_positive_test_paths, - load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path, - load_label_from_path_function=positive_label_function) - negative_test_light_curve_collection = LightCurveObservationCollection.new( - get_paths_function=get_negative_test_paths, - load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path, - load_label_from_path_function=negative_label_function) - test_light_curve_dataset = FiniteStandardLightCurveObservationDataset.new( - light_curve_collections=[positive_test_light_curve_collection, - negative_test_light_curve_collection]) - return test_light_curve_dataset diff --git a/examples/transit_finite_dataset_test.py b/examples/transit_finite_dataset_test.py deleted file mode 100644 index f09b1aaf..00000000 --- a/examples/transit_finite_dataset_test.py +++ /dev/null @@ -1,23 +0,0 @@ -import torch -from torch.nn import BCELoss -from torchmetrics.classification import BinaryAccuracy - -from qusi.session import finite_datasets_test_session, get_device -from qusi.model import Hadryss - -from transit_dataset import get_transit_finite_test_dataset - -def main(): - test_light_curve_dataset = get_transit_finite_test_dataset() - model = Hadryss.new() - device = get_device() - model.load_state_dict(torch.load('sessions/_latest_model.pt', map_location=device)) - metric_functions = [BinaryAccuracy(), BCELoss()] - results = finite_datasets_test_session(test_datasets=[test_light_curve_dataset], model=model, - metric_functions=metric_functions, batch_size=100, device=device) - print(f'Binary accuracy: {results[0][0]}') - print(f'Binary cross entropy: {results[0][1]}') - - -if __name__ == '__main__': - main() diff --git a/examples/transit_infer.py b/examples/transit_infer.py deleted file mode 100644 index 7e2a3923..00000000 --- a/examples/transit_infer.py +++ /dev/null @@ -1,32 +0,0 @@ -from pathlib import Path - -import torch - -from qusi.data import FiniteStandardLightCurveDataset, LightCurveCollection -from qusi.model import Hadryss -from qusi.session import get_device, infer_session -from transit_dataset import load_times_and_fluxes_from_path - - -def main(): - infer_light_curve_collection = LightCurveCollection.new( - get_paths_function=get_infer_paths, - load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path) - - test_light_curve_dataset = FiniteStandardLightCurveDataset.new( - light_curve_collections=[infer_light_curve_collection]) - - model = Hadryss.new() - device = get_device() - model.load_state_dict(torch.load('sessions/_latest_model.pt', map_location=device)) - confidences = infer_session(infer_datasets=[test_light_curve_dataset], model=model, - batch_size=100, device=device)[0] - paths = list(get_infer_paths()) - paths_with_confidences = zip(paths, confidences) - sorted_paths_with_confidences = sorted( - paths_with_confidences, key=lambda path_with_confidence: path_with_confidence[1], reverse=True) - print(sorted_paths_with_confidences) - - -if __name__ == '__main__': - main() diff --git a/examples/transit_infinite_dataset_test.py b/examples/transit_infinite_dataset_test.py deleted file mode 100644 index 98484286..00000000 --- a/examples/transit_infinite_dataset_test.py +++ /dev/null @@ -1,96 +0,0 @@ -from pathlib import Path - -import numpy as np -import torch -from torch.nn import BCELoss, Module -from torch.types import Device -from torch.utils.data import DataLoader -from torchmetrics.classification import BinaryAccuracy - -from qusi.model import Hadryss -from qusi.session import get_device -from qusi.data import LightCurveObservationCollection -from qusi.data import LightCurveDataset -from qusi.experimental.application.tess import TessMissionLightCurve - - -def get_negative_test_paths(): - return list(Path('data/spoc_transit_experiment/test/negatives').glob('*.fits')) - - -def get_positive_test_paths(): - return list(Path('data/spoc_transit_experiment/test/positives').glob('*.fits')) - - -def load_times_and_fluxes_from_path(path: Path) -> (np.ndarray, np.ndarray): - light_curve = TessMissionLightCurve.from_path(path) - return light_curve.times, light_curve.fluxes - - -def positive_label_function(_path: Path) -> int: - return 1 - - -def negative_label_function(_path: Path) -> int: - return 0 - - -def main(): - positive_test_light_curve_collection = LightCurveObservationCollection.new( - get_paths_function=get_positive_test_paths, - load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path, - load_label_from_path_function=positive_label_function) - negative_test_light_curve_collection = LightCurveObservationCollection.new( - get_paths_function=get_negative_test_paths, - load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path, - load_label_from_path_function=negative_label_function) - - test_light_curve_dataset = LightCurveDataset.new( - standard_light_curve_collections=[positive_test_light_curve_collection, - negative_test_light_curve_collection]) - - model = Hadryss.new() - device = get_device() - model.load_state_dict(torch.load('sessions/pleasant-lion-32_latest_model.pt', map_location=device)) - metric_functions = [BinaryAccuracy(), BCELoss()] - results = infinite_datasets_test_session(test_datasets=[test_light_curve_dataset], model=model, - metric_functions=metric_functions, batch_size=100, device=device, - steps=100) - return results - - -def infinite_datasets_test_session(test_datasets: list[LightCurveDataset], model: Module, - metric_functions: list[Module], *, batch_size: int, device: Device, steps: int): - test_dataloaders: list[DataLoader] = [] - for test_dataset in test_datasets: - test_dataloaders.append(DataLoader(test_dataset, batch_size=batch_size, pin_memory=True)) - model.eval() - results = [] - for test_dataloader in test_dataloaders: - result = infinite_dataset_test_phase(test_dataloader, model, metric_functions, device=device, steps=steps) - results.append(result) - return results - - -def infinite_dataset_test_phase(dataloader, model: Module, metric_functions: list[Module], device: Device, steps: int): - batch_count = 0 - metric_totals = torch.zeros(size=[len(metric_functions)]) - model.eval() - with torch.no_grad(): - for input_features, targets in dataloader: - input_features = input_features.to(device, non_blocking=True) - targets = targets.to(device, non_blocking=True) - predicted_targets = model(input_features) - for metric_function_index, metric_function in enumerate(metric_functions): - batch_metric_value = metric_function(predicted_targets.to(device, non_blocking=True), - targets) - metric_totals[metric_function_index] += batch_metric_value.to('cpu', non_blocking=True) - batch_count += 1 - if batch_count >= steps: - break - cycle_metric_values = metric_totals / batch_count - return cycle_metric_values - - -if __name__ == '__main__': - main() diff --git a/examples/transit_light_curve_visualization.py b/examples/transit_light_curve_visualization.py deleted file mode 100644 index 42d0a552..00000000 --- a/examples/transit_light_curve_visualization.py +++ /dev/null @@ -1,20 +0,0 @@ -from pathlib import Path - -from bokeh.io import show -from bokeh.plotting import figure as Figure - -from qusi.experimental.application.tess import TessMissionLightCurve - - -def main(): - light_curve_path = Path( - 'data/spoc_transit_experiment/train/positives/hlsp_tess-spoc_tess_phot_0000000004605846-s0044_tess_v1_lc.fits') - light_curve = TessMissionLightCurve.from_path(light_curve_path) - light_curve_figure = Figure(x_axis_label='Time (BTJD)', y_axis_label='Flux') - light_curve_figure.circle(x=light_curve.times, y=light_curve.fluxes) - light_curve_figure.line(x=light_curve.times, y=light_curve.fluxes, line_alpha=0.3) - show(light_curve_figure) - - -if __name__ == '__main__': - main() diff --git a/examples/transit_train.py b/examples/transit_train.py deleted file mode 100644 index 77a8b5d0..00000000 --- a/examples/transit_train.py +++ /dev/null @@ -1,17 +0,0 @@ -from qusi.model import Hadryss -from qusi.session import TrainHyperparameterConfiguration, train_session - -from transit_dataset import get_transit_train_dataset, get_transit_validation_dataset - -def main(): - train_light_curve_dataset = get_transit_train_dataset() - validation_light_curve_dataset = get_transit_validation_dataset() - model = Hadryss.new() - train_hyperparameter_configuration = TrainHyperparameterConfiguration.new( - batch_size=100, cycles=20, train_steps_per_cycle=100, validation_steps_per_cycle=10) - train_session(train_datasets=[train_light_curve_dataset], validation_datasets=[validation_light_curve_dataset], - model=model, hyperparameter_configuration=train_hyperparameter_configuration) - - -if __name__ == '__main__': - main() From d8945e842974c0b3839940f157bbe6f473dc0090 Mon Sep 17 00:00:00 2001 From: golmschenk Date: Thu, 9 May 2024 14:31:57 -0400 Subject: [PATCH 14/21] Add binary AUROC metric --- pyproject.toml | 3 ++- src/qusi/internal/train_session.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6b3e2de8..fe9d2221 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,8 @@ dependencies = [ "sphinx>=6.1.3", "backports.strenum", "typing_extensions", - "myst-parser" + "myst-parser", + "torcheval>=0.0.7", ] [build-system] diff --git a/src/qusi/internal/train_session.py b/src/qusi/internal/train_session.py index 4366d4a6..4b7f4110 100644 --- a/src/qusi/internal/train_session.py +++ b/src/qusi/internal/train_session.py @@ -9,7 +9,7 @@ from torch.nn import BCELoss, Module from torch.optim import AdamW from torch.utils.data import DataLoader -from torchmetrics.classification import BinaryAccuracy +from torcheval.metrics import BinaryAccuracy, BinaryAUROC import wandb from qusi.internal.light_curve_dataset import InterleavedDataset, LightCurveDataset @@ -49,7 +49,7 @@ def train_session( if loss_function is None: loss_function = BCELoss() if metric_functions is None: - metric_functions = [BinaryAccuracy()] + metric_functions = [BinaryAccuracy(), BinaryAUROC()] set_up_default_logger() wandb_init( process_rank=0, From a1accd5f0dd585969b58d1f62809982b80c53724 Mon Sep 17 00:00:00 2001 From: golmschenk Date: Thu, 9 May 2024 14:32:28 -0400 Subject: [PATCH 15/21] Remove torchmetrics requirement --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index fe9d2221..60cc0360 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,6 @@ dependencies = [ "torch>=2.0.1", "torchvision>=0.15.2", "polars>=0.19.10", - "torchmetrics>=1.2.0", "stringcase>=1.2.0", "atpublic>=4.0", "pytest-pycharm>=0.7.0", From 98c8b8ce94b96999f20b0d1c01704340158cf5a8 Mon Sep 17 00:00:00 2001 From: golmschenk Date: Thu, 9 May 2024 22:50:29 -0400 Subject: [PATCH 16/21] Switch back to torchmetrics --- pyproject.toml | 1 + src/qusi/internal/light_curve_collection.py | 4 ++++ src/qusi/internal/train_session.py | 8 +++++--- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 60cc0360..fe9d2221 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "torch>=2.0.1", "torchvision>=0.15.2", "polars>=0.19.10", + "torchmetrics>=1.2.0", "stringcase>=1.2.0", "atpublic>=4.0", "pytest-pycharm>=0.7.0", diff --git a/src/qusi/internal/light_curve_collection.py b/src/qusi/internal/light_curve_collection.py index 531810ab..6bbb479d 100644 --- a/src/qusi/internal/light_curve_collection.py +++ b/src/qusi/internal/light_curve_collection.py @@ -156,6 +156,8 @@ def light_curve_iter(self) -> Iterator[LightCurve]: :return: The iterable of the light curves. """ light_curve_paths = self.path_getter.get_shuffled_paths() + if len(light_curve_paths) == 0: + raise ValueError('LightCurveCollection returned no paths.') for light_curve_path in light_curve_paths: times, fluxes = self.load_times_and_fluxes_from_path_function( light_curve_path @@ -264,6 +266,8 @@ def observation_iter(self) -> Iterator[LightCurveObservation]: :return: The iterable of the light curves. """ light_curve_paths = self.path_getter.get_shuffled_paths() + if len(light_curve_paths) == 0: + raise ValueError('LightCurveObservationCollection returned no paths.') for light_curve_path in light_curve_paths: times, fluxes = self.light_curve_collection.load_times_and_fluxes_from_path( light_curve_path diff --git a/src/qusi/internal/train_session.py b/src/qusi/internal/train_session.py index 4b7f4110..136355e7 100644 --- a/src/qusi/internal/train_session.py +++ b/src/qusi/internal/train_session.py @@ -9,9 +9,10 @@ from torch.nn import BCELoss, Module from torch.optim import AdamW from torch.utils.data import DataLoader -from torcheval.metrics import BinaryAccuracy, BinaryAUROC import wandb +from torchmetrics.classification import BinaryAccuracy, BinaryAUROC + from qusi.internal.light_curve_dataset import InterleavedDataset, LightCurveDataset from qusi.internal.logging import set_up_default_logger from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration @@ -51,14 +52,15 @@ def train_session( if metric_functions is None: metric_functions = [BinaryAccuracy(), BinaryAUROC()] set_up_default_logger() + sessions_directory = Path("sessions") + sessions_directory.mkdir(exist_ok=True) wandb_init( process_rank=0, project=logging_configuration.wandb_project, entity=logging_configuration.wandb_entity, settings=wandb.Settings(start_method="thread"), + dir=sessions_directory, ) - sessions_directory = Path("sessions") - sessions_directory.mkdir(exist_ok=True) train_dataset = InterleavedDataset.new(*train_datasets) torch.multiprocessing.set_start_method("spawn") debug = False From 60bff1d673b52d425857c0d881fae0266d0a1d50 Mon Sep 17 00:00:00 2001 From: golmschenk Date: Tue, 14 May 2024 18:46:43 -0400 Subject: [PATCH 17/21] Add better name casing for acronyms --- src/qusi/internal/logging.py | 34 ++++++++++++++++++++++++++---- src/qusi/internal/train_session.py | 10 +-------- tests/unit_tests/logging.py | 13 ++++++++++++ 3 files changed, 44 insertions(+), 13 deletions(-) create mode 100644 tests/unit_tests/logging.py diff --git a/src/qusi/internal/logging.py b/src/qusi/internal/logging.py index 25b63467..1a4b9f26 100644 --- a/src/qusi/internal/logging.py +++ b/src/qusi/internal/logging.py @@ -1,12 +1,17 @@ +from __future__ import annotations + import datetime import logging +import re import sys +import stringcase + logger_initialized = False def create_default_formatter() -> logging.Formatter: - formatter = logging.Formatter("qusi [{asctime} {levelname} {name}] {message}", style="{") + formatter = logging.Formatter('qusi [{asctime} {levelname} {name}] {message}', style='{') return formatter @@ -17,7 +22,7 @@ def set_up_default_logger(): handler = logging.StreamHandler(sys.stdout) handler.setLevel(logging.DEBUG) handler.setFormatter(formatter) - logger = logging.getLogger("qusi") + logger = logging.getLogger('qusi') logger.addHandler(handler) logger.setLevel(logging.INFO) logger.propagate = False @@ -26,7 +31,28 @@ def set_up_default_logger(): def excepthook(exc_type, exc_value, exc_traceback): - logger = logging.getLogger("qusi") - logger.critical(f"Uncaught exception at {datetime.datetime.now().astimezone()}:") + logger = logging.getLogger('qusi') + logger.critical(f'Uncaught exception at {datetime.datetime.now().astimezone()}:') logger.handlers[0].flush() sys.__excepthook__(exc_type, exc_value, exc_traceback) + + +def get_metric_name(metric_function): + metric_name = type(metric_function).__name__ + metric_name = camel_case_acronyms(metric_name) + metric_name = stringcase.snakecase(metric_name) + metric_name = metric_name.replace('_metric', '').replace('_loss', '') + return metric_name + + +def camel_case_acronyms(string: str) -> str: + def camel_case_single_acronym(string: str | None) -> str: + if string is None: + return '' + return stringcase.capitalcase(string.lower()) + + return re.sub( + r'([A-Z]{2,})([A-Z][a-z])|([A-Z]{2,})', + lambda match: ''.join(map(camel_case_single_acronym, [match.group(1), match.group(2), match.group(3)])), + string + ) diff --git a/src/qusi/internal/train_session.py b/src/qusi/internal/train_session.py index 136355e7..b821f15e 100644 --- a/src/qusi/internal/train_session.py +++ b/src/qusi/internal/train_session.py @@ -4,7 +4,6 @@ from pathlib import Path import numpy as np -import stringcase import torch from torch.nn import BCELoss, Module from torch.optim import AdamW @@ -14,7 +13,7 @@ from torchmetrics.classification import BinaryAccuracy, BinaryAUROC from qusi.internal.light_curve_dataset import InterleavedDataset, LightCurveDataset -from qusi.internal.logging import set_up_default_logger +from qusi.internal.logging import set_up_default_logger, get_metric_name from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration from qusi.internal.train_logging_configuration import TrainLoggingConfiguration from qusi.internal.wandb_liaison import wandb_commit, wandb_init, wandb_log @@ -179,13 +178,6 @@ def train_phase( ) -def get_metric_name(metric_function): - metric_name = type(metric_function).__name__ - metric_name = stringcase.snakecase(metric_name) - metric_name = metric_name.replace("_metric", "").replace("_loss", "") - return metric_name - - def validation_phase( dataloader, model, loss_function, metric_functions: list[Module], steps, device ): diff --git a/tests/unit_tests/logging.py b/tests/unit_tests/logging.py new file mode 100644 index 00000000..5e64016d --- /dev/null +++ b/tests/unit_tests/logging.py @@ -0,0 +1,13 @@ +from torch.nn import BCELoss +from torchmetrics.classification import BinaryAUROC + +from qusi.internal.logging import camel_case_acronyms, get_metric_name + + +def test_camel_case_acronyms(): + assert camel_case_acronyms('BCEntropy') == 'BcEntropy' + assert camel_case_acronyms('BinaryAUROC') == 'BinaryAuroc' + +def test_get_metric_name(): + assert get_metric_name(BCELoss()) == 'bce' + assert get_metric_name(BinaryAUROC()) == 'binary_auroc' From 250f2a648bb405d1e24e74e60024d9cc0f6e6875 Mon Sep 17 00:00:00 2001 From: golmschenk Date: Wed, 15 May 2024 14:21:10 -0400 Subject: [PATCH 18/21] Move transforms to transforms file --- src/qusi/internal/light_curve_dataset.py | 52 +-------------------- src/qusi/internal/light_curve_transforms.py | 46 ++++++++++++++++++ src/qusi/transform.py | 5 +- 3 files changed, 50 insertions(+), 53 deletions(-) diff --git a/src/qusi/internal/light_curve_dataset.py b/src/qusi/internal/light_curve_dataset.py index f320cd5c..b0f435c9 100644 --- a/src/qusi/internal/light_curve_dataset.py +++ b/src/qusi/internal/light_curve_dataset.py @@ -32,7 +32,7 @@ ) from qusi.internal.light_curve_transforms import ( from_light_curve_observation_to_fluxes_array_and_label_array, - pair_array_to_tensor, + pair_array_to_tensor, normalize_tensor_by_modified_z_score, make_uniform_length, ) if TYPE_CHECKING: @@ -337,56 +337,6 @@ def default_light_curve_post_injection_transform( return x -def normalize_tensor_by_modified_z_score(tensor: Tensor) -> Tensor: - """ - Normalizes a tensor by a modified z-score. That is, normalizes the values of the tensor based on the median - absolute deviation. - - :param tensor: The tensor to normalize. - :return: The normalized tensor. - """ - median = torch.median(tensor) - deviation_from_median = tensor - median - absolute_deviation_from_median = torch.abs(deviation_from_median) - median_absolute_deviation_from_median = torch.median(absolute_deviation_from_median) - if median_absolute_deviation_from_median != 0: - modified_z_score = ( - 0.6745 * deviation_from_median / median_absolute_deviation_from_median - ) - else: - modified_z_score = torch.zeros_like(tensor) - return modified_z_score - - -def make_uniform_length( - example: np.ndarray, length: int, *, randomize: bool = True -) -> np.ndarray: - """Makes the example a specific length, by clipping those too large and repeating those too small.""" - if len(example.shape) not in [1, 2]: # Only tested for 1D and 2D cases. - raise ValueError( - f"Light curve dimensions expected to be in [1, 2], but found {len(example.shape)}" - ) - if randomize: - example = randomly_roll_elements(example) - if example.shape[0] == length: - pass - elif example.shape[0] > length: - example = example[:length] - else: - elements_to_repeat = length - example.shape[0] - if len(example.shape) == 1: - example = np.pad(example, (0, elements_to_repeat), mode="wrap") - else: - example = np.pad(example, ((0, elements_to_repeat), (0, 0)), mode="wrap") - return example - - -def randomly_roll_elements(example: np.ndarray) -> np.ndarray: - """Randomly rolls the elements.""" - example = np.roll(example, np.random.randint(example.shape[0]), axis=0) - return example - - class OutOfBoundsInjectionHandlingMethod(Enum): """ An enum of approaches for handling cases where the injectable signal is shorter than the injectee signal. diff --git a/src/qusi/internal/light_curve_transforms.py b/src/qusi/internal/light_curve_transforms.py index dae294da..f5afa28c 100644 --- a/src/qusi/internal/light_curve_transforms.py +++ b/src/qusi/internal/light_curve_transforms.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np import numpy.typing as npt import torch @@ -38,3 +40,47 @@ def randomly_roll_elements(example: np.ndarray) -> np.ndarray: """Randomly rolls the elements.""" example = np.roll(example, np.random.randint(example.shape[0]), axis=0) return example + + +def normalize_tensor_by_modified_z_score(tensor: Tensor) -> Tensor: + """ + Normalizes a tensor by a modified z-score. That is, normalizes the values of the tensor based on the median + absolute deviation. + + :param tensor: The tensor to normalize. + :return: The normalized tensor. + """ + median = torch.median(tensor) + deviation_from_median = tensor - median + absolute_deviation_from_median = torch.abs(deviation_from_median) + median_absolute_deviation_from_median = torch.median(absolute_deviation_from_median) + if median_absolute_deviation_from_median != 0: + modified_z_score = ( + 0.6745 * deviation_from_median / median_absolute_deviation_from_median + ) + else: + modified_z_score = torch.zeros_like(tensor) + return modified_z_score + + +def make_uniform_length( + example: np.ndarray, length: int, *, randomize: bool = True +) -> np.ndarray: + """Makes the example a specific length, by clipping those too large and repeating those too small.""" + if len(example.shape) not in [1, 2]: # Only tested for 1D and 2D cases. + raise ValueError( + f"Light curve dimensions expected to be in [1, 2], but found {len(example.shape)}" + ) + if randomize: + example = randomly_roll_elements(example) + if example.shape[0] == length: + pass + elif example.shape[0] > length: + example = example[:length] + else: + elements_to_repeat = length - example.shape[0] + if len(example.shape) == 1: + example = np.pad(example, (0, elements_to_repeat), mode="wrap") + else: + example = np.pad(example, ((0, elements_to_repeat), (0, 0)), mode="wrap") + return example diff --git a/src/qusi/transform.py b/src/qusi/transform.py index 3f929510..ba3b8638 100644 --- a/src/qusi/transform.py +++ b/src/qusi/transform.py @@ -3,17 +3,18 @@ """ from qusi.internal.light_curve import randomly_roll_light_curve, remove_nan_flux_data_points_from_light_curve from qusi.internal.light_curve_dataset import default_light_curve_post_injection_transform, \ - default_light_curve_observation_post_injection_transform, make_uniform_length + default_light_curve_observation_post_injection_transform from qusi.internal.light_curve_observation import remove_nan_flux_data_points_from_light_curve_observation, \ randomly_roll_light_curve_observation from qusi.internal.light_curve_transforms import from_light_curve_observation_to_fluxes_array_and_label_array, \ - pair_array_to_tensor + pair_array_to_tensor, make_uniform_length, normalize_tensor_by_modified_z_score __all__ = [ 'default_light_curve_post_injection_transform', 'default_light_curve_observation_post_injection_transform', 'from_light_curve_observation_to_fluxes_array_and_label_array', 'make_uniform_length', + 'normalize_tensor_by_modified_z_score', 'pair_array_to_tensor', 'randomly_roll_light_curve', 'randomly_roll_light_curve_observation', From 63fcfc5c3f6f143c5a8e4a2313d5f6dfbaaed465 Mon Sep 17 00:00:00 2001 From: golmschenk Date: Wed, 15 May 2024 14:26:17 -0400 Subject: [PATCH 19/21] Remove randomization from make_uniform_length and leave that to a separate randomly roll call --- docs/source/tutorials/crafting_standard_datasets.md | 2 +- src/qusi/internal/light_curve_dataset.py | 4 ++-- src/qusi/internal/light_curve_transforms.py | 6 +----- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/docs/source/tutorials/crafting_standard_datasets.md b/docs/source/tutorials/crafting_standard_datasets.md index 89273d95..9c06cf41 100644 --- a/docs/source/tutorials/crafting_standard_datasets.md +++ b/docs/source/tutorials/crafting_standard_datasets.md @@ -39,7 +39,7 @@ def default_light_curve_observation_post_injection_transform(x: LightCurveObserv if randomize: x = randomly_roll_light_curve_observation(x) x = from_light_curve_observation_to_fluxes_array_and_label_array(x) - x = (make_uniform_length(x[0], length=length, randomize=randomize), x[1]) # Make the fluxes a uniform length. + x = (make_uniform_length(x[0], length=length), x[1]) x = pair_array_to_tensor(x) x = (normalize_tensor_by_modified_z_score(x[0]), x[1]) return x diff --git a/src/qusi/internal/light_curve_dataset.py b/src/qusi/internal/light_curve_dataset.py index b0f435c9..6aad6bc7 100644 --- a/src/qusi/internal/light_curve_dataset.py +++ b/src/qusi/internal/light_curve_dataset.py @@ -306,7 +306,7 @@ def default_light_curve_observation_post_injection_transform( if randomize: x = randomly_roll_light_curve_observation(x) x = from_light_curve_observation_to_fluxes_array_and_label_array(x) - x = (make_uniform_length(x[0], length=length, randomize=randomize), x[1]) # Make the fluxes a uniform length. + x = (make_uniform_length(x[0], length=length), x[1]) # Make the fluxes a uniform length. x = pair_array_to_tensor(x) x = (normalize_tensor_by_modified_z_score(x[0]), x[1]) return x @@ -331,7 +331,7 @@ def default_light_curve_post_injection_transform( if randomize: x = randomly_roll_light_curve(x) x = x.fluxes - x = make_uniform_length(x, length=length, randomize=randomize) + x = make_uniform_length(x, length=length) x = torch.tensor(x, dtype=torch.float32) x = normalize_tensor_by_modified_z_score(x) return x diff --git a/src/qusi/internal/light_curve_transforms.py b/src/qusi/internal/light_curve_transforms.py index f5afa28c..7abdeeb7 100644 --- a/src/qusi/internal/light_curve_transforms.py +++ b/src/qusi/internal/light_curve_transforms.py @@ -63,16 +63,12 @@ def normalize_tensor_by_modified_z_score(tensor: Tensor) -> Tensor: return modified_z_score -def make_uniform_length( - example: np.ndarray, length: int, *, randomize: bool = True -) -> np.ndarray: +def make_uniform_length(example: np.ndarray, length: int) -> np.ndarray: """Makes the example a specific length, by clipping those too large and repeating those too small.""" if len(example.shape) not in [1, 2]: # Only tested for 1D and 2D cases. raise ValueError( f"Light curve dimensions expected to be in [1, 2], but found {len(example.shape)}" ) - if randomize: - example = randomly_roll_elements(example) if example.shape[0] == length: pass elif example.shape[0] > length: From 176d9b0f6c40708895aa9c02cd656111bc75670a Mon Sep 17 00:00:00 2001 From: golmschenk Date: Wed, 15 May 2024 14:31:15 -0400 Subject: [PATCH 20/21] Fix tutorial code --- docs/source/tutorials/crafting_standard_datasets.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/tutorials/crafting_standard_datasets.md b/docs/source/tutorials/crafting_standard_datasets.md index 9c06cf41..99e8dcf5 100644 --- a/docs/source/tutorials/crafting_standard_datasets.md +++ b/docs/source/tutorials/crafting_standard_datasets.md @@ -20,8 +20,8 @@ Then, were we specify the construction of our dataset, we'll add an additional i ```python train_light_curve_dataset = LightCurveObservationDataset.new( light_curve_collections=[positive_train_light_curve_collection, - negative_train_light_curve_collection]) -post_injection_transform = partial(default_light_curve_post_injection_transform, length=4000) + negative_train_light_curve_collection], + post_injection_transform = partial(default_light_curve_post_injection_transform, length=4000) ) ``` From 117d200acab0ee860b3c05143dcc395ad81fb4d9 Mon Sep 17 00:00:00 2001 From: golmschenk Date: Fri, 17 May 2024 00:14:00 -0400 Subject: [PATCH 21/21] Switch the tutorial to use the example project repository --- ...sit_identification_dataset_construction.md | 10 ++++---- ...identification_with_prebuilt_components.md | 25 +++++++++++-------- src/qusi/internal/train_session.py | 1 + 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/docs/source/tutorials/basic_transit_identification_dataset_construction.md b/docs/source/tutorials/basic_transit_identification_dataset_construction.md index 4b9984cf..033799ad 100644 --- a/docs/source/tutorials/basic_transit_identification_dataset_construction.md +++ b/docs/source/tutorials/basic_transit_identification_dataset_construction.md @@ -21,7 +21,7 @@ def get_positive_train_paths(): This functions says to create a `Path` object for a directory at `data/spoc_transit_experiment/train/positives`. Then, it obtains all the files ending with the `.fits` extension. It puts that in a list and returns that list. In particular, `qusi` expects a function that takes no input parameters and outputs a list of `Path`s. -In our example code, we've split the data based on if it's train, validation, or test data and we've split the data based on if it's positive or negative data. And we provide a function for each of the 6 permutations of this, which is almost identical to what's above. You can see the above function and other 5 similar functions near the top of `examples/transit_dataset.py`. +In our example code, we've split the data based on if it's train, validation, or test data and we've split the data based on if it's positive or negative data. And we provide a function for each of the 6 permutations of this, which is almost identical to what's above. You can see the above function and other 5 similar functions near the top of `scripts/transit_dataset.py`. `qusi` is flexible in how the paths are provided, and this construction of having a separate function for each type of data is certainly not the only way of approaching this. Depending on your task, another option might serve better. In another tutorial, we will explore a few example alternatives. However, to better understand those alternatives, it's first useful to see the rest of this dataset construction. @@ -35,7 +35,7 @@ def load_times_and_fluxes_from_path(path): return light_curve.times, light_curve.fluxes ``` -This uses a builtin class in `qusi` that is designed for loading light curves from TESS mission FITS files. However, the important thing is that your function returns two comma separated values, which is a NumPy array of the times and a NumPy array of the fluxes of your light curve. And the function takes a single `Path` object as input. These `Path` objects will be one of the ones we returned from the functions in the previous section. But you can write any code you need to get from a `Path` to the two arrays that represent times and fluxes. For example, if your file is a simple CSV file, it would be easy to use Pandas to load the CSV file and extract the time column and the flux column as two arrays which are then returned at the end of the function. You will see the above function in `examples/transit_dataset.py`. +This uses a builtin class in `qusi` that is designed for loading light curves from TESS mission FITS files. However, the important thing is that your function returns two comma separated values, which is a NumPy array of the times and a NumPy array of the fluxes of your light curve. And the function takes a single `Path` object as input. These `Path` objects will be one of the ones we returned from the functions in the previous section. But you can write any code you need to get from a `Path` to the two arrays that represent times and fluxes. For example, if your file is a simple CSV file, it would be easy to use Pandas to load the CSV file and extract the time column and the flux column as two arrays which are then returned at the end of the function. You will see the above function in `scripts/transit_dataset.py`. ## Creating a function to provide a label for the data @@ -49,7 +49,7 @@ def negative_label_function(path): return 0 ``` -Note, `qusi` expects the label functions to take in a `Path` object as input, even if we don't end up using it. This is because, it allows for more flexible configurations. For example, in a different situation, the data might not be split into positive and negative directories, but instead, the label data might be contained within the user's data file itself. Also, in other cases, this label can also be something other than 0 and 1. The label is whatever the NN is attempting to predict for the input light curve. But for our binary classification case, 0 and 1 are what we want to use. Once again, you can see these functions in `examples/transit_dataset.py`. +Note, `qusi` expects the label functions to take in a `Path` object as input, even if we don't end up using it. This is because, it allows for more flexible configurations. For example, in a different situation, the data might not be split into positive and negative directories, but instead, the label data might be contained within the user's data file itself. Also, in other cases, this label can also be something other than 0 and 1. The label is whatever the NN is attempting to predict for the input light curve. But for our binary classification case, 0 and 1 are what we want to use. Once again, you can see these functions in `scripts/transit_dataset.py`. ## Creating a light curve collection @@ -59,7 +59,7 @@ Now we're going to join the various functions we've just defined into `LightCurv positive_train_light_curve_collection = LightCurveObservationCollection.new() ``` -This defines a collection of labeled light curves where `qusi` knows how to obtain the paths, how to load the times and fluxes of the light curves, and how to load the labels. This `LightCurveObservationCollection.new(...` function takes in the three pieces we just built earlier. Note that you pass in the functions themselves, not the output of the functions. So for the `get_paths_function` parameter, we pass `get_positive_train_paths`, not `get_positive_train_paths()` (notice the difference in parenthesis). `qusi` will call these functions internally. However, the above bit of code is not by itself in `examples/transit_dataset.py` as the rest of the code in this tutorial was. This is because `qusi` doesn't use this collection by itself. It uses it as part of a dataset. We will explain why there's this extra layer in a moment. +This defines a collection of labeled light curves where `qusi` knows how to obtain the paths, how to load the times and fluxes of the light curves, and how to load the labels. This `LightCurveObservationCollection.new(...` function takes in the three pieces we just built earlier. Note that you pass in the functions themselves, not the output of the functions. So for the `get_paths_function` parameter, we pass `get_positive_train_paths`, not `get_positive_train_paths()` (notice the difference in parenthesis). `qusi` will call these functions internally. However, the above bit of code is not by itself in `scripts/transit_dataset.py` as the rest of the code in this tutorial was. This is because `qusi` doesn't use this collection by itself. It uses it as part of a dataset. We will explain why there's this extra layer in a moment. ## Creating a dataset @@ -76,7 +76,7 @@ def get_transit_train_dataset(): This is the function which generates the training dataset we called in the {doc}`/tutorials/basic_transit_identification_with_prebuilt_components` tutorial. The parts of this function are as follows. First, we create the `positive_train_light_curve_collection`. This is exactly what we just saw in the previous section. Next, we create a `negative_train_light_curve_collection`. This is almost identical to its positive counterpart, except now we pass the `get_negative_train_paths` and `negative_label_function` instead of the positive versions. Then there is the `train_light_curve_dataset = LightCurveDataset.new(` line. This creates a `qusi` dataset built from these two collections. The reason the collections are separate is that `LightCurveDataset` has several mechanisms working under-the-hood. Notably for this case, `LightCurveDataset` will balance the two light curve collections. We know of a lot more light curves that don't have planet transits in them than we do light curves that do have planet transits. In the real world case, it's thousands of times more at least. But for a NN, it's usually useful to during the training process to show equal amounts of the positives and negatives. `LightCurveDataset` will do this for us. You may have also noticed that we passed these collections in as the `standard_light_curve_collections` parameter. `LightCurveDataset` also allows for passing different types of collections. Notably, collections can be passed such that light curves from one collection will be injected into another. This is useful for injecting synthetic signals into real telescope data. However, we'll save the injection options for another tutorial. -You can see the above `get_transit_train_dataset` dataset creation function in the `examples/transit_dataset.py` file. The only part of that file we haven't yet looked at in detail is the `get_transit_validation_dataset` and `get_transit_finite_test_dataset` functions. However, these are nearly identical to the above `get_transit_train_dataset` expect using the validation and test path obtaining functions above instead of the train ones. +You can see the above `get_transit_train_dataset` dataset creation function in the `scripts/transit_dataset.py` file. The only part of that file we haven't yet looked at in detail is the `get_transit_validation_dataset` and `get_transit_finite_test_dataset` functions. However, these are nearly identical to the above `get_transit_train_dataset` expect using the validation and test path obtaining functions above instead of the train ones. ## Adjusting this for your own binary classification task diff --git a/docs/source/tutorials/basic_transit_identification_with_prebuilt_components.md b/docs/source/tutorials/basic_transit_identification_with_prebuilt_components.md index a4975db0..9ebf8472 100644 --- a/docs/source/tutorials/basic_transit_identification_with_prebuilt_components.md +++ b/docs/source/tutorials/basic_transit_identification_with_prebuilt_components.md @@ -1,17 +1,22 @@ # Basic transit identification with prebuilt components -This tutorial will get you up and running with a neural network (NN) that can identify transiting exoplanets in data from the Transiting Exoplanet Survey Satellite (TESS). Many of the components used in this example will be prebuilt bits of code that we'll import from the package's example code. However, in later tutorials, we'll walkthrough how you would build each of these pieces yourself and how you would modify it for whatever your use case is. +This tutorial will get you up and running with a neural network (NN) that can identify transiting exoplanets in data from the Transiting Exoplanet Survey Satellite (TESS). Many of the components used in this example will be prebuilt bits of code that we'll import from the package's example code. However, in later tutorials, we'll walk through how you would build each of these pieces yourself and how you would modify it for whatever your use case is. ## Getting the example code -First, create a directory to hold the project named `qusi_example_project`, or some other suitable name. Then get the example scripts from the `qusi` repository. You can download just that directory by clicking [here](https://download-directory.github.io/?url=https%3A%2F%2Fgithub.com%2Fgolmschenk%2Fqusi%2Ftree%2Fmain%2Fexamples). Move this `examples` directory into your project directory so that you have `qusi_example_project/examples`. The remainder of the commands will assume you are running code from the project directory, unless otherwise stated. +First, we'll download some example code and enter that project's directory. To do this, run +```sh +git clone https://github.com/golmschenk/qusi_example_transit_binary_classification.git +cd qusi_example_transit_binary_classification +``` +The remainder of the commands will assume you are running code from the project directory, unless otherwise stated. ## Downloading the dataset -The next thing we'll do is download a dataset of light curves that include cases both with and without transiting planets. To do this, run the example script at `examples/download_spoc_transit_light_curves`. For now, don't worry about how each part of the code works. You can run the script with +The next thing we'll do is download a dataset of light curves that include cases both with and without transiting planets. To do this, run the example script at `scripts/download_spoc_transit_light_curves`. For now, don't worry about how each part of the code works. You can run the script with ```sh -python examples/download_spoc_transit_light_curves.py +python scripts/download_spoc_transit_light_curves.py ``` The main thing to know is that this will create a `data` directory within the project directory and within that will be a `spoc_transit_experiment` directory, referring to the data for the experiment of finding transiting planets within the TESS SPOC data. This will further contain 3 directories. One for train data, one for validation data, and one for test data. Within each of those, it will create a `positive` directory, that will hold the light curves with transits, and a `negative` directory, that will hold the light curves without transits. So the project directory tree now looks like @@ -31,10 +36,10 @@ data examples ``` -Each of these `positive` and `negative` data directories will now contain a set of light curves. The reason why the code in this script is not very important for you to know, is that it's mostly irrelevant for future uses. When you're working on your own problem, you'll obtain your data some other way. And `qusi` is flexible about the data structure, so this directory structure is not required. It's just one way to structure the data. Note, this is a relatively small dataset to make sure it doesn't take very long to get up and running. To get a better result, you'd want to download all known transiting light curves and a much larger collection non-transiting light curves. To quickly visualize one of these light curves, you can use the script at `examples/transit_light_curve_visualization.py`. Due to the available light curves on MAST being updated constantly, the random selection of light curves you downloaded might not include the light curve noted in this example file. Be sure to open the `examples/transit_light_curve_visualization.py` file and update the path to one of the light curves you downloaded. To see a transit case, be sure to select one from one of the `positive` directories. Then run +Each of these `positive` and `negative` data directories will now contain a set of light curves. The reason why the code in this script is not very important for you to know, is that it's mostly irrelevant for future uses. When you're working on your own problem, you'll obtain your data some other way. And `qusi` is flexible about the data structure, so this directory structure is not required. It's just one way to structure the data. Note, this is a relatively small dataset to make sure it doesn't take very long to get up and running. To get a better result, you'd want to download all known transiting light curves and a much larger collection non-transiting light curves. To quickly visualize one of these light curves, you can use the script at `scripts/transit_light_curve_visualization.py`. Due to the available light curves on MAST being updated constantly, the random selection of light curves you downloaded might not include the light curve noted in this example file. Be sure to open the `scripts/transit_light_curve_visualization.py` file and update the path to one of the light curves you downloaded. To see a transit case, be sure to select one from one of the `positive` directories. Then run ```sh -python examples/transit_light_curve_visualization.py +python scripts/transit_light_curve_visualization.py ``` You should see something like @@ -62,7 +67,7 @@ This will only log runs locally. If you choose the offline route, at some point, ## Train the network -Next, we'll look at the `examples/transit_train.py` file. In this script is a `main` function which will train our neural network on our data. The training script has 3 main components: +Next, we'll look at the `scripts/transit_train.py` file. In this script is a `main` function which will train our neural network on our data. The training script has 3 main components: 1. Code to prepare our datasets. 2. Code to prepare the neural network model. @@ -71,7 +76,7 @@ Next, we'll look at the `examples/transit_train.py` file. In this script is a `m Since `qusi` provides both models and and training loop code, the only one of these components that every user will be expected to deal with is preparing the dataset, since you'll eventually want to have `qusi` tackle the task you're interested in which will require you're own data. And the `qusi` dataset component will help make your data more suitable for training a neural network. However, we're going to save how to set up your own dataset (and how these example datasets are created) for the next tutorial. For now, we'll just use the example datasets as is. So, in the example script, you will see the first couple of lines of the `main` function call other functions that produce an example train and validation dataset for us. Then we choose one of the neural network models `qusi` provides (in this case the `Hadryss` model). Then finally, we start the training session. To run this training, simply run the script with: ```sh -python examples/transit_train.py +python scripts/transit_train.py ``` You should see some output showing basic training statistics from the terminal as it runs through the training loop. It will run for as many train cycles as were specified in the script. On every completed cycle, `qusi` will save the latest version of the fitted model to `sessions//latest_model`. @@ -80,10 +85,10 @@ You can also go to your Wandb project to see the metrics over the course of the ## Test the fitted model -A "fitted model" is a model which has been trained, or fitted, on some training data. Next, we'll take the fitted model we produced during training, and test it on data it didn't see during the training process. This is what happens in the `examples/transit_finite_dataset_test.py` script. The `main` function will look semi-similar to from the training script. Again, we'll defer how the dataset is produced until the next tutorial. Then we create the model as we did before, but this time we load the fitted parameters of the model from the saved file. Here, you will need to update the script to point to your saved model produced in the last section. Then we can run the script with +A "fitted model" is a model which has been trained, or fitted, on some training data. Next, we'll take the fitted model we produced during training, and test it on data it didn't see during the training process. This is what happens in the `scripts/transit_finite_dataset_test.py` script. The `main` function will look semi-similar to from the training script. Again, we'll defer how the dataset is produced until the next tutorial. Then we create the model as we did before, but this time we load the fitted parameters of the model from the saved file. Here, you will need to update the script to point to your saved model produced in the last section. Then we can run the script with ```sh -python examples/transit_finite_dataset_test.py +python scripts/transit_finite_dataset_test.py ``` This will run the network on the test data, producing the metrics that are requested in the file. diff --git a/src/qusi/internal/train_session.py b/src/qusi/internal/train_session.py index b821f15e..50067025 100644 --- a/src/qusi/internal/train_session.py +++ b/src/qusi/internal/train_session.py @@ -102,6 +102,7 @@ def train_session( for metric_function in metric_functions ] for _cycle_index in range(hyperparameter_configuration.cycles): + logger.info(f'Cycle {_cycle_index}') train_phase( dataloader=train_dataloader, model=model,