
Streaming #1874

Merged · merged 23 commits on Dec 17, 2024
Changes from 11 commits
3 changes: 2 additions & 1 deletion .gitignore
@@ -14,6 +14,7 @@
/examples/nli/scone/compiled_program.dspy
/examples/qa/hotpot/compiled_program.dspy
/ScoNe/
testing/playbook.ipynb

# Byte-compiled / optimized / DLL files
__pycache__/
@@ -58,4 +59,4 @@ assertion.log
.mypy_cache
dummy.csv
docs/docs/**/*.json*
*.index
*.index
41 changes: 37 additions & 4 deletions docs/docs/tutorials/deployment/index.md
@@ -41,9 +41,11 @@ class Question(BaseModel):
# Configure your language model and 'asyncify' your DSPy program.
lm = dspy.LM("openai/gpt-4o-mini")
dspy.settings.configure(lm=lm, async_max_workers=4) # default is 8
dspy_program = dspy.ChainOfThought("question -> answer")
dspy_program = dspy.asyncify(dspy_program)

dspy_program = dspy.asyncify(dspy.ChainOfThought("question -> answer"))
streaming_dspy_program = dspy.streamify(dspy_program)

# Define an endpoint (no streaming)
@app.post("/predict")
async def predict(question: Question):
try:
@@ -54,14 +56,45 @@ async def predict(question: Question):
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

# Define an endpoint (streaming)
from fastapi.responses import StreamingResponse

@app.post("/predict/stream")
async def stream(question: Question):
async def generate():
async for value in streaming_dspy_program(question=question.text):
if isinstance(value, dspy.Prediction):
data = {"prediction": value.labels().toDict()}
elif isinstance(value, litellm.ModelResponse):
data = {"chunk": value.json()}
yield f"data: {ujson.dumps(data)}\n\n"
yield "data: [DONE]\n\n"

return StreamingResponse(generate(), media_type="text/event-stream")

# Since you'll often want to stream the result of a DSPy program as server-sent events,
# we've included a helper function for that, which is equivalent to the code above.

from dspy.utils.streaming import streaming_response

@app.post("/predict/stream")
async def stream(question: Question):
stream = streaming_dspy_program(question=question.text)
return StreamingResponse(streaming_response(stream), media_type="text/event-stream")
```

In the code above, we call `dspy.asyncify` to convert the dspy program to run in async mode for high-throughput FastAPI
deployments. Currently, this runs the dspy program in a
separate thread and awaits its result. By default, the limit of spawned threads is 8. Think of this like a worker pool.
deployments. Currently, this runs the dspy program in a separate thread and awaits its result.

By default, the limit of spawned threads is 8. Think of this like a worker pool.
If you have 8 in-flight programs and call it once more, the 9th call will wait until one of the 8 returns.
You can configure the async capacity using the new `async_max_workers` setting.
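
As a rough, minimal sketch (outside FastAPI, and assuming the asyncified program can be awaited from any async context, just as the endpoint above does), the worker-pool behavior looks like this; the questions and the cap of 4 workers are illustrative placeholders:

```python
import asyncio

import dspy

lm = dspy.LM("openai/gpt-4o-mini")
# Illustrative cap of 4 worker threads; the default is 8.
dspy.settings.configure(lm=lm, async_max_workers=4)

# asyncify wraps the synchronous program so each call runs in a worker thread.
dspy_program = dspy.asyncify(dspy.ChainOfThought("question -> answer"))

async def main():
    # Each in-flight call holds one worker thread; a 5th concurrent call
    # would wait until one of the 4 workers frees up.
    predictions = await asyncio.gather(
        *(dspy_program(question=f"What is {i} + {i}?") for i in range(4))
    )
    for prediction in predictions:
        print(prediction.answer)

asyncio.run(main())
```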

We also use `dspy.streamify` to convert the dspy program to streaming mode. This is useful when you want to stream
the intermediate outputs (e.g., O1-style reasoning) to the client before the final prediction is ready. This uses
asyncify under the hood and inherits its execution semantics.
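
For reference, here is a minimal sketch of consuming the stream directly (outside FastAPI), handling the same two value types the `/predict/stream` endpoint above checks for; the question text is an illustrative placeholder:

```python
import asyncio

import dspy
import litellm

lm = dspy.LM("openai/gpt-4o-mini")
dspy.settings.configure(lm=lm)

# streamify wraps the asyncified program and yields values as they arrive.
streaming_program = dspy.streamify(dspy.asyncify(dspy.ChainOfThought("question -> answer")))

async def main():
    async for value in streaming_program(question="What is DSPy?"):
        if isinstance(value, litellm.ModelResponse):
            # Intermediate chunk from the streamed LM call.
            print("chunk:", value.json())
        elif isinstance(value, dspy.Prediction):
            # Final structured prediction, same shape as the non-streaming result.
            print("prediction:", value)

asyncio.run(main())
```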

Write your code to a file, e.g., `fastapi_dspy.py`. Then you can serve the app with:

```bash
2 changes: 2 additions & 0 deletions dsp/utils/settings.py
@@ -24,6 +24,8 @@
backoff_time=10,
callbacks=[],
async_max_workers=8,
request_cache=None,
send_stream=None,
)

# Global base configuration
9 changes: 5 additions & 4 deletions dspy/__init__.py
@@ -9,12 +9,13 @@
import dspy.retrievers

# Functional must be imported after primitives, predict and signatures
from .functional import * # isort: skip
from dspy.evaluate import Evaluate # isort: skip
from dspy.clients import * # isort: skip
from dspy.adapters import * # isort: skip
from .functional import * # isort: skip
from dspy.evaluate import Evaluate # isort: skip
from dspy.clients import * # isort: skip
from dspy.adapters import * # isort: skip
from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging
from dspy.utils.asyncify import asyncify
from dspy.utils.streaming import streamify

settings = dsp.settings

149 changes: 91 additions & 58 deletions dspy/clients/lm.py
@@ -1,18 +1,20 @@
import functools
import logging
import os
import threading
import uuid
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Dict, List, Literal, Optional, cast

from anyio.streams.memory import MemoryObjectSendStream
from asyncer import syncify
import litellm
import ujson

import dspy
from dspy.adapters.base import Adapter
from dspy.clients.openai import OpenAIProvider
from dspy.clients.provider import Provider, TrainingJob
from dspy.clients.utils_finetune import DataFormat, infer_data_format, validate_data_format
from dspy.utils.caching import LRUCache
from dspy.utils.callback import BaseCallback, with_callbacks

from .base_lm import BaseLM
@@ -85,30 +87,32 @@ def __call__(self, prompt=None, messages=None, **kwargs):
messages = messages or [{"role": "user", "content": prompt}]
kwargs = {**self.kwargs, **kwargs}

# Make the request and handle LRU & disk caching.
if self.model_type == "chat":
completion = cached_litellm_completion if cache else litellm_completion
else:
completion = cached_litellm_text_completion if cache else litellm_text_completion
completion = litellm_completion if self.model_type == "chat" else litellm_text_completion

response = completion(
request=ujson.dumps(dict(model=self.model, messages=messages, **kwargs)),
num_retries=self.num_retries,
request=dict(
model=self.model,
messages=messages,
num_retries=self.num_retries,
**kwargs,
),
cache=cache,
)
outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]

# Logging, with removed api key & where `cost` is None on cache hit.
kwargs = {k: v for k, v in kwargs.items() if not k.startswith("api_")}
entry = dict(prompt=prompt, messages=messages, kwargs=kwargs, response=response)
entry = dict(**entry, outputs=outputs, usage=dict(response["usage"]))
entry = dict(**entry, cost=response.get("_hidden_params", {}).get("response_cost"))
entry = dict(
**entry,
timestamp=datetime.now().isoformat(),
uuid=str(uuid.uuid4()),
model=self.model,
model_type=self.model_type,
)
entry = {
"prompt": prompt,
"messages": messages,
"kwargs": {k: v for k, v in kwargs.items() if not k.startswith("api_")},
"response": response,
"outputs": outputs,
"usage": dict(response["usage"]),
"cost": response.get("_hidden_params", {}).get("response_cost"),
"timestamp": datetime.now().isoformat(),
"uuid": str(uuid.uuid4()),
"model": self.model,
"model_type": self.model_type,
}
self.history.append(entry)
self.update_global_history(entry)

@@ -153,7 +157,11 @@ def thread_function_wrapper():
thread = threading.Thread(target=thread_function_wrapper)
model_to_finetune = self.finetuning_model or self.model
job = self.provider.TrainingJob(
thread=thread, model=model_to_finetune, train_data=train_data, train_kwargs=train_kwargs, data_format=data_format
thread=thread,
model=model_to_finetune,
train_data=train_data,
train_kwargs=train_kwargs,
data_format=data_format,
)
thread.start()

@@ -212,54 +220,79 @@ def copy(self, **kwargs):
return new_instance


@functools.lru_cache(maxsize=None)
def cached_litellm_completion(request, num_retries: int):
return litellm_completion(
request,
cache={"no-cache": False, "no-store": False},
num_retries=num_retries,
)
def request_cache(default_cache=LRUCache([], maxsize=10_000_000)) -> LRUCache:
return dspy.settings.request_cache or default_cache


def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
kwargs = ujson.loads(request)
return litellm.completion(
num_retries=num_retries,
cache=cache,
**kwargs,
)
def _litellm_completion(request: dict, **kwargs):
stream = dspy.settings.send_stream
if stream is None:
return litellm.completion(**request, **kwargs)

# The stream is already opened, and will be closed by the caller.
stream = cast(MemoryObjectSendStream, stream)

@syncify
async def stream_completion():
response = await litellm.acompletion(**request, **kwargs, stream=True)
chunks = []
async for chunk in response:
chunks.append(chunk)
await stream.send(chunk)
return litellm.stream_chunk_builder(chunks)

return stream_completion()


@functools.lru_cache(maxsize=None)
def cached_litellm_text_completion(request, num_retries: int):
return litellm_text_completion(
request,
num_retries=num_retries,
def litellm_completion(request: dict, cache=False):
if not cache:
return _litellm_completion(request, cache={"no-cache": True, "no-store": True})

response = request_cache().get(request, None)
if response:
return response

response = _litellm_completion(request, cache={"no-cache": False, "no-store": False})
request_cache()[request] = response

return response


def litellm_text_completion(request: dict, cache=False):
params = _prepare_litellm_text_completion_params(request)
if not cache:
return litellm.text_completion(**params, cache={"no-cache": True, "no-store": True})

response = request_cache().get(request, None)
if response:
return response

response = litellm.text_completion(
**params,
cache={"no-cache": False, "no-store": False},
)
request_cache()[request] = response

return response

def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
kwargs = ujson.loads(request)

def _prepare_litellm_text_completion_params(request: dict):
# Extract the provider and model from the model string.
# TODO: Not all the models are in the format of "provider/model"
model = kwargs.pop("model").split("/", 1)
model = request.pop("model").split("/", 1)
provider, model = model[0] if len(model) > 1 else "openai", model[-1]

# Use the API key and base from the kwargs, or from the environment.
api_key = kwargs.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
api_base = kwargs.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")
api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")

# Build the prompt from the messages.
prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"])

return litellm.text_completion(
cache=cache,
model=f"text-completion-openai/{model}",
api_key=api_key,
api_base=api_base,
prompt=prompt,
num_retries=num_retries,
**kwargs,
)
prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"])

return {
"model": f"text-completion-openai/{model}",
"api_key": api_key,
"api_base": api_base,
"prompt": prompt,
**request,
}
63 changes: 62 additions & 1 deletion dspy/utils/caching.py
@@ -1,5 +1,11 @@
import os
from collections import OrderedDict
from hashlib import sha256
from pathlib import Path
from typing import Union
import os
import pickle

import ujson


_DEFAULT_CACHE_DIR = os.path.join(Path.home(), ".dspy_cache")
@@ -12,3 +18,58 @@ def create_subdir_in_cachedir(subdir: str) -> str:
subdir = os.path.abspath(subdir)
os.makedirs(subdir, exist_ok=True)
return subdir


class LRUCache(OrderedDict):
maxsize: int

def __init__(self, iterable, maxsize: int):
super().__init__(iterable)
self.maxsize = maxsize

@classmethod
def load(cls, file, maxsize: int) -> "LRUCache":
return cls(pickle.load(file), maxsize)

@staticmethod
def dump(obj, file) -> None:
pickle.dump([[k, v] for k, v in obj.items()], file)

def __setitem__(self, request: dict, value):
key = self.cache_key(request)

if key in self:
self.move_to_end(key)
return

if len(self) == self.maxsize:
self.popitem(last=False)

super().__setitem__(key, value)

def __getitem__(self, request: dict):
key = self.cache_key(request)
return super().__getitem__(key)

def __contains__(self, request: dict):
key = self.cache_key(request)
return super().__contains__(key)

def get(self, request: dict, default=None):
key = self.cache_key(request)
return super().get(key, default)

def __delitem__(self, request: dict):
key = self.cache_key(request)
super().__delitem__(key)

def pop(self, request: dict, default=None):
key = self.cache_key(request)
return super().pop(key, default)

@staticmethod
def cache_key(request: Union[dict, str]) -> str:
params = request
if isinstance(request, dict):
params = {k: v for k, v in request.items() if not callable(v)}
return sha256(ujson.dumps(params, sort_keys=True).encode()).hexdigest()