diff --git a/src/ITensorNetworks.jl b/src/ITensorNetworks.jl index 85126258..4da7f0d3 100644 --- a/src/ITensorNetworks.jl +++ b/src/ITensorNetworks.jl @@ -28,9 +28,10 @@ include("edge_sequences.jl") include("formnetworks/abstractformnetwork.jl") include("formnetworks/bilinearformnetwork.jl") include("formnetworks/quadraticformnetwork.jl") +include("caches/abstractbeliefpropagationcache.jl") include("caches/beliefpropagationcache.jl") -include("caches/boundarympscacheutils.jl") -include("caches/boundarympscache.jl") +#include("caches/boundarympscacheutils.jl") +#include("caches/boundarympscache.jl") include("contraction_tree_to_graph.jl") include("gauging.jl") include("utils.jl") diff --git a/src/caches/abstractbeliefpropagationcache.jl b/src/caches/abstractbeliefpropagationcache.jl new file mode 100644 index 00000000..3646faa5 --- /dev/null +++ b/src/caches/abstractbeliefpropagationcache.jl @@ -0,0 +1,236 @@ +using Graphs: IsDirected +using SplitApplyCombine: group +using LinearAlgebra: diag, dot +using ITensors: dir +using ITensorMPS: ITensorMPS +using NamedGraphs.PartitionedGraphs: + PartitionedGraphs, + PartitionedGraph, + PartitionVertex, + boundary_partitionedges, + partitionvertices, + partitionedges, + unpartitioned_graph +using SimpleTraits: SimpleTraits, Not, @traitfn +using NDTensors: NDTensors + +abstract type AbstractBeliefPropagationCache end + +function default_message_update(contract_list::Vector{ITensor}; normalize=true, kwargs...) + sequence = optimal_contraction_sequence(contract_list) + updated_messages = contract(contract_list; sequence, kwargs...) + message_norm = norm(updated_messages) + if normalize && !iszero(message_norm) + updated_messages /= message_norm + end + return ITensor[updated_messages] +end + +#TODO: Take `dot` without precontracting the messages to allow scaling to more complex messages +function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor}) + lhs, rhs = contract(message_a), contract(message_b) + f = abs2(dot(lhs / norm(lhs), rhs / norm(rhs))) + return 1 - f +end + +default_message(elt, inds_e) = ITensor[denseblocks(delta(elt, i)) for i in inds_e] +default_messages(ptn::PartitionedGraph) = Dictionary() +@traitfn default_bp_maxiter(g::::(!IsDirected)) = is_tree(g) ? 1 : nothing +@traitfn function default_bp_maxiter(g::::IsDirected) + return default_bp_maxiter(undirected_graph(underlying_graph(g))) +end +default_partitioned_vertices(ψ::AbstractITensorNetwork) = group(v -> v, vertices(ψ)) +function default_partitioned_vertices(f::AbstractFormNetwork) + return group(v -> original_state_vertex(f, v), vertices(f)) +end +default_cache_update_kwargs(cache) = (; maxiter=25, tol=1e-8) + +partitioned_tensornetwork(bpc::AbstractBeliefPropagationCache) = not_implemented() +messages(bpc::AbstractBeliefPropagationCache) = not_implemented() +default_message(bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs...) = not_implemented() +Base.copy(bpc::AbstractBeliefPropagationCache) = not_implemented() +default_bp_maxiter(bpc::AbstractBeliefPropagationCache) = not_implemented() +default_edge_sequence(bpc::AbstractBeliefPropagationCache) = not_implemented() +environment(bpc::AbstractBeliefPropagationCache, partition_vertices::Vector{<:PartitionVertex}; kwargs...) = not_implemented() +region_scalar(bpc::AbstractBeliefPropagationCache, pv::PartitionVertex; kwargs...) = not_implemented() +region_scalar(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge; kwargs...) = not_implemented() + +function factors(bpc::AbstractBeliefPropagationCache, verts::Vector) + return ITensor[tensornetwork(bpc)[v] for v in verts] +end + +function factor(bpc::AbstractBeliefPropagationCache, vertex::PartitionVertex) + return factors(bpc, vertices(bpc, vertex)) +end + +function vertex_scalars( + bpc::AbstractBeliefPropagationCache, + pvs=partitionvertices(partitioned_tensornetwork(bpc)); + kwargs..., + ) + return map(pv -> region_scalar(bpc, pv; kwargs...), pvs) +end + +function edge_scalars( + bpc::AbstractBeliefPropagationCache, + pes=partitionedges(partitioned_tensornetwork(bpc)); + kwargs..., + ) + return map(pe -> region_scalar(bpc, pe; kwargs...), pes) +end + +function scalar_factors_quotient(bpc::AbstractBeliefPropagationCache) + return vertex_scalars(bpc), edge_scalars(bpc) +end + +function environment( + bpc::AbstractBeliefPropagationCache, partition_vertex::PartitionVertex; kwargs... + ) + return environment(bpc, [partition_vertex]; kwargs...) +end + +function environment(bpc::AbstractBeliefPropagationCache, verts::Vector) + partition_verts = partitionvertices(bpc, verts) + messages = environment(bpc, partition_verts) + central_tensors = factors(bpc, setdiff(vertices(bpc, partition_verts), verts)) + return vcat(messages, central_tensors) +end + +function tensornetwork(bpc::AbstractBeliefPropagationCache) + return unpartitioned_graph(partitioned_tensornetwork(bpc)) +end + +#Forward from partitioned graph +for f in [ + :(PartitionedGraphs.partitioned_graph), + :(PartitionedGraphs.partitionedge), + :(PartitionedGraphs.partitionvertices), + :(PartitionedGraphs.vertices), + :(PartitionedGraphs.boundary_partitionedges), + :(ITensorMPS.linkinds), + ] + @eval begin + function $f(bpc::AbstractBeliefPropagationCache, args...; kwargs...) + return $f(partitioned_tensornetwork(bpc), args...; kwargs...) + end + end +end + +NDTensors.scalartype(bpc::AbstractBeliefPropagationCache) = scalartype(tensornetwork(bpc)) + +""" +Update the tensornetwork inside the cache +""" +function update_factors(bpc::AbstractBeliefPropagationCache, factors) + bpc = copy(bpc) + tn = tensornetwork(bpc) + for vertex in eachindex(factors) + # TODO: Add a check that this preserves the graph structure. + setindex_preserve_graph!(tn, factors[vertex], vertex) + end + return bpc +end + +function update_factor(bpc, vertex, factor) + return update_factors(bpc, Dictionary([vertex], [factor])) +end + +function message(bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs...) + mts = messages(bpc) + return get(() -> default_message(bpc, edge; kwargs...), mts, edge) +end +function messages(bpc::AbstractBeliefPropagationCache, edges; kwargs...) + return map(edge -> message(bpc, edge; kwargs...), edges) +end + +""" +Compute message tensor as product of incoming mts and local state +""" +function update_message( + bpc::AbstractBeliefPropagationCache, + edge::PartitionEdge; + message_update=default_message_update, + message_update_kwargs=(;), +) + vertex = src(edge) + messages = environment(bpc, vertex; ignore_edges=PartitionEdge[reverse(edge)]) + state = factor(bpc, vertex) + + return message_update(ITensor[messages; state]; message_update_kwargs...) +end + +""" +Do a sequential update of the message tensors on `edges` +""" +function update( + bpc::AbstractBeliefPropagationCache, + edges::Vector{<:PartitionEdge}; + (update_diff!)=nothing, + kwargs..., +) + bpc_updated = copy(bpc) + mts = messages(bpc_updated) + for e in edges + set!(mts, e, update_message(bpc, e; kwargs...)) + if !isnothing(update_diff!) + update_diff![] += message_diff(message(bpc, e), mts[e]) + end + end + return bpc_updated +end + + +""" +Update the message tensor on a single edge +""" +function update(bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs...) + return update(bpc, [edge]; kwargs...) +end + +""" +Do parallel updates between groups of edges of all message tensors +Currently we send the full message tensor data struct to update for each edge_group. But really we only need the +mts relevant to that group. +""" +function update( + bpc::AbstractBeliefPropagationCache, + edge_groups::Vector{<:Vector{<:PartitionEdge}}; + kwargs..., +) + new_mts = copy(messages(bpc)) + for edges in edge_groups + bpc_t = update(bpc, edges; kwargs...) + for e in edges + new_mts[e] = message(bpc_t, e) + end + end + return set_messages(bpc, new_mts) +end + +""" +More generic interface for update, with default params +""" +function update( + bpc::AbstractBeliefPropagationCache; + edges=default_edge_sequence(bpc), + maxiter=default_bp_maxiter(bpc), + tol=nothing, + verbose=false, + kwargs..., +) + compute_error = !isnothing(tol) + if isnothing(maxiter) + error("You need to specify a number of iterations for BP!") + end + for i in 1:maxiter + diff = compute_error ? Ref(0.0) : nothing + bpc = update(bpc, edges; (update_diff!)=diff, kwargs...) + if compute_error && (diff.x / length(edges)) <= tol + if verbose + println("BP converged to desired precision after $i iterations.") + end + break + end + end + return bpc +end \ No newline at end of file diff --git a/src/caches/beliefpropagationcache.jl b/src/caches/beliefpropagationcache.jl index d5886ec7..22ca093f 100644 --- a/src/caches/beliefpropagationcache.jl +++ b/src/caches/beliefpropagationcache.jl @@ -14,57 +14,29 @@ using NamedGraphs.PartitionedGraphs: using SimpleTraits: SimpleTraits, Not, @traitfn using NDTensors: NDTensors -default_message(elt, inds_e) = ITensor[denseblocks(delta(elt, i)) for i in inds_e] -default_messages(ptn::PartitionedGraph) = Dictionary() -function default_message_update(contract_list::Vector{ITensor}; normalize=true, kwargs...) - sequence = optimal_contraction_sequence(contract_list) - updated_messages = contract(contract_list; sequence, kwargs...) - message_norm = norm(updated_messages) - if normalize && !iszero(message_norm) - updated_messages /= message_norm - end - return ITensor[updated_messages] -end -@traitfn default_bp_maxiter(g::::(!IsDirected)) = is_tree(g) ? 1 : nothing -@traitfn function default_bp_maxiter(g::::IsDirected) - return default_bp_maxiter(undirected_graph(underlying_graph(g))) -end -default_partitioned_vertices(ψ::AbstractITensorNetwork) = group(v -> v, vertices(ψ)) -function default_partitioned_vertices(f::AbstractFormNetwork) - return group(v -> original_state_vertex(f, v), vertices(f)) -end -default_cache_update_kwargs(cache) = (; maxiter=25, tol=1e-8) function default_cache_construction_kwargs(alg::Algorithm"bp", ψ::AbstractITensorNetwork) return (; partitioned_vertices=default_partitioned_vertices(ψ)) end -#TODO: Take `dot` without precontracting the messages to allow scaling to more complex messages -function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor}) - lhs, rhs = contract(message_a), contract(message_b) - f = abs2(dot(lhs / norm(lhs), rhs / norm(rhs))) - return 1 - f -end - -struct BeliefPropagationCache{PTN,MTS,DM} +struct BeliefPropagationCache{PTN,MTS} <: AbstractBeliefPropagationCache partitioned_tensornetwork::PTN messages::MTS - default_message::DM end #Constructors... function BeliefPropagationCache( - ptn::PartitionedGraph; messages=default_messages(ptn), default_message=default_message + ptn::PartitionedGraph; messages=default_messages(ptn) ) - return BeliefPropagationCache(ptn, messages, default_message) + return BeliefPropagationCache(ptn, messages) end -function BeliefPropagationCache(tn, partitioned_vertices; kwargs...) +function BeliefPropagationCache(tn::AbstractITensorNetwork, partitioned_vertices; kwargs...) ptn = PartitionedGraph(tn, partitioned_vertices) return BeliefPropagationCache(ptn; kwargs...) end function BeliefPropagationCache( - tn; partitioned_vertices=default_partitioned_vertices(tn), kwargs... + tn::AbstractITensorNetwork; partitioned_vertices=default_partitioned_vertices(tn), kwargs... ) return BeliefPropagationCache(tn, partitioned_vertices; kwargs...) end @@ -76,47 +48,20 @@ end function partitioned_tensornetwork(bp_cache::BeliefPropagationCache) return bp_cache.partitioned_tensornetwork end + messages(bp_cache::BeliefPropagationCache) = bp_cache.messages -default_message(bp_cache::BeliefPropagationCache) = bp_cache.default_message function tensornetwork(bp_cache::BeliefPropagationCache) return unpartitioned_graph(partitioned_tensornetwork(bp_cache)) end -#Forward from partitioned graph -for f in [ - :(PartitionedGraphs.partitioned_graph), - :(PartitionedGraphs.partitionedge), - :(PartitionedGraphs.partitionvertices), - :(PartitionedGraphs.vertices), - :(PartitionedGraphs.boundary_partitionedges), - :(ITensorMPS.linkinds), -] - @eval begin - function $f(bp_cache::BeliefPropagationCache, args...; kwargs...) - return $f(partitioned_tensornetwork(bp_cache), args...; kwargs...) - end - end -end - -NDTensors.scalartype(bp_cache) = scalartype(tensornetwork(bp_cache)) - function default_message(bp_cache::BeliefPropagationCache, edge::PartitionEdge) - return default_message(bp_cache)(scalartype(bp_cache), linkinds(bp_cache, edge)) -end - -function message(bp_cache::BeliefPropagationCache, edge::PartitionEdge) - mts = messages(bp_cache) - return get(() -> default_message(bp_cache, edge), mts, edge) -end -function messages(bp_cache::BeliefPropagationCache, edges; kwargs...) - return map(edge -> message(bp_cache, edge; kwargs...), edges) + return default_message(scalartype(bp_cache), linkinds(bp_cache, edge)) end function Base.copy(bp_cache::BeliefPropagationCache) return BeliefPropagationCache( copy(partitioned_tensornetwork(bp_cache)), copy(messages(bp_cache)), - default_message(bp_cache), ) end @@ -129,7 +74,7 @@ end function set_messages(cache::BeliefPropagationCache, messages) return BeliefPropagationCache( - partitioned_tensornetwork(cache), messages, default_message(cache) + partitioned_tensornetwork(cache), messages ) end @@ -143,135 +88,6 @@ function environment( return reduce(vcat, ms; init=ITensor[]) end -function environment( - bp_cache::BeliefPropagationCache, partition_vertex::PartitionVertex; kwargs... -) - return environment(bp_cache, [partition_vertex]; kwargs...) -end - -function environment(bp_cache::BeliefPropagationCache, verts::Vector) - partition_verts = partitionvertices(bp_cache, verts) - messages = environment(bp_cache, partition_verts) - central_tensors = factors(bp_cache, setdiff(vertices(bp_cache, partition_verts), verts)) - return vcat(messages, central_tensors) -end - -function factors(bp_cache::BeliefPropagationCache, verts::Vector) - return ITensor[tensornetwork(bp_cache)[v] for v in verts] -end - -function factor(bp_cache::BeliefPropagationCache, vertex::PartitionVertex) - return factors(bp_cache, vertices(bp_cache, vertex)) -end - -""" -Compute message tensor as product of incoming mts and local state -""" -function update_message( - bp_cache::BeliefPropagationCache, - edge::PartitionEdge; - message_update=default_message_update, - message_update_kwargs=(;), -) - vertex = src(edge) - messages = environment(bp_cache, vertex; ignore_edges=PartitionEdge[reverse(edge)]) - state = factor(bp_cache, vertex) - - return message_update(ITensor[messages; state]; message_update_kwargs...) -end - -""" -Do a sequential update of the message tensors on `edges` -""" -function update( - bp_cache::BeliefPropagationCache, - edges::Vector{<:PartitionEdge}; - (update_diff!)=nothing, - kwargs..., -) - bp_cache_updated = copy(bp_cache) - mts = messages(bp_cache_updated) - for e in edges - set!(mts, e, update_message(bp_cache_updated, e; kwargs...)) - if !isnothing(update_diff!) - update_diff![] += message_diff(message(bp_cache, e), mts[e]) - end - end - return bp_cache_updated -end - -""" -Update the message tensor on a single edge -""" -function update(bp_cache::BeliefPropagationCache, edge::PartitionEdge; kwargs...) - return update(bp_cache, [edge]; kwargs...) -end - -""" -Do parallel updates between groups of edges of all message tensors -Currently we send the full message tensor data struct to update for each edge_group. But really we only need the -mts relevant to that group. -""" -function update( - bp_cache::BeliefPropagationCache, - edge_groups::Vector{<:Vector{<:PartitionEdge}}; - kwargs..., -) - new_mts = copy(messages(bp_cache)) - for edges in edge_groups - bp_cache_t = update(bp_cache, edges; kwargs...) - for e in edges - new_mts[e] = message(bp_cache_t, e) - end - end - return set_messages(bp_cache, new_mts) -end - -""" -More generic interface for update, with default params -""" -function update( - bp_cache::BeliefPropagationCache; - edges=default_edge_sequence(bp_cache), - maxiter=default_bp_maxiter(bp_cache), - tol=nothing, - verbose=false, - kwargs..., -) - compute_error = !isnothing(tol) - if isnothing(maxiter) - error("You need to specify a number of iterations for BP!") - end - for i in 1:maxiter - diff = compute_error ? Ref(0.0) : nothing - bp_cache = update(bp_cache, edges; (update_diff!)=diff, kwargs...) - if compute_error && (diff.x / length(edges)) <= tol - if verbose - println("BP converged to desired precision after $i iterations.") - end - break - end - end - return bp_cache -end - -""" -Update the tensornetwork inside the cache -""" -function update_factors(bp_cache::BeliefPropagationCache, factors) - bp_cache = copy(bp_cache) - tn = tensornetwork(bp_cache) - for vertex in eachindex(factors) - # TODO: Add a check that this preserves the graph structure. - setindex_preserve_graph!(tn, factors[vertex], vertex) - end - return bp_cache -end - -function update_factor(bp_cache, vertex, factor) - return update_factors(bp_cache, Dictionary([vertex], [factor])) -end - function region_scalar( bp_cache::BeliefPropagationCache, pv::PartitionVertex; @@ -291,23 +107,3 @@ function region_scalar( vcat(message(bp_cache, pe), message(bp_cache, reverse(pe))); contract_kwargs... )[] end - -function vertex_scalars( - bp_cache::BeliefPropagationCache, - pvs=partitionvertices(partitioned_tensornetwork(bp_cache)); - kwargs..., -) - return map(pv -> region_scalar(bp_cache, pv; kwargs...), pvs) -end - -function edge_scalars( - bp_cache::BeliefPropagationCache, - pes=partitionedges(partitioned_tensornetwork(bp_cache)); - kwargs..., -) - return map(pe -> region_scalar(bp_cache, pe; kwargs...), pes) -end - -function scalar_factors_quotient(bp_cache::BeliefPropagationCache) - return vertex_scalars(bp_cache), edge_scalars(bp_cache) -end