From 63fcfc5c3f6f143c5a8e4a2313d5f6dfbaaed465 Mon Sep 17 00:00:00 2001 From: golmschenk Date: Wed, 15 May 2024 14:26:17 -0400 Subject: [PATCH] Remove randomization from make_uniform_length and leave that to a separate randomly roll call --- docs/source/tutorials/crafting_standard_datasets.md | 2 +- src/qusi/internal/light_curve_dataset.py | 4 ++-- src/qusi/internal/light_curve_transforms.py | 6 +----- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/docs/source/tutorials/crafting_standard_datasets.md b/docs/source/tutorials/crafting_standard_datasets.md index 89273d9..9c06cf4 100644 --- a/docs/source/tutorials/crafting_standard_datasets.md +++ b/docs/source/tutorials/crafting_standard_datasets.md @@ -39,7 +39,7 @@ def default_light_curve_observation_post_injection_transform(x: LightCurveObserv if randomize: x = randomly_roll_light_curve_observation(x) x = from_light_curve_observation_to_fluxes_array_and_label_array(x) - x = (make_uniform_length(x[0], length=length, randomize=randomize), x[1]) # Make the fluxes a uniform length. + x = (make_uniform_length(x[0], length=length), x[1]) x = pair_array_to_tensor(x) x = (normalize_tensor_by_modified_z_score(x[0]), x[1]) return x diff --git a/src/qusi/internal/light_curve_dataset.py b/src/qusi/internal/light_curve_dataset.py index b0f435c..6aad6bc 100644 --- a/src/qusi/internal/light_curve_dataset.py +++ b/src/qusi/internal/light_curve_dataset.py @@ -306,7 +306,7 @@ def default_light_curve_observation_post_injection_transform( if randomize: x = randomly_roll_light_curve_observation(x) x = from_light_curve_observation_to_fluxes_array_and_label_array(x) - x = (make_uniform_length(x[0], length=length, randomize=randomize), x[1]) # Make the fluxes a uniform length. + x = (make_uniform_length(x[0], length=length), x[1]) # Make the fluxes a uniform length. x = pair_array_to_tensor(x) x = (normalize_tensor_by_modified_z_score(x[0]), x[1]) return x @@ -331,7 +331,7 @@ def default_light_curve_post_injection_transform( if randomize: x = randomly_roll_light_curve(x) x = x.fluxes - x = make_uniform_length(x, length=length, randomize=randomize) + x = make_uniform_length(x, length=length) x = torch.tensor(x, dtype=torch.float32) x = normalize_tensor_by_modified_z_score(x) return x diff --git a/src/qusi/internal/light_curve_transforms.py b/src/qusi/internal/light_curve_transforms.py index f5afa28..7abdeeb 100644 --- a/src/qusi/internal/light_curve_transforms.py +++ b/src/qusi/internal/light_curve_transforms.py @@ -63,16 +63,12 @@ def normalize_tensor_by_modified_z_score(tensor: Tensor) -> Tensor: return modified_z_score -def make_uniform_length( - example: np.ndarray, length: int, *, randomize: bool = True -) -> np.ndarray: +def make_uniform_length(example: np.ndarray, length: int) -> np.ndarray: """Makes the example a specific length, by clipping those too large and repeating those too small.""" if len(example.shape) not in [1, 2]: # Only tested for 1D and 2D cases. raise ValueError( f"Light curve dimensions expected to be in [1, 2], but found {len(example.shape)}" ) - if randomize: - example = randomly_roll_elements(example) if example.shape[0] == length: pass elif example.shape[0] > length: