Skip to content

Commit

Permalink
various direct contraction functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Jul 25, 2023
1 parent b50ba52 commit 6852f2d
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 59 deletions.
6 changes: 6 additions & 0 deletions quimb/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
set_contract_strategy,
set_tensor_linop_backend,
tensor_linop_backend,
get_symbol,
inds_to_eq,
array_contract,
)
from .tensor_core import (
bonds_size,
Expand Down Expand Up @@ -327,6 +330,9 @@
"tensor_direct_product",
"tensor_fuse_squeeze",
"tensor_linop_backend",
"get_symbol",
"inds_to_eq",
"array_contract",
"tensor_network_align",
"tensor_network_apply_op_vec",
"tensor_network_distance",
Expand Down
104 changes: 102 additions & 2 deletions quimb/tensor/contraction.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""Functions relating to tensor network contraction.
"""

import functools
import itertools
import threading
import contextlib
import collections

import opt_einsum as oe
from opt_einsum.contract import parse_backend
from autoray import infer_backend
from autoray import infer_backend, shape

from ..utils import concat

Expand Down Expand Up @@ -297,6 +297,106 @@ def get_contractor(
return info


@functools.lru_cache(2**12)
def get_symbol(i):
"""Get the 'ith' symbol.
"""
return oe.get_symbol(i)


def empty_symbol_map():
"""Get a default dictionary that will populate with symbol entries as they
are accessed.
"""
return collections.defaultdict(map(get_symbol, itertools.count()).__next__)


def inds_to_symbols(inputs):
"""Map a sequence of inputs terms, containing any hashable indices, to
single unicode letters, appropriate for einsum.
Parameters
----------
inputs : sequence of sequence of hashable
The input indices per tensor.
Returns
-------
symbols : dict[hashable, str]
The mapping from index to symbol.
"""
symbols = empty_symbol_map()
return {
ix: symbols[ix]
for term in inputs
for ix in term
}


@functools.lru_cache(2**12)
def inds_to_eq(inputs, output=None):
"""Turn input and output indices of any sort into a single 'equation'
string where each index is a single 'symbol' (unicode character).
Parameters
----------
inputs : sequence of sequence of hashable
The input indices per tensor.
output : sequence of hashable
The output indices.
Returns
-------
eq : str
The string to feed to einsum/contract.
"""
symbols = empty_symbol_map()
in_str = ("".join(symbols[ix] for ix in inds) for inds in inputs)
in_str = ",".join(in_str)
if output is None:
out_str = "".join(
ix for ix in symbols.values() if in_str.count(ix) == 1
)
else:
out_str = "".join(symbols[ix] for ix in output)
return f"{in_str}->{out_str}"


def array_contract(
arrays,
inputs,
output=None,
optimize=None,
backend=None,
**contract_opts
):
"""Contraction interface for raw arrays with arbitrary hashable indices.
Parameters
----------
arrays : sequence of array_like
The arrays to contract.
inputs : sequence of sequence of hashable
The input indices per tensor.
output : sequence of hashable, optional
The output indices, will be computed as every index that appears only
once in the inputs if not given.
optimize : None, str, path_like, PathOptimizer, optional
How to compute the contraction path.
backend : None, str, optional
Which backend to use for the contraction.
contract_opts
Supplied to ``contract_expression``.
Returns
-------
array_like
"""
eq = inds_to_eq(inputs, output)
shapes = tuple(shape(a) for a in arrays)
f = get_contractor(eq, *shapes, optimize=optimize, **contract_opts)
return f(*arrays, backend=backend)

try:
from opt_einsum.contract import infer_backend as _oe_infer_backend
del _oe_infer_backend
Expand Down
2 changes: 0 additions & 2 deletions quimb/tensor/tensor_arbgeom_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,6 @@ def draw(
Whether to show the legend of which terms are in which group.
ax : None or matplotlib.Axes, optional
Add to a existing set of axes.
return_fig : bool, optional
Whether to return any newly created figure.
"""
import networkx as nx
import matplotlib.pyplot as plt
Expand Down
69 changes: 15 additions & 54 deletions quimb/tensor/tensor_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import opt_einsum as oe
import scipy.sparse.linalg as spla
from autoray import (
do, conj, astype, infer_backend, get_dtype_name, dag, shape
do, conj, astype, infer_backend, get_dtype_name, dag, shape, size
)
try:
from autoray import get_common_dtype
Expand All @@ -39,8 +39,16 @@
get_contract_strategy,
get_tensor_linop_backend,
contract_strategy,
get_symbol,
inds_to_symbols,
inds_to_eq,
array_contract,
)

_inds_to_eq = deprecated(inds_to_eq, '_inds_to_eq', 'inds_to_eq')
get_symbol = deprecated(
get_symbol, 'tensor_core.get_symbol', 'contraction.get_symbol'
)

# --------------------------------------------------------------------------- #
# Tensor Funcs #
Expand Down Expand Up @@ -89,43 +97,6 @@ def _gen_output_inds(all_inds):
yield ind


@functools.lru_cache(2**12)
def get_symbol(i):
"""Get the 'ith' symbol.
"""
return oe.get_symbol(i)


def empty_symbol_map():
"""Get a default dictionary that will populate with symbol entries as they
are accessed.
"""
return collections.defaultdict(map(get_symbol, itertools.count()).__next__)


@functools.lru_cache(2**12)
def _inds_to_eq(inputs, output):
"""Turn input and output indices of any sort into a single 'equation'
string where each index is a single 'symbol' (unicode character).
Parameters
----------
inputs : sequence of sequence of str
The input indices per tensor.
output : sequence of str
The output indices.
Returns
-------
eq : str
The string to feed to einsum/contract.
"""
symbol_get = empty_symbol_map().__getitem__
in_str = ("".join(map(symbol_get, inds)) for inds in inputs)
out_str = "".join(map(symbol_get, output))
return ",".join(in_str) + f"->{out_str}"


_VALID_CONTRACT_GET = {None, 'expression', 'path', 'path-info', 'symbol-map'}


Expand Down Expand Up @@ -222,26 +193,20 @@ def tensor_contract(
inds_out = tuple(output_inds)

# possibly map indices into the range needed by opt-einsum
eq = _inds_to_eq(inds, inds_out)
eq = inds_to_eq(inds, inds_out)

if get is not None:
check_opt('get', get, _VALID_CONTRACT_GET)

if get == 'symbol-map':
return {
get_symbol(i): ix
for i, ix in enumerate(unique(concat(inds)))
}
return inds_to_symbols(inds)

if get == 'path':
return get_contractor(eq, *shapes, get='path', **contract_opts)

if get == 'path-info':
pathinfo = get_contractor(eq, *shapes, get='info', **contract_opts)
pathinfo.quimb_symbol_map = {
get_symbol(i): ix
for i, ix in enumerate(unique(concat(inds)))
}
pathinfo.quimb_symbol_map = inds_to_symbols(inds)
return pathinfo

if get == 'expression':
Expand Down Expand Up @@ -2228,7 +2193,7 @@ def collapse_repeated(self, inplace=False):
if len(old_inds) == len(new_inds):
return t

eq = _inds_to_eq((old_inds,), new_inds)
eq = inds_to_eq((old_inds,), new_inds)
t.modify(apply=lambda x: do('einsum', eq, x, like=x),
inds=new_inds, left_inds=None)

Expand Down Expand Up @@ -4220,11 +4185,7 @@ def get_symbol_map(self):
--------
get_equation, get_inputs_output_size_dict
"""
symbol_map = empty_symbol_map()
for t in self:
for ix in t.inds:
symbol_map[ix]
return symbol_map
return inds_to_symbols(t.inds for t in self)

def get_equation(self, output_inds=None):
"""Get the 'equation' describing this tensor network, in ``einsum``
Expand Down Expand Up @@ -4254,7 +4215,7 @@ def get_equation(self, output_inds=None):
if output_inds is None:
output_inds = self.outer_inds()
inputs_inds = tuple(t.inds for t in self)
return _inds_to_eq(inputs_inds, output_inds)
return inds_to_eq(inputs_inds, output_inds)

def get_inputs_output_size_dict(self, output_inds=None):
"""Get a tuple of ``inputs``, ``output`` and ``size_dict`` suitable for
Expand Down
2 changes: 1 addition & 1 deletion tests/test_tensor/test_tensor_2d_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_construct(self, Lx, Ly, H2_type, H1_type):
assert len({id(x) for x in ham.terms.values()}) == 1

print(ham)
fig = ham.draw(return_fig=True)
fig, ax = ham.draw()
plt.close(fig)

@pytest.mark.parametrize('Lx', [4, 5])
Expand Down

0 comments on commit 6852f2d

Please sign in to comment.