Skip to content

Commit

Permalink
Update tests to use mixed velocity representation for link forces in …
Browse files Browse the repository at this point in the history
…`step`
  • Loading branch information
xela-95 committed Jan 24, 2025
1 parent f6dcd5b commit 7f83dad
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 17 deletions.
10 changes: 5 additions & 5 deletions tests/test_automatic_differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def test_ad_integration(

_, subkey = jax.random.split(prng_key, num=2)
data, references = get_random_data_and_references(
model=model, velocity_representation=VelRepr.Inertial, key=subkey
model=model, velocity_representation=VelRepr.Mixed, key=subkey
)

# State in VelRepr.Inertial representation.
Expand All @@ -308,7 +308,7 @@ def test_ad_integration(
= data.joint_velocities

# Inputs.
W_f_L = references.link_forces(model=model)
LW_f_L = references.link_forces(model=model, data=data)
τ = references.joint_force_references(model=model)

# ====
Expand All @@ -323,7 +323,7 @@ def step(
W_v_WB: jtp.Vector,
: jtp.Vector,
τ: jtp.Vector,
W_f_L: jtp.Matrix,
LW_f_L: jtp.Matrix,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]:

# When JAX tests against finite differences, the injected ε will make the
Expand All @@ -344,7 +344,7 @@ def step(
model=model,
data=data_x0,
joint_force_references=τ,
link_forces=W_f_L,
link_forces_mixed=LW_f_L,
)

xf_W_p_B = data_xf.base_position
Expand All @@ -360,7 +360,7 @@ def step(
# current implementation of `optax` optimizers in the relaxed rigid contact model.
check_grads(
f=step,
args=(W_p_B, W_Q_B, s, W_v_WB, , τ, W_f_L),
args=(W_p_B, W_Q_B, s, W_v_WB, , τ, LW_f_L),
order=AD_ORDER,
modes=["fwd"],
eps=ε,
Expand Down
26 changes: 14 additions & 12 deletions tests/test_simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,14 @@ def test_box_with_external_forces(
data = data0.copy()

# ... and step the simulation.
for _ in T_ns:
with references.switch_velocity_representation(VelRepr.Mixed):
for _ in T_ns:

data = js.model.step(
model=model,
data=data,
link_forces=references.link_forces(model=model, data=data),
)
data = js.model.step(
model=model,
data=data,
link_forces_mixed=references.link_forces(model=model, data=data),
)

# Check that the box didn't move.
assert data.base_position == pytest.approx(data0.base_position)
Expand Down Expand Up @@ -143,13 +144,14 @@ def test_box_with_zero_gravity(
data = data0.copy()

# ... and step the simulation.
for _ in T:
with references.switch_velocity_representation(jaxsim.VelRepr.Mixed):
for _ in T:

data = js.model.step(
model=model,
data=data,
link_forces=references.link_forces(model=model, data=data),
)
data = js.model.step(
model=model,
data=data,
link_forces_mixed=references.link_forces(model=model, data=data),
)

# Check that the box moved as expected.
assert data.base_position == pytest.approx(
Expand Down

0 comments on commit 7f83dad

Please sign in to comment.