From 616c3ee2bf05833cf5958ae8717c41a1fc4996df Mon Sep 17 00:00:00 2001
From: awehenkel
Date: Wed, 6 Jan 2021 15:38:00 +0100
Subject: [PATCH] Add UMNN transforms to this repository. Implemented the
 coupling and autoregressive versions. Simple unit tests pass and the adapted
 notebook examples work too.

---
 environment.yml                                |  1 +
 nflows/transforms/UMNN/MonotonicNormalizer.py  | 83 +++++++++++++++++++
 nflows/transforms/UMNN/__init__.py             |  1 +
 nflows/transforms/__init__.py                  |  2 +
 nflows/transforms/autoregressive.py            | 66 +++++++++++++++
 nflows/transforms/coupling.py                  | 70 +++++++++++++++-
 tests/transforms/autoregressive_test.py        | 18 ++++
 tests/transforms/coupling_test.py              | 47 +++++++++++
 8 files changed, 287 insertions(+), 1 deletion(-)
 create mode 100644 nflows/transforms/UMNN/MonotonicNormalizer.py
 create mode 100644 nflows/transforms/UMNN/__init__.py

diff --git a/environment.yml b/environment.yml
index ff0abee..b8cd96f 100644
--- a/environment.yml
+++ b/environment.yml
@@ -21,6 +21,7 @@ dependencies:
 - pip:
   - torchtestcase
   - -e . # install package in development mode
+  - umnn
 - pytest
 - python
 - pytorch
diff --git a/nflows/transforms/UMNN/MonotonicNormalizer.py b/nflows/transforms/UMNN/MonotonicNormalizer.py
new file mode 100644
index 0000000..49f22bc
--- /dev/null
+++ b/nflows/transforms/UMNN/MonotonicNormalizer.py
@@ -0,0 +1,83 @@
+import torch
+from UMNN import NeuralIntegral, ParallelNeuralIntegral
+import torch.nn as nn
+
+
+def _flatten(sequence):
+    flat = [p.contiguous().view(-1) for p in sequence]
+    return torch.cat(flat) if len(flat) > 0 else torch.tensor([])
+
+
+class ELUPlus(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.elu = nn.ELU()
+
+    def forward(self, x):
+        return self.elu(x) + 1.
+
+
+class IntegrandNet(nn.Module):
+    def __init__(self, hidden, cond_in):
+        super(IntegrandNet, self).__init__()
+        l1 = [1 + cond_in] + hidden
+        l2 = hidden + [1]
+        layers = []
+        for h1, h2 in zip(l1, l2):
+            layers += [nn.Linear(h1, h2), nn.ReLU()]
+        layers.pop()
+        layers.append(ELUPlus())
+        self.net = nn.Sequential(*layers)
+
+    def forward(self, x, h):
+        nb_batch, in_d = x.shape
+        x = torch.cat((x, h), 1)
+        x_he = x.view(nb_batch, -1, in_d).transpose(1, 2).contiguous().view(nb_batch * in_d, -1)
+        y = self.net(x_he).view(nb_batch, -1)
+        return y
+
+
+class MonotonicNormalizer(nn.Module):
+    def __init__(self, integrand_net, cond_size, nb_steps=20, solver="CC"):
+        super(MonotonicNormalizer, self).__init__()
+        if type(integrand_net) is list:
+            self.integrand_net = IntegrandNet(integrand_net, cond_size)
+        else:
+            self.integrand_net = integrand_net
+        self.solver = solver
+        self.nb_steps = nb_steps
+
+    def forward(self, x, h, context=None):
+        x0 = torch.zeros(x.shape).to(x.device)
+        xT = x
+        z0 = h[:, :, 0]
+        h = h.permute(0, 2, 1).contiguous().view(x.shape[0], -1)
+        if self.solver == "CC":
+            z = NeuralIntegral.apply(x0, xT, self.integrand_net, _flatten(self.integrand_net.parameters()),
+                                     h, self.nb_steps) + z0
+        elif self.solver == "CCParallel":
+            z = ParallelNeuralIntegral.apply(x0, xT, self.integrand_net,
+                                             _flatten(self.integrand_net.parameters()),
+                                             h, self.nb_steps) + z0
+        else:
+            raise ValueError("Unknown quadrature solver: {}".format(self.solver))
+        return z, self.integrand_net(x, h)
+
+    def inverse_transform(self, z, h, context=None):
+        # Inversion by binary search: the transform is monotonic in x, so bisect on [-20, 20].
+        x_max = torch.ones_like(z) * 20
+        x_min = -torch.ones_like(z) * 20
+        z_max, _ = self.forward(x_max, h, context)
+        z_min, _ = self.forward(x_min, h, context)
+        for i in range(25):
+            x_middle = (x_max + x_min) / 2
+            z_middle, _ = self.forward(x_middle, h, context)
+            left = (z_middle > z).float()
+            right = 1 - left
+            x_max = left * x_middle + right * x_max
+            x_min = right * x_middle + left * x_min
+            z_max = left * z_middle + right * z_max
+            z_min = right * z_middle + left * z_min
+        return (x_max + x_min) / 2
+
+
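Note (not part of the patch): a minimal sketch of how MonotonicNormalizer is meant to be driven, assuming the umnn package is installed (pip install umnn); the shapes and tolerance below are arbitrary. The conditioner output h packs one cond_size-dimensional embedding per input feature, h[:, :, 0] additionally acts as the offset z0 added to the integral, and the returned jac is the (positive) integrand, i.e. dz/dx.

    import torch
    from nflows.transforms.UMNN import MonotonicNormalizer

    batch, features, cond_size = 4, 3, 8
    normalizer = MonotonicNormalizer([50, 50, 50], cond_size, nb_steps=20, solver="CC")

    x = torch.randn(batch, features)
    h = torch.randn(batch, features, cond_size)    # one conditioning embedding per feature

    z, jac = normalizer(x, h)                      # z is strictly increasing in x; jac = dz/dx > 0
    x_rec = normalizer.inverse_transform(z, h)     # bisection on [-20, 20], 25 iterations
    print(torch.allclose(x, x_rec, atol=1e-4))     # the inverse is numerical, hence approximate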
diff --git a/nflows/transforms/UMNN/__init__.py b/nflows/transforms/UMNN/__init__.py
new file mode 100644
index 0000000..b22e8c5
--- /dev/null
+++ b/nflows/transforms/UMNN/__init__.py
@@ -0,0 +1 @@
+from nflows.transforms.UMNN.MonotonicNormalizer import MonotonicNormalizer, IntegrandNet
\ No newline at end of file
diff --git a/nflows/transforms/__init__.py b/nflows/transforms/__init__.py
index 77a3001..d645752 100644
--- a/nflows/transforms/__init__.py
+++ b/nflows/transforms/__init__.py
@@ -4,6 +4,7 @@
     MaskedPiecewiseLinearAutoregressiveTransform,
     MaskedPiecewiseQuadraticAutoregressiveTransform,
     MaskedPiecewiseRationalQuadraticAutoregressiveTransform,
+    MaskedUMNNAutoregressiveTransform,
 )
 from nflows.transforms.base import (
     CompositeTransform,
@@ -21,6 +22,7 @@
     PiecewiseLinearCouplingTransform,
     PiecewiseQuadraticCouplingTransform,
     PiecewiseRationalQuadraticCouplingTransform,
+    UMNNCouplingTransform,
 )
 from nflows.transforms.linear import NaiveLinear
 from nflows.transforms.lu import LULinear
diff --git a/nflows/transforms/autoregressive.py b/nflows/transforms/autoregressive.py
index 6c3958f..1105a49 100644
--- a/nflows/transforms/autoregressive.py
+++ b/nflows/transforms/autoregressive.py
@@ -18,6 +18,7 @@
     unconstrained_rational_quadratic_spline,
 )
 from nflows.utils import torchutils
+from nflows.transforms.UMNN import *
 
 
 class AutoregressiveTransform(Transform):
@@ -127,6 +128,71 @@ def _unconstrained_scale_and_shift(self, autoregressive_params):
         return unconstrained_scale, shift
 
 
+class MaskedUMNNAutoregressiveTransform(AutoregressiveTransform):
+    """An unconstrained monotonic neural network autoregressive layer that transforms the variables.
+
+    Reference:
+    > A. Wehenkel and G. Louppe, Unconstrained Monotonic Neural Networks, NeurIPS 2019.
+
+    ---- Specific arguments ----
+        integrand_net_layers: the sizes of the hidden layers of the integrand network.
+        cond_size: the embedding size for the conditioning factors.
+        nb_steps: the number of integration steps.
+        solver: the quadrature algorithm, either "CC" or "CCParallel". Both implement Clenshaw-Curtis
+            quadrature with the Leibniz rule for the backward pass. "CCParallel" evaluates all nb_steps
+            integration points at once, which is faster but requires more memory.
+    """
+
+    def __init__(
+        self,
+        features,
+        hidden_features,
+        context_features=None,
+        num_blocks=2,
+        use_residual_blocks=True,
+        random_mask=False,
+        activation=F.relu,
+        dropout_probability=0.0,
+        use_batch_norm=False,
+        integrand_net_layers=[50, 50, 50],
+        cond_size=20,
+        nb_steps=20,
+        solver="CCParallel",
+    ):
+        self.features = features
+        self.cond_size = cond_size
+        made = made_module.MADE(
+            features=features,
+            hidden_features=hidden_features,
+            context_features=context_features,
+            num_blocks=num_blocks,
+            output_multiplier=self._output_dim_multiplier(),
+            use_residual_blocks=use_residual_blocks,
+            random_mask=random_mask,
+            activation=activation,
+            dropout_probability=dropout_probability,
+            use_batch_norm=use_batch_norm,
+        )
+        self._epsilon = 1e-3
+        super().__init__(made)
+        self.transformer = MonotonicNormalizer(integrand_net_layers, cond_size, nb_steps, solver)
+
+    def _output_dim_multiplier(self):
+        return self.cond_size
+
+    def _elementwise_forward(self, inputs, autoregressive_params):
+        z, jac = self.transformer(inputs, autoregressive_params.reshape(inputs.shape[0], inputs.shape[1], -1))
+        log_det_jac = jac.log().sum(1)
+        return z, log_det_jac
+
+    def _elementwise_inverse(self, inputs, autoregressive_params):
+        x = self.transformer.inverse_transform(inputs, autoregressive_params.reshape(inputs.shape[0], inputs.shape[1], -1))
+        z, jac = self.transformer(x, autoregressive_params.reshape(inputs.shape[0], inputs.shape[1], -1))
+        log_det_jac = -jac.log().sum(1)
+        return x, log_det_jac
+
+
 class MaskedPiecewiseLinearAutoregressiveTransform(AutoregressiveTransform):
     def __init__(
         self,
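Note (not part of the patch): an illustrative sketch of dropping the new autoregressive layer into a small flow, using the existing nflows Flow / StandardNormal / CompositeTransform API; the feature count and hyperparameters below are arbitrary.

    import torch
    from nflows.flows.base import Flow
    from nflows.distributions.normal import StandardNormal
    from nflows.transforms.base import CompositeTransform
    from nflows.transforms.permutations import ReversePermutation
    from nflows.transforms.autoregressive import MaskedUMNNAutoregressiveTransform

    features = 4
    transform = CompositeTransform([
        MaskedUMNNAutoregressiveTransform(features=features, hidden_features=32,
                                          integrand_net_layers=[50, 50, 50],
                                          cond_size=8, nb_steps=20, solver="CCParallel"),
        ReversePermutation(features=features),
        MaskedUMNNAutoregressiveTransform(features=features, hidden_features=32),
    ])
    flow = Flow(transform, StandardNormal([features]))

    x = torch.randn(16, features)
    log_prob = flow.log_prob(x)   # differentiable, so the flow can be trained with -log_prob.mean()
    samples = flow.sample(8)      # goes through the sequential, bisection-based inverse (slow)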
diff --git a/nflows/transforms/coupling.py b/nflows/transforms/coupling.py
index a33b319..86bd620 100644
--- a/nflows/transforms/coupling.py
+++ b/nflows/transforms/coupling.py
@@ -13,6 +13,7 @@
     PiecewiseRationalQuadraticCDF,
 )
 from nflows.utils import torchutils
+from nflows.transforms.UMNN import *
 
 
 class CouplingTransform(Transform):
@@ -140,6 +141,73 @@ def _coupling_transform_inverse(self, inputs, transform_params):
         raise NotImplementedError()
 
 
+class UMNNCouplingTransform(CouplingTransform):
+    """An unconstrained monotonic neural network coupling layer that transforms the variables.
+
+    Reference:
+    > A. Wehenkel and G. Louppe, Unconstrained Monotonic Neural Networks, NeurIPS 2019.
+
+    ---- Specific arguments ----
+        integrand_net_layers: the sizes of the hidden layers of the integrand network.
+        cond_size: the embedding size for the conditioning factors.
+        nb_steps: the number of integration steps.
+        solver: the quadrature algorithm, either "CC" or "CCParallel". Both implement Clenshaw-Curtis
+            quadrature with the Leibniz rule for the backward pass. "CCParallel" evaluates all nb_steps
+            integration points at once, which is faster but requires more memory.
+    """
+
+    def __init__(
+        self,
+        mask,
+        transform_net_create_fn,
+        integrand_net_layers=[50, 50, 50],
+        cond_size=20,
+        nb_steps=20,
+        solver="CCParallel",
+        apply_unconditional_transform=False
+    ):
+        if apply_unconditional_transform:
+            unconditional_transform = lambda features: MonotonicNormalizer(integrand_net_layers, 0, nb_steps, solver)
+        else:
+            unconditional_transform = None
+        self.cond_size = cond_size
+        super().__init__(
+            mask,
+            transform_net_create_fn,
+            unconditional_transform=unconditional_transform,
+        )
+
+        self.transformer = MonotonicNormalizer(integrand_net_layers, cond_size, nb_steps, solver)
+
+    def _transform_dim_multiplier(self):
+        return self.cond_size
+
+    def _coupling_transform_forward(self, inputs, transform_params):
+        if len(inputs.shape) == 2:
+            z, jac = self.transformer(inputs, transform_params.reshape(inputs.shape[0], inputs.shape[1], -1))
+            log_det_jac = jac.log().sum(1)
+            return z, log_det_jac
+        else:
+            B, C, H, W = inputs.shape
+            z, jac = self.transformer(inputs.permute(0, 2, 3, 1).reshape(-1, inputs.shape[1]),
+                                      transform_params.permute(0, 2, 3, 1).reshape(-1, 1, transform_params.shape[1]))
+            log_det_jac = jac.log().reshape(B, -1).sum(1)
+            return z.reshape(B, H, W, C).permute(0, 3, 1, 2), log_det_jac
+
+    def _coupling_transform_inverse(self, inputs, transform_params):
+        if len(inputs.shape) == 2:
+            x = self.transformer.inverse_transform(inputs, transform_params.reshape(inputs.shape[0], inputs.shape[1], -1))
+            z, jac = self.transformer(x, transform_params.reshape(inputs.shape[0], inputs.shape[1], -1))
+            log_det_jac = -jac.log().sum(1)
+            return x, log_det_jac
+        else:
+            B, C, H, W = inputs.shape
+            x = self.transformer.inverse_transform(inputs.permute(0, 2, 3, 1).reshape(-1, inputs.shape[1]),
+                                                   transform_params.permute(0, 2, 3, 1).reshape(-1, 1, transform_params.shape[1]))
+            z, jac = self.transformer(x, transform_params.permute(0, 2, 3, 1).reshape(-1, 1, transform_params.shape[1]))
+            log_det_jac = -jac.log().reshape(B, -1).sum(1)
+            return x.reshape(B, H, W, C).permute(0, 3, 1, 2), log_det_jac
+
+
 class AffineCouplingTransform(CouplingTransform):
     """An affine coupling layer that scales and shifts part of the variables.
 
@@ -151,7 +219,7 @@ def _transform_dim_multiplier(self):
         return 2
 
     def _scale_and_shift(self, transform_params):
-        unconstrained_scale = transform_params[:, self.num_transform_features :, ...]
+        unconstrained_scale = transform_params[:, self.num_transform_features:, ...]
         shift = transform_params[:, : self.num_transform_features, ...]
         # scale = (F.softplus(unconstrained_scale) + 1e-3).clamp(0, 3)
         scale = torch.sigmoid(unconstrained_scale + 2) + 1e-3
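Note (not part of the patch): a sketch of constructing the coupling version directly. The conditioner factory and mask helper below (nets.ResidualNet, torchutils.create_alternating_binary_mask) are existing nflows utilities; the sizes are arbitrary.

    import torch
    from nflows.nn import nets
    from nflows.utils import torchutils
    from nflows.transforms.coupling import UMNNCouplingTransform

    features = 6
    mask = torchutils.create_alternating_binary_mask(features=features)

    def create_net(in_features, out_features):
        # maps the identity features to cond_size parameters per transformed feature
        return nets.ResidualNet(in_features, out_features, hidden_features=64, num_blocks=2)

    transform = UMNNCouplingTransform(
        mask=mask,
        transform_net_create_fn=create_net,
        integrand_net_layers=[50, 50, 50],
        cond_size=20,
        nb_steps=20,
        solver="CCParallel",
    )

    x = torch.randn(32, features)
    z, logabsdet = transform(x)        # masked-out (identity) features pass through unchanged
    x_rec, _ = transform.inverse(z)    # transformed features are inverted numerically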
diff --git a/tests/transforms/autoregressive_test.py b/tests/transforms/autoregressive_test.py
index d7072ba..61861d8 100644
--- a/tests/transforms/autoregressive_test.py
+++ b/tests/transforms/autoregressive_test.py
@@ -114,6 +114,24 @@ def test_forward_inverse_are_consistent(self):
         self.assert_forward_inverse_are_consistent(transform, inputs)
 
 
+class MaskedUMNNAutoregressiveTranformTest(TransformTest):
+    def test_forward_inverse_are_consistent(self):
+        batch_size = 10
+        features = 20
+        inputs = torch.rand(batch_size, features)
+        self.eps = 1e-4
+
+        transform = autoregressive.MaskedUMNNAutoregressiveTransform(
+            cond_size=10,
+            features=features,
+            hidden_features=30,
+            num_blocks=5,
+            use_residual_blocks=True,
+        )
+
+        self.assert_forward_inverse_are_consistent(transform, inputs)
+
+
 class MaskedPiecewiseCubicAutoregressiveTranformTest(TransformTest):
     def test_forward_inverse_are_consistent(self):
         batch_size = 10
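Note (not part of the patch): the test above sets self.eps = 1e-4 because MaskedUMNNAutoregressiveTransform has no closed-form inverse; inverse_transform bisects [-20, 20] for 25 iterations, so (assuming the true preimage lies inside that interval) the per-coordinate reconstruction error is roughly bounded by the final half-interval width. A back-of-the-envelope check:

    interval = 40.0                 # search range [-20, 20] used by inverse_transform
    resolution = interval / 2**26   # 25 halvings, then the midpoint is returned: ~6e-7
    print(resolution < 1e-4)        # True, so eps = 1e-4 leaves a comfortable margin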
diff --git a/tests/transforms/coupling_test.py b/tests/transforms/coupling_test.py
index f218936..90ed480 100644
--- a/tests/transforms/coupling_test.py
+++ b/tests/transforms/coupling_test.py
@@ -112,6 +112,53 @@ def test_forward_inverse_are_consistent(self):
         self.assert_forward_inverse_are_consistent(transform, inputs)
 
 
+class UMNNTransformTest(TransformTest):
+    shapes = [[20], [2, 4, 4]]
+
+    def test_forward(self):
+        for shape in self.shapes:
+            inputs = torch.randn(batch_size, *shape)
+            transform, mask = create_coupling_transform(
+                coupling.UMNNCouplingTransform, shape, integrand_net_layers=[50, 50, 50],
+                cond_size=20,
+                nb_steps=20,
+                solver="CC"
+            )
+            outputs, logabsdet = transform(inputs)
+            with self.subTest(shape=shape):
+                self.assert_tensor_is_good(outputs, [batch_size] + shape)
+                self.assert_tensor_is_good(logabsdet, [batch_size])
+                self.assertEqual(outputs[:, mask <= 0, ...], inputs[:, mask <= 0, ...])
+
+    def test_inverse(self):
+        for shape in self.shapes:
+            inputs = torch.randn(batch_size, *shape)
+            transform, mask = create_coupling_transform(
+                coupling.UMNNCouplingTransform, shape, integrand_net_layers=[50, 50, 50],
+                cond_size=20,
+                nb_steps=20,
+                solver="CC"
+            )
+            outputs, logabsdet = transform.inverse(inputs)
+            with self.subTest(shape=shape):
+                self.assert_tensor_is_good(outputs, [batch_size] + shape)
+                self.assert_tensor_is_good(logabsdet, [batch_size])
+                self.assertEqual(outputs[:, mask <= 0, ...], inputs[:, mask <= 0, ...])
+
+    def test_forward_inverse_are_consistent(self):
+        self.eps = 1e-6
+        for shape in self.shapes:
+            inputs = torch.randn(batch_size, *shape)
+            transform, mask = create_coupling_transform(
+                coupling.UMNNCouplingTransform, shape, integrand_net_layers=[50, 50, 50],
+                cond_size=20,
+                nb_steps=20,
+                solver="CC"
+            )
+            with self.subTest(shape=shape):
+                self.assert_forward_inverse_are_consistent(transform, inputs)
+
+
 class PiecewiseCouplingTransformTest(TransformTest):
     classes = [
         coupling.PiecewiseLinearCouplingTransform,