"""
This module defines a `ConversationMemoryRunnable` class that stores and manages
conversation history for an LLM-based assistant.
"""
from langchain_core.runnables import Runnable
# Initialize a global variable to store the conversation history
conversation_history = []
class ConversationMemoryRunnable(Runnable):
"""
A runnable that manages conversation memory by storing user inputs and
assistant responses. It uses an LLM to generate responses based on session history.
"""
def __init__(self, llm):
"""
Initializes the ConversationMemoryRunnable with an LLM.
Args:
llm: The language model used to generate responses.
"""
self.llm = llm
def run(self, input_text: str, **kwargs) -> str:
"""
Generates a response from the LLM based on conversation history and
the latest user input.
Args:
input_text: The user's input message.
**kwargs: Additional keyword arguments for the LLM.
Returns:
str: The response generated by the LLM.
"""
session_history = get_session_history()
combined_input = f"{session_history}\nUser: {input_text}\nAssistant:"
response = self.llm.invoke(combined_input, **kwargs)
update_conversation_history(input_text, response)
return response
def invoke(self, *args, **kwargs) -> str:
"""
Invokes the run method using the first argument as input text.
Args:
*args: Positional arguments where the first argument is expected to be the input text.
**kwargs: Additional keyword arguments for the run method.
Returns:
str: The response generated by the run method.
"""
input_text = args[0] if args else ""
return self.run(input_text, **kwargs)
def get_session_history() -> str:
"""
Retrieves the formatted session history of conversation between the user and the assistant.
Returns:
str: A formatted string of the conversation history.
"""
formatted_history = "\n".join(
f"User: {entry['user']}\nAssistant: {entry['assistant']}" for entry in conversation_history
)
return formatted_history
def update_conversation_history(user_message: str, assistant_response: str) -> None:
"""
Updates the conversation history by adding the latest user message and assistant response.
Args:
user_message: The user's input message.
assistant_response: The assistant's response message.
"""
conversation_history.append({"user": user_message, "assistant": assistant_response})
if len(conversation_history) > 2: # Adjust the limit as needed
conversation_history.pop(0) # Remove the oldest entry