Skip to content

Commit

Permalink
Add required keyword parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
golmschenk committed May 2, 2024
1 parent 7c68be6 commit 7b7ea94
Show file tree
Hide file tree
Showing 10 changed files with 103 additions and 110 deletions.
2 changes: 1 addition & 1 deletion examples/transit_infinite_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def main():


def infinite_datasets_test_session(test_datasets: list[LightCurveDataset], model: Module,
metric_functions: list[Module], batch_size: int, device: Device, steps: int):
metric_functions: list[Module], *, batch_size: int, device: Device, steps: int):
test_dataloaders: list[DataLoader] = []
for test_dataset in test_datasets:
test_dataloaders.append(DataLoader(test_dataset, batch_size=batch_size, pin_memory=True))
Expand Down
11 changes: 6 additions & 5 deletions src/qusi/internal/finite_test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@


def finite_datasets_test_session(
test_datasets: list[FiniteStandardLightCurveObservationDataset],
model: Module,
metric_functions: list[Module],
batch_size: int,
device: Device,
test_datasets: list[FiniteStandardLightCurveObservationDataset],
model: Module,
*,
metric_functions: list[Module],
batch_size: int,
device: Device,
):
test_dataloaders: list[DataLoader] = []
for test_dataset in test_datasets:
Expand Down
2 changes: 1 addition & 1 deletion src/qusi/internal/hadryss_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class Hadryss(Module):
A 1D convolutional neural network model for light curve data that will auto-size itself for a given input light
curve length.
"""
def __init__(self, input_length: int):
def __init__(self, *, input_length: int):
super().__init__()
self.input_length: int = input_length
pooling_sizes, dense_size = self.determine_block_pooling_sizes_and_dense_size()
Expand Down
2 changes: 1 addition & 1 deletion src/qusi/internal/infer_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


def infer_session(
infer_datasets: list[FiniteStandardLightCurveDataset], model: Module, batch_size: int, device: Device
infer_datasets: list[FiniteStandardLightCurveDataset], model: Module, *, batch_size: int, device: Device
) -> list[np.ndarray]:
infer_dataloaders: list[DataLoader] = []
for infer_dataset in infer_datasets:
Expand Down
124 changes: 51 additions & 73 deletions src/qusi/internal/light_curve_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,25 +47,17 @@ class LightCurveDataset(IterableDataset):
"""

def __init__(
self,
standard_light_curve_collections: list[LightCurveObservationCollection],
injectee_light_curve_collections: list[LightCurveObservationCollection],
injectable_light_curve_collections: list[LightCurveObservationCollection],
post_injection_transform: Callable[[Any], Any],
self,
standard_light_curve_collections: list[LightCurveObservationCollection],
injectee_light_curve_collections: list[LightCurveObservationCollection],
injectable_light_curve_collections: list[LightCurveObservationCollection],
post_injection_transform: Callable[[Any], Any],
):
self.standard_light_curve_collections: list[
LightCurveObservationCollection
] = standard_light_curve_collections
self.injectee_light_curve_collections: list[
LightCurveObservationCollection
] = injectee_light_curve_collections
self.standard_light_curve_collections: list[LightCurveObservationCollection] = standard_light_curve_collections
self.injectee_light_curve_collections: list[LightCurveObservationCollection] = injectee_light_curve_collections
self.injectable_light_curve_collections: list[
LightCurveObservationCollection
] = injectable_light_curve_collections
if (
len(self.standard_light_curve_collections) == 0
and len(self.injectee_light_curve_collections) == 0
):
LightCurveObservationCollection] = injectable_light_curve_collections
if len(self.standard_light_curve_collections) == 0 and len(self.injectee_light_curve_collections) == 0:
error_message = (
"Either the standard or injectee light curve collection lists must not be empty. "
"Both were empty."
Expand Down Expand Up @@ -99,27 +91,18 @@ def __iter__(self):
loop_iter_function(injectee_collection.observation_iter),
LightCurveCollectionType.INJECTEE,
)
base_light_curve_collection_iter_and_type_pairs.append(
base_light_curve_collection_iter_and_type_pair
)
base_light_curve_collection_iter_and_type_pairs.append(base_light_curve_collection_iter_and_type_pair)
injectable_light_curve_collection_iters: list[
Iterator[LightCurveObservation]
] = []
for injectable_collection in self.injectable_light_curve_collections:
injectable_light_curve_collection_iter = loop_iter_function(
injectable_collection.observation_iter
)
injectable_light_curve_collection_iters.append(
injectable_light_curve_collection_iter
)
injectable_light_curve_collection_iter = loop_iter_function(injectable_collection.observation_iter)
injectable_light_curve_collection_iters.append(injectable_light_curve_collection_iter)
while True:
for (
base_light_curve_collection_iter_and_type_pair
base_light_curve_collection_iter_and_type_pair
) in base_light_curve_collection_iter_and_type_pairs:
(
base_collection_iter,
collection_type,
) = base_light_curve_collection_iter_and_type_pair
(base_collection_iter, collection_type) = base_light_curve_collection_iter_and_type_pair
if collection_type in [
LightCurveCollectionType.STANDARD,
LightCurveCollectionType.STANDARD_AND_INJECTEE,
Expand All @@ -135,9 +118,7 @@ def __iter__(self):
LightCurveCollectionType.INJECTEE,
LightCurveCollectionType.STANDARD_AND_INJECTEE,
]:
for (
injectable_light_curve_collection_iter
) in injectable_light_curve_collection_iters:
for (injectable_light_curve_collection_iter) in injectable_light_curve_collection_iters:
injectable_light_curve = next(
injectable_light_curve_collection_iter
)
Expand All @@ -152,18 +133,16 @@ def __iter__(self):

@classmethod
def new(
cls,
standard_light_curve_collections: list[LightCurveObservationCollection]
| None = None,
injectee_light_curve_collections: list[LightCurveObservationCollection]
| None = None,
injectable_light_curve_collections: list[LightCurveObservationCollection]
| None = None,
post_injection_transform: Callable[[Any], Any] | None = None,
cls,
*,
standard_light_curve_collections: list[LightCurveObservationCollection] | None = None,
injectee_light_curve_collections: list[LightCurveObservationCollection] | None = None,
injectable_light_curve_collections: list[LightCurveObservationCollection] | None = None,
post_injection_transform: Callable[[Any], Any] | None = None,
) -> Self:
if (
standard_light_curve_collections is None
and injectee_light_curve_collections is None
standard_light_curve_collections is None
and injectee_light_curve_collections is None
):
error_message = (
"Either the standard or injectee light curve collection lists must be specified. "
Expand All @@ -190,8 +169,8 @@ def new(


def inject_light_curve(
injectee_observation: LightCurveObservation,
injectable_observation: LightCurveObservation,
injectee_observation: LightCurveObservation,
injectable_observation: LightCurveObservation,
) -> LightCurveObservation:
(
fluxes_with_injected_signal,
Expand Down Expand Up @@ -298,7 +277,7 @@ def __iter__(self):


def default_light_curve_observation_post_injection_transform(
x: LightCurveObservation, length: int
x: LightCurveObservation, length: int
) -> (Tensor, Tensor):
x = remove_nan_flux_data_points_from_light_curve_observation(x)
x = randomly_roll_light_curve_observation(x)
Expand Down Expand Up @@ -326,18 +305,18 @@ def normalize_tensor_by_modified_z_score(tensor: Tensor) -> Tensor:
median_absolute_deviation_from_median = torch.median(absolute_deviation_from_median)
if median_absolute_deviation_from_median != 0:
modified_z_score = (
0.6745 * deviation_from_median / median_absolute_deviation_from_median
0.6745 * deviation_from_median / median_absolute_deviation_from_median
)
else:
modified_z_score = torch.zeros_like(tensor)
return modified_z_score


def make_fluxes_and_label_array_uniform_length(
arrays: tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]],
length: int,
*,
randomize: bool = True,
arrays: tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]],
length: int,
*,
randomize: bool = True,
) -> (np.ndarray, np.ndarray):
fluxes, label = arrays
uniform_length_times = make_uniform_length(
Expand All @@ -347,7 +326,7 @@ def make_fluxes_and_label_array_uniform_length(


def make_uniform_length(
example: np.ndarray, length: int, *, randomize: bool = True
example: np.ndarray, length: int, *, randomize: bool = True
) -> np.ndarray:
"""Makes the example a specific length, by clipping those too large and repeating those too small."""
if len(example.shape) not in [1, 2]: # Only tested for 1D and 2D cases.
Expand Down Expand Up @@ -395,14 +374,13 @@ class BaselineFluxEstimationMethod(Enum):


def inject_signal_into_light_curve_with_intermediates(
light_curve_times: npt.NDArray[np.float64],
light_curve_fluxes: npt.NDArray[np.float64],
signal_times: npt.NDArray[np.float64],
signal_magnifications: npt.NDArray[np.float64],
out_of_bounds_injection_handling_method: OutOfBoundsInjectionHandlingMethod = (
OutOfBoundsInjectionHandlingMethod.ERROR
),
baseline_flux_estimation_method: BaselineFluxEstimationMethod = BaselineFluxEstimationMethod.MEDIAN,
light_curve_times: npt.NDArray[np.float64],
light_curve_fluxes: npt.NDArray[np.float64],
signal_times: npt.NDArray[np.float64],
signal_magnifications: npt.NDArray[np.float64],
out_of_bounds_injection_handling_method: OutOfBoundsInjectionHandlingMethod = (
OutOfBoundsInjectionHandlingMethod.ERROR),
baseline_flux_estimation_method: BaselineFluxEstimationMethod = BaselineFluxEstimationMethod.MEDIAN,
) -> (npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]):
"""
Injects a synthetic magnification signal into real light curve fluxes.
Expand All @@ -423,12 +401,12 @@ def inject_signal_into_light_curve_with_intermediates(
light_curve_time_length = np.max(relative_light_curve_times)
time_length_difference = light_curve_time_length - signal_time_length
signal_start_offset = (
np.random.random() * time_length_difference
) + minimum_light_curve_time
np.random.random() * time_length_difference
) + minimum_light_curve_time
offset_signal_times = relative_signal_times + signal_start_offset
if (
baseline_flux_estimation_method
== BaselineFluxEstimationMethod.MEDIAN_ABSOLUTE_DEVIATION
baseline_flux_estimation_method
== BaselineFluxEstimationMethod.MEDIAN_ABSOLUTE_DEVIATION
):
baseline_flux = stats.median_abs_deviation(light_curve_fluxes)
baseline_to_median_absolute_deviation_ratio = (
Expand All @@ -439,16 +417,16 @@ def inject_signal_into_light_curve_with_intermediates(
baseline_flux = np.median(light_curve_fluxes)
signal_fluxes = (signal_magnifications * baseline_flux) - baseline_flux
if (
out_of_bounds_injection_handling_method
is OutOfBoundsInjectionHandlingMethod.RANDOM_INJECTION_LOCATION
out_of_bounds_injection_handling_method
is OutOfBoundsInjectionHandlingMethod.RANDOM_INJECTION_LOCATION
):
signal_flux_interpolator = interp1d(
offset_signal_times, signal_fluxes, bounds_error=False, fill_value=0
)
elif (
out_of_bounds_injection_handling_method
is OutOfBoundsInjectionHandlingMethod.REPEAT_SIGNAL
and time_length_difference > 0
out_of_bounds_injection_handling_method
is OutOfBoundsInjectionHandlingMethod.REPEAT_SIGNAL
and time_length_difference > 0
):
before_signal_gap = signal_start_offset - minimum_light_curve_time
after_signal_gap = time_length_difference - before_signal_gap
Expand All @@ -465,13 +443,13 @@ def inject_signal_into_light_curve_with_intermediates(
repeated_signal_times = None
for repeat_index in range(-before_repeats_needed, after_repeats_needed + 1):
repeat_signal_start_offset = (
signal_time_length + minimum_signal_time_step
) * repeat_index
signal_time_length + minimum_signal_time_step
) * repeat_index
if repeated_signal_times is None:
repeated_signal_times = offset_signal_times + repeat_signal_start_offset
else:
repeat_index_signal_times = (
offset_signal_times + repeat_signal_start_offset
offset_signal_times + repeat_signal_start_offset
)
repeated_signal_times = np.concatenate(
[repeated_signal_times, repeat_index_signal_times]
Expand Down
19 changes: 10 additions & 9 deletions src/qusi/internal/train_hyperparameter_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@ class TrainHyperparameterConfiguration:

@classmethod
def new(
cls,
learning_rate: float = 1e-4,
optimizer_epsilon: float = 1e-7,
weight_decay: float = 0.0001,
batch_size: int = 100,
train_steps_per_cycle: int = 100,
validation_steps_per_cycle: int = 10,
cycles: int = 5000,
norm_based_gradient_clip: float = 1.0,
cls,
*,
learning_rate: float = 1e-4,
optimizer_epsilon: float = 1e-7,
weight_decay: float = 0.0001,
batch_size: int = 100,
train_steps_per_cycle: int = 100,
validation_steps_per_cycle: int = 10,
cycles: int = 5000,
norm_based_gradient_clip: float = 1.0,
):
return cls(
learning_rate=learning_rate,
Expand Down
9 changes: 5 additions & 4 deletions src/qusi/internal/train_logging_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ class TrainLoggingConfiguration:

@classmethod
def new(
cls,
wandb_project: str | None = None,
wandb_entity: str | None = None,
additional_log_dictionary: dict[str, Any] | None = None,
cls,
*,
wandb_project: str | None = None,
wandb_entity: str | None = None,
additional_log_dictionary: dict[str, Any] | None = None,
):
if additional_log_dictionary is None:
additional_log_dictionary = {}
Expand Down
31 changes: 16 additions & 15 deletions src/qusi/internal/train_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@


def train_session(
train_datasets: list[LightCurveDataset],
validation_datasets: list[LightCurveDataset],
model: Module,
loss_function: Module | None = None,
metric_functions: list[Module] | None = None,
hyperparameter_configuration: TrainHyperparameterConfiguration | None = None,
logging_configuration: TrainLoggingConfiguration | None = None,
train_datasets: list[LightCurveDataset],
validation_datasets: list[LightCurveDataset],
model: Module,
loss_function: Module | None = None,
metric_functions: list[Module] | None = None,
*,
hyperparameter_configuration: TrainHyperparameterConfiguration | None = None,
logging_configuration: TrainLoggingConfiguration | None = None,
):
if hyperparameter_configuration is None:
hyperparameter_configuration = TrainHyperparameterConfiguration.new()
Expand Down Expand Up @@ -112,13 +113,13 @@ def train_session(


def train_phase(
dataloader,
model,
loss_function,
metric_functions: list[Module],
optimizer,
steps,
device,
dataloader,
model,
loss_function,
metric_functions: list[Module],
optimizer,
steps,
device,
):
model.train()
total_loss = 0
Expand Down Expand Up @@ -173,7 +174,7 @@ def get_metric_name(metric_function):


def validation_phase(
dataloader, model, loss_function, metric_functions: list[Module], steps, device
dataloader, model, loss_function, metric_functions: list[Module], steps, device
):
model.eval()
validation_loss = 0
Expand Down
6 changes: 5 additions & 1 deletion src/qusi/internal/train_system_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,9 @@ class TrainSystemConfiguration:
preprocessing_processes_per_train_process: int

@classmethod
def new(cls, preprocessing_processes_per_train_process: int = 10):
def new(
cls,
*,
preprocessing_processes_per_train_process: int = 10
):
return cls(preprocessing_processes_per_train_process=preprocessing_processes_per_train_process)
Loading

0 comments on commit 7b7ea94

Please sign in to comment.