support a custom HTTPX client in Client and AsyncClient #380

Open · wants to merge 2 commits into base: main
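
In short, the change lets callers pass their own configured httpx client instead of having the library build one. A minimal usage sketch (illustrative only; it assumes the constructor signatures added in the diff below, the package's usual `from ollama import ...` import, and the default local host):

```python
import httpx
from ollama import AsyncClient, Client

# Bring your own httpx.Client, configured however you like (timeouts, proxies, transports, ...).
http_client = httpx.Client(timeout=httpx.Timeout(60.0), follow_redirects=True)
client = Client('http://localhost:11434', client=http_client)

# The async variant accepts a pre-configured httpx.AsyncClient the same way.
async_client = AsyncClient('http://localhost:11434', client=httpx.AsyncClient())
```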
178 changes: 126 additions & 52 deletions ollama/_client.py
@@ -74,48 +74,54 @@
T = TypeVar('T')


class BaseClient:
class Client:
@overload
Collaborator: These overloads aren't necessary since there's no overlap between the client and non-client versions.

Author: They are needed to maintain proper typing: the point is that you can't pass `client` together with `follow_redirects` or `timeout`. These overloads mean doing so will give a typing error.
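
A sketch of the call sites being discussed (illustrative, assuming the overloaded constructor above):

```python
import httpx
from ollama import Client

# Accepted: a pre-configured client; per the second overload, only `host` and
# `headers` may accompany it.
ok = Client('http://localhost:11434', client=httpx.Client(timeout=30))

# Intended to be rejected: per the reasoning above this should not type-check, and the
# runtime assertions in __init__ reject combining `client` with `timeout`,
# `follow_redirects`, or extra httpx kwargs in any case.
bad = Client('http://localhost:11434', client=httpx.Client(), timeout=30)
```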

def __init__(
self,
client,
host: Optional[str] = None,
follow_redirects: bool = True,
*,
follow_redirects: Optional[bool] = None,
timeout: Any = None,
headers: Optional[Mapping[str, str]] = None,
**kwargs,
) -> None:
"""
Creates a httpx client. Default parameters are the same as those defined in httpx
except for the following:
- `follow_redirects`: True
- `timeout`: None
`kwargs` are passed to the httpx client.
"""
**httpx_kwargs: Any,
) -> None: ...

self._client = client(
base_url=_parse_host(host or os.getenv('OLLAMA_HOST')),
follow_redirects=follow_redirects,
timeout=timeout,
# Lowercase all headers to ensure override
headers={
k.lower(): v
for k, v in {
**(headers or {}),
'Content-Type': 'application/json',
'Accept': 'application/json',
'User-Agent': f'ollama-python/{__version__} ({platform.machine()} {platform.system().lower()}) Python/{platform.python_version()}',
}.items()
},
**kwargs,
)


class Client(BaseClient):
def __init__(self, host: Optional[str] = None, **kwargs) -> None:
super().__init__(httpx.Client, host, **kwargs)
@overload
def __init__(
self,
host: Optional[str] = None,
*,
client: httpx.Client,
headers: Optional[Mapping[str, str]] = None,
) -> None: ...

def _request_raw(self, *args, **kwargs):
r = self._client.request(*args, **kwargs)
def __init__(
self,
host: Optional[str] = None,
*,
client: Optional[httpx.Client] = None,
follow_redirects: Optional[bool] = None,
timeout: Any = None,
headers: Optional[Mapping[str, str]] = None,
**httpx_kwargs: Any,
) -> None:
self._host = _parse_host(host or os.getenv('OLLAMA_HOST'))
self._request_headers = _get_headers(headers)
if client:
assert follow_redirects is None, 'Cannot provide both `client` and `follow_redirects`'
assert timeout is None, 'Cannot provide both `client` and `timeout`'
assert not httpx_kwargs, 'Cannot provide both `client` and `httpx_kwargs`'
self._client = client
else:
self._client = httpx.Client(
follow_redirects=True if follow_redirects is None else follow_redirects,
timeout=timeout,
**httpx_kwargs,
)

def _request_raw(self, method: str, path: str, **kwargs):
assert path.startswith('/'), 'path must start with "/"'
r = self._client.request(method, self._host + path, headers=self._request_headers, **kwargs)
try:
r.raise_for_status()
except httpx.HTTPStatusError as e:
@@ -126,7 +132,9 @@ def _request_raw(self, *args, **kwargs):
def _request(
self,
cls: Type[T],
*args,
method: str,
path: str,
*,
stream: Literal[False] = False,
**kwargs,
) -> T: ...
@@ -135,7 +143,9 @@ def _request(
def _request(
self,
cls: Type[T],
*args,
method: str,
path: str,
*,
stream: Literal[True] = True,
**kwargs,
) -> Iterator[T]: ...
@@ -144,22 +154,26 @@ def _request(
def _request(
self,
cls: Type[T],
*args,
method: str,
path: str,
*,
stream: bool = False,
**kwargs,
) -> Union[T, Iterator[T]]: ...

def _request(
self,
cls: Type[T],
*args,
method: str,
path: str,
stream: bool = False,
**kwargs,
) -> Union[T, Iterator[T]]:
if stream:

def inner():
with self._client.stream(*args, **kwargs) as r:
assert path.startswith('/'), 'path must start with "/"'
with self._client.stream(method, self._host + path, headers=self._request_headers, **kwargs) as r:
try:
r.raise_for_status()
except httpx.HTTPStatusError as e:
@@ -174,7 +188,7 @@ def inner():

return inner()

return cls(**self._request_raw(*args, **kwargs).json())
return cls(**self._request_raw(method, path, **kwargs).json())

@overload
def generate(
@@ -612,12 +626,54 @@ def ps(self) -> ProcessResponse:
)


class AsyncClient(BaseClient):
def __init__(self, host: Optional[str] = None, **kwargs) -> None:
super().__init__(httpx.AsyncClient, host, **kwargs)
class AsyncClient:
@overload
def __init__(
self,
host: Optional[str] = None,
*,
follow_redirects: Optional[bool] = None,
timeout: Any = None,
headers: Optional[Mapping[str, str]] = None,
**httpx_kwargs: Any,
) -> None: ...

async def _request_raw(self, *args, **kwargs):
r = await self._client.request(*args, **kwargs)
@overload
def __init__(
self,
host: Optional[str] = None,
*,
client: httpx.AsyncClient,
headers: Optional[Mapping[str, str]] = None,
) -> None: ...

def __init__(
self,
host: Optional[str] = None,
*,
client: Optional[httpx.AsyncClient] = None,
follow_redirects: Optional[bool] = None,
timeout: Any = None,
headers: Optional[Mapping[str, str]] = None,
**httpx_kwargs: Any,
) -> None:
self._host = _parse_host(host or os.getenv('OLLAMA_HOST'))
self._request_headers = _get_headers(headers)
if client:
assert follow_redirects is None, 'Cannot provide both `client` and `follow_redirects`'
assert timeout is None, 'Cannot provide both `client` and `timeout`'
assert not httpx_kwargs, 'Cannot provide both `client` and `httpx_kwargs`'
self._client = client
else:
self._client = httpx.AsyncClient(
follow_redirects=True if follow_redirects is None else follow_redirects,
timeout=timeout,
**httpx_kwargs,
)

async def _request_raw(self, method: str, path: str, **kwargs):
assert path.startswith('/'), 'path must start with "/"'
r = await self._client.request(method, self._host + path, headers=self._request_headers, **kwargs)
try:
r.raise_for_status()
except httpx.HTTPStatusError as e:
@@ -628,7 +684,8 @@ async def _request_raw(self, *args, **kwargs):
async def _request(
self,
cls: Type[T],
*args,
method: str,
path: str,
stream: Literal[False] = False,
**kwargs,
) -> T: ...
@@ -637,7 +694,8 @@ async def _request(
async def _request(
self,
cls: Type[T],
*args,
method: str,
path: str,
stream: Literal[True] = True,
**kwargs,
) -> AsyncIterator[T]: ...
@@ -646,22 +704,25 @@ async def _request(
async def _request(
self,
cls: Type[T],
*args,
method: str,
path: str,
stream: bool = False,
**kwargs,
) -> Union[T, AsyncIterator[T]]: ...

async def _request(
self,
cls: Type[T],
*args,
method: str,
path: str,
stream: bool = False,
**kwargs,
) -> Union[T, AsyncIterator[T]]:
if stream:

async def inner():
async with self._client.stream(*args, **kwargs) as r:
assert path.startswith('/'), 'path must start with "/"'
async with self._client.stream(method, self._host + path, headers=self._request_headers, **kwargs) as r:
try:
r.raise_for_status()
except httpx.HTTPStatusError as e:
@@ -676,7 +737,7 @@ async def inner():

return inner()

return cls(**(await self._request_raw(*args, **kwargs)).json())
return cls(**(await self._request_raw(method, path, **kwargs)).json())

@overload
async def generate(
@@ -1231,3 +1292,16 @@ def _parse_host(host: Optional[str]) -> str:
return f'{scheme}://{host}:{port}/{path}'

return f'{scheme}://{host}:{port}'


def _get_headers(extra_headers: Optional[Mapping[str, str]] = None) -> Mapping[str, str]:
# Lowercase all headers to ensure override
return {
k.lower(): v
for k, v in {
**(extra_headers or {}),
'Content-Type': 'application/json',
'Accept': 'application/json',
'User-Agent': f'ollama-python/{__version__} ({platform.machine()} {platform.system().lower()}) Python/{platform.python_version()}',
}.items()
}
73 changes: 65 additions & 8 deletions tests/test_client.py
@@ -1,6 +1,8 @@
import os
import io
import json
from typing import Optional

from pydantic import ValidationError, BaseModel
import pytest
import tempfile
@@ -1193,20 +1195,75 @@ async def test_async_client_copy(httpserver: HTTPServer):
assert response['status'] == 'success'


def test_headers():
client = Client()
assert client._client.headers['content-type'] == 'application/json'
assert client._client.headers['accept'] == 'application/json'
assert client._client.headers['user-agent'].startswith('ollama-python/')
def custom_header_matcher(header_name: str, actual: Optional[str], expected: str) -> bool:
if header_name == 'User-Agent':
return actual.startswith(expected)
else:
return actual == expected


def test_headers(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/chat',
method='POST',
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'stream': False,
},
header_value_matcher=custom_header_matcher,
headers={'Content-Type': 'application/json', 'Accept': 'application/json', 'User-Agent': 'ollama-python/'},
).respond_with_json(
{
'model': 'dummy',
'message': {
'role': 'assistant',
'content': "I don't know.",
},
}
)

client = Client(httpserver.url_for('/'))
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}])
assert response['model'] == 'dummy'
assert response['message']['role'] == 'assistant'
assert response['message']['content'] == "I don't know."


def test_custom_headers(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/chat',
method='POST',
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'stream': False,
},
header_value_matcher=custom_header_matcher,
headers={'Content-Type': 'application/json', 'Accept': 'application/json', 'User-Agent': 'ollama-python/', 'X-Custom': 'value'},
).respond_with_json(
{
'model': 'dummy',
'message': {
'role': 'assistant',
'content': "I don't know.",
},
}
)

client = Client(
httpserver.url_for('/'),
headers={
'X-Custom': 'value',
'Content-Type': 'text/plain',
}
},
)
assert client._client.headers['x-custom'] == 'value'
assert client._client.headers['content-type'] == 'application/json'
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}])
assert response['model'] == 'dummy'
assert response['message']['role'] == 'assistant'
assert response['message']['content'] == "I don't know."


def test_copy_tools():