Skip to content

Commit

Permalink
Merge pull request #974 from mzgubic/mz/rename
Browse files Browse the repository at this point in the history
rename ChainRules differential types
  • Loading branch information
CarloLucibello authored May 21, 2021
2 parents d1c89e0 + 892022c commit b89ea3c
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 7 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
*.jl.mem
docs/build
Manifest.toml
dev/
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.10"
version = "0.6.11"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand All @@ -24,7 +24,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractFFTs = "0.5, 1.0"
ChainRules = "0.7.55"
ChainRulesCore = "0.9.32"
ChainRulesCore = "0.9.44"
DiffRules = "1.0"
FillArrays = "0.8, 0.9, 0.10, 0.11"
ForwardDiff = "0.10"
Expand Down
6 changes: 3 additions & 3 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ for T_outer in (:Tuple, :NamedTuple)
# we create separate methods rather than using a `Union` + an `if` so that we avoid a
# branch that changes output type, because nested AD on that kinda thing makes Zygote less
# than happy.
@eval @inline function wrap_chainrules_output(x::ChainRules.Composite{P, T}) where {P, T<:$T_outer}
@eval @inline function wrap_chainrules_output(x::ChainRules.Tangent{P, T}) where {P, T<:$T_outer}
xp = map(wrap_chainrules_output, canonicalize(x))
convert($T_outer, xp)
end
Expand All @@ -59,10 +59,10 @@ end
Convert `x` from the format Zygote uses internally to differentials types ChainRules uses.
"""
@inline wrap_chainrules_input(x) = x
@inline wrap_chainrules_input(::Nothing) = ChainRules.Zero()
@inline wrap_chainrules_input(::Nothing) = ChainRules.ZeroTangent()
@inline function wrap_chainrules_input(xs::Union{Tuple, NamedTuple})
xp = map(wrap_chainrules_input, xs)
ChainRules.Composite{Any, typeof(xp)}(xp)
ChainRules.Tangent{Any, typeof(xp)}(xp)
end

"""
Expand Down
4 changes: 2 additions & 2 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ using Zygote, Test, ChainRules
not_diff_eg(x, i) = [10, 20][i]
function ChainRules.rrule(::typeof(not_diff_eg), x, i)
function not_diff_eg_pullback(Δ)
return ChainRules.NO_FIELDS, ChainRules.Zero(), ChainRules.DoesNotExist()
return ChainRules.NO_FIELDS, ChainRules.ZeroTangent(), ChainRules.NoTangent()
end
return not_diff_eg(x, i), not_diff_eg_pullback
end
Expand Down Expand Up @@ -204,7 +204,7 @@ using Zygote, Test, ChainRules
not_diff_kw_eg(x, i; kw=1.0) = [10, 20][i]
function ChainRules.rrule(::typeof(not_diff_kw_eg), x, i; kwargs...)
function not_diff_kw_eg_pullback(Δ)
return ChainRules.NO_FIELDS, ChainRules.Zero(), ChainRules.DoesNotExist()
return ChainRules.NO_FIELDS, ChainRules.ZeroTangent(), ChainRules.NoTangent()
end
return not_diff_kw_eg(x, i; kwargs...), not_diff_kw_eg_pullback
end
Expand Down

2 comments on commit b89ea3c

@CarloLucibello
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/37178

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.11 -m "<description of version>" b89ea3caea0c7abfe903852591a8e7ca26e79552
git push origin v0.6.11

Please sign in to comment.