Skip to content

Commit

Permalink
Fix qvel bind in MJX when using multiple joints.
Browse files Browse the repository at this point in the history
Reported by #2402.

PiperOrigin-RevId: 725237347
Change-Id: I279c73e5cf2b959086292b739972afe5f70ceffb
  • Loading branch information
quagla authored and copybara-github committed Feb 10, 2025
1 parent 4156350 commit 94115ec
Showing 1 changed file with 13 additions and 15 deletions.
28 changes: 13 additions & 15 deletions mjx/mujoco/mjx/_src/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,29 +438,28 @@ def __getname(self, name: str):
return name
else:
raise AttributeError('ctrl is not available for this type')
if name == 'qpos':
if name == 'qpos' or name == 'qvel':
if self.prefix == 'jnt_':
return name
else:
raise AttributeError('qpos is not available for this type')
raise AttributeError('qpos and qvel are not available for this type')
else:
return self.prefix + name

def __getattr__(self, name: str):
if name == 'sensordata' or name == 'qpos':
if name in ('sensordata', 'qpos', 'qvel'):
adr = num = 0
if name == 'sensordata':
adr = self.model.sensor_adr[self.id]
num = self.model.sensor_dim[self.id]
elif name == 'qpos':
adr = self.model.jnt_qposadr[self.id]
typ = self.model.jnt_type[self.id]
num = (
(typ == JointType.FREE) * JointType.FREE.qpos_width()
+ (typ == JointType.BALL) * JointType.BALL.qpos_width()
+ (typ == JointType.HINGE) * JointType.HINGE.qpos_width()
+ (typ == JointType.SLIDE) * JointType.SLIDE.qpos_width()
)
num = sum((typ == jt) * jt.qpos_width() for jt in JointType)
elif name == 'qvel':
adr = self.model.jnt_dofadr[self.id]
typ = self.model.jnt_type[self.id]
num = sum((typ == jt) * jt.dof_width() for jt in JointType)
if isinstance(self.id, list):
idx = []
for a, n in zip(adr, num):
Expand All @@ -484,12 +483,11 @@ def set(self, name: str, value: jax.Array) -> Data:
if name == 'qpos':
adr = self.model.jnt_qposadr[self.id]
typ = self.model.jnt_type[self.id]
num = (
(typ == JointType.FREE) * JointType.FREE.qpos_width()
+ (typ == JointType.BALL) * JointType.BALL.qpos_width()
+ (typ == JointType.HINGE) * JointType.HINGE.qpos_width()
+ (typ == JointType.SLIDE) * JointType.SLIDE.qpos_width()
)
num = sum((typ == jt) * jt.qpos_width() for jt in JointType)
elif name == 'qvel':
adr = self.model.jnt_dofadr[self.id]
typ = self.model.jnt_type[self.id]
num = sum((typ == jt) * jt.dof_width() for jt in JointType)
elif isinstance(self.id, list):
adr = self.id
num = [1 for _ in range(len(self.id))]
Expand Down

0 comments on commit 94115ec

Please sign in to comment.