
Fix rules for ^ #513
Merged · merged 20 commits · Sep 5, 2021

Conversation

@mcabbott (Member) commented Aug 27, 2021:

Builds on #485 but tries to get all the edge cases. Still a bit messy.

Comment on lines 188 to 191
(0.0, 2) => (0.0, 0.0),
(-0.0, 2) => (-0.0, 0.0),
(0.0, 1) => (1.0, 0.0),
(-0.0, 1) => (1.0, 0.0),
(0.0, 0) => (0.0, NaN),
(-0.0, 0) => (0.0, NaN),
(0.0, -1) => (-Inf, NaN),
(-0.0, -1) => (-Inf, NaN),
(0.0, -2) => (-Inf, NaN),
(-0.0, -2) => (Inf, NaN),

@mcabbott (Member Author) commented Aug 27, 2021:
This list of (x,p) => (dx,dp) is what the PR aims to produce for the 0^p cases, with x and p both real.
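
For readers following along, one entry can be spot-checked with the same rrule machinery used in the plotting snippet later in this thread. A minimal sketch, assuming this PR's rules and that the rrule for ^ accepts these mixed Real arguments:

using ChainRules
y, back = ChainRules.rrule(^, 0.0, 2)          # primal 0.0^2
dx, dp = ChainRules.unthunk.(back(1.0)[2:3])   # pull back a unit sensitivity
(dx, dp)                                       # expected (0.0, 0.0), the first entry above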

@mcabbott (Member Author) commented Sep 2, 2021:

Here's a not-very-sophisticated argument that these NaNs should instead be -Inf for even negative integer p, or ±Inf according to the sign of x for odd negative integer p:

using Plots, ChainRules;

powderivs(x,p) = real.(ChainRules.unthunk.(ChainRules.rrule(^, x+0im, p)[2](1)[2:3]));

function pplot(p)
    ps = p .+ (-0.1:0.001:0.1)
    plot(p ->real((0.0+0im)^p), ps; lab="0^p");
    for x in (-0.1, -0.05, -0.01, 0.01, 0.05, 0.1)
        plot!(p -> powderivs(x,p)[2], ps; lab="slope($x^p)", s=(x<0 ? :dash : :solid))
    end
    plot!(xguide="p")
end
pplot(-1)  # gradient smooth in p, but diverges to +-Inf as x -> 0
pplot(-2)  # ... diverges to -Inf only
pplot(-3)  # +- Inf again

pplot(0.5)
pplot(1)
pplot(1.5)
pplot(2)  # these all converge to 0

pplot(0)  # argue this is -Inf?

But for non-integer p, it has to do something in between, which sounds like NaN.

Comment on lines 211 to 215
_pow_grad_x(x, p, y) = (p * y / x)
function _pow_grad_x(x::Real, p::Real, y)
    return ifelse(!iszero(x) | (p<0), (p * y / x),
             ifelse(isone(p), one(y),
               ifelse(0<p<1, oftype(y, Inf), zero(y) )))
end

@mcabbott (Member Author):

... and this ugly nest is how it does it, with primal y = x^p.

If you try computing yox = x^(p-1) and y = yox * x, then getting the edge cases for the gradient is easier. But getting the edge cases for the primal to match what Base does results in a similar-size nest of ifelse.
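
A minimal sketch of that alternative, under a hypothetical name (this is not the PR's code):

# Compute y/x as x^(p-1) first, so the x-gradient needs no division.
function _pow_grad_x_alt(x::Real, p::Real)
    yox = x^(p - 1)    # y/x; finite at x == 0 whenever p > 1
    y = yox * x        # primal, but can disagree with Base's x^p at edge cases
    return y, p * yox  # (primal, gradient wrt x)
end

For example, at x = 0.0, p = 1 this gives the gradient 1.0 directly (since 0.0^0 == 1.0), matching the table above; but at x = 0.0, p = 0.5 the primal becomes Inf * 0.0 == NaN, where Base returns 0.0.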

Member:

Is using ifelse here really beneficial? I would have expected that floating point division is expensive enough that branching here makes more sense.

@mcabbott (Member Author):

I haven't tried to time it; I was focused on figuring out how to hit all the edge cases.

@oxinabox (Member):

Sorry, I am not going to have time to review this till mid next week.
I have started redoing JuliaDiff/ChainRulesCore.jl#404
to be less about subgradients and more just iterating through every case and saying what to do.

One thing I currently believe is that:
Inf should be avoided where possible, but it is OK if it can't be.
(Also that 0*Inf=NaN is right, because it is better for the derivative of cbrt(x)^3 at 0 to be NaN than for it to be 0. Ideally it would be 1.)
NaN should really be avoided if there is any possible excuse to answer anything else, unless the input or output primal is also NaN, in which case it is fine.
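
To make the cbrt(x)^3 point concrete, here is the chain-rule product at zero, written out in plain Julia (a worked illustration, not code from the PR):

x = 0.0
outer = 3 * cbrt(x)^2         # 0.0  (derivative of t^3 at t = cbrt(x))
inner = 1 / (3 * cbrt(x)^2)   # Inf  (derivative of cbrt at x = 0)
outer * inner                 # NaN under 0*Inf=NaN; a "strong" zero would give 0.0,
                              # while the true slope of x -> cbrt(x)^3 == x is 1.0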

@oxinabox (Member) left a comment:

Sorry, I still haven't finished reviewing.
Can we maybe add a justification for each example in the tests of POWERGRADS?

Comment on lines +192 to +187
(0.0, 0) => (0.0, NaN),
(-0.0, 0) => (0.0, NaN),

Member:

I thought there should be a NoTangent() here, which we could normalize to 0 for performance reasons, since you can't really perturb the input about that point.
Why is NaN better?

@mcabbott (Member Author):

I thought I had persuaded myself these were indeterminate, hence made them like 0/0. But I'm not super-sure; sometimes I also think they should be a particular sign of Inf.

In some ways NaN is a safe answer, as it guarantees that nobody will be relying on a particular, possibly wrong, choice.
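
A one-line illustration of that safety property, using nothing but IEEE arithmetic:

julia> 0.1 * NaN   # any downstream use of an indeterminate gradient stays visibly indeterminate
NaN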

@simeonschaub (Member) left a comment:

Overall looks good, but the test failures need to be addressed of course. If you want, I can also take a look at them.


Comment on lines 201 to 202
return ifelse(!iszero(x), y * real(log(complex(x))),
         ifelse(p>0, zero(y), oftype(y, NaN) ))

Member:

Same here, this probably should use branching instead of ifelse.

@mcabbott (Member Author):

A quick attempt to time this, with contradictory results.

julia> _pow_ifelse(x::Real, p::Real, y) = ifelse(!iszero(x) | (p<0), (p * y / x),
                            ifelse(isone(p), one(y),
                              ifelse(0<p<1, oftype(y, Inf), zero(y) )));

julia> _pow_branch(x::Real, p::Real, y) = (!iszero(x) | (p<0)) ? (p * y / x) :
                            ifelse(isone(p), one(y),
                              ifelse(0<p<1, oftype(y, Inf), zero(y) ));

julia> N = 1000; x,y,p,z = randn(N),randn(N),randn(N),randn(N);

julia> _pow_ifelse.(x,p,y) ≈ _pow_branch.(x,p,y)
true

julia> @btime @. $z = _pow_ifelse($x,$p,$y);
  1.625 μs (0 allocations: 0 bytes)

julia> @btime @. $z = _pow_branch($x,$p,$y);
  653.632 ns (0 allocations: 0 bytes)

julia> x[randn(N).>0] .= 0; p[rand(N).>0.9] .= 1;

julia> @btime @. $z = _pow_branch($x,$p,$y);
  997.200 ns (0 allocations: 0 bytes)

julia> versioninfo()
Julia Version 1.7.0-beta3.0
  OS: macOS (x86_64-apple-darwin18.7.0)  # rosetta

Natively (M1), it's the reverse:

julia> @btime @. $z = _pow_ifelse($x,$p,$y);
  491.835 ns (0 allocations: 0 bytes)

julia> @btime @. $z = _pow_branch($x,$p,$y);
  634.312 ns (0 allocations: 0 bytes)

julia> x[randn(N).>0] .= 0; p[rand(N).>0.9] .= 1;

julia> @btime @. $z = _pow_branch($x,$p,$y);
  973.059 ns (0 allocations: 0 bytes)

julia> versioninfo()
Julia Version 1.8.0-DEV.459
  OS: macOS (arm64-apple-darwin20.6.0)  # M1 native

Remotely:

julia> @btime @. $z = _pow_ifelse($x,$p,$y);
  2.545 μs (0 allocations: 0 bytes)

julia> @btime @. $z = _pow_branch($x,$p,$y);
  1.177 μs (0 allocations: 0 bytes)

julia> x[randn(N).>0] .= 0; p[rand(N).>0.9] .= 1;

julia> @btime @. $z = _pow_branch($x,$p,$y);
  1.216 μs (0 allocations: 0 bytes)

julia> versioninfo()
Julia Version 1.6.1
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Xeon(R) Gold 6226 CPU @ 2.70GHz

Not sure this is a good way to time this. Will a broadcast gradient ever compile down to just this, or will there always be a lot of other junk around?

Member:

Oh, very interesting! So it does look like the ifelse can be faster in some situations, but I think I'd still go with the branching version, since it seems faster on the more common architectures, and ideally LLVM would figure out how to vectorize the branching version better if it can prove that there are no side effects.

@mcabbott (Member Author):

What I find most surprising is that in the first test, where the first condition is always true, the ifelse path is sometimes much slower. I expected stuff like oftype(y, Inf) and zero(y) to be essentially free compared to p * y / x.

Maybe the other tests like isone(p) are more expensive than I thought? In which case maybe you want more branches than just the one I tried here.

@mcabbott (Member Author):

I think this is telling me that 0<p<1 contains a branch; maybe much of the above is about how that gets treated:

julia> Meta.@lower 1<2<3
:($(Expr(:thunk, CodeInfo(
    @ none within `top-level scope`
1 ─ %1 = 1 < 2
└──      goto #3 if not %1
2 ─ %3 = 2 < 3
└──      return %3
3 ─      return false
))))

Member:

Ah, good catch! You could try it with (0 < p) & (p < 1) instead. These are really just micro-optimizations though, so we can always leave that for later to decide.
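
For reference, a sketch of _pow_grad_x with that substitution (a hypothetical variant, not committed code):

# `&` lowers to an ordinary function call, so unlike the chained
# comparison 0<p<1 it introduces no branch:
_pow_grad_x_nobranch(x::Real, p::Real, y) =
    ifelse(!iszero(x) | (p < 0), p * y / x,
      ifelse(isone(p), one(y),
        ifelse((0 < p) & (p < 1), oftype(y, Inf), zero(y))))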

@oxinabox (Member) commented Sep 2, 2021:

I will leave giving final approval on this to @simeonschaub
and unsubscribe from this issue.
Ping me if I am needed.

@simeonschaub (Member) left a comment:

Just a small question and some style nits, otherwise I think this just needs a version bump and it's good to go. Or is there any other reason this is still marked as a draft?

Comment on lines 215 to 218
# ∂x_fwd = frule((0,1,0), ^, x, p)[1]
# ∂p_fwd = frule((0,0,1), ^, x, p)[2]
# isequal(∂x, ∂x_fwd) || println("^ forward `x` gradient for $y = $x^$p: got $∂x_fwd, expected $∂x, maybe!")
# isequal(∂p, ∂p_fwd) || println("^ forward `p` gradient for $x^$p: got $∂p_fwd, expected $∂p, maybe")

@mcabbott (Member Author):

It's marked draft because these frule tests aren't quite right.

Member:

Ah, I see!

@mcabbott (Member Author):

The print statements give this:

^ forward `x` gradient for 1.0 = 1.0^2: got 1.0, expected 2.0, maybe!
^ forward `x` gradient for 0.0 = 0.0^1: got 0.0, expected 1.0, maybe!
^ forward `x` gradient for -0.0 = -0.0^1: got -0.0, expected 1.0, maybe!
^ forward `x` gradient for 1.0 = 0.0^0: got 1.0, expected 0.0, maybe!
^ forward `x` gradient for 1.0 = -0.0^0: got 1.0, expected 0.0, maybe!
^ forward `x` gradient for Inf = 0.0^-1: got Inf, expected -Inf, maybe!
^ forward `x` gradient for Inf = 0.0^-2: got Inf, expected -Inf, maybe!
^ forward `x` gradient for 0 = 0^1: got 0, expected 1.0, maybe!
^ forward `x` gradient for 1 = 0^0: got 1, expected 0.0, maybe!
^ forward `x` gradient for 0.0 = 0.0^0.5: got 0.0, expected Inf, maybe!
^ forward `p` gradient for 0.0^0.5: got NaN, expected 0.0, maybe
^ forward `x` gradient for Inf = 0.0^-1.5: got Inf, expected -Inf, maybe!

@mcabbott (Member Author):

Maybe fixed, by regarding an input Δp == 0 as a "strong zero". I'm not 100% sure this is ideal, but someone else can fix it if they desire.
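
A minimal sketch of what treating Δp == 0 as a "strong zero" means for the frule, with a hypothetical helper name (not the PR's exact code):

# If the tangent Δp is exactly zero, skip the y*log(x) term entirely,
# so that 0 * Inf and 0 * NaN never arise from it.
function pow_frule_sketch(x::Real, p::Real, Δx::Real, Δp::Real)
    y = x^p
    ∂x = p * x^(p - 1) * Δx
    ∂p = iszero(Δp) ? zero(y) : y * log(x) * Δp   # "strong zero" branch
    return y, ∂x + ∂p
end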

@mcabbott (Member Author):

The Diffractor issue is forward mode. My version:

function f1(x) # from https://github.com/JuliaDiff/Diffractor.jl/issues/26
    res = 0.0
    for i in 1:10
        res += x^i
    end
    res
end

using Diffractor, ForwardDiff, Plots
f1fwd(x) = let var"'" = Diffractor.PrimeDerivativeFwd 
    f1'(x)
end
f1rev(x) = let var"'" = Diffractor.PrimeDerivativeBack
    f1'(x)
end
xs = -0.5:0.01:0.5
p1 = plot(xs, f1.(xs), lab="f1")
plot!(xs, f1fwd.(xs), lab="f1', forward")
plot!(xs, ForwardDiff.derivative.(f1, xs).+0.02, lab="f1', dual")
plot!(xs, f1rev.(xs).+0.04, lab="f1', reverse")  # fails with this PR

Before: [plot: f1_dual]

After: [plot: f1_after]

So the original problem is fixed. But there was also a worse problem with reverse mode, which now doesn't run at all after this PR.

@mcabbott (Member Author):

Oh, now I'm less sure. That was on Julia 1.8-DEV, Mac M1 native. On 1.7 under Rosetta I can't plot, but I don't see the problem. With or without this PR, forward and reverse agree:

julia> f1fwd(0.5)
3.9765625

julia> f1rev(0.5)
3.9765625

@mcabbott marked this pull request as ready for review September 3, 2021 22:32