From 7269e908dde9b0e190b5e464f694582e6a3f7798 Mon Sep 17 00:00:00 2001 From: Baptiste Pasquier Date: Wed, 27 Mar 2024 10:11:25 +0100 Subject: [PATCH] :bug: fix dollar in markdown --- frontend/lib/basic_chat.py | 4 +++- frontend/lib/session_chat.py | 6 ++++-- frontend/lib/utils.py | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 3 deletions(-) create mode 100644 frontend/lib/utils.py diff --git a/frontend/lib/basic_chat.py b/frontend/lib/basic_chat.py index baffe47..00fce6d 100644 --- a/frontend/lib/basic_chat.py +++ b/frontend/lib/basic_chat.py @@ -1,5 +1,7 @@ import streamlit as st +from .utils import format_string + def basic_chat(): user_question = st.chat_input("Say something") @@ -13,4 +15,4 @@ def basic_chat(): response = chain.stream(user_question) with st.chat_message("assistant"): - st.write(response) + st.write(format_string(response)) diff --git a/frontend/lib/session_chat.py b/frontend/lib/session_chat.py index 064276a..60d5a9d 100644 --- a/frontend/lib/session_chat.py +++ b/frontend/lib/session_chat.py @@ -6,6 +6,8 @@ from frontend.lib.backend_interface import query +from .utils import format_string + @dataclass class Message: @@ -28,7 +30,7 @@ def session_chat(): with st.container(border=True): for message in st.session_state.get("messages", []): with st.chat_message(message.sender): - st.write(message.content) + st.write(format_string(message.content)) if user_question: if len(st.session_state.get("messages", [])) == 0: @@ -54,7 +56,7 @@ def session_chat(): placeholder = st.empty() for chunk in response: full_response += chunk - placeholder.write(full_response) + placeholder.write(format_string(full_response)) bot_message = Message("assistant", full_response, session_id) st.session_state["messages"].append(bot_message) diff --git a/frontend/lib/utils.py b/frontend/lib/utils.py new file mode 100644 index 0000000..babadae --- /dev/null +++ b/frontend/lib/utils.py @@ -0,0 +1,34 @@ +"""Utility functions for the frontend.""" + +import re + + +def protect_dollar(string: str) -> str: + r"""Escape unescaped dollar signs in a string. + + This function takes a string and returns a new string where all dollar signs ($) + that are not preceded by a backslash (\) are escaped with an additional backslash. + This is useful for preparing strings that contain dollar signs for environments + where the dollar sign may be interpreted as a special character, such as in + Markdown. + + Args: + string (str): The input string containing dollar signs to be escaped. + + Returns: + str: A new string with unescaped dollar signs preceded by a backslash. + """ + return re.sub(r"(? str: + """Format a string for safe Markdown usage. + + Args: + string (str): The input string to be formatted. + + Returns: + str: The formatted string. + """ + string = protect_dollar(string) + return string