Skip to content

Commit

Permalink
Refine several API docs and add a test to prevent colab docs from bre…
Browse files Browse the repository at this point in the history
…aking.

PiperOrigin-RevId: 682082999
  • Loading branch information
cpgaffney1 authored and Orbax Authors committed Oct 3, 2024
1 parent 8bf99e3 commit 410ca56
Show file tree
Hide file tree
Showing 4 changed files with 521 additions and 528 deletions.
1 change: 1 addition & 0 deletions checkpoint/orbax/checkpoint/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

# pylint: disable=g-importing-member, unused-import

from orbax.checkpoint._src.handlers import handler_registration
from orbax.checkpoint._src.handlers.array_checkpoint_handler import ArrayCheckpointHandler
from orbax.checkpoint._src.handlers.async_checkpoint_handler import AsyncCheckpointHandler
from orbax.checkpoint._src.handlers.base_pytree_checkpoint_handler import BasePyTreeCheckpointHandler
Expand Down
144 changes: 76 additions & 68 deletions docs/guides/checkpoint/api_refactor.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"id": "0jj2MOXcL9Eh"
},
"source": [
"# Using the Refactored CheckpointManager API"
"# Using the Refactored CheckpointManager API"
]
},
{
Expand All @@ -24,7 +24,7 @@
"id": "wZKtrmVojffN"
},
"source": [
"**The legacy APIs is deprecated and will stop working after May 1st, 2024. Please ensure you are using the new style by then.**"
"**The legacy APIs is deprecated and will stop working soon. Please ensure you are using the new style ASAP.**"
]
},
{
Expand Down Expand Up @@ -54,7 +54,8 @@
},
"outputs": [],
"source": [
"import orbax.checkpoint as ocp"
"import orbax.checkpoint as ocp\n",
"from etils import epath"
]
},
{
Expand Down Expand Up @@ -142,19 +143,11 @@
" ocp.test_utils.erase_and_create_empty('/tmp/ckpt2/'),\n",
" options=options,\n",
") as mngr:\n",
"\n",
" mngr.save(0, args=ocp.args.StandardSave(pytree))\n",
" mngr.wait_until_finished()\n",
"\n",
" # After providing `args` during an initial `save` or `restore` call, the\n",
" # `CheckpointManager` instance records the type so that you do not need to\n",
" # specify it again. If the `CheckpointManager` instance is not provided with a\n",
" # `ocp.args.CheckpointArgs` instance for a particular item on a previous\n",
" # occasion it cannot be restored without specifying the argument at restore\n",
" # time.\n",
"\n",
" # In many cases, you can restore exactly as saved without specifying additional\n",
" # arguments.\n",
" # The `CheckpointManager` already knows that the object is saved and restored\n",
" # using \"standard\" pytree logic. In many cases, you can restore exactly as\n",
" # saved without specifying additional arguments.\n",
" mngr.restore(0)\n",
" # If customization of properties like sharding or dtype is desired, just provide\n",
" # the abstract target PyTree, the properties of which will be used to set\n",
Expand All @@ -181,9 +174,7 @@
"id": "ebL-zbpVaH-4"
},
"source": [
"Let's explore scenarios when `restore()` and `item_metadata()` calls raise errors due to unspecified CheckpointHandlers for item names.\n",
"\n",
"`CheckpointManager(..., item_handlers=...)` can be used to resolve these scenarios."
"Let's explore scenarios when `restore()` and `item_metadata()` calls raise errors due to unspecified CheckpointHandlers for item names."
]
},
{
Expand All @@ -196,7 +187,19 @@
"source": [
"# Unmapped CheckpointHandlers on a new CheckpointManager instance.\n",
"new_mngr = ocp.CheckpointManager('/tmp/ckpt2/', options=options)\n",
"new_mngr.restore(0) # Raises error due to unmapped CheckpointHandler"
"try:\n",
" new_mngr.restore(0) # Raises error due to unmapped CheckpointHandler\n",
"except BaseException as e:\n",
" print(e)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o6royiq9WNIg"
},
"source": [
"To fix this, use one of the following options:"
]
},
{
Expand All @@ -207,18 +210,41 @@
},
"outputs": [],
"source": [
"new_mngr.restore(0, args=ocp.args.StandardRestore(abstract_pytree))"
"new_mngr.restore(0, args=ocp.args.StandardRestore(abstract_pytree))\n",
"new_mngr.close()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2wZzpye8WYdi"
},
"source": [
"We can also configure the `CheckpointManager` to know how to restore the object in advance."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tGXkK0BIXDaz"
"id": "3Bdto6tPAfQy"
},
"outputs": [],
"source": [
"new_mngr.close()"
"# The item name is \"default\".\n",
"list(epath.Path('/tmp/ckpt2/0').iterdir())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tATQ5rmCAqey"
},
"outputs": [],
"source": [
"registry = ocp.handlers.DefaultCheckpointHandlerRegistry()\n",
"registry.add('default', ocp.args.StandardRestore, ocp.StandardCheckpointHandler)"
]
},
{
Expand All @@ -233,7 +259,7 @@
"with ocp.CheckpointManager(\n",
" '/tmp/ckpt2/',\n",
" options=options,\n",
" item_handlers=ocp.StandardCheckpointHandler()\n",
" handler_registry=registry,\n",
") as new_mngr:\n",
" print(new_mngr.restore(0))"
]
Expand All @@ -247,7 +273,7 @@
"**NOTE:**\n",
"`CheckpointManager.item_metadata(step)` doesn't support any input like `args` in `restore(..., args=...)`.\n",
"\n",
"So, `item_handlers` is the only option available with `item_metadata(step)` calls."
"So, `handler_registry` is currently required when calling `item_metadata(step)` before calling restore or save."
]
},
{
Expand All @@ -260,7 +286,10 @@
"source": [
"# item_handlers becomes even more critical with item_metadata() calls.\n",
"new_mngr = ocp.CheckpointManager('/tmp/ckpt2/', options=options)\n",
"new_mngr.item_metadata(0) # Raises error due to unmapped CheckpointHandler"
"try:\n",
" new_mngr.item_metadata(0) # Raises error due to unmapped CheckpointHandler\n",
"except BaseException as e:\n",
" print(e)"
]
},
{
Expand All @@ -271,23 +300,12 @@
},
"outputs": [],
"source": [
"new_mngr = ocp.CheckpointManager(\n",
"with ocp.CheckpointManager(\n",
" '/tmp/ckpt2/',\n",
" options=options,\n",
" item_handlers=ocp.StandardCheckpointHandler(),\n",
")\n",
"new_mngr.item_metadata(0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mTZK3rB-LFg8"
},
"outputs": [],
"source": [
"new_mngr.close()"
" handler_registry=registry,\n",
") as new_mngr:\n",
" new_mngr.item_metadata(0)"
]
},
{
Expand Down Expand Up @@ -362,7 +380,6 @@
" ocp.test_utils.erase_and_create_empty('/tmp/ckpt4/'),\n",
" # `item_names` defines an up-front contract about what items the\n",
" # CheckpointManager will be dealing with.\n",
" item_names=('state', 'extra_metadata'),\n",
" options=options,\n",
")\n",
"\n",
Expand Down Expand Up @@ -415,9 +432,7 @@
"id": "mZSZzWkhLSvz"
},
"source": [
"Just like single item use case described above, let's explore scenarios when `restore()` and `item_metadata()` calls raise errors due to unspecified CheckpointHandlers for item names.\n",
"\n",
"`CheckpointManager(..., item_handlers=...)` can be used to resolve these scenarios."
"Just like single item use case described above, let's explore scenarios when `restore()` and `item_metadata()` calls raise errors due to unspecified CheckpointHandlers for item names."
]
},
{
Expand All @@ -434,7 +449,10 @@
" options=options,\n",
" item_names=('state', 'extra_metadata'),\n",
")\n",
"new_mngr.restore(0) # Raises error due to unmapped CheckpointHandlers"
"try:\n",
" new_mngr.restore(0) # Raises error due to unmapped CheckpointHandlers\n",
"except BaseException as e:\n",
" print(e)"
]
},
{
Expand All @@ -451,18 +469,21 @@
" state=ocp.args.StandardRestore(abstract_pytree),\n",
" extra_metadata=ocp.args.JsonRestore(),\n",
" ),\n",
")"
")\n",
"new_mngr.close()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wVtGDZS1XQKy"
"id": "nG12gM6l_y6C"
},
"outputs": [],
"source": [
"new_mngr.close()"
"registry = ocp.handlers.DefaultCheckpointHandlerRegistry()\n",
"registry.add('state', ocp.args.StandardRestore, ocp.StandardCheckpointHandler)\n",
"registry.add('extra_metadata', ocp.args.JsonRestore, ocp.JsonCheckpointHandler)"
]
},
{
Expand All @@ -477,10 +498,7 @@
"with ocp.CheckpointManager(\n",
" '/tmp/ckpt4/',\n",
" options=options,\n",
" item_handlers={\n",
" 'state': ocp.StandardCheckpointHandler(),\n",
" 'extra_metadata': ocp.JsonCheckpointHandler(),\n",
" },\n",
" handler_registry=registry,\n",
") as new_mngr:\n",
" print(new_mngr.restore(0))"
]
Expand All @@ -494,7 +512,7 @@
"**NOTE:**\n",
"`CheckpointManager.item_metadata(step)` doesn't support any input like `args` in `restore(..., args=...)`.\n",
"\n",
"So, `item_handlers` is the only option available with `item_metadata(step)` calls."
"So, `handler_registry` is currently required with `item_metadata(step)` calls."
]
},
{
Expand All @@ -511,7 +529,10 @@
" options=options,\n",
" item_names=('state', 'extra_metadata'),\n",
") as new_mngr:\n",
" new_mngr.item_metadata(0) # Raises error due to unmapped CheckpointHandlers"
" try:\n",
" new_mngr.item_metadata(0) # Raises error due to unmapped CheckpointHandlers\n",
" except BaseException as e:\n",
" print(e)"
]
},
{
Expand All @@ -525,28 +546,15 @@
"with ocp.CheckpointManager(\n",
" '/tmp/ckpt4/',\n",
" options=options,\n",
" item_handlers={\n",
" 'state': ocp.StandardCheckpointHandler(),\n",
" 'extra_metadata': ocp.JsonCheckpointHandler(),\n",
" },\n",
" handler_registry=registry,\n",
") as new_mngr:\n",
" print(new_mngr.item_metadata(0))"
]
}
],
"metadata": {
"colab": {
"last_runtime": {
"build_target": "//learning/grp/tools/ml_python:ml_notebook",
"kind": "private"
},
"private_outputs": true,
"provenance": [
{
"file_id": "17zeb2jhSE6p1x3u7r_s15AuKrOPtQ7y3",
"timestamp": 1704302873675
}
]
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
Expand Down
Loading

0 comments on commit 410ca56

Please sign in to comment.