Preserve sparsity SPARSEGPT (#2282)
* test

* Preserve weight sparsity if greater than threshold

* Add argument to preserve sparsity mask in SPARSEGPT

* fix case when mask is none

---------

Co-authored-by: Sara Adkins <[email protected]>
rahul-tuli and Sara Adkins authored May 17, 2024
1 parent 14a1b08 commit 446555f
Showing 3 changed files with 45 additions and 2 deletions.
4 changes: 4 additions & 0 deletions src/sparseml/modifiers/obcq/base.py
@@ -54,6 +54,9 @@ class SparseGPTModifier(Modifier):
:param block_size: Used to determine number of columns to compress in one pass
:param dampening_frac: Amount of dampening to apply to H, as a fraction of the
diagonal norm
:param preserve_sparsity_mask: Whether or not to preserve the existing sparsity
    mask when applying SparseGPT; useful when starting from a previously
    pruned model. Defaults to False.
"""

sparsity: Union[float, List[float]] = 0.0
@@ -68,6 +71,7 @@ class SparseGPTModifier(Modifier):
prunem_: Optional[int] = None
block_size: int = 128
dampening_frac: Optional[float] = 0.01
preserve_sparsity_mask: bool = False

def on_initialize_structure(self, state: State, **kwargs):
"""
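The flag added above sits alongside the modifier's existing knobs, so enabling it is a one-line change wherever SparseGPTModifier is configured. A minimal sketch, assuming the modifier can be instantiated directly with its declared fields; the 0.7 target sparsity is an illustrative value, and any other required fields (such as layer targets) are omitted:

# Illustrative construction only, not taken from the diff.
from sparseml.modifiers.obcq.base import SparseGPTModifier

modifier = SparseGPTModifier(
    sparsity=0.7,                  # example target sparsity
    block_size=128,                # default shown in the diff
    dampening_frac=0.01,           # default shown in the diff
    preserve_sparsity_mask=True,   # keep (and extend) the zeros already in the model
)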
1 change: 1 addition & 0 deletions src/sparseml/modifiers/obcq/pytorch.py
@@ -203,6 +203,7 @@ def _compression_arguments(self, sparsity):
"prunem": self.prunem_,
"blocksize": self.block_size,
"percdamp": self.dampening_frac,
"preserve_sparsity_mask": self.preserve_sparsity_mask,
}

def _compression_class(self):
42 changes: 40 additions & 2 deletions src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py
@@ -84,6 +84,7 @@ def fasterprune(
prunem: int = 0,
blocksize: int = 128,
percdamp: float = 0.01,
preserve_sparsity_mask: bool = False,
):
"""
Run pruning and quantization (if applicable) on the layer up to the target
@@ -95,6 +96,7 @@ def fasterprune(
:param blocksize: Number of columns to compress in one pass
:param percdamp: Amount of dampening to apply to H, as a fraction of the
diagonal norm
:param preserve_sparsity_mask: Whether to preserve and, if needed, extend the
    existing sparsity mask of the layer rather than ignore it
"""
final_shape = self.layer.weight.shape
final_dtype = self.layer.weight.dtype
@@ -123,6 +125,13 @@
Hinv = self.H

mask = None
if preserve_sparsity_mask:
# compute existing sparsity mask
mask = torch.where(
W == 0,
torch.tensor(1, dtype=torch.bool),
torch.tensor(0, dtype=torch.bool),
)

# See section 3.4 of https://arxiv.org/abs/2203.07259
for i1 in range(0, self.columns, blocksize):
@@ -138,12 +147,32 @@
if prunen == 0:
if mask is not None:
mask1 = mask[:, i1:i2]
if int(W1.numel() * sparsity) > mask1.sum():
# target sparsity is higher than base sparsity, extend mask1
tmp = (
(~mask[:, i1:i2])
* W1**2
/ (torch.diag(Hinv1).reshape((1, -1))) ** 2
)
thresh = torch.sort(tmp.flatten())[0][
int(tmp.numel() * sparsity)
]
mask1 = tmp <= thresh
else:
raise ValueError(
"The target sparsity is lower than the sparsity "
"of the base model. Please retry "
"after turning preserve_sparsity_mask=False"
)
else:
tmp = W1**2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
mask1 = tmp <= thresh
else:
mask1 = torch.zeros_like(W1) == 1
if mask is not None:
mask1 = mask[:, i1:i2]
else:
mask1 = torch.zeros_like(W1) == 1

for i in range(count):
w = W1[:, i]
@@ -154,6 +183,10 @@
W1[:, i : (i + prunem)] ** 2
/ (torch.diag(Hinv1)[i : (i + prunem)].reshape((1, -1))) ** 2
)

if mask is not None:
tmp = tmp * (~mask[:, i : (i + prunem)])

mask1.scatter_(
1, i + torch.topk(tmp, prunen, dim=1, largest=False)[1], True
)
@@ -174,7 +207,12 @@
W[:, i1:i2] = Q1
Losses += torch.sum(Losses1, 1) / 2

W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
if preserve_sparsity_mask:
# respect the sparsity of other groups
# really not needed, but kept for explicitness
W[:, i2:] -= (~mask[:, i2:]) * Err1.matmul(Hinv[i1:i2, i2:])
else:
W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

_LOGGER.info("time %.2f" % (time.time() - tick))
_LOGGER.info("error %.2f" % torch.sum(Losses).item())
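To make the new branch in fasterprune easier to follow, here is a small self-contained sketch of the core idea: the zeros already present in the weight matrix form the base mask, and when the target sparsity exceeds the base sparsity the mask is extended by also selecting the unmasked weights with the lowest saliency W**2 / diag(Hinv)**2, mirroring the thresholding in the diff. The helper name and shapes are illustrative, not part of the library's API; the diff additionally uses the same base mask to restrict the N:M top-k selection and the blockwise error propagation, which this sketch omits.

# Standalone sketch of the preserve-and-extend masking logic (illustrative only).
import torch


def extend_sparsity_mask(W: torch.Tensor, hinv_diag: torch.Tensor, sparsity: float) -> torch.Tensor:
    """Return a boolean mask (True = prune) that keeps the existing zeros of W
    and, if needed, adds the lowest-saliency weights until `sparsity` is reached."""
    base_mask = W == 0  # zeros inherited from the previously pruned model
    target = int(W.numel() * sparsity)
    if target <= int(base_mask.sum()):
        # matches the diff's behavior: refuse to prune below the base sparsity
        raise ValueError(
            "Target sparsity is not higher than the base sparsity; "
            "retry with preserve_sparsity_mask=False"
        )
    # Saliency of the remaining weights; already-masked entries score 0, so they
    # fall below the threshold and stay pruned.
    saliency = (~base_mask) * W**2 / hinv_diag.reshape(1, -1) ** 2
    thresh = torch.sort(saliency.flatten())[0][target]
    return saliency <= thresh


# Example: a 4x4 weight matrix at 25% sparsity, extended to a 50% target.
W = torch.tensor([[0.0, 2.0, 3.0, 4.0],
                  [1.0, 0.0, 2.0, 5.0],
                  [3.0, 1.0, 0.0, 2.0],
                  [4.0, 2.0, 1.0, 0.0]])
mask = extend_sparsity_mask(W, hinv_diag=torch.ones(4), sparsity=0.5)
print(int(mask.sum()))  # number of entries selected for pruning (at least the 4 original zeros)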
