filter_pmap traces function twice #932

Open
garymm opened this issue Jan 9, 2025 · 4 comments

@garymm
Contributor

garymm commented Jan 9, 2025

```python
import equinox as eqx
import jax


def f(x):
    return x + 1


# Raise an error if `f` is traced more than once.
f_amt = eqx.debug.assert_max_traces(f, max_traces=1)

f_amt_pmap = eqx.filter_pmap(f_amt, axis_name="device", devices=jax.local_devices())

# A leading axis of size 1 assumes a single local device.
f_amt_pmap(jax.numpy.ones((1, 1)))
```

```
RuntimeError: <function f at 0x107a1a840> can only be traced 1 times. However, it is has now been traced 2 times. Could not determine argument was responsible for re-tracing.
```

Not sure if this is WAI (working as intended), but I was adding pmap to some code, was surprised to see this, and it doesn't seem to be documented. Happy to contribute a fix or update the docs if you can give me any pointers.

@patrick-kidger
Owner

So this is expected, as filter_pmap actually uses an additional vmap call under the hood:

fun_abstract = jax.vmap(

in order to resolve the output axes correctly.
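To see why a trace counter fires twice, here is a minimal sketch in plain JAX. It is not Equinox's actual implementation: it assumes the extra step amounts to tracing `jax.vmap(f)` abstractly (here via `jax.eval_shape`, which is an assumption) before the real `jax.pmap` call. The global `trace_count` is a hypothetical stand-in for `eqx.debug.assert_max_traces`.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for assert_max_traces: incremented only when JAX
# traces `f` (Python runs the body), not when compiled code executes.
trace_count = 0


def f(x):
    global trace_count
    trace_count += 1
    return x + 1


# pmap needs one leading entry per local device.
n = jax.local_device_count()
x = jnp.ones((n, 1))

# Trace 1: an abstract vmap over the leading axis. eval_shape computes only
# output shapes/dtypes, similar in spirit to resolving the output axes.
out_shape = jax.eval_shape(jax.vmap(f), x)

# Trace 2: the actual pmap, which traces `f` again from scratch.
y = jax.pmap(f, axis_name="device")(x)

print(trace_count)
```

Each transformation traces the Python function independently, so a counter like this (or `assert_max_traces` with `max_traces=1`) sees two traces even though the pmapped function was invoked only once.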

I can see the interaction with eqx.debug.assert_max_traces is annoying though. I don't think I have a good solution to that. FWIW I think eqx.filter_pmap should anyway be considered semi-deprecated, much like jax.pmap -- it's largely been superseded by using JIT with sharding.
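As a rough illustration of the "JIT with sharding" style mentioned above, here is a minimal sketch using only core JAX (`jax.sharding.Mesh`, `NamedSharding`, `PartitionSpec`). The mesh axis name `"device"` simply mirrors the pmap example, and `trace_count` is again a hypothetical stand-in for `eqx.debug.assert_max_traces`. It covers only simple data parallelism; code that relies on per-device collectives over an `axis_name` may need more than this (e.g. shard_map).

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical stand-in for assert_max_traces.
trace_count = 0


def f(x):
    global trace_count
    trace_count += 1
    return x + 1


# A 1-D mesh over all local devices, with one named axis.
mesh = Mesh(np.array(jax.local_devices()), axis_names=("device",))
# Split the leading array axis across the "device" mesh axis.
sharding = NamedSharding(mesh, P("device"))

n = jax.local_device_count()
x = jax.device_put(jnp.ones((n, 1)), sharding)

f_jit = jax.jit(f)
y = f_jit(x)  # traces and compiles once
y = f_jit(x)  # same shapes and shardings: served from jit's cache
```

Since both calls see the same shapes, dtypes and shardings, `jax.jit` reuses its cached trace and `trace_count` should stay at 1; there is no separate vmap pass.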

@garymm
Contributor Author

garymm commented Jan 10, 2025

Yeah, I realize pmap is the old way, but I have a use case where I'm not convinced JIT with sharding can do what I want (or at least I haven't figured out how to do it yet).

@garymm
Contributor Author

garymm commented Jan 10, 2025

Would it be worth updating the docs of filter_pmap to note this is expected?

@patrick-kidger
Owner

It's a pretty edge case. Unless this comes up a few more times, I think the existence of this issue is probably enough :)
