diff --git a/docetl/operations/utils/api.py b/docetl/operations/utils/api.py
index 0bec04a1..d5c5a12b 100644
--- a/docetl/operations/utils/api.py
+++ b/docetl/operations/utils/api.py
@@ -10,7 +10,16 @@
 from docetl.utils import completion_cost
 
-from rich import print as rprint
+from .cache import cache, cache_key, freezeargs
+from .llm import InvalidOutputError, LLMResult, timeout, truncate_messages
+from .oss_llm import OutlinesBackend
+from .validation import (
+    convert_dict_schema_to_list_schema,
+    convert_val,
+    get_user_input_for_schema,
+    safe_eval,
+    strict_render,
+)
 
 BASIC_MODELS = ["gpt-4o-mini", "gpt-4o"]
 
@@ -19,6 +28,21 @@ class APIWrapper:
     def __init__(self, runner):
         self.runner = runner
         self._oss_operations = {}
+        self.outlines_backend = OutlinesBackend()
+
+        # If any of the models in the config are outlines models, initialize the outlines backend
+        default_model = self.runner.config.get("default_model", "gpt-4o-mini")
+        models_and_schemas = [
+            (op.get("model", default_model), op.get("output", {}).get("schema", {}))
+            for op in self.runner.config.get("operations", [])
+        ]
+        for model, schema in models_and_schemas:
+            if self._is_outlines_model(model):
+                model = model.replace("outlines/", "")
+                self.runner.console.log(
+                    f"Initializing outlines backend for model: {model}"
+                )
+                self.outlines_backend.setup_model(model, schema)
 
     def _is_outlines_model(self, model: str) -> bool:
         """Check if model is an Outlines model"""
@@ -27,27 +51,6 @@ def _is_outlines_model(self, model: str) -> bool:
     def _get_model_path(self, model: str) -> str:
         """Extract model path from outlines model string"""
         return model.split("outlines/", 1)[1]
-
-    def _call_llm_with_cache(
-        self,
-        model: str,
-        op_type: str,
-        messages: List[Dict[str, str]],
-        output_schema: Dict[str, str],
-        tools: Optional[str] = None,
-        scratchpad: Optional[str] = None,
-        litellm_completion_kwargs: Dict[str, Any] = {},
-    ) -> ModelResponse:
-        """Handle both Outlines and cloud model calls"""
-
-        # Add OSS model handling
-        if self._is_outlines_model(model):
-            model_path = self._get_model_path(model)
-            return self.outlines_backend.process_messages(
-                model_path=model_path,
-                messages=messages,
-                output_schema=output_schema
-            ).response
 
     @freezeargs
     def gen_embedding(self, model: str, input: List[str]) -> List[float]:
@@ -461,6 +464,14 @@ def _call_llm_with_cache(
         Returns:
             str: The response from the LLM.
         """
+        # Add OSS model handling
+        if self._is_outlines_model(model):
+            model_path = self._get_model_path(model)
+            returned_val = self.outlines_backend.process_messages(
+                model_path=model_path, messages=messages, output_schema=output_schema
+            ).response
+            return returned_val
+
         props = {key: convert_val(value) for key, value in output_schema.items()}
 
         use_tools = True
@@ -570,8 +581,9 @@
                     "role": "system",
                     "content": system_prompt,
                 },
-            ] + messages,
-            output_schema=output_schema
+            ]
+            + messages,
+            output_schema=output_schema,
         )
         # Handle other models through LiteLLM
         if tools is not None:
diff --git a/docetl/operations/utils/oss_llm.py b/docetl/operations/utils/oss_llm.py
index bd1c81fe..6f3f9213 100644
--- a/docetl/operations/utils/oss_llm.py
+++ b/docetl/operations/utils/oss_llm.py
@@ -1,16 +1,19 @@
-from typing import Any, Dict, List, Tuple
-from pydantic import BaseModel, create_model
-from outlines import generate, models
-import json
 import hashlib
-from .llm import LLMResult, InvalidOutputError
+import json
+from typing import Any, Dict, List
+
+from outlines import generate, models
+from pydantic import create_model
+
+from .llm import InvalidOutputError, LLMResult
+
 
 class OutlinesBackend:
     """Backend for handling Outlines (local) models in DocETL operations."""
-    
+
     def __init__(self, config: Dict[str, Any] = None):
         """Initialize the Outlines backend.
-        
+
         Args:
             config: Optional configuration dictionary containing global settings
         """
@@ -20,88 +23,91 @@ def __init__(self, config: Dict[str, Any] = None):
 
     def setup_model(self, model_path: str, output_schema: Dict[str, Any] = None):
         """Initialize Outlines model and processor if needed.
-        
+
         Args:
             model_path: Path to the model, without the 'outlines/' prefix
             output_schema: Schema for the expected output
         """
         if model_path not in self._models:
-            model_kwargs = {
-                k: v for k, v in self.config.items()
-                if k in ['max_tokens']
-            }
+            model_kwargs = {k: v for k, v in self.config.items() if k in ["max_tokens"]}
             self._models[model_path] = models.transformers(model_path, **model_kwargs)
-            
+
             if output_schema:
                 field_definitions = {
                     k: (eval(v) if isinstance(v, str) else v, ...)
                     for k, v in output_schema.items()
                 }
-                output_model = create_model('OutputModel', **field_definitions)
+                output_model = create_model("OutputModel", **field_definitions)
                 self._processors[model_path] = generate.json(
-                    self._models[model_path],
-                    output_model
+                    self._models[model_path], output_model
                 )
 
     def process_messages(
         self,
         model_path: str,
        messages: List[Dict[str, str]],
-        output_schema: Dict[str, Any]
+        output_schema: Dict[str, Any],
     ) -> LLMResult:
         """Process messages through Outlines model.
-        
+
         Args:
             model_path: Path to the model, without the 'outlines/' prefix
             messages: List of message dictionaries with 'role' and 'content'
             output_schema: Schema for the expected output
-        
+
         Returns:
             LLMResult containing the model's response in LiteLLM format
         """
         try:
             self.setup_model(model_path, output_schema)
-            
+
             prompt = "\n".join(
-                f"{msg['role'].capitalize()}: {msg['content']}"
-                for msg in messages
+                f"{msg['role'].capitalize()}: {msg['content']}" for msg in messages
             )
-            
+
             result = self._processors[model_path](prompt)
-            
+
             response = {
-                "choices": [{
-                    "message": {
-                        "role": "assistant",
-                        "content": None,
-                        "tool_calls": [{
-                            "function": {
-                                "name": "send_output",
-                                "arguments": json.dumps(result.model_dump())
-                            },
-                            "id": "call_" + hashlib.md5(
-                                json.dumps(result.model_dump()).encode()
-                            ).hexdigest(),
-                            "type": "function"
-                        }]
-                    },
-                    "finish_reason": "stop",
-                    "index": 0
-                }],
+                "choices": [
+                    {
+                        "message": {
+                            "role": "assistant",
+                            "content": None,
+                            "tool_calls": [
+                                {
+                                    "function": {
+                                        "name": "send_output",
+                                        "arguments": json.dumps(result.model_dump()),
+                                    },
+                                    "id": "call_"
+                                    + hashlib.md5(
+                                        json.dumps(result.model_dump()).encode()
+                                    ).hexdigest(),
+                                    "type": "function",
+                                }
+                            ],
+                        },
+                        "finish_reason": "stop",
+                        "index": 0,
+                    }
+                ],
                 "model": f"outlines/{model_path}",
                 "usage": {
                     "prompt_tokens": 0,
                     "completion_tokens": 0,
-                    "total_tokens": 0
-                }
+                    "total_tokens": 0,
+                },
             }
-            
+
             return LLMResult(response=response, total_cost=0.0, validated=True)
-            
+
         except Exception as e:
+            import traceback
+
+            traceback.print_exc()
             raise InvalidOutputError(
                 message=str(e),
                 output=str(e),
                 expected_schema=output_schema,
-                messages=messages
-            )
\ No newline at end of file
+                messages=messages,
+            )
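Not part of the patch above: a minimal usage sketch of the new backend. It assumes a local transformers checkpoint is available; the model name "HuggingFaceTB/SmolLM2-1.7B-Instruct" is only a placeholder, and the response parsing simply mirrors the LiteLLM-style tool-call structure that process_messages builds in this PR.

    import json

    from docetl.operations.utils.oss_llm import OutlinesBackend

    # Schema values may be type names given as strings; setup_model() eval()s
    # them into Python types when building the pydantic output model.
    output_schema = {"summary": "str", "sentiment": "str"}

    backend = OutlinesBackend()
    result = backend.process_messages(
        model_path="HuggingFaceTB/SmolLM2-1.7B-Instruct",  # placeholder local model
        messages=[
            {"role": "system", "content": "You extract structured summaries."},
            {"role": "user", "content": "Summarize: DocETL now runs local models via Outlines."},
        ],
        output_schema=output_schema,
    )

    # process_messages wraps the structured output in a LiteLLM-shaped tool call,
    # so downstream code can parse it the same way as a cloud model response.
    tool_call = result.response["choices"][0]["message"]["tool_calls"][0]
    print(json.loads(tool_call["function"]["arguments"]))
    # e.g. {"summary": "...", "sentiment": "..."}

In a pipeline config, the same path is selected by prefixing an operation's model with "outlines/"; APIWrapper.__init__ pre-loads each such model with its output schema, and _call_llm_with_cache routes the call to OutlinesBackend instead of LiteLLM.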