Skip to content

Commit

Permalink
add MatrixProductState.from_fill_fn
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Oct 3, 2023
1 parent 60a0fb6 commit 4f2f8c1
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 40 deletions.
4 changes: 3 additions & 1 deletion docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Release notes for `quimb`.

**Enhancements:**

- {class}`~quimb.tensor.Circuit` : allow any gate to be controlled by any
- [`Circuit`](quimb.tensor.Circuit) : allow any gate to be controlled by any
number of qubits.
- add [`is_cyclic_x`](quimb.tensor.TensorNetwork2D.is_cyclic_x),
[`is_cyclic_y`](quimb.tensor.TensorNetwork2D.is_cyclic_y) and
Expand All @@ -24,6 +24,8 @@ Release notes for `quimb`.
- add [TensorNetwork.compress_all_1d](quimb.tensor.TensorNetwork.compress_all_1d)
for compressing generic tensor networks that you promise have a 1D topology,
without casting as a [TensorNetwork1D](quimb.tensor.TensorNetwork1D).
- add [MatrixProductState.from_fill_fn](quimb.tensor.tensor_1d.MatrixProductState.from_fill_fn)
for constructing MPS from a function that fills the tensors.

(whats-new-1-6-0)=

Expand Down
44 changes: 42 additions & 2 deletions quimb/tensor/tensor_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import operator
import functools
import itertools
import collections
from math import log2
from numbers import Integral
Expand Down Expand Up @@ -1125,7 +1126,8 @@ def compress_site(
self.right_compress_site(i + 1, bra=bra, **compress_opts)

def bond(self, i, j):
"""Get the name of the index defining the bond between sites i and j."""
"""Get the name of the index defining the bond between sites i and j.
"""
(bond,) = self[i].bonds(self[j])
return bond

Expand Down Expand Up @@ -1438,9 +1440,47 @@ def from_fill_fn(
site_tag_id="I{}",
tags=None,
):
"""Create a random MPS by supplying a function to generate the data
for each site.
Parameters
----------
fill_fn : callable
A function with signature
``fill_fn(shape : tuple[int]) -> array_like``.
L : int
The number of sites.
bond_dim : int
The bond dimension.
phys_dim : int or Sequence[int], optional
The physical dimension(s) of each site, if a sequence it will be
cycled over.
cyclic : bool, optional
Whether the MPS should be cyclic (periodic).
shape : str, optional
What specific order to layout the indices in, should be a sequence
of ``'l'``, ``'r'``, and ``'p'``, corresponding to left, right, and
physical indices respectively.
site_ind_id : str, optional
How to label the physical site indices.
site_tag_id : str, optional
How to tag the physical sites.
tags : str or sequence of str, optional
Global tags to attach to all tensors.
Returns
-------
MatrixProductState
"""
if set(shape) - set("lrp"):
raise ValueError("Invalid shape string: {}".format(shape))

# check for site varying physical dimensions
if isinstance(phys_dim, Integral):
phys_dims = itertools.repeat(phys_dim)
else:
phys_dims = itertools.cycle(phys_dim)

mps = TensorNetwork()
global_tags = tags_to_oset(tags)
bonds = collections.defaultdict(rand_uuid)
Expand All @@ -1459,7 +1499,7 @@ def from_fill_fn(
data_shape.append(bond_dim)
else: # c == 'p':
inds.append(site_ind_id.format(i))
data_shape.append(phys_dim)
data_shape.append(next(phys_dims))
data = fill_fn(data_shape)
tags = global_tags | oset((site_tag_id.format(i),))
mps |= Tensor(data, inds=inds, tags=tags)
Expand Down
75 changes: 38 additions & 37 deletions quimb/tensor/tensor_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3017,45 +3017,44 @@ def MPS_rand_state(
mps_opts
Supplied to :class:`~quimb.tensor.tensor_1d.MatrixProductState`.
"""
if trans_invar and not cyclic:
raise ValueError(
"State cannot be translationally invariant with open "
"boundary conditions."
if trans_invar:
if not cyclic:
raise ValueError(
"State cannot be translationally invariant with open "
"boundary conditions."
)
array = sensibly_scale(
randn(shape=(bond_dim, bond_dim, phys_dim), dtype=dtype)
)

# check for site varying physical dimensions
if isinstance(phys_dim, Integral):
phys_dims = itertools.repeat(phys_dim)
else:
phys_dims = itertools.cycle(phys_dim)

cyc_dim = (bond_dim,) if cyclic else ()

def gen_shapes():
yield (*cyc_dim, bond_dim, next(phys_dims))
for _ in range(L - 2):
yield (bond_dim, bond_dim, next(phys_dims))
yield (bond_dim, *cyc_dim, next(phys_dims))
def fill_fn(shape):
return array

def gen_data(shape):
return randn(shape, dtype=dtype)

if trans_invar:
array = sensibly_scale(gen_data(next(gen_shapes())))
arrays = (array for _ in range(L))
else:
arrays = map(sensibly_scale, map(gen_data, gen_shapes()))
def fill_fn(shape):
return sensibly_scale(randn(shape, dtype=dtype))

rmps = MatrixProductState(arrays, **mps_opts)
mps = MatrixProductState.from_fill_fn(
fill_fn,
L=L,
bond_dim=bond_dim,
phys_dim=phys_dim,
cyclic=cyclic,
**mps_opts,
)

if normalize == "left":
rmps.left_canonize(normalize=True)
if cyclic:
raise ValueError("Cannot left normalize cyclic MPS.")
mps.left_canonize(normalize=True)
elif normalize == "right":
rmps.left_canonize(normalize=True)
if cyclic:
raise ValueError("Cannot right normalize cyclic MPS.")
mps.left_canonize(normalize=True)
elif normalize:
rmps /= (rmps.H @ rmps) ** 0.5
mps.normalize()

return rmps
return mps


def MPS_product_state(arrays, cyclic=False, **mps_opts):
Expand Down Expand Up @@ -3238,15 +3237,17 @@ def MPS_zero_state(
mps_opts
Supplied to :class:`~quimb.tensor.tensor_1d.MatrixProductState`.
"""
cyc_dim = (bond_dim,) if cyclic else ()

def gen_arrays():
yield np.zeros((*cyc_dim, bond_dim, phys_dim), dtype=dtype)
for _ in range(L - 2):
yield np.zeros((bond_dim, bond_dim, phys_dim), dtype=dtype)
yield np.zeros((bond_dim, *cyc_dim, phys_dim), dtype=dtype)
def fill_fn(shape):
return np.zeros(shape, dtype=dtype)

return MatrixProductState(gen_arrays(), **mps_opts)
return MatrixProductState.from_fill_fn(
fill_fn,
L=L,
bond_dim=bond_dim,
phys_dim=phys_dim,
cyclic=cyclic,
**mps_opts
)


def MPS_sampler(L, dtype=complex, squeeze=True, **mps_opts):
Expand Down

0 comments on commit 4f2f8c1

Please sign in to comment.