Commit

Add survey mode
cornzz committed Feb 8, 2025
1 parent a895fce commit 0afea2e
Showing 3 changed files with 22 additions and 19 deletions.
8 changes: 5 additions & 3 deletions README.md
@@ -10,10 +10,12 @@ pip install -r requirements.txt
 ```
 - Create a `.env` file, e.g.:
 ```
-LLM_ENDPOINT=https://api.openai.com/v1 # Optional. If not provided, only compression will be possible
+LLM_ENDPOINT=https://api.openai.com/v1 # Optional. If not provided, only compression will be possible.
 LLM_TOKEN=token_1234
-LLM_LIST=gpt-4o-mini, gpt-3.5-turbo # Optional. If not provided, a list of models will be fetched from the API
-FLAG_PASSWORD=very_secret # Optional. If not provided, /flagged and /logs endpoints are disabled
+LLM_LIST=gpt-4o-mini, gpt-3.5-turbo # Optional. If not provided, a list of models will be fetched from the API.
+SURVEY_MODE=false # Optional. If set to true, survey mode is enabled, i.e. answers are returned in random order and feedback can be submitted by the user.
+FLAG_PASSWORD=very_secret # Optional. If not provided, /flagged and /logs endpoints are disabled.
+APP_PATH=/ # Optional. Sets the root path of the application, for example if the application is behind a reverse proxy.
 ```

 ## Running
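A minimal sketch of how these variables can be read at startup. The use of python-dotenv here is an assumption for illustration and is not confirmed by this diff; only the `SURVEY_MODE` line mirrors code actually added in src/app.py below, where the exact string `true` (not `1`) enables survey mode.

```python
# Illustrative sketch, assuming python-dotenv; names match the .env example above.
import os

from dotenv import load_dotenv

load_dotenv()  # read .env from the working directory

LLM_ENDPOINT = os.getenv("LLM_ENDPOINT")  # None -> compression-only mode
LLM_TOKEN = os.getenv("LLM_TOKEN")
SURVEY_MODE = os.getenv("SURVEY_MODE", "false") == "true"  # mirrors src/app.py
APP_PATH = os.getenv("APP_PATH", "/")
```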
26 changes: 13 additions & 13 deletions src/app.py
@@ -39,6 +39,7 @@

 BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 CONSENT_POPUP = os.getenv("CONSENT_POPUP", "false")
+SURVEY_MODE = os.getenv("SURVEY_MODE", "false") == "true"
 FLAG_DIRECTORY = os.path.join(BASE_DIR, "../flagged")
 FLAG_PASSWORD = os.getenv("FLAG_PASSWORD")
 LOG_DIRECTORY = os.path.join(FLAG_DIRECTORY, "logs")
@@ -205,7 +206,7 @@ def run_demo(
             compressed,
             diff,
             metrics,
-            *shuffle_and_flatten(res_original, res_compressed),
+            *shuffle_and_flatten(res_compressed, res_original, survey_mode=SURVEY_MODE),
         ]
         + [gr.Button(interactive=not error, elem_classes="button-pulse" if not error else "")] * 4
         + [[None, None]]
@@ -234,7 +235,7 @@ def run_demo(
     )
     gr.Markdown(
         f"""
-        - **The order of the responses (prompt compressed / uncompressed) is randomized** and will be revealed after feedback submission.
+        {'- **The order of the responses (prompt compressed / uncompressed) is randomized** and will be revealed after feedback submission.' if SURVEY_MODE else ''}
         - LLMLingua-2 is a task-agnostic compression model, the value of the question field is not considered in the compression process. Compression is performed {'on a CPU. Using a GPU would be faster.' if not (MPS_AVAILABLE or CUDA_AVAILABLE) else f'on a GPU {"using MPS." if MPS_AVAILABLE else f"({torch.cuda.get_device_name()})."}'}
         - The example prompts were taken from the [MeetingBank-QA-Summary](https://huggingface.co/datasets/microsoft/MeetingBank-QA-Summary) dataset. Click on a question to autofill the question field.
         - Token counts are calculated using the [GPT-3.5/-4 tokenizer](https://platform.openai.com/tokenizer), actual counts may vary for different target models. The saving metric is based on an API pricing of $0.03 / 1000 tokens.
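The token-count note in the last bullet can be reproduced with a short sketch. Assumptions: the tiktoken package and its cl100k_base encoding (used by GPT-3.5/-4); the demo's own counting code is not shown in this diff.

```python
# Hedged sketch: GPT-3.5/-4 token counting and the saving metric's pricing basis.
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")  # GPT-3.5/-4 encoding
n_tokens = len(enc.encode("Summarize the meeting transcript below."))
saving = n_tokens * 0.03 / 1000  # $0.03 per 1000 tokens, as stated above
print(n_tokens, f"${saving:.5f}")
```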
@@ -323,25 +324,24 @@ def run_demo(
     compressed = gr.Textbox(label="Compressed Prompt", visible=False)
     with gr.Row(elem_classes="responses") as responses:
         with gr.Column(elem_classes="responses"):
-            response_a = gr.Textbox(
-                label="LLM Response A", lines=10, max_lines=10, autoscroll=False, interactive=False
-            )
+            res_label_a = "LLM Response A" if SURVEY_MODE else "LLM Response Compressed Prompt"
+            response_a = gr.Textbox(label=res_label_a, lines=10, max_lines=10, autoscroll=False, interactive=False)
             response_a_obj = gr.Textbox(label="Response A", visible=False)
-            with gr.Row():
+            with gr.Row(visible=SURVEY_MODE):
                 a_yes = gr.Button("✅", interactive=False)
                 a_no = gr.Button("❌", interactive=False)
         with gr.Column(elem_classes="responses"):
-            response_b = gr.Textbox(
-                label="LLM Response B", lines=10, max_lines=10, autoscroll=False, interactive=False
-            )
+            res_label_b = "LLM Response B" if SURVEY_MODE else "LLM Response Original Prompt"
+            response_b = gr.Textbox(label=res_label_b, lines=10, max_lines=10, autoscroll=False, interactive=False)
             response_b_obj = gr.Textbox(label="Response B", visible=False)
-            with gr.Row():
+            with gr.Row(visible=SURVEY_MODE):
                 b_yes = gr.Button("✅", interactive=False)
                 b_no = gr.Button("❌", interactive=False)
     FLAG_BUTTONS = [a_yes, a_no, b_yes, b_no]
     gr.Markdown(
         '<div class="button-hint"><b>Please click on one of the two buttons <em>for each answer &nbsp;</em>to submit feedback.</b><br>'
-        "✅ = answered your question / solved your problem&nbsp;&nbsp;&nbsp; ❌ = did not answer your question / solve your problem.</div>"
+        "✅ = answered your question / solved your problem&nbsp;&nbsp;&nbsp; ❌ = did not answer your question / solve your problem.</div>",
+        visible=SURVEY_MODE,
     )

     # States
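The hunk above gates every survey-only widget on a single flag via `visible=SURVEY_MODE`. A self-contained sketch of the pattern (hypothetical layout, not the demo's actual one):

```python
# Minimal illustration of visible=SURVEY_MODE gating in Gradio.
import gradio as gr

SURVEY_MODE = False  # parsed from the environment, as in src/app.py

with gr.Blocks() as demo:
    response = gr.Textbox(label="LLM Response", lines=10, interactive=False)
    with gr.Row(visible=SURVEY_MODE):  # feedback row is hidden outside survey mode
        yes = gr.Button("✅", interactive=False)
        no = gr.Button("❌", interactive=False)
```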
@@ -368,7 +368,7 @@
     )
     clear.click(
         lambda: [None] * 6
-        + [gr.Textbox(label="LLM Response A", value=None), gr.Textbox(label="LLM Response B", value=None)]
+        + [gr.Textbox(label=res_label_a, value=None), gr.Textbox(label=res_label_b, value=None)]
         + [create_metrics_df(), gr.Dataset(visible=True), gr.Button(visible=False), gr.DataFrame(visible=False)]
        + [gr.Button(elem_classes="", interactive=False)] * 4
        + [[None, None]],
@@ -424,7 +424,7 @@ def handle_flag_selection(question, prompt, compressed, rate, metrics, res_a, re
         flagging_callback.flag(args, flag_option=json.dumps(flags), username=request.cookies["session"])
         gr.Info("Preference saved. Thank you for your feedback.")
         get_label = lambda res: "LLM Response " + (
-            "(compressed prompt)" if '"compressed": true' in res else "(original prompt)"
+            "Compressed Prompt" if '"compressed": true' in res else "Original Prompt"
         )
         return gr.Textbox(label=get_label(res_a)), gr.Textbox(label=get_label(res_b))
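The relabeling above is what reveals the true response order after feedback is submitted. A quick check with hypothetical payloads (the hidden response objects are matched as raw JSON substrings):

```python
# Hypothetical payloads shaped like the hidden response_*_obj values.
get_label = lambda res: "LLM Response " + (
    "Compressed Prompt" if '"compressed": true' in res else "Original Prompt"
)

print(get_label('{"compressed": true, "text": "..."}'))   # LLM Response Compressed Prompt
print(get_label('{"compressed": false, "text": "..."}'))  # LLM Response Original Prompt
```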

7 changes: 4 additions & 3 deletions src/utils.py
@@ -80,9 +80,10 @@ def update_label(content: str, component: gr.Textbox | gr.HighlightedText) -> gr
     return gr.Textbox(label=new_label) if isinstance(component, gr.Textbox) else gr.HighlightedText(label=new_label)


-def shuffle_and_flatten(original: dict[str, object], compressed: dict[str, object]):
-    responses = [original, compressed]
-    shuffle(responses)
+def shuffle_and_flatten(compressed: dict[str, object], original: dict[str, object], survey_mode: bool) -> iter:
+    responses = [compressed, original]
+    if survey_mode:
+        shuffle(responses)
     return (x for xs in responses for x in xs.values())
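Usage sketch for the new signature (the response dicts are hypothetical): with `survey_mode=False` the argument order is preserved, compressed first, which is what lets the UI label the two columns directly; with `survey_mode=True` the pair is shuffled before flattening.

```python
# Assumes shuffle_and_flatten as defined above; dict shapes are illustrative.
res_compressed = {"response": "A", "obj": '{"compressed": true}'}
res_original = {"response": "B", "obj": '{"compressed": false}'}

flat = list(shuffle_and_flatten(res_compressed, res_original, survey_mode=False))
assert flat == ["A", '{"compressed": true}', "B", '{"compressed": false}']

# survey_mode=True randomizes which response lands in column A.
shuffled = list(shuffle_and_flatten(res_compressed, res_original, survey_mode=True))
```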


