Skip to content
This repository has been archived by the owner on Dec 18, 2021. It is now read-only.

Fix issue 170 #171

Merged
merged 2 commits into from
Nov 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions src/autodiff/chainrules_patch.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
import ChainRulesCore: rrule, @non_differentiable, NoTangent
import ChainRulesCore: rrule, @non_differentiable, NoTangent, Tangent

function rrule(::typeof(apply), reg::ArrayReg, block::AbstractBlock)
out = apply(reg, block)
out, function (outδ)
(in, inδ), paramsδ = apply_back((copy(out), outδ), block)
return (NoTangent(), inδ, paramsδ)
return (NoTangent(), inδ, dispatch(block, paramsδ))
end
end

function rrule(::typeof(apply), reg::ArrayReg, block::Add)
out = apply(reg, block)
out, function (outδ)
(in, inδ), paramsδ = apply_back((copy(out), outδ), block; in = reg)
return (NoTangent(), inδ, paramsδ)
return (NoTangent(), inδ, dispatch(block, paramsδ))
end
end

function rrule(::typeof(dispatch), block::AbstractBlock, params)
out = dispatch(block, params)
out, function (outδ)
(NoTangent(), NoTangent(), outδ)
(NoTangent(), NoTangent(), parameters(outδ))
end
end

Expand All @@ -34,11 +34,30 @@ function rrule(::typeof(expect), op::AbstractBlock, reg::AbstractRegister{B}) wh
end
end

function rrule(::Type{Matrix}, block::AbstractBlock)
out = Matrix(block)
function rrule(::typeof(expect), op::AbstractBlock, reg_and_circuit::Pair{<:ArrayReg{B},<:AbstractBlock}) where {B}
out = expect(op, reg_and_circuit)
out, function (outδ)
greg, gcircuit = expect_g(op, reg_and_circuit)
for b in 1:B
viewbatch(greg, b).state .*= 2 * outδ[b]
end
return (NoTangent(), NoTangent(), Tangent{typeof(reg_and_circuit)}(; first=greg, second=dispatch(reg_and_circuit.second, gcircuit)))
end
end

function rrule(::Type{T}, block::AbstractBlock) where T<:Matrix
out = T(block)
out, function (outδ)
paramsδ = mat_back(block, outδ)
return (NoTangent(), dispatch(block, paramsδ))
end
end

function rrule(::typeof(mat), ::Type{T}, block::AbstractBlock) where T
out = mat(T, block)
out, function (outδ)
paramsδ = mat_back(block, outδ)
return (NoTangent(), paramsδ)
return (NoTangent(), NoTangent(), dispatch(block, paramsδ))
end
end

Expand Down
51 changes: 49 additions & 2 deletions test/autodiff/chainrules_patch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ import Zygote, ForwardDiff
using Random, Test
using YaoBlocks, YaoArrayRegister

function Zygote.accum(a::AbstractBlock, b::AbstractBlock)
dispatch(a, parameters(a) + parameters(b))
end

@testset "rules" begin
h = put(5, 3 => Z) + put(5, 2 => X)
c = chain(put(5, 2 => chain(Rx(1.4), Rx(0.5))), cnot(5, 3, 1), put(5, 3 => Rx(-0.5)))
Expand Down Expand Up @@ -31,8 +35,7 @@ using YaoBlocks, YaoArrayRegister
@test Zygote.gradient(x -> real(sum(abs2, statevec(x'))), r)[1].state ≈ g1
# zygote does not work if `sin` is not here,
# because it gives an adjoint of different type as the output matrix type.
# do not modify the data type please! Zygote
@test Zygote.gradient(x -> real(sum(sin, Matrix(x))), c)[1] ≈
@test parameters(Zygote.gradient(x -> real(sum(sin, Matrix(x))), c)[1]) ≈
ForwardDiff.gradient(x -> real(sum(sin, Matrix(dispatch(c, x)))), parameters(c))
end

Expand All @@ -46,6 +49,7 @@ end
sum(real(st .* st))
end

# apply
reg0 = zero_state(5)
params = rand!(parameters(c))
paramsδ = Zygote.gradient(params -> loss(reg0, dispatch(c, params)), params)[1]
Expand All @@ -64,6 +68,49 @@ end
)
@test fregδ ≈ reinterpret(Float64, regδ.state)
@test fparamsδ ≈ paramsδ

# expect and fidelity
c = chain(put(5, 5=>Rx(1.5)), put(5,1=>Rx(0.4)), put(5,4=>Rx(0.2)), put(5, 2 => chain(Rx(0.4), Rx(0.5))), cnot(5, 3, 1), put(5, 3 => Rx(-0.5)))
h = chain(repeat(5, X, 1:5))
reg = rand_state(5)
function loss2(reg::AbstractRegister, circuit::AbstractBlock{N}) where {N}
return 5*real(expect(h, copy(reg) => circuit) + fidelity(reg, apply(reg, circuit)))
end
params = rand!(parameters(c))
fδc = ForwardDiff.gradient(
params ->
loss2(ArrayReg(Matrix{Complex{eltype(params)}}(reg.state)), dispatch(c, params)),
params,
)
δr, δc = Zygote.gradient((reg, params)->loss2(reg, dispatch(c, params)),reg, params)
@test δc ≈ fδc

fregδ = ForwardDiff.gradient(
x -> loss2(
ArrayReg([Complex(x[2i-1], x[2i]) for i in 1:length(x)÷2]),
dispatch(c, Vector{real(eltype(x))}(params)),
),
reinterpret(Float64, reg.state),
)
@test fregδ ≈ reinterpret(Float64, δr.state)

# operator fidelity
c = chain(put(5, 5=>Rx(1.5)), put(5,1=>Rx(0.4)), put(5,4=>Rx(0.2)), put(5, 2 => chain(Rx(0.4), Rx(0.5))), cnot(5, 3, 1), put(5, 3 => Rx(-0.5)))
h = chain(repeat(5, X, 1:5))
function loss3(circuit::AbstractBlock{N}, h) where {N}
return operator_fidelity(circuit, h)
end
params = rand!(parameters(c))
fδc = ForwardDiff.gradient(
params ->
loss3(dispatch(c, params), h),
params,
)
δc, = Zygote.gradient(p->loss3(dispatch(c, p), h), params)
@test δc ≈ fδc

# NOTE: operator back propagation in expect is not implemented!
# to differentiate operators, we need to use the expensive `mat_back` function.
end

@testset "add block" begin
Expand Down