Fix rules for ^
#513
Conversation
test/rulesets/Base/fastmath_able.jl
Outdated
(0.0, 2) => (0.0, 0.0),
(-0.0, 2) => (-0.0, 0.0),
(0.0, 1) => (1.0, 0.0),
(-0.0, 1) => (1.0, 0.0),
(0.0, 0) => (0.0, NaN),
(-0.0, 0) => (0.0, NaN),
(0.0, -1) => (-Inf, NaN),
(-0.0, -1) => (-Inf, NaN),
(0.0, -2) => (-Inf, NaN),
(-0.0, -2) => (Inf, NaN),
This list of (x,p) => (dx,dp) pairs is what the PR aims to produce for 0^p cases, with x and p both real.
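For reference, one way to check a single entry of this table against the rule itself (a minimal sketch; the expected values come from the list above):

using ChainRules
# check the (0.0, 2) entry: the reverse rule for ^ should give (dx, dp) = (0.0, 0.0)
y, back = ChainRules.rrule(^, 0.0, 2)
dx, dp = ChainRules.unthunk.(back(1.0)[2:3])   # drop the NoTangent() for ^ itself
(dx, dp)                                       # should match the table entry above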
Here's a not very sophisticated argument that the NaNs should be either -Inf, or ±Inf according to the sign of x, for even and odd negative integer p:
using Plots, ChainRules;
# (dx, dp) gradients of x^p via the complex rrule, keeping only the real parts
powderivs(x, p) = real.(ChainRules.unthunk.(ChainRules.rrule(^, x+0im, p)[2](1)[2:3]));
function pplot(p)
    ps = p .+ (-0.1:0.001:0.1)
    plot(p -> real((0.0+0im)^p), ps; lab="0^p")   # primal 0^p over a small range of p
    for x in (-0.1, -0.05, -0.01, 0.01, 0.05, 0.1)
        plot!(p -> powderivs(x, p)[2], ps; lab="slope($x^p)", s=(x<0 ? :dash : :solid))
    end
    plot!(xguide="p")
end
pplot(-1) # gradient smooth in p, but diverges to +-Inf as x -> 0
pplot(-2) # ... diverges to -Inf only
pplot(-3) # +- Inf again
pplot(0.5)
pplot(1)
pplot(1.5)
pplot(2) # these all converge to 0
pplot(0) # argue this is -Inf?
But for non-integer p, it has to do something in between, which sounds like NaN.
src/rulesets/Base/fastmath_able.jl
Outdated
_pow_grad_x(x, p, y) = (p * y / x)
function _pow_grad_x(x::Real, p::Real, y)
    return ifelse(!iszero(x) | (p<0), (p * y / x),
           ifelse(isone(p), one(y),
           ifelse(0<p<1, oftype(y, Inf), zero(y) )))
... and this ugly nest is how it does it, with primal y = x^p. If you try computing yox = x^(p-1) and y = yox * x, then getting the edge cases for the gradient is easier. But getting the edge cases for the primal to match what Base does results in a similar-size nest of ifelse.
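For concreteness, a minimal sketch of that alternative (the name pow_via_yox is only for illustration, it is not what the PR uses):

# compute x^(p-1) once and reuse it, so the x-gradient never divides by x
function pow_via_yox(x::Real, p::Real)
    yox = x^(p - 1)     # Base's ^ supplies the Inf/NaN behaviour here
    y = yox * x         # primal; may disagree with Base's x^p at some edge cases
    dx = p * yox        # d/dx x^p = p * x^(p-1)
    return y, dx
end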
Is using ifelse here really beneficial? I would have expected that floating-point division is expensive enough that branching here makes more sense.
I haven't tried to time it; I was focused on figuring out how to hit all the edge cases.
Sorry, I am not going to have time to review this until mid next week. One thing I am currently believing is that:
Sorry, I still haven't finished reviewing.
Can we maybe add a justification for each example in the tests of POWERGRADS?
(0.0, 0) => (0.0, NaN),
(-0.0, 0) => (0.0, NaN),
I thought there should be NoTangent() here, which we could normalize to 0 for performance reasons, since you can't really perturb the input about that point. Why is NaN better?
I thought I persuaded myself these were indeterminate, hence made them like 0/0. But I'm not super-sure; sometimes I also think they are a particular sign of Inf.
In some ways NaN is a safe answer, as it guarantees that nobody will be relying on a particular, possibly wrong choice.
Overall looks good, but the test failures need to be addressed of course. If you want, I can also take a look at them.
src/rulesets/Base/fastmath_able.jl
Outdated
return ifelse(!iszero(x), y * real(log(complex(x))),
       ifelse(p>0, zero(y), oftype(y, NaN) ))
Same here, this probably should use branching instead of ifelse.
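For illustration, a branching version of the p-gradient shown above might look like this (a sketch only; the name _pow_grad_p_branch is made up for the example):

function _pow_grad_p_branch(x::Real, p::Real, y)
    # usual case: d/dp x^p = x^p * log(x), taking the real part of the complex log
    iszero(x) || return y * real(log(complex(x)))
    # at x == 0: the gradient is 0 for p > 0 and indeterminate (NaN) otherwise
    return p > 0 ? zero(y) : oftype(y, NaN)
end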
A quick attempt to time this, with contradictory results.
julia> _pow_ifelse(x::Real, p::Real, y) = ifelse(!iszero(x) | (p<0), (p * y / x),
                                          ifelse(isone(p), one(y),
                                          ifelse(0<p<1, oftype(y, Inf), zero(y) )));
julia> _pow_branch(x::Real, p::Real, y) = (!iszero(x) | (p<0)) ? (p * y / x) :
                                          ifelse(isone(p), one(y),
                                          ifelse(0<p<1, oftype(y, Inf), zero(y) ));
julia> N = 1000; x,y,p,z = randn(N),randn(N),randn(N),randn(N);
julia> _pow_ifelse.(x,p,y) ≈ _pow_branch.(x,p,y)
true
julia> @btime @. $z = _pow_ifelse($x,$p,$y);
1.625 μs (0 allocations: 0 bytes)
julia> @btime @. $z = _pow_branch($x,$p,$y);
653.632 ns (0 allocations: 0 bytes)
julia> x[randn(N).>0] .= 0; p[rand(N).>0.9] .= 1;
julia> @btime @. $z = _pow_branch($x,$p,$y);
997.200 ns (0 allocations: 0 bytes)
julia> versioninfo()
Julia Version 1.7.0-beta3.0
OS: macOS (x86_64-apple-darwin18.7.0) # rosetta
Natively, the reverse?
julia> @btime @. $z = _pow_ifelse($x,$p,$y);
491.835 ns (0 allocations: 0 bytes)
julia> @btime @. $z = _pow_branch($x,$p,$y);
634.312 ns (0 allocations: 0 bytes)
julia> x[randn(N).>0] .= 0; p[rand(N).>0.9] .= 1;
julia> @btime @. $z = _pow_branch($x,$p,$y);
973.059 ns (0 allocations: 0 bytes)
julia> versioninfo()
Julia Version 1.8.0-DEV.459
OS: macOS (arm64-apple-darwin20.6.0) # M1 native
Remotely:
julia> @btime @. $z = _pow_ifelse($x,$p,$y);
2.545 μs (0 allocations: 0 bytes)
julia> @btime @. $z = _pow_branch($x,$p,$y);
1.177 μs (0 allocations: 0 bytes)
julia> x[randn(N).>0] .= 0; p[rand(N).>0.9] .= 1;
julia> @btime @. $z = _pow_branch($x,$p,$y);
1.216 μs (0 allocations: 0 bytes)
julia> versioninfo()
Julia Version 1.6.1
OS: Linux (x86_64-pc-linux-gnu)
CPU: Intel(R) Xeon(R) Gold 6226 CPU @ 2.70GHz
Not sure this is a good way to time this. Will a broadcast gradient ever compile down to just this or will there always be a lot of other junk around?
Oh, very interesting! So it does look like the ifelse can be faster in some situations, but I think I'd still go with the branching version, since it seems faster on the more common architectures, and ideally LLVM would figure out how to vectorize the branching version better if it can prove that there are no side effects.
What I find most surprising is that in the first test, where the first condition is always true, the ifelse path is sometimes much slower. I expected stuff like oftype(y, Inf) and zero(y) to be essentially free compared to p * y / x.
Maybe the other tests like isone(p) are more expensive than I thought? In which case maybe you want more branches than just the one I tried here.
I think this is telling me that 0<p<1 has a branch; maybe much of the above is about how that gets treated:
julia> Meta.@lower 1<2<3
:($(Expr(:thunk, CodeInfo(
@ none within `top-level scope`
1 ─ %1 = 1 < 2
└── goto #3 if not %1
2 ─ %3 = 2 < 3
└── return %3
3 ─ return false
))))
Ah, good catch! You could try it with (0 < p) & (p < 1) instead. These are really just micro-optimizations though, so we can always leave that for later to decide.
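For reference, the suggested change applied to the gradient above would look roughly like this (a sketch only; _pow_grad_x_nobranch is an illustrative name, and the logic is otherwise unchanged):

_pow_grad_x_nobranch(x::Real, p::Real, y) =
    ifelse(!iszero(x) | (p < 0), p * y / x,
    ifelse(isone(p), one(y),
    # (0 < p) & (p < 1) avoids the short-circuit branch that the chained 0 < p < 1 lowers to
    ifelse((0 < p) & (p < 1), oftype(y, Inf), zero(y))))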
I will leave giving final approval on this to @simeonschaub
This reverts commit 688af01.
a2b08dd to 936cb21
Just a small question and some style nits, otherwise I think this just needs a version bump and it's good to go. Or is there any other reason this is still marked as a draft?
test/rulesets/Base/fastmath_able.jl
Outdated
# ∂x_fwd = frule((0,1,0), ^, x, p)[1]
# ∂p_fwd = frule((0,0,1), ^, x, p)[2]
# isequal(∂x, ∂x_fwd) || println("^ forward `x` gradient for $y = $x^$p: got $∂x_fwd, expected $∂x, maybe!")
# isequal(∂p, ∂p_fwd) || println("^ forward `p` gradient for $x^$p: got $∂p_fwd, expected $∂p, maybe")
It's marked draft because these frule tests aren't quite right.
Ah, I see!
The print statements give this:
^ forward `x` gradient for 1.0 = 1.0^2: got 1.0, expected 2.0, maybe!
^ forward `x` gradient for 0.0 = 0.0^1: got 0.0, expected 1.0, maybe!
^ forward `x` gradient for -0.0 = -0.0^1: got -0.0, expected 1.0, maybe!
^ forward `x` gradient for 1.0 = 0.0^0: got 1.0, expected 0.0, maybe!
^ forward `x` gradient for 1.0 = -0.0^0: got 1.0, expected 0.0, maybe!
^ forward `x` gradient for Inf = 0.0^-1: got Inf, expected -Inf, maybe!
^ forward `x` gradient for Inf = 0.0^-2: got Inf, expected -Inf, maybe!
^ forward `x` gradient for 0 = 0^1: got 0, expected 1.0, maybe!
^ forward `x` gradient for 1 = 0^0: got 1, expected 0.0, maybe!
^ forward `x` gradient for 0.0 = 0.0^0.5: got 0.0, expected Inf, maybe!
^ forward `p` gradient for 0.0^0.5: got NaN, expected 0.0, maybe
^ forward `x` gradient for Inf = 0.0^-1.5: got Inf, expected -Inf, maybe!
Maybe fixed, by regarding the input Δp == 0 as a "strong zero". I'm not 100% sure this is ideal, but someone else can fix it if they desire.
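Roughly what the "strong zero" treatment means here, as a sketch only (pow_pushforward and its body are illustrative, not the PR's actual frule):

function pow_pushforward(x, Δx, p, Δp, y)
    # if the incoming tangent Δp is literally zero, skip the p-term entirely, so that a NaN
    # or Inf value of y*log(x) at an edge case cannot contaminate an otherwise finite result
    dx_term = Δx * p * y / x    # naive x-term; the PR's _pow_grad_x handles its edge cases
    dp_term = iszero(Δp) ? zero(y) : Δp * y * real(log(complex(x)))
    return dx_term + dp_term
end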
The Diffractor issue is forward mode. My version:
function f1(x)  # from https://github.com/JuliaDiff/Diffractor.jl/issues/26
    res = 0.0
    for i in 1:10
        res += x^i
    end
    res
end
using Diffractor, ForwardDiff, Plots
f1fwd(x) = let var"'" = Diffractor.PrimeDerivativeFwd
    f1'(x)
end
f1rev(x) = let var"'" = Diffractor.PrimeDerivativeBack
    f1'(x)
end
xs = -0.5:0.01:0.5
p1 = plot(xs, f1.(xs), lab="f1")
plot!(xs, f1fwd.(xs), lab="f1', forward")
plot!(xs, ForwardDiff.derivative.(f1, xs) .+ 0.02, lab="f1', dual")
plot!(xs, f1rev.(xs) .+ 0.04, lab="f1', reverse")  # fails with this PR
So the original problem is fixed. But there was also a worse problem with reverse mode, which now doesn't run at all with this PR.
Oh, now I'm less sure. That's on Julia 1.8-, Mac M1 native. On 1.7 under Rosetta I can't plot, but I don't see the problem. With or without this PR, forward and reverse agree:
julia> f1fwd(0.5)
3.9765625
julia> f1rev(0.5)
3.9765625
Builds on #485 but tries to get all the edge cases. Still a bit messy.