Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix rules for ^ #513

Merged
merged 20 commits into from
Sep 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.11.3"
version = "1.11.4"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
26 changes: 18 additions & 8 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,14 +179,24 @@ end
@scalar_rule floor(x) zero(x)
@scalar_rule ceil(x) zero(x)

# note: rules for ^ are defined in the fastmath_able.jl
function frule((_, _, Δx, _), ::typeof(Base.literal_pow), ::typeof(^), x::Real, pv::Val{p}) where p
y = Base.literal_pow(^, x, pv)
return y, (p * y / x * Δx)
# `literal_pow`
# This is mostly handled by AD; it's a micro-optimisation to provide a gradient for x*x*x
# Note that rules for `^` are defined in the fastmath_able.jl

# Forward rule for the literal square `x^2`, which lowers to
# `Base.literal_pow(^, x, Val(2))`.
# Only the tangent of `x` (third slot of the tangent tuple) is used.
# Returns the primal `x * x` together with the pushed-forward tangent `2 * x * Δx`.
function frule((_, _, Δx, _), ::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{2})
    Ω = x * x
    ∂Ω = 2 * x * Δx
    return (Ω, ∂Ω)
end
# Forward rule for the literal cube `x^3`, which lowers to
# `Base.literal_pow(^, x, Val(3))`.
# The square `x * x` is computed once and shared by the primal and the tangent.
# Returns the primal `x * x * x` together with the tangent `3 * x^2 * Δx`.
function frule((_, _, Δx, _), ::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{3})
    xsq = x * x
    Ω = xsq * x
    ∂Ω = 3 * xsq * Δx
    return (Ω, ∂Ω)
end

function rrule(::typeof(Base.literal_pow), ::typeof(^), x::Real, pv::Val{p}) where p
y = Base.literal_pow(^, x, pv)
literal_pow_pullback(dy) = NoTangent(), NoTangent(), (p * y / x * dy), NoTangent()
return y, literal_pow_pullback
# Reverse rule for the literal square `x^2`, i.e. `Base.literal_pow(^, x, Val(2))`.
# Returns the primal `x * x` and a pullback. The pullback yields `NoTangent()` for
# `literal_pow`, `^` and the `Val`, and `2 * x * dy` projected through `ProjectTo(x)`
# for `x`.
function rrule(::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{2})
    function square_pullback(dy)
        x̄ = ProjectTo(x)(2 * x * dy)
        return (NoTangent(), NoTangent(), x̄, NoTangent())
    end
    return x * x, square_pullback
end
# Reverse rule for the literal cube `x^3`, i.e. `Base.literal_pow(^, x, Val(3))`.
# The square `x * x` is computed once in the forward pass and captured by the pullback.
# The pullback yields `NoTangent()` for `literal_pow`, `^` and the `Val`, and
# `3 * x^2 * dy` projected through `ProjectTo(x)` for `x`.
function rrule(::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{3})
    xsq = x * x
    function cube_pullback(dy)
        x̄ = ProjectTo(x)(3 * xsq * dy)
        return (NoTangent(), NoTangent(), x̄, NoTangent())
    end
    return xsq * x, cube_pullback
end
80 changes: 66 additions & 14 deletions src/rulesets/Base/fastmath_able.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ let
# exponents
@scalar_rule cbrt(x) inv(3 * Ω ^ 2)
@scalar_rule inv(x) -(Ω ^ 2)
@scalar_rule sqrt(x) inv(2Ω)
@scalar_rule sqrt(x) inv(2Ω) # gradient +Inf at x==0
@scalar_rule exp(x) Ω
@scalar_rule exp10(x) Ω * log(oftype(x, 10))
@scalar_rule exp2(x) Ω * log(oftype(x, 2))
Expand Down Expand Up @@ -137,8 +137,7 @@ let

# Binary functions

# `hypot`

## `hypot`
function frule(
(_, Δx, Δy),
::typeof(hypot),
Expand All @@ -163,29 +162,53 @@ let
@scalar_rule x + y (true, true)
@scalar_rule x - y (true, -1)
@scalar_rule x / y (one(x) / y, -(Ω / y))
#log(complex(x)) is required so it gives correct complex answer for x<0
@scalar_rule(x ^ y,
(ifelse(iszero(x), zero(Ω), y * Ω / x), Ω * log(complex(x))),
)
# x^y for x < 0 errors when y is not an integer, but then derivative wrt y
# is undefined, so we adopt subgradient convention and set derivative to 0.
@scalar_rule(x::Real ^ y::Real,
(ifelse(iszero(x), zero(Ω), y * Ω / x), Ω * log(oftype(Ω, ifelse(x ≤ 0, one(x), x)))),
)

## power
# literal_pow is in base.jl
# Forward rule for `x ^ p` with arbitrary `Number` base and exponent.
# The partial derivatives come from the helpers `_pow_grad_x` and `_pow_grad_p`,
# which are handed the floating-point value of the primal result.
# Returns the primal `x ^ p` and the combined tangent.
function frule((_, Δx, Δp), ::typeof(^), x::Number, p::Number)
    Ω = x ^ p
    ∂x = _pow_grad_x(x, p, float(Ω))
    if iszero(Δp)
        # A zero exponent tangent is treated as a strong zero: this avoids NaN
        # and skips the cost of the log inside `_pow_grad_p`.
        return (Ω, ∂x * Δx)
    end
    # This may do real(log(complex(...))), which matches ProjectTo in the rrule
    ∂p = _pow_grad_p(x, p, float(Ω))
    return (Ω, muladd(∂p, Δp, ∂x * Δx))
end

# Reverse rule for `x ^ p` with arbitrary `Number` base and exponent.
# Returns the primal `x ^ p` and a pullback producing `(NoTangent(), x̄, p̄)`.
# Both cotangents are the conjugated partial derivative times `dy`, passed through
# the `ProjectTo` of the corresponding input.
function rrule(::typeof(^), x::Number, p::Number)
    Ω = x^p
    proj_x = ProjectTo(x)
    proj_p = ProjectTo(p)
    function power_pullback(dy)
        x̄ = proj_x(conj(_pow_grad_x(x, p, float(Ω))) * dy)
        # _pow_grad_p contains a log, so the exponent cotangent is deferred in a thunk
        p̄ = @thunk(proj_p(conj(_pow_grad_p(x, p, float(Ω))) * dy))
        return (NoTangent(), x̄, p̄)
    end
    return Ω, power_pullback
end

## `rem`
@scalar_rule(
rem(x, y),
@setup((u, nan) = promote(x / y, NaN16), isint = isinteger(x / y)),
(ifelse(isint, nan, one(u)), ifelse(isint, nan, -trunc(u))),
)
## `min`, `max`
@scalar_rule max(x, y) @setup(gt = x > y) (gt, !gt)
@scalar_rule min(x, y) @setup(gt = x > y) (!gt, gt)

# Unary functions
@scalar_rule +x true
@scalar_rule -x -1

# `sign`

## `sign`
function frule((_, Δx), ::typeof(sign), x)
n = ifelse(iszero(x), one(real(x)), abs(x))
Ω = x isa Real ? sign(x) : x / n
Expand Down Expand Up @@ -237,9 +260,38 @@ let
"Non-FastMath compatible rules defined in fastmath_able.jl. \n Definitions:\n" *
join(non_transformed_definitions, "\n")
)
# This error() may not play well with Revise. But a warning @error does:
# @error "Non-FastMath compatible rules defined in fastmath_able.jl." non_transformed_definitions
end

eval(fast_ast)
eval(fastable_ast) # Get original definitions
# we do this second so it overwrites anything we included by mistake in the fastable
end

## power
# These functions need to be defined outside the eval() block.
# The special cases they aim to hit are in POWERGRADS in tests.
# Partial derivative of `y = x^p` with respect to `x`.
# Generic fallback: the closed form `p * y / x`, with no special-casing.
_pow_grad_x(x, p, y) = (p * y / x)

# Real base and exponent: same closed form, except at `x == 0` with `p < 0` false,
# where the result is chosen by the value of `p`:
#   p == 1            -> one(y)
#   p == 0 or p > 1   -> zero(y)
#   anything else     -> Inf (converted to the type of `y`)
function _pow_grad_x(x::Real, p::Real, y)
    if !iszero(x) || p < 0
        return p * y / x
    end
    # From here on `x` is zero and `p < 0` is false.
    isone(p) && return one(y)
    (iszero(p) || p > 1) && return zero(y)
    return oftype(y, Inf)
end

# Partial derivative of `y = x^p` with respect to `p`.
# Generic fallback: `y * log(complex(x))`, so the result is complex.
_pow_grad_p(x, p, y) = y * log(complex(x))

# Real base and exponent: only the real part of the complex log is kept.
# At `x == 0` the log is not evaluated; the result is `zero(y)` when `p > 0`
# and NaN (converted to the type of `y`) otherwise.
function _pow_grad_p(x::Real, p::Real, y)
    iszero(x) || return y * real(log(complex(x)))
    return p > 0 ? zero(y) : oftype(y, NaN)
end
8 changes: 4 additions & 4 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,10 @@
@test rrule(Base.depwarn, "message", :f) !== nothing
end

@testset "literal_pow" begin
# for real x and n, x must be >0
test_frule(Base.literal_pow, ^, 3.5, Val(3))
test_rrule(Base.literal_pow, ^, 3.5, Val(3))
# Exercise the hand-written `literal_pow` rules for the exponents 2 and 3,
# at a negative, a zero and a positive base.
# No guard against `x == 0 && p < 0` is needed: `p` only ranges over 2 and 3.
@testset "literal_pow: $x^$p" for x in [-1.5, 0.0, 3.5], p in [2, 3]
    test_frule(Base.literal_pow, ^, x, Val(p))
    test_rrule(Base.literal_pow, ^, x, Val(p))
end

@testset "Float conversions" begin
Expand Down
94 changes: 72 additions & 22 deletions test/rulesets/Base/fastmath_able.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ const FASTABLE_AST = quote
test_rrule(f, 10rand(T), rand(T))
end

@testset "$f(x::$T, y::$T) type check" for f in (/, +, -,\, hypot, ^), T in (Float32, Float64)
@testset "$f(x::$T, y::$T) type check" for f in (/, +, -,\, hypot), T in (Float32, Float64)
x, Δx, x̄ = 10rand(T, 3)
y, Δy, ȳ = rand(T, 3)
@assert T == typeof(f(x, y))
Expand All @@ -159,28 +159,78 @@ const FASTABLE_AST = quote
end
end

@testset "^(x::$T, n::$T)" for T in (Float64, ComplexF64)
# for real x and n, x must be >0
test_frule(^, rand(T) + 3, rand(T) + 3)
test_rrule(^, rand(T) + 3, rand(T) + 3)

T <: Real && @testset "discontinuity for ^(x::Real, n::Int) when x ≤ 0" begin
# finite differences doesn't work for x < 0, so we check manually
x = -rand(T) .- 3
y = 3
Δx = randn(T)
Δy = randn(T)
Δz = randn(T)

@test frule((ZeroTangent(), Δx, Δy), ^, x, y)[2] ≈ Δx * y * x^(y - 1)
@test frule((ZeroTangent(), Δx, Δy), ^, zero(x), y)[2] ≈ 0
_, ∂x, ∂y = rrule(^, x, y)[2](Δz)
@test ∂x ≈ Δz * y * x^(y - 1)
@test ∂y ≈ 0
_, ∂x, ∂y = rrule(^, zero(x), y)[2](Δz)
@test ∂x ≈ 0
@test ∂y ≈ 0
# Check the `^` rules for every real/complex combination of base type T and
# exponent type S.
@testset "^(x::$T, p::$S)" for T in (Float64, ComplexF64), S in (Float64, ComplexF64)
# Base and exponent are both shifted by 3 here.
test_frule(^, rand(T) + 3, rand(S) + 3)
test_rrule(^, rand(T) + 3, rand(S) + 3)

# When both x & p are Real, and !(isinteger(p)),
# then x must be positive to avoid a DomainError
T <: Real && S <: Real && continue
# In other cases, we can test values near zero:

# NOTE(review): the forward test draws its base with `randn`, the reverse test with
# `rand` -- confirm that this asymmetry is intended.
test_frule(^, randn(T), rand(S))
test_rrule(^, rand(T), rand(S))
end

# Tests for power functions, at values near to zero.
POWERGRADS = [ # (x,p) => (dx,dp)
# Some regular points, as sanity checks:
(1.0, 2) => (2.0, 0.0),
(2.0, 2) => (4.0, 2.772588722239781),
# At x=0, gradients for x seem clear,
# for p less certain what's best.
(0.0, 2) => (0.0, 0.0),
(-0.0, 2) => (0.0, 0.0), # probably (-0.0, 0.0) would be ideal
(0.0, 1) => (1.0, 0.0),
(-0.0, 1) => (1.0, 0.0),
(0.0, 0) => (0.0, NaN),
(-0.0, 0) => (0.0, NaN),
Comment on lines +186 to +187
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought there should be NoTangent() which we could normalize to 0 for performance reason.
Since you can't really perturb the input about that point.
Why is NaN better?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought I persuaded myself these were indeterminate, hence made them like 0/0. But I'm not super-sure, sometimes I also think they are a particular sign of Inf.

In some way NaN is a safe answer, as it guarantees that nobody will be relying on a particular possibly wrong choice.

(0.0, -1) => (-Inf, NaN),
(-0.0, -1) => (-Inf, NaN),
(0.0, -2) => (-Inf, NaN),
(-0.0, -2) => (Inf, NaN),
# Integer x & p, check no InexactErrors
(0, 2) => (0.0, 0.0),
(0, 1) => (1.0, 0.0),
(0, 0) => (0.0, NaN),
(0, -1) => (-Inf, NaN),
(0, -2) => (-Inf, NaN),
# Non-integer powers:
(0.0, 0.5) => (Inf, 0.0),
(0.0, 3.5) => (0.0, 0.0),
(0.0, -1.5) => (-Inf, NaN),
]

# For every `(x, p) => (∂x, ∂p)` entry of POWERGRADS, check that the forward and
# reverse rules for `^` reproduce the primal and the tabulated partial derivatives.
# Comparisons use `isequal`, so NaN entries compare equal to NaN results.
@testset "$x ^ $p" for ((x,p), (∂x, ∂p)) in POWERGRADS
    # Integer base with a negative integer power: only check that the primal throws.
    if x isa Integer && p isa Integer && p < 0
        @test_throws DomainError x^p
        continue
    end
    Ω = x^p

    # Forward: primal value
    Ω_fwd, _ = frule((1,1,1), ^, x, p)
    @test isequal(Ω, Ω_fwd)

    # Forward: one partial at a time, selected by unit / zero input tangents
    _, dx_fwd = frule((0,1,0), ^, x, p)
    _, dp_fwd = frule((0,0,1), ^, x, p)
    @test isequal(∂x, dx_fwd)
    if x===0.0 && p===0.5
        # This single point is marked as broken for the exponent partial
        @test_broken isequal(∂p, dp_fwd)
    else
        @test isequal(∂p, dp_fwd)
    end

    # Forward again, with ZeroTangent() for the exponent: easier, strong zero
    _, dx_strong = frule((0,1,ZeroTangent()), ^, x, p)
    @test isequal(∂x, dx_strong)

    # Reverse: primal value, then both cotangents for a unit seed
    Ω_rev, pullback = rrule(^, x, p)
    @test isequal(Ω, Ω_rev)

    dx_rev, dp_rev = unthunk.(pullback(1))[2:3]
    @test isequal(∂x, dx_rev)
    @test isequal(∂p, dp_rev)
end
end

Expand Down