I am doing some tests with the identity initialization for rational quadratic splines.
When using the new identity initialization implemented in #65 with the input `x = torch.tensor([1, 1e-2, 1e-6, 1e-8, 1e2], dtype=torch.float32)`, this is what I get when transforming back with the untrained network (which should be initialized as the identity):
```python
# in spline def: enable_identity_init=True
# transform back
flow.transform_to_noise(x.view(-1, 1))
```
```
tensor([[  1.7013],
        [  1.3796],
        [  1.3739],
        [  1.3739],
        [100.0000]], grad_fn=<AddmmBackward0>)
```
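For reference, here is what "initialized as the identity" requires of a rational-quadratic spline. The formula below is the one from Durkan et al. (2019, "Neural Spline Flows"), written within bin $k$. It uses knots $(x^{(k)}, y^{(k)})$, bin width $w^{(k)}$, bin height $h^{(k)}$ and knot derivatives $\delta^{(k)}$. This notation is an assumption and does not come from the code above:

```math
g(x) = y^{(k)} + \frac{h^{(k)}\left[s^{(k)}\xi^{2} + \delta^{(k)}\xi(1-\xi)\right]}{s^{(k)} + \left[\delta^{(k+1)} + \delta^{(k)} - 2s^{(k)}\right]\xi(1-\xi)},
\qquad \xi = \frac{x - x^{(k)}}{w^{(k)}},\quad s^{(k)} = \frac{h^{(k)}}{w^{(k)}}.
```

Suppose all bins have the same width and height on the same interval. Then $x^{(k)} = y^{(k)}$, $h^{(k)} = w^{(k)}$ and $s^{(k)} = 1$. If, in addition, every $\delta^{(k)} = 1$, the spline collapses to the identity:

```math
g(x) = y^{(k)} + h^{(k)}\,\frac{\xi^{2} + \xi(1-\xi)}{1 + 0\cdot\xi(1-\xi)} = x^{(k)} + w^{(k)}\xi = x.
```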
If instead I manually set the weights of the final layer of the transform network to 0 and its bias to `log(exp(1 - min_derivative) - 1)` (as done in the normflows package), I get the identity as expected:
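```python
import numpy as np
import torch

# in spline def: enable_identity_init=False
# in the model def
if init_identity:
    # zero weights: the network output no longer depends on the input
    torch.nn.init.constant_(autoregressive_net.final_layer.weight, 0.0)
    # bias chosen so that min_derivative + softplus(bias) == 1
    torch.nn.init.constant_(
        autoregressive_net.final_layer.bias,
        np.log(np.exp(1 - min_derivative) - 1),
    )
# ... rest of the model definition
# transform back
flow.transform_to_noise(x.view(-1, 1))
```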
```
tensor([[1.0000e+00],
        [1.0000e-02],
        [1.0000e-06],
        [1.0000e-08],
        [1.0000e+02]], grad_fn=<AddmmBackward0>)
```
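Here is a minimal pure-Python sketch of why the manual initialization gives the identity. It assumes an nflows-style parameterization:

- widths and heights come from a softmax;
- derivatives are `min_derivative + softplus(·)`;
- the two boundary derivatives are padded with the same `log(exp(1 - min_derivative) - 1)` constant used for linear tails.

With zero final-layer weights, every network output equals the bias. The bins therefore come out equal, because a softmax of a constant vector is uniform. Every derivative then becomes `min_derivative + softplus(bias) = min_derivative + (1 - min_derivative) = 1`.

The helpers `rq_spline`, `knots` and `zero_weight_params` are hypothetical. The value `MIN_DERIVATIVE = 1e-3` is an assumption. The sketch also leaves out details such as minimum bin width and height.

```python
import math

# Assumed value: nflows' spline functions default to min_derivative=1e-3.
MIN_DERIVATIVE = 1e-3


def softplus(v: float) -> float:
    return math.log1p(math.exp(v))


def softmax(vs: list) -> list:
    m = max(vs)
    exps = [math.exp(v - m) for v in vs]
    total = sum(exps)
    return [e / total for e in exps]


def knots(unnormalized: list, bound: float) -> list:
    """Map K unnormalized bin sizes to K+1 knot positions on [-bound, bound]."""
    out, acc = [-bound], 0.0
    for size in softmax(unnormalized)[:-1]:
        acc += size
        out.append(-bound + 2 * bound * acc)
    out.append(bound)  # pin the last knot exactly to the boundary
    return out


def rq_spline(x, u_widths, u_heights, u_derivs, bound=1.0, min_derivative=MIN_DERIVATIVE):
    """Hypothetical forward rational-quadratic spline with identity (linear) tails.

    u_derivs holds the K-1 interior derivatives; both boundary derivatives are
    padded with log(exp(1 - min_derivative) - 1), which makes them equal to 1.
    """
    if x < -bound or x > bound:
        return x  # outside the interval the map is the identity
    num_bins = len(u_widths)
    xs, ys = knots(u_widths, bound), knots(u_heights, bound)
    pad = math.log(math.exp(1 - min_derivative) - 1)
    ds = [min_derivative + softplus(u) for u in [pad, *u_derivs, pad]]
    # index of the bin containing x (the right boundary belongs to the last bin)
    k = next(i for i in range(num_bins) if x < xs[i + 1] or i == num_bins - 1)
    w, h = xs[k + 1] - xs[k], ys[k + 1] - ys[k]
    s, xi = h / w, (x - xs[k]) / w
    num = h * (s * xi**2 + ds[k] * xi * (1 - xi))
    den = s + (ds[k] + ds[k + 1] - 2 * s) * xi * (1 - xi)
    return ys[k] + num / den


def zero_weight_params(num_bins, bias):
    """Hypothetical: with zero final-layer weights every output equals the bias."""
    return [bias] * num_bins, [bias] * num_bins, [bias] * (num_bins - 1)


if __name__ == "__main__":
    identity_bias = math.log(math.exp(1 - MIN_DERIVATIVE) - 1)
    params = zero_weight_params(5, identity_bias)
    for x in [-0.95, -0.1, 0.0, 0.42, 1.0, 100.0]:
        print(x, rq_spline(x, *params))
```

Zeroing the weights alone is not enough. With `bias = 0.0`, the interior derivatives become `min_derivative + log 2 ≈ 0.694`, and the spline is no longer the identity inside `[-bound, bound]`. The specific bias value is what makes the derivatives equal to 1.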
I was wondering whether you could help me figure out this difference in behavior.
If this seems useful, I would be more than glad to work on a pull request.
Best regards,
Francesco