Merge with main branch
alle-pawols committed Nov 5, 2024
2 parents 79acecf + 4e3c893 commit 37ea284
Showing 17 changed files with 369 additions and 182 deletions.
14 changes: 7 additions & 7 deletions README.md
@@ -26,13 +26,13 @@ ___

## Supported Models

| LLM Family | Hosting | Supported LLMs |
|-------------|---------------------|-----------------------------------------|
| GPT(s) | OpenAI endpoint | `gpt-3.5-turbo`, `gpt-4`, `gpt-4-turbo` |
| Google LLMs | VertexAI deployment | `text-bison@001`, `gemini-pro` |
| Llama2 | Azure deployment | `llama2-7b`, `llama2-13b`, `llama2-70b` |
| Mistral | Azure deployment | `Mistral-7b`, `Mixtral-7bx8` |
| Gemma | GCP deployment | `gemma` |
| LLM Family | Hosting | Supported LLMs |
|-------------|---------------------|------------------------------------------------------------------|
| GPT(s)     | OpenAI endpoint     | `gpt-3.5-turbo`, `gpt-4`, `gpt-4-turbo`, `gpt-4o`, `gpt-4o-mini`  |
| Google LLMs | VertexAI deployment | `text-bison@001`, `gemini-pro`, `gemini-flash` |
| Llama2 | Azure deployment | `llama2-7b`, `llama2-13b`, `llama2-70b` |
| Mistral | Azure deployment | `Mistral-7b`, `Mixtral-7bx8` |
| Gemma | GCP deployment | `gemma` |

* Do you already have a subscription to a Cloud Provider for any of the models above? Configure
the model using your credentials and start querying, as in the sketch below.
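
A minimal sketch, assuming an Azure OpenAI subscription (the endpoint, deployment name, and key are placeholders, and `generate` is assumed to be the querying entry point exposed by the model classes):

```python
from allms.domain.configuration import AzureOpenAIConfiguration
from allms.models import AzureOpenAIModel

# Placeholder credentials; replace with your own Azure OpenAI values.
configuration = AzureOpenAIConfiguration(
    api_key="<your-api-key>",
    base_url="https://<your-endpoint>.openai.azure.com/",
    api_version="2023-12-01-preview",
    deployment="<your-deployment-name>",
    model_name="gpt-4",
)

model = AzureOpenAIModel(config=configuration)
responses = model.generate("Say hello!")
```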
2 changes: 1 addition & 1 deletion allms/defaults/vertex_ai.py
@@ -10,7 +10,7 @@ class PalmModelDefaults:


class GeminiModelDefaults:
GCP_MODEL_NAME = "gemini-pro"
GCP_MODEL_NAME = "gemini-1.5-flash-001"
MODEL_TOTAL_MAX_TOKENS = 30720
MAX_OUTPUT_TOKENS = 2048
TEMPERATURE = 0.0
8 changes: 6 additions & 2 deletions allms/domain/configuration.py
@@ -1,5 +1,7 @@
from dataclasses import dataclass
from typing import Optional
from typing import Dict, Optional

from langchain_google_vertexai import HarmBlockThreshold, HarmCategory

from allms.defaults.vertex_ai import GeminiModelDefaults, PalmModelDefaults

@@ -10,7 +12,8 @@ class AzureOpenAIConfiguration:
deployment: str
model_name: str
api_version: str
api_key: str
api_key: Optional[str] = None
azure_ad_token: Optional[str] = None


@dataclass
@@ -26,6 +29,7 @@ class VertexAIConfiguration:
cloud_location: str
palm_model_name: Optional[str] = PalmModelDefaults.GCP_MODEL_NAME
gemini_model_name: Optional[str] = GeminiModelDefaults.GCP_MODEL_NAME
gemini_safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None


class VertexAIModelGardenConfiguration(VertexAIConfiguration):
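
The new `gemini_safety_settings` field takes the `HarmCategory` / `HarmBlockThreshold` enums imported above. A hedged sketch of how a caller might populate it (the project, location, and thresholds are illustrative assumptions, not part of this commit):

```python
from allms.domain.configuration import VertexAIConfiguration
from allms.models import HarmBlockThreshold, HarmCategory, VertexAIGeminiModel

# Illustrative project/location; the thresholds shown are one possible policy.
config = VertexAIConfiguration(
    cloud_project="my-gcp-project",
    cloud_location="us-central1",
    gemini_model_name="gemini-1.5-flash-001",
    gemini_safety_settings={
        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
    },
)

model = VertexAIGeminiModel(config=config)
```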
7 changes: 5 additions & 2 deletions allms/models/__init__.py
@@ -1,5 +1,6 @@
from typing import Type
from typing import Dict, Type

from allms.domain.configuration import HarmBlockThreshold, HarmCategory
from allms.domain.enumerables import AvailableModels
from allms.models.abstract import AbstractModel
from allms.models.azure_llama2 import AzureLlama2Model
@@ -16,11 +17,13 @@
"VertexAIPalmModel",
"VertexAIGeminiModel",
"VertexAIGemmaModel",
"HarmCategory",
"HarmBlockThreshold",
"get_available_models"
]


def get_available_models() -> dict[str, Type[AbstractModel]]:
def get_available_models() -> Dict[str, Type[AbstractModel]]:
return {
AvailableModels.AZURE_OPENAI_MODEL: AzureOpenAIModel,
AvailableModels.AZURE_LLAMA2_MODEL: AzureLlama2Model,
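
The switch from the built-in `dict` to `typing.Dict` keeps the annotation valid on Python versions below 3.9. A quick sketch of querying the registry:

```python
from allms.models import get_available_models

# Enumerate the model classes registered by the library.
for model_name, model_class in get_available_models().items():
    print(f"{model_name}: {model_class.__name__}")
```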
7 changes: 4 additions & 3 deletions allms/models/abstract.py
@@ -33,6 +33,7 @@
from allms.domain.input_data import InputData
from allms.domain.prompt_dto import SummaryOutputClass, KeywordsOutputClass
from allms.domain.response import ResponseData
from allms.models.vertexai_base import GCPInvalidRequestError
from allms.utils.long_text_processing_utils import get_max_allowed_number_of_tokens
from allms.utils.response_parsing_utils import ResponseParser

@@ -168,7 +169,7 @@ async def _build_chat_prompts(
self,
prompt_template_args: dict,
system_prompt: SystemMessagePromptTemplate
) -> list[SystemMessagePromptTemplate | HumanMessagePromptTemplate]:
) -> typing.List[typing.Union[SystemMessagePromptTemplate, HumanMessagePromptTemplate]]:
human_message = HumanMessagePromptTemplate(prompt=PromptTemplate(**prompt_template_args))
if not system_prompt:
return [human_message]
@@ -261,7 +262,7 @@ async def _predict_example(
model_response = None
error_message = f"{IODataConstants.ERROR_MESSAGE_STR}: {invalid_request_error}"

except (InvalidArgument, ValueError, TimeoutError, openai.APIError) as other_error:
except (InvalidArgument, ValueError, TimeoutError, openai.APIError, GCPInvalidRequestError) as other_error:
model_response = None
logger.info(f"Error for id {input_data.id} has occurred. Message: {other_error} ")
error_message = f"{type(other_error).__name__}: {other_error}"
@@ -331,7 +332,7 @@ def _validate_system_prompt(self, system_prompt: typing.Optional[str] = None) ->
raise ValueError(input_exception_message.get_system_prompt_contains_input_variables())

@staticmethod
def _extract_input_variables_from_prompt(prompt: str) -> set[str]:
def _extract_input_variables_from_prompt(prompt: str) -> typing.Set[str]:
input_variables_pattern = r'(?<!\{)\{([^{}]+)\}(?!\})'
input_variables_set = set(re.findall(input_variables_pattern, prompt))
return input_variables_set
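
`typing.Set` likewise keeps the signature valid below Python 3.9; the regex itself is unchanged. As a standalone sanity check of what the pattern extracts (the example prompt is illustrative):

```python
import re

# Matches single-brace {variables} but ignores double-brace {{literals}}.
input_variables_pattern = r'(?<!\{)\{([^{}]+)\}(?!\})'

prompt = "Summarize {text} as {{json}} using at most {max_words} words"
print(set(re.findall(input_variables_pattern, prompt)))
# Prints: {'text', 'max_words'} (set order may vary)
```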
1 change: 1 addition & 0 deletions allms/models/azure_openai.py
@@ -40,6 +40,7 @@ def _create_llm(self) -> AzureChatOpenAI:
model_name=self._config.model_name,
base_url=self._config.base_url,
api_key=self._config.api_key,
azure_ad_token=self._config.azure_ad_token,
temperature=self._temperature,
max_tokens=self._max_output_tokens,
request_timeout=self._request_timeout_s
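
Because `api_key` is now optional, the pass-through above enables Azure AD authentication. A hedged sketch of supplying a token (fetching it via `azure.identity` is an assumption of this example, not part of the commit):

```python
from azure.identity import DefaultAzureCredential

from allms.domain.configuration import AzureOpenAIConfiguration

# Acquire a bearer token for Azure OpenAI instead of using a static API key.
credential = DefaultAzureCredential()
token = credential.get_token("https://cognitiveservices.azure.com/.default")

configuration = AzureOpenAIConfiguration(
    base_url="https://<your-endpoint>.openai.azure.com/",
    deployment="<your-deployment-name>",
    model_name="gpt-4",
    api_version="2023-12-01-preview",
    azure_ad_token=token.token,
)
```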
7 changes: 7 additions & 0 deletions allms/models/vertexai_base.py
@@ -9,6 +9,10 @@
from allms.constants.vertex_ai import VertexModelConstants


class GCPInvalidRequestError(Exception):
pass


class CustomVertexAI(VertexAI):
async def _agenerate(
self,
@@ -31,6 +35,9 @@ def was_response_blocked(generation: Generation) -> bool:
**kwargs
)

if not all(result.generations):
raise GCPInvalidRequestError("The response is empty. It may have been blocked due to content filtering.")

return LLMResult(
generations=(
chain(result.generations)
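
Together with the `abstract.py` change above, which adds `GCPInvalidRequestError` to the caught exceptions, a blocked Gemini response now surfaces as a per-example error message instead of an empty result. A hedged sketch of what a caller would observe (the `generate` call and the response's `error` field follow the library's existing interface, assumed here):

```python
# Hypothetical, already-configured Gemini model; see VertexAIGeminiModel.
responses = gemini_model.generate(prompt="{text}", input_data=input_data_list)

for response in responses:
    if response.error:
        # e.g. "GCPInvalidRequestError: The response is empty. It may have
        # been blocked due to content filtering."
        print(response.error)
```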
38 changes: 35 additions & 3 deletions allms/models/vertexai_gemini.py
@@ -1,12 +1,20 @@
import typing
from asyncio import AbstractEventLoop
from langchain_community.llms.vertexai import VertexAI
from typing import Optional

from langchain_core.prompts import ChatPromptTemplate
from langchain_google_vertexai import VertexAI
from vertexai.preview import tokenization
from vertexai.tokenization._tokenizers import Tokenizer

from allms.defaults.general_defaults import GeneralDefaults
from allms.defaults.vertex_ai import GeminiModelDefaults
from allms.domain.configuration import VertexAIConfiguration
from allms.models.vertexai_base import CustomVertexAI
from allms.domain.input_data import InputData
from allms.models.abstract import AbstractModel
from allms.models.vertexai_base import CustomVertexAI

BASE_GEMINI_MODEL_NAMES = ["gemini-1.0-pro", "gemini-1.5-pro", "gemini-1.5-flash"]


class VertexAIGeminiModel(AbstractModel):
@@ -28,6 +36,8 @@ def __init__(
self._verbose = verbose
self._config = config

self._gcp_tokenizer = self._get_gcp_tokenizer(self._config.gemini_model_name)

super().__init__(
temperature=temperature,
model_total_max_tokens=model_total_max_tokens,
@@ -44,7 +54,29 @@ def _create_llm(self) -> CustomVertexAI:
temperature=self._temperature,
top_p=self._top_p,
top_k=self._top_k,
safety_settings=self._config.gemini_safety_settings,
verbose=self._verbose,
project=self._config.cloud_project,
location=self._config.cloud_location
)
)

def _get_prompt_tokens_number(self, prompt: ChatPromptTemplate, input_data: InputData) -> int:
return self._gcp_tokenizer.count_tokens(
prompt.format_prompt(**input_data.input_mappings).to_string()
).total_tokens

def _get_model_response_tokens_number(self, model_response: typing.Optional[str]) -> int:
if model_response:
return self._gcp_tokenizer.count_tokens(model_response).total_tokens
return 0

@staticmethod
def _get_gcp_tokenizer(model_name) -> Tokenizer:
try:
return tokenization.get_tokenizer_for_model(model_name)
except ValueError:
for base_model_name in BASE_GEMINI_MODEL_NAMES:
if model_name.startswith(base_model_name):
return tokenization.get_tokenizer_for_model(base_model_name)
raise
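
The fallback above lets versioned deployments such as `gemini-1.5-flash-001` resolve to their base model's tokenizer, which is what the `ValueError` branch handles. A small sketch of the underlying call:

```python
from vertexai.preview import tokenization

# "gemini-1.5-flash-001" is not a known tokenizer name, so the helper above
# retries with the matching base name "gemini-1.5-flash".
tokenizer = tokenization.get_tokenizer_for_model("gemini-1.5-flash")
print(tokenizer.count_tokens("Hello, Gemini!").total_tokens)
```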

2 changes: 1 addition & 1 deletion allms/models/vertexai_gemma.py
@@ -1,6 +1,6 @@
from asyncio import AbstractEventLoop

from langchain_community.llms.vertexai import VertexAIModelGarden
from langchain_google_vertexai import VertexAIModelGarden
from typing import Optional

from allms.defaults.general_defaults import GeneralDefaults
2 changes: 1 addition & 1 deletion allms/models/vertexai_palm.py
@@ -1,5 +1,5 @@
from asyncio import AbstractEventLoop
from langchain_community.llms.vertexai import VertexAI
from langchain_google_vertexai import VertexAI
from typing import Optional

from allms.defaults.general_defaults import GeneralDefaults