feat: add mistral + chatml prompts (#1426)
CognitiveTech authored Jan 16, 2024
1 parent 6191bcd commit e326126
Showing 6 changed files with 107 additions and 5 deletions.
20 changes: 18 additions & 2 deletions fern/docs/pages/recipes/list-llm.mdx
@@ -24,15 +24,31 @@ user: {{ user_message }}
assistant: {{ assistant_message }}
```

-And the "`tag`" style looks like this:
+The "`tag`" style looks like this:

```text
<|system|>: {{ system_prompt }}
<|user|>: {{ user_message }}
<|assistant|>: {{ assistant_message }}
```

-Some LLMs will not understand this prompt style, and will not work (returning nothing).
The "`mistral`" style looks like this:

```text
<s>[INST] You are an AI assistant. [/INST]</s>[INST] Hello, how are you doing? [/INST]
```

The "`chatml`" style looks like this:
```text
<|im_start|>system
{{ system_prompt }}<|im_end|>
<|im_start|>user"
{{ user_message }}<|im_end|>
<|im_start|>assistant
{{ assistant_message }}
```

+Some LLMs will not understand these prompt styles, and will not work (returning nothing).
You can try to change the prompt style to `default` (or `tag`) in the settings, and it will
change the way the messages are formatted to be passed to the LLM.
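
To switch styles, set `prompt_style` in your settings. As a sketch, this mirrors the `local:` block that this same commit updates in `settings.yaml` further down:

```yaml
local:
  prompt_style: "mistral"  # one of: default, llama2, tag, mistral, chatml
```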

49 changes: 48 additions & 1 deletion private_gpt/components/llm/prompt_helper.py
@@ -123,8 +123,51 @@ def _completion_to_prompt(self, completion: str) -> str:
        )


class MistralPromptStyle(AbstractPromptStyle):
    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        prompt = "<s>"
        for message in messages:
            role = message.role
            content = message.content or ""
            if role.lower() == "system":
                # System instructions get their own [INST] block.
                message_from_user = f"[INST] {content.strip()} [/INST]"
                prompt += message_from_user
            elif role.lower() == "user":
                # Close the preceding sequence before the user turn.
                prompt += "</s>"
                message_from_user = f"[INST] {content.strip()} [/INST]"
                prompt += message_from_user
        return prompt

    def _completion_to_prompt(self, completion: str) -> str:
        return self._messages_to_prompt(
            [ChatMessage(content=completion, role=MessageRole.USER)]
        )
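
As a quick illustration (a sketch, not part of the diff; it uses the public `messages_to_prompt` wrapper exercised by the tests below), a system plus user exchange reproduces the string shown in the docs. Note that assistant messages fall through both branches and are skipped:

```python
from llama_index.llms import ChatMessage, MessageRole

style = MistralPromptStyle()
print(
    style.messages_to_prompt(
        [
            ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
            ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
        ]
    )
)
# <s>[INST] You are an AI assistant. [/INST]</s>[INST] Hello, how are you doing? [/INST]
```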


class ChatMLPromptStyle(AbstractPromptStyle):
    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        prompt = "<|im_start|>system\n"
        for message in messages:
            role = message.role
            content = message.content or ""
            if role.lower() == "system":
                message_from_user = f"{content.strip()}"
                prompt += message_from_user
            elif role.lower() == "user":
                # Close the system block, then open the user turn.
                prompt += "<|im_end|>\n<|im_start|>user\n"
                message_from_user = f"{content.strip()}<|im_end|>\n"
                prompt += message_from_user
        # Leave the assistant header open so the model generates from here.
        prompt += "<|im_start|>assistant\n"
        return prompt

    def _completion_to_prompt(self, completion: str) -> str:
        return self._messages_to_prompt(
            [ChatMessage(content=completion, role=MessageRole.USER)]
        )
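
Two behaviours are worth noting (again a sketch, assuming the public `completion_to_prompt` wrapper from `AbstractPromptStyle`): the prompt always opens with a `system` header even when no system message is supplied, and it always ends with an open `<|im_start|>assistant` header so the model completes the assistant turn:

```python
style = ChatMLPromptStyle()
print(style.completion_to_prompt("Hello, how are you doing?"))
# <|im_start|>system
# <|im_end|>
# <|im_start|>user
# Hello, how are you doing?<|im_end|>
# <|im_start|>assistant
```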


def get_prompt_style(
-    prompt_style: Literal["default", "llama2", "tag"] | None
+    prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None
) -> AbstractPromptStyle:
    """Get the prompt style to use from the given string.
@@ -137,4 +180,8 @@ def get_prompt_style(
        return Llama2PromptStyle()
    elif prompt_style == "tag":
        return TagPromptStyle()
    elif prompt_style == "mistral":
        return MistralPromptStyle()
    elif prompt_style == "chatml":
        return ChatMLPromptStyle()
    raise ValueError(f"Unknown prompt_style='{prompt_style}'")
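
A minimal usage sketch for the factory (imports as in the test file below; not part of the commit):

```python
from llama_index.llms import ChatMessage, MessageRole

from private_gpt.components.llm.prompt_helper import get_prompt_style

style = get_prompt_style("mistral")  # -> MistralPromptStyle
prompt = style.messages_to_prompt(
    [ChatMessage(content="Hello!", role=MessageRole.USER)]
)
# Any name outside the Literal, e.g. get_prompt_style("alpaca"),
# falls through to the ValueError above.
```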
3 changes: 2 additions & 1 deletion private_gpt/settings/settings.py
@@ -110,13 +110,14 @@ class LocalSettings(BaseModel):
    embedding_hf_model_name: str = Field(
        description="Name of the HuggingFace model to use for embeddings"
    )
-    prompt_style: Literal["default", "llama2", "tag"] = Field(
+    prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] = Field(
        "llama2",
        description=(
            "The prompt style to use for the chat engine. "
            "If `default` - use the default prompt style from the llama_index. It should look like `role: message`.\n"
            "If `llama2` - use the llama2 prompt style from the llama_index. Based on `<s>`, `[INST]` and `<<SYS>>`.\n"
            "If `tag` - use the `tag` prompt style. It should look like `<|role|>: message`.\n"
            "If `mistral` - use the `mistral` prompt style. It should look like `<s>[INST] {System Prompt} [/INST]</s>[INST] { UserInstructions } [/INST]`.\n"
            "`llama2` is the historic behaviour. `default` might work better with your custom models."
        ),
    )
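
Because `prompt_style` is a `Literal`, pydantic rejects unsupported values when settings load. A sketch (hypothetical field values, mirroring the `local:` block in `settings.yaml`; `LocalSettings` may require more fields than shown):

```python
from private_gpt.settings.settings import LocalSettings

settings = LocalSettings(
    llm_hf_repo_id="TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
    llm_hf_model_file="mistral-7b-instruct-v0.2.Q4_K_M.gguf",
    embedding_hf_model_name="BAAI/bge-small-en-v1.5",
    prompt_style="chatml",
)
# prompt_style="alpaca" would raise a pydantic ValidationError here.
```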
1 change: 1 addition & 0 deletions scripts/extract_openapi.py
@@ -1,6 +1,7 @@
import argparse
import json
import sys

import yaml
from uvicorn.importer import import_from_string

2 changes: 1 addition & 1 deletion settings.yaml
@@ -51,7 +51,7 @@ qdrant:
  path: local_data/private_gpt/qdrant

local:
-  prompt_style: "llama2"
+  prompt_style: "mistral"
  llm_hf_repo_id: TheBloke/Mistral-7B-Instruct-v0.2-GGUF
  llm_hf_model_file: mistral-7b-instruct-v0.2.Q4_K_M.gguf
  embedding_hf_model_name: BAAI/bge-small-en-v1.5
37 changes: 37 additions & 0 deletions tests/test_prompt_helper.py
@@ -2,8 +2,10 @@
from llama_index.llms import ChatMessage, MessageRole

from private_gpt.components.llm.prompt_helper import (
    ChatMLPromptStyle,
    DefaultPromptStyle,
    Llama2PromptStyle,
    MistralPromptStyle,
    TagPromptStyle,
    get_prompt_style,
)
@@ -15,6 +17,8 @@
("default", DefaultPromptStyle),
("llama2", Llama2PromptStyle),
("tag", TagPromptStyle),
("mistral", MistralPromptStyle),
("chatml", ChatMLPromptStyle),
],
)
def test_get_prompt_style_success(prompt_style, expected_prompt_style):
@@ -62,6 +66,39 @@ def test_tag_prompt_style_format_with_system_prompt():
    assert prompt_style.messages_to_prompt(messages) == expected_prompt


def test_mistral_prompt_style_format():
    prompt_style = MistralPromptStyle()
    messages = [
        ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
    ]

    expected_prompt = (
        "<s>[INST] You are an AI assistant. [/INST]</s>"
        "[INST] Hello, how are you doing? [/INST]"
    )

    assert prompt_style.messages_to_prompt(messages) == expected_prompt


def test_chatml_prompt_style_format():
    prompt_style = ChatMLPromptStyle()
    messages = [
        ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
    ]

    expected_prompt = (
        "<|im_start|>system\n"
        "You are an AI assistant.<|im_end|>\n"
        "<|im_start|>user\n"
        "Hello, how are you doing?<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    assert prompt_style.messages_to_prompt(messages) == expected_prompt
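
A further check one could add (a sketch, not in this commit): `_completion_to_prompt` routes the completion through the user branch, so with no system message the `</s>` lands directly after the opening `<s>`:

```python
def test_mistral_prompt_style_completion():
    prompt_style = MistralPromptStyle()
    expected = "<s></s>[INST] Hello, how are you doing? [/INST]"
    assert prompt_style.completion_to_prompt("Hello, how are you doing?") == expected
```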


def test_llama2_prompt_style_format():
    prompt_style = Llama2PromptStyle()
    messages = [
