diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index 84a50d6c5..4cac7cb34 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -297,7 +297,7 @@ def test_ad_integration( _, subkey = jax.random.split(prng_key, num=2) data, references = get_random_data_and_references( - model=model, velocity_representation=VelRepr.Inertial, key=subkey + model=model, velocity_representation=VelRepr.Mixed, key=subkey ) # State in VelRepr.Inertial representation. @@ -308,7 +308,7 @@ def test_ad_integration( ṡ = data.joint_velocities # Inputs. - W_f_L = references.link_forces(model=model) + LW_f_L = references.link_forces(model=model, data=data) τ = references.joint_force_references(model=model) # ==== @@ -323,7 +323,7 @@ def step( W_v_WB: jtp.Vector, ṡ: jtp.Vector, τ: jtp.Vector, - W_f_L: jtp.Matrix, + LW_f_L: jtp.Matrix, ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: # When JAX tests against finite differences, the injected ε will make the @@ -344,7 +344,7 @@ def step( model=model, data=data_x0, joint_force_references=τ, - link_forces=W_f_L, + link_forces_mixed=LW_f_L, ) xf_W_p_B = data_xf.base_position @@ -360,7 +360,7 @@ def step( # current implementation of `optax` optimizers in the relaxed rigid contact model. check_grads( f=step, - args=(W_p_B, W_Q_B, s, W_v_WB, ṡ, τ, W_f_L), + args=(W_p_B, W_Q_B, s, W_v_WB, ṡ, τ, LW_f_L), order=AD_ORDER, modes=["fwd"], eps=ε, diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 86e3696e3..93f19ccff 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -69,13 +69,14 @@ def test_box_with_external_forces( data = data0.copy() # ... and step the simulation. - for _ in T_ns: + with references.switch_velocity_representation(VelRepr.Mixed): + for _ in T_ns: - data = js.model.step( - model=model, - data=data, - link_forces=references.link_forces(model=model, data=data), - ) + data = js.model.step( + model=model, + data=data, + link_forces_mixed=references.link_forces(model=model, data=data), + ) # Check that the box didn't move. assert data.base_position == pytest.approx(data0.base_position) @@ -143,13 +144,14 @@ def test_box_with_zero_gravity( data = data0.copy() # ... and step the simulation. - for _ in T: + with references.switch_velocity_representation(jaxsim.VelRepr.Mixed): + for _ in T: - data = js.model.step( - model=model, - data=data, - link_forces=references.link_forces(model=model, data=data), - ) + data = js.model.step( + model=model, + data=data, + link_forces_mixed=references.link_forces(model=model, data=data), + ) # Check that the box moved as expected. assert data.base_position == pytest.approx(