Skip to content

Commit

Permalink
Merge pull request #353 from ami-iit/sprint/fix-xela95
Browse files Browse the repository at this point in the history
[Sprint] Fix missing change of data representation
  • Loading branch information
CarlottaSartore authored Jan 24, 2025
2 parents 976ff4a + 563f33a commit 15e402e
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 44 deletions.
4 changes: 1 addition & 3 deletions src/jaxsim/api/contact_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def link_contact_forces(
*,
link_forces: jtp.MatrixLike | None = None,
joint_torques: jtp.VectorLike | None = None,
**kwargs,
) -> jtp.Matrix:
"""
Compute the 6D contact forces of all links of the model in inertial representation.
Expand All @@ -26,8 +25,7 @@ def link_contact_forces(
link_forces:
The 6D external forces to apply to the links expressed in inertial representation
joint_torques:
The joint torques applied to the joints.
kwargs: Additional keyword arguments to pass to the active contact model..
The joint torques acting on the joints.
Returns:
A `(nL, 6)` array containing the stacked 6D contact forces of the links,
Expand Down
81 changes: 41 additions & 40 deletions src/jaxsim/api/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,57 +15,58 @@ def semi_implicit_euler_integration(
) -> JaxSimModelData:
"""Integrate the system state using the semi-implicit Euler method."""
# Step the dynamics forward.
with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):

dt = model.time_step
W_v̇_WB = base_acceleration_inertial
= joint_accelerations
dt = model.time_step
W_v̇_WB = base_acceleration_inertial
= joint_accelerations

B_H_W = Transform.inverse(data.base_transform).at[:3, :3].set(jnp.eye(3))
BW_X_W = Adjoint.from_transform(B_H_W)
B_H_W = Transform.inverse(data.base_transform).at[:3, :3].set(jnp.eye(3))
BW_X_W = Adjoint.from_transform(B_H_W)

new_generalized_acceleration = jnp.hstack([W_v̇_WB, ])
new_generalized_acceleration = jnp.hstack([W_v̇_WB, ])

new_generalized_velocity = (
data.generalized_velocity() + dt * new_generalized_acceleration
)
new_generalized_velocity = (
data.generalized_velocity() + dt * new_generalized_acceleration
)

new_base_velocity_inertial = new_generalized_velocity[0:6]
new_joint_velocities = new_generalized_velocity[6:]
new_base_velocity_inertial = new_generalized_velocity[0:6]
new_joint_velocities = new_generalized_velocity[6:]

base_lin_velocity_inertial = new_base_velocity_inertial[0:3]
base_lin_velocity_inertial = new_base_velocity_inertial[0:3]

new_base_velocity_mixed = BW_X_W @ new_generalized_velocity[0:6]
base_lin_velocity_mixed = new_base_velocity_mixed[0:3]
base_ang_velocity_mixed = new_base_velocity_mixed[3:6]
new_base_velocity_mixed = BW_X_W @ new_generalized_velocity[0:6]
base_lin_velocity_mixed = new_base_velocity_mixed[0:3]
base_ang_velocity_mixed = new_base_velocity_mixed[3:6]

base_quaternion_derivative = jaxsim.math.Quaternion.derivative(
quaternion=data.base_orientation(),
omega=base_ang_velocity_mixed,
omega_in_body_fixed=False,
).squeeze()
base_quaternion_derivative = jaxsim.math.Quaternion.derivative(
quaternion=data.base_orientation(),
omega=base_ang_velocity_mixed,
omega_in_body_fixed=False,
).squeeze()

new_base_position = data.base_position + dt * base_lin_velocity_mixed
new_base_quaternion = data.base_orientation() + dt * base_quaternion_derivative
new_base_position = data.base_position + dt * base_lin_velocity_mixed
new_base_quaternion = data.base_orientation() + dt * base_quaternion_derivative

base_quaternion_norm = jaxsim.math.safe_norm(new_base_quaternion)
base_quaternion_norm = jaxsim.math.safe_norm(new_base_quaternion)

new_base_quaternion = new_base_quaternion / jnp.where(
base_quaternion_norm == 0, 1.0, base_quaternion_norm
)
new_base_quaternion = new_base_quaternion / jnp.where(
base_quaternion_norm == 0, 1.0, base_quaternion_norm
)

new_joint_position = data.joint_positions + dt * new_joint_velocities
new_joint_position = data.joint_positions + dt * new_joint_velocities

data = data.replace(
validate=True,
base_quaternion=new_base_quaternion,
base_position=new_base_position,
joint_positions=new_joint_position,
joint_velocities=new_joint_velocities,
base_linear_velocity=base_lin_velocity_inertial,
# Here we use the base angular velocity in mixed representation since
# it's equivalent to the one in inertial representation
# See: S. Traversaro and A. Saccon, “Multibody Dynamics Notation (Version 2), pg.9
base_angular_velocity=base_ang_velocity_mixed,
)
data = data.replace(
validate=True,
base_quaternion=new_base_quaternion,
base_position=new_base_position,
joint_positions=new_joint_position,
joint_velocities=new_joint_velocities,
base_linear_velocity=base_lin_velocity_inertial,
# Here we use the base angular velocity in mixed representation since
# it's equivalent to the one in inertial representation
# See: S. Traversaro and A. Saccon, “Multibody Dynamics Notation (Version 2), pg.9
base_angular_velocity=base_ang_velocity_mixed,
)

return data
return data
2 changes: 1 addition & 1 deletion src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2049,7 +2049,7 @@ def step(
model=model,
data=data,
link_forces=W_f_L_external,
joint_force_references=τ_total,
joint_torques=τ_total,
)

# ==============================
Expand Down

0 comments on commit 15e402e

Please sign in to comment.