Skip to content

Commit

Permalink
Merge pull request #1489 from lxvm/product
Browse files Browse the repository at this point in the history
Improve adjoint for product and zip
  • Loading branch information
ToucheSir authored Jan 19, 2024
2 parents 54f1e80 + 18e48db commit 46477ee
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 25 deletions.
73 changes: 48 additions & 25 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,9 @@ end
struct StaticGetter{i} end
(::StaticGetter{i})(v) where {i} = v[i]
(::StaticGetter{i})(::Nothing) where {i} = nothing
@generated function _unzip(tuples, ::Val{N}) where {N}
Expr(:tuple, (:(map($(StaticGetter{i}()), tuples)) for i 1:N)...)
function _unzip(tuples, ::Val{N}) where {N}
getters = ntuple(n -> StaticGetter{n}(), N)
map(g -> map(g, tuples), getters)
end
function unzip(tuples)
N = length(first(tuples))
Expand Down Expand Up @@ -169,8 +170,11 @@ _reverse(x::Symmetric) = Symmetric(_reverse(x.data), x.uplo == 'U' ? :L : :U)
# So we keep axes(x) to restore gradient dx to its full length & correct shape.
_tryaxes(x) = axes(x)
_tryaxes(x::Tuple) = Val(length(x))
_restore(dx, ax::Tuple) = axes(dx) == ax ? dx : reshape(vcat(dx, falses(prod(length, ax) - length(dx))), ax)
_tryaxes(x::Number) = x
_restore(dx::AbstractArray{Nothing}, ax::Tuple) = similar(dx, ax)
_restore(dx, ax::Tuple) = axes(dx) == ax ? dx : reshape(vcat(dx, falses(prod(map(length, ax)) - length(dx))), ax)
_restore(dx, ::Val{N}) where {N} = ntuple(i -> get(dx,i,nothing), N)
_restore(dx, ::Number) = only(dx)

# Sometimes a pullback doesn't return a Tuple, but rather returns only a
# single nothing to say "all arguments have zero cotangent". This function is needed to
Expand Down Expand Up @@ -268,32 +272,51 @@ end
_ndims(::Base.HasShape{d}) where {d} = d
_ndims(x) = Base.IteratorSize(x) isa Base.HasShape ? _ndims(Base.IteratorSize(x)) : 1

function productfunc(xs, dy)
@assert length(first(dy)) == length(xs)
ndim = map(Zygote._ndims, xs)
cdim = cumsum((1, ndim[begin:end-1]...))
getters = ntuple(n -> StaticGetter{n}(), length(xs))
map(first(dy), xs, cdim, getters) do dyn, x, cd, getter
dyn === nothing && return nothing
nd = _ndims(x)
dims = nd == 0 ? (:) : ntuple(i -> i<cd ? i : i+nd, Val(ndims(dy)-nd))
init = map(zero, dyn) # allows for tuples, which accum can add:
red = mapreduce(getter, accum, dy; dims, init)
return _project(x, nd == 0 ? red : reshape(red, axes(x)))
end
end

@adjoint function Iterators.product(xs...)
back(::AbstractArray{Nothing}) = nothing
back(dy::NamedTuple{(:iterators,)}) = dy.iterators
function back(dy::AbstractArray)
d = 1
ntuple(length(xs)) do n
nd = _ndims(xs[n])
dims = ntuple(i -> i<d ? i : i+nd, ndims(dy)-nd)
d += nd
first(dy)[n] === nothing && return nothing
init = zero.(first(dy)[n]) # allows for tuples, which accum can add:
red = mapreduce(StaticGetter{n}(), accum, dy; dims=dims, init=init)
return _project(xs[n], reshape(red, axes(xs[n])))
end
product_pullback(::AbstractArray{Nothing}) = nothing
product_pullback(dy::NamedTuple{(:iterators,)}) = dy.iterators
product_pullback(dy::AbstractArray) = productfunc(xs, dy)
Iterators.product(xs...), product_pullback
end

@adjoint function Base.collect(p::Base.Iterators.ProductIterator)
collect_product_pullback(dy) = ((iterators=productfunc(p.iterators, dy),),)
return collect(p), collect_product_pullback
end

function zipfunc(xs, dy)
getters = ntuple(n -> StaticGetter{n}(), length(xs))
map(xs, getters) do x, getter
dx = map(getter, dy)
_project(x, _restore(dx, _tryaxes(x)))
end
Iterators.product(xs...), back
end

@adjoint function Iterators.Zip(xs)
axs = map(_tryaxes, xs) # same function used for map
back(dy::NamedTuple{(:is,)}) = tuple(dy.is)
back(dy::AbstractArray) = ntuple(length(xs)) do d
dx = map(StaticGetter{d}(), dy)
_project(xs[d], _restore(dx, axs[d]))
end |> tuple
Iterators.Zip(xs), back
@adjoint function Iterators.zip(xs...)
zip_pullback(::AbstractArray{Nothing}) = nothing
zip_pullback(dy::NamedTuple{(:is,)}) = dy.is
zip_pullback(dy::AbstractArray) = zipfunc(xs, dy)
Iterators.zip(xs...), zip_pullback
end

@adjoint function Base.collect(z::Base.Iterators.Zip)
collect_zip_pullback(dy::AbstractArray) = ((is=zipfunc(z.is, dy),),)
collect(z), collect_zip_pullback
end

# Reductions
Expand Down
52 changes: 52 additions & 0 deletions test/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,36 @@ test_rrule(ZygoteRuleConfig(), x->sum(sin, Diagonal(x)), rand(3); rrule_f=rrule_
# This was wrong before https://github.com/FluxML/Zygote.jl/pull/1170
@test gradient(x -> sum([y[2] * y[3] for y in Iterators.product(x, x, x, x)]), [1,2,3,4])[1] [320, 320, 320, 320]
@test gradient(x -> sum(y[2] * y[3] for y in Iterators.product(x, x, x, x)), [1,2,3,4])[1] [320, 320, 320, 320]

# Numbers failed before https://github.com/FluxML/Zygote.jl/pull/1489
for p in (1.0, fill(1.0), [1.0])
@test gradient(p -> sum([x*q for q in p, x in 1:3]), p) == (6p,)
@test gradient(p -> sum(x*q for (q, x) in Iterators.product(p, 1:3)), p) == (6p,)
end

# inference would also fail before #1489
y, back = _pullback(Iterators.product, 1:5, fill(1))
@test @inferred back(collect(y)) == (nothing, [1.0, 2.0, 3.0, 4.0, 5.0], fill(5.0))
end

@testset "adjoints of Iterators.zip" begin
y, back = _pullback(Iterators.zip, 1:5, 1:3, 1:2)
@test back(collect(y)) == (nothing, [1.0, 2.0, 0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [1.0, 2.0])
@test back([(nothing, j, k) for (i,j,k) in zip(1:5, 1:3, 1:2)]) == (nothing, nothing, [1.0, 2.0, 0.0], [1.0, 2.0])
@test back([(i, nothing, k) for (i,j,k) in zip(1:5, 1:3, 1:2)]) == (nothing, [1.0, 2.0, 0.0, 0.0, 0.0], nothing, [1.0, 2.0])
@test back([(i, j, nothing) for (i,j,k) in zip(1:5, 1:3, 1:2)]) == (nothing, [1.0, 2.0, 0.0, 0.0, 0.0], [1.0, 2.0, 0.0], nothing)


@test gradient(x -> sum([y[2] * y[3] for y in Iterators.zip(x, x, x, x)]), [1,2,3,4])[1] [2, 4, 6, 8]
@test gradient(x -> sum(y[2] * y[3] for y in Iterators.zip(x, x, x, x)), [1,2,3,4])[1] [2, 4, 6, 8]

for p in (1.0, fill(1.0), [1.0])
@test gradient(p_ -> sum(map(prod, Iterators.zip(p_, p))), p) == (p,)
@test gradient(p_ -> sum(x*q for (q, x) in Iterators.zip(p_, p)), p) == (p,)
end

y, back = _pullback(Iterators.zip, 1:5, fill(1))
@test @inferred back(collect(y)) == (nothing, [1.0, 0.0, 0.0, 0.0, 0.0], fill(1.0))
end

@testset "collect" begin
Expand Down Expand Up @@ -45,6 +75,28 @@ end
g = gradient(d -> sum(x^2 for x in collect(d)), t)[1]
@test g === (2.0, 4.0)
end

@testset "Iterators.ProductIterator" begin
p = Iterators.product(1:3, 1:2)
g = gradient(p -> sum(prod, collect(p)), p)[1]
@test g == (iterators=(3ones(3), 6ones(2)),)

@test gradient(x -> sum(broadcast(prod, Iterators.product(x,x))), ones(4)) == (2*4ones(4),)
@test gradient(x -> sum(broadcast(prod, Iterators.product(x .^ 2, x))), ones(4)) == (3*4ones(4),)
@test gradient(x -> sum(broadcast(prod, Iterators.product(x, x .^ 2))), ones(4)) == (3*4ones(4),)
@test gradient(x -> sum(broadcast(prod, Iterators.product(x .^ 2, x .^ 2))), ones(4)) == (4*4ones(4),)
end

@testset "Iterators.Zip" begin
z = Iterators.zip(1:3, 1:2)
g = gradient(z -> sum(prod, collect(z)), z)[1]
@test g == (is=([1.0, 2.0, 0.0], [1.0, 2.0]),)

@test gradient(x -> sum(broadcast(prod, Iterators.zip(x,x))), ones(4)) == (2ones(4),)
@test gradient(x -> sum(broadcast(prod, Iterators.zip(x.^2,x))), ones(4)) == (3ones(4),)
@test gradient(x -> sum(broadcast(prod, Iterators.zip(x,x.^2))), ones(4)) == (3ones(4),)
@test gradient(x -> sum(broadcast(prod, Iterators.zip(x.^2,x.^2))), ones(4)) == (4ones(4),)
end
end

@testset "dictionary comprehension" begin
Expand Down

2 comments on commit 46477ee

@ToucheSir
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error while trying to register: Version 0.6.68 already exists

Please sign in to comment.