
Incompatibility with Haiku #528

Open
chriscarmona opened this issue Oct 2, 2023 · 4 comments

Comments

@chriscarmona

chriscarmona commented Oct 2, 2023

Reopening an issue about incompatibility with Haiku's module naming conventions (similar to a previous issue). The problem does not occur in v0.3.5.

Sample code:

from jax import numpy as jnp
import orbax.checkpoint as ocp
import haiku as hk


@hk.transform
def forward_fn(inputs):
  # net = hk.Linear(output_size=2) # This works
  net = hk.nets.MLP(
      output_sizes=[2, 2], activate_final=True)  # This doesn't work
  return net(inputs)


prng_seq = hk.PRNGSequence(0)
params = forward_fn.init(next(prng_seq), jnp.ones((1, 5)))

ckpt_dir = '/tmp/my-checkpoints/'
orbax_mngr = ocp.CheckpointManager(
    ckpt_dir,
    {'state': ocp.PyTreeCheckpointer()},
    options=ocp.CheckpointManagerOptions(max_to_keep=1),
)
orbax_mngr.save(step=0, items={'state': params})

The error:

Traceback (most recent call last):
  File "/workspaces/modularbayes/examples/bar.py", line 23, in <module>
    orbax_mngr.save(step=0, items={'state': params})
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 496, in save
    self._checkpointers[k].save(item_dir, item, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 79, in save
    self._handler.save(tmpdir, *args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 818, in save
    asyncio.run(async_save(directory, item, *args, **kwargs))
  File "/usr/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/usr/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 811, in async_save
    commit_futures = await self.async_save(*args, **kwargs)  # pytype: disable=bad-return-type
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 786, in async_save
    commit_futures = await asyncio.gather(*serialize_ops)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py", line 893, in serialize
    open_future = ts.open(
ValueError: Error parsing object member "json_pointer": JSON Pointer requires '~' to be followed by '0' or '1': "/mlp/~/linear_0.b" [source locations='tensorstore/internal/json_binding/json_binding.h:861\ntensorstore/internal/json_binding/json_binding.h:825']
sys:1: RuntimeWarning: coroutine 'async_serialize' was never awaited
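For context on the error message: tensorstore is parsing `json_pointer` as a JSON Pointer (RFC 6901), where `~` is an escape character. A literal `~` inside a reference token must be written `~0`, and a literal `/` must be written `~1`. Haiku names the parameters of nested submodules with a bare `/~/` (here `mlp/~/linear_0`), so a pointer built directly from that name is rejected. The sketch below only illustrates the RFC 6901 rules. The helper names are hypothetical, treating the whole parameter name as a single reference token is an assumption, and this is not Orbax's or tensorstore's code.

```python
def escape_json_pointer_token(token: str) -> str:
    """Escape one reference token per RFC 6901 (hypothetical helper)."""
    # '~' must be escaped first, otherwise the '~1' produced for '/' would become '~01'.
    return token.replace("~", "~0").replace("/", "~1")


def unescape_json_pointer_token(token: str) -> str:
    """Reverse the escaping: '~1' -> '/' first, then '~0' -> '~' (RFC 6901 order)."""
    return token.replace("~1", "/").replace("~0", "~")


def is_valid_json_pointer(pointer: str) -> bool:
    """Return True if every '~' in the pointer is followed by '0' or '1'."""
    i = 0
    while i < len(pointer):
        if pointer[i] == "~":
            if i + 1 >= len(pointer) or pointer[i + 1] not in "01":
                return False
            i += 2  # skip the complete '~0' / '~1' escape sequence
        else:
            i += 1
    return True


if __name__ == "__main__":
    raw = "/mlp/~/linear_0.b"  # the pointer from the traceback
    print(is_valid_json_pointer(raw))  # False: '~' is followed by '/'
    # Assumption: treat the whole parameter name as one reference token.
    escaped = "/" + escape_json_pointer_token("mlp/~/linear_0.b")
    print(escaped)  # /mlp~1~0~1linear_0.b
    print(is_valid_json_pointer(escaped))  # True
```

The same rule explains why `hk.Linear` alone works: its parameters live under the plain module name `linear`, which contains no `~`.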
@chriscarmona changed the title from "Incompatibility with Haiku (again)" to "Incompatibility with Haiku" on Oct 3, 2023
@liangyaning33
Collaborator

Hi Chris,

Thanks for raising the issue. We have submitted a fix, and this issue should now be resolved. Can you please verify that it no longer errors out for you? Thanks!!

Best,
Yaning

@Carbon225

@liangyaning33

I installed with

pip install 'git+https://github.com/google/orbax/#subdirectory=checkpoint'

but I still get the error:

ValueError: Error parsing object member "json_pointer": JSON Pointer requires '~' to be followed by '0' or '1': "/state.disc.batch_norm/~/mean_ema.average" [source locations='tensorstore/internal/json_binding/json_binding.h:861\ntensorstore/internal/json_binding/json_binding.h:825']

Version 0.3.5 works.
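Both failing paths contain a `/~/` segment (`mlp/~/linear_0.b` and `state.disc.batch_norm/~/mean_ema.average`), which is how Haiku names the parameters and state of nested submodules. A minimal, hypothetical helper for spotting affected entries in a parameter tree before saving might look like the following. It assumes the tree is made of nested mappings (as Haiku returns) and that path components are joined with `.`, which is inferred from the error messages rather than from Orbax documentation. Leaves are placeholder values so the sketch runs without JAX.

```python
from typing import Any, Iterator, List, Mapping


def iter_leaf_paths(tree: Mapping[str, Any], prefix: str = "") -> Iterator[str]:
    """Yield '.'-joined paths to every non-mapping leaf (assumed joining scheme)."""
    for key, value in tree.items():
        path = f"{prefix}.{key}" if prefix else key
        if isinstance(value, Mapping):
            yield from iter_leaf_paths(value, path)
        else:
            yield path


def find_tilde_paths(tree: Mapping[str, Any]) -> List[str]:
    """Hypothetical check: list leaf paths containing a raw '~'."""
    return [path for path in iter_leaf_paths(tree) if "~" in path]


if __name__ == "__main__":
    # Placeholder trees shaped like Haiku params/state; strings stand in for arrays.
    linear_params = {"linear": {"w": "array", "b": "array"}}
    mlp_params = {
        "mlp/~/linear_0": {"w": "array", "b": "array"},
        "mlp/~/linear_1": {"w": "array", "b": "array"},
    }
    bn_state = {"state": {"disc": {"batch_norm/~/mean_ema": {"average": "array"}}}}

    print(find_tilde_paths(linear_params))  # []
    print(find_tilde_paths(mlp_params))
    print(find_tilde_paths(bn_state))  # ['state.disc.batch_norm/~/mean_ema.average']
```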

@liangyaning33
Collaborator

Hi Chris,

Could you check which version of orbax-checkpoint you are using? I just checked this in a notebook with v0.4.1 and it works.
[Screenshot, 2023-10-24 at 10:20:09 AM]

@chriscarmona
Author

Hi @liangyaning33,

Apologies for the delayed response. I observe the same behaviour as @Carbon225: the error still appears in v0.4.1 (as currently installed by pip install -U orbax-checkpoint).
It works with v0.3.5.


3 participants