Skip to content

Commit

Permalink
add HD1BP amd HV1BP impls and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Oct 19, 2023
1 parent 270c810 commit fc072ae
Show file tree
Hide file tree
Showing 6 changed files with 1,507 additions and 6 deletions.
130 changes: 125 additions & 5 deletions quimb/experimental/belief_propagation/bp_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,7 @@ def prod(xs):
return functools.reduce(operator.mul, xs)


def initialize_hyper_messages(
tn,
fill_fn=None,
smudge_factor=1e-12
):
def initialize_hyper_messages(tn, fill_fn=None, smudge_factor=1e-12):
"""Initialize messages for belief propagation, this is equivalent to doing
a single round of belief propagation with uniform messages.
Expand Down Expand Up @@ -184,6 +180,130 @@ def combine_local_contractions(
return mantissa * 10**exponent


def contract_hyper_messages(
tn,
messages,
strip_exponent=False,
backend=None,
):
"""Estimate the contraction of ``tn`` given ``messages``, via the
exponential of the Bethe free entropy.
"""
tvals = []
mvals = []

for tid, t in tn.tensor_map.items():
if backend is None:
backend = ar.infer_backend(t.data)

arrays = [t.data]
inputs = [range(t.ndim)]
for i, ix in enumerate(t.inds):
arrays.append(messages[ix, tid])
inputs.append((i,))

# local message overlap correction
mvals.append(
qtn.array_contract(
(messages[tid, ix], messages[ix, tid]),
inputs=((0,), (0,)),
output=(),
)
)

# local factor free entropy
tvals.append(qtn.array_contract(arrays, inputs, output=()))

for ix, tids in tn.ind_map.items():
arrays = tuple(messages[tid, ix] for tid in tids)
inputs = tuple((0,) for _ in tids)
# local variable free entropy
tvals.append(qtn.array_contract(arrays, inputs, output=()))

return combine_local_contractions(
tvals, mvals, backend, strip_exponent=strip_exponent
)


def compute_index_marginal(tn, ind, messages):
"""Compute the marginal for a single index given ``messages``.
Parameters
----------
tn : TensorNetwork
The tensor network to compute the marginal for.
ind : int
The index to compute the marginal for.
messages : dict
The messages to use, which should match ``tn``.
Returns
-------
marginal : array_like
The marginal probability distribution for the index ``ind``.
"""
m = prod(messages[tid, ind] for tid in tn.ind_map[ind])
return m / ar.do("sum", m)


def compute_tensor_marginal(tn, tid, messages):
"""Compute the marginal for the region surrounding a single tensor/factor
given ``messages``.
Parameters
----------
tn : TensorNetwork
The tensor network to compute the marginal for.
tid : int
The tensor id to compute the marginal for.
messages : dict
The messages to use, which should match ``tn``.
Returns
-------
marginal : array_like
The marginal probability distribution for the tensor/factor ``tid``.
"""
t = tn.tensor_map[tid]

output = tuple(range(t.ndim))
inputs = [output]
arrays = [t.data]

for i, ix in enumerate(t.inds):
mix = prod(
messages[otid, ix] for otid in tn.ind_map[ix] if otid != tid
)
inputs.append((i,))
arrays.append(mix)

m = qtn.array_contract(
arrays=arrays,
inputs=inputs,
output=output,
)

return m / ar.do("sum", m)


def compute_all_index_marginals_from_messages(tn, messages):
"""Compute all index marginals from belief propagation messages.
Parameters
----------
tn : TensorNetwork
The tensor network to compute marginals for.
messages : dict
The belief propagation messages.
Returns
-------
marginals : dict
The marginals for each index.
"""
return {ix: compute_index_marginal(tn, ix, messages) for ix in tn.ind_map}


def maybe_get_thread_pool(thread_pool):
"""Get a thread pool if requested."""
if thread_pool is False:
Expand Down
Loading

0 comments on commit fc072ae

Please sign in to comment.