Skip to content

Commit

Permalink
Materialize functional C->R operator
Browse files Browse the repository at this point in the history
  • Loading branch information
Randl committed Nov 18, 2024
1 parent 0fd9b1f commit 4179194
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 7 deletions.
44 changes: 37 additions & 7 deletions lineax/_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import jax.tree_util as jtu
import numpy as np
from equinox.internal import ω
from jax import ShapeDtypeStruct
from jaxtyping import (
Array,
ArrayLike,
Expand All @@ -39,6 +40,7 @@

from ._custom_types import sentinel
from ._misc import (
complex_to_real_dtype,
default_floating_dtype,
inexact_asarray,
jacobian,
Expand Down Expand Up @@ -1322,16 +1324,44 @@ def _(operator):

@materialise.register(FunctionLinearOperator)
def _(operator):
flat, unravel = strip_weak_dtype(
eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure())
complex_input = jnp.isdtype(
jnp.result_type(*(jax.tree.flatten(operator.in_structure())[0])),
"complex floating",
)
if jnp.result_type(operator.out_structure()) != jnp.result_type(
operator.in_structure()
):
real_output = not jnp.isdtype(
jnp.result_type(*(jax.tree.flatten(operator.out_structure())[0])),
"complex floating",
)
if complex_input and real_output:
# We'll use R^2->R representation for C->R function.
pass
in_structure = jtu.tree_map(
lambda x: ShapeDtypeStruct(
tuple(x.shape) + (2,), complex_to_real_dtype(x.dtype)
)
if jnp.isdtype(x.dtype, "complex floating")
else x,
operator.in_structure(),
)

def map_to_original(x):
with jax.numpy_dtype_promotion("standard"):
return jtu.tree_map(
lambda x, struct: x[..., 0] + 1.0j * x[..., 1]
if jnp.isdtype(struct.dtype, "complex floating")
else x,
x,
operator.in_structure(),
)
else:
map_to_original = lambda x: x
in_structure = operator.in_structure()
flat, unravel = strip_weak_dtype(
eqx.filter_eval_shape(jfu.ravel_pytree, in_structure)
)
fn = lambda x: operator.fn(map_to_original(unravel(x)))
eye = jnp.eye(flat.size, dtype=flat.dtype)
jac = jax.vmap(lambda x: operator.fn(unravel(x)), out_axes=-1)(eye)

jac = jax.vmap(fn, out_axes=-1)(eye)

def batch_unravel(x):
assert x.ndim > 0
Expand Down
23 changes: 23 additions & 0 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import jax.random as jr
import lineax as lx
import pytest
from lineax._misc import complex_to_real_dtype

from .helpers import (
make_diagonal_operator,
Expand Down Expand Up @@ -321,6 +322,28 @@ def test_materialise_function_linear_operator(dtype, getkey):
assert jax.eval_shape(lambda: materialised_operator.pytree) == expected_struct


def test_materialise_function_real_linear_operator(getkey):
dtype = jnp.complex128
x = (
jr.normal(getkey(), (5, 9), dtype=dtype),
jr.normal(getkey(), (3,), dtype=dtype),
)
input_structure = jax.eval_shape(lambda: x)
fn = lambda x: {"a": jnp.broadcast_to(jnp.sum(x[0]).real, (1, 2))}
output_structure = jax.eval_shape(fn, input_structure)
operator = lx.FunctionLinearOperator(fn, input_structure)
materialised_operator = lx.materialise(operator)
assert materialised_operator.out_structure() == output_structure
assert isinstance(materialised_operator, lx.PyTreeLinearOperator)
expected_struct = {
"a": (
jax.ShapeDtypeStruct((1, 2, 5, 9, 2), complex_to_real_dtype(dtype)),
jax.ShapeDtypeStruct((1, 2, 3, 2), complex_to_real_dtype(dtype)),
)
}
assert jax.eval_shape(lambda: materialised_operator.pytree) == expected_struct


@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_pytree_transpose(dtype, getkey):
out_struct = jax.eval_shape(
Expand Down

0 comments on commit 4179194

Please sign in to comment.