From 179fd90ce314952474ba8dcecc03b085aa8cab67 Mon Sep 17 00:00:00 2001
From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
Date: Wed, 3 Jul 2024 18:17:48 +0200
Subject: [PATCH] [Fix] Allow creating `SparseAutoModelForCausalLM` with
 `trust_remote_code=True` (#2349)

* initial commit

* better comments

---------

Co-authored-by: bogunowicz@arrival.com
---
 .../transformers/sparsification/sparse_model.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/src/sparseml/transformers/sparsification/sparse_model.py b/src/sparseml/transformers/sparsification/sparse_model.py
index c5e17764874..31feeeef88c 100644
--- a/src/sparseml/transformers/sparsification/sparse_model.py
+++ b/src/sparseml/transformers/sparsification/sparse_model.py
@@ -101,16 +101,29 @@ def skip(*args, **kwargs):
         )
 
         # instantiate compressor from model config
-        compressor = ModelCompressor.from_pretrained(pretrained_model_name_or_path)
+        compressor = ModelCompressor.from_pretrained(
+            pretrained_model_name_or_path, **kwargs
+        )
 
         # temporarily set the log level to error, to ignore printing out long missing
         # and unexpected key error messages (these are EXPECTED for quantized models)
         logger = logging.getLogger("transformers.modeling_utils")
         restore_log_level = logger.getEffectiveLevel()
         logger.setLevel(level=logging.ERROR)
+
+        if kwargs.get("trust_remote_code"):
+            # By artificially aliasing the class name
+            # SparseAutoModelForCausalLM to
+            # AutoModelForCausalLM we can "trick" the
+            # `from_pretrained` method into properly
+            # resolving the logic when
+            # (has_remote_code and trust_remote_code) == True
+            cls.__name__ = AutoModelForCausalLM.__name__
+
         model = super(AutoModelForCausalLM, cls).from_pretrained(
             pretrained_model_name_or_path, *model_args, **kwargs
         )
+
         if model.dtype != model.config.torch_dtype:
             _LOGGER.warning(
                 f"The dtype of the loaded model: {model.dtype} is different "
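
For reference, a minimal sketch of the call this patch unblocks. The model id
below is a hypothetical placeholder for a checkpoint whose repository ships
custom modeling code, and the import assumes `SparseAutoModelForCausalLM` is
exported from the `sparseml.transformers` namespace:

    # Minimal sketch; MODEL_ID is a hypothetical placeholder, not a real repo.
    from sparseml.transformers import SparseAutoModelForCausalLM

    # A checkpoint whose repo defines custom modeling code; loading it
    # requires executing that code, hence trust_remote_code=True.
    MODEL_ID = "some-org/model-with-remote-code"

    # Before this patch, the remote-code path inside from_pretrained did not
    # resolve for the SparseAutoModelForCausalLM subclass; with the
    # class-name alias applied in the diff above, it now does.
    model = SparseAutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
    )

The `cls.__name__` mutation is a deliberately narrow workaround: transformers
keys its remote-code resolution on the auto-class name (e.g. the
`AutoModelForCausalLM` entry in a config's `auto_map`), so presenting the
subclass under that name lets the inherited `from_pretrained` take the
remote-code branch without duplicating its dispatch logic in sparseml.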