-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
247 additions
and
214 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,236 @@ | ||
using Graphs: IsDirected | ||
using SplitApplyCombine: group | ||
using LinearAlgebra: diag, dot | ||
using ITensors: dir | ||
using ITensorMPS: ITensorMPS | ||
using NamedGraphs.PartitionedGraphs: | ||
PartitionedGraphs, | ||
PartitionedGraph, | ||
PartitionVertex, | ||
boundary_partitionedges, | ||
partitionvertices, | ||
partitionedges, | ||
unpartitioned_graph | ||
using SimpleTraits: SimpleTraits, Not, @traitfn | ||
using NDTensors: NDTensors | ||
|
||
abstract type AbstractBeliefPropagationCache end | ||
|
||
function default_message_update(contract_list::Vector{ITensor}; normalize=true, kwargs...) | ||
sequence = optimal_contraction_sequence(contract_list) | ||
updated_messages = contract(contract_list; sequence, kwargs...) | ||
message_norm = norm(updated_messages) | ||
if normalize && !iszero(message_norm) | ||
updated_messages /= message_norm | ||
end | ||
return ITensor[updated_messages] | ||
end | ||
|
||
#TODO: Take `dot` without precontracting the messages to allow scaling to more complex messages | ||
function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor}) | ||
lhs, rhs = contract(message_a), contract(message_b) | ||
f = abs2(dot(lhs / norm(lhs), rhs / norm(rhs))) | ||
return 1 - f | ||
end | ||
|
||
default_message(elt, inds_e) = ITensor[denseblocks(delta(elt, i)) for i in inds_e] | ||
default_messages(ptn::PartitionedGraph) = Dictionary() | ||
@traitfn default_bp_maxiter(g::::(!IsDirected)) = is_tree(g) ? 1 : nothing | ||
@traitfn function default_bp_maxiter(g::::IsDirected) | ||
return default_bp_maxiter(undirected_graph(underlying_graph(g))) | ||
end | ||
default_partitioned_vertices(ψ::AbstractITensorNetwork) = group(v -> v, vertices(ψ)) | ||
function default_partitioned_vertices(f::AbstractFormNetwork) | ||
return group(v -> original_state_vertex(f, v), vertices(f)) | ||
end | ||
default_cache_update_kwargs(cache) = (; maxiter=25, tol=1e-8) | ||
|
||
partitioned_tensornetwork(bpc::AbstractBeliefPropagationCache) = not_implemented() | ||
messages(bpc::AbstractBeliefPropagationCache) = not_implemented() | ||
default_message(bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs...) = not_implemented() | ||
Base.copy(bpc::AbstractBeliefPropagationCache) = not_implemented() | ||
default_bp_maxiter(bpc::AbstractBeliefPropagationCache) = not_implemented() | ||
default_edge_sequence(bpc::AbstractBeliefPropagationCache) = not_implemented() | ||
environment(bpc::AbstractBeliefPropagationCache, partition_vertices::Vector{<:PartitionVertex}; kwargs...) = not_implemented() | ||
region_scalar(bpc::AbstractBeliefPropagationCache, pv::PartitionVertex; kwargs...) = not_implemented() | ||
region_scalar(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge; kwargs...) = not_implemented() | ||
|
||
function factors(bpc::AbstractBeliefPropagationCache, verts::Vector) | ||
return ITensor[tensornetwork(bpc)[v] for v in verts] | ||
end | ||
|
||
function factor(bpc::AbstractBeliefPropagationCache, vertex::PartitionVertex) | ||
return factors(bpc, vertices(bpc, vertex)) | ||
end | ||
|
||
function vertex_scalars( | ||
bpc::AbstractBeliefPropagationCache, | ||
pvs=partitionvertices(partitioned_tensornetwork(bpc)); | ||
kwargs..., | ||
) | ||
return map(pv -> region_scalar(bpc, pv; kwargs...), pvs) | ||
end | ||
|
||
function edge_scalars( | ||
bpc::AbstractBeliefPropagationCache, | ||
pes=partitionedges(partitioned_tensornetwork(bpc)); | ||
kwargs..., | ||
) | ||
return map(pe -> region_scalar(bpc, pe; kwargs...), pes) | ||
end | ||
|
||
function scalar_factors_quotient(bpc::AbstractBeliefPropagationCache) | ||
return vertex_scalars(bpc), edge_scalars(bpc) | ||
end | ||
|
||
function environment( | ||
bpc::AbstractBeliefPropagationCache, partition_vertex::PartitionVertex; kwargs... | ||
) | ||
return environment(bpc, [partition_vertex]; kwargs...) | ||
end | ||
|
||
function environment(bpc::AbstractBeliefPropagationCache, verts::Vector) | ||
partition_verts = partitionvertices(bpc, verts) | ||
messages = environment(bpc, partition_verts) | ||
central_tensors = factors(bpc, setdiff(vertices(bpc, partition_verts), verts)) | ||
return vcat(messages, central_tensors) | ||
end | ||
|
||
function tensornetwork(bpc::AbstractBeliefPropagationCache) | ||
return unpartitioned_graph(partitioned_tensornetwork(bpc)) | ||
end | ||
|
||
#Forward from partitioned graph | ||
for f in [ | ||
:(PartitionedGraphs.partitioned_graph), | ||
:(PartitionedGraphs.partitionedge), | ||
:(PartitionedGraphs.partitionvertices), | ||
:(PartitionedGraphs.vertices), | ||
:(PartitionedGraphs.boundary_partitionedges), | ||
:(ITensorMPS.linkinds), | ||
] | ||
@eval begin | ||
function $f(bpc::AbstractBeliefPropagationCache, args...; kwargs...) | ||
return $f(partitioned_tensornetwork(bpc), args...; kwargs...) | ||
end | ||
end | ||
end | ||
|
||
NDTensors.scalartype(bpc::AbstractBeliefPropagationCache) = scalartype(tensornetwork(bpc)) | ||
|
||
""" | ||
Update the tensornetwork inside the cache | ||
""" | ||
function update_factors(bpc::AbstractBeliefPropagationCache, factors) | ||
bpc = copy(bpc) | ||
tn = tensornetwork(bpc) | ||
for vertex in eachindex(factors) | ||
# TODO: Add a check that this preserves the graph structure. | ||
setindex_preserve_graph!(tn, factors[vertex], vertex) | ||
end | ||
return bpc | ||
end | ||
|
||
function update_factor(bpc, vertex, factor) | ||
return update_factors(bpc, Dictionary([vertex], [factor])) | ||
end | ||
|
||
function message(bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs...) | ||
mts = messages(bpc) | ||
return get(() -> default_message(bpc, edge; kwargs...), mts, edge) | ||
end | ||
function messages(bpc::AbstractBeliefPropagationCache, edges; kwargs...) | ||
return map(edge -> message(bpc, edge; kwargs...), edges) | ||
end | ||
|
||
""" | ||
Compute message tensor as product of incoming mts and local state | ||
""" | ||
function update_message( | ||
bpc::AbstractBeliefPropagationCache, | ||
edge::PartitionEdge; | ||
message_update=default_message_update, | ||
message_update_kwargs=(;), | ||
) | ||
vertex = src(edge) | ||
messages = environment(bpc, vertex; ignore_edges=PartitionEdge[reverse(edge)]) | ||
state = factor(bpc, vertex) | ||
|
||
return message_update(ITensor[messages; state]; message_update_kwargs...) | ||
end | ||
|
||
""" | ||
Do a sequential update of the message tensors on `edges` | ||
""" | ||
function update( | ||
bpc::AbstractBeliefPropagationCache, | ||
edges::Vector{<:PartitionEdge}; | ||
(update_diff!)=nothing, | ||
kwargs..., | ||
) | ||
bpc_updated = copy(bpc) | ||
mts = messages(bpc_updated) | ||
for e in edges | ||
set!(mts, e, update_message(bpc, e; kwargs...)) | ||
if !isnothing(update_diff!) | ||
update_diff![] += message_diff(message(bpc, e), mts[e]) | ||
end | ||
end | ||
return bpc_updated | ||
end | ||
|
||
|
||
""" | ||
Update the message tensor on a single edge | ||
""" | ||
function update(bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs...) | ||
return update(bpc, [edge]; kwargs...) | ||
end | ||
|
||
""" | ||
Do parallel updates between groups of edges of all message tensors | ||
Currently we send the full message tensor data struct to update for each edge_group. But really we only need the | ||
mts relevant to that group. | ||
""" | ||
function update( | ||
bpc::AbstractBeliefPropagationCache, | ||
edge_groups::Vector{<:Vector{<:PartitionEdge}}; | ||
kwargs..., | ||
) | ||
new_mts = copy(messages(bpc)) | ||
for edges in edge_groups | ||
bpc_t = update(bpc, edges; kwargs...) | ||
for e in edges | ||
new_mts[e] = message(bpc_t, e) | ||
end | ||
end | ||
return set_messages(bpc, new_mts) | ||
end | ||
|
||
""" | ||
More generic interface for update, with default params | ||
""" | ||
function update( | ||
bpc::AbstractBeliefPropagationCache; | ||
edges=default_edge_sequence(bpc), | ||
maxiter=default_bp_maxiter(bpc), | ||
tol=nothing, | ||
verbose=false, | ||
kwargs..., | ||
) | ||
compute_error = !isnothing(tol) | ||
if isnothing(maxiter) | ||
error("You need to specify a number of iterations for BP!") | ||
end | ||
for i in 1:maxiter | ||
diff = compute_error ? Ref(0.0) : nothing | ||
bpc = update(bpc, edges; (update_diff!)=diff, kwargs...) | ||
if compute_error && (diff.x / length(edges)) <= tol | ||
if verbose | ||
println("BP converged to desired precision after $i iterations.") | ||
end | ||
break | ||
end | ||
end | ||
return bpc | ||
end |
Oops, something went wrong.