Implement scalar rules for Zygote with ChainRules #103

Merged
merged 10 commits into from
Aug 23, 2020

devmotion
Member

I think we should start moving from ZygoteRules to ChainRulesCore wherever possible (as suggested by @oxinabox at JuliaCon). We can't drop ZygoteRules completely before JuliaDiff/ChainRulesCore.jl#68 is fixed, since we sometimes forward calls to Zygote.pullback. In this PR I only moved the simple scalar rules over to ChainRules, which was quite straightforward. I also noticed a bug (or at least an inconsistency with logpdf(::Uniform, x) in Distributions) in the implementation of uniformlogpdf, which I fixed in this PR: in the second branch, only the derivatives should be NaN, not the function value.
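To illustrate the uniformlogpdf fix described above, here is a minimal sketch in plain Julia. It is not the code from src/chainrules.jl: the helper name `uniformlogpdf_partials` is hypothetical, and the sketch assumes that the "second branch" is the case where x lies outside [a, b]. The function value follows the convention of logpdf(::Uniform, x) in Distributions (-Inf outside the support), and only the partial derivatives are NaN there.

```julia
# Hypothetical sketch, not the implementation merged in this PR.
function uniformlogpdf(a::Real, b::Real, x::Real)
    c = -log(b - a)
    # Outside the support the density is zero, so the log density is -Inf.
    return a <= x <= b ? c : oftype(c, -Inf)
end

# Hand-written partial derivatives with respect to (a, b, x).
# `uniformlogpdf_partials` is an illustrative name, not a DistributionsAD function.
function uniformlogpdf_partials(a::Real, b::Real, x::Real)
    d = inv(b - a)
    if a <= x <= b
        # d/da = 1/(b - a), d/db = -1/(b - a), d/dx = 0
        return (d, -d, zero(d))
    else
        # Second branch: only the derivatives are NaN.
        n = oftype(d, NaN)
        return (n, n, n)
    end
end
```

In a ChainRulesCore-based rule these partials would be returned from the rule definition rather than from a separate helper.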

@devmotion
Member Author

BTW, the use of thunks also addresses some of @marcoct's comments in #65 (comment), namely:

One potential issue (which is probably a separate discussion, because I see it shows up in the current DistributionsAD code as well) is that the user may want gradients with respect to some subset of the parameters, and always computing the gradients for all parameters may be wasteful, and only computing gradients with respect to individual parameters would also be wasteful.
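The idea behind thunks can be sketched with plain closures. The `LazyValue` type, `force`, `lazy_partials`, and the `EVALUATED` log below are all hypothetical names for illustration; ChainRulesCore provides its own thunk type via the `@thunk` macro, which this sketch only imitates. Each partial derivative is wrapped in a closure, so a caller who needs only one gradient never pays for the others.

```julia
# A minimal stand-in for a thunk: wraps a zero-argument closure.
# `LazyValue` and `force` are illustrative names, not ChainRulesCore API.
struct LazyValue{F}
    f::F
end
force(t::LazyValue) = t.f()

# Records which partials were actually computed (for demonstration only).
const EVALUATED = Symbol[]

# Deferred partials of -log(b - a) with respect to a and b.
function lazy_partials(a::Real, b::Real)
    da = LazyValue(() -> (push!(EVALUATED, :a); inv(b - a)))
    db = LazyValue(() -> (push!(EVALUATED, :b); -inv(b - a)))
    return (da, db)
end

da, db = lazy_partials(0.0, 2.0)
grad_a = force(da)  # only the partial with respect to `a` is computed
```

After this runs, `EVALUATED` contains only `:a`; the closure for `b` was created but never called.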

@sethaxen
Member

sethaxen commented Aug 1, 2020

One suggestion for rules defined using the ChainRulesCore methods is to test them using the corresponding testers in ChainRulesTestUtils, which use finite differences and also test conventions as they are established.

@devmotion
Member Author

Yes, I am aware of this, I was just lazy 😄

Currently we evaluate the gradients of logpdf(::Distribution, x) for a fixed set of distributions, distribution parameters, and samples using Zygote, ForwardDiff, ReverseDiff, and Tracker, and compare them with finite differencing using FiniteDifferences. IMO it would be an improvement to additionally test the gradient implementations more directly whenever possible.
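As a rough illustration of what such a direct check does, the sketch below compares hand-written partials against a simple central difference in plain Julia. It does not use FiniteDifferences or ChainRulesTestUtils; `central_gradient`, `toy_logpdf`, and `toy_partials` are hypothetical names, and the fixed step size is an assumption made for simplicity.

```julia
# Central difference approximation of the gradient of f at the point xs.
# A fixed step h is used here; this is a simplification for illustration.
function central_gradient(f, xs::Vector{Float64}; h::Float64 = 1e-6)
    map(eachindex(xs)) do i
        up = copy(xs); up[i] += h
        down = copy(xs); down[i] -= h
        (f(up...) - f(down...)) / (2h)
    end
end

# Toy log density of a uniform distribution on [a, b] and its partials
# inside the support.
toy_logpdf(a, b, x) = a <= x <= b ? -log(b - a) : -Inf
toy_partials(a, b, x) = [inv(b - a), -inv(b - a), 0.0]

point = [0.0, 2.0, 1.0]  # (a, b, x) with x inside the support
approx = central_gradient(toy_logpdf, point)
exact = toy_partials(point...)
```

Inside the support the two results agree to within the discretization error. Outside the support both evaluations of `toy_logpdf` return -Inf, so the difference is NaN and this kind of comparison carries no information there.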

@devmotion
Member Author

I added tests, but I couldn't manage to test the NaN branch properly. Is this a general limitation of ChainRulesTestUtils right now? It failed with the error message:

chainrules: Error During Test at /home/david/.julia/dev/DistributionsAD/test/ad/chainrules.jl:2
  Got exception outside of a @test
  ArgumentError: eps must be positive, got NaN
  Stacktrace:
   [1] fdm(::FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}}, ::FiniteDifferences.var"#59#61"{Int64,Base.var"#64#65"{Base.var"#64#65"{Base.var"#64#65"{typeof(first),typeof(to_vec)},FiniteDifferences.var"#76#77"{ChainRulesTestUtils.var"#fnew#17"{ChainRulesTestUtils.var"#25#26"{NamedTuple{(),Tuple{}},typeof(DistributionsAD.uniformlogpdf)},Tuple{Float64,Float64,Float64},Tuple{Bool,Bool,Bool}}}},FiniteDifferences.var"#Tuple_from_vec#51"{Tuple{Float64,Float64,Float64},Tuple{Array{Float64,1},Array{Float64,1},Array{Float64,1}},Tuple{FiniteDifferences.var"#Real_from_vec#39",FiniteDifferences.var"#Real_from_vec#39",FiniteDifferences.var"#Real_from_vec#39"},Array{Int64,1}}},Array{Float64,1}}, ::Float64, ::Val{true}; condition::Int64, bound::Float64, eps::Float64, adapt::Int64, max_step::Float64) at /home/david/.julia/packages/FiniteDifferences/sCBeL/src/methods.jl:234
   [2] fdm(::FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}}, ::FiniteDifferences.var"#59#61"{Int64,Base.var"#64#65"{Base.var"#64#65"{Base.var"#64#65"{typeof(first),typeof(to_vec)},FiniteDifferences.var"#76#77"{ChainRulesTestUtils.var"#fnew#17"{ChainRulesTestUtils.var"#25#26"{NamedTuple{(),Tuple{}},typeof(DistributionsAD.uniformlogpdf)},Tuple{Float64,Float64,Float64},Tuple{Bool,Bool,Bool}}}},FiniteDifferences.var"#Tuple_from_vec#51"{Tuple{Float64,Float64,Float64},Tuple{Array{Float64,1},Array{Float64,1},Array{Float64,1}},Tuple{FiniteDifferences.var"#Real_from_vec#39",FiniteDifferences.var"#Real_from_vec#39",FiniteDifferences.var"#Real_from_vec#39"},Array{Int64,1}}},Array{Float64,1}}, ::Float64, ::Val{true}) at /home/david/.julia/packages/FiniteDifferences/sCBeL/src/methods.jl:231
   [3] #fdm#34 at /home/david/.julia/packages/FiniteDifferences/sCBeL/src/methods.jl:289 [inlined]
   [4] fdm at /home/david/.julia/packages/FiniteDifferences/sCBeL/src/methods.jl:289 [inlined] (repeats 2 times)
   [5] #_#15 at /home/david/.julia/packages/FiniteDifferences/sCBeL/src/methods.jl:95 [inlined]
   [6] Central at /home/david/.julia/packages/FiniteDifferences/sCBeL/src/methods.jl:95 [inlined]
   [7] #58 at /home/david/.julia/packages/FiniteDifferences/sCBeL/src/grad.jl:16 [inlined]
   [8] iterate at ./generator.jl:47 [inlined]
   [9] _collect(::Base.OneTo{Int64}, ::Base.Generator{Base.OneTo{Int64},FiniteDifferences.var"#58#60"{FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}},Base.var"#64#65"{Base.var"#64#65"{Base.var"#64#65"{typeof(first),typeof(to_vec)},FiniteDifferences.var"#76#77"{ChainRulesTestUtils.var"#fnew#17"{ChainRulesTestUtils.var"#25#26"{NamedTuple{(),Tuple{}},typeof(DistributionsAD.uniformlogpdf)},Tuple{Float64,Float64,Float64},Tuple{Bool,Bool,Bool}}}},FiniteDifferences.var"#Tuple_from_vec#51"{Tuple{Float64,Float64,Float64},Tuple{Array{Float64,1},Array{Float64,1},Array{Float64,1}},Tuple{FiniteDifferences.var"#Real_from_vec#39",FiniteDifferences.var"#Real_from_vec#39",FiniteDifferences.var"#Real_from_vec#39"},Array{Int64,1}}},Array{Float64,1}}}, ::Base.EltypeUnknown, ::Base.HasShape{1}) at ./array.jl:678
   [10] collect_similar(::Base.OneTo{Int64}, ::Base.Generator{Base.OneTo{Int64},FiniteDifferences.var"#58#60"{FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}},Base.var"#64#65"{Base.var"#64#65"{Base.var"#64#65"{typeof(first),typeof(to_vec)},FiniteDifferences.var"#76#77"{ChainRulesTestUtils.var"#fnew#17"{ChainRulesTestUtils.var"#25#26"{NamedTuple{(),Tuple{}},typeof(DistributionsAD.uniformlogpdf)},Tuple{Float64,Float64,Float64},Tuple{Bool,Bool,Bool}}}},FiniteDifferences.var"#Tuple_from_vec#51"{Tuple{Float64,Float64,Float64},Tuple{Array{Float64,1},Array{Float64,1},Array{Float64,1}},Tuple{FiniteDifferences.var"#Real_from_vec#39",FiniteDifferences.var"#Real_from_vec#39",FiniteDifferences.var"#Real_from_vec#39"},Array{Int64,1}}},Array{Float64,1}}}) at ./array.jl:607
   [11] map(::Function, ::Base.OneTo{Int64}) at ./abstractarray.jl:2072
   [12] #jacobian#57 at /home/david/.julia/packages/FiniteDifferences/sCBeL/src/grad.jl:15 [inlined]
   [13] jacobian at /home/david/.julia/packages/FiniteDifferences/sCBeL/src/grad.jl:10 [inlined]
   [14] _j′vp(::FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}}, ::Function, ::Array{Float64,1}, ::Array{Float64,1}) at /home/david/.julia/packages/FiniteDifferences/sCBeL/src/grad.jl:80
   [15] j′vp(::FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}}, ::Function, ::Float64, ::Tuple{Float64,Float64,Float64}) at /home/david/.julia/packages/FiniteDifferences/sCBeL/src/grad.jl:73
   [16] j′vp(::FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}}, ::Function, ::Float64, ::Float64, ::Float64, ::Vararg{Float64,N} where N) at /home/david/.julia/packages/FiniteDifferences/sCBeL/src/grad.jl:76
   [17] _make_j′vp_call(::FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}}, ::Function, ::Float64, ::Tuple{Float64,Float64,Float64}, ::Tuple{Bool,Bool,Bool}) at /home/david/.julia/packages/ChainRulesTestUtils/9hZgi/src/testers.jl:69
   [18] rrule_test(::Function, ::Float64, ::Tuple{Float64,Float64}, ::Vararg{Tuple{Float64,Float64},N} where N; rtol::Float64, atol::Float64, fdm::FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}}, fkwargs::NamedTuple{(),Tuple{}}, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/david/.julia/packages/ChainRulesTestUtils/9hZgi/src/testers.jl:239
   [19] rrule_test(::Function, ::Float64, ::Tuple{Float64,Float64}, ::Vararg{Tuple{Float64,Float64},N} where N) at /home/david/.julia/packages/ChainRulesTestUtils/9hZgi/src/testers.jl:220
   [20] top-level scope at /home/david/.julia/dev/DistributionsAD/test/ad/chainrules.jl:15
   [21] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.4/Test/src/Test.jl:1113
   [22] top-level scope at /home/david/.julia/dev/DistributionsAD/test/ad/chainrules.jl:3
   [23] include(::String) at ./client.jl:439
   [24] top-level scope at /home/david/.julia/dev/DistributionsAD/test/runtests.jl:59
   [25] include(::String) at ./client.jl:439
   [26] top-level scope at none:6
   [27] eval(::Module, ::Any) at ./boot.jl:331
   [28] exec_options(::Base.JLOptions) at ./client.jl:264
   [29] _start() at ./client.jl:484

@devmotion devmotion requested review from mohamed82008 and removed request for mohamed82008 August 17, 2020 14:17
Member

@mohamed82008 mohamed82008 left a comment

LGTM. Sorry for the delay!

@yebai yebai merged commit 01ad761 into master Aug 23, 2020
@yebai yebai deleted the chainrules branch August 23, 2020 10:06
@yebai
Member

yebai commented Aug 23, 2020

Thanks, @devmotion and @mohamed82008!

4 participants