Skip to content

Commit

Permalink
feat(api)!: define chat completion models (#157)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-app[bot] committed Jan 9, 2025
1 parent 3bb05ac commit 2a1d32a
Show file tree
Hide file tree
Showing 41 changed files with 691 additions and 664 deletions.
2 changes: 1 addition & 1 deletion .stats.yml
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
configured_endpoints: 21
openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/writerai%2Fwriter-fcd4d82943d0aeefc300520f0ee4684456ef647140f1d6ba9ffcb86278d83d3a.yml
openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/writerai%2Fwriter-efae0fba75d52fb4c68e8f0332de2486bea6777516ef5cc90163a7c504d95194.yml
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ client = Writer(
api_key=os.environ.get("WRITER_API_KEY"), # This is the default and can be omitted
)

chat = client.chat.chat(
chat_completion = client.chat.chat(
messages=[{"role": "user"}],
model="palmyra-x-004",
)
print(chat.id)
print(chat_completion.id)
```

While you can provide an `api_key` keyword argument,
Expand All @@ -59,11 +59,11 @@ client = AsyncWriter(


async def main() -> None:
chat = await client.chat.chat(
chat_completion = await client.chat.chat(
messages=[{"role": "user"}],
model="palmyra-x-004",
)
print(chat.id)
print(chat_completion.id)


asyncio.run(main())
Expand Down
33 changes: 30 additions & 3 deletions api.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,23 @@
# Shared Types

```python
from writerai.types import (
ErrorMessage,
ErrorObject,
FunctionDefinition,
FunctionParams,
GraphData,
Logprobs,
LogprobsToken,
Source,
ToolCall,
ToolCallStreaming,
ToolChoiceJsonObject,
ToolChoiceString,
ToolParam,
)
```

# Applications

Types:
Expand All @@ -15,19 +35,26 @@ Methods:
Types:

```python
from writerai.types import Chat, ChatCompletionChunk
from writerai.types import (
ChatCompletion,
ChatCompletionChoice,
ChatCompletionChunk,
ChatCompletionMessage,
ChatCompletionParams,
ChatCompletionUsage,
)
```

Methods:

- <code title="post /v1/chat">client.chat.<a href="./src/writerai/resources/chat.py">chat</a>(\*\*<a href="src/writerai/types/chat_chat_params.py">params</a>) -> <a href="./src/writerai/types/chat.py">Chat</a></code>
- <code title="post /v1/chat">client.chat.<a href="./src/writerai/resources/chat.py">chat</a>(\*\*<a href="src/writerai/types/chat_chat_params.py">params</a>) -> <a href="./src/writerai/types/chat_completion.py">ChatCompletion</a></code>

# Completions

Types:

```python
from writerai.types import Completion, StreamingData
from writerai.types import Completion, CompletionChunk, CompletionParams
```

Methods:
Expand Down
6 changes: 6 additions & 0 deletions src/writerai/_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def __stream__(self) -> Iterator[_T]:
iterator = self._iter_events()

for sse in iterator:
if sse.data.startswith("[DONE]"):
break

if sse.event is None:
yield process_data(data=sse.json(), cast_to=cast_to, response=response)

Expand Down Expand Up @@ -135,6 +138,9 @@ async def __stream__(self) -> AsyncIterator[_T]:
iterator = self._iter_events()

async for sse in iterator:
if sse.data.startswith("[DONE]"):
break

if sse.event is None:
yield process_data(data=sse.json(), cast_to=cast_to, response=response)

Expand Down
35 changes: 18 additions & 17 deletions src/writerai/resources/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
async_to_streamed_response_wrapper,
)
from .._streaming import Stream, AsyncStream
from ..types.chat import Chat
from .._base_client import make_request_options
from ..types.chat_completion import ChatCompletion
from ..types.chat_completion_chunk import ChatCompletionChunk
from ..types.shared_params.tool_param import ToolParam

__all__ = ["ChatResource", "AsyncChatResource"]

Expand Down Expand Up @@ -64,15 +65,15 @@ def chat(
stream_options: chat_chat_params.StreamOptions | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: chat_chat_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[chat_chat_params.Tool] | NotGiven = NOT_GIVEN,
tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Chat:
) -> ChatCompletion:
"""Generate a chat completion based on the provided messages.
The response shown
Expand Down Expand Up @@ -147,7 +148,7 @@ def chat(
stream_options: chat_chat_params.StreamOptions | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: chat_chat_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[chat_chat_params.Tool] | NotGiven = NOT_GIVEN,
tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
Expand Down Expand Up @@ -230,15 +231,15 @@ def chat(
stream_options: chat_chat_params.StreamOptions | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: chat_chat_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[chat_chat_params.Tool] | NotGiven = NOT_GIVEN,
tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Chat | Stream[ChatCompletionChunk]:
) -> ChatCompletion | Stream[ChatCompletionChunk]:
"""Generate a chat completion based on the provided messages.
The response shown
Expand Down Expand Up @@ -313,15 +314,15 @@ def chat(
stream_options: chat_chat_params.StreamOptions | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: chat_chat_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[chat_chat_params.Tool] | NotGiven = NOT_GIVEN,
tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Chat | Stream[ChatCompletionChunk]:
) -> ChatCompletion | Stream[ChatCompletionChunk]:
return self._post(
"/v1/chat",
body=maybe_transform(
Expand All @@ -344,7 +345,7 @@ def chat(
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=Chat,
cast_to=ChatCompletion,
stream=stream or False,
stream_cls=Stream[ChatCompletionChunk],
)
Expand Down Expand Up @@ -384,15 +385,15 @@ async def chat(
stream_options: chat_chat_params.StreamOptions | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: chat_chat_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[chat_chat_params.Tool] | NotGiven = NOT_GIVEN,
tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Chat:
) -> ChatCompletion:
"""Generate a chat completion based on the provided messages.
The response shown
Expand Down Expand Up @@ -467,7 +468,7 @@ async def chat(
stream_options: chat_chat_params.StreamOptions | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: chat_chat_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[chat_chat_params.Tool] | NotGiven = NOT_GIVEN,
tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
Expand Down Expand Up @@ -550,15 +551,15 @@ async def chat(
stream_options: chat_chat_params.StreamOptions | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: chat_chat_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[chat_chat_params.Tool] | NotGiven = NOT_GIVEN,
tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Chat | AsyncStream[ChatCompletionChunk]:
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
"""Generate a chat completion based on the provided messages.
The response shown
Expand Down Expand Up @@ -633,15 +634,15 @@ async def chat(
stream_options: chat_chat_params.StreamOptions | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: chat_chat_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[chat_chat_params.Tool] | NotGiven = NOT_GIVEN,
tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Chat | AsyncStream[ChatCompletionChunk]:
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
return await self._post(
"/v1/chat",
body=await async_maybe_transform(
Expand All @@ -664,7 +665,7 @@ async def chat(
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=Chat,
cast_to=ChatCompletion,
stream=stream or False,
stream_cls=AsyncStream[ChatCompletionChunk],
)
Expand Down
18 changes: 9 additions & 9 deletions src/writerai/resources/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .._streaming import Stream, AsyncStream
from .._base_client import make_request_options
from ..types.completion import Completion
from ..types.streaming_data import StreamingData
from ..types.completion_chunk import CompletionChunk

__all__ = ["CompletionsResource", "AsyncCompletionsResource"]

Expand Down Expand Up @@ -130,7 +130,7 @@ def create(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Stream[StreamingData]:
) -> Stream[CompletionChunk]:
"""
Text generation
Expand Down Expand Up @@ -191,7 +191,7 @@ def create(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Completion | Stream[StreamingData]:
) -> Completion | Stream[CompletionChunk]:
"""
Text generation
Expand Down Expand Up @@ -252,7 +252,7 @@ def create(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Completion | Stream[StreamingData]:
) -> Completion | Stream[CompletionChunk]:
return self._post(
"/v1/completions",
body=maybe_transform(
Expand All @@ -274,7 +274,7 @@ def create(
),
cast_to=Completion,
stream=stream or False,
stream_cls=Stream[StreamingData],
stream_cls=Stream[CompletionChunk],
)


Expand Down Expand Up @@ -378,7 +378,7 @@ async def create(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> AsyncStream[StreamingData]:
) -> AsyncStream[CompletionChunk]:
"""
Text generation
Expand Down Expand Up @@ -439,7 +439,7 @@ async def create(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Completion | AsyncStream[StreamingData]:
) -> Completion | AsyncStream[CompletionChunk]:
"""
Text generation
Expand Down Expand Up @@ -500,7 +500,7 @@ async def create(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Completion | AsyncStream[StreamingData]:
) -> Completion | AsyncStream[CompletionChunk]:
return await self._post(
"/v1/completions",
body=await async_maybe_transform(
Expand All @@ -522,7 +522,7 @@ async def create(
),
cast_to=Completion,
stream=stream or False,
stream_cls=AsyncStream[StreamingData],
stream_cls=AsyncStream[CompletionChunk],
)


Expand Down
22 changes: 20 additions & 2 deletions src/writerai/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,28 @@

from __future__ import annotations

from .chat import Chat as Chat
from .file import File as File
from .graph import Graph as Graph
from .shared import (
Source as Source,
Logprobs as Logprobs,
ToolCall as ToolCall,
GraphData as GraphData,
ToolParam as ToolParam,
ErrorObject as ErrorObject,
ErrorMessage as ErrorMessage,
LogprobsToken as LogprobsToken,
FunctionParams as FunctionParams,
ToolChoiceString as ToolChoiceString,
ToolCallStreaming as ToolCallStreaming,
FunctionDefinition as FunctionDefinition,
ToolChoiceJsonObject as ToolChoiceJsonObject,
)
from .question import Question as Question
from .completion import Completion as Completion
from .streaming_data import StreamingData as StreamingData
from .chat_completion import ChatCompletion as ChatCompletion
from .chat_chat_params import ChatChatParams as ChatChatParams
from .completion_chunk import CompletionChunk as CompletionChunk
from .file_list_params import FileListParams as FileListParams
from .file_retry_params import FileRetryParams as FileRetryParams
from .graph_list_params import GraphListParams as GraphListParams
Expand All @@ -19,11 +34,14 @@
from .model_list_response import ModelListResponse as ModelListResponse
from .file_delete_response import FileDeleteResponse as FileDeleteResponse
from .chat_completion_chunk import ChatCompletionChunk as ChatCompletionChunk
from .chat_completion_usage import ChatCompletionUsage as ChatCompletionUsage
from .graph_create_response import GraphCreateResponse as GraphCreateResponse
from .graph_delete_response import GraphDeleteResponse as GraphDeleteResponse
from .graph_question_params import GraphQuestionParams as GraphQuestionParams
from .graph_update_response import GraphUpdateResponse as GraphUpdateResponse
from .tool_parse_pdf_params import ToolParsePdfParams as ToolParsePdfParams
from .chat_completion_choice import ChatCompletionChoice as ChatCompletionChoice
from .chat_completion_message import ChatCompletionMessage as ChatCompletionMessage
from .question_response_chunk import QuestionResponseChunk as QuestionResponseChunk
from .tool_parse_pdf_response import ToolParsePdfResponse as ToolParsePdfResponse
from .completion_create_params import CompletionCreateParams as CompletionCreateParams
Expand Down
Loading

0 comments on commit 2a1d32a

Please sign in to comment.