Skip to content

Commit

Permalink
chore(backend): fix issues from rebase
Browse files — browse the repository at this point in the history
  • Loading branch information
ezawadski committed Jan 15, 2025
1 parent 74ed897 commit d341c24
Show file tree
Hide file tree
Showing 12 changed files with 129 additions and 116 deletions.
20 changes: 10 additions & 10 deletions src/backend/model_deployments/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,16 @@ def __init__(self, **kwargs: Any):
base_url=self.chat_endpoint_url, api_key=self.api_key
)

@classmethod
def name(cls) -> str:
@staticmethod
def name() -> str:
return "Azure"

@classmethod
def env_vars(cls) -> List[str]:
@staticmethod
def env_vars() -> list[str]:
return [AZURE_API_KEY_ENV_VAR, AZURE_CHAT_URL_ENV_VAR]

@classmethod
def rerank_enabled(cls) -> bool:
@staticmethod
def rerank_enabled() -> bool:
return False

@classmethod
Expand All @@ -62,14 +62,14 @@ def list_models(cls) -> list[str]:

return cls.DEFAULT_MODELS

@classmethod
def is_available(cls) -> bool:
@staticmethod
def is_available() -> bool:
return (
AzureDeployment.default_api_key is not None
and AzureDeployment.default_chat_endpoint_url is not None
)

async def invoke_chat(self, chat_request: CohereChatRequest) -> Any:
async def invoke_chat(self, chat_request: CohereChatRequest, **kwargs) -> Any:
response = self.client.chat(
**chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}),
)
Expand All @@ -86,6 +86,6 @@ async def invoke_chat_stream(
yield to_dict(event)

async def invoke_rerank(
self, query: str, documents: list[str], ctx: Context
self, query: str, documents: list[str], ctx: Context, **kwargs
) -> Any:
return None
23 changes: 12 additions & 11 deletions src/backend/model_deployments/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator
from typing import Any

from backend.config.settings import Settings
from backend.schemas.cohere_chat import CohereChatRequest
Expand All @@ -25,31 +25,32 @@ def __init__(self, db_id=None, **kwargs: Any):
def id(cls) -> str:
return cls.db_id if cls.db_id else cls.name().replace(" ", "_").lower()

@classmethod
@staticmethod
@abstractmethod
def name(cls) -> str: ...
def name() -> str: ...

@classmethod
@staticmethod
@abstractmethod
def env_vars(cls) -> List[str]: ...
def env_vars() -> list[str]: ...

@classmethod
@staticmethod
@abstractmethod
def rerank_enabled(cls) -> bool: ...
def rerank_enabled() -> bool: ...

@classmethod
@abstractmethod
def list_models(cls) -> list[str]: ...

@classmethod
@staticmethod
@abstractmethod
def is_available(cls) -> bool: ...
def is_available() -> bool: ...

@classmethod
def is_community(cls) -> bool:
return False

def config(cls) -> Dict[str, Any]:
@classmethod
def config(cls) -> dict[str, Any]:
config = Settings().get(f"deployments.{cls.id()}")
config_dict = {} if not config else dict(config)
for key, value in config_dict.items():
Expand Down Expand Up @@ -78,7 +79,7 @@ async def invoke_chat(
@abstractmethod
async def invoke_chat_stream(
self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any
) -> AsyncGenerator[Any, Any]: ...
) -> Any: ...

@abstractmethod
async def invoke_rerank(
Expand Down
20 changes: 10 additions & 10 deletions src/backend/model_deployments/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,21 @@ def __init__(self, **kwargs: Any):
),
)

@classmethod
def name(cls) -> str:
@staticmethod
def name() -> str:
return "Bedrock"

@classmethod
def env_vars(cls) -> List[str]:
@staticmethod
def env_vars() -> list[str]:
return [
BEDROCK_ACCESS_KEY_ENV_VAR,
BEDROCK_SECRET_KEY_ENV_VAR,
BEDROCK_SESSION_TOKEN_ENV_VAR,
BEDROCK_REGION_NAME_ENV_VAR,
]

@classmethod
def rerank_enabled(cls) -> bool:
@staticmethod
def rerank_enabled() -> bool:
return False

@classmethod
Expand All @@ -66,16 +66,16 @@ def list_models(cls) -> list[str]:

return cls.DEFAULT_MODELS

@classmethod
def is_available(cls) -> bool:
@staticmethod
def is_available() -> bool:
return (
BedrockDeployment.access_key is not None
and BedrockDeployment.secret_access_key is not None
and BedrockDeployment.session_token is not None
and BedrockDeployment.region_name is not None
)

async def invoke_chat(self, chat_request: CohereChatRequest) -> Any:
async def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
# bedrock accepts a subset of the chat request fields
bedrock_chat_req = chat_request.model_dump(
exclude={"tools", "conversation_id", "model", "stream"}, exclude_none=True
Expand All @@ -101,6 +101,6 @@ async def invoke_chat_stream(
yield to_dict(event)

async def invoke_rerank(
self, query: str, documents: list[str], ctx: Context
self, query: str, documents: list[str], ctx: Context, **kwargs: Any
) -> Any:
return None
18 changes: 9 additions & 9 deletions src/backend/model_deployments/cohere_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,16 @@ def __init__(self, **kwargs: Any):
)
self.client = cohere.Client(api_key, client_name=self.client_name)

@classmethod
def name(cls) -> str:
@staticmethod
def name() -> str:
return "Cohere Platform"

@classmethod
def env_vars(cls) -> List[str]:
@staticmethod
def env_vars() -> list[str]:
return [COHERE_API_KEY_ENV_VAR]

@classmethod
def rerank_enabled(cls) -> bool:
@staticmethod
def rerank_enabled() -> bool:
return True

@classmethod
Expand All @@ -64,12 +64,12 @@ def list_models(cls) -> list[str]:
models = response.json()["models"]
return [model["name"] for model in models if model.get("endpoints") and "chat" in model["endpoints"]]

@classmethod
def is_available(cls) -> bool:
@staticmethod
def is_available() -> bool:
return CohereDeployment.api_key is not None

async def invoke_chat(
self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any
self, chat_request: CohereChatRequest, **kwargs: Any
) -> Any:
response = self.client.chat(
**chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}),
Expand Down
18 changes: 9 additions & 9 deletions src/backend/model_deployments/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ def __init__(self, **kwargs: Any):
"ContentType": "application/json",
}

@classmethod
def name(cls) -> str:
@staticmethod
def name() -> str:
return "SageMaker"

@classmethod
def env_vars(cls) -> List[str]:
@staticmethod
def env_vars() -> list[str]:
return [
SAGE_MAKER_ACCESS_KEY_ENV_VAR,
SAGE_MAKER_SECRET_KEY_ENV_VAR,
Expand All @@ -79,8 +79,8 @@ def env_vars(cls) -> List[str]:
SAGE_MAKER_ENDPOINT_NAME_ENV_VAR,
]

@classmethod
def rerank_enabled(cls) -> bool:
@staticmethod
def rerank_enabled() -> bool:
return False

@classmethod
Expand All @@ -90,8 +90,8 @@ def list_models(cls) -> list[str]:

return cls.DEFAULT_MODELS

@classmethod
def is_available(cls) -> bool:
@staticmethod
def is_available() -> bool:
return (
SageMakerDeployment.region_name is not None
and SageMakerDeployment.aws_access_key_id is not None
Expand Down Expand Up @@ -121,7 +121,7 @@ async def invoke_chat_stream(
yield stream_event

async def invoke_rerank(
self, query: str, documents: list[str], ctx: Context
self, query: str, documents: list[str], ctx: Context, **kwargs
) -> Any:
return None

Expand Down
20 changes: 10 additions & 10 deletions src/backend/model_deployments/single_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,16 @@ def __init__(self, **kwargs: Any):
base_url=self.url, client_name=self.client_name, api_key="none"
)

@classmethod
def name(cls) -> str:
@staticmethod
def name() -> str:
return "Single Container"

@classmethod
def env_vars(cls) -> List[str]:
@staticmethod
def env_vars() -> list[str]:
return [SC_URL_ENV_VAR, SC_MODEL_ENV_VAR]

@classmethod
def rerank_enabled(cls) -> bool:
@staticmethod
def rerank_enabled() -> bool:
return SingleContainerDeployment.default_model.startswith("rerank")

@classmethod
Expand All @@ -52,14 +52,14 @@ def list_models(cls) -> list[str]:

return [SingleContainerDeployment.default_model]

@classmethod
def is_available(cls) -> bool:
@staticmethod
def is_available() -> bool:
return (
SingleContainerDeployment.default_model is not None
and SingleContainerDeployment.default_url is not None
)

async def invoke_chat(self, chat_request: CohereChatRequest) -> Any:
async def invoke_chat(self, chat_request: CohereChatRequest, **kwargs) -> Any:
response = self.client.chat(
**chat_request.model_dump(
exclude={"stream", "file_ids", "model", "agent_id"}
Expand All @@ -80,7 +80,7 @@ async def invoke_chat_stream(
yield to_dict(event)

async def invoke_rerank(
self, query: str, documents: list[str], ctx: Context
self, query: str, documents: list[str], ctx: Context, **kwargs
) -> Any:
return self.client.rerank(
query=query, documents=documents, model=DEFAULT_RERANK_MODEL
Expand Down
Loading

0 comments on commit d341c24

Please sign in to comment.