Conversation
20b91ca to
9f878f1
Compare
6e718c6 to
2da826b
Compare
There was a problem hiding this comment.
This should probably go into the mujoco warp repo. It could also use a bit more thorough review as I'm not super fluent with the FFI semantics.
493b601 to
7f0f35f
Compare
|
This should be ready to be reviewed. I see some tests failing locally like this: Details$ pytest
============================= test session starts =============================
platform darwin -- Python 3.12.10, pytest-9.0.2, pluggy-1.6.0
rootdir: /mujoco/main/mjx
configfile: pyproject.toml
collected 718 items
mujoco/mjx/_src/collision_driver_test.py .............................. [ 4%]
....... [ 5%]
mujoco/mjx/_src/constraint_test.py .......... [ 6%]
mujoco/mjx/_src/dataclasses_test.py . [ 6%]
mujoco/mjx/_src/forward_test.py ......... [ 7%]
mujoco/mjx/_src/inverse_test.py ....... [ 8%]
mujoco/mjx/_src/io_test.py ............................................ [ 15%]
........................................................ [ 22%]
mujoco/mjx/_src/math_test.py .......................................... [ 28%]
....................................................................... [ 38%]
........................ [ 41%]
mujoco/mjx/_src/mesh_test.py ... [ 42%]
mujoco/mjx/_src/passive_test.py . [ 42%]
mujoco/mjx/_src/ray_test.py ........... [ 44%]
mujoco/mjx/_src/scan_test.py .... [ 44%]
mujoco/mjx/_src/sensor_test.py .......... [ 45%]
mujoco/mjx/_src/smooth_test.py ............................. [ 50%]
mujoco/mjx/_src/solver_test.py ............. [ 51%]
mujoco/mjx/_src/support_test.py .............. [ 53%]
mujoco/mjx/integration_test/collision_driver_test.py .................. [ 56%]
....................................................................... [ 66%]
....................................................................... [ 76%]
....................................................................... [ 85%]
......................... [ 89%]
mujoco/mjx/integration_test/forward_test.py ........................... [ 93%]
... [ 93%]
mujoco/mjx/integration_test/smooth_test.py ............................ [ 97%]
.. [ 97%]
mujoco/mjx/warp/collision_driver_test.py . [ 97%]
mujoco/mjx/warp/forward_test.py .F..FF..... [ 99%]
mujoco/mjx/warp/smooth_test.py .... [100%]
================================== FAILURES ===================================
__________________________ ForwardTest.test_forward1 __________________________
self = <mjx.warp.forward_test.ForwardTest testMethod=test_forward1>
xml = 'humanoid/humanoid.xml', batch_size = 7
@parameterized.product(
xml=(
'humanoid/humanoid.xml',
'pendula.xml',
),
batch_size=(1, 7),
)
def test_forward(self, xml: str, batch_size: int):
if not _FORCE_TEST:
if not mjxw.WARP_INSTALLED:
self.skipTest('Warp not installed.')
m = test_util.load_test_file(xml)
m.opt.iterations = 10
m.opt.ls_iterations = 10
m.opt.jacobian = mujoco.mjtJacobian.mjJAC_DENSE
mx = mjx.put_model(m, impl='warp')
d = mujoco.MjData(m)
worldids = jp.arange(batch_size)
dx_batch = jax.vmap(functools.partial(tu.make_data, m))(worldids)
dx_batch = jax.jit(jax.vmap(forward.forward, in_axes=(None, 0)))(
mx, dx_batch
)
for i in range(batch_size):
dx = dx_batch[i]
d.qpos[:] = dx.qpos
d.qvel[:] = dx.qvel
d.ctrl[:] = dx.ctrl
d.mocap_pos[:] = dx.mocap_pos
d.mocap_quat[:] = dx.mocap_quat
mujoco.mj_forward(m, d)
# fwd_position
tu.assert_attr_eq(dx, d, 'xpos')
tu.assert_attr_eq(dx, d, 'xquat')
tu.assert_attr_eq(dx, d, 'xipos')
tu.assert_eq(d.ximat.reshape((-1, 3, 3)), dx.ximat, 'ximat')
tu.assert_attr_eq(dx, d, 'xanchor')
tu.assert_attr_eq(dx, d, 'xaxis')
tu.assert_attr_eq(dx, d, 'geom_xpos')
tu.assert_eq(dx.geom_xmat, d.geom_xmat.reshape((-1, 3, 3)), 'geom_xmat')
if m.nsite:
tu.assert_attr_eq(dx, d, 'site_xpos')
tu.assert_eq(dx.site_xmat, d.site_xmat.reshape((-1, 3, 3)), 'site_xmat')
tu.assert_attr_eq(dx, d, 'cdof')
tu.assert_attr_eq(dx._impl, d, 'cinert')
tu.assert_attr_eq(dx, d, 'subtree_com')
if m.nlight:
tu.assert_attr_eq(dx._impl, d, 'light_xpos')
tu.assert_attr_eq(dx._impl, d, 'light_xdir')
if m.ncam:
tu.assert_attr_eq(dx, d, 'cam_xpos')
tu.assert_eq(dx.cam_xmat, d.cam_xmat.reshape((-1, 3, 3)), 'cam_xmat')
tu.assert_attr_eq(dx, d, 'ten_length')
tu.assert_attr_eq(dx._impl, d, 'ten_J')
tu.assert_attr_eq(dx._impl, d, 'ten_wrapadr')
tu.assert_attr_eq(dx._impl, d, 'ten_wrapnum')
tu.assert_attr_eq(dx._impl, d, 'wrap_xpos')
tu.assert_attr_eq(dx._impl, d, 'wrap_obj')
tu.assert_attr_eq(dx._impl, d, 'crb')
qm = np.zeros((m.nv, m.nv))
mujoco.mj_fullM(m, qm, d.qM)
# mjwarp adds padding to qM
tu.assert_eq(qm, dx._impl.qM[: m.nv, : m.nv], 'qM')
# qLD is fused in a cholesky factorize and solve, and not written to.
tu.assert_contact_eq(d, dx, worldid=i)
tu.assert_attr_eq(dx, d, 'actuator_length')
actuator_moment = np.zeros((m.nu, m.nv))
mujoco.mju_sparse2dense(
actuator_moment,
d.actuator_moment,
d.moment_rownnz,
d.moment_rowadr,
d.moment_colind,
)
tu.assert_eq(dx._impl.actuator_moment, actuator_moment, 'actuator_moment')
# fwd_velocity
tu.assert_attr_eq(dx._impl, d, 'actuator_velocity')
tu.assert_attr_eq(dx, d, 'cvel')
tu.assert_attr_eq(dx, d, 'cdof_dot')
tu.assert_attr_eq(dx._impl, d, 'qfrc_spring')
tu.assert_attr_eq(dx._impl, d, 'qfrc_damper')
tu.assert_attr_eq(dx, d, 'qfrc_gravcomp')
tu.assert_attr_eq(dx, d, 'qfrc_fluid')
tu.assert_attr_eq(dx, d, 'qfrc_passive')
tu.assert_attr_eq(dx, d, 'qfrc_bias')
> tu.assert_efc_eq(d, dx, worldid=i)
mujoco/mjx/warp/forward_test.py:179:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
mujoco/mjx/warp/test_util.py:199: in assert_efc_eq
assert_eq(jp_, j, 'J')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Array([[-0.9875637 , 0.15721938, 1. , -0.00917677, 1.335037 ,
-0.10656909, -0.10577767, 1.1439778 ... -0.03383656, 0. , 0. , 0. , 0. ,
0. , 0. ]], dtype=float32)
b = array([[-0.9875637 , 0.15721938, 1. , -0.00917675, 1.33503689,
-0.10656907, -0.10577765, 1.14397789....07388643,
-0.0338367 , 0. , 0. , 0. , 0. ,
0. , 0. ]])
name = 'J'
def assert_eq(a, b, name):
tol = _TOLERANCE * 10 # avoid test noise
err_msg = f'mismatch: {name}'
> np.testing.assert_allclose(a, b, err_msg=err_msg, atol=tol, rtol=tol)
E AssertionError:
E Not equal to tolerance rtol=0.0005, atol=0.0005
E mismatch: J
E Mismatched elements: 160 / 432 (37%)
E First 5 mismatches are at indices:
E [4, 0]: -0.9871430993080139 (ACTUAL), -0.999479684061378 (DESIRED)
E [4, 1]: 0.15983904898166656 (ACTUAL), -0.03225462987801541 (DESIRED)
E [4, 3]: 0.16618438065052032 (ACTUAL), -0.22991350005486522 (DESIRED)
E [4, 4]: 1.2810968160629272 (ACTUAL), 1.3359474985903166 (DESIRED)
E [4, 5]: 0.044975683093070984 (ACTUAL), -0.0601160972922109 (DESIRED)
E Max absolute difference among violations: 0.87907806
E Max relative difference among violations: 12.74774404
E ACTUAL: array([[-0.987564, 0.157219, 1. , -0.009177, 1.335037, -0.106569,
E -0.105778, 1.143978, -0.053232, 0.039397, -0.004723, 0.880073,
E -0.497201, 0.173519, -0.011014, 0. , 0. , 0. ,...
E DESIRED: array([[-0.987564, 0.157219, 1. , -0.009177, 1.335037, -0.106569,
E -0.105778, 1.143978, -0.053232, 0.039397, -0.004723, 0.880073,
E -0.497201, 0.173519, -0.011014, 0. , 0. , 0. ,...
mujoco/mjx/warp/test_util.py:36: AssertionError
________________________ ForwardTest.test_jit_caching0 ________________________
self = <mjx.warp.forward_test.ForwardTest testMethod=test_jit_caching0>
xml = 'pendula.xml'
@parameterized.parameters(
'pendula.xml',
'humanoid/humanoid.xml',
)
def test_jit_caching(self, xml):
"""Tests jit caching on the full step function."""
if not _FORCE_TEST:
if not mjxw.WARP_INSTALLED:
self.skipTest('Warp not installed.')
batch_size = 7
m = test_util.load_test_file(xml)
mx = mjx.put_model(m, impl='warp')
keys = jp.arange(batch_size)
dx_batch = jax.vmap(functools.partial(tu.make_data, m))(keys)
step_fn = jax.jit(jax.vmap(forward.step, in_axes=(None, 0)))
dx_batch1 = step_fn(mx, dx_batch)
jax.tree_util.tree_map(lambda x: x.block_until_ready(), dx_batch1)
> self.assertEqual(step_fn._cache_size(), 1)
E AssertionError: 0 != 1
mujoco/mjx/warp/forward_test.py:76: AssertionError
---------------------------- Captured stdout call -----------------------------
Module mul_m_sparse_diag__locals___mul_m_sparse_diag_1de23634 4949128 load on device 'cpu' took 195.10 ms (compiled)
Module mul_m_sparse_ij__locals___mul_m_sparse_ij_b1bb5fb3 8b21b7f load on device 'cpu' took 204.99 ms (compiled)
Module update_gradient_JTDAJ_sparse_tiled__locals__kernel_8f59ead1 6963331 load on device 'cpu' took 269.85 ms (compiled)
________________________ ForwardTest.test_jit_caching1 ________________________
self = <mjx.warp.forward_test.ForwardTest testMethod=test_jit_caching1>
xml = 'humanoid/humanoid.xml'
@parameterized.parameters(
'pendula.xml',
'humanoid/humanoid.xml',
)
def test_jit_caching(self, xml):
"""Tests jit caching on the full step function."""
if not _FORCE_TEST:
if not mjxw.WARP_INSTALLED:
self.skipTest('Warp not installed.')
batch_size = 7
m = test_util.load_test_file(xml)
mx = mjx.put_model(m, impl='warp')
keys = jp.arange(batch_size)
dx_batch = jax.vmap(functools.partial(tu.make_data, m))(keys)
step_fn = jax.jit(jax.vmap(forward.step, in_axes=(None, 0)))
dx_batch1 = step_fn(mx, dx_batch)
jax.tree_util.tree_map(lambda x: x.block_until_ready(), dx_batch1)
> self.assertEqual(step_fn._cache_size(), 1)
E AssertionError: 0 != 1
mujoco/mjx/warp/forward_test.py:76: AssertionError
============================== warnings summary ===============================
mujoco/mjx/_src/collision_driver_test.py::ConvexTest::test_box_box
mujoco/mjx/_src/collision_driver_test.py::ConvexTest::test_box_box_edge
mujoco/mjx/_src/collision_driver_test.py::ConvexTest::test_convex_convex
mujoco/mjx/_src/collision_driver_test.py::ConvexTest::test_convex_convex_edge
mujoco/mjx/_src/collision_driver_test.py::HFieldTest::test_hfield_deep
mujoco/mjx/_src/support_test.py::SupportTest::test_bind
mujoco/mjx/integration_test/smooth_test.py::TransmissionIntegrationTest::test_transmission14
/mujoco/main/.venv/lib/python3.12/site-packages/jax/_src/abstract_arrays.py:135: RuntimeWarning: overflow encountered in cast
return literals.TypedNdArray(np.asarray(x, dtype), weak_type=False)
mujoco/mjx/_src/forward_test.py::ActuatorTest::test_actuator2
mujoco/mjx/_src/io_test.py::DataIOTest::test_qm_mapm2m0
mujoco/mjx/_src/passive_test.py::PassiveTest::test_passive
mujoco/mjx/integration_test/collision_driver_test.py::CollisionDriverIntegrationTest::test_collision_driver0
mujoco/mjx/integration_test/collision_driver_test.py::CollisionDriverIntegrationTest::test_collision_driver115
mujoco/mjx/integration_test/collision_driver_test.py::CollisionDriverIntegrationTest::test_collision_driver125
/mujoco/main/.venv/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py:2412: DeprecationWarning: Passing arguments 'a', 'a_min' or 'a_max' to jax.numpy.clip is deprecated. Please use 'arr', 'min' or 'max' respectively instead.
ans_pytree = fun(*args, **kwargs)
mujoco/mjx/_src/io_test.py::DataIOTest::test_make_data_warp
/mujoco/main/mjx/mujoco/mjx/_src/io_test.py:468: DeprecationWarning: nconmax will be deprecated in mujoco-mjx>=3.5. Use naconmax instead.
d = mjx.make_data(m, impl='warp', nconmax=9, njmax=23)
mujoco/mjx/integration_test/collision_driver_test.py: 490 warnings
/mujoco/main/mjx/mujoco/mjx/integration_test/collision_driver_test.py:80: DeprecationWarning: Accessing `contact` directly from `Data` is deprecated. Access it via `data._impl.contact` instead.
idx_mjx = list(zip(dx.contact.geom1, dx.contact.geom2))
mujoco/mjx/integration_test/collision_driver_test.py: 245 warnings
/mujoco/main/mjx/mujoco/mjx/integration_test/collision_driver_test.py:86: DeprecationWarning: Accessing `contact` directly from `Data` is deprecated. Access it via `data._impl.contact` instead.
lambda x: x.take(np.array(idx), axis=0), dx.contact
mujoco/mjx/integration_test/collision_driver_test.py: 11 warnings
/mujoco/main/mjx/mujoco/mjx/integration_test/collision_driver_test.py:75: DeprecationWarning: Accessing `contact` directly from `Data` is deprecated. Access it via `data._impl.contact` instead.
self.assertTrue((dx.contact.dist > 0).all())
mujoco/mjx/integration_test/smooth_test.py: 30 warnings
/mujoco/main/mjx/mujoco/mjx/integration_test/smooth_test.py:86: DeprecationWarning: Accessing `actuator_moment` directly from `Data` is deprecated. Access it via `data._impl.actuator_moment` instead.
dx.actuator_moment,
mujoco/mjx/warp/forward_test.py: 11 warnings
mujoco/mjx/warp/smooth_test.py: 3 warnings
/mujoco/main/mjx/mujoco/mjx/warp/test_util.py:47: DeprecationWarning: nconmax will be deprecated in mujoco-mjx>=3.5. Use naconmax instead.
dx = mjx.make_data(m, impl='warp', nconmax=nconmax, njmax=njmax)
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================== short test summary info ===========================
FAILED mujoco/mjx/warp/forward_test.py::ForwardTest::test_forward1 - AssertionError:
FAILED mujoco/mjx/warp/forward_test.py::ForwardTest::test_jit_caching0 - AssertionError: 0 != 1
FAILED mujoco/mjx/warp/forward_test.py::ForwardTest::test_jit_caching1 - AssertionError: 0 != 1
=========== 3 failed, 715 passed, 804 warnings in 357.33s (0:05:57) ===========
WARNING: Nan, Inf or huge value in QACC at DOF 0. The simulation is unstable. Time = 0.0000.
WARNING: Nan, Inf or huge value in QACC at DOF 9. The simulation is unstable. Time = 0.0680.Weirdly enough, the two cache tests don't fail if I run the Detailspytest ./mjx/mujoco/mjx/warp/forward_test.py
================================ test session starts ================================
platform darwin -- Python 3.12.10, pytest-9.0.2, pluggy-1.6.0
rootdir: /mujoco/mjx
configfile: pyproject.toml
collected 11 items
mjx/mujoco/mjx/warp/forward_test.py .F......... [100%]
===================================== FAILURES ======================================
_____________________________ ForwardTest.test_forward1 _____________________________
self = <mjx.warp.forward_test.ForwardTest testMethod=test_forward1>
xml = 'humanoid/humanoid.xml', batch_size = 7
@parameterized.product(
xml=(
'humanoid/humanoid.xml',
'pendula.xml',
),
batch_size=(1, 7),
)
def test_forward(self, xml: str, batch_size: int):
if not _FORCE_TEST:
if not mjxw.WARP_INSTALLED:
self.skipTest('Warp not installed.')
m = test_util.load_test_file(xml)
m.opt.iterations = 10
m.opt.ls_iterations = 10
m.opt.jacobian = mujoco.mjtJacobian.mjJAC_DENSE
mx = mjx.put_model(m, impl='warp')
d = mujoco.MjData(m)
worldids = jp.arange(batch_size)
dx_batch = jax.vmap(functools.partial(tu.make_data, m))(worldids)
dx_batch = jax.jit(jax.vmap(forward.forward, in_axes=(None, 0)))(
mx, dx_batch
)
for i in range(batch_size):
dx = dx_batch[i]
d.qpos[:] = dx.qpos
d.qvel[:] = dx.qvel
d.ctrl[:] = dx.ctrl
d.mocap_pos[:] = dx.mocap_pos
d.mocap_quat[:] = dx.mocap_quat
mujoco.mj_forward(m, d)
# fwd_position
tu.assert_attr_eq(dx, d, 'xpos')
tu.assert_attr_eq(dx, d, 'xquat')
tu.assert_attr_eq(dx, d, 'xipos')
tu.assert_eq(d.ximat.reshape((-1, 3, 3)), dx.ximat, 'ximat')
tu.assert_attr_eq(dx, d, 'xanchor')
tu.assert_attr_eq(dx, d, 'xaxis')
tu.assert_attr_eq(dx, d, 'geom_xpos')
tu.assert_eq(dx.geom_xmat, d.geom_xmat.reshape((-1, 3, 3)), 'geom_xmat')
if m.nsite:
tu.assert_attr_eq(dx, d, 'site_xpos')
tu.assert_eq(dx.site_xmat, d.site_xmat.reshape((-1, 3, 3)), 'site_xmat')
tu.assert_attr_eq(dx, d, 'cdof')
tu.assert_attr_eq(dx._impl, d, 'cinert')
tu.assert_attr_eq(dx, d, 'subtree_com')
if m.nlight:
tu.assert_attr_eq(dx._impl, d, 'light_xpos')
tu.assert_attr_eq(dx._impl, d, 'light_xdir')
if m.ncam:
tu.assert_attr_eq(dx, d, 'cam_xpos')
tu.assert_eq(dx.cam_xmat, d.cam_xmat.reshape((-1, 3, 3)), 'cam_xmat')
tu.assert_attr_eq(dx, d, 'ten_length')
tu.assert_attr_eq(dx._impl, d, 'ten_J')
tu.assert_attr_eq(dx._impl, d, 'ten_wrapadr')
tu.assert_attr_eq(dx._impl, d, 'ten_wrapnum')
tu.assert_attr_eq(dx._impl, d, 'wrap_xpos')
tu.assert_attr_eq(dx._impl, d, 'wrap_obj')
tu.assert_attr_eq(dx._impl, d, 'crb')
qm = np.zeros((m.nv, m.nv))
mujoco.mj_fullM(m, qm, d.qM)
# mjwarp adds padding to qM
tu.assert_eq(qm, dx._impl.qM[: m.nv, : m.nv], 'qM')
# qLD is fused in a cholesky factorize and solve, and not written to.
tu.assert_contact_eq(d, dx, worldid=i)
tu.assert_attr_eq(dx, d, 'actuator_length')
actuator_moment = np.zeros((m.nu, m.nv))
mujoco.mju_sparse2dense(
actuator_moment,
d.actuator_moment,
d.moment_rownnz,
d.moment_rowadr,
d.moment_colind,
)
tu.assert_eq(dx._impl.actuator_moment, actuator_moment, 'actuator_moment')
# fwd_velocity
tu.assert_attr_eq(dx._impl, d, 'actuator_velocity')
tu.assert_attr_eq(dx, d, 'cvel')
tu.assert_attr_eq(dx, d, 'cdof_dot')
tu.assert_attr_eq(dx._impl, d, 'qfrc_spring')
tu.assert_attr_eq(dx._impl, d, 'qfrc_damper')
tu.assert_attr_eq(dx, d, 'qfrc_gravcomp')
tu.assert_attr_eq(dx, d, 'qfrc_fluid')
tu.assert_attr_eq(dx, d, 'qfrc_passive')
tu.assert_attr_eq(dx, d, 'qfrc_bias')
# NOTE(user): This fails due to some weird sorting of keys.
> tu.assert_efc_eq(d, dx, worldid=i)
mjx/mujoco/mjx/warp/forward_test.py:179:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
mjx/mujoco/mjx/warp/test_util.py:199: in assert_efc_eq
assert_eq(jp_, j, 'J')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Array([[-0.9875637 , 0.15721938, 1. , -0.00917677, 1.335037 ,
-0.10656909, -0.10577767, 1.1439778 ... -0.03383656, 0. , 0. , 0. , 0. ,
0. , 0. ]], dtype=float32)
b = array([[-0.9875637 , 0.15721938, 1. , -0.00917675, 1.33503689,
-0.10656907, -0.10577765, 1.14397789....07388643,
-0.0338367 , 0. , 0. , 0. , 0. ,
0. , 0. ]])
name = 'J'
def assert_eq(a, b, name):
tol = _TOLERANCE * 10 # avoid test noise
err_msg = f'mismatch: {name}'
> np.testing.assert_allclose(a, b, err_msg=err_msg, atol=tol, rtol=tol)
E AssertionError:
E Not equal to tolerance rtol=0.0005, atol=0.0005
E mismatch: J
E Mismatched elements: 160 / 432 (37%)
E First 5 mismatches are at indices:
E [4, 0]: -0.9871430993080139 (ACTUAL), -0.999479684061378 (DESIRED)
E [4, 1]: 0.15983904898166656 (ACTUAL), -0.03225462987801541 (DESIRED)
E [4, 3]: 0.16618438065052032 (ACTUAL), -0.22991350005486522 (DESIRED)
E [4, 4]: 1.2810968160629272 (ACTUAL), 1.3359474985903166 (DESIRED)
E [4, 5]: 0.044975683093070984 (ACTUAL), -0.0601160972922109 (DESIRED)
E Max absolute difference among violations: 0.87907806
E Max relative difference among violations: 12.74774404
E ACTUAL: array([[-0.987564, 0.157219, 1. , -0.009177, 1.335037, -0.106569,
E -0.105778, 1.143978, -0.053232, 0.039397, -0.004723, 0.880073,
E -0.497201, 0.173519, -0.011014, 0. , 0. , 0. ,...
E DESIRED: array([[-0.987564, 0.157219, 1. , -0.009177, 1.335037, -0.106569,
E -0.105778, 1.143978, -0.053232, 0.039397, -0.004723, 0.880073,
E -0.497201, 0.173519, -0.011014, 0. , 0. , 0. ,...
mjx/mujoco/mjx/warp/test_util.py:36: AssertionError
================================= warnings summary ==================================
mujoco/mjx/warp/forward_test.py: 13 warnings
/mujoco/mjx/mujoco/mjx/warp/test_util.py:47: DeprecationWarning: nconmax will be deprecated in mujoco-mjx>=3.5. Use naconmax instead.
dx = mjx.make_data(m, impl='warp', nconmax=nconmax, njmax=njmax)
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================== short test summary info ==============================
FAILED mjx/mujoco/mjx/warp/forward_test.py::ForwardTest::test_forward1 - AssertionError:
==================== 1 failed, 10 passed, 13 warnings in 28.00s =====================I tracked the numerical errors down to mujoco/mjx/mujoco/mjx/warp/test_util.py Line 156 in 07d7bc9 Changing this line to the following fixes the issue: keys_sorted = np.lexsort((-np.round(efc_pos, 4), efc_type, np.round(efc_d, 4)))I'm not sure if the |
7f0f35f to
965bef3
Compare
9fd58de to
732877b
Compare
| efc_type = select(dx._impl.efc__type)[:nefc] | ||
| efc_d = select(dx._impl.efc__D)[:nefc] | ||
| keys_sorted = np.lexsort((-efc_pos, efc_type, efc_d)) | ||
| keys_sorted = np.lexsort((np.round(-efc_pos, 12), efc_type, np.round(efc_d, 12))) |
There was a problem hiding this comment.
Note these rounding changes. Let me know if you want these to be in a separate PR.
877f79a to
f88be6e
Compare
ac6152b to
0b3f62b
Compare
0b3f62b to
05c7453
Compare
05c7453 to
237cb19
Compare
This is not ready yet but just demonstrates that it should be possible to enable warp backend on MacOS. ```console $ uv pip install -U --extra-index-url="https://py.mujoco.org" "mujoco>=3.7.0.dev0,<3.8.0" && \ uv pip install -U warp-lang && \ uv pip install -U -e ./mjx /Users/google-deepmind/mujoco/main/mjx/mujoco/mjx/_src/io_test.py:472: DeprecationWarning: nconmax will be deprecated in mujoco-mjx>=3.5. Use naconmax instead. d = mjx.make_data(m, impl='warp', nconmax=9, njmax=23) Warp 1.12.0 initialized: CUDA not enabled in this build Devices: "cpu" : "arm" Kernel cache: /private/var/folders/ml/rrlg98ln26l7xvxgq_yqfn4c0000gn/T/tmpoklv_7xg Warp DeprecationWarning: The symbol `warp.types.warp_type_to_np_dtype` will soon be removed from the public API. It can still be accessed from `warp._src.types.warp_type_to_np_dtype` but might be changed or removed without notice. ./opt/homebrew/Cellar/python@3.12/3.12.10/Frameworks/Python.framework/Versions/3.12/lib/python3.12/tempfile.py:940: ResourceWarning: Implicitly cleaning up <TemporaryDirectory '/var/folders/ml/rrlg98ln26l7xvxgq_yqfn4c0000gn/T/tmpoklv_7xg'> _warnings.warn(warn_message, ResourceWarning) ---------------------------------------------------------------------- Ran 1 test in 1.105s OK ```
When using warp on CPU, there are small numerical differences that cause the `np.lexsort` to return inconsistently-ordered `efc-pos` and `efc_D`. This changes the sorting to round the values, fixing the flakiness.
237cb19 to
3b9919c
Compare
See #2947.
With the existing changes and
warp-lang>=1.11.0, we can now enable warp backend for mujoco on MacOS. More generally, mujoco warp now also works on CPU and not just cuda.