Skip to content

Commit

Permalink
fix: Database is locked using RWLock (#817)
Browse files Browse the repository at this point in the history
* Vendor `readerwriterlock`

* Use rw lock to prevent db errors

* Use multiple read locks, not just one global one

* Use read lock in the DBConnection

* Refactor

* Edit docstring

* Raise `LockAcquisitionTimeoutError`

* Release write lock after detaching

* Remove sorting capability for `EditedAfterSyncColumn`

* Use lock in context manager

* Create `write_lock_context` and `read_lock_context`

* Use `detached_ankihub_db` in `errors.py`

* Extract timeout constant

* Edit docstrings

* Refactor

* Reduce log level

* Add tests

* Edit test

* Edit tests to make them work on all Anki versions

* Edit tests

* Lock on transaction level

* fix: Don't commit if there was an exception in the context

* Adjust code for the not-context manager-case

* Rollback explicitly

* Add test

* Add tests

* Add docstrings

* ref: Rename to `thread_safety_context_func`
  • Loading branch information
RisingOrange authored Dec 29, 2023
1 parent f95b9fb commit 2ace8f7
Show file tree
Hide file tree
Showing 16 changed files with 1,720 additions and 104 deletions.
3 changes: 1 addition & 2 deletions ankihub/db/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from .db import ( # noqa: F401
ankihub_db,
attach_ankihub_db_to_anki_db_connection,
attached_ankihub_db,
detach_ankihub_db_from_anki_db_connection,
detached_ankihub_db,
is_ankihub_db_attached_to_anki_db,
)
48 changes: 34 additions & 14 deletions ankihub/db/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,37 @@
from ..settings import ANKI_INT_VERSION, ANKI_VERSION_23_10_00
from .db_utils import DBConnection
from .exceptions import IntegrityError
from .rw_lock import exclusive_db_access_context, non_exclusive_db_access_context


def attach_ankihub_db_to_anki_db_connection() -> None:
@contextmanager
def attached_ankihub_db():
    """Attach the AnkiHub DB to the Anki DB connection for the duration of the context.

    While the context is active, the AnkiHub DB can be used in queries that go through
    the Anki DB connection (e.g. joins via aqt.mw.col.db.execute()). The DB is detached
    again when the context exits, also when the body raises.

    Exclusive DB access is held for the whole time, so that other threads do not access
    the AnkiHub DB through the _AnkiHubDB class while it is attached to the Anki DB.
    """
    with exclusive_db_access_context():
        # Attaching happens before the try block: if it fails, there is nothing to detach.
        _attach_ankihub_db_to_anki_db_connection()
        try:
            yield
        finally:
            # Always detach, even if the code inside the context raised.
            _detach_ankihub_db_from_anki_db_connection()


@contextmanager
def detached_ankihub_db():
    """Keep the AnkiHub DB detached from the Anki DB connection while the context is active.

    This makes it safe to perform operations that require the AnkiHub DB to be detached,
    for example copying the AnkiHub DB file. The _AnkiHubDB class also uses it, so that
    queries executed through that class run while the DB is detached.

    This function does not detach anything itself. It holds non-exclusive DB access,
    whereas attached_ankihub_db() needs exclusive access for as long as the DB is attached.
    """
    with non_exclusive_db_access_context():
        yield


def _attach_ankihub_db_to_anki_db_connection() -> None:
if aqt.mw.col is None:
LOGGER.info("The collection is not open. Not attaching AnkiHub DB.")
return
Expand All @@ -44,7 +72,7 @@ def attach_ankihub_db_to_anki_db_connection() -> None:
LOGGER.info("Attached AnkiHub DB to Anki DB connection")


def detach_ankihub_db_from_anki_db_connection() -> None:
def _detach_ankihub_db_from_anki_db_connection() -> None:
if aqt.mw.col is None:
LOGGER.info("The collection is not open. Not detaching AnkiHub DB.")
return
Expand Down Expand Up @@ -82,17 +110,7 @@ def is_ankihub_db_attached_to_anki_db() -> bool:
return result


@contextmanager
def attached_ankihub_db():
attach_ankihub_db_to_anki_db_connection()
try:
yield
finally:
detach_ankihub_db_from_anki_db_connection()


class _AnkiHubDB:

# name of the database when attached to the Anki DB connection
database_name = "ankihub_db"
database_path: Optional[Path] = None
Expand Down Expand Up @@ -199,7 +217,10 @@ def schema_version(self) -> int:
return result

def connection(self) -> DBConnection:
result = DBConnection(conn=sqlite3.connect(ankihub_db.database_path))
result = DBConnection(
conn=sqlite3.connect(ankihub_db.database_path),
thread_safety_context_func=detached_ankihub_db,
)
return result

def upsert_notes_data(
Expand All @@ -224,7 +245,6 @@ def upsert_notes_data(
upserted_notes: List[NoteInfo] = []
skipped_notes: List[NoteInfo] = []
with self.connection() as conn:

for note_data in notes_data:
conflicting_ah_nid = conn.first(
"""
Expand Down
23 changes: 21 additions & 2 deletions ankihub/db/db_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sqlite3
from typing import Any, List, Optional, Tuple
from typing import Any, Callable, ContextManager, List, Optional, Tuple

from .. import LOGGER

Expand All @@ -11,21 +11,37 @@ class DBConnection:
the context will be part of a single transaction. If an exception occurs within the context,
the transaction will be automatically rolled back.
thread_safety_context_func: A function that returns a context manager. This is used to ensure that
threads only access the database when it is safe to do so.
Note: Once a query has been executed using an instance of this class,
the instance cannot be used to execute another query unless it is within a context manager.
Attempting to do so will raise an exception.
"""

def __init__(self, conn: sqlite3.Connection):
def __init__(
    self,
    conn: sqlite3.Connection,
    thread_safety_context_func: Callable[[], ContextManager],
):
    """Store the connection and the factory for the thread-safety context.

    conn: The sqlite3 connection that queries are executed on.
    thread_safety_context_func: Zero-argument callable returning a context manager.
        It is called to obtain the context in which database access takes place.
    """
    self._conn = conn

    # Thread-safety state: the factory, and the context created from it by __enter__
    # (None until __enter__ has been called).
    self._thread_safety_context_func = thread_safety_context_func
    self._thread_safety_context: Optional[ContextManager] = None

    # Set to True while the instance is used in a `with` statement.
    self._is_used_as_context_manager = False

def execute(
    self,
    sql: str,
    *args,
    first_row_only=False,
) -> List:
    """Execute `sql` with `args` and return the result of `_execute_inner`.

    When the instance is used as a context manager, the thread-safety context has
    already been entered in `__enter__`, so the query is executed directly.
    Otherwise a thread-safety context is entered just for this single call.
    """
    if not self._is_used_as_context_manager:
        # Stand-alone call: guard only this query with its own context.
        with self._thread_safety_context_func():
            return self._execute_inner(sql, *args, first_row_only=first_row_only)

    return self._execute_inner(sql, *args, first_row_only=first_row_only)

def _execute_inner(self, sql: str, *args, first_row_only=False) -> List:
try:
cur = self._conn.cursor()
cur.execute(sql, args)
Expand Down Expand Up @@ -67,6 +83,8 @@ def first(self, sql: str, *args) -> Optional[Tuple]:
return None

def __enter__(self):
    """Enter a new thread-safety context and mark the instance as being used in a `with` block.

    Returns the instance itself, so it can be bound with `with ... as conn`.
    """
    context = self._thread_safety_context_func()
    # Store the context before entering it, so that it can be exited later on.
    self._thread_safety_context = context
    context.__enter__()

    self._is_used_as_context_manager = True
    return self

Expand All @@ -78,3 +96,4 @@ def __exit__(self, exc_type, exc_val, exc_tb):

self._conn.close()
self._is_used_as_context_manager = False
self._thread_safety_context.__exit__(exc_type, exc_val, exc_tb)
5 changes: 5 additions & 0 deletions ankihub/db/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,8 @@ class AnkiHubDBError(Exception):
class IntegrityError(AnkiHubDBError):
    """AnkiHubDBError subclass that is constructed with an error message."""

    def __init__(self, message):
        super().__init__(message)


class LockAcquisitionTimeoutError(AnkiHubDBError):
    """AnkiHubDBError subclass for a lock that could not be acquired; constructed with an error message."""

    def __init__(self, message):
        super().__init__(message)
44 changes: 44 additions & 0 deletions ankihub/db/rw_lock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""This module defines a readers-writer lock for the AnkiHub DB.
Multiple threads can enter the non_exclusive_db_access_context() context, but when
a thread enters the exclusive_db_access_context() context, no other thread can enter
either context until the thread that entered the exclusive_db_access_context() context
exits it.
"""
from contextlib import contextmanager

from readerwriterlock import rwlock

from .. import LOGGER
from .exceptions import LockAcquisitionTimeoutError

LOCK_TIMEOUT_SECONDS = 5

rw_lock = rwlock.RWLockFair()
write_lock = rw_lock.gen_wlock()


@contextmanager
def exclusive_db_access_context():
    """Hold the module-level `write_lock` while the context is active.

    Raises:
        LockAcquisitionTimeoutError: If `write_lock.acquire`, called with a timeout of
            LOCK_TIMEOUT_SECONDS, does not succeed.
    """
    acquired = write_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT_SECONDS)
    if not acquired:
        raise LockAcquisitionTimeoutError("Could not acquire exclusive access.")

    LOGGER.debug("Acquired exclusive access.")
    try:
        yield
    finally:
        # Release the lock even if the code inside the context raised.
        write_lock.release()
        LOGGER.debug("Released exclusive access.")


@contextmanager
def non_exclusive_db_access_context():
    """Hold a reader lock generated from `rw_lock` while the context is active.

    A new reader lock is generated on every call.

    Raises:
        LockAcquisitionTimeoutError: If acquiring the reader lock, with a timeout of
            LOCK_TIMEOUT_SECONDS, does not succeed.
    """
    read_lock = rw_lock.gen_rlock()
    acquired = read_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT_SECONDS)
    if not acquired:
        raise LockAcquisitionTimeoutError("Could not acquire non-exclusive access.")

    LOGGER.debug("Acquired non-exclusive access.")
    try:
        yield
    finally:
        # Release the lock even if the code inside the context raised.
        read_lock.release()
        LOGGER.debug("Released non-exclusive access.")
41 changes: 2 additions & 39 deletions ankihub/gui/browser/browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,7 @@

from ... import LOGGER
from ...ankihub_client import SuggestionType
from ...db import (
ankihub_db,
attach_ankihub_db_to_anki_db_connection,
attached_ankihub_db,
detach_ankihub_db_from_anki_db_connection,
)
from ...db import ankihub_db, attached_ankihub_db
from ...main.importing import get_fields_protected_by_tags
from ...main.note_conversion import (
TAG_FOR_PROTECTING_ALL_FIELDS,
Expand All @@ -55,7 +50,6 @@
from ..utils import ask_user, choose_ankihub_deck, choose_list, choose_subset
from .custom_columns import (
AnkiHubIdColumn,
CustomColumn,
EditedAfterSyncColumn,
UpdatedSinceLastReviewColumn,
)
Expand Down Expand Up @@ -512,7 +506,6 @@ def on_done(future: Future) -> None:


def _reset_optional_tag_group(extension_id: int) -> None:

extension_config = config.deck_extension_config(extension_id)
_remove_optional_tags_of_extension(extension_config)

Expand Down Expand Up @@ -566,30 +559,10 @@ def _on_browser_did_fetch_row(
# custom search nodes
def _setup_search():
browser_will_search.append(_on_browser_will_search)
browser_did_search.append(_on_browser_did_search)
browser_did_search.append(_on_browser_did_search_handle_custom_search_parameters)


def _on_browser_will_search(ctx: SearchContext):
_on_browser_will_search_handle_custom_column_ordering(ctx)
_on_browser_will_search_handle_custom_search_parameters(ctx)


def _on_browser_will_search_handle_custom_column_ordering(ctx: SearchContext):
if not isinstance(ctx.order, Column):
return

custom_column: CustomColumn = next(
(c for c in custom_columns if c.builtin_column.key == ctx.order.key), None
)
if custom_column is None:
return

attach_ankihub_db_to_anki_db_connection()

ctx.order = custom_column.order_by_str()


def _on_browser_will_search_handle_custom_search_parameters(ctx: SearchContext):
if not ctx.search:
return

Expand All @@ -615,15 +588,6 @@ def _on_browser_will_search_handle_custom_search_parameters(ctx: SearchContext):
ctx.search = ctx.search.replace(m.group(0), "")


def _on_browser_did_search(ctx: SearchContext):
# Detach the ankihub database in case it was attached in on_browser_will_search_handle_custom_column_ordering.
# The attached_ankihub_db context manager can't be used for this because the database query happens
# in the rust backend.
detach_ankihub_db_from_anki_db_connection()

_on_browser_did_search_handle_custom_search_parameters(ctx)


def _on_browser_did_search_handle_custom_search_parameters(ctx: SearchContext):
global custom_search_nodes

Expand Down Expand Up @@ -735,7 +699,6 @@ def _sidebar_item_descendants(item: SidebarItem) -> List[SidebarItem]:


def _add_ankihub_tree(tree: SidebarItem) -> SidebarItem:

result = tree.add_simple(
name="👑 AnkiHub",
icon="",
Expand Down
34 changes: 1 addition & 33 deletions ankihub/gui/browser/custom_columns.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
"""Custom Anki browser columns."""
import uuid
from abc import abstractmethod
from typing import Optional, Sequence
from typing import Sequence

import aqt
from anki.collection import BrowserColumns
from anki.notes import Note
from anki.utils import ids2str
from aqt.browser import Browser, CellRow, Column, ItemId

from ...db import ankihub_db
from ...main.utils import note_types_with_ankihub_id_field
from ...settings import ANKI_INT_VERSION, ANKI_VERSION_23_10_00


class CustomColumn:
Expand Down Expand Up @@ -49,11 +46,6 @@ def _display_value(
) -> str:
raise NotImplementedError

def order_by_str(self) -> Optional[str]:
"""Return the SQL string that will be appended after "ORDER BY" to the query that
fetches the search results when sorting by this column."""
return None


class AnkiHubIdColumn(CustomColumn):
builtin_column = Column(
Expand All @@ -79,21 +71,10 @@ def _display_value(

class EditedAfterSyncColumn(CustomColumn):
def __init__(self) -> None:
if ANKI_INT_VERSION >= ANKI_VERSION_23_10_00:
sorting_args = {
"sorting_cards": BrowserColumns.SORTING_DESCENDING,
"sorting_notes": BrowserColumns.SORTING_DESCENDING,
}
else:
sorting_args = {
"sorting": BrowserColumns.SORTING_DESCENDING,
}

self.builtin_column = Column(
key="edited_after_sync",
cards_mode_label="AnkiHub: Modified After Sync",
notes_mode_label="AnkiHub: Modified After Sync",
**sorting_args, # type: ignore
uses_cell_font=False,
alignment=BrowserColumns.ALIGNMENT_CENTER,
)
Expand All @@ -112,19 +93,6 @@ def _display_value(

return "Yes" if note.mod > last_sync else "No"

def order_by_str(self) -> str:
mids = note_types_with_ankihub_id_field()
if not mids:
return None

return f"""
(
SELECT n.mod > ah_n.mod from {ankihub_db.database_name}.notes AS ah_n
WHERE ah_n.anki_note_id = n.id LIMIT 1
) DESC,
(n.mid IN {ids2str(mids)}) DESC
"""


class UpdatedSinceLastReviewColumn(CustomColumn):
builtin_column = Column(
Expand Down
Loading

0 comments on commit 2ace8f7

Please sign in to comment.