From 5c0ab7a39bfe13be581607f8aac1d75d08fa17a6 Mon Sep 17 00:00:00 2001 From: lhchavez Date: Tue, 27 Feb 2024 14:01:31 -0800 Subject: [PATCH] [replit] The lesser of two evils (#167) * [replit] The lesser of two evils Currently the replit library has a very gross quirk: it has a global in `replit.database.default_db.db`, and the mere action of importing this library causes side effects to run! (connects to the database, starts a thread to refresh the URL, and prints a warning to stdout, adding insult to injury). So we're trading that very gross quirk with a gross workaround to preserve backwards compatibility: the modules that somehow end up importing that module now have a `__getattr__` that _lazily_ calls the code that used to be invoked as a side-effect of importing the library. Maybe in the future we'll deploy a breaking version of the library where we're not beholden to this backwards-compatibility quirck. * Marking internal properties as private Providing accessors, to hint that we are accessing mutable state * Reintroduce refresh_db noop to avoid errors on upgrade * Reflow LazyDB back down into default_db module An issue with LazyDB is that the refresh_db timer would not get canceled if the user closes the database. Additionally, the db_url refresh logic relies on injection, whereas the Database should ideally be the thing requesting that information from the environment * Removing stale main.sh --------- Co-authored-by: Devon Stewart --- main.sh | 1 - src/replit/__init__.py | 15 ++++- src/replit/database/__init__.py | 16 +++++- src/replit/database/database.py | 96 ++++++++++++++++++++++++++++--- src/replit/database/default_db.py | 65 +++++++++++++-------- src/replit/database/server.py | 22 ++++--- src/replit/web/__init__.py | 13 ++++- src/replit/web/user.py | 23 +++++--- 8 files changed, 200 insertions(+), 51 deletions(-) delete mode 100644 main.sh diff --git a/main.sh b/main.sh deleted file mode 100644 index 975cd92a..00000000 --- a/main.sh +++ /dev/null @@ -1 +0,0 @@ -python testapp.py \ No newline at end of file diff --git a/src/replit/__init__.py b/src/replit/__init__.py index 14f79cf9..66f5ef4f 100644 --- a/src/replit/__init__.py +++ b/src/replit/__init__.py @@ -2,10 +2,11 @@ """The Replit Python module.""" -from . import web +from typing import Any + +from . import database, web from .audio import Audio from .database import ( - db, Database, AsyncDatabase, make_database_proxy_blueprint, @@ -23,3 +24,13 @@ def clear() -> None: audio = Audio() + + +# Previous versions of this library would just have side-effects and always set +# up a database unconditionally. That is very undesirable, so instead of doing +# that, we are using this egregious hack to get the database / database URL +# lazily. +def __getattr__(name: str) -> Any: + if name == "db": + return database.db + raise AttributeError(name) diff --git a/src/replit/database/__init__.py b/src/replit/database/__init__.py index 2e94cc7a..2fafdec0 100644 --- a/src/replit/database/__init__.py +++ b/src/replit/database/__init__.py @@ -1,6 +1,8 @@ """Interface with the Replit Database.""" +from typing import Any + +from . import default_db from .database import AsyncDatabase, Database, DBJSONEncoder, dumps, to_primitive -from .default_db import db, db_url from .server import make_database_proxy_blueprint, start_database_proxy __all__ = [ @@ -14,3 +16,15 @@ "start_database_proxy", "to_primitive", ] + + +# Previous versions of this library would just have side-effects and always set +# up a database unconditionally. That is very undesirable, so instead of doing +# that, we are using this egregious hack to get the database / database URL +# lazily. +def __getattr__(name: str) -> Any: + if name == "db": + return default_db.db + if name == "db_url": + return default_db.db_url + raise AttributeError(name) diff --git a/src/replit/database/database.py b/src/replit/database/database.py index f4e63399..025e6767 100644 --- a/src/replit/database/database.py +++ b/src/replit/database/database.py @@ -1,7 +1,8 @@ -"""Async and dict-like interfaces for interacting with Repl.it Database.""" +"""Async and dict-like interfaces for interacting with Replit Database.""" from collections import abc import json +import threading from typing import ( Any, Callable, @@ -61,24 +62,57 @@ def dumps(val: Any) -> str: class AsyncDatabase: - """Async interface for Repl.it Database.""" + """Async interface for Replit Database. - __slots__ = ("db_url", "sess", "client") + :param str db_url: The Database URL to connect to + :param int retry_count: How many retry attempts we should make + :param get_db_url Callable: A callback that returns the current db_url + :param unbind Callable: Permit additional behavior after Database close + """ + + __slots__ = ("db_url", "sess", "client", "_get_db_url", "_unbind", "_refresh_timer") + _refresh_timer: Optional[threading.Timer] - def __init__(self, db_url: str, retry_count: int = 5) -> None: + def __init__( + self, + db_url: str, + retry_count: int = 5, + get_db_url: Optional[Callable[[], Optional[str]]] = None, + unbind: Optional[Callable[[], None]] = None, + ) -> None: """Initialize database. You shouldn't have to do this manually. Args: db_url (str): Database url to use. retry_count (int): How many times to retry connecting (with exponential backoff) + get_db_url (callable[[], str]): A function that will be called to refresh + the db_url property + unbind (callable[[], None]): A callback to clean up after .close() is called """ self.db_url = db_url self.sess = aiohttp.ClientSession() + self._get_db_url = get_db_url + self._unbind = unbind retry_options = ExponentialRetry(attempts=retry_count) self.client = RetryClient(client_session=self.sess, retry_options=retry_options) + if self._get_db_url: + self._refresh_timer = threading.Timer(3600, self._refresh_db) + self._refresh_timer.start() + + def _refresh_db(self) -> None: + if self._refresh_timer: + self._refresh_timer.cancel() + self._refresh_timer = None + if self._get_db_url: + db_url = self._get_db_url() + if db_url: + self.update_db_url(db_url) + self._refresh_timer = threading.Timer(3600, self._refresh_db) + self._refresh_timer.start() + def update_db_url(self, db_url: str) -> None: """Update the database url. @@ -239,6 +273,16 @@ async def items(self) -> Tuple[Tuple[str, str], ...]: """ return tuple((await self.to_dict()).items()) + async def close(self) -> None: + """Closes the database client connection.""" + await self.sess.close() + if self._refresh_timer: + self._refresh_timer.cancel() + self._refresh_timer = None + if self._unbind: + # Permit signaling to surrounding scopes that we have closed + self._unbind() + def __repr__(self) -> str: """A representation of the database. @@ -417,30 +461,62 @@ def item_to_observed(on_mutate: Callable[[Any], None], item: Any) -> Any: class Database(abc.MutableMapping): - """Dictionary-like interface for Repl.it Database. + """Dictionary-like interface for Replit Database. This interface will coerce all values everything to and from JSON. If you don't want this, use AsyncDatabase instead. + + :param str db_url: The Database URL to connect to + :param int retry_count: How many retry attempts we should make + :param get_db_url Callable: A callback that returns the current db_url + :param unbind Callable: Permit additional behavior after Database close """ - __slots__ = ("db_url", "sess") + __slots__ = ("db_url", "sess", "_get_db_url", "_unbind", "_refresh_timer") + _refresh_timer: Optional[threading.Timer] - def __init__(self, db_url: str, retry_count: int = 5) -> None: + def __init__( + self, + db_url: str, + retry_count: int = 5, + get_db_url: Optional[Callable[[], Optional[str]]] = None, + unbind: Optional[Callable[[], None]] = None, + ) -> None: """Initialize database. You shouldn't have to do this manually. Args: db_url (str): Database url to use. retry_count (int): How many times to retry connecting (with exponential backoff) + get_db_url (callable[[], str]): A function that will be called to refresh + the db_url property + unbind (callable[[], None]): A callback to clean up after .close() is called """ self.db_url = db_url self.sess = requests.Session() + self._get_db_url = get_db_url + self._unbind = unbind retries = Retry( total=retry_count, backoff_factor=0.1, status_forcelist=[500, 502, 503, 504] ) self.sess.mount("http://", HTTPAdapter(max_retries=retries)) self.sess.mount("https://", HTTPAdapter(max_retries=retries)) + if self._get_db_url: + self._refresh_timer = threading.Timer(3600, self._refresh_db) + self._refresh_timer.start() + + def _refresh_db(self) -> None: + if self._refresh_timer: + self._refresh_timer.cancel() + self._refresh_timer = None + if self._get_db_url: + db_url = self._get_db_url() + if db_url: + self.update_db_url(db_url) + self._refresh_timer = threading.Timer(3600, self._refresh_db) + self._refresh_timer.start() + def update_db_url(self, db_url: str) -> None: """Update the database url. @@ -627,3 +703,9 @@ def __repr__(self) -> str: def close(self) -> None: """Closes the database client connection.""" self.sess.close() + if self._refresh_timer: + self._refresh_timer.cancel() + self._refresh_timer = None + if self._unbind: + # Permit signaling to surrounding scopes that we have closed + self._unbind() diff --git a/src/replit/database/default_db.py b/src/replit/database/default_db.py index cc00c8d2..7b6e1427 100644 --- a/src/replit/database/default_db.py +++ b/src/replit/database/default_db.py @@ -1,41 +1,58 @@ """A module containing the default database.""" -from os import environ, path -import threading -from typing import Optional - +import os +import os.path +from typing import Any, Optional from .database import Database -def get_db_url() -> str: +def get_db_url() -> Optional[str]: """Fetches the most up-to-date db url from the Repl environment.""" # todo look into the security warning ignored below tmpdir = "/tmp/replitdb" # noqa: S108 - if path.exists(tmpdir): + if os.path.exists(tmpdir): with open(tmpdir, "r") as file: - db_url = file.read() - else: - db_url = environ.get("REPLIT_DB_URL") + return file.read() - return db_url + return os.environ.get("REPLIT_DB_URL") def refresh_db() -> None: - """Refresh the DB URL every hour.""" - global db + """Deprecated: refresh_db is now the responsibility of the Database instance.""" + pass + + +def _unbind() -> None: + global _db + _db = None + + +def _get_db() -> Optional[Database]: + global _db + if _db is not None: + return _db + db_url = get_db_url() - db.update_db_url(db_url) - threading.Timer(3600, refresh_db).start() + if db_url: + _db = Database(db_url, get_db_url=get_db_url, unbind=_unbind) + else: + # The user will see errors if they try to use the database. + print("Warning: error initializing database. Replit DB is not configured.") + _db = None + return _db + + +_db: Optional[Database] = None -db: Optional[Database] -db_url = get_db_url() -if db_url: - db = Database(db_url) -else: - # The user will see errors if they try to use the database. - print("Warning: error initializing database. Replit DB is not configured.") - db = None -if db: - refresh_db() +# Previous versions of this library would just have side-effects and always set +# up a database unconditionally. That is very undesirable, so instead of doing +# that, we are using this egregious hack to get the database / database URL +# lazily. +def __getattr__(name: str) -> Any: + if name == "db": + return _get_db() + if name == "db_url": + return get_db_url() + raise AttributeError(name) diff --git a/src/replit/database/server.py b/src/replit/database/server.py index 0149415f..6b70b30d 100644 --- a/src/replit/database/server.py +++ b/src/replit/database/server.py @@ -4,7 +4,7 @@ from flask import Blueprint, Flask, request -from .default_db import db +from . import default_db def make_database_proxy_blueprint(view_only: bool, prefix: str = "") -> Blueprint: @@ -20,10 +20,12 @@ def make_database_proxy_blueprint(view_only: bool, prefix: str = "") -> Blueprin app = Blueprint("database_proxy" + ("_view_only" if view_only else ""), __name__) def list_keys() -> Any: - user_prefix = request.args.get("prefix") + if default_db.db is None: + return "Database is not configured", 500 + user_prefix = request.args.get("prefix", "") encode = "encode" in request.args - keys = db.prefix(prefix=prefix + user_prefix) - keys = [k[len(prefix) :] for k in keys] + raw_keys = default_db.db.prefix(prefix=prefix + user_prefix) + keys = [k[len(prefix) :] for k in raw_keys] if encode: return "\n".join(quote(k) for k in keys) @@ -31,10 +33,12 @@ def list_keys() -> Any: return "\n".join(keys) def set_key() -> Any: + if default_db.db is None: + return "Database is not configured", 500 if view_only: return "Database is view only", 401 for k, v in request.form.items(): - db[prefix + k] = v + default_db.db[prefix + k] = v return "" @app.route("/", methods=["GET", "POST"]) @@ -44,16 +48,20 @@ def index() -> Any: return set_key() def get_key(key: str) -> Any: + if default_db.db is None: + return "Database is not configured", 500 try: - return db[prefix + key] + return default_db.db[prefix + key] except KeyError: return "", 404 def delete_key(key: str) -> Any: + if default_db.db is None: + return "Database is not configured", 500 if view_only: return "Database is view only", 401 try: - del db[prefix + key] + del default_db.db[prefix + key] except KeyError: return "", 404 return "" diff --git a/src/replit/web/__init__.py b/src/replit/web/__init__.py index 117b02f3..03c86f1b 100644 --- a/src/replit/web/__init__.py +++ b/src/replit/web/__init__.py @@ -9,6 +9,17 @@ from .app import debug, ReplitAuthContext, run from .user import User, UserStore from .utils import * -from ..database import AsyncDatabase, Database, db +from .. import database +from ..database import AsyncDatabase, Database auth = LocalProxy(lambda: ReplitAuthContext.from_headers(flask.request.headers)) + + +# Previous versions of this library would just have side-effects and always set +# up a database unconditionally. That is very undesirable, so instead of doing +# that, we are using this egregious hack to get the database / database URL +# lazily. +def __getattr__(name: str) -> Any: + if name == "db": + return database.db + raise AttributeError(name) diff --git a/src/replit/web/user.py b/src/replit/web/user.py index fa85ba87..133e6c37 100644 --- a/src/replit/web/user.py +++ b/src/replit/web/user.py @@ -5,9 +5,7 @@ import flask from .app import ReplitAuthContext -from ..database import Database, db as real_db - -db: Database = real_db # type: ignore +from .. import database class User(MutableMapping): @@ -31,15 +29,22 @@ def set_value(self, value: str) -> None: Args: value (str): The value to set in the database + + Raises: + RuntimeError: Raised if the database is not configured. """ - db[self.db_key()] = value + if database.db is None: + raise RuntimeError("database not configured") + database.db[self.db_key()] = value def _ensure_value(self) -> Any: + if database.db is None: + raise RuntimeError("database not configured") try: - return db[self.db_key()] + return database.db[self.db_key()] except KeyError: - db[self.db_key()] = {} - return db[self.db_key()] + database.db[self.db_key()] = {} + return database.db[self.db_key()] def set(self, key: str, val: Any) -> None: """Sets a key to a value for this user's entry in the database. @@ -103,7 +108,9 @@ def __getitem__(self, name: str) -> User: return User(username=name, prefix=self.prefix) def __iter__(self) -> Iterator[str]: - for k in db.keys(): + if database.db is None: + raise RuntimeError("database not configured") + for k in database.db.keys(): if k.startswith(self.prefix): yield self._strip_prefix(k)