Add a direct call to JAXOp to some tests
jdehning committed Feb 6, 2025
1 parent 48fbf0a commit b8c4523
Showing 1 changed file with 114 additions and 17 deletions.
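For context, this commit keeps each test's JAX function plain and exercises it in two ways: wrapping it on the fly with as_jax_op(f) instead of the @as_jax_op decorator, and constructing a JAXOp directly from explicit input and output types. The following is a minimal sketch of the two call patterns, mirroring the first test in the diff below; the names out_wrapped, out_direct, and op are illustrative only, and the JAXOp constructor signature (input types, output types, function) is taken from how the tests use it.

import jax
import pytensor.tensor as pt
from pytensor import as_jax_op, config, grad
from pytensor.link.jax.ops import JAXOp
from pytensor.tensor import TensorType, tensor

x = tensor("x", shape=(2,))
y = tensor("y", shape=(2,))

def f(x, y):
    return jax.nn.sigmoid(x + y)

# Pattern 1: wrap the JAX function at the call site instead of decorating it.
out_wrapped = as_jax_op(f)(x, y)

# Pattern 2: build the Op directly, spelling out input and output types.
op = JAXOp([x.type, y.type], [TensorType(config.floatX, shape=(2,))], f)
out_direct = op(x, y)

# Both variants participate in PyTensor's autodiff like any other Op.
grad_out = grad(pt.sum(out_direct), [x, y])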
131 changes: 114 additions & 17 deletions tests/link/jax/test_as_jax_op.py
@@ -6,8 +6,9 @@
import pytensor.tensor as pt
from pytensor import as_jax_op, config, grad
from pytensor.graph.fg import FunctionGraph
from pytensor.link.jax.ops import JAXOp
from pytensor.scalar import all_types
from pytensor.tensor import tensor
from pytensor.tensor import TensorType, tensor
from tests.link.jax.test_basic import compare_jax_and_py


@@ -19,18 +20,29 @@ def test_two_inputs_single_output():
rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
]

@as_jax_op
def f(x, y):
return jax.nn.sigmoid(x + y)

out = f(x, y)
# Test with as_jax_op decorator
out = as_jax_op(f)(x, y)
grad_out = grad(pt.sum(out), [x, y])

fg = FunctionGraph([x, y], [out, *grad_out])
fn, _ = compare_jax_and_py(fg, test_values)
with jax.disable_jit():
fn, _ = compare_jax_and_py(fg, test_values)

# Test direct JAXOp usage
jax_op = JAXOp(
[x.type, y.type],
[TensorType(config.floatX, shape=(2,))],
f,
)
out = jax_op(x, y)
grad_out = grad(pt.sum(out), [x, y])
fg = FunctionGraph([x, y], [out, *grad_out])
fn, _ = compare_jax_and_py(fg, test_values)


def test_two_inputs_tuple_output():
rng = np.random.default_rng(2)
@@ -40,11 +52,11 @@ def test_two_inputs_tuple_output():
rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
]

@as_jax_op
def f(x, y):
return jax.nn.sigmoid(x + y), y * 2

out1, out2 = f(x, y)
# Test with as_jax_op decorator
out1, out2 = as_jax_op(f)(x, y)
grad_out = grad(pt.sum(out1 + out2), [x, y])

fg = FunctionGraph([x, y], [out1, out2, *grad_out])
@@ -54,6 +66,17 @@ def f(x, y):
# inputs are not automatically transformed to jax.Array anymore
fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)

# Test direct JAXOp usage
jax_op = JAXOp(
[x.type, y.type],
[TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))],
f,
)
out1, out2 = jax_op(x, y)
grad_out = grad(pt.sum(out1 + out2), [x, y])
fg = FunctionGraph([x, y], [out1, out2, *grad_out])
fn, _ = compare_jax_and_py(fg, test_values)


def test_two_inputs_list_output_one_unused_output():
# One output is unused, to test whether the wrapper can handle DisconnectedType
@@ -64,72 +87,119 @@ def test_two_inputs_list_output_one_unused_output():
rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
]

@as_jax_op
def f(x, y):
return [jax.nn.sigmoid(x + y), y * 2]

out, _ = f(x, y)
# Test with as_jax_op decorator
out, _ = as_jax_op(f)(x, y)
grad_out = grad(pt.sum(out), [x, y])

fg = FunctionGraph([x, y], [out, *grad_out])
fn, _ = compare_jax_and_py(fg, test_values)
with jax.disable_jit():
fn, _ = compare_jax_and_py(fg, test_values)

# Test direct JAXOp usage
jax_op = JAXOp(
[x.type, y.type],
[TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))],
f,
)
out, _ = jax_op(x, y)
grad_out = grad(pt.sum(out), [x, y])
fg = FunctionGraph([x, y], [out, *grad_out])
fn, _ = compare_jax_and_py(fg, test_values)


def test_single_input_tuple_output():
rng = np.random.default_rng(4)
x = tensor("x", shape=(2,))
test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]

@as_jax_op
def f(x):
return jax.nn.sigmoid(x), x * 2

out1, out2 = f(x)
# Test with as_jax_op decorator
out1, out2 = as_jax_op(f)(x)
grad_out = grad(pt.sum(out1), [x])

fg = FunctionGraph([x], [out1, out2, *grad_out])
fn, _ = compare_jax_and_py(fg, test_values)
with jax.disable_jit():
fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)

# Test direct JAXOp usage
jax_op = JAXOp(
[x.type],
[TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))],
f,
)
out1, out2 = jax_op(x)
grad_out = grad(pt.sum(out1), [x])
fg = FunctionGraph([x], [out1, out2, *grad_out])
fn, _ = compare_jax_and_py(fg, test_values)


def test_scalar_input_tuple_output():
rng = np.random.default_rng(5)
x = tensor("x", shape=())
test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]

@as_jax_op
def f(x):
return jax.nn.sigmoid(x), x

out1, out2 = f(x)
# Test with as_jax_op decorator
out1, out2 = as_jax_op(f)(x)
grad_out = grad(pt.sum(out1), [x])

fg = FunctionGraph([x], [out1, out2, *grad_out])
fn, _ = compare_jax_and_py(fg, test_values)
with jax.disable_jit():
fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)

# Test direct JAXOp usage
jax_op = JAXOp(
[x.type],
[TensorType(config.floatX, shape=()), TensorType(config.floatX, shape=())],
f,
)
out1, out2 = jax_op(x)
grad_out = grad(pt.sum(out1), [x])
fg = FunctionGraph([x], [out1, out2, *grad_out])
fn, _ = compare_jax_and_py(fg, test_values)


def test_single_input_list_output():
rng = np.random.default_rng(6)
x = tensor("x", shape=(2,))
test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]

@as_jax_op
def f(x):
return [jax.nn.sigmoid(x), 2 * x]

out1, out2 = f(x)
# Test with as_jax_op decorator
out1, out2 = as_jax_op(f)(x)
grad_out = grad(pt.sum(out1), [x])

fg = FunctionGraph([x], [out1, out2, *grad_out])
fn, _ = compare_jax_and_py(fg, test_values)
with jax.disable_jit():
fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)

# Test direct JAXOp usage, with unspecified output shapes
jax_op = JAXOp(
[x.type],
[
TensorType(config.floatX, shape=(None,)),
TensorType(config.floatX, shape=(None,)),
],
f,
)
out1, out2 = jax_op(x)
grad_out = grad(pt.sum(out1), [x])
fg = FunctionGraph([x], [out1, out2, *grad_out])
fn, _ = compare_jax_and_py(fg, test_values)


def test_pytree_input_tuple_output():
rng = np.random.default_rng(7)
@@ -144,6 +214,7 @@ def test_pytree_input_tuple_output():
def f(x, y):
return jax.nn.sigmoid(x), 2 * x + y["y"] + y["y2"][0]

# Test with as_jax_op decorator
out = f(x, y_tmp)
grad_out = grad(pt.sum(out[1]), [x, y])

@@ -167,6 +238,7 @@ def test_pytree_input_pytree_output():
def f(x, y):
return x, jax.tree_util.tree_map(lambda x: jnp.exp(x), y)

# Test with as_jax_op decorator
out = f(x, y_tmp)
grad_out = grad(pt.sum(out[1]["b"][0]), [x, y])

@@ -198,6 +270,7 @@ def f(x, y, depth, which_variable):
var = jax.nn.sigmoid(var)
return var

# Test with as_jax_op decorator
# arguments depth and which_variable are not part of the graph
out = f(x, y_tmp, depth=3, which_variable="x")
grad_out = grad(pt.sum(out), [x])
Expand Down Expand Up @@ -228,11 +301,11 @@ def test_unused_matrix_product():
rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
]

@as_jax_op
def f(x, y):
return x[:, None] @ y[None], jnp.exp(x)

out = f(x, y)
# Test with as_jax_op decorator
out = as_jax_op(f)(x, y)
grad_out = grad(pt.sum(out[1]), [x])

fg = FunctionGraph([x, y], [out[1], *grad_out])
@@ -241,6 +314,20 @@ def f(x, y):
with jax.disable_jit():
fn, _ = compare_jax_and_py(fg, test_values)

# Test direct JAXOp usage
jax_op = JAXOp(
[x.type, y.type],
[
TensorType(config.floatX, shape=(3, 3)),
TensorType(config.floatX, shape=(3,)),
],
f,
)
out = jax_op(x, y)
grad_out = grad(pt.sum(out[1]), [x])
fg = FunctionGraph([x, y], [out[1], *grad_out])
fn, _ = compare_jax_and_py(fg, test_values)


def test_unknown_static_shape():
rng = np.random.default_rng(11)
@@ -252,11 +339,10 @@ def test_unknown_static_shape():

x_cumsum = pt.cumsum(x) # Now x_cumsum has an unknown shape

@as_jax_op
def f(x, y):
return x * jnp.ones(3)

out = f(x_cumsum, y)
out = as_jax_op(f)(x_cumsum, y)
grad_out = grad(pt.sum(out), [x])

fg = FunctionGraph([x, y], [out, *grad_out])
@@ -265,6 +351,17 @@ def f(x, y):
with jax.disable_jit():
fn, _ = compare_jax_and_py(fg, test_values)

# Test direct JAXOp usage
jax_op = JAXOp(
[x.type, y.type],
[TensorType(config.floatX, shape=(None,))],
f,
)
out = jax_op(x_cumsum, y)
grad_out = grad(pt.sum(out), [x])
fg = FunctionGraph([x, y], [out, *grad_out])
fn, _ = compare_jax_and_py(fg, test_values)


def test_nested_functions():
rng = np.random.default_rng(13)
