From b9d6a9cc95f5b318dc6b11d90b453ae577948c5d Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 14 Dec 2024 11:45:23 +0000 Subject: [PATCH 1/2] support a custom httpx client in Client and AsyncClient --- ollama/_client.py | 175 ++++++++++++++++++++++++++++++------------- tests/test_client.py | 73 ++++++++++++++++-- 2 files changed, 188 insertions(+), 60 deletions(-) diff --git a/ollama/_client.py b/ollama/_client.py index 87fa881..dde82ac 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -74,48 +74,54 @@ T = TypeVar('T') -class BaseClient: +class Client: + @overload def __init__( self, - client, host: Optional[str] = None, - follow_redirects: bool = True, + *, + follow_redirects: bool | None = None, timeout: Any = None, headers: Optional[Mapping[str, str]] = None, - **kwargs, - ) -> None: - """ - Creates a httpx client. Default parameters are the same as those defined in httpx - except for the following: - - `follow_redirects`: True - - `timeout`: None - `kwargs` are passed to the httpx client. - """ + **httpx_kwargs: Any, + ) -> None: ... - self._client = client( - base_url=_parse_host(host or os.getenv('OLLAMA_HOST')), - follow_redirects=follow_redirects, - timeout=timeout, - # Lowercase all headers to ensure override - headers={ - k.lower(): v - for k, v in { - **(headers or {}), - 'Content-Type': 'application/json', - 'Accept': 'application/json', - 'User-Agent': f'ollama-python/{__version__} ({platform.machine()} {platform.system().lower()}) Python/{platform.python_version()}', - }.items() - }, - **kwargs, - ) - - -class Client(BaseClient): - def __init__(self, host: Optional[str] = None, **kwargs) -> None: - super().__init__(httpx.Client, host, **kwargs) + @overload + def __init__( + self, + host: Optional[str] = None, + *, + client: httpx.Client, + headers: Optional[Mapping[str, str]] = None, + ) -> None: ... - def _request_raw(self, *args, **kwargs): - r = self._client.request(*args, **kwargs) + def __init__( + self, + host: Optional[str] = None, + *, + client: httpx.Client | None = None, + follow_redirects: bool | None = None, + timeout: Any = None, + headers: Optional[Mapping[str, str]] = None, + **httpx_kwargs: Any, + ) -> None: + self._host = _parse_host(host or os.getenv('OLLAMA_HOST')) + self._request_headers = _get_headers(headers) + if client: + assert follow_redirects is None, 'Cannot provide both `client` and `follow_redirects`' + assert timeout is None, 'Cannot provide both `client` and `timeout`' + assert not httpx_kwargs, 'Cannot provide both `client` and `httpx_kwargs`' + self._client = client + else: + self._client = httpx.Client( + follow_redirects=True if follow_redirects is None else follow_redirects, + timeout=timeout, + **httpx_kwargs, + ) + + def _request_raw(self, method: str, path: str, **kwargs): + assert path.startswith('/'), 'path must start with "/"' + r = self._client.request(method, self._host + path, headers=self._request_headers, **kwargs) try: r.raise_for_status() except httpx.HTTPStatusError as e: @@ -126,7 +132,8 @@ def _request_raw(self, *args, **kwargs): def _request( self, cls: Type[T], - *args, + method: str, + path: str, stream: Literal[False] = False, **kwargs, ) -> T: ... @@ -135,7 +142,8 @@ def _request( def _request( self, cls: Type[T], - *args, + method: str, + path: str, stream: Literal[True] = True, **kwargs, ) -> Iterator[T]: ... @@ -144,7 +152,8 @@ def _request( def _request( self, cls: Type[T], - *args, + method: str, + path: str, stream: bool = False, **kwargs, ) -> Union[T, Iterator[T]]: ... @@ -152,14 +161,16 @@ def _request( def _request( self, cls: Type[T], - *args, + method: str, + path: str, stream: bool = False, **kwargs, ) -> Union[T, Iterator[T]]: if stream: def inner(): - with self._client.stream(*args, **kwargs) as r: + assert path.startswith('/'), 'path must start with "/"' + with self._client.stream(method, self._host + path, headers=self._request_headers, **kwargs) as r: try: r.raise_for_status() except httpx.HTTPStatusError as e: @@ -174,7 +185,7 @@ def inner(): return inner() - return cls(**self._request_raw(*args, **kwargs).json()) + return cls(**self._request_raw(method, path, **kwargs).json()) @overload def generate( @@ -612,12 +623,54 @@ def ps(self) -> ProcessResponse: ) -class AsyncClient(BaseClient): - def __init__(self, host: Optional[str] = None, **kwargs) -> None: - super().__init__(httpx.AsyncClient, host, **kwargs) +class AsyncClient: + @overload + def __init__( + self, + host: Optional[str] = None, + *, + follow_redirects: bool | None = None, + timeout: Any = None, + headers: Optional[Mapping[str, str]] = None, + **httpx_kwargs: Any, + ) -> None: ... - async def _request_raw(self, *args, **kwargs): - r = await self._client.request(*args, **kwargs) + @overload + def __init__( + self, + host: Optional[str] = None, + *, + client: httpx.AsyncClient, + headers: Optional[Mapping[str, str]] = None, + ) -> None: ... + + def __init__( + self, + host: Optional[str] = None, + *, + client: httpx.AsyncClient | None = None, + follow_redirects: bool | None = None, + timeout: Any = None, + headers: Optional[Mapping[str, str]] = None, + **httpx_kwargs: Any, + ) -> None: + self._host = _parse_host(host or os.getenv('OLLAMA_HOST')) + self._request_headers = _get_headers(headers) + if client: + assert follow_redirects is None, 'Cannot provide both `client` and `follow_redirects`' + assert timeout is None, 'Cannot provide both `client` and `timeout`' + assert not httpx_kwargs, 'Cannot provide both `client` and `httpx_kwargs`' + self._client = client + else: + self._client = httpx.AsyncClient( + follow_redirects=True if follow_redirects is None else follow_redirects, + timeout=timeout, + **httpx_kwargs, + ) + + async def _request_raw(self, method: str, path: str, **kwargs): + assert path.startswith('/'), 'path must start with "/"' + r = await self._client.request(method, self._host + path, headers=self._request_headers, **kwargs) try: r.raise_for_status() except httpx.HTTPStatusError as e: @@ -628,7 +681,8 @@ async def _request_raw(self, *args, **kwargs): async def _request( self, cls: Type[T], - *args, + method: str, + path: str, stream: Literal[False] = False, **kwargs, ) -> T: ... @@ -637,7 +691,8 @@ async def _request( async def _request( self, cls: Type[T], - *args, + method: str, + path: str, stream: Literal[True] = True, **kwargs, ) -> AsyncIterator[T]: ... @@ -646,7 +701,8 @@ async def _request( async def _request( self, cls: Type[T], - *args, + method: str, + path: str, stream: bool = False, **kwargs, ) -> Union[T, AsyncIterator[T]]: ... @@ -654,14 +710,16 @@ async def _request( async def _request( self, cls: Type[T], - *args, + method: str, + path: str, stream: bool = False, **kwargs, ) -> Union[T, AsyncIterator[T]]: if stream: async def inner(): - async with self._client.stream(*args, **kwargs) as r: + assert path.startswith('/'), 'path must start with "/"' + async with self._client.stream(method, self._host + path, headers=self._request_headers, **kwargs) as r: try: r.raise_for_status() except httpx.HTTPStatusError as e: @@ -676,7 +734,7 @@ async def inner(): return inner() - return cls(**(await self._request_raw(*args, **kwargs)).json()) + return cls(**(await self._request_raw(method, path, **kwargs)).json()) @overload async def generate( @@ -1231,3 +1289,16 @@ def _parse_host(host: Optional[str]) -> str: return f'{scheme}://{host}:{port}/{path}' return f'{scheme}://{host}:{port}' + + +def _get_headers(extra_headers: Optional[Mapping[str, str]] = None) -> Mapping[str, str]: + # Lowercase all headers to ensure override + return { + k.lower(): v + for k, v in { + **(extra_headers or {}), + 'Content-Type': 'application/json', + 'Accept': 'application/json', + 'User-Agent': f'ollama-python/{__version__} ({platform.machine()} {platform.system().lower()}) Python/{platform.python_version()}', + }.items() + } diff --git a/tests/test_client.py b/tests/test_client.py index aab2f2e..7c7cda3 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,6 +1,8 @@ import os import io import json +from typing import Optional + from pydantic import ValidationError, BaseModel import pytest import tempfile @@ -1193,20 +1195,75 @@ async def test_async_client_copy(httpserver: HTTPServer): assert response['status'] == 'success' -def test_headers(): - client = Client() - assert client._client.headers['content-type'] == 'application/json' - assert client._client.headers['accept'] == 'application/json' - assert client._client.headers['user-agent'].startswith('ollama-python/') +def custom_header_matcher(header_name: str, actual: Optional[str], expected: str) -> bool: + if header_name == 'User-Agent': + return actual.startswith(expected) + else: + return actual == expected + + +def test_headers(httpserver: HTTPServer): + httpserver.expect_ordered_request( + '/api/chat', + method='POST', + json={ + 'model': 'dummy', + 'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}], + 'tools': [], + 'stream': False, + }, + header_value_matcher=custom_header_matcher, + headers={'Content-Type': 'application/json', 'Accept': 'application/json', 'User-Agent': 'ollama-python/'}, + ).respond_with_json( + { + 'model': 'dummy', + 'message': { + 'role': 'assistant', + 'content': "I don't know.", + }, + } + ) + + client = Client(httpserver.url_for('/')) + response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}]) + assert response['model'] == 'dummy' + assert response['message']['role'] == 'assistant' + assert response['message']['content'] == "I don't know." + + +def test_custom_headers(httpserver: HTTPServer): + httpserver.expect_ordered_request( + '/api/chat', + method='POST', + json={ + 'model': 'dummy', + 'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}], + 'tools': [], + 'stream': False, + }, + header_value_matcher=custom_header_matcher, + headers={'Content-Type': 'application/json', 'Accept': 'application/json', 'User-Agent': 'ollama-python/', 'X-Custom': 'value'}, + ).respond_with_json( + { + 'model': 'dummy', + 'message': { + 'role': 'assistant', + 'content': "I don't know.", + }, + } + ) client = Client( + httpserver.url_for('/'), headers={ 'X-Custom': 'value', 'Content-Type': 'text/plain', - } + }, ) - assert client._client.headers['x-custom'] == 'value' - assert client._client.headers['content-type'] == 'application/json' + response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}]) + assert response['model'] == 'dummy' + assert response['message']['role'] == 'assistant' + assert response['message']['content'] == "I don't know." def test_copy_tools(): From 3572fc73594ea04302a75c16216c40831876e9c0 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 21 Dec 2024 15:31:53 +0000 Subject: [PATCH 2/2] fix suggestions --- ollama/_client.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/ollama/_client.py b/ollama/_client.py index dde82ac..a9d65aa 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -80,7 +80,7 @@ def __init__( self, host: Optional[str] = None, *, - follow_redirects: bool | None = None, + follow_redirects: Optional[bool] = None, timeout: Any = None, headers: Optional[Mapping[str, str]] = None, **httpx_kwargs: Any, @@ -99,8 +99,8 @@ def __init__( self, host: Optional[str] = None, *, - client: httpx.Client | None = None, - follow_redirects: bool | None = None, + client: Optional[httpx.Client] = None, + follow_redirects: Optional[bool] = None, timeout: Any = None, headers: Optional[Mapping[str, str]] = None, **httpx_kwargs: Any, @@ -134,6 +134,7 @@ def _request( cls: Type[T], method: str, path: str, + *, stream: Literal[False] = False, **kwargs, ) -> T: ... @@ -144,6 +145,7 @@ def _request( cls: Type[T], method: str, path: str, + *, stream: Literal[True] = True, **kwargs, ) -> Iterator[T]: ... @@ -154,6 +156,7 @@ def _request( cls: Type[T], method: str, path: str, + *, stream: bool = False, **kwargs, ) -> Union[T, Iterator[T]]: ... @@ -629,7 +632,7 @@ def __init__( self, host: Optional[str] = None, *, - follow_redirects: bool | None = None, + follow_redirects: Optional[bool] = None, timeout: Any = None, headers: Optional[Mapping[str, str]] = None, **httpx_kwargs: Any, @@ -648,8 +651,8 @@ def __init__( self, host: Optional[str] = None, *, - client: httpx.AsyncClient | None = None, - follow_redirects: bool | None = None, + client: Optional[httpx.AsyncClient] = None, + follow_redirects: Optional[bool] = None, timeout: Any = None, headers: Optional[Mapping[str, str]] = None, **httpx_kwargs: Any,