-
Notifications
You must be signed in to change notification settings - Fork 31
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement scalar rules for Zygote with ChainRules #103
Conversation
BTW the use of thunks also addresses some of @marcoct's comments in #65 (comment), namely
|
One suggestion for rules defined using the ChainRulesCore methods is to test them using the corresponding testers in ChainRulesTestUtils, which uses finite differences and also tests conventions as they are established. |
Co-authored-by: Seth Axen <[email protected]>
Yes, I am aware of this, I was just lazy 😄 Currently we evaluate the gradients of |
I added tests but I couldn't manage to test the chainrules: Error During Test at /home/david/.julia/dev/DistributionsAD/test/ad/chainrules.jl:2
Got exception outside of a @test
ArgumentError: eps must be positive, got NaN
Stacktrace:
[1] fdm(::FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}}, ::FiniteDifferences.var"#59#61"{Int64,Base.var"#64#65"{Base.var"#64#65"{Base.var"#64#65"{typeof(first),typeof(to_vec)},FiniteDifferences.var"#76#77"{ChainRulesTestUtils.var"#fnew#17"{ChainRulesTestUtils.var"#25#26"{NamedTuple{(),Tuple{}},typeof(DistributionsAD.uniformlogpdf)},Tuple{Float64,Float64,Float64},Tuple{Bool,Bool,Bool}}}},FiniteDifferences.var"#Tuple_from_vec#51"{Tuple{Float64,Float64,Float64},Tuple{Array{Float64,1},Array{Float64,1},Array{Float64,1}},Tuple{FiniteDifferences.var"#Real_from_vec#39",FiniteDifferences.var"#Real_from_vec#39",FiniteDifferences.var"#Real_from_vec#39"},Array{Int64,1}}},Array{Float64,1}}, ::Float64, ::Val{true}; condition::Int64, bound::Float64, eps::Float64, adapt::Int64, max_step::Float64) at /home/david/.julia/packages/FiniteDifferences/sCBeL/src/methods.jl:234
[2] fdm(::FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}}, ::FiniteDifferences.var"#59#61"{Int64,Base.var"#64#65"{Base.var"#64#65"{Base.var"#64#65"{typeof(first),typeof(to_vec)},FiniteDifferences.var"#76#77"{ChainRulesTestUtils.var"#fnew#17"{ChainRulesTestUtils.var"#25#26"{NamedTuple{(),Tuple{}},typeof(DistributionsAD.uniformlogpdf)},Tuple{Float64,Float64,Float64},Tuple{Bool,Bool,Bool}}}},FiniteDifferences.var"#Tuple_from_vec#51"{Tuple{Float64,Float64,Float64},Tuple{Array{Float64,1},Array{Float64,1},Array{Float64,1}},Tuple{FiniteDifferences.var"#Real_from_vec#39",FiniteDifferences.var"#Real_from_vec#39",FiniteDifferences.var"#Real_from_vec#39"},Array{Int64,1}}},Array{Float64,1}}, ::Float64, ::Val{true}) at /home/david/.julia/packages/FiniteDifferences/sCBeL/src/methods.jl:231
[3] #fdm#34 at /home/david/.julia/packages/FiniteDifferences/sCBeL/src/methods.jl:289 [inlined]
[4] fdm at /home/david/.julia/packages/FiniteDifferences/sCBeL/src/methods.jl:289 [inlined] (repeats 2 times)
[5] #_#15 at /home/david/.julia/packages/FiniteDifferences/sCBeL/src/methods.jl:95 [inlined]
[6] Central at /home/david/.julia/packages/FiniteDifferences/sCBeL/src/methods.jl:95 [inlined]
[7] #58 at /home/david/.julia/packages/FiniteDifferences/sCBeL/src/grad.jl:16 [inlined]
[8] iterate at ./generator.jl:47 [inlined]
[9] _collect(::Base.OneTo{Int64}, ::Base.Generator{Base.OneTo{Int64},FiniteDifferences.var"#58#60"{FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}},Base.var"#64#65"{Base.var"#64#65"{Base.var"#64#65"{typeof(first),typeof(to_vec)},FiniteDifferences.var"#76#77"{ChainRulesTestUtils.var"#fnew#17"{ChainRulesTestUtils.var"#25#26"{NamedTuple{(),Tuple{}},typeof(DistributionsAD.uniformlogpdf)},Tuple{Float64,Float64,Float64},Tuple{Bool,Bool,Bool}}}},FiniteDifferences.var"#Tuple_from_vec#51"{Tuple{Float64,Float64,Float64},Tuple{Array{Float64,1},Array{Float64,1},Array{Float64,1}},Tuple{FiniteDifferences.var"#Real_from_vec#39",FiniteDifferences.var"#Real_from_vec#39",FiniteDifferences.var"#Real_from_vec#39"},Array{Int64,1}}},Array{Float64,1}}}, ::Base.EltypeUnknown, ::Base.HasShape{1}) at ./array.jl:678
[10] collect_similar(::Base.OneTo{Int64}, ::Base.Generator{Base.OneTo{Int64},FiniteDifferences.var"#58#60"{FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}},Base.var"#64#65"{Base.var"#64#65"{Base.var"#64#65"{typeof(first),typeof(to_vec)},FiniteDifferences.var"#76#77"{ChainRulesTestUtils.var"#fnew#17"{ChainRulesTestUtils.var"#25#26"{NamedTuple{(),Tuple{}},typeof(DistributionsAD.uniformlogpdf)},Tuple{Float64,Float64,Float64},Tuple{Bool,Bool,Bool}}}},FiniteDifferences.var"#Tuple_from_vec#51"{Tuple{Float64,Float64,Float64},Tuple{Array{Float64,1},Array{Float64,1},Array{Float64,1}},Tuple{FiniteDifferences.var"#Real_from_vec#39",FiniteDifferences.var"#Real_from_vec#39",FiniteDifferences.var"#Real_from_vec#39"},Array{Int64,1}}},Array{Float64,1}}}) at ./array.jl:607
[11] map(::Function, ::Base.OneTo{Int64}) at ./abstractarray.jl:2072
[12] #jacobian#57 at /home/david/.julia/packages/FiniteDifferences/sCBeL/src/grad.jl:15 [inlined]
[13] jacobian at /home/david/.julia/packages/FiniteDifferences/sCBeL/src/grad.jl:10 [inlined]
[14] _j′vp(::FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}}, ::Function, ::Array{Float64,1}, ::Array{Float64,1}) at /home/david/.julia/packages/FiniteDifferences/sCBeL/src/grad.jl:80
[15] j′vp(::FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}}, ::Function, ::Float64, ::Tuple{Float64,Float64,Float64}) at /home/david/.julia/packages/FiniteDifferences/sCBeL/src/grad.jl:73
[16] j′vp(::FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}}, ::Function, ::Float64, ::Float64, ::Float64, ::Vararg{Float64,N} where N) at /home/david/.julia/packages/FiniteDifferences/sCBeL/src/grad.jl:76
[17] _make_j′vp_call(::FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}}, ::Function, ::Float64, ::Tuple{Float64,Float64,Float64}, ::Tuple{Bool,Bool,Bool}) at /home/david/.julia/packages/ChainRulesTestUtils/9hZgi/src/testers.jl:69
[18] rrule_test(::Function, ::Float64, ::Tuple{Float64,Float64}, ::Vararg{Tuple{Float64,Float64},N} where N; rtol::Float64, atol::Float64, fdm::FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}}, fkwargs::NamedTuple{(),Tuple{}}, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/david/.julia/packages/ChainRulesTestUtils/9hZgi/src/testers.jl:239
[19] rrule_test(::Function, ::Float64, ::Tuple{Float64,Float64}, ::Vararg{Tuple{Float64,Float64},N} where N) at /home/david/.julia/packages/ChainRulesTestUtils/9hZgi/src/testers.jl:220
[20] top-level scope at /home/david/.julia/dev/DistributionsAD/test/ad/chainrules.jl:15
[21] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.4/Test/src/Test.jl:1113
[22] top-level scope at /home/david/.julia/dev/DistributionsAD/test/ad/chainrules.jl:3
[23] include(::String) at ./client.jl:439
[24] top-level scope at /home/david/.julia/dev/DistributionsAD/test/runtests.jl:59
[25] include(::String) at ./client.jl:439
[26] top-level scope at none:6
[27] eval(::Module, ::Any) at ./boot.jl:331
[28] exec_options(::Base.JLOptions) at ./client.jl:264
[29] _start() at ./client.jl:484 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Sorry for the delay!
Thanks, @devmotion and @mohamed82008! |
I think we should start moving from ZygoteRules to ChainRulesCore wherever it is possible (as suggested by @oxinabox at JuliaCon). We can't drop it completely before JuliaDiff/ChainRulesCore.jl#68 is fixed since sometimes we forward calls to
Zygote.pullback
. In this PR I just moved simple scalar rules over to ChainRules, which was quite straightforward. I noticed a bug (or at least inconsistency withlogpdf(::Uniform, x)
in Distributions) in the implementation ofuniformlogpdf
which I fixed in this PR (only the derivatives should beNaN
in the second branch).