Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLC-22] prompt: use CO-STAR to fix Gemma... #27

Merged
merged 1 commit into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion app/src/components/chat/Chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ const Chat = ({
: newHistory.filter((chat) => chat.role !== 'system'),
temperature: 0.7,
// eslint-disable-next-line @typescript-eslint/naming-convention
top_p: 0.95,
top_p: 1.0,
// eslint-disable-next-line @typescript-eslint/naming-convention
max_tokens: 200,
directory: selectedDirectory,
Expand Down
75 changes: 27 additions & 48 deletions server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,51 +82,24 @@ def create_response(chat_id, prompt, tokens, text):
return response


def format_messages(messages, context):
failedString = "ERROR"
if context:
messages[-1]['content'] = f"""
Only using the documents in the index, answer the following, respond with jsut the answer and never "The answer is:" or "Answer:" or anything like that.

<BEGIN_QUESTION>
{messages[-1]['content']}
</END_QUESTION>

<BEGIN_INDEX>
{context}
</END_INDEX>

Try to give as much detail as possible, but only from what is provided within the index.
If steps are given, you MUST ALWAYS use bullet points to list each of them them and you MUST use markdown when applicable.
Only use information you can find in the index, do not make up knowledge.
NEVER try to make up the answer, always return "{failedString}" if you do not know the answer or it's not provided in the index.
def format_messages(messages: List[Dict], indexed_files: Optional[str],
                    instructions: Optional[Dict]) -> List[Dict]:
    """Rewrite the last entry of *messages* into a CO-STAR structured prompt.

    Mutates ``messages[-1]['content']`` in place and returns *messages* for
    convenience.

    :param messages: chat history; the last element is the user turn to wrap.
    :param indexed_files: retrieved document text to inject as background
        knowledge, or ``None``/empty when no directory index was queried.
    :param instructions: optional user preferences with ``'personalization'``
        and ``'response'`` keys; safe to pass ``None``.
    """
    # Guard: the caller passes None when no instructions were supplied.
    instructions = instructions or {}
    # Collapse user-entered newlines so each CO-STAR field stays on one line.
    personalization = instructions.get(
        'personalization', '').strip().replace('\n', '; ')
    response = instructions.get('response', '').strip().replace('\n', '; ')

    context = ''
    if indexed_files:
        knowledge = indexed_files.strip().replace('\n', '; ')
        context = f"with background knowledge of {knowledge}"
    audience = personalization if personalization else 'general'
    style = response if response else 'technical, accurate, and professional'

    messages[-1]['content'] = f"""
Context: you are a personalized AI chatbot {context}
Objective: respond to the following, {messages[-1]['content']}
Style: {style}
Tone: friendly, helpful, and confident
Audience: {audience}
Response: brief, concise, and to the point. Please don't start with "Sure, ..."
""".strip()
    return messages


def add_instructions(messages: List[Dict], instructions: Optional[Dict]) -> None:
    """Prefix the last message with user-supplied assistant instructions.

    Mutates ``messages[-1]['content']`` in place; leaves it untouched when no
    non-empty instruction text is provided. Returns ``None``.

    :param messages: chat history; the last element is the user turn to prefix.
    :param instructions: optional dict with ``'personalization'`` and
        ``'response'`` keys; safe to pass ``None``.
    """
    # Guard: the caller passes None when no instructions were supplied.
    if not instructions:
        return
    personalization = instructions.get('personalization', '').strip()
    response = instructions.get('response', '').strip()

    if not personalization and not response:
        return

    content = ''
    if personalization:
        content += ("You are an assistant who knows the following about me:\n"
                    f"{personalization}\n\n")
    if response:
        content += ("You are an assistant who responds based on the following "
                    f"specifications:\n{response}\n\n")

    messages[-1]['content'] = content + messages[-1]['content']


class APIHandler(BaseHTTPRequestHandler):
Expand Down Expand Up @@ -214,21 +187,22 @@ def query(self, body):
messages = body.get('messages', [])
instructions = body.get('instructions', None)

indexed_files = ''
if directory:
            # empirically better than `similarity_search`
docs = _database.max_marginal_relevance_search(
body['messages'][-1]['content'],
messages[-1]['content'],
k=4 # number of documents to return
)
context = '\n'.join([doc.page_content for doc in docs])
format_messages(body['messages'], context)
indexed_files = '\n'.join([doc.page_content for doc in docs])

print(body, flush=True)
print(('\n'+'--'*10+'\n').join([
f'{doc.metadata}\n{doc.page_content}' for doc in docs]), flush=True)

add_instructions(messages, instructions)
format_messages(messages, indexed_files, instructions)
print(messages, flush=True)

prompt = mx.array(_tokenizer.encode(_tokenizer.apply_chat_template(
messages,
tokenize=False,
Expand Down Expand Up @@ -259,6 +233,11 @@ def query(self, body):
tokens.append(token.item())

text = _tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, '')
# TODO: GEMMA IS OBSESSED WITH "Sure, ..."
if text.startswith('Sure, '):
text = text.split('\n')
text[0] = text[0].replace('Sure, ', '').capitalize()
text = '\n'.join([l for l in text])
return create_response(chat_id, prompt, tokens, text)


Expand Down
Loading