diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
index 47b7f35cbdd7c..b69bd229745c6 100644
--- a/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
+++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
@@ -20,6 +20,14 @@
 # 4) Install the latest ONNX Runtime version
 #
 # $ pip install onnxruntime-gpu
+#
+# 5) Install flash attention v2
+#
+# $ pip install flash-attn --no-build-isolation
+#
+# 6) Install bitsandbytes
+#
+# $ pip install bitsandbytes
 
 from __future__ import annotations
 
@@ -38,22 +46,44 @@
 import torch
 from benchmark_helper import setup_logger
 from llama_inputs import add_io_bindings_as_tensors, get_initial_inputs_and_outputs
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 
 import onnxruntime as ort
 
 logger = logging.getLogger(__name__)
 
 
-def get_model(args):
+def get_model(args: argparse.Namespace):
     if args.benchmark_type in {"pt-eager", "pt-compile"}:
-        model = AutoModelForCausalLM.from_pretrained(
-            args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
-            cache_dir=args.cache_dir,
-            torch_dtype=args.torch_dtype,
-            use_auth_token=args.auth,
-            use_cache=True,
-        ).to(args.target_device)
+        model = None
+        if args.onnx_precision == "int4" and args.device == "cuda":
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=torch.float16,
+            )
+
+            model = AutoModelForCausalLM.from_pretrained(
+                args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
+                cache_dir=args.cache_dir,
+                torch_dtype=args.torch_dtype,
+                use_auth_token=args.auth,
+                use_cache=True,
+                attn_implementation="flash_attention_2",
+                quantization_config=bnb_config,
+                max_memory={args.device_id: "80GB"},
+            )
+        else:
+            model = AutoModelForCausalLM.from_pretrained(
+                args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
+                cache_dir=args.cache_dir,
+                torch_dtype=args.torch_dtype,
+                use_auth_token=args.auth,
+                use_cache=True,
+                attn_implementation=("flash_attention_2" if args.device == "cuda" else "sdpa"),
+            ).to(args.target_device)
+
         model.eval()
 
         if args.benchmark_type == "pt-compile":
@@ -223,7 +253,7 @@ def get_args():
     parser.add_argument(
         "-s",
         "--prompt-lengths",
-        default="32 64 128 256 512",
+        default="16 64 256 1024",
     )
 
     parser.add_argument(
@@ -277,6 +307,7 @@ def get_args():
     args.prompt_lengths = args.prompt_lengths.split(" ")
 
     # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
+    setattr(args, "onnx_precision", args.precision)  # noqa: B010
     args.precision = (
         "fp32" if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.device == "cpu") else "fp16"
     )
diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
index 9cbc9af7fe9b5..7b186eec2f5a9 100644
--- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
+++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
@@ -170,7 +170,7 @@ def get_args(argv: list[str]):
     parser.add_argument(
         "-m",
         "--model_name",
-        required=True,
+        required=False,
         help="Model name in Hugging Face",
     )
 
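
For reference, a minimal standalone sketch of the new int4/CUDA load path that the diff adds to get_model(). This is illustrative only, not part of the change: the model id "meta-llama/Llama-2-7b-hf", device id 0, and the "80GB" memory cap are placeholder assumptions, and it requires flash-attn v2 and bitsandbytes installed as described in the new setup steps.

# Illustrative sketch (not part of the diff): 4-bit NF4 quantized load with
# FlashAttention-2, mirroring the CUDA int4 branch added to get_model().
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 4-bit NF4 weight quantization with double quantization and fp16 compute.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",               # placeholder model id (assumption)
    torch_dtype=torch.float16,
    use_cache=True,
    attn_implementation="flash_attention_2",  # requires flash-attn v2 on CUDA
    quantization_config=bnb_config,           # requires bitsandbytes
    max_memory={0: "80GB"},                   # per-device memory cap; device id 0 assumed
)
model.eval()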