Skip to content

Commit

Permalink
allow passing in kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul-tuli committed Jun 13, 2024
1 parent e95e766 commit 97b8e15
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
8 changes: 6 additions & 2 deletions src/compressed_tensors/utils/converters/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def translate(cls, state_dict: StateDictType, **kwargs) -> StateDictType:
return new_state_dict

@classmethod
def convert_from_safetensors(cls, filepath: str, save_dir: str = None) -> str:
def convert_from_safetensors(
cls, filepath: str, save_dir: str = None, **kwargs
) -> str:
"""
Convert a .safetensors file or directory of .safetensors files, applying
transformations to the state_dict and saving the new state_dict to a new
Expand Down Expand Up @@ -86,7 +88,9 @@ def convert_from_safetensors(cls, filepath: str, save_dir: str = None) -> str:
file, by_layers=True
)
for layer_state_dict in state_dict:
new_state_dict.update(cls.translate(state_dict=layer_state_dict))
new_state_dict.update(
cls.translate(state_dict=layer_state_dict, **kwargs)
)

if new_state_dict:
save_file(
Expand Down
7 changes: 5 additions & 2 deletions src/compressed_tensors/utils/converters/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,22 @@
__all__ = ["convert_autogptq_checkpoint"]


def convert_autogptq_checkpoint(old_checkpoint_path, new_checkpoint_path) -> str:
def convert_autogptq_checkpoint(
old_checkpoint_path, new_checkpoint_path, **kwargs
) -> str:
"""
Convert an autogptq checkpoint to a compressed tensor checkpoint
:param old_checkpoint_path: the path to the autogptq checkpoint
:param new_checkpoint_path: the path to save the converted compressed
tensor checkpoint
:param kwargs: additional arguments to pass to the transformations
:return: the path to the new checkpoint
"""
converter: BaseConverter = BaseConverter.load_from_registry(
ConverterNames.EXLLAMA_TO_COMPRESSED_TENSOR
)
checkpoint_path = converter.convert_from_safetensors(
old_checkpoint_path, new_checkpoint_path
old_checkpoint_path, new_checkpoint_path, **kwargs
)
return checkpoint_path

0 comments on commit 97b8e15

Please sign in to comment.