Fix exceeded output tokens limit for extraction (#588) · PR #591

Merged · 3 commits · Aug 27, 2024
docs/index.md (2 changes: 1 addition & 1 deletion)

@@ -55,7 +55,7 @@ world_model = WorldModel()
 action_engine = ActionEngine(selenium_driver)
 agent = WebAgent(world_model, action_engine)
 agent.get("https://huggingface.co/docs")
-agent.run("Go on the installation page for PEFT")
+agent.run("Go on the quicktour of PEFT")
 
 # Launch Gradio Agent Demo
 agent.demo("Go on the quicktour of PEFT")
lavague-core/lavague/core/action_engine.py (8 changes: 6 additions & 2 deletions)

@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import Dict
+from typing import Dict, Optional
 from llama_index.core import PromptTemplate
 from llama_index.core.base.llms.base import BaseLLM
 from llama_index.core.base.embeddings.base import BaseEmbedding
@@ -49,6 +49,7 @@ def __init__(
         python_engine: BaseEngine = None,
         navigation_control: BaseEngine = None,
         llm: BaseLLM = None,
+        extraction_llm: Optional[BaseLLM] = None,
         embedding: BaseEmbedding = None,
         retriever: BaseHtmlRetriever = None,
         prompt_template: PromptTemplate = NAVIGATION_ENGINE_PROMPT_TEMPLATE.prompt_template,
@@ -63,6 +64,9 @@ def __init__(
         if embedding is None:
             embedding = get_default_context().embedding
 
+        if extraction_llm is None:
+            extraction_llm = get_default_context().extraction_llm
+
         self.driver = driver
 
         if retriever is None:
@@ -81,7 +85,7 @@ def __init__(
                 embedding=embedding,
             )
         if python_engine is None:
-            python_engine = PythonEngine(driver, llm, embedding)
+            python_engine = PythonEngine(driver, extraction_llm, embedding)
         if navigation_control is None:
             navigation_control = NavigationControl(
                 driver,
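With this change a caller can hand `ActionEngine` a dedicated, higher-budget model for extraction while navigation keeps the default LLM. A minimal sketch of the new parameter in use; the `SeleniumDriver` setup follows the docs snippet above, while the model name and token budget are illustrative assumptions:

```python
from llama_index.llms.openai import OpenAI
from lavague.core import ActionEngine
from lavague.drivers.selenium import SeleniumDriver

selenium_driver = SeleniumDriver()
action_engine = ActionEngine(
    selenium_driver,
    # Dedicated extraction model with a larger output budget.
    # If omitted, the engine falls back to get_default_context().extraction_llm.
    extraction_llm=OpenAI(model="gpt-4o", max_tokens=4096),
)
```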
lavague-core/lavague/core/context.py (3 changes: 3 additions & 0 deletions)

@@ -1,6 +1,7 @@
 from llama_index.core.llms import LLM
 from llama_index.core.multi_modal_llms import MultiModalLLM
 from llama_index.core.embeddings import BaseEmbedding
+from typing import Optional
 
 DEFAULT_MAX_TOKENS = 512
 DEFAULT_TEMPERATURE = 0.0
@@ -14,6 +15,7 @@ def __init__(
         llm: LLM,
         mm_llm: MultiModalLLM,
         embedding: BaseEmbedding,
+        extraction_llm: Optional[LLM] = None,
     ):
         """
         llm (`LLM`):
@@ -26,6 +28,7 @@ def __init__(
         self.llm = llm
         self.mm_llm = mm_llm
         self.embedding = embedding
+        self.extraction_llm = extraction_llm or llm
 
 
 def get_default_context() -> Context:
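The `extraction_llm or llm` fallback keeps existing `Context` objects working unchanged: extraction simply reuses the main LLM unless a dedicated one is supplied. A quick check of both paths, using llama-index mocks so it runs offline (`mm_llm` is stubbed with `None` because the fallback never touches it):

```python
from llama_index.core import MockEmbedding
from llama_index.core.llms import MockLLM
from lavague.core.context import Context

main_llm = MockLLM()
ctx = Context(llm=main_llm, mm_llm=None, embedding=MockEmbedding(embed_dim=8))
assert ctx.extraction_llm is main_llm  # no dedicated model: falls back to llm

big_llm = MockLLM(max_tokens=4096)
ctx = Context(
    llm=main_llm,
    mm_llm=None,
    embedding=MockEmbedding(embed_dim=8),
    extraction_llm=big_llm,
)
assert ctx.extraction_llm is big_llm  # dedicated extraction model wins
```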
lavague-core/lavague/core/extractors.py (34 changes: 21 additions & 13 deletions)

@@ -3,7 +3,7 @@
 from jsonschema import validate, ValidationError
 import yaml
 import json
-from typing import Any, Dict
+from typing import Any, Dict, Tuple
 
 
 def extract_xpaths_from_html(html):
@@ -59,11 +59,18 @@ def extract(self, markdown_text: str) -> str:
         if match:
             # Return the first matched group, which is the code inside the ```python ```
             yml_str = match.group(1).strip()
+            cleaned_yml = re.sub(r"^```.*\n|```$", "", yml_str, flags=re.DOTALL)
             try:
-                yaml.safe_load(yml_str)
-                return yml_str
+                yaml.safe_load(cleaned_yml)
+                return cleaned_yml
             except yaml.YAMLError:
-                return None
+                # retry with extra quote in case of truncated output
+                cleaned_yml += '"'
+                try:
+                    yaml.safe_load(cleaned_yml)
+                    return cleaned_yml
+                except yaml.YAMLError:
+                    return None
 
     def extract_as_object(self, text: str):
         return yaml.safe_load(self.extract(text))
@@ -164,27 +171,28 @@ def __init__(self):
             "python": PythonFromMarkdownExtractor(),
         }
 
-    def get_type(self, text: str) -> str:
+    def get_type(self, text: str) -> Tuple[str, str]:
         types_pattern = "|".join(self.extractors.keys())
         pattern = rf"```({types_pattern}).*?```"
         match = re.search(pattern, text, re.DOTALL)
         if match:
-            return match.group(1).strip()
+            return match.group(1).strip(), text
         else:
-            # Try to auto-detect first matching extractor
+            # Try to auto-detect first matching extractor, and remove extra ```(type)``` wrappers
+            cleaned_text = re.sub(r"^```.*\n|```$", "", text, flags=re.DOTALL)
             for type, extractor in self.extractors.items():
                 try:
-                    value = extractor.extract(text)
+                    value = extractor.extract(cleaned_text)
                     if value:
-                        return type
+                        return type, value
                 except:
                     pass
             raise ValueError(f"No extractor pattern can be found from {text}")
 
     def extract(self, text: str) -> str:
-        type = self.get_type(text)
-        return self.extractors[type].extract(text)
+        type, target_text = self.get_type(text)
+        return self.extractors[type].extract(target_text)
 
     def extract_as_object(self, text: str) -> Any:
-        type = self.get_type(text)
-        return self.extractors[type].extract_as_object(text)
+        type, target_text = self.get_type(text)
+        return self.extractors[type].extract_as_object(target_text)
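The extra-quote retry is the core of the truncation fix: when the model exhausts its output tokens mid-string, the YAML is cut off inside the quoted `ret` value, and appending one closing quote is often enough to make it parse again. A standalone sketch of that repair logic (mirroring the diff rather than importing it):

```python
import re
import yaml

def repair_truncated_yaml(yml_str: str):
    # Strip stray ``` fences the model may have left around the payload
    cleaned = re.sub(r"^```.*\n|```$", "", yml_str, flags=re.DOTALL)
    try:
        yaml.safe_load(cleaned)
        return cleaned
    except yaml.YAMLError:
        # Retry with an extra quote in case the output was truncated mid-string
        cleaned += '"'
        try:
            yaml.safe_load(cleaned)
            return cleaned
        except yaml.YAMLError:
            return None

# A response cut off by the token limit inside the quoted 'ret' value:
truncated = 'score: 0.9\nret: "The quicktour shows how to load a PEFT'
print(repair_truncated_yaml(truncated))
```

The first `safe_load` raises on the unterminated string; the second succeeds, so the partial answer survives along with its `score` field.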
lavague-core/lavague/core/python_engine.py (31 changes: 12 additions & 19 deletions)

@@ -1,4 +1,3 @@
-import json
 import shutil
 import time
 from io import BytesIO
@@ -20,7 +19,6 @@
 from llama_index.core import Document, VectorStoreIndex
 from llama_index.core.base.llms.base import BaseLLM
 from llama_index.core.embeddings import BaseEmbedding
-import re
 from lavague.core.extractors import DynamicExtractor
 
 DEFAULT_TEMPERATURE = 0.0
@@ -60,14 +58,14 @@ def __init__(
         temp_screenshots_path="./tmp_screenshots",
         n_search_attemps=10,
     ):
-        self.llm = llm or get_default_context().llm
+        self.llm = llm or get_default_context().extraction_llm
        self.embedding = embedding or get_default_context().embedding
         self.clean_html = clean_html
         self.driver = driver
         self.logger = logger
         self.display = display
         self.ocr_mm_llm = ocr_mm_llm or OpenAIMultiModal(
-            model="gpt-4o-mini", temperature=DEFAULT_TEMPERATURE
+            model="gpt-4o-mini", temperature=DEFAULT_TEMPERATURE, max_new_tokens=16384
         )
         self.ocr_llm = ocr_llm or self.llm
         self.batch_size = batch_size
@@ -80,15 +78,9 @@ def __init__(
     def from_context(cls, context: Context, driver: BaseDriver):
         return cls(llm=context.llm, embedding=context.embedding, driver=driver)
 
-    def extract_json(self, output: str) -> Optional[dict]:
+    def extract_structured_data(self, output: str) -> Optional[dict]:
         extractor = DynamicExtractor()
-        clean = extractor.extract(output)
-        try:
-            output_dict = json.loads(clean)
-        except json.JSONDecodeError as e:
-            print(f"Error extracting Json: {e}")
-            return None
-        return output_dict
+        return extractor.extract_as_object(output)
 
     def get_screenshots_batch(self) -> list[str]:
         screenshot_paths = []
@@ -149,7 +141,7 @@ def perform_fallback(self, prompt, instruction) -> str:
             output = self.ocr_mm_llm.complete(
                 image_documents=screenshots, prompt=prompt
             ).text.strip()
-            output_dict = self.extract_json(output)
+            output_dict = self.extract_structured_data(output)
             if output_dict:
                 context_score = output_dict.get("score", 0)
                 output = output_dict.get("ret")
@@ -193,17 +185,18 @@ def execute_instruction(self, instruction: str) -> ActionResult:
         query_engine = index.as_query_engine(llm=llm)
 
         prompt = f"""
-Based on the context provided, you must respond to query with a JSON object in the following format:
-{{
-"ret": "[your answer]",
-"score": [a float value between 0 and 1 on your confidence that you have enough context to answer the question]
-}}
+Based on the context provided, you must respond to query with a YAML object in the following format:
+```yaml
+score: [a float value between 0 and 1 on your confidence that you have enough context to answer the question]
+ret: "[your answer]"
+```
 If you do not have sufficient context, set 'ret' to 'Insufficient context' and 'score' to 0.
 Keep the answer in 'ret' concise but informative.
 The query is: {instruction}
 """
 
         output = query_engine.query(prompt).response.strip()
-        output_dict = self.extract_json(output)
+        output_dict = self.extract_structured_data(output)
 
         try:
             if (
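Switching the prompt from JSON to fenced YAML is what makes the retry above applicable: `extract_structured_data` now delegates parsing (including the truncation repair) to `DynamicExtractor` instead of `json.loads`, and ordering `score` before `ret` means even a truncated response keeps its confidence value. A usage sketch, assuming the extractor registry contains a `yaml` entry as `get_type`'s key-driven pattern implies:

```python
from lavague.core.extractors import DynamicExtractor

# A well-formed model response in the new prompt format
output = '```yaml\nscore: 0.85\nret: "Install PEFT, then follow the quicktour"\n```'

extractor = DynamicExtractor()
print(extractor.extract_as_object(output))
# expected: {'score': 0.85, 'ret': 'Install PEFT, then follow the quicktour'}
```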
Additional file (path truncated in this view; evidently the OpenAI context): 6 additions & 0 deletions

@@ -30,6 +30,12 @@ def __init__(
             ),
             OpenAIMultiModal(api_key=api_key, model=mm_llm),
             OpenAIEmbedding(api_key=api_key, model=embedding),
+            OpenAI(
+                api_key=api_key,
+                model=llm,
+                max_tokens=4096,
+                temperature=DEFAULT_TEMPERATURE,
+            ),
         )
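This last hunk passes a fourth positional argument to the parent `Context`, matching the new `extraction_llm` slot, with an explicit 4096-token output cap. Wiring the same thing into a hand-built `Context` might look like this; the model names are illustrative, not taken from the PR:

```python
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from llama_index.multi_modal_llms.openai import OpenAIMultiModal
from lavague.core.context import Context, DEFAULT_TEMPERATURE

context = Context(
    llm=OpenAI(model="gpt-4o", max_tokens=512),
    mm_llm=OpenAIMultiModal(model="gpt-4o"),
    embedding=OpenAIEmbedding(model="text-embedding-3-small"),
    # Larger output budget so long extractions are not cut off mid-answer
    extraction_llm=OpenAI(
        model="gpt-4o",
        max_tokens=4096,
        temperature=DEFAULT_TEMPERATURE,
    ),
)
```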