Skip to content

Commit

Permalink
Bug fix in SupervisedSolver and improve LabelTensor with override of …
Browse files Browse the repository at this point in the history
…__torch_functions__ and __mul__
  • Loading branch information
FilippoOlivo committed Nov 9, 2024
1 parent 30e2fa8 commit 457820d
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 21 deletions.
56 changes: 55 additions & 1 deletion pina/label_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def issubset(a, b):
if isinstance(a, range) and isinstance(b, range):
return a.start <= b.start and a.stop >= b.stop
return False

MATH_MODULES = {torch.sin, torch.cos, torch.exp, torch.tan, torch.log,
torch.sqrt}

class LabelTensor(torch.Tensor):
"""Torch tensor with a label for any column."""
Expand Down Expand Up @@ -48,6 +49,59 @@ def __init__(self, x, labels, **kwargs):
self.full = kwargs.get('full', full_labels)
self.labels = labels

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func in MATH_MODULES:
str_labels = func.__name__
labels = copy(args[0].stored_labels)
lt = super().__torch_function__(func, types, args=args,
kwargs=kwargs)
lt_shape = lt.shape

if len(lt_shape) - 1 in labels.keys():
labels.update({
len(lt_shape) - 1: {
'dof': [f'{str_labels}({i})' for i in
labels[len(lt_shape) - 1]['dof']],
'name': len(lt_shape) - 1
}
})
lt._labels = labels
return lt
return super().__torch_function__(func, types, args=args, kwargs=kwargs)

def __mul__(self, other):
lt = super().__mul__(other)
if isinstance(other, (int, float)):
if hasattr(self, '_labels'):
lt._labels = self._labels
if isinstance(other, LabelTensor):
lt_shape = lt.shape
labels = copy(self.stored_labels)
other_labels = other.stored_labels
check = False
for (k, v), (ko, vo) in zip(sorted(labels.items()),
sorted(other_labels.items())):
if k != ko:
raise ValueError('Labels must be the same')
if k != len(lt_shape) - 1:
if vo != v:
raise ValueError('Labels must be the same')
else:
check = True
if check:
labels.update({
len(lt_shape) - 1: {'dof': [f'{i}{j}' for i, j in
zip(self.stored_labels[len(lt_shape) - 1]['dof'],
other.stored_labels[len(lt_shape) - 1]['dof'])],
'name': self.stored_labels[len(lt_shape) - 1]['name']}
})

lt._labels = labels
return lt

@classmethod
def __internal_init__(cls,
x,
Expand Down
24 changes: 9 additions & 15 deletions pina/solvers/pinns/basepinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def training_step(self, batch, _):
condition_idx = supervised.condition_indices
else:
condition_idx = torch.tensor([])

loss = torch.tensor(0, dtype=torch.float32)
for condition_id in torch.unique(condition_idx).tolist():
condition_name = self._dataloader.condition_names[condition_id]
condition = self.problem.conditions[condition_name]
Expand All @@ -132,11 +132,8 @@ def training_step(self, batch, _):
output_pts = out[condition_idx == condition_id]
input_pts = pts[condition_idx == condition_id]

input_pts.labels = pts.labels
output_pts.labels = out.labels

loss = self.loss_data(input_points=input_pts, output_points=output_pts)
loss = loss.as_subclass(torch.Tensor)
loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
loss += loss_.as_subclass(torch.Tensor)

condition_idx = physics.condition_indices
for condition_id in torch.unique(condition_idx).tolist():
Expand All @@ -147,20 +144,18 @@ def training_step(self, batch, _):
pts = batch.physics.input_points
input_pts = pts[condition_idx == condition_id]

input_pts.labels = pts.labels
loss = self.loss_phys(pts, condition.equation)
loss_ = self.loss_phys(input_pts, condition.equation)

# add condition losses for each epoch
condition_losses.append(loss)
loss += loss_.as_subclass(torch.Tensor)

# clamp unknown parameters in InverseProblem (if needed)
self._clamp_params()

# total loss (must be a torch.Tensor)
total_loss = sum(condition_losses)
return total_loss.as_subclass(torch.Tensor)
return loss

def loss_data(self, input_points, output_points):
def loss_data(self, input_pts, output_pts):
"""
The data loss for the PINN solver. It computes the loss between
the network output against the true solution. This function
Expand All @@ -172,9 +167,8 @@ def loss_data(self, input_points, output_points):
:return: The residual loss averaged on the input coordinates
:rtype: torch.Tensor
"""
loss_value = self.loss(self.forward(input_points), output_points)
self.store_log(loss_value=float(loss_value))
return loss_value
return self._loss(self.forward(input_pts), output_pts)


@abstractmethod
def loss_phys(self, samples, equation):
Expand Down
9 changes: 4 additions & 5 deletions pina/solvers/supervised.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
""" Module for SupervisedSolver """

import torch
from torch.nn.modules.loss import _Loss
from ..optim import TorchOptimizer, TorchScheduler
Expand Down Expand Up @@ -118,6 +117,7 @@ def training_step(self, batch, batch_idx):
"""

condition_idx = batch.supervised.condition_indices
loss = torch.tensor(0, dtype=torch.float32)
for condition_id in range(condition_idx.min(), condition_idx.max() + 1):
condition_name = self._dataloader.condition_names[condition_id]
condition = self.problem.conditions[condition_name]
Expand All @@ -130,14 +130,13 @@ def training_step(self, batch, batch_idx):
if not hasattr(condition, "output_points"):
raise NotImplementedError(
f"{type(self).__name__} works only in data-driven mode.")

output_pts = out[condition_idx == condition_id]
input_pts = pts[condition_idx == condition_id]

input_pts.labels = pts.labels
output_pts.labels = out.labels

loss = self.loss_data(input_pts=input_pts, output_pts=output_pts)
loss = loss.as_subclass(torch.Tensor)
loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
loss += loss_.as_subclass(torch.Tensor)

self.log("mean_loss", float(loss), prog_bar=True, logger=True)
return loss
Expand Down

0 comments on commit 457820d

Please sign in to comment.