Was wondering if there is an easy and minimal way to reproduce this behaviour with our own config parameters:

```python
import jax
from functools import partial
from jax import numpy as jnp

@partial(jax.jit, static_argnums=(0,))
def f(shape):
    print(f"jitted for shape {shape}")
    return jnp.ones(shape)

jax.config.update("jax_enable_x64", False)
shape = (2, 2, 2)
f(shape)
# Out:
# jitted for shape (2, 2, 2)
# Array([[[1., 1.],
#         [1., 1.]],
#
#        [[1., 1.],
#         [1., 1.]]], dtype=float32)

jax.config.update("jax_enable_x64", True)
f(shape)
# Out:
# jitted for shape (2, 2, 2)
# Array([[[1., 1.],
#         [1., 1.]],
#
#        [[1., 1.],
#         [1., 1.]]], dtype=float64)
```

Changing `jax_enable_x64` causes `f` to be retraced (the message is printed a second time). Currently, if I change one of our own global config parameters, the function is not re-jitted.

Thank you!
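A minimal sketch of the behaviour described in the question, assuming the user's own config parameter is a plain module-level Python global (`SCALE`, `g`, and `trace_count` are hypothetical names). The global is read once, while tracing, and baked into the compiled computation; it is not part of `jax.jit`'s cache key, so changing it later does not trigger a retrace.

```python
import jax
import jax.numpy as jnp

# Hypothetical user-level "config" parameter stored as a plain module global.
SCALE = 1.0
trace_count = 0

@jax.jit
def g(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces g
    # SCALE is read as a Python constant during tracing and baked in.
    return x * SCALE

x = jnp.ones(3)
first = g(x)   # traces and compiles with SCALE == 1.0
SCALE = 2.0
second = g(x)  # cache hit: no retrace, so this still multiplies by 1.0
```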
Answered by jakevdp on Oct 21, 2024
Answer selected by ASKabalan
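Hi - thanks for the question! The way that this works for configurations like `enable_x64` is via a function in `jax/_src/config.py` of the jax repository (lines 194 to 201 at commit f833891), whose result is included in the cache key that `jax.jit` uses, so changing such an option triggers a retrace. This is an internal utility, and unfortunately there is no public API for customizing the trace context.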
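Since there is no public hook into the trace context, one common workaround (a different technique from how JAX handles `enable_x64` internally) is to read your own config value on every call and pass it to the jitted function as a static argument. Static arguments are part of `jax.jit`'s cache key, so a new value triggers a retrace while a previously seen value reuses the cached compilation. This is a sketch under the assumption that your config values are hashable; `MY_CONFIG`, `_g_impl`, `g`, and `trace_count` are hypothetical names.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical user-level config: a plain dict of hashable values.
MY_CONFIG = {"scale": 1.0}
trace_count = 0

@partial(jax.jit, static_argnames=("scale",))
def _g_impl(x, scale):
    global trace_count
    trace_count += 1  # runs only when JAX (re)traces _g_impl
    return x * scale

def g(x):
    # Read the current config value on every call and pass it as a static
    # argument, so it becomes part of jit's cache key.
    return _g_impl(x, scale=MY_CONFIG["scale"])

x = jnp.ones(3)
a = g(x)                  # first trace, scale == 1.0
MY_CONFIG["scale"] = 2.0
b = g(x)                  # new static value: retraced
MY_CONFIG["scale"] = 1.0
c = g(x)                  # value seen before: cached version reused
```

The wrapper adds a little Python overhead per call, and each distinct config value compiles its own version of the function, so this suits a small number of settings rather than frequently changing values.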