Skip to content

Commit

Permalink
Adding version guard for jax.core.DebugInfo, which is the new name fo…
Browse files Browse the repository at this point in the history
…r jax.core.JaxprDebugInfo from JAX 0.5.1.

PiperOrigin-RevId: 730573843
  • Loading branch information
james-martens authored and KfacJaxDev committed Feb 24, 2025
1 parent a05444d commit 479d949
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion kfac_jax/_src/tag_graph_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@
from kfac_jax._src import utils
import numpy as np

jax_version = (
jax.__version_info__ if hasattr(jax, "__version_info__")
else tuple(map(int, jax.__version__.split("."))))

if jax_version >= (0, 5, 1):
DebugInfo = jax.core.DebugInfo
else:
DebugInfo = jax.core.JaxprDebugInfo # pytype: disable=module-attr


HIGHER_ORDER_NAMES = ("cond", "while", "scan", "pjit", "xla_call", "xla_pmap")
ITERATIVE_HIGHER_ORDER_NAMES = ("while", "scan")

Expand Down Expand Up @@ -385,7 +395,7 @@ def make_jax_graph(

debug_info = closed_jaxpr.jaxpr.debug_info
if debug_info is not None:
debug_info = jax.core.DebugInfo(
debug_info = DebugInfo(
debug_info.traced_for,
debug_info.func_src_info,
debug_info.arg_names,
Expand Down

0 comments on commit 479d949

Please sign in to comment.