diff --git a/jax_cfd/base/initial_conditions.py b/jax_cfd/base/initial_conditions.py index c1d4135..c2bc579 100644 --- a/jax_cfd/base/initial_conditions.py +++ b/jax_cfd/base/initial_conditions.py @@ -65,7 +65,7 @@ def _log_normal_pdf(x, mode, variance=.25): def _max_speed(v): - return jnp.linalg.norm([u.data for u in v], axis=0).max() + return jnp.linalg.norm(jnp.array([u.data for u in v]), axis=0).max() def filtered_velocity_field( diff --git a/jax_cfd/base/initial_conditions_test.py b/jax_cfd/base/initial_conditions_test.py index 76259ca..e92f3fc 100644 --- a/jax_cfd/base/initial_conditions_test.py +++ b/jax_cfd/base/initial_conditions_test.py @@ -50,7 +50,7 @@ def test_filtered_velocity_field( self, seed, grid, maximum_velocity, peak_wavenumber): v = ic.filtered_velocity_field( jax.random.PRNGKey(seed), grid, maximum_velocity, peak_wavenumber) - actual_maximum_velocity = jnp.linalg.norm([u.data for u in v], axis=0).max() + actual_maximum_velocity = jnp.linalg.norm(jnp.array([u.data for u in v]), axis=0).max() max_divergence = fd.divergence(v).data.max() # Assert that initial velocity is divergence free diff --git a/jax_cfd/collocated/initial_conditions.py b/jax_cfd/collocated/initial_conditions.py index cb1f4ff..0918be9 100644 --- a/jax_cfd/collocated/initial_conditions.py +++ b/jax_cfd/collocated/initial_conditions.py @@ -44,7 +44,7 @@ def _log_normal_pdf(x, mode, variance=.25): def _max_speed(v): - return jnp.linalg.norm([u.data for u in v], axis=0).max() + return jnp.linalg.norm(jnp.array([u.data for u in v]), axis=0).max() def filtered_velocity_field( diff --git a/jax_cfd/collocated/initial_conditions_test.py b/jax_cfd/collocated/initial_conditions_test.py index 006a4af..f5bfe20 100644 --- a/jax_cfd/collocated/initial_conditions_test.py +++ b/jax_cfd/collocated/initial_conditions_test.py @@ -43,7 +43,7 @@ def test_filtered_velocity_field( self, seed, grid, maximum_velocity, peak_wavenumber): v = initial_conditions.filtered_velocity_field( jax.random.PRNGKey(seed), grid, maximum_velocity, peak_wavenumber) - actual_maximum_velocity = jnp.linalg.norm([u.data for u in v], axis=0).max() + actual_maximum_velocity = jnp.linalg.norm(jnp.array([u.data for u in v]), axis=0).max() max_divergence = fd.centered_divergence(v).data.max() # Assert that initial velocity is divergence free