-
Notifications
You must be signed in to change notification settings - Fork 118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix In-place Assignments for PiecewiseRationalQuadratic
Compatibility with functorch
and torch2.0
#77
base: master
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey Hamid,
Thanks for taking the time for submit the PR. Left a few comments.
I am not against making the code more functional, but I would like to be a bit careful about performance implications of the additional copies. Is there a chance to see some before/after performance stats of the function?
Cheers,
Artur
@@ -21,7 +21,6 @@ def unconstrained_rational_quadratic_spline( | |||
min_bin_width=DEFAULT_MIN_BIN_WIDTH, | |||
min_bin_height=DEFAULT_MIN_BIN_HEIGHT, | |||
min_derivative=DEFAULT_MIN_DERIVATIVE, | |||
enable_identity_init=False, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this on purpose, or we should re-base the changes?
outputs[outside_interval_mask] = inputs[outside_interval_mask] | ||
logabsdet[outside_interval_mask] = 0 | ||
|
||
outputs = torch.where(outside_interval_mask, inputs, outputs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we drop the zero init above if we're copying here anyway? I.e. outputs = torch.where(outside_interval_mask, inputs, torch.zeros_like(inputs))
.
logabsdet[outside_interval_mask] = 0 | ||
outputs = torch.where(outside_interval_mask, inputs, outputs) | ||
logabsdet = torch.where(outside_interval_mask, torch.zeros_like(logabsdet), logabsdet) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, does the original line even have an effect? We're assigning zeros to what is already initialized to zeros.
else: | ||
raise RuntimeError("{} tails are not implemented.".format(tails)) | ||
|
||
if torch.any(inside_interval_mask): | ||
( | ||
outputs[inside_interval_mask], | ||
logabsdet[inside_interval_mask], | ||
# outputs[inside_interval_mask], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's not leave unused code in comments.
@@ -97,11 +114,7 @@ def rational_quadratic_spline( | |||
cumwidths[..., -1] = right | |||
widths = cumwidths[..., 1:] - cumwidths[..., :-1] | |||
|
|||
if enable_identity_init: #flow is the identity if initialized with parameters equal to zero |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as earlier comment: we want to keep this, not sure if this is on purpose.
) | ||
|
||
|
||
# turn inside_interval_mask into an int tensor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we simplify below by using masked_scatter?
Hello,
While working on my project (OOD Detection using Manifolds), I noticed an issue with the current implementation of the
PiecewiseRationalQuadratic
coupling layers. The in-place masked assignments, specifically:pose challenges when constructing the computation graph in both
functorch
andtorch2.0
.To address this, I've made necessary modifications in my fork, ensuring a functional adaptation that aligns with these libraries without altering the primary functionality of the layer.
I kindly ask for your review of these changes. If they align with your vision and maintain the library's integrity, I'd appreciate their incorporation into the main branch.
Thank you for your time and consideration.
Best,
Hamid