From 645a7ca17ebe3bf339ba0395b5cf0fd610a79130 Mon Sep 17 00:00:00 2001
From: Shroominic
Date: Mon, 13 Nov 2023 15:08:21 +0100
Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=80=20ruff=20fmt?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .pre-commit-config.yaml           | 27 ++++++++------------
 examples/email_answering.py       |  4 ++-
 examples/generate_pp_and_tos.py   |  8 ++++--
 funcchain/chain.py                | 42 +++++++++++++++++++++++--------
 funcchain/config.py               |  4 ++-
 funcchain/parser.py               | 28 +++++++++++++++------
 funcchain/prompt.py               | 21 +++++++++++++---
 funcchain/types.py                |  8 ++++--
 funcchain/utils/function_frame.py | 12 +++++++--
 funcchain/utils/helpers.py        | 15 ++++++++---
 funcchain/utils/model_defaults.py | 14 +++++++----
 11 files changed, 129 insertions(+), 54 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 1247f22..0cad823 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,28 +1,23 @@
 repos:
+
 - repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v2.3.0
+  rev: v4.5.0
   hooks:
   - id: check-yaml
   - id: end-of-file-fixer
   - id: trailing-whitespace
-- repo: https://github.com/psf/black
-  rev: 23.7.0
-  hooks:
-  - id: black
-    args: [--line-length=120]
+
 - repo: https://github.com/pre-commit/mirrors-mypy
-  rev: v1.4.1
+  rev: v1.7.0
   hooks:
   - id: mypy
     args: [--ignore-missing-imports, --follow-imports=skip]
     additional_dependencies: [types-requests]
-- repo: https://github.com/pre-commit/mirrors-isort
-  rev: v5.9.3
-  hooks:
-  - id: isort
-    args: [--profile=black]
-- repo: https://github.com/PYCQA/flake8
-  rev: 6.1.0
+
+- repo: https://github.com/astral-sh/ruff-pre-commit
+  rev: v0.1.4
   hooks:
-  - id: flake8
-    args: [--max-line-length=120, "--ignore=F401,F403,W503"]
+  - id: ruff
+    args: [ --fix ]
+  - id: ruff-format
+    types_or: [ python, pyi, jupyter ]
diff --git a/examples/email_answering.py b/examples/email_answering.py
index a1d2416..d7b1071 100644
--- a/examples/email_answering.py
+++ b/examples/email_answering.py
@@ -20,7 +20,9 @@ def get_emails_from_inbox() -> List[Tuple[str, str]]:
     """
 
     # Run AppleScript and collect output
-    process = subprocess.Popen(["osascript", "-e", apple_script], stdout=subprocess.PIPE)
+    process = subprocess.Popen(
+        ["osascript", "-e", apple_script], stdout=subprocess.PIPE
+    )
     out, _ = process.communicate()
 
     raw_output = out.decode("utf-8").strip()
diff --git a/examples/generate_pp_and_tos.py b/examples/generate_pp_and_tos.py
index e0ab2c1..163ceea 100644
--- a/examples/generate_pp_and_tos.py
+++ b/examples/generate_pp_and_tos.py
@@ -62,14 +62,18 @@ def generate_pp(answered_questions: list[str]) -> str:
 
 
 if __name__ == "__main__":
-    print("Please answer the following questions to generate a Terms of Service and Privacy Policy.")
+    print(
+        "Please answer the following questions to generate a Terms of Service and Privacy Policy."
+    )
     print("To skip a question, press enter without typing anything.")
 
     legal_questions = example_legal_questions.copy()  # or from scratch using generate_legal_questions()
 
     for i, question in enumerate(legal_questions):
-        answer = input(f"{i+1}/{len(legal_questions)}: {question} ") or "No answer provided."
+        answer = (
+            input(f"{i+1}/{len(legal_questions)}: {question} ") or "No answer provided."
+        )
 
         legal_questions[i] = f"Q: {question}\nA: {answer}\n"
 
 
diff --git a/funcchain/chain.py b/funcchain/chain.py
index c849bc0..0ca33c8 100644
--- a/funcchain/chain.py
+++ b/funcchain/chain.py
@@ -42,13 +42,19 @@ def create_union_chain(
     output_types = output_type.__args__  # type: ignore
     output_type_names = [t.__name__ for t in output_types]
 
-    input_kwargs["format_instructions"] = f"Extract to one of these output types: {output_type_names}."
+    input_kwargs[
+        "format_instructions"
+    ] = f"Extract to one of these output types: {output_type_names}."
 
     functions = multi_pydantic_to_functions(output_types)
 
     if isinstance(LLM, RunnableWithFallbacks):
         LLM = LLM.runnable.bind(**functions).with_fallbacks(
-            [fallback.bind(**functions) for fallback in LLM.fallbacks if hasattr(LLM, "fallbacks")]
+            [
+                fallback.bind(**functions)
+                for fallback in LLM.fallbacks
+                if hasattr(LLM, "fallbacks")
+            ]
         )
     else:
         LLM = LLM.bind(**functions)  # type: ignore
@@ -78,7 +84,11 @@ def create_pydanctic_chain(
     functions = pydantic_to_functions(output_type)
     LLM = (
         LLM.runnable.bind(**functions).with_fallbacks(  # type: ignore
-            [fallback.bind(**functions) for fallback in LLM.fallbacks if hasattr(LLM, "fallbacks")]
+            [
+                fallback.bind(**functions)
+                for fallback in LLM.fallbacks
+                if hasattr(LLM, "fallbacks")
+            ]
         )
         if isinstance(LLM, RunnableWithFallbacks)
         else LLM.bind(**functions)
@@ -106,17 +116,25 @@ def create_chain(
 
     images = [v for v in input_kwargs.values() if isinstance(v, Image.Image)]
     if is_vision_model(LLM):
-        input_kwargs = {k: v for k, v in input_kwargs.items() if not isinstance(v, Image.Image)}
+        input_kwargs = {
+            k: v for k, v in input_kwargs.items() if not isinstance(v, Image.Image)
+        }
     elif images:
         raise RuntimeError("Images as input are only supported for vision models.")
 
     prompt = create_prompt(instruction, system, context, images=images, **input_kwargs)
 
     if func_model:
-        if getattr(output_type, "__origin__", None) is Union or isinstance(output_type, UnionType):
-            return create_union_chain(output_type, instruction, system, context, LLM, **input_kwargs)
-
-        if issubclass(output_type, BaseModel) and not issubclass(output_type, ParserBaseModel):
+        if getattr(output_type, "__origin__", None) is Union or isinstance(
+            output_type, UnionType
+        ):
+            return create_union_chain(
+                output_type, instruction, system, context, LLM, **input_kwargs
+            )
+
+        if issubclass(output_type, BaseModel) and not issubclass(
+            output_type, ParserBaseModel
+        ):
             return create_pydanctic_chain(output_type, prompt, LLM, **input_kwargs)
 
     return prompt | LLM | parser  # type: ignore
@@ -152,7 +170,9 @@ def chain(
 
     with get_openai_callback() as cb:
         result = chain.invoke(input_kwargs)
         if cb.total_tokens != 0:
-            log(f"{cb.total_tokens:05}T / {cb.total_cost:.3f}$ - {get_parent_frame(3).function}")
+            log(
+                f"{cb.total_tokens:05}T / {cb.total_cost:.3f}$ - {get_parent_frame(3).function}"
+            )
 
     return result
@@ -173,6 +193,8 @@ async def achain(
     with get_openai_callback() as cb:
         result = await chain.ainvoke(input_kwargs)
         if cb.total_tokens != 0:
-            log(f"{cb.total_tokens:05}T / {cb.total_cost:.3f}$ - {get_parent_frame(3).function}")
+            log(
+                f"{cb.total_tokens:05}T / {cb.total_cost:.3f}$ - {get_parent_frame(3).function}"
+            )
 
     return result
diff --git a/funcchain/config.py b/funcchain/config.py
index bef0416..a47c749 100644
--- a/funcchain/config.py
+++ b/funcchain/config.py
@@ -36,7 +36,9 @@ class FuncchainSettings(BaseSettings):
 
     def model_kwargs(self) -> dict[str, Any]:
         return {
-            "model_name": self.MODEL_NAME if "::" not in self.MODEL_NAME else self.MODEL_NAME.split("::")[1],
+            "model_name": self.MODEL_NAME
+            if "::" not in self.MODEL_NAME
+            else self.MODEL_NAME.split("::")[1],
             "temperature": self.MODEL_TEMPERATURE,
             "verbose": self.VERBOSE,
             "openai_api_key": self.OPENAI_API_KEY,
diff --git a/funcchain/parser.py b/funcchain/parser.py
index a7e1442..d7d2f6a 100644
--- a/funcchain/parser.py
+++ b/funcchain/parser.py
@@ -18,7 +18,9 @@ class LambdaOutputParser(BaseOutputParser[T]):
 
     def parse(self, text: str) -> T:
         if self._parse is None:
-            raise NotImplementedError("LambdaOutputParser.lambda_parse() is not implemented")
+            raise NotImplementedError(
+                "LambdaOutputParser.lambda_parse() is not implemented"
+            )
         return self._parse(text)
 
     @property
@@ -59,19 +61,27 @@ def parse_result(self, result: list[Generation], *, partial: bool = False) -> M:
 
     def _pre_parse_function_call(self, result: list[Generation]) -> dict:
         generation = result[0]
         if not isinstance(generation, ChatGeneration):
-            raise OutputParserException("This output parser can only be used with a chat generation.")
+            raise OutputParserException(
+                "This output parser can only be used with a chat generation."
+            )
         message = generation.message
         try:
             func_call = copy.deepcopy(message.additional_kwargs["function_call"])
         except KeyError:
-            raise OutputParserException(f"The model refused to respond with a function call:\n{message.content}\n\n")
+            raise OutputParserException(
+                f"The model refused to respond with a function call:\n{message.content}\n\n"
+            )
 
         return func_call
 
     def _get_parser_for(self, function_name: str) -> BaseGenerationOutputParser[M]:
-        output_type_iter = filter(lambda t: t.__name__.lower() == function_name, self.output_types)
+        output_type_iter = filter(
+            lambda t: t.__name__.lower() == function_name, self.output_types
+        )
         if output_type_iter is None:
-            raise OutputParserException(f"No parser found for function: {function_name}")
+            raise OutputParserException(
+                f"No parser found for function: {function_name}"
+            )
         output_type: Type[M] = next(output_type_iter)
         return PydanticOutputFunctionsParser(pydantic_schema=output_type)
@@ -85,7 +95,9 @@ def output_parser(cls) -> BaseOutputParser[Self]:
     @classmethod
     def parse(cls, text: str) -> Self:
         """Override for custom parsing."""
-        match = re.search(r"\{.*\}", text.strip(), re.MULTILINE | re.IGNORECASE | re.DOTALL)
+        match = re.search(
+            r"\{.*\}", text.strip(), re.MULTILINE | re.IGNORECASE | re.DOTALL
+        )
         json_str = ""
         if match:
             json_str = match.group()
@@ -134,7 +146,9 @@ class CodeBlock(ParserBaseModel):
 
     @classmethod
     def parse(cls, text: str) -> "CodeBlock":
-        matches = re.finditer(r"```(?P<language>\w+)?\n?(?P<code>.*?)```", text, re.DOTALL)
+        matches = re.finditer(
+            r"```(?P<language>\w+)?\n?(?P<code>.*?)```", text, re.DOTALL
+        )
         for match in matches:
             groupdict = match.groupdict()
             groupdict["language"] = groupdict.get("language", None)
diff --git a/funcchain/prompt.py b/funcchain/prompt.py
index 856bb99..b2a8b2d 100644
--- a/funcchain/prompt.py
+++ b/funcchain/prompt.py
@@ -2,7 +2,10 @@
 from typing import Any, Type
 
 from langchain.prompts import ChatPromptTemplate
-from langchain.prompts.chat import BaseStringMessagePromptTemplate, MessagePromptTemplateT
+from langchain.prompts.chat import (
+    BaseStringMessagePromptTemplate,
+    MessagePromptTemplateT,
+)
 from langchain.prompts.prompt import PromptTemplate
 from langchain.schema import BaseMessage, HumanMessage, SystemMessage
 from PIL import Image
@@ -88,13 +91,19 @@ def create_prompt(
             input_kwargs[k] = v[: (settings.MAX_TOKENS - base_tokens) * 2 // 3]
             print("Truncated: ", len(input_kwargs[k]))
 
-    template_format = "jinja2" if "{{" in instruction or "{%" in instruction else "f-string"
+    template_format = (
+        "jinja2" if "{{" in instruction or "{%" in instruction else "f-string"
+    )
     required_f_str_vars = extract_fstring_vars(instruction)  # TODO: jinja2
 
     if "format_instructions" in required_f_str_vars:
         required_f_str_vars.remove("format_instructions")
 
-    inject_vars = [f"[{var}]:\n{value}\n" for var, value in input_kwargs.items() if var not in required_f_str_vars]
+    inject_vars = [
+        f"[{var}]:\n{value}\n"
+        for var, value in input_kwargs.items()
+        if var not in required_f_str_vars
+    ]
     added_instruction = ("".join(inject_vars)).replace("{", "{{").replace("}", "}}")
     instruction = added_instruction + instruction
 
@@ -117,4 +126,8 @@ def extract_fstring_vars(template: str) -> list[str]:
     """
     Function to extract f-string variables from a string.
    """
-    return [field_name for _, field_name, _, _ in Formatter().parse(template) if field_name is not None]
+    return [
+        field_name
+        for _, field_name, _, _ in Formatter().parse(template)
+        if field_name is not None
+    ]
diff --git a/funcchain/types.py b/funcchain/types.py
index b1030ea..136f99a 100644
--- a/funcchain/types.py
+++ b/funcchain/types.py
@@ -15,7 +15,9 @@ class CodeBlock(ParserBaseModel):
 
     @classmethod
     def parse(cls, text: str) -> "CodeBlock":
-        matches = re.finditer(r"```(?P<language>\w+)?\n?(?P<code>.*?)```", text, re.DOTALL)
+        matches = re.finditer(
+            r"```(?P<language>\w+)?\n?(?P<code>.*?)```", text, re.DOTALL
+        )
     for match in matches:
             groupdict = match.groupdict()
             groupdict["language"] = groupdict.get("language", None)
@@ -38,4 +40,6 @@ class Error(BaseModel):
     """If anything goes wrong and you can not do what is expected, use this error function as fallback."""
 
     title: str = Field(..., description="CamelCase Name titeling the error")
-    description: str = Field(..., description="Short description of the unexpected situation")
+    description: str = Field(
+        ..., description="Short description of the unexpected situation"
+    )
diff --git a/funcchain/utils/function_frame.py b/funcchain/utils/function_frame.py
index b329e89..04d717d 100644
--- a/funcchain/utils/function_frame.py
+++ b/funcchain/utils/function_frame.py
@@ -19,7 +19,11 @@ def from_docstring() -> str:
     """
     Get the docstring of the parent caller function.
     """
-    doc_str = (caller_frame := get_parent_frame()).frame.f_globals[caller_frame.function].__doc__
+    doc_str = (
+        (caller_frame := get_parent_frame())
+        .frame.f_globals[caller_frame.function]
+        .__doc__
+    )
     return "\n".join([line.lstrip() for line in doc_str.split("\n")])
 
 
@@ -28,7 +32,11 @@ def get_output_type() -> type:
     Get the output type annotation of the parent caller function.
     """
     try:
-        return (caller_frame := get_parent_frame()).frame.f_globals[caller_frame.function].__annotations__["return"]
+        return (
+            (caller_frame := get_parent_frame())
+            .frame.f_globals[caller_frame.function]
+            .__annotations__["return"]
+        )
     except KeyError:
         raise ValueError("The funcchain must have a return type annotation")
diff --git a/funcchain/utils/helpers.py b/funcchain/utils/helpers.py
index 323b237..1ee5006 100644
--- a/funcchain/utils/helpers.py
+++ b/funcchain/utils/helpers.py
@@ -83,7 +83,9 @@ def gather_llm_type(llm: BaseLanguageModel | Runnable, func_check: bool = True)
     if func_check:
         llm.predict_messages(
             [
-                SystemMessage(content="This is a test message to see if the model can run functions."),
+                SystemMessage(
+                    content="This is a test message to see if the model can run functions."
+                ),
                 HumanMessage(content="Hello!"),
             ],
             functions=[
@@ -141,11 +143,15 @@ def pydantic_to_functions(pydantic_object: Type[BaseModel]) -> dict[str, Any]:
     docstring = parse(pydantic_object.__doc__ or "")
     parameters = {k: v for k, v in schema.items() if k not in ("title", "description")}
     for param in docstring.params:
-        if (name := param.arg_name) in parameters["properties"] and (description := param.description):
+        if (name := param.arg_name) in parameters["properties"] and (
+            description := param.description
+        ):
             if "description" not in parameters["properties"][name]:
                 parameters["properties"][name]["description"] = description
 
-    parameters["required"] = sorted(k for k, v in parameters["properties"].items() if "default" not in v)
+    parameters["required"] = sorted(
+        k for k, v in parameters["properties"].items() if "default" not in v
+    )
     parameters["type"] = "object"
 
     if "description" not in schema:
@@ -178,7 +184,8 @@ def multi_pydantic_to_functions(
     pydantic_objects: list[Type[BaseModel]],
 ) -> dict[str, Any]:
     functions: list[dict[str, Any]] = [
-        pydantic_to_functions(pydantic_object)["functions"][0] for pydantic_object in pydantic_objects
+        pydantic_to_functions(pydantic_object)["functions"][0]
+        for pydantic_object in pydantic_objects
     ]
 
     return {
diff --git a/funcchain/utils/model_defaults.py b/funcchain/utils/model_defaults.py
index f5392da..e4d4022 100644
--- a/funcchain/utils/model_defaults.py
+++ b/funcchain/utils/model_defaults.py
@@ -1,7 +1,13 @@
 from typing import Any
 
 from dotenv import load_dotenv
-from langchain.chat_models import AzureChatOpenAI, ChatAnthropic, ChatGooglePalm, ChatOpenAI, JinaChat
+from langchain.chat_models import (
+    AzureChatOpenAI,
+    ChatAnthropic,
+    ChatGooglePalm,
+    ChatOpenAI,
+    JinaChat,
+)
 from langchain.chat_models.base import BaseChatModel
 
 from funcchain.config import settings
@@ -41,8 +47,7 @@ def model_from_env(
 
     if name := settings.MODEL_NAME:
         return model_from_name(name, **kwargs)
     raise ValueError(
-        "Model not found! "
-        "Make sure to use the correct env variables."
+        "Model not found! Make sure to use the correct env variables."
         # "For more info: docs.url"
     )
@@ -86,7 +91,6 @@ def model_from_name(
        case "google":
            return ChatGooglePalm(**kwargs)
     raise ValueError(
-        "Model not found! "
-        "Make sure to use the correct env variables."
+        "Model not found! Make sure to use the correct env variables."
         # "For more info: docs.url"
     )