-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6dbc8bc
commit a1a6bb2
Showing
6 changed files
with
219 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from dspy.refine.metrics import BoolMetric, FloatMetric | ||
from dspy.refine.refine import Refine | ||
|
||
__all__ = [ | ||
"Refine", | ||
"BoolMetric", | ||
"FloatMetric", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from typing import Union | ||
|
||
from dspy.signatures.field import InputField, OutputField | ||
from dspy.signatures.signature import Signature | ||
|
||
|
||
class GenerateFeedback(Signature): | ||
""" | ||
Based on each metric value and metric definition for the inputs-outputs pair, provide feedback the DSPy module | ||
along with submodules in order to improve the metric values at the retry. Only provide feedback for built-in | ||
classses, e.g., dspy.Predict, dspy.Module, dspy.ChainOfThought and so on. If an attribute is a list, make sure | ||
you look into every element. It's also possible that some components are not related to the certain score, we | ||
should skip generating feedback if it is the case. | ||
""" | ||
|
||
metrics: list[str] = InputField(desc="The definition of each scoring criterion") | ||
metric_values: list[Union[int, float, bool]] = InputField(desc="The value of each metric, the higher the better") | ||
module_inputs: dict = InputField(desc="The inputs of the DSPy module") | ||
module_outputs: dict = InputField(desc="The outputs of the DSPy module") | ||
source_code: str = InputField(desc="The source code of the DSPy module") | ||
feedback: dict[str, list[str]] = OutputField( | ||
desc="Feedback for the DSPy module in general, along with feedback for each submodule in the DSPy model, only " | ||
"provide feedback for attributes in `__init__` method that is a built-in class of dspy. The key should be the " | ||
"attribute name, e.g., `self.cot` or `self.predict`. If the attribute is a list, write the key as " | ||
"`self.cots[0]`, `self.predicts[1]` and so on. The feedback should be " | ||
"a list of strings, corresponding to each score function in `metrics`." | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import inspect | ||
|
||
|
||
class Metric: | ||
def __init__(self, name, description, fn): | ||
self.name = name | ||
self.description = description | ||
self.fn = fn | ||
|
||
def __call__(self, inputs, outputs): | ||
return self.fn(inputs, outputs) | ||
|
||
def __repr__(self): | ||
if self.description: | ||
return f"Metric name: {self.name}\nMetric description: {self.description}\n" | ||
else: | ||
return f"Metric name: {self.name}\nMetric function: {inspect.getsource(self.fn)}\n" | ||
|
||
|
||
class BoolMetric(Metric): | ||
def __init__(self, name, description, fn): | ||
super().__init__(name, description, fn) | ||
self.type_description = "This is a bool metric, true if the metric looks good, false otherwise." | ||
|
||
def __repr__(self): | ||
return f"{super().__repr__()}\nMetric type:{self.type_description}" | ||
|
||
|
||
class FloatMetric(Metric): | ||
def __init__(self, name, description, fn): | ||
super().__init__(name, description, fn) | ||
self.type_description = "This is a float metric, the higher the value the better." | ||
|
||
def __repr__(self): | ||
return f"{super().__repr__()}\nMetric type:{self.type_description}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import inspect | ||
from functools import partial | ||
from types import MethodType | ||
|
||
from dspy.predict.chain_of_thought import ChainOfThought | ||
from dspy.primitives.program import Module | ||
from dspy.refine.feedback import GenerateFeedback | ||
from dspy.refine.utils import get_traces | ||
from dspy.signatures.field import InputField | ||
|
||
|
||
class Refine(Module): | ||
def __init__(self, module, metrics, metric_thresholds=None, max_iter=3): | ||
self.module = module.deepcopy() | ||
self.metrics = metrics | ||
self.metric_thresholds = metric_thresholds | ||
self.max_iter = max_iter | ||
|
||
self.metric_descriptions = [self._get_metric_description(metric) for metric in metrics] | ||
self.feedback_program = ChainOfThought(GenerateFeedback) | ||
|
||
self._named_predicts = {name: predict for name, predict in self.module.named_predictors()} | ||
|
||
def _get_metric_description(self, metric): | ||
if hasattr(metric, "__repr__"): | ||
return str(metric) | ||
else: | ||
return inspect.getsource(metric.__class__) | ||
|
||
def _patch_predict_call_with_feedback(self, feedbacks): | ||
named_predicts = {} | ||
for name in feedbacks.keys(): | ||
# Only patch the predict that has feedback. | ||
named_predicts[name] = self._named_predicts[name] | ||
|
||
predict_traces = get_traces(named_predicts) | ||
|
||
def forward_with_feedback(instance, dspy_refine_feedback, dspy_refine_last_trace, **kwargs): | ||
return instance.original_forward( | ||
**kwargs, | ||
dspy_refine_feedback=dspy_refine_feedback, | ||
dspy_refine_last_trace=dspy_refine_last_trace, | ||
) | ||
|
||
for name, predict in named_predicts.items(): | ||
last_trace = predict_traces.get(name, None) | ||
# We trim out the last round's feedback and last_trace from the inputs to avoid too much nesting. | ||
if "dspy_refine_feedback" in last_trace["inputs"]: | ||
del last_trace["inputs"]["dspy_refine_feedback"] | ||
if "dspy_refine_last_trace" in last_trace["inputs"]: | ||
del last_trace["inputs"]["dspy_refine_last_trace"] | ||
|
||
feedback = feedbacks.get(name, None) | ||
if not hasattr(predict, "original_forward"): | ||
# If the predict has never been patched for refine calls, patch it. | ||
predict.original_signature = predict.signature | ||
predict.signature = predict.signature.prepend( | ||
"dspy_refine_feedback", | ||
InputField(desc="Improvement suggestion based on last try", type=str), | ||
).prepend("dspy_refine_last_trace", InputField(desc="Trace of the last try", type=dict)) | ||
|
||
# Save the original forward method before patching. | ||
predict.original_forward = predict.forward | ||
|
||
partial_forward = partial( | ||
forward_with_feedback, dspy_refine_feedback=feedback, dspy_refine_last_trace=last_trace | ||
) | ||
# Patch the `forward` method to the `forward_with_feedback` methd with partial values of feedback and | ||
# last_trace. | ||
predict.forward = MethodType(partial_forward, predict) | ||
|
||
def _undo_patch_predict_call_with_feedback(self, named_predicts): | ||
for _, predict in named_predicts.items(): | ||
if hasattr(predict, "original_forward"): | ||
predict.forward = predict.original_forward | ||
predict.signature = predict.original_signature | ||
del predict.original_signature | ||
del predict.original_forward | ||
|
||
def _get_feedback_for_predicts(self, inputs, outputs): | ||
metric_descriptions = [] | ||
metric_values = [] | ||
for i, metric in enumerate(self.metrics): | ||
metric_value = metric(inputs, outputs) | ||
if self.metric_thresholds and metric_value < self.metric_thresholds[i]: | ||
metric_descriptions.append(self.metric_descriptions[i]) | ||
metric_values.append(metric_value) | ||
|
||
if len(metric_descriptions) == 0: | ||
# All metric values are above the threshold, no need to refine. | ||
return {} | ||
|
||
# Get feedback for each metric. | ||
feedbacks = self.feedback_program( | ||
metrics=metric_descriptions, | ||
metric_values=metric_values, | ||
module_inputs=inputs, | ||
module_outputs=outputs, | ||
source_code=inspect.getsource(self.module.__class__), | ||
).feedback | ||
named_predicts = self._named_predicts | ||
|
||
predict_name_to_feedback = {} | ||
for name in named_predicts.keys(): | ||
top_module_name = name.split(".")[0] | ||
if top_module_name in feedbacks: | ||
predict_name_to_feedback[name] = feedbacks[top_module_name] | ||
elif f"self.{top_module_name}" in feedbacks: | ||
predict_name_to_feedback[name] = feedbacks[f"self.{top_module_name}"] | ||
return predict_name_to_feedback | ||
|
||
def __call__(self, **kwargs): | ||
outputs = self.module(**kwargs) | ||
|
||
for i in range(self.max_iter): | ||
feedbacks = self._get_feedback_for_predicts(kwargs, outputs) | ||
|
||
if len(feedbacks) == 0: | ||
break | ||
self._patch_predict_call_with_feedback(feedbacks) | ||
|
||
outputs = self.module(**kwargs) | ||
|
||
named_predicts = {name: predict for name, predict in self.module.named_predictors()} | ||
self._undo_patch_predict_call_with_feedback(named_predicts) | ||
return outputs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from dspy.dsp.utils.settings import settings | ||
|
||
|
||
def get_traces(named_predicts): | ||
predict_name_to_traces = {} | ||
predict_id_to_name = {id(predict): name for name, predict in named_predicts.items()} | ||
|
||
traces = settings.trace | ||
for i in range(len(traces)): | ||
trace = traces[-i - 1] | ||
trace_predict_id = id(trace[0]) | ||
if trace_predict_id in predict_id_to_name: | ||
predict_name = predict_id_to_name[trace_predict_id] | ||
if predict_name not in predict_name_to_traces: | ||
predict_name_to_traces[predict_name] = { | ||
"inputs": trace[1], | ||
"outputs": trace[2].toDict(), | ||
} | ||
if len(predict_name_to_traces) == len(named_predicts): | ||
# Stop searching when all predicts' traces are found. | ||
break | ||
return predict_name_to_traces |