From 457820d8fd4dbfad13489395f99267a852dd827e Mon Sep 17 00:00:00 2001
From: FilippoOlivo
Date: Sat, 9 Nov 2024 14:07:37 +0100
Subject: [PATCH] Bug fix in SupervisedSolver and improve LabelTensor with
 override of __torch_function__ and __mul__

---
 pina/label_tensor.py           | 56 +++++++++++++++++++++++++++++++++-
 pina/solvers/pinns/basepinn.py | 24 ++++++---------
 pina/solvers/supervised.py     |  9 +++---
 3 files changed, 68 insertions(+), 21 deletions(-)

diff --git a/pina/label_tensor.py b/pina/label_tensor.py
index ad9034b2..a9c9da4e 100644
--- a/pina/label_tensor.py
+++ b/pina/label_tensor.py
@@ -15,7 +15,8 @@ def issubset(a, b):
     if isinstance(a, range) and isinstance(b, range):
         return a.start <= b.start and a.stop >= b.stop
     return False
-
+MATH_MODULES = {torch.sin, torch.cos, torch.exp, torch.tan, torch.log,
+                torch.sqrt}
 
 class LabelTensor(torch.Tensor):
     """Torch tensor with a label for any column."""
@@ -48,6 +49,59 @@ def __init__(self, x, labels, **kwargs):
         self.full = kwargs.get('full', full_labels)
         self.labels = labels
 
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        if func in MATH_MODULES:
+            str_labels = func.__name__
+            labels = copy(args[0].stored_labels)
+            lt = super().__torch_function__(func, types, args=args,
+                                            kwargs=kwargs)
+            lt_shape = lt.shape
+
+            if len(lt_shape) - 1 in labels.keys():
+                labels.update({
+                    len(lt_shape) - 1: {
+                        'dof': [f'{str_labels}({i})' for i in
+                                labels[len(lt_shape) - 1]['dof']],
+                        'name': len(lt_shape) - 1
+                    }
+                })
+            lt._labels = labels
+            return lt
+        return super().__torch_function__(func, types, args=args, kwargs=kwargs)
+
+    def __mul__(self, other):
+        lt = super().__mul__(other)
+        if isinstance(other, (int, float)):
+            if hasattr(self, '_labels'):
+                lt._labels = self._labels
+        if isinstance(other, LabelTensor):
+            lt_shape = lt.shape
+            labels = copy(self.stored_labels)
+            other_labels = other.stored_labels
+            check = False
+            for (k, v), (ko, vo) in zip(sorted(labels.items()),
+                                        sorted(other_labels.items())):
+                if k != ko:
+                    raise ValueError('Labels must be the same')
+                if k != len(lt_shape) - 1:
+                    if vo != v:
+                        raise ValueError('Labels must be the same')
+                else:
+                    check = True
+            if check:
+                labels.update({
+                    len(lt_shape) - 1: {'dof': [f'{i}{j}' for i, j in
+                        zip(self.stored_labels[len(lt_shape) - 1]['dof'],
+                            other.stored_labels[len(lt_shape) - 1]['dof'])],
+                        'name': self.stored_labels[len(lt_shape) - 1]['name']}
+                })
+
+            lt._labels = labels
+        return lt
+
     @classmethod
     def __internal_init__(cls,
diff --git a/pina/solvers/pinns/basepinn.py b/pina/solvers/pinns/basepinn.py
index 3762cc88..fbed0bc5 100644
--- a/pina/solvers/pinns/basepinn.py
+++ b/pina/solvers/pinns/basepinn.py
@@ -122,7 +122,7 @@ def training_step(self, batch, _):
             condition_idx = supervised.condition_indices
         else:
             condition_idx = torch.tensor([])
-
+        loss = torch.tensor(0, dtype=torch.float32)
         for condition_id in torch.unique(condition_idx).tolist():
             condition_name = self._dataloader.condition_names[condition_id]
             condition = self.problem.conditions[condition_name]
@@ -132,11 +132,8 @@ def training_step(self, batch, _):
 
             output_pts = out[condition_idx == condition_id]
             input_pts = pts[condition_idx == condition_id]
-            input_pts.labels = pts.labels
-            output_pts.labels = out.labels
-
-            loss = self.loss_data(input_points=input_pts, output_points=output_pts)
-            loss = loss.as_subclass(torch.Tensor)
+            loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
+            loss += loss_.as_subclass(torch.Tensor)
 
         condition_idx = physics.condition_indices
         for condition_id in torch.unique(condition_idx).tolist():
@@ -147,20 +144,18 @@ def training_step(self, batch, _):
             pts = batch.physics.input_points
             input_pts = pts[condition_idx == condition_id]
-            input_pts.labels = pts.labels
 
-            loss = self.loss_phys(pts, condition.equation)
+            loss_ = self.loss_phys(input_pts, condition.equation)
 
             # add condition losses for each epoch
-            condition_losses.append(loss)
+            loss += loss_.as_subclass(torch.Tensor)
 
         # clamp unknown parameters in InverseProblem (if needed)
         self._clamp_params()
 
         # total loss (must be a torch.Tensor)
-        total_loss = sum(condition_losses)
-        return total_loss.as_subclass(torch.Tensor)
+        return loss
 
-    def loss_data(self, input_points, output_points):
+    def loss_data(self, input_pts, output_pts):
         """
         The data loss for the PINN solver. It computes the loss between the
         network output against the true solution. This function
@@ -172,9 +167,8 @@ def training_step(self, batch, _):
         :return: The residual loss averaged on the input coordinates
         :rtype: torch.Tensor
         """
-        loss_value = self.loss(self.forward(input_points), output_points)
-        self.store_log(loss_value=float(loss_value))
-        return loss_value
+        return self._loss(self.forward(input_pts), output_pts)
+
 
     @abstractmethod
     def loss_phys(self, samples, equation):
diff --git a/pina/solvers/supervised.py b/pina/solvers/supervised.py
index b9258f7d..a2be1102 100644
--- a/pina/solvers/supervised.py
+++ b/pina/solvers/supervised.py
@@ -1,5 +1,4 @@
 """ Module for SupervisedSolver """
-
 import torch
 from torch.nn.modules.loss import _Loss
 from ..optim import TorchOptimizer, TorchScheduler
@@ -118,6 +117,7 @@
         """
         condition_idx = batch.supervised.condition_indices
+        loss = torch.tensor(0, dtype=torch.float32)
         for condition_id in range(condition_idx.min(), condition_idx.max() + 1):
             condition_name = self._dataloader.condition_names[condition_id]
             condition = self.problem.conditions[condition_name]
@@ -130,14 +130,13 @@
             if not hasattr(condition, "output_points"):
                 raise NotImplementedError(
                     f"{type(self).__name__} works only in data-driven mode.")
+
             output_pts = out[condition_idx == condition_id]
             input_pts = pts[condition_idx == condition_id]
-            input_pts.labels = pts.labels
-            output_pts.labels = out.labels
-            loss = self.loss_data(input_pts=input_pts, output_pts=output_pts)
-            loss = loss.as_subclass(torch.Tensor)
+            loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
+            loss += loss_.as_subclass(torch.Tensor)
 
         self.log("mean_loss", float(loss), prog_bar=True, logger=True)
         return loss
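
Reviewer note (not part of the patch): a minimal sketch of the behaviour the
new overrides introduce, assuming the LabelTensor(tensor, labels) constructor
shown in __init__ above and a tensor labelled only on its last dimension; the
variable names and two-column example are illustrative, not from the test
suite.

    import torch
    from pina.label_tensor import LabelTensor

    x = LabelTensor(torch.rand(4, 2), ['u', 'v'])

    # torch.sin is in MATH_MODULES, so __torch_function__ rewrites the
    # last-dimension dofs: ['u', 'v'] -> ['sin(u)', 'sin(v)'].
    y = torch.sin(x)

    # Multiplying by a scalar keeps the labels of the left operand.
    z = x * 2.0

    # Multiplying two LabelTensors concatenates the last-dimension labels
    # pairwise (['u', 'v'] with ['a', 'b'] -> ['ua', 'vb']); labels on any
    # other dimension must match, otherwise __mul__ raises ValueError.
    w = x * LabelTensor(torch.rand(4, 2), ['a', 'b'])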