Code duplication around tensorstore_spec logic in orbax-checkpoint #1241

Open
minotru opened this issue Oct 13, 2024 · 2 comments

minotru commented Oct 13, 2024

Hi Orbax team,

I was looking at the Orbax code at the latest version (0.7.0) and found that some fairly heavy logic around tensorstore_spec creation appears to be duplicated.

I'd like to know whether this code duplication is intentional or whether I am welcome to submit a PR.

Here, get_tensorstore_spec is part of the public API, and I can't find any usage of get_tensorstore_spec within orbax-checkpoint itself:
https://github.com/google/orbax/blob/8b4e90d573082a5c7caa5f99c51db376f62a6995/checkpoint/orbax/checkpoint/serialization.py#L97C5-L124

And here is a very similar piece of code, build_kvstore_tspec in the _internal package, which is used heavily by type_handlers.py:

def build_kvstore_tspec(
    directory: str,
    name: str | None = None,
    *,
    use_ocdbt: bool = True,
    process_id: int | str | None = None,
) -> JsonSpec:
  """Constructs a spec for a Tensorstore KvStore.
  Args:
    directory: Base path (key prefix) of the KvStore, used by the underlying
      file driver.
    name: Name (filename) of the parameter.
    use_ocdbt: Whether to use OCDBT driver.
    process_id: [only used with OCDBT driver] If provided,
      `{directory}/ocdbt.process_{process_id}` path is used as the base path.
      If a string, must conform to [A-Za-z0-9]+ pattern.
  Returns:
    A Tensorstore KvStore spec in dictionary form.
  """
  default_driver = DEFAULT_DRIVER
  # Normalize path to exclude trailing '/'. In GCS path case, we will need to
  # fix the path prefix to add back the stripped '/'.
  directory = os.path.normpath(directory).replace('gs:/', 'gs://')
  is_gcs_path = directory.startswith('gs://')
  kv_spec = {}
  if use_ocdbt:
    if not is_gcs_path and not os.path.isabs(directory):
      raise ValueError(f'Checkpoint path should be absolute. Got {directory}')
    if process_id is not None:
      process_id = str(process_id)
      if re.fullmatch(_OCDBT_PROCESS_ID_RE, process_id) is None:
        raise ValueError(
            f'process_id must conform to {_OCDBT_PROCESS_ID_RE} pattern'
            f', got {process_id}'
        )
      directory = os.path.join(
          directory, f'{PROCESS_SUBDIR_PREFIX}{process_id}'
      )
    base_driver_spec = (
        directory
        if is_gcs_path
        else {'driver': default_driver, 'path': str(directory)}
    )
    kv_spec.update({
        'driver': 'ocdbt',
        'base': base_driver_spec,
    })
    if name is not None:
      kv_spec['path'] = name
    kv_spec.update({  # pytype: disable=attribute-error
        # Enable read coalescing. This feature merges adjacent read_ops into
        # one, which could reduce I/O ops by a factor of 10. This is especially
        # beneficial for unstacked models.
        'experimental_read_coalescing_threshold_bytes': 1000000,
        'experimental_read_coalescing_merged_bytes': 500000000000,
        'experimental_read_coalescing_interval': '1ms',
        # References the cache specified in ts.Context.
        'cache_pool': 'cache_pool#ocdbt',
    })
  else:
    if name is None:
      path = directory
    else:
      path = os.path.join(directory, name)
    if is_gcs_path:
      kv_spec = _get_kvstore_for_gcs(path)
    else:
      kv_spec = {'driver': default_driver, 'path': path}
  return kv_spec

Would you consider having get_tensorstore_spec reuse build_kvstore_tspec under the hood?


Also, the default ts_context value seems a bit unclear.

  • In the public orbax/checkpoint/serialization.py, there is TS_CONTEXT, which is used as the default value of context in async_serialize (Orbax does not actually use async_serialize anywhere and recommends async_serialize_shards instead), async_serialize_shards, async_deserialize, and by StringHandler.
  • At the same time, in type_handlers.py, there is get_ts_context() (it references _DEFAULT_OCDBT_TS_CONTEXT), and get_ts_context is used by all other handler implementations.

So TS_CONTEXT from serialization.py seems never to be used by the common checkpoint IO code.

Should we keep only one source of truth for the default ts_context value?

@cpgaffney1
Collaborator

This code is currently being reworked by @dicentra13. get_tensorstore_spec in serialization.py should either be removed or reuse build_kvstore_tspec. I think it is only used in a couple of places in internal code, and that is just due to slower progress in refactoring rather than any inherent need.

Likewise, async_serialize is used in one place internally that I'm working on eliminating, although I will probably leave that function in place as a wrapper.

Agreed that ts_context should use one source of truth; that could use fixing.

@minotru
Author

minotru commented Oct 16, 2024

Got it @cpgaffney1, thanks for the reply! Waiting for the refactoring by @dicentra13 :)

Let's close this issue once the refactoring is completed and merged.
