Kylesayrs/update readme (#252)
* WIP

Signed-off-by: Kyle Sayers <[email protected]>

* Fix errors in scripts and notebooks in `examples/` and drop `sparseml` dependence (#247)

* first pass, awaiting team feedback

* drop hf-transfer nonsense

* remaining example files

* black/isort

* Apply suggestions from code review

Co-authored-by: Kyle Sayers <[email protected]>

* notebook example cleanup

Signed-off-by: Brian Dellabetta <[email protected]>

* update compressed-tensors examples for QDQ and actual compression using torch hooks

Signed-off-by: Brian Dellabetta <[email protected]>

* Update examples/llama_1.1b/ex_config_quantization.py

Co-authored-by: Rahul Tuli <[email protected]>

* f-string typo

Signed-off-by: Brian Dellabetta <[email protected]>

* updates from code review

Signed-off-by: Brian Dellabetta <[email protected]>

---------

Signed-off-by: Brian Dellabetta <[email protected]>
Co-authored-by: Kyle Sayers <[email protected]>
Co-authored-by: Rahul Tuli <[email protected]>

---------

Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Brian Dellabetta <[email protected]>
Co-authored-by: Kyle Sayers <[email protected]>
Co-authored-by: Rahul Tuli <[email protected]>
3 people authored Feb 6, 2025
1 parent 22c09f3 commit 5f24384
Showing 9 changed files with 323 additions and 165 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -158,3 +158,5 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+
+examples/**/*.safetensors
79 changes: 43 additions & 36 deletions examples/bit_packing/ex_quantize_and_pack.py
@@ -12,35 +12,46 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from tqdm import tqdm
-from torch.utils.data import RandomSampler
+####
+#
+# The following example shows how to run QDQ inside `compressed-tensors`
+# QDQ (quantize & de-quantize) is a way to evaluate quantized model
+# accuracy but will not lead to a runtime speedup.
+# See `../llama_1.1b/ex_config_quantization.py` to go beyond QDQ
+# and quantize models that will run more performantly.
+#
+####
+
+from pathlib import Path
+
+import torch
+from compressed_tensors.compressors import ModelCompressor
 from compressed_tensors.quantization import (
+    apply_quantization_config,
     freeze_module_quantization,
     QuantizationConfig,
     QuantizationStatus,
-    apply_quantization_config,
 )
-from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
-from sparseml.transformers.finetune.data.base import TextGenerationDataset
-from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator
 from datasets import load_dataset
 from torch.utils.data import DataLoader
-from sparseml.pytorch.utils import tensors_to_device
-import torch
-from compressed_tensors.compressors import ModelCompressor
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
-config_file = "int4_config.json"
+config_file = Path(__file__).parent / "int4_config.json"
 model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
-dataset_name = "open_platypus"
+dataset_name = "garage-bAInd/Open-Platypus"
 split = "train"
 num_calibration_samples = 128
 max_seq_length = 512
 pad_to_max_length = False
 output_dir = "./llama1.1b_new_quant_out_test_packing"
 device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
-model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype="auto")
+model = AutoModelForCausalLM.from_pretrained(
+    model_name, device_map=device, torch_dtype="auto"
+)
 model.eval()  # no grad or updates needed for base model
-config = QuantizationConfig.parse_file(config_file)
+config = QuantizationConfig.model_validate_json(config_file.read_text())
 
 # set status to calibration
 config.quantization_status = QuantizationStatus.CALIBRATION
@@ -49,39 +60,35 @@
 apply_quantization_config(model, config)
 
 # create dataset
+dataset = load_dataset(dataset_name, split=f"train[:{num_calibration_samples}]")
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-data_args = DataTrainingArguments(
-    dataset=dataset_name,
-    max_seq_length=max_seq_length,
-    pad_to_max_length=pad_to_max_length,
-)
-dataset_manager = TextGenerationDataset.load_from_registry(
-    data_args.dataset,
-    data_args=data_args,
-    split=split,
-    tokenizer=tokenizer,
-)
-calib_dataset = dataset_manager.tokenize_and_process(
-    dataset_manager.get_raw_dataset()
-)
+
+
+def tokenize_function(examples):
+    return tokenizer(
+        examples["output"], padding=False, truncation=True, max_length=1024
+    )
+
+
+tokenized_dataset = dataset.map(tokenize_function, batched=True)
 
 data_loader = DataLoader(
-    calib_dataset, batch_size=1, collate_fn=DefaultDataCollator(), sampler=RandomSampler(calib_dataset)
+    tokenized_dataset,
+    batch_size=1,
 )
 
 # run calibration
 with torch.no_grad():
     for idx, sample in tqdm(enumerate(data_loader), desc="Running calibration"):
-        sample = tensors_to_device(sample, "cuda:0")
+        sample = {k: v.to(model.device) for k, v in sample.items()}
        _ = model(**sample)
 
         if idx >= num_calibration_samples:
             break
 
 # freeze params after calibration
 model.apply(freeze_module_quantization)
 
-# apply compression
+# convert model to QDQ model
 compressor = ModelCompressor(quantization_config=config)
 compressed_state_dict = compressor.compress(model)
 
+# save QDQ model
 model.save_pretrained(output_dir, state_dict=compressed_state_dict)
 compressor.update_config(output_dir)
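
For readers new to QDQ, the header comment added above describes it as quantize-and-dequantize: weights take a round trip through the quantized grid so accuracy loss can be measured, while staying in full precision at runtime. A minimal sketch of the idea on a single tensor (illustrative only, not the library's implementation; the symmetric int4 range and rounding choices here are assumptions):

import torch

def qdq_int4_symmetric(w: torch.Tensor) -> torch.Tensor:
    # symmetric int4 covers [-8, 7]; the scale maps the max magnitude onto 7
    scale = w.abs().max() / 7
    q = torch.clamp(torch.round(w / scale), -8, 7)  # quantize
    return q * scale  # de-quantize back to w's dtype

w = torch.randn(8, 8)
w_qdq = qdq_int4_symmetric(w)
# QDQ lets you measure this error; there is no runtime speedup because
# the de-quantized weights are still dense floating point
print(f"max abs quantization error: {(w - w_qdq).abs().max():.4f}")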
3 changes: 1 addition & 2 deletions examples/bit_packing/int4_config.json
@@ -12,6 +12,5 @@
       },
       "targets": ["Linear"]
     }
-  },
-  "ignore": ["lm_head"]
+  }
 }
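
The bracket fix above keeps the example config valid JSON. Note the loading change recorded in the previous file's diff: `QuantizationConfig.parse_file` became `QuantizationConfig.model_validate_json`, matching the pydantic v2 API, which validates a JSON string rather than opening a path itself. A sketch of the updated loading pattern (the path here is an assumption, relative to the repo root):

from pathlib import Path

from compressed_tensors.quantization import QuantizationConfig, QuantizationStatus

config_file = Path("examples/bit_packing/int4_config.json")  # assumed path
config = QuantizationConfig.model_validate_json(config_file.read_text())

# the example then flips the status to CALIBRATION before attaching observers
config.quantization_status = QuantizationStatus.CALIBRATION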
40 changes: 20 additions & 20 deletions examples/bitmask_compression.ipynb
@@ -15,7 +15,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 8,
+"execution_count": 1,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -29,7 +29,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 9,
+"execution_count": 2,
 "metadata": {},
 "outputs": [
 {
@@ -40,30 +40,30 @@
 "    (embed_tokens): Embedding(32000, 768)\n",
 "    (layers): ModuleList(\n",
 "      (0-11): 12 x LlamaDecoderLayer(\n",
-"        (self_attn): LlamaSdpaAttention(\n",
+"        (self_attn): LlamaAttention(\n",
 "          (q_proj): Linear(in_features=768, out_features=768, bias=False)\n",
 "          (k_proj): Linear(in_features=768, out_features=768, bias=False)\n",
 "          (v_proj): Linear(in_features=768, out_features=768, bias=False)\n",
 "          (o_proj): Linear(in_features=768, out_features=768, bias=False)\n",
-"          (rotary_emb): LlamaRotaryEmbedding()\n",
 "        )\n",
 "        (mlp): LlamaMLP(\n",
 "          (gate_proj): Linear(in_features=768, out_features=2048, bias=False)\n",
 "          (up_proj): Linear(in_features=768, out_features=2048, bias=False)\n",
 "          (down_proj): Linear(in_features=2048, out_features=768, bias=False)\n",
 "          (act_fn): SiLU()\n",
 "        )\n",
-"        (input_layernorm): LlamaRMSNorm()\n",
-"        (post_attention_layernorm): LlamaRMSNorm()\n",
+"        (input_layernorm): LlamaRMSNorm((768,), eps=1e-05)\n",
+"        (post_attention_layernorm): LlamaRMSNorm((768,), eps=1e-05)\n",
 "      )\n",
 "    )\n",
-"    (norm): LlamaRMSNorm()\n",
+"    (norm): LlamaRMSNorm((768,), eps=1e-05)\n",
+"    (rotary_emb): LlamaRotaryEmbedding()\n",
 "  )\n",
 "  (lm_head): Linear(in_features=768, out_features=32000, bias=False)\n",
 ")"
 ]
 },
-"execution_count": 9,
+"execution_count": 2,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -77,14 +77,14 @@
 },
 {
 "cell_type": "code",
-"execution_count": 10,
+"execution_count": 3,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"The example layer model.layers.0.self_attn.q_proj.weight has sparsity 0.50%\n"
+"The example layer model.layers.0.self_attn.q_proj.weight has sparsity 50%\n"
 ]
 }
 ],
@@ -93,42 +93,42 @@
 "state_dict = model.state_dict()\n",
 "state_dict.keys()\n",
 "example_layer = \"model.layers.0.self_attn.q_proj.weight\"\n",
-"print(f\"The example layer {example_layer} has sparsity {torch.sum(state_dict[example_layer] == 0).item() / state_dict[example_layer].numel():.2f}%\")"
+"print(f\"The example layer {example_layer} has sparsity {100 * state_dict[example_layer].eq(0).sum().item() / state_dict[example_layer].numel():.0f}%\")"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 11,
+"execution_count": 4,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"The model is 31.67% sparse overall\n"
+"The model is 32% sparse overall\n"
 ]
 }
 ],
 "source": [
-"# we can inspect to total sparisity of the state_dict\n",
+"# we can inspect to total sparsity of the state_dict\n",
 "total_num_parameters = 0\n",
 "total_num_zero_parameters = 0\n",
 "for key in state_dict:\n",
 "    total_num_parameters += state_dict[key].numel()\n",
 "    total_num_zero_parameters += state_dict[key].eq(0).sum().item()\n",
-"print(f\"The model is {total_num_zero_parameters/total_num_parameters*100:.2f}% sparse overall\")"
+"print(f\"The model is {total_num_zero_parameters/total_num_parameters*100:.0f}% sparse overall\")"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 12,
+"execution_count": 5,
 "metadata": {},
 "outputs": [
 {
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"Compressing model: 100%|██████████| 111/111 [00:06<00:00, 17.92it/s]\n"
+"Compressing model: 100%|██████████| 111/111 [00:00<00:00, 313.39it/s]\n"
 ]
 },
 {
@@ -168,7 +168,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 13,
+"execution_count": 6,
 "metadata": {},
 "outputs": [
 {
@@ -185,8 +185,8 @@
 "## load the uncompressed safetensors to memory ##\n",
 "state_dict_1 = {}\n",
 "with safe_open('model.safetensors', framework=\"pt\") as f:\n",
-"  for key in f.keys():\n",
-"    state_dict_1[key] = f.get_tensor(key)\n",
+"    for key in f.keys():\n",
+"        state_dict_1[key] = f.get_tensor(key)\n",
 "\n",
 "## load the compressed-tensors to memory ##\n",
 "config = BitmaskConfig() # we need to specify the method for decompression\n",
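
To close with context for the notebook's bitmask cells: bitmask compression stores only a tensor's nonzero values plus a one-bit-per-element mask, and decompression scatters the values back into place. A conceptual round trip (an illustration of the idea only, not the actual storage layout behind `BitmaskConfig`):

import torch

def bitmask_compress(t: torch.Tensor):
    mask = t != 0  # bool mask; packable to 1 bit per element on disk
    values = t[mask]  # dense vector of the nonzero entries
    return values, mask

def bitmask_decompress(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    out = torch.zeros(mask.shape, dtype=values.dtype)
    out[mask] = values
    return out

w = torch.randn(4, 4) * (torch.rand(4, 4) > 0.5)  # roughly half the entries zeroed
values, mask = bitmask_compress(w)
assert torch.equal(bitmask_decompress(values, mask), w)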