Commit

WIP
CyrusNuevoDia committed Nov 20, 2024
1 parent f289a25 commit c8d3d98
Showing 8 changed files with 657 additions and 145 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -14,6 +14,7 @@
/examples/nli/scone/compiled_program.dspy
/examples/qa/hotpot/compiled_program.dspy
/ScoNe/
# testing/playbook.ipynb

# Byte-compiled / optimized / DLL files
__pycache__/
@@ -58,4 +59,4 @@ assertion.log
.mypy_cache
dummy.csv
docs/docs/**/*.json*
*.index
*.index
10 changes: 6 additions & 4 deletions dspy/__init__.py
@@ -7,10 +7,10 @@
from .signatures import *

# Functional must be imported after primitives, predict and signatures
from .functional import * # isort: skip
from dspy.evaluate import Evaluate # isort: skip
from dspy.clients import * # isort: skip
from dspy.adapters import * # isort: skip
from .functional import * # isort: skip
from dspy.evaluate import Evaluate # isort: skip
from dspy.clients import * # isort: skip
from dspy.adapters import * # isort: skip
from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging
from dspy.utils.asyncify import asyncify

@@ -68,6 +68,8 @@
LabeledFewShot = dspy.teleprompt.LabeledFewShot
BootstrapFewShot = dspy.teleprompt.BootstrapFewShot
BootstrapFewShotWithRandomSearch = dspy.teleprompt.BootstrapFewShotWithRandomSearch
BootstrapKNN = dspy.teleprompt.BootstrapKNN
BootstrapKNNWithRandomSearch = dspy.teleprompt.BootstrapKNNWithRandomSearch
BootstrapRS = dspy.teleprompt.BootstrapFewShotWithRandomSearch
BootstrapFinetune = dspy.teleprompt.BootstrapFinetune
BetterTogether = dspy.teleprompt.BetterTogether
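With the two aliases above in place, the KNN optimizers become reachable from the top-level dspy namespace once the teleprompt module defines them (see bootstrap.py and random_search.py below). A minimal sketch of what the change enables, assuming the package imports cleanly:

import dspy

# Both aliases should resolve to the teleprompt implementations added later in this commit.
assert dspy.BootstrapKNN is dspy.teleprompt.BootstrapKNN
assert dspy.BootstrapKNNWithRandomSearch is dspy.teleprompt.BootstrapKNNWithRandomSearch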
3 changes: 3 additions & 0 deletions dspy/predict/knn.py
@@ -40,3 +40,6 @@ def __call__(self, **kwargs) -> List["dspy.Example"]:
        nearest_samples_idxs = scores.argsort()[-self.k :][::-1]
        train_sampled = [self.trainset[cur_idx] for cur_idx in nearest_samples_idxs]
        return train_sampled

    def demo_selector(self, predict: "dspy.Predict", inputs):
        return self(**inputs)
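The new demo_selector hook ignores the predict argument and simply delegates to KNN.__call__, so a KNN instance can act as a per-input demo chooser for a predictor. A minimal sketch of that equivalence; the toy_embed vectorizer and the two-example trainset are hypothetical placeholders, and the vectorizer contract assumed here is the one noted in BootstrapKNN's signature below (a callable from a list of strings to a numpy array):

import numpy as np
import dspy

def toy_embed(texts: list[str]) -> np.ndarray:
    # Hypothetical stand-in for a real embedding model: one small vector per input string.
    return np.stack([np.array([len(t), t.count(" ")], dtype=np.float32) for t in texts])

trainset = [
    dspy.Example(question="What is 2 + 2?", answer="4").with_inputs("question"),
    dspy.Example(question="What is 3 + 5?", answer="8").with_inputs("question"),
]

knn = dspy.KNN(k=1, trainset=trainset, vectorizer=toy_embed)
nearest = knn(question="What is 1 + 1?")                        # the k closest training Examples
same = knn.demo_selector(None, {"question": "What is 1 + 1?"})  # identical result via the new hook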
3 changes: 1 addition & 2 deletions dspy/predict/predict.py
@@ -1,7 +1,6 @@
import logging
import random
from functools import lru_cache
from typing import Optional

from pydantic import BaseModel

@@ -191,7 +190,7 @@ def forward(self, **kwargs):
        import dspy

        if hasattr(self, "knn"):
            demos = self.knn(**inputs)
            demos += self.knn(**inputs)

        if isinstance(lm, dspy.LM):
            completions = v2_5_generate(lm, config, signature, demos, inputs, _parse_values=self._parse_values)
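Changing `demos = self.knn(**inputs)` to `demos += self.knn(**inputs)` means forward() now appends the KNN-retrieved examples to whatever demos the predictor already carries instead of replacing them; this is what lets BootstrapKNN (below) pin a few static demos while the rest are retrieved per input. A toy illustration of the merge semantics only, with placeholder lists standing in for predictor.demos and the KNN result:

static_demos = ["static demo 1", "static demo 2"]   # set once by the optimizer
retrieved = ["nearest-neighbour demo"]              # returned by self.knn(**inputs) at call time

demos = list(static_demos)
demos += retrieved                                  # the new += behaviour: append, don't replace
assert demos == ["static demo 1", "static demo 2", "nearest-neighbour demo"]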
29 changes: 19 additions & 10 deletions dspy/teleprompt/bootstrap.py
@@ -278,15 +278,18 @@ class BootstrapKNN(BootstrapFewShot):
    def __init__(
        self,
        embedding: Optional[Callable[[list[str]], np.ndarray]] = None,  # dspy.Embedding
        k: int = 16,
        metric=None,
        metric_threshold=None,
        teacher_settings: Optional[Dict] = None,
        max_bootstrapped_demos=10_000,  # fill the predictor .demos
        max_labeled_demos=16,  # for the teacher model only
        max_bootstrapped_demos=10_000,
        num_static_demos=0,
        max_labeled_demos=16,
        max_rounds=1,
        max_errors=2000,
        random_seed=0,
    ):
        assert num_static_demos < max_labeled_demos, "static demos must be less than max labeled demos."

        super().__init__(
            metric=metric,
            metric_threshold=metric_threshold,
@@ -296,17 +296,23 @@ def __init__(
            max_rounds=max_rounds,
            max_errors=max_errors,
        )
        self.k = k
        self.num_static_demos = num_static_demos
        self.embedding = embedding
        self.random_seed = random_seed

    def _train(self):
        rng = random.Random(self.random_seed)

        for name, predictor in self.student.named_predictors():
            predictor.demos = self.name2traces[name][: self.max_labeled_demos]
            augmented_demos = self.name2traces[name]

            static_demos = rng.sample(augmented_demos, self.num_static_demos)
            dynamic_demos = list(set(augmented_demos) - set(static_demos))

            predictor.demos = static_demos

            # TODO: Make this dump/load-able
            predictor.knn = dspy.KNN(
                k=self.k,
                trainset=predictor.demos,
                vectorizer=self.embedding,
            )
            k = self.max_labeled_demos - self.num_static_demos
            predictor.knn = dspy.KNN(k, trainset=dynamic_demos, vectorizer=self.embedding)

        return self.student
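Taken together, _train now pins num_static_demos randomly sampled bootstrapped demos on each predictor and hands the remaining traces to a dspy.KNN retriever with k = max_labeled_demos - num_static_demos, so the rest of the demo budget is filled per input at call time. A minimal compile sketch under explicit assumptions: the QA module, toy_embed embedding, exact_match metric, tiny trainset, and the model name passed to dspy.LM are all placeholders, not part of this commit:

import numpy as np
import dspy

def toy_embed(texts: list[str]) -> np.ndarray:
    # Hypothetical embedding callable matching the signature noted above
    # (list[str] -> np.ndarray); swap in a real encoder for actual use.
    return np.stack([np.array([len(t), t.count(" ")], dtype=np.float32) for t in texts])

def exact_match(example, prediction, trace=None):
    return example.answer == prediction.answer

class QA(dspy.Module):
    def __init__(self):
        super().__init__()
        self.answer = dspy.Predict("question -> answer")

    def forward(self, question):
        return self.answer(question=question)

trainset = [
    dspy.Example(question="What is 2 + 2?", answer="4").with_inputs("question"),
    dspy.Example(question="What is 3 + 5?", answer="8").with_inputs("question"),
    dspy.Example(question="What is 7 - 4?", answer="3").with_inputs("question"),
]

dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))  # placeholder model; any configured LM works

optimizer = dspy.BootstrapKNN(
    embedding=toy_embed,
    metric=exact_match,
    max_labeled_demos=2,   # total demos per predictor call
    num_static_demos=1,    # one pinned demo; the remaining one is retrieved by KNN per input
)
compiled = optimizer.compile(QA(), trainset=trainset)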
143 changes: 140 additions & 3 deletions dspy/teleprompt/random_search.py
@@ -3,7 +3,7 @@
from dspy.evaluate.evaluate import Evaluate
from dspy.teleprompt.teleprompt import Teleprompter

from .bootstrap import BootstrapFewShot
from .bootstrap import BootstrapFewShot, BootstrapKNN
from .vanilla import LabeledFewShot

# TODO: Don't forget dealing with the raw demos.
@@ -27,7 +27,7 @@ class BootstrapFewShotWithRandomSearch(Teleprompter):
    def __init__(
        self,
        metric,
        teacher_settings={},
        teacher_settings=None,
        max_bootstrapped_demos=4,
        max_labeled_demos=16,
        max_rounds=1,
@@ -38,7 +38,7 @@ def __init__(
        metric_threshold=None,
    ):
        self.metric = metric
        self.teacher_settings = teacher_settings
        self.teacher_settings = teacher_settings or {}
        self.max_rounds = max_rounds

        self.num_threads = num_threads
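The switch from teacher_settings={} to teacher_settings=None (resolved with `or {}` above) avoids Python's shared-mutable-default pitfall: a default dict is created once at function definition time and reused by every call. A standalone illustration of the general language behaviour, unrelated to any dspy API:

def bad(settings={}):
    settings["seen"] = True
    return settings

def good(settings=None):
    settings = settings or {}
    settings["seen"] = True
    return settings

assert bad() is bad()          # every call shares (and mutates) the same dict object
assert good() is not good()    # each call gets a fresh dict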
@@ -150,6 +150,143 @@ def compile(self, student, *, teacher=None, trainset, valset=None, restrict=None
        return best_program


class BootstrapKNNWithRandomSearch(Teleprompter):
    def __init__(
        self,
        metric,
        embedding,
        teacher_settings=None,
        max_bootstrapped_demos=10_000,
        max_labeled_demos=16,
        max_rounds=1,
        num_candidate_programs=16,
        num_threads=6,
        max_errors=10,
        stop_at_score=None,
        metric_threshold=None,
    ):
        self.metric = metric
        self.embedding = embedding
        self.teacher_settings = teacher_settings or {}
        self.max_rounds = max_rounds

        self.num_threads = num_threads
        self.stop_at_score = stop_at_score
        self.metric_threshold = metric_threshold
        self.min_static_demos = 1
        self.max_static_demos = max_labeled_demos - 1
        self.max_errors = max_errors
        self.num_candidate_sets = num_candidate_programs
        self.max_labeled_demos = max_labeled_demos
        self.max_bootstrapped_demos = max_bootstrapped_demos

        print(
            f"Going to sample between {self.min_static_demos} and {self.max_static_demos} static demos per predictor."
        )
        print(f"Will attempt to bootstrap {self.num_candidate_sets} candidate sets.")

    def compile(self, student, *, teacher=None, trainset, valset=None, restrict=None, labeled_sample=True):
        self.trainset = trainset
        self.valset = valset or trainset  # TODO: FIXME: Note this choice.

        scores = []
        all_subscores = []
        score_data = []

        for seed in range(-3, self.num_candidate_sets):
            if (restrict is not None) and (seed not in restrict):
                continue

            trainset_copy = list(self.trainset)

            if seed == -3:
                # Zero Shot
                program = student.reset_copy()

            elif seed == -2:
                # Labeled Few Shot
                teleprompter = LabeledFewShot(k=self.max_labeled_demos)
                program = teleprompter.compile(student, trainset=trainset_copy, sample=labeled_sample)

            elif seed == -1:
                print("BootstrapKNN with 0 static demos")
                optimizer = BootstrapKNN(
                    metric=self.metric,
                    embedding=self.embedding,
                    metric_threshold=self.metric_threshold,
                    max_bootstrapped_demos=self.max_bootstrapped_demos,
                    max_labeled_demos=self.max_labeled_demos,
                    num_static_demos=0,
                    teacher_settings=self.teacher_settings,
                    max_rounds=self.max_rounds,
                    max_errors=self.max_errors,
                )
                program = optimizer.compile(student, teacher=teacher, trainset=trainset_copy)

            else:
                num_static_demos = random.Random(seed).randint(
                    self.min_static_demos,
                    self.max_static_demos,
                )
                print(f"BootstrapKNN with {num_static_demos} static demos")

                optimizer = BootstrapKNN(
                    metric=self.metric,
                    metric_threshold=self.metric_threshold,
                    max_bootstrapped_demos=self.max_bootstrapped_demos,
                    max_labeled_demos=self.max_labeled_demos,
                    teacher_settings=self.teacher_settings,
                    max_rounds=self.max_rounds,
                    max_errors=self.max_errors,
                    num_static_demos=num_static_demos,
                    random_seed=seed,
                )

                program = optimizer.compile(student, teacher=teacher, trainset=trainset_copy)

            evaluate = Evaluate(
                devset=self.valset,
                metric=self.metric,
                num_threads=self.num_threads,
                max_errors=self.max_errors,
                display_table=False,
                display_progress=True,
            )

            score, subscores = evaluate(program, return_all_scores=True)

            all_subscores.append(subscores)

            ############ Assertion-aware Optimization ############
            if hasattr(program, "_suggest_failures"):
                score = score - program._suggest_failures * 0.2
            if hasattr(program, "_assert_failures"):
                score = 0 if program._assert_failures > 0 else score
            ######################################################

            if len(scores) == 0 or score > max(scores):
                print("New best score:", score, "for seed", seed)
                best_program = program

            scores.append(score)
            print(f"Scores so far: {scores}")
            print(f"Best score so far: {max(scores)}")

            score_data.append((score, subscores, seed, program))

            if self.stop_at_score is not None and score >= self.stop_at_score:
                print(f"Stopping early because score {score} is >= stop_at_score {self.stop_at_score}")
                break

        # To best program, attach all program candidates in decreasing average score
        best_program.candidate_programs = score_data
        best_program.candidate_programs = sorted(best_program.candidate_programs, key=lambda x: x[0], reverse=True)

        print(f"{len(best_program.candidate_programs)} candidate programs found.")

        return best_program


# sample between 4 and 10 examples from traces
# TODO: FIXME: The max number of demos should be determined in part by the LM's tokenizer + max_length.
# This does require executing the program, or at least the predictor.
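BootstrapKNNWithRandomSearch mirrors BootstrapFewShotWithRandomSearch: seeds -3, -2, and -1 are fixed baselines (zero-shot, labeled few-shot, and BootstrapKNN with no static demos), and each non-negative seed deterministically draws its static-demo count between min_static_demos and max_static_demos before compiling and evaluating a candidate on the valset. A standalone sketch of that schedule only, requiring no LM; 16 stands in for max_labeled_demos and 4 for num_candidate_programs:

import random

min_static, max_static = 1, 16 - 1     # self.min_static_demos, self.max_static_demos
for seed in range(-3, 4):              # -3: zero-shot, -2: labeled few-shot, -1: 0 static demos
    if seed < 0:
        continue
    num_static = random.Random(seed).randint(min_static, max_static)
    print(f"seed {seed}: {num_static} static demos, {16 - num_static} retrieved by KNN per input")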