Implements score testing for embedding models.
Signed-off-by: Gabriel Marinho <[email protected]>
gmarinho2 committed Jan 21, 2025
1 parent f5864a6 commit b8267c1
Showing 1 changed file with 104 additions and 0 deletions: tests/models/embedding/language/test_scoring.py
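The new tests exercise vLLM's score API for embedding models: unlike the cross-encoders already covered, an embedding model encodes each text into a vector, and the score of a pair is the cosine similarity of the two vectors. Below is a minimal standalone sketch of the entry point that the vllm_runner fixture wraps, assuming vLLM's LLM.score API and its .outputs.score result field; the second sentence of the pair is illustrative:

from vllm import LLM

# Assumption: a local vLLM install. task="embed" loads the model as an
# embedding model, mirroring the vllm_runner(..., task="embed") calls below.
llm = LLM(model="sentence-transformers/all-MiniLM-L12-v2", task="embed")
(output,) = llm.score("What is the capital of France?",
                      "The capital of France is Paris.")  # illustrative pair
print(output.outputs.score)  # cosine similarity of the two embeddings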
@@ -5,12 +5,18 @@
import math

import pytest
import torch
import torch.nn.functional as F

MODELS = [
"cross-encoder/ms-marco-MiniLM-L-6-v2", # Bert
"BAAI/bge-reranker-v2-m3", # Roberta
]

EMBEDDING_MODELS = [
"sentence-transformers/all-MiniLM-L12-v2",
]

TEXTS_1 = [
"What is the capital of France?",
"What is the capital of Germany?",
@@ -87,3 +93,101 @@ def test_llm_N_to_N(vllm_runner, hf_runner, model_name, dtype: str):

assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01)
assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01)


@pytest.fixture(scope="module", params=EMBEDDING_MODELS)
def emb_model_name(request):
yield request.param


@pytest.mark.parametrize("dtype", ["half"])
def test_llm_1_to_1_embedding(vllm_runner, hf_runner, emb_model_name,
dtype: str):

text_pair = [TEXTS_1[0], TEXTS_2[0]]

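    # Reference: encode both texts with sentence-transformers and score the
    # pair by the cosine similarity of the two embeddings.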
with hf_runner(emb_model_name, dtype=dtype,
is_sentence_transformer=True) as hf_model:
hf_embeddings = hf_model.encode(text_pair)
hf_outputs = [
F.cosine_similarity(torch.tensor(hf_embeddings[0]),
torch.tensor(hf_embeddings[1]),
dim=0)
]

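    # vLLM's score() on an embedding model should match the reference above.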
with vllm_runner(emb_model_name,
task="embed",
dtype=dtype,
max_model_len=None) as vllm_model:
vllm_outputs = vllm_model.score(text_pair[0], text_pair[1])

assert len(vllm_outputs) == 1
assert len(hf_outputs) == 1

assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01)


@pytest.mark.parametrize("dtype", ["half"])
def test_llm_1_to_N_embedding(vllm_runner, hf_runner, emb_model_name,
dtype: str):

text_pairs = [
[TEXTS_1[0], TEXTS_2[0]],
[TEXTS_1[0], TEXTS_2[1]],
]

with hf_runner(emb_model_name, dtype=dtype,
is_sentence_transformer=True) as hf_model:
hf_embeddings = [
hf_model.encode(text_pair) for text_pair in text_pairs
]
hf_outputs = [
F.cosine_similarity(torch.tensor(encoded_pair[0]),
torch.tensor(encoded_pair[1]),
dim=0) for encoded_pair in hf_embeddings
]

with vllm_runner(emb_model_name,
task="embed",
dtype=dtype,
max_model_len=None) as vllm_model:
vllm_outputs = vllm_model.score(TEXTS_1[0], TEXTS_2)

assert len(vllm_outputs) == 2
assert len(hf_outputs) == 2

assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01)
assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01)


@pytest.mark.parametrize("dtype", ["half"])
def test_llm_N_to_N_embedding(vllm_runner, hf_runner, emb_model_name,
dtype: str):

text_pairs = [
[TEXTS_1[0], TEXTS_2[0]],
[TEXTS_1[1], TEXTS_2[1]],
]

with hf_runner(emb_model_name, dtype=dtype,
is_sentence_transformer=True) as hf_model:
hf_embeddings = [
hf_model.encode(text_pair) for text_pair in text_pairs
]
hf_outputs = [
F.cosine_similarity(torch.tensor(encoded_pair[0]),
torch.tensor(encoded_pair[1]),
dim=0) for encoded_pair in hf_embeddings
]

with vllm_runner(emb_model_name,
task="embed",
dtype=dtype,
max_model_len=None) as vllm_model:
vllm_outputs = vllm_model.score(TEXTS_1, TEXTS_2)

assert len(vllm_outputs) == 2
assert len(hf_outputs) == 2

assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01)
assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01)
