Skip to content

Commit

Permalink
TN1D: utilize standard canonize/compress funcs, test new compression …
Browse files Browse the repository at this point in the history
…methods
  • Loading branch information
jcmgray committed Apr 26, 2024
1 parent 605868f commit d38b707
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 179 deletions.
86 changes: 41 additions & 45 deletions quimb/tensor/tensor_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@
)
from .tensor_core import (
Tensor,
TensorNetwork,
bonds,
new_bond,
oset,
rand_uuid,
tags_to_oset,
tensor_canonize_bond,
tensor_compress_bond,
)

align_TN_1D = deprecated(
Expand Down Expand Up @@ -681,42 +682,6 @@ class TensorNetwork1DFlat(TensorNetwork1D):

_EXTRA_PROPS = ("_site_tag_id", "_L")

def _left_decomp_site(self, i, bra=None, **split_opts):
T1, T2 = self[i], self[i + 1]
rix, lix = T1.filter_bonds(T2)

set_default_compress_mode(split_opts, self.cyclic)
Q, R = T1.split(lix, get="tensors", right_inds=rix, **split_opts)
R = R @ T2

Q.transpose_like_(T1)
R.transpose_like_(T2)

self[i].modify(data=Q.data)
self[i + 1].modify(data=R.data)

if bra is not None:
bra[i].modify(data=Q.data.conj())
bra[i + 1].modify(data=R.data.conj())

def _right_decomp_site(self, i, bra=None, **split_opts):
T1, T2 = self[i], self[i - 1]
lix, rix = T1.filter_bonds(T2)

set_default_compress_mode(split_opts, self.cyclic)
L, Q = T1.split(lix, get="tensors", right_inds=rix, **split_opts)
L = T2 @ L

L.transpose_like_(T2)
Q.transpose_like_(T1)

self[i - 1].modify(data=L.data)
self[i].modify(data=Q.data)

if bra is not None:
bra[i - 1].modify(data=L.data.conj())
bra[i].modify(data=Q.data.conj())

def left_canonize_site(self, i, bra=None):
r"""Left canonize this TN's ith site, inplace::
Expand All @@ -732,7 +697,11 @@ def left_canonize_site(self, i, bra=None):
bra : None or matching TensorNetwork to self, optional
If set, also update this TN's data with the conjugate canonization.
"""
self._left_decomp_site(i, bra=bra, method="qr")
tl, tr = self[i], self[i + 1]
tensor_canonize_bond(tl, tr)
if bra is not None:
bra[i].modify(data=conj(tl.data))
bra[i + 1].modify(data=conj(tr.data))

def right_canonize_site(self, i, bra=None):
r"""Right canonize this TN's ith site, inplace::
Expand All @@ -749,7 +718,11 @@ def right_canonize_site(self, i, bra=None):
bra : None or matching TensorNetwork to self, optional
If set, also update this TN's data with the conjugate canonization.
"""
self._right_decomp_site(i, bra=bra, method="lq")
tl, tr = self[i - 1], self[i]
tensor_canonize_bond(tr, tl)
if bra is not None:
bra[i].modify(data=conj(tr.data))
bra[i - 1].modify(data=conj(tl.data))

def left_canonize(self, stop=None, start=None, normalize=False, bra=None):
r"""Left canonize all or a portion of this TN. If this is a MPS,
Expand Down Expand Up @@ -981,8 +954,17 @@ def left_compress_site(self, i, bra=None, **compress_opts):
compress_opts
Supplied to :meth:`Tensor.split`.
"""
set_default_compress_mode(compress_opts, self.cyclic)
compress_opts.setdefault("absorb", "right")
self._left_decomp_site(i, bra=bra, **compress_opts)
compress_opts.setdefault("reduced", "left")

tl, tr = self[i], self[i + 1]
tensor_compress_bond(tl, tr, **compress_opts)

if bra is not None:
bra[i].modify(data=conj(tl.data))
bra[i + 1].modify(data=conj(tr.data))


def right_compress_site(self, i, bra=None, **compress_opts):
"""Right compress this 1D TN's ith site, such that the site is then
Expand All @@ -997,8 +979,16 @@ def right_compress_site(self, i, bra=None, **compress_opts):
compress_opts
Supplied to :meth:`Tensor.split`.
"""
set_default_compress_mode(compress_opts, self.cyclic)
compress_opts.setdefault("absorb", "left")
self._right_decomp_site(i, bra=bra, **compress_opts)
compress_opts.setdefault("reduced", "right")

tl, tr = self[i - 1], self[i]
tensor_compress_bond(tl, tr, **compress_opts)

if bra is not None:
bra[i].modify(data=conj(tr.data))
bra[i - 1].modify(data=conj(tl.data))

def left_compress(self, start=None, stop=None, bra=None, **compress_opts):
"""Compress this 1D TN, from left to right, such that it becomes
Expand Down Expand Up @@ -1066,13 +1056,19 @@ def compress(self, form=None, **compress_opts):
form = "right"

if isinstance(form, Integral):
self.right_canonize()
self.left_compress(**compress_opts)
self.right_canonize(stop=form)
if form < self.L // 2:
self.left_canonize()
self.right_compress(**compress_opts)
self.left_canonize(stop=form)
else:
self.right_canonize()
self.left_compress(**compress_opts)
self.right_canonize(stop=form)

elif form == "left":
self.right_canonize(bra=compress_opts.get("bra", None))
self.left_compress(**compress_opts)

elif form == "right":
self.left_canonize(bra=compress_opts.get("bra", None))
self.right_compress(**compress_opts)
Expand Down Expand Up @@ -1841,7 +1837,7 @@ def gate_with_auto_swap(
cur_orthog = (i + 1, i + 2)

# make sure sites are orthog center, then apply and split
mps.canonize((i, i + 1), cur_orthog)
mps.canonize((i, i + 1), cur_orthog=cur_orthog)
mps.gate_split_(
G, where=(i + 1, i) if need2flip else (i, i + 1), **compress_opts
)
Expand Down
Loading

0 comments on commit d38b707

Please sign in to comment.