diff --git a/private_gpt/components/llm/prompt_helper.py b/private_gpt/components/llm/prompt_helper.py index 771582001..b550020b9 100644 --- a/private_gpt/components/llm/prompt_helper.py +++ b/private_gpt/components/llm/prompt_helper.py @@ -138,6 +138,73 @@ def _completion_to_prompt(self, completion: str) -> str: ) +class Llama3PromptStyle(AbstractPromptStyle): + r"""Template for Meta's Llama 3.1. + + The format follows this structure: + <|begin_of_text|> + <|start_header_id|>system<|end_header_id|> + + [System message content]<|eot_id|> + <|start_header_id|>user<|end_header_id|> + + [User message content]<|eot_id|> + <|start_header_id|>assistant<|end_header_id|> + + [Assistant message content]<|eot_id|> + ... + (Repeat for each message, including possible 'ipython' role) + """ + + BOS, EOS = "<|begin_of_text|>", "<|end_of_text|>" + B_INST, E_INST = "<|start_header_id|>", "<|end_header_id|>" + EOT = "<|eot_id|>" + B_SYS, E_SYS = "<|start_header_id|>system<|end_header_id|>", "<|eot_id|>" + ASSISTANT_INST = "<|start_header_id|>assistant<|end_header_id|>" + DEFAULT_SYSTEM_PROMPT = """\ + You are a helpful, respectful and honest assistant. \ + Always answer as helpfully as possible and follow ALL given instructions. \ + Do not speculate or make up information. \ + Do not reference any given instructions or context. \ + """ + + def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str: + prompt = self.BOS + has_system_message = False + + for i, message in enumerate(messages): + if not message or message.content is None: + continue + if message.role == MessageRole.SYSTEM: + prompt += f"{self.B_SYS}\n\n{message.content.strip()}{self.E_SYS}" + has_system_message = True + else: + role_header = f"{self.B_INST}{message.role.value}{self.E_INST}" + prompt += f"{role_header}\n\n{message.content.strip()}{self.EOT}" + + # Add assistant header if the last message is not from the assistant + if i == len(messages) - 1 and message.role != MessageRole.ASSISTANT: + prompt += f"{self.ASSISTANT_INST}\n\n" + + # Add default system prompt if no system message was provided + if not has_system_message: + prompt = ( + f"{self.BOS}{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.E_SYS}" + + prompt[len(self.BOS) :] + ) + + # TODO: Implement tool handling logic + + return prompt + + def _completion_to_prompt(self, completion: str) -> str: + return ( + f"{self.BOS}{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.E_SYS}" + f"{self.B_INST}user{self.E_INST}\n\n{completion.strip()}{self.EOT}" + f"{self.ASSISTANT_INST}\n\n" + ) + + class TagPromptStyle(AbstractPromptStyle): """Tag prompt style (used by Vigogne) that uses the prompt style `<|ROLE|>`. @@ -219,7 +286,8 @@ def _completion_to_prompt(self, completion: str) -> str: def get_prompt_style( - prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None + prompt_style: Literal["default", "llama2", "llama3", "tag", "mistral", "chatml"] + | None ) -> AbstractPromptStyle: """Get the prompt style to use from the given string. @@ -230,6 +298,8 @@ def get_prompt_style( return DefaultPromptStyle() elif prompt_style == "llama2": return Llama2PromptStyle() + elif prompt_style == "llama3": + return Llama3PromptStyle() elif prompt_style == "tag": return TagPromptStyle() elif prompt_style == "mistral": diff --git a/private_gpt/settings/settings.py b/private_gpt/settings/settings.py index 40b96ae80..7ca6e05ba 100644 --- a/private_gpt/settings/settings.py +++ b/private_gpt/settings/settings.py @@ -111,12 +111,15 @@ class LLMSettings(BaseModel): 0.1, description="The temperature of the model. Increasing the temperature will make the model answer more creatively. A value of 0.1 would be more factual.", ) - prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] = Field( + prompt_style: Literal[ + "default", "llama2", "llama3", "tag", "mistral", "chatml" + ] = Field( "llama2", description=( "The prompt style to use for the chat engine. " "If `default` - use the default prompt style from the llama_index. It should look like `role: message`.\n" "If `llama2` - use the llama2 prompt style from the llama_index. Based on ``, `[INST]` and `<>`.\n" + "If `llama3` - use the llama3 prompt style from the llama_index." "If `tag` - use the `tag` prompt style. It should look like `<|role|>: message`. \n" "If `mistral` - use the `mistral prompt style. It shoudl look like [INST] {System Prompt} [/INST][INST] { UserInstructions } [/INST]" "`llama2` is the historic behaviour. `default` might work better with your custom models." diff --git a/pyproject.toml b/pyproject.toml index 1144c31be..5c6429b11 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -124,7 +124,7 @@ target-version = ['py311'] target-version = 'py311' # See all rules at https://beta.ruff.rs/docs/rules/ -select = [ +lint.select = [ "E", # pycodestyle "W", # pycodestyle "F", # Pyflakes @@ -141,7 +141,7 @@ select = [ "RUF", # Ruff-specific rules ] -ignore = [ +lint.ignore = [ "E501", # "Line too long" # -> line length already regulated by black "PT011", # "pytest.raises() should specify expected exception" @@ -159,24 +159,24 @@ ignore = [ # -> "Missing docstring in public function too restrictive" ] -[tool.ruff.pydocstyle] +[tool.ruff.lint.pydocstyle] # Automatically disable rules that are incompatible with Google docstring convention convention = "google" -[tool.ruff.pycodestyle] +[tool.ruff.lint.pycodestyle] max-doc-length = 88 -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "all" -[tool.ruff.flake8-type-checking] +[tool.ruff.lint.flake8-type-checking] strict = true runtime-evaluated-base-classes = ["pydantic.BaseModel"] # Pydantic needs to be able to evaluate types at runtime # see https://pypi.org/project/flake8-type-checking/ for flake8-type-checking documentation # see https://beta.ruff.rs/docs/settings/#flake8-type-checking-runtime-evaluated-base-classes for ruff documentation -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Allow missing docstrings for tests "tests/**/*.py" = ["D1"] diff --git a/tests/test_prompt_helper.py b/tests/test_prompt_helper.py index ef764370e..ad9349c8b 100644 --- a/tests/test_prompt_helper.py +++ b/tests/test_prompt_helper.py @@ -5,6 +5,7 @@ ChatMLPromptStyle, DefaultPromptStyle, Llama2PromptStyle, + Llama3PromptStyle, MistralPromptStyle, TagPromptStyle, get_prompt_style, @@ -139,3 +140,57 @@ def test_llama2_prompt_style_with_system_prompt(): ) assert prompt_style.messages_to_prompt(messages) == expected_prompt + + +def test_llama3_prompt_style_format(): + prompt_style = Llama3PromptStyle() + messages = [ + ChatMessage(content="You are a helpful assistant", role=MessageRole.SYSTEM), + ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER), + ] + + expected_prompt = ( + "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n" + "You are a helpful assistant<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n" + "Hello, how are you doing?<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + + assert prompt_style.messages_to_prompt(messages) == expected_prompt + + +def test_llama3_prompt_style_with_default_system(): + prompt_style = Llama3PromptStyle() + messages = [ + ChatMessage(content="Hello!", role=MessageRole.USER), + ] + expected = ( + "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n" + f"{prompt_style.DEFAULT_SYSTEM_PROMPT}<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\nHello!<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + assert prompt_style._messages_to_prompt(messages) == expected + + +def test_llama3_prompt_style_with_assistant_response(): + prompt_style = Llama3PromptStyle() + messages = [ + ChatMessage(content="You are a helpful assistant", role=MessageRole.SYSTEM), + ChatMessage(content="What is the capital of France?", role=MessageRole.USER), + ChatMessage( + content="The capital of France is Paris.", role=MessageRole.ASSISTANT + ), + ] + + expected_prompt = ( + "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n" + "You are a helpful assistant<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n" + "What is the capital of France?<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + "The capital of France is Paris.<|eot_id|>" + ) + + assert prompt_style.messages_to_prompt(messages) == expected_prompt