Skip to content

Commit

Permalink
outline schema updated
Browse files Browse the repository at this point in the history
  • Loading branch information
staru09 committed Jan 11, 2025
1 parent e66cb34 commit 0b236b0
Show file tree
Hide file tree
Showing 5 changed files with 363 additions and 44 deletions.
4 changes: 3 additions & 1 deletion docetl/operations/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .llm import LLMResult, InvalidOutputError, truncate_messages
from .progress import RichLoopBar, rich_as_completed
from .validation import safe_eval, convert_val, convert_dict_schema_to_list_schema, get_user_input_for_schema, strict_render
from .oss_llm import OutlinesBackend

__all__ = [
'APIWrapper',
Expand All @@ -32,5 +33,6 @@
'convert_dict_schema_to_list_schema',
'get_user_input_for_schema',
'truncate_messages',
"strict_render"
"strict_render",
"OutlinesBackend"
]
42 changes: 32 additions & 10 deletions docetl/operations/utils/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,44 @@

from docetl.utils import completion_cost

from .cache import cache, cache_key, freezeargs
from .llm import InvalidOutputError, LLMResult, timeout, truncate_messages
from .validation import (
convert_dict_schema_to_list_schema,
convert_val,
get_user_input_for_schema,
safe_eval,
strict_render,
)
from rich import print as rprint

BASIC_MODELS = ["gpt-4o-mini", "gpt-4o"]


class APIWrapper(object):
class APIWrapper:
def __init__(self, runner):
self.runner = runner
self._oss_operations = {}

def _is_outlines_model(self, model: str) -> bool:
"""Check if model is an Outlines model"""
return model.startswith("outlines/")

def _get_model_path(self, model: str) -> str:
"""Extract model path from outlines model string"""
return model.split("outlines/", 1)[1]

def _call_llm_with_cache(
    self,
    model: str,
    op_type: str,
    messages: List[Dict[str, str]],
    output_schema: Dict[str, str],
    tools: Optional[str] = None,
    scratchpad: Optional[str] = None,
    litellm_completion_kwargs: Dict[str, Any] = {},
) -> ModelResponse:
    """Route an LLM call, sending ``outlines/`` models to the Outlines backend.

    Args:
        model: Model identifier; an ``outlines/`` prefix selects the local backend.
        op_type: Operation type (not used by the code visible in this method).
        messages: Chat messages with ``role`` and ``content`` keys.
        output_schema: Schema describing the expected structured output.
        tools: Not used by the code visible in this method.
        scratchpad: Not used by the code visible in this method.
        litellm_completion_kwargs: Not used by the code visible in this method.
            The shared ``{}`` default is never mutated here.

    Returns:
        The ``response`` payload of the backend's result for ``outlines/``
        models, otherwise ``None``.
    """
    if not self._is_outlines_model(model):
        # NOTE(review): no cloud-model branch is visible in this method, so
        # non-Outlines models fall through and yield None exactly as before.
        # Confirm whether another definition handles them.
        return None

    # The backend attribute is not created in __init__, so build it on first
    # use instead of failing with AttributeError. An instance assigned
    # elsewhere is reused as-is.
    backend = getattr(self, "outlines_backend", None)
    if backend is None:
        from .oss_llm import OutlinesBackend

        backend = OutlinesBackend()
        self.outlines_backend = backend

    result = backend.process_messages(
        model_path=self._get_model_path(model),
        messages=messages,
        output_schema=output_schema,
    )
    return result.response

@freezeargs
def gen_embedding(self, model: str, input: List[str]) -> List[float]:
Expand Down
107 changes: 107 additions & 0 deletions docetl/operations/utils/oss_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from typing import Any, Dict, List, Tuple
from pydantic import BaseModel, create_model
from outlines import generate, models
import json
import hashlib
from .llm import LLMResult, InvalidOutputError

class OutlinesBackend:
"""Backend for handling Outlines (local) models in DocETL operations."""

def __init__(self, config: Dict[str, Any] = None):
"""Initialize the Outlines backend.
Args:
config: Optional configuration dictionary containing global settings
"""
self._models = {}
self._processors = {}
self.config = config or {}

def setup_model(self, model_path: str, output_schema: Dict[str, Any] = None):
"""Initialize Outlines model and processor if needed.
Args:
model_path: Path to the model, without the 'outlines/' prefix
output_schema: Schema for the expected output
"""
if model_path not in self._models:
model_kwargs = {
k: v for k, v in self.config.items()
if k in ['max_tokens']
}
self._models[model_path] = models.transformers(model_path, **model_kwargs)

if output_schema:
field_definitions = {
k: (eval(v) if isinstance(v, str) else v, ...)
for k, v in output_schema.items()
}
output_model = create_model('OutputModel', **field_definitions)
self._processors[model_path] = generate.json(
self._models[model_path],
output_model
)

def process_messages(
self,
model_path: str,
messages: List[Dict[str, str]],
output_schema: Dict[str, Any]
) -> LLMResult:
"""Process messages through Outlines model.
Args:
model_path: Path to the model, without the 'outlines/' prefix
messages: List of message dictionaries with 'role' and 'content'
output_schema: Schema for the expected output
Returns:
LLMResult containing the model's response in LiteLLM format
"""
try:
self.setup_model(model_path, output_schema)

prompt = "\n".join(
f"{msg['role'].capitalize()}: {msg['content']}"
for msg in messages
)

result = self._processors[model_path](prompt)

response = {
"choices": [{
"message": {
"role": "assistant",
"content": None,
"tool_calls": [{
"function": {
"name": "send_output",
"arguments": json.dumps(result.model_dump())
},
"id": "call_" + hashlib.md5(
json.dumps(result.model_dump()).encode()
).hexdigest(),
"type": "function"
}]
},
"finish_reason": "stop",
"index": 0
}],
"model": f"outlines/{model_path}",
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
}

return LLMResult(response=response, total_cost=0.0, validated=True)

except Exception as e:
raise InvalidOutputError(
message=str(e),
output=str(e),
expected_schema=output_schema,
messages=messages
)
71 changes: 38 additions & 33 deletions docetl/operations/utils/validation.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,47 @@
import ast
import json
from typing import Any, Dict, Union

from typing import Union, Dict, Any
from asteval import Interpreter
from jinja2 import Environment, StrictUndefined, Template
from jinja2.exceptions import UndefinedError
from rich import print as rprint
from rich.prompt import Prompt

aeval = Interpreter()
from jinja2 import Environment, StrictUndefined, Template
from jinja2.exceptions import UndefinedError


aeval = Interpreter()

def strict_render(template: Union[Template, str], context: Dict[str, Any]) -> str:
"""
Renders a Jinja template with strict undefined checking.
Args:
template: Either a Jinja2 Template object or a template string
context: Dictionary containing the template variables
Returns:
The rendered template string
Raises:
UndefinedError: When any undefined variable, attribute or index is accessed
ValueError: When template is invalid
"""
# Create strict environment
env = Environment(undefined=StrictUndefined)

# Convert string to Template if needed
if isinstance(template, str):

# # If "inputs" in the context, make sure they are not accessing some attribute of inputs
# if "inputs" in context and "{{ inputs." in template:
# raise UndefinedError("The inputs variable is a list, so you cannot access attributes of inputs. Use inputs[index].key instead.")

try:
template = env.from_string(template)
except Exception as e:
raise ValueError(f"Invalid template: {str(e)}")

try:
try:
return template.render(context)
except UndefinedError as e:
# Get the available context keys for better error reporting
Expand All @@ -52,12 +53,8 @@ def strict_render(template: Union[Template, str], context: Dict[str, Any]) -> st
if isinstance(context[var], dict):
var_attributes[var] = list(context[var].keys())
elif isinstance(context[var], list) and len(context[var]) > 0:
var_attributes[var] = [
f"inputs[i].{k}"
for k in context[var][0].keys()
if "_observability" not in k
]

var_attributes[var] = [f"inputs[i].{k}" for k in context[var][0].keys() if "_observability" not in k]

raise UndefinedError(
f"{str(e)}\n"
f"Your prompt can include the following variables: {available_vars}\n"
Expand All @@ -77,28 +74,39 @@ def safe_eval(expression: str, output: Dict) -> bool:
except Exception:
return False


def convert_val(value: Any, model: str = "gpt-4o-mini") -> Dict[str, Any]:
"""Convert a string representation of a type to a dictionary representation."""
value = value.strip().lower()
is_outlines = model.startswith("outlines/")

# Basic types
if value in ["str", "text", "string", "varchar"]:
return {"type": "string"}
return "str" if is_outlines else {"type": "string"}
elif value in ["int", "integer"]:
return {"type": "integer"}
return "int" if is_outlines else {"type": "integer"}
elif value in ["float", "decimal", "number"]:
return {"type": "number"}
return "float" if is_outlines else {"type": "number"}
elif value in ["bool", "boolean"]:
return {"type": "boolean"}
return "bool" if is_outlines else {"type": "boolean"}

# Lists
elif value.startswith("list["):
inner_type = value[5:-1].strip()
return {"type": "array", "items": convert_val(inner_type, model)}
inner_val = convert_val(inner_type, model)
return f"List[{inner_val}]" if is_outlines else {"type": "array", "items": inner_val}
elif value == "list":
raise ValueError("List type must specify its elements, e.g., 'list[str]'")

# Objects
elif value.startswith("{") and value.endswith("}"):
properties = {}
for item in value[1:-1].split(","):
key, val = item.strip().split(":")
properties[key.strip()] = convert_val(val.strip(), model)

if is_outlines:
return properties

result = {
"type": "object",
"properties": properties,
Expand All @@ -107,20 +115,21 @@ def convert_val(value: Any, model: str = "gpt-4o-mini") -> Dict[str, Any]:
if "gemini" not in model:
result["additionalProperties"] = False
return result

# Enums
elif value.startswith("enum[") and value.endswith("]"):
enum_values = value[5:-1].strip().split(",")
enum_values = [v.strip() for v in enum_values]
return {"type": "string", "enum": enum_values}

else:
raise ValueError(f"Unsupported value type: {value}")


def convert_dict_schema_to_list_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
    """Wrap a dict schema as a single ``results`` field listing that object type."""
    fields = []
    for name, type_spec in schema.items():
        fields.append(f"{name}: {type_spec}")
    object_spec = "{" + ", ".join(fields) + "}"
    return {"results": f"list[{object_spec}]"}


def get_user_input_for_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
"""Prompt the user for input for each key in the schema."""
user_input = {}
Expand All @@ -134,15 +143,11 @@ def get_user_input_for_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
if isinstance(parsed_value, eval(value_type)):
user_input[key] = parsed_value
else:
rprint(
f"[bold red]Error:[/bold red] Input for '{key}' does not match the expected type {value_type}."
)
rprint(f"[bold red]Error:[/bold red] Input for '{key}' does not match the expected type {value_type}.")
return get_user_input_for_schema(schema)

except json.JSONDecodeError:
rprint(
f"[bold red]Error:[/bold red] Invalid JSON input for '{key}'. Please try again."
)
rprint(f"[bold red]Error:[/bold red] Invalid JSON input for '{key}'. Please try again.")
return get_user_input_for_schema(schema)

return user_input
return user_input
Loading

0 comments on commit 0b236b0

Please sign in to comment.