Skip to content

Commit

Permalink
Reset memory (#958)
Browse files Browse the repository at this point in the history
* reseting memory on cli

* using storage.reset

* deleting memories on command

* added tests

* handle when no flags are used

* added docs
  • Loading branch information
lorenzejay authored Jul 18, 2024
1 parent 61a1963 commit be1b9a3
Show file tree
Hide file tree
Showing 10 changed files with 231 additions and 4 deletions.
33 changes: 33 additions & 0 deletions docs/core-concepts/Memory.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,39 @@ my_crew = Crew(
)
```

### Resetting Memory
```sh
crewai reset_memories [OPTIONS]
```

#### Resetting Memory Options
- **`-l, --long`**
- **Description:** Reset LONG TERM memory.
- **Type:** Flag (boolean)
- **Default:** False

- **`-s, --short`**
- **Description:** Reset SHORT TERM memory.
- **Type:** Flag (boolean)
- **Default:** False

- **`-e, --entities`**
- **Description:** Reset ENTITIES memory.
- **Type:** Flag (boolean)
- **Default:** False

- **`-k, --kickoff-outputs`**
- **Description:** Reset LATEST KICKOFF TASK OUTPUTS.
- **Type:** Flag (boolean)
- **Default:** False

- **`-a, --all`**
- **Description:** Reset ALL memories.
- **Type:** Flag (boolean)
- **Default:** False



## Benefits of Using crewAI's Memory System
- **Adaptive Learning:** Crews become more efficient over time, adapting to new information and refining their approach to tasks.
- **Enhanced Personalization:** Memory enables agents to remember user preferences and historical interactions, leading to personalized experiences.
Expand Down
27 changes: 27 additions & 0 deletions src/crewai/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .create_crew import create_crew
from .train_crew import train_crew
from .replay_from_task import replay_task_command
from .reset_memories_command import reset_memories_command


@click.group()
Expand Down Expand Up @@ -99,5 +100,31 @@ def log_tasks_outputs() -> None:
click.echo(f"An error occurred while logging task outputs: {e}", err=True)


@crewai.command()
@click.option("-l", "--long", is_flag=True, help="Reset LONG TERM memory")
@click.option("-s", "--short", is_flag=True, help="Reset SHORT TERM memory")
@click.option("-e", "--entities", is_flag=True, help="Reset ENTITIES memory")
@click.option(
"-k",
"--kickoff-outputs",
is_flag=True,
help="Reset LATEST KICKOFF TASK OUTPUTS",
)
@click.option("-a", "--all", is_flag=True, help="Reset ALL memories")
def reset_memories(long, short, entities, kickoff_outputs, all):
"""
Reset the crew memories (long, short, entity, latest_crew_kickoff_ouputs). This will delete all the data saved.
"""
try:
if not all and not (long or short or entities or kickoff_outputs):
click.echo(
"Please specify at least one memory type to reset using the appropriate flags."
)
return
reset_memories_command(long, short, entities, kickoff_outputs, all)
except Exception as e:
click.echo(f"An error occurred while resetting memories: {e}", err=True)


if __name__ == "__main__":
crewai()
45 changes: 45 additions & 0 deletions src/crewai/cli/reset_memories_command.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import subprocess
import click

from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandler


def reset_memories_command(long, short, entity, kickoff_outputs, all) -> None:
"""
Replay the crew execution from a specific task.
Args:
task_id (str): The ID of the task to replay from.
"""

try:
if all:
ShortTermMemory().reset()
EntityMemory().reset()
LongTermMemory().reset()
TaskOutputStorageHandler().reset()
click.echo("All memories have been reset.")
else:
if long:
LongTermMemory().reset()
click.echo("Long term memory has been reset.")

if short:
ShortTermMemory().reset()
click.echo("Short term memory has been reset.")
if entity:
EntityMemory().reset()
click.echo("Entity memory has been reset.")
if kickoff_outputs:
TaskOutputStorageHandler().reset()
click.echo("Latest Kickoff outputs stored has been reset.")

except subprocess.CalledProcessError as e:
click.echo(f"An error occurred while resetting the memories: {e}", err=True)
click.echo(e.output, err=True)

except Exception as e:
click.echo(f"An unexpected error occurred: {e}", err=True)
6 changes: 6 additions & 0 deletions src/crewai/memory/entity/entity_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,9 @@ def save(self, item: EntityMemoryItem) -> None: # type: ignore # BUG?: Signatur
"""Saves an entity item into the SQLite storage."""
data = f"{item.name}({item.type}): {item.description}"
super().save(data, item.metadata)

def reset(self) -> None:
try:
self.storage.reset()
except Exception as e:
raise Exception(f"An error occurred while resetting the entity memory: {e}")
3 changes: 3 additions & 0 deletions src/crewai/memory/long_term/long_term_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,6 @@ def save(self, item: LongTermMemoryItem) -> None: # type: ignore # BUG?: Signat

def search(self, task: str, latest_n: int = 3) -> Dict[str, Any]:
return self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"

def reset(self) -> None:
self.storage.reset()
10 changes: 9 additions & 1 deletion src/crewai/memory/short_term/short_term_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,16 @@ def __init__(self, crew=None, embedder_config=None):
)
super().__init__(storage)

def save(self, item: ShortTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
def save(self, item: ShortTermMemoryItem) -> None:
super().save(item.data, item.metadata, item.agent)

def search(self, query: str, score_threshold: float = 0.35):
return self.storage.search(query=query, score_threshold=score_threshold) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters

def reset(self) -> None:
try:
self.storage.reset()
except Exception as e:
raise Exception(
f"An error occurred while resetting the short-term memory: {e}"
)
3 changes: 3 additions & 0 deletions src/crewai/memory/storage/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@ def save(self, key: str, value: Any, metadata: Dict[str, Any]) -> None:

def search(self, key: str) -> Dict[str, Any]: # type: ignore
pass

def reset(self) -> None:
pass
17 changes: 17 additions & 0 deletions src/crewai/memory/storage/ltm_sqlite_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,20 @@ def load(
color="red",
)
return None

def reset(
self,
) -> None:
"""Resets the LTM table with error handling."""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute("DELETE FROM long_term_memories")
conn.commit()

except sqlite3.Error as e:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
color="red",
)
return None
13 changes: 11 additions & 2 deletions src/crewai/memory/storage/rag_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import io
import logging
import os
import shutil
from typing import Any, Dict, List, Optional

from embedchain import App
Expand Down Expand Up @@ -71,13 +72,13 @@ def __init__(self, type, allow_reset=True, embedder_config=None, crew=None):

if embedder_config:
config["embedder"] = embedder_config

self.type = type
self.app = App.from_config(config=config)
self.app.llm = FakeLLM()
if allow_reset:
self.app.reset()

def save(self, value: Any, metadata: Dict[str, Any]) -> None: # type: ignore # BUG?: Should be save(key, value, metadata) Signature of "save" incompatible with supertype "Storage"
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
self._generate_embedding(value, metadata)

def search( # type: ignore # BUG?: Signature of "search" incompatible with supertype "Storage"
Expand All @@ -102,3 +103,11 @@ def search( # type: ignore # BUG?: Signature of "search" incompatible with supe
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> Any:
with suppress_logging():
self.app.add(text, data_type="text", metadata=metadata)

def reset(self) -> None:
try:
shutil.rmtree(f"{db_storage_path()}/{self.type}")
except Exception as e:
raise Exception(
f"An error occurred while resetting the {self.type} memory: {e}"
)
78 changes: 77 additions & 1 deletion tests/cli/cli_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
from click.testing import CliRunner

from crewai.cli.cli import train, version
from crewai.cli.cli import train, version, reset_memories


@pytest.fixture
Expand Down Expand Up @@ -41,6 +41,82 @@ def test_train_invalid_string_iterations(train_crew, runner):
)


@mock.patch("crewai.cli.reset_memories_command.ShortTermMemory")
@mock.patch("crewai.cli.reset_memories_command.EntityMemory")
@mock.patch("crewai.cli.reset_memories_command.LongTermMemory")
@mock.patch("crewai.cli.reset_memories_command.TaskOutputStorageHandler")
def test_reset_all_memories(
MockTaskOutputStorageHandler,
MockLongTermMemory,
MockEntityMemory,
MockShortTermMemory,
runner,
):
result = runner.invoke(reset_memories, ["--all"])
MockShortTermMemory().reset.assert_called_once()
MockEntityMemory().reset.assert_called_once()
MockLongTermMemory().reset.assert_called_once()
MockTaskOutputStorageHandler().reset.assert_called_once()

assert result.output == "All memories have been reset.\n"


@mock.patch("crewai.cli.reset_memories_command.ShortTermMemory")
def test_reset_short_term_memories(MockShortTermMemory, runner):
result = runner.invoke(reset_memories, ["-s"])
MockShortTermMemory().reset.assert_called_once()
assert result.output == "Short term memory has been reset.\n"


@mock.patch("crewai.cli.reset_memories_command.EntityMemory")
def test_reset_entity_memories(MockEntityMemory, runner):
result = runner.invoke(reset_memories, ["-e"])
MockEntityMemory().reset.assert_called_once()
assert result.output == "Entity memory has been reset.\n"


@mock.patch("crewai.cli.reset_memories_command.LongTermMemory")
def test_reset_long_term_memories(MockLongTermMemory, runner):
result = runner.invoke(reset_memories, ["-l"])
MockLongTermMemory().reset.assert_called_once()
assert result.output == "Long term memory has been reset.\n"


@mock.patch("crewai.cli.reset_memories_command.TaskOutputStorageHandler")
def test_reset_kickoff_outputs(MockTaskOutputStorageHandler, runner):
result = runner.invoke(reset_memories, ["-k"])
MockTaskOutputStorageHandler().reset.assert_called_once()
assert result.output == "Latest Kickoff outputs stored has been reset.\n"


@mock.patch("crewai.cli.reset_memories_command.ShortTermMemory")
@mock.patch("crewai.cli.reset_memories_command.LongTermMemory")
def test_reset_multiple_memory_flags(MockShortTermMemory, MockLongTermMemory, runner):
result = runner.invoke(
reset_memories,
[
"-s",
"-l",
],
)
MockShortTermMemory().reset.assert_called_once()
MockLongTermMemory().reset.assert_called_once()
assert (
result.output
== "Long term memory has been reset.\nShort term memory has been reset.\n"
)


def test_reset_no_memory_flags(runner):
result = runner.invoke(
reset_memories,
)
assert (
result.output
== "Please specify at least one memory type to reset using the appropriate flags.\n"
)


def test_version_command(runner):
result = runner.invoke(version)

Expand Down

0 comments on commit be1b9a3

Please sign in to comment.