Skip to content

Commit

Permalink
#588 fix exceeded output tokens limit for extraction (#591)
Browse files Browse the repository at this point in the history
* fix(#588): avoid extraction llm answer truncation

* chore: update doc

* fix: compress llm output
  • Loading branch information
adeprez authored Aug 27, 2024
1 parent eb177eb commit c0a7fc8
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 35 deletions.
2 changes: 1 addition & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ world_model = WorldModel()
action_engine = ActionEngine(selenium_driver)
agent = WebAgent(world_model, action_engine)
agent.get("https://huggingface.co/docs")
agent.run("Go on the installation page for PEFT")
agent.run("Go on the quicktour of PEFT")

# Launch Gradio Agent Demo
agent.demo("Go on the quicktour of PEFT")
Expand Down
8 changes: 6 additions & 2 deletions lavague-core/lavague/core/action_engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Dict
from typing import Dict, Optional
from llama_index.core import PromptTemplate
from llama_index.core.base.llms.base import BaseLLM
from llama_index.core.base.embeddings.base import BaseEmbedding
Expand Down Expand Up @@ -49,6 +49,7 @@ def __init__(
python_engine: BaseEngine = None,
navigation_control: BaseEngine = None,
llm: BaseLLM = None,
extraction_llm: Optional[BaseLLM] = None,
embedding: BaseEmbedding = None,
retriever: BaseHtmlRetriever = None,
prompt_template: PromptTemplate = NAVIGATION_ENGINE_PROMPT_TEMPLATE.prompt_template,
Expand All @@ -63,6 +64,9 @@ def __init__(
if embedding is None:
embedding = get_default_context().embedding

if extraction_llm is None:
extraction_llm = get_default_context().extraction_llm

self.driver = driver

if retriever is None:
Expand All @@ -81,7 +85,7 @@ def __init__(
embedding=embedding,
)
if python_engine is None:
python_engine = PythonEngine(driver, llm, embedding)
python_engine = PythonEngine(driver, extraction_llm, embedding)
if navigation_control is None:
navigation_control = NavigationControl(
driver,
Expand Down
3 changes: 3 additions & 0 deletions lavague-core/lavague/core/context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from llama_index.core.llms import LLM
from llama_index.core.multi_modal_llms import MultiModalLLM
from llama_index.core.embeddings import BaseEmbedding
from typing import Optional

DEFAULT_MAX_TOKENS = 512
DEFAULT_TEMPERATURE = 0.0
Expand All @@ -14,6 +15,7 @@ def __init__(
llm: LLM,
mm_llm: MultiModalLLM,
embedding: BaseEmbedding,
extraction_llm: Optional[LLM] = None,
):
"""
llm (`LLM`):
Expand All @@ -26,6 +28,7 @@ def __init__(
self.llm = llm
self.mm_llm = mm_llm
self.embedding = embedding
self.extraction_llm = extraction_llm or llm


def get_default_context() -> Context:
Expand Down
34 changes: 21 additions & 13 deletions lavague-core/lavague/core/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from jsonschema import validate, ValidationError
import yaml
import json
from typing import Any, Dict
from typing import Any, Dict, Tuple


def extract_xpaths_from_html(html):
Expand Down Expand Up @@ -59,11 +59,18 @@ def extract(self, markdown_text: str) -> str:
if match:
# Return the first matched group, which is the code inside the ```python ```
yml_str = match.group(1).strip()
cleaned_yml = re.sub(r"^```.*\n|```$", "", yml_str, flags=re.DOTALL)
try:
yaml.safe_load(yml_str)
return yml_str
yaml.safe_load(cleaned_yml)
return cleaned_yml
except yaml.YAMLError:
return None
# retry with extra quote in case of truncated output
cleaned_yml += '"'
try:
yaml.safe_load(cleaned_yml)
return cleaned_yml
except yaml.YAMLError:
return None

def extract_as_object(self, text: str):
    """Extract the fenced YAML block from *text* and parse it into a Python object.

    NOTE(review): `self.extract` can return ``None`` when the YAML cannot be
    repaired/parsed; ``yaml.safe_load(None)`` would then fail — confirm callers
    only pass text containing a recoverable YAML block.
    """
    return yaml.safe_load(self.extract(text))
Expand Down Expand Up @@ -164,27 +171,28 @@ def __init__(self):
"python": PythonFromMarkdownExtractor(),
}

def get_type(self, text: str) -> str:
def get_type(self, text: str) -> Tuple[str, str]:
types_pattern = "|".join(self.extractors.keys())
pattern = rf"```({types_pattern}).*?```"
match = re.search(pattern, text, re.DOTALL)
if match:
return match.group(1).strip()
return match.group(1).strip(), text
else:
# Try to auto-detect first matching extractor
# Try to auto-detect first matching extractor, and remove extra ```(type)``` wrappers
cleaned_text = re.sub(r"^```.*\n|```$", "", text, flags=re.DOTALL)
for type, extractor in self.extractors.items():
try:
value = extractor.extract(text)
value = extractor.extract(cleaned_text)
if value:
return type
return type, value
except:
pass
raise ValueError(f"No extractor pattern can be found from {text}")

def extract(self, text: str) -> str:
type = self.get_type(text)
return self.extractors[type].extract(text)
type, target_text = self.get_type(text)
return self.extractors[type].extract(target_text)

def extract_as_object(self, text: str) -> Any:
type = self.get_type(text)
return self.extractors[type].extract_as_object(text)
type, target_text = self.get_type(text)
return self.extractors[type].extract_as_object(target_text)
31 changes: 12 additions & 19 deletions lavague-core/lavague/core/python_engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import shutil
import time
from io import BytesIO
Expand All @@ -20,7 +19,6 @@
from llama_index.core import Document, VectorStoreIndex
from llama_index.core.base.llms.base import BaseLLM
from llama_index.core.embeddings import BaseEmbedding
import re
from lavague.core.extractors import DynamicExtractor

DEFAULT_TEMPERATURE = 0.0
Expand Down Expand Up @@ -60,14 +58,14 @@ def __init__(
temp_screenshots_path="./tmp_screenshots",
n_search_attemps=10,
):
self.llm = llm or get_default_context().llm
self.llm = llm or get_default_context().extraction_llm
self.embedding = embedding or get_default_context().embedding
self.clean_html = clean_html
self.driver = driver
self.logger = logger
self.display = display
self.ocr_mm_llm = ocr_mm_llm or OpenAIMultiModal(
model="gpt-4o-mini", temperature=DEFAULT_TEMPERATURE
model="gpt-4o-mini", temperature=DEFAULT_TEMPERATURE, max_new_tokens=16384
)
self.ocr_llm = ocr_llm or self.llm
self.batch_size = batch_size
Expand All @@ -80,15 +78,9 @@ def __init__(
def from_context(cls, context: Context, driver: BaseDriver):
    """Build an engine from a ``Context`` (its llm and embedding) plus a driver.

    NOTE(review): this passes ``context.llm`` explicitly, which bypasses the
    constructor's new ``get_default_context().extraction_llm`` fallback —
    confirm whether ``context.extraction_llm`` should be forwarded instead.
    """
    return cls(llm=context.llm, embedding=context.embedding, driver=driver)

def extract_json(self, output: str) -> Optional[dict]:
def extract_structured_data(self, output: str) -> Optional[dict]:
extractor = DynamicExtractor()
clean = extractor.extract(output)
try:
output_dict = json.loads(clean)
except json.JSONDecodeError as e:
print(f"Error extracting Json: {e}")
return None
return output_dict
return extractor.extract_as_object(output)

def get_screenshots_batch(self) -> list[str]:
screenshot_paths = []
Expand Down Expand Up @@ -149,7 +141,7 @@ def perform_fallback(self, prompt, instruction) -> str:
output = self.ocr_mm_llm.complete(
image_documents=screenshots, prompt=prompt
).text.strip()
output_dict = self.extract_json(output)
output_dict = self.extract_structured_data(output)
if output_dict:
context_score = output_dict.get("score", 0)
output = output_dict.get("ret")
Expand Down Expand Up @@ -193,17 +185,18 @@ def execute_instruction(self, instruction: str) -> ActionResult:
query_engine = index.as_query_engine(llm=llm)

prompt = f"""
Based on the context provided, you must respond to query with a JSON object in the following format:
{{
"ret": "[your answer]",
"score": [a float value between 0 and 1 on your confidence that you have enough context to answer the question]
}}
Based on the context provided, you must respond to query with a YAML object in the following format:
```yaml
score: [a float value between 0 and 1 on your confidence that you have enough context to answer the question]
ret: "[your answer]"
```
If you do not have sufficient context, set 'ret' to 'Insufficient context' and 'score' to 0.
Keep the answer in 'ret' concise but informative.
The query is: {instruction}
"""

output = query_engine.query(prompt).response.strip()
output_dict = self.extract_json(output)
output_dict = self.extract_structured_data(output)

try:
if (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ def __init__(
),
OpenAIMultiModal(api_key=api_key, model=mm_llm),
OpenAIEmbedding(api_key=api_key, model=embedding),
OpenAI(
api_key=api_key,
model=llm,
max_tokens=4096,
temperature=DEFAULT_TEMPERATURE,
),
)


Expand Down

0 comments on commit c0a7fc8

Please sign in to comment.