Skip to content

Commit

Permalink
Remove deprecation warning, add on_after_load_sample option
Browse files Browse the repository at this point in the history
	modified:   data.py
	modified:   simulator.py
  • Loading branch information
cweniger committed Oct 21, 2022
1 parent 47093fa commit af56155
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
6 changes: 4 additions & 2 deletions swyft/lightning/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
batch_size: int = 32,
num_workers: int = 0,
shuffle: bool = False,
on_after_load_sample: Optional[callable] = None,
):
super().__init__()
self.data = data
Expand All @@ -60,6 +61,7 @@ def __init__(
self.batch_size = batch_size
self.num_workers = num_workers
self.shuffle = shuffle
self.on_after_load_sample = on_after_load_sample

@staticmethod
def _get_lengths(fractions, N):
Expand All @@ -72,15 +74,15 @@ def _get_lengths(fractions, N):

def setup(self, stage: str):
if isinstance(self.data, Samples):
dataset = self.data.get_dataset()
dataset = self.data.get_dataset(on_after_load_sample = self.on_after_load_sample)
splits = torch.utils.data.random_split(dataset, self.lengths)
self.dataset_train, self.dataset_val, self.dataset_test = splits
elif isinstance(self.data, swyft.ZarrStore):
idxr1 = (0, self.lengths[1])
idxr2 = (self.lengths[1], self.lengths[1] + self.lengths[2])
idxr3 = (self.lengths[1] + self.lengths[2], len(self.data))
self.dataset_train = self.data.get_dataset(
idx_range=idxr1, on_after_load_sample=None
idx_range=idxr1, on_after_load_sample=self.on_after_load_sample
)
self.dataset_val = self.data.get_dataset(
idx_range=idxr2, on_after_load_sample=None
Expand Down
3 changes: 1 addition & 2 deletions swyft/lightning/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,14 @@ def get_dataloader(
repeat=None,
num_workers=0,
):
"""(Deprecated) Generator function to directly generate a dataloader object.
"""Generator function to directly generate a dataloader object.
Args:
batch_size: batch_size for dataloader
shuffle: shuffle for dataloader
on_after_load_sample: see `get_dataset`
repeat: If not None, Wrap dataset in RepeatDatasetWrapper
"""
print("WARNING: Deprecated")
dataset = self.get_dataset(on_after_load_sample=on_after_load_sample)
if repeat is not None:
dataset = swyft.lightning.data.RepeatDatasetWrapper(dataset, repeat=repeat)
Expand Down

0 comments on commit af56155

Please sign in to comment.