diff --git a/pdl/benchmark.py b/pdl/benchmark.py index 69f804d5..e91e2613 100644 --- a/pdl/benchmark.py +++ b/pdl/benchmark.py @@ -121,7 +121,7 @@ def accumulator(self, q): ) -def write_log( # pylint: disable=too-many-arguments +def write_log( # pylint: disable=too-many-arguments,too-many-positional-arguments log, index, question, truth, answer, solution, document, exc ): log.write("\n\n------------------------\n") diff --git a/pdl/pdl_ast.py b/pdl/pdl_ast.py index e998ad9d..620925f9 100644 --- a/pdl/pdl_ast.py +++ b/pdl/pdl_ast.py @@ -1,7 +1,6 @@ """PDL programs are represented by the Pydantic data structure defined in this file. """ - from enum import StrEnum from typing import Any, Literal, Optional, TypeAlias, TypedDict, Union @@ -353,7 +352,7 @@ class MessageBlock(Block): """Create a message.""" kind: Literal[BlockKind.MESSAGE] = BlockKind.MESSAGE - role: RoleType + role: RoleType # type: ignore """Role of associated to the message.""" content: "BlocksType" """Content of the message.""" @@ -602,32 +601,32 @@ def set_default_granite_model_parameters( if parameters is None: parameters = {} if "decoding_method" not in parameters: - parameters[ - "decoding_method" - ] = DECODING_METHOD # pylint: disable=attribute-defined-outside-init + parameters["decoding_method"] = ( + DECODING_METHOD # pylint: disable=attribute-defined-outside-init + ) if "max_tokens" in parameters and parameters["max_tokens"] is None: - parameters[ - "max_tokens" - ] = MAX_NEW_TOKENS # pylint: disable=attribute-defined-outside-init + parameters["max_tokens"] = ( + MAX_NEW_TOKENS # pylint: disable=attribute-defined-outside-init + ) if "min_new_tokens" not in parameters: - parameters[ - "min_new_tokens" - ] = MIN_NEW_TOKENS # pylint: disable=attribute-defined-outside-init + parameters["min_new_tokens"] = ( + MIN_NEW_TOKENS # pylint: disable=attribute-defined-outside-init + ) if "repetition_penalty" not in parameters: - parameters[ - "repetition_penalty" - ] = REPETITION_PENATLY # pylint: disable=attribute-defined-outside-init + parameters["repetition_penalty"] = ( + REPETITION_PENATLY # pylint: disable=attribute-defined-outside-init + ) if parameters["decoding_method"] == "sample": if "temperature" not in parameters: - parameters[ - "temperature" - ] = TEMPERATURE_SAMPLING # pylint: disable=attribute-defined-outside-init + parameters["temperature"] = ( + TEMPERATURE_SAMPLING # pylint: disable=attribute-defined-outside-init + ) if "top_k" not in parameters: - parameters[ - "top_k" - ] = TOP_K_SAMPLING # pylint: disable=attribute-defined-outside-init + parameters["top_k"] = ( + TOP_K_SAMPLING # pylint: disable=attribute-defined-outside-init + ) if "top_p" not in parameters: - parameters[ - "top_p" - ] = TOP_P_SAMPLING # pylint: disable=attribute-defined-outside-init + parameters["top_p"] = ( + TOP_P_SAMPLING # pylint: disable=attribute-defined-outside-init + ) return parameters diff --git a/pdl/pdl_compilers/to_regex.py b/pdl/pdl_compilers/to_regex.py index 371a1650..453430c6 100644 --- a/pdl/pdl_compilers/to_regex.py +++ b/pdl/pdl_compilers/to_regex.py @@ -30,8 +30,7 @@ class Re(ABC): @abstractmethod - def to_re(self) -> str: - ... + def to_re(self) -> str: ... class ReEmpty(Re): diff --git a/pdl/pdl_dumper.py b/pdl/pdl_dumper.py index 91495e8a..f2b108cf 100644 --- a/pdl/pdl_dumper.py +++ b/pdl/pdl_dumper.py @@ -203,9 +203,9 @@ def block_to_dict(block: pdl_ast.BlockType) -> int | float | str | dict[str, Any def blocks_to_dict( blocks: BlocksType, ) -> int | float | str | dict[str, Any] | list[int | float | str | dict[str, Any]]: - result: int | float | str | dict[str, Any] | list[ - int | float | str | dict[str, Any] - ] + result: ( + int | float | str | dict[str, Any] | list[int | float | str | dict[str, Any]] + ) if not isinstance(blocks, str) and isinstance(blocks, Sequence): result = [block_to_dict(block) for block in blocks] else: diff --git a/pdl/pdl_llms.py b/pdl/pdl_llms.py index 8657bda3..3a3403e2 100644 --- a/pdl/pdl_llms.py +++ b/pdl/pdl_llms.py @@ -43,7 +43,7 @@ def get_model() -> BamClient: return BamModel.bam_client @staticmethod - def generate_text( # pylint: disable=too-many-arguments + def generate_text( # pylint: disable=too-many-arguments,too-many-positional-arguments model_id: str, prompt_id: Optional[str], model_input: Optional[str], @@ -69,7 +69,7 @@ def generate_text( # pylint: disable=too-many-arguments return {"role": None, "content": text} @staticmethod - def generate_text_stream( # pylint: disable=too-many-arguments + def generate_text_stream( # pylint: disable=too-many-arguments,too-many-positional-arguments model_id: str, prompt_id: Optional[str], model_input: Optional[str], diff --git a/pdl/pdl_scheduler.py b/pdl/pdl_scheduler.py index 3da0dd4d..bc6d675a 100644 --- a/pdl/pdl_scheduler.py +++ b/pdl/pdl_scheduler.py @@ -95,9 +95,9 @@ def schedule( ) -> list[GeneratorReturnT]: global _LAST_ROLE # pylint: disable= global-statement todo: list[tuple[int, Generator[YieldMessage, Any, GeneratorReturnT], Any]] - todo_next: list[ - tuple[int, Generator[YieldMessage, Any, GeneratorReturnT], Any] - ] = [] + todo_next: list[tuple[int, Generator[YieldMessage, Any, GeneratorReturnT], Any]] = ( + [] + ) done: list[Optional[GeneratorReturnT]] todo = [(i, gen, None) for i, gen in enumerate(generators)] done = [None for _ in generators] @@ -106,10 +106,10 @@ def schedule( try: msg = gen.send(v) match msg: - case ModelYieldResultMessage( - result=result - ) | CodeYieldResultMessage(result=result) | YieldResultMessage( - result=result + case ( + ModelYieldResultMessage(result=result) + | CodeYieldResultMessage(result=result) + | YieldResultMessage(result=result) ): if msg.color is None: text = stringify(result) diff --git a/pdl/pdl_utils.py b/pdl/pdl_utils.py index ecee0fa2..f2247f68 100644 --- a/pdl/pdl_utils.py +++ b/pdl/pdl_utils.py @@ -36,9 +36,11 @@ def messages_to_str(messages: Messages) -> str: # TODO return "".join( [ - msg["content"] - if msg["role"] is None - else f"<|{msg['role']}|>{msg['content']}" + ( + msg["content"] + if msg["role"] is None + else f"<|{msg['role']}|>{msg['content']}" + ) for msg in messages ] )