Merge pull request #29 from AWehenkel/UMNN
UMNNs implementation
arturbekasov authored Feb 21, 2021
2 parents 75048ff + 616c3ee commit 639c3a7
Showing 8 changed files with 287 additions and 1 deletion.
1 change: 1 addition & 0 deletions environment.yml
@@ -21,6 +21,7 @@ dependencies:
  - pip:
    - torchtestcase
    - -e . # install package in development mode
    - umnn
  - pytest
  - python
  - pytorch
83 changes: 83 additions & 0 deletions nflows/transforms/UMNN/MonotonicNormalizer.py
@@ -0,0 +1,83 @@
import torch
from UMNN import NeuralIntegral, ParallelNeuralIntegral
import torch.nn as nn


def _flatten(sequence):
    flat = [p.contiguous().view(-1) for p in sequence]
    return torch.cat(flat) if len(flat) > 0 else torch.tensor([])


class ELUPlus(nn.Module):
    def __init__(self):
        super().__init__()
        self.elu = nn.ELU()

    def forward(self, x):
        return self.elu(x) + 1.


class IntegrandNet(nn.Module):
    def __init__(self, hidden, cond_in):
        super(IntegrandNet, self).__init__()
        l1 = [1 + cond_in] + hidden
        l2 = hidden + [1]
        layers = []
        for h1, h2 in zip(l1, l2):
            layers += [nn.Linear(h1, h2), nn.ReLU()]
        layers.pop()
        layers.append(ELUPlus())
        self.net = nn.Sequential(*layers)

    def forward(self, x, h):
        nb_batch, in_d = x.shape
        x = torch.cat((x, h), 1)
        x_he = x.view(nb_batch, -1, in_d).transpose(1, 2).contiguous().view(nb_batch * in_d, -1)
        y = self.net(x_he).view(nb_batch, -1)
        return y


class MonotonicNormalizer(nn.Module):
    def __init__(self, integrand_net, cond_size, nb_steps=20, solver="CC"):
        super(MonotonicNormalizer, self).__init__()
        if type(integrand_net) is list:
            self.integrand_net = IntegrandNet(integrand_net, cond_size)
        else:
            self.integrand_net = integrand_net
        self.solver = solver
        self.nb_steps = nb_steps

    def forward(self, x, h, context=None):
        x0 = torch.zeros(x.shape).to(x.device)
        xT = x
        z0 = h[:, :, 0]
        h = h.permute(0, 2, 1).contiguous().view(x.shape[0], -1)
        if self.solver == "CC":
            z = NeuralIntegral.apply(x0, xT, self.integrand_net, _flatten(self.integrand_net.parameters()),
                                     h, self.nb_steps) + z0
        elif self.solver == "CCParallel":
            z = ParallelNeuralIntegral.apply(x0, xT, self.integrand_net,
                                             _flatten(self.integrand_net.parameters()),
                                             h, self.nb_steps) + z0
        else:
            raise ValueError("Unknown solver: {}".format(self.solver))
        return z, self.integrand_net(x, h)

    def inverse_transform(self, z, h, context=None):
        # Inversion by binary search (bisection) over x in [-20, 20]
        x_max = torch.ones_like(z) * 20
        x_min = -torch.ones_like(z) * 20
        z_max, _ = self.forward(x_max, h, context)
        z_min, _ = self.forward(x_min, h, context)
        for i in range(25):
            x_middle = (x_max + x_min) / 2
            z_middle, _ = self.forward(x_middle, h, context)
            left = (z_middle > z).float()
            right = 1 - left
            x_max = left * x_middle + right * x_max
            x_min = right * x_middle + left * x_min
            z_max = left * z_middle + right * z_max
            z_min = right * z_middle + left * z_min
        return (x_max + x_min) / 2
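Here forward computes z = h[:, :, 0] + the integral from 0 to x of f(t, h) dt with a strictly positive integrand f (the ELU + 1 output of IntegrandNet), so the map is monotone in x and its log-Jacobian is simply log f evaluated at the inputs. The inverse has no closed form and is recovered by bisection over a fixed bracket. Below is a minimal, self-contained sketch of the same bisection update on a toy increasing map; the toy function, bracket, and tolerance are illustrative assumptions, not part of this commit.

import torch

def bisection_inverse(f, z, lo=-20.0, hi=20.0, n_iter=25):
    # Invert a strictly increasing elementwise map f by bisection,
    # mirroring MonotonicNormalizer.inverse_transform above.
    x_min = torch.full_like(z, lo)
    x_max = torch.full_like(z, hi)
    for _ in range(n_iter):
        x_mid = (x_min + x_max) / 2
        z_mid = f(x_mid)
        left = (z_mid > z).float()   # overshoot: the preimage lies below x_mid
        right = 1.0 - left           # undershoot: the preimage lies above x_mid
        x_max = left * x_mid + right * x_max
        x_min = right * x_mid + left * x_min
    return (x_min + x_max) / 2

# Toy monotone map standing in for the integral of a positive integrand.
f = lambda x: x + 0.5 * torch.tanh(x)      # derivative is 1 + 0.5 * (1 - tanh(x)**2) > 0
z = f(torch.tensor([-1.5, 0.3, 2.0]))
x = bisection_inverse(f, z)
print(torch.allclose(f(x), z, atol=1e-5))  # True, up to the bisection resolution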


1 change: 1 addition & 0 deletions nflows/transforms/UMNN/__init__.py
@@ -0,0 +1 @@
from nflows.transforms.UMNN.MonotonicNormalizer import MonotonicNormalizer, IntegrandNet
2 changes: 2 additions & 0 deletions nflows/transforms/__init__.py
@@ -4,6 +4,7 @@
    MaskedPiecewiseLinearAutoregressiveTransform,
    MaskedPiecewiseQuadraticAutoregressiveTransform,
    MaskedPiecewiseRationalQuadraticAutoregressiveTransform,
    MaskedUMNNAutoregressiveTransform,
)
from nflows.transforms.base import (
    CompositeTransform,
@@ -21,6 +22,7 @@
    PiecewiseLinearCouplingTransform,
    PiecewiseQuadraticCouplingTransform,
    PiecewiseRationalQuadraticCouplingTransform,
    UMNNCouplingTransform,
)
from nflows.transforms.linear import NaiveLinear
from nflows.transforms.lu import LULinear
66 changes: 66 additions & 0 deletions nflows/transforms/autoregressive.py
@@ -18,6 +18,7 @@
    unconstrained_rational_quadratic_spline,
)
from nflows.utils import torchutils
from nflows.transforms.UMNN import *


class AutoregressiveTransform(Transform):
@@ -127,6 +128,71 @@ def _unconstrained_scale_and_shift(self, autoregressive_params):
        return unconstrained_scale, shift


class MaskedUMNNAutoregressiveTransform(AutoregressiveTransform):
    """An autoregressive layer based on unconstrained monotonic neural networks (UMNNs).

    Reference:
    > A. Wehenkel and G. Louppe, Unconstrained Monotonic Neural Networks, NeurIPS 2019.

    ---- Specific arguments ----
        integrand_net_layers: the hidden-layer sizes of the integrand network.
        cond_size: the embedding size for the conditioning factors.
        nb_steps: the number of integration steps.
        solver: the quadrature algorithm, "CC" or "CCParallel". Both implement Clenshaw-Curtis
            quadrature with the Leibniz rule for the backward pass. "CCParallel" evaluates all
            nb_steps integration points at once, which is faster but requires more memory.
    """

    def __init__(
        self,
        features,
        hidden_features,
        context_features=None,
        num_blocks=2,
        use_residual_blocks=True,
        random_mask=False,
        activation=F.relu,
        dropout_probability=0.0,
        use_batch_norm=False,
        integrand_net_layers=[50, 50, 50],
        cond_size=20,
        nb_steps=20,
        solver="CCParallel",
    ):
        self.features = features
        self.cond_size = cond_size
        made = made_module.MADE(
            features=features,
            hidden_features=hidden_features,
            context_features=context_features,
            num_blocks=num_blocks,
            output_multiplier=self._output_dim_multiplier(),
            use_residual_blocks=use_residual_blocks,
            random_mask=random_mask,
            activation=activation,
            dropout_probability=dropout_probability,
            use_batch_norm=use_batch_norm,
        )
        self._epsilon = 1e-3
        super().__init__(made)
        self.transformer = MonotonicNormalizer(integrand_net_layers, cond_size, nb_steps, solver)

    def _output_dim_multiplier(self):
        return self.cond_size

    def _elementwise_forward(self, inputs, autoregressive_params):
        z, jac = self.transformer(inputs, autoregressive_params.reshape(inputs.shape[0], inputs.shape[1], -1))
        log_det_jac = jac.log().sum(1)
        return z, log_det_jac

    def _elementwise_inverse(self, inputs, autoregressive_params):
        x = self.transformer.inverse_transform(inputs, autoregressive_params.reshape(inputs.shape[0], inputs.shape[1], -1))
        z, jac = self.transformer(x, autoregressive_params.reshape(inputs.shape[0], inputs.shape[1], -1))
        log_det_jac = -jac.log().sum(1)
        return x, log_det_jac


class MaskedPiecewiseLinearAutoregressiveTransform(AutoregressiveTransform):
    def __init__(
        self,
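A hedged usage sketch of the new autoregressive layer (sizes and values are illustrative only; it assumes the umnn package added to environment.yml above is installed, and relies on the usual nflows convention that a Transform returns (outputs, logabsdet) and exposes .inverse):

import torch
from nflows.transforms.autoregressive import MaskedUMNNAutoregressiveTransform

batch_size, features = 16, 8
transform = MaskedUMNNAutoregressiveTransform(
    features=features,
    hidden_features=32,
    integrand_net_layers=[50, 50, 50],
    cond_size=20,
    nb_steps=20,
    solver="CCParallel",
)

x = torch.randn(batch_size, features)
z, logabsdet = transform(x)                 # forward pass: z and log|det J|
x_rec, _ = transform.inverse(z)             # inverse recovered dimension by dimension via bisection
print(z.shape, logabsdet.shape)             # torch.Size([16, 8]) torch.Size([16])
print(torch.allclose(x, x_rec, atol=1e-4))  # expected True, up to the bisection tolerance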
70 changes: 69 additions & 1 deletion nflows/transforms/coupling.py
@@ -13,6 +13,7 @@
    PiecewiseRationalQuadraticCDF,
)
from nflows.utils import torchutils
from nflows.transforms.UMNN import *


class CouplingTransform(Transform):
@@ -140,6 +141,73 @@ def _coupling_transform_inverse(self, inputs, transform_params):
        raise NotImplementedError()


class UMNNCouplingTransform(CouplingTransform):
    """A coupling layer based on unconstrained monotonic neural networks (UMNNs).

    Reference:
    > A. Wehenkel and G. Louppe, Unconstrained Monotonic Neural Networks, NeurIPS 2019.

    ---- Specific arguments ----
        integrand_net_layers: the hidden-layer sizes of the integrand network.
        cond_size: the embedding size for the conditioning factors.
        nb_steps: the number of integration steps.
        solver: the quadrature algorithm, "CC" or "CCParallel". Both implement Clenshaw-Curtis
            quadrature with the Leibniz rule for the backward pass. "CCParallel" evaluates all
            nb_steps integration points at once, which is faster but requires more memory.
    """

    def __init__(
        self,
        mask,
        transform_net_create_fn,
        integrand_net_layers=[50, 50, 50],
        cond_size=20,
        nb_steps=20,
        solver="CCParallel",
        apply_unconditional_transform=False
    ):
        if apply_unconditional_transform:
            unconditional_transform = lambda features: MonotonicNormalizer(integrand_net_layers, 0, nb_steps, solver)
        else:
            unconditional_transform = None
        self.cond_size = cond_size
        super().__init__(
            mask,
            transform_net_create_fn,
            unconditional_transform=unconditional_transform,
        )

        self.transformer = MonotonicNormalizer(integrand_net_layers, cond_size, nb_steps, solver)

    def _transform_dim_multiplier(self):
        return self.cond_size

    def _coupling_transform_forward(self, inputs, transform_params):
        if len(inputs.shape) == 2:
            z, jac = self.transformer(inputs, transform_params.reshape(inputs.shape[0], inputs.shape[1], -1))
            log_det_jac = jac.log().sum(1)
            return z, log_det_jac
        else:
            B, C, H, W = inputs.shape
            z, jac = self.transformer(inputs.permute(0, 2, 3, 1).reshape(-1, inputs.shape[1]),
                                      transform_params.permute(0, 2, 3, 1).reshape(-1, 1, transform_params.shape[1]))
            log_det_jac = jac.log().reshape(B, -1).sum(1)
            return z.reshape(B, H, W, C).permute(0, 3, 1, 2), log_det_jac

    def _coupling_transform_inverse(self, inputs, transform_params):
        if len(inputs.shape) == 2:
            x = self.transformer.inverse_transform(inputs, transform_params.reshape(inputs.shape[0], inputs.shape[1], -1))
            z, jac = self.transformer(x, transform_params.reshape(inputs.shape[0], inputs.shape[1], -1))
            log_det_jac = -jac.log().sum(1)
            return x, log_det_jac
        else:
            B, C, H, W = inputs.shape
            x = self.transformer.inverse_transform(inputs.permute(0, 2, 3, 1).reshape(-1, inputs.shape[1]),
                                                   transform_params.permute(0, 2, 3, 1).reshape(-1, 1, transform_params.shape[1]))
            z, jac = self.transformer(x, transform_params.permute(0, 2, 3, 1).reshape(-1, 1, transform_params.shape[1]))
            log_det_jac = -jac.log().reshape(B, -1).sum(1)
            return x.reshape(B, H, W, C).permute(0, 3, 1, 2), log_det_jac


class AffineCouplingTransform(CouplingTransform):
"""An affine coupling layer that scales and shifts part of the variables.
@@ -151,7 +219,7 @@ def _transform_dim_multiplier(self):
        return 2

    def _scale_and_shift(self, transform_params):
-       unconstrained_scale = transform_params[:, self.num_transform_features :, ...]
+       unconstrained_scale = transform_params[:, self.num_transform_features:, ...]
        shift = transform_params[:, : self.num_transform_features, ...]
        # scale = (F.softplus(unconstrained_scale) + 1e-3).clamp(0, 3)
        scale = torch.sigmoid(unconstrained_scale + 2) + 1e-3
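A similar sketch for the coupling version. The ResidualNet conditioner and the alternating-mask helper are assumed to be the existing nflows utilities (nflows.nn.nets.ResidualNet and nflows.utils.torchutils.create_alternating_binary_mask); all sizes are illustrative.

import torch
from nflows.nn.nets import ResidualNet
from nflows.transforms.coupling import UMNNCouplingTransform
from nflows.utils import torchutils

features = 8
mask = torchutils.create_alternating_binary_mask(features, even=True)

def create_net(in_features, out_features):
    # Conditioner: maps the identity half to the cond_size embedding of each transformed feature.
    return ResidualNet(in_features, out_features, hidden_features=32, num_blocks=2)

transform = UMNNCouplingTransform(
    mask=mask,
    transform_net_create_fn=create_net,
    integrand_net_layers=[50, 50, 50],
    cond_size=20,
    nb_steps=20,
    solver="CC",
)

x = torch.randn(16, features)
z, logabsdet = transform(x)
x_rec, _ = transform.inverse(z)
print(torch.allclose(x, x_rec, atol=1e-4))  # expected True, up to the bisection tolerance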
18 changes: 18 additions & 0 deletions tests/transforms/autoregressive_test.py
@@ -114,6 +114,24 @@ def test_forward_inverse_are_consistent(self):
        self.assert_forward_inverse_are_consistent(transform, inputs)


class MaskedUMNNAutoregressiveTranformTest(TransformTest):
    def test_forward_inverse_are_consistent(self):
        batch_size = 10
        features = 20
        inputs = torch.rand(batch_size, features)
        self.eps = 1e-4

        transform = autoregressive.MaskedUMNNAutoregressiveTransform(
            cond_size=10,
            features=features,
            hidden_features=30,
            num_blocks=5,
            use_residual_blocks=True,
        )

        self.assert_forward_inverse_are_consistent(transform, inputs)
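The relaxed tolerance here (and the eps = 1e-6 used in the coupling test below) matches the precision of the bisection inverse: MonotonicNormalizer.inverse_transform halves the bracket [-20, 20] twenty-five times and returns the midpoint of the final interval. A back-of-the-envelope check, for illustration only:

interval_width = 20.0 - (-20.0)                 # bisection bracket used by inverse_transform
n_iter = 25                                     # number of bisection iterations
max_error = interval_width / 2 ** (n_iter + 1)  # half the width of the final bracket
print(max_error)                                # ~5.96e-07, so eps around 1e-6 or looser is appropriate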


class MaskedPiecewiseCubicAutoregressiveTranformTest(TransformTest):
    def test_forward_inverse_are_consistent(self):
        batch_size = 10
47 changes: 47 additions & 0 deletions tests/transforms/coupling_test.py
@@ -112,6 +112,53 @@ def test_forward_inverse_are_consistent(self):
        self.assert_forward_inverse_are_consistent(transform, inputs)


class UMNNTransformTest(TransformTest):
    shapes = [[20], [2, 4, 4]]

    def test_forward(self):
        for shape in self.shapes:
            inputs = torch.randn(batch_size, *shape)
            transform, mask = create_coupling_transform(
                coupling.UMNNCouplingTransform, shape,
                integrand_net_layers=[50, 50, 50],
                cond_size=20,
                nb_steps=20,
                solver="CC"
            )
            outputs, logabsdet = transform(inputs)
            with self.subTest(shape=shape):
                self.assert_tensor_is_good(outputs, [batch_size] + shape)
                self.assert_tensor_is_good(logabsdet, [batch_size])
                self.assertEqual(outputs[:, mask <= 0, ...], inputs[:, mask <= 0, ...])

    def test_inverse(self):
        for shape in self.shapes:
            inputs = torch.randn(batch_size, *shape)
            transform, mask = create_coupling_transform(
                coupling.UMNNCouplingTransform, shape,
                integrand_net_layers=[50, 50, 50],
                cond_size=20,
                nb_steps=20,
                solver="CC"
            )
            outputs, logabsdet = transform.inverse(inputs)
            with self.subTest(shape=shape):
                self.assert_tensor_is_good(outputs, [batch_size] + shape)
                self.assert_tensor_is_good(logabsdet, [batch_size])
                self.assertEqual(outputs[:, mask <= 0, ...], inputs[:, mask <= 0, ...])

    def test_forward_inverse_are_consistent(self):
        self.eps = 1e-6
        for shape in self.shapes:
            inputs = torch.randn(batch_size, *shape)
            transform, mask = create_coupling_transform(
                coupling.UMNNCouplingTransform, shape,
                integrand_net_layers=[50, 50, 50],
                cond_size=20,
                nb_steps=20,
                solver="CC"
            )
            with self.subTest(shape=shape):
                self.assert_forward_inverse_are_consistent(transform, inputs)


class PiecewiseCouplingTransformTest(TransformTest):
    classes = [
        coupling.PiecewiseLinearCouplingTransform,
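Finally, a sketch of how either new transform plugs into a complete flow, assuming the standard nflows Flow and StandardNormal APIs; hyperparameters are illustrative.

import torch
from nflows.distributions.normal import StandardNormal
from nflows.flows.base import Flow
from nflows.transforms.autoregressive import MaskedUMNNAutoregressiveTransform

features = 8
transform = MaskedUMNNAutoregressiveTransform(features=features, hidden_features=32)
flow = Flow(transform, StandardNormal([features]))

x = torch.randn(64, features)
log_prob = flow.log_prob(x)            # density evaluation uses the forward (integration) pass
samples = flow.sample(16)              # sampling uses the bisection inverse, one dimension at a time
print(log_prob.shape, samples.shape)   # torch.Size([64]) torch.Size([16, 8])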
