diff --git a/tests/transforms/coupling_test.py b/tests/transforms/coupling_test.py index 90ed480..b53e40e 100644 --- a/tests/transforms/coupling_test.py +++ b/tests/transforms/coupling_test.py @@ -71,6 +71,18 @@ def test_forward_inverse_are_consistent(self): with self.subTest(shape=shape): self.assert_forward_inverse_are_consistent(transform, inputs) + def test_scale_activation_has_an_effect(self): + for shape in self.shapes: + inputs = torch.randn(batch_size, *shape) + transform, mask = create_coupling_transform( + coupling.AffineCouplingTransform, shape + ) + outputs_default, logabsdet_default = transform(inputs) + transform.scale_activation = coupling.AffineCouplingTransform.GENERAL_SCALE_ACTIVATION + outputs_general, logabsdet_general = transform(inputs) + with self.subTest(shape=shape): + self.assertNotEqual(outputs_default, outputs_general) + self.assertNotEqual(logabsdet_default, logabsdet_general) class AdditiveTransformTest(TransformTest): shapes = [[20], [2, 4, 4]]