From 7e9635ef2547e58d59a4ed2a79872117a42398a7 Mon Sep 17 00:00:00 2001 From: morvanzhou Date: Thu, 11 Jul 2024 20:46:43 +0800 Subject: [PATCH] feat(llm): extend node from llm --- src/retk/controllers/ai/knowledge.py | 20 +++++++++++++++----- src/retk/controllers/schemas/ai.py | 3 ++- src/retk/controllers/user.py | 3 ++- src/retk/routes/ai.py | 2 +- tests/test_api.py | 3 ++- 5 files changed, 22 insertions(+), 9 deletions(-) diff --git a/src/retk/controllers/ai/knowledge.py b/src/retk/controllers/ai/knowledge.py index 84b8ae5..e50d80e 100644 --- a/src/retk/controllers/ai/knowledge.py +++ b/src/retk/controllers/ai/knowledge.py @@ -9,14 +9,24 @@ async def get_extended_nodes( au: AuthedUser, ) -> schemas.ai.GetExtendedNodesResponse: docs = await core.ai.llm.knowledge.extended.get_extended_nodes(uid=au.u.id) - return schemas.ai.GetExtendedNodesResponse( - requestId=au.request_id, - nodes=[schemas.ai.GetExtendedNodesResponse.Node( + nodes = [] + for doc in docs: + res = doc["extendMd"].split("\n", 1) + if len(res) == 2: + title, content = res + else: + title, content = res[0], "" + node = schemas.ai.GetExtendedNodesResponse.Node( id=str(doc["_id"]), sourceNid=doc["sourceNid"], sourceTitle=doc["sourceMd"].split("\n", 1)[0].strip(), - md=doc["extendMd"], - ) for doc in docs] + title=title.strip(), + content=content.strip(), + ) + nodes.append(node) + return schemas.ai.GetExtendedNodesResponse( + requestId=au.request_id, + nodes=nodes ) diff --git a/src/retk/controllers/schemas/ai.py b/src/retk/controllers/schemas/ai.py index 317998f..947e5e6 100644 --- a/src/retk/controllers/schemas/ai.py +++ b/src/retk/controllers/schemas/ai.py @@ -8,7 +8,8 @@ class Node(BaseModel): id: str sourceNid: str sourceTitle: str - md: str + title: str + content: str requestId: str nodes: List[Node] diff --git a/src/retk/controllers/user.py b/src/retk/controllers/user.py index 54c9e84..c610004 100644 --- a/src/retk/controllers/user.py +++ b/src/retk/controllers/user.py @@ -15,6 +15,7 @@ async def get_user( max_space = 0 else: max_space = const.USER_TYPE.id2config(au.u.type).max_store_space + total_nodes = await core.user.get_user_nodes_count(uid=au.u.id, disabled=False, in_trash=False) return schemas.user.UserInfoResponse( requestId=au.request_id, user=schemas.user.UserInfoResponse.User( @@ -38,7 +39,7 @@ async def get_user( editorSepRightWidth=au.u.settings.editor_sep_right_width, editorSideCurrentToolId=au.u.settings.editor_side_current_tool_id, ), - totalNodes=await core.user.get_user_nodes_count(uid=au.u.id, disabled=False, in_trash=False), + totalNodes=total_nodes, ), ) diff --git a/src/retk/routes/ai.py b/src/retk/routes/ai.py index 71f6900..6fe198d 100644 --- a/src/retk/routes/ai.py +++ b/src/retk/routes/ai.py @@ -8,7 +8,7 @@ router = APIRouter( prefix="/api/ai", - tags=["node"], + tags=["ai"], responses={404: {"description": "Not found"}}, ) diff --git a/tests/test_api.py b/tests/test_api.py index ba69bb8..fef2319 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1598,7 +1598,8 @@ async def test_node_extend(self): rj = self.check_ok_response(resp, 200) self.assertEqual(1, len(rj["nodes"])) n = rj["nodes"][0] - self.assertEqual("this is extended md", n["md"]) + self.assertEqual("this is extended md", n["title"]) + self.assertEqual("", n["content"]) self.assertEqual(node["id"], n["sourceNid"]) resp = self.client.post(