diff --git a/fern/docs/pages/recipes/list-llm.mdx b/fern/docs/pages/recipes/list-llm.mdx
index 2cb80e483..1e53804bb 100644
--- a/fern/docs/pages/recipes/list-llm.mdx
+++ b/fern/docs/pages/recipes/list-llm.mdx
@@ -24,7 +24,7 @@ user: {{ user_message }}
 assistant: {{ assistant_message }}
 ```
 
-And the "`tag`" style looks like this:
+The "`tag`" style looks like this:
 
 ```text
 <|system|>: {{ system_prompt }}
@@ -32,7 +32,23 @@ And the "`tag`" style looks like this:
 <|assistant|>: {{ assistant_message }}
 ```
 
-Some LLMs will not understand this prompt style, and will not work (returning nothing).
+The "`mistral`" style looks like this:
+
+```text
+[INST] You are an AI assistant. [/INST][INST] Hello, how are you doing? [/INST]
+```
+
+The "`chatml`" style looks like this:
+```text
+<|im_start|>system
+{{ system_prompt }}<|im_end|>
+<|im_start|>user
+{{ user_message }}<|im_end|>
+<|im_start|>assistant
+{{ assistant_message }}
+```
+
+Some LLMs will not understand these prompt styles, and will not work (returning nothing).
 You can try to change the prompt style to `default` (or `tag`) in the settings, and it will
 change the way the messages are formatted to be passed to the LLM.
 
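For readers skimming the patch, here is a minimal standalone sketch of the two new formats the docs above describe. The function names are illustrative only; the real implementations are the `MistralPromptStyle` and `ChatMLPromptStyle` classes added to `prompt_helper.py` below.

```python
def format_mistral(system_prompt: str, user_message: str) -> str:
    # "mistral" style: each message is wrapped in [INST] ... [/INST] tags,
    # with no separator between turns.
    return f"[INST] {system_prompt.strip()} [/INST][INST] {user_message.strip()} [/INST]"


def format_chatml(system_prompt: str, user_message: str) -> str:
    # "chatml" style: each turn sits between <|im_start|>{role} and <|im_end|>,
    # and the prompt ends with an open assistant turn for the model to complete.
    return (
        f"<|im_start|>system\n{system_prompt.strip()}<|im_end|>\n"
        f"<|im_start|>user\n{user_message.strip()}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )


print(format_mistral("You are an AI assistant.", "Hello, how are you doing?"))
# [INST] You are an AI assistant. [/INST][INST] Hello, how are you doing? [/INST]
```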
diff --git a/private_gpt/components/llm/prompt_helper.py b/private_gpt/components/llm/prompt_helper.py
index a8ca60f27..d1df9b814 100644
--- a/private_gpt/components/llm/prompt_helper.py
+++ b/private_gpt/components/llm/prompt_helper.py
@@ -123,8 +123,51 @@ def _completion_to_prompt(self, completion: str) -> str:
         )
 
 
+class MistralPromptStyle(AbstractPromptStyle):
+    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
+        prompt = ""
+        for message in messages:
+            role = message.role
+            content = message.content or ""
+            if role.lower() == "system":
+                message_from_user = f"[INST] {content.strip()} [/INST]"
+                prompt += message_from_user
+            elif role.lower() == "user":
+                prompt += ""
+                message_from_user = f"[INST] {content.strip()} [/INST]"
+                prompt += message_from_user
+        return prompt
+
+    def _completion_to_prompt(self, completion: str) -> str:
+        return self._messages_to_prompt(
+            [ChatMessage(content=completion, role=MessageRole.USER)]
+        )
+
+
+class ChatMLPromptStyle(AbstractPromptStyle):
+    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
+        prompt = "<|im_start|>system\n"
+        for message in messages:
+            role = message.role
+            content = message.content or ""
+            if role.lower() == "system":
+                message_from_user = f"{content.strip()}"
+                prompt += message_from_user
+            elif role.lower() == "user":
+                prompt += "<|im_end|>\n<|im_start|>user\n"
+                message_from_user = f"{content.strip()}<|im_end|>\n"
+                prompt += message_from_user
+        prompt += "<|im_start|>assistant\n"
+        return prompt
+
+    def _completion_to_prompt(self, completion: str) -> str:
+        return self._messages_to_prompt(
+            [ChatMessage(content=completion, role=MessageRole.USER)]
+        )
+
+
 def get_prompt_style(
-    prompt_style: Literal["default", "llama2", "tag"] | None
+    prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None
 ) -> AbstractPromptStyle:
     """Get the prompt style to use from the given string.
@@ -137,4 +180,8 @@ def get_prompt_style(
         return Llama2PromptStyle()
     elif prompt_style == "tag":
         return TagPromptStyle()
+    elif prompt_style == "mistral":
+        return MistralPromptStyle()
+    elif prompt_style == "chatml":
+        return ChatMLPromptStyle()
     raise ValueError(f"Unknown prompt_style='{prompt_style}'")
diff --git a/private_gpt/settings/settings.py b/private_gpt/settings/settings.py
index 7c58a762e..499ce66d7 100644
--- a/private_gpt/settings/settings.py
+++ b/private_gpt/settings/settings.py
@@ -110,13 +110,14 @@ class LocalSettings(BaseModel):
     embedding_hf_model_name: str = Field(
         description="Name of the HuggingFace model to use for embeddings"
     )
-    prompt_style: Literal["default", "llama2", "tag"] = Field(
+    prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] = Field(
         "llama2",
         description=(
             "The prompt style to use for the chat engine. "
             "If `default` - use the default prompt style from the llama_index. It should look like `role: message`.\n"
             "If `llama2` - use the llama2 prompt style from the llama_index. Based on `<s>`, `[INST]` and `<<SYS>>`.\n"
             "If `tag` - use the `tag` prompt style. It should look like `<|role|>: message`. \n"
+            "If `mistral` - use the `mistral` prompt style. It should look like `[INST] {System Prompt} [/INST][INST] {User Instructions} [/INST]`.\n"
             "`llama2` is the historic behaviour. `default` might work better with your custom models."
         ),
     )
diff --git a/scripts/extract_openapi.py b/scripts/extract_openapi.py
index ba6f138ad..15840d91f 100644
--- a/scripts/extract_openapi.py
+++ b/scripts/extract_openapi.py
@@ -1,6 +1,7 @@
 import argparse
 import json
 import sys
+
 import yaml
 from uvicorn.importer import import_from_string
 
diff --git a/settings.yaml b/settings.yaml
index b2ea0c698..d7e7ce028 100644
--- a/settings.yaml
+++ b/settings.yaml
@@ -51,7 +51,7 @@ qdrant:
   path: local_data/private_gpt/qdrant
 
 local:
-  prompt_style: "llama2"
+  prompt_style: "mistral"
   llm_hf_repo_id: TheBloke/Mistral-7B-Instruct-v0.2-GGUF
   llm_hf_model_file: mistral-7b-instruct-v0.2.Q4_K_M.gguf
   embedding_hf_model_name: BAAI/bge-small-en-v1.5
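With the dispatch and settings changes above in place, selecting and rendering a prompt could look like the following sketch. The imports mirror the ones used in the tests below, and `messages_to_prompt` is the public wrapper around `_messages_to_prompt` that the tests exercise.

```python
from llama_index.llms import ChatMessage, MessageRole

from private_gpt.components.llm.prompt_helper import get_prompt_style

# Same literal as `local.prompt_style` in settings.yaml.
prompt_style = get_prompt_style("mistral")

messages = [
    ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
    ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
]

print(prompt_style.messages_to_prompt(messages))
# [INST] You are an AI assistant. [/INST][INST] Hello, how are you doing? [/INST]
```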
[/INST]" + ) + + assert prompt_style.messages_to_prompt(messages) == expected_prompt + + +def test_chatml_prompt_style_format(): + prompt_style = ChatMLPromptStyle() + messages = [ + ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM), + ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER), + ] + + expected_prompt = ( + "<|im_start|>system\n" + "You are an AI assistant.<|im_end|>\n" + "<|im_start|>user\n" + "Hello, how are you doing?<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + assert prompt_style.messages_to_prompt(messages) == expected_prompt + + def test_llama2_prompt_style_format(): prompt_style = Llama2PromptStyle() messages = [