feat(app):
- add relative search for LLM
- improve stream JSON detector
MorvanZhou committed Jul 23, 2024
1 parent 781eb19 commit de1f998
Showing 19 changed files with 559 additions and 180 deletions.
4 changes: 4 additions & 0 deletions src/retk/controllers/ai/knowledge.py
@@ -22,6 +22,10 @@ async def get_extended_nodes(
sourceTitle=doc["sourceMd"].split("\n", 1)[0].strip(),
title=title.strip(),
content=content.strip(),
searchTerms=list(filter(
lambda x: x != "",
map(str.strip, doc.get("extendSearchTerms", "").split(","))
))[:3],
)
nodes.append(node)
return schemas.ai.GetExtendedNodesResponse(
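
The new searchTerms value comes from splitting the stored comma-separated extendSearchTerms string, trimming whitespace, dropping empty entries, and keeping at most three terms. A standalone sketch of that parsing, with an illustrative sample value (not from the repo):

# Sketch of the new searchTerms parsing in get_extended_nodes.
# The sample extendSearchTerms value below is illustrative only.
doc = {"extendSearchTerms": " vector db , retrieval, , prompt caching , embeddings"}

search_terms = list(filter(
    lambda x: x != "",
    map(str.strip, doc.get("extendSearchTerms", "").split(","))
))[:3]  # trim, drop empties, keep at most the first three terms

print(search_terms)  # ['vector db', 'retrieval', 'prompt caching']
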
1 change: 1 addition & 0 deletions src/retk/controllers/schemas/ai.py
@@ -10,6 +10,7 @@ class Node(BaseModel):
sourceTitle: str
title: str
content: str
searchTerms: List[str]

requestId: str
nodes: List[Node]
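
Together with the controller change above, the response now carries the parsed terms per node. A minimal sketch of the updated schema in use, assuming pydantic's BaseModel as the class signature suggests (constructor values are illustrative):

from typing import List

from pydantic import BaseModel

class Node(BaseModel):
    sourceTitle: str
    title: str
    content: str
    searchTerms: List[str]  # new in this commit

node = Node(
    sourceTitle="Source note",
    title="Extended idea",
    content="Generated content",
    searchTerms=["vector db", "retrieval", "prompt caching"],  # illustrative
)
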
35 changes: 31 additions & 4 deletions src/retk/core/ai/llm/api/aliyun.py
@@ -1,7 +1,7 @@
import asyncio
import json
from enum import Enum
from typing import Tuple, AsyncIterable, Optional, Dict, List
from typing import Tuple, AsyncIterable, Optional, Dict, List, Union, Callable

from retk import config, const
from retk.core.utils import ratelimiter
@@ -166,12 +166,13 @@ async def stream_complete(
txt += choice["message"]["content"]
yield txt.encode("utf-8"), code

async def batch_complete(
async def _batch_complete_union(
self,
messages: List[MessagesType],
func: Callable,
model: str = None,
req_id: str = None,
) -> List[Tuple[str, const.CodeEnum]]:
) -> List[Tuple[Union[str, Dict[str, str]], const.CodeEnum]]:
if model is None:
m = self.default_model
else:
@@ -180,11 +181,37 @@ async def batch_complete
rate_limiter = ratelimiter.RateLimiter(requests=m.RPM, period=60)

tasks = [
self._batch_complete(
func(
limiters=[concurrent_limiter, rate_limiter],
messages=m,
model=model,
req_id=req_id,
) for m in messages
]
return await asyncio.gather(*tasks)

async def batch_complete(
self,
messages: List[MessagesType],
model: str = None,
req_id: str = None,
) -> List[Tuple[str, const.CodeEnum]]:
return await self._batch_complete_union(
messages=messages,
func=self._batch_complete,
model=model,
req_id=req_id,
)

async def batch_complete_json_detect(
self,
messages: List[MessagesType],
model: str = None,
req_id: str = None,
) -> List[Tuple[Dict[str, str], const.CodeEnum]]:
return await self._batch_complete_union(
messages=messages,
func=self._batch_stream_complete_json_detect,
model=model,
req_id=req_id,
)
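
The refactor moves the shared limiter setup and asyncio.gather fan-out into _batch_complete_union; the two public methods differ only in the per-request coroutine passed as func. A usage sketch from the caller's side (the service instance and message contents are illustrative assumptions):

async def demo(service):
    # service: an instance of this provider class (assumed for illustration)
    messages = [
        [{"role": "user", "content": "Summarize note A"}],
        [{"role": "user", "content": "Extend note B"}],
    ]
    # Plain-text completions: List[Tuple[str, const.CodeEnum]]
    texts = await service.batch_complete(messages=messages, req_id="rid-1")
    # JSON-detecting completions: List[Tuple[Dict[str, str], const.CodeEnum]]
    objs = await service.batch_complete_json_detect(messages=messages, req_id="rid-1")
    return texts, objs
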
35 changes: 31 additions & 4 deletions src/retk/core/ai/llm/api/baidu.py
@@ -2,7 +2,7 @@
import json
from datetime import datetime
from enum import Enum
from typing import Tuple, AsyncIterable, List, Dict
from typing import Tuple, AsyncIterable, List, Dict, Union, Callable

import httpx

@@ -201,24 +201,51 @@ async def stream_complete(
txt += json_data["result"]
yield txt.encode("utf-8"), code

async def batch_complete(
async def _batch_complete_union(
self,
messages: List[MessagesType],
func: Callable,
model: str = None,
req_id: str = None,
) -> List[Tuple[str, const.CodeEnum]]:
) -> List[Tuple[Union[str, Dict[str, str]], const.CodeEnum]]:
if model is None:
m = self.default_model
else:
m = _key2model[model].value
limiter = ratelimiter.RateLimiter(requests=m.RPM, period=60)

tasks = [
self._batch_complete(
func(
limiters=[limiter],
messages=m,
model=model,
req_id=req_id,
) for m in messages
]
return await asyncio.gather(*tasks)

async def batch_complete(
self,
messages: List[MessagesType],
model: str = None,
req_id: str = None,
) -> List[Tuple[str, const.CodeEnum]]:
return await self._batch_complete_union(
messages=messages,
func=self._batch_complete,
model=model,
req_id=req_id,
)

async def batch_complete_json_detect(
self,
messages: List[MessagesType],
model: str = None,
req_id: str = None,
) -> List[Tuple[Dict[str, str], const.CodeEnum]]:
return await self._batch_complete_union(
messages=messages,
func=self._batch_stream_complete_json_detect,
model=model,
req_id=req_id,
)
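
A single RateLimiter(requests=m.RPM, period=60) instance is shared by every task in the gather, so the whole batch stays inside the model's RPM budget. The limiter's internals are not part of this diff; the following is a minimal sliding-window sketch (an assumption, not retk's implementation) of the async context-manager shape the code relies on:

import asyncio
import time
from typing import List

class RateLimiter:
    """Sliding-window limiter: at most `requests` entries per `period` seconds."""

    def __init__(self, requests: int, period: float):
        self.requests = requests
        self.period = period
        self._stamps: List[float] = []
        self._lock = asyncio.Lock()

    async def __aenter__(self):
        while True:
            async with self._lock:
                now = time.monotonic()
                # keep only timestamps still inside the window
                self._stamps = [t for t in self._stamps if now - t < self.period]
                if len(self._stamps) < self.requests:
                    self._stamps.append(now)
                    return self
                wait = self.period - (now - self._stamps[0])
            await asyncio.sleep(max(wait, 0.01))

    async def __aexit__(self, exc_type, exc, tb):
        return False
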
70 changes: 68 additions & 2 deletions src/retk/core/ai/llm/api/base.py
@@ -1,12 +1,15 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Dict, Literal, AsyncIterable, Tuple, Optional, Union
from typing import (
List, Dict, Literal, AsyncIterable, Tuple, Optional, Union
)

import httpx

from retk import const
from retk.core.utils import ratelimiter
from retk.logger import logger
from ..utils import parse_json_pattern

MessagesType = List[Dict[Literal["role", "content"], str]]

@@ -95,12 +98,13 @@ async def _stream_complete(
url: str,
headers: Dict[str, str],
payload: bytes,
method: str = "POST",
params: Dict[str, str] = None,
req_id: str = None,
) -> AsyncIterable[Tuple[bytes, const.CodeEnum]]:
client = httpx.AsyncClient()
async with client.stream(
method="POST",
method=method,
url=url,
headers=headers,
content=payload,
@@ -141,6 +145,59 @@ async def _batch_complete(
else:
raise ValueError("Invalid number of limiters, should be between 1 and 4")

async def _batch_stream_complete_json_detect(
self,
limiters: List[Union[ratelimiter.RateLimiter, ratelimiter.ConcurrentLimiter]],
messages: MessagesType,
model: str = None,
req_id: str = None,
) -> Tuple[Optional[Dict[str, str]], const.CodeEnum]:
if len(limiters) == 4:
async with limiters[0], limiters[1], limiters[2], limiters[3]:
return await self.stream_complete_json_detect(messages=messages, model=model, req_id=req_id)
elif len(limiters) == 3:
async with limiters[0], limiters[1], limiters[2]:
return await self.stream_complete_json_detect(messages=messages, model=model, req_id=req_id)
elif len(limiters) == 2:
async with limiters[0], limiters[1]:
return await self.stream_complete_json_detect(messages=messages, model=model, req_id=req_id)
elif len(limiters) == 1:
async with limiters[0]:
return await self.stream_complete_json_detect(messages=messages, model=model, req_id=req_id)
else:
raise ValueError("Invalid number of limiters, should be between 1 and 4")
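
The if/elif chain supports one to four limiters by unpacking them into a single async with. For reference, contextlib.AsyncExitStack expresses the same acquisition for any number of limiters; a sketch of that alternative (not what this commit does):

from contextlib import AsyncExitStack
from typing import Awaitable, Callable, Sequence

async def with_limiters(limiters: Sequence, run: Callable[[], Awaitable]):
    # Enter every limiter as an async context manager, then run the request.
    async with AsyncExitStack() as stack:
        for limiter in limiters:
            await stack.enter_async_context(limiter)
        return await run()
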

async def stream_complete_json_detect(
self,
messages: MessagesType,
model: str = None,
req_id: str = None,
) -> Tuple[Dict[str, str], const.CodeEnum]:
chunks: List[bytes] = []
chunks_append = chunks.append

async for b, code in self.stream_complete(
messages=messages,
model=model,
req_id=req_id,
):
if code != const.CodeEnum.OK:
logger.error(f"rid='{req_id}' | Model error: {code}")
return {}, code

chunks_append(b)
if b"}" in b:
text_bytes = b"".join(chunks)
text = text_bytes.decode("utf-8")
try:
d = parse_json_pattern(text)
return d, const.CodeEnum.OK
except ValueError:
continue
oneline = (b"".join(chunks).decode("utf-8")).replace("\n", "\\n")
logger.error(f"rid='{req_id}' | {self.__class__.__name__} {model} | error: No JSON pattern found | {oneline}")
return {}, const.CodeEnum.LLM_INVALID_RESPONSE_FORMAT
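
stream_complete_json_detect accumulates streamed chunks and only attempts a parse once a chunk contains "}", a cheap signal that a JSON object may now be complete; on a failed parse it keeps streaming. parse_json_pattern is imported from ..utils and not shown in this diff; a plausible minimal sketch (an assumption, not the actual retk helper) that matches the ValueError contract used above:

import json
import re

_JSON_RE = re.compile(r"\{.*\}", re.DOTALL)

def parse_json_pattern(text: str) -> dict:
    # Find the outermost {...} span and decode it; raise ValueError
    # (which the caller catches to keep streaming) when nothing parses.
    match = _JSON_RE.search(text)
    if match is None:
        raise ValueError("no JSON pattern found")
    try:
        return json.loads(match.group(0))
    except json.JSONDecodeError as e:
        raise ValueError(f"invalid JSON: {e}") from e
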

@abstractmethod
async def stream_complete(
self,
@@ -158,3 +215,12 @@ async def batch_complete(
req_id: str = None,
) -> List[Tuple[str, const.CodeEnum]]:
...

@abstractmethod
async def batch_complete_json_detect(
self,
messages: List[MessagesType],
model: str = None,
req_id: str = None,
) -> List[Tuple[Dict[str, str], const.CodeEnum]]:
...
39 changes: 31 additions & 8 deletions src/retk/core/ai/llm/api/moonshot.py
@@ -1,6 +1,6 @@
import asyncio
from enum import Enum
from typing import List, Tuple
from typing import List, Tuple, Callable, Union, Dict

from retk import config, const
from retk.core.utils import ratelimiter
@@ -46,25 +46,48 @@ def get_api_key():
return config.get_settings().MOONSHOT_API_KEY

@staticmethod
def get_concurrency():
return config.get_settings().MOONSHOT_CONCURRENCY

async def batch_complete(
self,
async def _batch_complete_union(
messages: List[MessagesType],
func: Callable,
model: str = None,
req_id: str = None,
) -> List[Tuple[str, const.CodeEnum]]:
) -> List[Tuple[Union[str, Dict[str, str]], const.CodeEnum]]:
settings = config.get_settings()
rate_limiter = ratelimiter.RateLimiter(requests=settings.MOONSHOT_RPM, period=60)
concurrent_limiter = ratelimiter.ConcurrentLimiter(n=settings.MOONSHOT_CONCURRENCY)

tasks = [
self._batch_complete(
func(
limiters=[concurrent_limiter, rate_limiter],
messages=m,
model=model,
req_id=req_id,
) for m in messages
]
return await asyncio.gather(*tasks)

async def batch_complete(
self,
messages: List[MessagesType],
model: str = None,
req_id: str = None,
) -> List[Tuple[str, const.CodeEnum]]:
return await self._batch_complete_union(
messages=messages,
func=self._batch_complete,
model=model,
req_id=req_id,
)

async def batch_complete_json_detect(
self,
messages: List[MessagesType],
model: str = None,
req_id: str = None,
) -> List[Tuple[Dict[str, str], const.CodeEnum]]:
return await self._batch_complete_union(
messages=messages,
func=self._batch_stream_complete_json_detect,
model=model,
req_id=req_id,
)
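
Moonshot pairs a ConcurrentLimiter over MOONSHOT_CONCURRENCY with a RateLimiter over MOONSHOT_RPM, reading the settings directly now that the get_concurrency helper appears to be removed; the retained @staticmethod decorator suggests _batch_complete_union takes no self in this file. ConcurrentLimiter is likewise not shown in this diff; a minimal sketch (an assumption about retk.core.utils.ratelimiter), essentially an async context-manager view of a semaphore:

import asyncio

class ConcurrentLimiter:
    def __init__(self, n: int):
        self._sem = asyncio.Semaphore(n)  # at most n requests in flight

    async def __aenter__(self):
        await self._sem.acquire()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        self._sem.release()
        return False
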
(The remaining 13 changed files are not shown.)