Add DSPy Refine
chenmoneygithub committed Dec 19, 2024
1 parent 6dbc8bc commit a1a6bb2
Showing 6 changed files with 219 additions and 0 deletions.
1 change: 1 addition & 0 deletions dspy/__init__.py
@@ -5,6 +5,7 @@
from dspy.teleprompt import *

import dspy.retrievers
from dspy.refine import Refine

from dspy.evaluate import Evaluate # isort: skip
from dspy.clients import * # isort: skip
8 changes: 8 additions & 0 deletions dspy/refine/__init__.py
@@ -0,0 +1,8 @@
from dspy.refine.metrics import BoolMetric, FloatMetric
from dspy.refine.refine import Refine

__all__ = [
"Refine",
"BoolMetric",
"FloatMetric",
]
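
For context, a minimal usage sketch of the new API (the QA module, the conciseness heuristic, and the threshold are illustrative, and an LM is assumed to be configured via dspy.configure):

import dspy
from dspy.refine import FloatMetric, Refine

class QA(dspy.Module):
    def __init__(self):
        super().__init__()
        self.cot = dspy.ChainOfThought("question -> answer")

    def forward(self, question):
        return self.cot(question=question)

# Metric functions receive the module's input kwargs and its outputs.
def conciseness(inputs, outputs):
    return 1.0 if len(outputs.answer.split()) <= 50 else 0.0

metric = FloatMetric("conciseness", "1.0 if the answer is at most 50 words, else 0.0", conciseness)

# Retry up to max_iter times while any metric falls below its threshold.
refine = Refine(QA(), metrics=[metric], metric_thresholds=[0.5])
prediction = refine(question="What is the capital of France?")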
27 changes: 27 additions & 0 deletions dspy/refine/feedback.py
@@ -0,0 +1,27 @@
from typing import Union

from dspy.signatures.field import InputField, OutputField
from dspy.signatures.signature import Signature


class GenerateFeedback(Signature):
"""
Based on each metric value and metric definition for the inputs-outputs pair, provide feedback the DSPy module
along with submodules in order to improve the metric values at the retry. Only provide feedback for built-in
classses, e.g., dspy.Predict, dspy.Module, dspy.ChainOfThought and so on. If an attribute is a list, make sure
you look into every element. It's also possible that some components are not related to the certain score, we
should skip generating feedback if it is the case.
"""

metrics: list[str] = InputField(desc="The definition of each scoring criterion")
metric_values: list[Union[int, float, bool]] = InputField(desc="The value of each metric, the higher the better")
module_inputs: dict = InputField(desc="The inputs of the DSPy module")
module_outputs: dict = InputField(desc="The outputs of the DSPy module")
source_code: str = InputField(desc="The source code of the DSPy module")
    feedback: dict[str, list[str]] = OutputField(
        desc="Feedback for the DSPy module in general, along with feedback for each submodule in the DSPy module. "
        "Only provide feedback for attributes in the `__init__` method that are built-in DSPy classes. The key "
        "should be the attribute name, e.g., `self.cot` or `self.predict`. If the attribute is a list, write the "
        "key as `self.cots[0]`, `self.predicts[1]` and so on. The feedback should be a list of strings, one per "
        "score function in `metrics`."
    )
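
For orientation, the feedback output is keyed by submodule attribute name and holds one string per metric. A hypothetical value for a module with a single self.cot submodule and two metrics:

# Hypothetical `feedback` produced by ChainOfThought(GenerateFeedback):
{
    "self.cot": [
        "Answers are too verbose; instruct the model to answer in one sentence.",
        "The reasoning ignores the provided context; cite it before answering.",
    ],
}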
35 changes: 35 additions & 0 deletions dspy/refine/metrics.py
@@ -0,0 +1,35 @@
import inspect


class Metric:
def __init__(self, name, description, fn):
self.name = name
self.description = description
self.fn = fn

def __call__(self, inputs, outputs):
return self.fn(inputs, outputs)

def __repr__(self):
if self.description:
return f"Metric name: {self.name}\nMetric description: {self.description}\n"
else:
return f"Metric name: {self.name}\nMetric function: {inspect.getsource(self.fn)}\n"


class BoolMetric(Metric):
def __init__(self, name, description, fn):
super().__init__(name, description, fn)
self.type_description = "This is a bool metric, true if the metric looks good, false otherwise."

def __repr__(self):
        return f"{super().__repr__()}\nMetric type: {self.type_description}"


class FloatMetric(Metric):
def __init__(self, name, description, fn):
super().__init__(name, description, fn)
self.type_description = "This is a float metric, the higher the value the better."

def __repr__(self):
        return f"{super().__repr__()}\nMetric type: {self.type_description}"
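
Since Refine passes these __repr__ strings to the feedback LM as the metric descriptions, here is a sketch of what a BoolMetric renders as (the metric itself is a toy):

from dspy.refine.metrics import BoolMetric

is_grounded = BoolMetric(
    name="grounded",
    description="True if every claim in the answer is supported by the context",
    fn=lambda inputs, outputs: "context" in inputs,  # toy check, illustration only
)

print(repr(is_grounded))
# Metric name: grounded
# Metric description: True if every claim in the answer is supported by the context
#
# Metric type: This is a bool metric, true if the metric looks good, false otherwise.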
126 changes: 126 additions & 0 deletions dspy/refine/refine.py
@@ -0,0 +1,126 @@
import inspect
from functools import partial
from types import MethodType

from dspy.predict.chain_of_thought import ChainOfThought
from dspy.primitives.program import Module
from dspy.refine.feedback import GenerateFeedback
from dspy.refine.metrics import Metric
from dspy.refine.utils import get_traces
from dspy.signatures.field import InputField


class Refine(Module):
    def __init__(self, module, metrics, metric_thresholds=None, max_iter=3):
        super().__init__()
        self.module = module.deepcopy()
self.metrics = metrics
self.metric_thresholds = metric_thresholds
self.max_iter = max_iter

self.metric_descriptions = [self._get_metric_description(metric) for metric in metrics]
self.feedback_program = ChainOfThought(GenerateFeedback)

        self._named_predicts = dict(self.module.named_predictors())

    def _get_metric_description(self, metric):
        # `Metric` subclasses carry a descriptive `__repr__`; for other metric
        # objects, fall back to the source code of their class.
        if isinstance(metric, Metric):
            return str(metric)
        else:
            return inspect.getsource(metric.__class__)

def _patch_predict_call_with_feedback(self, feedbacks):
named_predicts = {}
        for name in feedbacks:
            # Only patch the predicts that have feedback.
            named_predicts[name] = self._named_predicts[name]

predict_traces = get_traces(named_predicts)

def forward_with_feedback(instance, dspy_refine_feedback, dspy_refine_last_trace, **kwargs):
return instance.original_forward(
**kwargs,
dspy_refine_feedback=dspy_refine_feedback,
dspy_refine_last_trace=dspy_refine_last_trace,
)

        for name, predict in named_predicts.items():
            last_trace = predict_traces.get(name, None)
            # Trim the previous round's feedback and trace from the recorded
            # inputs to avoid unbounded nesting across retries.
            if last_trace is not None:
                last_trace["inputs"].pop("dspy_refine_feedback", None)
                last_trace["inputs"].pop("dspy_refine_last_trace", None)

feedback = feedbacks.get(name, None)
if not hasattr(predict, "original_forward"):
# If the predict has never been patched for refine calls, patch it.
predict.original_signature = predict.signature
predict.signature = predict.signature.prepend(
"dspy_refine_feedback",
InputField(desc="Improvement suggestion based on last try", type=str),
).prepend("dspy_refine_last_trace", InputField(desc="Trace of the last try", type=dict))

# Save the original forward method before patching.
predict.original_forward = predict.forward

partial_forward = partial(
forward_with_feedback, dspy_refine_feedback=feedback, dspy_refine_last_trace=last_trace
)
            # Patch `forward` with `forward_with_feedback`, binding the current
            # feedback and last_trace via functools.partial.
predict.forward = MethodType(partial_forward, predict)

    def _undo_patch_predict_call_with_feedback(self, named_predicts):
        for predict in named_predicts.values():
if hasattr(predict, "original_forward"):
predict.forward = predict.original_forward
predict.signature = predict.original_signature
del predict.original_signature
del predict.original_forward

def _get_feedback_for_predicts(self, inputs, outputs):
metric_descriptions = []
metric_values = []
        for i, metric in enumerate(self.metrics):
            metric_value = metric(inputs, outputs)
            # With no thresholds configured, always collect feedback; otherwise
            # collect it only for metrics that fall below their threshold.
            if self.metric_thresholds is None or metric_value < self.metric_thresholds[i]:
                metric_descriptions.append(self.metric_descriptions[i])
                metric_values.append(metric_value)

        if len(metric_descriptions) == 0:
            # Every metric met its threshold; no refinement needed.
            return {}

# Get feedback for each metric.
feedbacks = self.feedback_program(
metrics=metric_descriptions,
metric_values=metric_values,
module_inputs=inputs,
module_outputs=outputs,
source_code=inspect.getsource(self.module.__class__),
).feedback
named_predicts = self._named_predicts

predict_name_to_feedback = {}
for name in named_predicts.keys():
top_module_name = name.split(".")[0]
if top_module_name in feedbacks:
predict_name_to_feedback[name] = feedbacks[top_module_name]
elif f"self.{top_module_name}" in feedbacks:
predict_name_to_feedback[name] = feedbacks[f"self.{top_module_name}"]
return predict_name_to_feedback

def __call__(self, **kwargs):
outputs = self.module(**kwargs)

for i in range(self.max_iter):
feedbacks = self._get_feedback_for_predicts(kwargs, outputs)

if len(feedbacks) == 0:
break
self._patch_predict_call_with_feedback(feedbacks)

outputs = self.module(**kwargs)

        named_predicts = dict(self.module.named_predictors())
self._undo_patch_predict_call_with_feedback(named_predicts)
return outputs
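
To see the signature patching in isolation, here is a small sketch using the same prepend calls as Refine does above (the resulting field order is an assumption about how prepend stacks fields):

import dspy
from dspy.signatures.field import InputField

predict = dspy.Predict("question -> answer")
patched = predict.signature.prepend(
    "dspy_refine_feedback",
    InputField(desc="Improvement suggestion based on last try", type=str),
).prepend("dspy_refine_last_trace", InputField(desc="Trace of the last try", type=dict))

# The last prepend lands first.
print(list(patched.input_fields))
# ['dspy_refine_last_trace', 'dspy_refine_feedback', 'question']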
22 changes: 22 additions & 0 deletions dspy/refine/utils.py
@@ -0,0 +1,22 @@
from dspy.dsp.utils.settings import settings


def get_traces(named_predicts):
predict_name_to_traces = {}
predict_id_to_name = {id(predict): name for name, predict in named_predicts.items()}

    traces = settings.trace
    # Walk the trace from newest to oldest so each predict maps to its most
    # recent call.
    for trace in reversed(traces):
        trace_predict_id = id(trace[0])
if trace_predict_id in predict_id_to_name:
predict_name = predict_id_to_name[trace_predict_id]
if predict_name not in predict_name_to_traces:
predict_name_to_traces[predict_name] = {
"inputs": trace[1],
"outputs": trace[2].toDict(),
}
if len(predict_name_to_traces) == len(named_predicts):
# Stop searching when all predicts' traces are found.
break
return predict_name_to_traces
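
A sketch of get_traces in use (the predictor name key is whatever named_predictors() reports, the (predictor, inputs, prediction) entry shape is inferred from the indexing above, and an LM is assumed to be configured):

import dspy
from dspy.refine.utils import get_traces

program = dspy.ChainOfThought("question -> answer")

with dspy.context(trace=[]):
    program(question="What is 2 + 2?")
    traces = get_traces(dict(program.named_predictors()))

# e.g. {"predict": {"inputs": {"question": "What is 2 + 2?"},
#                   "outputs": {"reasoning": "...", "answer": "4"}}}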
