Support config filtering in ondemand benchmark flow
Github Executorch authored and Guang Yang committed Jan 14, 2025
1 parent 09279e6 · commit e7f89f5
Showing 3 changed files with 75 additions and 37 deletions.
.ci/scripts/gather_benchmark_configs.py: 69 additions & 35 deletions
@@ -45,6 +45,22 @@
 }
 
 
+def extract_all_configs(data, os_name=None):
+    if isinstance(data, dict):
+        # If os_name is specified, include "xplat" and the specified branch
+        include_branches = {"xplat", os_name} if os_name else data.keys()
+        return [
+            v
+            for key, value in data.items()
+            if key in include_branches
+            for v in extract_all_configs(value, os_name)
+        ]
+    elif isinstance(data, list):
+        return [v for item in data for v in extract_all_configs(item, os_name)]
+    else:
+        return [data]
+
+
 def parse_args() -> Any:
     """
     Parse command-line arguments.
@@ -82,6 +98,11 @@ def comma_separated(value: str):
         type=comma_separated,  # Use the custom parser for comma-separated values
         help=f"Comma-separated device names. Available devices: {list(DEVICE_POOLS.keys())}",
     )
+    parser.add_argument(
+        "--configs",
+        type=comma_separated,  # Use the custom parser for comma-separated values
+        help=f"Comma-separated benchmark configs. Available configs: {extract_all_configs(BENCHMARK_CONFIGS)}",
+    )
 
     return parser.parse_args()
 
@@ -123,7 +144,7 @@ def is_valid_huggingface_model_id(model_name: str) -> bool:
     return bool(re.match(pattern, model_name))
 
 
-def get_benchmark_configs() -> Dict[str, Dict]:
+def get_benchmark_configs() -> Dict[str, Dict]:  # noqa: C901
     """
     Gather benchmark configurations for a given set of models on the target operating system and devices.
@@ -153,48 +174,61 @@ def get_benchmark_configs() -> Dict[str, Dict]:
     }
     """
     args = parse_args()
-    target_os = args.os
     devices = args.devices
     models = args.models
+    configs = args.configs
+    target_os = args.os
 
     benchmark_configs = {"include": []}
 
     for model_name in models:
-        configs = []
-        if is_valid_huggingface_model_id(model_name):
-            if model_name.startswith("meta-llama/"):
-                # LLaMA models
-                repo_name = model_name.split("meta-llama/")[1]
-                if "qlora" in repo_name.lower():
-                    configs.append("llama3_qlora")
-                elif "spinquant" in repo_name.lower():
-                    configs.append("llama3_spinquant")
-                else:
-                    configs.append("llama3_fb16")
-                    configs.extend(
-                        [
-                            config
-                            for config in BENCHMARK_CONFIGS.get(target_os, [])
-                            if config.startswith("llama")
-                        ]
-                    )
-            else:
-                # Non-LLaMA models
-                configs.append("hf_xnnpack_fp32")
-        elif model_name in MODEL_NAME_TO_MODEL:
-            # ExecuTorch in-tree non-GenAI models
-            configs.append("xnnpack_q8")
-            configs.extend(
-                [
-                    config
-                    for config in BENCHMARK_CONFIGS.get(target_os, [])
-                    if not config.startswith("llama")
-                ]
-            )
-        else:
-            # Skip unknown models with a warning
-            logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
-            continue
+        if len(configs) > 0:
+            supported_configs = extract_all_configs(BENCHMARK_CONFIGS, target_os)
+            for config in configs:
+                if config not in supported_configs:
+                    raise Exception(
+                        f"Unsupported config '{config}'. Skipping. Available configs: {extract_all_configs(BENCHMARK_CONFIGS, target_os)}"
+                    )
+            print(f"Using provided configs {configs} for model '{model_name}'")
+        else:
+            print(f"Discover all compatible configs for model '{model_name}'")
+            if is_valid_huggingface_model_id(model_name):
+                if model_name.startswith("meta-llama/"):
+                    # LLaMA models
+                    repo_name = model_name.split("meta-llama/")[1]
+                    if "qlora" in repo_name.lower():
+                        configs.append("llama3_qlora")
+                    elif "spinquant" in repo_name.lower():
+                        configs.append("llama3_spinquant")
+                    else:
+                        configs.append("llama3_fb16")
+                        configs.extend(
+                            [
+                                config
+                                for config in BENCHMARK_CONFIGS.get(target_os, [])
+                                if config.startswith("llama")
+                            ]
+                        )
+                else:
+                    # Non-LLaMA models
+                    configs.append("hf_xnnpack_fp32")
+            elif model_name in MODEL_NAME_TO_MODEL:
+                # ExecuTorch in-tree non-GenAI models
+                configs.append("xnnpack_q8")
+                configs.extend(
+                    [
+                        config
+                        for config in BENCHMARK_CONFIGS.get(target_os, [])
+                        if not config.startswith("llama")
+                    ]
+                )
+            else:
+                # Skip unknown models with a warning
+                logging.warning(
+                    f"Unknown or invalid model name '{model_name}'. Skipping."
+                )
+                continue
 
         # Add configurations for each valid device
         for device in devices:
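As a sanity check on the new traversal, here is a minimal runnable sketch of how extract_all_configs flattens the nested config table. The function body is copied from the diff above; the BENCHMARK_CONFIGS literal is an invented stand-in for the real table in the script, shaped the way the function expects (dicts keyed by OS branch such as "xplat", "android", "ios", holding lists of config names):

# Invented stand-in for the real BENCHMARK_CONFIGS table.
BENCHMARK_CONFIGS = {
    "xplat": ["xnnpack_q8", "hf_xnnpack_fp32"],
    "android": ["llama3_fb16", "llama3_qlora"],
    "ios": ["llama3_fb16", "llama3_spinquant", "coreml_fp16"],
}


def extract_all_configs(data, os_name=None):
    if isinstance(data, dict):
        # If os_name is specified, include "xplat" and the specified branch
        include_branches = {"xplat", os_name} if os_name else data.keys()
        return [
            v
            for key, value in data.items()
            if key in include_branches
            for v in extract_all_configs(value, os_name)
        ]
    elif isinstance(data, list):
        return [v for item in data for v in extract_all_configs(item, os_name)]
    else:
        return [data]


print(extract_all_configs(BENCHMARK_CONFIGS, "android"))
# ['xnnpack_q8', 'hf_xnnpack_fp32', 'llama3_fb16', 'llama3_qlora']
print(extract_all_configs(BENCHMARK_CONFIGS))
# With no os_name, every branch is included.

Passing os_name keeps the OS-specific branch plus the shared "xplat" branch, which is what get_benchmark_configs relies on when validating user-supplied configs against supported_configs.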
.github/workflows/android-perf.yml: 3 additions & 1 deletion
@@ -82,11 +82,13 @@ jobs:
           if [ -z "$DEVICES" ]; then
             DEVICES="$CRON_DEFAULT_DEVICES"
           fi
+          BENCHMARK_CONFIGS="${{ inputs.benchmark_configs }}"
           PYTHONPATH="${PWD}" python .ci/scripts/gather_benchmark_configs.py \
             --os "android" \
             --models $MODELS \
-            --devices $DEVICES
+            --devices $DEVICES \
+            --configs $BENCHMARK_CONFIGS
 
   prepare-test-specs:
     runs-on: linux.2xlarge
.github/workflows/apple-perf.yml: 3 additions & 1 deletion
@@ -82,11 +82,13 @@ jobs:
           if [ -z "$DEVICES" ]; then
             DEVICES="$CRON_DEFAULT_DEVICES"
           fi
+          BENCHMARK_CONFIGS="${{ inputs.benchmark_configs }}"
           PYTHONPATH="${PWD}" python .ci/scripts/gather_benchmark_configs.py \
             --os "ios" \
             --models $MODELS \
-            --devices $DEVICES
+            --devices $DEVICES \
+            --configs $BENCHMARK_CONFIGS
 
           echo "benchmark_configs is: ${{ steps.set-parameters.outputs.benchmark_configs }}"
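End to end, both workflows pass the benchmark_configs dispatch input to the script as a single comma-separated string, which argparse splits via the comma_separated helper. The helper's body sits outside this diff, so the one-liner below is an assumption about its behavior rather than the actual implementation:

import argparse


def comma_separated(value: str):
    # Assumed behavior (the helper's body is not shown in this diff):
    # split a comma-separated string into a list of names.
    return value.split(",")


parser = argparse.ArgumentParser()
parser.add_argument("--configs", type=comma_separated)

args = parser.parse_args(["--configs", "llama3_fb16,llama3_spinquant"])
print(args.configs)  # ['llama3_fb16', 'llama3_spinquant']

Each name in the resulting list is then checked against extract_all_configs(BENCHMARK_CONFIGS, target_os), and an unsupported name raises an exception rather than being silently dropped.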
