Streaming (#1874)
* dspy.streamify

* Update docs

* Fix ruff lint error

* Bring back send_stream to settings

* Improve doc

* Bring back request_cache setting

* sse => streaming_response

* Simplify dsp.utils.settings diff

* Add load/dump to LRUCache + drop callable request params

* ujson => pickle for dump/load

* Stream fix

Signed-off-by: dbczumar <[email protected]>

* test streaming

Signed-off-by: dbczumar <[email protected]>

* fix

Signed-off-by: dbczumar <[email protected]>

* fix

Signed-off-by: dbczumar <[email protected]>

* Streaming works

Signed-off-by: dbczumar <[email protected]>

* Fix

Signed-off-by: dbczumar <[email protected]>

* fix

Signed-off-by: dbczumar <[email protected]>

* no ignore change

Signed-off-by: dbczumar <[email protected]>

* Simple init

* Simple init

---------

Signed-off-by: dbczumar <[email protected]>
Co-authored-by: dbczumar <[email protected]>
CyrusNuevoDia and dbczumar authored Dec 17, 2024
1 parent 7e102fe commit 027312b
Showing 7 changed files with 254 additions and 16 deletions.
41 changes: 37 additions & 4 deletions docs/docs/tutorials/deployment/index.md
@@ -41,9 +41,11 @@ class Question(BaseModel):
# Configure your language model and 'asyncify' your DSPy program.
lm = dspy.LM("openai/gpt-4o-mini")
dspy.settings.configure(lm=lm, async_max_workers=4) # default is 8
dspy_program = dspy.asyncify(dspy.ChainOfThought("question -> answer"))
streaming_dspy_program = dspy.streamify(dspy_program)

# Define an endpoint (no streaming)
@app.post("/predict")
async def predict(question: Question):
try:
@@ -54,14 +56,45 @@ async def predict(question: Question):
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

# Define an endpoint (streaming)
import litellm
import ujson
from fastapi.responses import StreamingResponse

@app.post("/predict/stream")
async def stream(question: Question):
async def generate():
async for value in streaming_dspy_program(question=question.text):
if isinstance(value, dspy.Prediction):
data = {"prediction": value.labels().toDict()}
elif isinstance(value, litellm.ModelResponse):
data = {"chunk": value.json()}
yield f"data: {ujson.dumps(data)}\n\n"
yield "data: [DONE]\n\n"

return StreamingResponse(generate(), media_type="text/event-stream")

# Since you're often going to want to stream the result of a DSPy program as server-sent events,
# we've included a helper function for that, which is equivalent to the code above.

from dspy.utils.streaming import streaming_response

@app.post("/predict/stream")
async def stream(question: Question):
stream = streaming_dspy_program(question=question.text)
return StreamingResponse(streaming_response(stream), media_type="text/event-stream")
```

In the code above, we call `dspy.asyncify` to convert the dspy program to run in async mode for high-throughput FastAPI
deployments. Currently, this runs the dspy program in a separate thread and awaits its result.

By default, the limit of spawned threads is 8. Think of this like a worker pool.
If you have 8 in-flight programs and call it once more, the 9th call will wait until one of the 8 returns.
You can configure the async capacity using the new `async_max_workers` setting.
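For example, to raise that capacity (a sketch; the value is illustrative):

```python
dspy.settings.configure(async_max_workers=16)  # allow up to 16 concurrent asyncified calls
```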

We also use `dspy.streamify` to convert the dspy program into a streaming mode. This is useful when you want to stream
the intermediate outputs (e.g., o1-style reasoning) to the client before the final prediction is ready. This uses
asyncify under the hood and inherits its execution semantics.
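Outside of a web server, you can also consume the stream directly. Below is a minimal sketch (not part of this diff); it assumes the `streaming_dspy_program` defined above and relies on `dspy.streamify` yielding LM response chunks followed by the final prediction:

```python
import asyncio

import dspy
import litellm


async def consume():
    # streaming_dspy_program is the streamified program defined earlier in this tutorial.
    async for value in streaming_dspy_program(question="What is DSPy?"):
        if isinstance(value, litellm.ModelResponse):
            print("chunk:", value)   # incremental LM output
        elif isinstance(value, dspy.Prediction):
            print("final:", value)   # the program's final prediction


asyncio.run(consume())
```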

Write your code to a file, e.g., `fastapi_dspy.py`. Then you can serve the app with:

```bash
1 change: 1 addition & 0 deletions dspy/__init__.py
@@ -12,6 +12,7 @@
from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging
from dspy.utils.asyncify import asyncify
from dspy.utils.saving import load
from dspy.utils.streaming import streamify

from dspy.dsp.utils.settings import settings

36 changes: 32 additions & 4 deletions dspy/clients/lm.py
@@ -5,14 +5,17 @@
import uuid
from datetime import datetime
from hashlib import sha256
from typing import Any, Dict, List, Literal, Optional, cast

import litellm
import pydantic
import ujson
from anyio.streams.memory import MemoryObjectSendStream
from asyncer import syncify
from cachetools import LRUCache, cached
from litellm import RetryPolicy

import dspy
from dspy.adapters.base import Adapter
from dspy.clients.openai import OpenAIProvider
from dspy.clients.provider import Provider, TrainingJob
@@ -309,16 +312,41 @@ def cached_litellm_completion(request: Dict[str, Any], num_retries: int):


def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
    retry_kwargs = dict(
        retry_policy=_get_litellm_retry_policy(num_retries),
        # In LiteLLM version 1.55.3 (the first version that supports retry_policy as an argument
        # to completion()), the default value of max_retries is non-zero for certain providers, and
        # max_retries is stacked on top of the retry_policy. To avoid this, we set max_retries=0.
        max_retries=0,
    )

    stream = dspy.settings.send_stream
    if stream is None:
        return litellm.completion(
            cache=cache,
            **retry_kwargs,
            **request,
        )

    # The stream is already opened, and will be closed by the caller.
    stream = cast(MemoryObjectSendStream, stream)

    @syncify
    async def stream_completion():
        response = await litellm.acompletion(
            cache=cache,
            stream=True,
            **retry_kwargs,
            **request,
        )
        chunks = []
        async for chunk in response:
            chunks.append(chunk)
            await stream.send(chunk)
        return litellm.stream_chunk_builder(chunks)

    return stream_completion()


@request_cache(maxsize=None)
def cached_litellm_text_completion(request: Dict[str, Any], num_retries: int):
20 changes: 12 additions & 8 deletions dspy/dsp/utils/settings.py
@@ -1,6 +1,7 @@
import copy
import threading
from contextlib import contextmanager

from dspy.dsp.utils.utils import dotdict

DEFAULT_CONFIG = dotdict(
@@ -17,6 +18,7 @@
backoff_time=10,
callbacks=[],
async_max_workers=8,
send_stream=None,
)

# Global base configuration and owner tracking
@@ -26,20 +28,22 @@
# Global lock for settings configuration
global_lock = threading.Lock()


class ThreadLocalOverrides(threading.local):
def __init__(self):
self.overrides = dotdict()


thread_local_overrides = ThreadLocalOverrides()


class Settings:
"""
A singleton class for DSPy configuration settings.
Thread-safe global configuration.
- 'configure' can be called by only one 'owner' thread (the first thread that calls it).
- Other threads see the configured global values from 'main_thread_config'.
- 'context' sets thread-local overrides. These overrides propagate to threads spawned
inside that context block, when (and only when!) using a ParallelExecutor that copies overrides.
1. Only one unique thread (which can be any thread!) can call dspy.configure.
@@ -61,7 +65,7 @@ def lock(self):
return global_lock

def __getattr__(self, name):
overrides = getattr(thread_local_overrides, "overrides", dotdict())
if name in overrides:
return overrides[name]
elif name in main_thread_config:
@@ -70,7 +74,7 @@ def __getattr__(self, name):
raise AttributeError(f"'Settings' object has no attribute '{name}'")

def __setattr__(self, name, value):
if name in ("_instance",):
super().__setattr__(name, value)
else:
self.configure(**{name: value})
@@ -82,7 +86,7 @@ def __setitem__(self, key, value):
self.__setattr__(key, value)

def __contains__(self, key):
overrides = getattr(thread_local_overrides, "overrides", dotdict())
return key in overrides or key in main_thread_config

def get(self, key, default=None):
Expand All @@ -92,7 +96,7 @@ def get(self, key, default=None):
return default

def copy(self):
overrides = getattr(thread_local_overrides, "overrides", dotdict())
return dotdict({**main_thread_config, **overrides})

@property
@@ -122,7 +126,7 @@ def context(self, **kwargs):
If threads are spawned inside this block using ParallelExecutor, they will inherit these overrides.
"""

original_overrides = getattr(thread_local_overrides, "overrides", dotdict()).copy()
new_overrides = dotdict({**main_thread_config, **original_overrides, **kwargs})
thread_local_overrides.overrides = new_overrides

@@ -132,7 +136,7 @@ def __repr__(self):
thread_local_overrides.overrides = original_overrides

def __repr__(self):
overrides = getattr(thread_local_overrides, "overrides", dotdict())
combined_config = {**main_thread_config, **overrides}
return repr(combined_config)

87 changes: 87 additions & 0 deletions dspy/utils/streaming.py
@@ -0,0 +1,87 @@
from asyncio import iscoroutinefunction
from typing import Any, AsyncGenerator, Awaitable, Callable

import litellm
import ujson
from anyio import create_memory_object_stream, create_task_group
from anyio.streams.memory import MemoryObjectSendStream

from dspy.primitives.prediction import Prediction
from dspy.primitives.program import Module
from dspy.utils.asyncify import asyncify


def streamify(program: Module) -> Callable[[Any, Any], Awaitable[Any]]:
"""
Wrap a DSPy program so that it streams its outputs incrementally, rather than returning them
all at once.
Args:
program: The DSPy program to wrap with streaming functionality.
Returns:
A function that takes the same arguments as the original program, but returns an async
generator that yields the program's outputs incrementally.
Example:
>>> class TestSignature(dspy.Signature):
>>> input_text: str = dspy.InputField()
>>> output_text: str = dspy.OutputField()
>>>
>>> # Create the program and wrap it with streaming functionality
>>> program = dspy.streamify(dspy.Predict(TestSignature))
>>>
>>> # Use the program with streaming output
>>> async def use_streaming():
>>> output_stream = program(input_text="Test")
>>> async for value in output_stream:
>>> print(value) # Print each streamed value incrementally
"""
import dspy

if not iscoroutinefunction(program):
program = asyncify(program)

async def generator(args, kwargs, stream: MemoryObjectSendStream):
with dspy.settings.context(send_stream=stream):
prediction = await program(*args, **kwargs)

await stream.send(prediction)

async def streamer(*args, **kwargs):
send_stream, receive_stream = create_memory_object_stream(16)
async with create_task_group() as tg, send_stream, receive_stream:
tg.start_soon(generator, args, kwargs, send_stream)

async for value in receive_stream:
yield value
if isinstance(value, Prediction):
return

return streamer


async def streaming_response(streamer: AsyncGenerator) -> AsyncGenerator:
"""
Convert a DSPy program output stream to an OpenAI-compatible output stream that can be
used by a service as an API response to a streaming request.
Args:
streamer: An async generator that yields values from a DSPy program output stream.
Returns:
An async generator that yields OpenAI-compatible streaming response chunks.
"""
async for value in streamer:
if isinstance(value, Prediction):
data = {"prediction": {k: v for k, v in value.items(include_dspy=False)}}
yield f"data: {ujson.dumps(data)}\n\n"
elif isinstance(value, litellm.ModelResponse):
data = {"chunk": value.json()}
yield f"data: {ujson.dumps(data)}\n\n"
elif isinstance(value, str) and value.startswith("data:"):
# The chunk value is an OpenAI-compatible streaming chunk value,
# e.g. "data: {"finish_reason": "stop", "index": 0, "is_finished": True, ...}",
# so yield it directly
yield value
else:
raise ValueError(f"Unknown chunk value type: {value}")
yield "data: [DONE]\n\n"
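For reference, here is a minimal client-side sketch (not part of this commit) of consuming the SSE stream that `streaming_response` produces, assuming the `/predict/stream` FastAPI endpoint from the tutorial above is running locally. The choice of `httpx` and the `{"text": ...}` request payload are illustrative assumptions:

```python
import asyncio
import json

import httpx  # illustrative HTTP client choice


async def consume_sse():
    # Assumes the FastAPI app from the deployment tutorial is served at localhost:8000
    # and that its Question model accepts a "text" field.
    url = "http://127.0.0.1:8000/predict/stream"
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", url, json={"text": "What is DSPy?"}) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data:"):
                    continue
                payload = line[len("data:"):].strip()
                if payload == "[DONE]":
                    break
                event = json.loads(payload)
                if "chunk" in event:
                    print("chunk:", event["chunk"])            # incremental LM output
                elif "prediction" in event:
                    print("prediction:", event["prediction"])  # final program output


asyncio.run(consume_sse())
```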
25 changes: 25 additions & 0 deletions tests/test_utils/server/litellm_server.py
@@ -1,8 +1,11 @@
import json
import os
import time
from typing import AsyncIterator, Iterator

import litellm
from litellm import CustomLLM
from litellm.types.utils import GenericStreamingChunk

LITELLM_TEST_SERVER_LOG_FILE_PATH_ENV_VAR = "LITELLM_TEST_SERVER_LOG_FILE_PATH"

@@ -16,6 +19,28 @@ async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
_append_request_to_log_file(kwargs)
return _get_mock_llm_response(kwargs)

def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
generic_streaming_chunk: GenericStreamingChunk = {
"finish_reason": "stop",
"index": 0,
"is_finished": True,
"text": '{"output_text": "Hello!"}',
"tool_use": None,
"usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
}
return generic_streaming_chunk # type: ignore

async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
generic_streaming_chunk: GenericStreamingChunk = {
"finish_reason": "stop",
"index": 0,
"is_finished": True,
"text": '{"output_text": "Hello!"}',
"tool_use": None,
"usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
}
yield generic_streaming_chunk


def _get_mock_llm_response(request_kwargs):
_throw_exception_based_on_content_if_applicable(request_kwargs)