-
From what I understand, only environments implemented in JAX can be JIT-compiled together with a JAX policy.
Replies: 1 comment 1 reply
-
Hi @uduse, the XLA interface here is implemented via a custom call: any C++ code (including the EnvPool code) can be wrapped in a custom call and jitted into the JAX computation graph. GPU device support is faked with an extra call that copies the memory to the GPU device. Our main reference for the implementation is this post: https://dfm.io/posts/extending-jax/
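
To make the idea concrete, here is a rough sketch of non-JAX code sitting inside a jitted function. It is not EnvPool's implementation: it swaps the C++ XLA custom call for `jax.pure_callback`, a host callback that is a different mechanism but shows the same pattern (opaque host code embedded in the compiled graph, with results copied back to the device). `host_step` and `policy_and_step` are hypothetical names made up for illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp


def host_step(action):
    # Hypothetical stand-in for a non-JAX environment step.
    # Runs on the host with plain NumPy, outside of XLA.
    action = np.asarray(action)
    return (action * 2.0 + 1.0).astype(np.float32)


@jax.jit
def policy_and_step(obs):
    # Toy "policy" written in JAX.
    action = jnp.tanh(obs)
    # Declare the shape and dtype of the host function's result so that
    # JAX can trace and compile around the opaque call.
    out_spec = jax.ShapeDtypeStruct(action.shape, jnp.float32)
    # The host function becomes a node in the jitted computation graph;
    # inputs and outputs are copied between device and host as needed.
    return jax.pure_callback(host_step, out_spec, action)


obs = jnp.zeros(3, dtype=jnp.float32)
result = policy_and_step(obs)
print(result)
```

A real custom call avoids the Python round trip by registering a compiled C++ function with XLA, which is what the linked post walks through.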