From 9de31a9b844da2955f5cc0653243bec32934ad0b Mon Sep 17 00:00:00 2001 From: Alexis Deprez Date: Tue, 27 Aug 2024 12:06:50 +0200 Subject: [PATCH] fix(#588): avoid extraction llm answer truncature --- lavague-core/lavague/core/action_engine.py | 8 +++-- lavague-core/lavague/core/context.py | 3 ++ lavague-core/lavague/core/extractors.py | 34 ++++++++++++------- lavague-core/lavague/core/python_engine.py | 30 ++++++---------- .../lavague/contexts/openai/base.py | 6 ++++ 5 files changed, 47 insertions(+), 34 deletions(-) diff --git a/lavague-core/lavague/core/action_engine.py b/lavague-core/lavague/core/action_engine.py index d36b851f..25cd86c3 100644 --- a/lavague-core/lavague/core/action_engine.py +++ b/lavague-core/lavague/core/action_engine.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Dict +from typing import Dict, Optional from llama_index.core import PromptTemplate from llama_index.core.base.llms.base import BaseLLM from llama_index.core.base.embeddings.base import BaseEmbedding @@ -49,6 +49,7 @@ def __init__( python_engine: BaseEngine = None, navigation_control: BaseEngine = None, llm: BaseLLM = None, + extraction_llm: Optional[BaseLLM] = None, embedding: BaseEmbedding = None, retriever: BaseHtmlRetriever = None, prompt_template: PromptTemplate = NAVIGATION_ENGINE_PROMPT_TEMPLATE.prompt_template, @@ -63,6 +64,9 @@ def __init__( if embedding is None: embedding = get_default_context().embedding + if extraction_llm is None: + extraction_llm = get_default_context().extraction_llm + self.driver = driver if retriever is None: @@ -81,7 +85,7 @@ def __init__( embedding=embedding, ) if python_engine is None: - python_engine = PythonEngine(driver, llm, embedding) + python_engine = PythonEngine(driver, extraction_llm, embedding) if navigation_control is None: navigation_control = NavigationControl( driver, diff --git a/lavague-core/lavague/core/context.py b/lavague-core/lavague/core/context.py index 5c47bbc8..b093b879 100644 --- a/lavague-core/lavague/core/context.py +++ b/lavague-core/lavague/core/context.py @@ -1,6 +1,7 @@ from llama_index.core.llms import LLM from llama_index.core.multi_modal_llms import MultiModalLLM from llama_index.core.embeddings import BaseEmbedding +from typing import Optional DEFAULT_MAX_TOKENS = 512 DEFAULT_TEMPERATURE = 0.0 @@ -14,6 +15,7 @@ def __init__( llm: LLM, mm_llm: MultiModalLLM, embedding: BaseEmbedding, + extraction_llm: Optional[LLM] = None, ): """ llm (`LLM`): @@ -26,6 +28,7 @@ def __init__( self.llm = llm self.mm_llm = mm_llm self.embedding = embedding + self.extraction_llm = extraction_llm or llm def get_default_context() -> Context: diff --git a/lavague-core/lavague/core/extractors.py b/lavague-core/lavague/core/extractors.py index cb17689d..de0c0a21 100644 --- a/lavague-core/lavague/core/extractors.py +++ b/lavague-core/lavague/core/extractors.py @@ -3,7 +3,7 @@ from jsonschema import validate, ValidationError import yaml import json -from typing import Any, Dict +from typing import Any, Dict, Tuple def extract_xpaths_from_html(html): @@ -59,11 +59,18 @@ def extract(self, markdown_text: str) -> str: if match: # Return the first matched group, which is the code inside the ```python ``` yml_str = match.group(1).strip() + cleaned_yml = re.sub(r"^```.*\n|```$", "", yml_str, flags=re.DOTALL) try: - yaml.safe_load(yml_str) - return yml_str + yaml.safe_load(cleaned_yml) + return cleaned_yml except yaml.YAMLError: - return None + # retry with extra quote in case of truncated output + cleaned_yml += '"' + try: + yaml.safe_load(cleaned_yml) + return cleaned_yml + except yaml.YAMLError: + return None def extract_as_object(self, text: str): return yaml.safe_load(self.extract(text)) @@ -164,27 +171,28 @@ def __init__(self): "python": PythonFromMarkdownExtractor(), } - def get_type(self, text: str) -> str: + def get_type(self, text: str) -> Tuple[str, str]: types_pattern = "|".join(self.extractors.keys()) pattern = rf"```({types_pattern}).*?```" match = re.search(pattern, text, re.DOTALL) if match: - return match.group(1).strip() + return match.group(1).strip(), text else: - # Try to auto-detect first matching extractor + # Try to auto-detect first matching extractor, and remove extra ```(type)``` wrappers + cleaned_text = re.sub(r"^```.*\n|```$", "", text, flags=re.DOTALL) for type, extractor in self.extractors.items(): try: - value = extractor.extract(text) + value = extractor.extract(cleaned_text) if value: - return type + return type, value except: pass raise ValueError(f"No extractor pattern can be found from {text}") def extract(self, text: str) -> str: - type = self.get_type(text) - return self.extractors[type].extract(text) + type, target_text = self.get_type(text) + return self.extractors[type].extract(target_text) def extract_as_object(self, text: str) -> Any: - type = self.get_type(text) - return self.extractors[type].extract_as_object(text) + type, target_text = self.get_type(text) + return self.extractors[type].extract_as_object(target_text) diff --git a/lavague-core/lavague/core/python_engine.py b/lavague-core/lavague/core/python_engine.py index 9a763d51..30cc8c33 100644 --- a/lavague-core/lavague/core/python_engine.py +++ b/lavague-core/lavague/core/python_engine.py @@ -1,4 +1,3 @@ -import json import shutil import time from io import BytesIO @@ -20,7 +19,6 @@ from llama_index.core import Document, VectorStoreIndex from llama_index.core.base.llms.base import BaseLLM from llama_index.core.embeddings import BaseEmbedding -import re from lavague.core.extractors import DynamicExtractor DEFAULT_TEMPERATURE = 0.0 @@ -60,14 +58,14 @@ def __init__( temp_screenshots_path="./tmp_screenshots", n_search_attemps=10, ): - self.llm = llm or get_default_context().llm + self.llm = llm or get_default_context().extraction_llm self.embedding = embedding or get_default_context().embedding self.clean_html = clean_html self.driver = driver self.logger = logger self.display = display self.ocr_mm_llm = ocr_mm_llm or OpenAIMultiModal( - model="gpt-4o-mini", temperature=DEFAULT_TEMPERATURE + model="gpt-4o-mini", temperature=DEFAULT_TEMPERATURE, max_new_tokens=16384 ) self.ocr_llm = ocr_llm or self.llm self.batch_size = batch_size @@ -80,15 +78,9 @@ def __init__( def from_context(cls, context: Context, driver: BaseDriver): return cls(llm=context.llm, embedding=context.embedding, driver=driver) - def extract_json(self, output: str) -> Optional[dict]: + def extract_structured_data(self, output: str) -> Optional[dict]: extractor = DynamicExtractor() - clean = extractor.extract(output) - try: - output_dict = json.loads(clean) - except json.JSONDecodeError as e: - print(f"Error extracting Json: {e}") - return None - return output_dict + return extractor.extract_as_object(output) def get_screenshots_batch(self) -> list[str]: screenshot_paths = [] @@ -149,7 +141,7 @@ def perform_fallback(self, prompt, instruction) -> str: output = self.ocr_mm_llm.complete( image_documents=screenshots, prompt=prompt ).text.strip() - output_dict = self.extract_json(output) + output_dict = self.extract_structured_data(output) if output_dict: context_score = output_dict.get("score", 0) output = output_dict.get("ret") @@ -193,17 +185,17 @@ def execute_instruction(self, instruction: str) -> ActionResult: query_engine = index.as_query_engine(llm=llm) prompt = f""" - Based on the context provided, you must respond to query with a JSON object in the following format: - {{ - "ret": "[your answer]", - "score": [a float value between 0 and 1 on your confidence that you have enough context to answer the question] - }} + Based on the context provided, you must respond to query with a YAML object in the following format: + ```yaml + score: [a float value between 0 and 1 on your confidence that you have enough context to answer the question] + ret: "[your answer]" + ``` If you do not have sufficient context, set 'ret' to 'Insufficient context' and 'score' to 0. The query is: {instruction} """ output = query_engine.query(prompt).response.strip() - output_dict = self.extract_json(output) + output_dict = self.extract_structured_data(output) try: if ( diff --git a/lavague-integrations/contexts/lavague-contexts-openai/lavague/contexts/openai/base.py b/lavague-integrations/contexts/lavague-contexts-openai/lavague/contexts/openai/base.py index cb31f4dc..4ee270c7 100644 --- a/lavague-integrations/contexts/lavague-contexts-openai/lavague/contexts/openai/base.py +++ b/lavague-integrations/contexts/lavague-contexts-openai/lavague/contexts/openai/base.py @@ -30,6 +30,12 @@ def __init__( ), OpenAIMultiModal(api_key=api_key, model=mm_llm), OpenAIEmbedding(api_key=api_key, model=embedding), + OpenAI( + api_key=api_key, + model=llm, + max_tokens=4096, + temperature=DEFAULT_TEMPERATURE, + ), )