From 4d2ea7cfc19b6c015e2fd9f518761a4a51415b67 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Wed, 23 Oct 2024 22:24:35 -0700 Subject: [PATCH 1/3] check in --- py/compose.full.yaml | 3 +- py/core/base/api/models/__init__.py | 2 + py/core/main/api/ingestion_router.py | 71 ++++++++++++++++- .../hatchet/ingestion_workflow.py | 79 +++++++++++++++++++ .../simple/ingestion_workflow.py | 54 +++++++++++++ py/core/main/services/ingestion_service.py | 53 +++++++++++++ py/core/providers/database/vector.py | 1 + 7 files changed, 259 insertions(+), 4 deletions(-) diff --git a/py/compose.full.yaml b/py/compose.full.yaml index 8c7ea4a19..50793aa1b 100644 --- a/py/compose.full.yaml +++ b/py/compose.full.yaml @@ -270,7 +270,8 @@ services: retries: 5 r2r: - image: ${R2R_IMAGE:-ragtoriches/prod:latest} + # image: ${R2R_IMAGE:-ragtoriches/prod:latest} + image: r2r/test build: context: . args: diff --git a/py/core/base/api/models/__init__.py b/py/core/base/api/models/__init__.py index 5fc4ea41d..41a541a87 100644 --- a/py/core/base/api/models/__init__.py +++ b/py/core/base/api/models/__init__.py @@ -9,6 +9,7 @@ from shared.api.models.ingestion.responses import ( CreateVectorIndexResponse, IngestionResponse, + UpdateResponse, WrappedCreateVectorIndexResponse, WrappedDeleteVectorIndexResponse, WrappedIngestionResponse, @@ -87,6 +88,7 @@ "WrappedListVectorIndicesResponse", "WrappedDeleteVectorIndexResponse", "WrappedSelectVectorIndexResponse", + "UpdateResponse", # Knowledge Graph Responses "KGCreationResponse", "WrappedKGCreationResponse", diff --git a/py/core/main/api/ingestion_router.py b/py/core/main/api/ingestion_router.py index 2b2f546a3..b1532cbbf 100644 --- a/py/core/main/api/ingestion_router.py +++ b/py/core/main/api/ingestion_router.py @@ -1,17 +1,18 @@ import base64 import logging from io import BytesIO -from pathlib import Path +from pathlib import Path as pathlib_Path from typing import Optional, Union from uuid import UUID import yaml -from fastapi import Body, Depends, File, Form, Query, UploadFile +from fastapi import Body, Depends, File, Form, Query, UploadFile, Path from pydantic import Json from core.base import R2RException, RawChunk, generate_document_id from core.base.api.models import ( CreateVectorIndexResponse, + UpdateResponse, WrappedCreateVectorIndexResponse, WrappedDeleteVectorIndexResponse, WrappedIngestionResponse, @@ -64,6 +65,11 @@ def _register_workflows(self): if self.orchestration_provider.config.provider != "simple" else "Update task queued successfully." ), + "update-chunk": ( + "Update chunk task queued successfully." + if self.orchestration_provider.config.provider != "simple" + else "Chunk update completed successfully." + ), "create-vector-index": ( "Vector index creation task queued successfully." 
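# The queued/completed wording pattern repeats for every workflow in this
# mapping: a non-"simple" orchestration provider (e.g. Hatchet) enqueues the
# task and returns immediately, while the "simple" provider runs it inline.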
if self.orchestration_provider.config.provider != "simple" @@ -84,7 +90,9 @@ def _register_workflows(self): def _load_openapi_extras(self): yaml_path = ( - Path(__file__).parent / "data" / "ingestion_router_openapi.yml" + pathlib_Path(__file__).parent + / "data" + / "ingestion_router_openapi.yml" ) with open(yaml_path, "r") as yaml_file: yaml_content = yaml.safe_load(yaml_file) @@ -406,6 +414,63 @@ async def ingest_chunks_app( "task_id": None, } + @self.router.put( + "/update_chunk/{document_id}/{extraction_id}", + ) + @self.base_endpoint + async def update_chunk_app( + document_id: UUID = Path( + ..., description="The document ID of the chunk to update" + ), + extraction_id: UUID = Path( + ..., description="The extraction ID of the chunk to update" + ), + text: str = Body( + ..., description="The new text content for the chunk" + ), + metadata: Optional[dict] = Body( + None, description="Optional updated metadata" + ), + run_with_orchestration: Optional[bool] = Body(True), + auth_user=Depends(self.service.providers.auth.auth_wrapper), + ) -> WrappedUpdateResponse: + try: + workflow_input = { + "document_id": str(document_id), + "extraction_id": str(extraction_id), + "text": text, + "metadata": metadata, + "user": auth_user.model_dump_json(), + } + + if run_with_orchestration: + raw_message: dict[str, Union[str, None]] = await self.orchestration_provider.run_workflow( # type: ignore + "update-chunk", {"request": workflow_input}, {} + ) + raw_message["message"] = "Update task queued successfully." + raw_message["document_ids"] = [document_id] # type: ignore + + return raw_message # type: ignore + else: + logger.info("Running chunk update without orchestration.") + from core.main.orchestration import ( + simple_ingestion_factory, + ) + + simple_ingestor = simple_ingestion_factory(self.service) + await simple_ingestor["update-chunk"](workflow_input) + + return { # type: ignore + "message": "Update task completed successfully.", + "document_ids": workflow_input["document_ids"], + "task_id": None, + } + + except Exception as e: + raise R2RException( + status_code=500, message=f"Error updating chunk: {str(e)}" + ) + create_vector_index_extras = self.openapi_extras.get( "create_vector_index", {} ) diff --git a/py/core/main/orchestration/hatchet/ingestion_workflow.py b/py/core/main/orchestration/hatchet/ingestion_workflow.py index 4798a4178..19d2b9338 100644 --- a/py/core/main/orchestration/hatchet/ingestion_workflow.py +++ b/py/core/main/orchestration/hatchet/ingestion_workflow.py @@ -498,6 +498,83 @@ async def on_failure(self, context: Context) -> None: f"Failed to update document status for {document_id}: {e}" ) + @orchestration_provider.workflow( + name="update-chunk", + timeout="60m", + ) + class HatchetUpdateChunkWorkflow: + def __init__(self, ingestion_service: IngestionService): + self.ingestion_service = ingestion_service + + @orchestration_provider.step(timeout="60m") + async def update_chunk(self, context: Context) -> dict: + try: + input_data = context.workflow_input()["request"] + parsed_data = IngestionServiceAdapter.parse_update_chunk_input( + input_data + ) + + document_uuid = ( + parsed_data["document_id"] + if isinstance(parsed_data["document_id"], str) + else parsed_data["document_id"] + ) + extraction_uuid = ( + parsed_data["extraction_id"] + if isinstance(parsed_data["extraction_id"], str) + else parsed_data["extraction_id"] + ) + + document_info = ( + await self.ingestion_service.update_chunk_ingress( + **{ + **parsed_data, + "document_id": document_uuid, + "extraction_id": 
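# Note: parse_update_chunk_input (see IngestionServiceAdapter below)
# already converts both IDs to UUID instances, so the isinstance
# guards above never actually take the string branch.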
extraction_uuid, + } + ) + ) + + extraction = DocumentExtraction( + id=extraction_uuid, + document_id=document_uuid, + collection_ids=parsed_data.get("collection_ids", []), + user_id=document_info.user_id, + data=parsed_data["text"], + metadata=parsed_data["metadata"], + ).to_dict() + + embedding_generator = ( + await self.ingestion_service.embed_document([extraction]) + ) + embeddings = [ + embedding.to_dict() + async for embedding in embedding_generator + ] + + storage_generator = ( + await self.ingestion_service.store_embeddings(embeddings) + ) + async for _ in storage_generator: + pass + + return { + "message": "Chunk update completed successfully.", + "task_id": context.workflow_run_id(), # or None if not applicable + "document_ids": [str(document_uuid)], + } + + except Exception as e: + raise R2RException( + status_code=500, + message=f"Error during chunk update: {str(e)}", + ) + + @orchestration_provider.failure() + async def on_failure(self, context: Context) -> None: + # Handle failure case if necessary + pass + @orchestration_provider.workflow( name="create-vector-index", timeout="360m" ) @@ -545,6 +622,7 @@ async def delete_vector_index(self, context: Context) -> dict: ingest_files_workflow = HatchetIngestFilesWorkflow(service) update_files_workflow = HatchetUpdateFilesWorkflow(service) ingest_chunks_workflow = HatchetIngestChunksWorkflow(service) + update_chunks_workflow = HatchetUpdateChunkWorkflow(service) create_vector_index_workflow = HatchetCreateVectorIndexWorkflow(service) delete_vector_index_workflow = HatchetDeleteVectorIndexWorkflow(service) @@ -552,6 +630,7 @@ async def delete_vector_index(self, context: Context) -> dict: "ingest_files": ingest_files_workflow, "update_files": update_files_workflow, "ingest_chunks": ingest_chunks_workflow, + "update_chunk": update_chunks_workflow, "create_vector_index": create_vector_index_workflow, "delete_vector_index": delete_vector_index_workflow, } diff --git a/py/core/main/orchestration/simple/ingestion_workflow.py b/py/core/main/orchestration/simple/ingestion_workflow.py index 47d04f803..08df9b599 100644 --- a/py/core/main/orchestration/simple/ingestion_workflow.py +++ b/py/core/main/orchestration/simple/ingestion_workflow.py @@ -1,5 +1,6 @@ import asyncio import logging +from uuid import UUID from litellm import AuthenticationError @@ -255,6 +256,58 @@ async def ingest_chunks(input_data): message=f"Error during chunk ingestion: {str(e)}", ) + async def update_chunk(input_data): + try: + from core.main import IngestionServiceAdapter + + parsed_data = IngestionServiceAdapter.parse_update_chunk_input( + input_data + ) + + document_uuid = ( + UUID(parsed_data["document_id"]) + if isinstance(parsed_data["document_id"], str) + else parsed_data["document_id"] + ) + extraction_uuid = ( + UUID(parsed_data["extraction_id"]) + if isinstance(parsed_data["extraction_id"], str) + else parsed_data["extraction_id"] + ) + + document_info = await service.update_chunk_ingress( + **{ + **parsed_data, + "document_id": document_uuid, + "extraction_id": extraction_uuid, + } + ) + + extraction = DocumentExtraction( + id=extraction_uuid, + document_id=document_uuid, + collection_ids=parsed_data.get("collection_ids", []), + user_id=document_info.user_id, + data=parsed_data["text"], + metadata=parsed_data["metadata"], + ).model_dump() + + embedding_generator = await service.embed_document([extraction]) + embeddings = [ + embedding.model_dump() + async for embedding in embedding_generator + ] + + storage_generator = await 
service.store_embeddings(embeddings) + async for _ in storage_generator: + pass + + except Exception as e: + raise R2RException( + status_code=500, + message=f"Error during chunk update: {str(e)}", + ) + async def create_vector_index(input_data): try: @@ -298,6 +351,7 @@ async def delete_vector_index(input_data): "ingest-files": ingest_files, "update-files": update_files, "ingest-chunks": ingest_chunks, + "update-chunk": update_chunk, "create-vector-index": create_vector_index, "delete-vector-index": delete_vector_index, } diff --git a/py/core/main/services/ingestion_service.py b/py/core/main/services/ingestion_service.py index 478361c85..5838dd9e1 100644 --- a/py/core/main/services/ingestion_service.py +++ b/py/core/main/services/ingestion_service.py @@ -352,6 +352,48 @@ async def ingest_chunks_ingress( return document_info + @telemetry_event("UpdateChunk") + async def update_chunk_ingress( + self, + document_id: UUID, + extraction_id: UUID, + text: str, + user: UserResponse, + metadata: Optional[dict] = None, + *args: Any, + **kwargs: Any, + ) -> DocumentInfo: + # Verify chunk exists and user has access + existing_chunks = await self.providers.database.get_document_chunks( + document_id=document_id, limit=1 + ) + + if not existing_chunks["results"]: + raise R2RException( + status_code=404, + message=f"Chunk with extraction_id {extraction_id} not found.", + ) + + existing_chunk = existing_chunks["results"][0] + + if ( + str(existing_chunk["user_id"]) != str(user.id) + and not user.is_superuser + ): + raise R2RException( + status_code=403, + message="You don't have permission to modify this chunk.", + ) + + # Get document info for return + documents_overview = ( + await self.providers.database.get_documents_overview( + filter_document_ids=[document_id], + ) + )["results"] + + return documents_overview[0] + async def _get_enriched_chunk_text( self, chunk_idx: int, @@ -577,6 +619,17 @@ def parse_ingest_chunks_input(data: dict) -> dict: "chunks": [RawChunk.from_dict(chunk) for chunk in data["chunks"]], } + @staticmethod + def parse_update_chunk_input(data: dict) -> dict: + return { + "user": IngestionServiceAdapter._parse_user_data(data["user"]), + "document_id": UUID(data["document_id"]), + "extraction_id": UUID(data["extraction_id"]), + "text": data["text"], + "metadata": data.get("metadata"), + "collection_ids": data.get("collection_ids", []), + } + @staticmethod def parse_update_files_input(data: dict) -> dict: return { diff --git a/py/core/providers/database/vector.py b/py/core/providers/database/vector.py index 354da033c..d5e84a730 100644 --- a/py/core/providers/database/vector.py +++ b/py/core/providers/database/vector.py @@ -479,6 +479,7 @@ async def get_document_chunks( SELECT extraction_id, document_id, user_id, collection_ids, text, metadata{vector_select}, COUNT(*) OVER() AS total FROM {self._get_table_name(PostgresVectorHandler.TABLE_NAME)} WHERE document_id = $1 + ORDER BY (metadata->>'chunk_order')::integer OFFSET $2 {limit_clause}; """ From 578971a3d97099b9c3a280caa3bdcf7fb6515e84 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Thu, 24 Oct 2024 16:27:23 -0700 Subject: [PATCH 2/3] Finish ingest chunks, delete, and testing --- .../run-sdk-collections-tests/action.yml | 15 +++ .../r2rClientIntegrationSuperUser.test.ts | 34 ++++- .../r2rClientIntegrationUser.test.ts | 1 + js/sdk/src/r2rClient.ts | 39 ++++++ py/compose.full.yaml | 3 +- py/core/base/providers/database.py | 7 + .../hatchet/ingestion_workflow.py | 43 ++---- 
.../simple/ingestion_workflow.py | 37 ++---- py/core/main/services/ingestion_service.py | 48 +++++-- py/core/main/services/management_service.py | 81 +++++++----- py/core/providers/database/vector.py | 22 ++++ py/sdk/mixins/ingestion.py | 33 +++++ py/tests/integration/runner_sdk.py | 124 ++++++++++++++++++ 13 files changed, 378 insertions(+), 109 deletions(-) diff --git a/.github/actions/run-sdk-collections-tests/action.yml b/.github/actions/run-sdk-collections-tests/action.yml index 7f8878bed..704f8e487 100644 --- a/.github/actions/run-sdk-collections-tests/action.yml +++ b/.github/actions/run-sdk-collections-tests/action.yml @@ -82,3 +82,18 @@ runs: working-directory: ./py shell: bash run: poetry run python tests/integration/runner_sdk.py test_user_permissions + + - name: Ingest chunks + working-directory: ./py + shell: bash + run: poetry run python tests/integration/runner_sdk.py test_ingest_chunks + + - name: Update chunks + working-directory: ./py + shell: bash + run: poetry run python tests/integration/runner_sdk.py test_update_chunks + + - name: Delete chunks + working-directory: ./py + shell: bash + run: poetry run python tests/integration/runner_sdk.py test_delete_chunks diff --git a/js/sdk/__tests__/r2rClientIntegrationSuperUser.test.ts b/js/sdk/__tests__/r2rClientIntegrationSuperUser.test.ts index 8fa29bd76..c52b7e7ea 100644 --- a/js/sdk/__tests__/r2rClientIntegrationSuperUser.test.ts +++ b/js/sdk/__tests__/r2rClientIntegrationSuperUser.test.ts @@ -33,6 +33,7 @@ let newCollectionId: string; * - ingestFiles * - updateFiles * - ingestChunks + * - updateChunks * Management: * - serverStats * X updatePrompt @@ -254,8 +255,37 @@ describe("r2rClient Integration Tests", () => { ).resolves.not.toThrow(); }); - test("Collections overview", async () => { - await expect(client.collectionsOverview()).resolves.not.toThrow(); + test("Update chunk", async () => { + const response = await client.updateChunk( + "bd2cbead-66e0-57bc-acea-2c34711a39b5", + "c043aa2c-80e8-59ed-a390-54f1947ea32b", + "updated text", + ); + }); + + test("Ensure that updated chunk has updated text", async () => { + const response = await client.documentChunks( + "bd2cbead-66e0-57bc-acea-2c34711a39b5", + ); + + const targetId = "c043aa2c-80e8-59ed-a390-54f1947ea32b"; + const updatedChunk = response.results.find( + (chunk: { extraction_id: string; text: string }) => + String(chunk.extraction_id) === targetId, + ); + + expect(updatedChunk).toBeDefined(); + expect(updatedChunk?.text).toBe("updated text"); + }); + + test("Delete the updated chunk", async () => { + await expect( + client.delete({ + extraction_id: { + $eq: "c043aa2c-80e8-59ed-a390-54f1947ea32b", + }, + }), + ).resolves.toBe(""); }); test("Create collection", async () => { diff --git a/js/sdk/__tests__/r2rClientIntegrationUser.test.ts b/js/sdk/__tests__/r2rClientIntegrationUser.test.ts index 4d3733342..2afa7f5dc 100644 --- a/js/sdk/__tests__/r2rClientIntegrationUser.test.ts +++ b/js/sdk/__tests__/r2rClientIntegrationUser.test.ts @@ -29,6 +29,7 @@ const baseUrl = "http://localhost:7272"; * - ingestFiles * - updateFiles * X ingestChunks + * X updateChunks * Management: * - serverStats * X updatePrompt diff --git a/js/sdk/src/r2rClient.ts b/js/sdk/src/r2rClient.ts index 0b86dc345..5a8551ce2 100644 --- a/js/sdk/src/r2rClient.ts +++ b/js/sdk/src/r2rClient.ts @@ -678,6 +678,45 @@ export class r2rClient { }); } + @feature("updateChunk") + async updateChunk( + documentId: string, + extractionId: string, + text: string, + metadata?: Record, + runWithOrchestration?: 
boolean, + ): Promise> { + /** + * Update the content of an existing chunk. + * + * @param documentId - The ID of the document containing the chunk. + * @param extractionId - The ID of the chunk to update. + * @param text - The new text content of the chunk. + * @param metadata - Optional metadata dictionary for the chunk. + * @param runWithOrchestration - Whether to run the update through orchestration. + * @returns Update results containing processed, failed, and skipped documents. + */ + this._ensureAuthenticated(); + + const data: Record = { + text, + metadata, + run_with_orchestration: runWithOrchestration, + }; + + Object.keys(data).forEach( + (key) => data[key] === undefined && delete data[key], + ); + + return await this._makeRequest( + "PUT", + `update_chunk/${documentId}/${extractionId}`, + { + data, + }, + ); + } + // ----------------------------------------------------------------------------- // // Management diff --git a/py/compose.full.yaml b/py/compose.full.yaml index 50793aa1b..8c7ea4a19 100644 --- a/py/compose.full.yaml +++ b/py/compose.full.yaml @@ -270,8 +270,7 @@ services: retries: 5 r2r: - # image: ${R2R_IMAGE:-ragtoriches/prod:latest} - image: r2r/test + image: ${R2R_IMAGE:-ragtoriches/prod:latest} build: context: . args: diff --git a/py/core/base/providers/database.py b/py/core/base/providers/database.py index ecf00e5d9..e1c5f0295 100644 --- a/py/core/base/providers/database.py +++ b/py/core/base/providers/database.py @@ -492,6 +492,10 @@ async def get_document_chunks( ) -> dict[str, Any]: pass + @abstractmethod + async def get_chunk(self, extraction_id: UUID) -> Optional[dict[str, Any]]: + pass + @abstractmethod async def create_index( self, @@ -902,6 +906,9 @@ async def get_document_chunks( document_id, offset, limit, include_vectors ) + async def get_chunk(self, extraction_id: UUID) -> Optional[dict[str, Any]]: + return await self.vector_handler.get_chunk(extraction_id) + async def create_index( self, table_name: Optional[VectorTableName] = None, diff --git a/py/core/main/orchestration/hatchet/ingestion_workflow.py b/py/core/main/orchestration/hatchet/ingestion_workflow.py index 19d2b9338..5ee8853b0 100644 --- a/py/core/main/orchestration/hatchet/ingestion_workflow.py +++ b/py/core/main/orchestration/hatchet/ingestion_workflow.py @@ -1,6 +1,7 @@ import asyncio import logging import uuid +from uuid import UUID from typing import TYPE_CHECKING from hatchet_sdk import ConcurrencyLimitStrategy, Context @@ -515,52 +516,28 @@ async def update_chunk(self, context: Context) -> dict: ) document_uuid = ( - parsed_data["document_id"] + UUID(parsed_data["document_id"]) if isinstance(parsed_data["document_id"], str) else parsed_data["document_id"] ) extraction_uuid = ( - parsed_data["extraction_id"] + UUID(parsed_data["extraction_id"]) if isinstance(parsed_data["extraction_id"], str) else parsed_data["extraction_id"] ) - document_info = ( - await self.ingestion_service.update_chunk_ingress( - **{ - **parsed_data, - "document_id": document_uuid, - "extraction_id": extraction_uuid, - } - ) - ) - - extraction = DocumentExtraction( - id=extraction_uuid, + await self.ingestion_service.update_chunk_ingress( document_id=document_uuid, - collection_ids=parsed_data.get("collection_ids", []), - user_id=document_info.user_id, - data=parsed_data["text"], - metadata=parsed_data["metadata"], - ).to_dict() - - embedding_generator = ( - await self.ingestion_service.embed_document([extraction]) - ) - embeddings = [ - embedding.to_dict() - async for embedding in embedding_generator - ] - - 
storage_generator = ( - await self.ingestion_service.store_embeddings(embeddings) + extraction_id=extraction_uuid, + text=parsed_data.get("text"), + user=parsed_data["user"], + metadata=parsed_data.get("metadata"), + collection_ids=parsed_data.get("collection_ids"), ) - async for _ in storage_generator: - pass return { "message": "Chunk update completed successfully.", - "task_id": context.workflow_run_id(), # or None if not applicable + "task_id": context.workflow_run_id(), "document_ids": [str(document_uuid)], } diff --git a/py/core/main/orchestration/simple/ingestion_workflow.py b/py/core/main/orchestration/simple/ingestion_workflow.py index 08df9b599..9cf71d1cd 100644 --- a/py/core/main/orchestration/simple/ingestion_workflow.py +++ b/py/core/main/orchestration/simple/ingestion_workflow.py @@ -257,13 +257,12 @@ async def ingest_chunks(input_data): ) async def update_chunk(input_data): - try: - from core.main import IngestionServiceAdapter + from core.main import IngestionServiceAdapter + try: parsed_data = IngestionServiceAdapter.parse_update_chunk_input( input_data ) - document_uuid = ( UUID(parsed_data["document_id"]) if isinstance(parsed_data["document_id"], str) @@ -275,32 +274,14 @@ async def update_chunk(input_data): else parsed_data["extraction_id"] ) - document_info = await service.update_chunk_ingress( - **{ - **parsed_data, - "document_id": document_uuid, - "extraction_id": extraction_uuid, - } - ) - - extraction = DocumentExtraction( - id=extraction_uuid, + await service.update_chunk_ingress( document_id=document_uuid, - collection_ids=parsed_data.get("collection_ids", []), - user_id=document_info.user_id, - data=parsed_data["text"], - metadata=parsed_data["metadata"], - ).model_dump() - - embedding_generator = await service.embed_document([extraction]) - embeddings = [ - embedding.model_dump() - async for embedding in embedding_generator - ] - - storage_generator = await service.store_embeddings(embeddings) - async for _ in storage_generator: - pass + extraction_id=extraction_uuid, + text=parsed_data.get("text"), + user=parsed_data["user"], + metadata=parsed_data.get("metadata"), + collection_ids=parsed_data.get("collection_ids"), + ) except Exception as e: raise R2RException( diff --git a/py/core/main/services/ingestion_service.py b/py/core/main/services/ingestion_service.py index 5838dd9e1..34a89d7c2 100644 --- a/py/core/main/services/ingestion_service.py +++ b/py/core/main/services/ingestion_service.py @@ -362,7 +362,7 @@ async def update_chunk_ingress( metadata: Optional[dict] = None, *args: Any, **kwargs: Any, - ) -> DocumentInfo: + ) -> dict: # Verify chunk exists and user has access existing_chunks = await self.providers.database.get_document_chunks( document_id=document_id, limit=1 @@ -374,7 +374,12 @@ async def update_chunk_ingress( message=f"Chunk with extraction_id {extraction_id} not found.", ) - existing_chunk = existing_chunks["results"][0] + existing_chunk = await self.providers.database.get_chunk(extraction_id) + if not existing_chunk: + raise R2RException( + status_code=404, + message=f"Chunk with id {extraction_id} not found", + ) if ( str(existing_chunk["user_id"]) != str(user.id) @@ -385,14 +390,39 @@ async def update_chunk_ingress( message="You don't have permission to modify this chunk.", ) - # Get document info for return - documents_overview = ( - await self.providers.database.get_documents_overview( - filter_document_ids=[document_id], - ) - )["results"] + # Handle metadata merging + if metadata is not None: + merged_metadata = { + 
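# Stored metadata is spread first so that keys supplied in the
# incoming `metadata` override their stored counterparts while
# leaving untouched keys intact.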
**existing_chunk["metadata"], + **metadata, + } + else: + merged_metadata = existing_chunk["metadata"] + + # Create updated extraction + extraction_data = { + "id": extraction_id, + "document_id": document_id, + "collection_ids": kwargs.get( + "collection_ids", existing_chunk["collection_ids"] + ), + "user_id": existing_chunk["user_id"], + "data": text or existing_chunk["text"], + "metadata": merged_metadata, + } + + extraction = DocumentExtraction(**extraction_data).model_dump() + + embedding_generator = await self.embed_document([extraction]) + embeddings = [ + embedding.model_dump() async for embedding in embedding_generator + ] + + storage_generator = await self.store_embeddings(embeddings) + async for _ in storage_generator: + pass - return documents_overview[0] + return extraction async def _get_enriched_chunk_text( self, diff --git a/py/core/main/services/management_service.py b/py/core/main/services/management_service.py index c4297cdff..4a7746efc 100644 --- a/py/core/main/services/management_service.py +++ b/py/core/main/services/management_service.py @@ -230,7 +230,12 @@ async def delete( """ def validate_filters(filters: dict[str, Any]) -> None: - ALLOWED_FILTERS = {"document_id", "user_id", "collection_ids"} + ALLOWED_FILTERS = { + "document_id", + "user_id", + "collection_ids", + "extraction_id", + } if not filters: raise R2RException( @@ -244,7 +249,7 @@ def validate_filters(filters: dict[str, Any]) -> None: message=f"Invalid filter field: {field}", ) - for field in ["document_id", "user_id"]: + for field in ["document_id", "user_id", "extraction_id"]: if field in filters: op = next(iter(filters[field].keys())) try: @@ -280,12 +285,9 @@ def validate_filters(filters: dict[str, Any]) -> None: document_ids_to_purge: set[UUID] = set() if vector_delete_results: document_ids_to_purge.update( - UUID(doc_id) - for doc_id in ( - result.get("document_id") - for result in vector_delete_results.values() - ) - if doc_id + UUID(result.get("document_id")) + for result in vector_delete_results.values() + if result.get("document_id") ) relational_filters = {} @@ -300,38 +302,47 @@ def validate_filters(filters: dict[str, Any]) -> None: filters["collection_ids"]["$in"] ) - try: - documents_overview = ( - await self.providers.database.get_documents_overview( - **relational_filters # type: ignore + if relational_filters: + try: + documents_overview = ( + await self.providers.database.get_documents_overview( + **relational_filters # type: ignore + ) + )["results"] + except Exception as e: + logger.error( + f"Error fetching documents from relational database: {e}" ) - )["results"] - except Exception as e: - logger.error( - f"Error fetching documents from relational database: {e}" - ) - documents_overview = [] - - if documents_overview: - document_ids_to_purge.update(doc.id for doc in documents_overview) - - if not document_ids_to_purge: - raise R2RException( - status_code=404, message="No entries found for deletion." - ) + documents_overview = [] - for document_id in document_ids_to_purge: - try: - await self.providers.database.delete_from_documents_overview( - document_id + if documents_overview: + document_ids_to_purge.update( + doc.id for doc in documents_overview ) - logger.info( - f"Deleted document ID {document_id} from documents_overview." + + if not document_ids_to_purge: + raise R2RException( + status_code=404, message="No entries found for deletion." 
) - except Exception as e: - logger.error( - f"Error deleting document ID {document_id} from documents_overview: {e}" + + for document_id in document_ids_to_purge: + remaining_chunks = ( + await self.providers.database.get_document_chunks( + document_id + ) ) + if remaining_chunks["total_entries"] == 0: + try: + await self.providers.database.delete_from_documents_overview( + document_id + ) + logger.info( + f"Deleted document ID {document_id} from documents_overview." + ) + except Exception as e: + logger.error( + f"Error deleting document ID {document_id} from documents_overview: {e}" + ) return None diff --git a/py/core/providers/database/vector.py b/py/core/providers/database/vector.py index d5e84a730..c3bd2ec76 100644 --- a/py/core/providers/database/vector.py +++ b/py/core/providers/database/vector.py @@ -509,6 +509,28 @@ async def get_document_chunks( return {"results": chunks, "total_entries": total} + async def get_chunk(self, extraction_id: UUID) -> Optional[dict[str, Any]]: + query = f""" + SELECT extraction_id, document_id, user_id, collection_ids, text, metadata + FROM {self._get_table_name(PostgresVectorHandler.TABLE_NAME)} + WHERE extraction_id = $1; + """ + + result = await self.connection_manager.fetchrow_query( + query, (extraction_id,) + ) + + if result: + return { + "extraction_id": result["extraction_id"], + "document_id": result["document_id"], + "user_id": result["user_id"], + "collection_ids": result["collection_ids"], + "text": result["text"], + "metadata": json.loads(result["metadata"]), + } + return None + async def create_index( self, table_name: Optional[VectorTableName] = None, diff --git a/py/sdk/mixins/ingestion.py b/py/sdk/mixins/ingestion.py index 10c4cb9d3..4b7212e5a 100644 --- a/py/sdk/mixins/ingestion.py +++ b/py/sdk/mixins/ingestion.py @@ -169,6 +169,39 @@ async def ingest_chunks( data["run_with_orchestration"] = str(run_with_orchestration) # type: ignore return await self._make_request("POST", "ingest_chunks", json=data) # type: ignore + async def update_chunks( + self, + document_id: UUID, + extraction_id: UUID, + text: str, + metadata: Optional[dict] = None, + run_with_orchestration: Optional[bool] = None, + ) -> dict: + """ + Update the content of an existing chunk. + + Args: + document_id (UUID): The ID of the document containing the chunk. + extraction_id (UUID): The ID of the chunk to update. + text (str): The new text content of the chunk. + metadata (Optional[dict]): Metadata dictionary for the chunk. + run_with_orchestration (Optional[bool]): Whether to run the update through orchestration. + + Returns: + dict: Update results containing processed, failed, and skipped documents. 
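
        A minimal usage sketch (the UUIDs are illustrative placeholders):
            >>> await client.update_chunks(
            ...     document_id=UUID("9fbe403b-c11c-5aae-8ade-ef22980c3ad1"),
            ...     extraction_id=UUID("aeba6400-1bd0-5ee9-8925-04732d675434"),
            ...     text="Updated chunk content...",
            ...     metadata={"source": "manual_edit"},
            ... )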
+ """ + + data = { + "text": text, + "metadata": metadata, + "run_with_orchestration": run_with_orchestration, + } + + # Remove None values from payload + data = {k: v for k, v in data.items() if v is not None} + + return await self._make_request("PUT", f"update_chunk/{document_id}/{extraction_id}", json=data) # type: ignore + async def create_vector_index( self, table_name: VectorTableName = VectorTableName.VECTORS, diff --git a/py/tests/integration/runner_sdk.py b/py/tests/integration/runner_sdk.py index dc50a2ba7..4ae6c6177 100644 --- a/py/tests/integration/runner_sdk.py +++ b/py/tests/integration/runner_sdk.py @@ -1795,6 +1795,130 @@ def test_conversation_history_sdk(): print("~" * 100) +def test_ingest_chunks(): + print("Testing: Ingest chunks") + + client.ingest_chunks( + chunks=[ + { + # extraction_id should be 21acd7c0-fe60-572e-89b1-3ae71861bbb3 + "text": "Hello, world!", + }, + { + # extraction_id should be 7c1871cd-0f6a-52c1-84d8-3365c29251b3 + "text": "Hallo, Welt!", + }, + { + # extraction_id should be bccdb72f-ac9f-5708-81eb-b4d781ed9fe2 + "text": "Szia, világ!", + }, + { + # extraction_id should be 0d3d0fdd-5a13-55a7-8f42-8443f3ad7fbc + "text": "Dzień dobry, świecie!", + }, + ], + document_id="82346fd6-7479-4a49-a16a-88b5f91a3672", + metadata={ + "Language 1": "English", + "Language 2": "German", + "Language 3": "Hungarian", + "Language 4": "Polish", + }, + ) + + ingest_chunks_response = client.document_chunks( + document_id="82346fd6-7479-4a49-a16a-88b5f91a3672" + ) + + # Assert that the extraction_id is correct + assert ( + ingest_chunks_response["results"][0]["extraction_id"] + == "21acd7c0-fe60-572e-89b1-3ae71861bbb3" + ) + assert ( + ingest_chunks_response["results"][1]["extraction_id"] + == "7c1871cd-0f6a-52c1-84d8-3365c29251b3" + ) + assert ( + ingest_chunks_response["results"][2]["extraction_id"] + == "bccdb72f-ac9f-5708-81eb-b4d781ed9fe2" + ) + assert ( + ingest_chunks_response["results"][3]["extraction_id"] + == "0d3d0fdd-5a13-55a7-8f42-8443f3ad7fbc" + ) + + +def test_update_chunks(): + print("Testing: Update chunk") + + client.update_chunks( + document_id="82346fd6-7479-4a49-a16a-88b5f91a3672", + extraction_id="21acd7c0-fe60-572e-89b1-3ae71861bbb3", + text="Goodbye, world!", + ) + + client.update_chunks( + document_id="82346fd6-7479-4a49-a16a-88b5f91a3672", + extraction_id="7c1871cd-0f6a-52c1-84d8-3365c29251b3", + text="Auf Wiedersehen, Welt!", + ) + + client.update_chunks( + document_id="82346fd6-7479-4a49-a16a-88b5f91a3672", + extraction_id="bccdb72f-ac9f-5708-81eb-b4d781ed9fe2", + text="Viszlát, világ!", + ) + + client.update_chunks( + document_id="82346fd6-7479-4a49-a16a-88b5f91a3672", + extraction_id="0d3d0fdd-5a13-55a7-8f42-8443f3ad7fbc", + text="Dobranoc, świecie!", + ) + + ingest_chunks_response = client.document_chunks( + document_id="82346fd6-7479-4a49-a16a-88b5f91a3672" + ) + + # Assert that the text has been updated + assert ingest_chunks_response["results"][0]["text"] == "Goodbye, world!" + assert ( + ingest_chunks_response["results"][1]["text"] + == "Auf Wiedersehen, Welt!" + ) + assert ingest_chunks_response["results"][2]["text"] == "Viszlát, világ!" + assert ingest_chunks_response["results"][3]["text"] == "Dobranoc, świecie!" 
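    # No metadata argument was passed to update_chunks above, so
    # update_chunk_ingress keeps each chunk's existing metadata
    # (a supplied dict would be merged over it instead); the
    # document-level "Language" keys should therefore survive.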
+ + # Assert that the metadata has been maintained + assert ( + ingest_chunks_response["results"][0]["metadata"]["Language 1"] + == "English" + ) + + +def test_delete_chunks(): + print("Testing: Delete chunks") + + client.delete( + {"extraction_id": {"$eq": "21acd7c0-fe60-572e-89b1-3ae71861bbb3"}} + ) + client.delete( + {"extraction_id": {"$eq": "7c1871cd-0f6a-52c1-84d8-3365c29251b3"}} + ) + client.delete( + {"extraction_id": {"$eq": "bccdb72f-ac9f-5708-81eb-b4d781ed9fe2"}} + ) + client.delete( + {"extraction_id": {"$eq": "0d3d0fdd-5a13-55a7-8f42-8443f3ad7fbc"}} + ) + try: + client.document_chunks( + document_id="82346fd6-7479-4a49-a16a-88b5f91a3672" + ) + except R2RException as e: + assert e.status_code == 404 + + def create_client(base_url): return R2RClient(base_url) From e0b15b7fae70e394516a7a49ffb7ab6226d15414 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Thu, 24 Oct 2024 16:45:39 -0700 Subject: [PATCH 3/3] Docs and JS tests --- docs/cookbooks/ingestion.mdx | 52 +- docs/documentation/js-sdk/ingestion.mdx | 57 +++ docs/documentation/python-sdk/ingestion.mdx | 52 ++ js/sdk/__tests__/r2rClient.test.ts | 529 ++++++++++++++++++++ 4 files changed, 687 insertions(+), 3 deletions(-) diff --git a/docs/cookbooks/ingestion.mdx b/docs/cookbooks/ingestion.mdx index ab679d805..5d024979f 100644 --- a/docs/cookbooks/ingestion.mdx +++ b/docs/cookbooks/ingestion.mdx @@ -99,19 +99,65 @@ The `update_files` method accepts the following parameters: - `document_ids` (required): A list of document IDs corresponding to the files being updated. - `metadatas` (optional): A list of metadata dictionaries to update for each document. -## Deleting Documents -To delete documents from your R2R system, you can use the `delete` method: +## Updating Chunks + +To update specific chunks within existing documents in your R2R deployment, you can use the `update_chunks` method: + +```python +document_id = "9fbe403b-c11c-5aae-8ade-ef22980c3ad1" +extraction_id = "aeba6400-1bd0-5ee9-8925-04732d675434" + +update_response = client.update_chunks( + document_id=document_id, + extraction_id=extraction_id, + text="Updated chunk content with new information...", + metadata={ + "source": "manual_edit", + "edited_at": "2024-10-24", + "editor": "John Doe" + } +) +``` + +The `update_chunks` method accepts the following parameters: + +- `document_id` (required): The ID of the document containing the chunk you want to update. +- `extraction_id` (required): The ID of the specific chunk you want to update. +- `text` (required): The new text content that will replace the existing chunk text. +- `metadata` (optional): A metadata dictionary that will replace the existing chunk metadata. +- `run_with_orchestration` (optional): Whether to run the update through orchestration (default: true). + +This method is particularly useful when you need to: +- Correct errors in specific chunks +- Update outdated information +- Add or modify metadata for individual chunks +- Make targeted changes without reprocessing entire documents + +Note that updating chunks will trigger a re-vectorization of the modified content, ensuring that your vector search capabilities remain accurate with the updated information. 
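+After an update, you can confirm that the new content and re-embedded vector are live by re-reading the document's chunks. A minimal sketch, assuming the `document_chunks` management call used in the SDK integration tests:
+
+```python
+chunks = client.document_chunks(document_id=document_id)
+
+# Find the chunk we just modified and check its text was replaced.
+updated = next(
+    chunk
+    for chunk in chunks["results"]
+    if str(chunk["extraction_id"]) == extraction_id
+)
+assert updated["text"].startswith("Updated chunk content")
+```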
+ + +## Deleting Documents and Chunks + +To delete documents or chunks from your R2R deployment, you can use the `delete` method: ```python +# For documents delete_response = client.delete( { "document_id": {"$eq": "document1_id"} } ) + +# For chunks +delete_response = client.delete( + { + "extraction_id": {"$eq": "extraction1_id"} + } +) ``` -The `delete` method accepts a dictionary specifying the filters to identify the documents to delete. In this example, it deletes the document with the ID "document1_id". +The `delete` method accepts a dictionary specifying the filters to identify the documents to delete. In this example, it deletes the document with the ID "document1_id" and the chunk with the ID "extraction1_id." ## Conclusion diff --git a/docs/documentation/js-sdk/ingestion.mdx b/docs/documentation/js-sdk/ingestion.mdx index 1131e334e..4c9aaedeb 100644 --- a/docs/documentation/js-sdk/ingestion.mdx +++ b/docs/documentation/js-sdk/ingestion.mdx @@ -329,6 +329,63 @@ const updateResponse = await client.updateFiles(files, { +### Update Chunks + +Update the content of an existing chunk in your R2R system: + +```javascript +const documentId = "9fbe403b-c11c-5aae-8ade-ef22980c3ad1"; +const extractionId = "aeba6400-1bd0-5ee9-8925-04732d675434"; + +const updateResponse = await client.updateChunks({ + document_id: documentId, + extraction_id: extractionId, + text: "Updated chunk content...", + metadata: { + source: "manual_edit", + edited_at: "2024-10-24" + } +}); +``` + + + + + The response from the R2R system after updating the chunk. + ```bash + { + 'message': 'Update chunk task queued successfully.', + 'task_id': '7e27dfca-606d-422d-b73f-2d9e138661b4', + 'document_id': '9fbe403b-c11c-5aae-8ade-ef22980c3ad1' + } + ``` + + + + + + + The ID of the document containing the chunk to update. + + + + The ID of the specific chunk to update. + + + + The new text content to replace the existing chunk text. + + + + An optional metadata object for the updated chunk. If provided, this will replace the existing chunk metadata. + + + + Whether or not the update runs with orchestration, default is `true`. When set to `false`, the update process will run synchronous and directly return the result. + + + + ### Documents Overview Retrieve high-level document information, restricted to user files, except when called by a superuser where it will then return results from over all users: diff --git a/docs/documentation/python-sdk/ingestion.mdx b/docs/documentation/python-sdk/ingestion.mdx index 56504846f..df186ba32 100644 --- a/docs/documentation/python-sdk/ingestion.mdx +++ b/docs/documentation/python-sdk/ingestion.mdx @@ -452,6 +452,58 @@ The ingestion configuration can be customized analogously to the ingest files en +### Update Chunks + +Update the content of an existing chunk in your R2R system: + +```python +document_id = "9fbe403b-c11c-5aae-8ade-ef22980c3ad1" +extraction_id = "aeba6400-1bd0-5ee9-8925-04732d675434" + +update_response = client.update_chunks( + document_id=document_id, + extraction_id=extraction_id, + text="Updated chunk content...", + metadata={"source": "manual_edit", "edited_at": "2024-10-24"} +) +``` + + + + + The response from the R2R system after updating the chunk. + ```bash + { + 'message': 'Update chunk task queued successfully.', + 'task_id': '7e27dfca-606d-422d-b73f-2d9e138661b4', + 'document_id': '9fbe403b-c11c-5aae-8ade-ef22980c3ad1' + } + ``` + + + + + + The ID of the document containing the chunk to update. + + + + The ID of the specific chunk to update. 
+ + + + The new text content to replace the existing chunk text. + + + + An optional metadata dictionary for the updated chunk. If provided, this will replace the existing chunk metadata. + + + + Whether or not the update runs with orchestration, default is `True`. When set to `False`, the update process will run synchronous and directly return the result. + + + ### Documents Overview Retrieve high-level document information. Results are restricted to the current user's files, unless the request is made by a superuser, in which case results from all users are returned: diff --git a/js/sdk/__tests__/r2rClient.test.ts b/js/sdk/__tests__/r2rClient.test.ts index 44bb0556a..e35511962 100644 --- a/js/sdk/__tests__/r2rClient.test.ts +++ b/js/sdk/__tests__/r2rClient.test.ts @@ -41,4 +41,533 @@ describe("R2RClient", () => { }); }); }); + + describe("Authentication Methods", () => { + test("register should send POST request to /register with correct data", async () => { + const mockResponse = { success: true }; + mockAxiosInstance.request.mockResolvedValue({ data: mockResponse }); + + const email = "test@example.com"; + const password = "password123"; + const result = await client.register(email, password); + + expect(result).toEqual(mockResponse); + expect(mockAxiosInstance.request).toHaveBeenCalledWith({ + method: "POST", + url: "register", + data: JSON.stringify({ email, password }), + headers: { + "Content-Type": "application/json", + }, + responseType: "json", + }); + }); + + test("login should send POST request to /login with correct data and set tokens", async () => { + const mockResponse = { + results: { + access_token: { token: "access-token", token_type: "access_token" }, + refresh_token: { + token: "refresh-token", + token_type: "refresh_token", + }, + }, + }; + mockAxiosInstance.request.mockResolvedValue({ data: mockResponse }); + + const email = "test@example.com"; + const password = "password123"; + const result = await client.login(email, password); + + expect(result).toEqual(mockResponse.results); + expect(mockAxiosInstance.request).toHaveBeenCalledWith({ + method: "POST", + url: "login", + data: "username=test%40example.com&password=password123", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + responseType: "json", + }); + // Check that tokens are set + expect((client as any).accessToken).toBe("access-token"); + expect((client as any).refreshToken).toBe("refresh-token"); + }); + + test("verifyEmail should send POST request to /verify_email with correct data", async () => { + const mockResponse = { success: true }; + mockAxiosInstance.request.mockResolvedValue({ data: mockResponse }); + + const verification_code = "123456"; + const result = await client.verifyEmail(verification_code); + + expect(result).toEqual(mockResponse); + expect(mockAxiosInstance.request).toHaveBeenCalledWith({ + method: "POST", + url: "verify_email", + data: JSON.stringify({ verification_code }), + headers: { + "Content-Type": "application/json", + }, + responseType: "json", + }); + }); + + test("logout should send POST request to /logout and clear tokens", async () => { + mockAxiosInstance.request.mockResolvedValue({ data: {} }); + + // Set tokens first + (client as any).accessToken = "access-token"; + (client as any).refreshToken = "refresh-token"; + + const result = await client.logout(); + + expect(result).toEqual({}); + expect(mockAxiosInstance.request).toHaveBeenCalledWith({ + method: "POST", + url: "logout", + headers: { + Authorization: "Bearer access-token", + }, + 
responseType: "json", + }); + expect((client as any).accessToken).toBeNull(); + expect((client as any).refreshToken).toBeNull(); + }); + + test("user should send GET request to /user and return data", async () => { + const mockResponse = { id: "user-id", email: "test@example.com" }; + mockAxiosInstance.request.mockResolvedValue({ data: mockResponse }); + + // Set accessToken + (client as any).accessToken = "access-token"; + + const result = await client.user(); + + expect(result).toEqual(mockResponse); + expect(mockAxiosInstance.request).toHaveBeenCalledWith({ + method: "GET", + url: "user", + headers: { + Authorization: "Bearer access-token", + }, + responseType: "json", + }); + }); + + test("updateUser should send PUT request to /user with correct data", async () => { + const mockResponse = { success: true }; + mockAxiosInstance.request.mockResolvedValue({ data: mockResponse }); + + // Set accessToken + (client as any).accessToken = "access-token"; + + const userId = "user-id"; + const email = "new@example.com"; + const name = "New Name"; + const bio = "New Bio"; + const profilePicture = "http://example.com/pic.jpg"; + const isSuperuser = true; + + const result = await client.updateUser( + userId, + email, + isSuperuser, + name, + bio, + profilePicture, + ); + + expect(result).toEqual(mockResponse); + expect(mockAxiosInstance.request).toHaveBeenCalledWith({ + method: "PUT", + url: "user", + data: JSON.stringify({ + user_id: userId, + email, + is_superuser: isSuperuser, + name, + bio, + profile_picture: profilePicture, + }), + headers: { + Authorization: "Bearer access-token", + "Content-Type": "application/json", + }, + responseType: "json", + }); + }); + + test("refreshAccessToken should send POST request to /refresh_access_token and update tokens", async () => { + const mockResponse = { + results: { + access_token: { + token: "new-access-token", + token_type: "access_token", + }, + refresh_token: { + token: "new-refresh-token", + token_type: "refresh_token", + }, + }, + }; + mockAxiosInstance.request.mockResolvedValue({ data: mockResponse }); + + // Set refreshToken + (client as any).refreshToken = "old-refresh-token"; + + const result = await client.refreshAccessToken(); + + expect(result).toEqual(mockResponse); + expect((client as any).accessToken).toBe("new-access-token"); + expect((client as any).refreshToken).toBe("new-refresh-token"); + + expect(mockAxiosInstance.request).toHaveBeenCalledWith({ + method: "POST", + url: "refresh_access_token", + data: "old-refresh-token", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + responseType: "json", + }); + }); + + test("changePassword should send POST request to /change_password with correct data", async () => { + const mockResponse = { success: true }; + mockAxiosInstance.request.mockResolvedValue({ data: mockResponse }); + + // Set accessToken + (client as any).accessToken = "access-token"; + + const current_password = "old-password"; + const new_password = "new-password"; + + const result = await client.changePassword( + current_password, + new_password, + ); + + expect(result).toEqual(mockResponse); + expect(mockAxiosInstance.request).toHaveBeenCalledWith({ + method: "POST", + url: "change_password", + data: JSON.stringify({ + current_password, + new_password, + }), + headers: { + Authorization: "Bearer access-token", + "Content-Type": "application/json", + }, + responseType: "json", + }); + }); + + test("requestPasswordReset should send POST request to /request_password_reset with correct data", async () => 
{ + const mockResponse = { success: true }; + mockAxiosInstance.request.mockResolvedValue({ data: mockResponse }); + + const email = "test@example.com"; + + const result = await client.requestPasswordReset(email); + + expect(result).toEqual(mockResponse); + expect(mockAxiosInstance.request).toHaveBeenCalledWith({ + method: "POST", + url: "request_password_reset", + data: JSON.stringify({ email }), + headers: { + "Content-Type": "application/json", + }, + responseType: "json", + }); + }); + + test("confirmPasswordReset should send POST request to /reset_password/{resetToken} with correct data", async () => { + const mockResponse = { success: true }; + mockAxiosInstance.request.mockResolvedValue({ data: mockResponse }); + + const resetToken = "reset-token"; + const newPassword = "new-password"; + + const result = await client.confirmPasswordReset(resetToken, newPassword); + + expect(result).toEqual(mockResponse); + expect(mockAxiosInstance.request).toHaveBeenCalledWith({ + method: "POST", + url: `reset_password/${resetToken}`, + data: JSON.stringify({ new_password: newPassword }), + headers: { + "Content-Type": "application/json", + }, + responseType: "json", + }); + }); + + test("deleteUser should send DELETE request to /user/{userId} with correct data", async () => { + const mockResponse = { success: true }; + mockAxiosInstance.request.mockResolvedValue({ data: mockResponse }); + + // Set accessToken + (client as any).accessToken = "access-token"; + + const userId = "user-id"; + const password = "password123"; + + const result = await client.deleteUser(userId, password); + + expect(result).toEqual(mockResponse); + expect(mockAxiosInstance.request).toHaveBeenCalledWith({ + method: "DELETE", + url: `user/${userId}`, + data: JSON.stringify({ password }), + headers: { + Authorization: "Bearer access-token", + "Content-Type": "application/json", + }, + responseType: "json", + }); + }); + }); + + describe("Ingestion Methods", () => { + test("ingestChunks should send POST request to /ingest_chunks with correct data", async () => { + const mockResponse = { success: true }; + mockAxiosInstance.request.mockResolvedValue({ data: mockResponse }); + + // Set accessToken + (client as any).accessToken = "access-token"; + + const chunks = [ + { text: "Chunk 1", metadata: {} }, + { text: "Chunk 2", metadata: {} }, + ]; + const documentId = "doc-id"; + const metadata = { key: "value" }; + const run_with_orchestration = true; + + const result = await client.ingestChunks( + chunks, + documentId, + metadata, + run_with_orchestration, + ); + + expect(result).toEqual(mockResponse); + expect(mockAxiosInstance.request).toHaveBeenCalledWith({ + method: "POST", + url: "ingest_chunks", + data: JSON.stringify({ + chunks, + document_id: documentId, + metadata, + run_with_orchestration, + }), + headers: { + Authorization: "Bearer access-token", + "Content-Type": "application/json", + }, + responseType: "json", + }); + }); + + test("updateChunk should send PUT request to /update_chunk/{documentId}/{extractionId} with correct data", async () => { + const mockResponse = { success: true }; + mockAxiosInstance.request.mockResolvedValue({ data: mockResponse }); + + // Set accessToken + (client as any).accessToken = "access-token"; + + const documentId = "doc-id"; + const extractionId = "chunk-id"; + const text = "Updated text"; + const metadata = { key: "new value" }; + const runWithOrchestration = false; + + const result = await client.updateChunk( + documentId, + extractionId, + text, + metadata, + runWithOrchestration, + 
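// runWithOrchestration is false here; updateChunk strips only
// undefined values from its payload, so an explicit false is
// expected to be serialized and asserted below.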
); + + expect(result).toEqual(mockResponse); + expect(mockAxiosInstance.request).toHaveBeenCalledWith({ + method: "PUT", + url: `update_chunk/${documentId}/${extractionId}`, + data: JSON.stringify({ + text, + metadata, + run_with_orchestration: runWithOrchestration, + }), + headers: { + Authorization: "Bearer access-token", + "Content-Type": "application/json", + }, + responseType: "json", + }); + }); + }); + + describe("Management Methods", () => { + test("serverStats should send GET request to /server_stats and return data", async () => { + const mockResponse = { uptime: 12345 }; + mockAxiosInstance.request.mockResolvedValue({ data: mockResponse }); + + // Set accessToken + (client as any).accessToken = "access-token"; + + const result = await client.serverStats(); + + expect(result).toEqual(mockResponse); + expect(mockAxiosInstance.request).toHaveBeenCalledWith({ + method: "GET", + url: "server_stats", + headers: { + Authorization: "Bearer access-token", + }, + responseType: "json", + }); + }); + + test("updatePrompt should send POST request to /update_prompt with correct data", async () => { + const mockResponse = { success: true }; + mockAxiosInstance.request.mockResolvedValue({ data: mockResponse }); + + // Set accessToken + (client as any).accessToken = "access-token"; + + const name = "default_system"; + const template = "New template"; + const input_types = { key: "value" }; + + const result = await client.updatePrompt(name, template, input_types); + + expect(result).toEqual(mockResponse); + expect(mockAxiosInstance.request).toHaveBeenCalledWith({ + method: "POST", + url: "update_prompt", + data: JSON.stringify({ + name, + template, + input_types, + }), + headers: { + Authorization: "Bearer access-token", + "Content-Type": "application/json", + }, + responseType: "json", + }); + }); + + test("analytics should send GET request to /analytics with correct params", async () => { + const mockResponse = { data: [] }; + mockAxiosInstance.request.mockResolvedValue({ data: mockResponse }); + + // Set accessToken + (client as any).accessToken = "access-token"; + + const filter_criteria = { date: "2021-01-01" }; + const analysis_types = ["type1", "type2"]; + + const result = await client.analytics(filter_criteria, analysis_types); + + expect(result).toEqual(mockResponse); + expect(mockAxiosInstance.request).toHaveBeenCalledWith( + expect.objectContaining({ + method: "GET", + url: "analytics", + params: { + filter_criteria: JSON.stringify(filter_criteria), + analysis_types: JSON.stringify(analysis_types), + }, + headers: { + Authorization: "Bearer access-token", + }, + responseType: "json", + }), + ); + }); + }); + + describe("Retrieval Methods", () => { + test("search should send POST request to /search with correct data", async () => { + const mockResponse = { results: [] }; + mockAxiosInstance.request.mockResolvedValue({ data: mockResponse }); + + // Set accessToken + (client as any).accessToken = "access-token"; + + const query = "test query"; + const vector_search_settings = { top_k: 5 }; + const kg_search_settings = { max_hops: 2 }; + + const result = await client.search( + query, + vector_search_settings, + kg_search_settings, + ); + + expect(result).toEqual(mockResponse); + expect(mockAxiosInstance.request).toHaveBeenCalledWith({ + method: "POST", + url: "search", + data: JSON.stringify({ + query, + vector_search_settings, + kg_search_settings, + }), + headers: { + Authorization: "Bearer access-token", + "Content-Type": "application/json", + }, + responseType: "json", + }); + }); + + 
test("rag should send POST request to /rag with correct data", async () => { + const mockResponse = { answer: "Test answer" }; + mockAxiosInstance.request.mockResolvedValue({ data: mockResponse }); + + // Set accessToken + (client as any).accessToken = "access-token"; + + const query = "test query"; + const rag_generation_config = { max_tokens: 100 }; + const vector_search_settings = { top_k: 5 }; + const kg_search_settings = { max_hops: 2 }; + const task_prompt_override = "Custom prompt"; + const include_title_if_available = true; + + const result = await client.rag( + query, + vector_search_settings, + kg_search_settings, + rag_generation_config, + task_prompt_override, + include_title_if_available, + ); + + expect(result).toEqual(mockResponse); + expect(mockAxiosInstance.request).toHaveBeenCalledWith({ + method: "POST", + url: "rag", + data: JSON.stringify({ + query, + vector_search_settings, + kg_search_settings, + rag_generation_config, + task_prompt_override, + include_title_if_available, + }), + headers: { + Authorization: "Bearer access-token", + "Content-Type": "application/json", + }, + responseType: "json", + }); + }); + }); });
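Taken together, the three patches give the SDKs a full chunk lifecycle: ingest raw chunks, update a single chunk in place (with metadata merging and re-embedding), and delete by `extraction_id`. A minimal end-to-end sketch, assuming the usual `R2RClient` entry point and a local deployment; the document and extraction IDs are the deterministic values asserted in `py/tests/integration/runner_sdk.py`:

```python
from r2r import R2RClient

client = R2RClient("http://localhost:7272")

DOC_ID = "82346fd6-7479-4a49-a16a-88b5f91a3672"
CHUNK_ID = "21acd7c0-fe60-572e-89b1-3ae71861bbb3"

# Ingest a document from raw chunks; the resulting extraction IDs are
# deterministic (the integration tests assert their exact values).
client.ingest_chunks(
    chunks=[{"text": "Hello, world!"}],
    document_id=DOC_ID,
    metadata={"Language 1": "English"},
)

# Update one chunk in place; the service merges metadata and
# re-embeds the new text before storing it.
client.update_chunks(
    document_id=DOC_ID,
    extraction_id=CHUNK_ID,
    text="Goodbye, world!",
)

# extraction_id is now an allowed delete filter, so a single chunk
# can be removed without touching the rest of the document.
client.delete({"extraction_id": {"$eq": CHUNK_ID}})
```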