Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add OpenAPI support, OpenAPITool component #8

Merged
merged 42 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from 39 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
472b2de
Initial openapi impl
vblagoje May 28, 2024
3fd99ac
Refactoring step 1
vblagoje Jun 5, 2024
20be1e5
Refactoring step 2
vblagoje Jun 5, 2024
adb96c5
Refactor step 3
vblagoje Jun 5, 2024
221df1b
Refactoring step 4
vblagoje Jun 5, 2024
e2ecd8d
Add OpenAPITool initial impl
vblagoje Jun 5, 2024
992ef6e
Add headers
vblagoje Jun 5, 2024
910a7b2
Refactoring step 5 - move things around
vblagoje Jun 5, 2024
14d38e9
Fix linting
vblagoje Jun 5, 2024
406aa1b
Refactoring step 6 - simplify generator factory
vblagoje Jun 5, 2024
e5db2d0
Cosmetics
vblagoje Jun 5, 2024
4b35fdf
Add model_kwargs to OpenAPITool init
vblagoje Jun 5, 2024
7bca18f
Fix double ClientConfiguration creation
vblagoje Jun 5, 2024
72922a0
Update internal pydoc
vblagoje Jun 6, 2024
d996cc5
PR feedback
vblagoje Jun 6, 2024
2f01180
Remove lazy imports
vblagoje Jun 6, 2024
4d90cee
Typing fixes
vblagoje Jun 6, 2024
e57ac8a
Add lazy imports
vblagoje Jun 6, 2024
1b87f82
Expose LLMProvider
vblagoje Jun 6, 2024
7e09480
Avoid circular deps
vblagoje Jun 6, 2024
a7fcadc
Add header for types.py
vblagoje Jun 6, 2024
190f113
Improve pydoc
vblagoje Jun 7, 2024
d21bcad
Add back in http bearer auth
vblagoje Jun 17, 2024
e264ec2
Add firecrawl openapi conversion tests
vblagoje Jun 18, 2024
57bb344
PR feedback
vblagoje Jun 18, 2024
df71f3a
Fix header
vblagoje Jun 18, 2024
bfdcf94
PR feedback - details
vblagoje Jun 18, 2024
c6cca91
Lift up OpenAPISpecification
vblagoje Jun 18, 2024
05e38f2
Update OpenAPITool
vblagoje Jun 18, 2024
5de7003
Final touches
vblagoje Jun 18, 2024
32d8ca0
Final touches - pydoc
vblagoje Jun 18, 2024
4f55c96
Minor detail around ClientConfiguration LLMProvider setting
vblagoje Jun 18, 2024
dfcae5c
Make use of OpenAPISpecification explicit
vblagoje Jun 18, 2024
503e71f
First batch of OpenAPITool unit and integration tests
vblagoje Jun 18, 2024
d371dba
Add serde and unit tests
vblagoje Jun 18, 2024
a87d446
Merge branch 'main' into openapi
vblagoje Jun 18, 2024
e5c76c7
Skip github test
vblagoje Jun 18, 2024
ecdc37d
Skip github tests
vblagoje Jun 18, 2024
a5db73b
Increase default request timeout to 30 sec
vblagoje Jun 19, 2024
37e6e54
PR review
vblagoje Jun 19, 2024
7eb140f
Add to Experiments catalog
vblagoje Jun 24, 2024
d714a6f
Merge branch 'main' into openapi
vblagoje Jun 24, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions haystack_experimental/components/tools/openapi/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from haystack_experimental.components.tools.openapi.openapi_tool import OpenAPITool
from haystack_experimental.components.tools.openapi.types import LLMProvider

__all__ = ["LLMProvider", "OpenAPITool"]
341 changes: 341 additions & 0 deletions haystack_experimental/components/tools/openapi/_openapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,341 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import logging
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional

import requests

from haystack_experimental.components.tools.openapi._payload_extraction import (
create_function_payload_extractor,
)
from haystack_experimental.components.tools.openapi._schema_conversion import (
anthropic_converter,
cohere_converter,
openai_converter,
)
from haystack_experimental.components.tools.openapi.types import LLMProvider, OpenAPISpecification, Operation

MIN_REQUIRED_OPENAPI_SPEC_VERSION = 3
logger = logging.getLogger(__name__)


def send_request(request: Dict[str, Any]) -> Dict[str, Any]:
"""
Send an HTTP request and return the response.

:param request: The request to send.
:returns: The response from the server.
"""
url = request["url"]
headers = {**request.get("headers", {})}
try:
response = requests.request(
request["method"],
url,
headers=headers,
params=request.get("params", {}),
json=request.get("json"),
auth=request.get("auth"),
timeout=30,
)
response.raise_for_status()
return response.json()
except requests.exceptions.HTTPError as e:
logger.warning("HTTP error occurred: %s while sending request to %s", e, url)
raise HttpClientError(f"HTTP error occurred: {e}") from e
except requests.exceptions.RequestException as e:
logger.warning("Request error occurred: %s while sending request to %s", e, url)
raise HttpClientError(f"HTTP error occurred: {e}") from e
except Exception as e:
logger.warning("An error occurred: %s while sending request to %s", e, url)
raise HttpClientError(f"An error occurred: {e}") from e


# Authentication strategies
def create_api_key_auth_function(api_key: str) -> Callable[[Dict[str, Any], Dict[str, Any]], None]:
"""
Create a function that applies the API key authentication strategy to a given request.

:param api_key: the API key to use for authentication.
:returns: a function that applies the API key authentication to a request
at the schema specified location.
"""

def apply_auth(security_scheme: Dict[str, Any], request: Dict[str, Any]) -> None:
"""
Apply the API key authentication strategy to the given request.

:param security_scheme: the security scheme from the OpenAPI spec.
:param request: the request to apply the authentication to.
"""
if security_scheme["in"] == "header":
request.setdefault("headers", {})[security_scheme["name"]] = api_key
elif security_scheme["in"] == "query":
request.setdefault("params", {})[security_scheme["name"]] = api_key
elif security_scheme["in"] == "cookie":
request.setdefault("cookies", {})[security_scheme["name"]] = api_key
else:
raise ValueError(
f"Unsupported apiKey authentication location: {security_scheme['in']}, "
f"must be one of 'header', 'query', or 'cookie'"
)

return apply_auth


def create_http_auth_function(token: str) -> Callable[[Dict[str, Any], Dict[str, Any]], None]:
"""
Create a function that applies the http authentication strategy to a given request.

:param token: the authentication token to use.
:returns: a function that applies the API key authentication to a request
at the schema specified location.
"""

def apply_auth(security_scheme: Dict[str, Any], request: Dict[str, Any]) -> None:
"""
Apply the HTTP authentication strategy to the given request.

:param security_scheme: the security scheme from the OpenAPI spec.
:param request: the request to apply the authentication to.
"""
if security_scheme["type"] == "http":
# support bearer http auth, no basic support yet
if security_scheme["scheme"].lower() == "bearer":
if not token:
raise ValueError("Token must be provided for Bearer Auth.")
request.setdefault("headers", {})[
"Authorization"
] = f"Bearer {token}"
else:
raise ValueError(
f"Unsupported HTTP authentication scheme: {security_scheme['scheme']}"
)
else:
raise ValueError(
"HTTPAuthentication strategy received a non-HTTP security scheme."
)

return apply_auth


class HttpClientError(Exception):
"""Exception raised for errors in the HTTP client."""


class ClientConfiguration:
"""Configuration for the OpenAPI client."""

def __init__(
self,
openapi_spec: OpenAPISpecification,
shadeMe marked this conversation as resolved.
Show resolved Hide resolved
credentials: Optional[str] = None,
request_sender: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
llm_provider: LLMProvider = LLMProvider.OPENAI,
): # noqa: PLR0913
"""
Initialize a ClientConfiguration instance.

:param openapi_spec: The OpenAPI specification to use for the client.
:param credentials: The credentials to use for authentication.
:param request_sender: The function to use for sending requests.
:param llm_provider: The LLM provider to use for generating tools definitions.
:raises ValueError: If the OpenAPI specification format is invalid.
"""
self.openapi_spec = openapi_spec
self.credentials = credentials
self.request_sender = request_sender or send_request
self.llm_provider: LLMProvider = llm_provider

def get_auth_function(self) -> Callable[[Dict[str, Any], Dict[str, Any]], Any]:
"""
Get the authentication function that sets a schema specified authentication to the request.

The function takes a security scheme and a request as arguments:
`security_scheme: Dict[str, Any] - The security scheme from the OpenAPI spec.`
`request: Dict[str, Any] - The request to apply the authentication to.`
:returns: The authentication function.
:raises ValueError: If the credentials type is not supported.
"""
security_schemes = self.openapi_spec.get_security_schemes()
if not self.credentials:
return lambda security_scheme, request: None # No-op function
if isinstance(self.credentials, str):
return self._create_authentication_from_string(
self.credentials, security_schemes
)
raise ValueError(f"Unsupported credentials type: {type(self.credentials)}")

def get_tools_definitions(self) -> List[Dict[str, Any]]:
"""
Get the tools definitions used as tools LLM parameter.

:returns: The tools definitions passed to the LLM as tools parameter.
"""
provider_to_converter = defaultdict(
lambda: openai_converter,
{
LLMProvider.ANTHROPIC: anthropic_converter,
LLMProvider.COHERE: cohere_converter,
}
)
converter = provider_to_converter[self.llm_provider]
return converter(self.openapi_spec)

def get_payload_extractor(self) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
"""
Get the payload extractor for the LLM provider.

This function knows how to extract the exact function payload from the LLM generated function calling payload.
:returns: The payload extractor function.
"""
provider_to_arguments_field_name = defaultdict(
lambda: "arguments",
{
LLMProvider.ANTHROPIC: "input",
LLMProvider.COHERE: "parameters",
}
)
arguments_field_name = provider_to_arguments_field_name[self.llm_provider]
return create_function_payload_extractor(arguments_field_name)

def _create_authentication_from_string(
self, credentials: str, security_schemes: Dict[str, Any]
) -> Callable[[Dict[str, Any], Dict[str, Any]], Any]:
for scheme in security_schemes.values():
if scheme["type"] == "apiKey":
return create_api_key_auth_function(api_key=credentials)
if scheme["type"] == "http":
return create_http_auth_function(token=credentials)
raise ValueError(
f"Unsupported authentication type '{scheme['type']}' provided."
)
raise ValueError(
f"Unable to create authentication from provided credentials: {credentials}"
)


def build_request(operation: Operation, **kwargs) -> Dict[str, Any]:
"""
Build an HTTP request for the operation.

:param operation: The operation to build the request for.
:param kwargs: The arguments to use for building the request.
:returns: The HTTP request as a dictionary.
:raises ValueError: If a required parameter is missing.
:raises NotImplementedError: If the request body content type is not supported. We only support JSON payloads.
"""
path = operation.path
for parameter in operation.get_parameters("path"):
param_value = kwargs.get(parameter["name"], None)
if param_value:
path = path.replace(f"{{{parameter['name']}}}", str(param_value))
elif parameter.get("required", False):
raise ValueError(f"Missing required path parameter: {parameter['name']}")
url = operation.get_server() + path
# method
method = operation.method.lower()
# headers
headers = {}
for parameter in operation.get_parameters("header"):
param_value = kwargs.get(parameter["name"], None)
if param_value:
headers[parameter["name"]] = str(param_value)
elif parameter.get("required", False):
raise ValueError(f"Missing required header parameter: {parameter['name']}")
# query params
query_params = {}
for parameter in operation.get_parameters("query"):
param_value = kwargs.get(parameter["name"], None)
if param_value:
query_params[parameter["name"]] = param_value
elif parameter.get("required", False):
raise ValueError(f"Missing required query parameter: {parameter['name']}")

json_payload = None
request_body = operation.request_body
if request_body:
content = request_body.get("content", {})
if "application/json" in content:
json_payload = {**kwargs}
else:
raise NotImplementedError("Request body content type not supported")
return {
"url": url,
"method": method,
"headers": headers,
"params": query_params,
"json": json_payload,
}


def apply_authentication(
auth_strategy: Callable[[Dict[str, Any], Dict[str, Any]], Any],
operation: Operation,
request: Dict[str, Any],
):
"""
Apply the authentication strategy to the given request.

:param auth_strategy: The authentication strategy to apply.
This is a function that takes a security scheme and a request as arguments (at runtime)
and applies the authentication
:param operation: The operation to apply the authentication to.
:param request: The request to apply the authentication to.
"""
security_requirements = operation.security_requirements
security_schemes = operation.spec_dict.get("components", {}).get(
"securitySchemes", {}
)
if security_requirements:
for requirement in security_requirements:
for scheme_name in requirement:
if scheme_name in security_schemes:
security_scheme = security_schemes[scheme_name]
auth_strategy(security_scheme, request)
break


class OpenAPIServiceClient:
"""
A client for invoking operations on REST services defined by OpenAPI specifications.
"""

def __init__(self, client_config: ClientConfiguration):
self.client_config = client_config

def invoke(self, function_payload: Any) -> Any:
"""
Invokes a function specified in the function payload.

:param function_payload: The function payload containing the details of the function to be invoked.
:returns: The response from the service after invoking the function.
:raises OpenAPIClientError: If the function invocation payload cannot be extracted from the function payload.
:raises HttpClientError: If an error occurs while sending the request and receiving the response.
"""
fn_invocation_payload = {}
try:
fn_extractor = self.client_config.get_payload_extractor()
fn_invocation_payload = fn_extractor(function_payload)
except Exception as e:
raise OpenAPIClientError(
f"Error extracting function invocation payload: {str(e)}"
) from e

if "name" not in fn_invocation_payload or "arguments" not in fn_invocation_payload:
raise OpenAPIClientError(
f"Function invocation payload does not contain 'name' or 'arguments' keys: {fn_invocation_payload}, "
f"the payload extraction function may be incorrect."
)
# fn_invocation_payload, if not empty, guaranteed to have "name" and "arguments" keys from here on
operation = self.client_config.openapi_spec.find_operation_by_id(fn_invocation_payload["name"])
request = build_request(operation, **fn_invocation_payload["arguments"])
apply_authentication(self.client_config.get_auth_function(), operation, request)
return self.client_config.request_sender(request)


class OpenAPIClientError(Exception):
"""Exception raised for errors in the OpenAPI client."""
Loading
Loading