Refactor generation benchmark to compare with AWQ and HQQ #128

Merged · 3 commits · Mar 21, 2024
95 changes: 95 additions & 0 deletions bench/generation/evaluate_configurations.py
@@ -0,0 +1,95 @@
import argparse
import json

import torch
from evaluate_model import evaluate
from gen_barchart import gen_barchart

from quanto import qtype


def evaluate_model_configurations(model_id: str, metric: str, device: torch.device, batch_size: int = 32):
weights = [
"int4",
"int8",
"float8",
]

activations = [
"none",
"int8",
"float8",
]

def short_name(qtype: qtype):
return {
"none": "f16",
"int4": "i4",
"int8": "i8",
"float8": "f8",
}[qtype]

results = {}

# Evaluate float16 model
print(f"{model_id}[Wf16Af16]:")
results["Wf16Af16"] = evaluate(model_id, metric, "quanto", "none", "none", batch_size, device)
# Evaluate quantized models
for w in weights:
for a in activations:
config_name = f"W{short_name(w)}A{short_name(a)}"
print(f"{model_id}[{config_name}]:")
results[config_name] = evaluate(model_id, metric, "quanto", w, a, batch_size, device)

return results


def main():
parser = argparse.ArgumentParser(description="Evaluate quantized model predictions on Lambada Dataset")
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument(
"--model",
type=str,
default="facebook/opt-350m",
help="The name of the trained Model.",
)
parser.add_argument("--device", type=str, default=None, help="The device to use for generation.")
parser.add_argument("--metric", type=str, default="prediction", choices=["latency", "prediction", "perplexity"])
parser.add_argument("--batch_size", type=int, default=32, help="The batch size during evaluation.")
parser.add_argument("--json", action="store_true", help="Dump the results to a json file.")
parser.add_argument("--png", action="store_true", help="Generate a PNG.")
args = parser.parse_args()

torch.manual_seed(args.seed)

if args.device is None:
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
else:
device = torch.device(args.device)

results = evaluate_model_configurations(args.model, args.metric, device, batch_size=args.batch_size)
if args.json:
model_name = args.model.split("/")[-1]
json_path = f"{model_name}-{args.metric}.json"
with open(json_path, "w") as fp:
json.dump({model_name: results}, fp, indent=4)
if args.png:
if args.metric == "latency":
title = f"{args.model}: Mean latency per token"
label = "Latency (ms)"
elif args.metric == "prediction":
title = f"{args.model}: Prediction accuracy on Lambada dataset"
label = "Accuracy"
elif args.metric == "perplexity":
title = f"{args.model}: Perplexity evaluated on WikiText dataset"
label = "Perplexity"
gen_barchart(args.model, title, label, results)


if __name__ == "__main__":
main()
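
Note: `evaluate_configurations.py` sweeps the float16 baseline plus every quanto weight/activation pair through the shared `evaluate()` entry point from `evaluate_model.py`. As a rough sketch, it can also be driven programmatically instead of through the CLI (assuming it is imported from `bench/generation/`; result keys follow the `W<weights>A<activations>` naming built by `short_name`):

import torch

from evaluate_configurations import evaluate_model_configurations

# Pick a device roughly the way main() does, minus the MPS branch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Evaluate facebook/opt-350m under the float16 baseline and all quanto configurations.
results = evaluate_model_configurations("facebook/opt-350m", "prediction", device, batch_size=32)

# "Wf16Af16" is the float16 baseline, "Wi8Af8" the int8-weights / float8-activations run.
print(results["Wf16Af16"], results["Wi8Af8"])
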
12 changes: 6 additions & 6 deletions bench/generation/evaluate_many_models.sh
@@ -18,13 +18,13 @@ bigger_models=(
)

for m in ${small_models[@]}; do
python ${SCRIPT_PATH}/evaluate_model.py --model $m --metric prediction --png
python ${SCRIPT_PATH}/evaluate_model.py --model $m --metric perplexity --png
python ${SCRIPT_PATH}/evaluate_model.py --model $m --metric latency --png
python ${SCRIPT_PATH}/evaluate_configurations.py --model $m --metric prediction --png
python ${SCRIPT_PATH}/evaluate_configurations.py --model $m --metric perplexity --png
python ${SCRIPT_PATH}/evaluate_configurations.py --model $m --metric latency --png
done

for m in ${bigger_models[@]}; do
python ${SCRIPT_PATH}/evaluate_model.py --model $m --metric prediction --png --json --batch_size 16
python ${SCRIPT_PATH}/evaluate_model.py --model $m --metric perplexity --png --json --batch_size 16
python ${SCRIPT_PATH}/evaluate_model.py --model $m --metric latency --png --json --batch_size 16
python ${SCRIPT_PATH}/evaluate_configurations.py --model $m --metric prediction --png --json --batch_size 16
python ${SCRIPT_PATH}/evaluate_configurations.py --model $m --metric perplexity --png --json --batch_size 16
python ${SCRIPT_PATH}/evaluate_configurations.py --model $m --metric latency --png --json --batch_size 16
done
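
Note: for the bigger models the loop also passes `--json`, so each run dumps its results to a `<model-name>-<metric>.json` file as implemented in `evaluate_configurations.py`. A minimal sketch of reading one of those dumps back afterwards (the file name below is illustrative):

import json

# The file name follows the f"{model_name}-{args.metric}.json" pattern,
# where model_name is the last path component of the --model argument.
with open("opt-350m-prediction.json") as fp:
    data = json.load(fp)

# The file maps the model name to one result per W...A... configuration.
for config, value in data["opt-350m"].items():
    print(f"{config}: {value}")
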
135 changes: 60 additions & 75 deletions bench/generation/evaluate_model.py
@@ -1,63 +1,56 @@
import argparse
import json

import torch
from gen_barchart import gen_barchart
from latency import latency
from perplexity import perplexity
from prediction import prediction_accuracy

from quanto import qfloat8, qint4, qint8, qtype


def evaluate_model_configurations(
model_id: str, metric: str, device: torch.device, batch_size: int = 32, seed: int = 1
from datasets import load_dataset
from metrics.latency import latency
from metrics.perplexity import perplexity
from metrics.prediction import prediction_accuracy

from setup.awq import setup as awq_setup
from setup.bnb import setup as bnb_setup
from setup.hqq import setup as hqq_setup
from setup.quanto import setup as quanto_setup


@torch.no_grad()
def calibrate(model, tokenizer, batch_size, batches):
samples = batch_size * batches
cal_dataset = load_dataset("lambada", split=["validation"])[0]
model.eval()
total = 0
for batch in cal_dataset.iter(batch_size=batch_size):
inputs = tokenizer(batch["text"], return_tensors="pt", padding=True)
input_ids = inputs.input_ids.to(model.device)
attention_mask = inputs.attention_mask.to(model.device)
model(input_ids, attention_mask=attention_mask)
total += input_ids.size(0)
if total >= samples:
break


def evaluate(
model_id: str, metric: str, quantizer: str, weights: str, activations: str, batch_size: int, device: torch.device
):
weights = [
qint4,
qint8,
qfloat8,
]

activations = [
None,
qint8,
qfloat8,
]

def short_name(qtype: qtype):
return {
None: "f16",
qint4: "i4",
qint8: "i8",
qfloat8: "f8",
}[qtype]

results = {}

def get_results(model_id: str, w: qtype, a: qtype, device: torch.device, seed: int = 1):
if metric == "latency":
return latency(model_id, w, a, device, batch_size=batch_size, seed=seed)
elif metric == "prediction":
return prediction_accuracy(model_id, w, a, device, batch_size=batch_size, seed=seed)
elif metric == "perplexity":
return perplexity(model_id, w, a, device, batch_size=batch_size, seed=seed)

# Evaluate float16 model
print(f"{model_id}[Wf16Af16]:")
results["Wf16Af16"] = get_results(model_id, None, None, device, seed=seed)
# Evaluate quantized models
for w in weights:
for a in activations:
config_name = f"W{short_name(w)}A{short_name(a)}"
print(f"{model_id}[{config_name}]:")
results[config_name] = get_results(model_id, w, a, device, seed=seed)

return results
if quantizer == "quanto":
model, tokenizer = quanto_setup(model_id, weights, activations, batch_size, device)
elif quantizer == "awq":
model, tokenizer = awq_setup(model_id, weights, activations)
elif quantizer == "bnb":
model, tokenizer = bnb_setup(model_id, weights, activations, device)
elif quantizer == "hqq":
model, tokenizer = hqq_setup(model_id, weights, activations, device)
else:
raise ValueError(f"Unsupported quantizer {quantizer}")
if metric == "latency":
return latency(model, tokenizer, device, batch_size=1, prompt_length=512, nb_tokens=512, iterations=5)
elif metric == "prediction":
return prediction_accuracy(model, tokenizer, batch_size)
elif metric == "perplexity":
return perplexity(model, tokenizer)


def main():
parser = argparse.ArgumentParser(description="Evaluate quantized model predictions on Lambada Dataset")
parser = argparse.ArgumentParser(description="Evaluate quantized model metrics")
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument(
"--model",
@@ -67,9 +60,20 @@ def main():
)
parser.add_argument("--device", type=str, default=None, help="The device to use for generation.")
parser.add_argument("--metric", type=str, default="prediction", choices=["latency", "prediction", "perplexity"])
parser.add_argument("--quantizer", type=str, default="quanto", choices=["quanto", "awq", "bnb", "hqq"])
parser.add_argument(
"--weights",
type=str,
default="none",
choices=["none", "int4", "int8", "float8"],
)
parser.add_argument(
"--activations",
type=str,
default="none",
choices=["none", "int8", "float8"],
)
parser.add_argument("--batch_size", type=int, default=32, help="The batch size during evaluation.")
parser.add_argument("--json", action="store_true", help="Dump the results to a json file.")
parser.add_argument("--png", action="store_true", help="Generate a PNG.")
args = parser.parse_args()

torch.manual_seed(args.seed)
@@ -83,26 +87,7 @@ def main():
device = torch.device("cpu")
else:
device = torch.device(args.device)

results = evaluate_model_configurations(
args.model, args.metric, device, batch_size=args.batch_size, seed=args.seed
)
if args.json:
model_name = args.model.split("/")[-1]
json_path = f"{model_name}-{args.metric}.json"
with open(json_path, "w") as fp:
json.dump({model_name: results}, fp, indent=4)
if args.png:
if args.metric == "latency":
title = f"{args.model}: Mean latency per token"
label = "Latency (ms)"
elif args.metric == "prediction":
title = f"{args.model}: Prediction accuracy on Lambada dataset"
label = "Accuracy"
elif args.metric == "perplexity":
title = f"{args.model}: Perplexity evaluated on WikiText dataset"
label = "Perplexity"
gen_barchart(args.model, title, label, results)
evaluate(args.model, args.metric, args.quantizer, args.weights, args.activations, args.batch_size, device)


if __name__ == "__main__":
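
Note: the refactored `evaluate()` is what enables the cross-library comparison from the PR title: the same metric is computed after setting a model up with quanto, AWQ, bitsandbytes or HQQ. A minimal sketch of comparing backends through this entry point (illustrative only; not every backend necessarily supports every weight/activation combination):

import torch

from evaluate_model import evaluate

device = torch.device("cuda")

# Compare prediction accuracy with int4 weights and unquantized activations
# across the backends wired into evaluate().
for quantizer in ("quanto", "awq", "hqq"):
    acc = evaluate("facebook/opt-350m", "prediction", quantizer, "int4", "none", 32, device)
    print(f"{quantizer}: {acc}")
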