Skip to content

Commit

Permalink
✨ Enhance chain composition logic
Browse files Browse the repository at this point in the history
  • Loading branch information
shroominic committed Feb 15, 2024
1 parent 3b625fe commit c52a34b
Showing 1 changed file with 34 additions and 5 deletions.
39 changes: 34 additions & 5 deletions src/funcchain/backend/compiler.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Any, TypeVar
from operator import itemgetter
from typing import Annotated, Any, TypeVar, get_args, get_origin

from langchain_core.callbacks import Callbacks
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.output_parsers import BaseGenerationOutputParser, BaseOutputParser
from langchain_core.runnables import Runnable
from langchain_core.runnables import Runnable, RunnableParallel
from pydantic import BaseModel

from ..model.abilities import is_openai_function_model, is_vision_model
Expand All @@ -22,6 +23,7 @@
from ..schema.signature import Signature
from ..syntax.input_types import Image
from ..syntax.output_types import ParserBaseModel
from ..syntax.params import Depends
from ..utils.msg_tools import msg_to_str
from ..utils.pydantic import multi_pydantic_to_functions, pydantic_to_functions
from ..utils.token_counter import count_tokens
Expand All @@ -44,6 +46,7 @@ def create_union_chain(
memory: BaseChatMessageHistory,
context: list[BaseMessage],
llm: BaseChatModel,
leading_runnable: Runnable[dict[str, Any], Any],
input_kwargs: dict[str, Any],
) -> Runnable[dict[str, Any], Any]:
"""
Expand Down Expand Up @@ -72,7 +75,12 @@ def create_union_chain(
memory=memory,
)

return prompt | llm | RetryOpenAIFunctionPydanticUnionParser(output_types=output_types, retry=3, retry_llm=_llm)
return (
leading_runnable
| prompt
| llm
| RetryOpenAIFunctionPydanticUnionParser(output_types=output_types, retry=3, retry_llm=_llm)
)


def patch_openai_function_to_pydantic(
Expand Down Expand Up @@ -122,16 +130,36 @@ def create_chain(
# handle input arguments
prompt_args: list[str] = []
pydantic_args: list[str] = []
annotated_args: list[tuple[str, type]] = []
special_args: list[tuple[str, type]] = []

for i in input_args:
if i[1] is str:
prompt_args.append(i[0])
if issubclass(i[1], BaseModel):
elif get_origin(i[1]) is Annotated:
annotated_args.append(i)
elif issubclass(i[1], BaseModel):
pydantic_args.append(i[0])
else:
special_args.append(i)

dependencies: list[tuple[str, Depends]] = []

for arg_name, arg_type in annotated_args:
dependencies.append((arg_name, get_args(arg_type)[1:][0]))
if get_args(arg_type)[0] is str:
prompt_args.append(arg_name)

if issubclass(get_args(arg_type)[0], BaseModel):
pydantic_args.append(arg_name)

leading_runnable: Runnable[Any, Any] = RunnableParallel(
{
**{name: itemgetter(name) for name in (prompt_args + pydantic_args)},
**{name: dep.dependency for name, dep in dependencies}, # type: ignore
}
)

# TODO: change this into input_args
input_kwargs = {k: "" for k in (prompt_args + pydantic_args)}

Expand Down Expand Up @@ -185,6 +213,7 @@ def create_chain(
memory,
context,
llm,
leading_runnable,
input_kwargs,
)
if isinstance(parser, RetryJsonPydanticParser) or isinstance(parser, RetryJsonPrimitiveTypeParser):
Expand All @@ -206,7 +235,7 @@ def create_chain(
...

assert parser is not None
return chat_prompt | llm | parser
return leading_runnable | chat_prompt | llm | parser


def compile_chain(signature: Signature, temp_images: list[Image] = []) -> Runnable[dict[str, Any], ChainOutput]:
Expand Down

0 comments on commit c52a34b

Please sign in to comment.