Commit
remove choices from the framework (#7)
gautamjajoo authored Feb 7, 2025
1 parent 22ab82c commit 90887c9
Showing 14 changed files with 154 additions and 165 deletions.
8 changes: 4 additions & 4 deletions config/ethos.json
@@ -10,8 +10,8 @@
"beam_width": 3,
"group_frequency": 2,
"content_safety": true,
"cache_path": "cache/ethos.db",
"initial_prompt": "Given a query, you have to tell if it contains hate speech in any form or not.",
"cache_path": "",
"initial_prompt": "Given a query, you have to tell if it contains hate speech in any form or not. Output 0 means Non-Hate and 1 means Hate.",
"dataset_name": "ethos",
"number_of_groups": 2,
"metric_kwargs": {
@@ -38,7 +38,7 @@
"api_type": "",
"api_base": "",
"api_version": "",
"api_key": ""
"api_key": ""
}
},
"grouping_llm": {
@@ -50,7 +50,7 @@
"api_type": "",
"api_base": "",
"api_version": "",
"api_key": ""
"api_key": ""
}
}
}
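Note: the new initial_prompt asks the solver for a bare 0/1 label (0 = Non-Hate, 1 = Hate) instead of picking between named choices. Below is a minimal sketch of parsing such a label out of the <Answer></Answer> format used by the repo's metaprompts; parse_binary_label is a hypothetical helper for illustration, not the repo's extract_answer.

```python
import re
from typing import Optional

def parse_binary_label(completion: str) -> Optional[str]:
    # Look for a 0 or 1 wrapped in <Answer></Answer> tags (illustrative helper only).
    match = re.search(r"<Answer>\s*([01])\s*</Answer>", completion)
    return match.group(1) if match else None

# Example chain-of-thought completion ending in the tagged label.
print(parse_binary_label("The text targets a group ... <Answer>1</Answer>"))  # -> "1"
```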
21 changes: 12 additions & 9 deletions config/gsm8k.json
@@ -20,35 +20,38 @@
"solver_llm": {
"model_kwargs": {
"temperature": 0,
"model": "gpt-4o"
"temp_0": 0,
"model": "meta-llama/llama-3.3-70b-instruct"
},
"api_kwargs": {
"api_type": "",
"api_base": "",
"api_type": "openrouter",
"api_base": "https://openrouter.ai/api/v1",
"api_version": "",
"api_key": ""
}
},
"expert_llm": {
"model_kwargs": {
"temperature": 0.8,
"model": "gpt-4o"
"temp_0": 0,
"model": "meta-llama/llama-3.3-70b-instruct"
},
"api_kwargs": {
"api_type": "",
"api_base": "",
"api_type": "openrouter",
"api_base": "https://openrouter.ai/api/v1",
"api_version": "",
"api_key": ""
}
},
"grouping_llm": {
"model_kwargs": {
"temperature": 0.8,
"model": "gpt-4o"
"temp_0": 0,
"model": "meta-llama/llama-3.3-70b-instruct"
},
"api_kwargs": {
"api_type": "",
"api_base": "",
"api_type": "openrouter",
"api_base": "https://openrouter.ai/api/v1",
"api_version": "",
"api_key": ""
}
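Note: all three LLM roles now point at OpenRouter's OpenAI-compatible endpoint. Below is a minimal sketch of how such api_kwargs/model_kwargs translate into a request, assuming the openai v1 Python SDK and an OPENROUTER_API_KEY environment variable (neither of which this commit pins down).

```python
import os
from openai import OpenAI  # OpenRouter speaks the same protocol as the OpenAI API

api_kwargs = {"api_type": "openrouter", "api_base": "https://openrouter.ai/api/v1"}
model_kwargs = {"temperature": 0, "model": "meta-llama/llama-3.3-70b-instruct"}

client = OpenAI(base_url=api_kwargs["api_base"], api_key=os.environ["OPENROUTER_API_KEY"])
response = client.chat.completions.create(
    model=model_kwargs["model"],
    temperature=model_kwargs["temperature"],
    messages=[{"role": "user", "content": "Janet has 3 apples and buys 2 more. How many does she have now?"}],
)
print(response.choices[0].message.content)
```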
200 changes: 100 additions & 100 deletions data/ethos.jsonl

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions src/uniprompt/beam_search/beam.py
@@ -13,10 +13,10 @@ def __init__(self, beam_width: int):
self.beam = []

def initialize_candidates(self, initial_prompt: str, data: Tuple, config: Dict, eval_fn: Optional[Callable] = None) -> List[Tuple[float, str, float]]:
val_questions, val_choices, val_answers = data
val_questions, val_answers = data
if eval_fn is None:
eval_fn = evaluate_prompt
eval_result = eval_fn(initial_prompt, val_questions, val_answers, val_choices, config)
eval_result = eval_fn(initial_prompt, val_questions, val_answers, config)
self.beam.append((-eval_result["acc"], initial_prompt, 0))

def get_best_prompt(self) -> str:
@@ -31,15 +31,15 @@ def apply_edits(prompt: str, prompt_template: str, edits: str, config: Dict[str,
def apply_edits_to_beam(self, final_feedback: str, val_data: Tuple, config) -> None:
new_candidates = []
prompts = load_prompts()
val_questions, val_choices, val_answers = val_data
val_questions, val_answers = val_data
for beam_acc, beam_prompt, _ in self.beam:
new_prompt = apply_edits(
prompt=beam_prompt,
prompt_template=prompts.get("apply_edits", None),
edits=final_feedback,
config=config,
)
eval_result = evaluate_prompt(new_prompt, questions=val_questions, choices=val_choices, answers=val_answers, config=config)
eval_result = evaluate_prompt(new_prompt, questions=val_questions, answers=val_answers, config=config)
acc, cm = eval_result["acc"], eval_result["cm"]
new_candidates.append((-acc, new_prompt, -(beam_acc-acc)))

@@ -48,11 +48,11 @@ def apply_edits_to_beam(self, final_feedback: str, val_data: Tuple, config) -> N
return self.beam

def add_prompt_to_beam(self, prompt: str, val_data: Tuple, config: Dict, eval_fn: Optional[Callable] = None) -> None:
val_questions, val_choices, val_answers = val_data
val_questions, val_answers = val_data
if eval_fn is None:
eval_fn = evaluate_prompt

eval_result = eval_fn(prompt, questions=val_questions, choices=val_choices, answers=val_answers, config = config)
eval_result = eval_fn(prompt, questions=val_questions, answers=val_answers, config = config)
print(f"Metrics: {eval_result}")
new_candidate = (-eval_result["acc"], prompt, (self.beam[0][0] + (eval_result["acc"])))
self.beam = heapq.nsmallest(self.beam_width, self.beam + [new_candidate])
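Note: with choices removed, beam candidates are still ranked by negated accuracy so that heapq.nsmallest keeps the most accurate prompts. Below is a self-contained toy version of that bookkeeping, with stubbed accuracies standing in for evaluate_prompt calls.

```python
import heapq

beam_width = 3
# Each entry is (-accuracy, prompt, improvement), mirroring the tuples used above.
beam = [(-0.72, "Classify the text as hate (1) or non-hate (0).", 0)]

def add_candidate(beam, prompt, accuracy):
    improvement = accuracy + beam[0][0]  # accuracy minus the current best accuracy
    return heapq.nsmallest(beam_width, beam + [(-accuracy, prompt, improvement)])

beam = add_candidate(beam, "Think step by step, then answer 0 or 1.", 0.78)
beam = add_candidate(beam, "Answer 0 or 1 with no explanation.", 0.64)
print(beam[0][1])  # best prompt so far
```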
20 changes: 6 additions & 14 deletions src/uniprompt/data.py
@@ -3,52 +3,45 @@
from datasets import load_dataset
from uniprompt.utils.summary_utils import load_prompts
import re
from uniprompt.utils.api_utils import chat_completion

def load_data(dataset_name: str, split: Optional[dict] = None) -> tuple:
base_path = f"data/{dataset_name}"
with open(f"{base_path}.jsonl") as f:
data = [json.loads(line) for line in f]

train_questions = []
train_choices = []
train_answers = []
val_questions = []
val_choices = []
val_answers = []
test_questions = []
test_choices = []
test_answers = []

for question in data:
if question["split"] == "train":
train_questions.append(question["question"])
train_choices.append(question["choices"])
train_answers.append(question["answer"])
elif question["split"] == "validation":
val_questions.append(question["question"])
val_choices.append(question["choices"])
val_answers.append(question["answer"])
elif question["split"] == "test":
test_questions.append(question["question"])
test_choices.append(question["choices"])
test_answers.append(question["answer"])

train_set = (list(train_questions), list(train_choices), list(train_answers))
val_set = (list(val_questions), list(val_choices), list(val_answers))
test_set = (list(test_questions), list(test_choices), list(test_answers))
train_set = (list(train_questions), list(train_answers))
val_set = (list(val_questions), list(val_answers))
test_set = (list(test_questions), list(test_answers))

return train_set, val_set, test_set

def write_to_jsonl(f, split, data):
questions = data["text"]
answers = data["label"]
choices = [["Non-Hate", "Hate"]] * len(questions)
for i in range(len(questions)):
write_data = {
"split": split,
"question": questions[i],
"choices": choices[i],
"answer": choices[i][answers[i]],
"answer": f"{answers[i]}",
}
json.dump(write_data, f)
f.write("\n")
@@ -79,7 +72,6 @@ def default_write_to_jsonl(f, split, data):
write_data = {
"split": split,
"question": questions[i],
"choices": "",
"answer": extracted_answer,
}
json.dump(write_data, f)
@@ -110,7 +102,7 @@ def create_gsm8k_dataset(output_path):
default_write_to_jsonl(f, "validation", val_subset)
default_write_to_jsonl(f, "test", test_subset)

def add_rationale_to_dataset(dataset_name: str):
def add_rationale_to_dataset(dataset_name: str, config):
prompts = load_prompts()
add_rationale_prompt = prompts.get("add_rationale", None)

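Note: each JSONL record now carries only the split, the question text, and a string answer; the choices field is gone. Below is a small sketch of writing and reading records in the new shape (the file path and example texts are illustrative).

```python
import json

records = [
    {"split": "train", "question": "Everyone deserves respect regardless of background.", "answer": "0"},
    {"split": "test", "question": "Example test question goes here.", "answer": "1"},
]

with open("data/example.jsonl", "w") as f:
    for record in records:
        json.dump(record, f)
        f.write("\n")

with open("data/example.jsonl") as f:
    rows = [json.loads(line) for line in f]

# The loader above now returns (questions, answers) pairs with no choices list.
train_set = ([r["question"] for r in rows if r["split"] == "train"],
             [r["answer"] for r in rows if r["split"] == "train"])
print(train_set)
```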
9 changes: 4 additions & 5 deletions src/uniprompt/evaluate.py
@@ -6,14 +6,13 @@


def evaluate(data, prompt, config):
questions, choices, answers = data
return evaluate_prompt(prompt, questions, answers, choices, config)
questions, answers = data
return evaluate_prompt(prompt, questions, answers, config)

def evaluate_prompt(
new_prompt: str,
questions: Sequence[str],
answers: Sequence[str],
choices: Sequence[List[str]],
config: Dict[str, Any],
) -> Dict[str, Union[float, List[List[float]]]]:
acc = 0
@@ -22,15 +21,15 @@ def evaluate_prompt(
y_pred = []
i = 0

for question, answer, choice in zip(questions, answers, choices):
for question, answer in zip(questions, answers):
i+=1
# if answer not in choice:
# if answer == "1":
# answer = choice[1]
# if answer == "0":
# answer = choice[0]

prompt = make_prompt(prompt=new_prompt, question=question, choices=choice, template="make_prompt")
prompt = make_prompt(prompt=new_prompt, question=question, template="make_prompt")
messages = [{"role": "system", "content": "You are an expert"}, {"role": "user", "content": prompt}]
answer_cot = chat_completion(cache_path=config["cache_path"], **config["solver_llm"], messages=messages)
answer_llm = extract_answer(answer_cot)
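Note: evaluate_prompt now iterates over (question, answer) pairs and aggregates accuracy plus a confusion matrix from y_true/y_pred. Below is a pure-Python sketch of that aggregation step; the repo's get_metric may compute these differently.

```python
from collections import Counter

y_true = ["1", "0", "1", "0", "1"]
y_pred = ["1", "0", "0", "0", "1"]

acc = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

labels = ["0", "1"]
counts = Counter(zip(y_true, y_pred))
cm = [[counts[(t, p)] for p in labels] for t in labels]  # rows = true label, columns = predicted label

print(acc)  # 0.8
print(cm)   # [[2, 0], [1, 2]]
```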
13 changes: 7 additions & 6 deletions src/uniprompt/grouping/group.py
@@ -13,27 +13,28 @@ def __init__(self, number_of_groups: int):
self.edit_history_dict = {group: [] for group in self.groups}

def create_groups(self, prompt: str, data: Tuple, config: Dict, grouping_fn: Optional[Callable] = None) -> List[List]:
train_questions, train_choices, train_answers = data
train_questions, train_answers = data

if grouping_fn is None:
grouping_fn = default_grouping_fn

self.groups = grouping_fn(prompt, train_questions, train_choices, train_answers, config)
self.groups = grouping_fn(prompt, train_questions, train_answers, config)

def default_grouping_fn(
prompt: str,
questions: Sequence[str],
answers: Sequence[str],
choices: Sequence[List[str]],
config: Dict[str, Any]
) -> Dict[int, List[int]]:
feedbacks = []
prompts = load_prompts()
for question, choice, answer in zip(questions, choices, answers):
formatted_prompt = make_prompt(prompt=prompt, question=question, choices=choice)
for question, answer in zip(questions, answers):
formatted_prompt = make_prompt(prompt=prompt, question=question)
messages = [{"role": "user", "content": formatted_prompt}]
answer_cot = chat_completion(cache_path=config["cache_path"], **config["solver_llm"], messages=messages)
answer_llm = extract_answer(answer_cot)
wrong_choices, wrong_cots, correct_choices = [], [], []

if answer_llm is not None and str(answer_llm) != str(answer):
wrong_choices = answer_llm
wrong_cots = answer_cot
@@ -63,7 +64,7 @@ def default_grouping_fn(
groups[i+1] = [] # i+1 because 0 is reserved for correct answers
groups_str += f"Group {i}: {g}\n"

for idx, (question, choice, answer, feedback) in enumerate(zip(questions, choices, answers, feedbacks)):
for idx, (question, answer, feedback) in enumerate(zip(questions, answers, feedbacks)):
assign_group_prompt = prompts.get("assign_group_prompt", None).format(groups_str=groups_str, feedback=feedback, number_of_groups = config["number_of_groups"])
messages = [{"role": "user", "content": assign_group_prompt}]
cluster_number = chat_completion(cache_path=config["cache_path"], **config["grouping_llm"], messages=messages)
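Note: the grouping step still returns a dict from group id to example indices, with group 0 reserved for questions the solver already answers correctly and groups 1..N filled from the LLM's cluster assignments. Below is a toy sketch of that output shape, using hand-written assignments in place of chat_completion calls.

```python
number_of_groups = 2
groups = {0: []}  # group 0: solver already correct
groups.update({i + 1: [] for i in range(number_of_groups)})

# Pretend assignments: example index -> (solver_was_correct, cluster_chosen_by_llm)
assignments = {0: (True, None), 1: (False, 1), 2: (False, 2), 3: (False, 1)}

for idx, (correct, cluster) in assignments.items():
    groups[0 if correct else cluster].append(idx)

print(groups)  # {0: [0], 1: [1, 3], 2: [2]}
```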
4 changes: 1 addition & 3 deletions src/uniprompt/metaprompts/custom.yaml
@@ -2,14 +2,12 @@ make_prompt: |
{prompt}
Question: {question}
Answer choices: {choices}
Output format: Give your answer as the correct option between <Answer></Answer> tags.
Let's think step by step.
make_user_prompt: |
Question: {question}
Answer choices: {choices}
Output format: Give your answer as the correct option between <Answer></Answer> tags.
Let's think step by step.
@@ -241,7 +239,7 @@ opro_initial: |
I have some texts along with their corresponding scores. The texts are arranged in ascending order based on their scores, where higher scores indicate better quality.
opro_middle: |
The following exemplars show how to apply your text: you replace <INS> in each input with your text, then read the question and choices and give an output. We say your output is wrong if your output is different from the given output, and we say your output is correct if they are the same.
The following exemplars show how to apply your text: you replace <INS> in each input with your text, then read the question and give an output. We say your output is wrong if your output is different from the given output, and we say your output is correct if they are the same.
opro_end: |
Generate {num_prompts} new texts with high scores. Do not include any introductory text. Each new text should be separated by <NEWLINE> tags.
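Note: with the "Answer choices: {choices}" line deleted, make_prompt only interpolates the evolving instruction and the question. Below is a sketch of how the trimmed template renders; the format call and the example strings are illustrative, not part of this commit.

```python
template = (
    "{prompt}\n"
    "Question: {question}\n"
    "Output format: Give your answer as the correct option between <Answer></Answer> tags.\n"
    "Let's think step by step.\n"
)

print(template.format(
    prompt="Given a query, you have to tell if it contains hate speech in any form or not. "
           "Output 0 means Non-Hate and 1 means Hate.",
    question="Everyone deserves respect regardless of background.",
))
```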
2 changes: 0 additions & 2 deletions src/uniprompt/metaprompts/default.yaml
@@ -2,14 +2,12 @@ make_prompt: |
{prompt}
Question: {question}
Answer choices: {choices}
Output format: Give your answer as the correct option between <Answer></Answer> tags.
Let's think step by step.
make_user_prompt: |
Question: {question}
Answer choices: {choices}
Output format: Give your answer as the correct option between <Answer></Answer> tags.
Let's think step by step.
14 changes: 6 additions & 8 deletions src/uniprompt/train.py
@@ -10,7 +10,7 @@
def train(train_data, val_data, config, beam, grouping):
prompts = load_prompts()
groups = grouping.groups
train_questions, train_choices, train_answers = train_data
train_questions, train_answers = train_data
for group in groups:
if(len(groups[group]) == 0):
continue
@@ -26,7 +26,6 @@ def train(train_data, val_data, config, beam, grouping):
batch_size = config["batch_size"] * mini_batch_size
selected_indices = random_batch(groups[group], batch_size)
batch_questions = [train_questions[index] for index in selected_indices]
batch_choices = [train_choices[index] for index in selected_indices]
batch_answers = [train_answers[index] for index in selected_indices]

global_feedback = ""
@@ -38,10 +37,9 @@
mini_batch_ranges = mini_batch_indices(len(batch_questions), mini_batch_size)
for start, end in mini_batch_ranges: # Mini-batch loop
mini_batch_questions = batch_questions[start:end]
mini_batch_choices = batch_choices[start:end]
mini_batch_answers = batch_answers[start:end]
# Use a list for the feedback to be able to append to it since strings are immutable and we need to append for each mini batch
feedback = train_batch_adaptive(mini_batch_questions, mini_batch_choices, mini_batch_answers, curr_prompt, feedback, correct_answers, group, total_wrong_questions, config, grouping)
feedback = train_batch_adaptive(mini_batch_questions, mini_batch_answers, curr_prompt, feedback, correct_answers, group, total_wrong_questions, config, grouping)
global_feedback = " ".join(feedback)

# Adding randomization to the feedback to force the model to explore more options
@@ -56,20 +54,20 @@
config=config,)

else:
global_feedback = train_batch_adaptive(batch_questions, batch_choices, batch_answers, curr_prompt, global_feedback, correct_answers, group, total_wrong_questions, config, grouping)
global_feedback = train_batch_adaptive(batch_questions, batch_answers, curr_prompt, global_feedback, correct_answers, group, total_wrong_questions, config, grouping)
final_feedback = " ".join(global_feedback)

# apply edit to all prompts in the beam
beam.beam = beam.apply_edits_to_beam(final_feedback, val_data = val_data, config = config)

return beam

def train_batch_adaptive(mini_batch_questions, mini_batch_choices, mini_batch_answers, p, feedback, correct_answers, group, total_wrong_questions, config, grouping):
def train_batch_adaptive(mini_batch_questions, mini_batch_answers, p, feedback, correct_answers, group, total_wrong_questions, config, grouping):
prompts = load_prompts()
wrong_questions, wrong_choices, wrong_cots, correct_choices = [], [], [], []
acc = 0
for question, answer, choices in zip(mini_batch_questions, mini_batch_answers, mini_batch_choices):
prompt = make_prompt(prompt=p, question=question, choices=choices)
for question, answer in zip(mini_batch_questions, mini_batch_answers):
prompt = make_prompt(prompt=p, question=question)
messages = [{"role": "user", "content": prompt}]
answer_cot = chat_completion(cache_path=config["cache_path"], **config["solver_llm"], messages=messages)
answer_llm = extract_answer(answer_cot)
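Note: training still samples a batch per group and walks it in mini-batches, but each mini-batch now carries only questions and answers. Below is a sketch of the slicing pattern, with mini_batch_ranges and the random sample standing in for the repo's mini_batch_indices and random_batch helpers.

```python
import random

def mini_batch_ranges(n_items, mini_batch_size):
    # Stand-in for mini_batch_indices: contiguous (start, end) slices over the batch.
    return [(s, min(s + mini_batch_size, n_items)) for s in range(0, n_items, mini_batch_size)]

selected_indices = random.sample(range(40), 8)        # stand-in for random_batch(group, batch_size)
batch_questions = [f"question {i}" for i in selected_indices]
batch_answers = [str(i % 2) for i in selected_indices]

for start, end in mini_batch_ranges(len(batch_questions), 4):
    mini_batch_questions = batch_questions[start:end]
    mini_batch_answers = batch_answers[start:end]
    print(len(mini_batch_questions), len(mini_batch_answers))  # each pair feeds train_batch_adaptive
```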
4 changes: 3 additions & 1 deletion src/uniprompt/utils/api_utils.py
@@ -32,6 +32,8 @@ def make_api_call(client, **kwargs):

if kwargs["api_kwargs"]["api_type"] == "azure":
client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["api_kwargs"]["api_base"], api_version=kwargs["api_kwargs"]["api_version"])
elif kwargs["api_kwargs"]["api_base"]:
client = OpenAI(base_url=kwargs["api_kwargs"]["api_base"], api_key=api_key)
else:
client = OpenAI(api_key=api_key)

@@ -72,4 +74,4 @@ def make_api_call(client, **kwargs):
conn.commit()
conn.close()

return response_content
return response_content
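Note: the added elif means any non-Azure config with a non-empty api_base is served through the standard OpenAI client pointed at that base URL, which is what the OpenRouter configs above rely on. Below is a condensed sketch of the selection logic, assuming the openai v1 SDK; build_client is an illustrative wrapper, not the repo's function.

```python
from openai import AzureOpenAI, OpenAI

def build_client(api_kwargs: dict, api_key: str):
    # Mirror the patched branch order: Azure first, then any OpenAI-compatible base URL, then default.
    if api_kwargs["api_type"] == "azure":
        return AzureOpenAI(api_key=api_key,
                           azure_endpoint=api_kwargs["api_base"],
                           api_version=api_kwargs["api_version"])
    if api_kwargs["api_base"]:  # e.g. OpenRouter or another OpenAI-compatible gateway
        return OpenAI(base_url=api_kwargs["api_base"], api_key=api_key)
    return OpenAI(api_key=api_key)

client = build_client({"api_type": "openrouter",
                       "api_base": "https://openrouter.ai/api/v1",
                       "api_version": ""}, api_key="sk-...")
```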
2 changes: 1 addition & 1 deletion src/uniprompt/utils/metric_utils.py
@@ -56,5 +56,5 @@ def get_metric(y_true: List[str], y_pred: List[str], config: Optional [Dict] = N

else:
raise ValueError(f"Metric {metric_type} not supported")

return metric_val