Skip to content

Commit

Permalink
Merge pull request #352 from ami-iit/sprint/gpu-transfert
Browse files Browse the repository at this point in the history
[Sprint] Fix device transfer and exceptions handling
  • Loading branch information
flferretti authored Jan 24, 2025
2 parents 15e402e + 7aff901 commit ccb7f27
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 14 deletions.
2 changes: 1 addition & 1 deletion docs/guide/configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ The logging and exceptions configurations is controlled by the following environ

*Default:* ``DEBUG`` for development, ``WARNING`` for production.

- ``JAXSIM_DISABLE_EXCEPTIONS``: Disables the runtime checks and exceptions.
- ``JAXSIM_ENABLE_EXCEPTIONS``: Enables the runtime checks and exceptions. Note that enabling exceptions might lead to device-to-host transfer of data, increasing the computational time required.

*Default:* ``False``.

Expand Down
14 changes: 7 additions & 7 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ class JaxSimModel(JaxsimDataclass):

model_name: Static[str]

time_step: jtp.FloatLike = dataclasses.field(
default_factory=lambda: jnp.array(0.001, dtype=float),
time_step: float = dataclasses.field(
default=0.001,
)

terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
Expand Down Expand Up @@ -91,7 +91,7 @@ def __hash__(self) -> int:
return hash(
(
hash(self.model_name),
hash(float(self.time_step)),
hash(self.time_step),
hash(self.kin_dyn_parameters),
hash(self.contact_model),
)
Expand Down Expand Up @@ -222,7 +222,7 @@ def build(
time_step = (
time_step
if time_step is not None
else JaxSimModel.__dataclass_fields__["time_step"].default_factory()
else JaxSimModel.__dataclass_fields__["time_step"].default
)

# Create the default contact model.
Expand Down Expand Up @@ -317,7 +317,7 @@ def floating_base(self) -> bool:
True if the model is floating-base, False otherwise.
"""

return bool(self.kin_dyn_parameters.joint_model.joint_dofs[0] == 6)
return self.kin_dyn_parameters.joint_model.joint_dofs[0] == 6

def base_link(self) -> str:
"""
Expand Down Expand Up @@ -348,7 +348,7 @@ def dofs(self) -> int:
the number of joints. In the future, this could be different.
"""

return int(sum(self.kin_dyn_parameters.joint_model.joint_dofs[1:]))
return sum(self.kin_dyn_parameters.joint_model.joint_dofs[1:])

def joint_names(self) -> tuple[str, ...]:
"""
Expand Down Expand Up @@ -431,7 +431,7 @@ def reduce(
for joint_name in set(model.joint_names()) - set(considered_joints):
j = intermediate_description.joints_dict[joint_name]
with j.mutable_context():
j.initial_position = float(locked_joint_positions.get(joint_name, 0.0))
j.initial_position = locked_joint_positions.get(joint_name, 0.0)

# Reduce the model description.
# If `considered_joints` contains joints not existing in the model,
Expand Down
4 changes: 2 additions & 2 deletions src/jaxsim/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ def raise_if(

# Disable host callback if running on unsupported hardware or if the user
# explicitly disabled it.
if jax.devices()[0].platform in {"tpu", "METAL"} or os.environ.get(
"JAXSIM_DISABLE_EXCEPTIONS", 0
if jax.devices()[0].platform in {"tpu", "METAL"} or not os.environ.get(
"JAXSIM_ENABLE_EXCEPTIONS", 0
):
return

Expand Down
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import os

os.environ["JAXSIM_ENABLE_EXCEPTIONS"] = "1"

import pathlib
import subprocess

Expand Down
2 changes: 1 addition & 1 deletion tests/test_api_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_model_creation_and_reduction(
locked_joint_positions=dict(
zip(
model_full.joint_names(),
data_full.joint_positions,
data_full.joint_positions.tolist(),
strict=True,
)
),
Expand Down
2 changes: 1 addition & 1 deletion tests/test_automatic_differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def step(
model=model,
data=data_x0,
joint_force_references=τ,
link_forces=W_f_L,
link_forces_inertial=W_f_L,
)

xf_W_p_B = data_xf.base_position
Expand Down
4 changes: 2 additions & 2 deletions tests/test_simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_box_with_external_forces(
data = js.model.step(
model=model,
data=data,
link_forces=references.link_forces(model=model, data=data),
link_forces_inertial=references._link_forces,
)

# Check that the box didn't move.
Expand Down Expand Up @@ -148,7 +148,7 @@ def test_box_with_zero_gravity(
data = js.model.step(
model=model,
data=data,
link_forces=references.link_forces(model=model, data=data),
link_forces_inertial=references.link_forces(model=model, data=data),
)

# Check that the box moved as expected.
Expand Down

0 comments on commit ccb7f27

Please sign in to comment.