
ValueError when attempting to AOT-lower or remat functions that use jax_getattr #24314

Open
markblee opened this issue Oct 15, 2024 · 0 comments
Labels
bug Something isn't working

@markblee

Description

A simple repro:

```python
import jax
import jax.ad_checkpoint
from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies
from jax.experimental.attrs import jax_getattr, jax_setattr

class A:
    ...

a = A()
jax_setattr(a, 'x', 0)

def fn(c):
    jax_getattr(a, 'x')  # comment out to run successfully
    return c

# Each call below fails on its own. Run them separately: the first error
# stops the script before the second call is reached.

# ValueError: safe_zip() argument 2 is longer than argument 1
jax.jit(fn).lower(0.0)

# ValueError: too many values to unpack (expected 0)
remat_fn = jax.ad_checkpoint.remat(fn, policy=jax_remat_policies.everything_saveable)
remat_fn(0.0)
```
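Because the `jit(...).lower` error aborts the script before the remat call runs, the two failures are easier to compare with a small harness that runs every case in isolation. The sketch below is only an illustration: the helpers `try_case` and `remat_everything_saveable` and the control function `fn_without_getattr` are hypothetical names, not JAX APIs. It records the exception type and message for each case and pairs each failing call with the same function minus the `jax_getattr` line. On jax 0.4.34 the two getattr cases would be expected to report the ValueErrors quoted above; results on other JAX versions may differ.

```python
import jax
import jax.ad_checkpoint
from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies
from jax.experimental.attrs import jax_getattr, jax_setattr


class A:
    ...


a = A()
jax_setattr(a, "x", 0)


def fn_with_getattr(c):
    jax_getattr(a, "x")  # the read that triggers the reported errors
    return c


def fn_without_getattr(c):
    return c  # hypothetical control: identical except for the jax_getattr call


def try_case(label, thunk):
    """Hypothetical helper: run one case and return (label, outcome string)."""
    try:
        thunk()
        return label, "ok"
    except Exception as e:  # record the failure instead of aborting the script
        return label, f"{type(e).__name__}: {e}"


def remat_everything_saveable(f):
    """Hypothetical helper: wrap f with the same remat policy as the repro."""
    return jax.ad_checkpoint.remat(f, policy=jax_remat_policies.everything_saveable)


# Controls run before the failing cases in each pair.
results = dict([
    try_case("lower/control", lambda: jax.jit(fn_without_getattr).lower(0.0)),
    try_case("lower/getattr", lambda: jax.jit(fn_with_getattr).lower(0.0)),
    try_case("remat/control", lambda: remat_everything_saveable(fn_without_getattr)(0.0)),
    try_case("remat/getattr", lambda: remat_everything_saveable(fn_with_getattr)(0.0)),
])

for label, outcome in results.items():
    print(f"{label:15s} {outcome}")
```

If the control rows print `ok` while the getattr rows print a ValueError, that isolates the `jax_getattr` call as the trigger, consistent with the "comment out to run successfully" note in the repro.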

System info (python version, jaxlib version, accelerator, etc.)

python: 3.10.14
jax: 0.4.34
jaxlib: 0.4.34
accelerator: cpu
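JAX provides `jax.print_environment_info()` for gathering version and device details like those listed above. A minimal sketch, assuming a standard JAX install:

```python
import jax

# Collects the jax, jaxlib, numpy, and python versions plus device info;
# return_string=True returns the text instead of printing it directly.
info = jax.print_environment_info(return_string=True)
print(info)
```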
markblee added the bug label on Oct 15, 2024