Skip to content

Commit

Permalink
TN: add various checking methods
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Jul 27, 2023
1 parent 6852f2d commit 155ba5c
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 6 deletions.
2 changes: 1 addition & 1 deletion quimb/tensor/tensor_arbgeom.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ def gate(
is essentially a wrapper around
:meth:`~quimb.tensor.tensor_core.TensorNetwork.gate_inds` apart from
``where`` can be specified as a list of sites, and tags can be
optionally, intelligently propgated to the new gate tensor.
optionally, intelligently propagated to the new gate tensor.
.. math::
Expand Down
145 changes: 140 additions & 5 deletions quimb/tensor/tensor_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1652,6 +1652,19 @@ def left_inds(self):
def left_inds(self, left_inds):
self._left_inds = tuple(left_inds) if left_inds is not None else None

def check(self):
"""Do some basic diagnostics on this tensor, raising errors if
something is wrong."""
if ndim(self.data) != len(self.inds):
raise ValueError(
f"Wrong number of inds, {self.inds}, supplied for array"
f" of shape {self.data.shape}."
)
if not do("all", do("isfinite", self.data)):
raise ValueError(
f"Tensor data contains non-finite values: {self.data}."
)

@property
def owners(self):
return self._owners
Expand All @@ -1665,10 +1678,7 @@ def add_owner(self, tn, tid):
def remove_owner(self, tn):
"""Remove TensorNetwork ``tn`` as an owner of this Tensor.
"""
try:
del self._owners[hash(tn)]
except KeyError:
pass
self._owners.pop(hash(tn), None)

def check_owners(self):
"""Check if this tensor is 'owned' by any alive TensorNetworks. Also
Expand Down Expand Up @@ -2148,7 +2158,7 @@ def trace(

old_inds, new_inds = tuple(old_inds), tuple(new_inds)

eq = _inds_to_eq((old_inds,), new_inds)
eq = inds_to_eq((old_inds,), new_inds)
t.modify(apply=lambda x: do('einsum', eq, x, like=x),
inds=new_inds, left_inds=None)

Expand Down Expand Up @@ -3888,6 +3898,60 @@ def delete(self, tags, which='all'):
for tid in tuple(tids):
self.pop_tensor(tid)

def check(self):
"""Check some basic diagnostics of the tensor network.
"""
for tid, t in self.tensor_map.items():
t.check()

if not t.check_owners():
raise ValueError(
f"Tensor {tid} doesn't have any owners, but should have "
"this tensor network as one."
)
if not any(
(tid == ref_tid and (ref() is self))
for ref, ref_tid in t._owners.values()
):
raise ValueError(
f"Tensor {tid} does not have this tensor network as an "
"owner."
)

# check indices correctly registered
for ix in t.inds:
ix_tids = self.ind_map.get(ix, None)
if ix_tids is None:
raise ValueError(
f"Index {ix} of tensor {tid} not in index map."
)
if tid not in ix_tids:
raise ValueError(
f"Tensor {tid} not registered under index {ix}."
)

# check tags correctly registered
for tag in t.tags:
tag_tids = self.tag_map.get(tag, None)
if tag_tids is None:
raise ValueError(
f"Tag {tag} of tensor {tid} not in tag map."
)
if tid not in tag_tids:
raise ValueError(
f"Tensor {tid} not registered under tag {tag}."
)

# check that all index dimensions match across incident tensors
for ix, tids in self.ind_map.items():
ts = tuple(self._tids_get(*tids))
dims = {t.ind_size(ix) for t in ts}
if len(dims) != 1:
raise ValueError(
"Mismatched index dimension for index "
f"'{ix}' in tensors {ts}."
)

def add_tag(self, tag, where=None, which='all'):
"""Add tag to every tensor in this network, or if ``where`` is
specified, the tensors matching those tags -- i.e. adds the tag to
Expand Down Expand Up @@ -5827,6 +5891,64 @@ def _get_neighbor_tids(self, tids, exclude_inds=()):

return neighbors

def _get_subgraph_tids(self, tids):
"""Get the tids of tensors connected, by any distance, to the tensor or
region of tensors ``tids``.
"""
region = tags_to_oset(tids)
queue = list(self._get_neighbor_tids(region))
while queue:
tid = queue.pop()
if tid not in region:
region.add(tid)
queue.extend(self._get_neighbor_tids([tid]))
return region

def _ind_to_subgraph_tids(self, ind):
"""Get the tids of tensors connected, by any distance, to the index
``ind``.
"""
return self._get_subgraph_tids(self._get_tids_from_inds(ind))

def istree(self):
"""Check if this tensor network has a tree structure, (treating
multibonds as a single edge).
Examples
--------
>>> MPS_rand_state(10, 7).istree()
True
>>> MPS_rand_state(10, 7, cyclic=True).istree()
False
"""
tid0 = next(iter(self.tensor_map))
region = [(tid0, None)]
seen = {tid0}
while region:
tid, ptid = region.pop()
for ntid in self._get_neighbor_tids(tid):
if ntid == ptid:
# ignore the previous tid we just came from
continue
if ntid in seen:
# found a loop
return False
# expand the queue
region.append((ntid, tid))
seen.add(ntid)
return True

def isconnected(self):
"""Check whether this tensor network is connected, i.e. whether
there is a path between any two tensors, (including size 1 indices).
"""
tid0 = next(iter(self.tensor_map))
region = self._get_subgraph_tids([tid0])
return len(region) == len(self.tensor_map)

def subgraphs(self, virtual=False):
"""Split this tensor network into disconneceted subgraphs.
Expand Down Expand Up @@ -9049,6 +9171,19 @@ def gen_loops(self, max_loop_length=None):
hg = get_hypergraph(inputs, accel='auto')
return hg.compute_loops(max_loop_length)

def _get_string_between_tids(self, tida, tidb):
strings = [(tida,)]
while strings:
string = strings.pop(0)
tid_current = string[-1]
for tid_next in self._get_neighbor_tids(tid_current):
if tid_next == tidb:
# finished!
return string + (tidb,)
if tid_next not in string:
# continue onwards!
strings.append(string + (tid_next,))

def tids_are_connected(self, tids):
"""Check whether nodes ``tids`` are connected.
Expand Down
1 change: 1 addition & 0 deletions tests/test_tensor/test_tensor_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def test_matrix_product_state(self):
[np.random.rand(5, 5, 2)
for _ in range(3)] + [np.random.rand(5, 2)])
mps = MatrixProductState(tensors)
mps.check()
assert len(mps.tensors) == 5
nmps = mps.reindex_sites('foo{}', inplace=False, where=slice(0, 3))
assert nmps.site_ind_id == "k{}"
Expand Down
33 changes: 33 additions & 0 deletions tests/test_tensor/test_tensor_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def test_tensor_construct(self):
def test_tensor_copy(self):
a = Tensor(np.random.randn(2, 3, 4), inds=[0, 1, 2], tags="blue")
b = a.copy()
b.check()
b.add_tag("foo")
assert "foo" not in a.tags
b.data[:] = b.data / 2
Expand Down Expand Up @@ -144,6 +145,7 @@ def test_contract_some(self):
assert a.shared_bond_size(b) == 12

c = a @ b
c.check()

assert isinstance(c, Tensor)
assert c.shape == (2, 5)
Expand All @@ -160,6 +162,7 @@ def test_contract_None(self):
a = Tensor(np.random.randn(2, 3, 4), inds=[0, 1, 2])
b = Tensor(np.random.randn(3, 4, 5), inds=[3, 4, 5])
c = a @ b
c.check()
assert c.shape == (2, 3, 4, 3, 4, 5)
assert c.inds == (0, 1, 2, 3, 4, 5)

Expand All @@ -181,6 +184,7 @@ def test_multi_contract(self):
b = Tensor(np.random.randn(3, 4, 5), inds=[1, 2, 3], tags="blue")
c = Tensor(np.random.randn(5, 2, 6), inds=[3, 0, 4], tags="blue")
d = tensor_contract(a, b, c)
d.check()
assert isinstance(d, Tensor)
assert d.shape == (6,)
assert d.inds == (4,)
Expand Down Expand Up @@ -1071,8 +1075,10 @@ def test_reindex(self):
c = Tensor(np.random.randn(5, 2, 6), inds=[3, 0, 4], tags="green")

a_b_c = a & b & c
a_b_c.check()

d = a_b_c.reindex({4: "foo", 2: "bar"})
d.check()

assert a_b_c.outer_inds() == (4,)
assert d.outer_inds() == ("foo",)
Expand Down Expand Up @@ -1211,6 +1217,7 @@ def test_replace_with_identity(self):
)

tn = A & B & C & D
tn.check()

with pytest.raises(ValueError):
tn.replace_with_identity(("I1", "I2"), inplace=True)
Expand All @@ -1219,6 +1226,7 @@ def test_replace_with_identity(self):
tn["I3"] = rand_tensor((4,), "f", tags=["I3"])

tn1 = tn.replace_with_identity(("I1", "I2"))
tn1.check()
assert len(tn1.tensors) == 2
x = tn1 ^ ...
assert set(x.inds) == {"a", "b"}
Expand Down Expand Up @@ -1385,6 +1393,7 @@ def test_tn_split_tensor(self):
mps = MPS_rand_state(4, 3)
right_inds = bonds(mps[1], mps[2])
mps.split_tensor(1, left_inds=None, right_inds=right_inds, rtags="X")
mps.check()
assert mps.num_tensors == 5
assert mps["X"].shape == (3, 3)
assert mps.H @ mps == pytest.approx(1.0)
Expand Down Expand Up @@ -1561,6 +1570,30 @@ def test_hyperind_simplification_with_outputs(self):
p2 = stn2.contract(output_inds=output_inds).data
assert_allclose(pex, p2)

def test_istree(self):
assert Tensor().as_network().istree()
tn = rand_tensor([2] * 1, ['x']).as_network()
assert tn.istree()
tn |= rand_tensor([2] * 3, ['x', 'y', 'z'])
assert tn.istree()
tn |= rand_tensor([2] * 2, ['y', 'z'])
assert tn.istree()
tn |= rand_tensor([2] * 2, ['x', 'z'])
assert not tn.istree()

def test_isconnected(self):
assert Tensor().as_network().isconnected()
tn = rand_tensor([2] * 1, ['x']).as_network()
assert tn.isconnected()
tn |= rand_tensor([2] * 3, ['x', 'y', 'z'])
assert tn.isconnected()
tn |= rand_tensor([2] * 2, ['w', 'u'])
assert not tn.isconnected()
assert not (Tensor() | Tensor()).isconnected()

def test_get_string_between_tids(self):
tn = MPS_rand_state(5, 3)
assert tn._get_string_between_tids(0, 4) == (0, 1, 2, 3, 4)

class TestTensorNetworkSimplifications:
def test_rank_simplify(self):
Expand Down

0 comments on commit 155ba5c

Please sign in to comment.