From c0e0be6bd8244fb85daf6c0ea4d1b7df2229a3fd Mon Sep 17 00:00:00 2001 From: morvanzhou Date: Wed, 26 Jun 2024 09:19:05 +0800 Subject: [PATCH] feat(app): - typedict optimize - add extend_on_node_update / create --- src/retk/core/ai/__init__.py | 1 + src/retk/core/ai/llm/api/aliyun.py | 11 +- src/retk/core/ai/llm/api/baidu.py | 14 +-- src/retk/core/ai/llm/api/moonshot.py | 5 +- src/retk/core/ai/llm/api/openai.py | 31 ++++-- src/retk/core/ai/llm/api/tencent.py | 12 +- src/retk/core/ai/llm/api/xfyun.py | 16 ++- src/retk/core/ai/llm/knowledge/__init__.py | 13 ++- src/retk/core/ai/llm/knowledge/db_ops.py | 67 +++++++++++ .../core/files/importing/async_tasks/utils.py | 22 ++-- src/retk/core/node/node.py | 5 +- src/retk/core/notice.py | 40 +++---- src/retk/core/scheduler/schedule.py | 6 + src/retk/core/scheduler/tasks/__init__.py | 1 + src/retk/core/scheduler/tasks/extend_node.py | 66 +++++++++++ src/retk/core/statistic.py | 13 ++- src/retk/models/client.py | 2 + src/retk/models/coll.py | 9 ++ src/retk/models/indexing.py | 11 ++ src/retk/models/tps/llm.py | 21 ++++ src/retk/models/tps/statistic.py | 2 +- src/retk/utils.py | 104 +++++++++--------- tests/test_ai_llm_knowledge.py | 6 +- tests/test_core_local.py | 67 ++++++----- tests/test_core_remote.py | 27 +++-- 25 files changed, 390 insertions(+), 182 deletions(-) create mode 100644 src/retk/core/ai/llm/knowledge/db_ops.py create mode 100644 src/retk/core/scheduler/tasks/extend_node.py create mode 100644 src/retk/models/tps/llm.py diff --git a/src/retk/core/ai/__init__.py b/src/retk/core/ai/__init__.py index e69de29..9716642 100644 --- a/src/retk/core/ai/__init__.py +++ b/src/retk/core/ai/__init__.py @@ -0,0 +1 @@ +from . import llm diff --git a/src/retk/core/ai/llm/api/aliyun.py b/src/retk/core/ai/llm/api/aliyun.py index 88c119c..8cb3562 100644 --- a/src/retk/core/ai/llm/api/aliyun.py +++ b/src/retk/core/ai/llm/api/aliyun.py @@ -31,14 +31,15 @@ def __init__( timeout=timeout, default_model=AliyunModelEnum.QWEN1_5_05B.value, ) - self.api_key = config.get_settings().ALIYUN_DASHSCOPE_API_KEY - if self.api_key == "": - raise NoAPIKeyError("Aliyun API key is empty") - def get_headers(self, stream: bool) -> Dict[str, str]: + @staticmethod + def get_headers(stream: bool) -> Dict[str, str]: + k = config.get_settings().ALIYUN_DASHSCOPE_API_KEY + if k == "": + raise NoAPIKeyError("Aliyun API key is empty") h = { 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}', + 'Authorization': f'Bearer {k}', } if stream: h["Accept"] = "text/event-stream" diff --git a/src/retk/core/ai/llm/api/baidu.py b/src/retk/core/ai/llm/api/baidu.py index 617fd62..b820b51 100644 --- a/src/retk/core/ai/llm/api/baidu.py +++ b/src/retk/core/ai/llm/api/baidu.py @@ -33,11 +33,6 @@ def __init__( timeout=timeout, default_model=BaiduModelEnum.ERNIE_SPEED_8K.value, ) - self.api_key = config.get_settings().BAIDU_QIANFAN_API_KEY - self.secret_key = config.get_settings().BAIDU_QIANFAN_SECRET_KEY - if self.api_key == "" or self.secret_key == "": - raise NoAPIKeyError("Baidu api key or key is empty") - self.headers = { "Content-Type": "application/json", } @@ -46,16 +41,21 @@ def __init__( self.token = "" async def set_token(self, req_id: str = None): + _s = config.get_settings() + if _s.BAIDU_QIANFAN_API_KEY == "" or _s.BAIDU_QIANFAN_SECRET_KEY == "": + raise NoAPIKeyError("Baidu api key or skey is empty") + if self.token_expires_at > datetime.now().timestamp(): return + resp = await httpx_helper.get_async_client().post( 
url="https://aip.baidubce.com/oauth/2.0/token", headers={"Content-Type": "application/json", 'Accept': 'application/json'}, content=b"", params={ "grant_type": "client_credentials", - "client_id": self.api_key, - "client_secret": self.secret_key, + "client_id": _s.BAIDU_QIANFAN_API_KEY, + "client_secret": _s.BAIDU_QIANFAN_SECRET_KEY, } ) if resp.status_code != 200: diff --git a/src/retk/core/ai/llm/api/moonshot.py b/src/retk/core/ai/llm/api/moonshot.py index 47f1b8d..e28cd46 100644 --- a/src/retk/core/ai/llm/api/moonshot.py +++ b/src/retk/core/ai/llm/api/moonshot.py @@ -19,10 +19,13 @@ def __init__( timeout: float = 60., ): super().__init__( - api_key=config.get_settings().MOONSHOT_API_KEY, endpoint="https://api.moonshot.cn/v1/chat/completions", default_model=MoonshotModelEnum.V1_8K.value, top_p=top_p, temperature=temperature, timeout=timeout, ) + + @staticmethod + def get_api_key(): + return config.get_settings().MOONSHOT_API_KEY diff --git a/src/retk/core/ai/llm/api/openai.py b/src/retk/core/ai/llm/api/openai.py index 454a8fa..c9126d6 100644 --- a/src/retk/core/ai/llm/api/openai.py +++ b/src/retk/core/ai/llm/api/openai.py @@ -1,4 +1,5 @@ import json +from abc import ABC, abstractmethod from enum import Enum from typing import Tuple, AsyncIterable, Optional @@ -16,10 +17,9 @@ class OpenaiModelEnum(str, Enum): GPT35_TURBO_16K = "gpt-3.5-turbo-16k" -class OpenaiLLMStyle(BaseLLMService): +class OpenaiLLMStyle(BaseLLMService, ABC): def __init__( self, - api_key: str, endpoint: str, default_model: str, top_p: float = 0.9, @@ -33,12 +33,20 @@ def __init__( timeout=timeout, default_model=default_model, ) - self.api_key = api_key - if self.api_key == "": - raise NoAPIKeyError(f"{self.__class__.__name__} API key is empty") - self.headers = { + + @staticmethod + @abstractmethod + def get_api_key(): + pass + + def get_headers(self): + k = self.get_api_key() + if k == "": + raise NoAPIKeyError(f"{self.__class__.__name__} api key is empty") + + return { "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}", + "Authorization": f"Bearer {k}", } def get_payload(self, model: Optional[str], messages: MessagesType, stream: bool) -> bytes: @@ -62,7 +70,7 @@ async def complete( payload = self.get_payload(model, messages, stream=False) rj, code = await self._complete( url=self.endpoint, - headers=self.headers, + headers=self.get_headers(), payload=payload, req_id=req_id, ) @@ -82,7 +90,7 @@ async def stream_complete( payload = self.get_payload(model, messages, stream=True) async for b, code in self._stream_complete( url=self.endpoint, - headers=self.headers, + headers=self.get_headers(), payload=payload, req_id=req_id ): @@ -118,10 +126,13 @@ def __init__( timeout: float = 60., ): super().__init__( - api_key=config.get_settings().OPENAI_API_KEY, endpoint="https://api.openai.com/v1/chat/completions", default_model=OpenaiModelEnum.GPT35_TURBO.value, top_p=top_p, temperature=temperature, timeout=timeout, ) + + @staticmethod + def get_api_key(): + return config.get_settings().OPENAI_API_KEY diff --git a/src/retk/core/ai/llm/api/tencent.py b/src/retk/core/ai/llm/api/tencent.py index 51e700c..8131396 100644 --- a/src/retk/core/ai/llm/api/tencent.py +++ b/src/retk/core/ai/llm/api/tencent.py @@ -52,12 +52,12 @@ def __init__( timeout=timeout, default_model=TencentModelEnum.HUNYUAN_LITE.value, ) - self.secret_id = config.get_settings().HUNYUAN_SECRET_ID - self.secret_key = config.get_settings().HUNYUAN_SECRET_KEY - if self.secret_id == "" or self.secret_key == "": - raise NoAPIKeyError("Tencent 
secret id or key is empty") def get_auth(self, action: str, payload: bytes, timestamp: int, content_type: str) -> str: + _s = config.get_settings() + if _s.HUNYUAN_SECRET_KEY == "" or _s.HUNYUAN_SECRET_ID == "": + raise NoAPIKeyError("Tencent secret id or key is empty") + algorithm = "TC3-HMAC-SHA256" date = datetime.utcfromtimestamp(timestamp).strftime("%Y-%m-%d") @@ -78,14 +78,14 @@ def get_auth(self, action: str, payload: bytes, timestamp: int, content_type: st string_to_sign = f"{algorithm}\n{timestamp}\n{credential_scope}\n{hashed_canonical_request}" # ************* 步骤 3:计算签名 ************* - secret_date = sign(f"TC3{self.secret_key}".encode("utf-8"), date) + secret_date = sign(f"TC3{_s.HUNYUAN_SECRET_KEY}".encode("utf-8"), date) secret_service = sign(secret_date, self.service) secret_signing = sign(secret_service, "tc3_request") signature = hmac.new(secret_signing, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest() # ************* 步骤 4:拼接 Authorization ************* authorization = f"{algorithm}" \ - f" Credential={self.secret_id}/{credential_scope}," \ + f" Credential={_s.HUNYUAN_SECRET_ID}/{credential_scope}," \ f" SignedHeaders={signed_headers}," \ f" Signature={signature}" return authorization diff --git a/src/retk/core/ai/llm/api/xfyun.py b/src/retk/core/ai/llm/api/xfyun.py index 780da62..a8fcabd 100644 --- a/src/retk/core/ai/llm/api/xfyun.py +++ b/src/retk/core/ai/llm/api/xfyun.py @@ -45,24 +45,22 @@ def __init__( timeout=timeout, default_model=XfYunModelEnum.SPARK_LITE.value, ) - _s = config.get_settings() - self.api_secret = _s.XFYUN_API_SECRET - self.api_key = _s.XFYUN_API_KEY - self.app_id = _s.XFYUN_APP_ID - if self.api_secret == "" or self.api_key == "" or self.app_id == "": - raise NoAPIKeyError("XfYun api secret or key is empty") def get_url(self, model: Optional[str], req_id: str = None) -> str: + _s = config.get_settings() + if _s.XFYUN_API_KEY == "" or _s.XFYUN_API_SECRET == "" or _s.XFYUN_APP_ID == "": + raise NoAPIKeyError("XfYun api secret or skey or appID is empty") + if model is None: model = self.default_model cur_time = datetime.now() date = handlers.format_date_time(mktime(cur_time.timetuple())) tmp = f"host: spark-api.xf-yun.com\ndate: {date}\nGET /{model}/chat HTTP/1.1" - tmp_sha = hmac.new(self.api_secret.encode('utf-8'), tmp.encode('utf-8'), digestmod=hashlib.sha256).digest() + tmp_sha = hmac.new(_s.XFYUN_API_SECRET.encode('utf-8'), tmp.encode('utf-8'), digestmod=hashlib.sha256).digest() signature = base64.b64encode(tmp_sha).decode(encoding='utf-8') - authorization_origin = f'api_key="{self.api_key}", ' \ + authorization_origin = f'api_key="{_s.XFYUN_API_KEY}", ' \ f'algorithm="hmac-sha256", ' \ f'headers="host date request-line", ' \ f'signature="{signature}"' @@ -79,7 +77,7 @@ def get_url(self, model: Optional[str], req_id: str = None) -> str: def get_data(self, model: str, messages: MessagesType) -> Dict: return { "header": { - "app_id": self.app_id, + "app_id": config.get_settings().XFYUN_APP_ID, "uid": "12345" }, "parameter": { diff --git a/src/retk/core/ai/llm/knowledge/__init__.py b/src/retk/core/ai/llm/knowledge/__init__.py index 75260b9..0073131 100644 --- a/src/retk/core/ai/llm/knowledge/__init__.py +++ b/src/retk/core/ai/llm/knowledge/__init__.py @@ -2,6 +2,7 @@ from typing import Tuple from retk import const +from .db_ops import extend_on_node_update, extend_on_node_post, LLM_SERVICES from ..api.base import BaseLLMService, MessagesType system_summary_prompt = (Path(__file__).parent / "system_summary.md").read_text(encoding="utf-8") @@ 
-12,12 +13,12 @@ async def _send(
         llm_service: BaseLLMService,
         model: str,
         system_prompt: str,
-        query: str,
+        md: str,
         req_id: str,
 ) -> Tuple[str, const.CodeEnum]:
     _msgs: MessagesType = [
         {"role": "system", "content": system_prompt},
-        {"role": "user", "content": query},
+        {"role": "user", "content": md},
     ]
     return await llm_service.complete(messages=_msgs, model=model, req_id=req_id)
 
@@ -25,14 +26,14 @@ async def _send(
 async def summary(
         llm_service: BaseLLMService,
         model: str,
-        query: str,
+        md: str,
         req_id: str = None,
 ) -> Tuple[str, const.CodeEnum]:
     return await _send(
         llm_service=llm_service,
         model=model,
         system_prompt=system_summary_prompt,
-        query=query,
+        md=md,
         req_id=req_id,
     )
 
@@ -40,13 +41,13 @@ async def summary(
 async def extend(
         llm_service: BaseLLMService,
         model: str,
-        query: str,
+        md: str,
         req_id: str = None,
 ) -> Tuple[str, const.CodeEnum]:
     return await _send(
         llm_service=llm_service,
         model=model,
         system_prompt=system_extend_prompt,
-        query=query,
+        md=md,
         req_id=req_id,
     )
diff --git a/src/retk/core/ai/llm/knowledge/db_ops.py b/src/retk/core/ai/llm/knowledge/db_ops.py
new file mode 100644
index 0000000..1f55ba3
--- /dev/null
+++ b/src/retk/core/ai/llm/knowledge/db_ops.py
@@ -0,0 +1,67 @@
+from datetime import timedelta
+
+from bson import ObjectId
+
+from retk.models.client import client
+from retk.models.tps.llm import NodeExtendQueue
+from retk.models.tps.node import Node
+from .. import api
+
+TOP_P = 0.9
+TEMPERATURE = 0.5
+TIMEOUT = 60
+
+LLM_SERVICES = {
+    "tencent": api.TencentService(top_p=TOP_P, temperature=TEMPERATURE, timeout=TIMEOUT),
+    "ali": api.AliyunService(top_p=TOP_P, temperature=TEMPERATURE, timeout=TIMEOUT),
+    "openai": api.OpenaiService(top_p=TOP_P, temperature=TEMPERATURE, timeout=TIMEOUT),
+    "moonshot": api.MoonshotService(top_p=TOP_P, temperature=TEMPERATURE, timeout=TIMEOUT),
+    "xf": api.XfYunService(top_p=TOP_P, temperature=TEMPERATURE, timeout=TIMEOUT),
+    "baidu": api.BaiduService(top_p=TOP_P, temperature=TEMPERATURE, timeout=TIMEOUT),
+}
+
+
+async def extend_on_node_post(data: Node):
+    q: NodeExtendQueue = NodeExtendQueue(
+        _id=ObjectId(),
+        uid=data["uid"],
+        nid=data["id"],
+        summaryService="tencent",
+        summaryModel=api.TencentModelEnum.HUNYUAN_LITE.value,
+        extendService="ali",
+        extendModel=api.AliyunModelEnum.QWEN_PLUS.value,
+    )
+
+    # sort by _id desc
+    docs = await client.coll.llm_extend_node_queue.find(
+        filter={"uid": data["uid"]}
+    ).sort("_id", -1).to_list(None)
+    has_q = False
+    for doc in docs:
+        if doc["nid"] == data["id"]:
+            has_q = True
+            # _id cannot be changed by update_one; renew the creating time by
+            # replacing the old queue item with the freshly built one
+            await client.coll.llm_extend_node_queue.delete_one(
+                filter={"_id": doc["_id"]},
+            )
+            await client.coll.llm_extend_node_queue.insert_one(q)
+            break
+
+    max_keep = 5
+    if not has_q:
+        if len(docs) >= max_keep:
+            # remove the oldest so that only the latest 5 remain after the insert
+            await client.coll.llm_extend_node_queue.delete_many(
+                {"_id": {"$in": [doc["_id"] for doc in docs[max_keep - 1:]]}}
+            )
+
+        await client.coll.llm_extend_node_queue.insert_one(q)
+
+
+async def extend_on_node_update(old_data: Node, new_data: Node):
+    # filter out frequent updates
+    if new_data["modifiedAt"] - old_data["modifiedAt"] < timedelta(seconds=60):
+        return
+
+    await extend_on_node_post(new_data)
diff --git a/src/retk/core/files/importing/async_tasks/utils.py b/src/retk/core/files/importing/async_tasks/utils.py
index 1d857f0..cdce901 100644
--- a/src/retk/core/files/importing/async_tasks/utils.py
+++ b/src/retk/core/files/importing/async_tasks/utils.py
@@ -20,17 +20,17 @@ async def 
check_last_task_finished(uid: str, type_: str) -> Tuple[Optional[Impor return None, False if doc is None: - doc: ImportData = { - "_id": ObjectId(), - "uid": uid, - "process": 0, - "type": type_, - "startAt": datetime.datetime.now(tz=utc), - "running": True, - "msg": "", - "code": 0, - "obsidian": {}, - } + doc: ImportData = ImportData( + _id=ObjectId(), + uid=uid, + process=0, + type=type_, + startAt=datetime.datetime.now(tz=utc), + running=True, + msg="", + code=0, + obsidian={}, + ) res = await client.coll.import_data.insert_one(doc) if not res.acknowledged: await set_running_false( diff --git a/src/retk/core/node/node.py b/src/retk/core/node/node.py index eb2664d..874561b 100644 --- a/src/retk/core/node/node.py +++ b/src/retk/core/node/node.py @@ -7,7 +7,7 @@ from retk import config, const, utils, regex from retk import plugins -from retk.core import user +from retk.core import user, ai from retk.logger import logger from retk.models import tps, db_ops from retk.models.client import client @@ -78,6 +78,7 @@ async def post( code = await client.search.add(au=au, doc=SearchDoc(nid=nid, title=title, body=body)) if code != const.CodeEnum.OK: logger.error(f"add search index failed, code: {code}") + await ai.llm.knowledge.extend_on_node_post(data=data) return data, const.CodeEnum.OK @@ -220,6 +221,8 @@ async def update_md( code = await backup.storage_md(node=doc, keep_hist=True) if code != const.CodeEnum.OK: return doc, old_n, code + + await ai.llm.knowledge.extend_on_node_update(old_data=old_n, new_data=doc) return doc, old_n, code diff --git a/src/retk/core/notice.py b/src/retk/core/notice.py index 0336b2c..39edc04 100644 --- a/src/retk/core/notice.py +++ b/src/retk/core/notice.py @@ -30,18 +30,18 @@ async def post_in_manager_delivery( publish_at = publish_at.astimezone(utc) # add system notice - notice: NoticeManagerDelivery = { - "_id": ObjectId(), - "senderType": au.u.type, - "senderId": au.u.id, - "title": title, - "html": md2html(content), - "snippet": md2txt(content)[:20], - "recipientType": recipient_type, # send to which user type, 0: all, 1: batch, 2: admin, 3: manager - "batchTypeIds": batch_type_ids, # if recipient=batch, put user id here - "publishAt": publish_at, # publish time - "scheduled": False, # has been scheduled to sent to user - } + notice = NoticeManagerDelivery( + _id=ObjectId(), + senderType=au.u.type, + senderId=au.u.id, + title=title, + html=md2html(content), + snippet=md2txt(content)[:20], + recipientType=recipient_type, # send to which user type, 0: all, 1: batch, 2: admin, 3: manager + batchTypeIds=batch_type_ids, # if recipient=batch, put user id here + publishAt=publish_at, # publish time + scheduled=False, # has been scheduled to sent to user + ) res = await client.coll.notice_manager_delivery.insert_one(notice) if not res.acknowledged: return None, const.CodeEnum.OPERATION_FAILED @@ -150,14 +150,14 @@ async def get_user_notices( new_system_notices: List[Notice] = [] for usn in user_system_notices: detail = n_details_dict[usn["noticeId"]] - new_system_notices.append({ - "id": str(usn["noticeId"]), - "title": detail["title"], - "snippet": detail["snippet"], - "publishAt": datetime2str(detail["publishAt"]), - "read": usn["read"], - "readTime": datetime2str(usn["readTime"]) if usn["readTime"] is not None else None, - }) + new_system_notices.append(Notice( + id=str(usn["noticeId"]), + title=detail["title"], + snippet=detail["snippet"], + publishAt=datetime2str(detail["publishAt"]), + read=usn["read"], + readTime=datetime2str(usn["readTime"]) if usn["readTime"] 
is not None else None,
+        ))
 
     return {
         "hasUnread": has_unread,
diff --git a/src/retk/core/scheduler/schedule.py b/src/retk/core/scheduler/schedule.py
index 076445f..4551aa9 100644
--- a/src/retk/core/scheduler/schedule.py
+++ b/src/retk/core/scheduler/schedule.py
@@ -92,6 +92,12 @@ def init_tasks():
         func=tasks.notice.deliver_unscheduled_system_notices,
         second=0,
     )
+    # check for unscheduled node extensions every hour
+    run_every_at(
+        job_id="deliver_unscheduled_node_extend",
+        func=tasks.extend_node.deliver_unscheduled_extend_nodes,
+        minute=0,
+    )
     return
 
 
diff --git a/src/retk/core/scheduler/tasks/__init__.py b/src/retk/core/scheduler/tasks/__init__.py
index 7da7c15..1ec2016 100644
--- a/src/retk/core/scheduler/tasks/__init__.py
+++ b/src/retk/core/scheduler/tasks/__init__.py
@@ -1,4 +1,5 @@
 from . import (
     email,
     notice,
+    extend_node,
 )
diff --git a/src/retk/core/scheduler/tasks/extend_node.py b/src/retk/core/scheduler/tasks/extend_node.py
new file mode 100644
index 0000000..252c580
--- /dev/null
+++ b/src/retk/core/scheduler/tasks/extend_node.py
@@ -0,0 +1,75 @@
+import asyncio
+import random
+from typing import List
+
+from bson import ObjectId
+
+from retk import const
+from retk.core.ai.llm import knowledge
+from retk.logger import logger
+from retk.models.client import init_mongo
+from retk.models.tps.llm import NodeExtendQueue, ExtendedNode
+
+
+def deliver_unscheduled_extend_nodes():
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    res = loop.run_until_complete(async_deliver_unscheduled_extend_nodes())
+    loop.close()
+    return res
+
+
+async def async_deliver_unscheduled_extend_nodes() -> str:
+    _, db = init_mongo(connection_timeout=5)
+    batch_size = 3
+    total_knowledge_extended = 0
+    while True:
+        batch: List[NodeExtendQueue] = await db["llmExtendNodeQueue"].find().limit(batch_size).to_list(None)
+        if len(batch) == 0:
+            break
+
+        batch_result: List[ExtendedNode] = []
+        for item in batch:
+            req_id = "".join([str(random.randint(0, 9)) for _ in range(10)])
+            # find_one returns the whole node document; the LLM prompts only need its markdown
+            node = await db["nodes"].find_one({"id": item["nid"]})
+            if node is None:
+                continue
+            md = node["md"]
+            # md = md[:int(8000 * 1.8)]
+            _summary, code = await knowledge.summary(
+                llm_service=knowledge.LLM_SERVICES[item["summaryService"]],
+                model=item["summaryModel"],
+                md=md,
+                req_id=req_id,
+            )
+            if code != const.CodeEnum.OK:
+                logger.error(f"knowledge summary error: {code}")
+                continue
+            _extended, code = await knowledge.extend(
+                llm_service=knowledge.LLM_SERVICES[item["extendService"]],
+                model=item["extendModel"],
+                md=md,
+                req_id=req_id,
+            )
+            if code != const.CodeEnum.OK:
+                logger.error(f"knowledge extend error: {code}")
+                continue
+            batch_result.append(ExtendedNode(
+                _id=ObjectId(),
+                uid=item["uid"],
+                sourceNids=[item["nid"]],
+                sourceMd=[md],
+                extendMd=_extended,
+            ))
+            total_knowledge_extended += 1
+
+        if len(batch_result) > 0:
+            await db["llmExtendedNode"].insert_many(batch_result)
+
+        # remove the processed items (including failed ones) so the next
+        # iteration fetches a fresh batch instead of looping on the same one
+        await db["llmExtendNodeQueue"].delete_many(
+            {"_id": {"$in": [item["_id"] for item in batch]}}
+        )
+
+    return f"successfully extended {total_knowledge_extended} node(s)"
diff --git a/src/retk/core/statistic.py b/src/retk/core/statistic.py
index 4e163fd..170ac03 100644
--- a/src/retk/core/statistic.py
+++ b/src/retk/core/statistic.py
@@ -2,6 +2,7 @@
 
 from retk import const
 from retk.models.client import client
+from retk.models.tps.statistic import UserBehavior
 
 
 async def add_user_behavior(
@@ -9,9 +10,9 @@ async def add_user_behavior(
     type_: const.UserBehaviorTypeEnum,
     remark: str,
 ):
-    await client.coll.user_behavior.insert_one({
-        "_id": ObjectId(),
-        "uid": uid,
-        "type": type_.value,
-        "remark": remark,
-    })
+    await client.coll.user_behavior.insert_one(UserBehavior(
+        
_id=ObjectId(), + uid=uid, + type=type_.value, + remark=remark, + )) diff --git a/src/retk/models/client.py b/src/retk/models/client.py index 16da657..88738f8 100644 --- a/src/retk/models/client.py +++ b/src/retk/models/client.py @@ -80,6 +80,8 @@ def init_mongo(self): self.coll.user_behavior = db["userBehavior"] self.coll.notice_manager_delivery = db["noticeManagerDelivery"] self.coll.notice_system = db["noticeSystem"] + self.coll.llm_extend_node_queue = db["llmExtendNodeQueue"] + self.coll.llm_extended_node = db["llmExtendedNode"] async def init_search(self): conf = config.get_settings() diff --git a/src/retk/models/coll.py b/src/retk/models/coll.py index 42d1431..74ad0c4 100644 --- a/src/retk/models/coll.py +++ b/src/retk/models/coll.py @@ -9,10 +9,19 @@ @dataclass class Collections: + # base collections users: Union[Collection, "AsyncIOMotorCollection"] = None nodes: Union[Collection, "AsyncIOMotorCollection"] = None import_data: Union[Collection, "AsyncIOMotorCollection"] = None user_file: Union[Collection, "AsyncIOMotorCollection"] = None + + # notification notice_manager_delivery: Union[Collection, "AsyncIOMotorCollection"] = None notice_system: Union[Collection, "AsyncIOMotorCollection"] = None + + # user behavior user_behavior: Union[Collection, "AsyncIOMotorCollection"] = None + + # llm + llm_extend_node_queue: Union[Collection, "AsyncIOMotorCollection"] = None + llm_extended_node: Union[Collection, "AsyncIOMotorCollection"] = None diff --git a/src/retk/models/indexing.py b/src/retk/models/indexing.py index 5c70b81..1be33e7 100644 --- a/src/retk/models/indexing.py +++ b/src/retk/models/indexing.py @@ -19,6 +19,7 @@ async def remote_try_build_index(coll: Collections): await user_behavior_coll(coll.user_behavior) await notice_manager_delivery_coll(coll.notice_manager_delivery) await notice_system_coll(coll.notice_system) + await llm_extend_node_queue_coll(coll.llm_extend_node_queue) async def not_in_and_create_index(coll: "AsyncIOMotorCollection", index_info, keys: list, unique: bool) -> str: @@ -128,3 +129,13 @@ async def notice_system_coll(coll: "AsyncIOMotorCollection"): keys=["senderId"], unique=False ) + + +async def llm_extend_node_queue_coll(coll: "AsyncIOMotorCollection"): + index_info = await coll.index_information() + await not_in_and_create_index( + coll=coll, + index_info=index_info, + keys=["uid"], + unique=False, + ) diff --git a/src/retk/models/tps/llm.py b/src/retk/models/tps/llm.py new file mode 100644 index 0000000..00c41e2 --- /dev/null +++ b/src/retk/models/tps/llm.py @@ -0,0 +1,21 @@ +from typing import TypedDict, List + +from bson import ObjectId + + +class NodeExtendQueue(TypedDict): + _id: ObjectId + uid: str + nid: str + summaryService: str + summaryModel: str + extendService: str + extendModel: str + + +class ExtendedNode(TypedDict): + _id: ObjectId + uid: str + sourceNids: List[str] + sourceMd: List[str] + extendMd: str diff --git a/src/retk/models/tps/statistic.py b/src/retk/models/tps/statistic.py index 01885b4..750c271 100644 --- a/src/retk/models/tps/statistic.py +++ b/src/retk/models/tps/statistic.py @@ -6,5 +6,5 @@ class UserBehavior(TypedDict): _id: ObjectId uid: str - bType: int + type: int remark: str diff --git a/src/retk/utils.py b/src/retk/utils.py index cc6af47..03ea7da 100644 --- a/src/retk/utils.py +++ b/src/retk/utils.py @@ -8,7 +8,7 @@ import webbrowser from html.parser import HTMLParser from io import StringIO -from typing import Tuple, Optional, List, Dict, Any +from typing import Tuple, Optional, List, Literal from urllib.parse import 
urlparse import httpx @@ -349,43 +349,43 @@ def get_user_dict( last_state_node_display_sort_key: str, settings_language: str, - settings_theme: str, - settings_editor_mode: str, + settings_theme: Literal["light", "dark"], + settings_editor_mode: Literal["ir", "wysiwyg"], settings_editor_font_size: int, - settings_editor_code_theme: str, + settings_editor_code_theme: tps.user.CODE_THEME_TYPES, settings_editor_sep_right_width: int, settings_editor_side_current_tool_id: str, ) -> tps.UserMeta: - return { - "_id": _id, - "id": uid, - "source": source, - "account": account, - "nickname": nickname, - "email": email, - "avatar": avatar, - "hashed": hashed, - "disabled": disabled, - "modifiedAt": modified_at, - "usedSpace": used_space, - "type": type_, - - "lastState": { - "recentCursorSearchSelectedNIds": last_state_recent_cursor_search_selected_nids, - "recentSearch": last_state_recent_search, - "nodeDisplayMethod": last_state_node_display_method, - "nodeDisplaySortKey": last_state_node_display_sort_key, - }, - "settings": { - "language": settings_language, - "theme": settings_theme, - "editorMode": settings_editor_mode, - "editorFontSize": settings_editor_font_size, - "editorCodeTheme": settings_editor_code_theme, - "editorSepRightWidth": settings_editor_sep_right_width, - "editorSideCurrentToolId": settings_editor_side_current_tool_id, - }, - } + return tps.UserMeta( + _id=_id, + id=uid, + source=source, + account=account, + nickname=nickname, + email=email, + avatar=avatar, + hashed=hashed, + disabled=disabled, + modifiedAt=modified_at, + usedSpace=used_space, + type=type_, + + lastState=tps.user._LastState( + recentCursorSearchSelectedNIds=last_state_recent_cursor_search_selected_nids, + recentSearch=last_state_recent_search, + nodeDisplayMethod=last_state_node_display_method, + nodeDisplaySortKey=last_state_node_display_sort_key, + ), + settings=tps.user._Settings( + language=settings_language, + theme=settings_theme, + editorMode=settings_editor_mode, + editorFontSize=settings_editor_font_size, + editorCodeTheme=settings_editor_code_theme, + editorSepRightWidth=settings_editor_sep_right_width, + editorSideCurrentToolId=settings_editor_side_current_tool_id, + ), + ) def get_node_dict( @@ -402,25 +402,25 @@ def get_node_dict( in_trash_at: Optional[datetime.datetime], from_node_ids: List[str], to_node_ids: List[str], - history: List[Dict[str, Any]], + history: List[str], ) -> tps.Node: - return { - "_id": _id, - "id": nid, - "uid": uid, - "md": md, - "title": title, - "snippet": snippet, - "type": type_, - "disabled": disabled, - "inTrash": in_trash, - "modifiedAt": modified_at, - "inTrashAt": in_trash_at, - "fromNodeIds": from_node_ids, - "toNodeIds": to_node_ids, - "history": history, - "favorite": False, - } + return tps.Node( + _id=_id, + id=nid, + uid=uid, + md=md, + title=title, + snippet=snippet, + type=type_, + disabled=disabled, + inTrash=in_trash, + modifiedAt=modified_at, + inTrashAt=in_trash_at, + fromNodeIds=from_node_ids, + toNodeIds=to_node_ids, + history=history, + favorite=False, + ) def get_token(uid: str, language: str) -> Tuple[str, str]: diff --git a/tests/test_ai_llm_knowledge.py b/tests/test_ai_llm_knowledge.py index 93cb222..73fe750 100644 --- a/tests/test_ai_llm_knowledge.py +++ b/tests/test_ai_llm_knowledge.py @@ -16,7 +16,7 @@ 2. 工艺简单,大量的预制工作,较低出餐时间,出餐快。适合快节奏的打工人群 3. 
因为出餐快,所以不用招人,省人力成本 -![IMG6992.png](https://files.rethink.run/userData/RroFuzYSd8NGoKRL5zrrkZ/3a4344ccd6ba477e59ddf1f7f67e98bd.png) +![IMG6992.png](https://files.rethink.run/userData/3a4344ccd6ba477e59ddf1f7f67e98bd.png) 更值得一提的是猪脚饭在广东便宜,其它地方贵,原因之一是可以从香港走私猪脚,因为外国人不吃,所以产能过剩 @@ -62,7 +62,7 @@ async def test_summary(self): text, code = await llm.knowledge.summary( llm_service=service, model=model.value, - query=md_source, + md=md_source, ) self.assertEqual(const.CodeEnum.OK, code, msg=text) print(f"{service.__class__.__name__} {model.name}\n{text}\n\n") @@ -80,7 +80,7 @@ async def test_extend(self): text, code = await llm.knowledge.extend( llm_service=service, model=model.value, - query=md_summary, + md=md_summary, ) self.assertEqual(const.CodeEnum.OK, code, msg=text) print(f"{service.__class__.__name__} {model.name}\n{text}\n\n") diff --git a/tests/test_core_local.py b/tests/test_core_local.py index 5da121f..755193f 100644 --- a/tests/test_core_local.py +++ b/tests/test_core_local.py @@ -6,7 +6,7 @@ from io import BytesIO from pathlib import Path from textwrap import dedent -from unittest.mock import patch +from unittest.mock import patch, AsyncMock import httpx from PIL import Image @@ -26,6 +26,7 @@ from . import utils +@patch("retk.core.ai.llm.knowledge._send", new_callable=AsyncMock, return_value=["", const.CodeEnum.OK]) class LocalModelsTest(unittest.IsolatedAsyncioTestCase): @classmethod def setUpClass(cls) -> None: @@ -54,7 +55,7 @@ async def asyncTearDown(self) -> None: shutil.rmtree(Path(__file__).parent / "temp" / const.settings.DOT_DATA / "files", ignore_errors=True) shutil.rmtree(Path(__file__).parent / "temp" / const.settings.DOT_DATA / "md", ignore_errors=True) - async def test_user(self): + async def test_user(self, mock_send): u, code = await core.user.get_by_email(email=const.DEFAULT_USER["email"]) self.assertEqual(const.CodeEnum.OK, code) self.assertEqual("rethink", u["nickname"]) @@ -113,7 +114,7 @@ async def test_user(self): await core.account.manager.delete_by_uid(uid=_uid) - async def test_node(self): + async def test_node(self, mock_send): node, code = await core.node.post( au=self.au, md="a" * (const.settings.MD_MAX_LENGTH + 1), type_=const.NodeTypeEnum.MARKDOWN.value ) @@ -212,7 +213,7 @@ async def test_node(self): self.assertEqual(2, len(nodes)) self.assertEqual(2, total) - async def test_parse_at(self): + async def test_parse_at(self, mock_send): nid1, _ = await core.node.post( au=self.au, md="c", type_=const.NodeTypeEnum.MARKDOWN.value, ) @@ -265,7 +266,7 @@ async def test_parse_at(self): self.assertEqual(const.CodeEnum.OK, code) self.assertEqual(0, len(n["fromNodeIds"])) - async def test_add_set(self): + async def test_add_set(self, mock_send): node, code = await core.node.post( au=self.au, md="title\ntext", type_=const.NodeTypeEnum.MARKDOWN.value ) @@ -278,7 +279,7 @@ async def test_add_set(self): self.assertEqual(const.CodeEnum.OK, code) self.assertEqual(1, len(node["toNodeIds"])) - async def test_cursor_text(self): + async def test_cursor_text(self, mock_send): n1, code = await core.node.post( au=self.au, md="title\ntext", type_=const.NodeTypeEnum.MARKDOWN.value ) @@ -322,7 +323,7 @@ async def test_cursor_text(self): self.assertEqual(3, total) self.assertEqual("Welcome to Rethink", recom[2].title) - async def test_to_trash(self): + async def test_to_trash(self, mock_send): n1, code = await core.node.post( au=self.au, md="title\ntext", type_=const.NodeTypeEnum.MARKDOWN.value ) @@ -350,7 +351,7 @@ async def test_to_trash(self): self.assertEqual(4, len(nodes)) 
self.assertEqual(4, total) - async def test_search(self): + async def test_search(self, mock_send): code = await core.recent.put_recent_search(au=self.au, query="a") self.assertEqual(const.CodeEnum.OK, code) await core.recent.put_recent_search(au=self.au, query="c") @@ -360,7 +361,7 @@ async def test_search(self): self.assertIsNotNone(doc) self.assertEqual(["b", "c", "a"], doc["lastState"]["recentSearch"]) - async def test_batch(self): + async def test_batch(self, mock_send): ns = [] for i in range(10): n, code = await core.node.post( @@ -393,19 +394,19 @@ async def test_batch(self): self.assertEqual(0, total) self.assertEqual(0, len(tns)) - async def test_files_upload_process(self): + async def test_files_upload_process(self, mock_send): now = datetime.datetime.now(tz=utc) - doc: ImportData = { - "_id": ObjectId(), - "uid": "xxx", - "process": 0, - "type": "text", - "startAt": now, - "running": True, - "obsidian": {}, - "msg": "", - "code": const.CodeEnum.OK.value, - } + doc = ImportData( + _id=ObjectId(), + uid="xxx", + process=0, + type="text", + startAt=now, + running=True, + obsidian={}, + msg="", + code=const.CodeEnum.OK.value, + ) res = await client.coll.import_data.insert_one(doc) self.assertTrue(res.acknowledged) @@ -423,7 +424,7 @@ async def test_files_upload_process(self): await client.coll.import_data.delete_one({"uid": "xxx"}) - async def test_update_title_and_from_nodes_updates(self): + async def test_update_title_and_from_nodes_updates(self, mock_send): n1, code = await core.node.post( au=self.au, md="title1\ntext", type_=const.NodeTypeEnum.MARKDOWN.value ) @@ -439,7 +440,7 @@ async def test_update_title_and_from_nodes_updates(self): self.assertEqual(const.CodeEnum.OK, code) self.assertEqual(f"title2\n[@title1Changed](/n/{n1['id']})", n2["md"]) - async def test_upload_image_vditor(self): + async def test_upload_image_vditor(self, mock_send): u, code = await core.user.get(self.au.u.id) self.assertEqual(const.CodeEnum.OK, code) used_space = u["usedSpace"] @@ -466,10 +467,8 @@ async def test_upload_image_vditor(self): u, code = await core.user.get(self.au.u.id) self.assertEqual(used_space + size, u["usedSpace"]) - @patch( - "retk.core.files.upload.httpx.AsyncClient.get", - ) - async def test_fetch_image_vditor(self, mock_get): + @patch("retk.core.files.upload.httpx.AsyncClient.get", ) + async def test_fetch_image_vditor(self, mock_get, mock_send): f = open(Path(__file__).parent / "temp" / "fake.png", "rb") mock_get.return_value = httpx.Response( 200, @@ -493,7 +492,7 @@ async def test_fetch_image_vditor(self, mock_get): self.assertEqual(used_space + f.tell(), u["usedSpace"]) f.close() - async def test_update_used_space(self): + async def test_update_used_space(self, mock_send): u, code = await core.user.get(self.au.u.id) base_used_space = u["usedSpace"] for delta, value in [ @@ -513,7 +512,7 @@ async def test_update_used_space(self): base_used_space = 0 self.assertAlmostEqual(value, now, msg=f"delta: {delta}, value: {value}") - async def test_node_version(self): + async def test_node_version(self, mock_send): node, code = await core.node.post( au=self.au, md="[title](/qqq)\nbody", type_=const.NodeTypeEnum.MARKDOWN.value ) @@ -534,7 +533,7 @@ async def test_node_version(self): self.assertEqual(const.CodeEnum.OK, code) self.assertEqual(2, len(list(hist_dir.glob("*.md")))) - async def test_md_history(self): + async def test_md_history(self, mock_send): bi = config.get_settings().MD_BACKUP_INTERVAL config.get_settings().MD_BACKUP_INTERVAL = 0.0001 n1, code = await core.node.post( 
@@ -571,14 +570,14 @@ async def test_md_history(self): config.get_settings().MD_BACKUP_INTERVAL = bi - async def test_get_version(self): + async def test_get_version(self, mock_send): v, code = await core.self_hosted.get_latest_pkg_version() self.assertEqual(const.CodeEnum.OK, code) self.assertEqual(3, len(v)) for num in v: self.assertTrue(isinstance(num, int)) - async def test_system_notice(self): + async def test_system_notice(self, mock_send): au = deepcopy(self.au) au.u.type = const.USER_TYPE.MANAGER.id publish_at = datetime.datetime.now() @@ -606,7 +605,7 @@ async def test_system_notice(self): docs, total = await core.notice.get_system_notices(0, 10) self.assertTrue(docs[0]["scheduled"]) - async def test_notice(self): + async def test_notice(self, mock_send): au = deepcopy(self.au) doc, code = await core.notice.post_in_manager_delivery( au=au, @@ -661,7 +660,7 @@ async def test_notice(self): self.assertFalse(sn[0]["read"]) self.assertIsNone(sn[0]["readTime"]) - async def test_mark_read(self): + async def test_mark_read(self, mock_send): au = deepcopy(self.au) au.u.type = const.USER_TYPE.MANAGER.id for i in range(3): diff --git a/tests/test_core_remote.py b/tests/test_core_remote.py index 5bdda29..559bbf3 100644 --- a/tests/test_core_remote.py +++ b/tests/test_core_remote.py @@ -3,7 +3,7 @@ import unittest from copy import deepcopy from textwrap import dedent -from unittest.mock import patch +from unittest.mock import patch, AsyncMock import elastic_transport import pymongo.errors @@ -20,6 +20,7 @@ from . import utils +@patch("retk.core.ai.llm.knowledge._send", new_callable=AsyncMock, return_value=["", const.CodeEnum.OK]) class RemoteModelsTest(unittest.IsolatedAsyncioTestCase): default_pwd = "rethink123" @@ -40,7 +41,8 @@ async def asyncSetUp(self) -> None: client.connection_timeout = 1 await client.init() for coll in client.coll.__dict__.values(): - await coll.delete_many({}) + if coll is not None: + await coll.delete_many({}) u, code = await signup( email=const.DEFAULT_USER["email"], @@ -85,7 +87,7 @@ async def asyncTearDown(self) -> None: utils.skip_no_connect.skip = True @utils.skip_no_connect - async def test_same_key(self): + async def test_same_key(self, mock_send): async def add(): oid = ObjectId() await client.coll.users.insert_one({ @@ -126,7 +128,7 @@ async def add(): await add() @utils.skip_no_connect - async def test_user(self): + async def test_user(self, mock_send): u, code = await core.user.get_by_email(email=const.DEFAULT_USER["email"]) self.assertEqual(const.CodeEnum.OK, code) self.assertEqual("rethink", u["nickname"]) @@ -187,6 +189,7 @@ async def test_user(self): @patch("retk.core.node.backup.__save_md_to_cos") async def test_node( self, + mock_send, mock_save_md_to_cos, mock_get_md_from_cos, mock_remove_md_from_cos, @@ -282,6 +285,7 @@ async def test_node( @patch("retk.core.node.backup.__save_md_to_cos") async def test_parse_at( self, + mock_send, mock_save_md_to_cos, mock_get_md_from_cos, mock_remove_md_from_cos, @@ -347,7 +351,7 @@ async def test_parse_at( self.assertEqual(0, len(n["fromNodeIds"])) @utils.skip_no_connect - async def test_add_set(self): + async def test_add_set(self, mock_send): node, code = await core.node.post( au=self.au, md="title\ntext", type_=const.NodeTypeEnum.MARKDOWN.value ) @@ -361,7 +365,7 @@ async def test_add_set(self): self.assertEqual(1, len(node["toNodeIds"])) @utils.skip_no_connect - async def test_to_trash(self): + async def test_to_trash(self, mock_send): n1, code = await core.node.post( au=self.au, md="title\ntext", 
type_=const.NodeTypeEnum.MARKDOWN.value ) @@ -399,10 +403,12 @@ async def test_to_trash(self): @patch("retk.core.node.backup.__save_md_to_cos") async def test_batch( self, + mock_send, mock_save_md_to_cos, mock_get_md_from_cos, mock_remove_md_from_cos, mock_remove_md_all_versions_from_cos, + ): mock_save_md_to_cos.return_value = const.CodeEnum.OK mock_get_md_from_cos.return_value = ("", const.CodeEnum.OK) @@ -446,7 +452,7 @@ async def test_batch( self.assertEqual(0, len(tns)) @utils.skip_no_connect - async def test_update_used_space(self): + async def test_update_used_space(self, mock_send): u, code = await core.user.get(self.au.u.id) base_used_space = u["usedSpace"] for delta, value in [ @@ -473,6 +479,7 @@ async def test_update_used_space(self): @patch("retk.core.node.backup.__save_md_to_cos") async def test_md_history( self, + mock_send, mock_save_md_to_cos, mock_get_md_from_cos, mock_remove_md_from_cos, @@ -520,7 +527,7 @@ async def test_md_history( config.get_settings().MD_BACKUP_INTERVAL = bi @utils.skip_no_connect - async def test_system_notice(self): + async def test_system_notice(self, mock_send): au = deepcopy(self.au) au.u.type = const.USER_TYPE.MANAGER.id publish_at = datetime.datetime.now() @@ -549,7 +556,7 @@ async def test_system_notice(self): self.assertTrue(docs[0]["scheduled"]) @utils.skip_no_connect - async def test_notice(self): + async def test_notice(self, mock_send): au = deepcopy(self.au) doc, code = await core.notice.post_in_manager_delivery( au=au, @@ -605,7 +612,7 @@ async def test_notice(self): self.assertIsNone(sn[0]["readTime"]) @utils.skip_no_connect - async def test_mark_read(self): + async def test_mark_read(self, mock_send): au = deepcopy(self.au) au.u.type = const.USER_TYPE.MANAGER.id for i in range(3):
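
Reviewer note (not part of the patch itself): the pieces added here fit together as follows. core.node.post and core.node.update_md call extend_on_node_post / extend_on_node_update, which enqueue one NodeExtendQueue item per node into llmExtendNodeQueue; the hourly scheduler job deliver_unscheduled_extend_nodes drains that queue through knowledge.summary and knowledge.extend and writes the results into llmExtendedNode. Below is a minimal sketch of one queue item and of the two LLM calls the job makes for it, using only names introduced by this patch; the uid / nid values and the process_item helper are illustrative placeholders, and error handling is omitted.

    from bson import ObjectId

    from retk.core.ai.llm import api, knowledge
    from retk.models.tps.llm import NodeExtendQueue

    # what extend_on_node_post writes into the llmExtendNodeQueue collection
    item: NodeExtendQueue = NodeExtendQueue(
        _id=ObjectId(),
        uid="uid-placeholder",
        nid="nid-placeholder",
        summaryService="tencent",  # key into knowledge.LLM_SERVICES
        summaryModel=api.TencentModelEnum.HUNYUAN_LITE.value,
        extendService="ali",
        extendModel=api.AliyunModelEnum.QWEN_PLUS.value,
    )


    # roughly what deliver_unscheduled_extend_nodes does for each queued item,
    # given the markdown of the referenced node
    async def process_item(queued: NodeExtendQueue, md: str, req_id: str):
        summary, _ = await knowledge.summary(
            llm_service=knowledge.LLM_SERVICES[queued["summaryService"]],
            model=queued["summaryModel"],
            md=md,
            req_id=req_id,
        )
        extended, _ = await knowledge.extend(
            llm_service=knowledge.LLM_SERVICES[queued["extendService"]],
            model=queued["extendModel"],
            md=md,
            req_id=req_id,
        )
        return summary, extended

Note that the API-key checks move in this patch from the service constructors into the per-request header/URL builders, which is what makes building knowledge.LLM_SERVICES at import time safe even when some provider keys are not configured.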