Skip to content

Commit

Permalink
Fix #118
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Dec 15, 2024
1 parent 42cf2d8 commit fc40e5b
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions tests/test_vmap_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import functools as ft

import equinox as eqx
import equinox.internal as eqxi
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
Expand Down Expand Up @@ -123,10 +122,10 @@ def linear_solve(operator, vector):
eqx.filter_vmap(
lambda x: x.as_matrix(),
in_axes=vmap1_op,
out_axes=eqxi.if_mapped(0),
out_axes=None if vmap1_op is None else 0,
),
in_axes=vmap2_op,
out_axes=eqxi.if_mapped(0),
out_axes=None if vmap2_op is None else 0,
)(operator)

vmap1_axes = (vmap1_op, vmap1_vec)
Expand Down

0 comments on commit fc40e5b

Please sign in to comment.