add damping to hyper 1-BP algorithms
jcmgray committed Oct 19, 2023
1 parent fc072ae commit 7bd8848
Showing 4 changed files with 32 additions and 19 deletions.
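For context: damping in belief propagation conventionally blends each freshly computed message with its previous value, which helps suppress the oscillations plain BP can show on loopy hypergraphs. The following is a minimal sketch of that standard damped update, not the exact code added in this commit; blend_messages is a hypothetical helper and the normalization convention used inside HD1BP/HV1BP may differ.

import numpy as np

def blend_messages(m_old, m_new, damping=0.0):
    # standard damped BP update: damping=0.0 keeps only the new message,
    # larger values retain more of the previous one
    m = (1 - damping) * m_new + damping * m_old
    # one common convention is to renormalize the blended message
    return m / np.sum(m)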
24 changes: 15 additions & 9 deletions quimb/experimental/belief_propagation/hd1bp.py
@@ -88,7 +88,10 @@ def compute_all_tensor_messages_shortcuts(x, ms, ndim):


def compute_all_tensor_messages_prod(
x, ms, backend=None, smudge_factor=1e-12,
x,
ms,
backend=None,
smudge_factor=1e-12,
):
"""Given set of messages ``ms`` incident to tensor with data ``x``, compute
the corresponding next output messages, using the 'prod' implementation.
@@ -236,7 +239,8 @@ def _normalize_and_insert(k, m, max_dm):
for tid, t in tn.tensor_map.items():
inds = t.inds
ms = compute_all_tensor_messages_tree(
t.data, [messages[ix, tid] for ix in inds],
t.data,
[messages[ix, tid] for ix in inds],
)
for ix, m in zip(inds, ms):
max_dm = _normalize_and_insert((tid, ix), m, max_dm)
@@ -298,15 +302,13 @@ def contract(self, strip_exponent=False):
)




def contract_hd1bp(
tn,
messages=None,
max_iterations=1000,
tol=5e-6,
smudge_factor=1e-12,
damping=0.0,
smudge_factor=1e-12,
strip_exponent=False,
progbar=False,
):
@@ -323,11 +325,11 @@ def contract_hd1bp(
The maximum number of iterations to perform.
tol : float, optional
The convergence tolerance for messages.
damping : float, optional
The damping factor to use, 0.0 means no damping.
smudge_factor : float, optional
A small number to add to the denominator of messages to avoid division
by zero. Note when this happens the numerator will also be zero.
damping : float, optional
The damping factor to use, 0.0 means no damping.
strip_exponent : bool, optional
Whether to strip the exponent from the final result. If ``True``
then the returned result is ``(mantissa, exponent)``.
@@ -352,12 +354,12 @@ def contract_hd1bp(
return bp.contract(strip_exponent=strip_exponent)



def run_belief_propagation_hd1bp(
tn,
messages=None,
max_iterations=1000,
tol=5e-6,
damping=0.0,
smudge_factor=1e-12,
progbar=False,
):
@@ -389,7 +391,9 @@
converged : bool
Whether the algorithm converged.
"""
bp = HD1BP(tn, messages=messages, smudge_factor=smudge_factor)
bp = HD1BP(
tn, messages=messages, damping=damping, smudge_factor=smudge_factor
)
bp.run(max_iterations=max_iterations, tol=tol, progbar=progbar)
return bp.messages, bp.converged

@@ -400,6 +404,7 @@ def sample_hd1bp(
output_inds=None,
max_iterations=1000,
tol=1e-2,
damping=0.0,
smudge_factor=1e-12,
bias=False,
seed=None,
@@ -485,6 +490,7 @@ def sample_hd1bp(
messages,
max_iterations=max_iterations,
tol=tol,
damping=damping,
smudge_factor=smudge_factor,
progbar=True,
)
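A minimal usage sketch of the updated contract_hd1bp signature, assuming the import path matches the file above and borrowing the random k-SAT setup from the tests further down:

import quimb.tensor as qtn
from quimb.experimental.belief_propagation.hd1bp import contract_hd1bp

# random 3-SAT counting instance, as in the tests below
htn = qtn.HTN_random_ksat(3, 50, alpha=2.0, seed=42, mode="dense")

# damping=0.0 reproduces the previous behaviour; a small positive value
# blends each new message with the old one before updating
num_solutions = contract_hd1bp(htn, damping=0.1)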
13 changes: 9 additions & 4 deletions quimb/experimental/belief_propagation/hv1bp.py
@@ -512,8 +512,7 @@ def iterate(self, **kwargs):
return None, None, max_dm

def get_messages(self):
"""Get messages in individual form from the batched stacks.
"""
"""Get messages in individual form from the batched stacks."""
return _extract_messages_from_inputs_batched(
self.batched_inputs_m,
self.batched_inputs_t,
@@ -587,6 +586,7 @@ def run_belief_propagation_hv1bp(
messages=None,
max_iterations=1000,
tol=5e-6,
damping=0.0,
smudge_factor=1e-12,
progbar=False,
):
@@ -604,6 +604,8 @@
The maximum number of iterations to run for.
tol : float, optional
The convergence tolerance.
damping : float, optional
The damping factor to use, 0.0 means no damping.
smudge_factor : float, optional
A small number to add to the denominator of messages to avoid division
by zero. Note when this happens the numerator will also be zero.
@@ -617,7 +619,9 @@
converged : bool
Whether the algorithm converged.
"""
bp = HV1BP(tn, messages=messages, smudge_factor=smudge_factor)
bp = HV1BP(
tn, messages=messages, damping=damping, smudge_factor=smudge_factor
)
bp.run(max_iterations=max_iterations, tol=tol, progbar=progbar)
return bp.get_messages(), bp.converged

@@ -628,6 +632,7 @@ def sample_hv1bp(
output_inds=None,
max_iterations=1000,
tol=1e-2,
damping=0.0,
smudge_factor=1e-12,
bias=False,
seed=None,
@@ -713,8 +718,8 @@ def sample_hv1bp(
messages,
max_iterations=max_iterations,
tol=tol,
damping=damping,
smudge_factor=smudge_factor,
progbar=True,
)

marginals = compute_all_index_marginals_from_messages(
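And a corresponding sketch for the batched variant, based only on the signatures shown above (the import path and problem setup are assumptions for illustration, not part of the commit):

import quimb.tensor as qtn
from quimb.experimental.belief_propagation.hv1bp import (
    run_belief_propagation_hv1bp,
    sample_hv1bp,
)

htn = qtn.HTN_random_ksat(3, 20, alpha=2.0, seed=42, mode="dense")

# run damped BP and check convergence
messages, converged = run_belief_propagation_hv1bp(
    htn, max_iterations=1000, tol=5e-6, damping=0.1
)

# or draw a configuration, as exercised in the updated tests
config, tn_config, omega = sample_hv1bp(htn, damping=0.1)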
9 changes: 5 additions & 4 deletions tests/test_tensor/test_belief_propagation/test_hd1bp.py
@@ -8,7 +8,7 @@
)


@pytest.mark.parametrize("damping", [0.0, 0.1, 0.5])
@pytest.mark.parametrize("damping", [0.0, 0.1])
def test_contract_hyper(damping):
htn = qtn.HTN_random_ksat(3, 50, alpha=2.0, seed=42, mode="dense")
num_solutions = contract_hd1bp(htn, damping=damping)
@@ -22,18 +22,19 @@ def test_contract_tree_exact():
assert Z == pytest.approx(Z_bp, rel=1e-12)


@pytest.mark.parametrize("damping", [0.0, 0.1, 0.5])
@pytest.mark.parametrize("damping", [0.0, 0.1])
def test_contract_normal(damping):
tn = qtn.TN2D_from_fill_fn(lambda s: qu.randn(s, dist="uniform"), 6, 6, 2)
Z = tn.contract()
Z_bp = contract_hd1bp(tn, damping=damping)
assert Z == pytest.approx(Z_bp, rel=1e-1)


def test_sample():
@pytest.mark.parametrize("damping", [0.0, 0.1])
def test_sample(damping):
nvars = 20
htn = qtn.HTN_random_ksat(3, nvars, alpha=2.0, seed=42, mode="dense")
config, tn_config, omega = sample_hd1bp(htn, progbar=True)
config, tn_config, omega = sample_hd1bp(htn, damping=damping)
assert len(config) == nvars
assert tn_config.num_indices == 0
assert tn_config.contract() == pytest.approx(1.0)
5 changes: 3 additions & 2 deletions tests/test_tensor/test_belief_propagation/test_hv1bp.py
@@ -30,10 +30,11 @@ def test_contract_normal(damping):
assert Z == pytest.approx(Z_bp, rel=1e-1)


def test_sample():
@pytest.mark.parametrize("damping", [0.0, 0.1])
def test_sample(damping):
nvars = 20
htn = qtn.HTN_random_ksat(3, nvars, alpha=2.0, seed=42, mode="dense")
config, tn_config, omega = sample_hv1bp(htn, progbar=True)
config, tn_config, omega = sample_hv1bp(htn, damping=damping)
assert len(config) == nvars
assert tn_config.num_indices == 0
assert tn_config.contract() == pytest.approx(1.0)
