diff --git a/src/sparseml/modifiers/obcq/base.py b/src/sparseml/modifiers/obcq/base.py
index c1534618302..74920d0d697 100644
--- a/src/sparseml/modifiers/obcq/base.py
+++ b/src/sparseml/modifiers/obcq/base.py
@@ -54,6 +54,9 @@ class SparseGPTModifier(Modifier):
     :param block_size: Used to determine number of columns to compress in one pass
     :param dampening_frac: Amount of dampening to apply to H, as a fraction of the
         diagonal norm
+    :param preserve_sparsity_mask: Whether or not to preserve the sparsity mask
+        when applying SparseGPT; this is useful when starting from a previously
+        pruned model. Defaults to False.
     """
 
     sparsity: Union[float, List[float]] = 0.0
@@ -68,6 +71,7 @@ class SparseGPTModifier(Modifier):
     prunem_: Optional[int] = None
     block_size: int = 128
     dampening_frac: Optional[float] = 0.01
+    preserve_sparsity_mask: bool = False
 
     def on_initialize_structure(self, state: State, **kwargs):
         """
diff --git a/src/sparseml/modifiers/obcq/pytorch.py b/src/sparseml/modifiers/obcq/pytorch.py
index 4825eed1a92..ec9dfd90d23 100644
--- a/src/sparseml/modifiers/obcq/pytorch.py
+++ b/src/sparseml/modifiers/obcq/pytorch.py
@@ -203,6 +203,7 @@ def _compression_arguments(self, sparsity):
             "prunem": self.prunem_,
             "blocksize": self.block_size,
             "percdamp": self.dampening_frac,
+            "preserve_sparsity_mask": self.preserve_sparsity_mask,
         }
 
     def _compression_class(self):
diff --git a/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py b/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py
index d8a95f18853..0079071bd0e 100644
--- a/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py
+++ b/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py
@@ -84,6 +84,7 @@ def fasterprune(
         prunem: int = 0,
         blocksize: int = 128,
         percdamp: float = 0.01,
+        preserve_sparsity_mask: bool = False,
     ):
         """
         Run pruning and quantization(if applicable) on the layer up to the target
@@ -95,6 +96,7 @@ def fasterprune(
         :param blocksize: Number of columns to compress in one pass
         :param percdamp: Amount of dampening to apply to H, as a fraction of the
             diagonal norm
+        :param preserve_sparsity_mask: Extend or ignore the base sparsity mask
         """
         final_shape = self.layer.weight.shape
         final_dtype = self.layer.weight.dtype
@@ -123,6 +125,13 @@ def fasterprune(
         Hinv = self.H
 
         mask = None
+        if preserve_sparsity_mask:
+            # compute existing sparsity mask
+            mask = torch.where(
+                W == 0,
+                torch.tensor(1, dtype=torch.bool),
+                torch.tensor(0, dtype=torch.bool),
+            )
 
         # See section 3.4 of https://arxiv.org/abs/2203.07259
         for i1 in range(0, self.columns, blocksize):
@@ -138,12 +147,32 @@ def fasterprune(
             if prunen == 0:
                 if mask is not None:
                     mask1 = mask[:, i1:i2]
+                    if int(W1.numel() * sparsity) > mask1.sum():
+                        # target sparsity is higher than base sparsity, extend mask1
+                        tmp = (
+                            (~mask[:, i1:i2])
+                            * W1**2
+                            / (torch.diag(Hinv1).reshape((1, -1))) ** 2
+                        )
+                        thresh = torch.sort(tmp.flatten())[0][
+                            int(tmp.numel() * sparsity)
+                        ]
+                        mask1 = tmp <= thresh
+                    else:
+                        raise ValueError(
+                            "The target sparsity is lower than the sparsity "
+                            "of the base model. Please retry "
+                            "after setting preserve_sparsity_mask=False"
+                        )
                 else:
                     tmp = W1**2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
                     thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
                     mask1 = tmp <= thresh
             else:
-                mask1 = torch.zeros_like(W1) == 1
+                if mask is not None:
+                    mask1 = mask[:, i1:i2]
+                else:
+                    mask1 = torch.zeros_like(W1) == 1
 
             for i in range(count):
                 w = W1[:, i]
@@ -154,6 +183,10 @@ def fasterprune(
                         W1[:, i : (i + prunem)] ** 2
                         / (torch.diag(Hinv1)[i : (i + prunem)].reshape((1, -1))) ** 2
                     )
+
+                    if mask is not None:
+                        tmp = tmp * (~mask[:, i : (i + prunem)])
+
                     mask1.scatter_(
                         1, i + torch.topk(tmp, prunen, dim=1, largest=False)[1], True
                     )
@@ -174,7 +207,12 @@ def fasterprune(
             W[:, i1:i2] = Q1
             Losses += torch.sum(Losses1, 1) / 2
 
-            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
+            if preserve_sparsity_mask:
+                # respect the sparsity of other groups
+                # not strictly needed, but kept for explicitness
+                W[:, i2:] -= (~mask[:, i2:]) * Err1.matmul(Hinv[i1:i2, i2:])
+            else:
+                W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
 
         _LOGGER.info("time %.2f" % (time.time() - tick))
         _LOGGER.info("error %.2f" % torch.sum(Losses).item())
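
For reference, a minimal standalone sketch of the mask-extension idea used in the unstructured (prunen == 0) branch above: zeros in the weight matrix are taken as the base mask, the remaining weights are ranked by the same saliency metric W**2 / diag(Hinv)**2, and enough of them are added to the mask to reach the target sparsity. The helper name extend_sparsity_mask and the inputs W and Hinv_diag are illustrative assumptions, not part of the PR; only torch is required.

import torch


def extend_sparsity_mask(W: torch.Tensor, Hinv_diag: torch.Tensor, sparsity: float):
    """Return a boolean mask (True = prune) that contains the base mask (W == 0)."""
    base_mask = W == 0
    target = int(W.numel() * sparsity)
    if target <= base_mask.sum():
        raise ValueError("Target sparsity is not higher than the base sparsity")
    # Saliency of the remaining weights; already-pruned entries score 0, so they
    # sort first and stay below the threshold (the base mask is preserved).
    tmp = (~base_mask) * W**2 / Hinv_diag.reshape((1, -1)) ** 2
    thresh = torch.sort(tmp.flatten())[0][target]
    return tmp <= thresh


# Toy usage: a 4x8 weight matrix that is roughly 25% sparse, extended to 50%.
W = torch.randn(4, 8)
W[torch.rand_like(W) < 0.25] = 0.0
Hinv_diag = torch.rand(8) + 0.1  # stand-in for diag of the inverse Hessian
mask = extend_sparsity_mask(W, Hinv_diag, sparsity=0.5)
W[mask] = 0.0

In the actual modifier this behavior is switched on via the new preserve_sparsity_mask field on SparseGPTModifier, which is threaded through _compression_arguments into fasterprune; as in the diff, requesting a target sparsity that is not higher than the model's existing sparsity raises the ValueError shown above.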