From f4536a87467d22432a401e260a617d6c6821fe91 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 8 Sep 2021 07:02:31 +0200 Subject: [PATCH 1/3] fix pair getfield adjoint --- src/lib/base.jl | 25 ++++++++++++++----------- test/features.jl | 29 +++++++++++++++++------------ 2 files changed, 31 insertions(+), 23 deletions(-) diff --git a/src/lib/base.jl b/src/lib/base.jl index 67f8b2c5e..d9e748f9c 100644 --- a/src/lib/base.jl +++ b/src/lib/base.jl @@ -118,23 +118,26 @@ end # named tuple @adjoint function pairs(t::NamedTuple{N}) where N - pairs_namedtuple(dx::NamedTuple) = (dx.data,) - function pairs_namedtuple(Δ::Dict) - t0 = map(zero, t) - for (idx, v) in Δ - t0 = NamedTuple{N}(Base.setindex((t0...,), v, idx)) - end - return (t0,) + + pairs_namedtuple_pullback(dx::NamedTuple) = (dx.data,) + + function pairs_namedtuple_pullback(Δ::Dict) + t0 = map(zero, t) + for (idx, v) in Δ + t0 = NamedTuple{N}(Base.setindex((t0...,), v, idx)) end - return pairs(t), pairs_namedtuple + return (t0,) + end + + return pairs(t), pairs_namedtuple_pullback end @adjoint function Base.getfield(p::Pair, i::Int) - function pair_getfield(Δ) - f, s = i == 1 ? (Δ, zero(p[2])) : (zero(p[1]), Δ) + function pair_getfield_pullback(Δ) + f, s = i == 1 ? (Δ, nothing) : (nothing, Δ) return (first=f, second=s), nothing end - return getfield(p, i), pair_getfield + return getfield(p, i), pair_getfield_pullback end @adjoint Base.nameof(x::UnionAll) = nameof(x), _ -> (nothing,) diff --git a/test/features.jl b/test/features.jl index b17f55b41..7f9a1f70c 100644 --- a/test/features.jl +++ b/test/features.jl @@ -424,17 +424,17 @@ end end mutable struct MyMutable - value::Float64 + value::Float64 end function foo!(m::MyMutable, x) - m.value = x + m.value = x end function baz(args) - m = MyMutable(0.) - foo!(m, args...) - m.value + m = MyMutable(0.) + foo!(m, args...) + m.value end let @@ -449,13 +449,18 @@ end @test pullback(type_test)[1] == Complex{<:Real} @testset "Pairs" begin - @test (x->10*pairs((a=x, b=2))[1])'(100) === 10.0 - @test (x->10*pairs((a=x, b=2))[2])'(100) === 0 - foo(;kw...) = 1 - @test gradient(() -> foo(a=1,b=2.0)) === () - - @test (x->10*(x => 2)[1])'(100) === 10.0 - @test (x->10*(x => 2)[2])'(100) === 0 + @test (x->10*pairs((a=x, b=2))[1])'(100) === 10.0 + @test (x->10*pairs((a=x, b=2))[2])'(100) === 0 + foo(;kw...) = 1 + @test gradient(() -> foo(a=1,b=2.0)) === () + + @test (x->10*(x => 2)[1])'(100) === 10.0 + @test (x->10*(x => 2)[2])'(100) === 0 + + @test gradient(x-> (:x => x)[2], 17) == (1,) + + d = Dict(:x=>1.0, :y=>3.0); + @test gradient(d -> Dict(:x => d[:x])[:x], d) == (Dict(:x => 1),) end # https://github.com/JuliaDiff/ChainRules.jl/issues/257 From b7ee5381822a7de3265223baaf8f688cda1ab2a1 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 8 Sep 2021 07:58:52 +0200 Subject: [PATCH 2/3] fix test --- src/lib/array.jl | 1 + test/features.jl | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 15b994564..c63f0f74a 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -246,6 +246,7 @@ end @nograd workers function _pullback(cx::AContext, ::typeof(collect), g::Base.Generator) + @show g.f g.iter y, b = ∇map(cx, g.f, g.iter) back(::Nothing) = nothing function back(ȳ) diff --git a/test/features.jl b/test/features.jl index 7f9a1f70c..f3931464d 100644 --- a/test/features.jl +++ b/test/features.jl @@ -455,7 +455,7 @@ end @test gradient(() -> foo(a=1,b=2.0)) === () @test (x->10*(x => 2)[1])'(100) === 10.0 - @test (x->10*(x => 2)[2])'(100) === 0 + @test (x->10*(x => 2)[2])'(100) === nothing @test gradient(x-> (:x => x)[2], 17) == (1,) From d9227ba07bdc36d74f804fbdd04aa251755cb0e5 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 8 Sep 2021 10:31:59 +0200 Subject: [PATCH 3/3] cleanup --- src/lib/array.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index c63f0f74a..15b994564 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -246,7 +246,6 @@ end @nograd workers function _pullback(cx::AContext, ::typeof(collect), g::Base.Generator) - @show g.f g.iter y, b = ∇map(cx, g.f, g.iter) back(::Nothing) = nothing function back(ȳ)