Skip to content

Commit

Permalink
Merge pull request #302 from stanfordnlp/optimizers
Browse files Browse the repository at this point in the history
Adding example awareness to bayesian teleprompter
  • Loading branch information
klopsahlong authored Jan 29, 2024
2 parents e9cfae5 + 7ff3909 commit 533093c
Showing 1 changed file with 158 additions and 38 deletions.
196 changes: 158 additions & 38 deletions dspy/teleprompt/signature_opt_bayesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,32 @@ class BasicGenerateInstructionWithDataObservations(Signature):
proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model")
proposed_prefix_for_output_field = dspy.OutputField(desc="The string at the end of the prompt, which will help the model start solving the task")

class BasicGenerateInstructionWithExamples(dspy.Signature):
("""You are an instruction optimizer for large language models. I will give you a ``signature`` of fields (inputs and outputs) in English. Specifically, I will also provide you with the current ``basic instruction`` that is being used for this task. I will also provide you with some ``examples`` of the expected inputs and outputs.
Your task is to propose an instruction that will lead a good language model to perform the task well. Don't be afraid to be creative.""")
# attempted_instructions = dspy.InputField(format=str, desc="Previously attempted task instructions, along with their resulting validation score, and an example of the instruction in use on a sample from our dataset.")
basic_instruction = dspy.InputField(desc="The initial instructions before optimization")
# examples = dspy.InputField(format=dsp.passages2text, desc="Example(s) of the task")
examples = dspy.InputField(format=dsp.passages2text, desc="Example(s) of the task")
proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model")
proposed_prefix_for_output_field = dspy.OutputField(desc="The string at the end of the prompt, which will help the model start solving the task")

class BasicGenerateInstructionWithExamplesAndDataObservations(dspy.Signature):
("""You are an instruction optimizer for large language models. I will give you a ``signature`` of fields (inputs and outputs) in English. Specifically, I will also provide you with the current ``basic instruction`` that is being used for this task. I will also provide you with some ``observations`` I have made about the dataset and task, along with some ``examples`` of the expected inputs and outputs.
Your task is to propose a new improved instruction and prefix for the output field that will lead a good language model to perform the task well. Don't be afraid to be creative.""")
basic_instruction = dspy.InputField(desc="The initial instructions before optimization")
observations = dspy.InputField(desc="Observations about the dataset and task")
examples = dspy.InputField(format=dsp.passages2text, desc="Example(s) of the task")
proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model")
proposed_prefix_for_output_field = dspy.OutputField(desc="The string at the end of the prompt, which will help the model start solving the task")

class ObservationSummarizer(dspy.Signature):
("""Given a series of observations I have made about my dataset, please summarize them into a brief 2-3 sentence summary which highlights only the most important details.""")
observations = dspy.InputField(desc="Observations I have made about my dataset")
summary = dspy.OutputField(desc="Two to Three sentence summary of only the most significant highlights of my observations")

class DatasetDescriptor(dspy.Signature):
("""Given several examples from a dataset please write observations about trends that hold for most or all of the samples. """
"""Some areas you may consider in your observations: topics, content, syntax, conciceness, etc. """
Expand Down Expand Up @@ -98,14 +124,14 @@ def _print_model_history(self, model, n=1):
if self.verbose: print(f"Model ({model}) History:")
model.inspect_history(n=n)

def _observe_data(self, trainset, data_limit=500):
def _observe_data(self, trainset):
upper_lim = min(len(trainset), self.view_data_batch_size)
observation = dspy.Predict(DatasetDescriptor, n=1, temperature=1.0)(examples=(trainset[0:upper_lim].__repr__()))
observations = observation["observations"]

skips = 0
for b in range(self.view_data_batch_size, min(len(trainset), data_limit), self.view_data_batch_size):
upper_lim = min(min(len(trainset), data_limit), b+self.view_data_batch_size)
for b in range(self.view_data_batch_size, len(trainset), self.view_data_batch_size):
upper_lim = min(len(trainset), b+self.view_data_batch_size)
output = dspy.Predict(DatasetDescriptorWithPriorObservations, n=1, temperature=1.0)(prior_observations=observations, examples=(trainset[b:upper_lim].__repr__()))
if len(output["observations"]) >= 8 and output["observations"][:8].upper() == "COMPLETE":
skips += 1
Expand All @@ -114,20 +140,64 @@ def _observe_data(self, trainset, data_limit=500):
continue
observations += output["observations"]

return observations
summary = dspy.Predict(ObservationSummarizer, n=1, temperature=1.0)(observations=observations)

return summary.summary

def _generate_first_N_candidates(self, module, N, data_informed_prompt_generation, devset):
def _create_example_string(self, fields, example):

# Building the output string
output = []
for field in fields:
name = field.name
separator = field.separator
input_variable = field.input_variable

# Determine the value from input_data or prediction_data
value = example.get(input_variable)

# Construct the string for the current field
field_str = f"{name}{separator}{value}"
output.append(field_str)

# Joining all the field strings
return '\n'.join(output)

def _generate_first_N_candidates(self, module, N, view_data, view_examples, demo_candidates, devset):
candidates = {}
evaluated_candidates = defaultdict(dict)

if data_informed_prompt_generation:
if view_data:
# Create data observations
self.observations = None
if self.prompt_model:
with dspy.settings.context(lm=self.prompt_model):
self.observations = self._observe_data(devset).replace("Observations:","")
self.observations = self._observe_data(devset).replace("Observations:","").replace("Summary:","")
else:
self.observations = self._observe_data(devset).replace("Observations:","")
self.observations = self._observe_data(devset).replace("Observations:","").replace("Summary:","")

if view_examples:
example_sets = {}
for predictor in module.predictors():
# Get all augmented examples
example_set = {}
all_sets_of_examples = demo_candidates[id(predictor)] # Get all generated sets of examples
for example_set_i, set_of_examples in enumerate(all_sets_of_examples):
if example_set_i != 0: # Skip the no examples case
for example in set_of_examples: # Get each individual example in the set
if "augmented" in example.keys():
if example["augmented"]:
if example_set_i not in example_set:
example_set[example_set_i] = []
fields_to_use = predictor.signature.fields
input_variable_names = [field.input_variable for field in fields_to_use]
example_with_only_signature_fields = {key: value for key, value in example.items() if key in input_variable_names}
example_string = self._create_example_string(fields_to_use, example_with_only_signature_fields)
example_set[example_set_i].append(example_string)
example_sets[id(predictor)] = example_set
else:
example_set[example_set_i] = []
example_sets[id(predictor)] = example_set

# Seed the prompt optimizer zero shot with just the instruction, generate BREADTH new prompts
for predictor in module.predictors():
Expand All @@ -140,55 +210,105 @@ def _generate_first_N_candidates(self, module, N, data_informed_prompt_generatio
basic_instruction = predictor.extended_signature1.instructions
basic_prefix = predictor.extended_signature1.fields[-1].name
if self.prompt_model:
if data_informed_prompt_generation:
with dspy.settings.context(lm=self.prompt_model):
with dspy.settings.context(lm=self.prompt_model):
# Data & Examples
if view_data and view_examples:
instruct = None
for i in range(1,self.n):
new_instruct = dspy.Predict(BasicGenerateInstructionWithExamplesAndDataObservations, n=1, temperature=self.init_temperature)(basic_instruction=basic_instruction, observations=self.observations, examples=example_sets[id(predictor)][i])
if not instruct:
instruct = new_instruct
else:
instruct.completions.proposed_instruction.extend(new_instruct.completions.proposed_instruction)
instruct.completions.proposed_prefix_for_output_field.extend(new_instruct.completions.proposed_prefix_for_output_field)
# Just data
elif view_data:
instruct = dspy.Predict(BasicGenerateInstructionWithDataObservations, n=N-1, temperature=self.init_temperature)(basic_instruction=basic_instruction, observations=self.observations)
else:
with dspy.settings.context(lm=self.prompt_model):
# Just examples
elif view_examples:
instruct = None
for i in range(1,self.n): # Note: skip over the first example set which is empty
new_instruct = dspy.Predict(BasicGenerateInstructionWithExamples, n=1, temperature=self.init_temperature)(basic_instruction=basic_instruction, examples=example_sets[id(predictor)][i])
if not instruct:
instruct = new_instruct
else:
instruct.completions.proposed_instruction.extend(new_instruct.completions.proposed_instruction)
instruct.completions.proposed_prefix_for_output_field.extend(new_instruct.completions.proposed_prefix_for_output_field)
# Neither
else:
instruct = dspy.Predict(BasicGenerateInstruction, n=N-1, temperature=self.init_temperature)(basic_instruction=basic_instruction)
else:
if data_informed_prompt_generation:
# Data & Examples
if view_data and view_examples:
instruct = None
for i in range(1,self.n):
new_instruct = dspy.Predict(BasicGenerateInstructionWithExamplesAndDataObservations, n=1, temperature=self.init_temperature)(basic_instruction=basic_instruction, observations=self.observations, examples=example_sets[id(predictor)][i])
if not instruct:
instruct = new_instruct
else:
instruct.completions.proposed_instruction.extend(new_instruct.completions.proposed_instruction)
instruct.completions.proposed_prefix_for_output_field.extend(new_instruct.completions.proposed_prefix_for_output_field)
# Just data
elif view_data:
instruct = dspy.Predict(BasicGenerateInstructionWithDataObservations, n=N-1, temperature=self.init_temperature)(basic_instruction=basic_instruction, observations=self.observations)
else:
# Just examples
elif view_examples:
instruct = None
for i in range(1,self.n): # Note: skip over the first example set which is empty
new_instruct = dspy.Predict(BasicGenerateInstructionWithExamples, n=1, temperature=self.init_temperature)(basic_instruction=basic_instruction, examples=example_sets[id(predictor)][i])
if not instruct:
instruct = new_instruct
else:
instruct.completions.proposed_instruction.extend(new_instruct.completions.proposed_instruction)
instruct.completions.proposed_prefix_for_output_field.extend(new_instruct.completions.proposed_prefix_for_output_field)
# Neither
else:
instruct = dspy.Predict(BasicGenerateInstruction, n=N-1, temperature=self.init_temperature)(basic_instruction=basic_instruction)

# Add in our initial prompt as a candidate as well
instruct.completions.proposed_instruction.append(basic_instruction)
instruct.completions.proposed_prefix_for_output_field.append(basic_prefix)
instruct.completions.proposed_instruction.insert(0, basic_instruction)
instruct.completions.proposed_prefix_for_output_field.insert(0, basic_prefix)
candidates[id(predictor)] = instruct.completions
evaluated_candidates[id(predictor)] = {}

if self.verbose and self.prompt_model: self._print_model_history(self.prompt_model)

return candidates, evaluated_candidates

def compile(self, student, *, devset, optuna_trials_num, max_bootstrapped_demos, max_labeled_demos, eval_kwargs, seed=42, data_informed_prompt_generation=True):
def compile(self, student, *, devset, optuna_trials_num, max_bootstrapped_demos, max_labeled_demos, eval_kwargs, seed=42, view_data=True, view_examples=True):

random.seed(seed)

# Set up program and evaluation function
module = student.deepcopy()
evaluate = Evaluate(devset=devset, metric=self.metric, **eval_kwargs)

# Generate N candidate prompts
instruction_candidates, _ = self._generate_first_N_candidates(module, self.n, data_informed_prompt_generation, devset)

# Generate N few shot example sets
demo_candidates = {}
for seed in range(self.n):
if self.verbose: print(f"Creating basic bootstrap {seed}/{self.n}")

# Create a new basic bootstrap few - shot program .
rng = random.Random(seed)
shuffled_devset = devset[:] # Create a copy of devset
rng.shuffle(shuffled_devset) # Shuffle the copy
tp = BootstrapFewShot(metric = self.metric, max_bootstrapped_demos=max_bootstrapped_demos, max_labeled_demos=max_labeled_demos, teacher_settings=self.teacher_settings)
candidate_program = tp.compile(student=module.deepcopy(), trainset=shuffled_devset)

# Store the candidate demos
for module_p, candidate_p in zip(module.predictors(), candidate_program.predictors()):
if id(module_p) not in demo_candidates.keys():
demo_candidates[id(module_p)] = []
demo_candidates[id(module_p)].append(candidate_p.demos)
for i in range(self.n):
if i == 0: # Story empty set of demos as default for index 0
for module_p in module.predictors():
if id(module_p) not in demo_candidates.keys():
demo_candidates[id(module_p)] = []
demo_candidates[id(module_p)].append([])
else:
if self.verbose: print(f"Creating basic bootstrap: {i}/{self.n-1}")

# Create a new basic bootstrap few - shot program .
rng = random.Random(i)
shuffled_devset = devset[:] # Create a copy of devset
rng.shuffle(shuffled_devset) # Shuffle the copy
tp = BootstrapFewShot(metric = self.metric, max_bootstrapped_demos=max_bootstrapped_demos, max_labeled_demos=max_labeled_demos, teacher_settings=self.teacher_settings)
candidate_program = tp.compile(student=module.deepcopy(), trainset=shuffled_devset)

# Store the candidate demos
for module_p, candidate_p in zip(module.predictors(), candidate_program.predictors()):
if id(module_p) not in demo_candidates.keys():
demo_candidates[id(module_p)] = []
demo_candidates[id(module_p)].append(candidate_p.demos)

# Generate N candidate prompts
instruction_candidates, _ = self._generate_first_N_candidates(module, self.n, view_data, view_examples, demo_candidates, devset)

# Initialize variables to store the best program and its score
best_score = float('-inf')
Expand Down Expand Up @@ -216,8 +336,8 @@ def objective(trial):
# Suggest the index of the instruction candidate to use in our trial
instruction_idx = trial.suggest_int(f"{id(p_old)}_predictor_instruction",low=0, high=len(p_instruction_candidates)-1)
demos_idx = trial.suggest_int(f"{id(p_old)}_predictor_demos",low=0, high=len(p_demo_candidates)-1)
trial_logs[trial_num]["instruction_idx"] = instruction_idx
trial_logs[trial_num]["demos_idx"] = demos_idx
trial_logs[trial_num][f"{id(p_old)}_predictor_instruction"] = instruction_idx
trial_logs[trial_num][f"{id(p_old)}_predictor_demos"] = demos_idx

# Get the selected instruction candidate
selected_candidate = p_instruction_candidates[instruction_idx]
Expand All @@ -239,7 +359,7 @@ def objective(trial):
trial_logs[trial_num]["program"] = candidate_program

# Evaluate with the new prompts
total_score, curr_avg_score = 0, 0
total_score = 0
batch_size = 100
num_batches = math.ceil(len(devset) / batch_size)

Expand Down

0 comments on commit 533093c

Please sign in to comment.