Add cohere v2/rerank support #8421

Open · wants to merge 8 commits into `main`
2 changes: 1 addition & 1 deletion docs/my-website/docs/providers/cohere.md
@@ -108,7 +108,7 @@ response = embedding(

### Usage


LiteLLM supports both the v1 and v2 clients for Cohere rerank. By default, the `rerank` endpoint uses the v2 client, but you can use the v1 client by explicitly calling `v1/rerank`.

<Tabs>
<TabItem value="sdk" label="LiteLLM SDK Usage">
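For reference, a minimal sketch of what this default looks like from the SDK (illustrative only; the model name and endpoint values here are assumptions, and the full tabbed examples are truncated in this diff):

```python
import litellm

# Default behavior: the request is sent to Cohere's v2/rerank endpoint.
response = litellm.rerank(
    model="cohere/rerank-english-v3.0",
    query="What is the capital of France?",
    documents=["Paris is the capital of France.", "Berlin is in Germany."],
    top_n=1,
)

# Opting into the deprecated v1 client by targeting v1/rerank explicitly.
response_v1 = litellm.rerank(
    model="cohere/rerank-english-v3.0",
    query="What is the capital of France?",
    documents=["Paris is the capital of France.", "Berlin is in Germany."],
    top_n=1,
    api_base="https://api.cohere.ai/v1/rerank",
)
```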
2 changes: 1 addition & 1 deletion docs/my-website/docs/rerank.md
@@ -111,7 +111,7 @@ curl http://0.0.0.0:4000/rerank \

| Provider | Link to Usage |
|-------------|--------------------|
| Cohere | [Usage](#quick-start) |
| Cohere (v1 + v2 clients) | [Usage](#quick-start) |
| Together AI| [Usage](../docs/providers/togetherai) |
| Azure AI| [Usage](../docs/providers/azure_ai) |
| Jina AI| [Usage](../docs/providers/jina_ai) |
4 changes: 4 additions & 0 deletions litellm/llms/azure_ai/rerank/transformation.py
@@ -18,6 +18,10 @@ class AzureAIRerankConfig(CohereRerankConfig):
    Azure AI Rerank - Follows the same Spec as Cohere Rerank
    """

    # Azure does not support v2/rerank for Cohere yet, so force the v1 client
    def __init__(self):
        super().__init__(api_base=None, present_version_params=[])
        self.uses_v1_client = True

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        if api_base is None:
            raise ValueError(
1 change: 1 addition & 0 deletions litellm/llms/base_llm/rerank/transformation.py
@@ -76,6 +76,7 @@ def map_cohere_rerank_params(
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        max_tokens_per_doc: Optional[int] = None,
    ) -> OptionalRerankParams:
        pass

56 changes: 46 additions & 10 deletions litellm/llms/cohere/rerank/transformation.py
@@ -18,17 +18,36 @@ class CohereRerankConfig(BaseRerankConfig):
    Reference: https://docs.cohere.com/v2/reference/rerank
    """

    def __init__(self) -> None:
        pass

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
    def __init__(self, api_base: Optional[str], present_version_params: List[str]) -> None:
        # Default to the v2 client unless the user specifically uses the v1/rerank endpoint or uses v1-specific params
Reviewer comment (Contributor):
request, instead of adding complexity to /rerank. Create rerankv2/transformation.py. It should inherit from CohereRerankConfig.

This is an established pattern we follow. See
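A rough sketch of the suggested structure (hypothetical: the module path and class name follow the reviewer's suggestion rather than code in this PR, and it assumes the base config keeps its no-argument constructor from main):

```python
# litellm/llms/cohere/rerank_v2/transformation.py (hypothetical)
from typing import Optional

from litellm.llms.cohere.rerank.transformation import CohereRerankConfig


class CohereRerankV2Config(CohereRerankConfig):
    """Cohere /v2/rerank config; shared behavior stays in CohereRerankConfig."""

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        if not api_base:
            return "https://api.cohere.ai/v2/rerank"
        api_base = api_base.rstrip("/")
        if api_base.endswith("/v2/rerank"):
            return api_base
        return f"{api_base}/v2/rerank"
```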

        uses_v1_params = (
            "max_chunks_per_doc" in present_version_params
            and "max_tokens_per_doc" not in present_version_params
        )
        if api_base:
            # Remove trailing slashes and ensure clean base URL
            api_base = api_base.rstrip("/")
            if not api_base.endswith("/v1/rerank"):
                api_base = f"{api_base}/v1/rerank"
            if api_base.endswith("/v1/rerank") or (uses_v1_params and not api_base.endswith("/v2/rerank")):
                self.uses_v1_client = True
                return

        self.uses_v1_client = False

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        if not api_base:
            return "https://api.cohere.ai/v2/rerank"

        api_base = api_base.rstrip("/")

        # Use /v1/rerank if user intentionally uses the deprecated cohere endpoint
        if self.uses_v1_client:
            if api_base.endswith("/v1/rerank"):
                return api_base
            else:
                return f"{api_base}/v1/rerank"

        # By default use the v2 endpoint
        if api_base.endswith("/v2/rerank"):
            return api_base
        return "https://api.cohere.ai/v1/rerank"

        return f"{api_base}/v2/rerank"
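Concretely, the added logic routes requests as follows (a sketch against the new code above, ignoring the two deleted lines interleaved in the diff):

```python
# No api_base and no v1-only params: defaults to the v2 client and endpoint.
cfg = CohereRerankConfig(api_base=None, present_version_params=[])
assert cfg.get_complete_url(None, "rerank-english-v3.0") == "https://api.cohere.ai/v2/rerank"

# An api_base that already ends in /v1/rerank pins the v1 client.
cfg_v1 = CohereRerankConfig(
    api_base="https://api.cohere.ai/v1/rerank", present_version_params=[]
)
assert cfg_v1.uses_v1_client

# A v1-only param (max_chunks_per_doc) plus a generic api_base also selects v1.
cfg_params = CohereRerankConfig(
    api_base="https://api.cohere.ai", present_version_params=["max_chunks_per_doc"]
)
assert cfg_params.get_complete_url("https://api.cohere.ai", "m") == "https://api.cohere.ai/v1/rerank"
```

Note that when no api_base is supplied at all, a v1-only param does not flip the client, since the version check sits inside the `if api_base:` branch.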

    def get_supported_cohere_rerank_params(self, model: str) -> list:
        return [
@@ -38,6 +57,7 @@ def get_supported_cohere_rerank_params(self, model: str) -> list:
            "max_chunks_per_doc",
            "rank_fields",
            "return_documents",
            "max_tokens_per_doc",
        ]

    def map_cohere_rerank_params(
@@ -52,21 +72,29 @@ def map_cohere_rerank_params(
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        max_tokens_per_doc: Optional[int] = None,
    ) -> OptionalRerankParams:
        """
        Map Cohere rerank params

        No mapping required - returns all supported params
        """

        unique_version_params = (
            {"max_chunks_per_doc": max_chunks_per_doc}
            if self.uses_v1_client
            else {"max_tokens_per_doc": max_tokens_per_doc}
        )

        return OptionalRerankParams(
            query=query,
            documents=documents,
            top_n=top_n,
            rank_fields=rank_fields,
            return_documents=return_documents,
            max_chunks_per_doc=max_chunks_per_doc,
            **unique_version_params,
        )
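The conditional dict above keeps the two clients' version-specific params mutually exclusive; as a standalone sketch of that selection:

```python
# v2 client (the default): only max_tokens_per_doc survives the mapping.
uses_v1_client = False
max_chunks_per_doc, max_tokens_per_doc = 10, 512

unique_version_params = (
    {"max_chunks_per_doc": max_chunks_per_doc}
    if uses_v1_client
    else {"max_tokens_per_doc": max_tokens_per_doc}
)
assert unique_version_params == {"max_tokens_per_doc": 512}
```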

    def validate_environment(
        self,
        headers: dict,
@@ -108,15 +136,23 @@ def transform_rerank_request(
            raise ValueError("query is required for Cohere rerank")
        if "documents" not in optional_rerank_params:
            raise ValueError("documents is required for Cohere rerank")

        unique_version_params = (
            {"max_chunks_per_doc": optional_rerank_params.get("max_chunks_per_doc", None)}
            if self.uses_v1_client
            else {"max_tokens_per_doc": optional_rerank_params.get("max_tokens_per_doc", None)}
        )

        rerank_request = RerankRequest(
            model=model,
            query=optional_rerank_params["query"],
            documents=optional_rerank_params["documents"],
            top_n=optional_rerank_params.get("top_n", None),
            rank_fields=optional_rerank_params.get("rank_fields", None),
            return_documents=optional_rerank_params.get("return_documents", None),
            max_chunks_per_doc=optional_rerank_params.get("max_chunks_per_doc", None),
            **unique_version_params,
        )

        return rerank_request.model_dump(exclude_none=True)
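For the default v2 client, the serialized request body then looks roughly like this (a sketch; `exclude_none=True` drops unset fields, including the v1-only `max_chunks_per_doc`):

```python
# Hypothetical output of rerank_request.model_dump(exclude_none=True) for a v2 call
{
    "model": "rerank-english-v3.0",
    "query": "What is the capital of France?",
    "documents": ["Paris is the capital of France.", "Berlin is in Germany."],
    "top_n": 1,
    "max_tokens_per_doc": 512,
}
```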

    def transform_rerank_response(
6 changes: 2 additions & 4 deletions litellm/llms/custom_httpx/llm_http_handler.py
@@ -708,6 +708,7 @@ def rerank(
        model: str,
        custom_llm_provider: str,
        logging_obj: LiteLLMLoggingObj,
        provider_config: BaseRerankConfig,
        optional_rerank_params: OptionalRerankParams,
        timeout: Optional[Union[float, httpx.Timeout]],
        model_response: RerankResponse,
@@ -717,10 +718,7 @@
        api_base: Optional[str] = None,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    ) -> RerankResponse:

        provider_config = ProviderConfigManager.get_provider_rerank_config(
            model=model, provider=litellm.LlmProviders(custom_llm_provider)
        )

        # get config from model, custom llm provider
        headers = provider_config.validate_environment(
            api_key=api_key,
5 changes: 5 additions & 0 deletions litellm/llms/infinity/rerank/transformation.py
@@ -20,6 +20,11 @@


class InfinityRerankConfig(CohereRerankConfig):
    # Pinned to the v1/rerank endpoint so existing integrations don't break.
    # This should be changed once it is confirmed that Infinity supports the v2 endpoint.
    def __init__(self):
        super().__init__(api_base=None, present_version_params=[])
        self.uses_v1_client = True

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        if api_base is None:
            raise ValueError("api_base is required for Infinity rerank")
1 change: 1 addition & 0 deletions litellm/proxy/_types.py
@@ -230,6 +230,7 @@ class LiteLLMRoutes(enum.Enum):
        # rerank
        "/rerank",
        "/v1/rerank",
        "/v2/rerank",
        # realtime
        "/realtime",
        "/v1/realtime",
Expand Down
7 changes: 6 additions & 1 deletion litellm/proxy/rerank_endpoints/endpoints.py
@@ -11,7 +11,12 @@
router = APIRouter()
import asyncio


@router.post(
    "/v2/rerank",
    dependencies=[Depends(user_api_key_auth)],
    response_class=ORJSONResponse,
    tags=["rerank"],
)
@router.post(
    "/v1/rerank",
    dependencies=[Depends(user_api_key_auth)],
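With the route registered, the proxy accepts rerank calls on `/v2/rerank` in addition to `/rerank` and `/v1/rerank`; a quick sketch (assuming a proxy running locally on port 4000 and a hypothetical key):

```python
import requests

response = requests.post(
    "http://0.0.0.0:4000/v2/rerank",
    headers={"Authorization": "Bearer sk-1234"},  # hypothetical proxy key
    json={
        "model": "rerank-english-v3.0",
        "query": "What is the capital of France?",
        "documents": ["Paris is the capital of France.", "Berlin is in Germany."],
        "top_n": 1,
    },
)
print(response.json())
```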
10 changes: 10 additions & 0 deletions litellm/rerank_api/main.py
@@ -83,6 +83,7 @@ def rerank( # noqa: PLR0915
    rank_fields: Optional[List[str]] = None,
    return_documents: Optional[bool] = True,
    max_chunks_per_doc: Optional[int] = None,
    max_tokens_per_doc: Optional[int] = None,
    **kwargs,
) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
    """
@@ -99,6 +100,9 @@
    try:
        _is_async = kwargs.pop("arerank", False) is True
        optional_params = GenericLiteLLMParams(**kwargs)
        # Params that are unique to specific versions of the client for the rerank call
        unique_version_params = {
            "max_chunks_per_doc": max_chunks_per_doc,
            "max_tokens_per_doc": max_tokens_per_doc,
        }
        present_version_params = [k for k, v in unique_version_params.items() if v is not None]

        model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = (
            litellm.get_llm_provider(
@@ -113,6 +117,8 @@
            ProviderConfigManager.get_provider_rerank_config(
                model=model,
                provider=litellm.LlmProviders(_custom_llm_provider),
                api_base=optional_params.api_base,
                present_version_params=present_version_params,
Reviewer comment (Contributor):
can we avoid adding these two new params to get_provider_rerank_config. present_version seems specific to cohere and should be handled with CohereRerankConfig() and CohereRerankv2Config()

            )
        )

@@ -127,6 +133,7 @@
            rank_fields=rank_fields,
            return_documents=return_documents,
            max_chunks_per_doc=max_chunks_per_doc,
            max_tokens_per_doc=max_tokens_per_doc,
            non_default_params=kwargs,
        )

@@ -173,6 +180,7 @@
        response = base_llm_http_handler.rerank(
            model=model,
            custom_llm_provider=_custom_llm_provider,
            provider_config=rerank_provider_config,
            optional_rerank_params=optional_rerank_params,
            logging_obj=litellm_logging_obj,
            timeout=optional_params.timeout,
@@ -194,6 +202,7 @@
            model=model,
            custom_llm_provider=_custom_llm_provider,
            optional_rerank_params=optional_rerank_params,
            provider_config=rerank_provider_config,
            logging_obj=litellm_logging_obj,
            timeout=optional_params.timeout,
            api_key=dynamic_api_key or optional_params.api_key,
@@ -222,6 +231,7 @@
        response = base_llm_http_handler.rerank(
            model=model,
            custom_llm_provider=_custom_llm_provider,
            provider_config=rerank_provider_config,
            optional_rerank_params=optional_rerank_params,
            logging_obj=litellm_logging_obj,
            timeout=optional_params.timeout,
2 changes: 2 additions & 0 deletions litellm/rerank_api/rerank_utils.py
@@ -15,6 +15,7 @@ def get_optional_rerank_params(
    rank_fields: Optional[List[str]] = None,
    return_documents: Optional[bool] = True,
    max_chunks_per_doc: Optional[int] = None,
    max_tokens_per_doc: Optional[int] = None,
    non_default_params: Optional[dict] = None,
) -> OptionalRerankParams:
    return rerank_provider_config.map_cohere_rerank_params(
@@ -27,5 +28,6 @@
        rank_fields=rank_fields,
        return_documents=return_documents,
        max_chunks_per_doc=max_chunks_per_doc,
        max_tokens_per_doc=max_tokens_per_doc,
        non_default_params=non_default_params,
    )
3 changes: 3 additions & 0 deletions litellm/types/rerank.py
@@ -18,6 +18,8 @@ class RerankRequest(BaseModel):
    rank_fields: Optional[List[str]] = None
    return_documents: Optional[bool] = None
    max_chunks_per_doc: Optional[int] = None
    max_tokens_per_doc: Optional[int] = None


class OptionalRerankParams(TypedDict, total=False):
@@ -27,6 +29,7 @@ class OptionalRerankParams(TypedDict, total=False):
    rank_fields: Optional[List[str]]
    return_documents: Optional[bool]
    max_chunks_per_doc: Optional[int]
    max_tokens_per_doc: Optional[int]


class RerankBilledUnits(TypedDict, total=False):
6 changes: 4 additions & 2 deletions litellm/utils.py
@@ -6119,14 +6119,16 @@ def get_provider_embedding_config(
    def get_provider_rerank_config(
        model: str,
        provider: LlmProviders,
        api_base: Optional[str],
        present_version_params: List[str],
    ) -> BaseRerankConfig:
        if litellm.LlmProviders.COHERE == provider:
            return litellm.CohereRerankConfig()
            return litellm.CohereRerankConfig(api_base, present_version_params)
        elif litellm.LlmProviders.AZURE_AI == provider:
            return litellm.AzureAIRerankConfig()
        elif litellm.LlmProviders.INFINITY == provider:
            return litellm.InfinityRerankConfig()
        return litellm.CohereRerankConfig()
        return litellm.CohereRerankConfig(api_base, present_version_params)

    @staticmethod
    def get_provider_audio_transcription_config(
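Following the reviewer comments above, the dispatch could stay parameter-free, with version selection pushed into the Cohere configs themselves (hypothetical sketch; `CohereRerankV2Config` as sketched earlier):

```python
@staticmethod
def get_provider_rerank_config(
    model: str,
    provider: LlmProviders,
) -> BaseRerankConfig:
    if litellm.LlmProviders.COHERE == provider:
        # Default to the v2 config; v1 callers would be routed to
        # CohereRerankConfig, e.g. when the request path is /v1/rerank.
        return litellm.CohereRerankV2Config()
    elif litellm.LlmProviders.AZURE_AI == provider:
        return litellm.AzureAIRerankConfig()
    elif litellm.LlmProviders.INFINITY == provider:
        return litellm.InfinityRerankConfig()
    return litellm.CohereRerankV2Config()
```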