Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bugs in DataLoader and tutorial5 notebook #376

Draft
wants to merge 7 commits into
base: 0.2
Choose a base branch
from
Draft
2 changes: 1 addition & 1 deletion pina/condition/data_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class DataConditionInterface(ConditionInterface):
"""

__slots__ = ["input_points", "conditional_variables"]
condition_type = ['unsupervised']

def __init__(self, input_points, conditional_variables=None):
"""
Expand All @@ -23,7 +24,6 @@ def __init__(self, input_points, conditional_variables=None):
super().__init__()
self.input_points = input_points
self.conditional_variables = conditional_variables
self._condition_type = 'unsupervised'

def __setattr__(self, key, value):
if (key == 'input_points') or (key == 'conditional_variables'):
Expand Down
3 changes: 1 addition & 2 deletions pina/condition/domain_equation_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@ class DomainEquationCondition(ConditionInterface):
"""

__slots__ = ["domain", "equation"]

condition_type = ['physics']
def __init__(self, domain, equation):
"""
TODO
"""
super().__init__()
self.domain = domain
self.equation = equation
self._condition_type = 'physics'

def __setattr__(self, key, value):
if key == 'domain':
Expand Down
3 changes: 1 addition & 2 deletions pina/condition/input_equation_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@ class InputPointsEquationCondition(ConditionInterface):
"""

__slots__ = ["input_points", "equation"]

condition_type = ['physics']
def __init__(self, input_points, equation):
"""
TODO
"""
super().__init__()
self.input_points = input_points
self.equation = equation
self._condition_type = 'physics'

def __setattr__(self, key, value):
if key == 'input_points':
Expand Down
3 changes: 1 addition & 2 deletions pina/condition/input_output_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@ class InputOutputPointsCondition(ConditionInterface):
"""

__slots__ = ["input_points", "output_points"]

condition_type = ['supervised']
def __init__(self, input_points, output_points):
"""
TODO
"""
super().__init__()
self.input_points = input_points
self.output_points = output_points
self._condition_type = ['supervised', 'physics']

def __setattr__(self, key, value):
if (key == 'input_points') or (key == 'output_points'):
Expand Down
9 changes: 4 additions & 5 deletions pina/data/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import logging

from torch.utils.data import Dataset

from ..label_tensor import LabelTensor


Expand Down Expand Up @@ -109,14 +108,14 @@ def initialize(self):
already filled
"""
logging.debug(f'Initialize dataset {self.__class__.__name__}')

if self.num_el_per_condition:
self.condition_indices = torch.cat([
torch.tensor([i] * self.num_el_per_condition[i],
dtype=torch.uint8)
torch.tensor(
[self.conditions_idx[i]] * self.num_el_per_condition[i],
dtype=torch.uint8)
for i in range(len(self.num_el_per_condition))
],
dim=0)
dim=0)
for slot in self.__slots__:
current_attribute = getattr(self, slot)
if all(isinstance(a, LabelTensor) for a in current_attribute):
Expand Down
66 changes: 50 additions & 16 deletions pina/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ def __init__(self,
problem,
device,
train_size=.7,
test_size=.1,
val_size=.2,
test_size=.2,
val_size=.1,
predict_size=0.,
batch_size=None,
shuffle=True,
Expand Down Expand Up @@ -61,28 +61,31 @@ def __init__(self,
if train_size > 0:
self.split_names.append('train')
self.split_length.append(train_size)
self.loader_functions['train_dataloader'] = lambda: PinaDataLoader(
self.splits['train'], self.batch_size, self.condition_names)
else:
self.train_dataloader = super().train_dataloader

if test_size > 0:
self.split_length.append(test_size)
self.split_names.append('test')
self.loader_functions['test_dataloader'] = lambda: PinaDataLoader(
self.splits['test'], self.batch_size, self.condition_names)
else:
self.test_dataloader = super().test_dataloader

if val_size > 0:
self.split_length.append(val_size)
self.split_names.append('val')
self.loader_functions['val_dataloader'] = lambda: PinaDataLoader(
self.splits['val'], self.batch_size, self.condition_names)
else:
self.val_dataloader = super().val_dataloader

if predict_size > 0:
self.split_length.append(predict_size)
self.split_names.append('predict')
self.loader_functions['predict_dataloader'] = lambda: PinaDataLoader(
self.splits['predict'], self.batch_size, self.condition_names)
else:
self.predict_dataloader = super().predict_dataloader

self.splits = {k: {} for k in self.split_names}
self.shuffle = shuffle

for k, v in self.loader_functions.items():
setattr(self, k, v)
self.has_setup_fit = False
self.has_setup_test = False

def prepare_data(self):
if self.datasets is None:
Expand All @@ -104,8 +107,12 @@ def setup(self, stage=None):
for i in range(len(self.split_length)):
self.splits[self.split_names[i]][
dataset.data_type] = splits[i]
self.has_setup_fit = True
elif stage == 'test':
raise NotImplementedError("Testing pipeline not implemented yet")
if self.has_setup_fit is False:
raise NotImplementedError(
"You must call setup with stage='fit' "
"first")
else:
raise ValueError("stage must be either 'fit' or 'test'")

Expand Down Expand Up @@ -140,12 +147,11 @@ def dataset_split(dataset, lengths, seed=None, shuffle=True):
indices = torch.randperm(sum(lengths))
dataset.apply_shuffle(indices)

indices = torch.arange(0, sum(lengths), 1, dtype=torch.uint8).tolist()
offsets = [
sum(lengths[:i]) if i > 0 else 0 for i in range(len(lengths))
]
return [
PinaSubset(dataset, indices[offset:offset + length])
PinaSubset(dataset, slice(offset, offset + length))
for offset, length in zip(offsets, lengths)
]

Expand Down Expand Up @@ -177,3 +183,31 @@ def _create_datasets(self):
dataset.initialize()
datasets.append(dataset)
self.datasets = datasets

def val_dataloader(self):
    """
    Build the dataloader that iterates over the validation split.

    The loader is created from the ``'val'`` entry of ``self.splits``
    together with the batch size and condition names stored on this object.

    :return: a new loader wrapping the validation split.
    :rtype: PinaDataLoader
    """
    validation_split = self.splits['val']
    loader = PinaDataLoader(validation_split, self.batch_size,
                            self.condition_names)
    return loader

def train_dataloader(self):
    """
    Build the dataloader that iterates over the training split.

    The loader is created from the ``'train'`` entry of ``self.splits``
    together with the batch size and condition names stored on this object.

    :return: a new loader wrapping the training split.
    :rtype: PinaDataLoader
    """
    training_split = self.splits['train']
    loader = PinaDataLoader(training_split, self.batch_size,
                            self.condition_names)
    return loader

def test_dataloader(self):
    """
    Build the dataloader that iterates over the testing split.

    The loader is created from the ``'test'`` entry of ``self.splits``
    together with the batch size and condition names stored on this object.

    :return: a new loader wrapping the testing split.
    :rtype: PinaDataLoader
    """
    testing_split = self.splits['test']
    loader = PinaDataLoader(testing_split, self.batch_size,
                            self.condition_names)
    return loader

def predict_dataloader(self):
    """
    Build the dataloader that iterates over the prediction split.

    The loader is created from the ``'predict'`` entry of ``self.splits``
    together with the batch size and condition names stored on this object.

    :return: a new loader wrapping the prediction split.
    :rtype: PinaDataLoader
    """
    prediction_split = self.splits['predict']
    loader = PinaDataLoader(prediction_split, self.batch_size,
                            self.condition_names)
    return loader
27 changes: 12 additions & 15 deletions pina/data/pina_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,16 @@ class Batch:
def __init__(self, dataset_dict, idx_dict, require_grad=True):
self.attributes = []
for k, v in dataset_dict.items():
setattr(self, k, v)
index = idx_dict[k]
if isinstance(v, PinaSubset):
dataset_index = v.indices
if isinstance(dataset_index, slice):
index = slice(dataset_index.start + index.start,
min(dataset_index.start + index.stop,
dataset_index.stop))
setattr(self, k, PinaSubset(v.dataset, index,
require_grad=require_grad))
self.attributes.append(k)

for k, v in idx_dict.items():
setattr(self, k + '_idx', v)
self.require_grad = require_grad

def __len__(self):
Expand All @@ -27,21 +32,13 @@ def __len__(self):
:rtype: int
"""
length = 0
for dataset in dir(self):
for dataset in self.attributes:
attribute = getattr(self, dataset)
if isinstance(attribute, list):
length += len(getattr(self, dataset))
length += len(attribute)
return length

def __getattribute__(self, item):
if item in super().__getattribute__('attributes'):
dataset = super().__getattribute__(item)
index = super().__getattribute__(item + '_idx')
return PinaSubset(dataset.dataset, dataset.indices[index])
return super().__getattribute__(item)

def __getattr__(self, item):
if item == 'data' and len(self.attributes) == 1:
item = self.attributes[0]
return super().__getattribute__(item)
return self.__getattribute__(item)
raise AttributeError(f"'Batch' object has no attribute '{item}'")
52 changes: 40 additions & 12 deletions pina/data/pina_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
This module is used to create an iterable object used during training
"""
import math

from .pina_batch import Batch


Expand All @@ -26,6 +27,7 @@ def __init__(self, dataset_dict, batch_size, condition_names) -> None:
"""
self.condition_names = condition_names
self.dataset_dict = dataset_dict
self.batch_size = batch_size
self._init_batches(batch_size)

def _init_batches(self, batch_size=None):
Expand All @@ -36,20 +38,46 @@ def _init_batches(self, batch_size=None):
n_elements = sum(len(v) for v in self.dataset_dict.values())
if batch_size is None:
batch_size = n_elements
indexes_dict = {}
self.batch_size = n_elements
n_batches = int(math.ceil(n_elements / batch_size))
for k, v in self.dataset_dict.items():
if n_batches != 1:
indexes_dict[k] = math.floor(len(v) / (n_batches - 1))
else:
indexes_dict[k] = len(v)
for i in range(n_batches):
temp_dict = {}
for k, v in indexes_dict.items():
if i != n_batches - 1:
temp_dict[k] = slice(i * v, (i + 1) * v)
indexes_dict = {
k: math.floor(len(v) / n_batches) if n_batches != 1 else len(v) for
k, v in self.dataset_dict.items()}

dataset_names = list(self.dataset_dict.keys())
num_el_per_batch = [{i: indexes_dict[i] for i in dataset_names} for _
in range(n_batches - 1)] + [
{i: 0 for i in dataset_names}]
reminders = {
i: len(self.dataset_dict[i]) - indexes_dict[i] * (n_batches - 1) for
i in dataset_names}
dataset_names = iter(dataset_names)
name = next(dataset_names, None)
for batch in num_el_per_batch:
tot_num_el = sum(batch.values())
batch_reminder = batch_size - tot_num_el
for _ in range(batch_reminder):
if name is None:
break
if reminders[name] > 0:
batch[name] += 1
reminders[name] -= 1
else:
temp_dict[k] = slice(i * v, len(self.dataset_dict[k]))
name = next(dataset_names, None)
if name is None:
break
batch[name] += 1
reminders[name] -= 1

reminders, dataset_names, indexes_dict = None, None, None # free memory
actual_indices = {k: 0 for k in self.dataset_dict.keys()}
for batch in num_el_per_batch:
temp_dict = {}
total_length = 0
for k, v in batch.items():
temp_dict[k] = slice(actual_indices[k], actual_indices[k] + v)
actual_indices[k] = actual_indices[k] + v
total_length += v
self.batches.append(
Batch(idx_dict=temp_dict, dataset_dict=self.dataset_dict))

Expand Down
19 changes: 16 additions & 3 deletions pina/data/pina_subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class PinaSubset:
"""
__slots__ = ['dataset', 'indices', 'require_grad']

def __init__(self, dataset, indices, require_grad=True):
def __init__(self, dataset, indices, require_grad=False):
"""
TODO
"""
Expand All @@ -23,14 +23,27 @@ def __len__(self):
"""
TODO
"""
if isinstance(self.indices, slice):
return self.indices.stop - self.indices.start
return len(self.indices)

def __getattr__(self, name):
tensor = self.dataset.__getattribute__(name)
if isinstance(tensor, (LabelTensor, Tensor)):
tensor = tensor[[self.indices]].to(self.dataset.device)
if isinstance(self.indices, slice):
tensor = tensor[self.indices]
if (tensor.device != self.dataset.device
and tensor.dtype == float32):
tensor = tensor.to(self.dataset.device)
elif isinstance(self.indices, list):
tensor = tensor[[self.indices]].to(self.dataset.device)
else:
raise ValueError(f"Indices type {type(self.indices)} not "
f"supported")
return tensor.requires_grad_(
self.require_grad) if tensor.dtype == float32 else tensor
if isinstance(tensor, list):
return [tensor[i] for i in self.indices]
if isinstance(self.indices, list):
return [tensor[i] for i in self.indices]
return tensor[self.indices]
raise AttributeError(f"No attribute named {name}")
6 changes: 3 additions & 3 deletions pina/domain/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,15 +236,15 @@ def _single_points_sample(n, variables):
return result

if self.fixed_ and (not self.range_):
return _single_points_sample(n, variables)
return _single_points_sample(n, variables).sort_labels()

if variables == "all":
variables = list(self.range_.keys()) + list(self.fixed_.keys())

if mode in ["grid", "chebyshev"]:
return _1d_sampler(n, mode, variables)
return _1d_sampler(n, mode, variables).sort_labels()
elif mode in ["random", "lh", "latin"]:
return _Nd_sampler(n, mode, variables)
return _Nd_sampler(n, mode, variables).sort_labels()
else:
raise ValueError(f"mode={mode} is not valid.")

Expand Down
2 changes: 1 addition & 1 deletion pina/domain/operation_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def _check_dimensions(self, geometries):
:type geometries: list[Location]
"""
for geometry in geometries:
if geometry.variables != geometries[0].variables:
if sorted(geometry.variables) != sorted(geometries[0].variables):
raise NotImplementedError(
f"The geometries need to have same dimensions and labels."
)
Loading
Loading