Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use absolute imports #85

Merged
merged 5 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions linear_operator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
from . import beta_features, operators, settings, utils
from .functions import (
from linear_operator import beta_features, operators, settings, utils
from linear_operator.functions import (
add_diagonal,
add_jitter,
diagonalization,
Expand All @@ -13,7 +13,7 @@
solve,
sqrt_inv_matmul,
)
from .operators import LinearOperator, to_dense, to_linear_operator
from linear_operator.operators import LinearOperator, to_dense, to_linear_operator

# Read version number as written by setuptools_scm
try:
Expand Down
2 changes: 1 addition & 1 deletion linear_operator/beta_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import warnings

from .settings import _feature_flag
from linear_operator.settings import _feature_flag


class _moved_beta_feature(object):
Expand Down
20 changes: 10 additions & 10 deletions linear_operator/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import torch

from ._dsmm import DSMM
from linear_operator.functions._dsmm import DSMM

LinearOperatorType = Any # Want this to be "LinearOperator" but runtime type checker can't handle

Expand All @@ -23,7 +23,7 @@ def add_diagonal(input: Anysor, diag: torch.Tensor) -> LinearOperatorType:
:return: :math:`\mathbf A + \text{diag}(\mathbf d)`, where :math:`\mathbf A` is the linear operator
and :math:`\mathbf d` is the diagonal component
"""
from ..operators import to_linear_operator
from linear_operator.operators import to_linear_operator

return to_linear_operator(input).add_diagonal(diag)

Expand Down Expand Up @@ -61,7 +61,7 @@ def diagonalization(
based on size if not specified.
:return: eigenvalues and eigenvectors representing the diagonalization.
"""
from ..operators import to_linear_operator
from linear_operator.operators import to_linear_operator

return to_linear_operator(input).diagonalization(method=method)

Expand Down Expand Up @@ -105,7 +105,7 @@ def inv_quad(input: Anysor, inv_quad_rhs: torch.Tensor, reduce_inv_quad: bool =
:returns: The inverse quadratic term.
If `reduce_inv_quad=True`, the inverse quadratic term is of shape (...). Otherwise, it is (... x M).
"""
from ..operators import to_linear_operator
from linear_operator.operators import to_linear_operator

return to_linear_operator(input).inv_quad(inv_quad_rhs, reduce_inv_quad=reduce_inv_quad)

Expand All @@ -127,7 +127,7 @@ def inv_quad_logdet(
:returns: The inverse quadratic term (or None), and the logdet term (or None).
If `reduce_inv_quad=True`, the inverse quadratic term is of shape (...). Otherwise, it is (... x M).
"""
from ..operators import to_linear_operator
from linear_operator.operators import to_linear_operator

return to_linear_operator(input).inv_quad_logdet(inv_quad_rhs, logdet, reduce_inv_quad=reduce_inv_quad)

Expand Down Expand Up @@ -156,7 +156,7 @@ def pivoted_cholesky(
.. _Harbrecht et al., 2012:
https://www.sciencedirect.com/science/article/pii/S0168927411001814
"""
from ..operators import to_linear_operator
from linear_operator.operators import to_linear_operator

return to_linear_operator(input).pivoted_cholesky(rank=rank, error_tol=error_tol, return_pivots=return_pivots)

Expand All @@ -173,7 +173,7 @@ def root_decomposition(input: Anysor, method: Optional[str] = None) -> LinearOpe
"cholesky", "lanczos", "symeig", "pivoted_cholesky", or "svd".
:return: A tensor :math:`\mathbf R` such that :math:`\mathbf R \mathbf R^\top \approx \mathbf A`.
"""
from ..operators import to_linear_operator
from linear_operator.operators import to_linear_operator

return to_linear_operator(input).root_decomposition(method=method)

Expand All @@ -199,7 +199,7 @@ def root_inv_decomposition(
:param method: Root decomposition method to use (symeig, diagonalization, lanczos, or cholesky).
:return: A tensor :math:`\mathbf R` such that :math:`\mathbf R \mathbf R^\top \approx \mathbf A^{-1}`.
"""
from ..operators import to_linear_operator
from linear_operator.operators import to_linear_operator

return to_linear_operator(input).root_inv_decomposition(
initial_vectors=initial_vectors, test_vectors=test_vectors, method=method
Expand Down Expand Up @@ -235,7 +235,7 @@ def solve(input: Anysor, rhs: torch.Tensor, lhs: Optional[torch.Tensor] = None)
:param lhs: :math:`\mathbf L` - the left hand side
:return: :math:`\mathbf A^{-1} \mathbf R` or :math:`\mathbf L \mathbf A^{-1} \mathbf R`.
"""
from ..operators import to_linear_operator
from linear_operator.operators import to_linear_operator

return to_linear_operator(input).solve(right_tensor=rhs, left_tensor=lhs)

Expand Down Expand Up @@ -268,7 +268,7 @@ def sqrt_inv_matmul(
:param lhs: :math:`\mathbf L` - the left hand side
:return: :math:`\mathbf A^{-1/2} \mathbf R` or :math:`\mathbf L \mathbf A^{-1/2} \mathbf R`.
"""
from ..operators import to_linear_operator
from linear_operator.operators import to_linear_operator

return to_linear_operator(input).sqrt_inv_matmul(rhs=rhs, lhs=lhs)

Expand Down
4 changes: 2 additions & 2 deletions linear_operator/functions/_diagonalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import torch
from torch.autograd import Function

from .. import settings
from ..utils import lanczos
from linear_operator import settings
from linear_operator.utils import lanczos


class Diagonalization(Function):
Expand Down
2 changes: 1 addition & 1 deletion linear_operator/functions/_dsmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from torch.autograd import Function

from ..utils.sparse import bdsmm
from linear_operator.utils.sparse import bdsmm


class DSMM(Function):
Expand Down
2 changes: 1 addition & 1 deletion linear_operator/functions/_inv_quad.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from torch.autograd import Function

from .. import settings
from linear_operator import settings


def _solve(linear_op, rhs):
Expand Down
6 changes: 3 additions & 3 deletions linear_operator/functions/_inv_quad_logdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import torch
from torch.autograd import Function

from .. import settings
from ..utils.lanczos import lanczos_tridiag_to_diag
from ..utils.stochastic_lq import StochasticLQ
from linear_operator import settings
from linear_operator.utils.lanczos import lanczos_tridiag_to_diag
from linear_operator.utils.stochastic_lq import StochasticLQ


class InvQuadLogdet(Function):
Expand Down
2 changes: 1 addition & 1 deletion linear_operator/functions/_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from torch.autograd import Function

from .. import settings
from linear_operator import settings


class Matmul(Function):
Expand Down
6 changes: 3 additions & 3 deletions linear_operator/functions/_pivoted_cholesky.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import torch
from torch.autograd import Function

from .. import settings
from ..utils.cholesky import psd_safe_cholesky
from ..utils.permutation import apply_permutation, inverse_permutation
from linear_operator import settings
from linear_operator.utils.cholesky import psd_safe_cholesky
from linear_operator.utils.permutation import apply_permutation, inverse_permutation


class PivotedCholesky(Function):
Expand Down
6 changes: 3 additions & 3 deletions linear_operator/functions/_root_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import torch
from torch.autograd import Function

from .. import settings
from ..utils import lanczos
from linear_operator import settings
from linear_operator.utils import lanczos


class RootDecomposition(Function):
Expand All @@ -29,7 +29,7 @@ def forward(
:return: :attr:`R`, such that :math:`R R^T \approx A`, and :attr:`R_inv`, such that
:math:`R_{inv} R_{inv}^T \approx A^{-1}` (will only be populated if self.inverse = True)
"""
from ..operators import to_linear_operator
from linear_operator.operators import to_linear_operator

ctx.representation_tree = representation_tree
ctx.device = device
Expand Down
4 changes: 2 additions & 2 deletions linear_operator/functions/_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import torch
from torch.autograd import Function

from .. import settings
from linear_operator import settings


def _solve(linear_op, rhs):
from ..operators import CholLinearOperator, TriangularLinearOperator
from linear_operator.operators import CholLinearOperator, TriangularLinearOperator

if isinstance(linear_op, (CholLinearOperator, TriangularLinearOperator)):
# May want to do this for some KroneckerProductLinearOperators and possibly
Expand Down
2 changes: 1 addition & 1 deletion linear_operator/functions/_sqrt_inv_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from torch.autograd import Function

from .. import settings, utils
from linear_operator import settings, utils


class SqrtInvMatmul(Function):
Expand Down
67 changes: 36 additions & 31 deletions linear_operator/operators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,45 @@
#!/usr/bin/env python3

from ._linear_operator import LinearOperator, to_dense
from .added_diag_linear_operator import AddedDiagLinearOperator
from .batch_repeat_linear_operator import BatchRepeatLinearOperator
from .block_diag_linear_operator import BlockDiagLinearOperator
from .block_interleaved_linear_operator import BlockInterleavedLinearOperator
from .block_linear_operator import BlockLinearOperator
from .cat_linear_operator import cat, CatLinearOperator
from .chol_linear_operator import CholLinearOperator
from .constant_mul_linear_operator import ConstantMulLinearOperator
from .dense_linear_operator import DenseLinearOperator, to_linear_operator
from .diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator
from .identity_linear_operator import IdentityLinearOperator
from .interpolated_linear_operator import InterpolatedLinearOperator
from .keops_linear_operator import KeOpsLinearOperator
from .kernel_linear_operator import KernelLinearOperator
from .kronecker_product_added_diag_linear_operator import KroneckerProductAddedDiagLinearOperator
from .kronecker_product_linear_operator import (
from linear_operator.operators._linear_operator import LinearOperator, to_dense
from linear_operator.operators.added_diag_linear_operator import AddedDiagLinearOperator
from linear_operator.operators.batch_repeat_linear_operator import BatchRepeatLinearOperator
from linear_operator.operators.block_diag_linear_operator import BlockDiagLinearOperator
from linear_operator.operators.block_interleaved_linear_operator import BlockInterleavedLinearOperator
from linear_operator.operators.block_linear_operator import BlockLinearOperator
from linear_operator.operators.cat_linear_operator import cat, CatLinearOperator
from linear_operator.operators.chol_linear_operator import CholLinearOperator
from linear_operator.operators.constant_mul_linear_operator import ConstantMulLinearOperator
from linear_operator.operators.dense_linear_operator import DenseLinearOperator, to_linear_operator
from linear_operator.operators.diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator
from linear_operator.operators.identity_linear_operator import IdentityLinearOperator
from linear_operator.operators.interpolated_linear_operator import InterpolatedLinearOperator
from linear_operator.operators.keops_linear_operator import KeOpsLinearOperator
from linear_operator.operators.kernel_linear_operator import KernelLinearOperator
from linear_operator.operators.kronecker_product_added_diag_linear_operator import (
KroneckerProductAddedDiagLinearOperator,
)
from linear_operator.operators.kronecker_product_linear_operator import (
KroneckerProductDiagLinearOperator,
KroneckerProductLinearOperator,
KroneckerProductTriangularLinearOperator,
)
from .low_rank_root_added_diag_linear_operator import LowRankRootAddedDiagLinearOperator
from .low_rank_root_linear_operator import LowRankRootLinearOperator
from .masked_linear_operator import MaskedLinearOperator
from .matmul_linear_operator import MatmulLinearOperator
from .mul_linear_operator import MulLinearOperator
from .permutation_linear_operator import PermutationLinearOperator, TransposePermutationLinearOperator
from .psd_sum_linear_operator import PsdSumLinearOperator
from .root_linear_operator import RootLinearOperator
from .sum_batch_linear_operator import SumBatchLinearOperator
from .sum_kronecker_linear_operator import SumKroneckerLinearOperator
from .sum_linear_operator import SumLinearOperator
from .toeplitz_linear_operator import ToeplitzLinearOperator
from .triangular_linear_operator import TriangularLinearOperator
from .zero_linear_operator import ZeroLinearOperator
from linear_operator.operators.low_rank_root_added_diag_linear_operator import LowRankRootAddedDiagLinearOperator
from linear_operator.operators.low_rank_root_linear_operator import LowRankRootLinearOperator
from linear_operator.operators.masked_linear_operator import MaskedLinearOperator
from linear_operator.operators.matmul_linear_operator import MatmulLinearOperator
from linear_operator.operators.mul_linear_operator import MulLinearOperator
from linear_operator.operators.permutation_linear_operator import (
PermutationLinearOperator,
TransposePermutationLinearOperator,
)
from linear_operator.operators.psd_sum_linear_operator import PsdSumLinearOperator
from linear_operator.operators.root_linear_operator import RootLinearOperator
from linear_operator.operators.sum_batch_linear_operator import SumBatchLinearOperator
from linear_operator.operators.sum_kronecker_linear_operator import SumKroneckerLinearOperator
from linear_operator.operators.sum_linear_operator import SumLinearOperator
from linear_operator.operators.toeplitz_linear_operator import ToeplitzLinearOperator
from linear_operator.operators.triangular_linear_operator import TriangularLinearOperator
from linear_operator.operators.zero_linear_operator import ZeroLinearOperator

__all__ = [
"to_dense",
Expand Down
Loading
Loading