Skip to content

Commit

Permalink
Adding the ability to create an Index with an IndexConfig set
Browse files Browse the repository at this point in the history
  • Loading branch information
dbaxter-cohere committed Jan 13, 2025
1 parent f082124 commit b6fe262
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 3 deletions.
11 changes: 9 additions & 2 deletions cohere/compass/clients/compass.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
SearchInput,
UploadDocumentsInput,
)
from cohere.compass.models.config import IndexConfig
from cohere.compass.models.datasources import PaginatedList
from cohere.compass.models.documents import DocumentAttributes, PutDocumentsResponse

Expand Down Expand Up @@ -150,18 +151,22 @@ def __init__(
"list_datasources_objects_states": "/api/v1/datasources/{datasource_id}/documents?skip={skip}&limit={limit}", # noqa: E501
}

def create_index(self, *, index_name: str):
def create_index(
self, *, index_name: str, index_config: Optional[IndexConfig] = None
):
"""
Create an index in Compass.
:param index_name: the name of the index
:param index_config: the optional configuration for the index
:returns: the response from the Compass API
"""
return self._send_request(
api_name="create_index",
max_retries=DEFAULT_MAX_RETRIES,
sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS,
index_name=index_name,
data=index_config,
)

def refresh_index(self, *, index_name: str):
Expand Down Expand Up @@ -798,7 +803,9 @@ def _send_request_with_retry():
nonlocal error

try:
data_dict = data.model_dump(mode="json") if data else None
data_dict = (
data.model_dump(mode="json", exclude_none=True) if data else None
)

headers = None
auth = None
Expand Down
13 changes: 13 additions & 0 deletions cohere/compass/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,16 @@ class MetadataConfig(ValidatedModel):
keyword_search_attributes: list[str] = METADATA_HEURISTICS_ATTRIBUTES
keyword_search_separator: str = "."
ignore_errors: bool = True


class IndexConfig(BaseModel):
"""
A model class for specifying configuration related to a search index.
:param number_of_shards: the total number of shards to split the index into
:param number_of_replicas: the number of replicas for each shard. Number of shards
will be multiplied by this number to determine the total number of shards used.
"""

number_of_shards: Optional[int] = None
number_of_replicas: Optional[int] = None
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "compass-sdk"
version = "0.11.2"
version = "0.12.0"
authors = []
description = "Compass SDK"
readme = "README.md"
Expand Down
14 changes: 14 additions & 0 deletions tests/test_compass_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from cohere.compass.clients import CompassClient
from cohere.compass.exceptions import CompassClientError
from cohere.compass.models import CompassDocument
from cohere.compass.models.config import IndexConfig
from cohere.compass.models.documents import DocumentAttributes


Expand All @@ -27,6 +28,19 @@ def test_create_index_formatted_with_index(requests_mock: Mocker):
assert requests_mock.request_history[0].method == "PUT"


def test_create_index_with_index_config(requests_mock: Mocker):
compass = CompassClient(index_url="http://test.com")
compass.create_index(
index_name="test_index", index_config=IndexConfig(number_of_shards=5)
)
assert (
requests_mock.request_history[0].url
== "http://test.com/api/v1/indexes/test_index"
)
assert requests_mock.request_history[0].method == "PUT"
assert requests_mock.request_history[0].json() == {"number_of_shards": 5}


def test_put_documents_payload_and_url_exist(requests_mock: Mocker):
compass = CompassClient(index_url="http://test.com")
compass.insert_docs(index_name="test_index", docs=iter([CompassDocument()]))
Expand Down

0 comments on commit b6fe262

Please sign in to comment.