diff --git a/src/autodiff/chainrules_patch.jl b/src/autodiff/chainrules_patch.jl
index 5f64e75..87c9785 100644
--- a/src/autodiff/chainrules_patch.jl
+++ b/src/autodiff/chainrules_patch.jl
@@ -1,10 +1,10 @@
-import ChainRulesCore: rrule, @non_differentiable, NoTangent
+import ChainRulesCore: rrule, @non_differentiable, NoTangent, Tangent
 
 function rrule(::typeof(apply), reg::ArrayReg, block::AbstractBlock)
     out = apply(reg, block)
     out, function (outδ)
         (in, inδ), paramsδ = apply_back((copy(out), outδ), block)
-        return (NoTangent(), inδ, paramsδ)
+        return (NoTangent(), inδ, dispatch(block, paramsδ))
     end
 end
 
@@ -12,14 +12,14 @@ function rrule(::typeof(apply), reg::ArrayReg, block::Add)
     out = apply(reg, block)
     out, function (outδ)
         (in, inδ), paramsδ = apply_back((copy(out), outδ), block; in = reg)
-        return (NoTangent(), inδ, paramsδ)
+        return (NoTangent(), inδ, dispatch(block, paramsδ))
     end
 end
 
 function rrule(::typeof(dispatch), block::AbstractBlock, params)
     out = dispatch(block, params)
     out, function (outδ)
-        (NoTangent(), NoTangent(), outδ)
+        (NoTangent(), NoTangent(), parameters(outδ))
     end
 end
 
@@ -34,11 +34,30 @@ function rrule(::typeof(expect), op::AbstractBlock, reg::AbstractRegister{B}) wh
     end
 end
 
-function rrule(::Type{Matrix}, block::AbstractBlock)
-    out = Matrix(block)
+function rrule(::typeof(expect), op::AbstractBlock, reg_and_circuit::Pair{<:ArrayReg{B},<:AbstractBlock}) where {B}
+    out = expect(op, reg_and_circuit)
+    out, function (outδ)
+        greg, gcircuit = expect_g(op, reg_and_circuit)
+        for b in 1:B
+            viewbatch(greg, b).state .*= 2 * outδ[b]
+        end
+        return (NoTangent(), NoTangent(), Tangent{typeof(reg_and_circuit)}(; first=greg, second=dispatch(reg_and_circuit.second, gcircuit)))
+    end
+end
+
+function rrule(::Type{T}, block::AbstractBlock) where T<:Matrix
+    out = T(block)
+    out, function (outδ)
+        paramsδ = mat_back(block, outδ)
+        return (NoTangent(), dispatch(block, paramsδ))
+    end
+end
+
+function rrule(::typeof(mat), ::Type{T}, block::AbstractBlock) where T
+    out = mat(T, block)
     out, function (outδ)
         paramsδ = mat_back(block, outδ)
-        return (NoTangent(), paramsδ)
+        return (NoTangent(), NoTangent(), dispatch(block, paramsδ))
     end
 end
 
diff --git a/test/autodiff/chainrules_patch.jl b/test/autodiff/chainrules_patch.jl
index 91b410d..39f07a6 100644
--- a/test/autodiff/chainrules_patch.jl
+++ b/test/autodiff/chainrules_patch.jl
@@ -2,6 +2,10 @@ import Zygote, ForwardDiff
 using Random, Test
 using YaoBlocks, YaoArrayRegister
 
+function Zygote.accum(a::AbstractBlock, b::AbstractBlock)
+    dispatch(a, parameters(a) + parameters(b))
+end
+
 @testset "rules" begin
     h = put(5, 3 => Z) + put(5, 2 => X)
     c = chain(put(5, 2 => chain(Rx(1.4), Rx(0.5))), cnot(5, 3, 1), put(5, 3 => Rx(-0.5)))
@@ -31,8 +35,7 @@ using YaoBlocks, YaoArrayRegister
     @test Zygote.gradient(x -> real(sum(abs2, statevec(x'))), r)[1].state ≈ g1
     # zygote does not work if `sin` is not here,
     # because it gives an adjoint of different type as the output matrix type.
-    # do not modify the data type please! Zygote
-    @test Zygote.gradient(x -> real(sum(sin, Matrix(x))), c)[1] ≈
+    @test parameters(Zygote.gradient(x -> real(sum(sin, Matrix(x))), c)[1]) ≈
           ForwardDiff.gradient(x -> real(sum(sin, Matrix(dispatch(c, x)))), parameters(c))
 end
 
@@ -46,6 +49,7 @@ end
         sum(real(st .* st))
     end
 
+    # apply
     reg0 = zero_state(5)
     params = rand!(parameters(c))
     paramsδ = Zygote.gradient(params -> loss(reg0, dispatch(c, params)), params)[1]
@@ -64,6 +68,49 @@ end
     )
     @test fregδ ≈ reinterpret(Float64, regδ.state)
     @test fparamsδ ≈ paramsδ
+
+    # expect and fidelity
+    c = chain(put(5, 5=>Rx(1.5)), put(5,1=>Rx(0.4)), put(5,4=>Rx(0.2)), put(5, 2 => chain(Rx(0.4), Rx(0.5))), cnot(5, 3, 1), put(5, 3 => Rx(-0.5)))
+    h = chain(repeat(5, X, 1:5))
+    reg = rand_state(5)
+    function loss2(reg::AbstractRegister, circuit::AbstractBlock{N}) where {N}
+        return 5*real(expect(h, copy(reg) => circuit) + fidelity(reg, apply(reg, circuit)))
+    end
+    params = rand!(parameters(c))
+    fδc = ForwardDiff.gradient(
+        params ->
+            loss2(ArrayReg(Matrix{Complex{eltype(params)}}(reg.state)), dispatch(c, params)),
+        params,
+    )
+    δr, δc = Zygote.gradient((reg, params)->loss2(reg, dispatch(c, params)),reg, params)
+    @test δc ≈ fδc
+
+    fregδ = ForwardDiff.gradient(
+        x -> loss2(
+            ArrayReg([Complex(x[2i-1], x[2i]) for i in 1:length(x)÷2]),
+            dispatch(c, Vector{real(eltype(x))}(params)),
+        ),
+        reinterpret(Float64, reg.state),
+    )
+    @test fregδ ≈ reinterpret(Float64, δr.state)
+
+    # operator fidelity
+    c = chain(put(5, 5=>Rx(1.5)), put(5,1=>Rx(0.4)), put(5,4=>Rx(0.2)), put(5, 2 => chain(Rx(0.4), Rx(0.5))), cnot(5, 3, 1), put(5, 3 => Rx(-0.5)))
+    h = chain(repeat(5, X, 1:5))
+    function loss3(circuit::AbstractBlock{N}, h) where {N}
+        return operator_fidelity(circuit, h)
+    end
+    params = rand!(parameters(c))
+    fδc = ForwardDiff.gradient(
+        params ->
+            loss3(dispatch(c, params), h),
+        params,
+    )
+    δc, = Zygote.gradient(p->loss3(dispatch(c, p), h), params)
+    @test δc ≈ fδc
+
+    # NOTE: operator back propagation in expect is not implemented!
+    # to differentiate operators, we need to use the expensive `mat_back` function.
 end
 
 @testset "add block" begin
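
Usage note (not part of the patch): the sketch below mirrors the pattern exercised by the updated tests and shows the workflow these rrules enable. It assumes YaoBlocks, YaoArrayRegister, and Zygote are installed; the circuit, observable, and register below are illustrative only and do not appear in the diff.

    using YaoBlocks, YaoArrayRegister
    import Zygote

    # Illustrative 3-qubit parameterized circuit and observable.
    c      = chain(put(3, 1 => Rx(0.4)), put(3, 2 => Rx(0.5)))
    h      = put(3, 1 => Z)
    reg    = zero_state(3)
    params = parameters(c)

    # Gradient of the (real) expectation value with respect to the circuit parameters.
    # `dispatch` rebuilds the circuit from the parameter vector, and its rrule maps the
    # block-valued adjoint back to a plain numeric vector, while the new `expect` rrule
    # handles the `register => circuit` pair.
    grad = Zygote.gradient(p -> real(expect(h, copy(reg) => dispatch(c, p))), params)[1]

When the block itself is the differentiated argument (as in the updated `Matrix(x)` test), the adjoint comes back as a block and `parameters(...)` recovers the parameter vector; the added `Zygote.accum` method in the test file makes such block adjoints accumulate correctly.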