Skip to content

Commit

Permalink
Merge pull request #1205 from devmotion/dw/notimplemented
Browse files Browse the repository at this point in the history
Handle `ChainRulesCore.NotImplemented`
  • Loading branch information
CarloLucibello authored Apr 13, 2022
2 parents 1eb80c5 + b15eff1 commit 9602c6b
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.37"
version = "0.6.38"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
1 change: 1 addition & 0 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ Convert `x` from the differentials types ChainRules uses to the format Zygote us
# Zygote convention: even if many AbstractZero partials (i.e. multi-input function), make just 1 nothing.
@inline wrap_chainrules_output(x::Tuple{Vararg{ChainRules.AbstractZero}}) = nothing
@inline wrap_chainrules_output(x::ChainRules.AbstractZero) = nothing
@inline wrap_chainrules_output(x::ChainRulesCore.NotImplemented) = nothing
for T_outer in (:Tuple, :NamedTuple)
# we create separate methods rather than using a `Union` + an `if` so that we avoid a
# branch that changes output type, because nested AD on that kinda thing makes Zygote less
Expand Down
12 changes: 12 additions & 0 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,18 @@ using Zygote: ZygoteRuleConfig
@test (1.0,) == Zygote.gradient(oout_id_outer, π)
@test oout_id_rrule_hitcount[] == 0
end

# issue #1204
@testset "NotImplemented" begin
f_notimplemented(x) = x
@scalar_rule f_notimplemented(x) @not_implemented("not implemented :(")
@test Zygote.gradient(f_notimplemented, 0.1) === (nothing,)
@test Zygote.gradient(x -> f_notimplemented(x[1]), 0.1) === (nothing,)
if isdefined(Base, :only)
@test Zygote.gradient(x -> f_notimplemented(only(x)), (0.1,)) === (nothing,)
@test Zygote.gradient(x -> f_notimplemented(only(x)), [0.1]) === (nothing,)
end
end
end

@testset "ChainRulesCore.rrule_via_ad" begin
Expand Down

2 comments on commit 9602c6b

@CarloLucibello
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/58471

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.38 -m "<description of version>" 9602c6b2038879034c2de14d1f4aa251d99c6ea4
git push origin v0.6.38

Please sign in to comment.