Skip to content

Commit

Permalink
adding option to use safetensors for checkpoint manager
Browse files Browse the repository at this point in the history
  • Loading branch information
erfanzar committed Jun 20, 2024
1 parent 9b737af commit ea0ace3
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 5 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "fjformer"
authors = [{ name = "Erfan Zare Chavoshi", email = "[email protected]" }]
requires-python = ">=3.9"
readme = "README.md"
version = "0.0.68"
version = "0.0.69"

dependencies = [
"jax>=0.4.28",
Expand All @@ -20,6 +20,7 @@ dependencies = [
"einops==0.8.0",
"ml-collections==0.1.1",
"plum-dispatch==2.3.2",
"safetensors",
"termcolor"
]

Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ chex>=0.1.7
einops==0.8.0
ml-collections==0.1.1
plum-dispatch==2.3.2
termcolor
termcolor
safetensors
2 changes: 1 addition & 1 deletion src/fjformer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from fjformer import optimizers as optimizers
from fjformer import linen as linen

__version__ = "0.0.68"
__version__ = "0.0.69"

__all__ = (
"JaxRNG",
Expand Down
120 changes: 119 additions & 1 deletion src/fjformer/checkpoint/streamer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import jax
import flax
import safetensors.flax
import tqdm
from flax.serialization import (
from_bytes, to_bytes, to_state_dict, from_state_dict
Expand All @@ -10,7 +11,30 @@
from jax import numpy as jnp

from flax import struct
from typing import Callable, Literal, Union
from typing import Callable, Literal, Union, Optional, Tuple


def load_file(filename) -> Tuple[dict, dict]:
result = {}
with safetensors.safe_open(filename, framework="flax") as f:
metadata = f.metadata()
for k in f.keys():
result[k] = f.get_tensor(k)
return result, metadata


def is_flatten(pytree: Union[dict, struct.PyTreeNode]):
"""The is_flatten function checks if the pytree is flattened.
If it is, then the first key in the dictionary will be a tuple of (mpl, mpl_id).
Otherwise, it will be an integer representing mpl_id.
Args:
pytree: dict: Pass the pytree to the function
Returns:
True if the pytree is a flattened tree, and false otherwise
"""
return True if isinstance([k for k in pytree.keys()][0], tuple) else False


def get_dtype(tensor, dtype):
Expand Down Expand Up @@ -70,6 +94,100 @@ def save_checkpoint(
state, path, gather_fns, self.float_dtype, mismatch_allowed=mismatch_allowed
)

@staticmethod
def load_checkpoint_safe(
path: Union[str, os.PathLike],
target=None,
shard_fns: dict[Callable] = None,
verbose: bool = False,
mismatch_allowed: bool = True
):
shard_functions_mismatch = 0
state, metadata = load_file(path)
state = flax.traverse_util.unflatten_dict(state, sep=".")
state = flax.traverse_util.flatten_dict(state)
if shard_fns is not None:

pbar_sharding = tqdm.tqdm(list(state.keys()), desc="Sharding State", disable=not verbose)
if not is_flatten(shard_fns):
shard_fns = flatten_dict(shard_fns)
for key in list(state.keys()):
try:
callable_func = shard_fns[key]
if callable_func is None and not mismatch_allowed:
raise KeyError(f"Shard Function {key} is None and NoneType OBJ is not callable.")

if callable_func is None:
shard_functions_mismatch += 1
else:
state[key] = callable_func(state[key])
except KeyError as k_err:
if mismatch_allowed:
shard_functions_mismatch += 1
else:
raise KeyError(k_err)
pbar_sharding.set_postfix(sharding_mismatch=shard_functions_mismatch)
pbar_sharding.update(1)
if target is not None: # noqa
flattened_target = flatten_dict(
to_state_dict(target), keep_empty_nodes=True
)
for key, value in flattened_target.items():
if key not in state and value == empty_node:
state[key] = value

state = unflatten_dict(state)
if target is None:
return state, metadata

return from_state_dict(target, state), metadata

@staticmethod
def save_checkpoint_safe(
state: struct.PyTreeNode,
path: Union[str, os.PathLike],
gather_fns: dict[Callable] = None,
float_dtype=None,
verbose: bool = True,
mismatch_allowed: bool = True,
metadata: Optional[dict[str, str]] = None
):
state = to_state_dict(state)
gather_functions_mismatch = 0
if is_flatten(state):
state = unflatten_dict(state)
if gather_fns is not None:
if not is_flatten(gather_fns):
gather_fns = flatten_dict(gather_fns)
state = flatten_dict(state)
pbar_gather = tqdm.tqdm(list(state.keys()), desc="Gathering State", disable=not verbose)
for key in pbar_gather:
try:
callable_func = gather_fns[key]
if callable_func is None and not mismatch_allowed:
raise KeyError(f"Gather Function {key} is None and NoneType OBJ is not callable.")
if callable_func is None:
gather_functions_mismatch += 1
else:
state[key] = callable_func(state[key])
except KeyError as e:
if mismatch_allowed:
pbar_gather.set_postfix(gather_mismatch=gather_functions_mismatch)
else:
raise KeyError(e)
pbar_gather.update(1)
state = flax.traverse_util.flatten_dict(state, sep=".")
for key in list(state.keys()):
if not isinstance(state[key], jax.Array):
state[key] = jnp.array(state[key])
state[key] = get_dtype(state[key], float_dtype)

safetensors.flax.save_file(
tensors=state,
filename=path,
metadata=metadata
)

@staticmethod
def save_state_to_file(
state: struct.PyTreeNode,
Expand Down
2 changes: 1 addition & 1 deletion test/import_check_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

def main():
start_time = time.time()

from src import fjformer
from src.fjformer import checkpoint
from src.fjformer import functions
from src.fjformer import linen
Expand Down

0 comments on commit ea0ace3

Please sign in to comment.