-
Notifications
You must be signed in to change notification settings - Fork 41
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add OCDBT usage documentation to the RTD page.
PiperOrigin-RevId: 546994248
- Loading branch information
1 parent
670c7dd
commit 1c56e81
Showing
8 changed files
with
320 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,255 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "tKGlDfvNJM8R" | ||
}, | ||
"source": [ | ||
"# Optimized Checkpointing with Tensorstore" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "AIgEkzUjJRUt" | ||
}, | ||
"source": [ | ||
"Orbax relies on [Tensorstore](https://google.github.io/tensorstore/) to store\n", | ||
"individual arrays in a checkpoint. Tensorstore provides efficient, scalable library for reading and writing arrays.\n", | ||
"\n", | ||
"Until recently, however, our use of Tensorstore came with a few drawbacks. Chief among them was the fact that every parameter in a training state would be saved as a separate directory. This approach can be quite performant, even for models with hundreds of billions of parameters, *provided that model layers are stacked*. Otherwise, hundreds or thousands of directories may be created in the checkpoint.\n", | ||
"\n", | ||
"This fact can lead to very slow restore times, which is undesirable in and of itself, but is particularly painful for jobs that may be preempted frequently and need to restart, for example.\n", | ||
"\n", | ||
"While it is slightly less of a concern at save time, since writes to disk can happen asynchronously, the synchronous portion of the save can still be slow as many directories are created.\n", | ||
"\n", | ||
"Additionally, if individual parameters are small, storage may be wasted on filesystems with minimum file sizes." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "0zou_vidLNMd" | ||
}, | ||
"source": [ | ||
"## Introducing OCDBT\n", | ||
"\n", | ||
"The new, optimized checkpoint format provided by Orbax is backed by Tensorstore's [OCDBT](https://google.github.io/tensorstore/kvstore/ocdbt/index.html) driver (optionally-cooperative distributed B-tree).\n", | ||
"\n", | ||
"For practical purposes, this means that we will no longer store one parameter per directory, but will aggregate many parameters into a smaller set of large files.\n", | ||
"\n", | ||
"Empirically, we have observed substantial speed-ups in both save and restore when using OCDBT." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "uA_OTs-GEmty" | ||
}, | ||
"source": [ | ||
"### Save Performance (sec)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "MUs2AbpcaG8U" | ||
}, | ||
"source": [ | ||
"\u003cimg src=https://orbax.readthedocs.io/en/latest/img/checkpoint/benchmarks/save_ocdbt.png\u003e" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "FhyHhgwqEgmR" | ||
}, | ||
"source": [ | ||
"### Restore Performance (sec)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "5Yrd8RQ3cZer" | ||
}, | ||
"source": [ | ||
"\u003cimg src=https://orbax.readthedocs.io/en/latest/img/checkpoint/benchmarks/restore_ocdbt.png\u003e" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "kfO6-6ZENhEG" | ||
}, | ||
"source": [ | ||
"## Checkpoint Format\n", | ||
"\n", | ||
"Concretely, what does the new checkpoint format look like in comparison to the old?" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "j7veWqRzQ7Jb" | ||
}, | ||
"source": [ | ||
"### Old Format" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "QOHzh4hAPUJQ" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"f = \"\"\"\n", | ||
"path/to/my/checkpoint/dir/\n", | ||
" 0/\n", | ||
" state/\n", | ||
" layer0.param0/\n", | ||
" .zarray\n", | ||
" 0.0\n", | ||
" 0.1\n", | ||
" 1.0\n", | ||
" 1.1\n", | ||
" layer1.param0/\n", | ||
" .zarray\n", | ||
" 0.0\n", | ||
" ...\n", | ||
" \u003canother_item\u003e/\n", | ||
" ...\n", | ||
" 1/\n", | ||
" ...\n", | ||
" 2/\n", | ||
" ...\n", | ||
"\n", | ||
"Note: in this case, `0.0`, `0.1`, etc. provides an indication of how the array\n", | ||
"was sharded when originally saved.\n", | ||
"\"\"\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "-0J3DhoFQ-x1" | ||
}, | ||
"source": [ | ||
"### New Format" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "K5WqzSTpRBp3" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"f = \"\"\"\n", | ||
"path/to/my/checkpoint/dir/\n", | ||
" 0/\n", | ||
" state/\n", | ||
" checkpoint # legacy msgpack file, stores tree structure\n", | ||
" tree_metadata # (maybe) new proto file, stores tree structure\n", | ||
" d/ # array data stored here\n", | ||
" 012b2c6e5c9d2a16c240a59d5f0f35c0\n", | ||
" 056e0816bdc5496a86251e58a0ec202b\n", | ||
" ...\n", | ||
" manifest.0000000000000001\n", | ||
" ...\n", | ||
" manifest.ocdbt\n", | ||
" \u003canother_item\u003e/\n", | ||
" ...\n", | ||
" 1/\n", | ||
" ...\n", | ||
" 2/\n", | ||
" ...\n", | ||
"\"\"\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "hOf2vscWRF5u" | ||
}, | ||
"source": [ | ||
"## Enabling OCDBT" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "ADRsxIkFRPZR" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import orbax.checkpoint as ock\n", | ||
"\n", | ||
"# Ensure that the coordinator_server is kept alive for the duration of the\n", | ||
"# program (if not None).\n", | ||
"# The server will only be non-None on a single process.\n", | ||
"ocdbt_context, coordinator_server = (\n", | ||
" ock.type_handlers.create_coordinator_server_and_context()\n", | ||
")\n", | ||
"ock.type_handlers.register_standard_handlers_with_options(\n", | ||
" use_ocdbt=True, ts_context=ocdbt_context\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "gXXzqbco_UgX" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Later, make sure PyTreeCheckpointHandler is initialized with `use_ocdbt=True`.\n", | ||
"# Depending on when you read this, the option may already default to True.\n", | ||
"ckptr = ock.Checkpointer(ock.PyTreeCheckpointHandler(use_ocdbt=True))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "HxLoTiZnAOvw" | ||
}, | ||
"source": [ | ||
"## Additional Notes\n", | ||
"\n", | ||
"All checkpoints previously produced by Orbax in the old format will still be\n", | ||
"readable when OCDBT is enabled. However, if a checkpoint is produced in the OCDBT format, it cannot be read if the OCDBT feature is disabled." | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"colab": { | ||
"last_runtime": { | ||
"build_target": "//learning/grp/tools/ml_python:ml_notebook", | ||
"kind": "private" | ||
}, | ||
"private_outputs": true, | ||
"provenance": [ | ||
{ | ||
"file_id": "1bRC6p0AstPPAAW0AUoxHaOFEWpaW_GjI", | ||
"timestamp": 1688077923387 | ||
} | ||
], | ||
"toc_visible": true | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters