Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bugs in DataLoader and tutorial5 notebook #376

Draft
wants to merge 7 commits into
base: 0.2
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pina/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
__all__ = [
"Trainer", "LabelTensor", "Plotter", "Condition", "SamplePointDataset",
"PinaDataModule", "PinaDataLoader", 'TorchOptimizer', 'Graph'
"PinaDataModule", "PinaDataLoader", 'TorchOptimizer', 'Graph', 'LabelParameter'
]

from .meta import *
from .label_tensor import LabelTensor
from .label_tensor import LabelTensor, LabelParameter
from .solvers.solver import SolverInterface
from .trainer import Trainer
from .plotter import Plotter
Expand Down
4 changes: 2 additions & 2 deletions pina/condition/data_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ class DataConditionInterface(ConditionInterface):
"""

__slots__ = ["input_points", "conditional_variables"]
condition_type = ['unsupervised']

def __init__(self, input_points, conditional_variables=None):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.input_points = input_points
self.conditional_variables = conditional_variables
self._condition_type = 'unsupervised'

def __setattr__(self, key, value):
if (key == 'input_points') or (key == 'conditional_variables'):
Expand Down
5 changes: 2 additions & 3 deletions pina/condition/domain_equation_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@ class DomainEquationCondition(ConditionInterface):
"""

__slots__ = ["domain", "equation"]

condition_type = ['physics']
def __init__(self, domain, equation):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.domain = domain
self.equation = equation
self._condition_type = 'physics'

def __setattr__(self, key, value):
if key == 'domain':
Expand Down
5 changes: 2 additions & 3 deletions pina/condition/input_equation_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@ class InputPointsEquationCondition(ConditionInterface):
"""

__slots__ = ["input_points", "equation"]

condition_type = ['physics']
def __init__(self, input_points, equation):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.input_points = input_points
self.equation = equation
self._condition_type = 'physics'

def __setattr__(self, key, value):
if key == 'input_points':
Expand Down
5 changes: 2 additions & 3 deletions pina/condition/input_output_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@ class InputOutputPointsCondition(ConditionInterface):
"""

__slots__ = ["input_points", "output_points"]

condition_type = ['supervised']
def __init__(self, input_points, output_points):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.input_points = input_points
self.output_points = output_points
self._condition_type = ['supervised', 'physics']

def __setattr__(self, key, value):
if (key == 'input_points') or (key == 'output_points'):
Expand Down
64 changes: 54 additions & 10 deletions pina/data/base_dataset.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""
Basic data module implementation
"""
import math
import torch
import logging

from torch.utils.data import Dataset

from ..label_tensor import LabelTensor
from .pina_subset import PinaSubset


class BaseDataset(Dataset):
Expand Down Expand Up @@ -41,7 +41,7 @@ def __init__(self, problem=None, device=torch.device('cpu')):
super().__init__()
self.empty = True
self.problem = problem
self.device = device
self.device = torch.device('cpu')
self.condition_indices = None
for slot in self.__slots__:
setattr(self, slot, [])
Expand All @@ -53,7 +53,7 @@ def __init__(self, problem=None, device=torch.device('cpu')):

def _init_from_problem(self, collector_dict):
"""
TODO
TODO : Add docstring
"""
for name, data in collector_dict.items():
keys = list(data.keys())
Expand Down Expand Up @@ -109,14 +109,14 @@ def initialize(self):
already filled
"""
logging.debug(f'Initialize dataset {self.__class__.__name__}')

if self.num_el_per_condition:
self.condition_indices = torch.cat([
torch.tensor([i] * self.num_el_per_condition[i],
dtype=torch.uint8)
torch.tensor(
[self.conditions_idx[i]] * self.num_el_per_condition[i],
dtype=torch.uint8)
for i in range(len(self.num_el_per_condition))
],
dim=0)
dim=0)
for slot in self.__slots__:
current_attribute = getattr(self, slot)
if all(isinstance(a, LabelTensor) for a in current_attribute):
Expand Down Expand Up @@ -152,6 +152,50 @@ def apply_shuffle(self, indices):
if slot != 'equation':
attribute = getattr(self, slot)
if isinstance(attribute, (LabelTensor, torch.Tensor)):
setattr(self, 'slot', attribute[[indices]])
setattr(self, slot, attribute[[indices]].detach())
if isinstance(attribute, list):
setattr(self, 'slot', [attribute[i] for i in indices])
setattr(self, slot, [attribute[i] for i in indices])
self.condition_indices = self.condition_indices[indices]

def eval_splitting_lengths(self, lengths):
    """
    Convert fractional split sizes into integer element counts.

    Each fraction is multiplied by ``len(self)`` and floored; the elements
    lost to flooring are then handed out one at a time, round-robin from
    the first split, so the returned counts add up to ``len(self)``.

    :param lengths: fraction of the dataset assigned to each split.
    :type lengths: list[float]
    :return: number of elements in each split, in the same order.
    :rtype: list[int]
    :raises ValueError: if the fractions add up to ``1 + 1e-3`` or more.
    """
    total = sum(lengths)
    # Only an excess is rejected: a total below 1 is accepted and the
    # unassigned elements are spread over the splits by the loop below.
    if total - 1 >= 1e-3:
        raise ValueError(
            f"Sum of lengths is {total}, which is greater than 1")

    len_dataset = len(self)
    counts = [int(math.floor(len_dataset * length)) for length in lengths]
    remainder = len_dataset - sum(counts)
    for i in range(remainder):
        counts[i % len(counts)] += 1
    return counts

def dataset_split(self, lengths, seed=None, shuffle=True):
    """
    Split the dataset into consecutive subsets.

    :param lengths: split sizes, converted to element counts by
        ``eval_splitting_lengths``.
    :param seed: seed for the permutation generator; when ``None`` the
        permutation is drawn without an explicit generator.
    :param shuffle: whether to shuffle the dataset in place (through
        ``apply_shuffle``) before splitting.
    :return: one subset per entry of ``lengths``, in the same order.
    :rtype: list[PinaSubset]
    """
    lengths = self.eval_splitting_lengths(lengths)

    if shuffle:
        generator = None
        if seed is not None:
            generator = torch.Generator()
            generator.manual_seed(seed)
        indices = torch.randperm(sum(lengths), generator=generator)
        self.apply_shuffle(indices)

    # Each subset is a contiguous slice; keep a running offset rather
    # than re-summing the preceding lengths for every split.
    subsets = []
    offset = 0
    for length in lengths:
        subsets.append(PinaSubset(self, slice(offset, offset + length)))
        offset += length
    return subsets
114 changes: 53 additions & 61 deletions pina/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@
This module provides basic data management functionality
"""

import math
import torch
import logging
from pytorch_lightning import LightningDataModule
from lightning.pytorch import LightningDataModule
from .sample_dataset import SamplePointDataset
from .supervised_dataset import SupervisedDataset
from .unsupervised_dataset import UnsupervisedDataset
from .pina_dataloader import PinaDataLoader
from .pina_subset import PinaSubset


class PinaDataModule(LightningDataModule):
Expand All @@ -23,8 +20,8 @@ def __init__(self,
problem,
device,
train_size=.7,
test_size=.1,
val_size=.2,
test_size=.2,
val_size=.1,
predict_size=0.,
batch_size=None,
shuffle=True,
Expand Down Expand Up @@ -61,28 +58,31 @@ def __init__(self,
if train_size > 0:
self.split_names.append('train')
self.split_length.append(train_size)
self.loader_functions['train_dataloader'] = lambda: PinaDataLoader(
self.splits['train'], self.batch_size, self.condition_names)
else:
self.train_dataloader = super().train_dataloader

if test_size > 0:
self.split_length.append(test_size)
self.split_names.append('test')
self.loader_functions['test_dataloader'] = lambda: PinaDataLoader(
self.splits['test'], self.batch_size, self.condition_names)
else:
self.test_dataloader = super().test_dataloader

if val_size > 0:
self.split_length.append(val_size)
self.split_names.append('val')
self.loader_functions['val_dataloader'] = lambda: PinaDataLoader(
self.splits['val'], self.batch_size, self.condition_names)
else:
self.val_dataloader = super().val_dataloader

if predict_size > 0:
self.split_length.append(predict_size)
self.split_names.append('predict')
self.loader_functions['predict_dataloader'] = lambda: PinaDataLoader(
self.splits['predict'], self.batch_size, self.condition_names)
else:
self.predict_dataloader = super().predict_dataloader

self.splits = {k: {} for k in self.split_names}
self.shuffle = shuffle

for k, v in self.loader_functions.items():
setattr(self, k, v)
self.has_setup_fit = False
self.has_setup_test = False

def prepare_data(self):
if self.datasets is None:
Expand All @@ -98,57 +98,21 @@ def setup(self, stage=None):
if stage == 'fit' or stage is None:
for dataset in self.datasets:
if len(dataset) > 0:
splits = self.dataset_split(dataset,
self.split_length,
shuffle=self.shuffle)
splits = dataset.dataset_split(
self.split_length,
shuffle=self.shuffle)
for i in range(len(self.split_length)):
self.splits[self.split_names[i]][
dataset.data_type] = splits[i]
self.has_setup_fit = True
elif stage == 'test':
raise NotImplementedError("Testing pipeline not implemented yet")
if self.has_setup_fit is False:
raise NotImplementedError(
"You must call setup with stage='fit' "
"first")
else:
raise ValueError("stage must be either 'fit' or 'test'")

@staticmethod
def dataset_split(dataset, lengths, seed=None, shuffle=True):
    """
    Split ``dataset`` into consecutive subsets.

    Fractions are multiplied by ``len(dataset)`` and floored; the elements
    lost to flooring are handed out round-robin from the first split.

    :param dataset: dataset object to split; it is shuffled in place
        through its ``apply_shuffle`` method when ``shuffle`` is true.
    :param lengths: fraction of the dataset assigned to each split.
    :param seed: seed for the permutation generator; when ``None`` the
        permutation is drawn without an explicit generator.
    :param shuffle: whether to shuffle the dataset before splitting.
    :return: one subset per entry of ``lengths``, in the same order.
    :rtype: list[PinaSubset]
    :raises ValueError: if the fractions add up to ``1 + 1e-3`` or more.
    """
    total = sum(lengths)
    # Only an excess is rejected: a total below 1 is accepted and the
    # unassigned elements are spread over the splits by the loop below.
    if total - 1 >= 1e-3:
        raise ValueError(
            f"Sum of lengths is {total}, which is greater than 1")

    len_dataset = len(dataset)
    lengths = [int(math.floor(len_dataset * length)) for length in lengths]
    remainder = len_dataset - sum(lengths)
    for i in range(remainder):
        lengths[i % len(lengths)] += 1

    if shuffle:
        generator = None
        if seed is not None:
            generator = torch.Generator()
            generator.manual_seed(seed)
        permutation = torch.randperm(sum(lengths), generator=generator)
        dataset.apply_shuffle(permutation)

    # Plain Python ints: a uint8 tensor cannot represent positions above
    # 255, which broke the subsets of any larger dataset.
    indices = list(range(sum(lengths)))
    subsets = []
    offset = 0
    for length in lengths:
        subsets.append(PinaSubset(dataset, indices[offset:offset + length]))
        offset += length
    return subsets

def _create_datasets(self):
"""
Create the dataset objects putting data
Expand Down Expand Up @@ -177,3 +141,31 @@ def _create_datasets(self):
dataset.initialize()
datasets.append(dataset)
self.datasets = datasets

def val_dataloader(self):
    """
    Build the dataloader over the ``'val'`` split.

    :return: loader wrapping ``self.splits['val']``.
    :rtype: PinaDataLoader
    """
    split = self.splits['val']
    return PinaDataLoader(
        split,
        self.batch_size,
        self.condition_names,
        device=self.device,
    )

def train_dataloader(self):
    """
    Build the dataloader over the ``'train'`` split.

    :return: loader wrapping ``self.splits['train']``.
    :rtype: PinaDataLoader
    """
    split = self.splits['train']
    return PinaDataLoader(
        split,
        self.batch_size,
        self.condition_names,
        device=self.device,
    )

def test_dataloader(self):
    """
    Build the dataloader over the ``'test'`` split.

    :return: loader wrapping ``self.splits['test']``.
    :rtype: PinaDataLoader
    """
    split = self.splits['test']
    return PinaDataLoader(
        split,
        self.batch_size,
        self.condition_names,
        device=self.device,
    )

def predict_dataloader(self):
    """
    Build the dataloader over the ``'predict'`` split.

    :return: loader wrapping ``self.splits['predict']``.
    :rtype: PinaDataLoader
    """
    split = self.splits['predict']
    return PinaDataLoader(
        split,
        self.batch_size,
        self.condition_names,
        device=self.device,
    )
Loading
Loading