From 0b236b05172879396199d168c3d1cad3b756ae04 Mon Sep 17 00:00:00 2001
From: Aru Sharma
Date: Sat, 11 Jan 2025 16:21:44 +0530
Subject: [PATCH] outline schema updated

---
 docetl/operations/utils/__init__.py   |   4 +-
 docetl/operations/utils/api.py        |  42 ++++--
 docetl/operations/utils/oss_llm.py    | 107 +++++++++++++++
 docetl/operations/utils/validation.py |  71 +++++-----
 tests/test_oss_llm.py                 | 183 ++++++++++++++++++++++++++
 5 files changed, 363 insertions(+), 44 deletions(-)
 create mode 100644 docetl/operations/utils/oss_llm.py
 create mode 100644 tests/test_oss_llm.py

diff --git a/docetl/operations/utils/__init__.py b/docetl/operations/utils/__init__.py
index e787f842..a30e518f 100644
--- a/docetl/operations/utils/__init__.py
+++ b/docetl/operations/utils/__init__.py
@@ -12,6 +12,7 @@
 from .llm import LLMResult, InvalidOutputError, truncate_messages
 from .progress import RichLoopBar, rich_as_completed
 from .validation import safe_eval, convert_val, convert_dict_schema_to_list_schema, get_user_input_for_schema, strict_render
+from .oss_llm import OutlinesBackend

 __all__ = [
     'APIWrapper',
@@ -32,5 +33,6 @@
     'convert_dict_schema_to_list_schema',
     'get_user_input_for_schema',
     'truncate_messages',
-    "strict_render"
+    "strict_render",
+    "OutlinesBackend"
 ]
\ No newline at end of file
diff --git a/docetl/operations/utils/api.py b/docetl/operations/utils/api.py
index d82d9bcf..e6b79902 100644
--- a/docetl/operations/utils/api.py
+++ b/docetl/operations/utils/api.py
@@ -10,22 +10,44 @@
 from docetl.utils import completion_cost

 from .cache import cache, cache_key, freezeargs
 from .llm import InvalidOutputError, LLMResult, timeout, truncate_messages
 from .validation import (
     convert_dict_schema_to_list_schema,
     convert_val,
     get_user_input_for_schema,
     safe_eval,
     strict_render,
 )
+from .oss_llm import OutlinesBackend
+from rich import print as rprint

 BASIC_MODELS = ["gpt-4o-mini", "gpt-4o"]

-class APIWrapper(object):
+class APIWrapper:
     def __init__(self, runner):
         self.runner = runner
+        self.outlines_backend = OutlinesBackend()
+
+    def _is_outlines_model(self, model: str) -> bool:
+        """Check whether the model string refers to an Outlines (local) model."""
+        return model.startswith("outlines/")
+
+    def _get_model_path(self, model: str) -> str:
+        """Extract the model path from an 'outlines/<path>' model string."""
+        return model.split("outlines/", 1)[1]
+
+    def _call_llm_with_cache(
+        self,
+        model: str,
+        op_type: str,
+        messages: List[Dict[str, str]],
+        output_schema: Dict[str, str],
+        tools: Optional[str] = None,
+        scratchpad: Optional[str] = None,
+        litellm_completion_kwargs: Dict[str, Any] = {},
+    ) -> ModelResponse:
+        """Handle both Outlines (local) and cloud model calls."""
+
+        # Route local OSS models through the Outlines backend
+        if self._is_outlines_model(model):
+            model_path = self._get_model_path(model)
+            return self.outlines_backend.process_messages(
+                model_path=model_path,
+                messages=messages,
+                output_schema=output_schema,
+            ).response

     @freezeargs
     def gen_embedding(self, model: str, input: List[str]) -> List[float]:
diff --git a/docetl/operations/utils/oss_llm.py b/docetl/operations/utils/oss_llm.py
new file mode 100644
index 00000000..bd1c81fe
--- /dev/null
+++ b/docetl/operations/utils/oss_llm.py
@@ -0,0 +1,107 @@
+from typing import Any, Dict, List, Tuple
+from pydantic import BaseModel, create_model
+from outlines import generate, models
+import json
+import hashlib
+from .llm import LLMResult, InvalidOutputError
+
+class OutlinesBackend:
+    """Backend for handling Outlines (local) models in DocETL operations."""
+
+    def __init__(self, config: Dict[str, Any] = None):
+        """Initialize the Outlines backend.
+ + Args: + config: Optional configuration dictionary containing global settings + """ + self._models = {} + self._processors = {} + self.config = config or {} + + def setup_model(self, model_path: str, output_schema: Dict[str, Any] = None): + """Initialize Outlines model and processor if needed. + + Args: + model_path: Path to the model, without the 'outlines/' prefix + output_schema: Schema for the expected output + """ + if model_path not in self._models: + model_kwargs = { + k: v for k, v in self.config.items() + if k in ['max_tokens'] + } + self._models[model_path] = models.transformers(model_path, **model_kwargs) + + if output_schema: + field_definitions = { + k: (eval(v) if isinstance(v, str) else v, ...) + for k, v in output_schema.items() + } + output_model = create_model('OutputModel', **field_definitions) + self._processors[model_path] = generate.json( + self._models[model_path], + output_model + ) + + def process_messages( + self, + model_path: str, + messages: List[Dict[str, str]], + output_schema: Dict[str, Any] + ) -> LLMResult: + """Process messages through Outlines model. + + Args: + model_path: Path to the model, without the 'outlines/' prefix + messages: List of message dictionaries with 'role' and 'content' + output_schema: Schema for the expected output + + Returns: + LLMResult containing the model's response in LiteLLM format + """ + try: + self.setup_model(model_path, output_schema) + + prompt = "\n".join( + f"{msg['role'].capitalize()}: {msg['content']}" + for msg in messages + ) + + result = self._processors[model_path](prompt) + + response = { + "choices": [{ + "message": { + "role": "assistant", + "content": None, + "tool_calls": [{ + "function": { + "name": "send_output", + "arguments": json.dumps(result.model_dump()) + }, + "id": "call_" + hashlib.md5( + json.dumps(result.model_dump()).encode() + ).hexdigest(), + "type": "function" + }] + }, + "finish_reason": "stop", + "index": 0 + }], + "model": f"outlines/{model_path}", + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + } + } + + return LLMResult(response=response, total_cost=0.0, validated=True) + + except Exception as e: + raise InvalidOutputError( + message=str(e), + output=str(e), + expected_schema=output_schema, + messages=messages + ) \ No newline at end of file diff --git a/docetl/operations/utils/validation.py b/docetl/operations/utils/validation.py index b89f5388..380430b2 100644 --- a/docetl/operations/utils/validation.py +++ b/docetl/operations/utils/validation.py @@ -1,46 +1,47 @@ +import ast import json -from typing import Any, Dict, Union - +from typing import Union, Dict, Any from asteval import Interpreter -from jinja2 import Environment, StrictUndefined, Template -from jinja2.exceptions import UndefinedError from rich import print as rprint from rich.prompt import Prompt -aeval = Interpreter() +from jinja2 import Environment, StrictUndefined, Template +from jinja2.exceptions import UndefinedError + +aeval = Interpreter() def strict_render(template: Union[Template, str], context: Dict[str, Any]) -> str: """ Renders a Jinja template with strict undefined checking. 
- + Args: template: Either a Jinja2 Template object or a template string context: Dictionary containing the template variables - + Returns: The rendered template string - + Raises: UndefinedError: When any undefined variable, attribute or index is accessed ValueError: When template is invalid """ # Create strict environment env = Environment(undefined=StrictUndefined) - + # Convert string to Template if needed if isinstance(template, str): # # If "inputs" in the context, make sure they are not accessing some attribute of inputs # if "inputs" in context and "{{ inputs." in template: # raise UndefinedError("The inputs variable is a list, so you cannot access attributes of inputs. Use inputs[index].key instead.") - + try: template = env.from_string(template) except Exception as e: raise ValueError(f"Invalid template: {str(e)}") - - try: + + try: return template.render(context) except UndefinedError as e: # Get the available context keys for better error reporting @@ -52,12 +53,8 @@ def strict_render(template: Union[Template, str], context: Dict[str, Any]) -> st if isinstance(context[var], dict): var_attributes[var] = list(context[var].keys()) elif isinstance(context[var], list) and len(context[var]) > 0: - var_attributes[var] = [ - f"inputs[i].{k}" - for k in context[var][0].keys() - if "_observability" not in k - ] - + var_attributes[var] = [f"inputs[i].{k}" for k in context[var][0].keys() if "_observability" not in k] + raise UndefinedError( f"{str(e)}\n" f"Your prompt can include the following variables: {available_vars}\n" @@ -77,28 +74,39 @@ def safe_eval(expression: str, output: Dict) -> bool: except Exception: return False - def convert_val(value: Any, model: str = "gpt-4o-mini") -> Dict[str, Any]: """Convert a string representation of a type to a dictionary representation.""" value = value.strip().lower() + is_outlines = model.startswith("outlines/") + + # Basic types if value in ["str", "text", "string", "varchar"]: - return {"type": "string"} + return "str" if is_outlines else {"type": "string"} elif value in ["int", "integer"]: - return {"type": "integer"} + return "int" if is_outlines else {"type": "integer"} elif value in ["float", "decimal", "number"]: - return {"type": "number"} + return "float" if is_outlines else {"type": "number"} elif value in ["bool", "boolean"]: - return {"type": "boolean"} + return "bool" if is_outlines else {"type": "boolean"} + + # Lists elif value.startswith("list["): inner_type = value[5:-1].strip() - return {"type": "array", "items": convert_val(inner_type, model)} + inner_val = convert_val(inner_type, model) + return f"List[{inner_val}]" if is_outlines else {"type": "array", "items": inner_val} elif value == "list": raise ValueError("List type must specify its elements, e.g., 'list[str]'") + + # Objects elif value.startswith("{") and value.endswith("}"): properties = {} for item in value[1:-1].split(","): key, val = item.strip().split(":") properties[key.strip()] = convert_val(val.strip(), model) + + if is_outlines: + return properties + result = { "type": "object", "properties": properties, @@ -107,20 +115,21 @@ def convert_val(value: Any, model: str = "gpt-4o-mini") -> Dict[str, Any]: if "gemini" not in model: result["additionalProperties"] = False return result + + # Enums elif value.startswith("enum[") and value.endswith("]"): enum_values = value[5:-1].strip().split(",") enum_values = [v.strip() for v in enum_values] return {"type": "string", "enum": enum_values} + else: raise ValueError(f"Unsupported value type: {value}") - def 
convert_dict_schema_to_list_schema(schema: Dict[str, Any]) -> Dict[str, Any]: """Convert a dictionary schema to a list schema.""" schema_str = "{" + ", ".join([f"{k}: {v}" for k, v in schema.items()]) + "}" return {"results": f"list[{schema_str}]"} - def get_user_input_for_schema(schema: Dict[str, Any]) -> Dict[str, Any]: """Prompt the user for input for each key in the schema.""" user_input = {} @@ -134,15 +143,11 @@ def get_user_input_for_schema(schema: Dict[str, Any]) -> Dict[str, Any]: if isinstance(parsed_value, eval(value_type)): user_input[key] = parsed_value else: - rprint( - f"[bold red]Error:[/bold red] Input for '{key}' does not match the expected type {value_type}." - ) + rprint(f"[bold red]Error:[/bold red] Input for '{key}' does not match the expected type {value_type}.") return get_user_input_for_schema(schema) except json.JSONDecodeError: - rprint( - f"[bold red]Error:[/bold red] Invalid JSON input for '{key}'. Please try again." - ) + rprint(f"[bold red]Error:[/bold red] Invalid JSON input for '{key}'. Please try again.") return get_user_input_for_schema(schema) - return user_input + return user_input \ No newline at end of file diff --git a/tests/test_oss_llm.py b/tests/test_oss_llm.py new file mode 100644 index 00000000..81e17022 --- /dev/null +++ b/tests/test_oss_llm.py @@ -0,0 +1,183 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +from docetl.operations.utils.oss_llm import OutlinesBackend +from docetl.operations.utils.llm import LLMResult, InvalidOutputError +import json + +@pytest.fixture +def mock_global_config(): + return { + "max_tokens": 4096, + } + +@pytest.fixture +def mock_output_schema(): + return { + "first_name": "str", + "last_name": "str" + } + +@pytest.fixture +def mock_research_schema(): + return { + "title": "str", + "authors": "list", + "methodology": "str", + "findings": "list", + "limitations": "list", + "future_work": "list" + } + +@pytest.fixture +def mock_research_output(): + research_data = { + "title": "Deep Learning in Natural Language Processing", + "authors": ["John Smith", "Jane Doe"], + "methodology": "Comparative analysis of transformer architectures", + "findings": [ + "Improved accuracy by 15%", + "Reduced training time by 30%" + ], + "limitations": [ + "Limited dataset size", + "Computational constraints" + ], + "future_work": [ + "Extend to multilingual models", + "Optimize for edge devices" + ] + } + + class MockOutput: + def model_dump(self): + return research_data + return MockOutput() + +def test_process_messages(mock_global_config, mock_output_schema): + mock_model = MagicMock() + + class MockOutput: + def model_dump(self): + return { + "first_name": "John", + "last_name": "Doe" + } + + mock_processor = Mock(return_value=MockOutput()) + + with patch('outlines.models.transformers', return_value=mock_model) as mock_transformers, \ + patch('outlines.generate.json', return_value=mock_processor): + + backend = OutlinesBackend(mock_global_config) + messages = [ + {"role": "user", "content": "Extract information about John Doe"} + ] + + result = backend.process_messages( + model_path="meta-llama/Llama-3.2-1B-Instruct", + messages=messages, + output_schema=mock_output_schema + ) + + assert isinstance(result, LLMResult) + assert result.total_cost == 0.0 + assert result.validated == True + assert "tool_calls" in result.response["choices"][0]["message"] + + tool_call = result.response["choices"][0]["message"]["tool_calls"][0] + output_data = json.loads(tool_call["function"]["arguments"]) + assert 
output_data["first_name"] == "John" + assert output_data["last_name"] == "Doe" + +def test_research_paper_analysis(mock_global_config, mock_research_schema, mock_research_output): + mock_model = MagicMock() + mock_processor = Mock(return_value=mock_research_output) + + with patch('outlines.models.transformers', return_value=mock_model) as mock_transformers, \ + patch('outlines.generate.json', return_value=mock_processor): + + backend = OutlinesBackend(mock_global_config) + messages = [{ + "role": "user", + "content": """ + Analyze this research paper: + This paper presents a comprehensive analysis of deep learning approaches + in natural language processing. We compare various transformer architectures + and their performance on standard NLP tasks. + """ + }] + + result = backend.process_messages( + model_path="meta-llama/Llama-3.2-1B-Instruct", + messages=messages, + output_schema=mock_research_schema + ) + + # Verify LLMResult structure + assert isinstance(result, LLMResult) + assert result.total_cost == 0.0 + assert result.validated == True + + # Extract output data from tool call + tool_call = result.response["choices"][0]["message"]["tool_calls"][0] + output_data = json.loads(tool_call["function"]["arguments"]) + + # Verify structure and types + assert isinstance(output_data["title"], str) + assert isinstance(output_data["authors"], list) + assert isinstance(output_data["methodology"], str) + assert isinstance(output_data["findings"], list) + assert len(output_data["findings"]) > 0 + assert isinstance(output_data["limitations"], list) + assert isinstance(output_data["future_work"], list) + +def test_model_reuse(mock_global_config, mock_output_schema): + """Test that the same model is reused for multiple calls""" + mock_model = MagicMock() + mock_processor = Mock(return_value=MagicMock(model_dump=lambda: {"first_name": "John", "last_name": "Doe"})) + + with patch('outlines.models.transformers', return_value=mock_model) as mock_transformers, \ + patch('outlines.generate.json', return_value=mock_processor): + + backend = OutlinesBackend(mock_global_config) + messages = [{"role": "user", "content": "Test message"}] + model_path = "meta-llama/Llama-3.2-1B-Instruct" + + # First call should initialize the model + backend.process_messages(model_path, messages, mock_output_schema) + # Second call should reuse the model + backend.process_messages(model_path, messages, mock_output_schema) + + # Check that transformers was only called once + mock_transformers.assert_called_once() + +def test_invalid_output_schema(mock_global_config): + backend = OutlinesBackend(mock_global_config) + messages = [{"role": "user", "content": "Test"}] + + with pytest.raises(Exception): + backend.process_messages( + model_path="test-model", + messages=messages, + output_schema={"invalid_type": "unknown"} + ) + +def test_model_error_handling(mock_global_config, mock_output_schema): + """Test handling of model processing errors""" + mock_model = MagicMock() + mock_processor = Mock(side_effect=Exception("Model processing error")) + + with patch('outlines.models.transformers', return_value=mock_model), \ + patch('outlines.generate.json', return_value=mock_processor): + + backend = OutlinesBackend(mock_global_config) + messages = [{"role": "user", "content": "Test message"}] + + with pytest.raises(InvalidOutputError) as exc_info: + backend.process_messages( + model_path="test-model", + messages=messages, + output_schema=mock_output_schema + ) + + assert "Model processing error" in str(exc_info.value) \ No newline at end of 
file
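
A minimal usage sketch of the new backend (illustrative only, not part of the patch; the model path, config values, and schema below are placeholders taken from the tests):

    import json
    from docetl.operations.utils.oss_llm import OutlinesBackend

    # Construct the backend with an optional global config; only "max_tokens" is read today.
    backend = OutlinesBackend(config={"max_tokens": 4096})

    # In pipeline configs the model is written as "outlines/<model path>";
    # APIWrapper strips the prefix before calling process_messages.
    result = backend.process_messages(
        model_path="meta-llama/Llama-3.2-1B-Instruct",
        messages=[{"role": "user", "content": "Extract the name from: John Doe is a writer."}],
        output_schema={"first_name": "str", "last_name": "str"},
    )

    # result is an LLMResult with total_cost == 0.0; the structured output is carried
    # in LiteLLM-style tool-call arguments.
    args = result.response["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"]
    print(json.loads(args))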