Adding stream to DSPy LMs #338
I'll think about this. Support per LM call would be sufficient (and what we're doing anyway). For reference, here's from the langchain-dspy notebook:
But that sounds harder.
This sounds super cool, similar to #249. I think the broader question is "how does DSPy fit into the productionization workflow", and that's something we can think more about to come up with an elegant approach.
Three different aspects of dspy streaming:
Any news on this? I really like dspy, but without streaming it's hard to use in production
My take on this is to implement the following two parts:
Right now I am working on a temporary workaround as I only need streaming for the final response synthesis module. As my project is a RAG pipeline, I only need the last part that responds to the user to be streaming. Thus, I will try to export the compiled template as a string, then make the call to the LM and parse the streaming response myself.
@theta-lin if you find a way to compile the prompt string with dspy without having to check the history (i.e., after having made a call), or without a wrapper around that (a class that sends a null response, after which you check the history), do tell.
@pedroallenrevez I do have a solution here; it is adapted from dspy/dspy/predict/llamaindex.py line 22 at commit 55510ee:

```python
import dsp
import dspy
# The import paths below are an assumption; adjust them for the DSPy version you are using.
from dspy.signatures.signature import ensure_signature, signature_to_template


def get_template(predict_module: dspy.Predict, **kwargs) -> str:
    """Get formatted template from predict module."""
    # (I suddenly realized what these comments mean now... they are copied from `old_generate()`)
    # Extract the three privileged keyword arguments.
    signature = ensure_signature(predict_module.signature)
    # Switch to legacy format for dsp.generate
    template = signature_to_template(signature)
    # The only difference from the original code, to make it work with uncompiled predictors
    if hasattr(predict_module, "demos"):
        demos = predict_module.demos
    else:
        demos = []
    # All of the other kwargs are presumed to fit a prefix of the signature.
    # That is, they are input variables for the bottom-most generation, so
    # we place them inside the input - x - together with the demos.
    x = dsp.Example(demos=demos, **kwargs)
    return template(x)
```

As it happens that I indeed finished a workaround, here is an example of both extracting the prompt template string and actually implementing the streaming response:

```python
class Example(dspy.Module):
    def __init__(self):
        super().__init__()
        self.synthesizer = dspy.Predict(...)
        self.llm = ...

    def forward(self, query):
        synthesizer_template = get_template(
            # Suppose that we have a predict module called `synthesizer` that takes in an InputField
            # called `query` and outputs an OutputField called `response`
            self.synthesizer,
            # Just pass in all the input fields as kwargs
            query=query,
        )

        def parse_gen():
            """
            A generator that returns the part after "Response:" and strips whitespace.
            In other words, it's like `dsp.adapters.Template.extract()` but only for one field.
            The assumption is that you ONLY want to extract the `response` field and
            that the `response` field is the LAST field of the output. I only implemented
            this because I am using CoT for my actual code, so there is a `rationale`
            field preceding the `response` field.
            """

            # Most of this is just for stripping the whitespace at the beginning and
            # the end of the response, while preserving the whitespace in the middle.
            def rstripped(s):
                """Extract the trailing whitespace itself."""
                from itertools import takewhile
                return "".join(reversed(tuple(takewhile(str.isspace, reversed(s)))))

            field = "Response:"
            # Suppose that you have an `llm` class that returns a response generator for
            # its `stream_complete()` method
            gen = self.llm.stream_complete(synthesizer_template)
            before_response = ""
            for r in gen:
                # r.delta is the "delta" of LLM output, i.e., the newly generated tokens
                before_response += r.delta
                offset = before_response.find(field)
                if offset != -1:
                    s = before_response[offset + len(field):]
                    if s.strip():
                        yield s.strip()
                        prev_whitespace = rstripped(s)
                        break
            for r in gen:
                s = r.delta
                yield prev_whitespace + s.rstrip()
                prev_whitespace = rstripped(s)

        # The `response` is a generator for the actual streaming response
        return dspy.Prediction(response=parse_gen())
```
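For clarity, consuming this module would look something like the sketch below (hypothetical usage; the `...` placeholders above must of course be filled in with a real predictor and LM client first):

```python
rag = Example()
prediction = rag.forward(query="What is DSPy?")

# `prediction.response` is a generator, so iterating it streams the answer chunk by chunk.
for chunk in prediction.response:
    print(chunk, end="", flush=True)
```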
@pedroallenrevez my LinkedIn DMs are open, let's chat
I agree with @pedroallenrevez. Currently my whole pipeline (~3 components) is in DSPy, except the last step (final answer generation). I don't care about streaming in between; I just want the final answer to be streamed. I currently mock the prompt and send it to the LLM via a vLLM client and stream the answer. It would be nice if there was some streaming support for cases where there is only one output field.
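For anyone wanting to replicate that workaround, here is a minimal sketch of the final streamed step. It assumes a vLLM server exposing the OpenAI-compatible API and a prompt string already formatted for the last module (e.g., with a helper like `get_template` above); the endpoint, model name, and token limit are placeholders.

```python
from openai import OpenAI  # assumes the openai>=1.0 client and a vLLM OpenAI-compatible server

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # placeholder endpoint


def stream_final_answer(prompt: str):
    """Yield chunks of the final answer as they arrive from the vLLM server."""
    stream = client.completions.create(
        model="my-served-model",  # placeholder: whatever model vLLM is serving
        prompt=prompt,            # the mocked / pre-formatted DSPy prompt
        max_tokens=512,
        stream=True,
    )
    for chunk in stream:
        text = chunk.choices[0].text
        if text:
            yield text


# Everything upstream stays in DSPy; only this last call streams:
# for token in stream_final_answer(final_prompt):
#     print(token, end="", flush=True)
```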
Is there any update on this?
@LouisCastricato after async support in #1734, this is the next thing I'll be tackling
Hello @CyrusOfEden, @pedroallenrevez, I've developed an enhancement to DSPy's streaming capabilities that enables streaming of specific fields. This adaptation allows streaming individual fields from DSPy's output by specifying a `field_to_stream` parameter. Here's the key implementation:

```python
import asyncio
from dataclasses import dataclass
from typing import AsyncIterator, Optional

import dspy
import litellm
from dspy.signatures.signature import ensure_signature

llm = dspy.LM()  # configure with your model, e.g. dspy.LM("openai/gpt-4o-mini")
dspy.configure(lm=llm, backoff_time=240)


@dataclass
class StreamingResult:
    """Container for both streaming tokens and final parsed prediction"""
    token: str
    prediction: Optional[dspy.Prediction] = None


class StreamingChatAdapter(dspy.ChatAdapter):
    async def stream(self, lm, lm_kwargs, signature, demos, inputs, field_to_stream=None,
                     _parse_values=True) -> AsyncIterator[StreamingResult]:
        messages = self.format(signature, demos, inputs)
        completions = ""
        in_target_field = False
        current_field = None
        buffer = ""

        async for chunk in lm.astream(messages, **lm_kwargs):
            if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
                token = chunk.choices[0].delta.content
                completions += token

                # Check for field markers
                if "[[" in token:
                    buffer = token
                    current_field = None
                    in_target_field = False
                    continue

                if buffer:
                    buffer += token
                    if "]]" in buffer:
                        field_marker = buffer.split("## ")[1].split(" ##")[0].strip()
                        current_field = field_marker
                        in_target_field = (field_to_stream is None) or (field_marker == field_to_stream)
                        buffer = ""
                    continue

                try:
                    parsed_fields = self.parse(signature, completions, _parse_values=True)
                    prediction = dspy.Prediction(parsed_fields)
                    if in_target_field:
                        yield StreamingResult(token=token, prediction=prediction)
                    elif prediction:  # Store prediction even if we're not streaming this field
                        yield StreamingResult(token="", prediction=prediction)
                except (ValueError, KeyError):
                    if in_target_field:
                        yield StreamingResult(token=token)
                    else:
                        yield StreamingResult(token="")


class StreamingLM(dspy.LM):
    async def astream(self, messages, **kwargs):
        combined_kwargs = {**self.kwargs, **kwargs}
        response = litellm.acompletion(
            model=self.model,
            messages=messages,
            stream=True,
            **combined_kwargs
        )
        async for chunk in await response:
            yield chunk


class AsyncStreamingPredict(dspy.Predict):
    async def forward_stream(self, field_to_stream, **kwargs):
        signature = ensure_signature(kwargs.pop("signature", self.signature))
        demos = kwargs.pop("demos", self.demos)
        config = dict(**self.config, **kwargs.pop("config", {}))
        lm = kwargs.pop("lm", self.lm) or dspy.settings.lm
        assert lm is not None, "No LM is loaded."

        if not isinstance(lm, StreamingLM):
            lm = StreamingLM(lm.model, **lm.kwargs)

        adapter = StreamingChatAdapter()
        async for result in adapter.stream(lm, config, signature, demos, kwargs,
                                           field_to_stream=field_to_stream, _parse_values=True):
            yield result


class Joke(dspy.Signature):
    """
    write a joke
    """
    title: str = dspy.InputField()
    answer: str = dspy.OutputField()


class WriteJoke(dspy.Module):
    def __init__(self):
        self.predict = AsyncStreamingPredict("title->reasoning, answer")

    async def forward(self, title: str, field_to_stream: str = None):
        last_prediction = ""
        prediction = self.predict.forward_stream(
            title=title,
            field_to_stream=field_to_stream
        )
        async for result in prediction:
            if result.token:  # Only print if there's a token (i.e., we're in the answer field)
                print(result.token, end="", flush=True)
            if result.prediction:
                last_prediction = result.prediction
            await asyncio.sleep(0)
        print(last_prediction)


import nest_asyncio

if __name__ == '__main__':
    nest_asyncio.apply()
    predict_joke = WriteJoke()
    asyncio.run(predict_joke.forward(title="Tell me a joke about AI & humanity",
                                     field_to_stream="answer"))
```

Let me know if you'd like to discuss this implementation or if you have any suggestions for improvement.
Very cool! I think a sync form could make it into main -- async forwards aren't compatible with optimizers afaik. We're approaching this in a "sync on the inside, async on the outside" manner @aazizisoufiane
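To illustrate that "sync on the inside, async on the outside" idea in isolation (a hypothetical sketch, not DSPy's actual implementation): the module keeps a plain synchronous generator so optimizers can still drive it, while an async shell drains that generator off the event loop for serving.

```python
import asyncio
from typing import AsyncIterator, Iterator


def sync_token_stream(prompt: str) -> Iterator[str]:
    """Synchronous core: what a sync forward() could return internally.
    The tokens here are faked; a real version would wrap the LM client."""
    for token in ["Hello", ", ", "world", "!"]:
        yield token


async def as_async(tokens: Iterator[str]) -> AsyncIterator[str]:
    """Async shell: drain the sync generator without blocking the event loop."""
    sentinel = object()
    while True:
        token = await asyncio.to_thread(next, tokens, sentinel)
        if token is sentinel:
            break
        yield token


async def main():
    async for token in as_async(sync_token_stream("hi")):
        print(token, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(main())
```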
@CyrusOfEden something like this?

```python
from typing import AsyncIterator, Optional
from dataclasses import dataclass

import dspy
from dspy.signatures.signature import ensure_signature
import litellm


@dataclass
class StreamingResult:
    """Container for both streaming tokens and final parsed prediction"""
    token: str
    prediction: Optional[dspy.Prediction] = None


class StreamingChatAdapter(dspy.ChatAdapter):
    async def stream(self, lm, lm_kwargs, signature, demos, inputs, field_to_stream=None,
                     _parse_values=True) -> AsyncIterator[StreamingResult]:
        messages = self.format(signature, demos, inputs)
        completions = ""
        in_target_field = False
        buffer = ""

        async for chunk in lm.astream(messages, **lm_kwargs):
            if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
                token = chunk.choices[0].delta.content
                completions += token

                # Check for field markers
                if "[[" in token:
                    buffer = token
                    in_target_field = False
                    continue

                if buffer:
                    buffer += token
                    if "]]" in buffer:
                        field_marker = buffer.split("## ")[1].split(" ##")[0].strip()
                        in_target_field = (field_to_stream is None) or (field_marker == field_to_stream)
                        buffer = ""
                    continue

                try:
                    parsed_fields = self.parse(signature, completions, _parse_values=True)
                    prediction = dspy.Prediction(parsed_fields)
                    if in_target_field:
                        yield StreamingResult(token=token, prediction=prediction)
                    elif prediction:
                        yield StreamingResult(token="", prediction=prediction)
                except (ValueError, KeyError):
                    if in_target_field:
                        yield StreamingResult(token=token)
                    else:
                        yield StreamingResult(token="")


class StreamingLM(dspy.LM):
    async def astream(self, messages, **kwargs):
        """Async streaming interface using litellm"""
        combined_kwargs = {**self.kwargs, **kwargs}
        response = await litellm.acompletion(
            model=self.model,
            messages=messages,
            stream=True,
            **combined_kwargs
        )
        async for chunk in response:
            yield chunk


class StreamingPredict(dspy.Predict):
    async def aforward_stream(self, field_to_stream, **kwargs):
        """Streaming-specific async forward implementation"""
        signature = ensure_signature(kwargs.pop("signature", self.signature))
        demos = kwargs.pop("demos", self.demos)
        config = dict(**self.config, **kwargs.pop("config", {}))
        lm = kwargs.pop("lm", self.lm) or dspy.settings.lm
        assert lm is not None, "No LM is loaded."

        if not isinstance(lm, StreamingLM):
            lm = StreamingLM(lm.model, **lm.kwargs)

        adapter = StreamingChatAdapter()
        async for result in adapter.stream(
            lm,
            config,
            signature,
            demos,
            kwargs,
            field_to_stream=field_to_stream,
            _parse_values=True
        ):
            yield result

    async def aforward(self, **kwargs):
        """Standard async forward implementation"""
        field_to_stream = kwargs.pop("field_to_stream", None)
        if field_to_stream:
            # Collect all streaming results to return final prediction
            last_prediction = None
            async for result in self.aforward_stream(field_to_stream, **kwargs):
                if result.prediction:
                    last_prediction = result.prediction
            return last_prediction
        else:
            # Use standard predict forward
            return await super().aforward(**kwargs)


class Joke(dspy.Signature):
    """Write a joke"""
    title: str = dspy.InputField()
    answer: str = dspy.OutputField()


class WriteJoke(dspy.Module):
    def __init__(self):
        self.predict = StreamingPredict("title->reasoning, answer")

    async def aforward(self, title: str, field_to_stream: str = None):
        """Async forward implementation with optional streaming"""
        last_prediction = None
        async for result in self.predict.aforward_stream(
            title=title,
            field_to_stream=field_to_stream
        ):
            if result.token:
                print(result.token, end="", flush=True)
            if result.prediction:
                last_prediction = result.prediction
        return last_prediction


if __name__ == '__main__':
    import asyncio

    # `await` can't be used at module level, so drive the async forward with asyncio.run().
    joke_writer = WriteJoke()
    result = asyncio.run(joke_writer.aforward(
        title="Tell me a joke about Humanity & AI",
        field_to_stream="answer"
    ))
```
@aazizisoufiane I'll take a look at this this week, ETA Monday for a PR
@aazizisoufiane check out #1874 🔥
@CyrusNuevoDia, thank you for your fix! I noticed that all outputs are being streamed, including reasoning sections (e.g., `[[ ## reasoning ## ]]`) and placeholders like "completed", without any cleanup or formatting. Should we consider selectively streaming specific fields, such as just the `answer` field, instead of all raw outputs? This could provide a cleaner and more user-friendly experience.
Hey @aazizisoufiane! The MVP of streaming was to not modify the LM. For V2 I'm thinking of streaming partial signatures from every step :) There's still 3 days until Monday ;)
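For readers arriving later: field-selective streaming did eventually land in DSPy. The sketch below shows roughly what consuming it looks like via `dspy.streamify` with a per-field `StreamListener`; the exact names, chunk types, and the model string are stated from memory and should be checked against the current DSPy docs.

```python
import asyncio
import dspy

# Assumed interface (verify against current DSPy docs): streamify wraps a program,
# and StreamListener selects which signature field(s) to stream token by token.
dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))  # placeholder model

program = dspy.streamify(
    dspy.ChainOfThought("question -> answer"),
    stream_listeners=[dspy.streaming.StreamListener(signature_field_name="answer")],
)


async def main():
    async for chunk in program(question="Tell me a joke about AI & humanity"):
        if isinstance(chunk, dspy.streaming.StreamResponse):
            print(chunk.chunk, end="", flush=True)  # streamed tokens of the `answer` field
        elif isinstance(chunk, dspy.Prediction):
            print("\nFinal answer:", chunk.answer)  # full parsed prediction at the end


asyncio.run(main())
```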
A few members of the community have been asking for support for streaming LM output in DSPy.
@sutyum and @detaos have discussed this extensively before.
One of the challenges is that it's not even clear how to stream a DSPy program (it makes many calls to the LM, in arbitrary order, decided by the user). It's not really possible to stream that in the general case.
However, we can add support per LM call, e.g. in `dspy.Predict`. The key thing is that `Predict` actually outputs multiple fields, so support for streaming will need to return an object (a `Prediction`, or maybe a `StreamedPrediction`?), and each field in there (e.g., `rationale` and `answer`) will need to stream separately. I don't have any insight into what that would take, but it sounds simple overall.
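To make that idea concrete, here is one hypothetical shape such a `StreamedPrediction` could take (all names are illustrative, not an actual DSPy API): each output field is exposed as its own token iterator, and the object can be drained into ordinary field values once the LM call finishes.

```python
from typing import Dict, Iterator


class StreamedPrediction:
    """Hypothetical container: one token stream per output field (e.g. `rationale`, `answer`)."""

    def __init__(self, field_streams: Dict[str, Iterator[str]]):
        self._streams = field_streams
        self._finished: Dict[str, str] = {}

    def stream(self, field: str) -> Iterator[str]:
        """Iterate over the tokens of a single output field, recording the full text."""
        chunks = []
        for token in self._streams[field]:
            chunks.append(token)
            yield token
        self._finished[field] = "".join(chunks)

    def result(self) -> Dict[str, str]:
        """Drain any remaining streams and return the completed field values."""
        for field in self._streams:
            if field not in self._finished:
                for _ in self.stream(field):
                    pass
        return dict(self._finished)


# Usage sketch:
# pred = predictor.stream(question="...")  # hypothetical streaming entry point
# for token in pred.stream("answer"):
#     print(token, end="", flush=True)
# print(pred.result()["rationale"])
```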