Skip to content

Commit

Permalink
Bug fix in LabelTensor, Dataset and DataLoader and codacy warning cor…
Browse files Browse the repository at this point in the history
…rection
  • Loading branch information
FilippoOlivo committed Nov 4, 2024
1 parent dbb5476 commit a7d4582
Show file tree
Hide file tree
Showing 9 changed files with 181 additions and 175 deletions.
2 changes: 1 addition & 1 deletion pina/data/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def initialize(self):
dtype=torch.uint8)
for i in range(len(self.num_el_per_condition))
],
dim=0)
dim=0)
for slot in self.__slots__:
current_attribute = getattr(self, slot)
if all(isinstance(a, LabelTensor) for a in current_attribute):
Expand Down
6 changes: 3 additions & 3 deletions pina/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ def __init__(self,
problem,
device,
train_size=.7,
test_size=.1,
val_size=.2,
test_size=.2,
val_size=.1,
predict_size=0.,
batch_size=None,
shuffle=True,
Expand Down Expand Up @@ -140,7 +140,7 @@ def dataset_split(dataset, lengths, seed=None, shuffle=True):
indices = torch.randperm(sum(lengths))
dataset.apply_shuffle(indices)

indices = torch.arange(0, sum(lengths), 1, dtype=torch.uint8).tolist()
indices = torch.arange(0, sum(lengths), 1, dtype=torch.int32).tolist()
offsets = [
sum(lengths[:i]) if i > 0 else 0 for i in range(len(lengths))
]
Expand Down
3 changes: 2 additions & 1 deletion pina/data/pina_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def __getattribute__(self, item):
if item in super().__getattribute__('attributes'):
dataset = super().__getattribute__(item)
index = super().__getattribute__(item + '_idx')
return PinaSubset(dataset.dataset, dataset.indices[index])
return PinaSubset(dataset.dataset, dataset.indices[index],
require_grad=self.require_grad)
return super().__getattribute__(item)

def __getattr__(self, item):
Expand Down
2 changes: 1 addition & 1 deletion pina/data/pina_subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class PinaSubset:
"""
__slots__ = ['dataset', 'indices', 'require_grad']

def __init__(self, dataset, indices, require_grad=True):
def __init__(self, dataset, indices, require_grad=False):
"""
TODO
"""
Expand Down
95 changes: 61 additions & 34 deletions pina/label_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ class LabelTensor(torch.Tensor):

@staticmethod
def __new__(cls, x, labels, *args, **kwargs):
full = kwargs.pop("full", False)
if isinstance(x, LabelTensor):
x.full = full
return x
else:
return super().__new__(cls, x, *args, **kwargs)
Expand Down Expand Up @@ -124,13 +126,12 @@ def _init_labels_from_dict(self, labels):
does not match with tensor shape
"""
tensor_shape = self.shape

if hasattr(self, 'full') and self.full:
labels = {
i: labels[i] if i in labels else {
'name': i
'name': i, 'dof': range(tensor_shape[i])
}
for i in labels.keys()
for i in range(len(tensor_shape))
}
for k, v in labels.items():
# Init labels from str
Expand Down Expand Up @@ -231,8 +232,8 @@ def extract(self, labels_to_extract):
if not isinstance(v, range):
extractor[idx_dim] = [dim_labels.index(i)
for i in v] if len(v) > 1 else slice(
dim_labels.index(v[0]),
dim_labels.index(v[0]) + 1)
dim_labels.index(v[0]),
dim_labels.index(v[0]) + 1)
else:
extractor[idx_dim] = slice(v.start, v.stop)

Expand Down Expand Up @@ -274,31 +275,34 @@ def cat(tensors, dim=0):
return tensors[0]
# Perform cat on tensors
new_tensor = torch.cat(tensors, dim=dim)

new_tensor_shape = new_tensor.shape
# Update labels
labels = LabelTensor.__create_labels_cat(tensors, dim)
labels = LabelTensor.__create_labels_cat(tensors, dim, new_tensor_shape)

return LabelTensor.__internal_init__(new_tensor, labels,
tensors[0].dim_names)

@staticmethod
def __create_labels_cat(tensors, dim):
def __create_labels_cat(tensors, dim, tensor_shape):
# Check if names and dof of the labels are the same in all dimensions
# except in dim
stored_labels = [tensor.stored_labels for tensor in tensors]

# check if:
# - labels dict have same keys
# - all labels are the same except for dimension dim
if not all(
all(stored_labels[i][k] == stored_labels[0][k]
for i in range(len(stored_labels)))
for k in stored_labels[0].keys() if k != dim):
if not all(all(stored_labels[i][k] == stored_labels[0][k]
for i in range(len(stored_labels)))
for k in stored_labels[0].keys() if k != dim):
raise RuntimeError('tensors must have the same shape and dof')

labels = {k: copy(v) for k, v in tensors[0].stored_labels.items()}
if dim in labels.keys():
last_dim_dof = [i for j in stored_labels for i in j[dim]['dof']]
labels_list = [j[dim]['dof'] for j in stored_labels]
if all(isinstance(j, range) for j in labels_list):
last_dim_dof = range(tensor_shape[dim])
else:
last_dim_dof = [i for j in labels_list for i in j]
labels[dim]['dof'] = last_dim_dof
return labels

Expand Down Expand Up @@ -329,7 +333,8 @@ def clone(self, *args, **kwargs):
:return: A copy of the tensor.
:rtype: LabelTensor
"""
labels = {k: copy(v) for k, v in self._labels.items()}
labels = {k: {sub_k: copy(sub_v) for sub_k, sub_v in v.items()} for k, v
in self.stored_labels.items()}
out = LabelTensor(super().clone(*args, **kwargs), labels)
return out

Expand Down Expand Up @@ -396,34 +401,47 @@ def vstack(label_tensors):
"""
return LabelTensor.cat(label_tensors, dim=0)

@profile
def __getitem__(self, index):
"""
TODO: Complete docstring
:param index:
:return:
"""
if isinstance(index,
str) or (isinstance(index, (tuple, list))
and all(isinstance(a, str) for a in index)):
if isinstance(index, str) or (isinstance(index, (tuple, list))
and all(
isinstance(a, str) for a in index)):
return self.extract(index)

if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
index = [index.nonzero().squeeze()]
selected_lt = super().__getitem__(index)



if isinstance(index, (int, slice)):
index = [index]

if index[0] == Ellipsis:
index = [slice(None)] * (self.ndim - 1) + [index[1]]

if hasattr(self, "labels"):
labels = {k: copy(v) for k, v in self.stored_labels.items()}
try:
stored_labels = self.stored_labels
labels = {}
for j, idx in enumerate(index):

if isinstance(idx, int):
selected_lt = selected_lt.unsqueeze(j)
if j in labels.keys() and idx != slice(None):
self._update_single_label(labels, labels, idx, j)
if j in self.stored_labels.keys() and idx != slice(None):
self._update_single_label(stored_labels, labels, idx, j)
labels.update(
{k: {sub_k: copy(sub_v) for sub_k, sub_v in v.items()} for k, v
in stored_labels.items() if k not in labels})
selected_lt = LabelTensor.__internal_init__(selected_lt, labels,
self.dim_names)
except AttributeError:
import warnings
warnings.warn('No attribute labels in LabelTensor')
return selected_lt

@staticmethod
Expand All @@ -436,29 +454,38 @@ def _update_single_label(old_labels, to_update_labels, index, dim):
:param dim: label index
:return:
"""

old_dof = old_labels[dim]['dof']
if not isinstance(
index,
(int, slice)) and len(index) == len(old_dof) and isinstance(
old_dof, range):
if isinstance(index, torch.Tensor) and index.ndim == 0:
index = int(index)
if (not isinstance(
index, (int, slice)) and len(index) == len(old_dof) and
isinstance(old_dof, range)):
return

if isinstance(index, torch.Tensor):
index = index.nonzero(
as_tuple=True
)[0] if index.dtype == torch.bool else index.tolist()
if isinstance(old_dof, range):
to_update_labels.update({
dim: {
'dof': index.tolist(),
'name': old_labels[dim]['name']
}
})
return
index = index.tolist()
if isinstance(index, list):
to_update_labels.update({
dim: {
'dof': [old_dof[i] for i in index],
'name': old_labels[dim]['name']
}
})
else:
to_update_labels.update(
{dim: {
'dof': old_dof[index],
'name': old_labels[dim]['name']
}})
return
to_update_labels.update(
{dim: {
'dof': old_dof[index],
'name': old_labels[dim]['name']
}})

def sort_labels(self, dim=None):

Expand Down
22 changes: 12 additions & 10 deletions pina/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(self,
batch_size=None,
train_size=.7,
test_size=.2,
eval_size=.1,
val_size=.1,
**kwargs):
"""
PINA Trainer class for customizing every aspect of training via flags.
Expand All @@ -39,11 +39,12 @@ def __init__(self,
check_consistency(batch_size, int)
self.train_size = train_size
self.test_size = test_size
self.eval_size = eval_size
self.val_size = val_size
self.solver = solver
self.batch_size = batch_size
self._create_loader()
self._move_to_device()
self.data_module = None

def _move_to_device(self):
device = self._accelerator_connector._parallel_devices[0]
Expand All @@ -64,7 +65,7 @@ def _create_loader(self):
if not self.solver.problem.collector.full:
error_message = '\n'.join([
f"""{" " * 13} ---> Condition {key} {"sampled" if value else
"not sampled"}""" for key, value in
"not sampled"}""" for key, value in
self._solver.problem.collector._is_conditions_ready.items()
])
raise RuntimeError('Cannot create Trainer if not all conditions '
Expand All @@ -77,13 +78,14 @@ def _create_loader(self):

device = devices[0]

data_module = PinaDataModule(problem=self.solver.problem,
device=device,
train_size=self.train_size,
test_size=self.test_size,
val_size=self.eval_size)
data_module.setup()
self._loader = data_module.train_dataloader()
self.data_module = PinaDataModule(problem=self.solver.problem,
device=device,
train_size=self.train_size,
test_size=self.test_size,
val_size=self.val_size,
batch_size=self.batch_size, )
self.data_module.setup()
self._loader = self.data_module.train_dataloader()

def train(self, **kwargs):
"""
Expand Down
70 changes: 35 additions & 35 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,49 +32,49 @@ class Poisson(SpatialProblem):

conditions = {
'gamma1':
Condition(domain=CartesianDomain({
'x': [0, 1],
'y': 1
}),
equation=FixedValue(0.0)),
Condition(domain=CartesianDomain({
'x': [0, 1],
'y': 1
}),
equation=FixedValue(0.0)),
'gamma2':
Condition(domain=CartesianDomain({
'x': [0, 1],
'y': 0
}),
equation=FixedValue(0.0)),
Condition(domain=CartesianDomain({
'x': [0, 1],
'y': 0
}),
equation=FixedValue(0.0)),
'gamma3':
Condition(domain=CartesianDomain({
'x': 1,
'y': [0, 1]
}),
equation=FixedValue(0.0)),
Condition(domain=CartesianDomain({
'x': 1,
'y': [0, 1]
}),
equation=FixedValue(0.0)),
'gamma4':
Condition(domain=CartesianDomain({
'x': 0,
'y': [0, 1]
}),
equation=FixedValue(0.0)),
Condition(domain=CartesianDomain({
'x': 0,
'y': [0, 1]
}),
equation=FixedValue(0.0)),
'D':
Condition(input_points=LabelTensor(torch.rand(size=(100, 2)),
['x', 'y']),
equation=my_laplace),
Condition(input_points=LabelTensor(torch.rand(size=(100, 2)),
['x', 'y']),
equation=my_laplace),
'data':
Condition(input_points=in_, output_points=out_),
Condition(input_points=in_, output_points=out_),
'data2':
Condition(input_points=in2_, output_points=out2_),
Condition(input_points=in2_, output_points=out2_),
'unsupervised':
Condition(
input_points=LabelTensor(torch.rand(size=(45, 2)), ['x', 'y']),
conditional_variables=LabelTensor(torch.ones(size=(45, 1)),
['alpha']),
),
Condition(
input_points=LabelTensor(torch.rand(size=(45, 2)), ['x', 'y']),
conditional_variables=LabelTensor(torch.ones(size=(45, 1)),
['alpha']),
),
'unsupervised2':
Condition(
input_points=LabelTensor(torch.rand(size=(90, 2)), ['x', 'y']),
conditional_variables=LabelTensor(torch.ones(size=(90, 1)),
['alpha']),
)
Condition(
input_points=LabelTensor(torch.rand(size=(90, 2)), ['x', 'y']),
conditional_variables=LabelTensor(torch.ones(size=(90, 1)),
['alpha']),
)
}


Expand Down
2 changes: 1 addition & 1 deletion tests/test_solvers/test_supervised_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def test_train_cpu():
batch_size=5,
train_size=1,
test_size=0.,
eval_size=0.)
val_size=0.)
trainer.train()
test_train_cpu()

Expand Down
Loading

0 comments on commit a7d4582

Please sign in to comment.