Skip to content

Commit

Permalink
Added callback functions for building message tensors and contracting…
Browse files Browse the repository at this point in the history
… them
  • Loading branch information
JoeyT1994 committed Jan 24, 2024
1 parent 504085c commit c32edf5
Show file tree
Hide file tree
Showing 10 changed files with 79 additions and 90 deletions.
3 changes: 1 addition & 2 deletions examples/boundary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ tn = ITensorNetwork(named_grid((6, 3)); link_space=4)

@visualize tn

pvertices = partitioned_vertices(underlying_graph(tn); nvertices_per_partition=2)
ptn = PartitionedGraph(tn, pvertices)
ptn = PartitionedGraph(tn; nvertices_per_partition=2)
sub_vs_1 = vertices(ptn, PartitionVertex(1))
sub_vs_2 = vertices(ptn, PartitionVertex(2))

Expand Down
2 changes: 1 addition & 1 deletion examples/distances.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ t = dijkstra_tree(ψ, only(center(ψ)))
@show a_star(ψ, (2, 1), (2, 5))
@show mincut_partitions(ψ)
@show mincut_partitions(ψ, (1, 1), (3, 5))
@show partitioned_vertices(underlying_graph(ψ); npartitions=2)
@show partitioned_vertices(ψ; npartitions=2)
3 changes: 1 addition & 2 deletions src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ using NamedGraphs:
parent_graph,
vertex_to_parent_vertex,
parent_vertices_to_vertices,
not_implemented,
parent
not_implemented

include("imports.jl")

Expand Down
16 changes: 8 additions & 8 deletions src/apply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ Gate does not necessarily need to be passed. Can supply an edge to do an identit
function vidal_apply(
o::Union{ITensor,NamedEdge},
ψ::AbstractITensorNetwork,
bond_tensors;
bond_tensors::DataGraph;
normalize=false,
apply_kwargs...,
)
Expand All @@ -314,13 +314,13 @@ function vidal_apply(

for vn in neighbors(ψ, src(e))
if (vn != dst(e))
ψv1 = noprime(ψv1 * bond_tensors[NamedEdge(vn => src(e))])
ψv1 = noprime(ψv1 * bond_tensors[vn => src(e)])
end
end

for vn in neighbors(ψ, dst(e))
if (vn != src(e))
ψv2 = noprime(ψv2 * bond_tensors[NamedEdge(vn => dst(e))])
ψv2 = noprime(ψv2 * bond_tensors[vn => dst(e)])
end
end

Expand All @@ -336,20 +336,20 @@ function vidal_apply(

ind_to_replace = commonind(V, S)
ind_to_replace_with = commonind(U, S)
replaceind!(S, ind_to_replace, ind_to_replace_with')
replaceind!(V, ind_to_replace, ind_to_replace_with)
S = replaceind(S, ind_to_replace => ind_to_replace_with')
V = replaceind(V, ind_to_replace => ind_to_replace_with)

ψv1, bond_tensors[e], bond_tensors[reverse(e)], ψv2 = U * Qᵥ₁, S, S, V * Qᵥ₂
ψv1, bond_tensors[e], ψv2 = U * Qᵥ₁, S, V * Qᵥ₂

for vn in neighbors(ψ, src(e))
if (vn != dst(e))
ψv1 = noprime(ψv1 * inv_diag(bond_tensors[NamedEdge(vn => src(e))]))
ψv1 = noprime(ψv1 * inv_diag(bond_tensors[vn => src(e)]))
end
end

for vn in neighbors(ψ, dst(e))
if (vn != src(e))
ψv2 = noprime(ψv2 * inv_diag(bond_tensors[NamedEdge(vn => dst(e))]))
ψv2 = noprime(ψv2 * inv_diag(bond_tensors[vn => dst(e)]))
end
end

Expand Down
89 changes: 49 additions & 40 deletions src/beliefpropagation/beliefpropagation.jl
Original file line number Diff line number Diff line change
@@ -1,47 +1,59 @@
function message_tensors(
ptn::PartitionedGraph; itensor_constructor=inds_e -> ITensor[dense(delta(inds_e))]
)
mts = Dict()
for e in partitionedges(ptn)
src_e_itn = subgraph(ptn, src(e))
dst_e_itn = subgraph(ptn, dst(e))
inds_e = commoninds(src_e_itn, dst_e_itn)
mts[e] = itensor_constructor(inds_e)
mts[reverse(e)] = dag.(mts[e])
end
default_mt_constructor(inds_e) = ITensor[denseblocks(delta(inds_e))]
default_bp_cache(ptn::PartitionedGraph) = Dict()
function default_contractor(contract_list::Vector{ITensor})
return ITensor[normalize!(
contract(contract_list; sequence=contraction_sequence(contract_list; alg="optimal"))
)]
end

function contract_to_MPS(contract_list::Vector{ITensor}; svd_kwargs...)
contract_kwargs = (;
alg="density_matrix",
output_structure=path_graph_structure,
contraction_sequence_alg="optimal",
svd_kwargs...,
)
mts = ITensor(first(contract(ITensorNetwork(contract_list); contract_kwargs...)))
mts = normalize!.(mts)
return mts
end

function message_tensor(
ptn::PartitionedGraph, edge::PartitionEdge; mt_constructor=default_mt_constructor
)
src_e_itn = subgraph(ptn, src(edge))
dst_e_itn = subgraph(ptn, dst(edge))
inds_e = commoninds(src_e_itn, dst_e_itn)
return mt_constructor(inds_e)
end

"""
Do a single update of a message tensor using the current subgraph and the incoming mts
"""
function update_message_tensor(
ptn::PartitionedGraph, edge::PartitionEdge, mts; contract_kwargs=(; alg="exact")
ptn::PartitionedGraph,
edge::PartitionEdge,
mts;
contractor=default_contractor,
mt_constructor=default_mt_constructor,
)
pedges = setdiff(
partitionedges(ptn, boundary_edges(ptn, vertices(ptn, src(edge)); dir=:in)),
[reverse(edge)],
)
incoming_messages = [mts[e_in] for e_in in pedges]

incoming_messages = [
e_in keys(mts) ? mts[e_in] : message_tensor(ptn, e_in; mt_constructor) for
e_in in pedges
]
incoming_messages = reduce(vcat, incoming_messages; init=ITensor[])

contract_list = ITensor[
incoming_messages
ITensor(subgraph(ptn, src(edge)))
]

if contract_kwargs.alg != "exact"
mt = first(contract(ITensorNetwork(contract_list); contract_kwargs...))
else
mt = contract(
contract_list; sequence=contraction_sequence(contract_list; alg="optimal")
)
end

mt = isa(mt, ITensor) ? ITensor[mt] : ITensor(mt)
normalize!.(mt)

return mt
return contractor(contract_list)
end

"""
Expand All @@ -51,16 +63,17 @@ function belief_propagation_iteration(
ptn::PartitionedGraph,
mts,
edges::Vector{<:PartitionEdge};
contract_kwargs=(; alg="exact"),
contractor=default_contractor,
compute_norm=false,
)
new_mts = copy(mts)
c = 0
for e in edges
new_mts[e] = update_message_tensor(ptn, e, new_mts; contract_kwargs)
new_mts[e] = update_message_tensor(ptn, e, new_mts; contractor)

if compute_norm
LHS, RHS = ITensors.contract(mts[e]), ITensors.contract(new_mts[e])
LHS = e keys(mts) ? contract(mts[e]) : contract(message_tensor(ptn, e))
RHS = contract(new_mts[e])
#This line only makes sense if the message tensors are rank 2??? Should fix this.
LHS /= sum(diag(LHS))
RHS /= sum(diag(RHS))
Expand All @@ -79,14 +92,14 @@ function belief_propagation_iteration(
ptn::PartitionedGraph,
mts,
edge_groups::Vector{<:Vector{<:PartitionEdge}};
contract_kwargs=(; alg="exact"),
contractor=default_contractor,
compute_norm=false,
)
new_mts = copy(mts)
c = 0
for edges in edge_groups
updated_mts, ct = belief_propagation_iteration(
ptn, mts, edges; contract_kwargs, compute_norm
ptn, mts, edges; contractor, compute_norm
)
for e in edges
new_mts[e] = updated_mts[e]
Expand All @@ -99,17 +112,17 @@ end
function belief_propagation_iteration(
ptn::PartitionedGraph,
mts;
contract_kwargs=(; alg="exact"),
contractor=default_contractor,
compute_norm=false,
edges=PartitionEdge.(edge_sequence(partitioned_graph(ptn))),
)
return belief_propagation_iteration(ptn, mts, edges; contract_kwargs, compute_norm)
return belief_propagation_iteration(ptn, mts, edges; contractor, compute_norm)
end

function belief_propagation(
ptn::PartitionedGraph,
mts;
contract_kwargs=(; alg="exact"),
contractor=default_contractor,
niters=default_bp_niters(partitioned_graph(ptn)),
target_precision=nothing,
edges=PartitionEdge.(edge_sequence(partitioned_graph(ptn))),
Expand All @@ -120,7 +133,7 @@ function belief_propagation(
error("You need to specify a number of iterations for BP!")
end
for i in 1:niters
mts, c = belief_propagation_iteration(ptn, mts, edges; contract_kwargs, compute_norm)
mts, c = belief_propagation_iteration(ptn, mts, edges; contractor, compute_norm)
if compute_norm && c <= target_precision
if verbose
println("BP converged to desired precision after $i iterations.")
Expand All @@ -131,12 +144,8 @@ function belief_propagation(
return mts
end

function belief_propagation(
ptn::PartitionedGraph;
itensor_constructor=inds_e -> ITensor[dense(delta(inds_e))],
kwargs...,
)
mts = message_tensors(ptn; itensor_constructor)
function belief_propagation(ptn::PartitionedGraph; bp_cache=default_bp_cache, kwargs...)
mts = bp_cache(ptn)
return belief_propagation(ptn, mts; kwargs...)
end
"""
Expand Down
1 change: 1 addition & 0 deletions src/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ export AbstractITensorNetwork,
itensors,
reverse_bfs_edges,
data_graph,
default_bp_cache,
flatten_networks,
inner_network,
norm_network,
Expand Down
31 changes: 11 additions & 20 deletions src/gauging.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
"""initialize bond tensors of an ITN to identity matrices"""
function initialize_bond_tensors::ITensorNetwork; index_map=prime)
bond_tensors = Dict()
bond_tensors = DataGraph{vertextype(ψ),Nothing,ITensor}(underlying_graph(ψ))

for e in edges(ψ)
index = commoninds(ψ[src(e)], ψ[dst(e)])
bond_tensors[e] = dense(delta(index, index_map(index)))
bond_tensors[reverse(e)] = bond_tensors[e]
bond_tensors[e] = denseblocks(delta(index, index_map(index)))
end

return bond_tensors
Expand All @@ -16,7 +15,7 @@ function vidal_gauge(
ψ::ITensorNetwork,
pψψ::PartitionedGraph,
mts,
bond_tensors;
bond_tensors::DataGraph;
eigen_message_tensor_cutoff=10 * eps(real(scalartype(ψ))),
regularization=10 * eps(real(scalartype(ψ))),
edges=NamedGraphs.edges(ψ),
Expand Down Expand Up @@ -68,7 +67,7 @@ function vidal_gauge(
[commoninds(S, U)..., commoninds(S, V)...] =>
[new_edge_ind..., prime(new_edge_ind)...],
)
bond_tensors[e], bond_tensors[reverse(e)] = S, S
bond_tensors[e] = S
end

return ψ_vidal, bond_tensors
Expand Down Expand Up @@ -109,27 +108,19 @@ function vidal_gauge(
)
ψψ = norm_network(ψ)
pψψ = PartitionedGraph(ψψ, group(v -> v[1], vertices(ψψ)))
mts = message_tensors(pψψ)

mts = belief_propagation(
pψψ,
mts;
contract_kwargs=(; alg="exact"),
niters,
target_precision=target_canonicalness,
verbose,
)
mts = belief_propagation(pψψ; niters, target_precision=target_canonicalness, verbose)
return vidal_gauge(
ψ, pψψ, mts; eigen_message_tensor_cutoff, regularization, svd_kwargs...
)
end

"""Transform from an ITensor in the Vidal Gauge (bond tensors) to the Symmetric Gauge (partitionedgraph, message tensors)"""
function vidal_to_symmetric_gauge::ITensorNetwork, bond_tensors)
function vidal_to_symmetric_gauge::ITensorNetwork, bond_tensors::DataGraph)
ψsymm = copy(ψ)
ψψsymm = norm_network(ψsymm)
pψψsymm = PartitionedGraph(ψψsymm, group(v -> v[1], vertices(ψψsymm)))
ψsymm_mts = message_tensors(pψψsymm)
ψsymm_mts = default_bp_cache(pψψsymm)

for e in edges(ψsymm)
vsrc, vdst = src(e), dst(e)
Expand Down Expand Up @@ -176,7 +167,7 @@ function symmetric_to_vidal_gauge(
mts;
regularization=10 * eps(real(scalartype(ψ))),
)
bond_tensors = Dict()
bond_tensors = DataGraph{vertextype(ψ),Nothing,ITensor}(underlying_graph(ψ))

ψ_vidal = copy(ψ)

Expand All @@ -195,7 +186,7 @@ end
"""Function to measure the 'isometries' of a state in the Vidal Gauge"""
function vidal_itn_isometries(
ψ::ITensorNetwork,
bond_tensors;
bond_tensors::DataGraph;
edges=vcat(NamedGraphs.edges(ψ), reverse.(NamedGraphs.edges(ψ))),
)
isometries = Dict()
Expand All @@ -204,7 +195,7 @@ function vidal_itn_isometries(
vsrc, vdst = src(e), dst(e)
ψv = copy(ψ[vsrc])
for vn in setdiff(neighbors(ψ, vsrc), [vdst])
ψv = noprime(ψv * bond_tensors[NamedEdge(vn => vsrc)])
ψv = noprime(ψv * bond_tensors[vn => vsrc])
end

ψvdag = dag(ψv)
Expand All @@ -216,7 +207,7 @@ function vidal_itn_isometries(
end

"""Function to measure the 'canonicalness' of a state in the Vidal Gauge"""
function vidal_itn_canonicalness::ITensorNetwork, bond_tensors)
function vidal_itn_canonicalness::ITensorNetwork, bond_tensors::DataGraph)
f = 0

isometries = vidal_itn_isometries(ψ, bond_tensors)
Expand Down
5 changes: 2 additions & 3 deletions test/test_apply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ using ITensorNetworks:
belief_propagation,
environment_tensors,
contract_inner,
message_tensors,
vidal_gauge,
vidal_apply,
vidal_to_symmetric_gauge,
Expand Down Expand Up @@ -31,14 +30,14 @@ using SplitApplyCombine

#Simple Belief Propagation Grouping
pψψ_SBP = PartitionedGraph(ψψ, group(v -> v[1], vertices(ψψ)))
mtsSBP = belief_propagation(pψψ_SBP; contract_kwargs=(; alg="exact"), niters=50)
mtsSBP = belief_propagation(pψψ_SBP; niters=20)
envsSBP = environment_tensors(pψψ_SBP, mtsSBP, PartitionVertex.([v1, v2]))

ψ_vidal, bond_tensors = vidal_gauge(ψ, pψψ_SBP, mtsSBP)

#This grouping will correspond to calculating the environments exactly (each column of the grid is a partition)
pψψ_GBP = PartitionedGraph(ψψ, group(v -> v[1][1], vertices(ψψ)))
mtsGBP = belief_propagation(pψψ_GBP; contract_kwargs=(; alg="exact"), niters=50)
mtsGBP = belief_propagation(pψψ_GBP; niters=20)
envsGBP = environment_tensors(pψψ_GBP, mtsGBP, [(v1, 1), (v1, 2), (v2, 1), (v2, 2)])

ngates = 5
Expand Down
Loading

0 comments on commit c32edf5

Please sign in to comment.