Skip to content

Commit

Permalink
Allow skipping tmp checkpoint saving during eval.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 729482348
  • Loading branch information
The kauldron Authors committed Feb 21, 2025
1 parent 746dbc8 commit 173a5b9
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions kauldron/evals/eval_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,11 @@ def _preemptable_iter_new_checkpoints(
) -> Iterator[train_step.TrainState]:
"""Yields the new checkpoints."""
# Skip the `iter_new_checkpoints` for eval-only jobs.
if 'save_tmp_ckpt' in trainer.aux:
save_tmp_ckpt = trainer.aux['save_tmp_ckpt']
else:
save_tmp_ckpt = True

if trainer.setup.eval_only:
return

Expand Down Expand Up @@ -206,11 +211,13 @@ def _preemptable_iter_new_checkpoints(
assert int(state.step) == step
# Temporarily copy the state to the eval checkpoint, to ensure that
# it won't be deleted by the train job until the current eval is done.
eval_ckpt.save(state, step=step)
if save_tmp_ckpt:
eval_ckpt.save(state, step=step)
yield state
# state might have been donated, we should not access it after this point.
# Eval is done, remove the duplicated checkpoint
eval_ckpt.delete(step)
if save_tmp_ckpt:
eval_ckpt.delete(step)


def _restore_checkpoint(
Expand Down

0 comments on commit 173a5b9

Please sign in to comment.