Skip to content

Commit

Permalink
Update document name in models
Browse files: browse the repository at this point in the history
  • Loading branch information
benrules3 committed Dec 6, 2024
1 parent ec32e5b commit 4917353
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 30 deletions.
36 changes: 18 additions & 18 deletions cohere/compass/clients/compass.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ def __init__(
"create_index": "/api/v1/indexes/{index_name}",
"list_indexes": "/api/v1/indexes",
"delete_index": "/api/v1/indexes/{index_name}",
"delete_document": "/api/v1/indexes/{index_name}/documents/{doc_id}",
"get_document": "/api/v1/indexes/{index_name}/documents/{doc_id}",
"delete_document": "/api/v1/indexes/{index_name}/documents/{document_id}",
"get_document": "/api/v1/indexes/{index_name}/documents/{document_id}",
"put_documents": "/api/v1/indexes/{index_name}/documents",
"search_documents": "/api/v1/indexes/{index_name}/documents/_search",
"search_chunks": "/api/v1/indexes/{index_name}/documents/_search_chunks",
"add_context": "/api/v1/indexes/{index_name}/documents/add_context/{doc_id}",
"add_context": "/api/v1/indexes/{index_name}/documents/add_context/{document_id}",
"refresh": "/api/v1/indexes/{index_name}/_refresh",
"upload_documents": "/api/v1/indexes/{index_name}/documents/_upload",
"edit_group_authorization": "/api/v1/indexes/{index_name}/group_authorization",
Expand Down Expand Up @@ -185,31 +185,31 @@ def delete_index(self, *, index_name: str):
index_name=index_name,
)

def delete_document(self, *, index_name: str, doc_id: str):
def delete_document(self, *, index_name: str, document_id: str):
"""
Delete a document from Compass
:param index_name: the name of the index
:doc_id: the id of the document
:document_id: the id of the document
:return: the response from the Compass API
"""
return self._send_request(
api_name="delete_document",
doc_id=doc_id,
document_id=document_id,
max_retries=DEFAULT_MAX_RETRIES,
sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS,
index_name=index_name,
)

def get_document(self, *, index_name: str, doc_id: str):
def get_document(self, *, index_name: str, document_id: str):
"""
Get a document from Compass
:param index_name: the name of the index
:doc_id: the id of the document
:document_id: the id of the document
:return: the response from the Compass API
"""
return self._send_request(
api_name="get_document",
doc_id=doc_id,
document_id=document_id,
max_retries=DEFAULT_MAX_RETRIES,
sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS,
index_name=index_name,
Expand All @@ -231,7 +231,7 @@ def add_context(
self,
*,
index_name: str,
doc_id: str,
document_id: str,
context: dict[str, Any],
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
Expand All @@ -240,15 +240,15 @@ def add_context(
Update the content field of an existing document with additional context
:param index_name: the name of the index
:param doc_id: the document to modify
:param document_id: the document to modify
:param context: A dictionary of key:value pairs to insert into the content field of a document
:param max_retries: the maximum number of times to retry a doc insertion
:param sleep_retry_seconds: number of seconds to go to sleep before retrying a doc insertion
"""

return self._send_request(
api_name="add_context",
doc_id=doc_id,
document_id=document_id,
data=context,
max_retries=max_retries,
sleep_retry_seconds=sleep_retry_seconds,
Expand Down Expand Up @@ -373,7 +373,7 @@ def put_request(
compass_doc for compass_doc, _ in request_data
]
put_docs_input = PutDocumentsInput(
docs=[input_doc for _, input_doc in request_data],
documents=[input_doc for _, input_doc in request_data],
authorized_groups=authorized_groups,
merge_groups_on_conflict=merge_groups_on_conflict,
)
Expand Down Expand Up @@ -401,7 +401,7 @@ def put_request(
)
errors.append(
{
doc.metadata.doc_id: f"{doc.metadata.filename}: {results.error}"
doc.metadata.document_id: f"{doc.metadata.filename}: {results.error}"
}
)
else:
Expand Down Expand Up @@ -570,9 +570,9 @@ def _get_request_blocks(
num_chunks = 0
for _, doc in enumerate(docs, 1):
if doc.status != CompassDocumentStatus.Success:
logger.error(f"Document {doc.metadata.doc_id} has errors: {doc.errors}")
logger.error(f"Document {doc.metadata.document_id} has errors: {doc.errors}")
for error in doc.errors:
errors.append({doc.metadata.doc_id: list(error.values())[0]})
errors.append({doc.metadata.document_id: list(error.values())[0]})
else:
num_chunks += (
len(doc.chunks)
Expand All @@ -588,8 +588,8 @@ def _get_request_blocks(
(
doc,
Document(
doc_id=doc.metadata.doc_id,
parent_doc_id=doc.metadata.parent_doc_id,
document_id=doc.metadata.document_id,
parent_document_id=doc.metadata.parent_document_id,
path=doc.metadata.filename,
content=doc.content,
chunks=[Chunk(**c.model_dump()) for c in doc.chunks],
Expand Down
37 changes: 36 additions & 1 deletion cohere/compass/clients/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def process_file(
docs: list[CompassDocument] = []
for doc in res.json()["docs"]:
if not doc.get("errors", []):
compass_doc = CompassDocument(**doc)
compass_doc = self._adapt_doc_id_compass_doc(doc)
additional_metadata = CompassParserClient._get_metadata(
doc=compass_doc, custom_context=custom_context
)
Expand All @@ -268,3 +268,38 @@ def process_file(
logger.error(f"Error processing file: {res.text}")

return docs

@staticmethod
def _adapt_doc_id_compass_doc(doc: Dict[Any, Any]) -> CompassDocument:
    """
    Build a CompassDocument from a parser payload, adapting legacy id keys.

    Payloads may still use the legacy ``doc_id`` / ``parent_doc_id`` keys in
    the document metadata and in each chunk. These are renamed to
    ``document_id`` / ``parent_document_id``. Each key is handled
    independently, and a legacy key is only renamed when the new-style key is
    absent, so an existing new-style value is never overwritten.

    Chunks without a ``path`` get the metadata ``filename`` (or ``""`` when
    the metadata has no filename).

    The input ``doc`` and its nested dicts are not mutated; adapted shallow
    copies are passed to the model instead.

    :param doc: a single document dict from the parser response
    :return: the CompassDocument built from the adapted payload
    """
    legacy_to_current = (
        ("doc_id", "document_id"),
        ("parent_doc_id", "parent_document_id"),
    )
    # Only these top-level fields are forwarded to the model, as before.
    passthrough_fields = (
        "filebytes",
        "content",
        "content_type",
        "elements",
        "index_fields",
        "errors",
        "ignore_metadata_errors",
        "markdown",
    )

    def _with_current_keys(raw: Dict[Any, Any]) -> Dict[Any, Any]:
        # Work on a shallow copy so the caller's payload is left untouched.
        adapted = dict(raw)
        for legacy_key, current_key in legacy_to_current:
            if current_key not in adapted and legacy_key in adapted:
                adapted[current_key] = adapted.pop(legacy_key)
        return adapted

    metadata = _with_current_keys(doc.get("metadata") or {})

    # Forward only the fields that are present; absent ones are left to the
    # model's own defaults/validation rather than failing with a KeyError.
    model_kwargs: Dict[str, Any] = {
        name: doc[name] for name in passthrough_fields if name in doc
    }
    model_kwargs["metadata"] = metadata

    raw_chunks = doc.get("chunks")
    if raw_chunks is not None:
        chunks = []
        for raw_chunk in raw_chunks:
            chunk = _with_current_keys(raw_chunk)
            if "path" not in chunk:
                chunk["path"] = metadata.get("filename", "")
            chunks.append(chunk)
        model_kwargs["chunks"] = chunks

    return CompassDocument(**model_kwargs)
27 changes: 16 additions & 11 deletions cohere/compass/models/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ class CompassDocumentMetadata(ValidatedModel):
Compass document metadata
"""

doc_id: str = ""
document_id: str = ""
filename: str = ""
meta: list[Any] = field(default_factory=list)
parent_doc_id: str = ""
parent_document_id: str = ""


class CompassDocumentChunkAsset(BaseModel):
Expand All @@ -30,14 +30,15 @@ class CompassDocumentChunkAsset(BaseModel):
class CompassDocumentChunk(BaseModel):
chunk_id: str
sort_id: str
doc_id: str
parent_doc_id: str
document_id: str
parent_document_id: str
content: Dict[str, Any]
origin: Optional[Dict[str, Any]] = None
assets: Optional[list[CompassDocumentChunkAsset]] = None
path: Optional[str] = ""

def parent_doc_is_split(self):
return self.doc_id != self.parent_doc_id
return self.document_id != self.parent_document_id


class CompassDocumentStatus(str, Enum):
Expand Down Expand Up @@ -140,23 +141,27 @@ class DocumentChunkAsset(BaseModel):
class Chunk(BaseModel):
chunk_id: str
sort_id: int
parent_document_id: str
path: str = ""
content: Dict[str, Any]
origin: Optional[Dict[str, Any]] = None
assets: Optional[list[DocumentChunkAsset]] = None
parent_doc_id: str
assets: Optional[List[DocumentChunkAsset]] = None
asset_ids: Optional[List[str]] = None



class Document(BaseModel):
"""
A document that can be indexed in Compass (i.e., a list of indexable chunks)
"""

doc_id: str
document_id: str
path: str
parent_doc_id: str
parent_document_id: str
content: Dict[str, Any]
chunks: List[Chunk]
index_fields: List[str] = field(default_factory=list)
index_fields: Optional[List[str]] = None
authorized_groups: Optional[List[str]] = None


class ParseableDocument(BaseModel):
Expand All @@ -183,6 +188,6 @@ class PutDocumentsInput(BaseModel):
A Compass request to put a list of Document
"""

docs: List[Document]
documents: List[Document]
authorized_groups: Optional[List[str]] = None
merge_groups_on_conflict: bool = False

0 comments on commit 4917353

Please sign in to comment.