fix: Select benchmarks CLI option #1261

Merged 8 commits on Oct 3, 2024

mteb/cli.py: 25 additions & 8 deletions
@@ -122,12 +122,16 @@ def run(args: argparse.Namespace) -> None:

    model = mteb.get_model(args.model, args.model_revision, device=device)

-    tasks = mteb.get_tasks(
-        categories=args.categories,
-        task_types=args.task_types,
-        languages=args.languages,
-        tasks=args.tasks,
-    )
+    if args.benchmarks:
+        tasks = mteb.get_benchmarks(names=args.benchmarks)
+    else:
+        tasks = mteb.get_tasks(
+            categories=args.categories,
+            task_types=args.task_types,
+            languages=args.languages,
+            tasks=args.tasks,
+        )

    eval = mteb.MTEB(tasks=tasks)

    encode_kwargs = {}
@@ -153,7 +157,7 @@ def run(args: argparse.Namespace) -> None:


def available_benchmarks(args: argparse.Namespace) -> None:
-    benchmarks = mteb.get_benchmarks()
+    benchmarks = mteb.get_benchmarks(names=args.benchmarks)
    eval = mteb.MTEB(tasks=benchmarks)
    eval.mteb_benchmarks()

@@ -169,6 +173,18 @@ def available_tasks(args: argparse.Namespace) -> None:
    eval.mteb_tasks()


+def add_benchmark_selection_args(parser: argparse.ArgumentParser) -> None:
+    """Adds arguments to the parser for filtering benchmarks by name."""
+    parser.add_argument(
+        "-b",
+        "--benchmarks",
+        nargs="+",
+        type=str,
+        default=None,
+        help="List of benchmarks to be evaluated.",
+    )
+
+
def add_task_selection_args(parser: argparse.ArgumentParser) -> None:
    """Adds arguments to the parser for filtering tasks by type, category, language, and task name."""
    parser.add_argument(
@@ -216,7 +232,7 @@ def add_available_benchmarks_parser(subparsers) -> None:
    parser = subparsers.add_parser(
        "available_benchmarks", help="List the available benchmarks within MTEB"
    )
-    add_task_selection_args(parser)
+    add_benchmark_selection_args(parser)

    parser.set_defaults(func=available_benchmarks)

@@ -232,6 +248,7 @@ def add_run_parser(subparsers) -> None:
    )

    add_task_selection_args(parser)
+    add_benchmark_selection_args(parser)

    parser.add_argument(
        "--device", type=int, default=None, help="Device to use for computation"
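
In Python terms, the new branch in run() corresponds roughly to the library usage below. This is a minimal sketch, not part of the PR: the model name is only a placeholder, and the benchmark name "MTEB(eng)" is borrowed from tests/test_cli.py. On the command line, the same selection is made by passing the new -b/--benchmarks flag to the run subcommand.

import mteb

# Placeholder model; any encoder resolvable by mteb.get_model would do.
model = mteb.get_model("sentence-transformers/all-MiniLM-L6-v2")

# Select a named benchmark directly instead of assembling tasks from
# categories/types/languages.
benchmarks = mteb.get_benchmarks(names=["MTEB(eng)"])

evaluation = mteb.MTEB(tasks=benchmarks)
evaluation.run(model, output_folder="results")
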
mteb/evaluation/MTEB.py: 7 additions & 1 deletion
@@ -6,6 +6,7 @@
import traceback
from copy import copy
from datetime import datetime
+from itertools import chain
from pathlib import Path
from time import time
from typing import Any, Iterable
@@ -52,12 +53,17 @@ def __init__(
            err_logs_path: Path to save error logs.
            kwargs: Additional arguments to be passed to the tasks
        """
+        from mteb.benchmarks import Benchmark
+
        self.deprecation_warning(
            task_types, task_categories, task_langs, tasks, version
        )

        if tasks is not None:
            self._tasks = tasks
+            if isinstance(tasks[0], Benchmark):
+                self.benchmarks = tasks
+                self._tasks = list(chain.from_iterable(tasks))
            assert (
                task_types is None and task_categories is None
            ), "Cannot specify both `tasks` and `task_types`/`task_categories`"
@@ -170,7 +176,7 @@ def _display_tasks(self, task_list, name=None):

    def mteb_benchmarks(self):
        """Get all benchmarks available in the MTEB."""
-        for benchmark in self._tasks:
+        for benchmark in self.benchmarks:
            name = benchmark.name
            self._display_tasks(benchmark.tasks, name=name)

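The __init__ change above relies on each Benchmark being iterable over its tasks, so that chain.from_iterable can flatten a list of benchmarks into one task list. A minimal sketch of that flattening, with illustrative benchmark names and task names borrowed from the test below:

from itertools import chain

import mteb
from mteb.benchmarks import Benchmark

benchmarks = [
    Benchmark(name="bench_a", tasks=mteb.get_tasks(tasks=["STS12"])),
    Benchmark(name="bench_b", tasks=mteb.get_tasks(tasks=["SummEval"])),
]

# Chaining the benchmarks yields their tasks one after another, which is the
# flat list that MTEB stores in self._tasks.
flat_tasks = list(chain.from_iterable(benchmarks))
print(len(flat_tasks))  # 2, one task from each benchmark
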
tests/test_benchmark/test_benchmark.py: 13 additions & 0 deletions
@@ -159,6 +159,19 @@ def test_run_using_benchmark(model: mteb.Encoder):
    )  # we just want to test that it runs


+@pytest.mark.parametrize("model", [MockNumpyEncoder()])
+def test_run_using_list_of_benchmark(model: mteb.Encoder):
+    """Test that a list of benchmark objects can be run using the MTEB class."""
+    bench = [
+        Benchmark(name="test_bench", tasks=mteb.get_tasks(tasks=["STS12", "SummEval"]))
+    ]
+
+    eval = mteb.MTEB(tasks=bench)
+    eval.run(
+        model, output_folder="tests/results", overwrite_results=True
+    )  # we just want to test that it runs
+
+
def test_benchmark_names_must_be_unique():
    import mteb.benchmarks.benchmarks as benchmark_module

tests/test_cli.py: 2 additions & 1 deletion
@@ -28,7 +28,7 @@ def test_available_benchmarks():
    assert result.returncode == 0, "Command failed"
    assert (
        "MTEB(eng)" in result.stdout
-    ), "Sample benchmark MTEB(eng) task not found in available bencmarks"
+    ), "Sample benchmark MTEB(eng) task not found in available benchmarks"


run_task_fixures = [
@@ -65,6 +65,7 @@ def test_run_task(
        co2_tracker=None,
        overwrite=True,
        eval_splits=None,
+        benchmarks=None,
    )

    run(args)
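
These CLI tests construct the argparse Namespace directly, which is why the new benchmarks attribute must now be present (with a None default) for run() to read. End to end, the option would be exercised through the console entry point roughly as sketched below; this is an assumption-laden example, not part of the PR: the -m/--model and --output_folder flags and the model name are assumed to exist already, and only -b/--benchmarks is introduced here.

import subprocess

# Hypothetical invocation of the run subcommand with the new flag; the model
# name and output folder are placeholders.
command = [
    "mteb", "run",
    "-m", "sentence-transformers/all-MiniLM-L6-v2",
    "-b", "MTEB(eng)",
    "--output_folder", "results",
]
result = subprocess.run(command, capture_output=True, text=True)
assert result.returncode == 0, "Command failed"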