Merge pull request #52 from dennisprangle/custom_activation
Allow scales >1.001 in AffineCouplingTransform (fixes #49)
arturbekasov authored Dec 2, 2021
2 parents 639c3a7 + cd6e3b7 commit ac0bf43
Showing 2 changed files with 26 additions and 2 deletions.
nflows/transforms/coupling.py (14 additions, 2 deletions)
@@ -3,6 +3,7 @@
 import numpy as np
 import torch
+from torch.nn.functional import softplus
 
 from nflows.transforms import splines
 from nflows.transforms.base import Transform
@@ -213,16 +214,27 @@ class AffineCouplingTransform(CouplingTransform):
     Reference:
     > L. Dinh et al., Density estimation using Real NVP, ICLR 2017.
+
+    The user should supply `scale_activation`, the final activation function in the neural network producing the scale tensor.
+    Two options are predefined in the class.
+    `DEFAULT_SCALE_ACTIVATION` preserves backwards compatibility but only produces scales <= 1.001.
+    `GENERAL_SCALE_ACTIVATION` produces scales <= 3, which is more useful in general applications.
     """
 
+    DEFAULT_SCALE_ACTIVATION = lambda x : torch.sigmoid(x + 2) + 1e-3
+    GENERAL_SCALE_ACTIVATION = lambda x : (softplus(x) + 1e-3).clamp(0, 3)
+
+    def __init__(self, mask, transform_net_create_fn, unconditional_transform=None, scale_activation=DEFAULT_SCALE_ACTIVATION):
+        self.scale_activation = scale_activation
+        super().__init__(mask, transform_net_create_fn, unconditional_transform)
+
     def _transform_dim_multiplier(self):
         return 2
 
     def _scale_and_shift(self, transform_params):
         unconstrained_scale = transform_params[:, self.num_transform_features:, ...]
         shift = transform_params[:, : self.num_transform_features, ...]
-        # scale = (F.softplus(unconstrained_scale) + 1e-3).clamp(0, 3)
-        scale = torch.sigmoid(unconstrained_scale + 2) + 1e-3
+        scale = self.scale_activation(unconstrained_scale)
         return scale, shift
 
     def _coupling_transform_forward(self, inputs, transform_params):
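The bounds quoted in the new docstring follow directly from the two lambdas: `torch.sigmoid(x + 2)` lies in (0, 1), so the default activation produces scales in (1e-3, 1.001), while `softplus` is unbounded above and the `clamp` caps the general activation at 3. A minimal numeric check of both ranges (a sketch, not part of the commit):

    import torch
    from torch.nn.functional import softplus

    x = torch.linspace(-20.0, 20.0, 2001)
    default = torch.sigmoid(x + 2) + 1e-3        # bounded in (1e-3, 1 + 1e-3)
    general = (softplus(x) + 1e-3).clamp(0, 3)   # bounded in (1e-3, 3]
    print(default.max().item())  # ~1.001
    print(general.max().item())  # 3.0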
tests/transforms/coupling_test.py (12 additions, 0 deletions)
@@ -71,6 +71,18 @@ def test_forward_inverse_are_consistent(self):
             with self.subTest(shape=shape):
                 self.assert_forward_inverse_are_consistent(transform, inputs)
 
+    def test_scale_activation_has_an_effect(self):
+        for shape in self.shapes:
+            inputs = torch.randn(batch_size, *shape)
+            transform, mask = create_coupling_transform(
+                coupling.AffineCouplingTransform, shape
+            )
+            outputs_default, logabsdet_default = transform(inputs)
+            transform.scale_activation = coupling.AffineCouplingTransform.GENERAL_SCALE_ACTIVATION
+            outputs_general, logabsdet_general = transform(inputs)
+            with self.subTest(shape=shape):
+                self.assertNotEqual(outputs_default, outputs_general)
+                self.assertNotEqual(logabsdet_default, logabsdet_general)
 
 
 class AdditiveTransformTest(TransformTest):
     shapes = [[20], [2, 4, 4]]
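For context, a hypothetical usage sketch (not part of this commit) showing how a caller opts in to the new behaviour; the `make_net` factory and the mask below are illustrative assumptions, not nflows code:

    import torch
    from torch import nn
    from nflows.transforms.coupling import AffineCouplingTransform

    def make_net(in_features, out_features):
        # transform_net_create_fn receives (in_features, out_features); the module
        # it returns is called as net(inputs, context), so accept a context argument.
        class ToyNet(nn.Module):
            def __init__(self):
                super().__init__()
                self.mlp = nn.Sequential(
                    nn.Linear(in_features, 64),
                    nn.ReLU(),
                    nn.Linear(64, out_features),
                )

            def forward(self, inputs, context=None):
                return self.mlp(inputs)

        return ToyNet()

    # Binary mask splitting the features into identity and transformed halves.
    mask = torch.tensor([0, 0, 1, 1])
    transform = AffineCouplingTransform(
        mask,
        make_net,
        scale_activation=AffineCouplingTransform.GENERAL_SCALE_ACTIVATION,
    )
    outputs, logabsdet = transform(torch.randn(16, 4))

Omitting `scale_activation` keeps the old sigmoid-based default, which is why existing code is unaffected by this change.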
