-
Notifications
You must be signed in to change notification settings - Fork 1
/
respond.py
76 lines (53 loc) · 1.74 KB
/
respond.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import json
from pydantic import BaseModel
from gba.client import ChatClient, Llama3Instruct
from gba.tools.base import Tool
from gba.utils import Scratchpad
class Result(BaseModel):
    """Structured output schema the LLM response is constrained to (see
    ``Result.model_json_schema()`` in ``RespondTool.run``)."""

    # Natural-language answer; the prompt asks for a single sentence.
    answer: str
# Minimal system prompt; the task-specific instructions live in the user prompt.
SYSTEM_PROMPT = """You are a helpful assistant."""

# Template rendered via str.format with {request} and {context}. The doubled
# braces at the bottom escape to literal braces, yielding the JSON output-format
# example the model is asked to follow.
USER_PROMPT_TEMPLATE = """You are given a user request and context information:
User request:
```
{request}
```
Context information:
```
{context}
```
Answer the user request using the available context information only.
The answer should be a single sentence in natural language containing all relevant information to answer the user request.
Do not mention the existence of that context in your answer.
Use the following output format:
{{
"answer": <generated answer>
}}"""
class RespondTool(Tool):
    """Tool that produces the final natural-language answer to a user request,
    grounded in the context accumulated on the scratchpad."""

    name: str = "respond_to_user"

    def __init__(self, model: Llama3Instruct):
        # Wrap the model in a chat client used for schema-constrained completion.
        self.client = ChatClient(model=model)

    def run(
        self,
        request: str,
        task: str,
        scratchpad: Scratchpad,
        temperature: float = -1,
        return_user_prompt: bool = False,
        **kwargs,
    ) -> "str | tuple[str, str]":
        """Useful for responding with a final answer to the user request.

        Args:
            request: The original user request to answer.
            task: Unused here; accepted for interface compatibility with other tools.
            scratchpad: Source of the context passed into the prompt
                (via ``scratchpad.entries_repr()``).
            temperature: Sampling temperature forwarded to the client.
                NOTE(review): -1 presumably means "use the client default" — confirm.
            return_user_prompt: When True, also return the rendered user prompt.
            **kwargs: Ignored; tolerates extra arguments from the caller.

        Returns:
            The generated answer string, or ``(answer, user_prompt)`` when
            ``return_user_prompt`` is True.
        """
        # Render the prompt with the request and the scratchpad's context entries.
        user_prompt = USER_PROMPT_TEMPLATE.format(request=request, context=scratchpad.entries_repr())
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ]
        # Constrain generation to the Result JSON schema so the reply is parseable.
        message = self.client.complete(
            messages,
            schema=Result.model_json_schema(),
            temperature=temperature,
        )
        result = json.loads(message["content"])
        answer = result["answer"]
        if return_user_prompt:
            # Caller also wants the exact prompt that was sent (e.g. for logging).
            return answer, user_prompt
        return answer