diff --git a/linear_operator/__init__.py b/linear_operator/__init__.py index bdfe087a..627b2c66 100644 --- a/linear_operator/__init__.py +++ b/linear_operator/__init__.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -from . import beta_features, operators, settings, utils -from .functions import ( +from linear_operator import beta_features, operators, settings, utils +from linear_operator.functions import ( add_diagonal, add_jitter, diagonalization, @@ -13,7 +13,7 @@ solve, sqrt_inv_matmul, ) -from .operators import LinearOperator, to_dense, to_linear_operator +from linear_operator.operators import LinearOperator, to_dense, to_linear_operator # Read version number as written by setuptools_scm try: diff --git a/linear_operator/beta_features.py b/linear_operator/beta_features.py index f7560676..1bb84d6b 100644 --- a/linear_operator/beta_features.py +++ b/linear_operator/beta_features.py @@ -2,7 +2,7 @@ import warnings -from .settings import _feature_flag +from linear_operator.settings import _feature_flag class _moved_beta_feature(object): diff --git a/linear_operator/functions/__init__.py b/linear_operator/functions/__init__.py index bb2c7c1f..8a2940a9 100644 --- a/linear_operator/functions/__init__.py +++ b/linear_operator/functions/__init__.py @@ -6,7 +6,7 @@ import torch -from ._dsmm import DSMM +from linear_operator.functions._dsmm import DSMM LinearOperatorType = Any # Want this to be "LinearOperator" but runtime type checker can't handle @@ -23,7 +23,7 @@ def add_diagonal(input: Anysor, diag: torch.Tensor) -> LinearOperatorType: :return: :math:`\mathbf A + \text{diag}(\mathbf d)`, where :math:`\mathbf A` is the linear operator and :math:`\mathbf d` is the diagonal component """ - from ..operators import to_linear_operator + from linear_operator.operators import to_linear_operator return to_linear_operator(input).add_diagonal(diag) @@ -61,7 +61,7 @@ def diagonalization( based on size if not specified. :return: eigenvalues and eigenvectors representing the diagonalization. """ - from ..operators import to_linear_operator + from linear_operator.operators import to_linear_operator return to_linear_operator(input).diagonalization(method=method) @@ -105,7 +105,7 @@ def inv_quad(input: Anysor, inv_quad_rhs: torch.Tensor, reduce_inv_quad: bool = :returns: The inverse quadratic term. If `reduce_inv_quad=True`, the inverse quadratic term is of shape (...). Otherwise, it is (... x M). """ - from ..operators import to_linear_operator + from linear_operator.operators import to_linear_operator return to_linear_operator(input).inv_quad(inv_quad_rhs, reduce_inv_quad=reduce_inv_quad) @@ -127,7 +127,7 @@ def inv_quad_logdet( :returns: The inverse quadratic term (or None), and the logdet term (or None). If `reduce_inv_quad=True`, the inverse quadratic term is of shape (...). Otherwise, it is (... x M). """ - from ..operators import to_linear_operator + from linear_operator.operators import to_linear_operator return to_linear_operator(input).inv_quad_logdet(inv_quad_rhs, logdet, reduce_inv_quad=reduce_inv_quad) @@ -156,7 +156,7 @@ def pivoted_cholesky( .. 
_Harbrecht et al., 2012: https://www.sciencedirect.com/science/article/pii/S0168927411001814 """ - from ..operators import to_linear_operator + from linear_operator.operators import to_linear_operator return to_linear_operator(input).pivoted_cholesky(rank=rank, error_tol=error_tol, return_pivots=return_pivots) @@ -173,7 +173,7 @@ def root_decomposition(input: Anysor, method: Optional[str] = None) -> LinearOpe "cholesky", "lanczos", "symeig", "pivoted_cholesky", or "svd". :return: A tensor :math:`\mathbf R` such that :math:`\mathbf R \mathbf R^\top \approx \mathbf A`. """ - from ..operators import to_linear_operator + from linear_operator.operators import to_linear_operator return to_linear_operator(input).root_decomposition(method=method) @@ -199,7 +199,7 @@ def root_inv_decomposition( :param method: Root decomposition method to use (symeig, diagonalization, lanczos, or cholesky). :return: A tensor :math:`\mathbf R` such that :math:`\mathbf R \mathbf R^\top \approx \mathbf A^{-1}`. """ - from ..operators import to_linear_operator + from linear_operator.operators import to_linear_operator return to_linear_operator(input).root_inv_decomposition( initial_vectors=initial_vectors, test_vectors=test_vectors, method=method @@ -235,7 +235,7 @@ def solve(input: Anysor, rhs: torch.Tensor, lhs: Optional[torch.Tensor] = None) :param lhs: :math:`\mathbf L` - the left hand side :return: :math:`\mathbf A^{-1} \mathbf R` or :math:`\mathbf L \mathbf A^{-1} \mathbf R`. """ - from ..operators import to_linear_operator + from linear_operator.operators import to_linear_operator return to_linear_operator(input).solve(right_tensor=rhs, left_tensor=lhs) @@ -268,7 +268,7 @@ def sqrt_inv_matmul( :param lhs: :math:`\mathbf L` - the left hand side :return: :math:`\mathbf A^{-1/2} \mathbf R` or :math:`\mathbf L \mathbf A^{-1/2} \mathbf R`. """ - from ..operators import to_linear_operator + from linear_operator.operators import to_linear_operator return to_linear_operator(input).sqrt_inv_matmul(rhs=rhs, lhs=lhs) diff --git a/linear_operator/functions/_diagonalization.py b/linear_operator/functions/_diagonalization.py index 81e05a6b..8b7241ab 100644 --- a/linear_operator/functions/_diagonalization.py +++ b/linear_operator/functions/_diagonalization.py @@ -3,8 +3,8 @@ import torch from torch.autograd import Function -from .. import settings -from ..utils import lanczos +from linear_operator import settings +from linear_operator.utils import lanczos class Diagonalization(Function): diff --git a/linear_operator/functions/_dsmm.py b/linear_operator/functions/_dsmm.py index a25c27a5..8c316e07 100644 --- a/linear_operator/functions/_dsmm.py +++ b/linear_operator/functions/_dsmm.py @@ -2,7 +2,7 @@ from torch.autograd import Function -from ..utils.sparse import bdsmm +from linear_operator.utils.sparse import bdsmm class DSMM(Function): diff --git a/linear_operator/functions/_inv_quad.py b/linear_operator/functions/_inv_quad.py index 488b41ab..6d7a611b 100644 --- a/linear_operator/functions/_inv_quad.py +++ b/linear_operator/functions/_inv_quad.py @@ -3,7 +3,7 @@ import torch from torch.autograd import Function -from .. import settings +from linear_operator import settings def _solve(linear_op, rhs): diff --git a/linear_operator/functions/_inv_quad_logdet.py b/linear_operator/functions/_inv_quad_logdet.py index 304bf7f5..ec3e7896 100644 --- a/linear_operator/functions/_inv_quad_logdet.py +++ b/linear_operator/functions/_inv_quad_logdet.py @@ -5,9 +5,9 @@ import torch from torch.autograd import Function -from .. 
import settings -from ..utils.lanczos import lanczos_tridiag_to_diag -from ..utils.stochastic_lq import StochasticLQ +from linear_operator import settings +from linear_operator.utils.lanczos import lanczos_tridiag_to_diag +from linear_operator.utils.stochastic_lq import StochasticLQ class InvQuadLogdet(Function): diff --git a/linear_operator/functions/_matmul.py b/linear_operator/functions/_matmul.py index 35b571b9..13cf96ee 100644 --- a/linear_operator/functions/_matmul.py +++ b/linear_operator/functions/_matmul.py @@ -2,7 +2,7 @@ from torch.autograd import Function -from .. import settings +from linear_operator import settings class Matmul(Function): diff --git a/linear_operator/functions/_pivoted_cholesky.py b/linear_operator/functions/_pivoted_cholesky.py index 695f8af8..023eb2ae 100644 --- a/linear_operator/functions/_pivoted_cholesky.py +++ b/linear_operator/functions/_pivoted_cholesky.py @@ -3,9 +3,9 @@ import torch from torch.autograd import Function -from .. import settings -from ..utils.cholesky import psd_safe_cholesky -from ..utils.permutation import apply_permutation, inverse_permutation +from linear_operator import settings +from linear_operator.utils.cholesky import psd_safe_cholesky +from linear_operator.utils.permutation import apply_permutation, inverse_permutation class PivotedCholesky(Function): diff --git a/linear_operator/functions/_root_decomposition.py b/linear_operator/functions/_root_decomposition.py index a8b7b196..5b3fbe58 100644 --- a/linear_operator/functions/_root_decomposition.py +++ b/linear_operator/functions/_root_decomposition.py @@ -3,8 +3,8 @@ import torch from torch.autograd import Function -from .. import settings -from ..utils import lanczos +from linear_operator import settings +from linear_operator.utils import lanczos class RootDecomposition(Function): @@ -29,7 +29,7 @@ def forward( :return: :attr:`R`, such that :math:`R R^T \approx A`, and :attr:`R_inv`, such that :math:`R_{inv} R_{inv}^T \approx A^{-1}` (will only be populated if self.inverse = True) """ - from ..operators import to_linear_operator + from linear_operator.operators import to_linear_operator ctx.representation_tree = representation_tree ctx.device = device diff --git a/linear_operator/functions/_solve.py b/linear_operator/functions/_solve.py index 659df944..8783b68b 100644 --- a/linear_operator/functions/_solve.py +++ b/linear_operator/functions/_solve.py @@ -3,11 +3,11 @@ import torch from torch.autograd import Function -from .. import settings +from linear_operator import settings def _solve(linear_op, rhs): - from ..operators import CholLinearOperator, TriangularLinearOperator + from linear_operator.operators import CholLinearOperator, TriangularLinearOperator if isinstance(linear_op, (CholLinearOperator, TriangularLinearOperator)): # May want to do this for some KroneckerProductLinearOperators and possibly diff --git a/linear_operator/functions/_sqrt_inv_matmul.py b/linear_operator/functions/_sqrt_inv_matmul.py index 4aa5ba05..c137c762 100644 --- a/linear_operator/functions/_sqrt_inv_matmul.py +++ b/linear_operator/functions/_sqrt_inv_matmul.py @@ -3,7 +3,7 @@ import torch from torch.autograd import Function -from .. 
import settings, utils +from linear_operator import settings, utils class SqrtInvMatmul(Function): diff --git a/linear_operator/operators/__init__.py b/linear_operator/operators/__init__.py index 06fbc79f..41d093cf 100644 --- a/linear_operator/operators/__init__.py +++ b/linear_operator/operators/__init__.py @@ -1,40 +1,45 @@ #!/usr/bin/env python3 -from ._linear_operator import LinearOperator, to_dense -from .added_diag_linear_operator import AddedDiagLinearOperator -from .batch_repeat_linear_operator import BatchRepeatLinearOperator -from .block_diag_linear_operator import BlockDiagLinearOperator -from .block_interleaved_linear_operator import BlockInterleavedLinearOperator -from .block_linear_operator import BlockLinearOperator -from .cat_linear_operator import cat, CatLinearOperator -from .chol_linear_operator import CholLinearOperator -from .constant_mul_linear_operator import ConstantMulLinearOperator -from .dense_linear_operator import DenseLinearOperator, to_linear_operator -from .diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator -from .identity_linear_operator import IdentityLinearOperator -from .interpolated_linear_operator import InterpolatedLinearOperator -from .keops_linear_operator import KeOpsLinearOperator -from .kernel_linear_operator import KernelLinearOperator -from .kronecker_product_added_diag_linear_operator import KroneckerProductAddedDiagLinearOperator -from .kronecker_product_linear_operator import ( +from linear_operator.operators._linear_operator import LinearOperator, to_dense +from linear_operator.operators.added_diag_linear_operator import AddedDiagLinearOperator +from linear_operator.operators.batch_repeat_linear_operator import BatchRepeatLinearOperator +from linear_operator.operators.block_diag_linear_operator import BlockDiagLinearOperator +from linear_operator.operators.block_interleaved_linear_operator import BlockInterleavedLinearOperator +from linear_operator.operators.block_linear_operator import BlockLinearOperator +from linear_operator.operators.cat_linear_operator import cat, CatLinearOperator +from linear_operator.operators.chol_linear_operator import CholLinearOperator +from linear_operator.operators.constant_mul_linear_operator import ConstantMulLinearOperator +from linear_operator.operators.dense_linear_operator import DenseLinearOperator, to_linear_operator +from linear_operator.operators.diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator +from linear_operator.operators.identity_linear_operator import IdentityLinearOperator +from linear_operator.operators.interpolated_linear_operator import InterpolatedLinearOperator +from linear_operator.operators.keops_linear_operator import KeOpsLinearOperator +from linear_operator.operators.kernel_linear_operator import KernelLinearOperator +from linear_operator.operators.kronecker_product_added_diag_linear_operator import ( + KroneckerProductAddedDiagLinearOperator, +) +from linear_operator.operators.kronecker_product_linear_operator import ( KroneckerProductDiagLinearOperator, KroneckerProductLinearOperator, KroneckerProductTriangularLinearOperator, ) -from .low_rank_root_added_diag_linear_operator import LowRankRootAddedDiagLinearOperator -from .low_rank_root_linear_operator import LowRankRootLinearOperator -from .masked_linear_operator import MaskedLinearOperator -from .matmul_linear_operator import MatmulLinearOperator -from .mul_linear_operator import MulLinearOperator -from .permutation_linear_operator import PermutationLinearOperator, 
TransposePermutationLinearOperator -from .psd_sum_linear_operator import PsdSumLinearOperator -from .root_linear_operator import RootLinearOperator -from .sum_batch_linear_operator import SumBatchLinearOperator -from .sum_kronecker_linear_operator import SumKroneckerLinearOperator -from .sum_linear_operator import SumLinearOperator -from .toeplitz_linear_operator import ToeplitzLinearOperator -from .triangular_linear_operator import TriangularLinearOperator -from .zero_linear_operator import ZeroLinearOperator +from linear_operator.operators.low_rank_root_added_diag_linear_operator import LowRankRootAddedDiagLinearOperator +from linear_operator.operators.low_rank_root_linear_operator import LowRankRootLinearOperator +from linear_operator.operators.masked_linear_operator import MaskedLinearOperator +from linear_operator.operators.matmul_linear_operator import MatmulLinearOperator +from linear_operator.operators.mul_linear_operator import MulLinearOperator +from linear_operator.operators.permutation_linear_operator import ( + PermutationLinearOperator, + TransposePermutationLinearOperator, +) +from linear_operator.operators.psd_sum_linear_operator import PsdSumLinearOperator +from linear_operator.operators.root_linear_operator import RootLinearOperator +from linear_operator.operators.sum_batch_linear_operator import SumBatchLinearOperator +from linear_operator.operators.sum_kronecker_linear_operator import SumKroneckerLinearOperator +from linear_operator.operators.sum_linear_operator import SumLinearOperator +from linear_operator.operators.toeplitz_linear_operator import ToeplitzLinearOperator +from linear_operator.operators.triangular_linear_operator import TriangularLinearOperator +from linear_operator.operators.zero_linear_operator import ZeroLinearOperator __all__ = [ "to_dense", diff --git a/linear_operator/operators/_linear_operator.py b/linear_operator/operators/_linear_operator.py index eb31dcc4..b3243571 100644 --- a/linear_operator/operators/_linear_operator.py +++ b/linear_operator/operators/_linear_operator.py @@ -24,21 +24,22 @@ import linear_operator -from .. 
import settings, utils -from ..functions._diagonalization import Diagonalization -from ..functions._inv_quad import InvQuad -from ..functions._inv_quad_logdet import InvQuadLogdet -from ..functions._matmul import Matmul -from ..functions._pivoted_cholesky import PivotedCholesky -from ..functions._root_decomposition import RootDecomposition -from ..functions._solve import Solve -from ..functions._sqrt_inv_matmul import SqrtInvMatmul -from ..utils.broadcasting import _matmul_broadcast_shape -from ..utils.cholesky import psd_safe_cholesky -from ..utils.deprecation import _deprecate_renamed_methods -from ..utils.errors import CachingError -from ..utils.generic import _to_helper -from ..utils.getitem import ( +from linear_operator import settings, utils +from linear_operator.functions._diagonalization import Diagonalization +from linear_operator.functions._inv_quad import InvQuad +from linear_operator.functions._inv_quad_logdet import InvQuadLogdet +from linear_operator.functions._matmul import Matmul +from linear_operator.functions._pivoted_cholesky import PivotedCholesky +from linear_operator.functions._root_decomposition import RootDecomposition +from linear_operator.functions._solve import Solve +from linear_operator.functions._sqrt_inv_matmul import SqrtInvMatmul +from linear_operator.operators.linear_operator_representation_tree import LinearOperatorRepresentationTree +from linear_operator.utils.broadcasting import _matmul_broadcast_shape +from linear_operator.utils.cholesky import psd_safe_cholesky +from linear_operator.utils.deprecation import _deprecate_renamed_methods +from linear_operator.utils.errors import CachingError +from linear_operator.utils.generic import _to_helper +from linear_operator.utils.getitem import ( _compute_getitem_size, _convert_indices_to_tensors, _is_noop_index, @@ -46,11 +47,16 @@ _noop_index, IndexType, ) -from ..utils.lanczos import _postprocess_lanczos_root_inv_decomp -from ..utils.memoize import _is_in_cache_ignore_all_args, _is_in_cache_ignore_args, add_to_cache, cached, pop_from_cache -from ..utils.pinverse import stable_pinverse -from ..utils.warnings import NumericalWarning, PerformanceWarning -from .linear_operator_representation_tree import LinearOperatorRepresentationTree +from linear_operator.utils.lanczos import _postprocess_lanczos_root_inv_decomp +from linear_operator.utils.memoize import ( + _is_in_cache_ignore_all_args, + _is_in_cache_ignore_args, + add_to_cache, + cached, + pop_from_cache, +) +from linear_operator.utils.pinverse import stable_pinverse +from linear_operator.utils.warnings import NumericalWarning, PerformanceWarning _HANDLED_FUNCTIONS = {} _HANDLED_SECOND_ARG_FUNCTIONS = {} @@ -298,7 +304,7 @@ def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: I col_interp_values = torch.tensor(1.0, dtype=self.dtype, device=self.device).expand_as(col_interp_indices) # Construct interpolated LinearOperator - from . import InterpolatedLinearOperator + from linear_operator.operators import InterpolatedLinearOperator res = InterpolatedLinearOperator( self, @@ -441,7 +447,7 @@ def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indice col_interp_values = torch.tensor(1.0, dtype=self.dtype, device=self.device).expand_as(col_interp_indices) # Construct interpolated LinearOperator - from . 
import InterpolatedLinearOperator + from linear_operator.operators import InterpolatedLinearOperator res = ( InterpolatedLinearOperator( @@ -504,8 +510,8 @@ def _cholesky( :param upper: Upper triangular or lower triangular factor (default: False). :return: Cholesky factor (lower or upper triangular) """ - from .keops_linear_operator import KeOpsLinearOperator - from .triangular_linear_operator import TriangularLinearOperator + from linear_operator.operators.keops_linear_operator import KeOpsLinearOperator + from linear_operator.operators.triangular_linear_operator import TriangularLinearOperator evaluated_kern_mat = self.evaluate_kernel() @@ -581,7 +587,7 @@ def _mul_constant( :param other: The constant (or batch of constants) """ - from .constant_mul_linear_operator import ConstantMulLinearOperator + from linear_operator.operators.constant_mul_linear_operator import ConstantMulLinearOperator return ConstantMulLinearOperator(self, other) @@ -598,8 +604,8 @@ def _mul_matrix( :param other: The other linear operator to multiply against. """ - from .dense_linear_operator import DenseLinearOperator - from .mul_linear_operator import MulLinearOperator + from linear_operator.operators.dense_linear_operator import DenseLinearOperator + from linear_operator.operators.mul_linear_operator import MulLinearOperator self = self.evaluate_kernel() other = other.evaluate_kernel() @@ -635,8 +641,8 @@ def _prod_batch(self, dim: int) -> LinearOperator: :param dim: The (positive valued) dimension to multiply """ - from .mul_linear_operator import MulLinearOperator - from .root_linear_operator import RootLinearOperator + from linear_operator.operators.mul_linear_operator import MulLinearOperator + from linear_operator.operators.root_linear_operator import RootLinearOperator if self.size(dim) == 1: return self.squeeze(dim) @@ -731,7 +737,7 @@ def _root_inv_decomposition( :param test_vectors: Vectors used to test the accuracy of the decomposition. :return: A tensor :math:`\mathbf R` such that :math:`\mathbf R \mathbf R^\top \approx \mathbf A^{-1}`. """ - from .root_linear_operator import RootLinearOperator + from linear_operator.operators.root_linear_operator import RootLinearOperator roots, inv_roots = RootDecomposition.apply( self.representation_tree(), @@ -850,7 +856,7 @@ def _sum_batch(self, dim: int) -> LinearOperator: :param dim: The (positive valued) dimension to sum """ - from .sum_batch_linear_operator import SumBatchLinearOperator + from linear_operator.operators.sum_batch_linear_operator import SumBatchLinearOperator return SumBatchLinearOperator(self, block_dim=dim) @@ -954,8 +960,8 @@ def add_diagonal( :return: :math:`\mathbf A + \text{diag}(\mathbf d)`, where :math:`\mathbf A` is the linear operator and :math:`\mathbf d` is the diagonal component """ - from .added_diag_linear_operator import AddedDiagLinearOperator - from .diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator + from linear_operator.operators.added_diag_linear_operator import AddedDiagLinearOperator + from linear_operator.operators.diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator if not self.is_square: raise RuntimeError("add_diagonal only defined for square matrices") @@ -1049,10 +1055,10 @@ def add_low_rank( .. _Kernel Interpolation for Scalable Online Gaussian Processes: https://arxiv.org/abs/2103.01454. """ - from . 
import to_linear_operator - from .root_linear_operator import RootLinearOperator - from .sum_linear_operator import SumLinearOperator - from .triangular_linear_operator import TriangularLinearOperator + from linear_operator.operators import to_linear_operator + from linear_operator.operators.root_linear_operator import RootLinearOperator + from linear_operator.operators.sum_linear_operator import SumLinearOperator + from linear_operator.operators.triangular_linear_operator import TriangularLinearOperator if not isinstance(self, SumLinearOperator): new_linear_op = self + to_linear_operator(low_rank_mat.matmul(low_rank_mat.mT)) @@ -1219,10 +1225,10 @@ def cat_rows( .. _Efficient Nonmyopic Bayesian Optimization via One-Shot Multistep Trees: https://arxiv.org/abs/2006.15779 """ - from . import to_linear_operator - from .cat_linear_operator import CatLinearOperator - from .root_linear_operator import RootLinearOperator - from .triangular_linear_operator import TriangularLinearOperator + from linear_operator.operators import to_linear_operator + from linear_operator.operators.cat_linear_operator import CatLinearOperator + from linear_operator.operators.root_linear_operator import RootLinearOperator + from linear_operator.operators.triangular_linear_operator import TriangularLinearOperator if not generate_roots and generate_inv_roots: warnings.warn( @@ -1436,7 +1442,7 @@ def diagonalization( method = "lanczos" if method == "lanczos": - from ..operators import to_linear_operator + from linear_operator.operators import to_linear_operator evals, evecs = Diagonalization.apply( self.representation_tree(), @@ -1471,7 +1477,7 @@ def div(self, other: Union[float, torch.Tensor]) -> LinearOperator: :param other: Object to divide against :return: Result of division. """ - from .zero_linear_operator import ZeroLinearOperator + from linear_operator.operators.zero_linear_operator import ZeroLinearOperator if isinstance(other, ZeroLinearOperator): raise RuntimeError("Attempted to divide by a ZeroLinearOperator (divison by zero)") @@ -1686,8 +1692,8 @@ def inv_quad_logdet( """ # Special case: use Cholesky to compute these terms if settings.fast_computations.log_prob.off() or (self.size(-1) <= settings.max_cholesky_size.value()): - from .chol_linear_operator import CholLinearOperator - from .triangular_linear_operator import TriangularLinearOperator + from linear_operator.operators.chol_linear_operator import CholLinearOperator + from linear_operator.operators.triangular_linear_operator import TriangularLinearOperator # if the root decomposition has already been computed and is triangular we can use it instead # of computing the cholesky. @@ -1747,7 +1753,7 @@ def inv_quad_logdet( preconditioner, precond_lt, logdet_p = self._preconditioner() if precond_lt is None: - from ..operators.identity_linear_operator import IdentityLinearOperator + from linear_operator.operators.identity_linear_operator import IdentityLinearOperator precond_lt = IdentityLinearOperator( diag_shape=self.size(-1), @@ -1824,7 +1830,7 @@ def matmul( _matmul_broadcast_shape(self.shape, other.shape) if isinstance(other, LinearOperator): - from .matmul_linear_operator import MatmulLinearOperator + from linear_operator.operators.matmul_linear_operator import MatmulLinearOperator return MatmulLinearOperator(self, other) @@ -1855,8 +1861,8 @@ def mul( :obj:`~linear_operator.operators.ConstantMulLinearOperator`. If :obj:`other` was a matrix or LinearOperator, this will likely be a :obj:`MulLinearOperator`. 
""" - from .dense_linear_operator import to_linear_operator - from .zero_linear_operator import ZeroLinearOperator + from linear_operator.operators.dense_linear_operator import to_linear_operator + from linear_operator.operators.zero_linear_operator import ZeroLinearOperator if isinstance(other, ZeroLinearOperator): return other @@ -2021,7 +2027,7 @@ def repeat(self, *sizes: Union[int, Tuple[int, ...]]) -> LinearOperator: :param sizes: The number of times to repeat this tensor along each dimension. :return: A LinearOperator with repeated dimensions. """ - from .batch_repeat_linear_operator import BatchRepeatLinearOperator + from linear_operator.operators.batch_repeat_linear_operator import BatchRepeatLinearOperator # Short path if no repetition is necessary if all(x == 1 for x in sizes) and len(sizes) == self.dim(): @@ -2130,9 +2136,9 @@ def root_decomposition( "cholesky", "lanczos", "symeig", "pivoted_cholesky", or "svd". :return: A tensor :math:`\mathbf R` such that :math:`\mathbf R \mathbf R^\top \approx \mathbf A`. """ - from . import to_linear_operator - from .chol_linear_operator import CholLinearOperator - from .root_linear_operator import RootLinearOperator + from linear_operator.operators import to_linear_operator + from linear_operator.operators.chol_linear_operator import CholLinearOperator + from linear_operator.operators.root_linear_operator import RootLinearOperator if not self.is_square: raise RuntimeError( @@ -2201,8 +2207,8 @@ def root_inv_decomposition( :param method: Root decomposition method to use (symeig, diagonalization, lanczos, or cholesky). :return: A tensor :math:`\mathbf R` such that :math:`\mathbf R \mathbf R^\top \approx \mathbf A^{-1}`. """ - from .dense_linear_operator import to_linear_operator - from .root_linear_operator import RootLinearOperator + from linear_operator.operators.dense_linear_operator import to_linear_operator + from linear_operator.operators.root_linear_operator import RootLinearOperator if not self.is_square: raise RuntimeError( @@ -2705,7 +2711,7 @@ def zero_mean_mvn_samples( :param num_samples: Number of samples to draw. :return: Samples from MVN :math:`\mathcal N( \mathbf 0, \mathbf A)`. """ - from ..utils.contour_integral_quad import contour_integral_quad + from linear_operator.utils.contour_integral_quad import contour_integral_quad if settings.ciq_samples.on(): base_samples = torch.randn( @@ -2753,12 +2759,12 @@ def __add__( self: Float[LinearOperator, "... #M #N"], other: Union[Float[Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"], float], ) -> Union[Float[LinearOperator, "... M N"], Float[Tensor, "... 
M N"]]: - from .added_diag_linear_operator import AddedDiagLinearOperator - from .dense_linear_operator import to_linear_operator - from .diag_linear_operator import DiagLinearOperator - from .root_linear_operator import RootLinearOperator - from .sum_linear_operator import SumLinearOperator - from .zero_linear_operator import ZeroLinearOperator + from linear_operator.operators.added_diag_linear_operator import AddedDiagLinearOperator + from linear_operator.operators.dense_linear_operator import to_linear_operator + from linear_operator.operators.diag_linear_operator import DiagLinearOperator + from linear_operator.operators.root_linear_operator import RootLinearOperator + from linear_operator.operators.sum_linear_operator import SumLinearOperator + from linear_operator.operators.zero_linear_operator import ZeroLinearOperator if isinstance(other, ZeroLinearOperator): return self diff --git a/linear_operator/operators/added_diag_linear_operator.py b/linear_operator/operators/added_diag_linear_operator.py index 546206ad..71a82e46 100644 --- a/linear_operator/operators/added_diag_linear_operator.py +++ b/linear_operator/operators/added_diag_linear_operator.py @@ -9,14 +9,14 @@ from jaxtyping import Float from torch import Tensor -from .. import settings -from ..utils.memoize import cached -from ..utils.warnings import NumericalWarning -from ._linear_operator import LinearOperator -from .diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator -from .psd_sum_linear_operator import PsdSumLinearOperator -from .root_linear_operator import RootLinearOperator -from .sum_linear_operator import SumLinearOperator +from linear_operator import settings +from linear_operator.operators._linear_operator import LinearOperator +from linear_operator.operators.diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator +from linear_operator.operators.psd_sum_linear_operator import PsdSumLinearOperator +from linear_operator.operators.root_linear_operator import RootLinearOperator +from linear_operator.operators.sum_linear_operator import SumLinearOperator +from linear_operator.utils.memoize import cached +from linear_operator.utils.warnings import NumericalWarning class AddedDiagLinearOperator(SumLinearOperator): @@ -86,7 +86,7 @@ def __add__( self: Float[LinearOperator, "... #M #N"], other: Union[Float[Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"], float], ) -> Union[Float[LinearOperator, "... M N"], Float[Tensor, "... M N"]]: - from .diag_linear_operator import DiagLinearOperator + from linear_operator.operators.diag_linear_operator import DiagLinearOperator if isinstance(other, DiagLinearOperator): return self.__class__(self._linear_op, self._diag_tensor + other) diff --git a/linear_operator/operators/batch_repeat_linear_operator.py b/linear_operator/operators/batch_repeat_linear_operator.py index 2ab11db1..04ee478e 100644 --- a/linear_operator/operators/batch_repeat_linear_operator.py +++ b/linear_operator/operators/batch_repeat_linear_operator.py @@ -8,10 +8,10 @@ from jaxtyping import Float from torch import Tensor -from .. 
import settings -from ..utils.broadcasting import _matmul_broadcast_shape -from ..utils.memoize import cached -from ._linear_operator import IndexType, LinearOperator +from linear_operator import settings +from linear_operator.operators._linear_operator import IndexType, LinearOperator +from linear_operator.utils.broadcasting import _matmul_broadcast_shape +from linear_operator.utils.memoize import cached class BatchRepeatLinearOperator(LinearOperator): @@ -42,7 +42,7 @@ def __init__(self, base_linear_op, batch_repeat=torch.Size((1,))): def _cholesky( self: Float[LinearOperator, "*batch N N"], upper: Optional[bool] = False ) -> Float[LinearOperator, "*batch N N"]: - from .triangular_linear_operator import TriangularLinearOperator + from linear_operator.operators.triangular_linear_operator import TriangularLinearOperator res = self.base_linear_op.cholesky(upper=upper)._tensor res = res.repeat(*self.batch_repeat, 1, 1) diff --git a/linear_operator/operators/block_diag_linear_operator.py b/linear_operator/operators/block_diag_linear_operator.py index de3a9dd6..971f5de2 100644 --- a/linear_operator/operators/block_diag_linear_operator.py +++ b/linear_operator/operators/block_diag_linear_operator.py @@ -7,9 +7,10 @@ from jaxtyping import Float from torch import Tensor -from ..utils.memoize import cached -from ._linear_operator import IndexType, LinearOperator -from .block_linear_operator import BlockLinearOperator +from linear_operator.operators._linear_operator import IndexType, LinearOperator +from linear_operator.operators.block_linear_operator import BlockLinearOperator + +from linear_operator.utils.memoize import cached # metaclass of BlockDiagLinearOperator, overwrites behavior of constructor call @@ -17,7 +18,7 @@ # if base_linear_op is a DiagLinearOperator itself class _MetaBlockDiagLinearOperator(ABCMeta): def __call__(cls, base_linear_op: Union[LinearOperator, Tensor], block_dim: int = -3): - from .diag_linear_operator import DiagLinearOperator + from linear_operator.operators.diag_linear_operator import DiagLinearOperator if cls is BlockDiagLinearOperator and isinstance(base_linear_op, DiagLinearOperator): if block_dim != -3: @@ -74,7 +75,7 @@ def _add_batch_dim( def _cholesky( self: Float[LinearOperator, "*batch N N"], upper: Optional[bool] = False ) -> Float[LinearOperator, "*batch N N"]: - from .triangular_linear_operator import TriangularLinearOperator + from linear_operator.operators.triangular_linear_operator import TriangularLinearOperator chol = self.__class__(self.base_linear_op.cholesky(upper=upper)) return TriangularLinearOperator(chol, upper=upper) @@ -183,7 +184,7 @@ def matmul( self: Float[LinearOperator, "*batch M N"], other: Union[Float[Tensor, "*batch2 N P"], Float[Tensor, "*batch2 N"], Float[LinearOperator, "*batch2 N P"]], ) -> Union[Float[Tensor, "... M P"], Float[Tensor, "... M"], Float[LinearOperator, "... 
M P"]]: - from .diag_linear_operator import DiagLinearOperator + from linear_operator.operators.diag_linear_operator import DiagLinearOperator # this is trivial if we multiply two BlockDiagLinearOperator with matching block sizes if isinstance(other, BlockDiagLinearOperator) and self.base_linear_op.shape == other.base_linear_op.shape: diff --git a/linear_operator/operators/block_interleaved_linear_operator.py b/linear_operator/operators/block_interleaved_linear_operator.py index 1823d219..ccfd6bb1 100644 --- a/linear_operator/operators/block_interleaved_linear_operator.py +++ b/linear_operator/operators/block_interleaved_linear_operator.py @@ -5,9 +5,10 @@ from jaxtyping import Float from torch import Tensor -from ..utils.memoize import cached -from ._linear_operator import IndexType, LinearOperator -from .block_linear_operator import BlockLinearOperator +from linear_operator.operators._linear_operator import IndexType, LinearOperator +from linear_operator.operators.block_linear_operator import BlockLinearOperator + +from linear_operator.utils.memoize import cached class BlockInterleavedLinearOperator(BlockLinearOperator): @@ -42,7 +43,7 @@ def _add_batch_dim(self, other): def _cholesky( self: Float[LinearOperator, "*batch N N"], upper: Optional[bool] = False ) -> Float[LinearOperator, "*batch N N"]: - from .triangular_linear_operator import TriangularLinearOperator + from linear_operator.operators.triangular_linear_operator import TriangularLinearOperator chol = self.__class__(self.base_linear_op.cholesky(upper=upper)) return TriangularLinearOperator(chol, upper=upper) diff --git a/linear_operator/operators/block_linear_operator.py b/linear_operator/operators/block_linear_operator.py index 7de2fbd0..8ad44763 100644 --- a/linear_operator/operators/block_linear_operator.py +++ b/linear_operator/operators/block_linear_operator.py @@ -7,9 +7,10 @@ from jaxtyping import Float from torch import Tensor -from ..utils.getitem import _is_noop_index, _noop_index -from ._linear_operator import IndexType, LinearOperator -from .dense_linear_operator import to_linear_operator +from linear_operator.operators._linear_operator import IndexType, LinearOperator +from linear_operator.operators.dense_linear_operator import to_linear_operator + +from linear_operator.utils.getitem import _is_noop_index, _noop_index class BlockLinearOperator(LinearOperator): @@ -154,7 +155,7 @@ def _mul_constant( ) -> Float[LinearOperator, "*batch M N"]: # We're using a custom method here - the constant mul is applied to the base_lazy tensor # This preserves the block structure - from .constant_mul_linear_operator import ConstantMulLinearOperator + from linear_operator.operators.constant_mul_linear_operator import ConstantMulLinearOperator return self.__class__(ConstantMulLinearOperator(self.base_linear_op, other)) diff --git a/linear_operator/operators/cat_linear_operator.py b/linear_operator/operators/cat_linear_operator.py index 1adc300e..521d8e02 100644 --- a/linear_operator/operators/cat_linear_operator.py +++ b/linear_operator/operators/cat_linear_operator.py @@ -8,12 +8,13 @@ from jaxtyping import Float from torch import Tensor -from ..utils.broadcasting import _matmul_broadcast_shape -from ..utils.deprecation import bool_compat -from ..utils.generic import _to_helper -from ..utils.getitem import _noop_index -from ._linear_operator import IndexType, LinearOperator, to_dense -from .dense_linear_operator import DenseLinearOperator, to_linear_operator +from linear_operator.operators._linear_operator import IndexType, 
LinearOperator, to_dense +from linear_operator.operators.dense_linear_operator import DenseLinearOperator, to_linear_operator + +from linear_operator.utils.broadcasting import _matmul_broadcast_shape +from linear_operator.utils.deprecation import bool_compat +from linear_operator.utils.generic import _to_helper +from linear_operator.utils.getitem import _noop_index def cat(inputs, dim=0, output_device=None): diff --git a/linear_operator/operators/chol_linear_operator.py b/linear_operator/operators/chol_linear_operator.py index dfd0706a..6b502d91 100644 --- a/linear_operator/operators/chol_linear_operator.py +++ b/linear_operator/operators/chol_linear_operator.py @@ -9,10 +9,11 @@ from jaxtyping import Float from torch import Tensor -from ..utils.memoize import cached -from ._linear_operator import LinearOperator -from .root_linear_operator import RootLinearOperator -from .triangular_linear_operator import _TriangularLinearOperatorBase, TriangularLinearOperator +from linear_operator.operators._linear_operator import LinearOperator +from linear_operator.operators.root_linear_operator import RootLinearOperator +from linear_operator.operators.triangular_linear_operator import _TriangularLinearOperatorBase, TriangularLinearOperator + +from linear_operator.utils.memoize import cached class CholLinearOperator(RootLinearOperator): diff --git a/linear_operator/operators/constant_mul_linear_operator.py b/linear_operator/operators/constant_mul_linear_operator.py index e6c6de60..44215365 100644 --- a/linear_operator/operators/constant_mul_linear_operator.py +++ b/linear_operator/operators/constant_mul_linear_operator.py @@ -8,9 +8,10 @@ from jaxtyping import Float from torch import Tensor -from ..utils.memoize import cached -from ._linear_operator import IndexType, LinearOperator -from .root_linear_operator import RootLinearOperator +from linear_operator.operators._linear_operator import IndexType, LinearOperator +from linear_operator.operators.root_linear_operator import RootLinearOperator + +from linear_operator.utils.memoize import cached class ConstantMulLinearOperator(LinearOperator): diff --git a/linear_operator/operators/dense_linear_operator.py b/linear_operator/operators/dense_linear_operator.py index fb5e3eea..a55db844 100644 --- a/linear_operator/operators/dense_linear_operator.py +++ b/linear_operator/operators/dense_linear_operator.py @@ -8,7 +8,7 @@ from jaxtyping import Float from torch import Tensor -from ._linear_operator import IndexType, LinearOperator, to_dense +from linear_operator.operators._linear_operator import IndexType, LinearOperator, to_dense class DenseLinearOperator(LinearOperator): diff --git a/linear_operator/operators/diag_linear_operator.py b/linear_operator/operators/diag_linear_operator.py index bff03a9d..a5553bd0 100644 --- a/linear_operator/operators/diag_linear_operator.py +++ b/linear_operator/operators/diag_linear_operator.py @@ -8,12 +8,12 @@ from jaxtyping import Float from torch import Tensor -from .. 
import settings -from ..utils.memoize import cached -from ._linear_operator import IndexType, LinearOperator -from .block_diag_linear_operator import BlockDiagLinearOperator -from .dense_linear_operator import DenseLinearOperator -from .triangular_linear_operator import TriangularLinearOperator +from linear_operator import settings +from linear_operator.operators._linear_operator import IndexType, LinearOperator +from linear_operator.operators.block_diag_linear_operator import BlockDiagLinearOperator +from linear_operator.operators.dense_linear_operator import DenseLinearOperator +from linear_operator.operators.triangular_linear_operator import TriangularLinearOperator +from linear_operator.utils.memoize import cached class DiagLinearOperator(TriangularLinearOperator): @@ -33,7 +33,7 @@ def __add__( ) -> Union[Float[LinearOperator, "... M N"], Float[Tensor, "... M N"]]: if isinstance(other, DiagLinearOperator): return self.add_diagonal(other._diag) - from .added_diag_linear_operator import AddedDiagLinearOperator + from linear_operator.operators.added_diag_linear_operator import AddedDiagLinearOperator return AddedDiagLinearOperator(other, self) diff --git a/linear_operator/operators/identity_linear_operator.py b/linear_operator/operators/identity_linear_operator.py index 3f8444c9..435e7cf0 100644 --- a/linear_operator/operators/identity_linear_operator.py +++ b/linear_operator/operators/identity_linear_operator.py @@ -8,13 +8,13 @@ from jaxtyping import Float from torch import Tensor -from ..utils.generic import _to_helper +from linear_operator.operators._linear_operator import IndexType, LinearOperator +from linear_operator.operators.diag_linear_operator import ConstantDiagLinearOperator +from linear_operator.operators.zero_linear_operator import ZeroLinearOperator -from ..utils.getitem import _compute_getitem_size, _is_noop_index -from ..utils.memoize import cached -from ._linear_operator import IndexType, LinearOperator -from .diag_linear_operator import ConstantDiagLinearOperator -from .zero_linear_operator import ZeroLinearOperator +from linear_operator.utils.generic import _to_helper +from linear_operator.utils.getitem import _compute_getitem_size, _is_noop_index +from linear_operator.utils.memoize import cached class IdentityLinearOperator(ConstantDiagLinearOperator): diff --git a/linear_operator/operators/interpolated_linear_operator.py b/linear_operator/operators/interpolated_linear_operator.py index 49e8a19a..d088eb25 100644 --- a/linear_operator/operators/interpolated_linear_operator.py +++ b/linear_operator/operators/interpolated_linear_operator.py @@ -8,15 +8,16 @@ from jaxtyping import Float from torch import Tensor -from ..utils import sparse -from ..utils.broadcasting import _pad_with_singletons -from ..utils.generic import _to_helper -from ..utils.getitem import _noop_index -from ..utils.interpolation import left_interp, left_t_interp -from ._linear_operator import IndexType, LinearOperator -from .dense_linear_operator import DenseLinearOperator, to_linear_operator -from .diag_linear_operator import DiagLinearOperator -from .root_linear_operator import RootLinearOperator +from linear_operator.operators._linear_operator import IndexType, LinearOperator +from linear_operator.operators.dense_linear_operator import DenseLinearOperator, to_linear_operator +from linear_operator.operators.diag_linear_operator import DiagLinearOperator +from linear_operator.operators.root_linear_operator import RootLinearOperator + +from linear_operator.utils import sparse +from 
linear_operator.utils.broadcasting import _pad_with_singletons +from linear_operator.utils.generic import _to_helper +from linear_operator.utils.getitem import _noop_index +from linear_operator.utils.interpolation import left_interp, left_t_interp class InterpolatedLinearOperator(LinearOperator): @@ -397,7 +398,7 @@ def _sum_batch(self, dim: int) -> LinearOperator: right_interp_values = right_interp_values.permute(permute_order).reshape(right_shape) # Make the base_lazy tensor block diagonal - from .block_diag_linear_operator import BlockDiagLinearOperator + from linear_operator.operators.block_diag_linear_operator import BlockDiagLinearOperator block_diag = BlockDiagLinearOperator(self.base_linear_op, block_dim=dim) diff --git a/linear_operator/operators/keops_linear_operator.py b/linear_operator/operators/keops_linear_operator.py index 6990b1f2..9fd0cd83 100644 --- a/linear_operator/operators/keops_linear_operator.py +++ b/linear_operator/operators/keops_linear_operator.py @@ -8,9 +8,10 @@ from jaxtyping import Float from torch import Tensor -from ..utils.getitem import _noop_index -from ..utils.memoize import cached -from ._linear_operator import IndexType, LinearOperator +from linear_operator.operators._linear_operator import IndexType, LinearOperator + +from linear_operator.utils.getitem import _noop_index +from linear_operator.utils.memoize import cached class KeOpsLinearOperator(LinearOperator): diff --git a/linear_operator/operators/kernel_linear_operator.py b/linear_operator/operators/kernel_linear_operator.py index 1df031bc..0d1c009d 100644 --- a/linear_operator/operators/kernel_linear_operator.py +++ b/linear_operator/operators/kernel_linear_operator.py @@ -6,10 +6,11 @@ from jaxtyping import Float from torch import Tensor -from ..utils.broadcasting import _pad_with_singletons -from ..utils.getitem import _noop_index, IndexType -from ..utils.memoize import cached -from ._linear_operator import LinearOperator, to_dense +from linear_operator.operators._linear_operator import LinearOperator, to_dense + +from linear_operator.utils.broadcasting import _pad_with_singletons +from linear_operator.utils.getitem import _noop_index, IndexType +from linear_operator.utils.memoize import cached def _x_getitem(x, batch_indices, data_index): diff --git a/linear_operator/operators/kronecker_product_added_diag_linear_operator.py b/linear_operator/operators/kronecker_product_added_diag_linear_operator.py index bfa97e77..d4a166b2 100644 --- a/linear_operator/operators/kronecker_product_added_diag_linear_operator.py +++ b/linear_operator/operators/kronecker_product_added_diag_linear_operator.py @@ -6,12 +6,15 @@ from jaxtyping import Float from torch import Tensor -from .. 
import settings -from ._linear_operator import LinearOperator -from .added_diag_linear_operator import AddedDiagLinearOperator -from .diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator -from .kronecker_product_linear_operator import KroneckerProductDiagLinearOperator, KroneckerProductLinearOperator -from .matmul_linear_operator import MatmulLinearOperator +from linear_operator import settings +from linear_operator.operators._linear_operator import LinearOperator +from linear_operator.operators.added_diag_linear_operator import AddedDiagLinearOperator +from linear_operator.operators.diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator +from linear_operator.operators.kronecker_product_linear_operator import ( + KroneckerProductDiagLinearOperator, + KroneckerProductLinearOperator, +) +from linear_operator.operators.matmul_linear_operator import MatmulLinearOperator def _constant_kpadlt_constructor(lt, dlt): diff --git a/linear_operator/operators/kronecker_product_linear_operator.py b/linear_operator/operators/kronecker_product_linear_operator.py index bd889310..1cece8b0 100644 --- a/linear_operator/operators/kronecker_product_linear_operator.py +++ b/linear_operator/operators/kronecker_product_linear_operator.py @@ -8,13 +8,13 @@ from jaxtyping import Float from torch import Tensor -from .. import settings -from ..utils.broadcasting import _matmul_broadcast_shape -from ..utils.memoize import cached -from ._linear_operator import IndexType, LinearOperator -from .dense_linear_operator import to_linear_operator -from .diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator -from .triangular_linear_operator import _TriangularLinearOperatorBase, TriangularLinearOperator +from linear_operator import settings +from linear_operator.operators._linear_operator import IndexType, LinearOperator +from linear_operator.operators.dense_linear_operator import to_linear_operator +from linear_operator.operators.diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator +from linear_operator.operators.triangular_linear_operator import _TriangularLinearOperatorBase, TriangularLinearOperator +from linear_operator.utils.broadcasting import _matmul_broadcast_shape +from linear_operator.utils.memoize import cached def _kron_diag(*lts) -> Tensor: @@ -100,11 +100,13 @@ def __add__( other: Union[Float[Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"], float], ) -> Union[Float[LinearOperator, "... M N"], Float[Tensor, "... M N"]]: if isinstance(other, (KroneckerProductDiagLinearOperator, ConstantDiagLinearOperator)): - from .kronecker_product_added_diag_linear_operator import KroneckerProductAddedDiagLinearOperator + from linear_operator.operators.kronecker_product_added_diag_linear_operator import ( + KroneckerProductAddedDiagLinearOperator, + ) return KroneckerProductAddedDiagLinearOperator(self, other) if isinstance(other, KroneckerProductLinearOperator): - from .sum_kronecker_linear_operator import SumKroneckerLinearOperator + from linear_operator.operators.sum_kronecker_linear_operator import SumKroneckerLinearOperator return SumKroneckerLinearOperator(self, other) if isinstance(other, DiagLinearOperator): @@ -115,7 +117,9 @@ def add_diagonal( self: Float[LinearOperator, "*batch N N"], diag: Union[Float[torch.Tensor, "... N"], Float[torch.Tensor, "... 
1"], Float[torch.Tensor, ""]], ) -> Float[LinearOperator, "*batch N N"]: - from .kronecker_product_added_diag_linear_operator import KroneckerProductAddedDiagLinearOperator + from linear_operator.operators.kronecker_product_added_diag_linear_operator import ( + KroneckerProductAddedDiagLinearOperator, + ) if not self.is_square: raise RuntimeError("add_diag only defined for square matrices") diff --git a/linear_operator/operators/low_rank_root_added_diag_linear_operator.py b/linear_operator/operators/low_rank_root_added_diag_linear_operator.py index f6db058f..6e7f1aba 100644 --- a/linear_operator/operators/low_rank_root_added_diag_linear_operator.py +++ b/linear_operator/operators/low_rank_root_added_diag_linear_operator.py @@ -5,14 +5,15 @@ from jaxtyping import Float from torch import Tensor -from ..utils.cholesky import psd_safe_cholesky -from ..utils.memoize import cached -from . import to_dense -from ._linear_operator import LinearOperator -from .added_diag_linear_operator import AddedDiagLinearOperator -from .diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator -from .low_rank_root_linear_operator import LowRankRootLinearOperator -from .sum_batch_linear_operator import SumBatchLinearOperator +from linear_operator.operators import to_dense +from linear_operator.operators._linear_operator import LinearOperator +from linear_operator.operators.added_diag_linear_operator import AddedDiagLinearOperator +from linear_operator.operators.diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator +from linear_operator.operators.low_rank_root_linear_operator import LowRankRootLinearOperator +from linear_operator.operators.sum_batch_linear_operator import SumBatchLinearOperator + +from linear_operator.utils.cholesky import psd_safe_cholesky +from linear_operator.utils.memoize import cached class LowRankRootAddedDiagLinearOperator(AddedDiagLinearOperator): @@ -100,7 +101,7 @@ def __add__( self: Float[LinearOperator, "... #M #N"], other: Union[Float[Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"], float], ) -> Union[Float[LinearOperator, "... M N"], Float[Tensor, "... M N"]]: - from .diag_linear_operator import DiagLinearOperator + from linear_operator.operators.diag_linear_operator import DiagLinearOperator if isinstance(other, DiagLinearOperator): return self.__class__(self._linear_op, self._diag_tensor + other) diff --git a/linear_operator/operators/low_rank_root_linear_operator.py b/linear_operator/operators/low_rank_root_linear_operator.py index dc8e8588..3cd39e2c 100644 --- a/linear_operator/operators/low_rank_root_linear_operator.py +++ b/linear_operator/operators/low_rank_root_linear_operator.py @@ -5,8 +5,8 @@ from jaxtyping import Float from torch import Tensor -from ._linear_operator import LinearOperator -from .root_linear_operator import RootLinearOperator +from linear_operator.operators._linear_operator import LinearOperator +from linear_operator.operators.root_linear_operator import RootLinearOperator class LowRankRootLinearOperator(RootLinearOperator): @@ -22,8 +22,10 @@ def add_diagonal( self: Float[LinearOperator, "*batch N N"], diag: Union[Float[torch.Tensor, "... N"], Float[torch.Tensor, "... 
1"], Float[torch.Tensor, ""]], ) -> Float[LinearOperator, "*batch N N"]: - from .diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator - from .low_rank_root_added_diag_linear_operator import LowRankRootAddedDiagLinearOperator + from linear_operator.operators.diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator + from linear_operator.operators.low_rank_root_added_diag_linear_operator import ( + LowRankRootAddedDiagLinearOperator, + ) if not self.is_square: raise RuntimeError("add_diag only defined for square matrices") @@ -52,8 +54,10 @@ def __add__( self: Float[LinearOperator, "... #M #N"], other: Union[Float[Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"], float], ) -> Union[Float[LinearOperator, "... M N"], Float[Tensor, "... M N"]]: - from .diag_linear_operator import DiagLinearOperator - from .low_rank_root_added_diag_linear_operator import LowRankRootAddedDiagLinearOperator + from linear_operator.operators.diag_linear_operator import DiagLinearOperator + from linear_operator.operators.low_rank_root_added_diag_linear_operator import ( + LowRankRootAddedDiagLinearOperator, + ) if isinstance(other, DiagLinearOperator): return LowRankRootAddedDiagLinearOperator(self, other) diff --git a/linear_operator/operators/masked_linear_operator.py b/linear_operator/operators/masked_linear_operator.py index 1783f93d..b9ffd148 100644 --- a/linear_operator/operators/masked_linear_operator.py +++ b/linear_operator/operators/masked_linear_operator.py @@ -4,9 +4,9 @@ from jaxtyping import Bool, Float from torch import Tensor -from ..utils.generic import _to_helper +from linear_operator.operators._linear_operator import _is_noop_index, IndexType, LinearOperator -from ._linear_operator import _is_noop_index, IndexType, LinearOperator +from linear_operator.utils.generic import _to_helper class MaskedLinearOperator(LinearOperator): diff --git a/linear_operator/operators/matmul_linear_operator.py b/linear_operator/operators/matmul_linear_operator.py index 05cb0fc0..0bd93f21 100644 --- a/linear_operator/operators/matmul_linear_operator.py +++ b/linear_operator/operators/matmul_linear_operator.py @@ -6,12 +6,13 @@ from jaxtyping import Float from torch import Tensor -from ..utils.broadcasting import _matmul_broadcast_shape, _pad_with_singletons -from ..utils.getitem import _noop_index -from ..utils.memoize import cached -from ._linear_operator import IndexType, LinearOperator -from .dense_linear_operator import DenseLinearOperator, to_linear_operator -from .diag_linear_operator import DiagLinearOperator +from linear_operator.operators._linear_operator import IndexType, LinearOperator +from linear_operator.operators.dense_linear_operator import DenseLinearOperator, to_linear_operator +from linear_operator.operators.diag_linear_operator import DiagLinearOperator + +from linear_operator.utils.broadcasting import _matmul_broadcast_shape, _pad_with_singletons +from linear_operator.utils.getitem import _noop_index +from linear_operator.utils.memoize import cached def _inner_repeat(tensor, amt): diff --git a/linear_operator/operators/mul_linear_operator.py b/linear_operator/operators/mul_linear_operator.py index 27025861..9d3dfd60 100644 --- a/linear_operator/operators/mul_linear_operator.py +++ b/linear_operator/operators/mul_linear_operator.py @@ -5,11 +5,12 @@ from jaxtyping import Float from torch import Tensor -from ..utils.broadcasting import _matmul_broadcast_shape -from ..utils.memoize import cached -from ._linear_operator import IndexType, LinearOperator 
-from .linear_operator_representation_tree import LinearOperatorRepresentationTree -from .root_linear_operator import RootLinearOperator +from linear_operator.operators._linear_operator import IndexType, LinearOperator +from linear_operator.operators.linear_operator_representation_tree import LinearOperatorRepresentationTree +from linear_operator.operators.root_linear_operator import RootLinearOperator + +from linear_operator.utils.broadcasting import _matmul_broadcast_shape +from linear_operator.utils.memoize import cached class MulLinearOperator(LinearOperator): diff --git a/linear_operator/operators/permutation_linear_operator.py b/linear_operator/operators/permutation_linear_operator.py index 25733a61..b0bdc3b7 100644 --- a/linear_operator/operators/permutation_linear_operator.py +++ b/linear_operator/operators/permutation_linear_operator.py @@ -4,7 +4,7 @@ from jaxtyping import Float from torch import Tensor -from ._linear_operator import LinearOperator +from linear_operator.operators._linear_operator import LinearOperator class AbstractPermutationLinearOperator(LinearOperator): diff --git a/linear_operator/operators/psd_sum_linear_operator.py b/linear_operator/operators/psd_sum_linear_operator.py index 5af0a6e9..c35ac090 100644 --- a/linear_operator/operators/psd_sum_linear_operator.py +++ b/linear_operator/operators/psd_sum_linear_operator.py @@ -2,8 +2,8 @@ from jaxtyping import Float from torch import Tensor -from ._linear_operator import LinearOperator -from .sum_linear_operator import SumLinearOperator +from linear_operator.operators._linear_operator import LinearOperator +from linear_operator.operators.sum_linear_operator import SumLinearOperator class PsdSumLinearOperator(SumLinearOperator): diff --git a/linear_operator/operators/root_linear_operator.py b/linear_operator/operators/root_linear_operator.py index 5784c227..50a257ef 100644 --- a/linear_operator/operators/root_linear_operator.py +++ b/linear_operator/operators/root_linear_operator.py @@ -5,12 +5,13 @@ from jaxtyping import Float from torch import Tensor -from ..utils.broadcasting import _pad_with_singletons -from ..utils.getitem import _equal_indices, _noop_index -from ..utils.memoize import cached -from ._linear_operator import IndexType, LinearOperator -from .dense_linear_operator import DenseLinearOperator, to_linear_operator -from .matmul_linear_operator import MatmulLinearOperator +from linear_operator.operators._linear_operator import IndexType, LinearOperator +from linear_operator.operators.dense_linear_operator import DenseLinearOperator, to_linear_operator +from linear_operator.operators.matmul_linear_operator import MatmulLinearOperator + +from linear_operator.utils.broadcasting import _pad_with_singletons +from linear_operator.utils.getitem import _equal_indices, _noop_index +from linear_operator.utils.memoize import cached class RootLinearOperator(LinearOperator): diff --git a/linear_operator/operators/sum_batch_linear_operator.py b/linear_operator/operators/sum_batch_linear_operator.py index 1e6fe3b2..50042a58 100644 --- a/linear_operator/operators/sum_batch_linear_operator.py +++ b/linear_operator/operators/sum_batch_linear_operator.py @@ -4,10 +4,11 @@ from jaxtyping import Float from torch import Tensor -from ..utils.broadcasting import _pad_with_singletons -from ..utils.getitem import _noop_index -from ._linear_operator import IndexType, LinearOperator -from .block_linear_operator import BlockLinearOperator +from linear_operator.operators._linear_operator import IndexType, LinearOperator +from 
linear_operator.operators.block_linear_operator import BlockLinearOperator + +from linear_operator.utils.broadcasting import _pad_with_singletons +from linear_operator.utils.getitem import _noop_index class SumBatchLinearOperator(BlockLinearOperator): diff --git a/linear_operator/operators/sum_kronecker_linear_operator.py b/linear_operator/operators/sum_kronecker_linear_operator.py index 56722264..c2930ffe 100644 --- a/linear_operator/operators/sum_kronecker_linear_operator.py +++ b/linear_operator/operators/sum_kronecker_linear_operator.py @@ -5,9 +5,9 @@ from jaxtyping import Float from torch import Tensor -from ._linear_operator import LinearOperator -from .kronecker_product_linear_operator import KroneckerProductLinearOperator -from .sum_linear_operator import SumLinearOperator +from linear_operator.operators._linear_operator import LinearOperator +from linear_operator.operators.kronecker_product_linear_operator import KroneckerProductLinearOperator +from linear_operator.operators.sum_linear_operator import SumLinearOperator class SumKroneckerLinearOperator(SumLinearOperator): diff --git a/linear_operator/operators/sum_linear_operator.py b/linear_operator/operators/sum_linear_operator.py index 018088a8..3fb511b1 100644 --- a/linear_operator/operators/sum_linear_operator.py +++ b/linear_operator/operators/sum_linear_operator.py @@ -5,12 +5,13 @@ from jaxtyping import Float from torch import Tensor -from ..utils.memoize import cached -from ._linear_operator import IndexType, LinearOperator -from .dense_linear_operator import to_linear_operator -from .zero_linear_operator import ZeroLinearOperator +from linear_operator.operators._linear_operator import IndexType, LinearOperator +from linear_operator.operators.dense_linear_operator import to_linear_operator +from linear_operator.operators.zero_linear_operator import ZeroLinearOperator -# from .broadcasted_linear_operator import BroadcastedLinearOperator +from linear_operator.utils.memoize import cached + +# from linear_operator.operators.broadcasted_linear_operator import BroadcastedLinearOperator class SumLinearOperator(LinearOperator): @@ -83,8 +84,8 @@ def __add__( self: Float[LinearOperator, "... #M #N"], other: Union[Float[Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"], float], ) -> Union[Float[LinearOperator, "... M N"], Float[Tensor, "... 
M N"]]: - from .added_diag_linear_operator import AddedDiagLinearOperator - from .diag_linear_operator import DiagLinearOperator + from linear_operator.operators.added_diag_linear_operator import AddedDiagLinearOperator + from linear_operator.operators.diag_linear_operator import DiagLinearOperator if isinstance(other, ZeroLinearOperator): return self diff --git a/linear_operator/operators/toeplitz_linear_operator.py b/linear_operator/operators/toeplitz_linear_operator.py index b54daca1..90057e2e 100644 --- a/linear_operator/operators/toeplitz_linear_operator.py +++ b/linear_operator/operators/toeplitz_linear_operator.py @@ -5,8 +5,9 @@ from jaxtyping import Float from torch import Tensor -from ..utils.toeplitz import sym_toeplitz_derivative_quadratic_form, sym_toeplitz_matmul -from ._linear_operator import IndexType, LinearOperator +from linear_operator.operators._linear_operator import IndexType, LinearOperator + +from linear_operator.utils.toeplitz import sym_toeplitz_derivative_quadratic_form, sym_toeplitz_matmul class ToeplitzLinearOperator(LinearOperator): diff --git a/linear_operator/operators/triangular_linear_operator.py b/linear_operator/operators/triangular_linear_operator.py index f5b45147..3dbeb883 100644 --- a/linear_operator/operators/triangular_linear_operator.py +++ b/linear_operator/operators/triangular_linear_operator.py @@ -6,11 +6,12 @@ from jaxtyping import Float from torch import Tensor -from ..utils.errors import NotPSDError -from ..utils.memoize import cached -from ._linear_operator import IndexType, LinearOperator -from .batch_repeat_linear_operator import BatchRepeatLinearOperator -from .dense_linear_operator import DenseLinearOperator +from linear_operator.operators._linear_operator import IndexType, LinearOperator +from linear_operator.operators.batch_repeat_linear_operator import BatchRepeatLinearOperator +from linear_operator.operators.dense_linear_operator import DenseLinearOperator + +from linear_operator.utils.errors import NotPSDError +from linear_operator.utils.memoize import cached Allsor = Union[Tensor, LinearOperator] @@ -53,10 +54,10 @@ def __add__( self: Float[LinearOperator, "... #M #N"], other: Union[Float[Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"], float], ) -> Union[Float[LinearOperator, "... M N"], Float[Tensor, "... 
M N"]]: - from .diag_linear_operator import DiagLinearOperator + from linear_operator.operators.diag_linear_operator import DiagLinearOperator if isinstance(other, DiagLinearOperator): - from .added_diag_linear_operator import AddedDiagLinearOperator + from linear_operator.operators.added_diag_linear_operator import AddedDiagLinearOperator return self.__class__(AddedDiagLinearOperator(self._tensor, other), upper=self.upper) if isinstance(other, TriangularLinearOperator) and not self.upper ^ other.upper: diff --git a/linear_operator/operators/zero_linear_operator.py b/linear_operator/operators/zero_linear_operator.py index 16e540c3..8c6d5867 100644 --- a/linear_operator/operators/zero_linear_operator.py +++ b/linear_operator/operators/zero_linear_operator.py @@ -8,9 +8,10 @@ from jaxtyping import Float from torch import Tensor -from ..utils.getitem import _compute_getitem_size -from ..utils.memoize import cached -from ._linear_operator import IndexType, LinearOperator +from linear_operator.operators._linear_operator import IndexType, LinearOperator + +from linear_operator.utils.getitem import _compute_getitem_size +from linear_operator.utils.memoize import cached class ZeroLinearOperator(LinearOperator): @@ -131,7 +132,7 @@ def add_diagonal( self: Float[LinearOperator, "*batch N N"], diag: Union[Float[torch.Tensor, "... N"], Float[torch.Tensor, "... 1"], Float[torch.Tensor, ""]], ) -> Float[LinearOperator, "*batch N N"]: - from .diag_linear_operator import DiagLinearOperator + from linear_operator.operators.diag_linear_operator import DiagLinearOperator if self.size(-1) != self.size(-2): raise RuntimeError("add_diag only defined for square matrices") diff --git a/linear_operator/test/linear_operator_test_case.py b/linear_operator/test/linear_operator_test_case.py index abbfa19a..f41cca66 100644 --- a/linear_operator/test/linear_operator_test_case.py +++ b/linear_operator/test/linear_operator_test_case.py @@ -12,11 +12,10 @@ import linear_operator from linear_operator.operators import DenseLinearOperator, DiagLinearOperator, to_dense from linear_operator.settings import linalg_dtypes +from linear_operator.test.base_test_case import BaseTestCase from linear_operator.utils.errors import CachingError from linear_operator.utils.memoize import get_from_cache - -from ..utils.warnings import PerformanceWarning -from .base_test_case import BaseTestCase +from linear_operator.utils.warnings import PerformanceWarning class RectangularLinearOperatorTestCase(BaseTestCase): diff --git a/linear_operator/utils/__init__.py b/linear_operator/utils/__init__.py index 9bfcc9d2..fcdcfb44 100644 --- a/linear_operator/utils/__init__.py +++ b/linear_operator/utils/__init__.py @@ -1,13 +1,23 @@ #!/usr/bin/env python3 -from . 
import broadcasting, cholesky, errors, getitem, interpolation, lanczos, permutation, sparse, warnings -from .contour_integral_quad import contour_integral_quad -from .linear_cg import linear_cg -from .memoize import cached -from .minres import minres -from .pinverse import stable_pinverse -from .qr import stable_qr -from .stochastic_lq import StochasticLQ +from linear_operator.utils import ( + broadcasting, + cholesky, + errors, + getitem, + interpolation, + lanczos, + permutation, + sparse, + warnings, +) +from linear_operator.utils.contour_integral_quad import contour_integral_quad +from linear_operator.utils.linear_cg import linear_cg +from linear_operator.utils.memoize import cached +from linear_operator.utils.minres import minres +from linear_operator.utils.pinverse import stable_pinverse +from linear_operator.utils.qr import stable_qr +from linear_operator.utils.stochastic_lq import StochasticLQ __all__ = [ "broadcasting", diff --git a/linear_operator/utils/cholesky.py b/linear_operator/utils/cholesky.py index bcc6d51f..4d779371 100644 --- a/linear_operator/utils/cholesky.py +++ b/linear_operator/utils/cholesky.py @@ -4,9 +4,9 @@ import torch -from .. import settings -from .errors import NanError, NotPSDError -from .warnings import NumericalWarning +from linear_operator import settings +from linear_operator.utils.errors import NanError, NotPSDError +from linear_operator.utils.warnings import NumericalWarning def _psd_safe_cholesky(A, out=None, jitter=None, max_tries=None): diff --git a/linear_operator/utils/contour_integral_quad.py b/linear_operator/utils/contour_integral_quad.py index 0bbc0382..4141a352 100644 --- a/linear_operator/utils/contour_integral_quad.py +++ b/linear_operator/utils/contour_integral_quad.py @@ -3,10 +3,10 @@ import torch -from .. import settings -from .linear_cg import linear_cg -from .minres import minres -from .warnings import NumericalWarning +from linear_operator import settings +from linear_operator.utils.linear_cg import linear_cg +from linear_operator.utils.minres import minres +from linear_operator.utils.warnings import NumericalWarning def contour_integral_quad( diff --git a/linear_operator/utils/getitem.py b/linear_operator/utils/getitem.py index 031f0822..17bf94ee 100644 --- a/linear_operator/utils/getitem.py +++ b/linear_operator/utils/getitem.py @@ -6,8 +6,8 @@ import torch -from .. import settings -from .broadcasting import _pad_with_singletons +from linear_operator import settings +from linear_operator.utils.broadcasting import _pad_with_singletons # EllipsisType is only available in Python 3.10+ IndexType = Union[type(Ellipsis), slice, Iterable[int], torch.LongTensor, int] diff --git a/linear_operator/utils/interpolation.py b/linear_operator/utils/interpolation.py index eb756a9e..73faee18 100644 --- a/linear_operator/utils/interpolation.py +++ b/linear_operator/utils/interpolation.py @@ -3,7 +3,7 @@ import torch -from .broadcasting import _matmul_broadcast_shape +from linear_operator.utils.broadcasting import _matmul_broadcast_shape def left_interp(interp_indices, interp_values, rhs): @@ -32,7 +32,7 @@ def left_interp(interp_indices, interp_values, rhs): def left_t_interp(interp_indices, interp_values, rhs, output_dim): """ """ - from .. 
import dsmm + from linear_operator import dsmm is_vector = rhs.ndimension() == 1 if is_vector: diff --git a/linear_operator/utils/lanczos.py b/linear_operator/utils/lanczos.py index 3def45a2..dc34a1f1 100644 --- a/linear_operator/utils/lanczos.py +++ b/linear_operator/utils/lanczos.py @@ -2,7 +2,7 @@ import torch -from .. import settings +from linear_operator import settings def lanczos_tridiag( diff --git a/linear_operator/utils/linear_cg.py b/linear_operator/utils/linear_cg.py index c7bd54fc..d9b97a2b 100644 --- a/linear_operator/utils/linear_cg.py +++ b/linear_operator/utils/linear_cg.py @@ -4,9 +4,9 @@ import torch -from .. import settings -from .deprecation import bool_compat -from .warnings import NumericalWarning +from linear_operator import settings +from linear_operator.utils.deprecation import bool_compat +from linear_operator.utils.warnings import NumericalWarning def _default_preconditioner(x): diff --git a/linear_operator/utils/memoize.py b/linear_operator/utils/memoize.py index 17ca63f7..4adbe303 100644 --- a/linear_operator/utils/memoize.py +++ b/linear_operator/utils/memoize.py @@ -3,7 +3,7 @@ import functools import pickle -from .errors import CachingError +from linear_operator.utils.errors import CachingError def cached(method=None, name=None, ignore_args=False): diff --git a/linear_operator/utils/minres.py b/linear_operator/utils/minres.py index c4736507..18c755a8 100644 --- a/linear_operator/utils/minres.py +++ b/linear_operator/utils/minres.py @@ -2,8 +2,8 @@ import torch -from .. import settings -from .broadcasting import _pad_with_singletons +from linear_operator import settings +from linear_operator.utils.broadcasting import _pad_with_singletons def minres( diff --git a/linear_operator/utils/permutation.py b/linear_operator/utils/permutation.py index 13ba52e7..8905dc2a 100644 --- a/linear_operator/utils/permutation.py +++ b/linear_operator/utils/permutation.py @@ -53,7 +53,7 @@ def apply_permutation( >>> ]) # Partial permutation: 2 x 3 x 3 >>> apply_permutation(matrix, left_permutation, right_permutation) # 2 x 3 x 5 x 3 """ - from ..operators import to_dense + from linear_operator.operators import to_dense if left_permutation is None and right_permutation is None: return to_dense(matrix) diff --git a/linear_operator/utils/pinverse.py b/linear_operator/utils/pinverse.py index 89a25c5f..4b34e9c2 100644 --- a/linear_operator/utils/pinverse.py +++ b/linear_operator/utils/pinverse.py @@ -3,7 +3,7 @@ import torch from torch import Tensor -from .qr import stable_qr +from linear_operator.utils.qr import stable_qr def stable_pinverse(A: Tensor) -> Tensor: diff --git a/linear_operator/utils/sparse.py b/linear_operator/utils/sparse.py index dfa4460e..b55ebb4b 100644 --- a/linear_operator/utils/sparse.py +++ b/linear_operator/utils/sparse.py @@ -2,7 +2,7 @@ import torch -from .broadcasting import _matmul_broadcast_shape +from linear_operator.utils.broadcasting import _matmul_broadcast_shape def make_sparse_from_indices_and_values(interp_indices, interp_values, num_rows): diff --git a/linear_operator/utils/stochastic_lq.py b/linear_operator/utils/stochastic_lq.py index 6650658b..762bec63 100644 --- a/linear_operator/utils/stochastic_lq.py +++ b/linear_operator/utils/stochastic_lq.py @@ -2,7 +2,7 @@ import torch -from .lanczos import lanczos_tridiag +from linear_operator.utils.lanczos import lanczos_tridiag class StochasticLQ(object): diff --git a/linear_operator/utils/toeplitz.py b/linear_operator/utils/toeplitz.py index aa6b1751..7064307c 100644 --- 
+++ b/linear_operator/utils/toeplitz.py
@@ -3,7 +3,7 @@
 import torch
 from torch.fft import fft, ifft
-from ..utils import broadcasting
+from linear_operator.utils import broadcasting
 def toeplitz(toeplitz_column, toeplitz_row):
diff --git a/test/operators/test_diag_linear_operator.py b/test/operators/test_diag_linear_operator.py
index 8357a9bc..a93a4263 100644
--- a/test/operators/test_diag_linear_operator.py
+++ b/test/operators/test_diag_linear_operator.py
@@ -4,10 +4,20 @@
 import torch
-from linear_operator.operators import DiagLinearOperator
+from linear_operator.operators import DiagLinearOperator, KroneckerProductDiagLinearOperator
 from linear_operator.test.linear_operator_test_case import LinearOperatorTestCase
+def kron_diag(*lts):
+    """Compute diagonal of a KroneckerProductLinearOperator from the diagonals of the constituting tensors"""
+    lead_diag = lts[0].diagonal(dim1=-1, dim2=-2)
+    if len(lts) == 1:  # base case:
+        return lead_diag
+    trail_diag = kron_diag(*lts[1:])
+    diag = lead_diag.unsqueeze(-2) * trail_diag.unsqueeze(-1)
+    return diag.mT.reshape(*diag.shape[:-2], -1)
+
+
 class TestDiagLinearOperator(LinearOperatorTestCase, unittest.TestCase):
     seed = 0
     should_test_sample = True
@@ -124,5 +134,31 @@ def evaluate_linear_op(self, linear_op):
         return torch.diag_embed(diag)
+class TestKroneckerProductDiagLinearOperator(TestDiagLinearOperator):
+    should_call_lanczos_diagonalization = False
+
+    def create_linear_op(self):
+        a = torch.tensor([4.0, 1.0, 2.0], dtype=torch.float)
+        b = torch.tensor([3.0, 1.3], dtype=torch.float)
+        c = torch.tensor([1.75, 1.95, 1.05, 0.25], dtype=torch.float)
+        a.requires_grad_(True)
+        b.requires_grad_(True)
+        c.requires_grad_(True)
+        kp_linear_op = KroneckerProductDiagLinearOperator(
+            DiagLinearOperator(a), DiagLinearOperator(b), DiagLinearOperator(c)
+        )
+        return kp_linear_op
+
+    def evaluate_linear_op(self, linear_op):
+        res_diag = kron_diag(*linear_op.linear_ops)
+        return torch.diag_embed(res_diag)
+
+    def test_exp(self):
+        pass
+
+    def test_log(self):
+        pass
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/operators/test_kronecker_product_linear_operator.py b/test/operators/test_kronecker_product_linear_operator.py
index 12c018c4..ea22cd2d 100644
--- a/test/operators/test_kronecker_product_linear_operator.py
+++ b/test/operators/test_kronecker_product_linear_operator.py
@@ -4,16 +4,9 @@
 import torch
-from linear_operator.operators import (
-    DenseLinearOperator,
-    DiagLinearOperator,
-    KroneckerProductDiagLinearOperator,
-    KroneckerProductLinearOperator,
-)
+from linear_operator.operators import DenseLinearOperator, KroneckerProductLinearOperator
 from linear_operator.test.linear_operator_test_case import LinearOperatorTestCase, RectangularLinearOperatorTestCase
-from .test_diag_linear_operator import TestDiagLinearOperator
-
 def kron(a, b):
     res = []
@@ -25,16 +18,6 @@ def kron(a, b):
     return torch.cat(res, -2)
-def kron_diag(*lts):
-    """Compute diagonal of a KroneckerProductLinearOperator from the diagonals of the constituiting tensors"""
-    lead_diag = lts[0].diagonal(dim1=-1, dim2=-2)
-    if len(lts) == 1:  # base case:
-        return lead_diag
-    trail_diag = kron_diag(*lts[1:])
-    diag = lead_diag.unsqueeze(-2) * trail_diag.unsqueeze(-1)
-    return diag.mT.reshape(*diag.shape[:-2], -1)
-
-
 class TestKroneckerProductLinearOperator(LinearOperatorTestCase, unittest.TestCase):
     seed = 0
     should_call_lanczos = True
@@ -61,32 +44,6 @@ def evaluate_linear_op(self, linear_op):
         return res
-class TestKroneckerProductDiagLinearOperator(TestDiagLinearOperator):
-    should_call_lanczos_diagonalization = False
-
-    def create_linear_op(self):
-        a = torch.tensor([4.0, 1.0, 2.0], dtype=torch.float)
-        b = torch.tensor([3.0, 1.3], dtype=torch.float)
-        c = torch.tensor([1.75, 1.95, 1.05, 0.25], dtype=torch.float)
-        a.requires_grad_(True)
-        b.requires_grad_(True)
-        c.requires_grad_(True)
-        kp_linear_op = KroneckerProductDiagLinearOperator(
-            DiagLinearOperator(a), DiagLinearOperator(b), DiagLinearOperator(c)
-        )
-        return kp_linear_op
-
-    def evaluate_linear_op(self, linear_op):
-        res_diag = kron_diag(*linear_op.linear_ops)
-        return torch.diag_embed(res_diag)
-
-    def test_exp(self):
-        pass
-
-    def test_log(self):
-        pass
-
-
 class TestKroneckerProductLinearOperatorBatch(TestKroneckerProductLinearOperator):
     seed = 0
     should_call_lanczos = True