Skip to content

Commit

Permalink
Unify update function
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Jan 9, 2025
1 parent 99b4929 commit 19630aa
Showing 1 changed file with 138 additions and 94 deletions.
232 changes: 138 additions & 94 deletions src/caches/boundarympscache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,31 +232,28 @@ function partition_update(bmpsc::BoundaryMPSCache, partition::Int64)
return bmpsc
end

#Update all messages within a partition along the path from from v1 to v2
function partition_update(bmpsc::BoundaryMPSCache, v1, v2)
return update(
Algorithm("SimpleBP"),
bmpsc,
PartitionEdge.(a_star(ppg(bmpsc), v1, v2));
message_update_kwargs=(; normalize=false),
)
function partition_update_sequence(bmpsc::BoundaryMPSCache, v1, v2)
return PartitionEdge.(a_star(ppg(bmpsc), v1, v2))
end

#Update all message tensors within a partition pointing towards v
function partition_update(bmpsc::BoundaryMPSCache, v)
function partition_update_sequence(bmpsc::BoundaryMPSCache, v)
pv = planargraph_partition(bmpsc, v)
g = subgraph(unpartitioned_graph(ppg(bmpsc)), planargraph_vertices(bmpsc, pv))
return PartitionEdge.(post_order_dfs_edges(g, v))
end

#Update all messages within a partition along the path from from v1 to v2
function partition_update(bmpsc::BoundaryMPSCache, args...)
return update(
Algorithm("SimpleBP"),
bmpsc,
PartitionEdge.(post_order_dfs_edges(g, v));
partition_update_sequence(bmpsc, args...);
message_update_kwargs=(; normalize=false),
)
end

#Move the orthogonality centre one step on an interpartition from the message tensor on pe1 to that on pe2
function gauge_step(
alg::Algorithm"orthogonalize",
alg::Algorithm"orthogonal",
bmpsc::BoundaryMPSCache,
pe1::PartitionEdge,
pe2::PartitionEdge;
Expand All @@ -276,7 +273,7 @@ end

#Move the biorthogonality centre one step on an interpartition from the partition edge pe1 (and its reverse) to that on pe2
function gauge_step(
alg::Algorithm"biorthogonalize",
alg::Algorithm"biorthogonal",
bmpsc::BoundaryMPSCache,
pe1::PartitionEdge,
pe2::PartitionEdge,
Expand Down Expand Up @@ -323,114 +320,161 @@ function gauge_walk(alg::Algorithm, bmpsc::BoundaryMPSCache, seq::Vector; kwargs
return bmpsc
end

function gauge(alg::Algorithm, bmpsc::BoundaryMPSCache, args...; kwargs...)
return gauge_walk(alg, bmpsc, mps_gauge_update_sequence(bmpsc, args...); kwargs...)
end

#Move the orthogonality centre on an interpartition to the message tensor on pe or between two pes
function ITensorMPS.orthogonalize(bmpsc::BoundaryMPSCache, args...; kwargs...)
return gauge_walk(
Algorithm("orthogonalize"), bmpsc, mps_gauge_update_sequence(bmpsc, args...); kwargs...
)
return gauge(Algorithm("orthogonal"), bmpsc, args...; kwargs...)
end

#Move the biorthogonality centre on an interpartition to the message tensor or between two pes
function biorthogonalize(bmpsc::BoundaryMPSCache, args...; kwargs...)
return gauge_walk(
Algorithm("biorthogonalize"),
bmpsc,
mps_gauge_update_sequence(bmpsc, args...);
kwargs...,
)
return gauge(Algorithm("biorthogonal"), bmpsc, args...; kwargs...)
end

#Update all the message tensors on an interpartition via an orthogonal fitting procedure
#TODO: Unify this to one function and make two-site possible
function update(
function default_inserter(
alg::Algorithm"orthogonal",
bmpsc::BoundaryMPSCache,
pe::PartitionEdge,
me::Vector{ITensor},
)
return set_message(bmpsc, reverse(pe), dag.(me))
end

function default_inserter(
alg::Algorithm"biorthogonal",
bmpsc::BoundaryMPSCache,
pe::PartitionEdge,
me::Vector{ITensor},
)
p_above, p_below = partitionedge_above(bmpsc, pe), partitionedge_below(bmpsc, pe)
me = only(me)
me_prev = only(message(bmpsc, pe))
if !isnothing(p_above)
me = replaceind(
me,
commonind(me, only(message(bmpsc, reverse(p_above)))),
commonind(me_prev, only(message(bmpsc, p_above))),
)
end
if !isnothing(p_below)
me = replaceind(
me,
commonind(me, only(message(bmpsc, reverse(p_below)))),
commonind(me_prev, only(message(bmpsc, p_below))),
)
end
return set_message(bmpsc, pe, ITensor[me])
end

function default_updater(
alg::Algorithm"orthogonal", bmpsc::BoundaryMPSCache, prev_pe, update_pe, prev_v, cur_v
)
bmpsc = if !isnothing(prev_pe)
gauge(alg, bmpsc, reverse(prev_pe), reverse(update_pe))
else
gauge(alg, bmpsc, reverse(update_pe))
end
bmpsc = if !isnothing(prev_v)
partition_update(bmpsc, prev_v, cur_v)
else
partition_update(bmpsc, cur_v)
end
return bmpsc
end

function default_updater(
alg::Algorithm"biorthogonal", bmpsc::BoundaryMPSCache, prev_pe, update_pe, prev_v, cur_v
)
bmpsc = if !isnothing(prev_pe)
gauge(alg, bmpsc, prev_pe, update_pe)
else
gauge(alg, bmpsc, update_pe)
end
bmpsc = if !isnothing(prev_v)
partition_update(bmpsc, prev_v, cur_v)
else
partition_update(bmpsc, cur_v)
end
return bmpsc
end

function default_cache_prep_function(
alg::Algorithm"biorthogonal", bmpsc::BoundaryMPSCache, partitionpair
)
return bmpsc
end
function default_cache_prep_function(
alg::Algorithm"orthogonal", bmpsc::BoundaryMPSCache, partitionpair
)
return switch_messages(bmpsc, partitionpair)
end

default_niters(alg::Algorithm"orthogonal") = 25
default_niters(alg::Algorithm"biorthogonal") = 3
default_tolerance(alg::Algorithm"orthogonal") = 1e-10
default_tolerance(alg::Algorithm"biorthogonal") = nothing

function default_costfunction(
alg::Algorithm"orthogonal",
bmpsc::BoundaryMPSCache,
pe::PartitionEdge,
me::Vector{ITensor},
)
return region_scalar(bp_cache(bmpsc), src(pe)) / norm(only(me))
end

function default_costfunction(
alg::Algorithm"biorthogonal",
bmpsc::BoundaryMPSCache,
pe::PartitionEdge,
me::Vector{ITensor},
)
return region_scalar(bp_cache(bmpsc), src(pe)) /
dot(only(me), only(message(bmpsc, reverse(pe))))
end

#Update all the message tensors on an interpartition via a specified fitting procedure
#TODO: Make two-site possible
function update(
alg::Algorithm,
bmpsc::BoundaryMPSCache,
partitionpair::Pair;
niters::Int64=25,
tolerance=1e-10,
inserter=default_inserter,
costfunction=default_costfunction,
updater=default_updater,
cache_prep_function=default_cache_prep_function,
niters::Int64=default_niters(alg),
tolerance=default_tolerance(alg),
normalize=true,
)
bmpsc = switch_messages(bmpsc, partitionpair)
bmpsc = cache_prep_function(alg, bmpsc, partitionpair)
pes = planargraph_partitionpair_partitionedges(bmpsc, partitionpair)
update_seq = vcat(pes, reverse(pes)[2:length(pes)])
prev_v, prev_pe = nothing, nothing
prev_costfunction = 0
prev_cf = 0
for i in 1:niters
costfunction = 0
cf = 0
for update_pe in update_seq
cur_v = parent(src(update_pe))
bmpsc = if !isnothing(prev_pe)
orthogonalize(bmpsc, reverse(prev_pe), reverse(update_pe))
else
orthogonalize(bmpsc, reverse(update_pe))
end
bmpsc = if !isnothing(prev_v)
partition_update(bmpsc, prev_v, cur_v)
else
partition_update(bmpsc, cur_v)
end
bmpsc = updater(alg, bmpsc, prev_pe, update_pe, prev_v, cur_v)
me = updated_message(
bmpsc, update_pe; message_update=ms -> default_message_update(ms; normalize)
)
costfunction += region_scalar(bp_cache(bmpsc), src(update_pe)) / norm(me)
bmpsc = set_message(bmpsc, reverse(update_pe), dag.(me))
cf += costfunction(alg, bmpsc, update_pe, me)
bmpsc = inserter(alg, bmpsc, update_pe, me)
prev_v, prev_pe = cur_v, update_pe
end
epsilon = abs(costfunction - prev_costfunction) / length(update_seq)
epsilon = abs(cf - prev_cf) / length(update_seq)
if !isnothing(tolerance) && epsilon < tolerance
return switch_messages(bmpsc, partitionpair)
return cache_prep_function(alg, bmpsc, partitionpair)
else
prev_costfunction = costfunction
prev_cf = cf
end
end
return switch_messages(bmpsc, partitionpair)
end

#Update all the message tensors on an interpartition via a biorthogonal fitting procedure
function update(
alg::Algorithm"biorthogonal", bmpsc::BoundaryMPSCache, partitionpair::Pair; normalize=true
)
prev_v, prev_pe = nothing, nothing
for update_pe in planargraph_partitionpair_partitionedges(bmpsc, partitionpair)
cur_v = parent(src(update_pe))
bmpsc = if !isnothing(prev_pe)
biorthogonalize(bmpsc, prev_pe, update_pe)
else
biorthogonalize(bmpsc, update_pe)
end
bmpsc = if !isnothing(prev_v)
partition_update(bmpsc, prev_v, cur_v)
else
partition_update(bmpsc, cur_v)
end

me_prev = only(message(bmpsc, update_pe))
me = only(
updated_message(
bmpsc, update_pe; message_update=ms -> default_message_update(ms; normalize)
),
)
p_above, p_below = partitionedge_above(bmpsc, update_pe),
partitionedge_below(bmpsc, update_pe)
if !isnothing(p_above)
me = replaceind(
me,
commonind(me, only(message(bmpsc, reverse(p_above)))),
commonind(me_prev, only(message(bmpsc, p_above))),
)
end
if !isnothing(p_below)
me = replaceind(
me,
commonind(me, only(message(bmpsc, reverse(p_below)))),
commonind(me_prev, only(message(bmpsc, p_below))),
)
end
bmpsc = set_message(bmpsc, update_pe, ITensor[me])
prev_v, prev_pe = cur_v, update_pe
end

return bmpsc
return cache_prep_function(alg, bmpsc, partitionpair)
end

#Assume all vertices live in the same partition for now
Expand Down

0 comments on commit 19630aa

Please sign in to comment.