Skip to content

Commit

Permalink
Update sparsity_config.py
Browse files Browse the repository at this point in the history
  • Loading branch information
dbogunowicz authored Jun 5, 2024
1 parent ef0232e commit 535b30d
Showing 1 changed file with 27 additions and 13 deletions.
40 changes: 27 additions & 13 deletions src/sparseml/transformers/compression/sparsity_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
from compressed_tensors import CompressionFormat, SparsityCompressionConfig
from compressed_tensors.quantization.utils import is_model_quantized
from sparseml.pytorch.utils import ModuleSparsificationInfo
from sparseml.transformers.compression.helpers import (
infer_sparsity_structure_from_model,
infer_sparsity_structure_from_stage_modifiers,
)


class SparsityConfigMetadata:
Expand All @@ -47,26 +51,34 @@ def infer_global_sparsity(
return global_sparsity

@staticmethod
def infer_sparsity_structure() -> str:
def infer_sparsity_structure(model: Optional[Module] = None) -> str:
"""
Determines what sparsity structure, if any, was applied in the currently active
sparse session
Determines what sparsity structure, if any, was applied.
First, there is an attempt to deduce the sparsity structure
from the currently active sparse session.
If that fails, the sparsity structure is inferred from the
model (if provided).
Finally, if both fail, the sparsity structure is set to
"unstructured"
:return: sparsity structure as a string
"""
sparsity_structure = None

current_session = sparseml.active_session()
stage_modifiers = current_session.lifecycle.modifiers
sparsity_structure = "unstructured"
if stage_modifiers:
sparsity_structure = infer_sparsity_structure_from_stage_modifiers(
stage_modifiers
)

# check for applied pruning modifiers
for stage in stage_modifiers:
if stage.applied:
for modifier in stage.modifiers:
if hasattr(modifier, "mask_structure"):
sparsity_structure = modifier.mask_structure
break
if model and sparsity_structure is None:
sparsity_structure = infer_sparsity_structure_from_model(model)

return sparsity_structure
return sparsity_structure or "unstructured"

@staticmethod
def from_pretrained(
Expand All @@ -91,7 +103,9 @@ def from_pretrained(
if global_sparsity < 0.05:
return None

sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure()
sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure(
model=model
)
if is_model_quantized(model):
# compressing a sparse quantized model is not supported yet
format = CompressionFormat.dense.value
Expand Down

0 comments on commit 535b30d

Please sign in to comment.