From 94115ecf0888213df3574c5c7d6c15d68b9a0946 Mon Sep 17 00:00:00 2001 From: Alessio Quaglino Date: Mon, 10 Feb 2025 09:08:24 -0800 Subject: [PATCH] Fix qvel bind in MJX when using multiple joints. Reported by #2402. PiperOrigin-RevId: 725237347 Change-Id: I279c73e5cf2b959086292b739972afe5f70ceffb --- mjx/mujoco/mjx/_src/support.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/mjx/mujoco/mjx/_src/support.py b/mjx/mujoco/mjx/_src/support.py index 9556c930c7..40bec2a93e 100644 --- a/mjx/mujoco/mjx/_src/support.py +++ b/mjx/mujoco/mjx/_src/support.py @@ -438,16 +438,16 @@ def __getname(self, name: str): return name else: raise AttributeError('ctrl is not available for this type') - if name == 'qpos': + if name == 'qpos' or name == 'qvel': if self.prefix == 'jnt_': return name else: - raise AttributeError('qpos is not available for this type') + raise AttributeError('qpos and qvel are not available for this type') else: return self.prefix + name def __getattr__(self, name: str): - if name == 'sensordata' or name == 'qpos': + if name in ('sensordata', 'qpos', 'qvel'): adr = num = 0 if name == 'sensordata': adr = self.model.sensor_adr[self.id] @@ -455,12 +455,11 @@ def __getattr__(self, name: str): elif name == 'qpos': adr = self.model.jnt_qposadr[self.id] typ = self.model.jnt_type[self.id] - num = ( - (typ == JointType.FREE) * JointType.FREE.qpos_width() - + (typ == JointType.BALL) * JointType.BALL.qpos_width() - + (typ == JointType.HINGE) * JointType.HINGE.qpos_width() - + (typ == JointType.SLIDE) * JointType.SLIDE.qpos_width() - ) + num = sum((typ == jt) * jt.qpos_width() for jt in JointType) + elif name == 'qvel': + adr = self.model.jnt_dofadr[self.id] + typ = self.model.jnt_type[self.id] + num = sum((typ == jt) * jt.dof_width() for jt in JointType) if isinstance(self.id, list): idx = [] for a, n in zip(adr, num): @@ -484,12 +483,11 @@ def set(self, name: str, value: jax.Array) -> Data: if name == 'qpos': adr = self.model.jnt_qposadr[self.id] typ = self.model.jnt_type[self.id] - num = ( - (typ == JointType.FREE) * JointType.FREE.qpos_width() - + (typ == JointType.BALL) * JointType.BALL.qpos_width() - + (typ == JointType.HINGE) * JointType.HINGE.qpos_width() - + (typ == JointType.SLIDE) * JointType.SLIDE.qpos_width() - ) + num = sum((typ == jt) * jt.qpos_width() for jt in JointType) + elif name == 'qvel': + adr = self.model.jnt_dofadr[self.id] + typ = self.model.jnt_type[self.id] + num = sum((typ == jt) * jt.dof_width() for jt in JointType) elif isinstance(self.id, list): adr = self.id num = [1 for _ in range(len(self.id))]