Commit 645a7ca: 🚀 ruff fmt
shroominic committed Nov 13, 2023
1 parent 7dedc2f
Showing 11 changed files with 129 additions and 54 deletions.
27 changes: 11 additions & 16 deletions .pre-commit-config.yaml
```diff
@@ -1,28 +1,23 @@
 repos:
+
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v2.3.0
+    rev: v4.5.0
     hooks:
       - id: check-yaml
       - id: end-of-file-fixer
       - id: trailing-whitespace
-  - repo: https://github.com/psf/black
-    rev: 23.7.0
-    hooks:
-      - id: black
-        args: [--line-length=120]
+
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.4.1
+    rev: v1.7.0
     hooks:
       - id: mypy
         args: [--ignore-missing-imports, --follow-imports=skip]
         additional_dependencies: [types-requests]
-  - repo: https://github.com/pre-commit/mirrors-isort
-    rev: v5.9.3
-    hooks:
-      - id: isort
-        args: [--profile=black]
-  - repo: https://github.com/PYCQA/flake8
-    rev: 6.1.0
+
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.1.4
     hooks:
-      - id: flake8
-        args: [--max-line-length=120, "--ignore=F401,F403,W503"]
+      - id: ruff
+        args: [ --fix ]
+      - id: ruff-format
+        types_or: [ python, pyi, jupyter ]
```
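Net effect of this file: Black, isort, and flake8 are replaced by a single pair of ruff hooks, alongside routine rev bumps for pre-commit-hooks and mypy. As a rough reading (not spelled out in the commit itself): the `ruff` hook lints with autofixes enabled via `--fix`, covering the flake8-style checks and, if the project enables ruff's `I` rules, isort's import sorting too, while `ruff-format` is ruff's Black-compatible formatter, applied here to `python`, `pyi`, and `jupyter` files through `types_or`. A single `pre-commit run --all-files` after this change would rewrap code across the repository, which matches the purely cosmetic diffs in the Python files below.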
4 changes: 3 additions & 1 deletion examples/email_answering.py
```diff
@@ -20,7 +20,9 @@ def get_emails_from_inbox() -> List[Tuple[str, str]]:
     """

     # Run AppleScript and collect output
-    process = subprocess.Popen(["osascript", "-e", apple_script], stdout=subprocess.PIPE)
+    process = subprocess.Popen(
+        ["osascript", "-e", apple_script], stdout=subprocess.PIPE
+    )
     out, _ = process.communicate()
     raw_output = out.decode("utf-8").strip()
```
8 changes: 6 additions & 2 deletions examples/generate_pp_and_tos.py
```diff
@@ -62,14 +62,18 @@ def generate_pp(answered_questions: list[str]) -> str:

 if __name__ == "__main__":
-    print("Please answer the following questions to generate a Terms of Service and Privacy Policy.")
+    print(
+        "Please answer the following questions to generate a Terms of Service and Privacy Policy."
+    )
     print("To skip a question, press enter without typing anything.")

     legal_questions = example_legal_questions.copy()
     # or from scratch using generate_legal_questions()

     for i, question in enumerate(legal_questions):
-        answer = input(f"{i+1}/{len(legal_questions)}: {question} ") or "No answer provided."
+        answer = (
+            input(f"{i+1}/{len(legal_questions)}: {question} ") or "No answer provided."
+        )

         legal_questions[i] = f"Q: {question}\nA: {answer}\n"
```
42 changes: 32 additions & 10 deletions funcchain/chain.py
```diff
@@ -42,13 +42,19 @@ def create_union_chain(
     output_types = output_type.__args__  # type: ignore
     output_type_names = [t.__name__ for t in output_types]

-    input_kwargs["format_instructions"] = f"Extract to one of these output types: {output_type_names}."
+    input_kwargs[
+        "format_instructions"
+    ] = f"Extract to one of these output types: {output_type_names}."

     functions = multi_pydantic_to_functions(output_types)

     if isinstance(LLM, RunnableWithFallbacks):
         LLM = LLM.runnable.bind(**functions).with_fallbacks(
-            [fallback.bind(**functions) for fallback in LLM.fallbacks if hasattr(LLM, "fallbacks")]
+            [
+                fallback.bind(**functions)
+                for fallback in LLM.fallbacks
+                if hasattr(LLM, "fallbacks")
+            ]
         )
     else:
         LLM = LLM.bind(**functions)  # type: ignore
@@ -78,7 +84,11 @@ def create_pydanctic_chain(
     functions = pydantic_to_functions(output_type)
     LLM = (
         LLM.runnable.bind(**functions).with_fallbacks(  # type: ignore
-            [fallback.bind(**functions) for fallback in LLM.fallbacks if hasattr(LLM, "fallbacks")]
+            [
+                fallback.bind(**functions)
+                for fallback in LLM.fallbacks
+                if hasattr(LLM, "fallbacks")
+            ]
         )
         if isinstance(LLM, RunnableWithFallbacks)
         else LLM.bind(**functions)
@@ -106,17 +116,25 @@ def create_chain(

     images = [v for v in input_kwargs.values() if isinstance(v, Image.Image)]
     if is_vision_model(LLM):
-        input_kwargs = {k: v for k, v in input_kwargs.items() if not isinstance(v, Image.Image)}
+        input_kwargs = {
+            k: v for k, v in input_kwargs.items() if not isinstance(v, Image.Image)
+        }
     elif images:
         raise RuntimeError("Images as input are only supported for vision models.")

     prompt = create_prompt(instruction, system, context, images=images, **input_kwargs)

     if func_model:
-        if getattr(output_type, "__origin__", None) is Union or isinstance(output_type, UnionType):
-            return create_union_chain(output_type, instruction, system, context, LLM, **input_kwargs)
-
-        if issubclass(output_type, BaseModel) and not issubclass(output_type, ParserBaseModel):
+        if getattr(output_type, "__origin__", None) is Union or isinstance(
+            output_type, UnionType
+        ):
+            return create_union_chain(
+                output_type, instruction, system, context, LLM, **input_kwargs
+            )
+
+        if issubclass(output_type, BaseModel) and not issubclass(
+            output_type, ParserBaseModel
+        ):
             return create_pydanctic_chain(output_type, prompt, LLM, **input_kwargs)

     return prompt | LLM | parser  # type: ignore
@@ -152,7 +170,9 @@ def chain(
     with get_openai_callback() as cb:
         result = chain.invoke(input_kwargs)
         if cb.total_tokens != 0:
-            log(f"{cb.total_tokens:05}T / {cb.total_cost:.3f}$ - {get_parent_frame(3).function}")
+            log(
+                f"{cb.total_tokens:05}T / {cb.total_cost:.3f}$ - {get_parent_frame(3).function}"
+            )

     return result
@@ -173,6 +193,8 @@ async def achain(
     with get_openai_callback() as cb:
         result = await chain.ainvoke(input_kwargs)
         if cb.total_tokens != 0:
-            log(f"{cb.total_tokens:05}T / {cb.total_cost:.3f}$ - {get_parent_frame(3).function}")
+            log(
+                f"{cb.total_tokens:05}T / {cb.total_cost:.3f}$ - {get_parent_frame(3).function}"
+            )

     return result
```
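The conditions rewrapped in `create_chain` carry the dispatch logic: union output types go to `create_union_chain`, pydantic models to `create_pydanctic_chain`, and everything else falls through to plain `prompt | LLM | parser`. The union check must catch both `typing.Union[...]` and the PEP 604 `A | B` form, which is a `types.UnionType` rather than a `typing` construct. A standalone sketch of that predicate (Python 3.10+; the helper name is illustrative, not from the commit):

```python
from types import UnionType
from typing import Any, Union

def is_union_annotation(output_type: Any) -> bool:
    # typing.Union[A, B] exposes typing.Union via __origin__;
    # A | B (PEP 604) is an instance of types.UnionType instead.
    return getattr(output_type, "__origin__", None) is Union or isinstance(
        output_type, UnionType
    )

print(is_union_annotation(Union[int, str]))  # True
print(is_union_annotation(int | str))        # True
print(is_union_annotation(int))              # False
```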
4 changes: 3 additions & 1 deletion funcchain/config.py
```diff
@@ -36,7 +36,9 @@ class FuncchainSettings(BaseSettings):

     def model_kwargs(self) -> dict[str, Any]:
         return {
-            "model_name": self.MODEL_NAME if "::" not in self.MODEL_NAME else self.MODEL_NAME.split("::")[1],
+            "model_name": self.MODEL_NAME
+            if "::" not in self.MODEL_NAME
+            else self.MODEL_NAME.split("::")[1],
             "temperature": self.MODEL_TEMPERATURE,
             "verbose": self.VERBOSE,
             "openai_api_key": self.OPENAI_API_KEY,
```
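The conditional expression rewrapped here implements a small naming convention: `MODEL_NAME` may carry a provider prefix separated by `::`, and only the part after the separator is passed on as `model_name`. A minimal sketch, assuming values shaped like `"openai::gpt-4"` (the prefix vocabulary itself lives in `model_defaults.py`, shown further down):

```python
def effective_model_name(model_name: str) -> str:
    # "openai::gpt-4" -> "gpt-4"; a bare "gpt-4" passes through unchanged.
    return model_name if "::" not in model_name else model_name.split("::")[1]

assert effective_model_name("openai::gpt-4") == "gpt-4"
assert effective_model_name("gpt-4") == "gpt-4"
```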
28 changes: 21 additions & 7 deletions funcchain/parser.py
```diff
@@ -18,7 +18,9 @@ class LambdaOutputParser(BaseOutputParser[T]):

     def parse(self, text: str) -> T:
         if self._parse is None:
-            raise NotImplementedError("LambdaOutputParser.lambda_parse() is not implemented")
+            raise NotImplementedError(
+                "LambdaOutputParser.lambda_parse() is not implemented"
+            )
         return self._parse(text)

     @property
@@ -59,19 +61,27 @@ def parse_result(self, result: list[Generation], *, partial: bool = False) -> M:
     def _pre_parse_function_call(self, result: list[Generation]) -> dict:
         generation = result[0]
         if not isinstance(generation, ChatGeneration):
-            raise OutputParserException("This output parser can only be used with a chat generation.")
+            raise OutputParserException(
+                "This output parser can only be used with a chat generation."
+            )
         message = generation.message
         try:
             func_call = copy.deepcopy(message.additional_kwargs["function_call"])
         except KeyError:
-            raise OutputParserException(f"The model refused to respond with a function call:\n{message.content}\n\n")
+            raise OutputParserException(
+                f"The model refused to respond with a function call:\n{message.content}\n\n"
+            )

         return func_call

     def _get_parser_for(self, function_name: str) -> BaseGenerationOutputParser[M]:
-        output_type_iter = filter(lambda t: t.__name__.lower() == function_name, self.output_types)
+        output_type_iter = filter(
+            lambda t: t.__name__.lower() == function_name, self.output_types
+        )
         if output_type_iter is None:
-            raise OutputParserException(f"No parser found for function: {function_name}")
+            raise OutputParserException(
+                f"No parser found for function: {function_name}"
+            )
         output_type: Type[M] = next(output_type_iter)

         return PydanticOutputFunctionsParser(pydantic_schema=output_type)
@@ -85,7 +95,9 @@ def output_parser(cls) -> BaseOutputParser[Self]:
     @classmethod
     def parse(cls, text: str) -> Self:
         """Override for custom parsing."""
-        match = re.search(r"\{.*\}", text.strip(), re.MULTILINE | re.IGNORECASE | re.DOTALL)
+        match = re.search(
+            r"\{.*\}", text.strip(), re.MULTILINE | re.IGNORECASE | re.DOTALL
+        )
         json_str = ""
         if match:
             json_str = match.group()
@@ -134,7 +146,9 @@ class CodeBlock(ParserBaseModel):

     @classmethod
     def parse(cls, text: str) -> "CodeBlock":
-        matches = re.finditer(r"```(?P<language>\w+)?\n?(?P<code>.*?)```", text, re.DOTALL)
+        matches = re.finditer(
+            r"```(?P<language>\w+)?\n?(?P<code>.*?)```", text, re.DOTALL
+        )
         for match in matches:
             groupdict = match.groupdict()
             groupdict["language"] = groupdict.get("language", None)
```
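The `CodeBlock.parse` regex rewrapped in the last hunk is easy to sanity-check in isolation: it captures an optional language tag and the fenced body from LLM output. A standalone snippet (not part of the commit) mirroring that loop:

```python
import re

text = "Sure, here you go:\n```python\nprint('hello')\n```"
for match in re.finditer(r"```(?P<language>\w+)?\n?(?P<code>.*?)```", text, re.DOTALL):
    groupdict = match.groupdict()
    print(groupdict["language"])    # -> python
    print(repr(groupdict["code"]))  # -> "print('hello')\n"
```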
21 changes: 17 additions & 4 deletions funcchain/prompt.py
```diff
@@ -2,7 +2,10 @@
 from typing import Any, Type

 from langchain.prompts import ChatPromptTemplate
-from langchain.prompts.chat import BaseStringMessagePromptTemplate, MessagePromptTemplateT
+from langchain.prompts.chat import (
+    BaseStringMessagePromptTemplate,
+    MessagePromptTemplateT,
+)
 from langchain.prompts.prompt import PromptTemplate
 from langchain.schema import BaseMessage, HumanMessage, SystemMessage
 from PIL import Image
@@ -88,13 +91,19 @@ def create_prompt(
             input_kwargs[k] = v[: (settings.MAX_TOKENS - base_tokens) * 2 // 3]
             print("Truncated: ", len(input_kwargs[k]))

-    template_format = "jinja2" if "{{" in instruction or "{%" in instruction else "f-string"
+    template_format = (
+        "jinja2" if "{{" in instruction or "{%" in instruction else "f-string"
+    )

     required_f_str_vars = extract_fstring_vars(instruction)  # TODO: jinja2
     if "format_instructions" in required_f_str_vars:
         required_f_str_vars.remove("format_instructions")

-    inject_vars = [f"[{var}]:\n{value}\n" for var, value in input_kwargs.items() if var not in required_f_str_vars]
+    inject_vars = [
+        f"[{var}]:\n{value}\n"
+        for var, value in input_kwargs.items()
+        if var not in required_f_str_vars
+    ]
     added_instruction = ("".join(inject_vars)).replace("{", "{{").replace("}", "}}")
     instruction = added_instruction + instruction
@@ -117,4 +126,8 @@ def extract_fstring_vars(template: str) -> list[str]:
     """
     Function to extract f-string variables from a string.
     """
-    return [field_name for _, field_name, _, _ in Formatter().parse(template) if field_name is not None]
+    return [
+        field_name
+        for _, field_name, _, _ in Formatter().parse(template)
+        if field_name is not None
+    ]
```
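`extract_fstring_vars` leans on `string.Formatter().parse`, which tokenizes a template into `(literal_text, field_name, format_spec, conversion)` tuples; trailing literals yield a `field_name` of `None`, hence the filter. For illustration (standalone, mirroring the function in this diff):

```python
from string import Formatter

def extract_fstring_vars(template: str) -> list[str]:
    # Collect every {placeholder} name, skipping plain literal segments.
    return [
        field_name
        for _, field_name, _, _ in Formatter().parse(template)
        if field_name is not None
    ]

print(extract_fstring_vars("Summarize {text} in {n} bullet points."))
# -> ['text', 'n']
```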
8 changes: 6 additions & 2 deletions funcchain/types.py
```diff
@@ -15,7 +15,9 @@ class CodeBlock(ParserBaseModel):

     @classmethod
     def parse(cls, text: str) -> "CodeBlock":
-        matches = re.finditer(r"```(?P<language>\w+)?\n?(?P<code>.*?)```", text, re.DOTALL)
+        matches = re.finditer(
+            r"```(?P<language>\w+)?\n?(?P<code>.*?)```", text, re.DOTALL
+        )
         for match in matches:
             groupdict = match.groupdict()
             groupdict["language"] = groupdict.get("language", None)
@@ -38,4 +40,6 @@ class Error(BaseModel):
     """If anything goes wrong and you can not do what is expected, use this error function as fallback."""

     title: str = Field(..., description="CamelCase Name titeling the error")
-    description: str = Field(..., description="Short description of the unexpected situation")
+    description: str = Field(
+        ..., description="Short description of the unexpected situation"
+    )
```
12 changes: 10 additions & 2 deletions funcchain/utils/function_frame.py
```diff
@@ -19,7 +19,11 @@ def from_docstring() -> str:
     """
     Get the docstring of the parent caller function.
     """
-    doc_str = (caller_frame := get_parent_frame()).frame.f_globals[caller_frame.function].__doc__
+    doc_str = (
+        (caller_frame := get_parent_frame())
+        .frame.f_globals[caller_frame.function]
+        .__doc__
+    )
     return "\n".join([line.lstrip() for line in doc_str.split("\n")])

@@ -28,7 +32,11 @@ def get_output_type() -> type:
     Get the output type annotation of the parent caller function.
     """
     try:
-        return (caller_frame := get_parent_frame()).frame.f_globals[caller_frame.function].__annotations__["return"]
+        return (
+            (caller_frame := get_parent_frame())
+            .frame.f_globals[caller_frame.function]
+            .__annotations__["return"]
+        )
     except KeyError:
         raise ValueError("The funcchain must have a return type annotation")
```
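Both rewrapped expressions follow the same pattern: `get_parent_frame()` (defined elsewhere in this module, so an assumption here) yields something shaped like an `inspect.stack()` `FrameInfo`, whose `frame.f_globals` maps the caller's name back to the function object, giving access to `__doc__` and `__annotations__`. A hedged, self-contained sketch of the idea:

```python
import inspect

def caller_doc_and_return() -> tuple[str | None, type | None]:
    # stack()[1] is the direct caller; the f_globals lookup only
    # works for module-level functions.
    caller = inspect.stack()[1]
    func = caller.frame.f_globals[caller.function]
    return func.__doc__, func.__annotations__.get("return")

def greet() -> str:
    """Say hello."""
    doc, ret = caller_doc_and_return()
    return f"{doc} / {ret}"

print(greet())  # -> Say hello. / <class 'str'>
```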
15 changes: 11 additions & 4 deletions funcchain/utils/helpers.py
```diff
@@ -83,7 +83,9 @@ def gather_llm_type(llm: BaseLanguageModel | Runnable, func_check: bool = True)
     if func_check:
         llm.predict_messages(
             [
-                SystemMessage(content="This is a test message to see if the model can run functions."),
+                SystemMessage(
+                    content="This is a test message to see if the model can run functions."
+                ),
                 HumanMessage(content="Hello!"),
             ],
             functions=[
@@ -141,11 +143,15 @@ def pydantic_to_functions(pydantic_object: Type[BaseModel]) -> dict[str, Any]:
     docstring = parse(pydantic_object.__doc__ or "")
     parameters = {k: v for k, v in schema.items() if k not in ("title", "description")}
     for param in docstring.params:
-        if (name := param.arg_name) in parameters["properties"] and (description := param.description):
+        if (name := param.arg_name) in parameters["properties"] and (
+            description := param.description
+        ):
             if "description" not in parameters["properties"][name]:
                 parameters["properties"][name]["description"] = description

-    parameters["required"] = sorted(k for k, v in parameters["properties"].items() if "default" not in v)
+    parameters["required"] = sorted(
+        k for k, v in parameters["properties"].items() if "default" not in v
+    )
     parameters["type"] = "object"

     if "description" not in schema:
@@ -178,7 +184,8 @@ def multi_pydantic_to_functions(
     pydantic_objects: list[Type[BaseModel]],
 ) -> dict[str, Any]:
     functions: list[dict[str, Any]] = [
-        pydantic_to_functions(pydantic_object)["functions"][0] for pydantic_object in pydantic_objects
+        pydantic_to_functions(pydantic_object)["functions"][0]
+        for pydantic_object in pydantic_objects
     ]

     return {
```
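The walrus-heavy condition rewrapped in `pydantic_to_functions` merges docstring parameter descriptions into the pydantic JSON schema used for OpenAI function calling. A trimmed sketch of that step, assuming `parse` comes from the `docstring-parser` package (consistent with the `docstring.params` / `param.arg_name` usage in the diff) and pydantic v1's `.schema()`:

```python
from docstring_parser import parse
from pydantic import BaseModel

class Weather(BaseModel):
    """Look up the current weather.

    Args:
        city: Name of the city to query.
    """

    city: str

schema = Weather.schema()  # pydantic v2 would use model_json_schema()
docstring = parse(Weather.__doc__ or "")
parameters = {k: v for k, v in schema.items() if k not in ("title", "description")}
for param in docstring.params:
    # Copy a docstring description onto any property that lacks one.
    if (name := param.arg_name) in parameters["properties"] and (
        description := param.description
    ):
        if "description" not in parameters["properties"][name]:
            parameters["properties"][name]["description"] = description

print(parameters["properties"]["city"]["description"])
# -> Name of the city to query.
```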
14 changes: 9 additions & 5 deletions funcchain/utils/model_defaults.py
```diff
@@ -1,7 +1,13 @@
 from typing import Any

 from dotenv import load_dotenv
-from langchain.chat_models import AzureChatOpenAI, ChatAnthropic, ChatGooglePalm, ChatOpenAI, JinaChat
+from langchain.chat_models import (
+    AzureChatOpenAI,
+    ChatAnthropic,
+    ChatGooglePalm,
+    ChatOpenAI,
+    JinaChat,
+)
 from langchain.chat_models.base import BaseChatModel

 from funcchain.config import settings
@@ -41,8 +47,7 @@ def model_from_env(
     if name := settings.MODEL_NAME:
         return model_from_name(name, **kwargs)
     raise ValueError(
-        "Model not found! "
-        "Make sure to use the correct env variables."
+        "Model not found! Make sure to use the correct env variables."
         # "For more info: docs.url"
     )
@@ -86,7 +91,6 @@ def model_from_name(
         case "google":
             return ChatGooglePalm(**kwargs)
     raise ValueError(
-        "Model not found! "
-        "Make sure to use the correct env variables."
+        "Model not found! Make sure to use the correct env variables."
         # "For more info: docs.url"
     )
```