Merge branch 'main' into add-nightly-weekly-gha
dsikka authored May 23, 2024
2 parents ac860a5 + c24e97f commit cfe0620
Showing 43 changed files with 1,836 additions and 235 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -1,7 +1,7 @@
.PHONY: build docs test

BUILDDIR := $(PWD)
-CHECKDIRS := integrations src tests utils status setup.py
+CHECKDIRS := integrations src tests utils status examples setup.py
CHECKGLOBS := 'integrations/**/*.py' 'src/**/*.py' 'tests/**/*.py' 'utils/**/*.py' 'status/**/*.py' setup.py
DOCDIR := docs
MDCHECKGLOBS := 'docs/**/*.md' 'docs/**/*.rst' 'integrations/**/*.md'
38 changes: 38 additions & 0 deletions examples/llama7b_sparse_quantized/2:4_w4a16_recipe.yaml
@@ -0,0 +1,38 @@
sparsity_stage:
  run_type: oneshot
  sparsity_modifiers:
    SparseGPTModifier:
      sparsity: 0.5
      mask_structure: "2:4"
      sequential_update: false
finetuning_stage:
  run_type: train
  finetuning_modifiers:
    ConstantPruningModifier:
      targets: [
        're:.*q_proj.weight',
        're:.*k_proj.weight',
        're:.*v_proj.weight',
        're:.*o_proj.weight',
        're:.*gate_proj.weight',
        're:.*up_proj.weight',
        're:.*down_proj.weight',
      ]
      start: 0
quantization_stage:
  run_type: oneshot
  quantization_modifiers:
    vLLMQuantizationModifier:
      ignore: ["lm_head"]
      config_groups:
        group_0:
          weights:
            num_bits: 4
            type: "int"
            symmetric: true
            strategy: "channel"
          targets: ["Linear"]
    SparseGPTModifier:
      sparsity: 0.0
      quantize: True
      sequential_update: false
47 changes: 47 additions & 0 deletions examples/llama7b_sparse_quantized/README.md
@@ -0,0 +1,47 @@
# Creating a Sparse Quantized Llama7b Model

The example in this folder runs in multiple stages to create a Llama 7b model with
a 2:4 sparsity pattern and W4A16 post-training quantization (PTQ). The model is
calibrated and trained with the ultrachat200k dataset. At least 75GB of GPU memory is
required to run this example.

## Recipe Summary

The recipe used for this flow is located in [2:4_w4a16_recipe.yaml](./2:4_w4a16_recipe.yaml). It contains 3 stages that are outlined below.


### Stage 1: Sparsification

Runs the SparseGPT one-shot algorithm to prune the model to 50% sparsity with a 2:4
sparsity pattern. This means that 2 weights out of every group of 4 weights are masked to 0.
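
For intuition, the minimal PyTorch sketch below (an illustration only, not part of the recipe) shows what a 2:4 mask does when applied directly to a weight tensor; SparseGPT additionally chooses which two weights to keep in each group so that the layer's output changes as little as possible.

```python
import torch

def apply_2_4_mask(weight: torch.Tensor) -> torch.Tensor:
    # group the weights in fours and keep the two largest-magnitude entries per group
    groups = weight.reshape(-1, 4)
    keep = groups.abs().topk(2, dim=1).indices
    mask = torch.zeros_like(groups).scatter(1, keep, 1.0)
    return (groups * mask).reshape(weight.shape)  # exactly 50% of the weights are zero
```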

### Stage 2: Finetuning Recovery

This stage runs a single epoch of training on the ultrachat200k dataset while maintaining
the sparsity mask from stage 1. The purpose of this stage is to recover any accuracy lost
during the sparsification process.
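
Conceptually, maintaining the mask means re-applying it after every optimizer step so that pruned weights stay at zero. The hypothetical helper below sketches that idea; in this example the `ConstantPruningModifier` in the recipe handles it automatically for the targeted projection layers.

```python
import torch

def masked_train_step(model, batch, optimizer, masks):
    # masks maps parameter names (e.g. the q/k/v/o and MLP projections) to 0/1 tensors
    loss = model(**batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in masks:
                param.mul_(masks[name])  # weights pruned in stage 1 remain zero
```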

### Stage 3: Quantization

Finally, we run the GPTQ one-shot algorithm to quantize all Linear weights to 4 bits
with channelwise scales.
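
As a rough illustration of what channelwise 4-bit quantization means (GPTQ itself also updates the remaining weights to compensate for rounding error), each output channel of a Linear weight gets its own scale mapping it into the signed int4 range:

```python
import torch

def quantize_w4_channelwise(weight: torch.Tensor):
    # symmetric int4: one scale per output channel (row), values clamped to [-8, 7]
    scale = weight.abs().amax(dim=1, keepdim=True) / 7
    q = torch.clamp(torch.round(weight / scale), -8, 7)
    return q, scale  # dequantize with q * scale
```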

## How to Run

We can run the entire staged recipe with one call to SparseML's `apply` pathway. This
will save a checkpoint of the model after each stage.

```
python examples/llama7b_sparse_quantized/llama7b_sparse_w4a16.py
```

### Compression

The resulting model will be uncompressed. To save a final compressed copy of the model
run the following:

```python
import torch
from sparseml.transformers import SparseAutoModelForCausalLM

# reload the final (uncompressed) checkpoint and re-save it with weight packing
model = SparseAutoModelForCausalLM.from_pretrained(output_dir, torch_dtype=torch.bfloat16)
model.save_pretrained(compressed_output_dir, save_compressed=True)
```
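
Once saved with `save_compressed=True`, the checkpoint is in the packed int4 format that vLLM can consume. A minimal sketch, assuming a vLLM installation recent enough to understand the `compression_config` that SparseML writes:

```python
from vllm import LLM, SamplingParams

llm = LLM(model=compressed_output_dir)  # path from the snippet above
outputs = llm.generate(["What is sparsity?"], SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```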
54 changes: 54 additions & 0 deletions examples/llama7b_sparse_quantized/llama7b_sparse_w4a16.py
@@ -0,0 +1,54 @@
import torch

from sparseml.transformers import SparseAutoModelForCausalLM, apply


# define a recipe to handle sparsity, finetuning and quantization
recipe = "2:4_w4a16_recipe.yaml"

# load the model in as bfloat16 to save on memory and compute
model_stub = "zoo:llama2-7b-ultrachat200k_llama2_pretrain-base"
model = SparseAutoModelForCausalLM.from_pretrained(
model_stub, torch_dtype=torch.bfloat16, device_map="auto"
)

# uses SparseML's built-in preprocessing for ultra chat
dataset = "ultrachat-200k"

# save location of quantized model
output_dir = "output_llama7b_2:4_w4a16_channel"

# set dataset config parameters
splits = {"calibration": "train_gen[:5%]", "train": "train_gen"}
max_seq_length = 512
num_calibration_samples = 512

# set training parameters for finetuning
num_train_epochs = 1
logging_steps = 500
save_steps = 5000
gradient_checkpointing = True # saves memory during training
learning_rate = 0.0001
bf16 = True # using bfloat16 for training
lr_scheduler_type = "cosine"
warmup_ratio = 0.1

# this will run the recipe stage by stage:
# oneshot sparsification -> finetuning -> oneshot quantization
apply(
model=model,
dataset=dataset,
recipe=recipe,
bf16=bf16,
output_dir=output_dir,
splits=splits,
max_seq_length=max_seq_length,
num_calibration_samples=num_calibration_samples,
num_train_epochs=num_train_epochs,
logging_steps=logging_steps,
save_steps=save_steps,
gradient_checkpointing=gradient_checkpointing,
learning_rate=learning_rate,
lr_scheduler_type=lr_scheduler_type,
warmup_ratio=warmup_ratio,
)
185 changes: 185 additions & 0 deletions examples/llama7b_w4a16_quantization.ipynb
@@ -0,0 +1,185 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Quantizing Llama 7B to W4A16 Using SparseML's OneShot Pathway\n",
"\n",
"This example notebook walks through how to quantize Llama 7B using SparseML. We apply int4 channel-wise quantization all Linear layers, using UltraChat 200k as a calibration dataset.\n",
"\n",
"This example requires at least 45GB of GPU memory to run. The memory requirement can be reduced to 32GB by setting `sequential_update: true` in the recipe definition, but this will increase the runtime significantly."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from sparseml.transformers import SparseAutoModelForCausalLM, oneshot"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"SparseML uses recipes to define configurations for different oneshot algorithms. Recipes can be defined as a string or a yaml file. Below we create a sample recipe for GPTQ quantization. The recipe is made up of two different algorithms, called modifiers.\n",
"\n",
"1. **vLLMQuantizationModifier**: calibrates the model for quantization by calculating scale and zero points from a small amount of calibration data\n",
"2. **SparseGPTModifier**: applies the GPTQ algorithm, using the result of the vLLMQuantizationModifier to determine the best quantization bin to place each linear weight into"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"recipe=\"\"\"\n",
"quant_stage:\n",
" quant_modifiers:\n",
" vLLMQuantizationModifier:\n",
" ignore: [\"lm_head\"]\n",
" config_groups:\n",
" group_0:\n",
" weights:\n",
" num_bits: 4\n",
" type: \"int\"\n",
" symmetric: true\n",
" strategy: \"channel\"\n",
" targets: [\"Linear\"]\n",
" SparseGPTModifier:\n",
" sparsity: 0.0\n",
" quantize: True\n",
" sequential_update: false\n",
"\"\"\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next we need to initialize the model we wish to quantize, and define a dataset for calibration. We will use a llama2 7b model that has been pretrained on the ultrachat 200k dataset. We will use the same dataset the model has been pretrained on for our one shot calibration. \n",
"\n",
"SparseML supports several datasets, such as ultrachat-200k, out of the box. You can also pass in a tokenized `datasets.Dataset` object for custom dataset support."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"# by setting the device_map to auto, we can spread the model evenly across all available GPUs\n",
"# load the model in as bfloat16 to save on memory and compute\n",
"model_stub = \"zoo:llama2-7b-ultrachat200k_llama2_pretrain-base\"\n",
"model = SparseAutoModelForCausalLM.from_pretrained(model_stub, torch_dtype=torch.bfloat16, device_map=\"auto\")\n",
"\n",
"# uses SparseML's built-in preprocessing for ultra chat\n",
"dataset = \"ultrachat-200k\"\n",
"\n",
"# save location of quantized model\n",
"output_dir = \"./output_llama7b_W4A16_channel_compressed\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we will configure our calibration dataset. To save on load time, we load only a small subset of ultrachat200k's `train_gen` split and label it as calibration data. For oneshot we do not need to pad the input, so we set `pad_to_max_length` to false. We also truncate each sample to a maximum of 512 tokens and select 512 samples for calibration. \n",
"\n",
"Using more calibration samples can improve model performance but will take longer to run. Generally 256-2048 calibration samples is recommended."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# set dataset config parameters\n",
"splits = {\"calibration\": \"train_gen[:5%]\"}\n",
"max_seq_length = 512\n",
"pad_to_max_length = False\n",
"num_calibration_samples = 512"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, we can launch our quantization recipe using the `oneshot` function. This function call will apply the algorithms defined in `recipe` to the input `model`, using `num_calibration_samples` from `dataset` as calibration data. We will save the quantized model to `output_dir`.\n",
"\n",
"By setting `save_compressed` to True, the model will be saved by packing every 8 int4 weights into a single int32. This will enable the model to be loaded by vLLM. Once a model has been saved in this way, you can no longer recover the original unquantized weights. To save the model in a \"fake quantized\" state instead so that the original weights are preserved, set `save_compressed` to False."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"oneshot(\n",
" model=model,\n",
" dataset=dataset,\n",
" recipe=recipe,\n",
" output_dir=output_dir,\n",
" splits=splits,\n",
" max_seq_length=max_seq_length,\n",
" pad_to_max_length=pad_to_max_length,\n",
" num_calibration_samples=num_calibration_samples,\n",
" save_compressed=True\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The quantized model should now be stored in the defined `output_dir`. Its `config.json` will contain a new `compression_config` field that describes how the model has been quantized. This config will be used to load the model into vLLM."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"output_dir"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.save_pretrained(\"/network/sadkins/llama1.1b_W4A16_channel_packed\", save_compressed=True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
56 changes: 56 additions & 0 deletions examples/llama7b_w4a16_quantization.py
@@ -0,0 +1,56 @@
import torch

from sparseml.transformers import SparseAutoModelForCausalLM, oneshot


# define a sparseml recipe for GPTQ W4A16 quantization
recipe = """
quant_stage:
quant_modifiers:
vLLMQuantizationModifier:
ignore: ["lm_head"]
config_groups:
group_0:
weights:
num_bits: 4
type: "int"
symmetric: true
strategy: "channel"
targets: ["Linear"]
SparseGPTModifier:
sparsity: 0.0
quantize: true
sequential_update: false
"""

# setting device_map to auto to spread the model evenly across all available GPUs
# load the model in as bfloat16 to save on memory and compute
model_stub = "zoo:llama2-7b-ultrachat200k_llama2_pretrain-base"
model = SparseAutoModelForCausalLM.from_pretrained(
model_stub, torch_dtype=torch.bfloat16, device_map="auto"
)

# uses SparseML's built-in preprocessing for ultra chat
dataset = "ultrachat-200k"

# save location of the quantized model
output_dir = "./output_llama7b_w4a16_channel_compressed"

# set dataset config parameters
splits = {"calibration": "train_gen[:5%]"}
max_seq_length = 512
pad_to_max_length = False
num_calibration_samples = 512

# apply recipe to the model and save quantized output in an int4 packed format
oneshot(
model=model,
dataset=dataset,
recipe=recipe,
output_dir=output_dir,
splits=splits,
max_seq_length=max_seq_length,
pad_to_max_length=pad_to_max_length,
num_calibration_samples=num_calibration_samples,
save_compressed=True,
)