diff --git a/src/replit/database/database.py b/src/replit/database/database.py index 41cd1b6..9a59f86 100644 --- a/src/replit/database/database.py +++ b/src/replit/database/database.py @@ -1,5 +1,6 @@ """Async and dict-like interfaces for interacting with Replit Database.""" +import asyncio from collections import abc import json import threading @@ -82,8 +83,17 @@ class AsyncDatabase: :param unbind Callable: Permit additional behavior after Database close """ - __slots__ = ("db_url", "sess", "client", "_get_db_url", "_unbind", "_refresh_timer") + __slots__ = ( + "db_url", + "sess", + "client", + "_get_db_url", + "_unbind", + "_refresh_timer", + "_watchdog_timer", + ) _refresh_timer: Optional[threading.Timer] + _watchdog_timer: Optional[threading.Timer] def __init__( self, @@ -113,6 +123,9 @@ def __init__( if self._get_db_url: self._refresh_timer = threading.Timer(3600, self._refresh_db) self._refresh_timer.start() + watched_thread = threading.main_thread() + self._watchdog_timer = threading.Timer(1, self._watchdog, args=[watched_thread]) + self._watchdog_timer.start() def _refresh_db(self) -> None: if self._refresh_timer: @@ -125,6 +138,14 @@ def _refresh_db(self) -> None: self._refresh_timer = threading.Timer(3600, self._refresh_db) self._refresh_timer.start() + def _watchdog(self, watched_thread: threading.Thread) -> None: + if not watched_thread.is_alive(): + return asyncio.run(self.close()) + if self._watchdog_timer: + self._watchdog_timer.cancel() + self._watchdog_timer = threading.Timer(1, self._watchdog, args=[watched_thread]) + self._watchdog_timer.start() + def update_db_url(self, db_url: str) -> None: """Update the database url. @@ -292,6 +313,9 @@ async def close(self) -> None: if self._refresh_timer: self._refresh_timer.cancel() self._refresh_timer = None + if self._watchdog_timer: + self._watchdog_timer.cancel() + self._watchdog_timer = None if self._unbind: # Permit signaling to surrounding scopes that we have closed self._unbind() @@ -485,8 +509,16 @@ class Database(abc.MutableMapping): :param unbind Callable: Permit additional behavior after Database close """ - __slots__ = ("db_url", "sess", "_get_db_url", "_unbind", "_refresh_timer") + __slots__ = ( + "db_url", + "sess", + "_get_db_url", + "_unbind", + "_refresh_timer", + "_watchdog_timer", + ) _refresh_timer: Optional[threading.Timer] + _watchdog_timer: Optional[threading.Timer] def __init__( self, @@ -518,6 +550,9 @@ def __init__( if self._get_db_url: self._refresh_timer = threading.Timer(3600, self._refresh_db) self._refresh_timer.start() + watched_thread = threading.main_thread() + self._watchdog_timer = threading.Timer(1, self._watchdog, args=[watched_thread]) + self._watchdog_timer.start() def _refresh_db(self) -> None: if self._refresh_timer: @@ -530,6 +565,14 @@ def _refresh_db(self) -> None: self._refresh_timer = threading.Timer(3600, self._refresh_db) self._refresh_timer.start() + def _watchdog(self, watched_thread: threading.Thread) -> None: + if not watched_thread.is_alive(): + return self.close() + if self._watchdog_timer: + self._watchdog_timer.cancel() + self._watchdog_timer = threading.Timer(1, self._watchdog, args=[watched_thread]) + self._watchdog_timer.start() + def update_db_url(self, db_url: str) -> None: """Update the database url. @@ -720,6 +763,9 @@ def close(self) -> None: if self._refresh_timer: self._refresh_timer.cancel() self._refresh_timer = None + if self._watchdog_timer: + self._watchdog_timer.cancel() + self._watchdog_timer = None if self._unbind: # Permit signaling to surrounding scopes that we have closed self._unbind()