Skip to content

Commit

Permalink
Merge pull request #184 from NeuroDiffGym/v0.6.1
Browse files Browse the repository at this point in the history
Hot fix: solve a fatal compatibility issue with torch v1.13
  • Loading branch information
shuheng-liu authored Dec 7, 2022
2 parents 4eac45c + 36966e8 commit 0f85aeb
Show file tree
Hide file tree
Showing 11 changed files with 121 additions and 62 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,6 @@ docs/_build/
_test/

.DS_Store

# Tensorboard
runs/
8 changes: 6 additions & 2 deletions neurodiffeq/_version_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ def warn_deprecate_class(new_class):
:return: a function that, when called, acts as if it is a class constructor
:rtype: callable
"""

@functools.wraps(new_class)
def old_class_getter(*args, **kwargs):
warnings.warn(f"This class name is deprecated, use {new_class} instead", FutureWarning)
Expand All @@ -26,12 +27,15 @@ def deprecated_alias(**aliases):
:return: A decorated function that can receive either `old_name` or `new_name` as input
:rtype: function
"""

def deco(f):
@functools.wraps(f) # preserves signature and docstring
def wrapper(*args, **kwargs):
_rename_kwargs(f.__name__, kwargs, aliases)
return f(*args, **kwargs)

return wrapper

return deco


Expand All @@ -40,5 +44,5 @@ def _rename_kwargs(func_name, kwargs, aliases):
if alias in kwargs:
if new in kwargs:
raise KeyError(f'{func_name} received both `{alias}` (deprecated) and `{new}` (recommended)')
warnings.warn(f'The argument `{alias}` is deprecated; use `{new}` instead for {func_name}.', FutureWarning)
kwargs[new] = kwargs.pop(alias)
warnings.warn(f'The argument `{alias}` is deprecated for {func_name}; use `{new}` instead.', FutureWarning)
kwargs[new] = kwargs.pop(alias)
17 changes: 10 additions & 7 deletions neurodiffeq/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,19 +262,19 @@ def __call__(self, solver):
)


class SetCriterion(ActionCallback):
class SetLossFn(ActionCallback):
r"""A callback that sets the ``criterion`` (a.k.a. loss function) of the solver.
Best used together with a condition callback.
:param criterion:
:param loss_fn:
The loss function to be set for the solver. It can be
- An instance of ``torch.nn.modules.loss._Loss``
which computes loss of the PDE/ODE residuals against a zero tensor.
- A callable object which maps residuals, function values, and input coordinates to a scalar loss; or
- A str which is present in ``neurodiffeq.losses._losses.keys()``.
:type criterion: ``torch.nn.modules.loss._Loss`` or callable or str.
:type loss_fn: ``torch.nn.modules.loss._Loss`` or callable or str.
:param reset:
If True, the criterion will be reset every time the callback is called.
Otherwise, the criterion will only be set once.
Expand All @@ -284,19 +284,22 @@ class SetCriterion(ActionCallback):
:type logger: str or ``logging.Logger``
"""

def __init__(self, criterion, reset=False, logger=None):
super(SetCriterion, self).__init__(logger=logger)
self.criterion = criterion
@deprecated_alias(criterion='loss_fn')
def __init__(self, loss_fn, reset=False, logger=None):
super(SetLossFn, self).__init__(logger=logger)
self.loss_fn = loss_fn
self.reset = reset
self.called = False

def __call__(self, solver):
if self.reset or (not self.called):
self.called = True
# noinspection PyProtectedMember
solver._set_criterion(self.criterion)
solver._set_loss_fn(self.loss_fn)


SetCriterion = warn_deprecate_class(SetLossFn)

class SetOptimizer(ActionCallback):
r"""A callback that sets the optimizer of the solver. Best used together with a condition callback.
Expand Down
6 changes: 2 additions & 4 deletions neurodiffeq/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,10 +671,8 @@ def _internal_vars(self) -> dict:
class PredefinedGenerator(BaseGenerator):
"""A generator for generating points that are fixed and predefined.
:param xs: The x-dimension of the trianing points
:type xs: `torch.Tensor`
:param ys: The y-dimension of the training points
:type ys: `torch.Tensor`
:param xs: training points that will be returned
:type xs: Tuple[`torch.Tensor`]
"""

def __init__(self, *xs):
Expand Down
2 changes: 1 addition & 1 deletion neurodiffeq/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ class CustomSolver1D(Solver1D):
train_generator=train_generator,
valid_generator=valid_generator,
optimizer=optimizer,
criterion=criterion,
loss_fn=criterion,
n_batches_train=n_batches_train,
n_batches_valid=n_batches_valid,
metrics=metrics,
Expand Down
2 changes: 1 addition & 1 deletion neurodiffeq/pde.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ class CustomSolver2D(Solver2D):
train_generator=train_generator,
valid_generator=valid_generator,
optimizer=optimizer,
criterion=criterion,
loss_fn=criterion,
n_batches_train=n_batches_train,
n_batches_valid=n_batches_valid,
metrics=metrics,
Expand Down
2 changes: 1 addition & 1 deletion neurodiffeq/pde_spherical.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def enforcer(net, cond, points):
valid_generator=valid_generator,
analytic_solutions=analytic_solutions,
optimizer=optimizer,
criterion=criterion,
loss_fn=criterion,
n_batches_train=1,
n_batches_valid=1,
enforcer=enforcer,
Expand Down
78 changes: 50 additions & 28 deletions neurodiffeq/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@


def _requires_closure(optimizer):
return inspect.signature(optimizer.step).parameters.get('closure').default == inspect._empty
# starting from torch v1.13, simple optimizers no longer have a `closure` argument
closure_param = inspect.signature(optimizer.step).parameters.get('closure')
return closure_param and closure_param.default == inspect._empty


class BaseSolver(ABC, PretrainedSolver):
Expand Down Expand Up @@ -60,7 +62,7 @@ class BaseSolver(ABC, PretrainedSolver):
:param optimizer:
The optimizer to be used for training.
:type optimizer: `torch.nn.optim.Optimizer`, optional
:param criterion:
:param loss_fn:
The loss function used for training.
- If a str, must be present in the keys of `neurodiffeq.losses._losses`.
Expand All @@ -72,7 +74,7 @@ class BaseSolver(ABC, PretrainedSolver):
to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the computational graph,
so that backpropagation can be performed.
:type criterion:
:type loss_fn:
str or `torch.nn.moduesl.loss._Loss` or callable
:param n_batches_train:
Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
Expand Down Expand Up @@ -107,9 +109,10 @@ class BaseSolver(ABC, PretrainedSolver):
:type shuffle: bool
"""

@deprecated_alias(criterion='loss_fn')
def __init__(self, diff_eqs, conditions,
nets=None, train_generator=None, valid_generator=None, analytic_solutions=None,
optimizer=None, criterion=None, n_batches_train=1, n_batches_valid=4,
optimizer=None, loss_fn=None, n_batches_train=1, n_batches_valid=4,
metrics=None, n_input_units=None, n_output_units=None,
# deprecated arguments are listed below
shuffle=None, batch_size=None):
Expand Down Expand Up @@ -176,7 +179,7 @@ def analytic_mse(*args):
self.metrics_history.update({'valid__' + name: [] for name in self.metrics_fn})

self.optimizer = optimizer if optimizer else Adam(set(chain.from_iterable(n.parameters() for n in self.nets)))
self._set_criterion(criterion)
self._set_loss_fn(loss_fn)

def make_pair_dict(train=None, valid=None):
return {'train': train, 'valid': valid}
Expand All @@ -203,15 +206,15 @@ def make_pair_dict(train=None, valid=None):
# the _phase variable is registered for callback functions to access
self._phase = None

def _set_criterion(self, criterion):
def _set_loss_fn(self, criterion):
if criterion is None:
self.criterion = lambda r, f, x: (r ** 2).mean()
self.loss_fn = lambda r, f, x: (r ** 2).mean()
elif isinstance(criterion, nn.modules.loss._Loss):
self.criterion = lambda r, f, x: criterion(r, torch.zeros_like(r))
self.loss_fn = lambda r, f, x: criterion(r, torch.zeros_like(r))
elif isinstance(criterion, str):
self.criterion = _losses[criterion.lower()]
self.loss_fn = _losses[criterion.lower()]
elif callable(criterion):
self.criterion = criterion
self.loss_fn = criterion
else:
raise TypeError(f"Unknown type of criterion {type(criterion)}")

Expand All @@ -236,6 +239,24 @@ def _batch_examples(self):
)
return self._batch

@property
def criterion(self):
warnings.warn(
f'`{self.__class__.__name__}`.criterion is a deprecated alias for `{self.__class__.__name__}.loss_fn`.'
f'The alias is only meant to be accessed by certain functions in `neurodiffeq.solver_utils` '
f'until proper fixes are made; by which time this alias will be removed.'
)
return self.loss_fn

@criterion.setter
def criterion(self, loss_fn):
warnings.warn(
f'`{self.__class__.__name__}`.criterion is a deprecated alias for `{self.__class__.__name__}.loss_fn`.'
f'The alias is only meant to be accessed by certain functions in `neurodiffeq.solver_utils` '
f'until proper fixes are made; by which time this alias will be removed.'
)
self.loss_fn = loss_fn

def compute_func_val(self, net, cond, *coordinates):
r"""Compute the function value evaluated on the points specified by ``coordinates``.
Expand Down Expand Up @@ -352,7 +373,7 @@ def closure(zero_grad=True):
residuals = self.diff_eqs(*funcs, *batch)
residuals = torch.cat(residuals, dim=1)
try:
loss = self.criterion(residuals, funcs, batch) + self.additional_loss(residuals, funcs, batch)
loss = self.loss_fn(residuals, funcs, batch) + self.additional_loss(residuals, funcs, batch)
except TypeError as e:
warnings.warn(
"You might need to update your code. "
Expand Down Expand Up @@ -507,7 +528,8 @@ def _get_internal_variables(self):
"metrics": self.metrics_fn,
"n_batches": self.n_batches,
"best_nets": self.best_nets,
"criterion": self.criterion,
"criterion": self.loss_fn,
"loss_fn": self.loss_fn,
"conditions": self.conditions,
"global_epoch": self.global_epoch,
"lowest_loss": self.lowest_loss,
Expand Down Expand Up @@ -766,7 +788,7 @@ class SolverSpherical(BaseSolver):
Optimizer to be used for training.
Defaults to a ``torch.optim.Adam`` instance that trains on all parameters of ``nets``.
:type optimizer: ``torch.nn.optim.Optimizer``, optional
:param criterion:
:param loss_fn:
The loss function used for training.
- If a str, must be present in the keys of `neurodiffeq.losses._losses`.
Expand All @@ -778,7 +800,7 @@ class SolverSpherical(BaseSolver):
to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the computational graph,
so that backpropagation can be performed.
:type criterion:
:type loss_fn:
str or `torch.nn.moduesl.loss._Loss` or callable
:param n_batches_train:
Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
Expand Down Expand Up @@ -820,7 +842,7 @@ class SolverSpherical(BaseSolver):

def __init__(self, pde_system, conditions, r_min=None, r_max=None,
nets=None, train_generator=None, valid_generator=None, analytic_solutions=None,
optimizer=None, criterion=None, n_batches_train=1, n_batches_valid=4, metrics=None, enforcer=None,
optimizer=None, loss_fn=None, n_batches_train=1, n_batches_valid=4, metrics=None, enforcer=None,
n_output_units=1,
# deprecated arguments are listed below
shuffle=None, batch_size=None):
Expand Down Expand Up @@ -848,7 +870,7 @@ def __init__(self, pde_system, conditions, r_min=None, r_max=None,
valid_generator=valid_generator,
analytic_solutions=analytic_solutions,
optimizer=optimizer,
criterion=criterion,
loss_fn=loss_fn,
n_batches_train=n_batches_train,
n_batches_valid=n_batches_valid,
metrics=metrics,
Expand Down Expand Up @@ -1025,7 +1047,7 @@ class Solver1D(BaseSolver):
Optimizer to be used for training.
Defaults to a ``torch.optim.Adam`` instance that trains on all parameters of ``nets``.
:type optimizer: ``torch.nn.optim.Optimizer``, optional
:param criterion:
:param loss_fn:
The loss function used for training.
- If a str, must be present in the keys of `neurodiffeq.losses._losses`.
Expand All @@ -1037,7 +1059,7 @@ class Solver1D(BaseSolver):
to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the computational graph,
so that backpropagation can be performed.
:type criterion:
:type loss_fn:
str or `torch.nn.moduesl.loss._Loss` or callable
:param n_batches_train:
Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
Expand Down Expand Up @@ -1073,7 +1095,7 @@ class Solver1D(BaseSolver):

def __init__(self, ode_system, conditions, t_min=None, t_max=None,
nets=None, train_generator=None, valid_generator=None, analytic_solutions=None, optimizer=None,
criterion=None, n_batches_train=1, n_batches_valid=4, metrics=None, n_output_units=1,
loss_fn=None, n_batches_train=1, n_batches_valid=4, metrics=None, n_output_units=1,
# deprecated arguments are listed below
batch_size=None, shuffle=None):

Expand All @@ -1098,7 +1120,7 @@ def __init__(self, ode_system, conditions, t_min=None, t_max=None,
valid_generator=valid_generator,
analytic_solutions=analytic_solutions,
optimizer=optimizer,
criterion=criterion,
loss_fn=loss_fn,
n_batches_train=n_batches_train,
n_batches_valid=n_batches_valid,
metrics=metrics,
Expand Down Expand Up @@ -1209,7 +1231,7 @@ class BundleSolver1D(BaseSolver):
Optimizer to be used for training.
Defaults to a ``torch.optim.Adam`` instance that trains on all parameters of ``nets``.
:type optimizer: ``torch.nn.optim.Optimizer``, optional
:param criterion:
:param loss_fn:
The loss function used for training.
- If a str, must be present in the keys of `neurodiffeq.losses._losses`.
Expand All @@ -1221,7 +1243,7 @@ class BundleSolver1D(BaseSolver):
to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the computational graph,
so that backpropagation can be performed.
:type criterion:
:type loss_fn:
str or `torch.nn.moduesl.loss._Loss` or callable
:param n_batches_train:
Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
Expand Down Expand Up @@ -1258,7 +1280,7 @@ class BundleSolver1D(BaseSolver):
def __init__(self, ode_system, conditions, t_min, t_max,
theta_min=None, theta_max=None,
nets=None, train_generator=None, valid_generator=None, analytic_solutions=None, optimizer=None,
criterion=None, n_batches_train=1, n_batches_valid=4, metrics=None, n_output_units=1,
loss_fn=None, n_batches_train=1, n_batches_valid=4, metrics=None, n_output_units=1,
# deprecated arguments are listed below
batch_size=None, shuffle=None):

Expand Down Expand Up @@ -1319,7 +1341,7 @@ def non_var_filter(*variables):
valid_generator=valid_generator,
analytic_solutions=analytic_solutions,
optimizer=optimizer,
criterion=criterion,
loss_fn=loss_fn,
n_batches_train=n_batches_train,
n_batches_valid=n_batches_valid,
metrics=metrics,
Expand Down Expand Up @@ -1420,7 +1442,7 @@ class Solver2D(BaseSolver):
Optimizer to be used for training.
Defaults to a ``torch.optim.Adam`` instance that trains on all parameters of ``nets``.
:type optimizer: ``torch.nn.optim.Optimizer``, optional
:param criterion:
:param loss_fn:
The loss function used for training.
- If a str, must be present in the keys of `neurodiffeq.losses._losses`.
Expand All @@ -1432,7 +1454,7 @@ class Solver2D(BaseSolver):
to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the computational graph,
so that backpropagation can be performed.
:type criterion:
:type loss_fn:
str or `torch.nn.moduesl.loss._Loss` or callable
:param n_batches_train:
Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
Expand Down Expand Up @@ -1468,7 +1490,7 @@ class Solver2D(BaseSolver):

def __init__(self, pde_system, conditions, xy_min=None, xy_max=None,
nets=None, train_generator=None, valid_generator=None, analytic_solutions=None, optimizer=None,
criterion=None, n_batches_train=1, n_batches_valid=4, metrics=None, n_output_units=1,
loss_fn=None, n_batches_train=1, n_batches_valid=4, metrics=None, n_output_units=1,
# deprecated arguments are listed below
batch_size=None, shuffle=None):

Expand All @@ -1493,7 +1515,7 @@ def __init__(self, pde_system, conditions, xy_min=None, xy_max=None,
valid_generator=valid_generator,
analytic_solutions=analytic_solutions,
optimizer=optimizer,
criterion=criterion,
loss_fn=loss_fn,
n_batches_train=n_batches_train,
n_batches_valid=n_batches_valid,
metrics=metrics,
Expand Down
Loading

0 comments on commit 0f85aeb

Please sign in to comment.