update get-action-from-context #586

Merged · 8 commits · Aug 26, 2024
14 changes: 12 additions & 2 deletions lavague-core/lavague/core/navigation.py
@@ -166,7 +166,14 @@ def get_action_from_context(self, context: str, query: str) -> str:
        """
        Generate the code from a query and a context
        """
-        prompt = self.prompt_template.format(context_str=context, query_str=query)
+        authorized_xpaths = extract_xpaths_from_html(context)
+
+        prompt = self.prompt_template.format(
+            context_str=context,
+            query_str=query,
+            authorized_xpaths=authorized_xpaths,
+        )
+
        response = self.llm.complete(prompt).text
        code = self.extractor.extract(response)
        return code
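
The new `extract_xpaths_from_html` helper is imported from elsewhere in lavague-core and its body is not part of this diff. A minimal sketch of what such a helper could look like, assuming an lxml-based traversal that collects one canonical XPath per element (illustrative only, not the project's actual implementation):

```python
# Illustrative stand-in for lavague's extract_xpaths_from_html;
# the real helper may filter or format XPaths differently.
from typing import List

from lxml import html


def extract_xpaths_from_html(page_html: str) -> List[str]:
    """Return a canonical /html/... XPath for every element in the page."""
    tree = html.fromstring(page_html)
    root = tree.getroottree()
    # getpath() yields an absolute, index-qualified XPath per element.
    return [root.getpath(el) for el in tree.iter()]
```

Whatever the concrete implementation, the effect of the change is that the set of XPaths present in the retrieved context is handed to the prompt, so the model can be constrained to elements that actually exist on the page.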
@@ -241,8 +248,11 @@ def execute_instruction_gradio(self, instruction: str, action_engine: Any):
        except:
            pass
        start = time.time()
+        authorized_xpaths = extract_xpaths_from_html(llm_context)
        prompt = self.prompt_template.format(
-            context_str=llm_context, query_str=instruction
+            context_str=llm_context,
+            query_str=instruction,
+            authorized_xpaths=authorized_xpaths,
        )
        response = self.llm.complete(prompt).text
        end = time.time()
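
Both call sites now pass `authorized_xpaths`, which assumes the prompt template exposes a matching placeholder. A hypothetical template illustrating the expected shape (the real template text in lavague-core differs):

```python
# Hypothetical template for illustration; not the actual lavague prompt.
PROMPT_TEMPLATE = """\
Context:
{context_str}

Query: {query_str}

Only generate actions targeting XPaths from this list:
{authorized_xpaths}
"""

prompt = PROMPT_TEMPLATE.format(
    context_str="<html><body><button id='login'>Log in</button></body></html>",
    query_str="Click the login button",
    authorized_xpaths=["/html/body/button[1]"],
)
```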
27 changes: 13 additions & 14 deletions lavague-core/lavague/core/python_engine.py
@@ -21,6 +21,7 @@
from llama_index.core.base.llms.base import BaseLLM
from llama_index.core.embeddings import BaseEmbedding
import re
+from lavague.core.extractors import DynamicExtractor

DEFAULT_TEMPERATURE = 0.0

@@ -39,6 +40,7 @@ class PythonEngine(BaseEngine):
    ocr_llm: BaseLLM
    batch_size: int
    confidence_threshold: float
+    fallback_threshold: float
    temp_screenshots_path: str
    n_search_attempts: int

@@ -54,6 +56,7 @@ def __init__(
        display: bool = False,
        batch_size: int = 5,
        confidence_threshold: float = 0.85,
+        fallback_threshold: float = 0.85,
        temp_screenshots_path="./tmp_screenshots",
        n_search_attemps=10,
    ):
@@ -71,24 +74,19 @@ def __init__(
        self.confidence_threshold = confidence_threshold
        self.temp_screenshots_path = temp_screenshots_path
        self.n_search_attempts = n_search_attemps
+        self.fallback_threshold = fallback_threshold

    @classmethod
    def from_context(cls, context: Context, driver: BaseDriver):
        return cls(llm=context.llm, embedding=context.embedding, driver=driver)

    def extract_json(self, output: str) -> Optional[dict]:
-        clean = (
-            output.replace("'ret'", '"ret"')
-            .replace("'score'", '"score"')
-            .replace("```json", "")
-            .replace("```", "")
-            .strip()
-        )
-        clean = re.sub(r"\n+", "\n", clean)
+        extractor = DynamicExtractor()
+        clean = extractor.extract(output)
        try:
            output_dict = json.loads(clean)
        except json.JSONDecodeError as e:
-            print(f"Error extracting JSON: {e}")
+            print(f"Error extracting Json: {e}")
            return None
        return output_dict

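`DynamicExtractor` replaces the hand-rolled quote fixing and fence stripping above. Its implementation is not shown in this diff; a rough approximation of the fence-stripping behavior it presumably covers (an assumption, not the library's actual code):

```python
import re

# Rough stand-in for lavague.core.extractors.DynamicExtractor; the real
# class likely detects the payload type (JSON, code, ...) dynamically.
FENCE = chr(96) * 3  # a literal triple backtick, spelled fence-safe here


class FenceStrippingExtractor:
    def extract(self, output: str) -> str:
        # Pull the body out of a triple-backtick fence (optionally tagged "json").
        match = re.search(FENCE + r"(?:json)?\s*(.*?)" + FENCE, output, re.DOTALL)
        return (match.group(1) if match else output).strip()


raw = FENCE + 'json\n{"ret": "done", "score": 0.9}\n' + FENCE
print(FenceStrippingExtractor().extract(raw))  # {"ret": "done", "score": 0.9}
```
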
@@ -121,7 +119,7 @@ def perform_fallback(self, prompt, instruction) -> str:
        context_score = -1

        prompt = f"""
-        You must respond with a dictionary in the following format:
+        You must respond with a JSON object in the following format:
        {{
            "ret": "[any relevant text transcribed from the image in order to answer the query {instruction} - make sure to answer with full sentences so the response can be understood out of context.]",
            "score": [a confidence score between 0 and 1 that the necessary context has been captured in order to answer the following query]
@@ -152,9 +150,10 @@ def perform_fallback(self, prompt, instruction) -> str:
                image_documents=screenshots, prompt=prompt
            ).text.strip()
            output_dict = self.extract_json(output)
-            context_score = output_dict.get("score")
-            output = output_dict.get("ret")
-            memory += output
+            if output_dict:
+                context_score = output_dict.get("score", 0)
+                output = output_dict.get("ret")
+                memory += output

        # delete temp image folder
        shutil.rmtree(Path(self.temp_screenshots_path))
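
The `if output_dict:` guard matters because `extract_json` returns `None` on a `JSONDecodeError`, so a malformed model response previously raised an `AttributeError` on the very next line rather than leaving `context_score` low and letting the fallback loop continue.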
@@ -208,7 +207,7 @@ def execute_instruction(self, instruction: str) -> ActionResult:

        try:
            if (
-                output_dict.get("score", 0) < self.confidence_threshold
+                output_dict.get("score", 0) < self.fallback_threshold
            ):  # use fallback method
                output = self.perform_fallback(prompt=prompt, instruction=instruction)