Skip to content

Commit

Permalink
Merge pull request #71
Browse files Browse the repository at this point in the history
corrected_dataloader_workers
  • Loading branch information
golmschenk authored Jul 27, 2024
2 parents 8da5e8a + 289ba8d commit 60ad5e7
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 20 deletions.
27 changes: 18 additions & 9 deletions src/qusi/internal/light_curve_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __getitem__(self, indexes: int | tuple[int]) -> Path | tuple[Path]:


class PathGetterBase(PathIterableBase, PathIndexableBase):
pass
random_number_generator: Random


@dataclass
Expand Down Expand Up @@ -265,17 +265,26 @@ def observation_iter(self) -> Iterator[LightCurveObservation]:
:return: The iterable of the light curves.
"""
light_curve_paths = self.path_iter()
for light_curve_path in light_curve_paths:
light_curve_observation = self.observation_from_path(light_curve_path)
yield light_curve_observation

def observation_from_path(self, light_curve_path: Path) -> LightCurveObservation:
times, fluxes = self.light_curve_collection.load_times_and_fluxes_from_path(
light_curve_path
)
label = self.load_label_from_path_function(light_curve_path)
light_curve = LightCurve.new(times, fluxes)
light_curve_observation = LightCurveObservation.new(light_curve, label)
light_curve_observation.path = light_curve_path # TODO: Quick debug hack.
return light_curve_observation

def path_iter(self) -> Iterable[Path]:
light_curve_paths = self.path_getter.get_shuffled_paths()
if len(light_curve_paths) == 0:
raise ValueError('LightCurveObservationCollection returned no paths.')
for light_curve_path in light_curve_paths:
times, fluxes = self.light_curve_collection.load_times_and_fluxes_from_path(
light_curve_path
)
label = self.load_label_from_path_function(light_curve_path)
light_curve = LightCurve.new(times, fluxes)
light_curve_observation = LightCurveObservation.new(light_curve, label)
yield light_curve_observation
return light_curve_paths

def __getitem__(self, index: int) -> LightCurveObservation:
light_curve_path = self.path_getter[index]
Expand Down
44 changes: 33 additions & 11 deletions src/qusi/internal/light_curve_dataset.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from __future__ import annotations

import copy
import itertools
import math
import re
import shutil
import socket
from enum import Enum
from functools import partial
from pathlib import Path
from random import Random
from typing import TYPE_CHECKING, Any, Callable, TypeVar

import numpy as np
Expand Down Expand Up @@ -75,52 +77,64 @@ def __init__(
)
raise ValueError(error_message)
self.post_injection_transform: Callable[[Any], Any] = post_injection_transform
self.worker_randomizing_set: bool = False

def __iter__(self):
if not self.worker_randomizing_set:
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
self.seed_random(worker_info.id)
self.worker_randomizing_set = True
base_light_curve_collection_iter_and_type_pairs: list[
tuple[Iterator[LightCurveObservation], LightCurveCollectionType]
tuple[Iterator[Path], Callable[[Path], LightCurveObservation], LightCurveCollectionType]
] = []
injectee_collections = copy.copy(self.injectee_light_curve_collections)
for standard_collection in self.standard_light_curve_collections:
if standard_collection in injectee_collections:
base_light_curve_collection_iter_and_type_pairs.append(
(
loop_iter_function(standard_collection.observation_iter),
loop_iter_function(standard_collection.path_iter),
standard_collection.observation_from_path,
LightCurveCollectionType.STANDARD_AND_INJECTEE,
)
)
injectee_collections.remove(standard_collection)
else:
base_light_curve_collection_iter_and_type_pairs.append(
(
loop_iter_function(standard_collection.observation_iter),
loop_iter_function(standard_collection.path_iter),
standard_collection.observation_from_path,
LightCurveCollectionType.STANDARD,
)
)
for injectee_collection in injectee_collections:
base_light_curve_collection_iter_and_type_pair = (
loop_iter_function(injectee_collection.observation_iter),
loop_iter_function(injectee_collection.path_iter),
injectee_collection.observation_from_path,
LightCurveCollectionType.INJECTEE,
)
base_light_curve_collection_iter_and_type_pairs.append(base_light_curve_collection_iter_and_type_pair)
injectable_light_curve_collection_iters: list[
Iterator[LightCurveObservation]
tuple[Iterator[Path], Callable[[Path], LightCurveObservation]]
] = []
for injectable_collection in self.injectable_light_curve_collections:
injectable_light_curve_collection_iter = loop_iter_function(injectable_collection.observation_iter)
injectable_light_curve_collection_iters.append(injectable_light_curve_collection_iter)
injectable_light_curve_collection_iter = loop_iter_function(injectable_collection.path_iter)
injectable_light_curve_collection_iters.append(
(injectable_light_curve_collection_iter, injectable_collection.observation_from_path))
while True:
for (
base_light_curve_collection_iter_and_type_pair
) in base_light_curve_collection_iter_and_type_pairs:
(base_collection_iter, collection_type) = base_light_curve_collection_iter_and_type_pair
(base_collection_iter, observation_from_path_function,
collection_type) = base_light_curve_collection_iter_and_type_pair
if collection_type in [
LightCurveCollectionType.STANDARD,
LightCurveCollectionType.STANDARD_AND_INJECTEE,
]:
# TODO: Preprocessing step should be here. Or maybe that should all be on the light curve collection
# as well? Or passed in somewhere else?
standard_light_curve = next(base_collection_iter)
standard_path = next(base_collection_iter)
standard_light_curve = observation_from_path_function(standard_path)
transformed_standard_light_curve = self.post_injection_transform(
standard_light_curve
)
Expand All @@ -129,10 +143,12 @@ def __iter__(self):
LightCurveCollectionType.INJECTEE,
LightCurveCollectionType.STANDARD_AND_INJECTEE,
]:
for (injectable_light_curve_collection_iter) in injectable_light_curve_collection_iters:
injectable_light_curve = next(
for (injectable_light_curve_collection_iter,
injectable_observation_from_path_function) in injectable_light_curve_collection_iters:
injectable_light_path = next(
injectable_light_curve_collection_iter
)
injectable_light_curve = injectable_observation_from_path_function(injectable_light_path)
injectee_light_curve = next(base_collection_iter)
injected_light_curve = inject_light_curve(
injectee_light_curve, injectable_light_curve
Expand Down Expand Up @@ -188,6 +204,12 @@ def new(
)
return instance

def seed_random(self, seed: int):
for collection_group in [self.standard_light_curve_collections, self.injectee_light_curve_collections,
self.injectable_light_curve_collections]:
for collection in collection_group:
collection.path_getter.random_number_generator = Random(seed)


def inject_light_curve(
injectee_observation: LightCurveObservation,
Expand Down
Empty file.
42 changes: 42 additions & 0 deletions tests/integration_tests/test_light_curve_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from pathlib import Path

import numpy as np
import numpy.typing as npt
import torch
from torch.utils.data import DataLoader

from qusi.internal.light_curve_collection import LightCurveObservationCollection
from qusi.internal.light_curve_dataset import LightCurveDataset
from qusi.internal.light_curve_transforms import from_light_curve_observation_to_fluxes_array_and_label_array, \
pair_array_to_tensor


def get_paths() -> list[Path]:
return [Path('1'), Path('2'), Path('3'), Path('4'), Path('5'), Path('6'), Path('7'), Path('8')]

def load_times_and_fluxes_from_path(path: Path) -> [npt.NDArray, npt.NDArray]:
value = float(str(path))
return np.array([value]), np.array([value])

def load_label_from_path_function(path: Path) -> int:
value = int(str(path))
return value * 10

def post_injection_transform(x):
x = from_light_curve_observation_to_fluxes_array_and_label_array(x)
x = pair_array_to_tensor(x)
return x


def test_light_curve_dataset_with_and_without_multiple_workers_gives_same_batch_order():
light_curve_collection = LightCurveObservationCollection.new(
get_paths_function=get_paths,
load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path,
load_label_from_path_function=load_label_from_path_function)
light_curve_dataset = LightCurveDataset.new(standard_light_curve_collections=[light_curve_collection],
post_injection_transform=post_injection_transform)
multi_worker_dataloader = DataLoader(light_curve_dataset, batch_size=4, num_workers=2, prefetch_factor=1)
multi_worker_dataloader_iter = iter(multi_worker_dataloader)
multi_worker_batch0 = next(multi_worker_dataloader_iter)[0].numpy()[:, 0]
multi_worker_batch1 = next(multi_worker_dataloader_iter)[0].numpy()[:, 0]
assert not np.array_equal(multi_worker_batch0, multi_worker_batch1)

0 comments on commit 60ad5e7

Please sign in to comment.