Skip to content

Commit

Permalink
Add OCDBT usage documentation to the RTD page.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 546994248
  • Loading branch information
cpgaffney1 authored and copybara-github committed Jul 10, 2023
1 parent 670c7dd commit 1c56e81
Show file tree
Hide file tree
Showing 8 changed files with 320 additions and 7 deletions.
16 changes: 14 additions & 2 deletions checkpoint/orbax/checkpoint/type_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@

def _get_coordinator_address_without_port(
coordinator_address: Optional[str],
) -> str:
) -> Optional[str]:
"""Returns JAX coordinator address stripped of port number."""
if not coordinator_address:
raise ValueError('Coordinator address not set.')
logging.warning('JAX coordinator address not set.')
return None
return coordinator_address.split(':')[0]


Expand Down Expand Up @@ -76,6 +77,17 @@ def create_coordinator_server_and_context() -> (
ocdbt_address = _get_coordinator_address_without_port(
jax_global_state.coordinator_address
)
if ocdbt_address is None:
return (
ts.Context(
{
# Provide cache pool for B-tree nodes to avoid repeated reads.
'cache_pool#ocdbt': {'total_bytes_limit': 100000000},
},
parent=serialization.TS_CONTEXT,
),
None,
)

coordinator_server = None
if jax_global_state.process_id == 0:
Expand Down
2 changes: 0 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@
'sphinxcontrib.katex',
'sphinx_autodoc_typehints',
'sphinx_book_theme',
# 'coverage_check',
'myst_nb', # This is used for the .ipynb notebooks
'sphinx.ext.autosectionlabel',
'sphinx.ext.mathjax',
Expand Down Expand Up @@ -97,7 +96,6 @@
html_theme = 'sphinx_book_theme'

html_theme_options = {
# 'logo_only': True,
'show_toc_level': 2,
}

Expand Down
Binary file added docs/img/checkpoint/benchmarks/restore_ocdbt.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/img/checkpoint/benchmarks/save_ocdbt.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
47 changes: 46 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,44 @@ users. It includes multiple distinct but interrelated libraries.
Installation
---------------

For more information how to install orbax, see the project README.
There is no single `orbax` package, but rather a separate package for each
functionality provided by the Orbax namespace.

The latest release of `orbax-checkpoint` can be installed from
`PyPI <https://pypi.org/project/orbax-checkpoint/>`_ using

``pip install orbax-checkpoint``

You may also install directly from GitHub, using the following command. This
can be used to obtain the most recent version of Optax.

``pip install 'git+https://github.com/google/orbax/#subdirectory=checkpoint'``

Similarly, `orbax-export` can be installed from
`PyPI <https://pypi.org/project/orbax-export/>`_ using

``pip install orbax-export``

Install from GitHub using the following.

``pip install 'git+https://github.com/google/orbax/#subdirectory=export'``


.. For TOC
.. toctree::
:hidden:
:maxdepth: 1
:caption: Checkpointing

orbax_checkpoint_101
api_reference/checkpoint

.. toctree::
:hidden:
:maxdepth: 1
:caption: Exporting

api_reference/export


.. For TOC
Expand All @@ -48,6 +85,7 @@ For more information how to install orbax, see the project README.
:caption: Checkpointing

orbax_checkpoint_101
optimized_checkpointing
api_reference/checkpoint

.. toctree::
Expand All @@ -70,6 +108,13 @@ Checkpointing
:class-card: sd-text-black sd-bg-light
:link: orbax_checkpoint_101.html

.. grid-item::
:columns: 6 6 6 4

.. card:: Optimized Checkpointing
:class-card: sd-text-black sd-bg-light
:link: optimized_checkpointing.html

.. grid-item::
:columns: 6 6 6 4

Expand Down
255 changes: 255 additions & 0 deletions docs/optimized_checkpointing.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "tKGlDfvNJM8R"
},
"source": [
"# Optimized Checkpointing with Tensorstore"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AIgEkzUjJRUt"
},
"source": [
"Orbax relies on [Tensorstore](https://google.github.io/tensorstore/) to store\n",
"individual arrays in a checkpoint. Tensorstore provides efficient, scalable library for reading and writing arrays.\n",
"\n",
"Until recently, however, our use of Tensorstore came with a few drawbacks. Chief among them was the fact that every parameter in a training state would be saved as a separate directory. This approach can be quite performant, even for models with hundreds of billions of parameters, *provided that model layers are stacked*. Otherwise, hundreds or thousands of directories may be created in the checkpoint.\n",
"\n",
"This fact can lead to very slow restore times, which is undesirable in and of itself, but is particularly painful for jobs that may be preempted frequently and need to restart, for example.\n",
"\n",
"While it is slightly less of a concern at save time, since writes to disk can happen asynchronously, the synchronous portion of the save can still be slow as many directories are created.\n",
"\n",
"Additionally, if individual parameters are small, storage may be wasted on filesystems with minimum file sizes."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0zou_vidLNMd"
},
"source": [
"## Introducing OCDBT\n",
"\n",
"The new, optimized checkpoint format provided by Orbax is backed by Tensorstore's [OCDBT](https://google.github.io/tensorstore/kvstore/ocdbt/index.html) driver (optionally-cooperative distributed B-tree).\n",
"\n",
"For practical purposes, this means that we will no longer store one parameter per directory, but will aggregate many parameters into a smaller set of large files.\n",
"\n",
"Empirically, we have observed substantial speed-ups in both save and restore when using OCDBT."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uA_OTs-GEmty"
},
"source": [
"### Save Performance (sec)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MUs2AbpcaG8U"
},
"source": [
"\u003cimg src=https://orbax.readthedocs.io/en/latest/img/checkpoint/benchmarks/save_ocdbt.png\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FhyHhgwqEgmR"
},
"source": [
"### Restore Performance (sec)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5Yrd8RQ3cZer"
},
"source": [
"\u003cimg src=https://orbax.readthedocs.io/en/latest/img/checkpoint/benchmarks/restore_ocdbt.png\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kfO6-6ZENhEG"
},
"source": [
"## Checkpoint Format\n",
"\n",
"Concretely, what does the new checkpoint format look like in comparison to the old?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "j7veWqRzQ7Jb"
},
"source": [
"### Old Format"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QOHzh4hAPUJQ"
},
"outputs": [],
"source": [
"f = \"\"\"\n",
"path/to/my/checkpoint/dir/\n",
" 0/\n",
" state/\n",
" layer0.param0/\n",
" .zarray\n",
" 0.0\n",
" 0.1\n",
" 1.0\n",
" 1.1\n",
" layer1.param0/\n",
" .zarray\n",
" 0.0\n",
" ...\n",
" \u003canother_item\u003e/\n",
" ...\n",
" 1/\n",
" ...\n",
" 2/\n",
" ...\n",
"\n",
"Note: in this case, `0.0`, `0.1`, etc. provides an indication of how the array\n",
"was sharded when originally saved.\n",
"\"\"\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-0J3DhoFQ-x1"
},
"source": [
"### New Format"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "K5WqzSTpRBp3"
},
"outputs": [],
"source": [
"f = \"\"\"\n",
"path/to/my/checkpoint/dir/\n",
" 0/\n",
" state/\n",
" checkpoint # legacy msgpack file, stores tree structure\n",
" tree_metadata # (maybe) new proto file, stores tree structure\n",
" d/ # array data stored here\n",
" 012b2c6e5c9d2a16c240a59d5f0f35c0\n",
" 056e0816bdc5496a86251e58a0ec202b\n",
" ...\n",
" manifest.0000000000000001\n",
" ...\n",
" manifest.ocdbt\n",
" \u003canother_item\u003e/\n",
" ...\n",
" 1/\n",
" ...\n",
" 2/\n",
" ...\n",
"\"\"\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hOf2vscWRF5u"
},
"source": [
"## Enabling OCDBT"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ADRsxIkFRPZR"
},
"outputs": [],
"source": [
"import orbax.checkpoint as ock\n",
"\n",
"# Ensure that the coordinator_server is kept alive for the duration of the\n",
"# program (if not None).\n",
"# The server will only be non-None on a single process.\n",
"ocdbt_context, coordinator_server = (\n",
" ock.type_handlers.create_coordinator_server_and_context()\n",
")\n",
"ock.type_handlers.register_standard_handlers_with_options(\n",
" use_ocdbt=True, ts_context=ocdbt_context\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gXXzqbco_UgX"
},
"outputs": [],
"source": [
"# Later, make sure PyTreeCheckpointHandler is initialized with `use_ocdbt=True`.\n",
"# Depending on when you read this, the option may already default to True.\n",
"ckptr = ock.Checkpointer(ock.PyTreeCheckpointHandler(use_ocdbt=True))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HxLoTiZnAOvw"
},
"source": [
"## Additional Notes\n",
"\n",
"All checkpoints previously produced by Orbax in the old format will still be\n",
"readable when OCDBT is enabled. However, if a checkpoint is produced in the OCDBT format, it cannot be read if the OCDBT feature is disabled."
]
}
],
"metadata": {
"colab": {
"last_runtime": {
"build_target": "//learning/grp/tools/ml_python:ml_notebook",
"kind": "private"
},
"private_outputs": true,
"provenance": [
{
"file_id": "1bRC6p0AstPPAAW0AUoxHaOFEWpaW_GjI",
"timestamp": 1688077923387
}
],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
4 changes: 3 additions & 1 deletion docs/orbax_checkpoint_101.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@
},
"outputs": [],
"source": [
"path = epath.Path('/tmp/checkpoint')\n",
"path = epath.Path('/tmp/my-checkpoints/')\n",
"if path.exists():\n",
" path.rmtree()\n",
"path.mkdir()"
]
},
Expand Down
3 changes: 2 additions & 1 deletion readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ python:
- requirements: ./docs/requirements/requirements-docs.txt
- method: pip
path: ./checkpoint
# TODO(cpgaffney) support export.
- method: pip
path: ./export

0 comments on commit 1c56e81

Please sign in to comment.