Skip to content

Commit

Permalink
add lazy 2-norm BP, various tree funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Oct 20, 2023
1 parent 7bd8848 commit 11d1c3a
Show file tree
Hide file tree
Showing 9 changed files with 971 additions and 315 deletions.
17 changes: 10 additions & 7 deletions quimb/experimental/belief_propagation/bp_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
import quimb.tensor as qtn


def prod(xs):
"""Product of all elements in ``xs``."""
return functools.reduce(operator.mul, xs)


class RollingDiffMean:
"""Tracker for the absolute rolling mean of diffs between values, to
assess effective convergence of BP above actual message tolerance.
Expand Down Expand Up @@ -59,11 +64,14 @@ def run(self, max_iterations=1000, tol=5e-6, progbar=False):
rdm = RollingDiffMean()
self.converged = False
while not self.converged and it < max_iterations:
# can only converge if tol > 0.0
# perform a single iteration of BP
# we supply tol here for use with local convergence
nconv, ncheck, max_mdiff = self.iterate(tol=tol)
it += 1

# check rolling mean convergence
rdm.update(max_mdiff)
self.converged = (max_mdiff < tol) or (rdm.absmeandiff() < tol)
it += 1

if pbar is not None:
pbar.set_description(
Expand All @@ -85,11 +93,6 @@ def run(self, max_iterations=1000, tol=5e-6, progbar=False):
)


def prod(xs):
"""Product of all elements in ``xs``."""
return functools.reduce(operator.mul, xs)


def initialize_hyper_messages(tn, fill_fn=None, smudge_factor=1e-12):
"""Initialize messages for belief propagation, this is equivalent to doing
a single round of belief propagation with uniform messages.
Expand Down
4 changes: 2 additions & 2 deletions quimb/experimental/belief_propagation/d2bp.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,11 +449,11 @@ def compress_d2bp(
Computed automatically if not specified.
optimize : str or PathOptimizer, optional
The path optimizer to use when contracting the messages.
damping : float, optional
The damping parameter to use, defaults to no damping.
local_convergence : bool, optional
Whether to allow messages to locally converge - i.e. if all their
input messages have converged then stop updating them.
damping : float, optional
The damping parameter to use, defaults to no damping.
inplace : bool, optional
Whether to perform the compression inplace.
progbar : bool, optional
Expand Down
Loading

0 comments on commit 11d1c3a

Please sign in to comment.