Skip to content

Commit

Permalink
Merge pull request #106 from ami-iit/functional_tests
Browse files Browse the repository at this point in the history
Add new test suite of functional APIs
  • Loading branch information
diegoferigo authored Mar 12, 2024
2 parents 4fd2032 + a54bfae commit 795df2b
Show file tree
Hide file tree
Showing 29 changed files with 1,728 additions and 1,080 deletions.
15 changes: 13 additions & 2 deletions .github/workflows/ci_cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,22 @@ jobs:
with:
fetch-depth: 0

- name: Install Gazebo Classic
# - name: Install Gazebo Classic
# if: contains(matrix.os, 'ubuntu')
# run: |
# sudo apt-get update
# sudo apt-get install gazebo

# https://gazebosim.org/docs/harmonic/install_ubuntu
- name: Install Gazebo Sim
if: contains(matrix.os, 'ubuntu')
run: |
sudo apt-get update
sudo apt-get install gazebo
sudo apt-get install lsb-release wget gnupg
sudo wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg
echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null
sudo apt-get update
sudo apt-get install gz-harmonic
- name: Run the Python tests
if: contains(matrix.os, 'ubuntu')
Expand Down
3 changes: 1 addition & 2 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ dependencies:
- jaxlie >= 1.3.0
- jax-dataclasses >= 1.4.0
- pptree
- rod
- rod >= 0.2.0
- typing_extensions # python<3.12
# Optional dependencies from setup.cfg
# [style]
Expand All @@ -19,7 +19,6 @@ dependencies:
# [testing]
- idyntree
- pytest
- pytest-forked
- pytest-icdiff
- robot_descriptions
# [viz]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ multi_line_output = 3

[tool.pytest.ini_options]
minversion = "6.0"
addopts = "-rsxX -v --strict-markers --forked"
addopts = "-rsxX -v --strict-markers"
testpaths = [
"tests",
]
9 changes: 4 additions & 5 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ package_dir =
python_requires = >=3.11
install_requires =
coloredlogs
jax >= 0.4.13,< 0.4.25
jaxlib >= 0.4.13,< 0.4.25
jax >= 0.4.13
jaxlib >= 0.4.13
jaxlie >= 1.3.0
jax_dataclasses >= 1.4.0
pptree
rod
rod >= 0.2.0
typing_extensions ; python_version < '3.12'

[options.packages.find]
Expand All @@ -71,8 +71,7 @@ style =
pre-commit
testing =
idyntree
pytest >= 6.0
pytest-forked
pytest >=6.0
pytest-icdiff
robot-descriptions
viz =
Expand Down
3 changes: 2 additions & 1 deletion src/jaxsim/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from . import contact, data, joint, link, model, ode
from . import model, data # isort:skip
from . import common, contact, joint, link, ode, references
26 changes: 14 additions & 12 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,20 +761,22 @@ def random_model_data(
*jax.random.uniform(key=k2, shape=(3,), minval=0, maxval=2 * jnp.pi)
).as_quaternion_xyzw()[np.array([3, 0, 1, 2])]

physics_model_state.joint_positions = jaxsim.api.joint.random_joint_positions(
model=model, key=k3
)
if model.number_of_joints() > 0:
physics_model_state.joint_positions = (
jaxsim.api.joint.random_joint_positions(model=model, key=k3)
)

physics_model_state.base_linear_velocity = jax.random.uniform(
key=k4, shape=(3,), minval=v_min, maxval=v_max
)
physics_model_state.joint_velocities = jax.random.uniform(
key=k4, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max
)

physics_model_state.base_angular_velocity = jax.random.uniform(
key=k5, shape=(3,), minval=ω_min, maxval=ω_max
)
if model.floating_base():
physics_model_state.base_linear_velocity = jax.random.uniform(
key=k5, shape=(3,), minval=v_min, maxval=v_max
)

physics_model_state.joint_velocities = jax.random.uniform(
key=k6, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max
)
physics_model_state.base_angular_velocity = jax.random.uniform(
key=k6, shape=(3,), minval=ω_min, maxval=ω_max
)

return random_data
11 changes: 7 additions & 4 deletions src/jaxsim/api/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def name_to_idx(model: Model.JaxSimModel, *, joint_name: str) -> jtp.Int:
"""

return jnp.array(
model.physics_model.description.joints_dict[joint_name].index, dtype=int
model.physics_model.description.joints_dict[joint_name].index - 1, dtype=int
)


Expand Down Expand Up @@ -103,10 +103,13 @@ def position_limit(
) -> tuple[jtp.Float, jtp.Float]:
""""""

min = model.physics_model._joint_position_limits_min[joint_index]
max = model.physics_model._joint_position_limits_max[joint_index]
if model.physics_model.NB <= 1:
return jnp.array([]), jnp.array([])

return min.astype(float), max.astype(float)
s_min = model.physics_model._joint_position_limits_min[joint_index]
s_max = model.physics_model._joint_position_limits_max[joint_index]

return s_min.astype(float), s_max.astype(float)


@functools.partial(jax.jit, static_argnames=["joint_names"])
Expand Down
36 changes: 26 additions & 10 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,14 +549,22 @@ def forward_dynamics_aba(
else jnp.zeros((model.number_of_links(), 6))
)

references = js.references.JaxSimModelReferences.build(
model=model,
joint_force_references=τ,
link_forces=f_ext,
data=data,
velocity_representation=data.velocity_representation,
)

# Compute ABA
W_v̇_WB, = jaxsim.physics.algos.aba.aba(
model=model.physics_model,
xfb=data.state.physics_model.xfb(),
q=data.state.physics_model.joint_positions,
qd=data.state.physics_model.joint_velocities,
tau=τ,
f_ext=f_ext,
tau=references.input.physics_model.tau,
f_ext=references.input.physics_model.f_ext,
)

def to_active(W_vd_WB, W_H_C, W_v_WB, W_vl_WC):
Expand Down Expand Up @@ -602,6 +610,12 @@ def to_active(W_vd_WB, W_H_C, W_v_WB, W_vl_WC):
W_vl_WC=W_vl_WC,
)

# The ABA algorithm already returns a zero base 6D acceleration for
# fixed-based models. However, the to_active function introduces an
# additional acceleration component in Mixed representation.
# Here below we make sure that the base acceleration is zero.
C_v̇_WB = C_v̇_WB if model.floating_base() else jnp.zeros(6).astype(float)

# Adjust shape
= jnp.atleast_1d(.squeeze())

Expand Down Expand Up @@ -948,18 +962,20 @@ def free_floating_bias_forces(
data.state.physics_model.joint_positions
)

data_rnea.state.physics_model.base_linear_velocity = (
data.state.physics_model.base_linear_velocity
)

data_rnea.state.physics_model.base_angular_velocity = (
data.state.physics_model.base_angular_velocity
)

data_rnea.state.physics_model.joint_velocities = (
data.state.physics_model.joint_velocities
)

# Make sure that base velocity is zero for fixed-base model.
if model.floating_base():
data_rnea.state.physics_model.base_linear_velocity = (
data.state.physics_model.base_linear_velocity
)

data_rnea.state.physics_model.base_angular_velocity = (
data.state.physics_model.base_angular_velocity
)

return jnp.hstack(
inverse_dynamics(
model=model,
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def system_velocity_dynamics(
lambda nc: (
jnp.vstack(
jnp.equal(
np.array(model.physics_model.gc.body, dtype=int), nc
jnp.array(model.physics_model.gc.body, dtype=int), nc
).astype(int)
)
* W_f_Ci
Expand Down
12 changes: 1 addition & 11 deletions src/jaxsim/integrators/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,20 +209,10 @@ def build(
The integrator object.
"""

# Adjust the shape of the tableau coefficients.
c = jnp.atleast_1d(cls.c.squeeze())
b = jnp.atleast_2d(jnp.vstack(cls.b.squeeze()))
A = jnp.atleast_2d(cls.A.squeeze())

# Check validity of the Butcher tableau.
if not ExplicitRungeKutta.butcher_tableau_is_valid(A=A, b=b, c=c):
if not ExplicitRungeKutta.butcher_tableau_is_valid(A=cls.A, b=cls.b, c=cls.c):
raise ValueError("The Butcher tableau of this class is not valid.")

# Store the adjusted shapes of the tableau coefficients.
cls.c = c
cls.b = b
cls.A = A

# Check that b.T has enough rows based on the configured index of the solution.
if cls.row_index_of_solution >= cls.b.T.shape[0]:
msg = "The index of the solution ({}-th row of `b.T`) is out of range ({})."
Expand Down
26 changes: 6 additions & 20 deletions src/jaxsim/integrators/fixed_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,32 +20,18 @@
@jax_dataclasses.pytree_dataclass
class ForwardEuler(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):

A: ClassVar[jax.typing.ArrayLike] = jnp.array(
[
[0],
]
).astype(float)
A: ClassVar[jax.typing.ArrayLike] = jnp.atleast_2d(0).astype(float)

b: ClassVar[jax.typing.ArrayLike] = (
jnp.array(
[
[1],
]
)
.astype(float)
.transpose()
)
b: ClassVar[jax.typing.ArrayLike] = jnp.atleast_2d(1).astype(float).transpose()

c: ClassVar[jax.typing.ArrayLike] = jnp.array(
[0],
).astype(float)
c: ClassVar[jax.typing.ArrayLike] = jnp.atleast_1d(0).astype(float)

row_index_of_solution: ClassVar[int] = 0
order_of_bT_rows: ClassVar[tuple[int, ...]] = (1,)


@jax_dataclasses.pytree_dataclass
class Heun(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
class Heun2(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):

A: ClassVar[jax.typing.ArrayLike] = jnp.array(
[
Expand Down Expand Up @@ -144,12 +130,12 @@ def post_process_state(


@jax_dataclasses.pytree_dataclass
class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, Heun[ODEState]):
class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, ForwardEuler[ODEState]):
pass


@jax_dataclasses.pytree_dataclass
class HeunSO3(ExplicitRungeKuttaSO3Mixin, Heun[ODEState]):
class Heun2SO3(ExplicitRungeKuttaSO3Mixin, Heun2[ODEState]):
pass


Expand Down
42 changes: 27 additions & 15 deletions src/jaxsim/physics/algos/aba.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> Tuple[Pass1Carry, None]:
i_X_λi = i_X_λi.at[i].set(i_X_λi_i)

# Propagate link velocity
vJ = S[i] * qd[ii] if qd.size != 0 else S[i] * 0
vJ = S[i] * qd[ii]

v_i = i_X_λi[i] @ v[λ[i]] + vJ
v = v.at[i].set(v_i)
Expand All @@ -134,10 +134,14 @@ def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> Tuple[Pass1Carry, None]:

return (i_X_λi, v, c, MA, pA, i_X_0), None

(i_X_λi, v, c, MA, pA, i_X_0), _ = jax.lax.scan(
f=loop_body_pass1,
init=pass_1_carry,
xs=np.arange(start=1, stop=model.NB),
(i_X_λi, v, c, MA, pA, i_X_0), _ = (
jax.lax.scan(
f=loop_body_pass1,
init=pass_1_carry,
xs=np.arange(start=1, stop=model.NB),
)
if model.NB > 1
else [(i_X_λi, v, c, MA, pA, i_X_0), None]
)

U = jnp.zeros_like(S)
Expand Down Expand Up @@ -166,7 +170,7 @@ def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> Tuple[Pass2Carry, None]:
d_i = S[i].T @ U[i]
d = d.at[i].set(d_i.squeeze())

u_i = tau[ii] - S[i].T @ pA[i] if tau.size != 0 else -S[i].T @ pA[i]
u_i = tau[ii] - S[i].T @ pA[i]
u = u.at[i].set(u_i.squeeze())

# Compute the articulated-body inertia and bias forces of this link
Expand Down Expand Up @@ -196,10 +200,14 @@ def propagate(

return (U, d, u, MA, pA), None

(U, d, u, MA, pA), _ = jax.lax.scan(
f=loop_body_pass2,
init=pass_2_carry,
xs=np.flip(np.arange(start=1, stop=model.NB)),
(U, d, u, MA, pA), _ = (
jax.lax.scan(
f=loop_body_pass2,
init=pass_2_carry,
xs=np.flip(np.arange(start=1, stop=model.NB)),
)
if model.NB > 1
else [(U, d, u, MA, pA), None]
)

if model.is_floating_base:
Expand All @@ -226,15 +234,19 @@ def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> Tuple[Pass3Carry, None]:
qdd_ii = (u[i] - U[i].T @ a_i) / d[i]
qdd = qdd.at[i - 1].set(qdd_ii.squeeze()) if qdd.size != 0 else qdd

a_i = a_i + S[i] * qdd[ii] if qdd.size != 0 else a_i
a_i = a_i + S[i] * qdd[ii]
a = a.at[i].set(a_i)

return (a, qdd), None

(a, qdd), _ = jax.lax.scan(
f=loop_body_pass3,
init=pass_3_carry,
xs=np.arange(1, model.NB),
(a, qdd), _ = (
jax.lax.scan(
f=loop_body_pass3,
init=pass_3_carry,
xs=np.arange(1, model.NB),
)
if model.NB > 1
else [(a, qdd), None]
)

# Handle 1 DoF models
Expand Down
Loading

0 comments on commit 795df2b

Please sign in to comment.