From 535b30dca4818319afce40fc24cbcd5acb2908e2 Mon Sep 17 00:00:00 2001 From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com> Date: Wed, 5 Jun 2024 15:08:41 +0200 Subject: [PATCH] Update sparsity_config.py --- .../compression/sparsity_config.py | 40 +++++++++++++------ 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/src/sparseml/transformers/compression/sparsity_config.py b/src/sparseml/transformers/compression/sparsity_config.py index b5f69cb83e1..f8bc477366c 100644 --- a/src/sparseml/transformers/compression/sparsity_config.py +++ b/src/sparseml/transformers/compression/sparsity_config.py @@ -21,6 +21,10 @@ from compressed_tensors import CompressionFormat, SparsityCompressionConfig from compressed_tensors.quantization.utils import is_model_quantized from sparseml.pytorch.utils import ModuleSparsificationInfo +from sparseml.transformers.compression.helpers import ( + infer_sparsity_structure_from_model, + infer_sparsity_structure_from_stage_modifiers, +) class SparsityConfigMetadata: @@ -47,26 +51,34 @@ def infer_global_sparsity( return global_sparsity @staticmethod - def infer_sparsity_structure() -> str: + def infer_sparsity_structure(model: Optional[Module] = None) -> str: """ - Determines what sparsity structure, if any, was applied in the currently active - sparse session + Determines what sparsity structure, if any, was applied. + + First, there is an attempt to deduce the sparsity structure + from the currently active sparse session. 
+ + If that fails, the sparsity structure is inferred from the + model (if provided) + + Finally, if both fail, the sparsity structure is set to + "unstructured" :return: sparsity structure as a string """ + sparsity_structure = None + current_session = sparseml.active_session() stage_modifiers = current_session.lifecycle.modifiers - sparsity_structure = "unstructured" + if stage_modifiers: + sparsity_structure = infer_sparsity_structure_from_stage_modifiers( + stage_modifiers + ) - # check for applied pruning modifiers - for stage in stage_modifiers: - if stage.applied: - for modifier in stage.modifiers: - if hasattr(modifier, "mask_structure"): - sparsity_structure = modifier.mask_structure - break + if model and sparsity_structure is None: + sparsity_structure = infer_sparsity_structure_from_model(model) - return sparsity_structure + return sparsity_structure or "unstructured" @staticmethod def from_pretrained( @@ -91,7 +103,9 @@ def from_pretrained( if global_sparsity < 0.05: return None - sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure() + sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure( + model=model + ) if is_model_quantized(model): # compressing a sparse quantized model is not supported yet format = CompressionFormat.dense.value