Skip to content

Commit

Permalink
Update TfGrainCheckpointHandler.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 686958575
  • Loading branch information
jan-matthis authored and copybara-github committed Oct 17, 2024
1 parent ce73f91 commit 64b309c
Showing 1 changed file with 1 addition and 9 deletions.
10 changes: 1 addition & 9 deletions connectomics/jax/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,7 @@ def restore_checkpoint(
args=ocp.args.Composite(**restore_args_dict))


class TfGrainCheckpointHandler(tfgrain.OrbaxCheckpointHandler):

def save(self, directory: epath.Path, args: 'TfGrainCheckpointArgs') -> None:
return super().save(directory, args.item)

def restore(
self, directory: epath.Path, args: 'TfGrainCheckpointArgs'
) -> tfgrain.TfGrainDatasetIterator:
return super().restore(directory, args.item)
TfGrainCheckpointHandler = tfgrain.OrbaxCheckpointHandler


@ocp.args.register_with_handler( # pytype:disable=wrong-arg-types
Expand Down

0 comments on commit 64b309c

Please sign in to comment.