From 4e57490e24390114614c2381456c6cd3da5fa965 Mon Sep 17 00:00:00 2001 From: Daniel Baxter Date: Mon, 13 Jan 2025 13:17:00 -0500 Subject: [PATCH] Adding the ability to create an Index with an IndexConfig set --- cohere/compass/clients/compass.py | 9 +++++++-- cohere/compass/models/config.py | 12 ++++++++++++ pyproject.toml | 2 +- tests/test_compass_client.py | 15 +++++++++++++++ 4 files changed, 35 insertions(+), 3 deletions(-) diff --git a/cohere/compass/clients/compass.py b/cohere/compass/clients/compass.py index 7bc51e7..6ab5975 100644 --- a/cohere/compass/clients/compass.py +++ b/cohere/compass/clients/compass.py @@ -59,6 +59,7 @@ SearchInput, UploadDocumentsInput, ) +from cohere.compass.models.config import IndexConfig from cohere.compass.models.datasources import PaginatedList from cohere.compass.models.documents import DocumentAttributes, PutDocumentsResponse @@ -150,11 +151,14 @@ def __init__( "list_datasources_objects_states": "/api/v1/datasources/{datasource_id}/documents?skip={skip}&limit={limit}", # noqa: E501 } - def create_index(self, *, index_name: str): + def create_index( + self, *, index_name: str, index_config: Optional[IndexConfig] = None + ): """ Create an index in Compass. :param index_name: the name of the index + :param index_config: the optional configuration for the index :returns: the response from the Compass API """ return self._send_request( @@ -162,6 +166,7 @@ def create_index(self, *, index_name: str): max_retries=DEFAULT_MAX_RETRIES, sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS, index_name=index_name, + data=index_config, ) def refresh_index(self, *, index_name: str): @@ -798,7 +803,7 @@ def _send_request_with_retry(): nonlocal error try: - data_dict = data.model_dump(mode="json") if data else None + data_dict = data.model_dump(mode="json", exclude_none=True) if data else None headers = None auth = None diff --git a/cohere/compass/models/config.py b/cohere/compass/models/config.py index 747d593..62c92da 100644 --- a/cohere/compass/models/config.py +++ b/cohere/compass/models/config.py @@ -192,3 +192,15 @@ class MetadataConfig(ValidatedModel): keyword_search_attributes: list[str] = METADATA_HEURISTICS_ATTRIBUTES keyword_search_separator: str = "." ignore_errors: bool = True + +class IndexConfig(BaseModel): + """ + A model class for specifying configuration related to a search index. + + :param number_of_shards: the total number of shards to split the index into + :param number_of_replicas: the number of replicas for each shard. Number of shards + will be multiplied by this number to determine the total number of shards used. + """ + + number_of_shards: Optional[int] = None + number_of_replicas: Optional[int] = None diff --git a/pyproject.toml b/pyproject.toml index 7f2dbde..adf791d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "compass-sdk" -version = "0.11.2" +version = "0.11.3" authors = [] description = "Compass SDK" readme = "README.md" diff --git a/tests/test_compass_client.py b/tests/test_compass_client.py index cabcfcc..2e15c3e 100644 --- a/tests/test_compass_client.py +++ b/tests/test_compass_client.py @@ -4,6 +4,7 @@ from cohere.compass.clients import CompassClient from cohere.compass.exceptions import CompassClientError from cohere.compass.models import CompassDocument +from cohere.compass.models.config import IndexConfig from cohere.compass.models.documents import DocumentAttributes @@ -27,6 +28,20 @@ def test_create_index_formatted_with_index(requests_mock: Mocker): assert requests_mock.request_history[0].method == "PUT" +def test_create_index_with_index_config(requests_mock: Mocker): + compass = CompassClient(index_url="http://test.com") + compass.create_index( + index_name="test_index", + index_config=IndexConfig(number_of_shards=5) + ) + assert ( + requests_mock.request_history[0].url + == "http://test.com/api/v1/indexes/test_index" + ) + assert requests_mock.request_history[0].method == "PUT" + assert requests_mock.request_history[0].json() == {"number_of_shards": 5} + + def test_put_documents_payload_and_url_exist(requests_mock: Mocker): compass = CompassClient(index_url="http://test.com") compass.insert_docs(index_name="test_index", docs=iter([CompassDocument()]))