Skip to content

Commit

Permalink
Adding PR comment updates
Browse files Browse the repository at this point in the history
  • Loading branch information
geomin12 committed Feb 20, 2025
1 parent 3724c10 commit 1f2169a
Show file tree
Hide file tree
Showing 18 changed files with 367 additions and 219 deletions.
51 changes: 32 additions & 19 deletions sharktank_models/clip/test_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,47 +17,51 @@

THIS_DIR = pathlib.Path(__file__).parent


def load_tensor_from_irpa(path: PathLike) -> np.ndarray:
index = iree.runtime.ParameterIndex()
index.load(str(path))
index_entry: iree.runtime.ParameterIndexEntry = index.items()[0][1]
return iree.runtime.parameter_index_entry_as_numpy_ndarray(index_entry)


@pytest.fixture(
params=[
pytest.param("local-task", marks=pytest.mark.target_cpu),
pytest.param("hip", marks=pytest.mark.target_hip),
]
params=[
pytest.param("local-task", marks=pytest.mark.target_cpu),
pytest.param("hip", marks=pytest.mark.target_hip),
]
)
def device_id(request: pytest.FixtureRequest) -> str:
return request.param


@pytest.fixture(
params=["bf16", "f32"]
)
@pytest.fixture(params=["bf16", "f32"])
def model_variant(request: pytest.FixtureRequest) -> str:
return request.param


mlir_path = {
"bf16": THIS_DIR / "assets/text_model/toy/bf16.mlir",
"f32": THIS_DIR / "assets/text_model/toy/f32.mlir"
"f32": THIS_DIR / "assets/text_model/toy/f32.mlir",
}

parameters_path = {
"bf16": THIS_DIR / "assets/text_model/toy/bf16_parameters.irpa",
"f32": THIS_DIR / "assets/text_model/toy/f32_parameters.irpa"
"f32": THIS_DIR / "assets/text_model/toy/f32_parameters.irpa",
}

function_arg0_path = THIS_DIR / "assets/text_model/toy/forward_bs4_arg0_input_ids.irpa"
function_expected_result0 = THIS_DIR / "assets/text_model/toy/forward_bs4_expected_result0_last_hidden_state_f32.irpa"
function_expected_result0 = (
THIS_DIR
/ "assets/text_model/toy/forward_bs4_expected_result0_last_hidden_state_f32.irpa"
)

absolute_tolerance = {
"bf16": 1e-3,
"f32" : 1e-5,
"f32": 1e-5,
}


def compiler_args(device_id: str) -> list[str]:
if device_id == "local-task":
return ["--iree-hal-target-device=llvm-cpu", "--iree-llvmcpu-target-cpu=host"]
Expand All @@ -70,16 +74,21 @@ def compiler_args(device_id: str) -> list[str]:

raise KeyError(f"Compiler args for {device_id} not found")

def compile_and_run(mlir_path: str, compiler_args: list[str], function: str, args: list[np.ndarray]) -> list[np.ndarray]:

def compile_and_run(
mlir_path: str, compiler_args: list[str], function: str, args: list[np.ndarray]
) -> list[np.ndarray]:
iree.compiler.compile_file(
mlir_path,
extra_args=compiler_args,
)


@pytest.fixture(scope="session")
def iree_module(model_variant, device_id) -> iree.runtime.VmModule:
compiler_arguments = compiler_args(device_id)


def device_array_to_host(device_array: iree.runtime.DeviceArray) -> np.ndarray:
def reinterpret_hal_buffer_view_element_type(
buffer_view: iree.runtime.HalBufferView,
Expand Down Expand Up @@ -157,11 +166,12 @@ def assert_text_encoder_state_close(
rtol=0,
)


def test_results_close(model_variant, device_id):
module_buffer = iree.compiler.compile_file(
str(mlir_path[model_variant]),
extra_args=compiler_args(device_id),
)
str(mlir_path[model_variant]),
extra_args=compiler_args(device_id),
)

vm_instance = iree.runtime.VmInstance()
paramIndex = iree.runtime.ParameterIndex()
Expand All @@ -173,13 +183,16 @@ def test_results_close(model_variant, device_id):
device = iree.runtime.get_device(device_id)
hal_module = iree.runtime.create_hal_module(instance=vm_instance, devices=[device])
vm_module = iree.runtime.VmModule.from_buffer(vm_instance, module_buffer)
config=iree.runtime.Config(device=device)
bound_modules = iree.runtime.load_vm_modules(hal_module, parameters_module, vm_module,
config=config)
config = iree.runtime.Config(device=device)
bound_modules = iree.runtime.load_vm_modules(
hal_module, parameters_module, vm_module, config=config
)
module = bound_modules[-1]
result = module.forward_bs4(load_tensor_from_irpa(function_arg0_path))[0]

expected_result = load_tensor_from_irpa(function_expected_result0)
result = device_array_to_host(result).astype(dtype=expected_result.dtype)

assert_text_encoder_state_close(result, expected_result, absolute_tolerance[model_variant])
assert_text_encoder_state_close(
result, expected_result, absolute_tolerance[model_variant]
)
2 changes: 1 addition & 1 deletion sharktank_models/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ ml_dtypes
numpy
pytest
pytest-check
pytest-depends
pytest-dependency
pytest-html
pytest-reportlog
pytest-retry
Expand Down
1 change: 1 addition & 0 deletions sharktank_models/test_suite/benchmarks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ Argument options for the script
| function_run | required | string | The function that the `iree-benchmark-module` will run adnd benchmark |
| benchmark_repetitions | required | float | The number of times the benchmark tests will repeat |
| benchmark_min_warmup_time | required | float | The minimum warm up time for the benchmark test |
| device | required | string | The device that the benchmark tests are running |
| golden_time_tolerance_multiplier | optional | object | An object of tolerance multipliers, where the key is the sku and the value is the multiplier, (ex: `{"mi250": 1.3}`) |
| golden_time_ms | optional | object | An object of golden times, where the key is the sku and the value is the golden time in ms, (ex: `{"mi250": 100}`) |
| golden_dispatch | optional | object | An object of golden dispatches, where the key is the sku and the value is the golden dispatch count, (ex: `{"mi250": 1602}`) |
Expand Down
61 changes: 34 additions & 27 deletions sharktank_models/test_suite/benchmarks/run_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,40 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import subprocess
import os
import os
from pathlib import Path
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="sdxl")
parser.add_argument("--filename", type=str, default="*")
parser.add_argument("--sku", type=str, default="mi300")
parser.add_argument("--rocm-chip", type=str, default="gfx942")
args = parser.parse_args()
model = args.model
filename = args.filename
sku = args.sku
rocm_chip = args.rocm_chip

os.environ['BENCHMARK_MODEL'] = model
os.environ['BENCHMARK_FILE_NAME'] = filename
os.environ['SKU'] = sku
os.environ['ROCM_CHIP'] = rocm_chip

THIS_DIR = Path(__file__).parent

command = [
"pytest",
THIS_DIR / "test_model_benchmark.py",
"--log-cli-level=info",
"--timeout=600",
"--retries=7"
]
subprocess.run(command)

def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="sdxl")
parser.add_argument("--filename", type=str, default="*")
parser.add_argument("--sku", type=str, default="mi300")
parser.add_argument("--rocm-chip", type=str, default="gfx942")
args = parser.parse_args()
model = args.model
filename = args.filename
sku = args.sku
rocm_chip = args.rocm_chip

os.environ["BENCHMARK_MODEL"] = model
os.environ["BENCHMARK_FILE_NAME"] = filename
os.environ["SKU"] = sku
os.environ["ROCM_CHIP"] = rocm_chip

THIS_DIR = Path(__file__).parent

command = [
"pytest",
THIS_DIR / "test_model_benchmark.py",
"--log-cli-level=info",
"--timeout=600",
"--retries=7",
]
subprocess.run(command)
return 0


if __name__ == "__main__":
sys.exit(main())
8 changes: 6 additions & 2 deletions sharktank_models/test_suite/benchmarks/sdxl/clip_rocm.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@
"1x64xi64"
],
"function_run": "encode_prompts",
"benchmark_repetitions": 10,
"benchmark_min_warmup_time": 3.0,
"benchmark_flags": [
"--benchmark_repetitions=10",
"--benchmark_min_warmup_time=3.0",
"--device_allocator=caching"
],
"device": "hip",
"golden_time_tolerance_multiplier": {
"mi250": 1.3,
"mi300": 1.1,
Expand Down
11 changes: 8 additions & 3 deletions sharktank_models/test_suite/benchmarks/sdxl/e2e_rocm.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
"--iree-codegen-gpu-native-math-precision=true",
"--iree-hip-waves-per-eu=2",
"--iree-opt-outer-dim-concat=true",
"--iree-llvmgpu-enable-prefetch"
"--iree-llvmgpu-enable-prefetch",
"--iree-hal-target-backends=rocm"
],
"mlir_file_path": "external_test_files/sdxl_pipeline_bench_f16.mlir",
"modules": [
Expand All @@ -24,8 +25,12 @@
"sdxl_vae"
],
"function_run": "tokens_to_image",
"benchmark_repetitions": 10,
"benchmark_min_warmup_time": 3.0,
"benchmark_flags": [
"--benchmark_repetitions=10",
"--benchmark_min_warmup_time=3.0",
"--device_allocator=caching"
],
"device": "hip",
"golden_time_tolerance_multiplier": {
"mi250": 1.3,
"mi300": 1.1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@
"1xf16"
],
"function_run": "main",
"benchmark_repetitions": 10,
"benchmark_min_warmup_time": 3.0,
"benchmark_flags": [
"--benchmark_repetitions=10",
"--benchmark_min_warmup_time=3.0",
"--device_allocator=caching"
],
"device": "hip",
"golden_time_tolerance_multiplier": {
"mi300": 1.1,
"mi308": 1.1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@
"1xf16"
],
"function_run": "main",
"benchmark_repetitions": 10,
"benchmark_min_warmup_time": 3.0,
"benchmark_flags": [
"--benchmark_repetitions=10",
"--benchmark_min_warmup_time=3.0",
"--device_allocator=caching"
],
"device": "hip",
"golden_time_tolerance_multiplier": {
"mi300": 1.1,
"mi308": 1.1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@
"1xi64"
],
"function_run": "run_forward",
"benchmark_repetitions": 10,
"benchmark_min_warmup_time": 3.0,
"benchmark_flags": [
"--benchmark_repetitions=10",
"--benchmark_min_warmup_time=3.0",
"--device_allocator=caching"
],
"device": "hip",
"golden_time_tolerance_multiplier": {
"mi250": 1.3,
"mi300": 1.1,
Expand Down
8 changes: 6 additions & 2 deletions sharktank_models/test_suite/benchmarks/sdxl/vae_rocm.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@
"1x4x128x128xf16"
],
"function_run": "main",
"benchmark_repetitions": 10,
"benchmark_min_warmup_time": 3.0,
"benchmark_flags": [
"--benchmark_repetitions=10",
"--benchmark_min_warmup_time=3.0",
"--device_allocator=caching"
],
"golden_time_tolerance_multiplier": {
"mi250": 1.3,
"mi300": 1.1,
"mi308": 1.1
},
"device": "hip",
"golden_time_ms": {
"mi250": 310,
"mi300": 75,
Expand Down
Loading

0 comments on commit 1f2169a

Please sign in to comment.