From af27fe82a848f170ce800125c86acd813f452069 Mon Sep 17 00:00:00 2001
From: "stainless-app[bot]"
 <142633134+stainless-app[bot]@users.noreply.github.com>
Date: Thu, 8 Aug 2024 13:17:06 +0000
Subject: [PATCH] chore(internal): codegen related update (#48)

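Expose the number of retries made on raw responses: `_process_response`
now accepts a `retries_taken` argument (computed in `_request` for both
the sync and async clients) and stores it on `BaseAPIResponse` as
`retries_taken`. Also bumps pyright to 1.1.374 and adds parametrized
`test_retries_taken` tests for both clients.

A rough usage sketch of the new attribute; the API key and model name
below are placeholders, not values taken from this change:

    from writerai import Writer

    client = Writer(api_key="<WRITER_API_KEY>")  # placeholder key

    # with_raw_response exposes the wrapped APIResponse object
    response = client.chat.with_raw_response.chat(
        messages=[{"content": "Hello", "role": "user"}],
        model="<model-name>",  # placeholder model id
    )

    # 0 when the first attempt succeeded; otherwise the number of retries
    print(response.retries_taken)
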
---
 requirements-dev.lock        |  2 +-
 src/writerai/_base_client.py |  8 +++++
 src/writerai/_response.py    |  5 +++
 tests/test_client.py         | 61 ++++++++++++++++++++++++++++++++++++
 4 files changed, 75 insertions(+), 1 deletion(-)

diff --git a/requirements-dev.lock b/requirements-dev.lock
index d9332cf..5f5a3d5 100644
--- a/requirements-dev.lock
+++ b/requirements-dev.lock
@@ -70,7 +70,7 @@ pydantic-core==2.18.2
     # via pydantic
 pygments==2.18.0
     # via rich
-pyright==1.1.364
+pyright==1.1.374
 pytest==7.1.1
     # via pytest-asyncio
 pytest-asyncio==0.21.1
diff --git a/src/writerai/_base_client.py b/src/writerai/_base_client.py
index 11f7e74..fb189bb 100644
--- a/src/writerai/_base_client.py
+++ b/src/writerai/_base_client.py
@@ -1049,6 +1049,7 @@ def _request(
             response=response,
             stream=stream,
             stream_cls=stream_cls,
+            retries_taken=options.get_max_retries(self.max_retries) - retries,
         )
 
     def _retry_request(
@@ -1090,6 +1091,7 @@ def _process_response(
         response: httpx.Response,
         stream: bool,
         stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
+        retries_taken: int = 0,
     ) -> ResponseT:
         origin = get_origin(cast_to) or cast_to
 
@@ -1107,6 +1109,7 @@ def _process_response(
                     stream=stream,
                     stream_cls=stream_cls,
                     options=options,
+                    retries_taken=retries_taken,
                 ),
             )
 
@@ -1120,6 +1123,7 @@ def _process_response(
             stream=stream,
             stream_cls=stream_cls,
             options=options,
+            retries_taken=retries_taken,
         )
         if bool(response.request.headers.get(RAW_RESPONSE_HEADER)):
             return cast(ResponseT, api_response)
@@ -1610,6 +1614,7 @@ async def _request(
             response=response,
             stream=stream,
             stream_cls=stream_cls,
+            retries_taken=options.get_max_retries(self.max_retries) - retries,
         )
 
     async def _retry_request(
@@ -1649,6 +1654,7 @@ async def _process_response(
         response: httpx.Response,
         stream: bool,
         stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
+        retries_taken: int = 0,
     ) -> ResponseT:
         origin = get_origin(cast_to) or cast_to
 
@@ -1666,6 +1672,7 @@ async def _process_response(
                     stream=stream,
                     stream_cls=stream_cls,
                     options=options,
+                    retries_taken=retries_taken,
                 ),
             )
 
@@ -1679,6 +1686,7 @@ async def _process_response(
             stream=stream,
             stream_cls=stream_cls,
             options=options,
+            retries_taken=retries_taken,
         )
         if bool(response.request.headers.get(RAW_RESPONSE_HEADER)):
             return cast(ResponseT, api_response)
diff --git a/src/writerai/_response.py b/src/writerai/_response.py
index f587c71..b7b9bfb 100644
--- a/src/writerai/_response.py
+++ b/src/writerai/_response.py
@@ -55,6 +55,9 @@ class BaseAPIResponse(Generic[R]):
 
     http_response: httpx.Response
 
+    retries_taken: int
+    """The number of retries made. If no retries happened this will be `0`"""
+
     def __init__(
         self,
         *,
@@ -64,6 +67,7 @@ def __init__(
         stream: bool,
         stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
         options: FinalRequestOptions,
+        retries_taken: int = 0,
     ) -> None:
         self._cast_to = cast_to
         self._client = client
@@ -72,6 +76,7 @@ def __init__(
         self._stream_cls = stream_cls
         self._options = options
         self.http_response = raw
+        self.retries_taken = retries_taken
 
     @property
     def headers(self) -> httpx.Headers:
diff --git a/tests/test_client.py b/tests/test_client.py
index c609809..d190307 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -763,6 +763,35 @@ def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> Non
 
         assert _get_open_connections(self.client) == 0
 
+    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
+    @mock.patch("writerai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
+    @pytest.mark.respx(base_url=base_url)
+    def test_retries_taken(self, client: Writer, failures_before_success: int, respx_mock: MockRouter) -> None:
+        client = client.with_options(max_retries=4)
+
+        nb_retries = 0
+
+        def retry_handler(_request: httpx.Request) -> httpx.Response:
+            nonlocal nb_retries
+            if nb_retries < failures_before_success:
+                nb_retries += 1
+                return httpx.Response(500)
+            return httpx.Response(200)
+
+        respx_mock.post("/v1/chat").mock(side_effect=retry_handler)
+
+        response = client.chat.with_raw_response.chat(
+            messages=[
+                {
+                    "content": "content",
+                    "role": "user",
+                }
+            ],
+            model="model",
+        )
+
+        assert response.retries_taken == failures_before_success
+
 
 class TestAsyncWriter:
     client = AsyncWriter(base_url=base_url, api_key=api_key, _strict_response_validation=True)
@@ -1493,3 +1522,35 @@ async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter)
             )
 
         assert _get_open_connections(self.client) == 0
+
+    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
+    @mock.patch("writerai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
+    @pytest.mark.respx(base_url=base_url)
+    @pytest.mark.asyncio
+    async def test_retries_taken(
+        self, async_client: AsyncWriter, failures_before_success: int, respx_mock: MockRouter
+    ) -> None:
+        client = async_client.with_options(max_retries=4)
+
+        nb_retries = 0
+
+        def retry_handler(_request: httpx.Request) -> httpx.Response:
+            nonlocal nb_retries
+            if nb_retries < failures_before_success:
+                nb_retries += 1
+                return httpx.Response(500)
+            return httpx.Response(200)
+
+        respx_mock.post("/v1/chat").mock(side_effect=retry_handler)
+
+        response = await client.chat.with_raw_response.chat(
+            messages=[
+                {
+                    "content": "content",
+                    "role": "user",
+                }
+            ],
+            model="model",
+        )
+
+        assert response.retries_taken == failures_before_success