Skip to content

Commit

Permalink
Change type annotation to allow complex types
Browse files Browse the repository at this point in the history
  • Loading branch information
Randl authored and patrick-kidger committed Nov 7, 2023
1 parent 9a57352 commit 3f34189
Showing 1 changed file with 25 additions and 16 deletions.
41 changes: 25 additions & 16 deletions lineax/_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,14 @@
import jax.tree_util as jtu
import numpy as np
from equinox.internal import ω
from jaxtyping import Array, ArrayLike, Float, PyTree, Scalar, Shaped # pyright: ignore
from jaxtyping import ( # pyright: ignore
Array,
ArrayLike,
Inexact,
PyTree,
Scalar,
Shaped,
)

from ._custom_types import sentinel
from ._misc import (
Expand Down Expand Up @@ -95,7 +102,9 @@ def __check_init__(self):
)

@abc.abstractmethod
def mv(self, vector: PyTree[Float[Array, " _b"]]) -> PyTree[Float[Array, " _a"]]:
def mv(
self, vector: PyTree[Inexact[Array, " _b"]]
) -> PyTree[Inexact[Array, " _a"]]:
"""Computes a matrix-vector product between this operator and a `vector`.
**Arguments:**
Expand All @@ -110,7 +119,7 @@ def mv(self, vector: PyTree[Float[Array, " _b"]]) -> PyTree[Float[Array, " _a"]]
"""

@abc.abstractmethod
def as_matrix(self) -> Float[Array, "a b"]:
def as_matrix(self) -> Inexact[Array, "a b"]:
"""Materialises this linear operator as a matrix.
Note that this can be a computationally (time and/or memory) expensive
Expand Down Expand Up @@ -230,7 +239,7 @@ class MatrixLinearOperator(AbstractLinearOperator):
shape `(a,)` and returns a vector of shape `(b,)`.
"""

matrix: Float[Array, "a b"]
matrix: Inexact[Array, "a b"]
tags: frozenset[object] = eqx.field(static=True)

def __init__(
Expand Down Expand Up @@ -353,7 +362,7 @@ class PyTreeLinearOperator(AbstractLinearOperator):
```
"""

pytree: PyTree[Float[Array, "..."]]
pytree: PyTree[Inexact[Array, "..."]]
output_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
tags: frozenset[object] = eqx.field(static=True)
input_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
Expand Down Expand Up @@ -513,9 +522,9 @@ class JacobianLinearOperator(AbstractLinearOperator):
"""

fn: Callable[
[PyTree[Float[Array, "..."]], PyTree[Any]], PyTree[Float[Array, "..."]]
[PyTree[Inexact[Array, "..."]], PyTree[Any]], PyTree[Inexact[Array, "..."]]
]
x: PyTree[Float[Array, "..."]]
x: PyTree[Inexact[Array, "..."]]
args: PyTree[Any]
tags: frozenset[object] = eqx.field(static=True)

Expand Down Expand Up @@ -592,13 +601,13 @@ class FunctionLinearOperator(AbstractLinearOperator):
in memory. (Similar to `.as_matrix()`.)
"""

fn: Callable[[PyTree[Float[Array, "..."]]], PyTree[Float[Array, "..."]]]
fn: Callable[[PyTree[Inexact[Array, "..."]]], PyTree[Inexact[Array, "..."]]]
input_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
tags: frozenset[object] = eqx.field(static=True)

def __init__(
self,
fn: Callable[[PyTree[Float[Array, "..."]]], PyTree[Float[Array, "..."]]],
fn: Callable[[PyTree[Inexact[Array, "..."]]], PyTree[Inexact[Array, "..."]]],
input_structure: PyTree[jax.ShapeDtypeStruct],
tags: Union[object, Iterable[object]] = (),
):
Expand Down Expand Up @@ -734,7 +743,7 @@ class DiagonalLinearOperator(AbstractLinearOperator):
`matrix @ vector` (for speed).
"""

diagonal: Float[Array, " size"]
diagonal: Inexact[Array, " size"]

def __init__(self, diagonal: Shaped[Array, " size"]):
"""**Arguments:**
Expand Down Expand Up @@ -767,15 +776,15 @@ class TridiagonalLinearOperator(AbstractLinearOperator):
matrix.
"""

diagonal: Float[Array, " size"]
lower_diagonal: Float[Array, " size-1"]
upper_diagonal: Float[Array, " size-1"]
diagonal: Inexact[Array, " size"]
lower_diagonal: Inexact[Array, " size-1"]
upper_diagonal: Inexact[Array, " size-1"]

def __init__(
self,
diagonal: Float[Array, " size"],
lower_diagonal: Float[Array, " size-1"],
upper_diagonal: Float[Array, " size-1"],
diagonal: Inexact[Array, " size"],
lower_diagonal: Inexact[Array, " size-1"],
upper_diagonal: Inexact[Array, " size-1"],
):
"""**Arguments:**
Expand Down

0 comments on commit 3f34189

Please sign in to comment.