
Adding stream to DSPy LMs #338

Open
okhat opened this issue Feb 4, 2024 · 18 comments
Labels
deployment, help wanted (Contributions are welcome!)

Comments

@okhat
Collaborator

okhat commented Feb 4, 2024

A few members of the community have been asking for support for streaming LM output in DSPy.

@sutyum and @detaos have discussed this extensively before.

One of the challenges is that it's not even clear how to stream a DSPy program (it makes many calls to the LM, in arbitrary order, decided by the user). It's not really possible to stream that in the general case.

However, we can add support per LM call, e.g. in dspy.Predict.

The key thing is that Predict actually outputs multiple fields, so streaming support will need to return an object (Prediction, or maybe a StreamedPrediction?) in which each field (e.g., rationale and answer) streams separately.

I don't have any insight into what that would take, but it sounds simple overall.
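
To make that shape concrete, here is a purely hypothetical sketch; none of these names exist in DSPy today. The idea is just that a streaming Predict call would hand back one token generator per output field.

from typing import Dict, Iterator


class StreamedPrediction:
    """Hypothetical container returned by a streaming Predict call (illustration only)."""

    def __init__(self, field_streams: Dict[str, Iterator[str]]):
        # One token generator per output field (e.g., "rationale", "answer").
        self.field_streams = field_streams

    def stream(self, field: str) -> Iterator[str]:
        """Return the token generator for a single output field."""
        return self.field_streams[field]


# Imagined usage (the .stream(...) entry point on Predict does not exist today):
# pred = dspy.Predict("question -> rationale, answer").stream(question="...")
# for token in pred.stream("answer"):
#     print(token, end="")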

@AriMKatz

AriMKatz commented Feb 4, 2024

I'll think about this. Support per LM call would be sufficient (and what we're doing anyway).

For reference, here's an excerpt from the langchain-dspy notebook:

If we convert this into a fuller integration, all users stand to benefit. LangChain users will gain the ability to optimize any chain with any DSPy optimizer. DSPy users will gain the ability to export any DSPy program into an LCEL that supports streaming and tracing, and other rich production-targeted features in LangChain.

But that sounds harder.

@okhat added the help wanted and deployment labels on Feb 11, 2024
@CyrusNuevoDia
Collaborator

This sounds super cool, similar to #249. I think the broader question is "how does DSPy fit into the productionization workflow", and it's something we can think more about to come up with an elegant approach.

@motin

motin commented May 24, 2024

Three different aspects of dspy streaming:

  • For debugging / understanding: Being able to see the LLM chatter streaming gives insight into what is actually going on. Even though the chatter may be completely incoherent when there are multiple concurrent streams, structured output, etc., it can help catch issues early without having to wait for the whole program to finish.
  • For lower latency UI: Being able to subscribe to changes to the final output of the program, be it a stream of text or a partial object
  • For progress visualization: Being able to subscribe to state changes in the execution of the program, so that it is possible to understand how far the execution has progressed so far
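
For the last two points, here is a hypothetical sketch of what a subscription interface could look like. None of these names exist in DSPy; they only illustrate "subscribe to output deltas" and "subscribe to execution progress".

from typing import Protocol


class StreamListener(Protocol):
    """Hypothetical listener a streaming-aware DSPy runtime could call into (illustration only)."""

    def on_output_delta(self, field: str, delta: str) -> None:
        """Called with each new chunk of a program-level output field (lower-latency UI)."""
        ...

    def on_step(self, module_name: str, step_index: int, total_steps: int) -> None:
        """Called when a sub-module starts or finishes (progress visualization)."""
        ...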

@pedroallenrevez

Any news on this? I really like dspy, but without streaming it's hard to use in production

@theta-lin
Contributor

theta-lin commented Aug 5, 2024

My take on this is to implement the following two parts:

  1. For the LM class, expose a streaming counterpart of the basic_request() abstract method. I suggest that it should be up to the LMs to implement something like basic_streaming_request(), which should be similar to OpenAI's streaming API in that it takes a request and returns a response generator (a minimal sketch follows at the end of this comment).
  2. Implement a streaming counterpart to dsp.adapters.Template.extract() so that it parses the streaming response returned by LM.basic_streaming_request() on the fly and returns the response generators corresponding to each output field. The user can then retrieve the streaming response for a field by reading its corresponding generator. However, I think implementing streaming parsing would be the difficult part.

Right now I am working on a temporary workaround, as I only need streaming for the final response-synthesis module. Since my project is a RAG pipeline, only the last part that responds to the user needs to stream. Thus, I will export the compiled template as a string, make the call to the LM myself, and parse the streaming response.
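
A minimal sketch of part 1 above, under the assumption of the method name proposed in this comment (basic_streaming_request() is not an existing DSPy method). The shape just mirrors OpenAI's streaming API: take a request, return a generator of text deltas.

from typing import Iterator


class MyLM:  # stand-in for a concrete LM subclass; not an actual DSPy base class
    def basic_request(self, prompt: str, **kwargs) -> str:
        """Existing behaviour: return the full completion at once."""
        raise NotImplementedError

    def basic_streaming_request(self, prompt: str, **kwargs) -> Iterator[str]:
        """Proposed counterpart: yield text deltas as the provider streams them."""
        raise NotImplementedError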

@pedroallenrevez

@theta-lin if you figure out how to compile the prompt string with dspy without having to check the history (i.e., after having made a call), or without a wrapper around that (a class that sends a null response so that you can then check the history), do tell.

@theta-lin
Contributor

@pedroallenrevez I do have a solution here; it is adapted from DSPy's own get_formatted_template():

def get_formatted_template(predict_module: Predict, kwargs: Dict[str, Any]) -> str:

import dsp
import dspy
# Import paths below follow the legacy (DSPy 2.4-era) template API.
from dspy.signatures.signature import ensure_signature, signature_to_template


def get_template(predict_module: dspy.Predict, **kwargs) -> str:
    """Get the formatted prompt template from a predict module."""

    # (I suddenly realized what these comments mean now... they are copied from `old_generate()`)
    # Extract the three privileged keyword arguments.
    signature = ensure_signature(predict_module.signature)
    # Switch to legacy format for dsp.generate
    template = signature_to_template(signature)

    # The only difference from the original code, to make it work with uncompiled predictors
    if hasattr(predict_module, "demos"):
        demos = predict_module.demos
    else:
        demos = []
    # All of the other kwargs are presumed to fit a prefix of the signature.
    # That is, they are input variables for the bottom most generation, so
    # we place them inside the input - x - together with the demos.
    x = dsp.Example(demos=demos, **kwargs)
    return template(x)

As it happens, I did finish a workaround; here is an example of both extracting the prompt template string and actually implementing the streaming response:

class Example(dspy.Module):
    def __init__(self):
        super().__init__()
        self.synthesizer = dspy.Predict(...)
        self.llm = ...  # any LLM client whose `stream_complete()` returns a response generator


    def forward(self, query):
        synthesizer_template = get_template(
            # Suppose that we have a predict module called `synthesizer` that takes in an InputField
            # called `query` and outputs an OutputField called `response`
            self.synthesizer,
            # Just pass in all the input fields as kwargs
            query=query,
        )

        def parse_gen():
            """
            A generator that returns the part after "Response:" and strips whitespace.

            In other words, it's like `dsp.adapters.Template.extract()` but only for one field.
            The assumption is that you ONLY want to extract the `response` field and
            that the "response" field is the LAST field of the output. I only implemented
            this because I am using CoT in my actual code, so there is a `rationale`
            field preceding the `response` field.
            """

            # Most of these are just for stripping the whitespace at the beginning and
            # the end, but preserving those in the middle.
            def rstripped(s):
                """Extract the trailing whitespace itself."""
                from itertools import takewhile

                return "".join(reversed(tuple(takewhile(str.isspace, reversed(s)))))

            field = "Response:"
            # Suppose that you have an `llm` class that returns a response generator for
            # its `stream_complete()` method
            gen = self.llm.stream_complete(synthesizer_template)
            before_response = ""
            for r in gen:
                # r.delta is the amount of "delta" LLM output, or, the new tokens
                before_response += r.delta
                offset = before_response.find(field)
                if offset != -1:
                    s = before_response[offset + len(field) :]
                    if s.strip():
                        yield s.strip()
                        prev_whitespace = rstripped(s)
                        break

            for r in gen:
                s = r.delta
                yield prev_whitespace + s.rstrip()
                prev_whitespace = rstripped(s)

        # The `response` is a generator for the actual streaming response
        return dspy.Prediction(response=parse_gen())

@CyrusNuevoDia
Collaborator

@pedroallenrevez my LinkedIn DMs are open, let's chat

@MohammedAlhajji
Contributor

I agree with @pedroallenrevez. Currently my whole pipeline (~3 components) is in DSPy, except the last step (final answer generation). I don't care about streaming in between; I just want the final answer to be streamed. I currently mock the prompt and send it to the LLM via a vLLM client and stream the answer.

It would be nice if there were some streaming support for cases where there is only one output field.
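
For anyone doing the same, here is a minimal sketch of that kind of workaround, assuming a vLLM server exposing its OpenAI-compatible endpoint at http://localhost:8000/v1 and a prompt already rendered by DSPy (e.g., with the get_template() helper shown earlier in this thread). The model name is a placeholder.

from openai import OpenAI

# vLLM's OpenAI-compatible server; the api_key is ignored by vLLM but required by the client.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")


def stream_final_answer(prompt: str, model: str = "my-served-model"):
    """Send the mocked prompt to the vLLM endpoint and yield answer tokens as they arrive."""
    stream = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        stream=True,
    )
    for chunk in stream:
        delta = chunk.choices[0].delta.content
        if delta:
            yield delta


# Usage:
# for token in stream_final_answer(rendered_prompt):
#     print(token, end="", flush=True)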

@LouisCastricato

Is there any update on this?

@CyrusNuevoDia
Collaborator

CyrusNuevoDia commented Nov 1, 2024

@LouisCastricato after async support in #1734, this is the next thing I'll be tackling.

@aazizisoufiane

aazizisoufiane commented Nov 25, 2024

Hello @CyrusOfEden, @pedroallenrevez, I've developed an enhancement to DSPy's streaming capabilities that enables streaming of specific fields. This adaptation allows streaming individual fields from DSPy's output by specifying a field_to_stream parameter.

Here's the key implementation:

import asyncio
from dataclasses import dataclass
from typing import AsyncIterator, Optional

import dspy
import litellm
from dspy.signatures.signature import ensure_signature

llm = dspy.LM("openai/gpt-4o-mini")  # any LiteLLM-style model id works here
dspy.configure(lm=llm, backoff_time=240)


@dataclass
class StreamingResult:
    """Container for both streaming tokens and final parsed prediction"""
    token: str
    prediction: Optional[dspy.Prediction] = None


class StreamingChatAdapter(dspy.ChatAdapter):
    async def stream(self, lm, lm_kwargs, signature, demos, inputs, field_to_stream=None, _parse_values=True) -> \
            AsyncIterator[StreamingResult]:
        messages = self.format(signature, demos, inputs)
        completions = ""
        in_target_field = False
        current_field = None
        buffer = ""

        async for chunk in lm.astream(messages, **lm_kwargs):
            if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
                token = chunk.choices[0].delta.content
                completions += token

                # Check for field markers
                if "[[" in token:
                    buffer = token
                    current_field = None
                    in_target_field = False
                    continue

                if buffer:
                    buffer += token
                    if "]]" in buffer:
                        field_marker = buffer.split("## ")[1].split(" ##")[0].strip()
                        current_field = field_marker
                        in_target_field = (field_to_stream is None) or (field_marker == field_to_stream)
                        # in_target_field = field_marker == field_to_stream
                        buffer = ""
                    continue

                try:
                    parsed_fields = self.parse(signature, completions, _parse_values=True)
                    prediction = dspy.Prediction(parsed_fields)
                    if in_target_field:
                        yield StreamingResult(token=token, prediction=prediction)
                    elif prediction:  # Store prediction even if we're not streaming this field
                        yield StreamingResult(token="", prediction=prediction)
                except (ValueError, KeyError):
                    if in_target_field:
                        yield StreamingResult(token=token)
                    else:
                        yield StreamingResult(token="")


class StreamingLM(dspy.LM):
    async def astream(self, messages, **kwargs):
        combined_kwargs = {**self.kwargs, **kwargs}
        response = await litellm.acompletion(
            model=self.model,
            messages=messages,
            stream=True,
            **combined_kwargs
        )
        async for chunk in response:
            yield chunk


class AsyncStreamingPredict(dspy.Predict):
    async def forward_stream(self, field_to_stream, **kwargs):
        signature = ensure_signature(kwargs.pop("signature", self.signature))
        demos = kwargs.pop("demos", self.demos)
        config = dict(**self.config, **kwargs.pop("config", {}))

        lm = kwargs.pop("lm", self.lm) or dspy.settings.lm
        assert lm is not None, "No LM is loaded."

        if not isinstance(lm, StreamingLM):
            lm = StreamingLM(lm.model, **lm.kwargs)

        adapter = StreamingChatAdapter()
        async for result in adapter.stream(lm, config, signature, demos, kwargs, field_to_stream=field_to_stream,
                                           _parse_values=True):
            yield result


class Joke(dspy.Signature):
    """
    write a joke
    """
    title: str = dspy.InputField()
    answer: str = dspy.OutputField()


class WriteJoke(dspy.Module):
    def __init__(self):
        self.predict = AsyncStreamingPredict("title->reasoning, answer")

    async def forward(self, title: str, field_to_stream: str = None):
        last_prediction = ""
        prediction = self.predict.forward_stream(
            title=title,
            field_to_stream=field_to_stream
        )
        async for result in prediction:
            if result.token:  # Only print if there's a token (i.e., we're in the answer field)
                print(result.token, end="", flush=True)
            if result.prediction:
                last_prediction = result.prediction
            await asyncio.sleep(0)

        print(last_prediction)


import nest_asyncio

if __name__ == '__main__':
    nest_asyncio.apply()  # allows asyncio.run() inside an already-running event loop (e.g., a notebook)
    predict_joke = WriteJoke()
    asyncio.run(predict_joke.forward(title="Tell me a joke about AI & humanity ",
                                     field_to_stream="answer"))

Let me know if you'd like to discuss this implementation or if you have any suggestions for improvement.

@CyrusNuevoDia
Collaborator

Very cool! I think a sync form could make it into main -- async forwards aren't compatible with optimizers afaik.

We're approaching this in a "sync on the inside, async on the outside" manner @aazizisoufiane
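
One possible reading of "sync on the inside, async on the outside", sketched under that assumption only: keep the DSPy module a plain synchronous generator and wrap it in an async iterator by running it in a worker thread and forwarding tokens through a queue. The as_async_stream name and the generator factory are hypothetical.

import asyncio
from typing import AsyncIterator, Callable, Iterator


async def as_async_stream(make_stream: Callable[[], Iterator[str]]) -> AsyncIterator[str]:
    """Run a synchronous token generator in a worker thread and expose it as an async iterator."""
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()
    done = object()  # sentinel marking the end of the sync stream

    def producer() -> None:
        # Runs in a worker thread; hand tokens back to the event loop thread-safely.
        try:
            for token in make_stream():
                loop.call_soon_threadsafe(queue.put_nowait, token)
        finally:
            loop.call_soon_threadsafe(queue.put_nowait, done)

    producer_future = loop.run_in_executor(None, producer)
    while True:
        item = await queue.get()
        if item is done:
            break
        yield item
    await producer_future  # re-raise any exception from the worker thread


# Imagined usage with a parse_gen()-style sync generator from earlier in this thread:
# async for token in as_async_stream(lambda: my_sync_dspy_stream()):
#     print(token, end="", flush=True)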

@aazizisoufiane

@CyrusOfEden something like this?

import asyncio
from dataclasses import dataclass
from typing import AsyncIterator, Optional

import dspy
import litellm
from dspy.signatures.signature import ensure_signature

@dataclass
class StreamingResult:
    """Container for both streaming tokens and final parsed prediction"""
    token: str
    prediction: Optional[dspy.Prediction] = None

class StreamingChatAdapter(dspy.ChatAdapter):
    async def stream(self, lm, lm_kwargs, signature, demos, inputs, field_to_stream=None, _parse_values=True) -> AsyncIterator[StreamingResult]:
        messages = self.format(signature, demos, inputs)
        completions = ""
        in_target_field = False
        buffer = ""

        async for chunk in lm.astream(messages, **lm_kwargs):
            if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
                token = chunk.choices[0].delta.content
                completions += token

                # Check for field markers
                if "[[" in token:
                    buffer = token
                    in_target_field = False
                    continue

                if buffer:
                    buffer += token
                    if "]]" in buffer:
                        field_marker = buffer.split("## ")[1].split(" ##")[0].strip()
                        in_target_field = (field_to_stream is None) or (field_marker == field_to_stream)
                        buffer = ""
                    continue

                try:
                    parsed_fields = self.parse(signature, completions, _parse_values=True)
                    prediction = dspy.Prediction(parsed_fields)
                    if in_target_field:
                        yield StreamingResult(token=token, prediction=prediction)
                    elif prediction:
                        yield StreamingResult(token="", prediction=prediction)
                except (ValueError, KeyError):
                    if in_target_field:
                        yield StreamingResult(token=token)
                    else:
                        yield StreamingResult(token="")

class StreamingLM(dspy.LM):
    async def astream(self, messages, **kwargs):
        """Async streaming interface using litellm"""
        combined_kwargs = {**self.kwargs, **kwargs}
        response = await litellm.acompletion(
            model=self.model,
            messages=messages,
            stream=True,
            **combined_kwargs
        )
        async for chunk in response:
            yield chunk

class StreamingPredict(dspy.Predict):
    async def aforward_stream(self, field_to_stream, **kwargs):
        """Streaming-specific async forward implementation"""
        signature = ensure_signature(kwargs.pop("signature", self.signature))
        demos = kwargs.pop("demos", self.demos)
        config = dict(**self.config, **kwargs.pop("config", {}))

        lm = kwargs.pop("lm", self.lm) or dspy.settings.lm
        assert lm is not None, "No LM is loaded."

        if not isinstance(lm, StreamingLM):
            lm = StreamingLM(lm.model, **lm.kwargs)

        adapter = StreamingChatAdapter()
        async for result in adapter.stream(
            lm, 
            config, 
            signature, 
            demos, 
            kwargs, 
            field_to_stream=field_to_stream,
            _parse_values=True
        ):
            yield result

    async def aforward(self, **kwargs):
        """Standard async forward implementation"""
        field_to_stream = kwargs.pop("field_to_stream", None)
        if field_to_stream:
            # Collect all streaming results to return final prediction
            last_prediction = None
            async for result in self.aforward_stream(field_to_stream, **kwargs):
                if result.prediction:
                    last_prediction = result.prediction
            return last_prediction
        else:
            # Use standard predict forward
            return await super().aforward(**kwargs)

class Joke(dspy.Signature):
    """Write a joke"""
    title: str = dspy.InputField()
    answer: str = dspy.OutputField()

class WriteJoke(dspy.Module):
    def __init__(self):
        self.predict = StreamingPredict("title->reasoning, answer")
    
    async def aforward(self, title: str, field_to_stream: str = None):
        """Async forward implementation with optional streaming"""
        last_prediction = None
        async for result in self.predict.aforward_stream(
            title=title,
            field_to_stream=field_to_stream
        ):
            if result.token:
                print(result.token, end="", flush=True)
            if result.prediction:
                last_prediction = result.prediction
        
        return last_prediction

if __name__ == '__main__':
    joke_writer = WriteJoke()
    result = asyncio.run(joke_writer.aforward(
        title="Tell me a joke about Humanity & AI",
        field_to_stream="answer"
    ))

@CyrusNuevoDia
Collaborator

CyrusNuevoDia commented Nov 26, 2024

@aazizisoufiane I'll take a look at this this week, ETA Monday for a PR

@CyrusNuevoDia
Collaborator

@aazizisoufiane check out #1874 🔥

@aazizisoufiane

@CyrusNuevoDia, thank you for your fix! I noticed that all outputs are being streamed, including reasoning sections (e.g., [[ ## reasoning ## ]]) and placeholders like the "completed" marker, without any cleanup or formatting.

Should we consider selectively streaming specific fields, such as just the "answer" field, instead of all raw outputs? This could provide a cleaner and more user-friendly experience.

@CyrusNuevoDia
Collaborator

CyrusNuevoDia commented Nov 29, 2024

Hey @aazizisoufiane! The MVP of streaming was to not modify the LM. For V2, I'm thinking of streaming partial signatures from every step :)

There's still 3 days until Monday ;)
