Network pruning is a popular approach to network compression. It removes the least important parameters from a network to obtain a compact architecture with minimal accuracy drop.
- Unstructured Pruning
  Unstructured pruning finds and removes the less salient connections in the model. The resulting nonzero pattern is irregular: the remaining weights can be anywhere in the matrix.
- Structured Pruning
  Structured pruning selects parameters in groups and deletes entire blocks, filters, or channels according to a pruning criterion.
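
The difference between the two types can be sketched on a toy weight matrix. This is an illustrative, pure-Python sketch and not Neural Compressor code; the helper names `unstructured_prune` and `structured_prune_rows`, the threshold, and the L1-norm row criterion are assumptions chosen for the example.

```python
# Toy 4x4 weight matrix as nested lists (one row per output channel).
weights = [
    [0.9, -0.15, 0.4, 0.05],
    [0.02, 0.03, -0.01, 0.04],
    [-0.7, 0.6, 0.2, -0.3],
    [0.2, -0.8, 0.05, 0.5],
]

def unstructured_prune(w, threshold):
    """Zero individual entries whose magnitude is below `threshold`.
    The shape is unchanged; zeros may appear anywhere in the matrix."""
    return [[x if abs(x) >= threshold else 0.0 for x in row] for row in w]

def structured_prune_rows(w, keep):
    """Remove whole rows, keeping the `keep` rows with the largest L1 norm.
    The result is a smaller dense matrix."""
    norms = [sum(abs(x) for x in row) for row in w]
    ranked = sorted(range(len(w)), key=lambda i: norms[i], reverse=True)
    kept = sorted(ranked[:keep])  # preserve the original row order
    return [w[i] for i in kept]

sparse = unstructured_prune(weights, threshold=0.1)   # still 4x4, scattered zeros
compact = structured_prune_rows(weights, keep=3)      # 3x4, low-norm row removed
```

The unstructured result keeps its shape and needs sparse kernels to run faster, while the structured result is a smaller dense matrix that ordinary dense kernels can use directly.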

| Pruning Type | Pruning Granularity | Pruning Algorithm | Framework |
|---|---|---|---|
| Unstructured Pruning | Element-wise | Magnitude | PyTorch, TensorFlow |
| | | Pattern Lock | PyTorch |
| Structured Pruning | Filter/Channel-wise | Gradient Sensitivity | PyTorch |
| | Block-wise | Group Lasso | PyTorch |
| | Element-wise | Pattern Lock | PyTorch |

- Magnitude
  - The algorithm prunes the weights with the lowest absolute value in each layer until the given sparsity target is reached.
- Gradient sensitivity
  - The algorithm prunes heads, intermediate layers, and hidden states in NLP models according to an importance score calculated following the FastFormers paper.
- Group Lasso
  - The algorithm uses group lasso regularization to prune entire rows, columns, or blocks of parameters, which results in a smaller dense network.
- Pattern Lock
  - The algorithm locks the sparsity pattern during the fine-tuning phase by keeping the zero values of the weight tensor frozen during weight updates.
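
As a rough illustration of the first and last entries above, the sketch below builds a magnitude mask for one flattened layer and then re-applies that mask after a plain SGD update so the zeros stay locked. It is a minimal pure-Python sketch under stated assumptions, not the library's implementation; `magnitude_mask` and `sgd_step_with_locked_pattern` are hypothetical names.

```python
def magnitude_mask(weights, sparsity):
    """Return a 0/1 mask that removes the `sparsity` fraction of entries
    with the smallest absolute value (one flat layer, ties broken by index)."""
    n_prune = int(len(weights) * sparsity)
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = set(order[:n_prune])
    return [0 if i in pruned else 1 for i in range(len(weights))]

def sgd_step_with_locked_pattern(weights, grads, mask, lr):
    """Plain SGD update, then re-apply the mask so previously pruned
    entries remain zero (the pattern-lock idea)."""
    updated = [w - lr * g for w, g in zip(weights, grads)]
    return [u * m for u, m in zip(updated, mask)]

layer = [0.5, -0.05, 0.3, 0.01, -0.8, 0.1]
mask = magnitude_mask(layer, sparsity=0.5)          # prune the 3 smallest
pruned = [w * m for w, m in zip(layer, mask)]

grads = [0.1, 0.2, -0.1, 0.3, 0.0, -0.2]            # made-up gradients
after = sgd_step_with_locked_pattern(pruned, grads, mask, lr=0.1)
```

Even though the pruned positions receive nonzero gradients here, they are still zero after the update, which is the behavior pattern lock is meant to guarantee during fine-tuning.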
The Neural Compressor pruning API is defined under `neural_compressor.experimental.Pruning`, which takes a user-defined yaml file as input. The yaml file defines the training, pruning, and evaluation behaviors. See the API Readme for details.
Below is the launcher code when the training behavior is defined in the user-defined yaml.

```python
from neural_compressor.experimental import Pruning

prune = Pruning('/path/to/user/pruning/yaml')
prune.model = model
model = prune.fit()
```

In the user-defined yaml, the `train` section is optional if the user implements a `pruning_func` and assigns it to the `pruning_func` attribute of the pruning instance.
Users can refer to the yaml template file for the meaning of each field.
The `train` section defines the training behavior, including which training hyper-parameters and which dataloader are used during training.
The `approach` section defines which pruning algorithm is used and how it is applied during the training process.
- `weight compression`: the pruning target; currently only `weight compression` is supported. `weight compression` means zeroing entries of the weight matrix. Its parameters are divided into global parameters and local parameters set in the individual `pruners`. Global parameters may contain `start_epoch`, `end_epoch`, `initial_sparsity`, `target_sparsity`, and `frequency`.
  - `start_epoch`: the epoch at which pruning begins.
  - `end_epoch`: the epoch at which pruning ends.
  - `initial_sparsity`: the initial sparsity goal, default 0.
  - `target_sparsity`: the target sparsity goal.
  - `frequency`: how often the sparsity is updated.
- `Pruner`:
  - `prune_type`: the pruning algorithm; currently `basic_magnitude`, `gradient_sensitivity`, and `group_lasso` are supported.
  - `names`: the names of the weights to be pruned. If no weight is specified, all weights of the model are pruned.
  - `parameters`: additional parameters required by the `gradient_sensitivity` prune_type. These parameters determine how a weight is pruned, including the pruning target and how the weight's importance is calculated. The field contains:
    - `target`: the pruning target for the weight; overrides the global `target_sparsity` if set.
    - `stride`: the stride of the pruned weight.
    - `transpose`: whether to transpose the weight before pruning.
    - `normalize`: whether to normalize the calculated importance.
    - `index`: the index of the calculated importance.
    - `importance_inputs`: the inputs of the importance calculation for the weight.
    - `importance_metric`: the metric used in the importance calculation; currently `abs_gradient` and `weighted_gradient` are supported.

    As an example, assume `bert.encoder.layer.0.attention.output.dense.weight` has shape [N, 12*64]. A `target` of 8 with a `stride` of 64 then prunes the weight to shape [N, 8*64]. Setting `transpose` to True indicates that the weight is pruned along dim 1 and should be transposed to [12*64, N] before pruning. `importance_inputs` and `importance_metric` specify the actual inputs and the metric used to calculate the importance matrix.
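
The shape arithmetic in the `target`/`stride` example can be checked with a small pure-Python sketch. This is not the library's implementation: `prune_column_groups` is a hypothetical helper, N is set to 4 for illustration, the weights and gradients are random, and the importance score (sum of absolute gradients per group) is an assumed stand-in for the `abs_gradient` metric.

```python
import random

def prune_column_groups(weight, importance, target, stride):
    """Keep the `target` groups of `stride` consecutive columns with the
    highest importance score and drop the remaining groups entirely."""
    n_groups = len(weight[0]) // stride
    assert len(importance) == n_groups
    ranked = sorted(range(n_groups), key=lambda g: importance[g], reverse=True)
    kept_groups = sorted(ranked[:target])  # preserve original group order
    kept_cols = [g * stride + k for g in kept_groups for k in range(stride)]
    return [[row[c] for c in kept_cols] for row in weight], kept_groups

random.seed(0)
N, heads, stride, target = 4, 12, 64, 8
weight = [[random.uniform(-1, 1) for _ in range(heads * stride)] for _ in range(N)]
grad = [[random.uniform(-1, 1) for _ in range(heads * stride)] for _ in range(N)]

# Assumed importance: sum of |gradient| over all entries in each group.
importance = [
    sum(abs(grad[r][g * stride + k]) for r in range(N) for k in range(stride))
    for g in range(heads)
]

pruned, kept = prune_column_groups(weight, importance, target, stride)
print(len(pruned), len(pruned[0]))  # 4 512, i.e. [N, 8*64]
```

Each group of 64 columns corresponds to one attention head, so keeping 8 of 12 groups shrinks the second dimension from 768 to 512 while leaving N untouched.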
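
The interplay of the global fields `start_epoch`, `end_epoch`, `initial_sparsity`, `target_sparsity`, and `frequency` can be sketched as a sparsity schedule. The source does not state how sparsity is interpolated between the two goals, so the linear ramp below, the epoch-based meaning of `frequency`, and the function name `scheduled_sparsity` are all assumptions made for illustration.

```python
def scheduled_sparsity(epoch, start_epoch, end_epoch,
                       initial_sparsity, target_sparsity, frequency=1):
    """Sparsity to apply at `epoch`, ASSUMING a linear ramp from
    `initial_sparsity` to `target_sparsity` that is refreshed only
    every `frequency` epochs between `start_epoch` and `end_epoch`."""
    if epoch < start_epoch:
        return initial_sparsity
    if epoch >= end_epoch:
        return target_sparsity
    # Snap to the most recent epoch at which the sparsity was updated.
    steps_done = (epoch - start_epoch) // frequency
    last_update = start_epoch + steps_done * frequency
    progress = (last_update - start_epoch) / (end_epoch - start_epoch)
    return initial_sparsity + (target_sparsity - initial_sparsity) * progress

# Ramp from 0.0 to 0.8 over epochs 0..4, updating every 2 epochs.
schedule = [
    round(scheduled_sparsity(e, 0, 4, 0.0, 0.8, frequency=2), 2)
    for e in range(6)
]
print(schedule)  # [0.0, 0.0, 0.4, 0.4, 0.8, 0.8]
```

A gradual ramp of this kind lets the model adapt between pruning steps instead of losing a large fraction of its weights at once.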
When the training behavior is instead supplied through a user-defined `pruning_func`, the launcher code looks like the following:

```python
from neural_compressor.experimental import Pruning, common

prune = Pruning(args.config)
prune.model = model
prune.pruning_func = pruning_func
model = prune.fit()
```

Users can pass customized training and evaluation functions to `Pruning` for more flexible scenarios. In this case, the pruning process is driven by pre-defined hooks in Neural Compressor, which the user must call inside the training function.
Neural Compressor defines several hooks for users:

- `on_epoch_begin(epoch)`: executed at the beginning of each epoch
- `on_batch_begin(batch)`: executed at the beginning of each batch
- `on_batch_end()`: executed at the end of each batch
- `on_epoch_end()`: executed at the end of each epoch
- `on_post_grad()`: executed after the gradients are calculated and before the optimizer step

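
The order in which these hooks would be called in a generic training loop can be sketched with a stand-in object. `RecordingPruner` and `train` are hypothetical names used only for illustration; the stand-in merely records calls and performs no pruning, and the placement of `on_post_grad()` between gradient computation and the weight update is an assumption based on the hook's description.

```python
class RecordingPruner:
    """Hypothetical stand-in that records hook calls instead of pruning."""

    def __init__(self):
        self.calls = []

    def on_epoch_begin(self, epoch):
        self.calls.append(f"on_epoch_begin({epoch})")

    def on_batch_begin(self, batch):
        self.calls.append(f"on_batch_begin({batch})")

    def on_post_grad(self):
        self.calls.append("on_post_grad()")

    def on_batch_end(self):
        self.calls.append("on_batch_end()")

    def on_epoch_end(self):
        self.calls.append("on_epoch_end()")


def train(pruner, num_epochs, num_batches):
    """Skeleton training loop showing where each hook is placed."""
    for epoch in range(num_epochs):
        pruner.on_epoch_begin(epoch)
        for batch in range(num_batches):
            pruner.on_batch_begin(batch)
            # ... forward pass and gradient computation would go here ...
            pruner.on_post_grad()
            # ... optimizer update would go here ...
            pruner.on_batch_end()
        pruner.on_epoch_end()


pruner = RecordingPruner()
train(pruner, num_epochs=1, num_batches=2)
print(pruner.calls)
```

In a real training function the same call sites apply, with the `Pruning` instance taking the place of the stand-in.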
The following excerpt from a BERT training example shows how to use the hooks in a user-supplied training function:

```python
# Excerpt: args, train_dataloader, optimizer, scheduler, ProgressBar and prune
# are defined by the surrounding training script.
def pruning_func(model):
    for epoch in range(int(args.num_train_epochs)):
        pbar = ProgressBar(n_total=len(train_dataloader), desc='Training')
        model.train()
        prune.on_epoch_begin(epoch)
        for step, batch in enumerate(train_dataloader):
            prune.on_batch_begin(step)
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {'input_ids': batch[0],
                      'attention_mask': batch[1],
                      'labels': batch[3]}
            #inputs['token_type_ids'] = batch[2]
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()

            prune.on_batch_end()
...
```

For related examples, please refer to Pruning examples.