Skip to content

Commit

Permalink
feat(app):
Browse files Browse the repository at this point in the history
- TypedDict optimize
- add extend_on_node_update / create
  • Loading branch information
MorvanZhou committed Jun 26, 2024
1 parent 19e5aed commit c0e0be6
Show file tree
Hide file tree
Showing 25 changed files with 390 additions and 182 deletions.
1 change: 1 addition & 0 deletions src/retk/core/ai/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import llm
11 changes: 6 additions & 5 deletions src/retk/core/ai/llm/api/aliyun.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@ def __init__(
timeout=timeout,
default_model=AliyunModelEnum.QWEN1_5_05B.value,
)
self.api_key = config.get_settings().ALIYUN_DASHSCOPE_API_KEY
if self.api_key == "":
raise NoAPIKeyError("Aliyun API key is empty")

def get_headers(self, stream: bool) -> Dict[str, str]:
@staticmethod
def get_headers(stream: bool) -> Dict[str, str]:
k = config.get_settings().ALIYUN_DASHSCOPE_API_KEY
if k == "":
raise NoAPIKeyError("Aliyun API key is empty")
h = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}',
'Authorization': f'Bearer {k}',
}
if stream:
h["Accept"] = "text/event-stream"
Expand Down
14 changes: 7 additions & 7 deletions src/retk/core/ai/llm/api/baidu.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,6 @@ def __init__(
timeout=timeout,
default_model=BaiduModelEnum.ERNIE_SPEED_8K.value,
)
self.api_key = config.get_settings().BAIDU_QIANFAN_API_KEY
self.secret_key = config.get_settings().BAIDU_QIANFAN_SECRET_KEY
if self.api_key == "" or self.secret_key == "":
raise NoAPIKeyError("Baidu api key or key is empty")

self.headers = {
"Content-Type": "application/json",
}
Expand All @@ -46,16 +41,21 @@ def __init__(
self.token = ""

async def set_token(self, req_id: str = None):
_s = config.get_settings()
if _s.BAIDU_QIANFAN_API_KEY == "" or _s.BAIDU_QIANFAN_SECRET_KEY == "":
raise NoAPIKeyError("Baidu api key or skey is empty")

if self.token_expires_at > datetime.now().timestamp():
return

resp = await httpx_helper.get_async_client().post(
url="https://aip.baidubce.com/oauth/2.0/token",
headers={"Content-Type": "application/json", 'Accept': 'application/json'},
content=b"",
params={
"grant_type": "client_credentials",
"client_id": self.api_key,
"client_secret": self.secret_key,
"client_id": _s.BAIDU_QIANFAN_API_KEY,
"client_secret": _s.BAIDU_QIANFAN_SECRET_KEY,
}
)
if resp.status_code != 200:
Expand Down
5 changes: 4 additions & 1 deletion src/retk/core/ai/llm/api/moonshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@ def __init__(
timeout: float = 60.,
):
super().__init__(
api_key=config.get_settings().MOONSHOT_API_KEY,
endpoint="https://api.moonshot.cn/v1/chat/completions",
default_model=MoonshotModelEnum.V1_8K.value,
top_p=top_p,
temperature=temperature,
timeout=timeout,
)

@staticmethod
def get_api_key():
    """Read the Moonshot API key from the current runtime settings."""
    settings = config.get_settings()
    return settings.MOONSHOT_API_KEY
31 changes: 21 additions & 10 deletions src/retk/core/ai/llm/api/openai.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from abc import ABC, abstractmethod
from enum import Enum
from typing import Tuple, AsyncIterable, Optional

Expand All @@ -16,10 +17,9 @@ class OpenaiModelEnum(str, Enum):
GPT35_TURBO_16K = "gpt-3.5-turbo-16k"


class OpenaiLLMStyle(BaseLLMService):
class OpenaiLLMStyle(BaseLLMService, ABC):
def __init__(
self,
api_key: str,
endpoint: str,
default_model: str,
top_p: float = 0.9,
Expand All @@ -33,12 +33,20 @@ def __init__(
timeout=timeout,
default_model=default_model,
)
self.api_key = api_key
if self.api_key == "":
raise NoAPIKeyError(f"{self.__class__.__name__} API key is empty")
self.headers = {

@staticmethod
@abstractmethod
def get_api_key():
    """Return this provider's API key string; an empty string means unset."""
    pass

def get_headers(self):
    """Build the JSON request headers, injecting the bearer token.

    Raises:
        NoAPIKeyError: when the configured API key is empty.
    """
    api_key = self.get_api_key()
    if api_key == "":
        raise NoAPIKeyError(f"{self.__class__.__name__} api key is empty")

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    return headers

def get_payload(self, model: Optional[str], messages: MessagesType, stream: bool) -> bytes:
Expand All @@ -62,7 +70,7 @@ async def complete(
payload = self.get_payload(model, messages, stream=False)
rj, code = await self._complete(
url=self.endpoint,
headers=self.headers,
headers=self.get_headers(),
payload=payload,
req_id=req_id,
)
Expand All @@ -82,7 +90,7 @@ async def stream_complete(
payload = self.get_payload(model, messages, stream=True)
async for b, code in self._stream_complete(
url=self.endpoint,
headers=self.headers,
headers=self.get_headers(),
payload=payload,
req_id=req_id
):
Expand Down Expand Up @@ -118,10 +126,13 @@ def __init__(
timeout: float = 60.,
):
super().__init__(
api_key=config.get_settings().OPENAI_API_KEY,
endpoint="https://api.openai.com/v1/chat/completions",
default_model=OpenaiModelEnum.GPT35_TURBO.value,
top_p=top_p,
temperature=temperature,
timeout=timeout,
)

@staticmethod
def get_api_key():
    """Read the OpenAI API key from the current runtime settings."""
    settings = config.get_settings()
    return settings.OPENAI_API_KEY
12 changes: 6 additions & 6 deletions src/retk/core/ai/llm/api/tencent.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ def __init__(
timeout=timeout,
default_model=TencentModelEnum.HUNYUAN_LITE.value,
)
self.secret_id = config.get_settings().HUNYUAN_SECRET_ID
self.secret_key = config.get_settings().HUNYUAN_SECRET_KEY
if self.secret_id == "" or self.secret_key == "":
raise NoAPIKeyError("Tencent secret id or key is empty")

def get_auth(self, action: str, payload: bytes, timestamp: int, content_type: str) -> str:
_s = config.get_settings()
if _s.HUNYUAN_SECRET_KEY == "" or _s.HUNYUAN_SECRET_ID == "":
raise NoAPIKeyError("Tencent secret id or key is empty")

algorithm = "TC3-HMAC-SHA256"
date = datetime.utcfromtimestamp(timestamp).strftime("%Y-%m-%d")

Expand All @@ -78,14 +78,14 @@ def get_auth(self, action: str, payload: bytes, timestamp: int, content_type: st
string_to_sign = f"{algorithm}\n{timestamp}\n{credential_scope}\n{hashed_canonical_request}"

# ************* 步骤 3:计算签名 *************
secret_date = sign(f"TC3{self.secret_key}".encode("utf-8"), date)
secret_date = sign(f"TC3{_s.HUNYUAN_SECRET_KEY}".encode("utf-8"), date)
secret_service = sign(secret_date, self.service)
secret_signing = sign(secret_service, "tc3_request")
signature = hmac.new(secret_signing, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest()

# ************* 步骤 4:拼接 Authorization *************
authorization = f"{algorithm}" \
f" Credential={self.secret_id}/{credential_scope}," \
f" Credential={_s.HUNYUAN_SECRET_ID}/{credential_scope}," \
f" SignedHeaders={signed_headers}," \
f" Signature={signature}"
return authorization
Expand Down
16 changes: 7 additions & 9 deletions src/retk/core/ai/llm/api/xfyun.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,24 +45,22 @@ def __init__(
timeout=timeout,
default_model=XfYunModelEnum.SPARK_LITE.value,
)
_s = config.get_settings()
self.api_secret = _s.XFYUN_API_SECRET
self.api_key = _s.XFYUN_API_KEY
self.app_id = _s.XFYUN_APP_ID
if self.api_secret == "" or self.api_key == "" or self.app_id == "":
raise NoAPIKeyError("XfYun api secret or key is empty")

def get_url(self, model: Optional[str], req_id: str = None) -> str:
_s = config.get_settings()
if _s.XFYUN_API_KEY == "" or _s.XFYUN_API_SECRET == "" or _s.XFYUN_APP_ID == "":
raise NoAPIKeyError("XfYun api secret or skey or appID is empty")

if model is None:
model = self.default_model
cur_time = datetime.now()
date = handlers.format_date_time(mktime(cur_time.timetuple()))

tmp = f"host: spark-api.xf-yun.com\ndate: {date}\nGET /{model}/chat HTTP/1.1"
tmp_sha = hmac.new(self.api_secret.encode('utf-8'), tmp.encode('utf-8'), digestmod=hashlib.sha256).digest()
tmp_sha = hmac.new(_s.XFYUN_API_SECRET.encode('utf-8'), tmp.encode('utf-8'), digestmod=hashlib.sha256).digest()

signature = base64.b64encode(tmp_sha).decode(encoding='utf-8')
authorization_origin = f'api_key="{self.api_key}", ' \
authorization_origin = f'api_key="{_s.XFYUN_API_KEY}", ' \
f'algorithm="hmac-sha256", ' \
f'headers="host date request-line", ' \
f'signature="{signature}"'
Expand All @@ -79,7 +77,7 @@ def get_url(self, model: Optional[str], req_id: str = None) -> str:
def get_data(self, model: str, messages: MessagesType) -> Dict:
return {
"header": {
"app_id": self.app_id,
"app_id": config.get_settings().XFYUN_APP_ID,
"uid": "12345"
},
"parameter": {
Expand Down
13 changes: 7 additions & 6 deletions src/retk/core/ai/llm/knowledge/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Tuple

from retk import const
from .db_ops import extend_on_node_update, extend_on_node_post, LLM_SERVICES
from ..api.base import BaseLLMService, MessagesType

system_summary_prompt = (Path(__file__).parent / "system_summary.md").read_text(encoding="utf-8")
async def _send(
    llm_service: BaseLLMService,
    model: str,
    system_prompt: str,
    md: str,
    req_id: str,
) -> Tuple[str, const.CodeEnum]:
    """Send a system prompt plus the user's markdown to the LLM and await the reply."""
    messages: MessagesType = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": md},
    ]
    return await llm_service.complete(messages=messages, model=model, req_id=req_id)


async def summary(
    llm_service: BaseLLMService,
    model: str,
    md: str,
    req_id: str = None,
) -> Tuple[str, const.CodeEnum]:
    """Summarize markdown content using the summary system prompt."""
    return await _send(llm_service, model, system_summary_prompt, md, req_id)


async def extend(
    llm_service: BaseLLMService,
    model: str,
    md: str,
    req_id: str = None,
) -> Tuple[str, const.CodeEnum]:
    """Extend markdown content using the extend system prompt."""
    return await _send(llm_service, model, system_extend_prompt, md, req_id)
67 changes: 67 additions & 0 deletions src/retk/core/ai/llm/knowledge/db_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from datetime import timedelta

from bson import ObjectId

from retk.models.client import client
from retk.models.tps.llm import NodeExtendQueue
from retk.models.tps.node import Node
from .. import api

# Shared sampling and timeout defaults applied to every LLM backend below.
TOP_P = 0.9
TEMPERATURE = 0.5
TIMEOUT = 60  # request timeout passed to each service — presumably seconds; confirm against BaseLLMService

# Registry of configured LLM services, keyed by short provider name.
# These are instantiated at import time; API keys are only checked lazily
# on request (see the per-service get_headers/get_api_key methods).
LLM_SERVICES = {
    "tencent": api.TencentService(top_p=TOP_P, temperature=TEMPERATURE, timeout=TIMEOUT),
    "ali": api.AliyunService(top_p=TOP_P, temperature=TEMPERATURE, timeout=TIMEOUT),
    "openai": api.OpenaiService(top_p=TOP_P, temperature=TEMPERATURE, timeout=TIMEOUT),
    "moonshot": api.MoonshotService(top_p=TOP_P, temperature=TEMPERATURE, timeout=TIMEOUT),
    "xf": api.XfYunService(top_p=TOP_P, temperature=TEMPERATURE, timeout=TIMEOUT),
    "baidu": api.BaiduService(top_p=TOP_P, temperature=TEMPERATURE, timeout=TIMEOUT),
}


async def extend_on_node_post(data: Node):
    """Queue a node for LLM knowledge extension.

    Deduplicates per node (re-queuing an already-queued node renews its
    position) and caps the per-user queue at 5 entries, evicting the oldest.

    Args:
        data: the node document; reads data["uid"] and data["id"].
    """
    q: NodeExtendQueue = NodeExtendQueue(
        _id=ObjectId(),
        uid=data["uid"],
        nid=data["id"],
        summaryService="tencent",
        summaryModel=api.TencentModelEnum.HUNYUAN_LITE.value,
        extendService="ali",
        extendModel=api.AliyunModelEnum.QWEN_PLUS.value,
    )

    # sort by _id desc: ObjectId embeds the creation timestamp, so this is newest-first
    docs = await client.coll.llm_extend_node_queue.find(
        filter={"uid": data["uid"]}
    ).sort("_id", -1).to_list(None)

    for doc in docs:
        if doc["nid"] == data["id"]:
            # Node already queued: renew its creation time. _id is immutable in
            # MongoDB (and a bare {"_id": ...} update document without a $
            # operator is rejected by the driver), so replace the document
            # with one carrying a fresh ObjectId instead.
            # NOTE: original code targeted llm_extend_knowledge_queue here,
            # inconsistent with the llm_extend_node_queue used everywhere else.
            await client.coll.llm_extend_node_queue.delete_one({"_id": doc["_id"]})
            await client.coll.llm_extend_node_queue.insert_one(q)
            return

    max_keep = 5
    if len(docs) >= max_keep:
        # evict the oldest so that after the insert below exactly max_keep remain
        stale_ids = [doc["_id"] for doc in docs[max_keep - 1:]]
        await client.coll.llm_extend_node_queue.delete_many(
            {"_id": {"$in": stale_ids}}
        )

    await client.coll.llm_extend_node_queue.insert_one(q)


async def extend_on_node_update(old_data: Node, new_data: Node):
    """Re-queue a node for extension on update, throttling frequent edits."""
    elapsed = new_data["modifiedAt"] - old_data["modifiedAt"]
    # skip updates landing within 60s of the previous one
    if elapsed < timedelta(seconds=60):
        return

    await extend_on_node_post(new_data)
22 changes: 11 additions & 11 deletions src/retk/core/files/importing/async_tasks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@ async def check_last_task_finished(uid: str, type_: str) -> Tuple[Optional[Impor
return None, False

if doc is None:
doc: ImportData = {
"_id": ObjectId(),
"uid": uid,
"process": 0,
"type": type_,
"startAt": datetime.datetime.now(tz=utc),
"running": True,
"msg": "",
"code": 0,
"obsidian": {},
}
doc: ImportData = ImportData(
_id=ObjectId(),
uid=uid,
process=0,
type=type_,
startAt=datetime.datetime.now(tz=utc),
running=True,
msg="",
code=0,
obsidian={},
)
res = await client.coll.import_data.insert_one(doc)
if not res.acknowledged:
await set_running_false(
Expand Down
Loading

0 comments on commit c0e0be6

Please sign in to comment.