From d5a83c68b0746fe16c8916b7058967f68e8c4c69 Mon Sep 17 00:00:00 2001 From: mikelarg Date: Wed, 29 Jan 2025 17:25:31 +0300 Subject: [PATCH 1/2] feat: added x_headers metadata to response --- src/gigachat/api/assistants/get_assistants.py | 17 +++------------ .../api/assistants/post_assistant_delete.py | 17 +++------------ .../assistants/post_assistant_files_delete.py | 17 +++------------ .../api/assistants/post_assistant_modify.py | 17 +++------------ .../api/assistants/post_assistants.py | 17 +++------------ src/gigachat/api/get_balance.py | 17 +++------------ src/gigachat/api/get_image.py | 5 +++-- src/gigachat/api/get_model.py | 17 +++------------ src/gigachat/api/get_models.py | 17 +++------------ src/gigachat/api/post_auth.py | 17 +++------------ src/gigachat/api/post_chat.py | 17 +++------------ src/gigachat/api/post_embeddings.py | 17 +++------------ src/gigachat/api/post_files.py | 17 +++------------ src/gigachat/api/post_functions_convert.py | 17 +++------------ src/gigachat/api/post_token.py | 17 +++------------ src/gigachat/api/stream_chat.py | 6 +++++- src/gigachat/api/threads/get_threads.py | 17 +++------------ .../api/threads/get_threads_messages.py | 17 +++------------ src/gigachat/api/threads/get_threads_run.py | 17 +++------------ .../api/threads/post_thread_messages_rerun.py | 17 +++------------ .../post_thread_messages_rerun_stream.py | 6 +++++- .../api/threads/post_thread_messages_run.py | 17 +++------------ .../post_thread_messages_run_stream.py | 6 +++++- .../api/threads/post_threads_messages.py | 17 +++------------ .../api/threads/post_threads_retrieve.py | 17 +++------------ src/gigachat/api/threads/post_threads_run.py | 17 +++------------ src/gigachat/api/utils.py | 21 +++++++++++++++++++ src/gigachat/client.py | 2 +- src/gigachat/models/__init__.py | 2 ++ src/gigachat/models/access_token.py | 4 ++-- .../models/assistants/assistant_delete.py | 4 ++-- .../assistants/assistant_file_delete.py | 4 ++-- src/gigachat/models/assistants/assistants.py | 4 ++-- .../models/assistants/create_assistant.py | 4 ++-- src/gigachat/models/balance.py | 3 ++- src/gigachat/models/chat_completion.py | 5 +++-- src/gigachat/models/chat_completion_chunk.py | 5 +++-- src/gigachat/models/embeddings.py | 5 +++-- src/gigachat/models/image.py | 4 ++-- src/gigachat/models/model.py | 5 +++-- src/gigachat/models/models.py | 5 +++-- src/gigachat/models/open_api_functions.py | 6 +++--- .../models/threads/thread_completion.py | 7 ++++--- .../models/threads/thread_completion_chunk.py | 7 ++++--- .../models/threads/thread_messages.py | 4 ++-- .../threads/thread_messages_response.py | 4 ++-- .../models/threads/thread_run_response.py | 4 ++-- .../models/threads/thread_run_result.py | 4 ++-- src/gigachat/models/threads/threads.py | 4 ++-- src/gigachat/models/token.py | 4 ++-- src/gigachat/models/tokens_count.py | 5 +++-- src/gigachat/models/uploaded_file.py | 5 +++-- src/gigachat/models/with_x_headers.py | 8 +++++++ tests/unit_tests/gigachat/test_client.py | 14 +++++++++---- 54 files changed, 184 insertions(+), 366 deletions(-) create mode 100644 src/gigachat/models/with_x_headers.py diff --git a/src/gigachat/api/assistants/get_assistants.py b/src/gigachat/api/assistants/get_assistants.py index 6faf355..584ab2d 100644 --- a/src/gigachat/api/assistants/get_assistants.py +++ b/src/gigachat/api/assistants/get_assistants.py @@ -1,10 +1,8 @@ -from http import HTTPStatus from typing import Any, Dict, Optional import httpx -from gigachat.api.utils import build_headers -from gigachat.exceptions import AuthenticationError, ResponseError +from gigachat.api.utils import build_headers, build_response from gigachat.models.assistants import Assistants @@ -24,15 +22,6 @@ def _get_kwargs( return params -def _build_response(response: httpx.Response) -> Assistants: - if response.status_code == HTTPStatus.OK: - return Assistants(**response.json()) - elif response.status_code == HTTPStatus.UNAUTHORIZED: - raise AuthenticationError(response.url, response.status_code, response.content, response.headers) - else: - raise ResponseError(response.url, response.status_code, response.content, response.headers) - - def sync( client: httpx.Client, *, @@ -42,7 +31,7 @@ def sync( """Возвращает массив объектов с данными доступных ассистентов""" kwargs = _get_kwargs(assistant_id=assistant_id, access_token=access_token) response = client.request(**kwargs) - return _build_response(response) + return build_response(response, Assistants) async def asyncio( @@ -54,4 +43,4 @@ async def asyncio( """Возвращает массив объектов с данными доступных ассистентов""" kwargs = _get_kwargs(assistant_id=assistant_id, access_token=access_token) response = await client.request(**kwargs) - return _build_response(response) + return build_response(response, Assistants) diff --git a/src/gigachat/api/assistants/post_assistant_delete.py b/src/gigachat/api/assistants/post_assistant_delete.py index 4871d1f..39a70b6 100644 --- a/src/gigachat/api/assistants/post_assistant_delete.py +++ b/src/gigachat/api/assistants/post_assistant_delete.py @@ -1,10 +1,8 @@ -from http import HTTPStatus from typing import Any, Dict, Optional import httpx -from gigachat.api.utils import build_headers -from gigachat.exceptions import AuthenticationError, ResponseError +from gigachat.api.utils import build_headers, build_response from gigachat.models.assistants import AssistantDelete @@ -25,15 +23,6 @@ def _get_kwargs( } -def _build_response(response: httpx.Response) -> AssistantDelete: - if response.status_code == HTTPStatus.OK: - return AssistantDelete(**response.json()) - elif response.status_code == HTTPStatus.UNAUTHORIZED: - raise AuthenticationError(response.url, response.status_code, response.content, response.headers) - else: - raise ResponseError(response.url, response.status_code, response.content, response.headers) - - def sync( client: httpx.Client, *, @@ -42,7 +31,7 @@ def sync( ) -> AssistantDelete: kwargs = _get_kwargs(assistant_id=assistant_id, access_token=access_token) response = client.request(**kwargs) - return _build_response(response) + return build_response(response, AssistantDelete) async def asyncio( @@ -53,4 +42,4 @@ async def asyncio( ) -> AssistantDelete: kwargs = _get_kwargs(assistant_id=assistant_id, access_token=access_token) response = await client.request(**kwargs) - return _build_response(response) + return build_response(response, AssistantDelete) diff --git a/src/gigachat/api/assistants/post_assistant_files_delete.py b/src/gigachat/api/assistants/post_assistant_files_delete.py index 75ae742..e4478d8 100644 --- a/src/gigachat/api/assistants/post_assistant_files_delete.py +++ b/src/gigachat/api/assistants/post_assistant_files_delete.py @@ -1,10 +1,8 @@ -from http import HTTPStatus from typing import Any, Dict, Optional import httpx -from gigachat.api.utils import build_headers -from gigachat.exceptions import AuthenticationError, ResponseError +from gigachat.api.utils import build_headers, build_response from gigachat.models.assistants import AssistantFileDelete @@ -27,15 +25,6 @@ def _get_kwargs( } -def _build_response(response: httpx.Response) -> AssistantFileDelete: - if response.status_code == HTTPStatus.OK: - return AssistantFileDelete(**response.json()) - elif response.status_code == HTTPStatus.UNAUTHORIZED: - raise AuthenticationError(response.url, response.status_code, response.content, response.headers) - else: - raise ResponseError(response.url, response.status_code, response.content, response.headers) - - def sync( client: httpx.Client, *, @@ -45,7 +34,7 @@ def sync( ) -> AssistantFileDelete: kwargs = _get_kwargs(assistant_id=assistant_id, file_id=file_id, access_token=access_token) response = client.request(**kwargs) - return _build_response(response) + return build_response(response, AssistantFileDelete) async def asyncio( @@ -57,4 +46,4 @@ async def asyncio( ) -> AssistantFileDelete: kwargs = _get_kwargs(assistant_id=assistant_id, file_id=file_id, access_token=access_token) response = await client.request(**kwargs) - return _build_response(response) + return build_response(response, AssistantFileDelete) diff --git a/src/gigachat/api/assistants/post_assistant_modify.py b/src/gigachat/api/assistants/post_assistant_modify.py index 2fc93de..9aab500 100644 --- a/src/gigachat/api/assistants/post_assistant_modify.py +++ b/src/gigachat/api/assistants/post_assistant_modify.py @@ -1,10 +1,8 @@ -from http import HTTPStatus from typing import Any, Dict, List, Optional import httpx -from gigachat.api.utils import build_headers -from gigachat.exceptions import AuthenticationError, ResponseError +from gigachat.api.utils import build_headers, build_response from gigachat.models import Function from gigachat.models.assistants import Assistant @@ -41,15 +39,6 @@ def _get_kwargs( } -def _build_response(response: httpx.Response) -> Assistant: - if response.status_code == HTTPStatus.OK: - return Assistant(**response.json()) - elif response.status_code == HTTPStatus.UNAUTHORIZED: - raise AuthenticationError(response.url, response.status_code, response.content, response.headers) - else: - raise ResponseError(response.url, response.status_code, response.content, response.headers) - - def sync( client: httpx.Client, *, @@ -73,7 +62,7 @@ def sync( access_token=access_token, ) response = client.request(**kwargs) - return _build_response(response) + return build_response(response, Assistant) async def asyncio( @@ -99,4 +88,4 @@ async def asyncio( access_token=access_token, ) response = await client.request(**kwargs) - return _build_response(response) + return build_response(response, Assistant) diff --git a/src/gigachat/api/assistants/post_assistants.py b/src/gigachat/api/assistants/post_assistants.py index 0d5936d..1571a61 100644 --- a/src/gigachat/api/assistants/post_assistants.py +++ b/src/gigachat/api/assistants/post_assistants.py @@ -1,10 +1,8 @@ -from http import HTTPStatus from typing import Any, Dict, List, Optional import httpx -from gigachat.api.utils import build_headers -from gigachat.exceptions import AuthenticationError, ResponseError +from gigachat.api.utils import build_headers, build_response from gigachat.models import Function from gigachat.models.assistants import CreateAssistant @@ -40,15 +38,6 @@ def _get_kwargs( } -def _build_response(response: httpx.Response) -> CreateAssistant: - if response.status_code == HTTPStatus.OK: - return CreateAssistant(**response.json()) - elif response.status_code == HTTPStatus.UNAUTHORIZED: - raise AuthenticationError(response.url, response.status_code, response.content, response.headers) - else: - raise ResponseError(response.url, response.status_code, response.content, response.headers) - - def sync( client: httpx.Client, *, @@ -73,7 +62,7 @@ def sync( access_token=access_token, ) response = client.request(**kwargs) - return _build_response(response) + return build_response(response, CreateAssistant) async def asyncio( @@ -100,4 +89,4 @@ async def asyncio( access_token=access_token, ) response = await client.request(**kwargs) - return _build_response(response) + return build_response(response, CreateAssistant) diff --git a/src/gigachat/api/get_balance.py b/src/gigachat/api/get_balance.py index cfaf38d..cf0600f 100644 --- a/src/gigachat/api/get_balance.py +++ b/src/gigachat/api/get_balance.py @@ -1,10 +1,8 @@ -from http import HTTPStatus from typing import Any, Dict, Optional import httpx -from gigachat.api.utils import build_headers -from gigachat.exceptions import AuthenticationError, ResponseError +from gigachat.api.utils import build_headers, build_response from gigachat.models.balance import Balance @@ -21,15 +19,6 @@ def _get_kwargs( } -def _build_response(response: httpx.Response) -> Balance: - if response.status_code == HTTPStatus.OK: - return Balance(**response.json()) - elif response.status_code == HTTPStatus.UNAUTHORIZED: - raise AuthenticationError(response.url, response.status_code, response.content, response.headers) - else: - raise ResponseError(response.url, response.status_code, response.content, response.headers) - - def sync( client: httpx.Client, *, @@ -39,7 +28,7 @@ def sync( Только для клиентов с предоплатой иначе http 403""" kwargs = _get_kwargs(access_token=access_token) response = client.request(**kwargs) - return _build_response(response) + return build_response(response, Balance) async def asyncio( @@ -51,4 +40,4 @@ async def asyncio( Только для клиентов с предоплатой иначе http 403""" kwargs = _get_kwargs(access_token=access_token) response = await client.request(**kwargs) - return _build_response(response) + return build_response(response, Balance) diff --git a/src/gigachat/api/get_image.py b/src/gigachat/api/get_image.py index cf99816..749cfad 100644 --- a/src/gigachat/api/get_image.py +++ b/src/gigachat/api/get_image.py @@ -4,7 +4,7 @@ import httpx -from gigachat.api.utils import build_headers +from gigachat.api.utils import build_headers, build_x_headers from gigachat.exceptions import AuthenticationError, ResponseError from gigachat.models import Image @@ -25,7 +25,8 @@ def _get_kwargs( def _build_response(response: httpx.Response) -> Image: if response.status_code == HTTPStatus.OK: - return Image(content=base64.b64encode(response.content).decode()) + x_headers = build_x_headers(response) + return Image(x_headers=x_headers, content=base64.b64encode(response.content).decode()) elif response.status_code == HTTPStatus.UNAUTHORIZED: raise AuthenticationError(response.url, response.status_code, response.content, response.headers) else: diff --git a/src/gigachat/api/get_model.py b/src/gigachat/api/get_model.py index 469ff27..30d4bef 100644 --- a/src/gigachat/api/get_model.py +++ b/src/gigachat/api/get_model.py @@ -1,10 +1,8 @@ -from http import HTTPStatus from typing import Any, Dict, Optional import httpx -from gigachat.api.utils import build_headers -from gigachat.exceptions import AuthenticationError, ResponseError +from gigachat.api.utils import build_headers, build_response from gigachat.models import Model @@ -22,15 +20,6 @@ def _get_kwargs( } -def _build_response(response: httpx.Response) -> Model: - if response.status_code == HTTPStatus.OK: - return Model(**response.json()) - elif response.status_code == HTTPStatus.UNAUTHORIZED: - raise AuthenticationError(response.url, response.status_code, response.content, response.headers) - else: - raise ResponseError(response.url, response.status_code, response.content, response.headers) - - def sync( client: httpx.Client, *, @@ -40,7 +29,7 @@ def sync( """Возвращает объект с описанием указанной модели""" kwargs = _get_kwargs(model=model, access_token=access_token) response = client.request(**kwargs) - return _build_response(response) + return build_response(response, Model) async def asyncio( @@ -52,4 +41,4 @@ async def asyncio( """Возвращает объект с описанием указанной модели""" kwargs = _get_kwargs(model=model, access_token=access_token) response = await client.request(**kwargs) - return _build_response(response) + return build_response(response, Model) diff --git a/src/gigachat/api/get_models.py b/src/gigachat/api/get_models.py index 50aaa15..6eeb8fc 100644 --- a/src/gigachat/api/get_models.py +++ b/src/gigachat/api/get_models.py @@ -1,10 +1,8 @@ -from http import HTTPStatus from typing import Any, Dict, Optional import httpx -from gigachat.api.utils import build_headers -from gigachat.exceptions import AuthenticationError, ResponseError +from gigachat.api.utils import build_headers, build_response from gigachat.models import Models @@ -21,15 +19,6 @@ def _get_kwargs( } -def _build_response(response: httpx.Response) -> Models: - if response.status_code == HTTPStatus.OK: - return Models(**response.json()) - elif response.status_code == HTTPStatus.UNAUTHORIZED: - raise AuthenticationError(response.url, response.status_code, response.content, response.headers) - else: - raise ResponseError(response.url, response.status_code, response.content, response.headers) - - def sync( client: httpx.Client, *, @@ -38,7 +27,7 @@ def sync( """Возвращает массив объектов с данными доступных моделей""" kwargs = _get_kwargs(access_token=access_token) response = client.request(**kwargs) - return _build_response(response) + return build_response(response, Models) async def asyncio( @@ -49,4 +38,4 @@ async def asyncio( """Возвращает массив объектов с данными доступных моделей""" kwargs = _get_kwargs(access_token=access_token) response = await client.request(**kwargs) - return _build_response(response) + return build_response(response, Models) diff --git a/src/gigachat/api/post_auth.py b/src/gigachat/api/post_auth.py index 15b3124..97df430 100644 --- a/src/gigachat/api/post_auth.py +++ b/src/gigachat/api/post_auth.py @@ -2,13 +2,11 @@ import binascii import logging import uuid -from http import HTTPStatus from typing import Any, Dict import httpx -from gigachat.api.utils import USER_AGENT -from gigachat.exceptions import AuthenticationError, ResponseError +from gigachat.api.utils import USER_AGENT, build_response from gigachat.models import AccessToken _logger = logging.getLogger(__name__) @@ -28,15 +26,6 @@ def _get_kwargs(*, url: str, credentials: str, scope: str) -> Dict[str, Any]: } -def _build_response(response: httpx.Response) -> AccessToken: - if response.status_code == HTTPStatus.OK: - return AccessToken(**response.json()) - elif response.status_code == HTTPStatus.UNAUTHORIZED: - raise AuthenticationError(response.url, response.status_code, response.content, response.headers) - else: - raise ResponseError(response.url, response.status_code, response.content, response.headers) - - def _validate_credentials(credentials: str) -> None: try: base64.b64decode(credentials, validate=True) @@ -50,11 +39,11 @@ def sync(client: httpx.Client, *, url: str, credentials: str, scope: str) -> Acc _validate_credentials(credentials) kwargs = _get_kwargs(url=url, credentials=credentials, scope=scope) response = client.request(**kwargs) - return _build_response(response) + return build_response(response, AccessToken) async def asyncio(client: httpx.AsyncClient, *, url: str, credentials: str, scope: str) -> AccessToken: _validate_credentials(credentials) kwargs = _get_kwargs(url=url, credentials=credentials, scope=scope) response = await client.request(**kwargs) - return _build_response(response) + return build_response(response, AccessToken) diff --git a/src/gigachat/api/post_chat.py b/src/gigachat/api/post_chat.py index cb340e7..6be6371 100644 --- a/src/gigachat/api/post_chat.py +++ b/src/gigachat/api/post_chat.py @@ -1,10 +1,8 @@ -from http import HTTPStatus from typing import Any, Dict, Optional import httpx -from gigachat.api.utils import build_headers -from gigachat.exceptions import AuthenticationError, ResponseError +from gigachat.api.utils import build_headers, build_response from gigachat.models import Chat, ChatCompletion @@ -23,15 +21,6 @@ def _get_kwargs( } -def _build_response(response: httpx.Response) -> ChatCompletion: - if response.status_code == HTTPStatus.OK: - return ChatCompletion(**response.json()) - elif response.status_code == HTTPStatus.UNAUTHORIZED: - raise AuthenticationError(response.url, response.status_code, response.content, response.headers) - else: - raise ResponseError(response.url, response.status_code, response.content, response.headers) - - def sync( client: httpx.Client, *, @@ -40,7 +29,7 @@ def sync( ) -> ChatCompletion: kwargs = _get_kwargs(chat=chat, access_token=access_token) response = client.request(**kwargs) - return _build_response(response) + return build_response(response, ChatCompletion) async def asyncio( @@ -51,4 +40,4 @@ async def asyncio( ) -> ChatCompletion: kwargs = _get_kwargs(chat=chat, access_token=access_token) response = await client.request(**kwargs) - return _build_response(response) + return build_response(response, ChatCompletion) diff --git a/src/gigachat/api/post_embeddings.py b/src/gigachat/api/post_embeddings.py index 2e0df75..779d761 100644 --- a/src/gigachat/api/post_embeddings.py +++ b/src/gigachat/api/post_embeddings.py @@ -1,10 +1,8 @@ -from http import HTTPStatus from typing import Any, Dict, List, Optional import httpx -from gigachat.api.utils import build_headers -from gigachat.exceptions import AuthenticationError, ResponseError +from gigachat.api.utils import build_headers, build_response from gigachat.models import Embeddings @@ -24,15 +22,6 @@ def _get_kwargs( } -def _build_response(response: httpx.Response) -> Embeddings: - if response.status_code == HTTPStatus.OK: - return Embeddings(**response.json()) - elif response.status_code == HTTPStatus.UNAUTHORIZED: - raise AuthenticationError(response.url, response.status_code, response.content, response.headers) - else: - raise ResponseError(response.url, response.status_code, response.content, response.headers) - - def sync( client: httpx.Client, *, @@ -42,7 +31,7 @@ def sync( ) -> Embeddings: kwargs = _get_kwargs(input_=input_, model=model, access_token=access_token) response = client.request(**kwargs) - return _build_response(response) + return build_response(response, Embeddings) async def asyncio( @@ -54,4 +43,4 @@ async def asyncio( ) -> Embeddings: kwargs = _get_kwargs(input_=input_, model=model, access_token=access_token) response = await client.request(**kwargs) - return _build_response(response) + return build_response(response, Embeddings) diff --git a/src/gigachat/api/post_files.py b/src/gigachat/api/post_files.py index 07580fe..261252a 100644 --- a/src/gigachat/api/post_files.py +++ b/src/gigachat/api/post_files.py @@ -1,11 +1,9 @@ -from http import HTTPStatus from typing import Any, Dict, Literal, Optional import httpx from gigachat._types import FileTypes -from gigachat.api.utils import build_headers -from gigachat.exceptions import AuthenticationError, ResponseError +from gigachat.api.utils import build_headers, build_response from gigachat.models import UploadedFile @@ -26,15 +24,6 @@ def _get_kwargs( } -def _build_response(response: httpx.Response) -> UploadedFile: - if response.status_code == HTTPStatus.OK: - return UploadedFile(**response.json()) - elif response.status_code == HTTPStatus.UNAUTHORIZED: - raise AuthenticationError(response.url, response.status_code, response.content, response.headers) - else: - raise ResponseError(response.url, response.status_code, response.content, response.headers) - - def sync( client: httpx.Client, *, @@ -44,7 +33,7 @@ def sync( ) -> UploadedFile: kwargs = _get_kwargs(file=file, purpose=purpose, access_token=access_token) response = client.request(**kwargs) - return _build_response(response) + return build_response(response, UploadedFile) async def asyncio( @@ -56,4 +45,4 @@ async def asyncio( ) -> UploadedFile: kwargs = _get_kwargs(file=file, purpose=purpose, access_token=access_token) response = await client.request(**kwargs) - return _build_response(response) + return build_response(response, UploadedFile) diff --git a/src/gigachat/api/post_functions_convert.py b/src/gigachat/api/post_functions_convert.py index e4924df..1fe8eef 100644 --- a/src/gigachat/api/post_functions_convert.py +++ b/src/gigachat/api/post_functions_convert.py @@ -1,10 +1,8 @@ -from http import HTTPStatus from typing import Any, Dict, Optional import httpx -from gigachat.api.utils import build_headers -from gigachat.exceptions import AuthenticationError, ResponseError +from gigachat.api.utils import build_headers, build_response from gigachat.models.open_api_functions import OpenApiFunctions @@ -23,15 +21,6 @@ def _get_kwargs( } -def _build_response(response: httpx.Response) -> OpenApiFunctions: - if response.status_code == HTTPStatus.OK: - return OpenApiFunctions(**response.json()) - elif response.status_code == HTTPStatus.UNAUTHORIZED: - raise AuthenticationError(response.url, response.status_code, response.content, response.headers) - else: - raise ResponseError(response.url, response.status_code, response.content, response.headers) - - def sync( client: httpx.Client, *, @@ -41,7 +30,7 @@ def sync( """Конвертация описание функции в формате OpenAPI в gigachat функцию""" kwargs = _get_kwargs(openapi_function=openapi_function, access_token=access_token) response = client.request(**kwargs) - return _build_response(response) + return build_response(response, OpenApiFunctions) async def asyncio( @@ -53,4 +42,4 @@ async def asyncio( """Конвертация описание функции в формате OpenAPI в gigachat функцию""" kwargs = _get_kwargs(openapi_function=openapi_function, access_token=access_token) response = await client.request(**kwargs) - return _build_response(response) + return build_response(response, OpenApiFunctions) diff --git a/src/gigachat/api/post_token.py b/src/gigachat/api/post_token.py index 6d0ba48..2d827ab 100644 --- a/src/gigachat/api/post_token.py +++ b/src/gigachat/api/post_token.py @@ -1,10 +1,8 @@ -from http import HTTPStatus from typing import Any, Dict import httpx -from gigachat.api.utils import build_headers -from gigachat.exceptions import AuthenticationError, ResponseError +from gigachat.api.utils import build_headers, build_response from gigachat.models import Token @@ -23,15 +21,6 @@ def _get_kwargs( } -def _build_response(response: httpx.Response) -> Token: - if response.status_code == HTTPStatus.OK: - return Token(**response.json()) - elif response.status_code == HTTPStatus.UNAUTHORIZED: - raise AuthenticationError(response.url, response.status_code, response.content, response.headers) - else: - raise ResponseError(response.url, response.status_code, response.content, response.headers) - - def sync( client: httpx.Client, *, @@ -40,7 +29,7 @@ def sync( ) -> Token: kwargs = _get_kwargs(user=user, password=password) response = client.request(**kwargs) - return _build_response(response) + return build_response(response, Token) async def asyncio( @@ -51,4 +40,4 @@ async def asyncio( ) -> Token: kwargs = _get_kwargs(user=user, password=password) response = await client.request(**kwargs) - return _build_response(response) + return build_response(response, Token) diff --git a/src/gigachat/api/stream_chat.py b/src/gigachat/api/stream_chat.py index 7732fa4..b251b1e 100644 --- a/src/gigachat/api/stream_chat.py +++ b/src/gigachat/api/stream_chat.py @@ -3,7 +3,7 @@ import httpx -from gigachat.api.utils import build_headers, parse_chunk +from gigachat.api.utils import build_headers, build_x_headers, parse_chunk from gigachat.exceptions import AuthenticationError, ResponseError from gigachat.models import Chat, ChatCompletionChunk @@ -60,8 +60,10 @@ def sync( kwargs = _get_kwargs(chat=chat, access_token=access_token) with client.stream(**kwargs) as response: _check_response(response) + x_headers = build_x_headers(response) for line in response.iter_lines(): if chunk := parse_chunk(line, ChatCompletionChunk): + chunk.x_headers = x_headers yield chunk @@ -74,6 +76,8 @@ async def asyncio( kwargs = _get_kwargs(chat=chat, access_token=access_token) async with client.stream(**kwargs) as response: await _acheck_response(response) + x_headers = build_x_headers(response) async for line in response.aiter_lines(): if chunk := parse_chunk(line, ChatCompletionChunk): + chunk.x_headers = x_headers yield chunk diff --git a/src/gigachat/api/threads/get_threads.py b/src/gigachat/api/threads/get_threads.py index c161969..427619d 100644 --- a/src/gigachat/api/threads/get_threads.py +++ b/src/gigachat/api/threads/get_threads.py @@ -1,10 +1,8 @@ -from http import HTTPStatus from typing import Any, Dict, List, Optional import httpx -from gigachat.api.utils import build_headers -from gigachat.exceptions import AuthenticationError, ResponseError +from gigachat.api.utils import build_headers, build_response from gigachat.models.threads import Threads @@ -32,15 +30,6 @@ def _get_kwargs( return params -def _build_response(response: httpx.Response) -> Threads: - if response.status_code == HTTPStatus.OK: - return Threads(**response.json()) - elif response.status_code == HTTPStatus.UNAUTHORIZED: - raise AuthenticationError(response.url, response.status_code, response.content, response.headers) - else: - raise ResponseError(response.url, response.status_code, response.content, response.headers) - - def sync( client: httpx.Client, *, @@ -57,7 +46,7 @@ def sync( access_token=access_token, ) response = client.request(**kwargs) - return _build_response(response) + return build_response(response, Threads) async def asyncio( @@ -76,4 +65,4 @@ async def asyncio( access_token=access_token, ) response = await client.request(**kwargs) - return _build_response(response) + return build_response(response, Threads) diff --git a/src/gigachat/api/threads/get_threads_messages.py b/src/gigachat/api/threads/get_threads_messages.py index f388c8c..a011bf3 100644 --- a/src/gigachat/api/threads/get_threads_messages.py +++ b/src/gigachat/api/threads/get_threads_messages.py @@ -1,10 +1,8 @@ -from http import HTTPStatus from typing import Any, Dict, Optional import httpx -from gigachat.api.utils import build_headers -from gigachat.exceptions import AuthenticationError, ResponseError +from gigachat.api.utils import build_headers, build_response from gigachat.models.threads import ThreadMessages @@ -30,15 +28,6 @@ def _get_kwargs( return params -def _build_response(response: httpx.Response) -> ThreadMessages: - if response.status_code == HTTPStatus.OK: - return ThreadMessages(**response.json()) - elif response.status_code == HTTPStatus.UNAUTHORIZED: - raise AuthenticationError(response.url, response.status_code, response.content, response.headers) - else: - raise ResponseError(response.url, response.status_code, response.content, response.headers) - - def sync( client: httpx.Client, *, @@ -50,7 +39,7 @@ def sync( """Получение сообщений треда""" kwargs = _get_kwargs(thread_id=thread_id, limit=limit, before=before, access_token=access_token) response = client.request(**kwargs) - return _build_response(response) + return build_response(response, ThreadMessages) async def asyncio( @@ -64,4 +53,4 @@ async def asyncio( """Получение сообщений треда""" kwargs = _get_kwargs(thread_id=thread_id, limit=limit, before=before, access_token=access_token) response = await client.request(**kwargs) - return _build_response(response) + return build_response(response, ThreadMessages) diff --git a/src/gigachat/api/threads/get_threads_run.py b/src/gigachat/api/threads/get_threads_run.py index a260096..17d1c32 100644 --- a/src/gigachat/api/threads/get_threads_run.py +++ b/src/gigachat/api/threads/get_threads_run.py @@ -1,10 +1,8 @@ -from http import HTTPStatus from typing import Any, Dict, Optional import httpx -from gigachat.api.utils import build_headers -from gigachat.exceptions import AuthenticationError, ResponseError +from gigachat.api.utils import build_headers, build_response from gigachat.models.threads import ThreadRunResult @@ -23,15 +21,6 @@ def _get_kwargs( return params -def _build_response(response: httpx.Response) -> ThreadRunResult: - if response.status_code == HTTPStatus.OK: - return ThreadRunResult(**response.json()) - elif response.status_code == HTTPStatus.UNAUTHORIZED: - raise AuthenticationError(response.url, response.status_code, response.content, response.headers) - else: - raise ResponseError(response.url, response.status_code, response.content, response.headers) - - def sync( client: httpx.Client, *, @@ -41,7 +30,7 @@ def sync( """Получить результат run треда""" kwargs = _get_kwargs(thread_id=thread_id, access_token=access_token) response = client.request(**kwargs) - return _build_response(response) + return build_response(response, ThreadRunResult) async def asyncio( @@ -53,4 +42,4 @@ async def asyncio( """Получить результат run треда""" kwargs = _get_kwargs(thread_id=thread_id, access_token=access_token) response = await client.request(**kwargs) - return _build_response(response) + return build_response(response, ThreadRunResult) diff --git a/src/gigachat/api/threads/post_thread_messages_rerun.py b/src/gigachat/api/threads/post_thread_messages_rerun.py index 40da1c0..1b27969 100644 --- a/src/gigachat/api/threads/post_thread_messages_rerun.py +++ b/src/gigachat/api/threads/post_thread_messages_rerun.py @@ -1,10 +1,8 @@ -from http import HTTPStatus from typing import Any, Dict, Optional import httpx -from gigachat.api.utils import build_headers -from gigachat.exceptions import AuthenticationError, ResponseError +from gigachat.api.utils import build_headers, build_response from gigachat.models.threads import ThreadCompletion, ThreadRunOptions @@ -32,15 +30,6 @@ def _get_kwargs( return params -def _build_response(response: httpx.Response) -> ThreadCompletion: - if response.status_code == HTTPStatus.OK: - return ThreadCompletion(**response.json()) - elif response.status_code == HTTPStatus.UNAUTHORIZED: - raise AuthenticationError(response.url, response.status_code, response.content, response.headers) - else: - raise ResponseError(response.url, response.status_code, response.content, response.headers) - - def sync( client: httpx.Client, *, @@ -55,7 +44,7 @@ def sync( access_token=access_token, ) response = client.request(**kwargs) - return _build_response(response) + return build_response(response, ThreadCompletion) async def asyncio( @@ -72,4 +61,4 @@ async def asyncio( access_token=access_token, ) response = await client.request(**kwargs) - return _build_response(response) + return build_response(response, ThreadCompletion) diff --git a/src/gigachat/api/threads/post_thread_messages_rerun_stream.py b/src/gigachat/api/threads/post_thread_messages_rerun_stream.py index df8090e..3d5899c 100644 --- a/src/gigachat/api/threads/post_thread_messages_rerun_stream.py +++ b/src/gigachat/api/threads/post_thread_messages_rerun_stream.py @@ -3,7 +3,7 @@ import httpx -from gigachat.api.utils import build_headers, parse_chunk +from gigachat.api.utils import build_headers, build_x_headers, parse_chunk from gigachat.exceptions import AuthenticationError, ResponseError from gigachat.models.threads import ThreadCompletionChunk, ThreadRunOptions @@ -78,8 +78,10 @@ def sync( ) with client.stream(**kwargs) as response: _check_response(response) + x_headers = build_x_headers(response) for line in response.iter_lines(): if chunk := parse_chunk(line, ThreadCompletionChunk): + chunk.x_headers = x_headers yield chunk @@ -100,6 +102,8 @@ async def asyncio( ) async with client.stream(**kwargs) as response: await _acheck_response(response) + x_headers = build_x_headers(response) async for line in response.aiter_lines(): if chunk := parse_chunk(line, ThreadCompletionChunk): + chunk.x_headers = x_headers yield chunk diff --git a/src/gigachat/api/threads/post_thread_messages_run.py b/src/gigachat/api/threads/post_thread_messages_run.py index 81c032a..061dc5a 100644 --- a/src/gigachat/api/threads/post_thread_messages_run.py +++ b/src/gigachat/api/threads/post_thread_messages_run.py @@ -1,10 +1,8 @@ -from http import HTTPStatus from typing import Any, Dict, List, Optional import httpx -from gigachat.api.utils import build_headers -from gigachat.exceptions import AuthenticationError, ResponseError +from gigachat.api.utils import build_headers, build_response from gigachat.models import Messages from gigachat.models.threads import ThreadCompletion, ThreadRunOptions @@ -41,15 +39,6 @@ def _get_kwargs( return params -def _build_response(response: httpx.Response) -> ThreadCompletion: - if response.status_code == HTTPStatus.OK: - return ThreadCompletion(**response.json()) - elif response.status_code == HTTPStatus.UNAUTHORIZED: - raise AuthenticationError(response.url, response.status_code, response.content, response.headers) - else: - raise ResponseError(response.url, response.status_code, response.content, response.headers) - - def sync( client: httpx.Client, *, @@ -70,7 +59,7 @@ def sync( access_token=access_token, ) response = client.request(**kwargs) - return _build_response(response) + return build_response(response, ThreadCompletion) async def asyncio( @@ -93,4 +82,4 @@ async def asyncio( access_token=access_token, ) response = await client.request(**kwargs) - return _build_response(response) + return build_response(response, ThreadCompletion) diff --git a/src/gigachat/api/threads/post_thread_messages_run_stream.py b/src/gigachat/api/threads/post_thread_messages_run_stream.py index d5d8ead..8d1f462 100644 --- a/src/gigachat/api/threads/post_thread_messages_run_stream.py +++ b/src/gigachat/api/threads/post_thread_messages_run_stream.py @@ -3,7 +3,7 @@ import httpx -from gigachat.api.utils import build_headers, parse_chunk +from gigachat.api.utils import build_headers, build_x_headers, parse_chunk from gigachat.exceptions import AuthenticationError, ResponseError from gigachat.models import Messages from gigachat.models.threads import ThreadCompletionChunk, ThreadRunOptions @@ -92,8 +92,10 @@ def sync( ) with client.stream(**kwargs) as response: _check_response(response) + x_headers = build_x_headers(response) for line in response.iter_lines(): if chunk := parse_chunk(line, ThreadCompletionChunk): + chunk.x_headers = x_headers yield chunk @@ -119,6 +121,8 @@ async def asyncio( ) async with client.stream(**kwargs) as response: await _acheck_response(response) + x_headers = build_x_headers(response) async for line in response.aiter_lines(): if chunk := parse_chunk(line, ThreadCompletionChunk): + chunk.x_headers = x_headers yield chunk diff --git a/src/gigachat/api/threads/post_threads_messages.py b/src/gigachat/api/threads/post_threads_messages.py index 646801a..24c9679 100644 --- a/src/gigachat/api/threads/post_threads_messages.py +++ b/src/gigachat/api/threads/post_threads_messages.py @@ -1,10 +1,8 @@ -from http import HTTPStatus from typing import Any, Dict, List, Optional import httpx -from gigachat.api.utils import build_headers -from gigachat.exceptions import AuthenticationError, ResponseError +from gigachat.api.utils import build_headers, build_response from gigachat.models import Messages from gigachat.models.threads import ThreadMessagesResponse @@ -34,15 +32,6 @@ def _get_kwargs( return params -def _build_response(response: httpx.Response) -> ThreadMessagesResponse: - if response.status_code == HTTPStatus.OK: - return ThreadMessagesResponse(**response.json()) - elif response.status_code == HTTPStatus.UNAUTHORIZED: - raise AuthenticationError(response.url, response.status_code, response.content, response.headers) - else: - raise ResponseError(response.url, response.status_code, response.content, response.headers) - - def sync( client: httpx.Client, *, @@ -61,7 +50,7 @@ def sync( access_token=access_token, ) response = client.request(**kwargs) - return _build_response(response) + return build_response(response, ThreadMessagesResponse) async def asyncio( @@ -82,4 +71,4 @@ async def asyncio( access_token=access_token, ) response = await client.request(**kwargs) - return _build_response(response) + return build_response(response, ThreadMessagesResponse) diff --git a/src/gigachat/api/threads/post_threads_retrieve.py b/src/gigachat/api/threads/post_threads_retrieve.py index 5421e5d..32a3fde 100644 --- a/src/gigachat/api/threads/post_threads_retrieve.py +++ b/src/gigachat/api/threads/post_threads_retrieve.py @@ -1,10 +1,8 @@ -from http import HTTPStatus from typing import Any, Dict, List, Optional import httpx -from gigachat.api.utils import build_headers -from gigachat.exceptions import AuthenticationError, ResponseError +from gigachat.api.utils import build_headers, build_response from gigachat.models.threads import Threads @@ -25,15 +23,6 @@ def _get_kwargs( return params -def _build_response(response: httpx.Response) -> Threads: - if response.status_code == HTTPStatus.OK: - return Threads(**response.json()) - elif response.status_code == HTTPStatus.UNAUTHORIZED: - raise AuthenticationError(response.url, response.status_code, response.content, response.headers) - else: - raise ResponseError(response.url, response.status_code, response.content, response.headers) - - def sync( client: httpx.Client, *, @@ -43,7 +32,7 @@ def sync( """Получение перечня тредов по идентификаторам""" kwargs = _get_kwargs(threads_ids=threads_ids, access_token=access_token) response = client.request(**kwargs) - return _build_response(response) + return build_response(response, Threads) async def asyncio( @@ -55,4 +44,4 @@ async def asyncio( """Получение перечня тредов по идентификаторам""" kwargs = _get_kwargs(threads_ids=threads_ids, access_token=access_token) response = await client.request(**kwargs) - return _build_response(response) + return build_response(response, Threads) diff --git a/src/gigachat/api/threads/post_threads_run.py b/src/gigachat/api/threads/post_threads_run.py index 44bf1cc..1f74852 100644 --- a/src/gigachat/api/threads/post_threads_run.py +++ b/src/gigachat/api/threads/post_threads_run.py @@ -1,10 +1,8 @@ -from http import HTTPStatus from typing import Any, Dict, Optional import httpx -from gigachat.api.utils import build_headers -from gigachat.exceptions import AuthenticationError, ResponseError +from gigachat.api.utils import build_headers, build_response from gigachat.models.threads import ThreadRunOptions, ThreadRunResponse @@ -27,15 +25,6 @@ def _get_kwargs( return params -def _build_response(response: httpx.Response) -> ThreadRunResponse: - if response.status_code == HTTPStatus.OK: - return ThreadRunResponse(**response.json()) - elif response.status_code == HTTPStatus.UNAUTHORIZED: - raise AuthenticationError(response.url, response.status_code, response.content, response.headers) - else: - raise ResponseError(response.url, response.status_code, response.content, response.headers) - - def sync( client: httpx.Client, *, @@ -46,7 +35,7 @@ def sync( """Получить результат run треда""" kwargs = _get_kwargs(thread_id=thread_id, thread_options=thread_options, access_token=access_token) response = client.request(**kwargs) - return _build_response(response) + return build_response(response, ThreadRunResponse) async def asyncio( @@ -59,4 +48,4 @@ async def asyncio( """Получить результат run треда""" kwargs = _get_kwargs(thread_id=thread_id, thread_options=thread_options, access_token=access_token) response = await client.request(**kwargs) - return _build_response(response) + return build_response(response, ThreadRunResponse) diff --git a/src/gigachat/api/utils.py b/src/gigachat/api/utils.py index f3b5ef3..3a583f4 100644 --- a/src/gigachat/api/utils.py +++ b/src/gigachat/api/utils.py @@ -1,6 +1,9 @@ import logging +from http import HTTPStatus from typing import Dict, Optional, Type, TypeVar +import httpx + from gigachat.context import ( authorization_cvar, client_id_cvar, @@ -9,6 +12,7 @@ service_id_cvar, session_id_cvar, ) +from gigachat.exceptions import AuthenticationError, ResponseError from gigachat.pydantic_v1 import BaseModel _logger = logging.getLogger(__name__) @@ -62,3 +66,20 @@ def parse_chunk(line: str, model_class: Type[T]) -> Optional[T]: raise e else: return None + + +def build_x_headers(response: httpx.Response) -> Dict[str, Optional[str]]: + return { + "x-request-id": response.headers.get("x-request-id"), + "x-session-id": response.headers.get("x-session-id"), + "x-client-id": response.headers.get("x-client-id"), + } + + +def build_response(response: httpx.Response, model_class: Type[T]) -> T: + if response.status_code == HTTPStatus.OK: + return model_class(x_headers=build_x_headers(response), **response.json()) + elif response.status_code == HTTPStatus.UNAUTHORIZED: + raise AuthenticationError(response.url, response.status_code, response.content, response.headers) + else: + raise ResponseError(response.url, response.status_code, response.content, response.headers) diff --git a/src/gigachat/client.py b/src/gigachat/client.py index b15fb55..85d707f 100644 --- a/src/gigachat/client.py +++ b/src/gigachat/client.py @@ -111,7 +111,7 @@ def _parse_chat(payload: Union[Chat, Dict[str, Any], str], settings: Settings) - def _build_access_token(token: Token) -> AccessToken: - return AccessToken(access_token=token.tok, expires_at=token.exp) + return AccessToken(access_token=token.tok, expires_at=token.exp, x_headers=token.x_headers) class _BaseClient: diff --git a/src/gigachat/models/__init__.py b/src/gigachat/models/__init__.py index b2a8974..7551629 100644 --- a/src/gigachat/models/__init__.py +++ b/src/gigachat/models/__init__.py @@ -22,6 +22,7 @@ from gigachat.models.tokens_count import TokensCount from gigachat.models.uploaded_file import UploadedFile from gigachat.models.usage import Usage +from gigachat.models.with_x_headers import WithXHeaders __all__ = ( "AccessToken", @@ -50,4 +51,5 @@ "Image", "threads", "assistants", + "WithXHeaders", ) diff --git a/src/gigachat/models/access_token.py b/src/gigachat/models/access_token.py index 4a2a125..fee7531 100644 --- a/src/gigachat/models/access_token.py +++ b/src/gigachat/models/access_token.py @@ -1,7 +1,7 @@ -from gigachat.pydantic_v1 import BaseModel +from gigachat.models.with_x_headers import WithXHeaders -class AccessToken(BaseModel): +class AccessToken(WithXHeaders): """Токен доступа""" access_token: str diff --git a/src/gigachat/models/assistants/assistant_delete.py b/src/gigachat/models/assistants/assistant_delete.py index eac781f..e9822fb 100644 --- a/src/gigachat/models/assistants/assistant_delete.py +++ b/src/gigachat/models/assistants/assistant_delete.py @@ -1,7 +1,7 @@ -from gigachat.pydantic_v1 import BaseModel +from gigachat.models.with_x_headers import WithXHeaders -class AssistantDelete(BaseModel): +class AssistantDelete(WithXHeaders): """Информация об удаленном ассистенте""" assistant_id: str diff --git a/src/gigachat/models/assistants/assistant_file_delete.py b/src/gigachat/models/assistants/assistant_file_delete.py index 056f411..a01ca5b 100644 --- a/src/gigachat/models/assistants/assistant_file_delete.py +++ b/src/gigachat/models/assistants/assistant_file_delete.py @@ -1,7 +1,7 @@ -from gigachat.pydantic_v1 import BaseModel +from gigachat.models.with_x_headers import WithXHeaders -class AssistantFileDelete(BaseModel): +class AssistantFileDelete(WithXHeaders): """Информация об удаленном файле""" file_id: str diff --git a/src/gigachat/models/assistants/assistants.py b/src/gigachat/models/assistants/assistants.py index ef468e4..85742c0 100644 --- a/src/gigachat/models/assistants/assistants.py +++ b/src/gigachat/models/assistants/assistants.py @@ -1,10 +1,10 @@ from typing import List from gigachat.models.assistants.assistant import Assistant -from gigachat.pydantic_v1 import BaseModel +from gigachat.models.with_x_headers import WithXHeaders -class Assistants(BaseModel): +class Assistants(WithXHeaders): """Доступные ассистенты""" data: List[Assistant] diff --git a/src/gigachat/models/assistants/create_assistant.py b/src/gigachat/models/assistants/create_assistant.py index bcc32f0..2a54206 100644 --- a/src/gigachat/models/assistants/create_assistant.py +++ b/src/gigachat/models/assistants/create_assistant.py @@ -1,7 +1,7 @@ -from gigachat.pydantic_v1 import BaseModel +from gigachat.models.with_x_headers import WithXHeaders -class CreateAssistant(BaseModel): +class CreateAssistant(WithXHeaders): """Информация о созданном ассистенте""" assistant_id: str diff --git a/src/gigachat/models/balance.py b/src/gigachat/models/balance.py index df1124a..0fba60d 100644 --- a/src/gigachat/models/balance.py +++ b/src/gigachat/models/balance.py @@ -1,5 +1,6 @@ from typing import List +from gigachat.models.with_x_headers import WithXHeaders from gigachat.pydantic_v1 import BaseModel @@ -12,7 +13,7 @@ class BalanceValue(BaseModel): """Количество доступных токенов""" -class Balance(BaseModel): +class Balance(WithXHeaders): """Текущий баланс""" balance: List[BalanceValue] diff --git a/src/gigachat/models/chat_completion.py b/src/gigachat/models/chat_completion.py index 24e56f2..702a439 100644 --- a/src/gigachat/models/chat_completion.py +++ b/src/gigachat/models/chat_completion.py @@ -2,10 +2,11 @@ from gigachat.models.choices import Choices from gigachat.models.usage import Usage -from gigachat.pydantic_v1 import BaseModel, Field +from gigachat.models.with_x_headers import WithXHeaders +from gigachat.pydantic_v1 import Field -class ChatCompletion(BaseModel): +class ChatCompletion(WithXHeaders): """Ответ модели""" choices: List[Choices] diff --git a/src/gigachat/models/chat_completion_chunk.py b/src/gigachat/models/chat_completion_chunk.py index c149648..2d3577d 100644 --- a/src/gigachat/models/chat_completion_chunk.py +++ b/src/gigachat/models/chat_completion_chunk.py @@ -1,10 +1,11 @@ from typing import List from gigachat.models.choices_chunk import ChoicesChunk -from gigachat.pydantic_v1 import BaseModel, Field +from gigachat.models.with_x_headers import WithXHeaders +from gigachat.pydantic_v1 import Field -class ChatCompletionChunk(BaseModel): +class ChatCompletionChunk(WithXHeaders): """Ответ модели в потоке""" choices: List[ChoicesChunk] diff --git a/src/gigachat/models/embeddings.py b/src/gigachat/models/embeddings.py index ee707b8..3be234c 100644 --- a/src/gigachat/models/embeddings.py +++ b/src/gigachat/models/embeddings.py @@ -1,10 +1,11 @@ from typing import List, Optional from gigachat.models.embedding import Embedding -from gigachat.pydantic_v1 import BaseModel, Field +from gigachat.models.with_x_headers import WithXHeaders +from gigachat.pydantic_v1 import Field -class Embeddings(BaseModel): +class Embeddings(WithXHeaders): """Ответ модели""" data: List[Embedding] diff --git a/src/gigachat/models/image.py b/src/gigachat/models/image.py index 5c86b38..42a00c0 100644 --- a/src/gigachat/models/image.py +++ b/src/gigachat/models/image.py @@ -1,7 +1,7 @@ -from gigachat.pydantic_v1 import BaseModel +from gigachat.models.with_x_headers import WithXHeaders -class Image(BaseModel): +class Image(WithXHeaders): """Изображение""" content: str diff --git a/src/gigachat/models/model.py b/src/gigachat/models/model.py index dca3337..51aa415 100644 --- a/src/gigachat/models/model.py +++ b/src/gigachat/models/model.py @@ -1,7 +1,8 @@ -from gigachat.pydantic_v1 import BaseModel, Field +from gigachat.models.with_x_headers import WithXHeaders +from gigachat.pydantic_v1 import Field -class Model(BaseModel): +class Model(WithXHeaders): """Описание модели""" id_: str = Field(alias="id") diff --git a/src/gigachat/models/models.py b/src/gigachat/models/models.py index 1aab066..833d5fb 100644 --- a/src/gigachat/models/models.py +++ b/src/gigachat/models/models.py @@ -1,10 +1,11 @@ from typing import List from gigachat.models.model import Model -from gigachat.pydantic_v1 import BaseModel, Field +from gigachat.models.with_x_headers import WithXHeaders +from gigachat.pydantic_v1 import Field -class Models(BaseModel): +class Models(WithXHeaders): """Доступные модели""" data: List[Model] diff --git a/src/gigachat/models/open_api_functions.py b/src/gigachat/models/open_api_functions.py index 64e5bde..0838530 100644 --- a/src/gigachat/models/open_api_functions.py +++ b/src/gigachat/models/open_api_functions.py @@ -1,10 +1,10 @@ from typing import List -from gigachat.models import Function -from gigachat.pydantic_v1 import BaseModel +from gigachat.models.function import Function +from gigachat.models.with_x_headers import WithXHeaders -class OpenApiFunctions(BaseModel): +class OpenApiFunctions(WithXHeaders): """Функции конвертированные из OpenAPI в GigaFunctions""" functions: List[Function] diff --git a/src/gigachat/models/threads/thread_completion.py b/src/gigachat/models/threads/thread_completion.py index de34585..4e52d3a 100644 --- a/src/gigachat/models/threads/thread_completion.py +++ b/src/gigachat/models/threads/thread_completion.py @@ -1,9 +1,10 @@ -from gigachat.models import Messages +from gigachat.models.messages import Messages from gigachat.models.usage import Usage -from gigachat.pydantic_v1 import BaseModel, Field +from gigachat.models.with_x_headers import WithXHeaders +from gigachat.pydantic_v1 import Field -class ThreadCompletion(BaseModel): +class ThreadCompletion(WithXHeaders): """Ответ модели""" object_: str = Field(alias="object") diff --git a/src/gigachat/models/threads/thread_completion_chunk.py b/src/gigachat/models/threads/thread_completion_chunk.py index d9f9c27..27b1a10 100644 --- a/src/gigachat/models/threads/thread_completion_chunk.py +++ b/src/gigachat/models/threads/thread_completion_chunk.py @@ -1,11 +1,12 @@ from typing import List -from gigachat.models import ChoicesChunk +from gigachat.models.choices_chunk import ChoicesChunk from gigachat.models.usage import Usage -from gigachat.pydantic_v1 import BaseModel, Field +from gigachat.models.with_x_headers import WithXHeaders +from gigachat.pydantic_v1 import Field -class ThreadCompletionChunk(BaseModel): +class ThreadCompletionChunk(WithXHeaders): """Ответ модели""" object_: str = Field(alias="object") diff --git a/src/gigachat/models/threads/thread_messages.py b/src/gigachat/models/threads/thread_messages.py index 0403443..9d81971 100644 --- a/src/gigachat/models/threads/thread_messages.py +++ b/src/gigachat/models/threads/thread_messages.py @@ -1,10 +1,10 @@ from typing import List from gigachat.models.threads.thread_message import ThreadMessage -from gigachat.pydantic_v1 import BaseModel +from gigachat.models.with_x_headers import WithXHeaders -class ThreadMessages(BaseModel): +class ThreadMessages(WithXHeaders): """Сообщения треда""" thread_id: str diff --git a/src/gigachat/models/threads/thread_messages_response.py b/src/gigachat/models/threads/thread_messages_response.py index fd7fd6a..cd6a302 100644 --- a/src/gigachat/models/threads/thread_messages_response.py +++ b/src/gigachat/models/threads/thread_messages_response.py @@ -1,10 +1,10 @@ from typing import List from gigachat.models.threads.thread_message_response import ThreadMessageResponse -from gigachat.pydantic_v1 import BaseModel +from gigachat.models.with_x_headers import WithXHeaders -class ThreadMessagesResponse(BaseModel): +class ThreadMessagesResponse(WithXHeaders): thread_id: str """Идентификатор треда""" messages: List[ThreadMessageResponse] diff --git a/src/gigachat/models/threads/thread_run_response.py b/src/gigachat/models/threads/thread_run_response.py index 6c33bd1..d8b937a 100644 --- a/src/gigachat/models/threads/thread_run_response.py +++ b/src/gigachat/models/threads/thread_run_response.py @@ -1,8 +1,8 @@ from gigachat.models.threads.thread_status import ThreadStatus -from gigachat.pydantic_v1 import BaseModel +from gigachat.models.with_x_headers import WithXHeaders -class ThreadRunResponse(BaseModel): +class ThreadRunResponse(WithXHeaders): status: ThreadStatus """Статус запуска""" thread_id: str diff --git a/src/gigachat/models/threads/thread_run_result.py b/src/gigachat/models/threads/thread_run_result.py index e3beb30..32d58db 100644 --- a/src/gigachat/models/threads/thread_run_result.py +++ b/src/gigachat/models/threads/thread_run_result.py @@ -2,10 +2,10 @@ from gigachat.models.threads.thread_message import ThreadMessage from gigachat.models.threads.thread_status import ThreadStatus -from gigachat.pydantic_v1 import BaseModel +from gigachat.models.with_x_headers import WithXHeaders -class ThreadRunResult(BaseModel): +class ThreadRunResult(WithXHeaders): """Run треда""" status: ThreadStatus diff --git a/src/gigachat/models/threads/threads.py b/src/gigachat/models/threads/threads.py index 05cabc2..9062a96 100644 --- a/src/gigachat/models/threads/threads.py +++ b/src/gigachat/models/threads/threads.py @@ -1,10 +1,10 @@ from typing import List from gigachat.models.threads.thread import Thread -from gigachat.pydantic_v1 import BaseModel +from gigachat.models.with_x_headers import WithXHeaders -class Threads(BaseModel): +class Threads(WithXHeaders): """Треды""" threads: List[Thread] diff --git a/src/gigachat/models/token.py b/src/gigachat/models/token.py index cc4db6c..57cf034 100644 --- a/src/gigachat/models/token.py +++ b/src/gigachat/models/token.py @@ -1,7 +1,7 @@ -from gigachat.pydantic_v1 import BaseModel +from gigachat.models.with_x_headers import WithXHeaders -class Token(BaseModel): +class Token(WithXHeaders): """Токен доступа""" tok: str diff --git a/src/gigachat/models/tokens_count.py b/src/gigachat/models/tokens_count.py index a2a67d3..47e9a9a 100644 --- a/src/gigachat/models/tokens_count.py +++ b/src/gigachat/models/tokens_count.py @@ -1,7 +1,8 @@ -from gigachat.pydantic_v1 import BaseModel, Field +from gigachat.models.with_x_headers import WithXHeaders +from gigachat.pydantic_v1 import Field -class TokensCount(BaseModel): +class TokensCount(WithXHeaders): """Информация о количестве токенов""" tokens: int diff --git a/src/gigachat/models/uploaded_file.py b/src/gigachat/models/uploaded_file.py index 144d897..aae7c8c 100644 --- a/src/gigachat/models/uploaded_file.py +++ b/src/gigachat/models/uploaded_file.py @@ -1,7 +1,8 @@ -from gigachat.pydantic_v1 import BaseModel, Field +from gigachat.models.with_x_headers import WithXHeaders +from gigachat.pydantic_v1 import Field -class UploadedFile(BaseModel): +class UploadedFile(WithXHeaders): """Информация о загруженном файле""" id_: str = Field(alias="id") diff --git a/src/gigachat/models/with_x_headers.py b/src/gigachat/models/with_x_headers.py new file mode 100644 index 0000000..14dac75 --- /dev/null +++ b/src/gigachat/models/with_x_headers.py @@ -0,0 +1,8 @@ +from typing import Dict, Optional + +from gigachat.pydantic_v1 import BaseModel, Field + + +class WithXHeaders(BaseModel): + x_headers: Optional[Dict[str, Optional[str]]] = Field(default=None) + """Служебная информация о запросе (x-request-id, x-session-id, x-client-id)""" diff --git a/tests/unit_tests/gigachat/test_client.py b/tests/unit_tests/gigachat/test_client.py index bdae0a2..3fab0d8 100644 --- a/tests/unit_tests/gigachat/test_client.py +++ b/tests/unit_tests/gigachat/test_client.py @@ -401,8 +401,11 @@ def test_get_token_credentials(httpx_mock: HTTPXMock) -> None: ) access_token = model.get_token() - assert model._access_token == ACCESS_TOKEN - assert access_token == ACCESS_TOKEN + assert model._access_token is not None + assert model._access_token.access_token == ACCESS_TOKEN["access_token"] + assert model._access_token.expires_at == ACCESS_TOKEN["expires_at"] + assert access_token.access_token == ACCESS_TOKEN["access_token"] + assert access_token.expires_at == ACCESS_TOKEN["expires_at"] def test_balance(httpx_mock: HTTPXMock) -> None: @@ -675,5 +678,8 @@ async def test_aget_token_credentials(httpx_mock: HTTPXMock) -> None: ) access_token = await model.aget_token() - assert model._access_token == ACCESS_TOKEN - assert access_token == ACCESS_TOKEN + assert model._access_token is not None + assert model._access_token.access_token == ACCESS_TOKEN["access_token"] + assert model._access_token.expires_at == ACCESS_TOKEN["expires_at"] + assert access_token.access_token == ACCESS_TOKEN["access_token"] + assert access_token.expires_at == ACCESS_TOKEN["expires_at"] From eaff10985c9b2f06199d5b6a60b67c729c048bcb Mon Sep 17 00:00:00 2001 From: mikelarg Date: Wed, 29 Jan 2025 17:27:59 +0300 Subject: [PATCH 2/2] chore: version up --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6595796..25fbf48 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "gigachat" -version = "0.1.37post1" +version = "0.1.38" description = "GigaChat. Python-library for GigaChain and LangChain" authors = ["Konstantin Krestnikov ", "Sergey Malyshev "] license = "MIT"