diff --git a/docs/source/conf.py b/docs/source/conf.py
index 8f582c59..f264459e 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -26,6 +26,10 @@ templates_path = ["_templates"]
 exclude_patterns = []
 source_suffix = [".rst", ".md"]
 
+autodoc_class_signature = 'separated'
+autodoc_default_options = {
+    'special-members': None,
+}
 
 # -- Options for HTML output -------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
diff --git a/docs/source/reference_index.md b/docs/source/reference_index.md
index 0d342294..162c5229 100644
--- a/docs/source/reference_index.md
+++ b/docs/source/reference_index.md
@@ -1,6 +1,23 @@
 # Reference
 
 ```{eval-rst}
-.. automodule:: qusi.light_curve
-    :members:
+.. autoclass:: qusi.data.LightCurve
+    :members: new
+.. autoclass:: qusi.data.LightCurveCollection
+    :members: new
+.. autoclass:: qusi.data.LightCurveDataset
+    :members: new
+.. autoclass:: qusi.data.LightCurveObservationCollection
+    :members: new
+.. autoclass:: qusi.data.FiniteStandardLightCurveDataset
+    :members: new
+.. autoclass:: qusi.data.FiniteStandardLightCurveObservationDataset
+    :members: new
+.. autoclass:: qusi.model.Hadryss
+    :members: new
+.. autofunction:: qusi.session.get_device
+.. autofunction:: qusi.session.infer_session
+.. autofunction:: qusi.session.train_session
+.. autoclass:: qusi.session.TrainHyperparameterConfiguration
+    :members: new
 ```
diff --git a/docs/source/tutorials/basic_transit_identification_dataset_construction.md b/docs/source/tutorials/basic_transit_identification_dataset_construction.md
index 1a61e93d..033799ad 100644
--- a/docs/source/tutorials/basic_transit_identification_dataset_construction.md
+++ b/docs/source/tutorials/basic_transit_identification_dataset_construction.md
@@ -21,7 +21,7 @@ def get_positive_train_paths():
 ```
 
 This function says to create a `Path` object for a directory at `data/spoc_transit_experiment/train/positives`. Then, it obtains all the files ending with the `.fits` extension. It puts that in a list and returns that list. In particular, `qusi` expects a function that takes no input parameters and outputs a list of `Path`s.
 
-In our example code, we've split the data based on if it's train, validation, or test data and we've split the data based on if it's positive or negative data. And we provide a function for each of the 6 permutations of this, which is almost identical to what's above. You can see the above function and other 5 similar functions near the top of `examples/transit_dataset.py`.
+In our example code, we've split the data based on whether it's train, validation, or test data, and based on whether it's positive or negative data. We provide a function for each of the 6 permutations of this, each almost identical to what's above. You can see the above function and the 5 other similar functions near the top of `scripts/transit_dataset.py`.
 
 `qusi` is flexible in how the paths are provided, and this construction of having a separate function for each type of data is certainly not the only way of approaching this. Depending on your task, another option might serve better. In another tutorial, we will explore a few example alternatives. However, to better understand those alternatives, it's first useful to see the rest of this dataset construction.
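+For example, here is the negative train counterpart as it appears in the example code; the remaining four functions simply swap in the validation and test directory names in the same way:
+
+```python
+from pathlib import Path
+
+
+def get_negative_train_paths():
+    # Identical to the positive version, but pointed at the negatives directory.
+    return list(Path('data/spoc_transit_experiment/train/negatives').glob('*.fits'))
+```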
@@ -35,7 +35,7 @@ def load_times_and_fluxes_from_path(path):
     return light_curve.times, light_curve.fluxes
 ```
 
-This uses a builtin class in `qusi` that is designed for loading light curves from TESS mission FITS files. However, the important thing is that your function returns two comma separated values, which is a NumPy array of the times and a NumPy array of the fluxes of your light curve. And the function takes a single `Path` object as input. These `Path` objects will be one of the ones we returned from the functions in the previous section. But you can write any code you need to get from a `Path` to the two arrays that represent times and fluxes. For example, if your file is a simple CSV file, it would be easy to use Pandas to load the CSV file and extract the time column and the flux column as two arrays which are then returned at the end of the function. You will see the above function in `examples/transit_dataset.py`.
+This uses a builtin class in `qusi` that is designed for loading light curves from TESS mission FITS files. However, the important thing is that your function returns two values: a NumPy array of the times and a NumPy array of the fluxes of your light curve. The function takes a single `Path` object as input. These `Path` objects will be the ones returned by the functions in the previous section. But you can write any code you need to get from a `Path` to the two arrays that represent times and fluxes. For example, if your file is a simple CSV file, it would be easy to use Pandas to load the CSV file and extract the time column and the flux column as two arrays which are then returned at the end of the function (a sketch of this appears below). You will see the above function in `scripts/transit_dataset.py`.
 
 ## Creating a function to provide a label for the data
@@ -49,20 +49,17 @@ def negative_label_function(path):
     return 0
 ```
 
-Note, `qusi` expects the label functions to take in a `Path` object as input, even if we don't end up using it. This is because, it allows for more flexible configurations. For example, in a different situation, the data might not be split into positive and negative directories, but instead, the label data might be contained within the user's data file itself. Also, in other cases, this label can also be something other than 0 and 1. The label is whatever the NN is attempting to predict for the input light curve. But for our binary classification case, 0 and 1 are what we want to use. Once again, you can see these functions in `examples/transit_dataset.py`.
+Note, `qusi` expects the label functions to take in a `Path` object as input, even if we don't end up using it. This is because it allows for more flexible configurations. For example, in a different situation, the data might not be split into positive and negative directories; instead, the label data might be contained within the user's data file itself. Also, in other cases, this label can be something other than 0 and 1. The label is whatever the NN is attempting to predict for the input light curve. But for our binary classification case, 0 and 1 are what we want to use. Once again, you can see these functions in `scripts/transit_dataset.py`.
 
 ## Creating a light curve collection
 
 Now we're going to join the various functions we've just defined into `LightCurveObservationCollection`s.
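+Before we do, one brief aside: the CSV-based loader alternative mentioned in the previous section might look something like the following minimal sketch. The `time` and `flux` column names here are hypothetical; adjust them to match your own files:
+
+```python
+import pandas as pd
+
+
+def load_times_and_fluxes_from_csv_path(path):
+    # Hypothetical CSV layout: one `time` column and one `flux` column.
+    light_curve_data_frame = pd.read_csv(path)
+    return light_curve_data_frame['time'].values, light_curve_data_frame['flux'].values
+```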
 For the case of positive train light curves, this looks like:
 
 ```python
 positive_train_light_curve_collection = LightCurveObservationCollection.new(
     get_paths_function=get_positive_train_paths,
     load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path,
     load_label_from_path_function=positive_label_function)
 ```
 
-This defines a collection of labeled light curves where `qusi` knows how to obtain the paths, how to load the times and fluxes of the light curves, and how to load the labels. This `LightCurveObservationCollection.new(...` function takes in the three pieces we just built earlier. Note that you pass in the functions themselves, not the output of the functions. So for the `get_paths_function` parameter, we pass `get_positive_train_paths`, not `get_positive_train_paths()` (notice the difference in parenthesis). `qusi` will call these functions internally. However, the above bit of code is not by itself in `examples/transit_dataset.py` as the rest of the code in this tutorial was. This is because `qusi` doesn't use this collection by itself. It uses it as part of a dataset. We will explain why there's this extra layer in a moment.
+This defines a collection of labeled light curves where `qusi` knows how to obtain the paths, how to load the times and fluxes of the light curves, and how to load the labels. This `LightCurveObservationCollection.new(...)` function takes in the three pieces we just built earlier. Note that you pass in the functions themselves, not the output of the functions. So for the `get_paths_function` parameter, we pass `get_positive_train_paths`, not `get_positive_train_paths()` (notice the difference in parentheses). `qusi` will call these functions internally. However, the above bit of code does not appear by itself in `scripts/transit_dataset.py` the way the rest of the code in this tutorial does. This is because `qusi` doesn't use this collection by itself. It uses it as part of a dataset. We will explain why there's this extra layer in a moment.
 
 ## Creating a dataset
 
@@ -70,23 +67,16 @@ Finally, we build the dataset `qusi` uses to train the network. First, we'll tak
 ```python
 def get_transit_train_dataset():
-    positive_train_light_curve_collection = LightCurveObservationCollection.new(
-        get_paths_function=get_positive_train_paths,
-        load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path,
-        load_label_from_path_function=positive_label_function)
-    negative_train_light_curve_collection = LightCurveObservationCollection.new(
-        get_paths_function=get_negative_train_paths,
-        load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path,
-        load_label_from_path_function=negative_label_function)
-    train_light_curve_dataset = LightCurveDataset.new(
-        standard_light_curve_collections=[positive_train_light_curve_collection,
-                                          negative_train_light_curve_collection])
+    positive_train_light_curve_collection = LightCurveObservationCollection.new(
+        get_paths_function=get_positive_train_paths,
+        load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path,
+        load_label_from_path_function=positive_label_function)
+    negative_train_light_curve_collection = LightCurveObservationCollection.new(
+        get_paths_function=get_negative_train_paths,
+        load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path,
+        load_label_from_path_function=negative_label_function)
+    train_light_curve_dataset = LightCurveDataset.new(
+        light_curve_collections=[positive_train_light_curve_collection,
+                                 negative_train_light_curve_collection])
     return train_light_curve_dataset
 ```
 
 This is the function which generates the training dataset we called in the {doc}`/tutorials/basic_transit_identification_with_prebuilt_components` tutorial. The parts of this function are as follows.
 First, we create the `positive_train_light_curve_collection`. This is exactly what we just saw in the previous section. Next, we create a `negative_train_light_curve_collection`. This is almost identical to its positive counterpart, except now we pass the `get_negative_train_paths` and `negative_label_function` instead of the positive versions. Then there is the `train_light_curve_dataset = LightCurveDataset.new(` line. This creates a `qusi` dataset built from these two collections. The reason the collections are separate is that `LightCurveDataset` has several mechanisms working under-the-hood. Notably for this case, `LightCurveDataset` will balance the two light curve collections. We know of far more light curves without planet transits than light curves with them. In the real-world case, it's thousands of times more at least. But for a NN, it's usually useful during the training process to show equal amounts of the positives and negatives. `LightCurveDataset` will do this for us. You may have also noticed that we passed these collections in as the `light_curve_collections` parameter. `LightCurveDataset` also allows for passing different types of collections. Notably, collections can be passed such that light curves from one collection will be injected into another. This is useful for injecting synthetic signals into real telescope data. However, we'll save the injection options for another tutorial.
 
-You can see the above `get_transit_train_dataset` dataset creation function in the `examples/transit_dataset.py` file. The only part of that file we haven't yet looked at in detail is the `get_transit_validation_dataset` and `get_transit_finite_test_dataset` functions. However, these are nearly identical to the above `get_transit_train_dataset` expect using the validation and test path obtaining functions above instead of the train ones.
+You can see the above `get_transit_train_dataset` dataset creation function in the `scripts/transit_dataset.py` file. The only part of that file we haven't yet looked at in detail is the `get_transit_validation_dataset` and `get_transit_finite_test_dataset` functions. However, these are nearly identical to the above `get_transit_train_dataset`, except using the validation and test path-obtaining functions instead of the train ones.
 
 ## Adjusting this for your own binary classification task
 
diff --git a/docs/source/tutorials/basic_transit_identification_with_prebuilt_components.md b/docs/source/tutorials/basic_transit_identification_with_prebuilt_components.md
index a4975db0..9ebf8472 100644
--- a/docs/source/tutorials/basic_transit_identification_with_prebuilt_components.md
+++ b/docs/source/tutorials/basic_transit_identification_with_prebuilt_components.md
@@ -1,17 +1,22 @@
 # Basic transit identification with prebuilt components
 
-This tutorial will get you up and running with a neural network (NN) that can identify transiting exoplanets in data from the Transiting Exoplanet Survey Satellite (TESS). Many of the components used in this example will be prebuilt bits of code that we'll import from the package's example code. However, in later tutorials, we'll walkthrough how you would build each of these pieces yourself and how you would modify it for whatever your use case is.
+This tutorial will get you up and running with a neural network (NN) that can identify transiting exoplanets in data from the Transiting Exoplanet Survey Satellite (TESS).
+Many of the components used in this example will be prebuilt bits of code that we'll import from the package's example code. However, in later tutorials, we'll walk through how you would build each of these pieces yourself and how you would modify them for your own use case.
 
 ## Getting the example code
 
-First, create a directory to hold the project named `qusi_example_project`, or some other suitable name. Then get the example scripts from the `qusi` repository. You can download just that directory by clicking [here](https://download-directory.github.io/?url=https%3A%2F%2Fgithub.com%2Fgolmschenk%2Fqusi%2Ftree%2Fmain%2Fexamples). Move this `examples` directory into your project directory so that you have `qusi_example_project/examples`. The remainder of the commands will assume you are running code from the project directory, unless otherwise stated.
+First, we'll download some example code and enter that project's directory. To do this, run
+```sh
+git clone https://github.com/golmschenk/qusi_example_transit_binary_classification.git
+cd qusi_example_transit_binary_classification
+```
+The remainder of the commands will assume you are running code from the project directory, unless otherwise stated.
 
 ## Downloading the dataset
 
-The next thing we'll do is download a dataset of light curves that include cases both with and without transiting planets. To do this, run the example script at `examples/download_spoc_transit_light_curves`. For now, don't worry about how each part of the code works. You can run the script with
+The next thing we'll do is download a dataset of light curves that includes cases both with and without transiting planets. To do this, run the example script at `scripts/download_spoc_transit_light_curves.py`. For now, don't worry about how each part of the code works. You can run the script with
 
 ```sh
-python examples/download_spoc_transit_light_curves.py
+python scripts/download_spoc_transit_light_curves.py
 ```
 
 The main thing to know is that this will create a `data` directory within the project directory and within that will be a `spoc_transit_experiment` directory, referring to the data for the experiment of finding transiting planets within the TESS SPOC data. This will further contain 3 directories: one for train data, one for validation data, and one for test data. Within each of those, it will create a `positive` directory, which will hold the light curves with transits, and a `negative` directory, which will hold the light curves without transits. So the project directory tree now looks like
@@ -31,10 +36,10 @@ data
-examples
+scripts
 ```
 
-Each of these `positive` and `negative` data directories will now contain a set of light curves. The reason why the code in this script is not very important for you to know, is that it's mostly irrelevant for future uses. When you're working on your own problem, you'll obtain your data some other way. And `qusi` is flexible about the data structure, so this directory structure is not required. It's just one way to structure the data. Note, this is a relatively small dataset to make sure it doesn't take very long to get up and running. To get a better result, you'd want to download all known transiting light curves and a much larger collection non-transiting light curves. To quickly visualize one of these light curves, you can use the script at `examples/transit_light_curve_visualization.py`. Due to the available light curves on MAST being updated constantly, the random selection of light curves you downloaded might not include the light curve noted in this example file. Be sure to open the `examples/transit_light_curve_visualization.py` file and update the path to one of the light curves you downloaded. To see a transit case, be sure to select one from one of the `positive` directories. Then run
+Each of these `positive` and `negative` data directories will now contain a set of light curves. The reason the code in this script is not very important for you to know is that it's mostly irrelevant for future uses. When you're working on your own problem, you'll obtain your data some other way. And `qusi` is flexible about the data structure, so this directory structure is not required. It's just one way to structure the data. Note, this is a relatively small dataset to make sure it doesn't take very long to get up and running. To get a better result, you'd want to download all known transiting light curves and a much larger collection of non-transiting light curves. To quickly visualize one of these light curves, you can use the script at `scripts/transit_light_curve_visualization.py`. Because the light curves available on MAST are constantly being updated, the random selection of light curves you downloaded might not include the light curve noted in this example file. Be sure to open the `scripts/transit_light_curve_visualization.py` file and update the path to one of the light curves you downloaded. To see a transit case, be sure to select one from one of the `positive` directories. Then run
 
 ```sh
-python examples/transit_light_curve_visualization.py
+python scripts/transit_light_curve_visualization.py
 ```
 
 You should see something like
@@ -62,7 +67,7 @@ This will only log runs locally. If you choose the offline route, at some point,
 
 ## Train the network
 
-Next, we'll look at the `examples/transit_train.py` file. In this script is a `main` function which will train our neural network on our data. The training script has 3 main components:
+Next, we'll look at the `scripts/transit_train.py` file. In this script is a `main` function which will train our neural network on our data. The training script has 3 main components:
 
 1. Code to prepare our datasets.
 2. Code to prepare the neural network model.
@@ -71,7 +76,7 @@ Next, we'll look at the `examples/transit_train.py` file. In this script is a `m
 
 Since `qusi` provides both models and training loop code, the only one of these components that every user will be expected to deal with is preparing the dataset, since you'll eventually want to have `qusi` tackle the task you're interested in, which will require your own data. And the `qusi` dataset component will help make your data more suitable for training a neural network. However, we're going to save how to set up your own dataset (and how these example datasets are created) for the next tutorial. For now, we'll just use the example datasets as is. So, in the example script, you will see the first couple of lines of the `main` function call other functions that produce an example train and validation dataset for us. Then we choose one of the neural network models `qusi` provides (in this case the `Hadryss` model). Then finally, we start the training session.
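+Putting those three components together, the `main` function of the example training script looks roughly like the following sketch (the import paths here assume `qusi`'s public `qusi.model` and `qusi.session` modules, as listed in the reference documentation):
+
+```python
+from qusi.model import Hadryss
+from qusi.session import TrainHyperparameterConfiguration, train_session
+
+from transit_dataset import get_transit_train_dataset, get_transit_validation_dataset
+
+
+def main():
+    # Component 1: prepare the example train and validation datasets.
+    train_light_curve_dataset = get_transit_train_dataset()
+    validation_light_curve_dataset = get_transit_validation_dataset()
+    # Component 2: prepare the neural network model.
+    model = Hadryss.new()
+    # Component 3: run the training loop.
+    train_hyperparameter_configuration = TrainHyperparameterConfiguration.new(
+        batch_size=100, cycles=20, train_steps_per_cycle=100, validation_steps_per_cycle=10)
+    train_session(train_datasets=[train_light_curve_dataset],
+                  validation_datasets=[validation_light_curve_dataset],
+                  model=model,
+                  hyperparameter_configuration=train_hyperparameter_configuration)
+
+
+if __name__ == '__main__':
+    main()
+```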
 To run this training, simply run the script with:
 
 ```sh
-python examples/transit_train.py
+python scripts/transit_train.py
 ```
 
 You should see some output showing basic training statistics from the terminal as it runs through the training loop. It will run for as many train cycles as were specified in the script. On every completed cycle, `qusi` will save the latest version of the fitted model to `sessions//latest_model`.
 
@@ -80,10 +85,10 @@ You can also go to your Wandb project to see the metrics over the course of the
 
 ## Test the fitted model
 
-A "fitted model" is a model which has been trained, or fitted, on some training data. Next, we'll take the fitted model we produced during training, and test it on data it didn't see during the training process. This is what happens in the `examples/transit_finite_dataset_test.py` script. The `main` function will look semi-similar to from the training script. Again, we'll defer how the dataset is produced until the next tutorial. Then we create the model as we did before, but this time we load the fitted parameters of the model from the saved file. Here, you will need to update the script to point to your saved model produced in the last section. Then we can run the script with
+A "fitted model" is a model which has been trained, or fitted, on some training data. Next, we'll take the fitted model we produced during training, and test it on data it didn't see during the training process. This is what happens in the `scripts/transit_finite_dataset_test.py` script. The `main` function will look semi-similar to that of the training script. Again, we'll defer how the dataset is produced until the next tutorial. Then we create the model as we did before, but this time we load the fitted parameters of the model from the saved file. Here, you will need to update the script to point to your saved model produced in the last section. Then we can run the script with
 
 ```sh
-python examples/transit_finite_dataset_test.py
+python scripts/transit_finite_dataset_test.py
 ```
 
 This will run the network on the test data, producing the metrics that are requested in the file.
diff --git a/docs/source/tutorials/crafting_standard_datasets.md b/docs/source/tutorials/crafting_standard_datasets.md
index 2c1097e7..99e8dcf5 100644
--- a/docs/source/tutorials/crafting_standard_datasets.md
+++ b/docs/source/tutorials/crafting_standard_datasets.md
@@ -12,38 +12,39 @@ However, the uniform length is set to a specific default value. A good choice fo
 
 ```python
 from functools import partial
 
-from qusi.light_curve_dataset import default_light_curve_post_injection_transform
+from qusi.transform import default_light_curve_observation_post_injection_transform
 ```
 
 Then, where we specify the construction of our dataset, we'll add an additional input parameter. So taking what we had in the previous tutorial, we can now change the dataset creation statement to:
 
 ```python
-train_light_curve_dataset = LightCurveDataset.new(
-    standard_light_curve_collections=[positive_train_light_curve_collection,
-                                      negative_train_light_curve_collection]
-    post_injection_transform=partial(default_light_curve_post_injection_transform, length=4000)
-)
+train_light_curve_dataset = LightCurveDataset.new(
+    light_curve_collections=[positive_train_light_curve_collection,
+                             negative_train_light_curve_collection],
+    post_injection_transform=partial(default_light_curve_observation_post_injection_transform, length=4000)
+)
 ```
 
-Let's clarify what's happening here. The `LightCurveDataset.new()` constructor takes as input a parameter called `post_injection_transform`. This function will be applied to our light curves before they get handed to the NN. `default_light_curve_post_injection_transform` is the default set of preprocessing transforms `qusi` uses. We'll look at these transforms in more detail in the next section. `partial` is a Python builtin function, that takes another function as input, along with a parameter of that function, and creates a new function with that parameter prefilled. So `partial(default_light_curve_post_injection_transform, length=4000)` is taking our default transforms, setting the uniforming lengthening step to 4000, then giving us back the updated function, which we can then give to the dataset. The advantage to this approach is that `post_injection_transform` is completely flexible, as we'll explore more in the next section.
+Let's clarify what's happening here. The `LightCurveDataset.new()` constructor takes as input a parameter called `post_injection_transform`. This function will be applied to our light curves before they get handed to the NN. `default_light_curve_observation_post_injection_transform` is the default set of preprocessing transforms `qusi` uses. We'll look at these transforms in more detail in the next section. `partial` is a Python builtin function that takes another function as input, along with a parameter of that function, and creates a new function with that parameter prefilled. So `partial(default_light_curve_observation_post_injection_transform, length=4000)` is taking our default transforms, setting the uniform lengthening step to 4000, then giving us back the updated function, which we can then give to the dataset. The advantage to this approach is that `post_injection_transform` is completely flexible, as we'll explore more in the next section.
 
 Before we run the updated code, we also need to use a NN model which expects our new input size. Luckily, `qusi` has NN architectures that automatically resize their components for a given input size. So the only other change from the existing code is to change `Hadryss.new()` to `Hadryss.new(input_length=4000)`.
 
 ## Modifying the preprocessing
 
-In the previous section, we only changed the length of that the uniform lengthening preprocessing transform is using. However, we still used all the remaining default preprocessing steps that are contained in `default_light_curve_post_injection_transform`. Let's take a look at what the default one does. It looks like:
+In the previous section, we only changed the length that the uniform lengthening preprocessing transform uses. However, we still used all the remaining default preprocessing steps that are contained in `default_light_curve_observation_post_injection_transform`. Let's take a look at what the default one does. It looks something like:
 
 ```python
 def default_light_curve_observation_post_injection_transform(x: LightCurveObservation, length: int, randomize: bool = True) -> (Tensor, Tensor):
     x = remove_nan_flux_data_points_from_light_curve_observation(x)
-    x = randomly_roll_light_curve_observation(x)
+    if randomize:
+        x = randomly_roll_light_curve_observation(x)
     x = from_light_curve_observation_to_fluxes_array_and_label_array(x)
-    x = make_fluxes_and_label_array_uniform_length(x, length=length)
+    x = (make_uniform_length(x[0], length=length), x[1])
     x = pair_array_to_tensor(x)
     x = (normalize_tensor_by_modified_z_score(x[0]), x[1])
     return x
 ```
 
-It's a function that takes in a `LightCurveObservation` and spits out two `Tensor`s, one for the fluxes and one for the label to predict. Most of the data transform functions within have names that are largely descriptive, but we'll walk through them anyway. `remove_nan_flux_data_points_from_light_curve_observation` removes time steps from a `LightCurveObservation` where the flux is NaN. `randomly_roll_light_curve_observation` randomly rolls the light curve (a random cut is made and the two segments' order is swapped). `from_light_curve_observation_to_fluxes_array_and_label_array` extracts two NumPy arrays from a `LightCurveObservation`, one for the fluxes and one from the label (which in this case will be an array with a single value). `make_fluxes_and_label_array_uniform_length` performs the uniform lengthening we discussed in the previous section. `pair_array_to_tensor` converts the pair of NumPy arrays to a pair of PyTorch tensors (PyTorch's equivalent of an array). `normalize_tensor_by_modified_z_score` normalizes a tensor via based on the median absolute deviation. Notice, this is only applied to the flux tensor, not the label tensor.
+It's a function that takes in a `LightCurveObservation` and spits out two `Tensor`s, one for the fluxes and one for the label to predict. Most of the data transform functions within have names that are largely descriptive, but we'll walk through them anyway. `remove_nan_flux_data_points_from_light_curve_observation` removes time steps from a `LightCurveObservation` where the flux is NaN. `randomly_roll_light_curve_observation` randomly rolls the light curve (a random cut is made and the two segments' order is swapped). `from_light_curve_observation_to_fluxes_array_and_label_array` extracts two NumPy arrays from a `LightCurveObservation`, one for the fluxes and one for the label (which in this case will be an array with a single value). `make_uniform_length` performs the uniform lengthening on the fluxes as we discussed in the previous section. `pair_array_to_tensor` converts the pair of NumPy arrays to a pair of PyTorch tensors (PyTorch's equivalent of an array). `normalize_tensor_by_modified_z_score` normalizes a tensor based on the median absolute deviation. Notice this is only applied to the flux tensor, not the label tensor. The `randomize` parameter enables or disables the transforms that involve randomization. During training, randomization should be on to make sure we get variation in training observations. During evaluation and inference, it should be off to get repeatable results. In our previous example, to keep the code simple, we did not disable randomization for the validation dataset.
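+To make the `randomize` flag concrete, a validation dataset with randomization disabled could be constructed like the following sketch, reusing the validation collections from the earlier tutorial and the `length` chosen above:
+
+```python
+validation_light_curve_dataset = LightCurveDataset.new(
+    light_curve_collections=[positive_validation_light_curve_collection,
+                             negative_validation_light_curve_collection],
+    post_injection_transform=partial(default_light_curve_observation_post_injection_transform,
+                                     length=4000, randomize=False))
+```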
+Although in most cases it will not make a major difference, randomization should be disabled on the validation dataset. It should only be enabled for the training dataset.
 
-It's worth noting, `default_light_curve_post_injection_transform` is just a function that can be replaced as desired. To remove one of the preprocessing steps or add in an addition one, we can simply make a modified version of this function. Additionally, `qusi` does not require the transform function to output only the fluxes and a binary label. The `Hadryss` NN model expects these two types of values for training, but other models may take advantage of the times of the light curve, or they may predict multi-class or regression labels.
+It's worth noting that `default_light_curve_observation_post_injection_transform` is just a function that can be replaced as desired. To remove one of the preprocessing steps or add an additional one, we can simply make a modified version of this function. Additionally, `qusi` does not require the transform function to output only the fluxes and a binary label. The `Hadryss` NN model expects these two types of values for training, but other models may take advantage of the times of the light curve, or they may predict multi-class or regression labels.
diff --git a/examples/download_spoc_transit_light_curves.py b/examples/download_spoc_transit_light_curves.py
deleted file mode 100644
index 418ec051..00000000
--- a/examples/download_spoc_transit_light_curves.py
+++ /dev/null
@@ -1,65 +0,0 @@
-from pathlib import Path
-
-import numpy as np
-
-from ramjet.data_interface.tess_data_interface import (
-    download_spoc_light_curves_for_tic_ids,
-    get_spoc_tic_id_list_from_mast,
-)
-from ramjet.data_interface.tess_toi_data_interface import TessToiDataInterface, ToiColumns
-
-
-def main():
-    print('Retrieving metadata...')
-    spoc_target_tic_ids = get_spoc_tic_id_list_from_mast()
-    tess_toi_data_interface = TessToiDataInterface()
-    positive_tic_ids = tess_toi_data_interface.toi_dispositions[
-        tess_toi_data_interface.toi_dispositions[ToiColumns.disposition.value] != 'FP'][ToiColumns.tic_id.value]
-    negative_tic_ids = list(set(spoc_target_tic_ids) - set(positive_tic_ids))
-    positive_tic_ids_splits = np.split(
-        np.array(positive_tic_ids), [int(len(positive_tic_ids) * 0.8), int(len(positive_tic_ids) * 0.9)])
-    positive_train_tic_ids = positive_tic_ids_splits[0].tolist()
-    positive_validation_tic_ids = positive_tic_ids_splits[1].tolist()
-    positive_test_tic_ids = positive_tic_ids_splits[2].tolist()
-    negative_tic_ids_splits = np.split(
-        np.array(negative_tic_ids), [int(len(negative_tic_ids) * 0.8), int(len(negative_tic_ids) * 0.9)])
-    negative_train_tic_ids = negative_tic_ids_splits[0].tolist()
-    negative_validation_tic_ids = negative_tic_ids_splits[1].tolist()
-    negative_test_tic_ids = negative_tic_ids_splits[2].tolist()
-    sectors = list(range(27, 56))
-
-    print('Downloading light curves...')
-    download_spoc_light_curves_for_tic_ids(
-        tic_ids=positive_train_tic_ids,
-        download_directory=Path('data/spoc_transit_experiment/train/positives'),
-        sectors=sectors,
-        limit=2000)
-    download_spoc_light_curves_for_tic_ids(
-        tic_ids=negative_train_tic_ids,
-        download_directory=Path('data/spoc_transit_experiment/train/negatives'),
-        sectors=sectors,
-        limit=6000)
-    download_spoc_light_curves_for_tic_ids(
-        tic_ids=positive_validation_tic_ids,
-        download_directory=Path('data/spoc_transit_experiment/validation/positives'),
-        sectors=sectors,
-        limit=200)
-    download_spoc_light_curves_for_tic_ids(
-
tic_ids=negative_validation_tic_ids, - download_directory=Path('data/spoc_transit_experiment/validation/negatives'), - sectors=sectors, - limit=600) - download_spoc_light_curves_for_tic_ids( - tic_ids=positive_test_tic_ids, - download_directory=Path('data/spoc_transit_experiment/test/positives'), - sectors=sectors, - limit=200) - download_spoc_light_curves_for_tic_ids( - tic_ids=negative_test_tic_ids, - download_directory=Path('data/spoc_transit_experiment/test/negatives'), - sectors=sectors, - limit=600) - - -if __name__ == '__main__': - main() diff --git a/examples/transit_dataset.py b/examples/transit_dataset.py deleted file mode 100644 index 37265bf0..00000000 --- a/examples/transit_dataset.py +++ /dev/null @@ -1,88 +0,0 @@ -from pathlib import Path - -from qusi.finite_standard_light_curve_observation_dataset import FiniteStandardLightCurveObservationDataset -from qusi.light_curve_collection import LightCurveObservationCollection -from qusi.light_curve_dataset import LightCurveDataset -from ramjet.photometric_database.tess_two_minute_cadence_light_curve import TessMissionLightCurve - - -def get_positive_train_paths(): - return list(Path('data/spoc_transit_experiment/train/positives').glob('*.fits')) - - -def get_negative_train_paths(): - return list(Path('data/spoc_transit_experiment/train/negatives').glob('*.fits')) - - -def get_positive_validation_paths(): - return list(Path('data/spoc_transit_experiment/validation/positives').glob('*.fits')) - - -def get_negative_validation_paths(): - return list(Path('data/spoc_transit_experiment/validation/negatives').glob('*.fits')) - - -def get_negative_test_paths(): - return list(Path('data/spoc_transit_experiment/test/negatives').glob('*.fits')) - - -def get_positive_test_paths(): - return list(Path('data/spoc_transit_experiment/test/positives').glob('*.fits')) - - -def load_times_and_fluxes_from_path(path): - light_curve = TessMissionLightCurve.from_path(path) - return light_curve.times, light_curve.fluxes - - -def positive_label_function(path): - return 1 - - -def negative_label_function(path): - return 0 - - -def get_transit_train_dataset(): - positive_train_light_curve_collection = LightCurveObservationCollection.new( - get_paths_function=get_positive_train_paths, - load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path, - load_label_from_path_function=positive_label_function) - negative_train_light_curve_collection = LightCurveObservationCollection.new( - get_paths_function=get_negative_train_paths, - load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path, - load_label_from_path_function=negative_label_function) - train_light_curve_dataset = LightCurveDataset.new( - standard_light_curve_collections=[positive_train_light_curve_collection, - negative_train_light_curve_collection]) - return train_light_curve_dataset - - -def get_transit_validation_dataset(): - positive_validation_light_curve_collection = LightCurveObservationCollection.new( - get_paths_function=get_positive_validation_paths, - load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path, - load_label_from_path_function=positive_label_function) - negative_validation_light_curve_collection = LightCurveObservationCollection.new( - get_paths_function=get_negative_validation_paths, - load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path, - load_label_from_path_function=negative_label_function) - validation_light_curve_dataset = LightCurveDataset.new( - 
standard_light_curve_collections=[positive_validation_light_curve_collection, - negative_validation_light_curve_collection]) - return validation_light_curve_dataset - - -def get_transit_finite_test_dataset(): - positive_test_light_curve_collection = LightCurveObservationCollection.new( - get_paths_function=get_positive_test_paths, - load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path, - load_label_from_path_function=positive_label_function) - negative_test_light_curve_collection = LightCurveObservationCollection.new( - get_paths_function=get_negative_test_paths, - load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path, - load_label_from_path_function=negative_label_function) - test_light_curve_dataset = FiniteStandardLightCurveObservationDataset.new( - standard_light_curve_collections=[positive_test_light_curve_collection, - negative_test_light_curve_collection]) - return test_light_curve_dataset diff --git a/examples/transit_finite_dataset_test.py b/examples/transit_finite_dataset_test.py deleted file mode 100644 index 6c6479ae..00000000 --- a/examples/transit_finite_dataset_test.py +++ /dev/null @@ -1,23 +0,0 @@ -import torch -from torch.nn import BCELoss -from torchmetrics.classification import BinaryAccuracy - -from qusi.finite_test_session import finite_datasets_test_session, get_device -from qusi.hadryss_model import Hadryss - -from transit_dataset import get_transit_finite_test_dataset - -def main(): - test_light_curve_dataset = get_transit_finite_test_dataset() - model = Hadryss.new() - device = get_device() - model.load_state_dict(torch.load('sessions/_latest_model.pt', map_location=device)) - metric_functions = [BinaryAccuracy(), BCELoss()] - results = finite_datasets_test_session(test_datasets=[test_light_curve_dataset], model=model, - metric_functions=metric_functions, batch_size=100, device=device) - print(f'Binary accuracy: {results[0][0]}') - print(f'Binary cross entropy: {results[0][1]}') - - -if __name__ == '__main__': - main() diff --git a/examples/transit_infer.py b/examples/transit_infer.py deleted file mode 100644 index 80673fdb..00000000 --- a/examples/transit_infer.py +++ /dev/null @@ -1,45 +0,0 @@ -from pathlib import Path - -import numpy as np -import torch - -from qusi.finite_standard_light_curve_dataset import FiniteStandardLightCurveDataset -from qusi.hadryss_model import Hadryss -from qusi.infer_session import infer_session -from qusi.device import get_device -from qusi.light_curve_collection import LightCurveCollection -from ramjet.photometric_database.tess_two_minute_cadence_light_curve import TessMissionLightCurve - - -def get_infer_paths(): - return (list(Path('data/spoc_transit_experiment/test/negatives').glob('*.fits')) + - list(Path('data/spoc_transit_experiment/test/positives').glob('*.fits'))) - - -def load_times_and_fluxes_from_path(path: Path) -> (np.ndarray, np.ndarray): - light_curve = TessMissionLightCurve.from_path(path) - return light_curve.times, light_curve.fluxes - - -def main(): - infer_light_curve_collection = LightCurveCollection.new( - get_paths_function=get_infer_paths, - load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path) - - test_light_curve_dataset = FiniteStandardLightCurveDataset.new( - light_curve_collections=[infer_light_curve_collection]) - - model = Hadryss.new() - device = get_device() - model.load_state_dict(torch.load('sessions/_latest_model.pt', map_location=device)) - confidences = infer_session(infer_datasets=[test_light_curve_dataset], model=model, - 
batch_size=100, device=device)[0] - paths = list(get_infer_paths()) - paths_with_confidences = zip(paths, confidences) - sorted_paths_with_confidences = sorted( - paths_with_confidences, key=lambda path_with_confidence: path_with_confidence[1], reverse=True) - print(sorted_paths_with_confidences) - - -if __name__ == '__main__': - main() diff --git a/examples/transit_infinite_dataset_test.py b/examples/transit_infinite_dataset_test.py deleted file mode 100644 index 6f6e34c5..00000000 --- a/examples/transit_infinite_dataset_test.py +++ /dev/null @@ -1,96 +0,0 @@ -from pathlib import Path - -import numpy as np -import torch -from torch.nn import BCELoss, Module -from torch.types import Device -from torch.utils.data import DataLoader -from torchmetrics.classification import BinaryAccuracy - -from qusi.hadryss_model import Hadryss -from qusi.device import get_device -from qusi.light_curve_collection import LightCurveObservationCollection -from qusi.light_curve_dataset import LightCurveDataset -from ramjet.photometric_database.tess_two_minute_cadence_light_curve import TessMissionLightCurve - - -def get_negative_test_paths(): - return list(Path('data/spoc_transit_experiment/test/negatives').glob('*.fits')) - - -def get_positive_test_paths(): - return list(Path('data/spoc_transit_experiment/test/positives').glob('*.fits')) - - -def load_times_and_fluxes_from_path(path: Path) -> (np.ndarray, np.ndarray): - light_curve = TessMissionLightCurve.from_path(path) - return light_curve.times, light_curve.fluxes - - -def positive_label_function(_path: Path) -> int: - return 1 - - -def negative_label_function(_path: Path) -> int: - return 0 - - -def main(): - positive_test_light_curve_collection = LightCurveObservationCollection.new( - get_paths_function=get_positive_test_paths, - load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path, - load_label_from_path_function=positive_label_function) - negative_test_light_curve_collection = LightCurveObservationCollection.new( - get_paths_function=get_negative_test_paths, - load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path, - load_label_from_path_function=negative_label_function) - - test_light_curve_dataset = LightCurveDataset.new( - standard_light_curve_collections=[positive_test_light_curve_collection, - negative_test_light_curve_collection]) - - model = Hadryss.new() - device = get_device() - model.load_state_dict(torch.load('sessions/pleasant-lion-32_latest_model.pt', map_location=device)) - metric_functions = [BinaryAccuracy(), BCELoss()] - results = infinite_datasets_test_session(test_datasets=[test_light_curve_dataset], model=model, - metric_functions=metric_functions, batch_size=100, device=device, - steps=100) - return results - - -def infinite_datasets_test_session(test_datasets: list[LightCurveDataset], model: Module, - metric_functions: list[Module], batch_size: int, device: Device, steps: int): - test_dataloaders: list[DataLoader] = [] - for test_dataset in test_datasets: - test_dataloaders.append(DataLoader(test_dataset, batch_size=batch_size, pin_memory=True)) - model.eval() - results = [] - for test_dataloader in test_dataloaders: - result = infinite_dataset_test_phase(test_dataloader, model, metric_functions, device=device, steps=steps) - results.append(result) - return results - - -def infinite_dataset_test_phase(dataloader, model: Module, metric_functions: list[Module], device: Device, steps: int): - batch_count = 0 - metric_totals = torch.zeros(size=[len(metric_functions)]) - model.eval() - with 
torch.no_grad(): - for input_features, targets in dataloader: - input_features = input_features.to(device, non_blocking=True) - targets = targets.to(device, non_blocking=True) - predicted_targets = model(input_features) - for metric_function_index, metric_function in enumerate(metric_functions): - batch_metric_value = metric_function(predicted_targets.to(device, non_blocking=True), - targets) - metric_totals[metric_function_index] += batch_metric_value.to('cpu', non_blocking=True) - batch_count += 1 - if batch_count >= steps: - break - cycle_metric_values = metric_totals / batch_count - return cycle_metric_values - - -if __name__ == '__main__': - main() diff --git a/examples/transit_light_curve_visualization.py b/examples/transit_light_curve_visualization.py deleted file mode 100644 index 1175da4d..00000000 --- a/examples/transit_light_curve_visualization.py +++ /dev/null @@ -1,20 +0,0 @@ -from pathlib import Path - -from bokeh.io import show -from bokeh.plotting import figure as Figure - -from ramjet.photometric_database.tess_two_minute_cadence_light_curve import TessMissionLightCurve - - -def main(): - light_curve_path = Path( - 'data/spoc_transit_experiment/train/positives/hlsp_tess-spoc_tess_phot_0000000004605846-s0044_tess_v1_lc.fits') - light_curve = TessMissionLightCurve.from_path(light_curve_path) - light_curve_figure = Figure(x_axis_label='Time (BTJD)', y_axis_label='Flux') - light_curve_figure.circle(x=light_curve.times, y=light_curve.fluxes) - light_curve_figure.line(x=light_curve.times, y=light_curve.fluxes, line_alpha=0.3) - show(light_curve_figure) - - -if __name__ == '__main__': - main() diff --git a/examples/transit_train.py b/examples/transit_train.py deleted file mode 100644 index 79054045..00000000 --- a/examples/transit_train.py +++ /dev/null @@ -1,18 +0,0 @@ -from qusi.hadryss_model import Hadryss -from qusi.train_hyperparameter_configuration import TrainHyperparameterConfiguration -from qusi.train_session import train_session - -from transit_dataset import get_transit_train_dataset, get_transit_validation_dataset - -def main(): - train_light_curve_dataset = get_transit_train_dataset() - validation_light_curve_dataset = get_transit_validation_dataset() - model = Hadryss.new() - train_hyperparameter_configuration = TrainHyperparameterConfiguration.new( - batch_size=100, cycles=20, train_steps_per_cycle=100, validation_steps_per_cycle=10) - train_session(train_datasets=[train_light_curve_dataset], validation_datasets=[validation_light_curve_dataset], - model=model, hyperparameter_configuration=train_hyperparameter_configuration) - - -if __name__ == '__main__': - main() diff --git a/pyproject.toml b/pyproject.toml index 6b3e2de8..fe9d2221 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,8 @@ dependencies = [ "sphinx>=6.1.3", "backports.strenum", "typing_extensions", - "myst-parser" + "myst-parser", + "torcheval>=0.0.7", ] [build-system] diff --git a/qodana.yaml b/qodana.yaml new file mode 100644 index 00000000..c28d57d6 --- /dev/null +++ b/qodana.yaml @@ -0,0 +1,12 @@ +version: "1.0" + +profile: + name: qodana.starter + +exclude: + - name: All + paths: + - src/ramjet + - examples + +linter: jetbrains/qodana-python:latest diff --git a/src/qusi/data.py b/src/qusi/data.py new file mode 100644 index 00000000..bf0dd4ec --- /dev/null +++ b/src/qusi/data.py @@ -0,0 +1,17 @@ +""" +Data structure related public interface. 
+""" +from qusi.internal.finite_standard_light_curve_dataset import FiniteStandardLightCurveDataset +from qusi.internal.finite_standard_light_curve_observation_dataset import FiniteStandardLightCurveObservationDataset +from qusi.internal.light_curve import LightCurve +from qusi.internal.light_curve_collection import LightCurveObservationCollection, LightCurveCollection +from qusi.internal.light_curve_dataset import LightCurveDataset + +__all__ = [ + 'FiniteStandardLightCurveDataset', + 'FiniteStandardLightCurveObservationDataset', + 'LightCurve', + 'LightCurveCollection', + 'LightCurveDataset', + 'LightCurveObservationCollection', +] diff --git a/src/qusi/experimental/__init__.py b/src/qusi/experimental/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/qusi/experimental/application/__init__.py b/src/qusi/experimental/application/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/qusi/experimental/application/tess.py b/src/qusi/experimental/application/tess.py new file mode 100644 index 00000000..419a83b7 --- /dev/null +++ b/src/qusi/experimental/application/tess.py @@ -0,0 +1,14 @@ +from ramjet.data_interface.tess_data_interface import ( + download_spoc_light_curves_for_tic_ids, + get_spoc_tic_id_list_from_mast, +) +from ramjet.data_interface.tess_toi_data_interface import TessToiDataInterface, ToiColumns +from ramjet.photometric_database.tess_two_minute_cadence_light_curve import TessMissionLightCurve + +__all__ = [ + 'download_spoc_light_curves_for_tic_ids', + 'get_spoc_tic_id_list_from_mast', + 'TessMissionLightCurve', + 'TessToiDataInterface', + 'ToiColumns', +] diff --git a/src/qusi/experimental/session.py b/src/qusi/experimental/session.py new file mode 100644 index 00000000..aa01d435 --- /dev/null +++ b/src/qusi/experimental/session.py @@ -0,0 +1,5 @@ +from qusi.internal.finite_test_session import finite_datasets_test_session + +__all__ = [ + 'finite_datasets_test_session', +] \ No newline at end of file diff --git a/src/qusi/internal/__init__.py b/src/qusi/internal/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/qusi/device.py b/src/qusi/internal/device.py similarity index 68% rename from src/qusi/device.py rename to src/qusi/internal/device.py index 37455a22..fc725a11 100644 --- a/src/qusi/device.py +++ b/src/qusi/internal/device.py @@ -3,6 +3,11 @@ def get_device() -> Device: + """ + Gets the available device for PyTorch to run on. + + :return: The device. 
+ """ if torch.cuda.is_available(): device = torch.device("cuda") else: diff --git a/src/qusi/finite_standard_light_curve_dataset.py b/src/qusi/internal/finite_standard_light_curve_dataset.py similarity index 59% rename from src/qusi/finite_standard_light_curve_dataset.py rename to src/qusi/internal/finite_standard_light_curve_dataset.py index 8a941663..f2333cba 100644 --- a/src/qusi/finite_standard_light_curve_dataset.py +++ b/src/qusi/internal/finite_standard_light_curve_dataset.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass from functools import partial from typing import Any, Callable @@ -6,19 +8,37 @@ from torch.utils.data import Dataset from typing_extensions import Self -from qusi.light_curve_collection import LightCurveCollection -from qusi.light_curve_dataset import default_light_curve_post_injection_transform +from qusi.internal.light_curve_collection import LightCurveCollection +from qusi.internal.light_curve_dataset import default_light_curve_post_injection_transform @dataclass class FiniteStandardLightCurveDataset(Dataset): + """ + A finite light curve dataset without injection. + """ standard_light_curve_collections: list[LightCurveCollection] post_injection_transform: Callable[[Any], Any] length: int collection_start_indexes: list[int] @classmethod - def new(cls, light_curve_collections: list[LightCurveCollection]) -> Self: + def new( + cls, + light_curve_collections: list[LightCurveCollection], + *, + post_injection_transform: Callable[[Any], Any] | None = None, + ) -> Self: + """ + Creates a new `FiniteStandardLightCurveDataset`. + + :param light_curve_collections: The light curve collections to include in the dataset. + :param post_injection_transform: Transforms to the data to occur after injection. + :return: The dataset. 
+ """ + if post_injection_transform is None: + post_injection_transform = partial(default_light_curve_post_injection_transform, length=3500, + randomize=False) length = 0 collection_start_indexes: list[int] = [] for light_curve_collection in light_curve_collections: @@ -27,7 +47,7 @@ def new(cls, light_curve_collections: list[LightCurveCollection]) -> Self: length += standard_light_curve_collection_length instance = cls( standard_light_curve_collections=light_curve_collections, - post_injection_transform=partial(default_light_curve_post_injection_transform, length=2500), + post_injection_transform=post_injection_transform, length=length, collection_start_indexes=collection_start_indexes, ) diff --git a/src/qusi/finite_standard_light_curve_observation_dataset.py b/src/qusi/internal/finite_standard_light_curve_observation_dataset.py similarity index 53% rename from src/qusi/finite_standard_light_curve_observation_dataset.py rename to src/qusi/internal/finite_standard_light_curve_observation_dataset.py index 61555b71..c1918361 100644 --- a/src/qusi/finite_standard_light_curve_observation_dataset.py +++ b/src/qusi/internal/finite_standard_light_curve_observation_dataset.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass from functools import partial from typing import Any, Callable @@ -6,28 +8,46 @@ from torch.utils.data import Dataset from typing_extensions import Self -from qusi.light_curve_collection import LightCurveObservationCollection -from qusi.light_curve_dataset import default_light_curve_observation_post_injection_transform +from qusi.internal.light_curve_collection import LightCurveObservationCollection +from qusi.internal.light_curve_dataset import default_light_curve_observation_post_injection_transform @dataclass class FiniteStandardLightCurveObservationDataset(Dataset): + """ + A finite light curve observation dataset without injection. + """ standard_light_curve_collections: list[LightCurveObservationCollection] post_injection_transform: Callable[[Any], Any] length: int collection_start_indexes: list[int] @classmethod - def new(cls, standard_light_curve_collections: list[LightCurveObservationCollection]) -> Self: + def new( + cls, + light_curve_collections: list[LightCurveObservationCollection], + *, + post_injection_transform: Callable[[Any], Any] | None = None, + ) -> Self: + """ + Creates a new `FiniteStandardLightCurveObservationDataset`. + + :param light_curve_collections: The light curve observation collections to include in the dataset. + :param post_injection_transform: Transforms to the data to occur after injection. + :return: The dataset. 
+ """ + if post_injection_transform is None: + post_injection_transform = partial(default_light_curve_observation_post_injection_transform, length=3500, + randomize=False) length = 0 collection_start_indexes: list[int] = [] - for standard_light_curve_collection in standard_light_curve_collections: + for standard_light_curve_collection in light_curve_collections: standard_light_curve_collection_length = len(list(standard_light_curve_collection.path_getter.get_paths())) collection_start_indexes.append(length) length += standard_light_curve_collection_length instance = cls( - standard_light_curve_collections=standard_light_curve_collections, - post_injection_transform=partial(default_light_curve_observation_post_injection_transform, length=2500), + standard_light_curve_collections=light_curve_collections, + post_injection_transform=post_injection_transform, length=length, collection_start_indexes=collection_start_indexes, ) diff --git a/src/qusi/finite_test_session.py b/src/qusi/internal/finite_test_session.py similarity index 64% rename from src/qusi/finite_test_session.py rename to src/qusi/internal/finite_test_session.py index de3c48ea..8430ec60 100644 --- a/src/qusi/finite_test_session.py +++ b/src/qusi/internal/finite_test_session.py @@ -3,16 +3,28 @@ from torch.types import Device from torch.utils.data import DataLoader -from qusi.finite_standard_light_curve_observation_dataset import FiniteStandardLightCurveObservationDataset +from qusi.internal.finite_standard_light_curve_observation_dataset import FiniteStandardLightCurveObservationDataset def finite_datasets_test_session( - test_datasets: list[FiniteStandardLightCurveObservationDataset], - model: Module, - metric_functions: list[Module], - batch_size: int, - device: Device, + test_datasets: list[FiniteStandardLightCurveObservationDataset], + model: Module, + metric_functions: list[Module], + *, + batch_size: int = 100, + device: Device = torch.device('cpu'), ): + """ + Runs a test session on finite datasets. + + :param test_datasets: A list of datasets to run the test session on. + :param model: A model to perform the inference. + :param metric_functions: A metrics to test. + :param batch_size: A batch size to use during testing. + :param device: A device to run the model on. + :return: A list of arrays, with one array for each test dataset, with each array containing an element for each + metric that was tested. + """ test_dataloaders: list[DataLoader] = [] for test_dataset in test_datasets: test_dataloader = DataLoader(test_dataset, batch_size=batch_size, pin_memory=True) @@ -25,14 +37,6 @@ def finite_datasets_test_session( return results -def get_device(): - if torch.cuda.is_available(): - device = torch.device("cuda") - else: - device = torch.device("cpu") - return device - - def finite_dataset_test_phase(dataloader, model: Module, metric_functions: list[Module], device: Device): batch_count = 0 metric_totals = torch.zeros(size=[len(metric_functions)]) diff --git a/src/qusi/hadryss_model.py b/src/qusi/internal/hadryss_model.py similarity index 89% rename from src/qusi/hadryss_model.py rename to src/qusi/internal/hadryss_model.py index 61f09d1d..c8b54908 100644 --- a/src/qusi/hadryss_model.py +++ b/src/qusi/internal/hadryss_model.py @@ -18,7 +18,11 @@ class Hadryss(Module): - def __init__(self, input_length: int): + """ + A 1D convolutional neural network model for light curve data that will auto-size itself for a given input light + curve length. 
+ """ + def __init__(self, *, input_length: int): super().__init__() self.input_length: int = input_length pooling_sizes, dense_size = self.determine_block_pooling_sizes_and_dense_size() @@ -123,7 +127,13 @@ def forward(self, x: Tensor) -> Tensor: return x @classmethod - def new(cls, input_length: int = 2500) -> Self: + def new(cls, input_length: int = 3500) -> Self: + """ + Creates a new Hadryss model. + + :param input_length: The length of the input to auto-size the network to. + :return: The model. + """ instance = cls(input_length=input_length) return instance @@ -145,16 +155,16 @@ def determine_block_pooling_sizes_and_dense_size(self) -> (list[int], int): class LightCurveNetworkBlock(Module): def __init__( - self, - input_channels: int, - output_channels: int, - kernel_size: int, - pooling_size: int, - dropout_rate: float = 0.0, - *, - batch_normalization: bool = False, - spatial: bool = True, - length: int | None = None, + self, + input_channels: int, + output_channels: int, + kernel_size: int, + pooling_size: int, + dropout_rate: float = 0.0, + *, + batch_normalization: bool = False, + spatial: bool = True, + length: int | None = None, ): super().__init__() self.leaky_relu = LeakyReLU() diff --git a/src/qusi/infer_session.py b/src/qusi/internal/infer_session.py similarity index 65% rename from src/qusi/infer_session.py rename to src/qusi/internal/infer_session.py index bc2a2f7b..32f9d936 100644 --- a/src/qusi/infer_session.py +++ b/src/qusi/internal/infer_session.py @@ -4,12 +4,25 @@ from torch.types import Device from torch.utils.data import DataLoader -from qusi.finite_standard_light_curve_dataset import FiniteStandardLightCurveDataset +from qusi.internal.finite_standard_light_curve_dataset import FiniteStandardLightCurveDataset def infer_session( - infer_datasets: list[FiniteStandardLightCurveDataset], model: Module, batch_size: int, device: Device + infer_datasets: list[FiniteStandardLightCurveDataset], + model: Module, + *, + batch_size: int, + device: Device, ) -> list[np.ndarray]: + """ + Runs an infer session on finite datasets. + + :param infer_datasets: The list of datasets to run the infer session on. + :param model: The model to perform the inference. + :param batch_size: The batch size to use during inference. + :param device: The device to run the model on. + :return: A list of arrays with each element being the array predicted for each light curve in the dataset. + """ infer_dataloaders: list[DataLoader] = [] for infer_dataset in infer_datasets: infer_dataloader = DataLoader(infer_dataset, batch_size=batch_size, pin_memory=True) diff --git a/src/qusi/internal/infinite_datasets_test_session.py b/src/qusi/internal/infinite_datasets_test_session.py new file mode 100644 index 00000000..5abddb72 --- /dev/null +++ b/src/qusi/internal/infinite_datasets_test_session.py @@ -0,0 +1,51 @@ +from torch.nn import Module +from torch.types import Device +from torch.utils.data import DataLoader +from wandb.wandb_torch import torch + +from qusi.internal.light_curve_dataset import LightCurveDataset + + +def infinite_datasets_test_session(test_datasets: list[LightCurveDataset], model: Module, + metric_functions: list[Module], *, batch_size: int, device: Device, steps: int): + """ + Runs a test session on finite datasets. + + :param test_datasets: A list of datasets to run the test session on. + :param model: A model to perform the inference. + :param metric_functions: A metrics to test. + :param batch_size: A batch size to use during testing. 
diff --git a/src/qusi/light_curve.py b/src/qusi/internal/light_curve.py
similarity index 73%
rename from src/qusi/light_curve.py
rename to src/qusi/internal/light_curve.py
index baa1b309..d866dda1 100644
--- a/src/qusi/light_curve.py
+++ b/src/qusi/internal/light_curve.py
@@ -34,6 +34,13 @@ def new(cls, times: npt.NDArray[np.float32], fluxes: npt.NDArray[np.float32]) -> Self:
 
 
 def remove_nan_flux_data_points_from_light_curve(light_curve: LightCurve) -> LightCurve:
+    """
+    Removes the NaN values from a light curve. If there is a NaN in either the times or the
+    fluxes, both corresponding values are removed.
+
+    :param light_curve: The light curve.
+    :return: The light curve with NaN values removed.
+    """
     light_curve = deepcopy(light_curve)
     nan_flux_indexes = np.isnan(light_curve.fluxes)
     light_curve.fluxes = light_curve.fluxes[~nan_flux_indexes]
@@ -42,6 +49,13 @@ def remove_nan_flux_data_points_from_light_curve(light_curve: LightCurve) -> LightCurve:
 
 
 def randomly_roll_light_curve(light_curve: LightCurve) -> LightCurve:
+    """
+    Randomly rolls a light curve. That is, a random position in the light curve is chosen, the light curve
+    is split at that point, and the order of the two halves are swapped.
+
+    :param light_curve: The light curve.
+    :return: The rolled light curve.
+ """ light_curve = deepcopy(light_curve) random_index = np.random.randint(light_curve.times.shape[0]) light_curve.times = np.roll(light_curve.times, random_index) diff --git a/src/qusi/light_curve_collection.py b/src/qusi/internal/light_curve_collection.py similarity index 94% rename from src/qusi/light_curve_collection.py rename to src/qusi/internal/light_curve_collection.py index d5059d5a..6bbb479d 100644 --- a/src/qusi/light_curve_collection.py +++ b/src/qusi/internal/light_curve_collection.py @@ -11,8 +11,8 @@ import numpy.typing as npt from typing_extensions import Self -from qusi.light_curve import LightCurve -from qusi.light_curve_observation import LightCurveObservation +from qusi.internal.light_curve import LightCurve +from qusi.internal.light_curve_observation import LightCurveObservation if TYPE_CHECKING: from collections.abc import Iterable, Iterator @@ -117,6 +117,8 @@ class LightCurveCollection( LightCurveCollectionBase, LightCurveObservationIndexableBase ): """ + A collection of light curves, including where to find paths to the data and how to load the data. + :ivar path_getter: The PathIterableBase object for the collection. :ivar load_times_and_fluxes_from_path_function: The function to load the times and fluxes from the light curve. """ @@ -154,6 +156,8 @@ def light_curve_iter(self) -> Iterator[LightCurve]: :return: The iterable of the light curves. """ light_curve_paths = self.path_getter.get_shuffled_paths() + if len(light_curve_paths) == 0: + raise ValueError('LightCurveCollection returned no paths.') for light_curve_path in light_curve_paths: times, fluxes = self.load_times_and_fluxes_from_path_function( light_curve_path @@ -178,6 +182,9 @@ class LightCurveObservationCollection( LightCurveObservationCollectionBase, LightCurveObservationIndexableBase ): """ + A collection of light curve observations. Includes where to find the light curve data paths, and how to load + the times, fluxes, and label data. + :ivar path_getter: The PathGetterBase object for the collection. :ivar light_curve_collection: The LightCurveCollectionBase object for the collection. :ivar load_label_from_path_function: The function to load the label for the light curve. @@ -259,6 +266,8 @@ def observation_iter(self) -> Iterator[LightCurveObservation]: :return: The iterable of the light curves. 
""" light_curve_paths = self.path_getter.get_shuffled_paths() + if len(light_curve_paths) == 0: + raise ValueError('LightCurveObservationCollection returned no paths.') for light_curve_path in light_curve_paths: times, fluxes = self.light_curve_collection.load_times_and_fluxes_from_path( light_curve_path diff --git a/src/qusi/light_curve_dataset.py b/src/qusi/internal/light_curve_dataset.py similarity index 70% rename from src/qusi/light_curve_dataset.py rename to src/qusi/internal/light_curve_dataset.py index 91c6adc0..6aad6bc7 100644 --- a/src/qusi/light_curve_dataset.py +++ b/src/qusi/internal/light_curve_dataset.py @@ -20,52 +20,45 @@ from torch.utils.data import IterableDataset from typing_extensions import Self -from qusi.light_curve import ( +from qusi.internal.light_curve import ( LightCurve, randomly_roll_light_curve, remove_nan_flux_data_points_from_light_curve, ) -from qusi.light_curve_observation import ( +from qusi.internal.light_curve_observation import ( LightCurveObservation, randomly_roll_light_curve_observation, remove_nan_flux_data_points_from_light_curve_observation, ) -from qusi.light_curve_transforms import ( +from qusi.internal.light_curve_transforms import ( from_light_curve_observation_to_fluxes_array_and_label_array, - pair_array_to_tensor, + pair_array_to_tensor, normalize_tensor_by_modified_z_score, make_uniform_length, ) if TYPE_CHECKING: from collections.abc import Iterable, Iterator - from qusi.light_curve_collection import LightCurveObservationCollection + from qusi.internal.light_curve_collection import LightCurveObservationCollection class LightCurveDataset(IterableDataset): """ - A dataset of light curve data. + A dataset of light curves. Includes cases where light curves can be injected into one another. """ def __init__( - self, - standard_light_curve_collections: list[LightCurveObservationCollection], - injectee_light_curve_collections: list[LightCurveObservationCollection], - injectable_light_curve_collections: list[LightCurveObservationCollection], - post_injection_transform: Callable[[Any], Any], + self, + standard_light_curve_collections: list[LightCurveObservationCollection], + *, + injectee_light_curve_collections: list[LightCurveObservationCollection], + injectable_light_curve_collections: list[LightCurveObservationCollection], + post_injection_transform: Callable[[Any], Any], ): - self.standard_light_curve_collections: list[ - LightCurveObservationCollection - ] = standard_light_curve_collections - self.injectee_light_curve_collections: list[ - LightCurveObservationCollection - ] = injectee_light_curve_collections + self.standard_light_curve_collections: list[LightCurveObservationCollection] = standard_light_curve_collections + self.injectee_light_curve_collections: list[LightCurveObservationCollection] = injectee_light_curve_collections self.injectable_light_curve_collections: list[ - LightCurveObservationCollection - ] = injectable_light_curve_collections - if ( - len(self.standard_light_curve_collections) == 0 - and len(self.injectee_light_curve_collections) == 0 - ): + LightCurveObservationCollection] = injectable_light_curve_collections + if len(self.standard_light_curve_collections) == 0 and len(self.injectee_light_curve_collections) == 0: error_message = ( "Either the standard or injectee light curve collection lists must not be empty. " "Both were empty." 
@@ -99,27 +92,18 @@ def __iter__(self):
                 loop_iter_function(injectee_collection.observation_iter),
                 LightCurveCollectionType.INJECTEE,
             )
-            base_light_curve_collection_iter_and_type_pairs.append(
-                base_light_curve_collection_iter_and_type_pair
-            )
+            base_light_curve_collection_iter_and_type_pairs.append(base_light_curve_collection_iter_and_type_pair)
         injectable_light_curve_collection_iters: list[
             Iterator[LightCurveObservation]
         ] = []
         for injectable_collection in self.injectable_light_curve_collections:
-            injectable_light_curve_collection_iter = loop_iter_function(
-                injectable_collection.observation_iter
-            )
-            injectable_light_curve_collection_iters.append(
-                injectable_light_curve_collection_iter
-            )
+            injectable_light_curve_collection_iter = loop_iter_function(injectable_collection.observation_iter)
+            injectable_light_curve_collection_iters.append(injectable_light_curve_collection_iter)
         while True:
             for (
-                base_light_curve_collection_iter_and_type_pair
+                    base_light_curve_collection_iter_and_type_pair
             ) in base_light_curve_collection_iter_and_type_pairs:
-                (
-                    base_collection_iter,
-                    collection_type,
-                ) = base_light_curve_collection_iter_and_type_pair
+                (base_collection_iter, collection_type) = base_light_curve_collection_iter_and_type_pair
                 if collection_type in [
                     LightCurveCollectionType.STANDARD,
                     LightCurveCollectionType.STANDARD_AND_INJECTEE,
@@ -135,9 +119,7 @@ def __iter__(self):
                     LightCurveCollectionType.INJECTEE,
                     LightCurveCollectionType.STANDARD_AND_INJECTEE,
                 ]:
-                    for (
-                        injectable_light_curve_collection_iter
-                    ) in injectable_light_curve_collection_iters:
+                    for (injectable_light_curve_collection_iter) in injectable_light_curve_collection_iters:
                         injectable_light_curve = next(
                             injectable_light_curve_collection_iter
                         )
@@ -152,18 +134,27 @@ def __iter__(self):
 
     @classmethod
     def new(
-        cls,
-        standard_light_curve_collections: list[LightCurveObservationCollection]
-        | None = None,
-        injectee_light_curve_collections: list[LightCurveObservationCollection]
-        | None = None,
-        injectable_light_curve_collections: list[LightCurveObservationCollection]
-        | None = None,
-        post_injection_transform: Callable[[Any], Any] | None = None,
+            cls,
+            standard_light_curve_collections: list[LightCurveObservationCollection] | None = None,
+            *,
+            injectee_light_curve_collections: list[LightCurveObservationCollection] | None = None,
+            injectable_light_curve_collections: list[LightCurveObservationCollection] | None = None,
+            post_injection_transform: Callable[[Any], Any] | None = None,
     ) -> Self:
+        """
+        Creates a new light curve dataset.
+
+        :param standard_light_curve_collections: The light curve collections to be used without injection.
+        :param injectee_light_curve_collections: The light curve collections that other light curves will be injected
+            into.
+        :param injectable_light_curve_collections: The light curve collections that will be injected into other light
+            curves.
+        :param post_injection_transform: The transform to apply to the data after injection.
+        :return: The light curve dataset.
+        """
         if (
-            standard_light_curve_collections is None
-            and injectee_light_curve_collections is None
+                standard_light_curve_collections is None
+                and injectee_light_curve_collections is None
         ):
             error_message = (
                 "Either the standard or injectee light curve collection lists must be specified.
" @@ -178,7 +168,7 @@ def new( injectable_light_curve_collections = [] if post_injection_transform is None: post_injection_transform = partial( - default_light_curve_observation_post_injection_transform, length=2500 + default_light_curve_observation_post_injection_transform, length=3500 ) instance = cls( standard_light_curve_collections=standard_light_curve_collections, @@ -190,8 +180,8 @@ def new( def inject_light_curve( - injectee_observation: LightCurveObservation, - injectable_observation: LightCurveObservation, + injectee_observation: LightCurveObservation, + injectable_observation: LightCurveObservation, ) -> LightCurveObservation: ( fluxes_with_injected_signal, @@ -251,7 +241,7 @@ class LightCurveCollectionType(Enum): class InterleavedDataset(IterableDataset): def __init__(self, *datasets: IterableDataset): - self.datasets: tuple[IterableDataset] = datasets + self.datasets: tuple[IterableDataset, ...] = datasets @classmethod def new(cls, *datasets: IterableDataset): @@ -266,7 +256,7 @@ def __iter__(self): class ConcatenatedIterableDataset(IterableDataset): def __init__(self, *datasets: IterableDataset): - self.datasets: tuple[IterableDataset] = datasets + self.datasets: tuple[IterableDataset, ...] = datasets @classmethod def new(cls, *datasets: IterableDataset): @@ -298,20 +288,48 @@ def __iter__(self): def default_light_curve_observation_post_injection_transform( - x: LightCurveObservation, length: int + x: LightCurveObservation, + *, + length: int, + randomize: bool = True, ) -> (Tensor, Tensor): + """ + The default light curve observation post injection transforms. A set of transforms that is expected to work well for + a variety of use cases. + + :param x: The light curve observation to be transformed. + :param length: The length to make all light curves. + :param randomize: Whether to have randomization in the transforms. + :return: The transformed light curve observation. + """ x = remove_nan_flux_data_points_from_light_curve_observation(x) - x = randomly_roll_light_curve_observation(x) + if randomize: + x = randomly_roll_light_curve_observation(x) x = from_light_curve_observation_to_fluxes_array_and_label_array(x) - x = make_fluxes_and_label_array_uniform_length(x, length=length) + x = (make_uniform_length(x[0], length=length), x[1]) # Make the fluxes a uniform length. x = pair_array_to_tensor(x) x = (normalize_tensor_by_modified_z_score(x[0]), x[1]) return x -def default_light_curve_post_injection_transform(x: LightCurve, length: int) -> Tensor: +def default_light_curve_post_injection_transform( + x: LightCurve, + *, + length: int, + randomize: bool = True, +) -> Tensor: + """ + The default light curve post injection transforms. A set of transforms that is expected to work well for a variety + of use cases. + + :param x: The light curve to be transformed. + :param length: The length to make all light curves. + :param randomize: Whether to have randomization in the transforms. + :return: The transformed light curve. 
+ """ x = remove_nan_flux_data_points_from_light_curve(x) - x = randomly_roll_light_curve(x) + if randomize: + x = randomly_roll_light_curve(x) x = x.fluxes x = make_uniform_length(x, length=length) x = torch.tensor(x, dtype=torch.float32) @@ -319,62 +337,6 @@ def default_light_curve_post_injection_transform(x: LightCurve, length: int) -> return x -def normalize_tensor_by_modified_z_score(tensor: Tensor) -> Tensor: - median = torch.median(tensor) - deviation_from_median = tensor - median - absolute_deviation_from_median = torch.abs(deviation_from_median) - median_absolute_deviation_from_median = torch.median(absolute_deviation_from_median) - if median_absolute_deviation_from_median != 0: - modified_z_score = ( - 0.6745 * deviation_from_median / median_absolute_deviation_from_median - ) - else: - modified_z_score = torch.zeros_like(tensor) - return modified_z_score - - -def make_fluxes_and_label_array_uniform_length( - arrays: tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]], - length: int, - *, - randomize: bool = True, -) -> (np.ndarray, np.ndarray): - fluxes, label = arrays - uniform_length_times = make_uniform_length( - fluxes, length=length, randomize=randomize - ) - return uniform_length_times, label - - -def make_uniform_length( - example: np.ndarray, length: int, *, randomize: bool = True -) -> np.ndarray: - """Makes the example a specific length, by clipping those too large and repeating those too small.""" - if len(example.shape) not in [1, 2]: # Only tested for 1D and 2D cases. - raise ValueError( - f"Light curve dimensions expected to be in [1, 2], but found {len(example.shape)}" - ) - if randomize: - example = randomly_roll_elements(example) - if example.shape[0] == length: - pass - elif example.shape[0] > length: - example = example[:length] - else: - elements_to_repeat = length - example.shape[0] - if len(example.shape) == 1: - example = np.pad(example, (0, elements_to_repeat), mode="wrap") - else: - example = np.pad(example, ((0, elements_to_repeat), (0, 0)), mode="wrap") - return example - - -def randomly_roll_elements(example: np.ndarray) -> np.ndarray: - """Randomly rolls the elements.""" - example = np.roll(example, np.random.randint(example.shape[0]), axis=0) - return example - - class OutOfBoundsInjectionHandlingMethod(Enum): """ An enum of approaches for handling cases where the injectable signal is shorter than the injectee signal. @@ -395,14 +357,13 @@ class BaselineFluxEstimationMethod(Enum): def inject_signal_into_light_curve_with_intermediates( - light_curve_times: npt.NDArray[np.float64], - light_curve_fluxes: npt.NDArray[np.float64], - signal_times: npt.NDArray[np.float64], - signal_magnifications: npt.NDArray[np.float64], - out_of_bounds_injection_handling_method: OutOfBoundsInjectionHandlingMethod = ( - OutOfBoundsInjectionHandlingMethod.ERROR - ), - baseline_flux_estimation_method: BaselineFluxEstimationMethod = BaselineFluxEstimationMethod.MEDIAN, + light_curve_times: npt.NDArray[np.float64], + light_curve_fluxes: npt.NDArray[np.float64], + signal_times: npt.NDArray[np.float64], + signal_magnifications: npt.NDArray[np.float64], + out_of_bounds_injection_handling_method: OutOfBoundsInjectionHandlingMethod = ( + OutOfBoundsInjectionHandlingMethod.ERROR), + baseline_flux_estimation_method: BaselineFluxEstimationMethod = BaselineFluxEstimationMethod.MEDIAN, ) -> (npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]): """ Injects a synthetic magnification signal into real light curve fluxes. 
@@ -423,12 +384,12 @@ def inject_signal_into_light_curve_with_intermediates( light_curve_time_length = np.max(relative_light_curve_times) time_length_difference = light_curve_time_length - signal_time_length signal_start_offset = ( - np.random.random() * time_length_difference - ) + minimum_light_curve_time + np.random.random() * time_length_difference + ) + minimum_light_curve_time offset_signal_times = relative_signal_times + signal_start_offset if ( - baseline_flux_estimation_method - == BaselineFluxEstimationMethod.MEDIAN_ABSOLUTE_DEVIATION + baseline_flux_estimation_method + == BaselineFluxEstimationMethod.MEDIAN_ABSOLUTE_DEVIATION ): baseline_flux = stats.median_abs_deviation(light_curve_fluxes) baseline_to_median_absolute_deviation_ratio = ( @@ -439,16 +400,16 @@ def inject_signal_into_light_curve_with_intermediates( baseline_flux = np.median(light_curve_fluxes) signal_fluxes = (signal_magnifications * baseline_flux) - baseline_flux if ( - out_of_bounds_injection_handling_method - is OutOfBoundsInjectionHandlingMethod.RANDOM_INJECTION_LOCATION + out_of_bounds_injection_handling_method + is OutOfBoundsInjectionHandlingMethod.RANDOM_INJECTION_LOCATION ): signal_flux_interpolator = interp1d( offset_signal_times, signal_fluxes, bounds_error=False, fill_value=0 ) elif ( - out_of_bounds_injection_handling_method - is OutOfBoundsInjectionHandlingMethod.REPEAT_SIGNAL - and time_length_difference > 0 + out_of_bounds_injection_handling_method + is OutOfBoundsInjectionHandlingMethod.REPEAT_SIGNAL + and time_length_difference > 0 ): before_signal_gap = signal_start_offset - minimum_light_curve_time after_signal_gap = time_length_difference - before_signal_gap @@ -465,13 +426,13 @@ def inject_signal_into_light_curve_with_intermediates( repeated_signal_times = None for repeat_index in range(-before_repeats_needed, after_repeats_needed + 1): repeat_signal_start_offset = ( - signal_time_length + minimum_signal_time_step - ) * repeat_index + signal_time_length + minimum_signal_time_step + ) * repeat_index if repeated_signal_times is None: repeated_signal_times = offset_signal_times + repeat_signal_start_offset else: repeat_index_signal_times = ( - offset_signal_times + repeat_signal_start_offset + offset_signal_times + repeat_signal_start_offset ) repeated_signal_times = np.concatenate( [repeated_signal_times, repeat_index_signal_times] diff --git a/src/qusi/light_curve_observation.py b/src/qusi/internal/light_curve_observation.py similarity index 65% rename from src/qusi/light_curve_observation.py rename to src/qusi/internal/light_curve_observation.py index 5d88dd18..aa1cc86f 100644 --- a/src/qusi/light_curve_observation.py +++ b/src/qusi/internal/light_curve_observation.py @@ -3,7 +3,7 @@ from typing_extensions import Self -from qusi.light_curve import LightCurve, randomly_roll_light_curve, remove_nan_flux_data_points_from_light_curve +from qusi.internal.light_curve import LightCurve, randomly_roll_light_curve, remove_nan_flux_data_points_from_light_curve @dataclass @@ -34,6 +34,13 @@ def new(cls, light_curve: LightCurve, label: int) -> Self: def remove_nan_flux_data_points_from_light_curve_observation( light_curve_observation: LightCurveObservation, ) -> LightCurveObservation: + """ + Removes the NaN values from a light curve in a light curve observation. If there is a NaN in either the times or the + fluxes, both corresponding values are removed. + + :param light_curve_observation: The light curve observation. + :return: The light curve observation with NaN values removed. 
+ """ light_curve_observation = deepcopy(light_curve_observation) light_curve_observation.light_curve = remove_nan_flux_data_points_from_light_curve( light_curve_observation.light_curve @@ -42,6 +49,13 @@ def remove_nan_flux_data_points_from_light_curve_observation( def randomly_roll_light_curve_observation(light_curve_observation: LightCurveObservation) -> LightCurveObservation: + """ + Randomly rolls a light curve observation. That is, a random position in the light curve is chosen, the light curve + is split at that point, and the order of the two halves are swapped. + + :param light_curve_observation: The light curve observation. + :return: The light curve observation with the rolled light curve. + """ light_curve_observation = deepcopy(light_curve_observation) light_curve_observation.light_curve = randomly_roll_light_curve(light_curve_observation.light_curve) return light_curve_observation diff --git a/src/qusi/internal/light_curve_transforms.py b/src/qusi/internal/light_curve_transforms.py new file mode 100644 index 00000000..7abdeeb7 --- /dev/null +++ b/src/qusi/internal/light_curve_transforms.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +import numpy as np +import numpy.typing as npt +import torch +from torch import Tensor + +from qusi.internal.light_curve_observation import LightCurveObservation + + +def from_light_curve_observation_to_fluxes_array_and_label_array( + light_curve_observation: LightCurveObservation, +) -> (npt.NDArray[np.float32], npt.NDArray[np.float32]): + """ + Extracts the fluxes and label from a light curve observation. + + :param light_curve_observation: The light curve observation. + :return: The fluxes and label array. + """ + fluxes = light_curve_observation.light_curve.fluxes + label = light_curve_observation.label + return fluxes, np.array(label, dtype=np.float32) + + +def pair_array_to_tensor( + arrays: tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]], +) -> (Tensor, Tensor): + """ + Converts a pair of arrays to a pair of tensors. + + :param arrays: The arrays to convert. + :return: The tensors. + """ + return torch.tensor(arrays[0], dtype=torch.float32), torch.tensor( + arrays[1], dtype=torch.float32 + ) + + +def randomly_roll_elements(example: np.ndarray) -> np.ndarray: + """Randomly rolls the elements.""" + example = np.roll(example, np.random.randint(example.shape[0]), axis=0) + return example + + +def normalize_tensor_by_modified_z_score(tensor: Tensor) -> Tensor: + """ + Normalizes a tensor by a modified z-score. That is, normalizes the values of the tensor based on the median + absolute deviation. + + :param tensor: The tensor to normalize. + :return: The normalized tensor. + """ + median = torch.median(tensor) + deviation_from_median = tensor - median + absolute_deviation_from_median = torch.abs(deviation_from_median) + median_absolute_deviation_from_median = torch.median(absolute_deviation_from_median) + if median_absolute_deviation_from_median != 0: + modified_z_score = ( + 0.6745 * deviation_from_median / median_absolute_deviation_from_median + ) + else: + modified_z_score = torch.zeros_like(tensor) + return modified_z_score + + +def make_uniform_length(example: np.ndarray, length: int) -> np.ndarray: + """Makes the example a specific length, by clipping those too large and repeating those too small.""" + if len(example.shape) not in [1, 2]: # Only tested for 1D and 2D cases. 
+ raise ValueError( + f"Light curve dimensions expected to be in [1, 2], but found {len(example.shape)}" + ) + if example.shape[0] == length: + pass + elif example.shape[0] > length: + example = example[:length] + else: + elements_to_repeat = length - example.shape[0] + if len(example.shape) == 1: + example = np.pad(example, (0, elements_to_repeat), mode="wrap") + else: + example = np.pad(example, ((0, elements_to_repeat), (0, 0)), mode="wrap") + return example diff --git a/src/qusi/internal/logging.py b/src/qusi/internal/logging.py new file mode 100644 index 00000000..1a4b9f26 --- /dev/null +++ b/src/qusi/internal/logging.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import datetime +import logging +import re +import sys + +import stringcase + +logger_initialized = False + + +def create_default_formatter() -> logging.Formatter: + formatter = logging.Formatter('qusi [{asctime} {levelname} {name}] {message}', style='{') + return formatter + + +def set_up_default_logger(): + global logger_initialized # noqa PLW0603 : TODO: Probably a bad hack. Consider further. + if not logger_initialized: + formatter = create_default_formatter() + handler = logging.StreamHandler(sys.stdout) + handler.setLevel(logging.DEBUG) + handler.setFormatter(formatter) + logger = logging.getLogger('qusi') + logger.addHandler(handler) + logger.setLevel(logging.INFO) + logger.propagate = False + sys.excepthook = excepthook + logger_initialized = True + + +def excepthook(exc_type, exc_value, exc_traceback): + logger = logging.getLogger('qusi') + logger.critical(f'Uncaught exception at {datetime.datetime.now().astimezone()}:') + logger.handlers[0].flush() + sys.__excepthook__(exc_type, exc_value, exc_traceback) + + +def get_metric_name(metric_function): + metric_name = type(metric_function).__name__ + metric_name = camel_case_acronyms(metric_name) + metric_name = stringcase.snakecase(metric_name) + metric_name = metric_name.replace('_metric', '').replace('_loss', '') + return metric_name + + +def camel_case_acronyms(string: str) -> str: + def camel_case_single_acronym(string: str | None) -> str: + if string is None: + return '' + return stringcase.capitalcase(string.lower()) + + return re.sub( + r'([A-Z]{2,})([A-Z][a-z])|([A-Z]{2,})', + lambda match: ''.join(map(camel_case_single_acronym, [match.group(1), match.group(2), match.group(3)])), + string + ) diff --git a/src/qusi/single_dense_layer_model.py b/src/qusi/internal/single_dense_layer_model.py similarity index 100% rename from src/qusi/single_dense_layer_model.py rename to src/qusi/internal/single_dense_layer_model.py diff --git a/src/qusi/toy_light_curve_collection.py b/src/qusi/internal/toy_light_curve_collection.py similarity index 93% rename from src/qusi/toy_light_curve_collection.py rename to src/qusi/internal/toy_light_curve_collection.py index df3dd764..d22a0afe 100644 --- a/src/qusi/toy_light_curve_collection.py +++ b/src/qusi/internal/toy_light_curve_collection.py @@ -2,13 +2,13 @@ import numpy as np -from qusi.finite_standard_light_curve_dataset import FiniteStandardLightCurveDataset -from qusi.light_curve import LightCurve -from qusi.light_curve_collection import ( +from qusi.internal.finite_standard_light_curve_dataset import FiniteStandardLightCurveDataset +from qusi.internal.light_curve import LightCurve +from qusi.internal.light_curve_collection import ( LightCurveObservationCollection, create_constant_label_for_path_function, LightCurveCollection, ) -from qusi.light_curve_dataset import LightCurveDataset +from 
qusi.internal.light_curve_dataset import LightCurveDataset class ToyLightCurve: diff --git a/src/qusi/internal/train_hyperparameter_configuration.py b/src/qusi/internal/train_hyperparameter_configuration.py new file mode 100644 index 00000000..09bfb38e --- /dev/null +++ b/src/qusi/internal/train_hyperparameter_configuration.py @@ -0,0 +1,71 @@ +from dataclasses import dataclass + + +@dataclass +class TrainHyperparameterConfiguration: + """ + Hyperparameter configuration settings for a train session. + + :ivar cycles: The number of cycles to run. Cycles consist of one set of training steps and one set of validation + steps. They can be seen as analogous to epochs. However, as qusi datasets are often + infinite or have different length sub-collections, there is not always the exact equivalent of an + epoch, so cycles are used instead. + :ivar train_steps_per_cycle: The number of training steps per cycle. + :ivar validation_steps_per_cycle: The number of validation steps per cycle. + :ivar batch_size: The size of the batch for each train process. Each training step will use a number of observations + equal to this value multiplied by the number of train processes. + :ivar learning_rate: The learning rate. + :ivar optimizer_epsilon: The epsilon to be used by the optimizer. + :ivar weight_decay: The weight decay of the optimizer. + :ivar norm_based_gradient_clip: The norm based gradient clipping value. + """ + + cycles: int + train_steps_per_cycle: int + validation_steps_per_cycle: int + batch_size: int + learning_rate: float + optimizer_epsilon: float + weight_decay: float + norm_based_gradient_clip: float + + @classmethod + def new( + cls, + *, + cycles: int = 5000, + train_steps_per_cycle: int = 100, + validation_steps_per_cycle: int = 10, + batch_size: int = 100, + learning_rate: float = 1e-4, + optimizer_epsilon: float = 1e-7, + weight_decay: float = 0.0001, + norm_based_gradient_clip: float = 1.0, + ): + """ + Creates a new `TrainHyperparameterConfiguration`. + + :param cycles: The number of cycles to run. Cycles consist of one set of training steps and one set of validation + steps. They can be seen as analogous to epochs. However, as qusi datasets are often + infinite or have different length sub-collections, there is not always the exact equivalent of an + epoch, so cycles are used instead. + :param train_steps_per_cycle: The number of training steps per cycle. + :param validation_steps_per_cycle: The number of validation steps per cycle. + :param batch_size: The size of the batch for each train process. Each training step will use a number of observations + equal to this value multiplied by the number of train processes. + :param learning_rate: The learning rate. + :param optimizer_epsilon: The epsilon to be used by the optimizer. + :param weight_decay: The weight decay of the optimizer. + :param norm_based_gradient_clip: The norm based gradient clipping value. + :return: The hyperparameter configuration. 
+ """ + return cls( + learning_rate=learning_rate, + optimizer_epsilon=optimizer_epsilon, + weight_decay=weight_decay, + batch_size=batch_size, + cycles=cycles, + train_steps_per_cycle=train_steps_per_cycle, + validation_steps_per_cycle=validation_steps_per_cycle, + norm_based_gradient_clip=norm_based_gradient_clip, + ) diff --git a/src/qusi/train_logging_configuration.py b/src/qusi/internal/train_logging_configuration.py similarity index 60% rename from src/qusi/train_logging_configuration.py rename to src/qusi/internal/train_logging_configuration.py index 2272a44a..6664c283 100644 --- a/src/qusi/train_logging_configuration.py +++ b/src/qusi/internal/train_logging_configuration.py @@ -20,11 +20,20 @@ class TrainLoggingConfiguration: @classmethod def new( - cls, - wandb_project: str | None = None, - wandb_entity: str | None = None, - additional_log_dictionary: dict[str, Any] | None = None, + cls, + *, + wandb_project: str | None = None, + wandb_entity: str | None = None, + additional_log_dictionary: dict[str, Any] | None = None, ): + """ + Creates a `TrainLoggingConfiguration`. + + :param wandb_project: The wandb project to log to. + :param wandb_entity: The wandb entity to log to. + :param additional_log_dictionary: The dictionary of additional values to log. + :return: The `TrainLoggingConfiguration`. + """ if additional_log_dictionary is None: additional_log_dictionary = {} return cls( diff --git a/src/qusi/train_session.py b/src/qusi/internal/train_session.py similarity index 80% rename from src/qusi/train_session.py rename to src/qusi/internal/train_session.py index c62e464e..50067025 100644 --- a/src/qusi/train_session.py +++ b/src/qusi/internal/train_session.py @@ -4,32 +4,44 @@ from pathlib import Path import numpy as np -import stringcase import torch from torch.nn import BCELoss, Module from torch.optim import AdamW from torch.utils.data import DataLoader -from torchmetrics.classification import BinaryAccuracy import wandb -from qusi.light_curve_dataset import InterleavedDataset, LightCurveDataset -from qusi.logging import set_up_default_logger -from qusi.train_hyperparameter_configuration import TrainHyperparameterConfiguration -from qusi.train_logging_configuration import TrainLoggingConfiguration -from qusi.wandb_liaison import wandb_commit, wandb_init, wandb_log +from torchmetrics.classification import BinaryAccuracy, BinaryAUROC + +from qusi.internal.light_curve_dataset import InterleavedDataset, LightCurveDataset +from qusi.internal.logging import set_up_default_logger, get_metric_name +from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration +from qusi.internal.train_logging_configuration import TrainLoggingConfiguration +from qusi.internal.wandb_liaison import wandb_commit, wandb_init, wandb_log logger = logging.getLogger(__name__) def train_session( - train_datasets: list[LightCurveDataset], - validation_datasets: list[LightCurveDataset], - model: Module, - loss_function: Module | None = None, - metric_functions: list[Module] | None = None, - hyperparameter_configuration: TrainHyperparameterConfiguration | None = None, - logging_configuration: TrainLoggingConfiguration | None = None, -): + train_datasets: list[LightCurveDataset], + validation_datasets: list[LightCurveDataset], + model: Module, + loss_function: Module | None = None, + metric_functions: list[Module] | None = None, + *, + hyperparameter_configuration: TrainHyperparameterConfiguration | None = None, + logging_configuration: TrainLoggingConfiguration | None = None, +) -> 
None:
+    """
+    Runs a training session.
+
+    :param train_datasets: The datasets to train on.
+    :param validation_datasets: The datasets to validate on.
+    :param model: The model to train.
+    :param loss_function: The loss function to train the model on.
+    :param metric_functions: A list of metric functions to record during the training process.
+    :param hyperparameter_configuration: The configuration of the hyperparameters.
+    :param logging_configuration: The configuration of the logging.
+    """
     if hyperparameter_configuration is None:
         hyperparameter_configuration = TrainHyperparameterConfiguration.new()
     if logging_configuration is None:
@@ -37,16 +49,17 @@ def train_session(
     if loss_function is None:
         loss_function = BCELoss()
     if metric_functions is None:
-        metric_functions = [BinaryAccuracy()]
+        metric_functions = [BinaryAccuracy(), BinaryAUROC()]
     set_up_default_logger()
+    sessions_directory = Path("sessions")
+    sessions_directory.mkdir(exist_ok=True)
     wandb_init(
         process_rank=0,
         project=logging_configuration.wandb_project,
         entity=logging_configuration.wandb_entity,
         settings=wandb.Settings(start_method="thread"),
+        dir=sessions_directory,
     )
-    sessions_directory = Path("sessions")
-    sessions_directory.mkdir(exist_ok=True)
     train_dataset = InterleavedDataset.new(*train_datasets)
     torch.multiprocessing.set_start_method("spawn")
     debug = False
@@ -89,6 +102,7 @@ def train_session(
         for metric_function in metric_functions
     ]
     for _cycle_index in range(hyperparameter_configuration.cycles):
+        logger.info(f'Cycle {_cycle_index}')
         train_phase(
             dataloader=train_dataloader,
             model=model,
@@ -112,13 +126,13 @@ def train_session(
 
 
 def train_phase(
-    dataloader,
-    model,
-    loss_function,
-    metric_functions: list[Module],
-    optimizer,
-    steps,
-    device,
+        dataloader,
+        model,
+        loss_function,
+        metric_functions: list[Module],
+        optimizer,
+        steps,
+        device,
 ):
     model.train()
     total_loss = 0
@@ -165,15 +179,8 @@ def train_phase(
     )
 
 
-def get_metric_name(metric_function):
-    metric_name = type(metric_function).__name__
-    metric_name = stringcase.snakecase(metric_name)
-    metric_name = metric_name.replace("_metric", "").replace("_loss", "")
-    return metric_name
-
-
 def validation_phase(
-    dataloader, model, loss_function, metric_functions: list[Module], steps, device
+        dataloader, model, loss_function, metric_functions: list[Module], steps, device
 ):
     model.eval()
     validation_loss = 0
diff --git a/src/qusi/train_system_configuration.py b/src/qusi/internal/train_system_configuration.py
similarity index 53%
rename from src/qusi/train_system_configuration.py
rename to src/qusi/internal/train_system_configuration.py
index 7e5cf877..ac3153b1 100644
--- a/src/qusi/train_system_configuration.py
+++ b/src/qusi/internal/train_system_configuration.py
@@ -14,5 +14,17 @@ class TrainSystemConfiguration:
     preprocessing_processes_per_train_process: int
 
     @classmethod
-    def new(cls, preprocessing_processes_per_train_process: int = 10):
+    def new(
+            cls,
+            *,
+            preprocessing_processes_per_train_process: int = 10
+    ):
+        """
+        Creates a `TrainSystemConfiguration`.
+
+        :param preprocessing_processes_per_train_process: The number of processes that are started to preprocess the data
+            per train process. The train session will create this many processes for each of the train data and the
+            validation data.
+        :return: The `TrainSystemConfiguration`.
+ """ return cls(preprocessing_processes_per_train_process=preprocessing_processes_per_train_process) diff --git a/src/qusi/wandb_liaison.py b/src/qusi/internal/wandb_liaison.py similarity index 100% rename from src/qusi/wandb_liaison.py rename to src/qusi/internal/wandb_liaison.py diff --git a/src/qusi/light_curve_transforms.py b/src/qusi/light_curve_transforms.py deleted file mode 100644 index 4aeb703e..00000000 --- a/src/qusi/light_curve_transforms.py +++ /dev/null @@ -1,28 +0,0 @@ -import numpy as np -import numpy.typing as npt -import torch -from torch import Tensor - -from qusi.light_curve_observation import LightCurveObservation - - -def from_light_curve_observation_to_fluxes_array_and_label_array( - light_curve_observation: LightCurveObservation, -) -> (npt.NDArray[np.float32], npt.NDArray[np.float32]): - fluxes = light_curve_observation.light_curve.fluxes - label = light_curve_observation.label - return fluxes, np.array(label, dtype=np.float32) - - -def pair_array_to_tensor( - arrays: tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]], -) -> (Tensor, Tensor): - return torch.tensor(arrays[0], dtype=torch.float32), torch.tensor( - arrays[1], dtype=torch.float32 - ) - - -def randomly_roll_elements(example: np.ndarray) -> np.ndarray: - """Randomly rolls the elements.""" - example = np.roll(example, np.random.randint(example.shape[0]), axis=0) - return example diff --git a/src/qusi/logging.py b/src/qusi/logging.py deleted file mode 100644 index 25b63467..00000000 --- a/src/qusi/logging.py +++ /dev/null @@ -1,32 +0,0 @@ -import datetime -import logging -import sys - -logger_initialized = False - - -def create_default_formatter() -> logging.Formatter: - formatter = logging.Formatter("qusi [{asctime} {levelname} {name}] {message}", style="{") - return formatter - - -def set_up_default_logger(): - global logger_initialized # noqa PLW0603 : TODO: Probably a bad hack. Consider further. - if not logger_initialized: - formatter = create_default_formatter() - handler = logging.StreamHandler(sys.stdout) - handler.setLevel(logging.DEBUG) - handler.setFormatter(formatter) - logger = logging.getLogger("qusi") - logger.addHandler(handler) - logger.setLevel(logging.INFO) - logger.propagate = False - sys.excepthook = excepthook - logger_initialized = True - - -def excepthook(exc_type, exc_value, exc_traceback): - logger = logging.getLogger("qusi") - logger.critical(f"Uncaught exception at {datetime.datetime.now().astimezone()}:") - logger.handlers[0].flush() - sys.__excepthook__(exc_type, exc_value, exc_traceback) diff --git a/src/qusi/model.py b/src/qusi/model.py new file mode 100644 index 00000000..02ecb910 --- /dev/null +++ b/src/qusi/model.py @@ -0,0 +1,8 @@ +""" +Neural network model related public interface. +""" +from qusi.internal.hadryss_model import Hadryss + +__all__ = [ + 'Hadryss', +] diff --git a/src/qusi/session.py b/src/qusi/session.py new file mode 100644 index 00000000..8d883245 --- /dev/null +++ b/src/qusi/session.py @@ -0,0 +1,23 @@ +""" +Session related public interface. 
+""" +from qusi.internal.device import get_device +from qusi.internal.finite_test_session import finite_datasets_test_session +from qusi.internal.infer_session import infer_session +from qusi.internal.infinite_datasets_test_session import infinite_datasets_test_session +from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration +from qusi.internal.train_logging_configuration import TrainLoggingConfiguration +from qusi.internal.train_system_configuration import TrainSystemConfiguration +from qusi.internal.train_session import train_session + +__all__ = [ + 'finite_datasets_test_session', + 'get_device', + 'infer_session', + 'infinite_datasets_test_session', + 'TrainHyperparameterConfiguration', + 'TrainLoggingConfiguration', + 'TrainSystemConfiguration', + 'train_session', +] + diff --git a/src/qusi/train_hyperparameter_configuration.py b/src/qusi/train_hyperparameter_configuration.py deleted file mode 100644 index 2f5dd864..00000000 --- a/src/qusi/train_hyperparameter_configuration.py +++ /dev/null @@ -1,44 +0,0 @@ -from dataclasses import dataclass - - -@dataclass -class TrainHyperparameterConfiguration: - """ - Hyperparameter configuration settings for a train session. - - :ivar batch_size: The size of the batch for each train process. Each training step will use a number of examples - equal to this value multiplied by the number of train processes. - :ivar cycles: The number of train cycles to run. - """ - - learning_rate: float - optimizer_epsilon: float - weight_decay: float - batch_size: int - cycles: int - train_steps_per_cycle: int - validation_steps_per_cycle: int - norm_based_gradient_clip: float - - @classmethod - def new( - cls, - learning_rate: float = 1e-4, - optimizer_epsilon: float = 1e-7, - weight_decay: float = 0.0001, - batch_size: int = 100, - train_steps_per_cycle: int = 100, - validation_steps_per_cycle: int = 10, - cycles: int = 5000, - norm_based_gradient_clip: float = 1.0, - ): - return cls( - learning_rate=learning_rate, - optimizer_epsilon=optimizer_epsilon, - weight_decay=weight_decay, - batch_size=batch_size, - cycles=cycles, - train_steps_per_cycle=train_steps_per_cycle, - validation_steps_per_cycle=validation_steps_per_cycle, - norm_based_gradient_clip=norm_based_gradient_clip, - ) diff --git a/src/qusi/transform.py b/src/qusi/transform.py new file mode 100644 index 00000000..ba3b8638 --- /dev/null +++ b/src/qusi/transform.py @@ -0,0 +1,23 @@ +""" +Data transform related public interface. 
+""" +from qusi.internal.light_curve import randomly_roll_light_curve, remove_nan_flux_data_points_from_light_curve +from qusi.internal.light_curve_dataset import default_light_curve_post_injection_transform, \ + default_light_curve_observation_post_injection_transform +from qusi.internal.light_curve_observation import remove_nan_flux_data_points_from_light_curve_observation, \ + randomly_roll_light_curve_observation +from qusi.internal.light_curve_transforms import from_light_curve_observation_to_fluxes_array_and_label_array, \ + pair_array_to_tensor, make_uniform_length, normalize_tensor_by_modified_z_score + +__all__ = [ + 'default_light_curve_post_injection_transform', + 'default_light_curve_observation_post_injection_transform', + 'from_light_curve_observation_to_fluxes_array_and_label_array', + 'make_uniform_length', + 'normalize_tensor_by_modified_z_score', + 'pair_array_to_tensor', + 'randomly_roll_light_curve', + 'randomly_roll_light_curve_observation', + 'remove_nan_flux_data_points_from_light_curve', + 'remove_nan_flux_data_points_from_light_curve_observation', +] diff --git a/tests/end_to_end_tests/test_toy_infer_session.py b/tests/end_to_end_tests/test_toy_infer_session.py index 9c3f046f..72384bc6 100644 --- a/tests/end_to_end_tests/test_toy_infer_session.py +++ b/tests/end_to_end_tests/test_toy_infer_session.py @@ -3,13 +3,13 @@ import numpy as np -from qusi.infer_session import infer_session -from qusi.device import get_device -from qusi.light_curve_dataset import ( +from qusi.internal.infer_session import infer_session +from qusi.internal.device import get_device +from qusi.internal.light_curve_dataset import ( default_light_curve_post_injection_transform, ) -from qusi.single_dense_layer_model import SingleDenseLayerBinaryClassificationModel -from qusi.toy_light_curve_collection import get_toy_finite_light_curve_dataset +from qusi.internal.single_dense_layer_model import SingleDenseLayerBinaryClassificationModel +from qusi.internal.toy_light_curve_collection import get_toy_finite_light_curve_dataset def test_toy_infer_session(): diff --git a/tests/end_to_end_tests/test_toy_train_session.py b/tests/end_to_end_tests/test_toy_train_session.py index 4833d8ca..47e17f0d 100644 --- a/tests/end_to_end_tests/test_toy_train_session.py +++ b/tests/end_to_end_tests/test_toy_train_session.py @@ -1,13 +1,13 @@ import os from functools import partial -from qusi.light_curve_dataset import ( +from qusi.internal.light_curve_dataset import ( default_light_curve_observation_post_injection_transform, ) -from qusi.single_dense_layer_model import SingleDenseLayerBinaryClassificationModel -from qusi.toy_light_curve_collection import get_toy_dataset -from qusi.train_hyperparameter_configuration import TrainHyperparameterConfiguration -from qusi.train_session import train_session +from qusi.internal.single_dense_layer_model import SingleDenseLayerBinaryClassificationModel +from qusi.internal.toy_light_curve_collection import get_toy_dataset +from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration +from qusi.internal.train_session import train_session def test_toy_train_session(): diff --git a/tests/unit_tests/logging.py b/tests/unit_tests/logging.py new file mode 100644 index 00000000..5e64016d --- /dev/null +++ b/tests/unit_tests/logging.py @@ -0,0 +1,13 @@ +from torch.nn import BCELoss +from torchmetrics.classification import BinaryAUROC + +from qusi.internal.logging import camel_case_acronyms, get_metric_name + + +def test_camel_case_acronyms(): + assert 
diff --git a/tests/end_to_end_tests/test_toy_infer_session.py b/tests/end_to_end_tests/test_toy_infer_session.py
index 9c3f046f..72384bc6 100644
--- a/tests/end_to_end_tests/test_toy_infer_session.py
+++ b/tests/end_to_end_tests/test_toy_infer_session.py
@@ -3,13 +3,13 @@
 
 import numpy as np
 
-from qusi.infer_session import infer_session
-from qusi.device import get_device
-from qusi.light_curve_dataset import (
+from qusi.internal.infer_session import infer_session
+from qusi.internal.device import get_device
+from qusi.internal.light_curve_dataset import (
     default_light_curve_post_injection_transform,
 )
-from qusi.single_dense_layer_model import SingleDenseLayerBinaryClassificationModel
-from qusi.toy_light_curve_collection import get_toy_finite_light_curve_dataset
+from qusi.internal.single_dense_layer_model import SingleDenseLayerBinaryClassificationModel
+from qusi.internal.toy_light_curve_collection import get_toy_finite_light_curve_dataset
 
 
 def test_toy_infer_session():
diff --git a/tests/end_to_end_tests/test_toy_train_session.py b/tests/end_to_end_tests/test_toy_train_session.py
index 4833d8ca..47e17f0d 100644
--- a/tests/end_to_end_tests/test_toy_train_session.py
+++ b/tests/end_to_end_tests/test_toy_train_session.py
@@ -1,13 +1,13 @@
 import os
 from functools import partial
 
-from qusi.light_curve_dataset import (
+from qusi.internal.light_curve_dataset import (
     default_light_curve_observation_post_injection_transform,
 )
-from qusi.single_dense_layer_model import SingleDenseLayerBinaryClassificationModel
-from qusi.toy_light_curve_collection import get_toy_dataset
-from qusi.train_hyperparameter_configuration import TrainHyperparameterConfiguration
-from qusi.train_session import train_session
+from qusi.internal.single_dense_layer_model import SingleDenseLayerBinaryClassificationModel
+from qusi.internal.toy_light_curve_collection import get_toy_dataset
+from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration
+from qusi.internal.train_session import train_session
 
 
 def test_toy_train_session():
diff --git a/tests/unit_tests/test_logging.py b/tests/unit_tests/test_logging.py
new file mode 100644
index 00000000..5e64016d
--- /dev/null
+++ b/tests/unit_tests/test_logging.py
@@ -0,0 +1,13 @@
+from torch.nn import BCELoss
+from torchmetrics.classification import BinaryAUROC
+
+from qusi.internal.logging import camel_case_acronyms, get_metric_name
+
+
+def test_camel_case_acronyms():
+    assert camel_case_acronyms('BCEntropy') == 'BcEntropy'
+    assert camel_case_acronyms('BinaryAUROC') == 'BinaryAuroc'
+
+def test_get_metric_name():
+    assert get_metric_name(BCELoss()) == 'bce'
+    assert get_metric_name(BinaryAUROC()) == 'binary_auroc'
diff --git a/tests/unit_tests/test_hydryss_model.py b/tests/unit_tests/test_hydryss_model.py
index 7e62ec5c..b95dc2fe 100644
--- a/tests/unit_tests/test_hydryss_model.py
+++ b/tests/unit_tests/test_hydryss_model.py
@@ -1,6 +1,6 @@
 import torch
 
-from qusi.hadryss_model import Hadryss
+from qusi.internal.hadryss_model import Hadryss
 
 
 def test_lengths_give_correct_output_size():
diff --git a/tests/unit_tests/test_light_curve_dataset.py b/tests/unit_tests/test_light_curve_dataset.py
index 62f55e28..b2e5897f 100644
--- a/tests/unit_tests/test_light_curve_dataset.py
+++ b/tests/unit_tests/test_light_curve_dataset.py
@@ -2,7 +2,7 @@
 from itertools import islice
 from unittest.mock import Mock
 
-from qusi.light_curve_dataset import (
+from qusi.internal.light_curve_dataset import (
     contains_injected_dataset,
     interleave_infinite_iterators,
     is_injected_dataset,
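# Reviewer sketch (not part of the patch): an end-to-end view of the renamed public interface
# after this change set. Dataset construction is elided; `train_dataset` and `validation_dataset`
# are assumed to be `LightCurveDataset`s built as in the tutorials.
from qusi.model import Hadryss
from qusi.session import TrainHyperparameterConfiguration, train_session

model = Hadryss.new(input_length=3500)  # matches the new default input length
hyperparameters = TrainHyperparameterConfiguration.new(cycles=5000, batch_size=100)
train_session(
    train_datasets=[train_dataset],
    validation_datasets=[validation_dataset],
    model=model,
    hyperparameter_configuration=hyperparameters,  # keyword-only after this change
)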