From bcbe111ae5ee5434b59485f6f84559a60f4936de Mon Sep 17 00:00:00 2001 From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com> Date: Thu, 30 May 2024 16:55:54 +0200 Subject: [PATCH] Update sparse_model.py --- src/sparseml/transformers/sparsification/sparse_model.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/sparseml/transformers/sparsification/sparse_model.py b/src/sparseml/transformers/sparsification/sparse_model.py index 3132411d332..c5e17764874 100644 --- a/src/sparseml/transformers/sparsification/sparse_model.py +++ b/src/sparseml/transformers/sparsification/sparse_model.py @@ -111,6 +111,14 @@ def skip(*args, **kwargs): model = super(AutoModelForCausalLM, cls).from_pretrained( pretrained_model_name_or_path, *model_args, **kwargs ) + if model.dtype != model.config.torch_dtype: + _LOGGER.warning( + f"The dtype of the loaded model: {model.dtype} is different " + "from the dtype specified in the model config: " + f"{model.config.torch_dtype}. " + "To load the model in the format that it was previously saved in, " + "set torch_dtype=`auto` in the SparseAutoModel creation call." + ) logger.setLevel(level=restore_log_level) # override the PreTrainedModel instance with compression save function modify_save_pretrained(model)