Add flash attention v2 and INT4 CUDA for LLaMA E2E benchmarking (#20149)
### Description
This PR adds flash attention v2 and INT4 CUDA support to the PyTorch side of the LLaMA end-to-end benchmarking script.

### Motivation and Context
The [flash attention v2](https://github.com/Dao-AILab/flash-attention)
algorithm improves attention performance for the PyTorch benchmarks. INT4
quantization on CUDA in PyTorch is provided through the
[`bitsandbytes`](https://github.com/TimDettmers/bitsandbytes) package.
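
For illustration, a minimal, self-contained sketch of how the two pieces combine when loading a Hugging Face model in PyTorch. This assumes `flash-attn` and `bitsandbytes` are installed on a CUDA machine, and the model id is only a placeholder; the exact arguments the benchmark passes are shown in the diff below.

```python
# Minimal sketch: INT4 (NF4) weight quantization via bitsandbytes plus flash attention v2.
# Assumes a CUDA GPU; the model id below is only a placeholder.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # quantize weights to 4 bits
    bnb_4bit_use_double_quant=True,        # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",             # NormalFloat4 quantization
    bnb_4bit_compute_dtype=torch.float16,  # compute in FP16
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    quantization_config=bnb_config,           # INT4 weights on CUDA
    attn_implementation="flash_attention_2",  # flash attention v2 kernels
)
model.eval()
```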
kunal-vaishnavi authored Mar 30, 2024
1 parent 00244ea commit a0ebd5f
Showing 2 changed files with 42 additions and 11 deletions.
51 changes: 41 additions & 10 deletions onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
@@ -20,6 +20,14 @@
# 4) Install the latest ONNX Runtime version
#
# $ pip install onnxruntime-gpu
+#
+# 5) Install flash attention v2
+#
+# $ pip install flash-attn --no-build-isolation
+#
+# 6) Install bitsandbytes
+#
+# $ pip install bitsandbytes

from __future__ import annotations

@@ -38,22 +46,44 @@
import torch
from benchmark_helper import setup_logger
from llama_inputs import add_io_bindings_as_tensors, get_initial_inputs_and_outputs
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

import onnxruntime as ort

logger = logging.getLogger(__name__)


-def get_model(args):
+def get_model(args: argparse.Namespace):
    if args.benchmark_type in {"pt-eager", "pt-compile"}:
-        model = AutoModelForCausalLM.from_pretrained(
-            args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
-            cache_dir=args.cache_dir,
-            torch_dtype=args.torch_dtype,
-            use_auth_token=args.auth,
-            use_cache=True,
-        ).to(args.target_device)
+        model = None
+        if args.onnx_precision == "int4" and args.device == "cuda":
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=torch.float16,
+            )
+
+            model = AutoModelForCausalLM.from_pretrained(
+                args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
+                cache_dir=args.cache_dir,
+                torch_dtype=args.torch_dtype,
+                use_auth_token=args.auth,
+                use_cache=True,
+                attn_implementation="flash_attention_2",
+                quantization_config=bnb_config,
+                max_memory={args.device_id: "80GB"},
+            )
+        else:
+            model = AutoModelForCausalLM.from_pretrained(
+                args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
+                cache_dir=args.cache_dir,
+                torch_dtype=args.torch_dtype,
+                use_auth_token=args.auth,
+                use_cache=True,
+                attn_implementation=("flash_attention_2" if args.device == "cuda" else "sdpa"),
+            ).to(args.target_device)

    model.eval()

    if args.benchmark_type == "pt-compile":
@@ -223,7 +253,7 @@ def get_args():
    parser.add_argument(
        "-s",
        "--prompt-lengths",
-        default="32 64 128 256 512",
+        default="16 64 256 1024",
    )

    parser.add_argument(
@@ -277,6 +307,7 @@ def get_args():
    args.prompt_lengths = args.prompt_lengths.split(" ")

    # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
+    setattr(args, "onnx_precision", args.precision)  # noqa: B010
    args.precision = (
        "fp32" if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.device == "cpu") else "fp16"
    )
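
As a side note, here is a standalone restatement of the precision mapping applied above. It is illustrative only (the benchmark derives these values from its parsed arguments), and `resolve_precision` is a hypothetical helper, not part of the script:

```python
# Illustrative only: mirrors the precision selection in the diff above.
def resolve_precision(requested: str, device: str):
    """Return (onnx_precision, torch_precision) for a requested precision/device combination."""
    onnx_precision = requested  # the ONNX model keeps the requested precision
    torch_precision = (
        "fp32" if requested in {"int8", "fp32"} or (requested == "int4" and device == "cpu") else "fp16"
    )
    return onnx_precision, torch_precision

# INT4 on CUDA loads the PyTorch model in FP16; INT4 on CPU and INT8 stay in FP32.
assert resolve_precision("int4", "cuda") == ("int4", "fp16")
assert resolve_precision("int4", "cpu") == ("int4", "fp32")
assert resolve_precision("int8", "cuda") == ("int8", "fp32")
```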
@@ -170,7 +170,7 @@ def get_args(argv: list[str]):
    parser.add_argument(
        "-m",
        "--model_name",
-        required=True,
+        required=False,
        help="Model name in Hugging Face",
    )

