Skip to content

Commit

Permalink
[Models] Vectorized edges length and force computation to avoid nan
Browse files Browse the repository at this point in the history
  • Loading branch information
arpastrana committed Feb 9, 2025
1 parent 7089b11 commit f81156c
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
5 changes: 3 additions & 2 deletions src/jax_cem/equilibrium/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def edges_length(
"""
vectors = edges_vector(xyz, structure.connectivity)

return vector_length(vectors)
return vmap(vector_length)(vectors)

# ------------------------------------------------------------------------------
# Edge forces
Expand Down Expand Up @@ -440,12 +440,13 @@ def trails_force(
The force in the trail edges of a structure.
"""
residuals = jnp.concatenate(residuals)
forces = trail_force(residuals)
forces = vmap(trail_force)(residuals)

lengths = jnp.reshape(jnp.concatenate(lengths), (-1, 1))

return jnp.copysign(forces, lengths)


# ------------------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------------------
Expand Down
8 changes: 7 additions & 1 deletion src/jax_cem/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@ def vector_length(v: jax.Array, keepdims: bool = True) -> jax.Array:
"""
Calculate the length of a vector over its last dimension.
"""
return jnp.linalg.norm(v, axis=-1, keepdims=keepdims)
v = jnp.nan_to_num(v)
is_zero_vector = jnp.allclose(v, 0.0)
d = jnp.where(is_zero_vector, jnp.ones_like(v), v) # replace d with ones if is_zero

length = jnp.where(is_zero_vector, 0.0, jnp.linalg.norm(d, axis=-1, keepdims=keepdims))

return length


def vector_normalized(u: jax.Array) -> jax.Array:
Expand Down

0 comments on commit f81156c

Please sign in to comment.