[FEATURE] Enables offline /score for embedding models #12021
base: main
@@ -5,6 +5,7 @@
                    Tuple, Type, Union, cast, overload)

import cloudpickle
import torch
import torch.nn as nn
from tqdm import tqdm
from typing_extensions import TypeVar, deprecated
@@ -996,6 +997,93 @@ def classify(

        return [ClassificationRequestOutput.from_base(item) for item in items]

    def _embedding_score(self, tokenizer, truncate_prompt_tokens,
                         text_1: List[str | TextPrompt | TokensPrompt],
                         text_2: List[str | TextPrompt | TokensPrompt],
                         use_tqdm, lora_request,
                         prompt_adapter_request) -> List[ScoringRequestOutput]:
        encoded_output = self.encode(text_1 + text_2)
[Review comment] Let's pass …
        encoded_output_1 = encoded_output[0:len(text_1)]
        encoded_output_2 = encoded_output[len(text_1):]

        if len(encoded_output_1) == 1:
            encoded_output_1 = encoded_output_1 * len(encoded_output_2)

        output_pairs = [(t1, t2)
                        for t1, t2 in zip(encoded_output_1, encoded_output_2)]

        scores = []
        scorer = torch.nn.CosineSimilarity(0)

        for embed_1, embed_2 in output_pairs:
            pair_score = scorer(embed_1.outputs.data, embed_2.outputs.data)

            if getattr(tokenizer, "pad_token", None) is None:
                tokens = embed_1.prompt_token_ids + embed_2.prompt_token_ids
            else:
                tokens = embed_1.prompt_token_ids + [
                    tokenizer.pad_token_type_id
                ] + embed_2.prompt_token_ids

            scores.append(
                PoolingRequestOutput(
                    request_id=f"{embed_1.request_id}_{embed_2.request_id}",
                    outputs=pair_score,
                    prompt_token_ids=tokens,
                    finished=True))

        items = self.engine_class.validate_outputs(scores,
                                                   PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in items]
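For reference, a minimal standalone sketch (plain PyTorch, made-up 768-dim embedding tensors) of the pairwise scoring that `_embedding_score` performs above: each query/document embedding pair is reduced to a single score along dimension 0, which is what `torch.nn.CosineSimilarity(0)` computes.

```python
import torch

# Stand-ins for embed_1.outputs.data / embed_2.outputs.data (hypothetical embeddings).
query_embeds = [torch.randn(768) for _ in range(3)]
doc_embeds = [torch.randn(768) for _ in range(3)]

scorer = torch.nn.CosineSimilarity(dim=0)

# One scalar score per (query, document) pair, in [-1, 1].
pair_scores = [scorer(q, d) for q, d in zip(query_embeds, doc_embeds)]
print([round(s.item(), 4) for s in pair_scores])
```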

    def _cross_encoding_score(
            self, tokenizer, truncate_prompt_tokens,
            text_1: List[str | TextPrompt | TokensPrompt],
            text_2: List[str | TextPrompt | TokensPrompt], use_tqdm,
            lora_request,
            prompt_adapter_request) -> List[ScoringRequestOutput]:

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "MistralTokenizer not supported for cross-encoding")

        if len(text_1) == 1:
            text_1 = text_1 * len(text_2)

        input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)]

        pooling_params = PoolingParams()

        tokenization_kwargs: Dict[str, Any] = {}
        if truncate_prompt_tokens is not None:
            tokenization_kwargs["truncation"] = True
            tokenization_kwargs["max_length"] = truncate_prompt_tokens

        parsed_prompts = []

        for q, t in input_pairs:
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
            parsed_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]
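For context, a small sketch of what the `tokenizer(text=q, text_pair=t, ...)` call above produces for a cross-encoder: the pair is packed into a single sequence with segment ids, which the loop then wraps in a `TokensPrompt`. The model name matches the comment further down in this file; the printed token ids are illustrative only.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")

enc = tokenizer(text="what is panda?",
                text_pair="The giant panda is a bear species endemic to China.",
                truncation=True,
                max_length=512)

print(enc["input_ids"])           # [CLS] query [SEP] passage [SEP], as token ids
print(enc.get("token_type_ids"))  # 0s for the query segment, 1s for the passage segment
```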

    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],

@@ -1032,6 +1120,7 @@ def score(
            A list of ``ScoringRequestOutput`` objects containing the
            generated scores in the same order as the input prompts.
        """
+
[Review comment, suggested change] Avoid unnecessary line changes.
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

@@ -1047,25 +1136,20 @@ def score(

            raise ValueError(" ".join(messages))

-        if not self.llm_engine.model_config.is_cross_encoder:
-            raise ValueError("Your model does not support cross encoding")
-        if self.llm_engine.model_config.task != "score":
-            raise ValueError("Score API is only enabled for `--task score`")
-
-        tokenizer = self.llm_engine.get_tokenizer()
-
-        if isinstance(tokenizer, MistralTokenizer):
+        if self.llm_engine.model_config.task not in ("embed", "score"):
             raise ValueError(
-                "MistralTokenizer not supported for cross-encoding")
+                "Score API is only enabled for `--task embed or score`")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
+        tokenizer = self.llm_engine.get_tokenizer()

        def ensure_str(prompt: SingletonPrompt):
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
-                                     "supported for cross encoding")
+                                     "supported for scoring")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
@@ -1091,40 +1175,15 @@ def ensure_str(prompt: SingletonPrompt):
        if len(text_2) == 0:
            raise ValueError("At least one text_pair element must be given")

-        if len(text_1) == 1:
-            text_1 = text_1 * len(text_2)
-
-        input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)]
-        pooling_params = PoolingParams()
-
-        tokenization_kwargs: Dict[str, Any] = {}
-        if truncate_prompt_tokens is not None:
-            tokenization_kwargs["truncation"] = True
-            tokenization_kwargs["max_length"] = truncate_prompt_tokens
-
-        parsed_prompts = []
-
-        for q, t in input_pairs:
-            prompt_inputs = tokenizer(text=q,
-                                      text_pair=t,
-                                      **tokenization_kwargs)
-            engine_prompt = TokensPrompt(
-                prompt_token_ids=prompt_inputs["input_ids"],
-                token_type_ids=prompt_inputs.get("token_type_ids"))
-            parsed_prompts.append(engine_prompt)
-
-        self._validate_and_add_requests(
-            prompts=parsed_prompts,
-            params=pooling_params,
-            lora_request=lora_request,
-            prompt_adapter_request=prompt_adapter_request,
-        )
-
-        outputs = self._run_engine(use_tqdm=use_tqdm)
-        items = self.engine_class.validate_outputs(outputs,
-                                                   PoolingRequestOutput)
-
-        return [ScoringRequestOutput.from_base(item) for item in items]
+        if self.llm_engine.model_config.is_cross_encoder:
+            return self._cross_encoding_score(tokenizer,
+                                              truncate_prompt_tokens, text_1,
+                                              text_2, use_tqdm, lora_request,
+                                              prompt_adapter_request)
+        else:
+            return self._embedding_score(tokenizer, truncate_prompt_tokens,
+                                         text_1, text_2, use_tqdm,
+                                         lora_request, prompt_adapter_request)
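To illustrate the effect of the new dispatch, a hedged end-to-end usage sketch: with an embedding model loaded for `--task embed` (the model name here is an assumption, any supported embedding model should do), `LLM.score()` now routes to `_embedding_score` instead of rejecting the model as a non-cross-encoder.

```python
from vllm import LLM

# Assumed embedding model; before this change, the call below raised
# "Your model does not support cross encoding".
llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed")

outputs = llm.score(
    "What is the capital of France?",
    ["The capital of France is Paris.",
     "The capital of Brazil is Brasilia."],
)

for output in outputs:
    print(output.outputs.score)  # one cosine-similarity score per pair
```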

    def start_profile(self) -> None:
        self.llm_engine.start_profile()
[Review comment] Please add type annotations for the parameters.