Skip to content

Commit

Permalink
fix some issues with rules for ^
Browse files Browse the repository at this point in the history
  • Loading branch information
simeonschaub authored and mcabbott committed Sep 3, 2021
1 parent a130b8f commit 415c8fb
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 8 deletions.
11 changes: 9 additions & 2 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,18 @@ end
# note: rules for ^ are defined in the fastmath_able.jl
function frule((_, _, Δx, _), ::typeof(Base.literal_pow), ::typeof(^), x::Real, pv::Val{p}) where p
y = Base.literal_pow(^, x, pv)
return y, (p * y / x * Δx)
return y, ifelse(iszero(x), zero(y), p * y / x * Δx)
end
frule((_, _, Δx, _), ::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{1}) = x^1, Δx

function rrule(::typeof(Base.literal_pow), ::typeof(^), x::Real, pv::Val{p}) where p
y = Base.literal_pow(^, x, pv)
literal_pow_pullback(dy) = NoTangent(), NoTangent(), (p * y / x * dy), NoTangent()
function literal_pow_pullback(dy)
return NoTangent(), NoTangent(), ifelse(iszero(x), zero(y), p * y / x * dy), NoTangent()
end
return y, literal_pow_pullback
end
function rrule(::typeof(Base.literal_pow), ::typeof(^), x::Real, pv::Val{1})
literal_pow_one_pullback(dy) = NoTangent(), NoTangent(), dy, NoTangent()
return x^1, literal_pow_one_pullback
end
14 changes: 8 additions & 6 deletions src/rulesets/Base/fastmath_able.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,14 +164,16 @@ let
@scalar_rule x - y (true, -1)
@scalar_rule x / y (one(x) / y, -/ y))
#log(complex(x)) is required so it gives correct complex answer for x<0
@scalar_rule(x ^ y,
(ifelse(iszero(x), zero(Ω), y * Ω / x), Ω * log(complex(x))),
)
@scalar_rule(x ^ y, (
ifelse(iszero(x), ifelse(isone(y), one(Ω), zero(Ω)), y * Ω / x),
Ω * log(complex(x)),
))
# x^y for x < 0 errors when y is not an integer, but then derivative wrt y
# is undefined, so we adopt subgradient convention and set derivative to 0.
@scalar_rule(x::Real ^ y::Real,
(ifelse(iszero(x), zero(Ω), y * Ω / x), Ω * log(oftype(Ω, ifelse(x 0, one(x), x)))),
)
@scalar_rule(x::Real ^ y::Real, (
ifelse(iszero(x), ifelse(isone(y), one(Ω), zero(Ω)), y * Ω / x),
Ω * log(oftype(Ω, ifelse(x 0, one(x), x))),
))
@scalar_rule(
rem(x, y),
@setup((u, nan) = promote(x / y, NaN16), isint = isinteger(x / y)),
Expand Down
9 changes: 9 additions & 0 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,15 @@
# for real x and n, x must be >0
test_frule(Base.literal_pow, ^, 3.5, Val(3))
test_rrule(Base.literal_pow, ^, 3.5, Val(3))

test_frule(Base.literal_pow, ^, 0.0, Val(3))
test_rrule(Base.literal_pow, ^, 0.0, Val(3))

test_frule(Base.literal_pow, ^, 3.5, Val(1))
test_rrule(Base.literal_pow, ^, 3.5, Val(1))

test_frule(Base.literal_pow, ^, 0.0, Val(1))
test_rrule(Base.literal_pow, ^, 0.0, Val(1))
end

@testset "Float conversions" begin
Expand Down
9 changes: 9 additions & 0 deletions test/rulesets/Base/fastmath_able.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,15 @@ const FASTABLE_AST = quote
@test ∂y 0
end
end

@testset "edge cases with ^" begin
# FIXME
@test_skip test_frule(^, 0.0, rand() + 3 NoTangent(); fdm=forward_fdm(5,1))
test_rrule(^, 0.0, rand() + 3; fdm=forward_fdm(5,1))

test_frule(^, 0.0, 1.0 NoTangent(); fdm=forward_fdm(5,1))
test_rrule(^, 0.0, 1.0; fdm=forward_fdm(5,1))
end
end

@testset "sign" begin
Expand Down

0 comments on commit 415c8fb

Please sign in to comment.