diff --git a/quimb/tensor/tensor_arbgeom.py b/quimb/tensor/tensor_arbgeom.py index 002ee492..81963834 100644 --- a/quimb/tensor/tensor_arbgeom.py +++ b/quimb/tensor/tensor_arbgeom.py @@ -595,7 +595,7 @@ def gate( is essentially a wrapper around :meth:`~quimb.tensor.tensor_core.TensorNetwork.gate_inds` apart from ``where`` can be specified as a list of sites, and tags can be - optionally, intelligently propgated to the new gate tensor. + optionally, intelligently propagated to the new gate tensor. .. math:: diff --git a/quimb/tensor/tensor_core.py b/quimb/tensor/tensor_core.py index f499a086..cb87ed42 100644 --- a/quimb/tensor/tensor_core.py +++ b/quimb/tensor/tensor_core.py @@ -1652,6 +1652,19 @@ def left_inds(self): def left_inds(self, left_inds): self._left_inds = tuple(left_inds) if left_inds is not None else None + def check(self): + """Do some basic diagnostics on this tensor, raising errors if + something is wrong.""" + if ndim(self.data) != len(self.inds): + raise ValueError( + f"Wrong number of inds, {self.inds}, supplied for array" + f" of shape {self.data.shape}." + ) + if not do("all", do("isfinite", self.data)): + raise ValueError( + f"Tensor data contains non-finite values: {self.data}." + ) + @property def owners(self): return self._owners @@ -1665,10 +1678,7 @@ def add_owner(self, tn, tid): def remove_owner(self, tn): """Remove TensorNetwork ``tn`` as an owner of this Tensor. """ - try: - del self._owners[hash(tn)] - except KeyError: - pass + self._owners.pop(hash(tn), None) def check_owners(self): """Check if this tensor is 'owned' by any alive TensorNetworks. Also @@ -2148,7 +2158,7 @@ def trace( old_inds, new_inds = tuple(old_inds), tuple(new_inds) - eq = _inds_to_eq((old_inds,), new_inds) + eq = inds_to_eq((old_inds,), new_inds) t.modify(apply=lambda x: do('einsum', eq, x, like=x), inds=new_inds, left_inds=None) @@ -3888,6 +3898,60 @@ def delete(self, tags, which='all'): for tid in tuple(tids): self.pop_tensor(tid) + def check(self): + """Check some basic diagnostics of the tensor network. + """ + for tid, t in self.tensor_map.items(): + t.check() + + if not t.check_owners(): + raise ValueError( + f"Tensor {tid} doesn't have any owners, but should have " + "this tensor network as one." + ) + if not any( + (tid == ref_tid and (ref() is self)) + for ref, ref_tid in t._owners.values() + ): + raise ValueError( + f"Tensor {tid} does not have this tensor network as an " + "owner." + ) + + # check indices correctly registered + for ix in t.inds: + ix_tids = self.ind_map.get(ix, None) + if ix_tids is None: + raise ValueError( + f"Index {ix} of tensor {tid} not in index map." + ) + if tid not in ix_tids: + raise ValueError( + f"Tensor {tid} not registered under index {ix}." + ) + + # check tags correctly registered + for tag in t.tags: + tag_tids = self.tag_map.get(tag, None) + if tag_tids is None: + raise ValueError( + f"Tag {tag} of tensor {tid} not in tag map." + ) + if tid not in tag_tids: + raise ValueError( + f"Tensor {tid} not registered under tag {tag}." + ) + + # check that all index dimensions match across incident tensors + for ix, tids in self.ind_map.items(): + ts = tuple(self._tids_get(*tids)) + dims = {t.ind_size(ix) for t in ts} + if len(dims) != 1: + raise ValueError( + "Mismatched index dimension for index " + f"'{ix}' in tensors {ts}." + ) + def add_tag(self, tag, where=None, which='all'): """Add tag to every tensor in this network, or if ``where`` is specified, the tensors matching those tags -- i.e. adds the tag to @@ -5827,6 +5891,64 @@ def _get_neighbor_tids(self, tids, exclude_inds=()): return neighbors + def _get_subgraph_tids(self, tids): + """Get the tids of tensors connected, by any distance, to the tensor or + region of tensors ``tids``. + """ + region = tags_to_oset(tids) + queue = list(self._get_neighbor_tids(region)) + while queue: + tid = queue.pop() + if tid not in region: + region.add(tid) + queue.extend(self._get_neighbor_tids([tid])) + return region + + def _ind_to_subgraph_tids(self, ind): + """Get the tids of tensors connected, by any distance, to the index + ``ind``. + """ + return self._get_subgraph_tids(self._get_tids_from_inds(ind)) + + def istree(self): + """Check if this tensor network has a tree structure, (treating + multibonds as a single edge). + + Examples + -------- + + >>> MPS_rand_state(10, 7).istree() + True + + >>> MPS_rand_state(10, 7, cyclic=True).istree() + False + + """ + tid0 = next(iter(self.tensor_map)) + region = [(tid0, None)] + seen = {tid0} + while region: + tid, ptid = region.pop() + for ntid in self._get_neighbor_tids(tid): + if ntid == ptid: + # ignore the previous tid we just came from + continue + if ntid in seen: + # found a loop + return False + # expand the queue + region.append((ntid, tid)) + seen.add(ntid) + return True + + def isconnected(self): + """Check whether this tensor network is connected, i.e. whether + there is a path between any two tensors, (including size 1 indices). + """ + tid0 = next(iter(self.tensor_map)) + region = self._get_subgraph_tids([tid0]) + return len(region) == len(self.tensor_map) + def subgraphs(self, virtual=False): """Split this tensor network into disconneceted subgraphs. @@ -9049,6 +9171,19 @@ def gen_loops(self, max_loop_length=None): hg = get_hypergraph(inputs, accel='auto') return hg.compute_loops(max_loop_length) + def _get_string_between_tids(self, tida, tidb): + strings = [(tida,)] + while strings: + string = strings.pop(0) + tid_current = string[-1] + for tid_next in self._get_neighbor_tids(tid_current): + if tid_next == tidb: + # finished! + return string + (tidb,) + if tid_next not in string: + # continue onwards! + strings.append(string + (tid_next,)) + def tids_are_connected(self, tids): """Check whether nodes ``tids`` are connected. diff --git a/tests/test_tensor/test_tensor_1d.py b/tests/test_tensor/test_tensor_1d.py index dea0b008..a4812609 100644 --- a/tests/test_tensor/test_tensor_1d.py +++ b/tests/test_tensor/test_tensor_1d.py @@ -21,6 +21,7 @@ def test_matrix_product_state(self): [np.random.rand(5, 5, 2) for _ in range(3)] + [np.random.rand(5, 2)]) mps = MatrixProductState(tensors) + mps.check() assert len(mps.tensors) == 5 nmps = mps.reindex_sites('foo{}', inplace=False, where=slice(0, 3)) assert nmps.site_ind_id == "k{}" diff --git a/tests/test_tensor/test_tensor_core.py b/tests/test_tensor/test_tensor_core.py index e5f906a8..52dbcf73 100644 --- a/tests/test_tensor/test_tensor_core.py +++ b/tests/test_tensor/test_tensor_core.py @@ -63,6 +63,7 @@ def test_tensor_construct(self): def test_tensor_copy(self): a = Tensor(np.random.randn(2, 3, 4), inds=[0, 1, 2], tags="blue") b = a.copy() + b.check() b.add_tag("foo") assert "foo" not in a.tags b.data[:] = b.data / 2 @@ -144,6 +145,7 @@ def test_contract_some(self): assert a.shared_bond_size(b) == 12 c = a @ b + c.check() assert isinstance(c, Tensor) assert c.shape == (2, 5) @@ -160,6 +162,7 @@ def test_contract_None(self): a = Tensor(np.random.randn(2, 3, 4), inds=[0, 1, 2]) b = Tensor(np.random.randn(3, 4, 5), inds=[3, 4, 5]) c = a @ b + c.check() assert c.shape == (2, 3, 4, 3, 4, 5) assert c.inds == (0, 1, 2, 3, 4, 5) @@ -181,6 +184,7 @@ def test_multi_contract(self): b = Tensor(np.random.randn(3, 4, 5), inds=[1, 2, 3], tags="blue") c = Tensor(np.random.randn(5, 2, 6), inds=[3, 0, 4], tags="blue") d = tensor_contract(a, b, c) + d.check() assert isinstance(d, Tensor) assert d.shape == (6,) assert d.inds == (4,) @@ -1071,8 +1075,10 @@ def test_reindex(self): c = Tensor(np.random.randn(5, 2, 6), inds=[3, 0, 4], tags="green") a_b_c = a & b & c + a_b_c.check() d = a_b_c.reindex({4: "foo", 2: "bar"}) + d.check() assert a_b_c.outer_inds() == (4,) assert d.outer_inds() == ("foo",) @@ -1211,6 +1217,7 @@ def test_replace_with_identity(self): ) tn = A & B & C & D + tn.check() with pytest.raises(ValueError): tn.replace_with_identity(("I1", "I2"), inplace=True) @@ -1219,6 +1226,7 @@ def test_replace_with_identity(self): tn["I3"] = rand_tensor((4,), "f", tags=["I3"]) tn1 = tn.replace_with_identity(("I1", "I2")) + tn1.check() assert len(tn1.tensors) == 2 x = tn1 ^ ... assert set(x.inds) == {"a", "b"} @@ -1385,6 +1393,7 @@ def test_tn_split_tensor(self): mps = MPS_rand_state(4, 3) right_inds = bonds(mps[1], mps[2]) mps.split_tensor(1, left_inds=None, right_inds=right_inds, rtags="X") + mps.check() assert mps.num_tensors == 5 assert mps["X"].shape == (3, 3) assert mps.H @ mps == pytest.approx(1.0) @@ -1561,6 +1570,30 @@ def test_hyperind_simplification_with_outputs(self): p2 = stn2.contract(output_inds=output_inds).data assert_allclose(pex, p2) + def test_istree(self): + assert Tensor().as_network().istree() + tn = rand_tensor([2] * 1, ['x']).as_network() + assert tn.istree() + tn |= rand_tensor([2] * 3, ['x', 'y', 'z']) + assert tn.istree() + tn |= rand_tensor([2] * 2, ['y', 'z']) + assert tn.istree() + tn |= rand_tensor([2] * 2, ['x', 'z']) + assert not tn.istree() + + def test_isconnected(self): + assert Tensor().as_network().isconnected() + tn = rand_tensor([2] * 1, ['x']).as_network() + assert tn.isconnected() + tn |= rand_tensor([2] * 3, ['x', 'y', 'z']) + assert tn.isconnected() + tn |= rand_tensor([2] * 2, ['w', 'u']) + assert not tn.isconnected() + assert not (Tensor() | Tensor()).isconnected() + + def test_get_string_between_tids(self): + tn = MPS_rand_state(5, 3) + assert tn._get_string_between_tids(0, 4) == (0, 1, 2, 3, 4) class TestTensorNetworkSimplifications: def test_rank_simplify(self):