Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance Caching System and Optimize Memory Usage with Customizable Storage Support #1016

Merged
merged 3 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 13 additions & 30 deletions core/cat/auth/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from cat.looking_glass.stray_cat import StrayCat
from cat.log import log


class ConnectionAuth(ABC):

def __init__(
Expand Down Expand Up @@ -103,16 +104,14 @@ def extract_credentials(self, connection: Request) -> Tuple[str, str] | None:


async def get_user_stray(self, user: AuthUserInfo, connection: Request) -> StrayCat:
strays = connection.app.state.strays
event_loop = connection.app.state.event_loop

if user.id not in strays.keys():
strays[user.id] = StrayCat(
# TODOV2: user_id should be the user.id
user_id=user.name, user_data=user, main_loop=event_loop
)
return strays[user.id]

return StrayCat(
# TODOV2: user_id should be the user.id
user_id=user.name, user_data=user, main_loop=event_loop
)


def not_allowed(self, connection: Request):
    """Reject the HTTP request by raising a 403 error.

    Parameters
    ----------
    connection : Request
        The rejected request (not inspected here).

    Raises
    ------
    HTTPException
        Always, with status code 403 and an "Invalid Credentials" detail.

    """
    error_detail = {"error": "Invalid Credentials"}
    raise HTTPException(status_code=403, detail=error_detail)

Expand All @@ -136,29 +135,13 @@ def extract_credentials(self, connection: WebSocket) -> Tuple[str, str] | None:


async def get_user_stray(self, user: AuthUserInfo, connection: WebSocket) -> StrayCat:
strays = connection.app.state.strays
return StrayCat(
ws=connection,
user_id=user.name, # TODOV2: user_id should be the user.id
user_data=user,
main_loop=asyncio.get_running_loop(),
)

if user.id in strays.keys():
stray = strays[user.id]
await stray.close_connection()

# Set new ws connection
stray.reset_connection(connection)
log.info(
f"New websocket connection for user '{user.id}', the old one has been closed."
)
return stray

else:
stray = StrayCat(
ws=connection,
user_id=user.name, # TODOV2: user_id should be the user.id
user_data=user,
main_loop=asyncio.get_running_loop(),
)
strays[user.id] = stray
return stray

def not_allowed(self, connection: WebSocket):
    """Reject the websocket connection by raising a WebSocketException.

    Parameters
    ----------
    connection : WebSocket
        The rejected connection (not inspected here).

    Raises
    ------
    WebSocketException
        Always, with code 1004 and reason "Invalid Credentials".

    """
    # NOTE(review): RFC 6455 lists close code 1004 as reserved; 1008 (policy
    # violation) may be what is intended here - confirm before changing.
    close_code = 1004
    raise WebSocketException(code=close_code, reason="Invalid Credentials")

Expand Down
85 changes: 85 additions & 0 deletions core/cat/cache/array_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from cat.cache.cache_interface import CacheInterface
from cat.cache.cache_item import CacheItem

from cat.utils import singleton


@singleton
class ArrayCache(CacheInterface):
    """In-memory cache backed by a plain python dictionary.

    Attributes
    ----------
    cache : dict
        Maps each key to the cache item stored under it.

    """

    def __init__(self):
        self.cache = {}

    def insert(self, cache_item):
        """Store a cache item under its own key.

        Any item already stored under the same key is replaced.

        Parameters
        ----------
        cache_item : CacheItem
            Item to store; `cache_item.key` is used as dictionary key.

        """
        self.cache[cache_item.key] = cache_item

    def get_item(self, key) -> CacheItem:
        """Look up the cache item stored under a key.

        An item that reports itself as expired is removed from the
        dictionary and treated as missing.

        Parameters
        ----------
        key : str
            Key of the item to retrieve.

        Returns
        -------
        CacheItem or None
            The stored item, or None when nothing usable is stored.

        """
        entry = self.cache.get(key)

        # Missing entries and entries that are still valid go back untouched.
        if not entry or not entry.is_expired():
            return entry

        del self.cache[key]
        return None

    def get_value(self, key):
        """Look up the value carried by the item stored under a key.

        Parameters
        ----------
        key : str
            Key of the item to retrieve.

        Returns
        -------
        any
            The `value` of the stored item, or None when `get_item`
            finds nothing.

        """
        entry = self.get_item(key)
        return entry.value if entry else None

    def delete(self, key):
        """Remove the item stored under a key, if there is one.

        Parameters
        ----------
        key : str
            Key of the item to remove.

        """
        self.cache.pop(key, None)
16 changes: 16 additions & 0 deletions core/cat/cache/cache_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@

class CacheInterface:
    """Set of operations shared by the cache implementations.

    Every method defined here does nothing and returns None.
    """

    def insert(self, cache_item):
        """Store `cache_item`; this base version does nothing."""
        return None

    def get_item(self, key):
        """Retrieve the item stored under `key`; this base version returns None."""
        return None

    def get_value(self, key):
        """Retrieve the value stored under `key`; this base version returns None."""
        return None

    def delete(self, key):
        """Remove the item stored under `key`; this base version does nothing."""
        return None


18 changes: 18 additions & 0 deletions core/cat/cache/cache_item.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import time


class CacheItem:
    """A cached value together with its key, time-to-live and creation time.

    Attributes
    ----------
    key : any
        Key the item is stored under.
    value : any
        The cached payload.
    ttl : number or None
        Lifetime in seconds counted from `created_at`; -1 or None means
        the item never expires.
    created_at : float
        `time.time()` reading taken when the item was built.

    """

    def __init__(self, key, value, ttl):
        self.key = key
        self.value = value
        self.ttl = ttl
        self.created_at = time.time()

    def is_expired(self):
        """Tell whether the lifetime of the item is over.

        Returns
        -------
        bool
            False for items without expiration (ttl of None or -1),
            otherwise whether `created_at + ttl` lies in the past.

        """
        never_expires = self.ttl is None or self.ttl == -1
        if never_expires:
            return False

        deadline = self.created_at + self.ttl
        return deadline < time.time()

    def __repr__(self):
        return (
            f'CacheItem(key={self.key}, value={self.value}, '
            f'ttl={self.ttl}, created_at={self.created_at})'
        )
18 changes: 18 additions & 0 deletions core/cat/cache/cache_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from cat.env import get_env

from cat.utils import singleton


@singleton
class CacheManager:
    """Pick the cache backend according to the environment.

    Attributes
    ----------
    cache : FileSystemCache or ArrayCache
        FileSystemCache when CCAT_CACHE_TYPE is "file_system", ArrayCache
        for any other value.

    """

    def __init__(self):
        if get_env("CCAT_CACHE_TYPE") == "file_system":
            # An unset or empty CCAT_CACHE_DIR falls back to /tmp.
            target_dir = get_env("CCAT_CACHE_DIR") or "/tmp"
            from cat.cache.file_system_cache import FileSystemCache
            self.cache = FileSystemCache(target_dir)
            return

        from cat.cache.array_cache import ArrayCache
        self.cache = ArrayCache()
99 changes: 99 additions & 0 deletions core/cat/cache/file_system_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import os
import pickle
from cat.cache.cache_interface import CacheInterface
from cat.utils import singleton


@singleton
class FileSystemCache(CacheInterface):
    """Cache implementation using the file system.

    Every item is pickled into its own file, `<cache_dir>/<key>.cache`.

    Attributes
    ----------
    cache_dir : str
        Directory to store the cache.

    """

    def __init__(self, cache_dir):
        """Remember the cache directory and create it when missing.

        Parameters
        ----------
        cache_dir : str
            Directory to store the cache.

        Raises
        ------
        FileExistsError
            If `cache_dir` exists but is not a directory.

        """
        self.cache_dir = cache_dir
        # exist_ok avoids a crash when the directory shows up between an
        # existence check and its creation.
        os.makedirs(self.cache_dir, exist_ok=True)

    def _get_file_path(self, key):
        """Return the path of the file holding the item stored under `key`."""
        # NOTE(review): the key goes into the file name as it is, so a key
        # containing path separators does not map to a file directly inside
        # cache_dir. Confirm that callers only use file-name-safe keys.
        return os.path.join(self.cache_dir, f"{key}.cache")

    def insert(self, cache_item):
        """Insert a key-value pair in the cache.

        The item is pickled to its file, overwriting any previous content.

        Parameters
        ----------
        cache_item : CacheItem
            Cache item to store.

        """
        with open(self._get_file_path(cache_item.key), "wb") as f:
            pickle.dump(cache_item, f)

    def get_item(self, key):
        """Get the cache item stored in the cache.

        An item found to be expired is deleted from disk and reported as
        missing.

        Parameters
        ----------
        key : str
            Key to retrieve the item.

        Returns
        -------
        CacheItem or None
            Item stored in the cache, None when missing or expired.

        """
        file_path = self._get_file_path(key)

        # Open directly instead of checking for existence first: the file
        # may be removed between the check and the open.
        try:
            f = open(file_path, "rb")
        except FileNotFoundError:
            return None

        with f:
            # NOTE(security): unpickling executes code embedded in the file;
            # the cache directory must not be writable by untrusted parties.
            cache_item = pickle.load(f)

        if cache_item.is_expired():
            self.delete(key)
            return None

        return cache_item

    def get_value(self, key):
        """Get the value stored in the cache.

        Parameters
        ----------
        key : str
            Key to retrieve the value.

        Returns
        -------
        any
            Value stored in the cache, None when missing or expired.

        """
        cache_item = self.get_item(key)
        if cache_item:
            return cache_item.value
        return None

    def delete(self, key):
        """Delete a key-value pair from the cache.

        Deleting a key that is not stored is not an error.

        Parameters
        ----------
        key : str
            Key to delete the value.

        """
        try:
            os.remove(self._get_file_path(key))
        except FileNotFoundError:
            # Already gone (never stored, or removed in the meantime):
            # the outcome the caller asked for is reached anyway.
            pass


4 changes: 3 additions & 1 deletion core/cat/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ def get_supported_env_variables():
"CCAT_JWT_EXPIRE_MINUTES": str(60 * 24), # JWT expires after 1 day
"CCAT_HTTPS_PROXY_MODE": False,
"CCAT_CORS_FORWARDED_ALLOW_IPS": "*",
"CCAT_CORS_ENABLED": "true"
"CCAT_CORS_ENABLED": "true",
"CCAT_CACHE_TYPE": "array",
"CCAT_CACHE_DIR": None,
}


Expand Down
9 changes: 9 additions & 0 deletions core/cat/looking_glass/cheshire_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
from cat.rabbit_hole import RabbitHole
from cat.utils import singleton
from cat import utils
from cat.cache.cache_manager import CacheManager
from cat.cache.cache_interface import CacheInterface


class Procedure(Protocol):
name: str
Expand Down Expand Up @@ -94,6 +97,8 @@ def __init__(self, fastapi_app):
# Rabbit Hole Instance
self.rabbit_hole = RabbitHole(self) # :(

self.cache_manager = CacheManager()

# allows plugins to do something after the cat bootstrap is complete
self.mad_hatter.execute_hook("after_cat_bootstrap", cat=self)

Expand Down Expand Up @@ -434,3 +439,7 @@ def llm(self, prompt, *args, **kwargs) -> str:
)

return output

@property
def cache(self) -> CacheInterface:
    """Cache held by the cache manager of this instance."""
    manager = self.cache_manager
    return manager.cache
16 changes: 15 additions & 1 deletion core/cat/looking_glass/stray_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from cat import utils
from websockets.exceptions import ConnectionClosedOK

from cat.cache.cache_item import CacheItem

MSG_TYPES = Literal["notification", "chat", "error", "chat_token"]


Expand All @@ -40,7 +42,7 @@ def __init__(
self.__user_id = user_id
self.__user_data = user_data

self.working_memory = WorkingMemory()
self.working_memory = self._cache.get_value(f"{user_id}_working_memory") or WorkingMemory()

# attribute to store ws connection
self.__ws = ws
Expand All @@ -50,6 +52,13 @@ def __init__(
def __repr__(self):
    """Return a short representation showing the user id."""
    owner = self.user_id
    return f"StrayCat(user_id={owner})"

def __del__(self):
    """Drop the references held by the instance and log its deletion."""
    self.__main_loop = self.__ws = None

    # The message is built before the user attributes are cleared below.
    message = f"StrayCat {self.user_id} deleted"
    log.warning(message)

    self.__user_id = self.__user_data = None

def __send_ws_json(self, data: Any):
# Run the corutine in the main event loop in the main thread
# and wait for the result
Expand Down Expand Up @@ -464,6 +473,8 @@ def __call__(self, message_dict):
final_output
)

self._cache.insert(CacheItem(f"{self.user_id}_working_memory", self.working_memory, -1))

return final_output

def run(self, user_message_json, return_message=False):
Expand Down Expand Up @@ -663,3 +674,6 @@ def main_agent(self):
@property
def white_rabbit(self):
    """White rabbit exposed by CheshireCat."""
    cheshire_cat = CheshireCat()
    return cheshire_cat.white_rabbit
@property
def _cache(self):
    """Cache exposed by CheshireCat."""
    cheshire_cat = CheshireCat()
    return cheshire_cat.cache
Loading
Loading