Skip to content

Commit

Permalink
Merge pull request #1063 from FluxML/cl/pair
Browse files Browse the repository at this point in the history
fix pair getfield pullback
  • Loading branch information
CarloLucibello authored Sep 9, 2021
2 parents 05d0c2a + d9227ba commit 96f8e2b
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 23 deletions.
25 changes: 14 additions & 11 deletions src/lib/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,23 +118,26 @@ end

# named tuple
@adjoint function pairs(t::NamedTuple{N}) where N
pairs_namedtuple(dx::NamedTuple) = (dx.data,)
function pairs_namedtuple::Dict)
t0 = map(zero, t)
for (idx, v) in Δ
t0 = NamedTuple{N}(Base.setindex((t0...,), v, idx))
end
return (t0,)

pairs_namedtuple_pullback(dx::NamedTuple) = (dx.data,)

function pairs_namedtuple_pullback::Dict)
t0 = map(zero, t)
for (idx, v) in Δ
t0 = NamedTuple{N}(Base.setindex((t0...,), v, idx))
end
return pairs(t), pairs_namedtuple
return (t0,)
end

return pairs(t), pairs_namedtuple_pullback
end

@adjoint function Base.getfield(p::Pair, i::Int)
function pair_getfield(Δ)
f, s = i == 1 ? (Δ, zero(p[2])) : (zero(p[1]), Δ)
function pair_getfield_pullback(Δ)
f, s = i == 1 ? (Δ, nothing) : (nothing, Δ)
return (first=f, second=s), nothing
end
return getfield(p, i), pair_getfield
return getfield(p, i), pair_getfield_pullback
end

@adjoint Base.nameof(x::UnionAll) = nameof(x), _ -> (nothing,)
Expand Down
29 changes: 17 additions & 12 deletions test/features.jl
Original file line number Diff line number Diff line change
Expand Up @@ -424,17 +424,17 @@ end
end

mutable struct MyMutable
value::Float64
value::Float64
end

function foo!(m::MyMutable, x)
m.value = x
m.value = x
end

function baz(args)
m = MyMutable(0.)
foo!(m, args...)
m.value
m = MyMutable(0.)
foo!(m, args...)
m.value
end

let
Expand All @@ -449,13 +449,18 @@ end
@test pullback(type_test)[1] == Complex{<:Real}

@testset "Pairs" begin
@test (x->10*pairs((a=x, b=2))[1])'(100) === 10.0
@test (x->10*pairs((a=x, b=2))[2])'(100) === 0
foo(;kw...) = 1
@test gradient(() -> foo(a=1,b=2.0)) === ()

@test (x->10*(x => 2)[1])'(100) === 10.0
@test (x->10*(x => 2)[2])'(100) === 0
@test (x->10*pairs((a=x, b=2))[1])'(100) === 10.0
@test (x->10*pairs((a=x, b=2))[2])'(100) === 0
foo(;kw...) = 1
@test gradient(() -> foo(a=1,b=2.0)) === ()

@test (x->10*(x => 2)[1])'(100) === 10.0
@test (x->10*(x => 2)[2])'(100) === nothing

@test gradient(x-> (:x => x)[2], 17) == (1,)

d = Dict(:x=>1.0, :y=>3.0);
@test gradient(d -> Dict(:x => d[:x])[:x], d) == (Dict(:x => 1),)
end

# https://github.com/JuliaDiff/ChainRules.jl/issues/257
Expand Down

0 comments on commit 96f8e2b

Please sign in to comment.