From 97b8e15a8bd7bea65595b9d3f7bf36ed298ec551 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Thu, 13 Jun 2024 15:08:03 +0000 Subject: [PATCH] allow passing in kwargs --- src/compressed_tensors/utils/converters/converters.py | 8 ++++++-- src/compressed_tensors/utils/converters/main.py | 7 +++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/utils/converters/converters.py b/src/compressed_tensors/utils/converters/converters.py index 57898e88..ebe106a5 100644 --- a/src/compressed_tensors/utils/converters/converters.py +++ b/src/compressed_tensors/utils/converters/converters.py @@ -56,7 +56,9 @@ def translate(cls, state_dict: StateDictType, **kwargs) -> StateDictType: return new_state_dict @classmethod - def convert_from_safetensors(cls, filepath: str, save_dir: str = None) -> str: + def convert_from_safetensors( + cls, filepath: str, save_dir: str = None, **kwargs + ) -> str: """ Convert a .safetensors file or directory of .safetensors files, applying transformations to the state_dict and saving the new state_dict to a new @@ -86,7 +88,9 @@ def convert_from_safetensors(cls, filepath: str, save_dir: str = None) -> str: file, by_layers=True ) for layer_state_dict in state_dict: - new_state_dict.update(cls.translate(state_dict=layer_state_dict)) + new_state_dict.update( + cls.translate(state_dict=layer_state_dict, **kwargs) + ) if new_state_dict: save_file( diff --git a/src/compressed_tensors/utils/converters/main.py b/src/compressed_tensors/utils/converters/main.py index 122400a3..3089849c 100644 --- a/src/compressed_tensors/utils/converters/main.py +++ b/src/compressed_tensors/utils/converters/main.py @@ -19,19 +19,22 @@ __all__ = ["convert_autogptq_checkpoint"] -def convert_autogptq_checkpoint(old_checkpoint_path, new_checkpoint_path) -> str: +def convert_autogptq_checkpoint( + old_checkpoint_path, new_checkpoint_path, **kwargs +) -> str: """ Convert an autogptq checkpoint to a compressed tensor checkpoint :param old_checkpoint_path: the path to the autogptq checkpoint :param new_checkpoint_path: the path to save the converted compressed tensor checkpoint + :param kwargs: additional arguments to pass to the transformations :return: the path to the new checkpoint """ converter: BaseConverter = BaseConverter.load_from_registry( ConverterNames.EXLLAMA_TO_COMPRESSED_TENSOR ) checkpoint_path = converter.convert_from_safetensors( - old_checkpoint_path, new_checkpoint_path + old_checkpoint_path, new_checkpoint_path, **kwargs ) return checkpoint_path