diff --git a/lavague-core/lavague/core/navigation.py b/lavague-core/lavague/core/navigation.py
index cc15f0ae..c8b52f79 100644
--- a/lavague-core/lavague/core/navigation.py
+++ b/lavague-core/lavague/core/navigation.py
@@ -166,7 +166,14 @@ def get_action_from_context(self, context: str, query: str) -> str:
         """
         Generate the code from a query and a context
         """
-        prompt = self.prompt_template.format(context_str=context, query_str=query)
+        authorized_xpaths = extract_xpaths_from_html(context)
+
+        prompt = self.prompt_template.format(
+            context_str=context,
+            query_str=query,
+            authorized_xpaths=authorized_xpaths,
+        )
+
         response = self.llm.complete(prompt).text
         code = self.extractor.extract(response)
         return code
@@ -241,8 +248,11 @@ def execute_instruction_gradio(self, instruction: str, action_engine: Any):
             except:
                 pass
             start = time.time()
+            authorized_xpaths = extract_xpaths_from_html(llm_context)
             prompt = self.prompt_template.format(
-                context_str=llm_context, query_str=instruction
+                context_str=llm_context,
+                query_str=instruction,
+                authorized_xpaths=authorized_xpaths,
             )
             response = self.llm.complete(prompt).text
             end = time.time()
diff --git a/lavague-core/lavague/core/python_engine.py b/lavague-core/lavague/core/python_engine.py
index 5d2fdae2..9a763d51 100644
--- a/lavague-core/lavague/core/python_engine.py
+++ b/lavague-core/lavague/core/python_engine.py
@@ -21,6 +21,7 @@
 from llama_index.core.base.llms.base import BaseLLM
 from llama_index.core.embeddings import BaseEmbedding
 import re
+from lavague.core.extractors import DynamicExtractor
 
 DEFAULT_TEMPERATURE = 0.0
 
@@ -39,6 +40,7 @@ class PythonEngine(BaseEngine):
     ocr_llm: BaseLLM
     batch_size: int
     confidence_threshold: float
+    fallback_threshold: float
     temp_screenshots_path: str
     n_search_attempts: int
 
@@ -54,6 +56,7 @@ def __init__(
         display: bool = False,
         batch_size: int = 5,
         confidence_threshold: float = 0.85,
+        fallback_threshold: float = 0.85,
         temp_screenshots_path="./tmp_screenshots",
         n_search_attemps=10,
     ):
@@ -71,24 +74,19 @@
         self.confidence_threshold = confidence_threshold
         self.temp_screenshots_path = temp_screenshots_path
         self.n_search_attempts = n_search_attemps
+        self.fallback_threshold = fallback_threshold
 
     @classmethod
     def from_context(cls, context: Context, driver: BaseDriver):
         return cls(llm=context.llm, embedding=context.embedding, driver=driver)
 
     def extract_json(self, output: str) -> Optional[dict]:
-        clean = (
-            output.replace("'ret'", '"ret"')
-            .replace("'score'", '"score"')
-            .replace("```json", "")
-            .replace("```", "")
-            .strip()
-        )
-        clean = re.sub(r"\n+", "\n", clean)
+        extractor = DynamicExtractor()
+        clean = extractor.extract(output)
         try:
             output_dict = json.loads(clean)
         except json.JSONDecodeError as e:
             print(f"Error extracting JSON: {e}")
             return None
         return output_dict
 
@@ -121,7 +119,7 @@ def perform_fallback(self, prompt, instruction) -> str:
         context_score = -1
         prompt = f"""
-        You must respond with a dictionary in the following format:
+        You must respond with a JSON object in the following format:
         {{
         "ret": "[any relevant text transcribed from the image in order to answer the query {instruction} - make sure to answer with full sentences so the reponse can be understood out of context.]",
         "score": [a confidence score between 0 and 1 that the necessary context has been captured in order to answer the following query]
         }}
@@ -152,9 +150,10 @@ def perform_fallback(self, prompt, instruction) -> str:
                 image_documents=screenshots, prompt=prompt
             ).text.strip()
             output_dict = self.extract_json(output)
-            context_score = output_dict.get("score")
-            output = output_dict.get("ret")
-            memory += output
+            if output_dict:
+                context_score = output_dict.get("score", 0)
+                output = output_dict.get("ret")
+                memory += output
 
         # delete temp image folder
         shutil.rmtree(Path(self.temp_screenshots_path))
@@ -208,7 +207,7 @@ def execute_instruction(self, instruction: str) -> ActionResult:
         try:
             if (
-                output_dict.get("score", 0) < self.confidence_threshold
+                output_dict.get("score", 0) < self.fallback_threshold
             ):  # use fallback method
                 output = self.perform_fallback(prompt=prompt, instruction=instruction)
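
Note on the navigation.py hunks: they call an extract_xpaths_from_html helper whose
import and implementation are not part of this diff. A minimal sketch of the idea,
assuming the cleaned HTML context carries xpath="..." attributes on its elements
(an assumption, not the library's actual helper), could look like:

import re
from typing import List

def extract_xpaths_from_html(html: str) -> List[str]:
    # Illustrative sketch only: collect every xpath="..." attribute value from
    # the HTML context so the prompt can restrict the LLM to selectors that
    # actually exist on the page. The real helper may differ.
    return re.findall(r'xpath="([^"]+)"', html)

Passing this list as authorized_xpaths lets the prompt template reject hallucinated
selectors. The new `if output_dict:` branch in perform_fallback is similarly
defensive: extract_json returns None on a JSONDecodeError, and the guard keeps that
None from raising an AttributeError on the subsequent .get calls.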