Skip to content

Commit

Permalink
Merge branch 'feature/cache' of https://github.com/AlboCode/core into…
Browse files Browse the repository at this point in the history
… AlboCode-feature/cache
  • Loading branch information
pieroit committed Feb 5, 2025
2 parents b3c3d5a + d7e5b9b commit 2dd3ac1
Show file tree
Hide file tree
Showing 11 changed files with 278 additions and 38 deletions.
43 changes: 13 additions & 30 deletions core/cat/auth/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from cat.looking_glass.stray_cat import StrayCat
from cat.log import log


class ConnectionAuth(ABC):

def __init__(
Expand Down Expand Up @@ -103,16 +104,14 @@ def extract_credentials(self, connection: Request) -> Tuple[str, str] | None:


async def get_user_stray(self, user: AuthUserInfo, connection: Request) -> StrayCat:
strays = connection.app.state.strays
event_loop = connection.app.state.event_loop

if user.id not in strays.keys():
strays[user.id] = StrayCat(
# TODOV2: user_id should be the user.id
user_id=user.name, user_data=user, main_loop=event_loop
)
return strays[user.id]

return StrayCat(
# TODOV2: user_id should be the user.id
user_id=user.name, user_data=user, main_loop=event_loop
)


def not_allowed(self, connection: Request):
raise HTTPException(status_code=403, detail={"error": "Invalid Credentials"})

Expand All @@ -136,29 +135,13 @@ def extract_credentials(self, connection: WebSocket) -> Tuple[str, str] | None:


async def get_user_stray(self, user: AuthUserInfo, connection: WebSocket) -> StrayCat:
strays = connection.app.state.strays
return StrayCat(
ws=connection,
user_id=user.name, # TODOV2: user_id should be the user.id
user_data=user,
main_loop=asyncio.get_running_loop(),
)

if user.id in strays.keys():
stray = strays[user.id]
await stray.close_connection()

# Set new ws connection
stray.reset_connection(connection)
log.info(
f"New websocket connection for user '{user.id}', the old one has been closed."
)
return stray

else:
stray = StrayCat(
ws=connection,
user_id=user.name, # TODOV2: user_id should be the user.id
user_data=user,
main_loop=asyncio.get_running_loop(),
)
strays[user.id] = stray
return stray

def not_allowed(self, connection: WebSocket):
raise WebSocketException(code=1004, reason="Invalid Credentials")

Expand Down
85 changes: 85 additions & 0 deletions core/cat/cache/array_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from cat.cache.cache_interface import CacheInterface
from cat.cache.cache_item import CacheItem

from cat.utils import singleton


@singleton
class ArrayCache(CacheInterface):
"""Cache implementation using a python dictionary.
Attributes
----------
cache : dict
Dictionary to store the cache.
"""

def __init__(self):
self.cache = {}

def insert(self, cache_item):
"""Insert a key-value pair in the cache.
Parameters
----------
cache_item : CacheItem
Cache item to store.
"""
self.cache[cache_item.key] = cache_item

def get_item(self, key) -> CacheItem:
"""Get the value stored in the cache.
Parameters
----------
key : str
Key to retrieve the value.
Returns
-------
any
Value stored in the cache.
"""
item = self.cache.get(key)

if item and item.is_expired():
del self.cache[key]
return None

return item

def get_value(self, key):
"""Get the value stored in the cache.
Parameters
----------
key : str
Key to retrieve the value.
Returns
-------
any
Value stored in the cache.
"""


item = self.get_item(key)
if item:
return item.value
return None

def delete(self, key):
"""Delete a key-value pair from the cache.
Parameters
----------
key : str
Key to delete the value.
"""
if key in self.cache:
del self.cache[key]
16 changes: 16 additions & 0 deletions core/cat/cache/cache_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@

class CacheInterface:

def insert(self, cache_item):
pass

def get_item(self, key):
pass

def get_value(self, key):
pass

def delete(self, key):
pass


18 changes: 18 additions & 0 deletions core/cat/cache/cache_item.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import time


class CacheItem:
def __init__(self, key, value, ttl):
self.key = key
self.value = value
self.ttl = ttl
self.created_at = time.time()

def is_expired(self):
if self.ttl == -1 or self.ttl is None:
return False

return (self.created_at + self.ttl) < time.time()

def __repr__(self):
return f'CacheItem(key={self.key}, value={self.value}, ttl={self.ttl}, created_at={self.created_at})'
18 changes: 18 additions & 0 deletions core/cat/cache/cache_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from cat.env import get_env

from cat.utils import singleton


@singleton
class CacheManager:
def __init__(self):
cache_type = get_env("CCAT_CACHE_TYPE")
if cache_type == "file_system":
cache_dir = get_env("CCAT_CACHE_DIR")
if not cache_dir:
cache_dir = "/tmp"
from cat.cache.file_system_cache import FileSystemCache
self.cache = FileSystemCache(cache_dir)
else:
from cat.cache.array_cache import ArrayCache
self.cache = ArrayCache()
99 changes: 99 additions & 0 deletions core/cat/cache/file_system_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import os
import pickle
from cat.cache.cache_interface import CacheInterface
from cat.utils import singleton


@singleton
class FileSystemCache(CacheInterface):
"""Cache implementation using the file system.
Attributes
----------
cache_dir : str
Directory to store the cache.
"""

def __init__(self, cache_dir):
self.cache_dir = cache_dir
if not os.path.exists(self.cache_dir):
os.makedirs(self.cache_dir)
def _get_file_path(self, key):
return os.path.join(self.cache_dir, f"{key}.cache")

def insert(self, cache_item):
"""Insert a key-value pair in the cache.
Parameters
----------
cache_item : CacheItem
Cache item to store.
"""

with open(self._get_file_path(cache_item.key), "wb") as f:
pickle.dump(cache_item, f)


def get_item(self, key):
"""Get the value stored in the cache.
Parameters
----------
key : str
Key to retrieve the value.
Returns
-------
any
Value stored in the cache.
"""
file_path = self._get_file_path(key)
if not os.path.exists(file_path):
return None

with open(file_path, "rb") as f:
cache_item = pickle.load(f)

if cache_item.is_expired():
os.remove(file_path)
return None

return cache_item

def get_value(self, key):
"""Get the value stored in the cache.
Parameters
----------
key : str
Key to retrieve the value.
Returns
-------
any
Value stored in the cache.
"""

cache_item = self.get_item(key)
if cache_item:
return cache_item.value
return None

def delete(self, key):
"""Delete a key-value pair from the cache.
Parameters
----------
key : str
Key to delete the value.
"""
file_path = self._get_file_path(key)
if os.path.exists(file_path):
os.remove(file_path)


4 changes: 3 additions & 1 deletion core/cat/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ def get_supported_env_variables():
"CCAT_JWT_EXPIRE_MINUTES": str(60 * 24), # JWT expires after 1 day
"CCAT_HTTPS_PROXY_MODE": False,
"CCAT_CORS_FORWARDED_ALLOW_IPS": "*",
"CCAT_CORS_ENABLED": "true"
"CCAT_CORS_ENABLED": "true",
"CCAT_CACHE_TYPE": "array",
"CCAT_CACHE_DIR": None,
}


Expand Down
9 changes: 9 additions & 0 deletions core/cat/looking_glass/cheshire_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
from cat.rabbit_hole import RabbitHole
from cat.utils import singleton
from cat import utils
from cat.cache.cache_manager import CacheManager
from cat.cache.cache_interface import CacheInterface


class Procedure(Protocol):
name: str
Expand Down Expand Up @@ -94,6 +97,8 @@ def __init__(self, fastapi_app):
# Rabbit Hole Instance
self.rabbit_hole = RabbitHole(self) # :(

self.cache_manager = CacheManager()

# allows plugins to do something after the cat bootstrap is complete
self.mad_hatter.execute_hook("after_cat_bootstrap", cat=self)

Expand Down Expand Up @@ -434,3 +439,7 @@ def llm(self, prompt, *args, **kwargs) -> str:
)

return output

@property
def cache(self) -> CacheInterface:
return self.cache_manager.cache
16 changes: 15 additions & 1 deletion core/cat/looking_glass/stray_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from cat import utils
from websockets.exceptions import ConnectionClosedOK

from cat.cache.cache_item import CacheItem

MSG_TYPES = Literal["notification", "chat", "error", "chat_token"]


Expand All @@ -40,7 +42,7 @@ def __init__(
self.__user_id = user_id
self.__user_data = user_data

self.working_memory = WorkingMemory()
self.working_memory = self._cache.get_value(f"{user_id}_working_memory") or WorkingMemory()

# attribute to store ws connection
self.__ws = ws
Expand All @@ -50,6 +52,13 @@ def __init__(
def __repr__(self):
return f"StrayCat(user_id={self.user_id})"

def __del__(self):
self.__main_loop = None
self.__ws = None
log.warning(f"StrayCat {self.user_id} deleted")
self.__user_id = None
self.__user_data = None

def __send_ws_json(self, data: Any):
# Run the corutine in the main event loop in the main thread
# and wait for the result
Expand Down Expand Up @@ -464,6 +473,8 @@ def __call__(self, message_dict):
final_output
)

self._cache.insert(CacheItem(f"{self.user_id}_working_memory", self.working_memory, -1))

return final_output

def run(self, user_message_json, return_message=False):
Expand Down Expand Up @@ -626,3 +637,6 @@ def main_agent(self):
@property
def white_rabbit(self):
return CheshireCat().white_rabbit
@property
def _cache(self):
return CheshireCat().cache
Loading

0 comments on commit 2dd3ac1

Please sign in to comment.