Skip to content

Commit

Permalink
Generic update interface for boundary mps and simple bp
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Jan 8, 2025
1 parent b11c2f8 commit 99b4929
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 89 deletions.
83 changes: 60 additions & 23 deletions src/caches/abstractbeliefpropagationcache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,15 @@ function default_message(
)
return not_implemented()
end
default_message_update_alg(bpc::AbstractBeliefPropagationCache) = not_implemented()
Base.copy(bpc::AbstractBeliefPropagationCache) = not_implemented()
default_bp_maxiter(bpc::AbstractBeliefPropagationCache) = not_implemented()
default_edge_sequence(bpc::AbstractBeliefPropagationCache) = not_implemented()
default_bp_maxiter(alg::Algorithm, bpc::AbstractBeliefPropagationCache) = not_implemented()
function default_edge_sequence(alg::Algorithm, bpc::AbstractBeliefPropagationCache)
return not_implemented()
end
function default_message_update_kwargs(alg::Algorithm, bpc::AbstractBeliefPropagationCache)
return not_implemented()
end
function environment(bpc::AbstractBeliefPropagationCache, verts::Vector; kwargs...)
return not_implemented()
end
Expand All @@ -65,6 +71,22 @@ function region_scalar(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge; k
return not_implemented()
end

function default_edge_sequence(
bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc)
)
return default_edge_sequence(Algorithm(alg), bpc)
end
function default_bp_maxiter(
bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc)
)
return default_bp_maxiter(Algorithm(alg), bpc)
end
function default_message_update_kwargs(
bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc)
)
return default_message_update_kwargs(Algorithm(alg), bpc)
end

function tensornetwork(bpc::AbstractBeliefPropagationCache)
return unpartitioned_graph(partitioned_tensornetwork(bpc))
end
Expand Down Expand Up @@ -165,44 +187,49 @@ end
"""
Compute message tensor as product of incoming mts and local state
"""
function update_message(
function updated_message(
bpc::AbstractBeliefPropagationCache,
edge::PartitionEdge;
message_update=default_message_update,
message_update_kwargs=(;),
)
vertex = src(edge)
messages = incoming_messages(bpc, vertex; ignore_edges=PartitionEdge[reverse(edge)])
incoming_ms = incoming_messages(bpc, vertex; ignore_edges=PartitionEdge[reverse(edge)])
state = factor(bpc, vertex)

return message_update(ITensor[messages; state]; message_update_kwargs...)
return message_update(ITensor[incoming_ms; state]; message_update_kwargs...)
end

function update(
alg::Algorithm"SimpleBP",
bpc::AbstractBeliefPropagationCache,
edge::PartitionEdge;
kwargs...,
)
new_m = updated_message(bpc, edge; kwargs...)
bpc = set_message(bpc, edge, new_m)
return bpc
end

"""
Do a sequential update of the message tensors on `edges`
"""
function update(
alg::Algorithm,
bpc::AbstractBeliefPropagationCache,
edges::Vector{<:PartitionEdge};
edges::Vector;
(update_diff!)=nothing,
kwargs...,
)
bpc_updated = copy(bpc)
mts = messages(bpc_updated)
bpc = copy(bpc)
for e in edges
set!(mts, e, update_message(bpc_updated, e; kwargs...))
prev_message = !isnothing(update_diff!) ? message(bpc, e) : nothing
bpc = update(alg, bpc, e; kwargs...)
if !isnothing(update_diff!)
update_diff![] += message_diff(message(bpc, e), mts[e])
update_diff![] += message_diff(message(bpc, e), prev_message)
end
end
return bpc_updated
end

"""
Update the message tensor on a single edge
"""
function update(bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs...)
return update(bpc, [edge]; kwargs...)
return bpc
end

"""
Expand All @@ -211,13 +238,14 @@ Currently we send the full message tensor data struct to update for each edge_gr
mts relevant to that group.
"""
function update(
alg::Algorithm,
bpc::AbstractBeliefPropagationCache,
edge_groups::Vector{<:Vector{<:PartitionEdge}};
kwargs...,
)
new_mts = copy(messages(bpc))
for edges in edge_groups
bpc_t = update(bpc, edges; kwargs...)
bpc_t = update(alg, bpc, edges; kwargs...)
for e in edges
new_mts[e] = message(bpc_t, e)
end
Expand All @@ -229,20 +257,21 @@ end
More generic interface for update, with default params
"""
function update(
alg::Algorithm,
bpc::AbstractBeliefPropagationCache;
edges=default_edge_sequence(bpc),
maxiter=default_bp_maxiter(bpc),
edges=default_edge_sequence(alg, bpc),
maxiter=default_bp_maxiter(alg, bpc),
message_update_kwargs=default_message_update_kwargs(alg, bpc),
tol=nothing,
verbose=false,
kwargs...,
)
compute_error = !isnothing(tol)
if isnothing(maxiter)
error("You need to specify a number of iterations for BP!")
end
for i in 1:maxiter
diff = compute_error ? Ref(0.0) : nothing
bpc = update(bpc, edges; (update_diff!)=diff, kwargs...)
bpc = update(alg, bpc, edges; (update_diff!)=diff, message_update_kwargs...)
if compute_error && (diff.x / length(edges)) <= tol
if verbose
println("BP converged to desired precision after $i iterations.")
Expand All @@ -252,3 +281,11 @@ function update(
end
return bpc
end

function update(
bpc::AbstractBeliefPropagationCache;
alg::String=default_message_update_alg(bpc),
kwargs...,
)
return update(Algorithm(alg), bpc; kwargs...)
end
11 changes: 9 additions & 2 deletions src/caches/beliefpropagationcache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,19 @@ function Base.copy(bp_cache::BeliefPropagationCache)
)
end

function default_bp_maxiter(bp_cache::BeliefPropagationCache)
default_message_update_alg(bp_cache::BeliefPropagationCache) = "SimpleBP"

function default_bp_maxiter(alg::Algorithm"SimpleBP", bp_cache::BeliefPropagationCache)
return default_bp_maxiter(partitioned_graph(bp_cache))
end
function default_edge_sequence(bp_cache::BeliefPropagationCache)
function default_edge_sequence(alg::Algorithm"SimpleBP", bp_cache::BeliefPropagationCache)
return default_edge_sequence(partitioned_tensornetwork(bp_cache))
end
function default_message_update_kwargs(
alg::Algorithm"SimpleBP", bpc::AbstractBeliefPropagationCache
)
return (;)
end

function set_messages(cache::BeliefPropagationCache, messages)
return BeliefPropagationCache(partitioned_tensornetwork(cache), messages)
Expand Down
90 changes: 33 additions & 57 deletions src/caches/boundarympscache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,19 @@ function partitioned_tensornetwork(bmpsc::BoundaryMPSCache)
end
messages(bmpsc::BoundaryMPSCache) = messages(bp_cache(bmpsc))

default_edge_sequence(bmpsc::BoundaryMPSCache) = pair.(default_edge_sequence(ppg(bmpsc)))
default_message_update_alg(bmpsc::BoundaryMPSCache) = "orthogonal"

function default_bp_maxiter(alg::Algorithm"orthogonal", bmpsc::BoundaryMPSCache)
return default_bp_maxiter(partitioned_graph(ppg(bmpsc)))
end
default_bp_maxiter(alg::Algorithm"biorthogonal", bmpsc::BoundaryMPSCache) = 50
function default_mps_fit_kwargs(alg::Algorithm"orthogonal", bmpsc::BoundaryMPSCache)
function default_edge_sequence(alg::Algorithm, bmpsc::BoundaryMPSCache)
return pair.(default_edge_sequence(ppg(bmpsc)))
end
function default_message_update_kwargs(alg::Algorithm"orthogonal", bmpsc::BoundaryMPSCache)
return (; niters=50, tolerance=1e-10)
end
default_mps_fit_kwargs(alg::Algorithm"biorthogonal", bmpsc::BoundaryMPSCache) = (;)
default_message_update_kwargs(alg::Algorithm"biorthogonal", bmpsc::BoundaryMPSCache) = (;)

function Base.copy(bmpsc::BoundaryMPSCache)
return BoundaryMPSCache(
Expand Down Expand Up @@ -221,23 +225,33 @@ function switch_messages(bmpsc::BoundaryMPSCache, partitionpair::Pair)
end

#Update all messages tensors within a partition
function partition_update(bmpsc::BoundaryMPSCache, partition::Int64; kwargs...)
function partition_update(bmpsc::BoundaryMPSCache, partition::Int64)
vs = sort(planargraph_vertices(bmpsc, partition))
bmpsc = partition_update(bmpsc, first(vs); kwargs...)
bmpsc = partition_update(bmpsc, last(vs); kwargs...)
bmpsc = partition_update(bmpsc, first(vs))
bmpsc = partition_update(bmpsc, last(vs))
return bmpsc
end

#Update all messages within a partition along the path from from v1 to v2
function partition_update(bmpsc::BoundaryMPSCache, v1, v2; kwargs...)
return update(bmpsc, PartitionEdge.(a_star(ppg(bmpsc), v1, v2)); kwargs...)
function partition_update(bmpsc::BoundaryMPSCache, v1, v2)
return update(
Algorithm("SimpleBP"),
bmpsc,
PartitionEdge.(a_star(ppg(bmpsc), v1, v2));
message_update_kwargs=(; normalize=false),
)
end

#Update all message tensors within a partition pointing towards v
function partition_update(bmpsc::BoundaryMPSCache, v; kwargs...)
function partition_update(bmpsc::BoundaryMPSCache, v)
pv = planargraph_partition(bmpsc, v)
g = subgraph(unpartitioned_graph(ppg(bmpsc)), planargraph_vertices(bmpsc, pv))
return update(bmpsc, PartitionEdge.(post_order_dfs_edges(g, v)); kwargs...)
return update(
Algorithm("SimpleBP"),
bmpsc,
PartitionEdge.(post_order_dfs_edges(g, v));
message_update_kwargs=(; normalize=false),
)
end

#Move the orthogonality centre one step on an interpartition from the message tensor on pe1 to that on pe2
Expand Down Expand Up @@ -327,7 +341,8 @@ function biorthogonalize(bmpsc::BoundaryMPSCache, args...; kwargs...)
end

#Update all the message tensors on an interpartition via an orthogonal fitting procedure
function mps_update(
#TODO: Unify this to one function and make two-site possible
function update(
alg::Algorithm"orthogonal",
bmpsc::BoundaryMPSCache,
partitionpair::Pair;
Expand All @@ -350,18 +365,11 @@ function mps_update(
orthogonalize(bmpsc, reverse(update_pe))
end
bmpsc = if !isnothing(prev_v)
partition_update(
bmpsc,
prev_v,
cur_v;
message_update=ms -> default_message_update(ms; normalize=false),
)
partition_update(bmpsc, prev_v, cur_v)
else
partition_update(
bmpsc, cur_v; message_update=ms -> default_message_update(ms; normalize=false)
)
partition_update(bmpsc, cur_v)
end
me = update_message(
me = updated_message(
bmpsc, update_pe; message_update=ms -> default_message_update(ms; normalize)
)
costfunction += region_scalar(bp_cache(bmpsc), src(update_pe)) / norm(me)
Expand All @@ -379,7 +387,7 @@ function mps_update(
end

#Update all the message tensors on an interpartition via a biorthogonal fitting procedure
function mps_update(
function update(
alg::Algorithm"biorthogonal", bmpsc::BoundaryMPSCache, partitionpair::Pair; normalize=true
)
prev_v, prev_pe = nothing, nothing
Expand All @@ -391,21 +399,14 @@ function mps_update(
biorthogonalize(bmpsc, update_pe)
end
bmpsc = if !isnothing(prev_v)
partition_update(
bmpsc,
prev_v,
cur_v;
message_update=ms -> default_message_update(ms; normalize=false),
)
partition_update(bmpsc, prev_v, cur_v)
else
partition_update(
bmpsc, cur_v; message_update=ms -> default_message_update(ms; normalize=false)
)
partition_update(bmpsc, cur_v)
end

me_prev = only(message(bmpsc, update_pe))
me = only(
update_message(
updated_message(
bmpsc, update_pe; message_update=ms -> default_message_update(ms; normalize)
),
)
Expand All @@ -432,31 +433,6 @@ function mps_update(
return bmpsc
end

"""
More generic interface for update, with default params
"""
function update(
alg::Algorithm,
bmpsc::BoundaryMPSCache;
partitionpairs=default_edge_sequence(bmpsc),
maxiter=default_bp_maxiter(alg, bmpsc),
mps_fit_kwargs=default_mps_fit_kwargs(alg, bmpsc),
)
if isnothing(maxiter)
error("You need to specify a number of iterations for Boundary MPS!")
end
for i in 1:maxiter
for partitionpair in partitionpairs
bmpsc = mps_update(alg, bmpsc, partitionpair; mps_fit_kwargs...)
end
end
return bmpsc
end

function update(bmpsc::BoundaryMPSCache; alg::String="orthogonal", kwargs...)
return update(Algorithm(alg), bmpsc; kwargs...)
end

#Assume all vertices live in the same partition for now
function ITensorNetworks.environment(bmpsc::BoundaryMPSCache, verts::Vector; kwargs...)
vs = parent.((partitionvertices(bp_cache(bmpsc), verts)))
Expand Down
4 changes: 2 additions & 2 deletions test/test_belief_propagation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ using ITensorNetworks:
tensornetwork,
update,
update_factor,
update_message,
updated_message,
message_diff
using ITensors: ITensors, ITensor, combiner, dag, inds, inner, op, prime, random_itensor
using ITensorNetworks.ModelNetworks: ModelNetworks
Expand Down Expand Up @@ -51,7 +51,7 @@ using Test: @test, @testset
bpc = update(bpc; maxiter=25, tol=eps(real(elt)))
#Test messages are converged
for pe in partitionedges(partitioned_tensornetwork(bpc))
@test message_diff(update_message(bpc, pe), message(bpc, pe)) < 10 * eps(real(elt))
@test message_diff(updated_message(bpc, pe), message(bpc, pe)) < 10 * eps(real(elt))
@test eltype(only(message(bpc, pe))) == elt
end
#Test updating the underlying tensornetwork in the cache
Expand Down
Loading

0 comments on commit 99b4929

Please sign in to comment.