Skip to content

Commit

Permalink
Update code to read in state dict layer by layer
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul-tuli committed Jun 13, 2024
1 parent 6b1226e commit e95e766
Showing 1 changed file with 43 additions and 11 deletions.
54 changes: 43 additions & 11 deletions src/compressed_tensors/utils/converters/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Callable, Dict, Iterable, Union
from typing import Callable, Dict, Iterable, Iterator, Tuple, Union

import torch
from compressed_tensors.registry.registry import RegistryMixin
Expand Down Expand Up @@ -77,22 +77,34 @@ def convert_from_safetensors(cls, filepath: str, save_dir: str = None) -> str:
save_dir_.mkdir(exist_ok=True, parents=True)

metadata = {"format": "pt", "source": "Created by SparseML"}

# transform and save the state_dict
if filepath_.is_dir():
for file in filepath_.glob("*.safetensors"):
_LOGGER.info(f"Loading file: {file}")
state_dict: StateDictType = load_safetensors_state_dict(file)
new_state_dict = cls.translate(state_dict=state_dict)
save_file(
new_state_dict, filename=save_dir_ / file.name, metadata=metadata
new_state_dict = {}
state_dict: Iterable[StateDictType] = load_safetensors_state_dict(
file, by_layers=True
)
for layer_state_dict in state_dict:
new_state_dict.update(cls.translate(state_dict=layer_state_dict))

if new_state_dict:
save_file(
new_state_dict,
filename=save_dir_ / file.name,
metadata=metadata,
)
_copy_non_safetensor_files_(filepath_, save_dir_)
_update_quantization_config(filepath_, save_dir_)

elif filepath_.is_file():
state_dict: StateDictType = load_safetensors_state_dict(filepath)
new_state_dict = cls.translate(state_dict=state_dict)
new_state_dict = {}
state_dict: Iterable[StateDictType] = load_safetensors_state_dict(
file, by_layers=True
)
for layer_state_dict in state_dict:
new_state_dict.update(cls.translate(state_dict=layer_state_dict))

save_file(
new_state_dict, save_path=save_dir_ / filepath_.name, metadata=metadata
)
Expand Down Expand Up @@ -177,12 +189,32 @@ def _update_quantization_config(source_dir: Path, dest_dir: Path):
config.save_pretrained(dest_dir)


def load_safetensors_state_dict(
    file_path: str, by_layers: bool = True
) -> Iterator[Dict[str, torch.Tensor]]:
    """
    Lazily load a safetensors file from disk.

    This function is a generator: the file is opened on first iteration and
    stays open until the generator is exhausted or closed. Every yielded item
    is a plain dictionary, so each one can be passed directly to code that
    expects a state dict.

    :param file_path: path to the safetensors file
    :param by_layers: if True, yield one dictionary per run of sorted keys that
        share the same prefix before the first "."; if False, yield a single
        dictionary holding every tensor in the file
    :return: iterator of dictionaries mapping each full tensor key to its
        tensor
    """
    with safe_open(file_path, framework="pt", device="cpu") as f:
        if not by_layers:
            yield {key: f.get_tensor(key) for key in f.keys()}
            return

        current_layer = None
        layer_data = {}
        # Sorting keeps keys with a common prefix adjacent so that a change of
        # prefix marks the end of a group.
        for key in sorted(f.keys()):
            # partition() never raises; a key without "." forms its own group
            # instead of failing on tuple unpacking.
            layer_name = key.partition(".")[0]
            if current_layer is not None and layer_name != current_layer:
                yield layer_data
                layer_data = {}
            current_layer = layer_name
            layer_data[key] = f.get_tensor(key)
        # Flush the final group; nothing is yielded for a file with no keys.
        if layer_data:
            yield layer_data

0 comments on commit e95e766

Please sign in to comment.