From 48df1ac04647343b4cd0e25a1d8292709c9e4246 Mon Sep 17 00:00:00 2001 From: Colin Gaffney Date: Wed, 22 Jan 2025 16:56:01 -0800 Subject: [PATCH] Use a separate `_is_saving_in_progress` bool to track whether a save is ongoing, as this prevents conflicts when concurrent `wait_until_finished` calls cause the finalize thread lock to block. PiperOrigin-RevId: 718583104 --- checkpoint/orbax/checkpoint/checkpoint_manager.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/checkpoint/orbax/checkpoint/checkpoint_manager.py b/checkpoint/orbax/checkpoint/checkpoint_manager.py index c8f16e7b..3e7bb05d 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/checkpoint_manager.py @@ -780,6 +780,10 @@ def __init__( with self._finalize_thread_lock: self._finalize_thread = None + self._is_saving_in_progress_lock = threading.Lock() + with self._is_saving_in_progress_lock: + self._is_saving_in_progress = False + self._checkpoint_deleter: deleter.CheckpointDeleter = ( deleter.create_checkpoint_deleter( self._multiprocessing_options.primary_host, @@ -1317,6 +1321,8 @@ def save( assert self._finalize_thread is None if is_async_checkpointer(self._checkpointer): + with self._is_saving_in_progress_lock: + self._is_saving_in_progress = True with self._finalize_thread_lock: finalize_thread_name = 'save_finalize' logging.info( @@ -1822,10 +1828,8 @@ def wait_until_finished(self): def is_saving_in_progress(self) -> bool: """Returns whether a checkpoint save is in progress.""" - with self._finalize_thread_lock: - return ( - self._finalize_thread is not None and self._finalize_thread.is_alive() - ) + with self._is_saving_in_progress_lock: + return self._is_saving_in_progress def check_for_errors(self): """See superclass documentation.""" @@ -1905,6 +1909,8 @@ def _finalize(self, step: int, steps_to_remove: List[int]): threading.current_thread().name, step, ) + with self._is_saving_in_progress_lock: + self._is_saving_in_progress = False def close(self): """See superclass documentation."""