Skip to content

Commit

Permalink
update get-action-from-context (#586)
Browse files Browse the repository at this point in the history
* update get-action-from-context

* update python engine

* add authorized paths to gradio
  • Loading branch information
lyie28 authored Aug 26, 2024
1 parent e6a6588 commit 2717a87
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 16 deletions.
14 changes: 12 additions & 2 deletions lavague-core/lavague/core/navigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,14 @@ def get_action_from_context(self, context: str, query: str) -> str:
"""
Generate the code from a query and a context
"""
prompt = self.prompt_template.format(context_str=context, query_str=query)
authorized_xpaths = extract_xpaths_from_html(context)

prompt = self.prompt_template.format(
context_str=context,
query_str=query,
authorized_xpaths=authorized_xpaths,
)

response = self.llm.complete(prompt).text
code = self.extractor.extract(response)
return code
Expand Down Expand Up @@ -241,8 +248,11 @@ def execute_instruction_gradio(self, instruction: str, action_engine: Any):
except:
pass
start = time.time()
authorized_xpaths = extract_xpaths_from_html(llm_context)
prompt = self.prompt_template.format(
context_str=llm_context, query_str=instruction
context_str=llm_context,
query_str=instruction,
authorized_xpaths=authorized_xpaths,
)
response = self.llm.complete(prompt).text
end = time.time()
Expand Down
27 changes: 13 additions & 14 deletions lavague-core/lavague/core/python_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from llama_index.core.base.llms.base import BaseLLM
from llama_index.core.embeddings import BaseEmbedding
import re
from lavague.core.extractors import DynamicExtractor

DEFAULT_TEMPERATURE = 0.0

Expand All @@ -39,6 +40,7 @@ class PythonEngine(BaseEngine):
ocr_llm: BaseLLM
batch_size: int
confidence_threshold: float
fallback_theshold: float
temp_screenshots_path: str
n_search_attempts: int

Expand All @@ -54,6 +56,7 @@ def __init__(
display: bool = False,
batch_size: int = 5,
confidence_threshold: float = 0.85,
fallback_threshold: float = 0.85,
temp_screenshots_path="./tmp_screenshots",
n_search_attemps=10,
):
Expand All @@ -71,24 +74,19 @@ def __init__(
self.confidence_threshold = confidence_threshold
self.temp_screenshots_path = temp_screenshots_path
self.n_search_attempts = n_search_attemps
self.fallback_theshold = fallback_threshold

@classmethod
def from_context(cls, context: Context, driver: BaseDriver):
    """Build a PythonEngine from a pre-configured Context.

    Pulls the LLM and embedding model off *context* and pairs them with
    *driver*; every other constructor argument keeps its default value.

    Args:
        context: Project Context carrying ``llm`` and ``embedding`` instances.
        driver: Browser driver the engine will operate.

    Returns:
        A new engine instance configured from the given context.
    """
    return cls(llm=context.llm, embedding=context.embedding, driver=driver)

def extract_json(self, output: str) -> Optional[dict]:
    """Parse a JSON object out of raw LLM *output*.

    Delegates fence/noise stripping to ``DynamicExtractor`` and then decodes
    the cleaned text with ``json.loads``.

    Args:
        output: Raw model response, possibly wrapped in markdown fences.

    Returns:
        The decoded dict, or ``None`` when the cleaned text is not valid JSON.
    """
    # The old hand-rolled quote/fence replacement chain was dead code:
    # its result was immediately overwritten by the extractor. Use the
    # extractor alone.
    extractor = DynamicExtractor()
    clean = extractor.extract(output)
    try:
        output_dict = json.loads(clean)
    except json.JSONDecodeError as e:
        # Single, consistently-spelled error report (was duplicated as
        # "JSON"/"Json").
        print(f"Error extracting JSON: {e}")
        return None
    return output_dict

Expand Down Expand Up @@ -121,7 +119,7 @@ def perform_fallback(self, prompt, instruction) -> str:
context_score = -1

prompt = f"""
You must respond with a dictionary in the following format:
You must respond with a JSON object in the following format:
{{
"ret": "[any relevant text transcribed from the image in order to answer the query {instruction} - make sure to answer with full sentences so the reponse can be understood out of context.]",
"score": [a confidence score between 0 and 1 that the necessary context has been captured in order to answer the following query]
Expand Down Expand Up @@ -152,9 +150,10 @@ def perform_fallback(self, prompt, instruction) -> str:
image_documents=screenshots, prompt=prompt
).text.strip()
output_dict = self.extract_json(output)
context_score = output_dict.get("score")
output = output_dict.get("ret")
memory += output
if output_dict:
context_score = output_dict.get("score", 0)
output = output_dict.get("ret")
memory += output

# delete temp image folder
shutil.rmtree(Path(self.temp_screenshots_path))
Expand Down Expand Up @@ -208,7 +207,7 @@ def execute_instruction(self, instruction: str) -> ActionResult:

try:
if (
output_dict.get("score", 0) < self.confidence_threshold
output_dict.get("score", 0) < self.fallback_theshold
): # use fallback method
output = self.perform_fallback(prompt=prompt, instruction=instruction)

Expand Down

0 comments on commit 2717a87

Please sign in to comment.