-
Notifications
You must be signed in to change notification settings - Fork 894
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add ZhipuAI provider integration to the python package. * Add `aisuite/providers/zhipuai_provider.py` implementing the `ZhipuaiProvider` class using the `zhipuai` library. * Update `ProviderFactory` in `aisuite/provider.py` to include ZhipuAI as a supported provider. * Add test cases for `ZhipuaiProvider` in `tests/client/test_client.py` and mock the `ZhipuaiProvider` in the test cases. * Update `README.md` to mention ZhipuAI as a supported provider and provide an example of using the ZhipuAI provider. --- For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/andrewyng/aisuite?shareId=XXXX-XXXX-XXXX-XXXX).
- Loading branch information
Showing
3 changed files
with
61 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import zhipuai | ||
from aisuite.provider import Provider | ||
from aisuite.framework import ChatCompletionResponse | ||
|
||
|
||
class ZhipuaiProvider(Provider): | ||
def __init__(self, **config): | ||
""" | ||
Initialize the ZhipuAI provider with the given configuration. | ||
Pass the entire configuration dictionary to the ZhipuAI client constructor. | ||
""" | ||
self.client = zhipuai.ZhipuAI(**config) | ||
|
||
def chat_completions_create(self, model, messages, **kwargs): | ||
return self.normalize_response( | ||
self.client.chat.completions.create(model=model, messages=messages, **kwargs) | ||
) | ||
|
||
def normalize_response(self, response): | ||
"""Normalize the response from the ZhipuAI API to match OpenAI's response format.""" | ||
normalized_response = ChatCompletionResponse() | ||
normalized_response.choices[0].message.content = response.choices[0].message | ||
return normalized_response |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters