Commit: initialize the outlines backend
shreyashankar committed Jan 19, 2025
1 parent 636f3f3 commit db857cf
Showing 2 changed files with 90 additions and 72 deletions.
60 changes: 36 additions & 24 deletions docetl/operations/utils/api.py
@@ -10,7 +10,16 @@

from docetl.utils import completion_cost

from rich import print as rprint
from .cache import cache, cache_key, freezeargs
from .llm import InvalidOutputError, LLMResult, timeout, truncate_messages
from .oss_llm import OutlinesBackend
from .validation import (
    convert_dict_schema_to_list_schema,
    convert_val,
    get_user_input_for_schema,
    safe_eval,
    strict_render,
)

BASIC_MODELS = ["gpt-4o-mini", "gpt-4o"]

@@ -19,6 +28,21 @@ class APIWrapper:
    def __init__(self, runner):
        self.runner = runner
        self._oss_operations = {}
        self.outlines_backend = OutlinesBackend()

        # If any of the models in the config are outlines models, initialize the outlines backend
        default_model = self.runner.config.get("default_model", "gpt-4o-mini")
        models_and_schemas = [
            (op.get("model", default_model), op.get("output", {}).get("schema", {}))
            for op in self.runner.config.get("operations", [])
        ]
        for model, schema in models_and_schemas:
            if self._is_outlines_model(model):
                model = model.replace("outlines/", "")
                self.runner.console.log(
                    f"Initializing outlines backend for model: {model}"
                )
                self.outlines_backend.setup_model(model, schema)
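For context, here is a minimal sketch of the runner config this constructor scans; the operation names, schema, and model path are hypothetical, and DocETL configs are typically authored in YAML before being loaded into this dict form.

# Hypothetical config: any op whose model carries the "outlines/" prefix
# triggers OutlinesBackend.setup_model() when APIWrapper is constructed.
config = {
    "default_model": "gpt-4o-mini",
    "operations": [
        {
            "name": "summarize",
            "model": "outlines/meta-llama/Llama-3.2-1B-Instruct",
            "output": {"schema": {"summary": "str"}},
        },
        {"name": "extract"},  # no model key: falls back to default_model
    ],
}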

    def _is_outlines_model(self, model: str) -> bool:
        """Check if model is an Outlines model"""
@@ -27,27 +51,6 @@ def _is_outlines_model(self, model: str) -> bool:
    def _get_model_path(self, model: str) -> str:
        """Extract model path from outlines model string"""
        return model.split("outlines/", 1)[1]
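A quick illustration of the naming convention these helpers implement; the wrapper instance and model path here are hypothetical.

# "outlines/" is the routing prefix; everything after it is the model path
assert wrapper._is_outlines_model("outlines/meta-llama/Llama-3.2-1B")
assert wrapper._get_model_path("outlines/meta-llama/Llama-3.2-1B") == "meta-llama/Llama-3.2-1B"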


    @freezeargs
    def gen_embedding(self, model: str, input: List[str]) -> List[float]:
@@ -461,6 +464,14 @@ def _call_llm_with_cache(
        Returns:
            str: The response from the LLM.
        """
        # Handle OSS (Outlines) models locally before falling through to LiteLLM
        if self._is_outlines_model(model):
            model_path = self._get_model_path(model)
            returned_val = self.outlines_backend.process_messages(
                model_path=model_path, messages=messages, output_schema=output_schema
            ).response
            return returned_val

        props = {key: convert_val(value) for key, value in output_schema.items()}
        use_tools = True
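A sketch of how this early-return branch might be exercised; everything beyond model, messages, and output_schema is an assumption, and the model path is hypothetical.

# Hypothetical call: the "outlines/" prefix short-circuits to the local backend,
# returning the LiteLLM-shaped response dict that OutlinesBackend builds.
response = wrapper._call_llm_with_cache(
    model="outlines/meta-llama/Llama-3.2-1B-Instruct",
    op_type="map",  # assumed parameter; not visible in this hunk
    messages=[{"role": "user", "content": "Summarize: ..."}],
    output_schema={"summary": "str"},
)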

@@ -570,8 +581,9 @@ def _call_llm_with_cache(
"role": "system",
"content": system_prompt,
},
] + messages,
output_schema=output_schema
]
+ messages,
output_schema=output_schema,
)
# Handle other models through LiteLLM
if tools is not None:
102 changes: 54 additions & 48 deletions docetl/operations/utils/oss_llm.py
@@ -1,16 +1,19 @@
import hashlib
import json
from typing import Any, Dict, List

from outlines import generate, models
from pydantic import create_model

from .llm import InvalidOutputError, LLMResult


class OutlinesBackend:
    """Backend for handling Outlines (local) models in DocETL operations."""

    def __init__(self, config: Dict[str, Any] = None):
        """Initialize the Outlines backend.

        Args:
            config: Optional configuration dictionary containing global settings
        """
@@ -20,88 +23,91 @@ def __init__(self, config: Dict[str, Any] = None):

    def setup_model(self, model_path: str, output_schema: Dict[str, Any] = None):
        """Initialize Outlines model and processor if needed.

        Args:
            model_path: Path to the model, without the 'outlines/' prefix
            output_schema: Schema for the expected output
        """
        if model_path not in self._models:
            model_kwargs = {k: v for k, v in self.config.items() if k in ["max_tokens"]}
            self._models[model_path] = models.transformers(model_path, **model_kwargs)

            if output_schema:
                field_definitions = {
                    k: (eval(v) if isinstance(v, str) else v, ...)
                    for k, v in output_schema.items()
                }
                output_model = create_model("OutputModel", **field_definitions)
                self._processors[model_path] = generate.json(
                    self._models[model_path], output_model
                )
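To make the dynamic schema handling concrete, here is a small standalone sketch of what the block above does with an output schema; the example schema is hypothetical, and eval is assumed to see only simple type names such as "str" or "int".

from pydantic import create_model

output_schema = {"summary": "str", "score": "int"}  # hypothetical op schema

field_definitions = {
    k: (eval(v) if isinstance(v, str) else v, ...)  # "str" -> str, "int" -> int
    for k, v in output_schema.items()
}
OutputModel = create_model("OutputModel", **field_definitions)
# OutputModel(summary="...", score=3) now validates the JSON the model generates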

    def process_messages(
        self,
        model_path: str,
        messages: List[Dict[str, str]],
        output_schema: Dict[str, Any],
    ) -> LLMResult:
        """Process messages through Outlines model.

        Args:
            model_path: Path to the model, without the 'outlines/' prefix
            messages: List of message dictionaries with 'role' and 'content'
            output_schema: Schema for the expected output

        Returns:
            LLMResult containing the model's response in LiteLLM format
        """
        try:
            self.setup_model(model_path, output_schema)

            prompt = "\n".join(
                f"{msg['role'].capitalize()}: {msg['content']}" for msg in messages
            )

            result = self._processors[model_path](prompt)
            response = {
                "choices": [
                    {
                        "message": {
                            "role": "assistant",
                            "content": None,
                            "tool_calls": [
                                {
                                    "function": {
                                        "name": "send_output",
                                        "arguments": json.dumps(result.model_dump()),
                                    },
                                    "id": "call_"
                                    + hashlib.md5(
                                        json.dumps(result.model_dump()).encode()
                                    ).hexdigest(),
                                    "type": "function",
                                }
                            ],
                        },
                        "finish_reason": "stop",
                        "index": 0,
                    }
                ],
                "model": f"outlines/{model_path}",
                "usage": {
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "total_tokens": 0,
                },
            }

            return LLMResult(response=response, total_cost=0.0, validated=True)

        except Exception as e:
            import traceback

            traceback.print_exc()
            raise InvalidOutputError(
                message=str(e),
                output=str(e),
                expected_schema=output_schema,
                messages=messages,
            )
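Putting the pieces together, here is a hedged end-to-end sketch of driving the backend directly; the model path is hypothetical, and actually running it requires the weights to be available to outlines' transformers loader.

import json

backend = OutlinesBackend()
result = backend.process_messages(
    model_path="meta-llama/Llama-3.2-1B-Instruct",  # hypothetical local model
    messages=[{"role": "user", "content": "Summarize: DocETL routes LLM calls."}],
    output_schema={"summary": "str"},
)

# The response mimics LiteLLM's tool-call format, so downstream code can
# unwrap it exactly like a cloud model's send_output tool call.
tool_call = result.response["choices"][0]["message"]["tool_calls"][0]
print(json.loads(tool_call["function"]["arguments"]))  # {"summary": "..."}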
