Chunk CRUD #1483

Merged · 5 commits · Oct 25, 2024
Changes from 2 commits
15 changes: 15 additions & 0 deletions .github/actions/run-sdk-collections-tests/action.yml
@@ -82,3 +82,18 @@ runs:
working-directory: ./py
shell: bash
run: poetry run python tests/integration/runner_sdk.py test_user_permissions

- name: Ingest chunks
working-directory: ./py
shell: bash
run: poetry run python tests/integration/runner_sdk.py test_ingest_chunks

- name: Update chunks
working-directory: ./py
shell: bash
run: poetry run python tests/integration/runner_sdk.py test_update_chunks

- name: Delete chunks
working-directory: ./py
shell: bash
run: poetry run python tests/integration/runner_sdk.py test_delete_chunks
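
These three steps call matching test functions in tests/integration/runner_sdk.py, which is not part of this diff. A minimal sketch of what test_update_chunks might look like, assuming the Python SDK mirrors the JS client below (client construction, method names, and the seeded chunk IDs are assumptions):

# Hypothetical sketch of test_update_chunks; runner_sdk.py is not shown in
# this diff, so the client API and the seeded chunk IDs are assumptions.
from r2r import R2RClient

def test_update_chunks():
    client = R2RClient("http://localhost:7272")
    document_id = "bd2cbead-66e0-57bc-acea-2c34711a39b5"    # created by test_ingest_chunks
    extraction_id = "c043aa2c-80e8-59ed-a390-54f1947ea32b"

    client.update_chunk(document_id, extraction_id, "updated text")

    chunks = client.document_chunks(document_id)["results"]
    updated = next(c for c in chunks if c["extraction_id"] == extraction_id)
    assert updated["text"] == "updated text"
    print("Update chunks test passed")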
34 changes: 32 additions & 2 deletions js/sdk/__tests__/r2rClientIntegrationSuperUser.test.ts
@@ -33,6 +33,7 @@ let newCollectionId: string;
* - ingestFiles
* - updateFiles
* - ingestChunks
* - updateChunks
* Management:
* - serverStats
* X updatePrompt
@@ -254,8 +255,37 @@ describe("r2rClient Integration Tests", () => {
).resolves.not.toThrow();
});

test("Collections overview", async () => {
await expect(client.collectionsOverview()).resolves.not.toThrow();
test("Update chunk", async () => {
const response = await client.updateChunk(
"bd2cbead-66e0-57bc-acea-2c34711a39b5",
"c043aa2c-80e8-59ed-a390-54f1947ea32b",
"updated text",
);
});

test("Ensure that updated chunk has updated text", async () => {
const response = await client.documentChunks(
"bd2cbead-66e0-57bc-acea-2c34711a39b5",
);

const targetId = "c043aa2c-80e8-59ed-a390-54f1947ea32b";
const updatedChunk = response.results.find(
(chunk: { extraction_id: string; text: string }) =>
String(chunk.extraction_id) === targetId,
);

expect(updatedChunk).toBeDefined();
expect(updatedChunk?.text).toBe("updated text");
});

test("Delete the updated chunk", async () => {
await expect(
client.delete({
extraction_id: {
$eq: "c043aa2c-80e8-59ed-a390-54f1947ea32b",
},
}),
).resolves.toBe("");
});

test("Create collection", async () => {
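
For reference, the delete-by-filter call in the test above has a direct counterpart in the Python SDK. A sketch (the Python-side signature is an assumption, not shown in this diff):

# Hedged sketch: deleting the updated chunk by filter via the Python client.
from r2r import R2RClient

client = R2RClient("http://localhost:7272")
client.delete(
    filters={"extraction_id": {"$eq": "c043aa2c-80e8-59ed-a390-54f1947ea32b"}}
)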
1 change: 1 addition & 0 deletions js/sdk/__tests__/r2rClientIntegrationUser.test.ts
@@ -29,6 +29,7 @@ const baseUrl = "http://localhost:7272";
* - ingestFiles
* - updateFiles
* X ingestChunks
* X updateChunks
* Management:
* - serverStats
* X updatePrompt
39 changes: 39 additions & 0 deletions js/sdk/src/r2rClient.ts
@@ -678,6 +678,45 @@ export class r2rClient {
});
}

@feature("updateChunk")
async updateChunk(
documentId: string,
extractionId: string,
text: string,
metadata?: Record<string, any>,
runWithOrchestration?: boolean,
): Promise<Record<string, any>> {
/**
* Update the content of an existing chunk.
*
* @param documentId - The ID of the document containing the chunk.
* @param extractionId - The ID of the chunk to update.
* @param text - The new text content of the chunk.
* @param metadata - Optional metadata dictionary for the chunk.
* @param runWithOrchestration - Whether to run the update through orchestration.
* @returns Update results containing processed, failed, and skipped documents.
*/
this._ensureAuthenticated();

const data: Record<string, any> = {
text,
metadata,
run_with_orchestration: runWithOrchestration,
};

Object.keys(data).forEach(
(key) => data[key] === undefined && delete data[key],
);

return await this._makeRequest(
"PUT",
`update_chunk/${documentId}/${extractionId}`,
{
data,
},
);
}

// -----------------------------------------------------------------------------
//
// Management
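
The new updateChunk method maps onto a plain PUT request, so the endpoint can also be exercised without the SDK. A raw-HTTP sketch (the /v2 prefix, port, and bearer-token header are assumptions):

# Hedged sketch: calling the new update_chunk route directly with requests.
import requests

resp = requests.put(
    "http://localhost:7272/v2/update_chunk/"
    "bd2cbead-66e0-57bc-acea-2c34711a39b5/"
    "c043aa2c-80e8-59ed-a390-54f1947ea32b",
    json={"text": "updated text", "run_with_orchestration": False},
    headers={"Authorization": "Bearer <access-token>"},
)
resp.raise_for_status()
print(resp.json())  # expect message, document_ids, and task_id fields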
2 changes: 2 additions & 0 deletions py/core/base/api/models/__init__.py
@@ -9,6 +9,7 @@
from shared.api.models.ingestion.responses import (
CreateVectorIndexResponse,
IngestionResponse,
UpdateResponse,
WrappedCreateVectorIndexResponse,
WrappedDeleteVectorIndexResponse,
WrappedIngestionResponse,
@@ -87,6 +88,7 @@
"WrappedListVectorIndicesResponse",
"WrappedDeleteVectorIndexResponse",
"WrappedSelectVectorIndexResponse",
"UpdateResponse",
# Knowledge Graph Responses
"KGCreationResponse",
"WrappedKGCreationResponse",
7 changes: 7 additions & 0 deletions py/core/base/providers/database.py
@@ -492,6 +492,10 @@ async def get_document_chunks(
) -> dict[str, Any]:
pass

@abstractmethod
async def get_chunk(self, extraction_id: UUID) -> Optional[dict[str, Any]]:
pass

@abstractmethod
async def create_index(
self,
@@ -902,6 +906,9 @@ async def get_document_chunks(
document_id, offset, limit, include_vectors
)

async def get_chunk(self, extraction_id: UUID) -> Optional[dict[str, Any]]:
return await self.vector_handler.get_chunk(extraction_id)

async def create_index(
self,
table_name: Optional[VectorTableName] = None,
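
The abstract get_chunk and its delegating wrapper are shown above, but the concrete vector_handler.get_chunk is not part of this diff. A plausible Postgres-backed sketch (the table-name helper, column names, and query runner are assumptions):

# Hedged sketch of a concrete get_chunk on the vector handler; the real
# implementation is not in this diff, so the helper names are assumptions.
async def get_chunk(self, extraction_id: UUID) -> Optional[dict[str, Any]]:
    query = f"""
        SELECT extraction_id, document_id, user_id, collection_ids, text, metadata
        FROM {self._get_table_name("vectors")}
        WHERE extraction_id = $1;
    """
    row = await self.connection_manager.fetchrow_query(query, [extraction_id])
    return dict(row) if row else None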
71 changes: 68 additions & 3 deletions py/core/main/api/ingestion_router.py
@@ -1,17 +1,18 @@
import base64
import logging
from io import BytesIO
from pathlib import Path
from pathlib import Path as pathlib_Path
from typing import Optional, Union
from uuid import UUID

import yaml
from fastapi import Body, Depends, File, Form, Query, UploadFile
from fastapi import Body, Depends, File, Form, Query, UploadFile, Path
from pydantic import Json

from core.base import R2RException, RawChunk, generate_document_id
from core.base.api.models import (
CreateVectorIndexResponse,
UpdateResponse,
WrappedCreateVectorIndexResponse,
WrappedDeleteVectorIndexResponse,
WrappedIngestionResponse,
@@ -64,6 +65,11 @@ def _register_workflows(self):
if self.orchestration_provider.config.provider != "simple"
else "Update task queued successfully."
),
"update-chunk": (
"Update chunk task queued successfully."
if self.orchestration_provider.config.provider != "simple"
else "Chunk update completed successfully."
),
"create-vector-index": (
"Vector index creation task queued successfully."
if self.orchestration_provider.config.provider != "simple"
@@ -84,7 +90,9 @@ def _register_workflows(self):

def _load_openapi_extras(self):
yaml_path = (
Path(__file__).parent / "data" / "ingestion_router_openapi.yml"
pathlib_Path(__file__).parent
/ "data"
/ "ingestion_router_openapi.yml"
)
with open(yaml_path, "r") as yaml_file:
yaml_content = yaml.safe_load(yaml_file)
@@ -406,6 +414,63 @@ async def ingest_chunks_app(
"task_id": None,
}

@self.router.put(
"/update_chunk/{document_id}/{extraction_id}",
)
@self.base_endpoint
async def update_chunk_app(
document_id: UUID = Path(
..., description="The document ID of the chunk to update"
),
extraction_id: UUID = Path(
..., description="The extraction ID of the chunk to update"
),
text: str = Body(
..., description="The new text content for the chunk"
),
metadata: Optional[dict] = Body(
None, description="Optional updated metadata"
),
run_with_orchestration: Optional[bool] = Body(True),
auth_user=Depends(self.service.providers.auth.auth_wrapper),
) -> WrappedUpdateResponse:
try:
workflow_input = {
"document_id": str(document_id),
"extraction_id": str(extraction_id),
"text": text,
"metadata": metadata,
"user": auth_user.model_dump_json(),
}

if run_with_orchestration:
raw_message: dict[str, Union[str, None]] = await self.orchestration_provider.run_workflow( # type: ignore
"update-chunk", {"request": workflow_input}, {}
)
raw_message["message"] = "Update task queued successfully."
raw_message["document_ids"] = [document_id] # type: ignore

return raw_message # type: ignore
else:
logger.info("Running chunk update without orchestration.")
from core.main.orchestration import (
simple_ingestion_factory,
)

simple_ingestor = simple_ingestion_factory(self.service)
await simple_ingestor["update-chunk"](workflow_input)

return { # type: ignore
"message": "Update task completed successfully.",
"document_ids": workflow_input["document_ids"],
"task_id": None,
}

except Exception as e:
raise R2RException(
status_code=500, message=f"Error updating chunk: {str(e)}"
)

create_vector_index_extras = self.openapi_extras.get(
"create_vector_index", {}
)
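
The else branch above looks up simple_ingestor["update-chunk"], so simple_ingestion_factory needs a matching entry; that file's diff is not shown here. A hedged sketch of the shape it would take, reusing the parse and ingress calls from the Hatchet workflow below:

# Hedged sketch of the "update-chunk" entry inside simple_ingestion_factory;
# the simple-orchestration changes are not part of this diff.
async def update_chunk(input_data: dict) -> None:
    parsed = IngestionServiceAdapter.parse_update_chunk_input(input_data)
    await service.update_chunk_ingress(
        document_id=UUID(parsed["document_id"]),
        extraction_id=UUID(parsed["extraction_id"]),
        text=parsed.get("text"),
        user=parsed["user"],
        metadata=parsed.get("metadata"),
        collection_ids=parsed.get("collection_ids"),
    )

# ...registered alongside the existing entries in the returned mapping:
# {"ingest-files": ..., "update-chunk": update_chunk, ...}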
56 changes: 56 additions & 0 deletions py/core/main/orchestration/hatchet/ingestion_workflow.py
@@ -1,6 +1,7 @@
import asyncio
import logging
import uuid
from uuid import UUID
from typing import TYPE_CHECKING

from hatchet_sdk import ConcurrencyLimitStrategy, Context
@@ -498,6 +499,59 @@ async def on_failure(self, context: Context) -> None:
f"Failed to update document status for {document_id}: {e}"
)

@orchestration_provider.workflow(
name="update-chunk",
timeout="60m",
)
class HatchetUpdateChunkWorkflow:
def __init__(self, ingestion_service: IngestionService):
self.ingestion_service = ingestion_service

@orchestration_provider.step(timeout="60m")
async def update_chunk(self, context: Context) -> dict:
try:
input_data = context.workflow_input()["request"]
parsed_data = IngestionServiceAdapter.parse_update_chunk_input(
input_data
)

document_uuid = (
UUID(parsed_data["document_id"])
if isinstance(parsed_data["document_id"], str)
else parsed_data["document_id"]
)
extraction_uuid = (
UUID(parsed_data["extraction_id"])
if isinstance(parsed_data["extraction_id"], str)
else parsed_data["extraction_id"]
)

await self.ingestion_service.update_chunk_ingress(
document_id=document_uuid,
extraction_id=extraction_uuid,
text=parsed_data.get("text"),
user=parsed_data["user"],
metadata=parsed_data.get("metadata"),
collection_ids=parsed_data.get("collection_ids"),
)

return {
"message": "Chunk update completed successfully.",
"task_id": context.workflow_run_id(),
"document_ids": [str(document_uuid)],
}

except Exception as e:
raise R2RException(
status_code=500,
message=f"Error during chunk update: {str(e)}",
)

@orchestration_provider.failure()
async def on_failure(self, context: Context) -> None:
# Handle failure case if necessary
pass

@orchestration_provider.workflow(
name="create-vector-index", timeout="360m"
)
@@ -545,13 +599,15 @@ async def delete_vector_index(self, context: Context) -> dict:
ingest_files_workflow = HatchetIngestFilesWorkflow(service)
update_files_workflow = HatchetUpdateFilesWorkflow(service)
ingest_chunks_workflow = HatchetIngestChunksWorkflow(service)
update_chunks_workflow = HatchetUpdateChunkWorkflow(service)
create_vector_index_workflow = HatchetCreateVectorIndexWorkflow(service)
delete_vector_index_workflow = HatchetDeleteVectorIndexWorkflow(service)

return {
"ingest_files": ingest_files_workflow,
"update_files": update_files_workflow,
"ingest_chunks": ingest_chunks_workflow,
"update_chunk": update_chunks_workflow,
"create_vector_index": create_vector_index_workflow,
"delete_vector_index": delete_vector_index_workflow,
}
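
IngestionServiceAdapter.parse_update_chunk_input is called by the workflow above but is not included in this diff. A minimal sketch consistent with how the router builds workflow_input (the user deserialization is an assumption; the router serializes it with model_dump_json()):

# Hedged sketch of parse_update_chunk_input; the adapter's actual
# implementation is not shown in this PR excerpt.
import json

@staticmethod
def parse_update_chunk_input(data: dict) -> dict:
    return {
        "user": json.loads(data["user"]),  # likely rebuilt into a UserResponse model
        "document_id": data["document_id"],
        "extraction_id": data["extraction_id"],
        "text": data["text"],
        "metadata": data.get("metadata"),
        "collection_ids": data.get("collection_ids", []),
    }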