Bug - [InferenceClient] - use proxy set in env var (#2421)
* set trust_env to True in aiohttp

* handle trust_env and set the trust_env parameter in the aiohttp client (async) and the requests client (sync)

* Update src/huggingface_hub/inference/_generated/_async_client.py

Co-authored-by: Benjamin BERNARD <[email protected]>

* do not modify InferenceClient

* Add trust_env parameter only in AsyncInferenceClient + respect proxy

* document proxies and trust_env parameters

* remove newlines

---------

Co-authored-by: Benjamin BERNARD <[email protected]>
Co-authored-by: Lucain Pouget <[email protected]>
3 people authored Aug 13, 2024
1 parent 9a9b8c1 commit c9c39b8
Showing 3 changed files with 89 additions and 24 deletions.
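
What this buys in practice: a minimal sketch of the async client with the new flag (the proxy and endpoint URLs below are placeholders, not part of the diff):

    import asyncio
    import os

    from huggingface_hub import AsyncInferenceClient

    # Placeholder proxy URL. With trust_env=True, the aiohttp session created
    # by the client reads HTTP_PROXY / HTTPS_PROXY / NO_PROXY from the
    # environment.
    os.environ["HTTPS_PROXY"] = "http://proxy.example:3128"

    client = AsyncInferenceClient(trust_env=True)


    async def main() -> None:
        # health_check() issues its GET through the shared session factory,
        # so the proxy above is honored.
        print(await client.health_check(model="https://my-endpoint.example"))


    asyncio.run(main())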
2 changes: 2 additions & 0 deletions src/huggingface_hub/inference/_client.py
@@ -148,6 +148,8 @@ class InferenceClient:
             Values in this dictionary will override the default values.
         cookies (`Dict[str, str]`, `optional`):
             Additional cookies to send to the server.
+        proxies (`Any`, `optional`):
+            Proxies to use for the request.
         base_url (`str`, `optional`):
             Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`]
             follow the same pattern as `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None.
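
For the synchronous client the diff only documents the existing proxies argument. A hedged sketch of its use, assuming the value is forwarded to requests (which takes a scheme-to-proxy-URL mapping; the URL is a placeholder):

    from huggingface_hub import InferenceClient

    # Placeholder proxy; requests-style {scheme: proxy_url} mapping (assumed).
    client = InferenceClient(proxies={"https": "http://proxy.example:3128"})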
43 changes: 31 additions & 12 deletions src/huggingface_hub/inference/_generated/_async_client.py
@@ -96,6 +96,7 @@

 if TYPE_CHECKING:
     import numpy as np
+    from aiohttp import ClientSession
     from PIL.Image import Image

 logger = logging.getLogger(__name__)
@@ -133,6 +134,10 @@ class AsyncInferenceClient:
             Values in this dictionary will override the default values.
         cookies (`Dict[str, str]`, `optional`):
             Additional cookies to send to the server.
+        trust_env ('bool', 'optional'):
+            Trust environment settings for proxy configuration if the parameter is `True` (`False` by default).
+        proxies (`Any`, `optional`):
+            Proxies to use for the request.
         base_url (`str`, `optional`):
             Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`]
             follow the same pattern as `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None.
@@ -150,6 +155,7 @@ def __init__(
         timeout: Optional[float] = None,
         headers: Optional[Dict[str, str]] = None,
         cookies: Optional[Dict[str, str]] = None,
+        trust_env: bool = False,
         proxies: Optional[Any] = None,
         # OpenAI compatibility
         base_url: Optional[str] = None,
@@ -176,6 +182,7 @@ def __init__(
             self.headers.update(headers)
         self.cookies = cookies
         self.timeout = timeout
+        self.trust_env = trust_env
         self.proxies = proxies

         # OpenAI compatibility
@@ -265,7 +272,7 @@ async def post(
             warnings.warn("Ignoring `json` as `data` is passed as binary.")

         # Set Accept header if relevant
-        headers = self.headers.copy()
+        headers = dict()
         if task in TASKS_EXPECTING_IMAGES and "Accept" not in headers:
             headers["Accept"] = "image/png"
@@ -275,9 +282,7 @@
         with _open_as_binary(data) as data_as_binary:
             # Do not use context manager as we don't want to close the connection immediately when returning
             # a stream
-            client = aiohttp.ClientSession(
-                headers=headers, cookies=self.cookies, timeout=aiohttp.ClientTimeout(self.timeout)
-            )
+            client = self._get_client_session(headers=headers)

             try:
                 response = await client.post(url, json=json, data=data_as_binary, proxy=self.proxies)
@@ -1299,8 +1304,8 @@ def _unpack_response(framework: str, items: List[Dict]) -> None:
             models_by_task.setdefault(model["task"], []).append(model["model_id"])

         async def _fetch_framework(framework: str) -> None:
-            async with _import_aiohttp().ClientSession(headers=self.headers) as client:
-                response = await client.get(f"{INFERENCE_ENDPOINT}/framework/{framework}")
+            async with self._get_client_session() as client:
+                response = await client.get(f"{INFERENCE_ENDPOINT}/framework/{framework}", proxy=self.proxies)
                 response.raise_for_status()
                 _unpack_response(framework, await response.json())
@@ -2581,6 +2586,20 @@ async def zero_shot_image_classification(
         )
         return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)

+    def _get_client_session(self, headers: Optional[Dict] = None) -> "ClientSession":
+        aiohttp = _import_aiohttp()
+        client_headers = self.headers.copy()
+        if headers is not None:
+            client_headers.update(headers)
+
+        # Return a new aiohttp ClientSession with correct settings.
+        return aiohttp.ClientSession(
+            headers=client_headers,
+            cookies=self.cookies,
+            timeout=aiohttp.ClientTimeout(self.timeout),
+            trust_env=self.trust_env,
+        )
+
     def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
         model = model or self.model or self.base_url
@@ -2687,8 +2706,8 @@ async def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]:
         else:
             url = f"{INFERENCE_ENDPOINT}/models/{model}/info"

-        async with _import_aiohttp().ClientSession(headers=self.headers) as client:
-            response = await client.get(url)
+        async with self._get_client_session() as client:
+            response = await client.get(url, proxy=self.proxies)
             response.raise_for_status()
             return await response.json()
@@ -2724,8 +2743,8 @@ async def health_check(self, model: Optional[str] = None) -> bool:
             )
         url = model.rstrip("/") + "/health"

-        async with _import_aiohttp().ClientSession(headers=self.headers) as client:
-            response = await client.get(url)
+        async with self._get_client_session() as client:
+            response = await client.get(url, proxy=self.proxies)
             return response.status == 200

     async def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
@@ -2766,8 +2785,8 @@ async def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
             raise NotImplementedError("Model status is only available for Inference API endpoints.")
         url = f"{INFERENCE_ENDPOINT}/status/{model}"

-        async with _import_aiohttp().ClientSession(headers=self.headers) as client:
-            response = await client.get(url)
+        async with self._get_client_session() as client:
+            response = await client.get(url, proxy=self.proxies)
             response.raise_for_status()
             response_data = await response.json()
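
For reference, the aiohttp behavior that _get_client_session opts into: with trust_env=True, ClientSession falls back to the standard proxy environment variables whenever no explicit proxy argument is passed. A standalone sketch (placeholder proxy URL):

    import asyncio
    import os

    import aiohttp

    os.environ["HTTPS_PROXY"] = "http://proxy.example:3128"  # placeholder


    async def main() -> None:
        # trust_env=True is the one switch needed for aiohttp to honor
        # HTTP_PROXY / HTTPS_PROXY / NO_PROXY for requests made without
        # an explicit proxy= argument.
        async with aiohttp.ClientSession(trust_env=True) as session:
            async with session.get("https://huggingface.co") as response:
                print(response.status)


    asyncio.run(main())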
68 changes: 56 additions & 12 deletions utils/generate_async_inference_client.py
@@ -68,6 +68,9 @@ def generate_async_client_code(code: str) -> str:
     # Adapt /info and /health endpoints
     code = _adapt_info_and_health_endpoints(code)

+    # Add _get_client_session
+    code = _add_get_client_session(code)
+
     # Adapt the proxy client (for client.chat.completions.create)
     code = _adapt_proxy_client(code)
@@ -186,7 +189,7 @@ def _rename_to_AsyncInferenceClient(code: str) -> str:
             warnings.warn("Ignoring `json` as `data` is passed as binary.")
         # Set Accept header if relevant
-        headers = self.headers.copy()
+        headers = dict()
         if task in TASKS_EXPECTING_IMAGES and "Accept" not in headers:
             headers["Accept"] = "image/png"
@@ -196,9 +199,7 @@ def _rename_to_AsyncInferenceClient(code: str) -> str:
         with _open_as_binary(data) as data_as_binary:
             # Do not use context manager as we don't want to close the connection immediately when returning
             # a stream
-            client = aiohttp.ClientSession(
-                headers=headers, cookies=self.cookies, timeout=aiohttp.ClientTimeout(self.timeout)
-            )
+            client = self._get_client_session(headers=headers)
             try:
                 response = await client.post(url, json=json, data=data_as_binary, proxy=self.proxies)
@@ -420,8 +421,8 @@ def _adapt_get_model_status(code: str) -> str:
         response_data = response.json()"""

     async_snippet = """
-        async with _import_aiohttp().ClientSession(headers=self.headers) as client:
-            response = await client.get(url)
+        async with self._get_client_session() as client:
+            response = await client.get(url, proxy=self.proxies)
             response.raise_for_status()
             response_data = await response.json()"""

@@ -437,8 +438,8 @@ def _adapt_list_deployed_models(code: str) -> str:

     async_snippet = """
         async def _fetch_framework(framework: str) -> None:
-            async with _import_aiohttp().ClientSession(headers=self.headers) as client:
-                response = await client.get(f"{INFERENCE_ENDPOINT}/framework/{framework}")
+            async with self._get_client_session() as client:
+                response = await client.get(f"{INFERENCE_ENDPOINT}/framework/{framework}", proxy=self.proxies)
                 response.raise_for_status()
                 _unpack_response(framework, await response.json())
@@ -456,8 +457,8 @@ def _adapt_info_and_health_endpoints(code: str) -> str:
         return response.json()"""

     info_async_snippet = """
-        async with _import_aiohttp().ClientSession(headers=self.headers) as client:
-            response = await client.get(url)
+        async with self._get_client_session() as client:
+            response = await client.get(url, proxy=self.proxies)
             response.raise_for_status()
             return await response.json()"""

@@ -468,20 +469,63 @@ def _adapt_info_and_health_endpoints(code: str) -> str:
         return response.status_code == 200"""

     health_async_snippet = """
-        async with _import_aiohttp().ClientSession(headers=self.headers) as client:
-            response = await client.get(url)
+        async with self._get_client_session() as client:
+            response = await client.get(url, proxy=self.proxies)
             return response.status == 200"""

     return code.replace(health_sync_snippet, health_async_snippet)


+def _add_get_client_session(code: str) -> str:
+    # Add trust_env as parameter
+    code = _add_before(code, "proxies: Optional[Any] = None,", "trust_env: bool = False,")
+    code = _add_before(code, "\n        self.proxies = proxies\n", "\n        self.trust_env = trust_env")
+
+    # Document `trust_env` parameter
+    code = _add_before(
+        code,
+        "\n        proxies (`Any`, `optional`):",
+        """
+        trust_env ('bool', 'optional'):
+            Trust environment settings for proxy configuration if the parameter is `True` (`False` by default).""",
+    )
+
+    # Insert `_get_client_session` before the `_resolve_url` method
+    client_session_code = """
+    def _get_client_session(self, headers: Optional[Dict] = None) -> "ClientSession":
+        aiohttp = _import_aiohttp()
+        client_headers = self.headers.copy()
+        if headers is not None:
+            client_headers.update(headers)
+
+        # Return a new aiohttp ClientSession with correct settings.
+        return aiohttp.ClientSession(
+            headers=client_headers,
+            cookies=self.cookies,
+            timeout=aiohttp.ClientTimeout(self.timeout),
+            trust_env=self.trust_env,
+        )
+"""
+    code = _add_before(code, "\n    def _resolve_url(", client_session_code)
+
+    return code
+
+
 def _adapt_proxy_client(code: str) -> str:
     return code.replace(
         "def __init__(self, client: InferenceClient):",
         "def __init__(self, client: AsyncInferenceClient):",
     )
+
+
+def _add_before(code: str, pattern: str, addition: str) -> str:
+    index = code.find(pattern)
+    assert index != -1, f"Pattern '{pattern}' not found in code."
+    return code[:index] + addition + code[index:]


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
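
The async client is produced by textual rewrites of the sync client, and the new _add_before helper does the splicing. A small self-contained demo on a toy snippet (output shown in comments):

    def _add_before(code: str, pattern: str, addition: str) -> str:
        # Splice `addition` immediately before the first occurrence of `pattern`.
        index = code.find(pattern)
        assert index != -1, f"Pattern '{pattern}' not found in code."
        return code[:index] + addition + code[index:]


    snippet = "    proxies: Optional[Any] = None,\n"
    print(_add_before(snippet, "proxies:", "trust_env: bool = False,\n    "), end="")
    # Output:
    #     trust_env: bool = False,
    #     proxies: Optional[Any] = None,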
