Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
saitcakmak committed Nov 20, 2023
1 parent fe711c0 commit edae3ac
Show file tree
Hide file tree
Showing 28 changed files with 65 additions and 55 deletions.
2 changes: 1 addition & 1 deletion linear_operator/operators/_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from linear_operator.functions._root_decomposition import RootDecomposition
from linear_operator.functions._solve import Solve
from linear_operator.functions._sqrt_inv_matmul import SqrtInvMatmul
from linear_operator.operators.linear_operator_representation_tree import LinearOperatorRepresentationTree
from linear_operator.utils.broadcasting import _matmul_broadcast_shape
from linear_operator.utils.cholesky import psd_safe_cholesky
from linear_operator.utils.deprecation import _deprecate_renamed_methods
Expand All @@ -56,7 +57,6 @@
)
from linear_operator.utils.pinverse import stable_pinverse
from linear_operator.utils.warnings import NumericalWarning, PerformanceWarning
from linear_operator.operators.linear_operator_representation_tree import LinearOperatorRepresentationTree

_HANDLED_FUNCTIONS = {}
_HANDLED_SECOND_ARG_FUNCTIONS = {}
Expand Down
4 changes: 2 additions & 2 deletions linear_operator/operators/added_diag_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
from torch import Tensor

from linear_operator import settings
from linear_operator.utils.memoize import cached
from linear_operator.utils.warnings import NumericalWarning
from linear_operator.operators._linear_operator import LinearOperator
from linear_operator.operators.diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator
from linear_operator.operators.psd_sum_linear_operator import PsdSumLinearOperator
from linear_operator.operators.root_linear_operator import RootLinearOperator
from linear_operator.operators.sum_linear_operator import SumLinearOperator
from linear_operator.utils.memoize import cached
from linear_operator.utils.warnings import NumericalWarning


class AddedDiagLinearOperator(SumLinearOperator):
Expand Down
2 changes: 1 addition & 1 deletion linear_operator/operators/batch_repeat_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from torch import Tensor

from linear_operator import settings
from linear_operator.operators._linear_operator import IndexType, LinearOperator
from linear_operator.utils.broadcasting import _matmul_broadcast_shape
from linear_operator.utils.memoize import cached
from linear_operator.operators._linear_operator import IndexType, LinearOperator


class BatchRepeatLinearOperator(LinearOperator):
Expand Down
3 changes: 2 additions & 1 deletion linear_operator/operators/block_diag_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from jaxtyping import Float
from torch import Tensor

from linear_operator.utils.memoize import cached
from linear_operator.operators._linear_operator import IndexType, LinearOperator
from linear_operator.operators.block_linear_operator import BlockLinearOperator

from linear_operator.utils.memoize import cached


# metaclass of BlockDiagLinearOperator, overwrites behavior of constructor call
# _MetaBlockDiagLinearOperator(base_linear_op, block_dim=-3) to return a DiagLinearOperator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from jaxtyping import Float
from torch import Tensor

from linear_operator.utils.memoize import cached
from linear_operator.operators._linear_operator import IndexType, LinearOperator
from linear_operator.operators.block_linear_operator import BlockLinearOperator

from linear_operator.utils.memoize import cached


class BlockInterleavedLinearOperator(BlockLinearOperator):
"""
Expand Down
3 changes: 2 additions & 1 deletion linear_operator/operators/block_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from jaxtyping import Float
from torch import Tensor

from linear_operator.utils.getitem import _is_noop_index, _noop_index
from linear_operator.operators._linear_operator import IndexType, LinearOperator
from linear_operator.operators.dense_linear_operator import to_linear_operator

from linear_operator.utils.getitem import _is_noop_index, _noop_index


class BlockLinearOperator(LinearOperator):
"""
Expand Down
5 changes: 3 additions & 2 deletions linear_operator/operators/cat_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
from jaxtyping import Float
from torch import Tensor

from linear_operator.operators._linear_operator import IndexType, LinearOperator, to_dense
from linear_operator.operators.dense_linear_operator import DenseLinearOperator, to_linear_operator

from linear_operator.utils.broadcasting import _matmul_broadcast_shape
from linear_operator.utils.deprecation import bool_compat
from linear_operator.utils.generic import _to_helper
from linear_operator.utils.getitem import _noop_index
from linear_operator.operators._linear_operator import IndexType, LinearOperator, to_dense
from linear_operator.operators.dense_linear_operator import DenseLinearOperator, to_linear_operator


def cat(inputs, dim=0, output_device=None):
Expand Down
3 changes: 2 additions & 1 deletion linear_operator/operators/chol_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
from jaxtyping import Float
from torch import Tensor

from linear_operator.utils.memoize import cached
from linear_operator.operators._linear_operator import LinearOperator
from linear_operator.operators.root_linear_operator import RootLinearOperator
from linear_operator.operators.triangular_linear_operator import _TriangularLinearOperatorBase, TriangularLinearOperator

from linear_operator.utils.memoize import cached


class CholLinearOperator(RootLinearOperator):
r"""
Expand Down
3 changes: 2 additions & 1 deletion linear_operator/operators/constant_mul_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from jaxtyping import Float
from torch import Tensor

from linear_operator.utils.memoize import cached
from linear_operator.operators._linear_operator import IndexType, LinearOperator
from linear_operator.operators.root_linear_operator import RootLinearOperator

from linear_operator.utils.memoize import cached


class ConstantMulLinearOperator(LinearOperator):
"""
Expand Down
2 changes: 1 addition & 1 deletion linear_operator/operators/diag_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from torch import Tensor

from linear_operator import settings
from linear_operator.utils.memoize import cached
from linear_operator.operators._linear_operator import IndexType, LinearOperator
from linear_operator.operators.block_diag_linear_operator import BlockDiagLinearOperator
from linear_operator.operators.dense_linear_operator import DenseLinearOperator
from linear_operator.operators.triangular_linear_operator import TriangularLinearOperator
from linear_operator.utils.memoize import cached


class DiagLinearOperator(TriangularLinearOperator):
Expand Down
7 changes: 4 additions & 3 deletions linear_operator/operators/identity_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
from jaxtyping import Float
from torch import Tensor

from linear_operator.utils.generic import _to_helper
from linear_operator.utils.getitem import _compute_getitem_size, _is_noop_index
from linear_operator.utils.memoize import cached
from linear_operator.operators._linear_operator import IndexType, LinearOperator
from linear_operator.operators.diag_linear_operator import ConstantDiagLinearOperator
from linear_operator.operators.zero_linear_operator import ZeroLinearOperator

from linear_operator.utils.generic import _to_helper
from linear_operator.utils.getitem import _compute_getitem_size, _is_noop_index
from linear_operator.utils.memoize import cached


class IdentityLinearOperator(ConstantDiagLinearOperator):
"""
Expand Down
9 changes: 5 additions & 4 deletions linear_operator/operators/interpolated_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@
from jaxtyping import Float
from torch import Tensor

from linear_operator.operators._linear_operator import IndexType, LinearOperator
from linear_operator.operators.dense_linear_operator import DenseLinearOperator, to_linear_operator
from linear_operator.operators.diag_linear_operator import DiagLinearOperator
from linear_operator.operators.root_linear_operator import RootLinearOperator

from linear_operator.utils import sparse
from linear_operator.utils.broadcasting import _pad_with_singletons
from linear_operator.utils.generic import _to_helper
from linear_operator.utils.getitem import _noop_index
from linear_operator.utils.interpolation import left_interp, left_t_interp
from linear_operator.operators._linear_operator import IndexType, LinearOperator
from linear_operator.operators.dense_linear_operator import DenseLinearOperator, to_linear_operator
from linear_operator.operators.diag_linear_operator import DiagLinearOperator
from linear_operator.operators.root_linear_operator import RootLinearOperator


class InterpolatedLinearOperator(LinearOperator):
Expand Down
3 changes: 2 additions & 1 deletion linear_operator/operators/keops_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
from jaxtyping import Float
from torch import Tensor

from linear_operator.operators._linear_operator import IndexType, LinearOperator

from linear_operator.utils.getitem import _noop_index
from linear_operator.utils.memoize import cached
from linear_operator.operators._linear_operator import IndexType, LinearOperator


class KeOpsLinearOperator(LinearOperator):
Expand Down
3 changes: 2 additions & 1 deletion linear_operator/operators/kernel_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
from jaxtyping import Float
from torch import Tensor

from linear_operator.operators._linear_operator import LinearOperator, to_dense

from linear_operator.utils.broadcasting import _pad_with_singletons
from linear_operator.utils.getitem import _noop_index, IndexType
from linear_operator.utils.memoize import cached
from linear_operator.operators._linear_operator import LinearOperator, to_dense


def _x_getitem(x, batch_indices, data_index):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,12 @@
from torch import Tensor

from linear_operator import settings
from linear_operator.utils.broadcasting import _matmul_broadcast_shape
from linear_operator.utils.memoize import cached
from linear_operator.operators._linear_operator import IndexType, LinearOperator
from linear_operator.operators.dense_linear_operator import to_linear_operator
from linear_operator.operators.diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator
from linear_operator.operators.triangular_linear_operator import (
_TriangularLinearOperatorBase,
TriangularLinearOperator,
)
from linear_operator.operators.triangular_linear_operator import _TriangularLinearOperatorBase, TriangularLinearOperator
from linear_operator.utils.broadcasting import _matmul_broadcast_shape
from linear_operator.utils.memoize import cached


def _kron_diag(*lts) -> Tensor:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@
from jaxtyping import Float
from torch import Tensor

from linear_operator.utils.cholesky import psd_safe_cholesky
from linear_operator.utils.memoize import cached
from linear_operator.operators import to_dense
from linear_operator.operators._linear_operator import LinearOperator
from linear_operator.operators.added_diag_linear_operator import AddedDiagLinearOperator
from linear_operator.operators.diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator
from linear_operator.operators.low_rank_root_linear_operator import LowRankRootLinearOperator
from linear_operator.operators.sum_batch_linear_operator import SumBatchLinearOperator

from linear_operator.utils.cholesky import psd_safe_cholesky
from linear_operator.utils.memoize import cached


class LowRankRootAddedDiagLinearOperator(AddedDiagLinearOperator):
def __init__(self, *linear_ops, preconditioner_override=None):
Expand Down
4 changes: 2 additions & 2 deletions linear_operator/operators/masked_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from jaxtyping import Bool, Float
from torch import Tensor

from linear_operator.utils.generic import _to_helper

from linear_operator.operators._linear_operator import _is_noop_index, IndexType, LinearOperator

from linear_operator.utils.generic import _to_helper


class MaskedLinearOperator(LinearOperator):
r"""
Expand Down
7 changes: 4 additions & 3 deletions linear_operator/operators/matmul_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
from jaxtyping import Float
from torch import Tensor

from linear_operator.utils.broadcasting import _matmul_broadcast_shape, _pad_with_singletons
from linear_operator.utils.getitem import _noop_index
from linear_operator.utils.memoize import cached
from linear_operator.operators._linear_operator import IndexType, LinearOperator
from linear_operator.operators.dense_linear_operator import DenseLinearOperator, to_linear_operator
from linear_operator.operators.diag_linear_operator import DiagLinearOperator

from linear_operator.utils.broadcasting import _matmul_broadcast_shape, _pad_with_singletons
from linear_operator.utils.getitem import _noop_index
from linear_operator.utils.memoize import cached


def _inner_repeat(tensor, amt):
return tensor.unsqueeze(-1).repeat(amt, 1).squeeze(-1)
Expand Down
5 changes: 3 additions & 2 deletions linear_operator/operators/mul_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from jaxtyping import Float
from torch import Tensor

from linear_operator.utils.broadcasting import _matmul_broadcast_shape
from linear_operator.utils.memoize import cached
from linear_operator.operators._linear_operator import IndexType, LinearOperator
from linear_operator.operators.linear_operator_representation_tree import LinearOperatorRepresentationTree
from linear_operator.operators.root_linear_operator import RootLinearOperator

from linear_operator.utils.broadcasting import _matmul_broadcast_shape
from linear_operator.utils.memoize import cached


class MulLinearOperator(LinearOperator):
def _check_args(self, left_linear_op, right_linear_op):
Expand Down
7 changes: 4 additions & 3 deletions linear_operator/operators/root_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
from jaxtyping import Float
from torch import Tensor

from linear_operator.utils.broadcasting import _pad_with_singletons
from linear_operator.utils.getitem import _equal_indices, _noop_index
from linear_operator.utils.memoize import cached
from linear_operator.operators._linear_operator import IndexType, LinearOperator
from linear_operator.operators.dense_linear_operator import DenseLinearOperator, to_linear_operator
from linear_operator.operators.matmul_linear_operator import MatmulLinearOperator

from linear_operator.utils.broadcasting import _pad_with_singletons
from linear_operator.utils.getitem import _equal_indices, _noop_index
from linear_operator.utils.memoize import cached


class RootLinearOperator(LinearOperator):
def __init__(self, root):
Expand Down
5 changes: 3 additions & 2 deletions linear_operator/operators/sum_batch_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from jaxtyping import Float
from torch import Tensor

from linear_operator.utils.broadcasting import _pad_with_singletons
from linear_operator.utils.getitem import _noop_index
from linear_operator.operators._linear_operator import IndexType, LinearOperator
from linear_operator.operators.block_linear_operator import BlockLinearOperator

from linear_operator.utils.broadcasting import _pad_with_singletons
from linear_operator.utils.getitem import _noop_index


class SumBatchLinearOperator(BlockLinearOperator):
"""
Expand Down
3 changes: 2 additions & 1 deletion linear_operator/operators/sum_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
from jaxtyping import Float
from torch import Tensor

from linear_operator.utils.memoize import cached
from linear_operator.operators._linear_operator import IndexType, LinearOperator
from linear_operator.operators.dense_linear_operator import to_linear_operator
from linear_operator.operators.zero_linear_operator import ZeroLinearOperator

from linear_operator.utils.memoize import cached

# from linear_operator.operators.broadcasted_linear_operator import BroadcastedLinearOperator


Expand Down
3 changes: 2 additions & 1 deletion linear_operator/operators/toeplitz_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from jaxtyping import Float
from torch import Tensor

from linear_operator.utils.toeplitz import sym_toeplitz_derivative_quadratic_form, sym_toeplitz_matmul
from linear_operator.operators._linear_operator import IndexType, LinearOperator

from linear_operator.utils.toeplitz import sym_toeplitz_derivative_quadratic_form, sym_toeplitz_matmul


class ToeplitzLinearOperator(LinearOperator):
def __init__(self, column):
Expand Down
5 changes: 3 additions & 2 deletions linear_operator/operators/triangular_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
from jaxtyping import Float
from torch import Tensor

from linear_operator.utils.errors import NotPSDError
from linear_operator.utils.memoize import cached
from linear_operator.operators._linear_operator import IndexType, LinearOperator
from linear_operator.operators.batch_repeat_linear_operator import BatchRepeatLinearOperator
from linear_operator.operators.dense_linear_operator import DenseLinearOperator

from linear_operator.utils.errors import NotPSDError
from linear_operator.utils.memoize import cached

Allsor = Union[Tensor, LinearOperator]


Expand Down
3 changes: 2 additions & 1 deletion linear_operator/operators/zero_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
from jaxtyping import Float
from torch import Tensor

from linear_operator.operators._linear_operator import IndexType, LinearOperator

from linear_operator.utils.getitem import _compute_getitem_size
from linear_operator.utils.memoize import cached
from linear_operator.operators._linear_operator import IndexType, LinearOperator


class ZeroLinearOperator(LinearOperator):
Expand Down
2 changes: 1 addition & 1 deletion linear_operator/test/linear_operator_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
import linear_operator
from linear_operator.operators import DenseLinearOperator, DiagLinearOperator, to_dense
from linear_operator.settings import linalg_dtypes
from linear_operator.test.base_test_case import BaseTestCase
from linear_operator.utils.errors import CachingError
from linear_operator.utils.memoize import get_from_cache
from linear_operator.utils.warnings import PerformanceWarning
from linear_operator.test.base_test_case import BaseTestCase


class RectangularLinearOperatorTestCase(BaseTestCase):
Expand Down
5 changes: 1 addition & 4 deletions test/operators/test_diag_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@

import torch

from linear_operator.operators import (
DiagLinearOperator,
KroneckerProductDiagLinearOperator
)
from linear_operator.operators import DiagLinearOperator, KroneckerProductDiagLinearOperator
from linear_operator.test.linear_operator_test_case import LinearOperatorTestCase


Expand Down
Loading

0 comments on commit edae3ac

Please sign in to comment.