diff --git a/application/llm/openai.py b/application/llm/openai.py index 34d568549..a132399aa 100644 --- a/application/llm/openai.py +++ b/application/llm/openai.py @@ -5,40 +5,38 @@ class OpenAILLM(BaseLLM): def __init__(self, api_key): global openai - import openai - openai.api_key = api_key - self.api_key = api_key # Save the API key to be used later + from openai import OpenAI + + self.client = OpenAI( + api_key=api_key, + ) + self.api_key = api_key def _get_openai(self): # Import openai when needed import openai - # Set the API key every time you import openai - openai.api_key = self.api_key + return openai def gen(self, model, engine, messages, stream=False, **kwargs): - response = openai.ChatCompletion.create( - model=model, - engine=engine, + response = self.client.chat.completions.create(model=model, messages=messages, stream=stream, - **kwargs - ) + **kwargs) - return response["choices"][0]["message"]["content"] + return response.choices[0].message.content def gen_stream(self, model, engine, messages, stream=True, **kwargs): - response = openai.ChatCompletion.create( - model=model, - engine=engine, + response = self.client.chat.completions.create(model=model, messages=messages, stream=stream, - **kwargs - ) + **kwargs) for line in response: - if "content" in line["choices"][0]["delta"]: - yield line["choices"][0]["delta"]["content"] + # import sys + # print(line.choices[0].delta.content, file=sys.stderr) + if line.choices[0].delta.content is not None: + yield line.choices[0].delta.content class AzureOpenAILLM(OpenAILLM): @@ -48,10 +46,15 @@ def __init__(self, openai_api_key, openai_api_base, openai_api_version, deployme self.api_base = settings.OPENAI_API_BASE, self.api_version = settings.OPENAI_API_VERSION, self.deployment_name = settings.AZURE_DEPLOYMENT_NAME, + from openai import AzureOpenAI + self.client = AzureOpenAI( + api_key=openai_api_key, + api_version=settings.OPENAI_API_VERSION, + api_base=settings.OPENAI_API_BASE, + deployment_name=settings.AZURE_DEPLOYMENT_NAME, + ) def _get_openai(self): openai = super()._get_openai() - openai.api_base = self.api_base - openai.api_version = self.api_version - openai.api_type = "azure" + return openai