
module 'jax.api_util' has no attribute 'debug_info' #160

Open
amrzv opened this issue Mar 3, 2025 · 1 comment


amrzv commented Mar 3, 2025

Hello.
When running the SigLIP2_demo.ipynb notebook in Colab, the line `zimg, _, out = model.apply({'params': params}, imgs, None)` raises the following error:

AttributeError                            Traceback (most recent call last)
<ipython-input-5-76fa53b787a5> in <cell line: 0>()
     46   print('imgs', imgs.shape)
     47 
---> 48 zimg, _, out = model.apply({'params': params}, imgs, None)
     49 
     50 print('zimg', zimg.shape)

    [... skipping hidden 6 frame]

/content/big_vision/big_vision/models/proj/image_text/two_towers.py in __call__(self, image, text, **kw)
     66       ).Model(**{"num_classes": out_dims[0], **(self.image or {})}, name="img")  # pylint: disable=not-a-mapping
     67 
---> 68       zimg, out_img = image_model(image, **kw)
     69       for k, v in out_img.items():
     70         out[f"img/{k}"] = v

    [... skipping hidden 2 frame]

/content/big_vision/big_vision/models/vit.py in __call__(self, image, train)
    228     x = nn.Dropout(rate=self.dropout)(x, not train)
    229 
--> 230     x, out["encoder"] = Encoder(
    231         depth=self.depth,
    232         mlp_dim=self.mlp_dim,

    [... skipping hidden 2 frame]

/content/big_vision/big_vision/models/vit.py in __call__(self, x, deterministic)
    134           policy=getattr(jax.checkpoint_policies, self.remat_policy, None),
    135           )
--> 136       x, scan_out = nn.scan(
    137           block,
    138           variable_axes={"params": 0},

    [... skipping hidden 3 frame]

/usr/local/lib/python3.11/dist-packages/flax/core/axes_scan.py in scan_fn(broadcast_in, init, *args)
    157 
    158     in_avals, in_tree = jax.tree_util.tree_flatten(input_avals)
--> 159     debug_info = jax.api_util.debug_info("flax scan", broadcast_body,
    160                                          (in_tree,), {})
    161     f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(

AttributeError: module 'jax.api_util' has no attribute 'debug_info'
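A quick way to confirm the mismatch before changing anything is to look at which flax and jax versions are installed and whether `jax.api_util` actually exposes `debug_info` (the traceback shows flax's `axes_scan.py` calling it). The sketch below is a hypothetical diagnostic; `installed_version` and `check_scan_compat` are illustrative names, not part of flax, jax, or big_vision. It only uses the standard library, so it also runs where jax is not installed.

```python
import importlib
import importlib.metadata


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None


def check_scan_compat():
    """Report flax/jax versions and whether jax.api_util.debug_info exists.

    Hypothetical helper: a False for "has_debug_info" means the call made in
    flax/core/axes_scan.py will fail with the AttributeError shown above.
    """
    report = {
        "flax": installed_version("flax"),
        "jax": installed_version("jax"),
        "jaxlib": installed_version("jaxlib"),
        "has_debug_info": None,  # stays None when jax cannot be imported
    }
    try:
        api_util = importlib.import_module("jax.api_util")
    except ImportError:
        return report
    report["has_debug_info"] = hasattr(api_util, "debug_info")
    return report


if __name__ == "__main__":
    for key, value in check_scan_compat().items():
        print(f"{key}: {value}")
```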


Collaborator

mitscha commented Mar 4, 2025

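This seems unrelated to big_vision and might be due to a flax update in the Colab environment: the failing call to `jax.api_util.debug_info` is made from flax's own `flax/core/axes_scan.py`, and the installed JAX does not provide that function. You should be able to work around it by adding `!pip3 install flax==0.8.5` after `!pip3 -q install --no-cache-dir -U crcmod` in the "Environment setup" cell.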
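To confirm that the flax pin is actually in effect in the running notebook, a small check like the following could be run after the setup cell. This is a hypothetical sketch using only the standard library; `flax_pin_status` and `PINNED_FLAX` are illustrative names, not part of flax or big_vision. Note that `pip install` does not replace modules already loaded into `sys.modules`, so if flax was imported before the pin was installed, restarting the runtime is the reliable way to pick up 0.8.5.

```python
import importlib.metadata
import sys

PINNED_FLAX = "0.8.5"  # version suggested in the workaround comment


def flax_pin_status(pinned=PINNED_FLAX):
    """Hypothetical check that the pinned flax version is installed and usable."""
    try:
        installed = importlib.metadata.version("flax")
    except importlib.metadata.PackageNotFoundError:
        installed = None
    return {
        "installed": installed,
        "pin_matches": installed == pinned,
        # True means this interpreter already holds previously loaded flax
        # modules; a reinstall does not swap them out without a restart.
        "already_imported": "flax" in sys.modules,
    }


print(flax_pin_status())
```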
