Adding override for invoke_llm to ChatOpenAI to catch param conversion
* If we get a BadRequestError that mentions max_completion_tokens, then
  we will monkey patch ChatOpenAI to prevent the automatic conversion of
  max_tokens to max_completion_tokens (a sketch of the pattern follows).
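
  A rough, self-contained sketch of the pattern (toy Base/Sub classes stand
  in for ChatOpenAI and its parent class here; this is not the real
  langchain-openai code): replacing a method on the subclass with a plain
  function that delegates through super(Sub, self) restores the parent's
  behavior for every instance.

    from typing import Any

    class Base:
        def _default_params(self) -> dict[str, Any]:
            return {"max_tokens": 256}

    class Sub(Base):
        def _default_params(self) -> dict[str, Any]:
            # the unwanted rewrite this commit works around
            params = super()._default_params()
            params["max_completion_tokens"] = params.pop("max_tokens")
            return params

    def _default_params(self: Sub) -> dict[str, Any]:
        # skip Sub's override and call Base's implementation directly
        return super(Sub, self)._default_params()

    Sub._default_params = _default_params  # patching the class affects every instance
    print(Sub()._default_params())  # {'max_tokens': 256}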

Signed-off-by: Shawn Hurley <[email protected]>
shawn-hurley committed Feb 28, 2025
1 parent 74b4682 commit 5db5080
Showing 1 changed file with 53 additions and 0 deletions.
53 changes: 53 additions & 0 deletions kai/llm_interfacing/model_provider.py
@@ -15,6 +15,7 @@
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_ollama import ChatOllama
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from openai import BadRequestError
from opentelemetry import trace
from pydantic.v1.utils import deep_update

@@ -255,6 +256,9 @@ def prepare_model_args(


class ModelProviderChatOpenAI(ModelProvider):

    is_monkey_patched: bool = False

    def __init__(self, config: KaiConfigModels, demo_mode: bool, cache: Cache | None):
        super().__init__(
            config=config,
@@ -281,6 +285,55 @@ def prepare_model_args(

        return model_args, model_id

    @override
    def invoke_llm(
        self,
        input: LanguageModelInput,
        config: RunnableConfig | None = None,
        configurable_fields: dict[str, Any] | None = None,
        stop: list[str] | None = None,
        do_continuation: bool = True,
        **kwargs: Any,
    ) -> BaseMessage:
        try:
            return super().invoke_llm(
                input, config, configurable_fields, stop, do_continuation, **kwargs
            )
        except BadRequestError as e:
            # If we already tried to monkey patch, then some other config is broken.
            if self.is_monkey_patched:
                raise e

            # Replacements that delegate past ChatOpenAI to its parent class,
            # skipping ChatOpenAI's rewrite of max_tokens into max_completion_tokens.
            @property  # type: ignore[misc]
            def _default_params(self: ChatOpenAI) -> dict[str, Any]:
                return super(ChatOpenAI, self)._default_params

            def _get_request_payload(
                self: ChatOpenAI,
                input_: LanguageModelInput,
                *,
                stop: list[str] | None = None,
                **kwargs: Any,
            ) -> dict:  # type: ignore[type-arg]
                return super(ChatOpenAI, self)._get_request_payload(
                    input_, stop=stop, **kwargs
                )

            if "max_completion_tokens" in e.message:
                LOG.debug(
                    f"got error: {e} - attempting to monkey patch to prevent conversion of max_tokens"
                )
                # Patch the class so every ChatOpenAI instance keeps max_tokens as-is,
                # then retry the request once.
                ChatOpenAI._default_params = _default_params  # type: ignore[method-assign]
                ChatOpenAI._get_request_payload = _get_request_payload  # type: ignore[method-assign]
                self.is_monkey_patched = True
                return super().invoke_llm(
                    input, config, configurable_fields, stop, do_continuation, **kwargs
                )
            else:
                raise e
        except Exception as e:
            raise e


class ModelProviderChatBedrock(ModelProvider):
    def __init__(self, config: KaiConfigModels, demo_mode: bool, cache: Cache | None):
