MWE:

```julia
julia> using Zygote;

julia> function foo(xs)
           sum(xs) do x
               (2*x, 3*x)
           end
       end
foo (generic function with 1 method)

julia> Zygote.pullback(foo, [1.0, 2.0, 3.0])
ERROR: MethodError: no method matching +(::Tuple{Float64, Float64}, ::Tuple{Float64, Float64})
The function `+` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...)
   @ Base operators.jl:596
  +(::ChainRulesCore.ZeroTangent, ::Any)
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/U6wNx/src/tangent_arithmetic.jl:99
  +(::Any, ::ChainRulesCore.NotImplemented)
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/U6wNx/src/tangent_arithmetic.jl:25
  ...

Stacktrace:
  [1] add_sum(x::Tuple{Float64, Float64}, y::Tuple{Float64, Float64})
    @ Base ./reduce.jl:24
  [2] _mapreduce
    @ ./reduce.jl:437 [inlined]
  [3] _mapreduce_dim
    @ ./reducedim.jl:337 [inlined]
  [4] mapreduce
    @ ./reducedim.jl:329 [inlined]
  [5] _sum
    @ ./reducedim.jl:987 [inlined]
  [6] sum
    @ ./reducedim.jl:983 [inlined]
  [7] #rrule#725
    @ ~/.julia/packages/ChainRules/sm2ny/src/rulesets/Base/mapreduce.jl:103 [inlined]
  [8] rrule
    @ ~/.julia/packages/ChainRules/sm2ny/src/rulesets/Base/mapreduce.jl:76 [inlined]
  [9] chain_rrule
    @ ~/.julia/packages/Zygote/3To5I/src/compiler/chainrules.jl:233 [inlined]
 [10] macro expansion
    @ ~/.julia/packages/Zygote/3To5I/src/compiler/interface2.jl:0 [inlined]
 [11] _pullback
    @ ~/.julia/packages/Zygote/3To5I/src/compiler/interface2.jl:91 [inlined]
 [12] foo
    @ ./REPL[6]:2 [inlined]
 [13] _pullback(ctx::Zygote.Context{false}, f::typeof(foo), args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/3To5I/src/compiler/interface2.jl:0
 [14] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/3To5I/src/compiler/interface.jl:96
 [15] pullback(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/3To5I/src/compiler/interface.jl:94
 [16] top-level scope
    @ REPL[7]:1
```
So it seems the problem is that the tuple should be a `Tangent{Tuple{...}, ...}`. Is Zygote not wrapping the values correctly before handing them to ChainRules?
The primal doesn't work either, and throws an almost identical error.
```julia
julia> foo([1., 2., 3.])
ERROR: MethodError: no method matching +(::Tuple{Float64, Float64}, ::Tuple{Float64, Float64})
The function `+` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...)
   @ Base operators.jl:596
  +(::BitMatrix, ::LinearAlgebra.UniformScaling)
   @ LinearAlgebra .julia/juliaup/julia-1.11.3+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/uniformscaling.jl:151
  +(::Bool, ::Complex{Bool})
   @ Base complex.jl:308
  ...
```
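Since the primal already fails, the error is not specific to Zygote: Base defines no `+` method for two tuples, so `sum` has no way to combine the `(2x, 3x)` values. A minimal sketch of a primal that does run, assuming elementwise addition is the intended meaning, reduces with broadcasting `.+` instead. The names `addtuples` and `foo_elementwise` are hypothetical, introduced only for illustration.

```julia
# Hypothetical helper: Base has no +(::Tuple, ::Tuple), but broadcasting
# .+ over two equal-length tuples returns a tuple of elementwise sums.
addtuples(a, b) = a .+ b

# Hypothetical variant of foo: map each x to (2x, 3x), then fold the
# resulting tuples together with addtuples instead of +.
function foo_elementwise(xs)
    mapreduce(x -> (2 * x, 3 * x), addtuples, xs)
end

foo_elementwise([1.0, 2.0, 3.0])  # (12.0, 18.0)
```

This sketch only addresses the primal computation.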