
Fix In-place Assignments for PiecewiseRationalQuadratic Compatibility with functorch and torch2.0 #77

Open
wants to merge 1 commit into
base: master

Conversation

HamidrezaKmK

Hello,

While working on my project (OOD Detection using Manifolds), I noticed an issue with the current implementation of the PiecewiseRationalQuadratic coupling layers. The in-place masked assignments, specifically:

outputs[outside_interval_mask] = inputs[outside_interval_mask]

pose challenges when constructing the computation graph in both functorch and torch2.0.

To address this, I've made the necessary modifications in my fork, replacing the in-place assignments with functional equivalents so that the layer works with these libraries without changing its behaviour.
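
For illustration, here is a minimal self-contained sketch of the pattern the fork moves to (the tensor shapes and the mask below are made up; the actual change is in the diff further down):

    import torch

    inputs = torch.randn(8, 4)
    outside_interval_mask = inputs.abs() > 1.0   # illustrative mask
    outputs = torch.zeros_like(inputs)

    # In-place masked assignment (current behaviour) -- problematic when tracing:
    # outputs[outside_interval_mask] = inputs[outside_interval_mask]

    # Functional equivalent: builds a new tensor instead of writing into one.
    outputs = torch.where(outside_interval_mask, inputs, outputs)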

I kindly ask for your review of these changes. If they align with your vision and maintain the library's integrity, I'd appreciate their incorporation into the main branch.

Thank you for your time and consideration.

Best,
Hamid

@arturbekasov (Contributor) left a comment


Hey Hamid,

Thanks for taking the time to submit the PR. I've left a few comments.

I am not against making the code more functional, but I would like to be a bit careful about the performance implications of the additional copies. Is there any chance of seeing some before/after performance stats for the function?

Cheers,

Artur
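
For illustration, a rough micro-benchmark along the lines Artur asks for (the shapes, iteration counts, and function bodies are hypothetical stand-ins, not taken from the PR):

    import time
    import torch

    def inplace_version(inputs, mask):
        outputs = torch.zeros_like(inputs)
        outputs[mask] = inputs[mask]               # in-place masked assignment
        return outputs

    def functional_version(inputs, mask):
        outputs = torch.zeros_like(inputs)
        return torch.where(mask, inputs, outputs)  # extra copy, no in-place write

    inputs = torch.randn(4096, 128)
    mask = inputs.abs() > 1.0

    for fn in (inplace_version, functional_version):
        for _ in range(10):                        # warm-up
            fn(inputs, mask)
        start = time.perf_counter()
        for _ in range(1000):
            fn(inputs, mask)
        print(f"{fn.__name__}: {(time.perf_counter() - start) * 1e3:.1f} ms")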

@@ -21,7 +21,6 @@ def unconstrained_rational_quadratic_spline(
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
enable_identity_init=False,

Is this on purpose, or should we re-base the changes?

outputs[outside_interval_mask] = inputs[outside_interval_mask]
logabsdet[outside_interval_mask] = 0

outputs = torch.where(outside_interval_mask, inputs, outputs)

Should we drop the zero init above if we're copying here anyway? I.e. outputs = torch.where(outside_interval_mask, inputs, torch.zeros_like(inputs)).

logabsdet[outside_interval_mask] = 0
outputs = torch.where(outside_interval_mask, inputs, outputs)
logabsdet = torch.where(outside_interval_mask, torch.zeros_like(logabsdet), logabsdet)

Actually, does the original line even have an effect? We're assigning zeros to what is already initialized to zeros.

else:
raise RuntimeError("{} tails are not implemented.".format(tails))

if torch.any(inside_interval_mask):
(
outputs[inside_interval_mask],
logabsdet[inside_interval_mask],
# outputs[inside_interval_mask],

Let's not leave unused code in comments.

@@ -97,11 +114,7 @@ def rational_quadratic_spline(
cumwidths[..., -1] = right
widths = cumwidths[..., 1:] - cumwidths[..., :-1]

if enable_identity_init: #flow is the identity if initialized with parameters equal to zero

Same as the earlier comment: we want to keep this; I'm not sure if the removal is on purpose.

)


# turn inside_interval_mask into an int tensor

Could we simplify below by using masked_scatter?
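
For reference, a hypothetical sketch of how the out-of-place masked_scatter could be used here (the variable names mirror the snippet above; spline_fn stands in for the spline call and is not part of the PR):

    import torch

    def scatter_spline_results(inputs, inside_interval_mask, spline_fn):
        # Start from zeros, matching the existing initialisation.
        outputs = torch.zeros_like(inputs)
        logabsdet = torch.zeros_like(inputs)

        # spline_fn returns one value per True element of the mask, which is
        # exactly the flat source that masked_scatter consumes in order.
        spline_outputs, spline_logabsdet = spline_fn(inputs[inside_interval_mask])

        # Out-of-place masked_scatter returns new tensors; no in-place writes.
        outputs = outputs.masked_scatter(inside_interval_mask, spline_outputs)
        logabsdet = logabsdet.masked_scatter(inside_interval_mask, spline_logabsdet)
        return outputs, logabsdet

Whether this keeps the graph traceable under functorch (the boolean indexing of inputs is still data-dependent) would need to be checked.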
