Skip to content

Commit

Permalink
improvements to TN.gauge_all_simple
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Nov 16, 2024
1 parent 00dae54 commit be4ddf7
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 18 deletions.
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Release notes for `quimb`.
- specialize [`CircuitMPS.local_expectation`](quimb.tensor.circuit.CircuitMPS.local_expectation) to make use of the MPS form.
- add [`PEPS.product_state`](quimb.tensor.tensor_2d.PEPS.product_state) for constructing a PEPS representing a product state.
- add [`PEPS.vacuum`](quimb.tensor.tensor_2d.PEPS.vacuum) for constructing a PEPS representing the vacuum state $|000\ldots0\rangle$.
- [tn.gauge_all_simple](quimb.tensor.tensor_core.TensorNetwork.gauge_all_simple): improve scheduling and add `damping` and `touched_tids` options.

---

Expand Down
127 changes: 109 additions & 18 deletions quimb/tensor/tensor_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4330,6 +4330,10 @@ def make_norm(
The tags to identify the top and bottom.
return_all : bool, optional
Return the norm, the ket and the bra.
Returns
-------
norm : TensorNetwork
"""
ket = self.copy()
ket.add_tag(layer_tags[0])
Expand Down Expand Up @@ -6833,25 +6837,57 @@ def gauge_all_simple(
tol=0.0,
smudge=1e-12,
power=1.0,
damping=0.0,
gauges=None,
equalize_norms=False,
touched_tids=None,
progbar=False,
inplace=False,
):
"""Iterative gauge all the bonds in this tensor network with a 'simple
update' like strategy.
Parameters
----------
max_iterations : int, optional
The maximum number of iterations to perform.
tol : float, optional
The convergence tolerance for the singular values.
smudge : float, optional
The smudge factor to add to the singular values when gauging.
power : float, optional
The power to raise the singular values to when gauging.
damping : float, optional
The damping factor to apply to the gauging updates.
gauges : dict, optional
Supply the initial gauges to use.
equalize_norms : bool, optional
Whether to equalize the norms of the tensors after each update.
touched_tids : sequence of int, optional
The tensor identifiers to start the gauge sweep from.
progbar : bool, optional
Whether to show a progress bar.
inplace : bool, optional
Whether to perform the gauging inplace.
Returns
-------
TensorNetwork
"""
tn = self if inplace else self.copy()

# every index in the TN
inds = list(tn.ind_map)

# the vector 'gauges' that will live on the bonds
gauges_supplied = gauges is not None
if not gauges_supplied:
gauges = {}

_sval_mapper = {
# we store the actual ("conditioned") vectors treated as the
# environments separately from the 'exact' gauges
gauges_conditioned = {}
# if damping we need to mark if we have updated gauge specifically
have_conditioned = set()

_sval_conditioner = {
(True, True): lambda s: s,
(True, False): lambda s: s + smudge,
(False, True): lambda s: s**power,
Expand All @@ -6871,14 +6907,31 @@ def gauge_all_simple(
else:
pbar = None

it = 0
not_converged = True
while not_converged and it < max_iterations:
# keep track of which indices are available to be updated
if touched_tids is not None:
# use indices adjacent to the given tensors
next_touched = oset(
ix
for tid in touched_tids
for ix in tn.tensor_map[tid].inds
)
else:
# use all indices
next_touched = oset(tn._inner_inds)

it = 0
unconverged = True
while unconverged and it < max_iterations:
# can only converge if tol > 0.0
max_sdiff = -1.0

for ind in inds:
touched, next_touched = next_touched, oset()
# add an arbitrary index to start the sweep
queue = oset([touched.popleft()])

while queue:
ind = queue.popleft()

try:
tid1, tid2 = tn.ind_map[ind]
except (KeyError, ValueError):
Expand All @@ -6893,10 +6946,30 @@ def gauge_all_simple(
inv_gauges = []
for t, ixs in ((t1, lix), (t2, rix)):
for ix in ixs:
try:
s = _sval_mapper(gauges[ix])
except KeyError:
if ix not in gauges:
continue

if ix not in have_conditioned:
if ix not in gauges_conditioned:
# first iteration
s = _sval_conditioner(gauges[ix])
gauges_conditioned[ix] = s
else:
snew = _sval_conditioner(gauges[ix])
if damping == 0.0:
s = snew
else:
# damped update, combine old and new
sold = gauges_conditioned[ix]
s = damping * sold + (1 - damping) * snew
gauges_conditioned[ix] = s

# mark as computed
have_conditioned.add(ix)
else:
# have already computed s this sweep
s = gauges_conditioned[ix]

t.multiply_index_diagonal_(ix, s)
# keep track of how to invert gauge
inv_gauges.append((t, ix, 1 / s))
Expand All @@ -6911,17 +6984,17 @@ def gauge_all_simple(
)

s = info["singular_values"]
smax = do("linalg.norm", s)
new_gauge = s / smax
nfact = do("log10", smax) + nfact
snorm = do("linalg.norm", s)
new_gauge = s / snorm
nfact = do("log10", snorm) + nfact

if (tol > 0.0) or (pbar is not None):
# check convergence
old_gauge = gauges.get(bond, 1.0)

if size(old_gauge) != size(new_gauge):
# the bond has changed size, so we can't compare
# the singular values directly
# the bond has changed size, so we can't
# compare the singular values directly
old_gauge = 1.0

sdiff = do("linalg.norm", old_gauge - new_gauge)
Expand All @@ -6939,13 +7012,31 @@ def gauge_all_simple(
tn.strip_exponent(tid1)
tn.strip_exponent(tid2)

# mark conditioned version as out-of-date
have_conditioned.discard(bond)
has_changed = (tol == 0.0) or (sdiff > tol)

if has_changed:
# mark index and neighbors as touched for next sweep
next_touched.add(bond)

for neighbor_ind in tn._get_neighbor_inds(bond):
if neighbor_ind in tn._inner_inds:
if neighbor_ind in touched:
# move into queue
touched.remove(neighbor_ind)
queue.add(neighbor_ind)
if has_changed:
# mark as touched for next sweep
next_touched.add(neighbor_ind)

if pbar is not None:
pbar.update()
pbar.set_description(
f"max|dS|={max_sdiff:.2e}, " f"nfact={nfact:.2f}"
)

not_converged = (tol == 0.0) or (max_sdiff > tol)
unconverged = (tol == 0.0) or (max_sdiff > tol)
it += 1

if equalize_norms:
Expand Down Expand Up @@ -10564,7 +10655,7 @@ def max_bond(self):

@property
def shape(self):
"""Actual, i.e. exterior, shape of this TensorNetwork."""
"""Effective, i.e. outer, shape of this TensorNetwork."""
return tuple(di[0] for di in self.outer_dims_inds())

@property
Expand Down

0 comments on commit be4ddf7

Please sign in to comment.