
Commit

Save pytorch profiler output for latency benchmark (vllm-project#1871)
* Save profiler output

* Apply feedback from code review
Yard1 authored Dec 6, 2023
1 parent 1d9b737 commit 05ff90b
Showing 1 changed file with 25 additions and 9 deletions.
34 changes: 25 additions & 9 deletions benchmarks/benchmark_latency.py
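
With this change, running benchmarks/benchmark_latency.py with --profile writes the PyTorch profiler trace to disk: the trace goes to the directory given by --profile-result-dir, or to an auto-generated vllm_benchmark_result/latency_result_<timestamp> directory when the flag is omitted, and can then be opened in TensorBoard or at ui.perfetto.dev. A plausible invocation (the exact command line below is illustrative; model and batch flags are passed as usual) is: python benchmarks/benchmark_latency.py --profile --profile-result-dir ./profile_out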
@@ -1,6 +1,8 @@
 """Benchmark the latency of processing a single batch of requests."""
 import argparse
 import time
+from pathlib import Path
+from typing import Optional
 
 import numpy as np
 import torch
@@ -34,12 +36,15 @@ def main(args: argparse.Namespace):
     print(sampling_params)
     dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size
 
-    def run_to_completion(profile: bool = False):
-        if profile:
-            with torch.profiler.profile(activities=[
-                    torch.profiler.ProfilerActivity.CPU,
-                    torch.profiler.ProfilerActivity.CUDA,
-            ]) as p:
+    def run_to_completion(profile_dir: Optional[str] = None):
+        if profile_dir:
+            with torch.profiler.profile(
+                    activities=[
+                        torch.profiler.ProfilerActivity.CPU,
+                        torch.profiler.ProfilerActivity.CUDA,
+                    ],
+                    on_trace_ready=torch.profiler.tensorboard_trace_handler(
+                        str(profile_dir))) as p:
                 llm.generate(prompt_token_ids=dummy_prompt_token_ids,
                              sampling_params=sampling_params,
                              use_tqdm=False)
@@ -54,11 +59,14 @@ def run_to_completion(profile: bool = False):
         return latency
 
     print("Warming up...")
-    run_to_completion(profile=False)
+    run_to_completion(profile_dir=None)
 
     if args.profile:
-        print("Profiling...")
-        run_to_completion(profile=True)
+        profile_dir = args.profile_result_dir
+        if not profile_dir:
+            profile_dir = Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}"
+        print(f"Profiling (results will be saved to '{profile_dir}')...")
+        run_to_completion(profile_dir=profile_dir)
         return
 
     # Benchmark.
@@ -107,5 +115,13 @@ def run_to_completion(profile: bool = False):
         '--profile',
         action='store_true',
         help='profile the generation process of a single batch')
+    parser.add_argument(
+        '--profile-result-dir',
+        type=str,
+        default=None,
+        help=(
+            'path to save the pytorch profiler output. Can be visualized '
+            'with ui.perfetto.dev or Tensorboard.'
+        ))
     args = parser.parse_args()
     main(args)
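
For reference, below is a minimal standalone sketch of the same torch.profiler pattern this diff adopts; the workload, output directory, and helper name are illustrative assumptions, not part of the benchmark script.

import torch

def profiled_matmul(profile_dir: str) -> None:
    # Collect CPU activity always, CUDA activity only when a GPU is present.
    activities = [torch.profiler.ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(torch.profiler.ProfilerActivity.CUDA)
    with torch.profiler.profile(
            activities=activities,
            # tensorboard_trace_handler writes a *.pt.trace.json file into
            # profile_dir when the profiler context exits.
            on_trace_ready=torch.profiler.tensorboard_trace_handler(
                profile_dir)) as p:
        # Dummy workload standing in for llm.generate(...) in the benchmark.
        x = torch.randn(1024, 1024)
        if torch.cuda.is_available():
            x = x.cuda()
        (x @ x).sum().item()
    # Optionally print a summary table of the collected events.
    print(p.key_averages().table(sort_by="cpu_time_total", row_limit=10))

if __name__ == "__main__":
    profiled_matmul("./vllm_benchmark_result/example_trace")

The resulting trace directory can be pointed at by TensorBoard (with the torch-tb-profiler plugin installed) or the JSON trace loaded at ui.perfetto.dev, matching the help text added for --profile-result-dir above.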
