Skip to content

Commit

Permalink
perf(generation): add dtype parameter (quanto only)
Browse files Browse the repository at this point in the history
  • Loading branch information
dacorvo committed Jul 9, 2024
1 parent 8a13cf9 commit 0e8a0ad
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
28 changes: 25 additions & 3 deletions bench/generation/evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from setup.bnb import setup as bnb_setup
from setup.hqq import setup as hqq_setup
from setup.quanto import setup as quanto_setup
from transformers import AutoConfig


@torch.no_grad()
Expand All @@ -42,10 +43,20 @@ def calibrate(model, tokenizer, batch_size, batches):


def evaluate(
model_id: str, metric: str, quantizer: str, weights: str, activations: str, batch_size: int, device: torch.device
model_id: str,
metric: str,
quantizer: str,
weights: str,
activations: str,
batch_size: int,
device: torch.device,
dtype: torch.dtype = None,
):
if quantizer == "quanto":
model, tokenizer = quanto_setup(model_id, weights, activations, batch_size, device)
if dtype is None:
config = AutoConfig.from_pretrained(model_id)
dtype = getattr(config, "torch_dtype", torch.float16)
model, tokenizer = quanto_setup(model_id, weights, activations, batch_size, device, dtype)
elif quantizer == "awq":
model, tokenizer = awq_setup(model_id, weights, activations, group_size=128)
elif quantizer == "bnb":
Expand All @@ -54,6 +65,10 @@ def evaluate(
model, tokenizer = hqq_setup(model_id, weights, activations, device)
else:
raise ValueError(f"Unsupported quantizer {quantizer}")
dtype = next(model.parameters()).dtype
weights = dtype if weights == "none" else weights
activations = dtype if activations == "none" else activations
print(f"Evaluating {model_id} {metric} with {weights} weights and {activations} activations.")
if metric == "latency":
return latency(model, tokenizer, device, batch_size=1, prompt_length=512, nb_tokens=512, iterations=5)
elif metric == "prediction":
Expand Down Expand Up @@ -87,6 +102,12 @@ def main():
choices=["none", "int8", "float8"],
)
parser.add_argument("--batch_size", type=int, default=32, help="The batch size during evaluation.")
parser.add_argument(
"--dtype",
type=str,
default="none",
choices=["none", "fp16", "bf16"],
)
args = parser.parse_args()

torch.manual_seed(args.seed)
Expand All @@ -100,7 +121,8 @@ def main():
device = torch.device("cpu")
else:
device = torch.device(args.device)
evaluate(args.model, args.metric, args.quantizer, args.weights, args.activations, args.batch_size, device)
dtype = {"none": None, "fp16": torch.float16, "bf16": torch.bfloat16}[args.dtype]
evaluate(args.model, args.metric, args.quantizer, args.weights, args.activations, args.batch_size, device, dtype)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion bench/generation/setup/quanto.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ def setup(
activations: str,
batch_size: int,
device: torch.device,
dtype: torch.dtype,
):
weights = keyword_to_qtype(weights)
activations = keyword_to_qtype(activations)
dtype = torch.float32 if device.type == "cpu" else torch.float16
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"
Expand Down

0 comments on commit 0e8a0ad

Please sign in to comment.