From 2ace8f72292aeb435b203a4eb9b863ffb3bdb29b Mon Sep 17 00:00:00 2001 From: Jakub Fidler <31575114+RisingOrange@users.noreply.github.com> Date: Fri, 29 Dec 2023 20:17:21 +0100 Subject: [PATCH 1/6] fix: Database is locked using RWLock (#817) * Vendor `readerwriterlock` * Use rw lock to prevent db errors * Use multiple read locks, not just one global one * Use read lock in the DBConnection * Refactor * Edit docstring * Raise `LockAcquisitionTimeoutError` * Release write lock after detaching * Remove sorting capability for `EditedAfterSyncColumn` * Use lock in context manager * Create `write_lock_context` and `read_lock_context` * Use `detached_ankihub_db` in `errors.py` * Extract timeout constant * Edit docstrings * Refactor * Reduce log level * Add tests * Edit test * Edit tests to make them work on all Anki versions * Edit tests * Lock on transaction level * fix: Dont commit if there was an exception in the context * Adjust code for the not-context manager-case * Rollback explicitly * Add test * Add tests * Add docstrings * ref: Rename to `thread_safety_context_func` --- ankihub/db/__init__.py | 3 +- ankihub/db/db.py | 48 +- ankihub/db/db_utils.py | 23 +- ankihub/db/exceptions.py | 5 + ankihub/db/rw_lock.py | 44 + ankihub/gui/browser/browser.py | 41 +- ankihub/gui/browser/custom_columns.py | 34 +- ankihub/gui/errors.py | 17 +- ankihub/lib/readerwriterlock/AUTHORS.md | 12 + ankihub/lib/readerwriterlock/LICENSE.txt | 21 + ankihub/lib/readerwriterlock/__init__.py | 3 + ankihub/lib/readerwriterlock/py.typed | 1 + ankihub/lib/readerwriterlock/rwlock.py | 718 +++++++++++++++++ ankihub/lib/readerwriterlock/rwlock_async.py | 794 +++++++++++++++++++ pyproject.toml | 1 + tests/addon/test_unit.py | 59 +- 16 files changed, 1720 insertions(+), 104 deletions(-) create mode 100644 ankihub/db/rw_lock.py create mode 100644 ankihub/lib/readerwriterlock/AUTHORS.md create mode 100644 ankihub/lib/readerwriterlock/LICENSE.txt create mode 100644 ankihub/lib/readerwriterlock/__init__.py create mode 100644 ankihub/lib/readerwriterlock/py.typed create mode 100644 ankihub/lib/readerwriterlock/rwlock.py create mode 100644 ankihub/lib/readerwriterlock/rwlock_async.py diff --git a/ankihub/db/__init__.py b/ankihub/db/__init__.py index 96d4cffcd..c0ca29048 100644 --- a/ankihub/db/__init__.py +++ b/ankihub/db/__init__.py @@ -1,7 +1,6 @@ from .db import ( # noqa: F401 ankihub_db, - attach_ankihub_db_to_anki_db_connection, attached_ankihub_db, - detach_ankihub_db_from_anki_db_connection, + detached_ankihub_db, is_ankihub_db_attached_to_anki_db, ) diff --git a/ankihub/db/db.py b/ankihub/db/db.py index 9ab39c54d..d28901c34 100644 --- a/ankihub/db/db.py +++ b/ankihub/db/db.py @@ -29,9 +29,37 @@ from ..settings import ANKI_INT_VERSION, ANKI_VERSION_23_10_00 from .db_utils import DBConnection from .exceptions import IntegrityError +from .rw_lock import exclusive_db_access_context, non_exclusive_db_access_context -def attach_ankihub_db_to_anki_db_connection() -> None: +@contextmanager +def attached_ankihub_db(): + """Context manager that attaches the AnkiHub DB to the Anki DB connection and detaches it when the context exits. + The purpose is to e.g. do join queries between the Anki DB and the AnkiHub DB through aqt.mw.col.db.execute(). + A lock is used to ensure that other threads don't try to access the AnkiHub DB through the _AnkiHubDB class + while it is attached to the Anki DB. + """ + with exclusive_db_access_context(): + _attach_ankihub_db_to_anki_db_connection() + try: + yield + finally: + _detach_ankihub_db_from_anki_db_connection() + + +@contextmanager +def detached_ankihub_db(): + """Context manager that ensures the AnkiHub DB is detached from the Anki DB connection while the context is active. + The purpose of this is to be able to safely perform operations on the AnkiHub DB which require it to be detached, + for example coyping the AnkiHub DB file. + It's used by the _AnkiHubDB class to ensure that the AnkiHub DB is detached from the Anki DB while + queries are executed through the _AnkiHubDB class. + """ + with non_exclusive_db_access_context(): + yield + + +def _attach_ankihub_db_to_anki_db_connection() -> None: if aqt.mw.col is None: LOGGER.info("The collection is not open. Not attaching AnkiHub DB.") return @@ -44,7 +72,7 @@ def attach_ankihub_db_to_anki_db_connection() -> None: LOGGER.info("Attached AnkiHub DB to Anki DB connection") -def detach_ankihub_db_from_anki_db_connection() -> None: +def _detach_ankihub_db_from_anki_db_connection() -> None: if aqt.mw.col is None: LOGGER.info("The collection is not open. Not detaching AnkiHub DB.") return @@ -82,17 +110,7 @@ def is_ankihub_db_attached_to_anki_db() -> bool: return result -@contextmanager -def attached_ankihub_db(): - attach_ankihub_db_to_anki_db_connection() - try: - yield - finally: - detach_ankihub_db_from_anki_db_connection() - - class _AnkiHubDB: - # name of the database when attached to the Anki DB connection database_name = "ankihub_db" database_path: Optional[Path] = None @@ -199,7 +217,10 @@ def schema_version(self) -> int: return result def connection(self) -> DBConnection: - result = DBConnection(conn=sqlite3.connect(ankihub_db.database_path)) + result = DBConnection( + conn=sqlite3.connect(ankihub_db.database_path), + thread_safety_context_func=detached_ankihub_db, + ) return result def upsert_notes_data( @@ -224,7 +245,6 @@ def upsert_notes_data( upserted_notes: List[NoteInfo] = [] skipped_notes: List[NoteInfo] = [] with self.connection() as conn: - for note_data in notes_data: conflicting_ah_nid = conn.first( """ diff --git a/ankihub/db/db_utils.py b/ankihub/db/db_utils.py index 9f6620cda..c5499482b 100644 --- a/ankihub/db/db_utils.py +++ b/ankihub/db/db_utils.py @@ -1,5 +1,5 @@ import sqlite3 -from typing import Any, List, Optional, Tuple +from typing import Any, Callable, ContextManager, List, Optional, Tuple from .. import LOGGER @@ -11,14 +11,23 @@ class DBConnection: the context will be part of a single transaction. If an exception occurs within the context, the transaction will be automatically rolled back. + thread_safety_context_func: A function that returns a context manager. This is used to ensure that + threads only access the database when it is safe to do so. + Note: Once a query has been executed using an instance of this class, the instance cannot be used to execute another query unless it is within a context manager. Attempting to do so will raise an exception. """ - def __init__(self, conn: sqlite3.Connection): + def __init__( + self, + conn: sqlite3.Connection, + thread_safety_context_func: Callable[[], ContextManager], + ): self._conn = conn self._is_used_as_context_manager = False + self._thread_safety_context_func = thread_safety_context_func + self._thread_safety_context: Optional[ContextManager] = None def execute( self, @@ -26,6 +35,13 @@ def execute( *args, first_row_only=False, ) -> List: + if self._is_used_as_context_manager: + return self._execute_inner(sql, *args, first_row_only=first_row_only) + else: + with self._thread_safety_context_func(): + return self._execute_inner(sql, *args, first_row_only=first_row_only) + + def _execute_inner(self, sql: str, *args, first_row_only=False) -> List: try: cur = self._conn.cursor() cur.execute(sql, args) @@ -67,6 +83,8 @@ def first(self, sql: str, *args) -> Optional[Tuple]: return None def __enter__(self): + self._thread_safety_context = self._thread_safety_context_func() + self._thread_safety_context.__enter__() self._is_used_as_context_manager = True return self @@ -78,3 +96,4 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._conn.close() self._is_used_as_context_manager = False + self._thread_safety_context.__exit__(exc_type, exc_val, exc_tb) diff --git a/ankihub/db/exceptions.py b/ankihub/db/exceptions.py index 007ca9996..2aaf95484 100644 --- a/ankihub/db/exceptions.py +++ b/ankihub/db/exceptions.py @@ -5,3 +5,8 @@ class AnkiHubDBError(Exception): class IntegrityError(AnkiHubDBError): def __init__(self, message): super().__init__(message) + + +class LockAcquisitionTimeoutError(AnkiHubDBError): + def __init__(self, message): + super().__init__(message) diff --git a/ankihub/db/rw_lock.py b/ankihub/db/rw_lock.py new file mode 100644 index 000000000..30771b20e --- /dev/null +++ b/ankihub/db/rw_lock.py @@ -0,0 +1,44 @@ +"""This module defines a readers-writer lock for the AnkiHub DB. +Multiple threads can enter the non_exclusive_db_access_context() context, but when +a thread enters the exclusive_db_access_context() context, no other thread can enter +either context until the thread that entered the exclusive_db_access_context() context +exits it. +""" +from contextlib import contextmanager + +from readerwriterlock import rwlock + +from .. import LOGGER +from .exceptions import LockAcquisitionTimeoutError + +LOCK_TIMEOUT_SECONDS = 5 + +rw_lock = rwlock.RWLockFair() +write_lock = rw_lock.gen_wlock() + + +@contextmanager +def exclusive_db_access_context(): + if write_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT_SECONDS): + LOGGER.debug("Acquired exclusive access.") + try: + yield + finally: + write_lock.release() + LOGGER.debug("Released exclusive access.") + else: + raise LockAcquisitionTimeoutError("Could not acquire exclusive access.") + + +@contextmanager +def non_exclusive_db_access_context(): + lock = rw_lock.gen_rlock() + if lock.acquire(blocking=True, timeout=LOCK_TIMEOUT_SECONDS): + LOGGER.debug("Acquired non-exclusive access.") + try: + yield + finally: + lock.release() + LOGGER.debug("Released non-exclusive access.") + else: + raise LockAcquisitionTimeoutError("Could not acquire non-exclusive access.") diff --git a/ankihub/gui/browser/browser.py b/ankihub/gui/browser/browser.py index 32ba2c53a..2e3913dd9 100644 --- a/ankihub/gui/browser/browser.py +++ b/ankihub/gui/browser/browser.py @@ -33,12 +33,7 @@ from ... import LOGGER from ...ankihub_client import SuggestionType -from ...db import ( - ankihub_db, - attach_ankihub_db_to_anki_db_connection, - attached_ankihub_db, - detach_ankihub_db_from_anki_db_connection, -) +from ...db import ankihub_db, attached_ankihub_db from ...main.importing import get_fields_protected_by_tags from ...main.note_conversion import ( TAG_FOR_PROTECTING_ALL_FIELDS, @@ -55,7 +50,6 @@ from ..utils import ask_user, choose_ankihub_deck, choose_list, choose_subset from .custom_columns import ( AnkiHubIdColumn, - CustomColumn, EditedAfterSyncColumn, UpdatedSinceLastReviewColumn, ) @@ -512,7 +506,6 @@ def on_done(future: Future) -> None: def _reset_optional_tag_group(extension_id: int) -> None: - extension_config = config.deck_extension_config(extension_id) _remove_optional_tags_of_extension(extension_config) @@ -566,30 +559,10 @@ def _on_browser_did_fetch_row( # cutom search nodes def _setup_search(): browser_will_search.append(_on_browser_will_search) - browser_did_search.append(_on_browser_did_search) + browser_did_search.append(_on_browser_did_search_handle_custom_search_parameters) def _on_browser_will_search(ctx: SearchContext): - _on_browser_will_search_handle_custom_column_ordering(ctx) - _on_browser_will_search_handle_custom_search_parameters(ctx) - - -def _on_browser_will_search_handle_custom_column_ordering(ctx: SearchContext): - if not isinstance(ctx.order, Column): - return - - custom_column: CustomColumn = next( - (c for c in custom_columns if c.builtin_column.key == ctx.order.key), None - ) - if custom_column is None: - return - - attach_ankihub_db_to_anki_db_connection() - - ctx.order = custom_column.order_by_str() - - -def _on_browser_will_search_handle_custom_search_parameters(ctx: SearchContext): if not ctx.search: return @@ -615,15 +588,6 @@ def _on_browser_will_search_handle_custom_search_parameters(ctx: SearchContext): ctx.search = ctx.search.replace(m.group(0), "") -def _on_browser_did_search(ctx: SearchContext): - # Detach the ankihub database in case it was attached in on_browser_will_search_handle_custom_column_ordering. - # The attached_ankihub_db context manager can't be used for this because the database query happens - # in the rust backend. - detach_ankihub_db_from_anki_db_connection() - - _on_browser_did_search_handle_custom_search_parameters(ctx) - - def _on_browser_did_search_handle_custom_search_parameters(ctx: SearchContext): global custom_search_nodes @@ -735,7 +699,6 @@ def _sidebar_item_descendants(item: SidebarItem) -> List[SidebarItem]: def _add_ankihub_tree(tree: SidebarItem) -> SidebarItem: - result = tree.add_simple( name="👑 AnkiHub", icon="", diff --git a/ankihub/gui/browser/custom_columns.py b/ankihub/gui/browser/custom_columns.py index 9ad417189..9058b7180 100644 --- a/ankihub/gui/browser/custom_columns.py +++ b/ankihub/gui/browser/custom_columns.py @@ -1,17 +1,14 @@ """Custom Anki browser columns.""" import uuid from abc import abstractmethod -from typing import Optional, Sequence +from typing import Sequence import aqt from anki.collection import BrowserColumns from anki.notes import Note -from anki.utils import ids2str from aqt.browser import Browser, CellRow, Column, ItemId from ...db import ankihub_db -from ...main.utils import note_types_with_ankihub_id_field -from ...settings import ANKI_INT_VERSION, ANKI_VERSION_23_10_00 class CustomColumn: @@ -49,11 +46,6 @@ def _display_value( ) -> str: raise NotImplementedError - def order_by_str(self) -> Optional[str]: - """Return the SQL string that will be appended after "ORDER BY" to the query that - fetches the search results when sorting by this column.""" - return None - class AnkiHubIdColumn(CustomColumn): builtin_column = Column( @@ -79,21 +71,10 @@ def _display_value( class EditedAfterSyncColumn(CustomColumn): def __init__(self) -> None: - if ANKI_INT_VERSION >= ANKI_VERSION_23_10_00: - sorting_args = { - "sorting_cards": BrowserColumns.SORTING_DESCENDING, - "sorting_notes": BrowserColumns.SORTING_DESCENDING, - } - else: - sorting_args = { - "sorting": BrowserColumns.SORTING_DESCENDING, - } - self.builtin_column = Column( key="edited_after_sync", cards_mode_label="AnkiHub: Modified After Sync", notes_mode_label="AnkiHub: Modified After Sync", - **sorting_args, # type: ignore uses_cell_font=False, alignment=BrowserColumns.ALIGNMENT_CENTER, ) @@ -112,19 +93,6 @@ def _display_value( return "Yes" if note.mod > last_sync else "No" - def order_by_str(self) -> str: - mids = note_types_with_ankihub_id_field() - if not mids: - return None - - return f""" - ( - SELECT n.mod > ah_n.mod from {ankihub_db.database_name}.notes AS ah_n - WHERE ah_n.anki_note_id = n.id LIMIT 1 - ) DESC, - (n.mid IN {ids2str(mids)}) DESC - """ - class UpdatedSinceLastReviewColumn(CustomColumn): builtin_column = Column( diff --git a/ankihub/gui/errors.py b/ankihub/gui/errors.py index 9a43a9b7f..fd4a74504 100644 --- a/ankihub/gui/errors.py +++ b/ankihub/gui/errors.py @@ -30,10 +30,7 @@ from .. import LOGGER from ..addon_ankihub_client import AddonAnkiHubClient as AnkiHubClient from ..ankihub_client import AnkiHubHTTPError, AnkiHubRequestException -from ..db import ( - detach_ankihub_db_from_anki_db_connection, - is_ankihub_db_attached_to_anki_db, -) +from ..db import detached_ankihub_db, is_ankihub_db_attached_to_anki_db from ..gui.exceptions import DeckDownloadAndInstallError from ..settings import ( ADDON_VERSION, @@ -135,9 +132,6 @@ def upload_logs_and_data_in_background( def _upload_logs_and_data_in_background(key: str) -> str: - # detach the ankihub database from the anki database connection to prevent file permission errors - detach_ankihub_db_from_anki_db_connection() - file_path = _zip_logs_and_data() # upload the zip file @@ -159,10 +153,11 @@ def _zip_logs_and_data() -> Path: temp_file = tempfile.NamedTemporaryFile(delete=False) temp_file.close() with zipfile.ZipFile(temp_file.name, "w") as zipf: - # Add the ankihub base directory to the zip. It also contains the logs. - source_dir = ankihub_base_path() - for file in source_dir.rglob("*"): - zipf.write(file, arcname=file.relative_to(source_dir)) + with detached_ankihub_db(): + # Add the ankihub base directory to the zip. It also contains the logs. + source_dir = ankihub_base_path() + for file in source_dir.rglob("*"): + zipf.write(file, arcname=file.relative_to(source_dir)) # Add the Anki collection to the zip. try: diff --git a/ankihub/lib/readerwriterlock/AUTHORS.md b/ankihub/lib/readerwriterlock/AUTHORS.md new file mode 100644 index 000000000..f555921c2 --- /dev/null +++ b/ankihub/lib/readerwriterlock/AUTHORS.md @@ -0,0 +1,12 @@ +Author +====== +Éric Larivière + +Contributors +------------ + +**Thank you to every contributor** + + +- Justin Patrin +- Mike Merrill diff --git a/ankihub/lib/readerwriterlock/LICENSE.txt b/ankihub/lib/readerwriterlock/LICENSE.txt new file mode 100644 index 000000000..96f8ed14e --- /dev/null +++ b/ankihub/lib/readerwriterlock/LICENSE.txt @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2018 Éric Larivière + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/ankihub/lib/readerwriterlock/__init__.py b/ankihub/lib/readerwriterlock/__init__.py new file mode 100644 index 000000000..3da1778dc --- /dev/null +++ b/ankihub/lib/readerwriterlock/__init__.py @@ -0,0 +1,3 @@ +"""Reader writer locks.""" + +__all__ = ["rwlock"] diff --git a/ankihub/lib/readerwriterlock/py.typed b/ankihub/lib/readerwriterlock/py.typed new file mode 100644 index 000000000..755a5649f --- /dev/null +++ b/ankihub/lib/readerwriterlock/py.typed @@ -0,0 +1 @@ +# Marker file for PEP 561. The readerwriterlock package uses inline types. diff --git a/ankihub/lib/readerwriterlock/rwlock.py b/ankihub/lib/readerwriterlock/rwlock.py new file mode 100644 index 000000000..a70509eac --- /dev/null +++ b/ankihub/lib/readerwriterlock/rwlock.py @@ -0,0 +1,718 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +"""Read Write Lock.""" + +import threading +import sys +import time + +from typing import Callable +from typing import Optional +from typing import Type +from types import TracebackType +from typing_extensions import Protocol +from typing_extensions import runtime_checkable + +try: + threading.Lock().release() +except BaseException as exc: + RELEASE_ERR_CLS = type(exc) # pylint: disable=invalid-name + RELEASE_ERR_MSG = str(exc) +else: + raise AssertionError() # pragma: no cover + + +@runtime_checkable +class Lockable(Protocol): + """Lockable. Compatible with threading.Lock interface.""" + + def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + """Acquire a lock.""" + raise AssertionError("Should be overriden") # Will be overriden. # pragma: no cover + + def release(self) -> None: + """Release the lock.""" + raise AssertionError("Should be overriden") # Will be overriden. # pragma: no cover + + def locked(self) -> bool: + """Answer to 'is it currently locked?'.""" + raise AssertionError("Should be overriden") # Will be overriden. # pragma: no cover + + def __enter__(self) -> bool: + """Enter context manager.""" + self.acquire() + return False + + def __exit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[Exception], exc_tb: Optional[TracebackType]) -> Optional[bool]: # type: ignore + """Exit context manager.""" + self.release() + return False + + +@runtime_checkable +class LockableD(Lockable, Protocol): + """Lockable Downgradable.""" + + def downgrade(self) -> Lockable: + """Downgrade.""" + raise AssertionError("Should be overriden") # Will be overriden. # pragma: no cover + + +class _ThreadSafeInt(): + """Internal thread safe integer like object. + + Implements only the bare minimum features for the RWLock implementation's need. + """ + + def __init__(self, initial_value: int, lock_factory: Callable[[], Lockable] = threading.Lock) -> None: + """Init.""" + self.__value_lock = lock_factory() + self.__value: int = initial_value + + def __int__(self) -> int: + """Get int value.""" + return self.__value + + def __eq__(self, other) -> bool: + """Self == other.""" + return int(self) == int(other) + + def increment(self) -> None: + """Increment value by one.""" + with self.__value_lock: + self.__value += 1 + + def decrement(self) -> None: + """Decrement value by one.""" + with self.__value_lock: + self.__value -= 1 + + +@runtime_checkable +class RWLockable(Protocol): + """Read/write lock.""" + + def gen_rlock(self) -> Lockable: + """Generate a reader lock.""" + raise AssertionError("Should be overriden") # Will be overriden. # pragma: no cover + + def gen_wlock(self) -> Lockable: + """Generate a writer lock.""" + raise AssertionError("Should be overriden") # Will be overriden. # pragma: no cover + + +@runtime_checkable +class RWLockableD(Protocol): + """Read/write lock Downgradable.""" + + def gen_rlock(self) -> Lockable: + """Generate a reader lock.""" + raise AssertionError("Should be overriden") # Will be overriden. # pragma: no cover + + def gen_wlock(self) -> LockableD: + """Generate a writer lock.""" + raise AssertionError("Should be overriden") # Will be overriden. # pragma: no cover + + +class RWLockRead(RWLockable): + """A Read/Write lock giving preference to Reader.""" + + def __init__(self, lock_factory: Callable[[], Lockable] = threading.Lock, time_source: Callable[[], float] = time.perf_counter) -> None: + """Init.""" + self.v_read_count: int = 0 + self.c_time_source = time_source + self.c_resource = lock_factory() + self.c_lock_read_count = lock_factory() + + class _aReader(Lockable): + def __init__(self, p_RWLock: "RWLockRead") -> None: + self.c_rw_lock = p_RWLock + self.v_locked: bool = False + + def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + """Acquire a lock.""" + p_timeout: Optional[float] = None if (blocking and timeout < 0) else (timeout if blocking else 0) + c_deadline: Optional[float] = None if p_timeout is None else (self.c_rw_lock.c_time_source() + p_timeout) + if not self.c_rw_lock.c_lock_read_count.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + return False + self.c_rw_lock.v_read_count += 1 + if 1 == self.c_rw_lock.v_read_count: + if not self.c_rw_lock.c_resource.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + self.c_rw_lock.v_read_count -= 1 + self.c_rw_lock.c_lock_read_count.release() + return False + self.c_rw_lock.c_lock_read_count.release() + self.v_locked = True + return True + + def release(self) -> None: + """Release the lock.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + self.v_locked = False + self.c_rw_lock.c_lock_read_count.acquire() + self.c_rw_lock.v_read_count -= 1 + if 0 == self.c_rw_lock.v_read_count: + self.c_rw_lock.c_resource.release() + self.c_rw_lock.c_lock_read_count.release() + + def locked(self) -> bool: + """Answer to 'is it currently locked?'.""" + return self.v_locked + + class _aWriter(Lockable): + def __init__(self, p_RWLock: "RWLockRead") -> None: + self.c_rw_lock = p_RWLock + self.v_locked: bool = False + + def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + """Acquire a lock.""" + locked: bool = self.c_rw_lock.c_resource.acquire(blocking, timeout) + self.v_locked = locked + return locked + + def release(self) -> None: + """Release the lock.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + self.v_locked = False + self.c_rw_lock.c_resource.release() + + def locked(self) -> bool: + """Answer to 'is it currently locked?'.""" + return self.v_locked + + def gen_rlock(self) -> "RWLockRead._aReader": + """Generate a reader lock.""" + return RWLockRead._aReader(self) + + def gen_wlock(self) -> "RWLockRead._aWriter": + """Generate a writer lock.""" + return RWLockRead._aWriter(self) + + +class RWLockWrite(RWLockable): + """A Read/Write lock giving preference to Writer.""" + + def __init__(self, lock_factory: Callable[[], Lockable] = threading.Lock, time_source: Callable[[], float] = time.perf_counter) -> None: + """Init.""" + self.v_read_count: int = 0 + self.v_write_count: int = 0 + self.c_time_source = time_source + self.c_lock_read_count = lock_factory() + self.c_lock_write_count = lock_factory() + self.c_lock_read_entry = lock_factory() + self.c_lock_read_try = lock_factory() + self.c_resource = lock_factory() + + class _aReader(Lockable): + def __init__(self, p_RWLock: "RWLockWrite") -> None: + self.c_rw_lock = p_RWLock + self.v_locked: bool = False + + def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + """Acquire a lock.""" + p_timeout = None if (blocking and timeout < 0) else (timeout if blocking else 0) + c_deadline = None if p_timeout is None else (self.c_rw_lock.c_time_source() + p_timeout) + if not self.c_rw_lock.c_lock_read_entry.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + return False + if not self.c_rw_lock.c_lock_read_try.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + self.c_rw_lock.c_lock_read_entry.release() + return False + if not self.c_rw_lock.c_lock_read_count.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + self.c_rw_lock.c_lock_read_try.release() + self.c_rw_lock.c_lock_read_entry.release() + return False + self.c_rw_lock.v_read_count += 1 + if 1 == self.c_rw_lock.v_read_count: + if not self.c_rw_lock.c_resource.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + self.c_rw_lock.c_lock_read_try.release() + self.c_rw_lock.c_lock_read_entry.release() + self.c_rw_lock.v_read_count -= 1 + self.c_rw_lock.c_lock_read_count.release() + return False + self.c_rw_lock.c_lock_read_count.release() + self.c_rw_lock.c_lock_read_try.release() + self.c_rw_lock.c_lock_read_entry.release() + self.v_locked = True + return True + + def release(self) -> None: + """Release the lock.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + self.v_locked = False + self.c_rw_lock.c_lock_read_count.acquire() + self.c_rw_lock.v_read_count -= 1 + if 0 == self.c_rw_lock.v_read_count: + self.c_rw_lock.c_resource.release() + self.c_rw_lock.c_lock_read_count.release() + + def locked(self) -> bool: + """Answer to 'is it currently locked?'.""" + return self.v_locked + + class _aWriter(Lockable): + def __init__(self, p_RWLock: "RWLockWrite") -> None: + self.c_rw_lock = p_RWLock + self.v_locked: bool = False + + def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + """Acquire a lock.""" + p_timeout = None if (blocking and timeout < 0) else (timeout if blocking else 0) + c_deadline = None if p_timeout is None else (self.c_rw_lock.c_time_source() + p_timeout) + if not self.c_rw_lock.c_lock_write_count.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + return False + self.c_rw_lock.v_write_count += 1 + if 1 == self.c_rw_lock.v_write_count: + if not self.c_rw_lock.c_lock_read_try.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + self.c_rw_lock.v_write_count -= 1 + self.c_rw_lock.c_lock_write_count.release() + return False + self.c_rw_lock.c_lock_write_count.release() + if not self.c_rw_lock.c_resource.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + self.c_rw_lock.c_lock_write_count.acquire() + self.c_rw_lock.v_write_count -= 1 + if 0 == self.c_rw_lock.v_write_count: + self.c_rw_lock.c_lock_read_try.release() + self.c_rw_lock.c_lock_write_count.release() + return False + self.v_locked = True + return True + + def release(self) -> None: + """Release the lock.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + self.v_locked = False + self.c_rw_lock.c_resource.release() + self.c_rw_lock.c_lock_write_count.acquire() + self.c_rw_lock.v_write_count -= 1 + if 0 == self.c_rw_lock.v_write_count: + self.c_rw_lock.c_lock_read_try.release() + self.c_rw_lock.c_lock_write_count.release() + + def locked(self) -> bool: + """Answer to 'is it currently locked?'.""" + return self.v_locked + + def gen_rlock(self) -> "RWLockWrite._aReader": + """Generate a reader lock.""" + return RWLockWrite._aReader(self) + + def gen_wlock(self) -> "RWLockWrite._aWriter": + """Generate a writer lock.""" + return RWLockWrite._aWriter(self) + + +class RWLockFair(RWLockable): + """A Read/Write lock giving fairness to both Reader and Writer.""" + + def __init__(self, lock_factory: Callable[[], Lockable] = threading.Lock, time_source: Callable[[], float] = time.perf_counter) -> None: + """Init.""" + self.v_read_count: int = 0 + self.c_time_source = time_source + self.c_lock_read_count = lock_factory() + self.c_lock_read = lock_factory() + self.c_lock_write = lock_factory() + + class _aReader(Lockable): + def __init__(self, p_RWLock: "RWLockFair") -> None: + self.c_rw_lock = p_RWLock + self.v_locked: bool = False + + def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + """Acquire a lock.""" + p_timeout = None if (blocking and timeout < 0) else (timeout if blocking else 0) + c_deadline = None if p_timeout is None else (self.c_rw_lock.c_time_source() + p_timeout) + if not self.c_rw_lock.c_lock_read.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + return False + if not self.c_rw_lock.c_lock_read_count.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + self.c_rw_lock.c_lock_read.release() + return False + self.c_rw_lock.v_read_count += 1 + if 1 == self.c_rw_lock.v_read_count: + if not self.c_rw_lock.c_lock_write.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + self.c_rw_lock.v_read_count -= 1 + self.c_rw_lock.c_lock_read_count.release() + self.c_rw_lock.c_lock_read.release() + return False + self.c_rw_lock.c_lock_read_count.release() + self.c_rw_lock.c_lock_read.release() + self.v_locked = True + return True + + def release(self) -> None: + """Release the lock.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + self.v_locked = False + self.c_rw_lock.c_lock_read_count.acquire() + self.c_rw_lock.v_read_count -= 1 + if 0 == self.c_rw_lock.v_read_count: + self.c_rw_lock.c_lock_write.release() + self.c_rw_lock.c_lock_read_count.release() + + def locked(self) -> bool: + """Answer to 'is it currently locked?'.""" + return self.v_locked + + class _aWriter(Lockable): + def __init__(self, p_RWLock: "RWLockFair") -> None: + self.c_rw_lock = p_RWLock + self.v_locked: bool = False + + def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + """Acquire a lock.""" + p_timeout = None if (blocking and timeout < 0) else (timeout if blocking else 0) + c_deadline = None if p_timeout is None else (self.c_rw_lock.c_time_source() + p_timeout) + if not self.c_rw_lock.c_lock_read.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + return False + if not self.c_rw_lock.c_lock_write.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + self.c_rw_lock.c_lock_read.release() + return False + self.v_locked = True + return True + + def release(self) -> None: + """Release the lock.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + self.v_locked = False + self.c_rw_lock.c_lock_write.release() + self.c_rw_lock.c_lock_read.release() + + def locked(self) -> bool: + """Answer to 'is it currently locked?'.""" + return self.v_locked + + def gen_rlock(self) -> "RWLockFair._aReader": + """Generate a reader lock.""" + return RWLockFair._aReader(self) + + def gen_wlock(self) -> "RWLockFair._aWriter": + """Generate a writer lock.""" + return RWLockFair._aWriter(self) + + +class RWLockReadD(RWLockableD): + """A Read/Write lock giving preference to Reader.""" + + def __init__(self, lock_factory: Callable[[], Lockable] = threading.Lock, time_source: Callable[[], float] = time.perf_counter) -> None: + """Init.""" + self.v_read_count: _ThreadSafeInt = _ThreadSafeInt(initial_value=0, lock_factory=lock_factory) + self.c_time_source = time_source + self.c_resource = lock_factory() + self.c_lock_read_count = lock_factory() + + class _aReader(Lockable): + def __init__(self, p_RWLock: "RWLockReadD") -> None: + self.c_rw_lock = p_RWLock + self.v_locked: bool = False + + def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + """Acquire a lock.""" + p_timeout: Optional[float] = None if (blocking and timeout < 0) else (timeout if blocking else 0) + c_deadline: Optional[float] = None if p_timeout is None else (self.c_rw_lock.c_time_source() + p_timeout) + if not self.c_rw_lock.c_lock_read_count.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + return False + self.c_rw_lock.v_read_count.increment() + if 1 == int(self.c_rw_lock.v_read_count): + if not self.c_rw_lock.c_resource.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + self.c_rw_lock.v_read_count.decrement() + self.c_rw_lock.c_lock_read_count.release() + return False + self.c_rw_lock.c_lock_read_count.release() + self.v_locked = True + return True + + def release(self) -> None: + """Release the lock.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + self.v_locked = False + self.c_rw_lock.c_lock_read_count.acquire() + self.c_rw_lock.v_read_count.decrement() + if 0 == int(self.c_rw_lock.v_read_count): + self.c_rw_lock.c_resource.release() + self.c_rw_lock.c_lock_read_count.release() + + def locked(self) -> bool: + """Answer to 'is it currently locked?'.""" + return self.v_locked + + class _aWriter(LockableD): + def __init__(self, p_RWLock: "RWLockReadD") -> None: + self.c_rw_lock = p_RWLock + self.v_locked: bool = False + + def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + """Acquire a lock.""" + locked: bool = self.c_rw_lock.c_resource.acquire(blocking, timeout) + self.v_locked = locked + return locked + + def downgrade(self) -> Lockable: + """Downgrade.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + + result = self.c_rw_lock.gen_rlock() + + wait_blocking: bool = True + + def lock_result() -> None: + nonlocal wait_blocking + wait_blocking = False + result.acquire() # This is a blocking action + + threading.Thread(group=None, target=lock_result, name="RWLockReadD_Downgrade", daemon=False).start() + while wait_blocking: # Busy wait for the thread to be almost in its blocking state. + time.sleep(sys.float_info.min) + + for _ in range(123): time.sleep(sys.float_info.min) # Heuristic sleep delay to leave some extra time for the thread to block. + + self.release() # Open the gate! the current RW lock strategy gives priority to reader, therefore the result will acquire lock before any other writer lock. + + while not result.locked(): + time.sleep(sys.float_info.min) # Busy wait for the threads to complete their tasks. + return result + + def release(self) -> None: + """Release the lock.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + self.v_locked = False + self.c_rw_lock.c_resource.release() + + def locked(self) -> bool: + """Answer to 'is it currently locked?'.""" + return self.v_locked + + def gen_rlock(self) -> "RWLockReadD._aReader": + """Generate a reader lock.""" + return RWLockReadD._aReader(self) + + def gen_wlock(self) -> "RWLockReadD._aWriter": + """Generate a writer lock.""" + return RWLockReadD._aWriter(self) + + +class RWLockWriteD(RWLockableD): + """A Read/Write lock giving preference to Writer.""" + + def __init__(self, lock_factory: Callable[[], Lockable] = threading.Lock, time_source: Callable[[], float] = time.perf_counter) -> None: + """Init.""" + self.v_read_count: _ThreadSafeInt = _ThreadSafeInt(lock_factory=lock_factory, initial_value=0) + self.v_write_count: int = 0 + self.c_time_source = time_source + self.c_lock_read_count = lock_factory() + self.c_lock_write_count = lock_factory() + self.c_lock_read_entry = lock_factory() + self.c_lock_read_try = lock_factory() + self.c_resource = lock_factory() + + class _aReader(Lockable): + def __init__(self, p_RWLock: "RWLockWriteD") -> None: + self.c_rw_lock = p_RWLock + self.v_locked: bool = False + + def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + """Acquire a lock.""" + p_timeout = None if (blocking and timeout < 0) else (timeout if blocking else 0) + c_deadline = None if p_timeout is None else (self.c_rw_lock.c_time_source() + p_timeout) + if not self.c_rw_lock.c_lock_read_entry.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + return False + if not self.c_rw_lock.c_lock_read_try.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + self.c_rw_lock.c_lock_read_entry.release() + return False + if not self.c_rw_lock.c_lock_read_count.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + self.c_rw_lock.c_lock_read_try.release() + self.c_rw_lock.c_lock_read_entry.release() + return False + self.c_rw_lock.v_read_count.increment() + if 1 == self.c_rw_lock.v_read_count: + if not self.c_rw_lock.c_resource.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + self.c_rw_lock.c_lock_read_try.release() + self.c_rw_lock.c_lock_read_entry.release() + self.c_rw_lock.v_read_count.decrement() + self.c_rw_lock.c_lock_read_count.release() + return False + self.c_rw_lock.c_lock_read_count.release() + self.c_rw_lock.c_lock_read_try.release() + self.c_rw_lock.c_lock_read_entry.release() + self.v_locked = True + return True + + def release(self) -> None: + """Release the lock.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + self.v_locked = False + self.c_rw_lock.c_lock_read_count.acquire() + self.c_rw_lock.v_read_count.decrement() + if 0 == self.c_rw_lock.v_read_count: + self.c_rw_lock.c_resource.release() + self.c_rw_lock.c_lock_read_count.release() + + def locked(self) -> bool: + """Answer to 'is it currently locked?'.""" + return self.v_locked + + class _aWriter(LockableD): + def __init__(self, p_RWLock: "RWLockWriteD") -> None: + self.c_rw_lock = p_RWLock + self.v_locked: bool = False + + def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + """Acquire a lock.""" + p_timeout = None if (blocking and timeout < 0) else (timeout if blocking else 0) + c_deadline = None if p_timeout is None else (self.c_rw_lock.c_time_source() + p_timeout) + if not self.c_rw_lock.c_lock_write_count.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + return False + self.c_rw_lock.v_write_count += 1 + if 1 == self.c_rw_lock.v_write_count: + if not self.c_rw_lock.c_lock_read_try.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + self.c_rw_lock.v_write_count -= 1 + self.c_rw_lock.c_lock_write_count.release() + return False + self.c_rw_lock.c_lock_write_count.release() + if not self.c_rw_lock.c_resource.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + self.c_rw_lock.c_lock_write_count.acquire() + self.c_rw_lock.v_write_count -= 1 + if 0 == self.c_rw_lock.v_write_count: + self.c_rw_lock.c_lock_read_try.release() + self.c_rw_lock.c_lock_write_count.release() + return False + self.v_locked = True + return True + + def downgrade(self) -> Lockable: + """Downgrade.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + self.c_rw_lock.v_read_count.increment() + + self.v_locked = False + self.c_rw_lock.c_lock_write_count.acquire() + self.c_rw_lock.v_write_count -= 1 + if 0 == self.c_rw_lock.v_write_count: + self.c_rw_lock.c_lock_read_try.release() + self.c_rw_lock.c_lock_write_count.release() + + result = self.c_rw_lock._aReader(p_RWLock=self.c_rw_lock) + result.v_locked = True + return result + + def release(self) -> None: + """Release the lock.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + self.v_locked = False + self.c_rw_lock.c_resource.release() + self.c_rw_lock.c_lock_write_count.acquire() + self.c_rw_lock.v_write_count -= 1 + if 0 == self.c_rw_lock.v_write_count: + self.c_rw_lock.c_lock_read_try.release() + self.c_rw_lock.c_lock_write_count.release() + + def locked(self) -> bool: + """Answer to 'is it currently locked?'.""" + return self.v_locked + + def gen_rlock(self) -> "RWLockWriteD._aReader": + """Generate a reader lock.""" + return RWLockWriteD._aReader(self) + + def gen_wlock(self) -> "RWLockWriteD._aWriter": + """Generate a writer lock.""" + return RWLockWriteD._aWriter(self) + + +class RWLockFairD(RWLockableD): + """A Read/Write lock giving fairness to both Reader and Writer.""" + + def __init__(self, lock_factory: Callable[[], Lockable] = threading.Lock, time_source: Callable[[], float] = time.perf_counter) -> None: + """Init.""" + self.v_read_count: int = 0 + self.c_time_source = time_source + self.c_lock_read_count = lock_factory() + self.c_lock_read = lock_factory() + self.c_lock_write = lock_factory() + + class _aReader(Lockable): + def __init__(self, p_RWLock: "RWLockFairD") -> None: + self.c_rw_lock = p_RWLock + self.v_locked: bool = False + + def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + """Acquire a lock.""" + p_timeout = None if (blocking and timeout < 0) else (timeout if blocking else 0) + c_deadline = None if p_timeout is None else (self.c_rw_lock.c_time_source() + p_timeout) + if not self.c_rw_lock.c_lock_read.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + return False + if not self.c_rw_lock.c_lock_read_count.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + self.c_rw_lock.c_lock_read.release() + return False + self.c_rw_lock.v_read_count += 1 + if 1 == self.c_rw_lock.v_read_count: + if not self.c_rw_lock.c_lock_write.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + self.c_rw_lock.v_read_count -= 1 + self.c_rw_lock.c_lock_read_count.release() + self.c_rw_lock.c_lock_read.release() + return False + self.c_rw_lock.c_lock_read_count.release() + self.c_rw_lock.c_lock_read.release() + self.v_locked = True + return True + + def release(self) -> None: + """Release the lock.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + self.v_locked = False + self.c_rw_lock.c_lock_read_count.acquire() + self.c_rw_lock.v_read_count -= 1 + if 0 == self.c_rw_lock.v_read_count: + self.c_rw_lock.c_lock_write.release() + self.c_rw_lock.c_lock_read_count.release() + + def locked(self) -> bool: + """Answer to 'is it currently locked?'.""" + return self.v_locked + + class _aWriter(LockableD): + def __init__(self, p_RWLock: "RWLockFairD") -> None: + self.c_rw_lock = p_RWLock + self.v_locked: bool = False + + def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + """Acquire a lock.""" + p_timeout = None if (blocking and timeout < 0) else (timeout if blocking else 0) + c_deadline = None if p_timeout is None else (self.c_rw_lock.c_time_source() + p_timeout) + if not self.c_rw_lock.c_lock_read.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + return False + if not self.c_rw_lock.c_lock_write.acquire(blocking=True, timeout=-1 if c_deadline is None else max(0, c_deadline - self.c_rw_lock.c_time_source())): + self.c_rw_lock.c_lock_read.release() + return False + self.v_locked = True + return True + + def downgrade(self) -> Lockable: + """Downgrade.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + self.c_rw_lock.v_read_count += 1 + + self.v_locked = False + self.c_rw_lock.c_lock_read.release() + + result = self.c_rw_lock._aReader(p_RWLock=self.c_rw_lock) + result.v_locked = True + return result + + def release(self) -> None: + """Release the lock.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + self.v_locked = False + self.c_rw_lock.c_lock_write.release() + self.c_rw_lock.c_lock_read.release() + + def locked(self) -> bool: + """Answer to 'is it currently locked?'.""" + return self.v_locked + + def gen_rlock(self) -> "RWLockFairD._aReader": + """Generate a reader lock.""" + return RWLockFairD._aReader(self) + + def gen_wlock(self) -> "RWLockFairD._aWriter": + """Generate a writer lock.""" + return RWLockFairD._aWriter(self) diff --git a/ankihub/lib/readerwriterlock/rwlock_async.py b/ankihub/lib/readerwriterlock/rwlock_async.py new file mode 100644 index 000000000..19e377b6e --- /dev/null +++ b/ankihub/lib/readerwriterlock/rwlock_async.py @@ -0,0 +1,794 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +"""Read Write Lock.""" + +import asyncio +import sys +import time + +from typing import Callable +from typing import Optional +from typing import Type +from typing import Union +from types import TracebackType +from typing_extensions import Protocol +from typing_extensions import runtime_checkable + +try: + from asyncio import create_task as run_task +except ImportError: # pragma: no cover + from asyncio import ensure_future as run_task # type: ignore [misc] # pragma: no cover + +try: + asyncio.Lock().release() +except BaseException as exc: + RELEASE_ERR_CLS = type(exc) # pylint: disable=invalid-name + RELEASE_ERR_MSG = str(exc) +else: + raise AssertionError() # pragma: no cover + + +@runtime_checkable +class Lockable(Protocol): + """Lockable. Compatible with threading.Lock interface.""" + + async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + """Acquire a lock.""" + raise AssertionError("Should be overriden") # Will be overriden. # pragma: no cover + + async def release(self) -> None: + """Release the lock.""" + raise AssertionError("Should be overriden") # Will be overriden. # pragma: no cover + + def locked(self) -> bool: + """Answer to 'is it currently locked?'.""" + raise AssertionError("Should be overriden") # Will be overriden. # pragma: no cover + + async def __aenter__(self) -> bool: + """Enter context manager.""" + await self.acquire() + return False + + async def __aexit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[Exception], exc_tb: Optional[TracebackType]) -> Optional[bool]: # type: ignore + """Exit context manager.""" + await self.release() + return False + + +@runtime_checkable +class LockableD(Lockable, Protocol): + """Lockable Downgradable.""" + + async def downgrade(self) -> Lockable: + """Downgrade.""" + raise AssertionError("Should be overriden") # Will be overriden. # pragma: no cover + + +class _ThreadSafeInt(): + """Internal thread safe integer like object. + + Implements only the bare minimum features for the RWLock implementation's need. + """ + + def __init__(self, initial_value: int, lock_factory: Union[Callable[[], Lockable], Type[asyncio.Lock]] = asyncio.Lock) -> None: + """Init.""" + self.__value_lock = lock_factory() + self.__value: int = initial_value + + def __int__(self) -> int: + """Get int value.""" + return self.__value + + def __eq__(self, other) -> bool: + """Self == other.""" + return int(self) == int(other) + + async def increment(self): + """Increment the value by one.""" + async with self.__value_lock: + self.__value += 1 + + async def decrement(self): + """Decrement the value by one.""" + async with self.__value_lock: + self.__value -= 1 + + +@runtime_checkable +class RWLockable(Protocol): + """Read/write lock.""" + + async def gen_rlock(self) -> Lockable: + """Generate a reader lock.""" + raise AssertionError("Should be overriden") # Will be overriden. # pragma: no cover + + async def gen_wlock(self) -> Lockable: + """Generate a writer lock.""" + raise AssertionError("Should be overriden") # Will be overriden. # pragma: no cover + + +@runtime_checkable +class RWLockableD(Protocol): + """Read/write lock Downgradable.""" + + async def gen_rlock(self) -> Lockable: + """Generate a reader lock.""" + raise AssertionError("Should be overriden") # Will be overriden. # pragma: no cover + + async def gen_wlock(self) -> LockableD: + """Generate a writer lock.""" + raise AssertionError("Should be overriden") # Will be overriden. # pragma: no cover + + +class RWLockRead(RWLockable): + """A Read/Write lock giving preference to Reader.""" + + def __init__(self, lock_factory: Union[Callable[[], Lockable], Type[asyncio.Lock]] = asyncio.Lock, time_source: Callable[[], float] = time.perf_counter) -> None: + """Init.""" + self.v_read_count: int = 0 + self.c_time_source = time_source + self.c_resource = lock_factory() + self.c_lock_read_count = lock_factory() + + class _aReader(Lockable): + def __init__(self, p_RWLock: "RWLockRead") -> None: + self.c_rw_lock = p_RWLock + self.v_locked: bool = False + + async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + """Acquire a lock.""" + p_timeout: Optional[float] = None if (blocking and timeout < 0) else (timeout if blocking else 0) + c_deadline: Optional[float] = None if p_timeout is None else (self.c_rw_lock.c_time_source() + p_timeout) + try: + await asyncio.wait_for(self.c_rw_lock.c_lock_read_count.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + return False + self.c_rw_lock.v_read_count += 1 + if 1 == self.c_rw_lock.v_read_count: + try: + await asyncio.wait_for(self.c_rw_lock.c_resource.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + self.c_rw_lock.v_read_count -= 1 + self.c_rw_lock.c_lock_read_count.release() + return False + self.c_rw_lock.c_lock_read_count.release() + self.v_locked = True + return True + + async def release(self) -> None: + """Release the lock.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + self.v_locked = False + await self.c_rw_lock.c_lock_read_count.acquire() + self.c_rw_lock.v_read_count -= 1 + if 0 == self.c_rw_lock.v_read_count: + self.c_rw_lock.c_resource.release() + self.c_rw_lock.c_lock_read_count.release() + + def locked(self) -> bool: + """Answer to 'is it currently locked?'.""" + return self.v_locked + + class _aWriter(Lockable): + def __init__(self, p_RWLock: "RWLockRead") -> None: + self.c_rw_lock = p_RWLock + self.v_locked: bool = False + + async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + """Acquire a lock.""" + p_timeout: Optional[float] = None if (blocking and timeout < 0) else (timeout if blocking else 0) + c_deadline: Optional[float] = None if p_timeout is None else (self.c_rw_lock.c_time_source() + p_timeout) + try: + await asyncio.wait_for(self.c_rw_lock.c_resource.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + locked: bool = True + except asyncio.TimeoutError: + locked = False + self.v_locked = locked + return locked + + async def release(self) -> None: + """Release the lock.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + self.v_locked = False + self.c_rw_lock.c_resource.release() + + def locked(self) -> bool: + """Answer to 'is it currently locked?'.""" + return self.v_locked + + async def gen_rlock(self) -> "RWLockRead._aReader": + """Generate a reader lock.""" + return RWLockRead._aReader(self) + + async def gen_wlock(self) -> "RWLockRead._aWriter": + """Generate a writer lock.""" + return RWLockRead._aWriter(self) + + +class RWLockWrite(RWLockable): + """A Read/Write lock giving preference to Writer.""" + + def __init__(self, lock_factory: Union[Callable[[], Lockable], Type[asyncio.Lock]] = asyncio.Lock, time_source: Callable[[], float] = time.perf_counter) -> None: + """Init.""" + self.v_read_count: int = 0 + self.v_write_count: int = 0 + self.c_time_source = time_source + self.c_lock_read_count = lock_factory() + self.c_lock_write_count = lock_factory() + self.c_lock_read_entry = lock_factory() + self.c_lock_read_try = lock_factory() + self.c_resource = lock_factory() + + class _aReader(Lockable): + def __init__(self, p_RWLock: "RWLockWrite") -> None: + self.c_rw_lock = p_RWLock + self.v_locked: bool = False + + async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + """Acquire a lock.""" + p_timeout = None if (blocking and timeout < 0) else (timeout if blocking else 0) + c_deadline = None if p_timeout is None else (self.c_rw_lock.c_time_source() + p_timeout) + try: + await asyncio.wait_for(self.c_rw_lock.c_lock_read_entry.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + return False + try: + await asyncio.wait_for(self.c_rw_lock.c_lock_read_try.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + self.c_rw_lock.c_lock_read_entry.release() + return False + try: + await asyncio.wait_for(self.c_rw_lock.c_lock_read_count.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + self.c_rw_lock.c_lock_read_try.release() + self.c_rw_lock.c_lock_read_entry.release() + return False + self.c_rw_lock.v_read_count += 1 + if 1 == self.c_rw_lock.v_read_count: + try: + await asyncio.wait_for(self.c_rw_lock.c_resource.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + self.c_rw_lock.c_lock_read_try.release() + self.c_rw_lock.c_lock_read_entry.release() + self.c_rw_lock.v_read_count -= 1 + self.c_rw_lock.c_lock_read_count.release() + return False + self.c_rw_lock.c_lock_read_count.release() + self.c_rw_lock.c_lock_read_try.release() + self.c_rw_lock.c_lock_read_entry.release() + self.v_locked = True + return True + + async def release(self) -> None: + """Release the lock.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + self.v_locked = False + await self.c_rw_lock.c_lock_read_count.acquire() + self.c_rw_lock.v_read_count -= 1 + if 0 == self.c_rw_lock.v_read_count: + self.c_rw_lock.c_resource.release() + self.c_rw_lock.c_lock_read_count.release() + + def locked(self) -> bool: + """Answer to 'is it currently locked?'.""" + return self.v_locked + + class _aWriter(Lockable): + def __init__(self, p_RWLock: "RWLockWrite") -> None: + self.c_rw_lock = p_RWLock + self.v_locked: bool = False + + async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + """Acquire a lock.""" + p_timeout = None if (blocking and timeout < 0) else (timeout if blocking else 0) + c_deadline = None if p_timeout is None else (self.c_rw_lock.c_time_source() + p_timeout) + try: + await asyncio.wait_for(self.c_rw_lock.c_lock_write_count.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + return False + self.c_rw_lock.v_write_count += 1 + if 1 == int(self.c_rw_lock.v_write_count): + try: + await asyncio.wait_for(self.c_rw_lock.c_lock_read_try.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + self.c_rw_lock.v_write_count -= 1 + self.c_rw_lock.c_lock_write_count.release() + return False + self.c_rw_lock.c_lock_write_count.release() + try: + await asyncio.wait_for(self.c_rw_lock.c_resource.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + await self.c_rw_lock.c_lock_write_count.acquire() + self.c_rw_lock.v_write_count -= 1 + if 0 == int(self.c_rw_lock.v_write_count): + self.c_rw_lock.c_lock_read_try.release() + self.c_rw_lock.c_lock_write_count.release() + return False + self.v_locked = True + return True + + async def release(self) -> None: + """Release the lock.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + self.v_locked = False + self.c_rw_lock.c_resource.release() + await self.c_rw_lock.c_lock_write_count.acquire() + self.c_rw_lock.v_write_count -= 1 + if 0 == int(self.c_rw_lock.v_write_count): + self.c_rw_lock.c_lock_read_try.release() + self.c_rw_lock.c_lock_write_count.release() + + def locked(self) -> bool: + """Answer to 'is it currently locked?'.""" + return self.v_locked + + async def gen_rlock(self) -> "RWLockWrite._aReader": + """Generate a reader lock.""" + return RWLockWrite._aReader(self) + + async def gen_wlock(self) -> "RWLockWrite._aWriter": + """Generate a writer lock.""" + return RWLockWrite._aWriter(self) + + +class RWLockFair(RWLockable): + """A Read/Write lock giving fairness to both Reader and Writer.""" + + def __init__(self, lock_factory: Union[Callable[[], Lockable], Type[asyncio.Lock]] = asyncio.Lock, time_source: Callable[[], float] = time.perf_counter) -> None: + """Init.""" + self.v_read_count: int = 0 + self.c_time_source = time_source + self.c_lock_read_count = lock_factory() + self.c_lock_read = lock_factory() + self.c_lock_write = lock_factory() + + class _aReader(Lockable): + def __init__(self, p_RWLock: "RWLockFair") -> None: + self.c_rw_lock = p_RWLock + self.v_locked: bool = False + + async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + """Acquire a lock.""" + p_timeout = None if (blocking and timeout < 0) else (timeout if blocking else 0) + c_deadline = None if p_timeout is None else (self.c_rw_lock.c_time_source() + p_timeout) + try: + await asyncio.wait_for(self.c_rw_lock.c_lock_read.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + return False + try: + await asyncio.wait_for(self.c_rw_lock.c_lock_read_count.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + self.c_rw_lock.c_lock_read.release() + return False + self.c_rw_lock.v_read_count += 1 + if 1 == self.c_rw_lock.v_read_count: + try: + await asyncio.wait_for(self.c_rw_lock.c_lock_write.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + self.c_rw_lock.v_read_count -= 1 + self.c_rw_lock.c_lock_read_count.release() + self.c_rw_lock.c_lock_read.release() + return False + self.c_rw_lock.c_lock_read_count.release() + self.c_rw_lock.c_lock_read.release() + self.v_locked = True + return True + + async def release(self) -> None: + """Release the lock.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + self.v_locked = False + await self.c_rw_lock.c_lock_read_count.acquire() + self.c_rw_lock.v_read_count -= 1 + if 0 == self.c_rw_lock.v_read_count: + self.c_rw_lock.c_lock_write.release() + self.c_rw_lock.c_lock_read_count.release() + + def locked(self) -> bool: + """Answer to 'is it currently locked?'.""" + return self.v_locked + + class _aWriter(Lockable): + def __init__(self, p_RWLock: "RWLockFair") -> None: + self.c_rw_lock = p_RWLock + self.v_locked: bool = False + + async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + """Acquire a lock.""" + p_timeout = None if (blocking and timeout < 0) else (timeout if blocking else 0) + c_deadline = None if p_timeout is None else (self.c_rw_lock.c_time_source() + p_timeout) + try: + await asyncio.wait_for(self.c_rw_lock.c_lock_read.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + return False + try: + await asyncio.wait_for(self.c_rw_lock.c_lock_write.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + self.c_rw_lock.c_lock_read.release() + return False + self.v_locked = True + return True + + async def release(self) -> None: + """Release the lock.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + self.v_locked = False + self.c_rw_lock.c_lock_write.release() + self.c_rw_lock.c_lock_read.release() + + def locked(self) -> bool: + """Answer to 'is it currently locked?'.""" + return self.v_locked + + async def gen_rlock(self) -> "RWLockFair._aReader": + """Generate a reader lock.""" + return RWLockFair._aReader(self) + + async def gen_wlock(self) -> "RWLockFair._aWriter": + """Generate a writer lock.""" + return RWLockFair._aWriter(self) + + +class RWLockReadD(RWLockableD): + """A Read/Write lock giving preference to Reader.""" + + def __init__(self, lock_factory: Union[Callable[[], Lockable], Type[asyncio.Lock]] = asyncio.Lock, time_source: Callable[[], float] = time.perf_counter) -> None: + """Init.""" + self.v_read_count: _ThreadSafeInt = _ThreadSafeInt(initial_value=0, lock_factory=lock_factory) + self.c_time_source = time_source + self.c_resource = lock_factory() + self.c_lock_read_count = lock_factory() + + class _aReader(Lockable): + def __init__(self, p_RWLock: "RWLockReadD") -> None: + self.c_rw_lock = p_RWLock + self.v_locked: bool = False + + async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + """Acquire a lock.""" + p_timeout: Optional[float] = None if (blocking and timeout < 0) else (timeout if blocking else 0) + c_deadline: Optional[float] = None if p_timeout is None else (self.c_rw_lock.c_time_source() + p_timeout) + try: + await asyncio.wait_for(self.c_rw_lock.c_lock_read_count.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + return False + await self.c_rw_lock.v_read_count.increment() + if 1 == int(self.c_rw_lock.v_read_count): + try: + await asyncio.wait_for(self.c_rw_lock.c_resource.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + await self.c_rw_lock.v_read_count.decrement() + self.c_rw_lock.c_lock_read_count.release() + return False + self.c_rw_lock.c_lock_read_count.release() + self.v_locked = True + return True + + async def release(self) -> None: + """Release the lock.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + self.v_locked = False + await self.c_rw_lock.c_lock_read_count.acquire() + await self.c_rw_lock.v_read_count.decrement() + if 0 == int(self.c_rw_lock.v_read_count): + self.c_rw_lock.c_resource.release() + self.c_rw_lock.c_lock_read_count.release() + + def locked(self) -> bool: + """Answer to 'is it currently locked?'.""" + return self.v_locked + + class _aWriter(LockableD): + def __init__(self, p_RWLock: "RWLockReadD") -> None: + self.c_rw_lock = p_RWLock + self.v_locked: bool = False + + async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + """Acquire a lock.""" + p_timeout: Optional[float] = None if (blocking and timeout < 0) else (timeout if blocking else 0) + c_deadline: Optional[float] = None if p_timeout is None else (self.c_rw_lock.c_time_source() + p_timeout) + try: + await asyncio.wait_for(self.c_rw_lock.c_resource.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + locked: bool = True + except asyncio.TimeoutError: + locked = False + + self.v_locked = locked + return locked + + async def downgrade(self) -> Lockable: + """Downgrade.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + + result = await self.c_rw_lock.gen_rlock() + + wait_blocking = asyncio.Event() + + async def lock_result() -> None: + wait_blocking.set() + await result.acquire() # This is a blocking action + wait_blocking.set() + + run_task(lock_result()) + + await wait_blocking.wait() # Wait for the thread to be almost in its blocking state. + wait_blocking.clear() + + for _ in range(123): + await asyncio.sleep(sys.float_info.min) # Heuristic sleep delay to leave some extra time for the thread to block. + + await self.release() # Open the gate! the current RW lock strategy gives priority to reader, therefore the result will acquire lock before any other writer lock. + + await wait_blocking.wait() # Wait for the lock to be acquired + return result + + async def release(self) -> None: + """Release the lock.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + self.v_locked = False + self.c_rw_lock.c_resource.release() + + def locked(self) -> bool: + """Answer to 'is it currently locked?'.""" + return self.v_locked + + async def gen_rlock(self) -> "RWLockReadD._aReader": + """Generate a reader lock.""" + return RWLockReadD._aReader(self) + + async def gen_wlock(self) -> "RWLockReadD._aWriter": + """Generate a writer lock.""" + return RWLockReadD._aWriter(self) + + +class RWLockWriteD(RWLockableD): + """A Read/Write lock giving preference to Writer.""" + + def __init__(self, lock_factory: Union[Callable[[], Lockable], Type[asyncio.Lock]] = asyncio.Lock, time_source: Callable[[], float] = time.perf_counter) -> None: + """Init.""" + self.v_read_count: _ThreadSafeInt = _ThreadSafeInt(lock_factory=lock_factory, initial_value=0) + self.v_write_count: int = 0 + self.c_time_source = time_source + self.c_lock_read_count = lock_factory() + self.c_lock_write_count = lock_factory() + self.c_lock_read_entry = lock_factory() + self.c_lock_read_try = lock_factory() + self.c_resource = lock_factory() + + class _aReader(Lockable): + def __init__(self, p_RWLock: "RWLockWriteD") -> None: + self.c_rw_lock = p_RWLock + self.v_locked: bool = False + + async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + """Acquire a lock.""" + p_timeout = None if (blocking and timeout < 0) else (timeout if blocking else 0) + c_deadline = None if p_timeout is None else (self.c_rw_lock.c_time_source() + p_timeout) + try: + await asyncio.wait_for(self.c_rw_lock.c_lock_read_entry.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + return False + try: + await asyncio.wait_for(self.c_rw_lock.c_lock_read_try.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + self.c_rw_lock.c_lock_read_entry.release() + return False + try: + await asyncio.wait_for(self.c_rw_lock.c_lock_read_count.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + self.c_rw_lock.c_lock_read_try.release() + self.c_rw_lock.c_lock_read_entry.release() + return False + await self.c_rw_lock.v_read_count.increment() + if 1 == int(self.c_rw_lock.v_read_count): + try: + await asyncio.wait_for(self.c_rw_lock.c_resource.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + self.c_rw_lock.c_lock_read_try.release() + self.c_rw_lock.c_lock_read_entry.release() + await self.c_rw_lock.v_read_count.decrement() + self.c_rw_lock.c_lock_read_count.release() + return False + self.c_rw_lock.c_lock_read_count.release() + self.c_rw_lock.c_lock_read_try.release() + self.c_rw_lock.c_lock_read_entry.release() + self.v_locked = True + return True + + async def release(self) -> None: + """Release the lock.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + self.v_locked = False + await self.c_rw_lock.c_lock_read_count.acquire() + await self.c_rw_lock.v_read_count.decrement() + if 0 == int(self.c_rw_lock.v_read_count): + self.c_rw_lock.c_resource.release() + self.c_rw_lock.c_lock_read_count.release() + + def locked(self) -> bool: + """Answer to 'is it currently locked?'.""" + return self.v_locked + + class _aWriter(LockableD): + def __init__(self, p_RWLock: "RWLockWriteD") -> None: + self.c_rw_lock = p_RWLock + self.v_locked: bool = False + + async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + """Acquire a lock.""" + p_timeout = None if (blocking and timeout < 0) else (timeout if blocking else 0) + c_deadline = None if p_timeout is None else (self.c_rw_lock.c_time_source() + p_timeout) + try: + await asyncio.wait_for(self.c_rw_lock.c_lock_write_count.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + return False + self.c_rw_lock.v_write_count += 1 + if 1 == self.c_rw_lock.v_write_count: + try: + await asyncio.wait_for(self.c_rw_lock.c_lock_read_try.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + self.c_rw_lock.v_write_count -= 1 + self.c_rw_lock.c_lock_write_count.release() + return False + self.c_rw_lock.c_lock_write_count.release() + try: + await asyncio.wait_for(self.c_rw_lock.c_resource.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + await self.c_rw_lock.c_lock_write_count.acquire() + self.c_rw_lock.v_write_count -= 1 + if 0 == self.c_rw_lock.v_write_count: + self.c_rw_lock.c_lock_read_try.release() + self.c_rw_lock.c_lock_write_count.release() + return False + self.v_locked = True + return True + + async def downgrade(self) -> Lockable: + """Downgrade.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + await self.c_rw_lock.v_read_count.increment() + + self.v_locked = False + await self.c_rw_lock.c_lock_write_count.acquire() + self.c_rw_lock.v_write_count -= 1 + if 0 == self.c_rw_lock.v_write_count: + self.c_rw_lock.c_lock_read_try.release() + self.c_rw_lock.c_lock_write_count.release() + + result = self.c_rw_lock._aReader(p_RWLock=self.c_rw_lock) + result.v_locked = True + return result + + async def release(self) -> None: + """Release the lock.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + self.v_locked = False + self.c_rw_lock.c_resource.release() + await self.c_rw_lock.c_lock_write_count.acquire() + self.c_rw_lock.v_write_count -= 1 + if 0 == self.c_rw_lock.v_write_count: + self.c_rw_lock.c_lock_read_try.release() + self.c_rw_lock.c_lock_write_count.release() + + def locked(self) -> bool: + """Answer to 'is it currently locked?'.""" + return self.v_locked + + async def gen_rlock(self) -> "RWLockWriteD._aReader": + """Generate a reader lock.""" + return RWLockWriteD._aReader(self) + + async def gen_wlock(self) -> "RWLockWriteD._aWriter": + """Generate a writer lock.""" + return RWLockWriteD._aWriter(self) + + +class RWLockFairD(RWLockableD): + """A Read/Write lock giving fairness to both Reader and Writer.""" + + def __init__(self, lock_factory: Union[Callable[[], Lockable], Type[asyncio.Lock]] = asyncio.Lock, time_source: Callable[[], float] = time.perf_counter) -> None: + """Init.""" + self.v_read_count: int = 0 + self.c_time_source = time_source + self.c_lock_read_count = lock_factory() + self.c_lock_read = lock_factory() + self.c_lock_write = lock_factory() + + class _aReader(Lockable): + def __init__(self, p_RWLock: "RWLockFairD") -> None: + self.c_rw_lock = p_RWLock + self.v_locked: bool = False + + async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + """Acquire a lock.""" + p_timeout = None if (blocking and timeout < 0) else (timeout if blocking else 0) + c_deadline = None if p_timeout is None else (self.c_rw_lock.c_time_source() + p_timeout) + try: + await asyncio.wait_for(self.c_rw_lock.c_lock_read.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + return False + try: + await asyncio.wait_for(self.c_rw_lock.c_lock_read_count.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + self.c_rw_lock.c_lock_read.release() + return False + self.c_rw_lock.v_read_count += 1 + if 1 == self.c_rw_lock.v_read_count: + try: + await asyncio.wait_for(self.c_rw_lock.c_lock_write.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + self.c_rw_lock.v_read_count -= 1 + self.c_rw_lock.c_lock_read_count.release() + self.c_rw_lock.c_lock_read.release() + return False + self.c_rw_lock.c_lock_read_count.release() + self.c_rw_lock.c_lock_read.release() + self.v_locked = True + return True + + async def release(self) -> None: + """Release the lock.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + self.v_locked = False + await self.c_rw_lock.c_lock_read_count.acquire() + self.c_rw_lock.v_read_count -= 1 + if 0 == self.c_rw_lock.v_read_count: + self.c_rw_lock.c_lock_write.release() + self.c_rw_lock.c_lock_read_count.release() + + def locked(self) -> bool: + """Answer to 'is it currently locked?'.""" + return self.v_locked + + class _aWriter(LockableD): + def __init__(self, p_RWLock: "RWLockFairD") -> None: + self.c_rw_lock = p_RWLock + self.v_locked: bool = False + + async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + """Acquire a lock.""" + p_timeout = None if (blocking and timeout < 0) else (timeout if blocking else 0) + c_deadline = None if p_timeout is None else (self.c_rw_lock.c_time_source() + p_timeout) + try: + await asyncio.wait_for(self.c_rw_lock.c_lock_read.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + return False + try: + await asyncio.wait_for(self.c_rw_lock.c_lock_write.acquire(), timeout=(None if c_deadline is None else max(sys.float_info.min, c_deadline - self.c_rw_lock.c_time_source()))) + except asyncio.TimeoutError: + self.c_rw_lock.c_lock_read.release() + return False + self.v_locked = True + return True + + async def downgrade(self) -> Lockable: + """Downgrade.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + self.c_rw_lock.v_read_count += 1 + + self.v_locked = False + self.c_rw_lock.c_lock_read.release() + + result = self.c_rw_lock._aReader(p_RWLock=self.c_rw_lock) + result.v_locked = True + return result + + async def release(self) -> None: + """Release the lock.""" + if not self.v_locked: raise RELEASE_ERR_CLS(RELEASE_ERR_MSG) + self.v_locked = False + self.c_rw_lock.c_lock_write.release() + self.c_rw_lock.c_lock_read.release() + + def locked(self) -> bool: + """Answer to 'is it currently locked?'.""" + return self.v_locked + + async def gen_rlock(self) -> "RWLockFairD._aReader": + """Generate a reader lock.""" + return RWLockFairD._aReader(self) + + async def gen_wlock(self) -> "RWLockFairD._aWriter": + """Generate a writer lock.""" + return RWLockFairD._aWriter(self) diff --git a/pyproject.toml b/pyproject.toml index c2e5ae9fc..613c8bf62 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ exclude = [ 'ankihub/lib/urllib', 'ankihub/lib/other/typing_extension', 'ankihub/lib/tenacity', + 'ankihub/lib/readerwriterlock', ] no_strict_optional = true diff --git a/tests/addon/test_unit.py b/tests/addon/test_unit.py index f36948027..749c28da6 100644 --- a/tests/addon/test_unit.py +++ b/tests/addon/test_unit.py @@ -8,7 +8,8 @@ from pathlib import Path from sqlite3 import ProgrammingError from textwrap import dedent -from typing import Callable, Generator, List, Optional, Protocol, Tuple +from time import sleep +from typing import Callable, ContextManager, Generator, List, Optional, Protocol, Tuple from unittest.mock import Mock import aqt @@ -60,8 +61,9 @@ SuggestionType, TagGroupValidationResponse, ) +from ankihub.db import attached_ankihub_db, detached_ankihub_db from ankihub.db.db import _AnkiHubDB -from ankihub.db.exceptions import IntegrityError +from ankihub.db.exceptions import IntegrityError, LockAcquisitionTimeoutError from ankihub.feature_flags import _FeatureFlags, feature_flags from ankihub.gui import errors, suggestion_dialog from ankihub.gui.error_dialog import ErrorDialog @@ -74,7 +76,7 @@ ) from ankihub.gui.media_sync import media_sync from ankihub.gui.menu import AnkiHubLogin -from ankihub.gui.operations import deck_creation +from ankihub.gui.operations import AddonQueryOp, deck_creation from ankihub.gui.operations.deck_creation import ( DeckCreationConfirmationDialog, create_collaborative_deck, @@ -1514,6 +1516,57 @@ def test_with_none_in_db( ) +class TestAnkiHubDBContextManagers: + @pytest.mark.parametrize( + "task_configs, task_times_out", + [ + # Format for a task_confg: (context_manager, duration) + # Detached, Detached - tasks don't block each other + ([(detached_ankihub_db, 0.1), (detached_ankihub_db, 0.1)], False), + ([(detached_ankihub_db, 0.5), (detached_ankihub_db, 0.1)], False), + # Attached, Attached - tasks block each other + ([(attached_ankihub_db, 0.1), (attached_ankihub_db, 0.1)], False), + ([(attached_ankihub_db, 0.5), (attached_ankihub_db, 0.1)], True), + # Attached, Detached - tasks block each other + ([(attached_ankihub_db, 0.1), (detached_ankihub_db, 0.1)], False), + ([(attached_ankihub_db, 0.5), (detached_ankihub_db, 0.1)], True), + # Detached, Attached - tasks block each other + ([(detached_ankihub_db, 0.1), (attached_ankihub_db, 0.1)], False), + ([(detached_ankihub_db, 0.5), (attached_ankihub_db, 0.1)], True), + ], + ) + def test_blocking_and_timeout_behavior( + self, + anki_session_with_addon_data: AnkiSession, + monkeypatch: MonkeyPatch, + qtbot: QtBot, + task_configs: List[Tuple[Callable[[], ContextManager], float]], + task_times_out: bool, + ): + monkeypatch.setattr("ankihub.db.rw_lock.LOCK_TIMEOUT_SECONDS", 0.2) + + def task(context_manager: Callable[[], ContextManager], duration: float): + with context_manager(): + sleep(duration) + + with anki_session_with_addon_data.profile_loaded(): + for context_manager, duration in task_configs: + AddonQueryOp( + parent=qtbot, + op=lambda _: task(context_manager, duration), + success=lambda _: None, + ).without_collection().run_in_background() + + with qtbot.captureExceptions() as exceptions: + qtbot.wait(500) + + if task_times_out: + assert len(exceptions) == 1 + assert isinstance(exceptions[0][1], LockAcquisitionTimeoutError) + else: + assert len(exceptions) == 0 + + class TestErrorHandling: def test_contains_path_to_this_addon(self): # Assert that the function returns True when the input string contains the From ad169335cfc9622cc4afe57e5c85a361015daf73 Mon Sep 17 00:00:00 2001 From: Jakub Fidler <31575114+RisingOrange@users.noreply.github.com> Date: Sat, 30 Dec 2023 10:43:29 +0100 Subject: [PATCH 2/6] chore: Fix client tests not working without web app docker container (#855) * Ignore "No such container" error * Remove "sudo" from docker commands --- tests/client/test_client.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 0cc924501..feafe31dc 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -106,7 +106,6 @@ def client_with_server_setup(vcr: VCR, marks: List[str], request: FixtureRequest # Restore DB from dump result = subprocess.run( [ - "sudo", "docker", "exec", "-i", @@ -152,7 +151,6 @@ def create_db_dump_if_not_exists() -> None: # Check if DB dump exists result = subprocess.run( [ - "sudo", "docker", "exec", "-i", @@ -178,7 +176,6 @@ def create_db_dump_if_not_exists() -> None: # Prepare the DB state result = subprocess.run( [ - "sudo", "docker", "exec", DJANGO_CONTAINER_NAME, @@ -199,7 +196,6 @@ def create_db_dump_if_not_exists() -> None: # Dump the DB to a file to be able to restore it before each test result = subprocess.run( [ - "sudo", "docker", "exec", DB_CONTAINER_NAME, @@ -223,7 +219,6 @@ def remove_db_dump() -> Generator: """Remove the db dump on the start of the session so that it is re-created for each session.""" result = subprocess.run( [ - "sudo", "docker", "exec", DB_CONTAINER_NAME, @@ -238,6 +233,9 @@ def remove_db_dump() -> Generator: if result.returncode == 0 or NO_SUCH_FILE_OR_DIRECTORY_MESSAGE in result.stderr: # Nothing to do pass + elif result.returncode == 1 and "No such container" in result.stderr: + # Nothing to do + pass elif "Container" in result.stderr and "is not running" in result.stderr: # Container is not running, nothing to do pass From 0519834d134ca01ef9b27fd8f38752e69fe56a02 Mon Sep 17 00:00:00 2001 From: Jakub Fidler <31575114+RisingOrange@users.noreply.github.com> Date: Sat, 30 Dec 2023 17:44:44 +0100 Subject: [PATCH 3/6] chore: Use pytest-mock in tests (#847) * Add pytest-mock to requirements * chore: Replace usages of monkeypatch in test_unit.py * Replace usages of `mock_function` in test_unit.py with `mocker.patch` * Simplify raising exceptions * Prefeer `mocker.patch` over `mocker.patch.object` * Use mocker.stub() for mocking callbacks * chore: Use pytest-mock in client tests (#848) * chore: Replaces usages of monkeypatch in test_client.py * Attempt to use mocker and reset os.remove (#850) * Attempt to use mocker and reset os.remove * Use mocker.resetall to reset all mocks * Use stop (reset is just for the mock call values) * Fix types: use mocker.stop on remove_mock --------- Co-authored-by: Trey Hunner * Chore/use pytest mock for integration tests (#849) * Replace some monkeypatch uses with pytest-mock * Use mocker everywhere in integration tests Replace all uses of unittest.mock, monkeypatch fixture, and the custom mock_function fixture with the mocker fixture from pytest-mock. * Use mocker for mock_function fixture * Use non-string type annotations for Mock * Prefer mocker.patch over mocker.patch.object * Apply suggestions from code review Co-authored-by: Jakub Fidler <31575114+RisingOrange@users.noreply.github.com> * Fix variable name typos (mesaage -> message) * Remove unused mock_function fixture --------- Co-authored-by: Jakub Fidler <31575114+RisingOrange@users.noreply.github.com> --------- Co-authored-by: Trey Hunner --- requirements/dev.txt | 1 + tests/addon/conftest.py | 1 - tests/addon/test_integration.py | 564 ++++++++++++++------------------ tests/addon/test_unit.py | 214 ++++++------ tests/client/test_client.py | 145 ++++---- tests/fixtures.py | 33 -- 6 files changed, 409 insertions(+), 549 deletions(-) diff --git a/requirements/dev.txt b/requirements/dev.txt index b5f1afc5e..a40ed9c82 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -17,6 +17,7 @@ vcrpy==4.2.0 pytest-vcr==1.0.2 pytest-qt==4.2.0 pytest-split==0.8.1 +pytest-mock==3.12.0 factory-boy==3.2.1 types-factory-boy==0.3.1 pytest-xvfb==2.0.0 diff --git a/tests/addon/conftest.py b/tests/addon/conftest.py index 5cb6559b3..0d982b835 100644 --- a/tests/addon/conftest.py +++ b/tests/addon/conftest.py @@ -20,7 +20,6 @@ install_ah_deck, mock_all_feature_flags_to_default_values, mock_download_and_install_deck_dependencies, - mock_function, mock_message_box_with_cb, mock_show_dialog_with_cb, mock_study_deck_dialog_with_cb, diff --git a/tests/addon/test_integration.py b/tests/addon/test_integration.py index 88c20247a..e03c02e6c 100644 --- a/tests/addon/test_integration.py +++ b/tests/addon/test_integration.py @@ -22,7 +22,7 @@ Tuple, Union, ) -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import Mock from zipfile import ZipFile import aqt @@ -42,8 +42,9 @@ from aqt.importing import AnkiPackageImporter from aqt.qt import QAction, Qt from aqt.theme import theme_manager -from pytest import MonkeyPatch, fixture +from pytest import fixture from pytest_anki import AnkiSession +from pytest_mock import MockerFixture from pytestqt.qtbot import QtBot # type: ignore from requests import Response # type: ignore from requests_mock import Mocker @@ -52,7 +53,6 @@ DeckMediaUpdateChunk, UserDeckExtensionRelation, ) -from ankihub.gui import deckbrowser from ankihub.gui.browser.browser import ( ModifiedAfterSyncSearchNode, NewNoteSearchNode, @@ -72,7 +72,6 @@ ImportAHNote, InstallAHDeck, MockDownloadAndInstallDeckDependencies, - MockFunction, MockShowDialogWithCB, MockStudyDeckDialogWithCB, MockSuggestionDialog, @@ -115,7 +114,7 @@ _setup_logging_for_db_begin, _setup_logging_for_sync_collection_and_media, ) -from ankihub.gui import auto_sync, operations, utils +from ankihub.gui import utils from ankihub.gui.auto_sync import ( SYNC_RATE_LIMIT_SECONDS, _setup_ankihub_sync_on_ankiweb_sync, @@ -132,7 +131,7 @@ from ankihub.gui.errors import upload_logs_and_data_in_background from ankihub.gui.media_sync import media_sync from ankihub.gui.menu import menu_state -from ankihub.gui.operations import ankihub_sync, new_deck_subscriptions +from ankihub.gui.operations import ankihub_sync from ankihub.gui.operations.db_check import ah_db_check from ankihub.gui.operations.db_check.ah_db_check import check_ankihub_db from ankihub.gui.operations.deck_installation import download_and_install_decks @@ -333,45 +332,20 @@ def mock_ankihub_sync_dependencies( @fixture def mock_fetch_note_types_to_return_empty_dict( - monkeypatch: MonkeyPatch, + mocker: MockerFixture, ) -> None: # This prevents the add-on from fetching the note types from the server - monkeypatch.setattr( - "ankihub.main.note_types._fetch_note_types", - lambda *args, **kwargs: {}, - ) + mocker.patch("ankihub.main.note_types._fetch_note_types") @pytest.fixture -def mock_client_methods_called_during_ankihub_sync(monkeypatch: MonkeyPatch) -> None: - monkeypatch.setattr( - AnkiHubClient, "get_deck_subscriptions", lambda *args, **kwargs: [] - ) - monkeypatch.setattr( - AnkiHubClient, - "get_deck_extensions_by_deck_id", - lambda *args, **kwargs: [], - ) - monkeypatch.setattr( - AnkiHubClient, - "is_media_upload_finished", - lambda *args, **kwargs: True, - ) - monkeypatch.setattr( - AnkiHubClient, - "get_deck_updates", - lambda *args, **kwargs: [], - ) - monkeypatch.setattr( - AnkiHubClient, - "get_deck_media_updates", - lambda *args, **kwargs: [], - ) - monkeypatch.setattr( - AnkiHubClient, - "send_card_review_data", - lambda *args, **kwargs: [], - ) +def mock_client_methods_called_during_ankihub_sync(mocker: MockerFixture) -> None: + mocker.patch.object(AnkiHubClient, "get_deck_subscriptions") + mocker.patch.object(AnkiHubClient, "get_deck_extensions_by_deck_id") + mocker.patch.object(AnkiHubClient, "is_media_upload_finished") + mocker.patch.object(AnkiHubClient, "get_deck_updates") + mocker.patch.object(AnkiHubClient, "get_deck_media_updates") + mocker.patch.object(AnkiHubClient, "send_card_review_data") class MockClientGetNoteType(Protocol): @@ -380,7 +354,7 @@ def __call__(self, note_types: List[NotetypeDict]) -> None: @fixture -def mock_client_get_note_type(monkeypatch: MonkeyPatch) -> MockClientGetNoteType: +def mock_client_get_note_type(mocker: MockerFixture) -> MockClientGetNoteType: """Mock the get_note_type method of the AnkiHubClient to return the matching note type based on the id of the note type.""" @@ -397,10 +371,7 @@ def note_type_by_id(self, note_type_id: int) -> NotetypeDict: assert result is not None return result - monkeypatch.setattr( - "ankihub.main.reset_local_changes.AnkiHubClient.get_note_type", - note_type_by_id, - ) + mocker.patch.object(AnkiHubClient, "get_note_type", side_effect=note_type_by_id) return _mock_client_note_types @@ -439,13 +410,13 @@ def __call__(self, note: Note, wait_for_media_upload: bool) -> Mock: @pytest.fixture def create_change_suggestion( - qtbot: QtBot, mock_function: MockFunction, mock_client_media_upload: Mocker + qtbot: QtBot, mocker: MockerFixture, mock_client_media_upload: Mocker ): """Create a change suggestion for a note and wait for the background thread that uploads media to finish. Returns the mock for the create_change_note_suggestion method. It can be used to get information about the suggestion that was passed to the client.""" - create_change_suggestion_mock = mock_function( + create_change_suggestion_mock = mocker.patch.object( AnkiHubClient, "create_change_note_suggestion", ) @@ -479,13 +450,13 @@ def __call__( @pytest.fixture def create_new_note_suggestion( - qtbot: QtBot, mock_function: MockFunction, mock_client_media_upload: Mocker + qtbot: QtBot, mocker: MockerFixture, mock_client_media_upload: Mocker ): """Create a new note suggestion for a note and wait for the background thread that uploads media to finish. Returns the mock for the create_new_note_suggestion_mock method. It can be used to get information about the suggestion that was passed to the client.""" - create_new_note_suggestion_mock = mock_function( + create_new_note_suggestion_mock = mocker.patch.object( AnkiHubClient, "create_new_note_suggestion", ) @@ -524,7 +495,7 @@ def test_entry_point(anki_session_with_addon_data: AnkiSession, qtbot: QtBot): def test_editor( anki_session_with_addon_data: AnkiSession, requests_mock: Mocker, - monkeypatch: MonkeyPatch, + mocker: MockerFixture, next_deterministic_uuid: Callable[[], uuid.UUID], install_sample_ah_deck: InstallSampleAHDeck, mock_suggestion_dialog: MockSuggestionDialog, @@ -544,7 +515,7 @@ def test_editor( note_1_ah_nid = next_deterministic_uuid() - monkeypatch.setattr("ankihub.main.exporting.uuid.uuid4", lambda: note_1_ah_nid) + mocker.patch("ankihub.main.exporting.uuid.uuid4", return_value=note_1_ah_nid) requests_mock.post( f"{config.api_url}/notes/{note_1_ah_nid}/suggestion/", @@ -625,7 +596,7 @@ def test_modify_note_type(anki_session_with_addon_data: AnkiSession): def test_create_collaborative_deck_and_upload( anki_session_with_addon_data: AnkiSession, - monkeypatch: MonkeyPatch, + mocker: MockerFixture, next_deterministic_uuid: Callable[[], uuid.UUID], ): with anki_session_with_addon_data.profile_loaded(): @@ -643,15 +614,14 @@ def test_create_collaborative_deck_and_upload( # upload deck ah_did = next_deterministic_uuid() - upload_deck_mock = Mock() - upload_deck_mock.return_value = ah_did ah_nid = next_deterministic_uuid() - with monkeypatch.context() as m: - m.setattr( - "ankihub.ankihub_client.AnkiHubClient.upload_deck", upload_deck_mock - ) - m.setattr("uuid.uuid4", lambda: ah_nid) - create_ankihub_deck(deck_name, private=False) + upload_deck_mock = mocker.patch.object( + AnkiHubClient, + "upload_deck", + return_value=ah_did, + ) + mocker.patch("uuid.uuid4", return_value=ah_nid) + create_ankihub_deck(deck_name, private=False) # re-load note to get updated note.mid note.load() @@ -697,6 +667,7 @@ def test_download_and_install_deck( self, anki_session_with_addon_data: AnkiSession, qtbot: QtBot, + mocker: MockerFixture, mock_download_and_install_deck_dependencies: MockDownloadAndInstallDeckDependencies, ankihub_basic_note_type: NotetypeDict, ): @@ -709,7 +680,7 @@ def test_download_and_install_deck( ) # Download and install the deck - on_success_mock = Mock() + on_success_mock = mocker.stub() download_and_install_decks([deck.ah_did], on_done=on_success_mock) qtbot.wait_until(lambda: on_success_mock.call_count == 1) @@ -735,18 +706,16 @@ def test_download_and_install_deck( ), f"Mock {name} was not called once, but {mock.call_count} times" def test_exception_is_not_backpropagated_to_caller( - self, anki_session_with_addon_data: AnkiSession, mock_function: MockFunction + self, anki_session_with_addon_data: AnkiSession, mocker: MockerFixture ): with anki_session_with_addon_data.profile_loaded(): # Mock a function which is called in download_install_decks to raise an exception. - exception_mesaage = "test exception" + exception_message = "test exception" - def raise_exception(*args, **kwargs) -> None: - raise Exception(exception_mesaage) - - mock_function( - "ankihub.gui.operations.deck_installation.aqt.mw.taskman.with_progress", - side_effect=raise_exception, + mocker.patch.object( + aqt.mw.taskman, + "with_progress", + side_effect=Exception(exception_message), ) # Set up the on_done callback @@ -760,7 +729,7 @@ def on_done(future_: Future) -> None: download_and_install_decks(ankihub_dids=[], on_done=on_done) # Assert that the future contains the exception and that it contains the expected message. - assert future.exception().args[0] == exception_mesaage + assert future.exception().args[0] == exception_message class TestCheckAndInstallNewDeckSubscriptions: @@ -768,7 +737,7 @@ def test_one_new_subscription( self, anki_session_with_addon_data: AnkiSession, qtbot: QtBot, - mock_function: MockFunction, + mocker: MockerFixture, mock_show_dialog_with_cb: MockShowDialogWithCB, ): anki_session = anki_session_with_addon_data @@ -780,16 +749,15 @@ def test_one_new_subscription( ) # Mock download and install operation to only call the on_done callback - download_and_install_decks_mock = mock_function( - operations.new_deck_subscriptions, - "download_and_install_decks", - side_effect=lambda *args, **kwargs: kwargs["on_done"]( + download_and_install_decks_mock = mocker.patch( + "ankihub.gui.operations.new_deck_subscriptions.download_and_install_decks", + side_effect=lambda *args, on_done, **kwargs: on_done( future_with_result(None) ), ) # Call the function with a deck - on_done_mock = Mock() + on_done_mock = mocker.stub() deck = DeckFactory.create() check_and_install_new_deck_subscriptions( subscribed_decks=[deck], on_done=on_done_mock @@ -809,6 +777,7 @@ def test_user_declines( self, anki_session_with_addon_data: AnkiSession, qtbot: QtBot, + mocker: MockerFixture, mock_show_dialog_with_cb: MockShowDialogWithCB, ): anki_session = anki_session_with_addon_data @@ -820,7 +789,7 @@ def test_user_declines( ) # Call the function with a deck - on_done_mock = Mock() + on_done_mock = mocker.stub() deck = DeckFactory.create() check_and_install_new_deck_subscriptions( subscribed_decks=[deck], on_done=on_done_mock @@ -835,12 +804,13 @@ def test_user_declines( def test_no_new_subscriptions( self, anki_session_with_addon_data: AnkiSession, + mocker: MockerFixture, qtbot: QtBot, ): anki_session = anki_session_with_addon_data with anki_session.profile_loaded(): # Call the function with an empty list - on_done_mock = Mock() + on_done_mock = mocker.stub() check_and_install_new_deck_subscriptions( subscribed_decks=[], on_done=on_done_mock ) @@ -855,23 +825,19 @@ def test_confirmation_dialog_raises_exception( self, anki_session_with_addon_data: AnkiSession, qtbot: QtBot, - mock_function: MockFunction, + mocker: MockerFixture, ): anki_session = anki_session_with_addon_data with anki_session.profile_loaded(): # Mock confirmation dialog to raise an exception - def raise_exception(*args, **kwargs): - raise Exception("Something went wrong") - - message_box_mock = mock_function( - new_deck_subscriptions, - "show_dialog", - side_effect=raise_exception, + message_box_mock = mocker.patch( + "ankihub.gui.operations.new_deck_subscriptions.show_dialog", + side_effect=Exception("Something went wrong"), ) # Call the function with a deck - on_done_mock = Mock() + on_done_mock = mocker.stub() deck = DeckFactory.create() check_and_install_new_deck_subscriptions( subscribed_decks=[deck], on_done=on_done_mock @@ -890,7 +856,7 @@ def test_install_operation_raises_exception( self, anki_session_with_addon_data: AnkiSession, qtbot: QtBot, - mock_function: MockFunction, + mocker: MockerFixture, mock_show_dialog_with_cb: MockShowDialogWithCB, ): anki_session = anki_session_with_addon_data @@ -902,17 +868,13 @@ def test_install_operation_raises_exception( ) # Mock download and install operation to raise an exception - def raise_exception(*args, **kwargs): - raise Exception("Something went wrong") - - download_and_install_decks_mock = mock_function( - operations.new_deck_subscriptions, - "download_and_install_decks", - side_effect=raise_exception, + download_and_install_decks_mock = mocker.patch( + "ankihub.gui.operations.new_deck_subscriptions.download_and_install_decks", + side_effect=Exception("Something went wrong"), ) # Call the function with a deck - on_done_mock = Mock() + on_done_mock = mocker.stub() deck = DeckFactory.create() check_and_install_new_deck_subscriptions( subscribed_decks=[deck], on_done=on_done_mock @@ -972,7 +934,7 @@ def test_get_deck_by_id( def test_suggest_note_update( anki_session_with_addon_data: AnkiSession, install_sample_ah_deck: InstallSampleAHDeck, - monkeypatch: MonkeyPatch, + mocker: MockerFixture, ): anki_session = anki_session_with_addon_data with anki_session.profile_loaded(): @@ -1009,17 +971,16 @@ def test_suggest_note_update( note.tags.remove("removed") # Suggest the changes - create_change_note_suggestion_mock = MagicMock() - monkeypatch.setattr( - "ankihub.ankihub_client.AnkiHubClient.create_change_note_suggestion", - create_change_note_suggestion_mock, + create_change_note_suggestion_mock = mocker.patch.object( + AnkiHubClient, + "create_change_note_suggestion", ) suggest_note_update( note=note, change_type=SuggestionType.NEW_CONTENT, comment="test", - media_upload_cb=Mock(), + media_upload_cb=mocker.stub(), ) # Check that the correct suggestion was created @@ -1039,6 +1000,7 @@ def test_suggest_note_update( def test_suggest_new_note( anki_session_with_addon_data: AnkiSession, + mocker: MockerFixture, requests_mock: Mocker, install_sample_ah_deck: InstallSampleAHDeck, ): @@ -1064,7 +1026,7 @@ def test_suggest_new_note( note=note, ankihub_did=ah_did, comment="test", - media_upload_cb=Mock(), + media_upload_cb=mocker.stub(), ) # ... assert that add-on internal and optional tags were filtered out @@ -1088,7 +1050,7 @@ def test_suggest_new_note( note=note, ankihub_did=ah_did, comment="test", - media_upload_cb=Mock(), + media_upload_cb=mocker.stub(), ) except AnkiHubHTTPError as e: exc = e @@ -1097,15 +1059,13 @@ def test_suggest_new_note( def test_suggest_notes_in_bulk( anki_session_with_addon_data: AnkiSession, - monkeypatch: MonkeyPatch, + mocker: MockerFixture, install_sample_ah_deck: InstallSampleAHDeck, next_deterministic_uuid: Callable[[], uuid.UUID], ): anki_session = anki_session_with_addon_data - bulk_suggestions_method_mock = MagicMock() - monkeypatch.setattr( - "ankihub.ankihub_client.AnkiHubClient.create_suggestions_in_bulk", - bulk_suggestions_method_mock, + bulk_suggestions_method_mock = mocker.patch.object( + AnkiHubClient, "create_suggestions_in_bulk" ) with anki_session.profile_loaded(): mw = anki_session.mw @@ -1134,16 +1094,15 @@ def test_suggest_notes_in_bulk( mw.col.update_notes(notes) new_note_ah_id = next_deterministic_uuid() - with monkeypatch.context() as m: - m.setattr("uuid.uuid4", lambda: new_note_ah_id) - suggest_notes_in_bulk( - ankihub_did=ah_did, - notes=notes, - auto_accept=False, - change_type=SuggestionType.NEW_CONTENT, - comment="test", - media_upload_cb=Mock(), - ) + mocker.patch("uuid.uuid4", return_value=new_note_ah_id) + suggest_notes_in_bulk( + ankihub_did=ah_did, + notes=notes, + auto_accept=False, + change_type=SuggestionType.NEW_CONTENT, + comment="test", + media_upload_cb=mocker.stub(), + ) assert bulk_suggestions_method_mock.call_count == 1 assert bulk_suggestions_method_mock.call_args.kwargs == { @@ -1866,7 +1825,7 @@ def test_unsubscribe_from_deck( anki_session_with_addon_data: AnkiSession, install_sample_ah_deck: InstallSampleAHDeck, qtbot: QtBot, - monkeypatch: MonkeyPatch, + mocker: MockerFixture, requests_mock: Mocker, ): anki_session = anki_session_with_addon_data @@ -1878,10 +1837,7 @@ def test_unsubscribe_from_deck( mids = ankihub_db.note_types_for_ankihub_deck(ah_did) assert len(mids) == 2 - monkeypatch.setattr( - "ankihub.settings._Config.is_logged_in", - lambda *args, **kwargs: True, - ) + mocker.patch.object(config, "is_logged_in", return_value=True) deck = mw.col.decks.get(anki_deck_id) requests_mock.get( f"{DEFAULT_API_URL}/decks/subscriptions/", @@ -1908,19 +1864,17 @@ def test_unsubscribe_from_deck( deck_item_index = 0 deck_item = decks_list.item(deck_item_index) deck_item.setSelected(True) - monkeypatch.setattr( - "ankihub.gui.decks_dialog.ask_user", - lambda *args, **kwargs: True, - ) + mocker.patch("ankihub.gui.decks_dialog.ask_user", return_value=True) requests_mock.get( f"{DEFAULT_API_URL}/decks/subscriptions/", status_code=200, json=[] ) - with patch.object( - AnkiHubClient, "unsubscribe_from_deck" - ) as unsubscribe_from_deck_mock: - qtbot.mouseClick(dialog.unsubscribe_btn, Qt.MouseButton.LeftButton) - unsubscribe_from_deck_mock.assert_called_once() + unsubscribe_from_deck_mock = mocker.patch.object( + AnkiHubClient, + "unsubscribe_from_deck", + ) + qtbot.mouseClick(dialog.unsubscribe_btn, Qt.MouseButton.LeftButton) + unsubscribe_from_deck_mock.assert_called_once() assert dialog.decks_list.count() == 0 @@ -2151,6 +2105,7 @@ def test_ModifiedAfterSyncSearchNode_with_notes( self, anki_session_with_addon_data: AnkiSession, install_sample_ah_deck: InstallSampleAHDeck, + mocker: MockerFixture, ): with anki_session_with_addon_data.profile_loaded(): mw = anki_session_with_addon_data.mw @@ -2158,7 +2113,7 @@ def test_ModifiedAfterSyncSearchNode_with_notes( install_sample_ah_deck() all_nids = mw.col.find_notes("") - browser = Mock() + browser = mocker.Mock() browser.table.is_notes_mode.return_value = True with attached_ankihub_db(): @@ -2187,6 +2142,7 @@ def test_ModifiedAfterSyncSearchNode_with_cards( self, anki_session_with_addon_data: AnkiSession, install_sample_ah_deck: InstallSampleAHDeck, + mocker: MockerFixture, ): with anki_session_with_addon_data.profile_loaded(): mw = anki_session_with_addon_data.mw @@ -2194,7 +2150,7 @@ def test_ModifiedAfterSyncSearchNode_with_cards( install_sample_ah_deck() all_cids = mw.col.find_cards("") - browser = Mock() + browser = mocker.Mock() browser.table.is_notes_mode.return_value = False with attached_ankihub_db(): @@ -2223,6 +2179,7 @@ def test_UpdatedInTheLastXDaysSearchNode( self, anki_session_with_addon_data: AnkiSession, install_sample_ah_deck: InstallSampleAHDeck, + mocker: MockerFixture, ): with anki_session_with_addon_data.profile_loaded(): mw = anki_session_with_addon_data.mw @@ -2231,7 +2188,7 @@ def test_UpdatedInTheLastXDaysSearchNode( all_nids = mw.col.find_notes("") - browser = Mock() + browser = mocker.Mock() browser.table.is_notes_mode.return_value = True with attached_ankihub_db(): @@ -2264,6 +2221,7 @@ def test_NewNoteSearchNode( self, anki_session_with_addon_data: AnkiSession, next_deterministic_uuid: Callable[[], uuid.UUID], + mocker: MockerFixture, ): with anki_session_with_addon_data.profile_loaded(): mw = anki_session_with_addon_data.mw @@ -2294,7 +2252,7 @@ def test_NewNoteSearchNode( all_nids = mw.col.find_notes("") - browser = Mock() + browser = mocker.Mock() browser.table.is_notes_mode.return_value = True with attached_ankihub_db(): @@ -2308,6 +2266,7 @@ def test_SuggestionTypeSearchNode( self, anki_session_with_addon_data: AnkiSession, next_deterministic_uuid: Callable[[], uuid.UUID], + mocker: MockerFixture, ): with anki_session_with_addon_data.profile_loaded(): mw = anki_session_with_addon_data.mw @@ -2338,7 +2297,7 @@ def test_SuggestionTypeSearchNode( all_nids = mw.col.find_notes("") - browser = Mock() + browser = mocker.Mock() browser.table.is_notes_mode.return_value = True with attached_ankihub_db(): @@ -2356,6 +2315,7 @@ def test_UpdatedSinceLastReviewSearchNode( self, anki_session_with_addon_data: AnkiSession, install_sample_ah_deck: InstallSampleAHDeck, + mocker: MockerFixture, ): with anki_session_with_addon_data.profile_loaded(): mw = anki_session_with_addon_data.mw @@ -2364,7 +2324,7 @@ def test_UpdatedSinceLastReviewSearchNode( all_nids = mw.col.find_notes("") - browser = Mock() + browser = mocker.Mock() browser.table.is_notes_mode.return_value = True with attached_ankihub_db(): @@ -2558,7 +2518,7 @@ def test_browser_custom_columns( def test_protect_fields_action( anki_session_with_addon_data: AnkiSession, install_sample_ah_deck: InstallSampleAHDeck, - monkeypatch: MonkeyPatch, + mocker: MockerFixture, qtbot: QtBot, field_names_to_protect: Set[str], expected_tag: str, @@ -2572,9 +2532,9 @@ def test_protect_fields_action( browser: Browser = dialogs.open("Browser", mw) # Patch gui function choose_subset to return the fields to protect - monkeypatch.setattr( + mocker.patch( "ankihub.gui.browser.browser.choose_subset", - lambda *args, **kwargs: field_names_to_protect, + return_value=field_names_to_protect, ) # Call the action for a note @@ -2600,20 +2560,20 @@ def test_basic( anki_session_with_addon_data: AnkiSession, install_ah_deck: InstallAHDeck, qtbot: QtBot, - monkeypatch: MonkeyPatch, + mocker: MockerFixture, nightmode: bool, ): with anki_session_with_addon_data.profile_loaded(): - self._mock_dependencies(monkeypatch) + self._mock_dependencies(mocker) deck_name = "Test Deck" ah_did = install_ah_deck(ah_deck_name=deck_name) anki_did = config.deck_config(ah_did).anki_id - monkeypatch.setattr( + mocker.patch.object( AnkiHubClient, "get_deck_subscriptions", - lambda *args: [ + return_value=[ DeckFactory.create(ah_did=ah_did, anki_did=anki_did, name=deck_name) ], ) @@ -2638,10 +2598,10 @@ def test_toggle_subdecks( qtbot: QtBot, install_ah_deck: InstallAHDeck, import_ah_note: ImportAHNote, - monkeypatch: MonkeyPatch, + mocker: MockerFixture, ): with anki_session_with_addon_data.profile_loaded(): - self._mock_dependencies(monkeypatch) + self._mock_dependencies(mocker) # Install a deck with subdeck tags subdeck_name, anki_did, ah_did = self._install_deck_with_subdeck_tag( @@ -2651,10 +2611,10 @@ def test_toggle_subdecks( assert aqt.mw.col.decks.by_name(subdeck_name) is None # Mock get_deck_subscriptions to return the deck - monkeypatch.setattr( + mocker.patch.object( AnkiHubClient, "get_deck_subscriptions", - lambda *args: [DeckFactory.create(ah_did=ah_did, anki_did=anki_did)], + return_value=[DeckFactory.create(ah_did=ah_did, anki_did=anki_did)], ) # Open the dialog @@ -2702,19 +2662,19 @@ def test_change_destination_for_new_cards( anki_session_with_addon_data: AnkiSession, qtbot: QtBot, install_ah_deck: InstallAHDeck, - monkeypatch: MonkeyPatch, + mocker: MockerFixture, mock_study_deck_dialog_with_cb: MockStudyDeckDialogWithCB, ): with anki_session_with_addon_data.profile_loaded(): - self._mock_dependencies(monkeypatch) + self._mock_dependencies(mocker) ah_did = install_ah_deck() # Mock get_deck_subscriptions to return the deck - monkeypatch.setattr( + mocker.patch.object( AnkiHubClient, "get_deck_subscriptions", - lambda *args: [ + return_value=[ DeckFactory.create( ah_did=ah_did, anki_did=config.deck_config(ah_did).anki_id ) @@ -2752,19 +2712,19 @@ def test_with_deck_not_installed( self, anki_session_with_addon_data: AnkiSession, qtbot: QtBot, - monkeypatch: MonkeyPatch, + mocker: MockerFixture, next_deterministic_uuid: Callable[[], uuid.UUID], next_deterministic_id: Callable[[], int], ): with anki_session_with_addon_data.profile_loaded(): - self._mock_dependencies(monkeypatch) + self._mock_dependencies(mocker) ah_did = next_deterministic_uuid() anki_did = next_deterministic_id() - monkeypatch.setattr( + mocker.patch.object( AnkiHubClient, "get_deck_subscriptions", - lambda *args: [DeckFactory.create(ah_did=ah_did, anki_did=anki_did)], + return_value=[DeckFactory.create(ah_did=ah_did, anki_did=anki_did)], ) dialog = DeckManagementDialog() @@ -2778,14 +2738,12 @@ def test_with_deck_not_installed( assert hasattr(dialog, "deck_not_installed_label") - def _mock_dependencies(self, monkeypatch: MonkeyPatch) -> None: + def _mock_dependencies(self, mocker: MockerFixture) -> None: # Mock the config to return that the user is logged in - monkeypatch.setattr(config, "is_logged_in", lambda: True) + mocker.patch.object(config, "is_logged_in", return_value=True) # Mock the ask_user function to always return True - monkeypatch.setattr( - operations.subdecks, "ask_user", lambda *args, **kwargs: True - ) + mocker.patch("ankihub.gui.operations.subdecks.ask_user", return_value=True) class TestBuildSubdecksAndMoveCardsToThem: @@ -2987,7 +2945,7 @@ def test_reset_local_changes_to_notes( anki_session_with_addon_data: AnkiSession, install_sample_ah_deck: InstallSampleAHDeck, mock_client_get_note_type: MockClientGetNoteType, - monkeypatch: MonkeyPatch, + mocker: MockerFixture, ): with anki_session_with_addon_data.profile_loaded(): mw = anki_session_with_addon_data.mw @@ -3007,14 +2965,8 @@ def test_reset_local_changes_to_notes( mw.col.remove_notes([basic_note_2.id]) # mock the client functions that are called to get the data needed for resetting local changes - monkeypatch.setattr( - "ankihub.main.reset_local_changes.AnkiHubClient.get_protected_fields", - lambda *args, **kwargs: {}, - ) - monkeypatch.setattr( - "ankihub.main.reset_local_changes.AnkiHubClient.get_protected_tags", - lambda *args, **kwargs: [], - ) + mocker.patch.object(AnkiHubClient, "get_protected_fields") + mocker.patch.object(AnkiHubClient, "get_protected_tags") mock_client_get_note_type([note_type for note_type in mw.col.models.all()]) # reset local changes @@ -3040,15 +2992,12 @@ def test_reset_local_changes_to_notes( def test_migrate_profile_data_from_old_location( anki_session_with_addon_before_profile_support: AnkiSession, - monkeypatch: MonkeyPatch, + mocker: MockerFixture, ): anki_session = anki_session_with_addon_before_profile_support # mock update_decks_and_media so that the add-on doesn't try to download updates from AnkiHub - monkeypatch.setattr( - "ankihub.gui.deck_updater.ah_deck_updater.update_decks_and_media", - lambda *args, **kwargs: None, - ) + mocker.patch("ankihub.gui.deck_updater.ah_deck_updater.update_decks_and_media") # run the entrypoint and load the profile to trigger the migration entry_point.run() @@ -3072,7 +3021,7 @@ def test_migrate_profile_data_from_old_location( def test_profile_swap( anki_session_with_addon_data: AnkiSession, - monkeypatch: MonkeyPatch, + mocker: MockerFixture, install_sample_ah_deck: InstallSampleAHDeck, ): anki_session = anki_session_with_addon_data @@ -3084,8 +3033,7 @@ def test_profile_swap( PROFILE_2_NAME = "User 2" PROFILE_2_ID = uuid.UUID("22222222-2222-2222-2222-222222222222") - general_setup_mock = Mock() - monkeypatch.setattr("ankihub.entry_point._general_setup", general_setup_mock) + general_setup_mock = mocker.patch("ankihub.entry_point._general_setup") entry_point.run() @@ -3107,15 +3055,14 @@ def test_profile_swap( # load the second profile mw.pm.load(PROFILE_2_NAME) - # monkeypatch uuid4 so that the id of the second profile is known - with monkeypatch.context() as m: - m.setattr("uuid.uuid4", lambda: PROFILE_2_ID) - with anki_session.profile_loaded(): - assert profile_files_path() == ankihub_base_path() / str(PROFILE_2_ID) - # the database should be empty - assert len(ankihub_db.ankihub_deck_ids()) == 0 - # the config should not conatin any deck subscriptions - assert len(config.deck_ids()) == 0 + # monkey patch uuid4 so that the id of the second profile is known + mocker.patch("uuid.uuid4", return_value=PROFILE_2_ID) + with anki_session.profile_loaded(): + assert profile_files_path() == ankihub_base_path() / str(PROFILE_2_ID) + # the database should be empty + assert len(ankihub_db.ankihub_deck_ids()) == 0 + # the config should not conatin any deck subscriptions + assert len(config.deck_ids()) == 0 # load the first profile again mw.pm.load(PROFILE_1_NAME) @@ -3159,7 +3106,7 @@ def test_update_note( anki_session_with_addon_data: AnkiSession, install_ah_deck: InstallAHDeck, import_ah_note: ImportAHNote, - mock_function: MockFunction, + mocker: MockerFixture, mock_ankihub_sync_dependencies: None, ): with anki_session_with_addon_data.profile_loaded(): @@ -3171,8 +3118,9 @@ def test_update_note( note_info.fields[0].value = "changed" latest_update = datetime.now() - mock_function( - "ankihub.gui.deck_updater.AnkiHubClient.get_deck_updates", + mocker.patch.object( + AnkiHubClient, + "get_deck_updates", return_value=[ DeckUpdateChunk( latest_update=latest_update, @@ -3183,8 +3131,9 @@ def test_update_note( ], ) - mock_function( - "ankihub.gui.deck_updater.AnkiHubClient.get_deck_by_id", + mocker.patch.object( + AnkiHubClient, + "get_deck_by_id", return_value=DeckFactory.create(ah_did=ah_did), ) @@ -3250,7 +3199,7 @@ def test_update_optional_tags( initial_tags: List[str], incoming_optional_tags: List[str], expected_tags: List[str], - mock_function: MockFunction, + mocker: MockerFixture, mock_ankihub_sync_dependencies: None, ): with anki_session_with_addon_data.profile_loaded(): @@ -3265,21 +3214,24 @@ def test_update_optional_tags( # Mock client to return a deck extension update with incoming_optional_tags latest_update = datetime.now() - mock_function( - "ankihub.gui.deck_updater.AnkiHubClient.get_deck_by_id", + mocker.patch.object( + AnkiHubClient, + "get_deck_by_id", return_value=DeckFactory.create(ah_did=ah_did), ) deck_extension = DeckExtensionFactory.create( ah_did=ah_did, tag_group_name="tag_group" ) - mock_function( - "ankihub.gui.deck_updater.AnkiHubClient.get_deck_extensions_by_deck_id", + mocker.patch.object( + AnkiHubClient, + "get_deck_extensions_by_deck_id", return_value=[deck_extension], ) - mock_function( - "ankihub.gui.deck_updater.AnkiHubClient.get_deck_extension_updates", + mocker.patch.object( + AnkiHubClient, + "get_deck_extension_updates", return_value=[ DeckExtensionUpdateChunk( note_customizations=[ @@ -3327,7 +3279,7 @@ def test_user_relation_gets_updated_in_deck_config( self, anki_session_with_addon_data: AnkiSession, install_ah_deck: InstallAHDeck, - mock_function: MockFunction, + mocker: MockerFixture, current_relation: UserDeckRelation, incoming_relation: UserDeckRelation, mock_ankihub_sync_dependencies: None, @@ -3344,8 +3296,9 @@ def test_user_relation_gets_updated_in_deck_config( # Mock client.get_deck_by_id to return the deck with the incoming relation deck = copy.deepcopy(deck) deck.user_relation = incoming_relation - mock_function( - "ankihub.gui.deck_updater.AnkiHubClient.get_deck_by_id", + mocker.patch.object( + AnkiHubClient, + "get_deck_by_id", return_value=deck, ) @@ -3371,7 +3324,7 @@ def test_sync_uninstalls_unsubscribed_decks( self, anki_session_with_addon_data: AnkiSession, install_sample_ah_deck: InstallSampleAHDeck, - monkeypatch: MonkeyPatch, + mocker: MockerFixture, mock_client_methods_called_during_ankihub_sync: None, sync_with_ankihub: SyncWithAnkiHub, subscribed_to_deck: bool, @@ -3384,10 +3337,10 @@ def test_sync_uninstalls_unsubscribed_decks( # Mock client.get_deck_subscriptions to return the deck if subscribed_to_deck is True and # return an empty list otherwise - monkeypatch.setattr( + mocker.patch.object( AnkiHubClient, "get_deck_subscriptions", - lambda *args, **kwargs: [DeckFactory.create(ah_did=ah_did)] + return_value=[DeckFactory.create(ah_did=ah_did)] if subscribed_to_deck else [], ) @@ -3434,18 +3387,16 @@ def test_sync_updates_api_version_on_last_sync( assert config._private_config.api_version_on_last_sync == API_VERSION def test_exception_is_not_backpropagated_to_caller( - self, anki_session_with_addon_data: AnkiSession, mock_function: MockFunction + self, anki_session_with_addon_data: AnkiSession, mocker: MockerFixture ): with anki_session_with_addon_data.profile_loaded(): # Mock a client function which is called in sync_with_ankihub to raise an exception. - exception_mesaage = "test exception" - - def raise_exception(*args, **kwargs) -> None: - raise Exception(exception_mesaage) + exception_message = "test exception" - mock_function( - "ankihub.gui.operations.ankihub_sync.AnkiHubClient.get_deck_subscriptions", - side_effect=raise_exception, + mocker.patch.object( + AnkiHubClient, + "get_deck_subscriptions", + side_effect=Exception(exception_message), ) # Set up the on_done callback @@ -3459,7 +3410,7 @@ def on_done(future_: Future) -> None: ankihub_sync.sync_with_ankihub(on_done=on_done) # Assert that the future contains the exception and that it contains the expected message. - assert future.exception().args[0] == exception_mesaage + assert future.exception().args[0] == exception_message def test_uninstalling_deck_removes_related_deck_extension_from_config( @@ -3483,7 +3434,7 @@ class TestAutoSync: def test_with_on_ankiweb_sync_config_option( self, anki_session_with_addon_data: AnkiSession, - monkeypatch: MonkeyPatch, + mocker: MockerFixture, mock_client_methods_called_during_ankihub_sync: None, qtbot: QtBot, ): @@ -3491,7 +3442,7 @@ def test_with_on_ankiweb_sync_config_option( mw = anki_session_with_addon_data.mw # Mock the syncs. - self._mock_syncs_and_check_new_subscriptions(monkeypatch) + self._mock_syncs_and_check_new_subscriptions(mocker) # Setup the auto sync. _setup_ankihub_sync_on_ankiweb_sync() @@ -3500,7 +3451,7 @@ def test_with_on_ankiweb_sync_config_option( config.public_config["auto_sync"] = "on_ankiweb_sync" # Trigger the AnkiWeb sync. - mw._sync_collection_and_media(after_sync=Mock()) + mw._sync_collection_and_media(after_sync=mocker.stub()) qtbot.wait(500) # Assert that both syncs were called. @@ -3512,14 +3463,14 @@ def test_with_on_ankiweb_sync_config_option( def test_with_never_option( self, anki_session_with_addon_data: AnkiSession, - monkeypatch: MonkeyPatch, + mocker: MockerFixture, qtbot: QtBot, ): with anki_session_with_addon_data.profile_loaded(): mw = anki_session_with_addon_data.mw # Mock the syncs. - self._mock_syncs_and_check_new_subscriptions(monkeypatch) + self._mock_syncs_and_check_new_subscriptions(mocker) # Setup the auto sync. _setup_ankihub_sync_on_ankiweb_sync() @@ -3528,7 +3479,7 @@ def test_with_never_option( config.public_config["auto_sync"] = "never" # Trigger the AnkiWeb sync. - mw._sync_collection_and_media(after_sync=Mock()) + mw._sync_collection_and_media(after_sync=mocker.stub()) qtbot.wait(500) # Assert that only the AnkiWeb sync was called. @@ -3540,7 +3491,7 @@ def test_with_never_option( def test_with_on_startup_option( self, anki_session_with_addon_data: AnkiSession, - monkeypatch: MonkeyPatch, + mocker: MockerFixture, mock_client_methods_called_during_ankihub_sync: None, qtbot: QtBot, ): @@ -3548,7 +3499,7 @@ def test_with_on_startup_option( mw = anki_session_with_addon_data.mw # Mock the syncs. - self._mock_syncs_and_check_new_subscriptions(monkeypatch) + self._mock_syncs_and_check_new_subscriptions(mocker) # Setup the auto sync. _setup_ankihub_sync_on_ankiweb_sync() @@ -3557,7 +3508,7 @@ def test_with_on_startup_option( config.public_config["auto_sync"] = "on_startup" # Trigger the AnkiWeb sync. - mw._sync_collection_and_media(after_sync=Mock()) + mw._sync_collection_and_media(after_sync=mocker.stub()) qtbot.wait(500) # Assert that both syncs were called. @@ -3568,7 +3519,7 @@ def test_with_on_startup_option( self.check_and_install_new_deck_subscriptions_mock.call_count == 1 # Trigger the AnkiWeb sync again. - mw._sync_collection_and_media(after_sync=Mock()) + mw._sync_collection_and_media(after_sync=mocker.stub()) qtbot.wait(500) # Assert that only the AnkiWeb sync was called the second time. @@ -3577,34 +3528,23 @@ def test_with_on_startup_option( assert self.check_and_install_new_deck_subscriptions_mock.call_count == 1 - def _mock_syncs_and_check_new_subscriptions(self, monkeypatch: MonkeyPatch): + def _mock_syncs_and_check_new_subscriptions(self, mocker: MockerFixture): # Mock the token so that the AnkiHub sync is not skipped. - monkeypatch.setattr( - config, "token", MagicMock(return_value=lambda: "test_token") - ) + mocker.patch.object(config, "token", return_value="test_token") # Mock update_decks_and_media so it does nothing. - self.udpate_decks_and_media_mock = Mock() - monkeypatch.setattr( - ah_deck_updater, "update_decks_and_media", self.udpate_decks_and_media_mock + self.udpate_decks_and_media_mock = mocker.patch.object( + ah_deck_updater, "update_decks_and_media" ) # Mock the AnkiWeb sync so it does nothing. - self.ankiweb_sync_mock = Mock() - monkeypatch.setattr( - aqt.sync, - "sync_collection", - self.ankiweb_sync_mock, - ) + self.ankiweb_sync_mock = mocker.patch.object(aqt.sync, "sync_collection") # ... and reload aqt.main so the mock is used. importlib.reload(aqt.main) # Mock the new deck subscriptions operation to just call its callback. - self.check_and_install_new_deck_subscriptions_mock = Mock() - monkeypatch.setattr( - operations.ankihub_sync, - "check_and_install_new_deck_subscriptions", - self.check_and_install_new_deck_subscriptions_mock, + self.check_and_install_new_deck_subscriptions_mock = mocker.patch( + "ankihub.gui.operations.ankihub_sync.check_and_install_new_deck_subscriptions" ) self.check_and_install_new_deck_subscriptions_mock.side_effect = ( lambda *args, **kwargs: kwargs["on_done"](future_with_result(None)) @@ -3624,7 +3564,7 @@ class TestAutoSyncRateLimit: def test_rate_limit( self, anki_session_with_addon_data: AnkiSession, - mock_function: MockFunction, + mocker: MockerFixture, qtbot: QtBot, mock_ankihub_sync_dependencies, delay_between_syncs_in_seconds: float, @@ -3633,7 +3573,9 @@ def test_rate_limit( # Run the entry point so that the auto sync and rate limit is set up. entry_point.run() with anki_session_with_addon_data.profile_loaded(): - sync_with_ankihub_mock = mock_function(auto_sync, "sync_with_ankihub") + sync_with_ankihub_mock = mocker.patch( + "ankihub.gui.auto_sync.sync_with_ankihub" + ) # Trigger the sync two times, with a delay in between. aqt.mw._sync_collection_and_media(lambda: None) @@ -3649,10 +3591,9 @@ def test_rate_limit( def test_optional_tag_suggestion_dialog( anki_session_with_addon_data: AnkiSession, qtbot: QtBot, - monkeypatch: MonkeyPatch, + mocker: MockerFixture, import_ah_note: ImportAHNote, next_deterministic_uuid, - mock_function: MockFunction, ): anki_session = anki_session_with_addon_data @@ -3684,14 +3625,16 @@ def test_optional_tag_suggestion_dialog( notes[2].flush() # Mock client methods - mock_function( - "ankihub.gui.optional_tag_suggestion_dialog.AnkiHubClient.get_deck_extensions", + mocker.patch.object( + AnkiHubClient, + "get_deck_extensions", return_value=[], ) - monkeypatch.setattr( - "ankihub.ankihub_client.AnkiHubClient.prevalidate_tag_groups", - lambda *args, **kwargs: [ + mocker.patch.object( + AnkiHubClient, + "prevalidate_tag_groups", + return_value=[ TagGroupValidationResponse( tag_group_name="VALID", deck_extension_id=1, @@ -3724,10 +3667,9 @@ def test_optional_tag_suggestion_dialog( assert dialog.tag_group_list.item(1).toolTip() == "" assert dialog.submit_btn.isEnabled() - suggest_optional_tags_mock = Mock() - monkeypatch.setattr( - "ankihub.ankihub_client.AnkiHubClient.suggest_optional_tags", - suggest_optional_tags_mock, + suggest_optional_tags_mock = mocker.patch.object( + AnkiHubClient, + "suggest_optional_tags", ) # Select the "VALID" tag group and click the submit button @@ -3769,7 +3711,7 @@ def test_optional_tag_suggestion_dialog( def test_reset_optional_tags_action( anki_session_with_addon_data: AnkiSession, qtbot: QtBot, - monkeypatch: MonkeyPatch, + mocker: MockerFixture, install_sample_ah_deck: InstallSampleAHDeck, ): entry_point.run() @@ -3805,24 +3747,25 @@ def test_reset_optional_tags_action( mw.col.add_note(other_note, DeckId(1)) # mock the choose_list function to always return the first item - choose_list_mock = Mock() - choose_list_mock.return_value = 0 - monkeypatch.setattr("ankihub.gui.browser.browser.choose_list", choose_list_mock) + choose_list_mock = mocker.patch( + "ankihub.gui.browser.browser.choose_list", + return_value=0, + ) # mock the ask_user function to always confirm the reset - monkeypatch.setattr( - "ankihub.gui.browser.browser.ask_user", lambda *args, **kwargs: True - ) + mocker.patch("ankihub.gui.browser.browser.ask_user", return_value=True) # mock the is_logged_in function to always return True - is_logged_in_mock = Mock() - is_logged_in_mock.return_value = True - monkeypatch.setattr(config, "is_logged_in", is_logged_in_mock) + is_logged_in_mock = mocker.patch.object( + config, + "is_logged_in", + return_value=True, + ) # mock method of ah_deck_updater - update_decks_and_media_mock = Mock() - monkeypatch.setattr( - ah_deck_updater, "update_decks_and_media", update_decks_and_media_mock + update_decks_and_media_mock = mocker.patch.object( + ah_deck_updater, + "update_decks_and_media", ) # run the reset action @@ -3854,7 +3797,7 @@ def test_download_media( self, anki_session_with_addon_data: AnkiSession, install_sample_ah_deck: InstallSampleAHDeck, - mock_function: MockFunction, + mocker: MockerFixture, qtbot: QtBot, ): with anki_session_with_addon_data.profile_loaded(): @@ -3869,7 +3812,7 @@ def test_download_media( exists_on_s3=True, download_enabled=True, ) - get_deck_media_updates_mock = mock_function( + get_deck_media_updates_mock = mocker.patch.object( AnkiHubClient, "get_deck_media_updates", return_value=[ @@ -3880,7 +3823,7 @@ def test_download_media( ) # Mock the client method for downloading media - download_media_mock = mock_function(AnkiHubClient, "download_media") + download_media_mock = mocker.patch.object(AnkiHubClient, "download_media") # Start the media sync and wait for it to finish media_sync.start_media_download() @@ -3911,14 +3854,14 @@ def test_download_media_with_no_updates( self, anki_session_with_addon_data: AnkiSession, install_sample_ah_deck: InstallSampleAHDeck, - mock_function: MockFunction, + mocker: MockerFixture, qtbot: QtBot, ): with anki_session_with_addon_data.profile_loaded(): _, ah_did = install_sample_ah_deck() # Mock client to return an empty deck media update - get_deck_media_updates_mock = mock_function( + get_deck_media_updates_mock = mocker.patch.object( AnkiHubClient, "get_deck_media_updates", return_value=[ @@ -3927,7 +3870,7 @@ def test_download_media_with_no_updates( ) # Mock the client method for downloading media - download_media_mock = mock_function(AnkiHubClient, "download_media") + download_media_mock = mocker.patch.object(AnkiHubClient, "download_media") # Start the media sync and wait for it to finish media_sync.start_media_download() @@ -3943,7 +3886,7 @@ def test_download_media_with_no_updates( @fixture def mock_client_media_upload( - monkeypatch: MonkeyPatch, + mocker: MockerFixture, requests_mock: Mocker, ) -> Iterator[Mocker]: fake_presigned_url = AnkiHubClient().s3_bucket_url + "/fake_key" @@ -3951,21 +3894,18 @@ def mock_client_media_upload( fake_presigned_url, json={"success": True}, status_code=204 ) - monkeypatch.setattr( + mocker.patch.object( AnkiHubClient, "is_media_upload_finished", - lambda *args, **kwargs: True, + return_value=True, ) - monkeypatch.setattr( - "ankihub.ankihub_client.AnkiHubClient.media_upload_finished", - lambda *args, **kwargs: False, - ) + mocker.patch.object(AnkiHubClient, "media_upload_finished") - monkeypatch.setattr( + mocker.patch.object( AnkiHubClient, "_get_presigned_url_for_multiple_uploads", - lambda *args, **kwargs: { + return_value={ "url": fake_presigned_url, "fields": { "key": "deck_images/test/${filename}", @@ -3974,8 +3914,7 @@ def mock_client_media_upload( ) # Mock os.remove so the zip is not deleted - os_remove_mock = MagicMock() - monkeypatch.setattr(os, "remove", os_remove_mock) + mocker.patch("os.remove") # Create a temporary media folder and copy the test media files to it. # Patch the media folder path to point to the temporary folder. @@ -3983,9 +3922,7 @@ def mock_client_media_upload( for file in (TEST_DATA_PATH / "media").glob("*"): shutil.copy(file, Path(tmp_dir) / file.name) - monkeypatch.setattr( - "anki.media.MediaManager.dir", lambda *args, **kwargs: tmp_dir - ) + mocker.patch("anki.media.MediaManager.dir", return_value=tmp_dir) yield s3_upload_request_mock # type: ignore @@ -4303,7 +4240,7 @@ def test_check_and_prompt_for_updates_on_main_window( @pytest.mark.qt_no_exception_capture class TestDebugModule: def test_setup_logging_for_sync_collection_and_media( - self, anki_session: AnkiSession, monkeypatch: MonkeyPatch + self, anki_session: AnkiSession, mocker: MockerFixture ): # Test that the original AnkiQt._sync_collection_and_media method gets called # despite the monkeypatching we do in debug.py. @@ -4311,13 +4248,12 @@ def test_setup_logging_for_sync_collection_and_media( mw = anki_session.mw # Mock the AnkiWeb sync to do nothing - monkeypatch.setattr(aqt.sync, "sync_collection", Mock()) + mocker.patch.object(aqt.sync, "sync_collection") # ... and reload the main module so that the mock is used. importlib.reload(aqt.main) # Mock the sync_will_start hook so that we can check if it was called when the sync starts. - sync_will_start_mock = Mock() - monkeypatch.setattr(gui_hooks, "sync_will_start", sync_will_start_mock) + sync_will_start_mock = mocker.patch.object(gui_hooks, "sync_will_start") _setup_logging_for_sync_collection_and_media() @@ -4326,13 +4262,12 @@ def test_setup_logging_for_sync_collection_and_media( sync_will_start_mock.assert_called_once() def test_setup_logging_for_db_begin( - self, anki_session: AnkiSession, monkeypatch: MonkeyPatch + self, anki_session: AnkiSession, mocker: MockerFixture ): with anki_session.profile_loaded(): mw = anki_session.mw - db_begin_mock = Mock() - monkeypatch.setattr(mw.col._backend, "db_begin", db_begin_mock) + db_begin_mock = mocker.patch.object(mw.col._backend, "db_begin") _setup_logging_for_db_begin() @@ -4393,7 +4328,7 @@ def test_handle_notes_deleted_from_webapp( def test_upload_logs_and_data( anki_session_with_addon_data: AnkiSession, - monkeypatch: MonkeyPatch, + mocker: MockerFixture, qtbot: QtBot, ): with anki_session_with_addon_data.profile_loaded(): @@ -4407,10 +4342,7 @@ def upload_logs_mock(*args, **kwargs): key = kwargs["key"] # Mock the client.upload_logs method - monkeypatch.setattr( - "ankihub.gui.errors.AnkiHubClient.upload_logs", - upload_logs_mock, - ) + mocker.patch.object(AnkiHubClient, "upload_logs", side_effect=upload_logs_mock) # Start the upload in the background and wait until it is finished. upload_logs_and_data_in_background() @@ -4471,7 +4403,7 @@ def test_delete_ankihub_private_config_on_deckBrowser__delete_option( anki_session_with_addon_data: AnkiSession, install_sample_ah_deck: InstallSampleAHDeck, qtbot: QtBot, - mock_function: MockFunction, + mocker: MockerFixture, ): entry_point.run() @@ -4489,13 +4421,13 @@ def test_delete_ankihub_private_config_on_deckBrowser__delete_option( assert deck_uuid # Will control the conditional responsible to delete or not the ankihub deck private config - mock_function(deckbrowser, "ask_user", return_value=True) + mocker.patch("ankihub.gui.deckbrowser.ask_user", return_value=True) - with patch.object( + unsubscribe_from_deck_mock = mocker.patch.object( AnkiHubClient, "unsubscribe_from_deck" - ) as unsubscribe_from_deck_mock: - mw.deckBrowser._delete(anki_deck_id) - unsubscribe_from_deck_mock.assert_called_once() + ) + mw.deckBrowser._delete(anki_deck_id) + unsubscribe_from_deck_mock.assert_called_once() qtbot.wait(500) @@ -4523,7 +4455,7 @@ def test_not_delete_ankihub_private_config_on_deckBrowser__delete_option( anki_session_with_addon_data: AnkiSession, install_sample_ah_deck: InstallSampleAHDeck, qtbot: QtBot, - mock_function: MockFunction, + mocker: MockerFixture, ): entry_point.run() @@ -4539,7 +4471,7 @@ def test_not_delete_ankihub_private_config_on_deckBrowser__delete_option( assert deck_uuid # Will control the conditional responsible to delete or not the ankihub deck private config - mock_function(deckbrowser, "ask_user", return_value=False) + mocker.patch("ankihub.gui.deckbrowser.ask_user", return_value=False) mw.deckBrowser._delete(anki_deck_id) qtbot.wait(500) @@ -4552,8 +4484,8 @@ def test_not_delete_ankihub_private_config_on_deckBrowser__delete_option( @pytest.mark.qt_no_exception_capture class TestAHDBCheck: - def test_with_nothing_missing(self, qtbot: QtBot): - on_done_mock = Mock() + def test_with_nothing_missing(self, qtbot: QtBot, mocker: MockerFixture): + on_done_mock = mocker.stub() check_ankihub_db(on_done_mock) qtbot.wait_until(lambda: on_done_mock.call_count == 1) @@ -4573,7 +4505,7 @@ def test_with_deck_missing_from_config( import_ah_note: ImportAHNote, mock_download_and_install_deck_dependencies: MockDownloadAndInstallDeckDependencies, ankihub_basic_note_type: NotetypeDict, - mock_function: MockFunction, + mocker: MockerFixture, qtbot: QtBot, user_confirms: bool, deck_exists_on_ankihub: bool, @@ -4601,17 +4533,17 @@ def raise_404(*args, **kwargs) -> None: response_404.status_code = 404 raise AnkiHubHTTPError(response=response_404) - mock_function( + mocker.patch.object( AnkiHubClient, "get_deck_by_id", side_effect=raise_404, ) # Mock ask_user function - mock_function(ah_db_check, "ask_user", return_value=user_confirms) + mocker.patch.object(ah_db_check, "ask_user", return_value=user_confirms) # Run the db check - on_done_mock = Mock() + on_done_mock = mocker.stub() check_ankihub_db(on_done_mock) qtbot.wait_until(lambda: on_done_mock.call_count == 1) diff --git a/tests/addon/test_unit.py b/tests/addon/test_unit.py index 749c28da6..f220db865 100644 --- a/tests/addon/test_unit.py +++ b/tests/addon/test_unit.py @@ -17,10 +17,9 @@ from anki.decks import DeckId from anki.models import NotetypeDict from anki.notes import Note, NoteId -from aqt import utils from aqt.qt import QDialog, QDialogButtonBox, Qt, QTimer, QWidget -from pytest import MonkeyPatch from pytest_anki import AnkiSession +from pytest_mock import MockerFixture from pytestqt.qtbot import QtBot # type: ignore from requests import Response @@ -38,7 +37,6 @@ from ..fixtures import ( # type: ignore ImportAHNoteType, InstallAHDeck, - MockFunction, MockStudyDeckDialogWithCB, NewNoteWithNoteType, SetFeatureFlagState, @@ -65,7 +63,6 @@ from ankihub.db.db import _AnkiHubDB from ankihub.db.exceptions import IntegrityError, LockAcquisitionTimeoutError from ankihub.feature_flags import _FeatureFlags, feature_flags -from ankihub.gui import errors, suggestion_dialog from ankihub.gui.error_dialog import ErrorDialog from ankihub.gui.errors import ( OUTDATED_CLIENT_ERROR_REASON, @@ -76,7 +73,7 @@ ) from ankihub.gui.media_sync import media_sync from ankihub.gui.menu import AnkiHubLogin -from ankihub.gui.operations import AddonQueryOp, deck_creation +from ankihub.gui.operations import AddonQueryOp from ankihub.gui.operations.deck_creation import ( DeckCreationConfirmationDialog, create_collaborative_deck, @@ -180,14 +177,11 @@ def test_update_media_names_on_notes( class TestMediaSyncMediaDownload: - def test_with_exception(self, mock_function: MockFunction, qtbot: QtBot): - def raise_exception() -> None: - raise Exception("test") - - update_and_download_mock = mock_function( + def test_with_exception(self, mocker: MockerFixture, qtbot: QtBot): + update_and_download_mock = mocker.patch.object( media_sync, "_update_deck_media_and_download_missing_media", - side_effect=raise_exception, + side_effect=Exception("test"), ) with qtbot.captureExceptions() as exceptions: @@ -204,19 +198,15 @@ class TestMediaSyncMediaUpload: def test_with_exception( self, anki_session_with_addon_data: AnkiSession, - mock_function: MockFunction, + mocker: MockerFixture, qtbot: QtBot, next_deterministic_uuid, ): with anki_session_with_addon_data.profile_loaded(): - - def raise_exception() -> None: - raise Exception("test") - - upload_media_mock = mock_function( + upload_media_mock = mocker.patch.object( media_sync._client, "upload_media", - side_effect=raise_exception, + side_effect=Exception("test"), ) with qtbot.captureExceptions() as exceptions: @@ -509,12 +499,16 @@ def test_add_subdeck_tags_to_notes_with_spaces_in_deck_name( class TestAnkiHubLoginDialog: - def test_login(self, qtbot: QtBot, mock_function: MockFunction): + def test_login( + self, + qtbot: QtBot, + mocker: MockerFixture, + ): username = "test_username" password = "test_password" token = "test_token" - login_mock = mock_function( + login_mock = mocker.patch( "ankihub.gui.menu.AnkiHubClient.login", return_value=token ) @@ -560,8 +554,9 @@ def test_visibility_of_form_elements_and_form_result( source_type: SourceType, media_was_added: bool, qtbot: QtBot, + mocker: MockerFixture, ): - callback_mock = Mock() + callback_mock = mocker.stub() dialog = SuggestionDialog( is_for_anking_deck=is_for_anking_deck, is_new_note_suggestion=is_new_note_suggestion, @@ -643,8 +638,10 @@ def test_visibility_of_form_elements_and_form_result( False, ], ) - def test_submit_without_review_checkbox(self, can_submit_without_review: bool): - callback_mock = Mock() + def test_submit_without_review_checkbox( + self, can_submit_without_review: bool, mocker: MockerFixture + ): + callback_mock = mocker.stub() dialog = SuggestionDialog( is_for_anking_deck=False, is_new_note_suggestion=False, @@ -737,7 +734,7 @@ def __call__(self, user_cancels: bool) -> Tuple[Mock, Mock]: @pytest.fixture def mock_dependiencies_for_suggestion_dialog( - mock_function: MockFunction, + mocker: MockerFixture, mock_suggestion_dialog, ) -> MockDependenciesForSuggestionDialog: """Mocks the dependencies for open_suggestion_dialog_for_note. @@ -750,13 +747,11 @@ def mock_dependencies_for_suggestion_dialog_inner( ) -> Tuple[Mock, Mock]: mock_suggestion_dialog(user_cancels=user_cancels) - suggest_note_update_mock = mock_function( - suggestion_dialog, - "suggest_note_update", + suggest_note_update_mock = mocker.patch( + "ankihub.gui.suggestion_dialog.suggest_note_update" ) - suggest_new_note_mock = mock_function( - suggestion_dialog, - "suggest_new_note", + suggest_new_note_mock = mocker.patch( + "ankihub.gui.suggestion_dialog.suggest_new_note" ) return suggest_note_update_mock, suggest_new_note_mock @@ -820,7 +815,7 @@ def test_with_new_note_which_could_belong_to_two_decks( import_ah_note_type: ImportAHNoteType, new_note_with_note_type: NewNoteWithNoteType, mock_dependiencies_for_suggestion_dialog: MockDependenciesForSuggestionDialog, - mock_function: MockFunction, + mocker: MockerFixture, user_cancels: bool, ): with anki_session_with_addon_data.profile_loaded(): @@ -838,9 +833,8 @@ def test_with_new_note_which_could_belong_to_two_decks( suggest_new_note_mock, ) = mock_dependiencies_for_suggestion_dialog(user_cancels=False) - choose_ankihub_deck_mock = mock_function( - suggestion_dialog, - "choose_ankihub_deck", + choose_ankihub_deck_mock = mocker.patch( + "ankihub.gui.suggestion_dialog.choose_ankihub_deck", return_value=None if user_cancels else ah_did_1, ) @@ -868,8 +862,8 @@ def __call__(self, user_cancels: bool) -> Mock: @pytest.fixture def mock_dependencies_for_bulk_suggestion_dialog( - monkeypatch: MonkeyPatch, mock_suggestion_dialog, + mocker: MockerFixture, ) -> MockDependenciesForBulkSuggestionDialog: """Mocks the dependencies for open_suggestion_dialog_for_bulk_suggestion. Returns a Mock that replaces suggest_notes_in_bulk. @@ -879,15 +873,11 @@ def mock_dependencies_for_bulk_suggestion_dialog( def mock_dependencies_for_suggestion_dialog_inner(user_cancels: bool) -> Mock: mock_suggestion_dialog(user_cancels=user_cancels) - suggest_notes_in_bulk_mock = Mock() - monkeypatch.setattr( + suggest_notes_in_bulk_mock = mocker.patch( "ankihub.gui.suggestion_dialog.suggest_notes_in_bulk", - suggest_notes_in_bulk_mock, ) - monkeypatch.setattr( - "ankihub.gui.suggestion_dialog._on_suggest_notes_in_bulk_done", Mock() - ) + mocker.patch("ankihub.gui.suggestion_dialog._on_suggest_notes_in_bulk_done") return suggest_notes_in_bulk_mock return mock_dependencies_for_suggestion_dialog_inner @@ -967,7 +957,7 @@ def test_with_two_new_notes_with_decks_in_common( import_ah_note_type: ImportAHNoteType, new_note_with_note_type: NewNoteWithNoteType, mock_dependencies_for_bulk_suggestion_dialog: MockDependenciesForBulkSuggestionDialog, - mock_function: MockFunction, + mocker: MockerFixture, qtbot: QtBot, ): with anki_session_with_addon_data.profile_loaded(): @@ -981,9 +971,8 @@ def test_with_two_new_notes_with_decks_in_common( nids = [note_1.id, note_2.id] - choose_ankihub_deck_mock = mock_function( - suggestion_dialog, - "choose_ankihub_deck", + choose_ankihub_deck_mock = mocker.patch( + "ankihub.gui.suggestion_dialog.choose_ankihub_deck", return_value=ah_did_1, ) suggest_notes_in_bulk_mock = mock_dependencies_for_bulk_suggestion_dialog( @@ -1007,12 +996,9 @@ def test_with_two_new_notes_with_decks_in_common( class TestOnSuggestNotesInBulkDone: def test_correct_message_is_shown( self, - mock_function: MockFunction, + mocker: MockerFixture, ): - showText_mock = mock_function( - suggestion_dialog, - "showText", - ) + showText_mock = mocker.patch("ankihub.gui.suggestion_dialog.showText") nid_1 = NoteId(1) nid_2 = NoteId(2) _on_suggest_notes_in_bulk_done( @@ -1056,13 +1042,15 @@ def test_with_exception_in_future(self): parent=aqt.mw, ) - def test_with_http_403_exception_in_future(self, mock_function: MockFunction): + def test_with_http_403_exception_in_future(self, mocker: MockerFixture): response = Response() response.status_code = 403 response.json = lambda: {"detail": "test"} # type: ignore exception = AnkiHubHTTPError(response) - show_error_dialog_mock = mock_function(suggestion_dialog, "show_error_dialog") + show_error_dialog_mock = mocker.patch( + "ankihub.gui.suggestion_dialog.show_error_dialog", + ) _on_suggest_notes_in_bulk_done( future=future_with_exception(exception), @@ -1538,12 +1526,12 @@ class TestAnkiHubDBContextManagers: def test_blocking_and_timeout_behavior( self, anki_session_with_addon_data: AnkiSession, - monkeypatch: MonkeyPatch, qtbot: QtBot, + mocker: MockerFixture, task_configs: List[Tuple[Callable[[], ContextManager], float]], task_times_out: bool, ): - monkeypatch.setattr("ankihub.db.rw_lock.LOCK_TIMEOUT_SECONDS", 0.2) + mocker.patch("ankihub.db.rw_lock.LOCK_TIMEOUT_SECONDS", 0.2) def task(context_manager: Callable[[], ContextManager], duration: float): with context_manager(): @@ -1601,9 +1589,9 @@ def test_contains_path_to_this_addon(self): "\\addons21\\12345789\\src\\ankihub\\errors.py" ) - def test_handle_ankihub_401(self, mock_function: MockFunction): + def test_handle_ankihub_401(self, mocker: MockerFixture): # Set up mock for AnkiHub login dialog. - display_login_mock = mock_function(AnkiHubLogin, "display_login") + display_login_mock = mocker.patch.object(AnkiHubLogin, "display_login") handled = _try_handle_exception( exc_type=AnkiHubHTTPError, @@ -1624,11 +1612,11 @@ def test_handle_ankihub_401(self, mock_function: MockFunction): ], ) def test_handle_ankihub_403( - self, mock_function: MockFunction, response_content: str, expected_handled: bool + self, mocker: MockerFixture, response_content: str, expected_handled: bool ): - show_error_dialog_mock = mock_function(errors, "show_error_dialog") + show_error_dialog_mock = mocker.patch("ankihub.gui.errors.show_error_dialog") - response_mock = Mock() + response_mock = mocker.Mock() response_mock.status_code = 403 response_mock.text = response_content response_mock.json = lambda: json.loads(response_content) # type: ignore @@ -1641,8 +1629,8 @@ def test_handle_ankihub_403( assert handled == expected_handled assert show_error_dialog_mock.called == expected_handled - def test_handle_ankihub_406(self, mock_function: MockFunction): - ask_user_mock = mock_function(errors, "ask_user", return_value=False) + def test_handle_ankihub_406(self, mocker: MockerFixture): + ask_user_mock = mocker.patch("ankihub.gui.errors.ask_user", return_value=False) handled = _try_handle_exception( exc_type=AnkiHubHTTPError, exc_value=AnkiHubHTTPError( @@ -1655,22 +1643,18 @@ def test_handle_ankihub_406(self, mock_function: MockFunction): def test_show_error_dialog( - anki_session_with_addon_data: AnkiSession, mock_function: MockFunction, qtbot: QtBot + anki_session_with_addon_data: AnkiSession, mocker: MockerFixture, qtbot: QtBot ): with anki_session_with_addon_data.profile_loaded(): - show_dialog_mock = mock_function("ankihub.gui.utils.show_dialog") + show_dialog_mock = mocker.patch("ankihub.gui.utils.show_dialog") show_error_dialog("some message", title="some title", parent=aqt.mw) qtbot.wait_until(lambda: show_dialog_mock.called) class TestUploadLogs: - def test_basic( - self, - qtbot: QtBot, - mock_function: MockFunction, - ): - on_done_mock = Mock() - upload_logs_mock = mock_function(AddonAnkiHubClient, "upload_logs") + def test_basic(self, qtbot: QtBot, mocker: MockerFixture): + on_done_mock = mocker.stub() + upload_logs_mock = mocker.patch.object(AddonAnkiHubClient, "upload_logs") upload_logs_in_background(on_done=on_done_mock) qtbot.wait_until(lambda: on_done_mock.called) @@ -1697,18 +1681,15 @@ def test_basic( def test_with_exception( self, qtbot: QtBot, - mock_function: MockFunction, - exception: Exception, + mocker: MockerFixture, expected_report_exception_called: bool, + exception: Exception, ): - def raise_exception(*args, **kwargs) -> None: - raise exception - - on_done_mock = Mock() - upload_logs_mock = mock_function( - AddonAnkiHubClient, "upload_logs", side_effect=raise_exception + on_done_mock = mocker.stub() + upload_logs_mock = mocker.patch.object( + AddonAnkiHubClient, "upload_logs", side_effect=exception ) - report_exception_mock = mock_function(errors, "_report_exception") + report_exception_mock = mocker.patch("ankihub.gui.errors._report_exception") upload_logs_in_background(on_done=on_done_mock) qtbot.wait(500) @@ -1757,7 +1738,7 @@ def foo(on_done: Callable[[], None]) -> None: assert execution_counter == 11 -def test_error_dialog(qtbot: QtBot, monkeypatch: MonkeyPatch): +def test_error_dialog(qtbot: QtBot, mocker: MockerFixture): try: raise Exception("test") except Exception as e: @@ -1771,8 +1752,7 @@ def test_error_dialog(qtbot: QtBot, monkeypatch: MonkeyPatch): dialog.debug_info_button.click() # Check that the Yes button opens a link (to the AnkiHub forum). - open_link_mock = Mock() - monkeypatch.setattr(utils, "openLink", open_link_mock) + open_link_mock = mocker.patch("aqt.utils.openLink") dialog.button_box.button(QDialogButtonBox.StandardButton.Yes).click() open_link_mock.assert_called_once() @@ -1862,7 +1842,7 @@ def __call__(self, deck_name: str) -> None: @pytest.fixture def mock_ui_for_create_collaborative_deck( - mock_function: MockFunction, + mocker: MockerFixture, mock_study_deck_dialog_with_cb: MockStudyDeckDialogWithCB, ) -> MockUIForCreateCollaborativeDeck: """Mock the UI interaction for creating a collaborative deck. @@ -1872,9 +1852,9 @@ def mock_ui_interaction_inner(deck_name) -> None: mock_study_deck_dialog_with_cb( "ankihub.gui.operations.deck_creation.StudyDeck", deck_name ) - mock_function(deck_creation, "ask_user", return_value=True) - mock_function(deck_creation, "showInfo") - mock_function(DeckCreationConfirmationDialog, "run", return_value=True) + mocker.patch("ankihub.gui.operations.deck_creation.ask_user", return_value=True) + mocker.patch("ankihub.gui.operations.deck_creation.showInfo") + mocker.patch.object(DeckCreationConfirmationDialog, "run", return_value=True) return mock_ui_interaction_inner @@ -1888,7 +1868,7 @@ class TestCreateCollaborativeDeck: def test_basic( self, anki_session_with_addon_data: AnkiSession, - mock_function: MockFunction, + mocker: MockerFixture, next_deterministic_uuid: Callable[[], uuid.UUID], qtbot: QtBot, mock_ui_for_create_collaborative_deck: MockUIForCreateCollaborativeDeck, @@ -1902,32 +1882,29 @@ def test_basic( mock_ui_for_create_collaborative_deck(deck_name) - mock_function(AnkiHubClient, "get_owned_decks", return_value=[]) - - def raise_exception(*args, **kwargs) -> None: - raise Exception("test") + mocker.patch.object(AnkiHubClient, "get_owned_decks", return_value=[]) ah_did = next_deterministic_uuid() notes_data = [NoteInfoFactory.create()] - create_ankihub_deck_mock = mock_function( - deck_creation, - "create_ankihub_deck", + create_ankihub_deck_mock = mocker.patch( + "ankihub.gui.operations.deck_creation.create_ankihub_deck", return_value=DeckCreationResult( ankihub_did=ah_did, notes_data=notes_data, ), - side_effect=raise_exception if creating_deck_fails else None, + side_effect=Exception("test") if creating_deck_fails else None, ) - get_media_names_from_notes_data_mock = mock_function( - deck_creation, - "get_media_names_from_notes_data", + get_media_names_from_notes_data_mock = mocker.patch( + "ankihub.gui.operations.deck_creation.get_media_names_from_notes_data", return_value=[], ) - start_media_upload_mock = mock_function( - deck_creation.media_sync, "start_media_upload" + start_media_upload_mock = mocker.patch.object( + media_sync, "start_media_upload" + ) + showInfo_mock = mocker.patch( + "ankihub.gui.operations.deck_creation.showInfo" ) - showInfo_mock = mock_function(deck_creation, "showInfo") # Create the AnkiHub deck. if creating_deck_fails: @@ -1950,7 +1927,7 @@ def raise_exception(*args, **kwargs) -> None: def test_with_deck_name_existing( self, anki_session_with_addon_data: AnkiSession, - mock_function: MockFunction, + mocker: MockerFixture, mock_ui_for_create_collaborative_deck: MockUIForCreateCollaborativeDeck, ): """When the user already has a deck with the same name, the deck creation is cancelled and @@ -1963,7 +1940,7 @@ def test_with_deck_name_existing( mock_ui_for_create_collaborative_deck(deck_name) - mock_function( + mocker.patch.object( AnkiHubClient, "get_owned_decks", return_value=[ @@ -1973,10 +1950,11 @@ def test_with_deck_name_existing( ], ) - showInfo_mock = mock_function(deck_creation, "showInfo") - create_ankihub_deck_mock = mock_function( - deck_creation, - "create_ankihub_deck", + showInfo_mock = mocker.patch( + "ankihub.gui.operations.deck_creation.showInfo" + ) + create_ankihub_deck_mock = mocker.patch( + "ankihub.gui.operations.deck_creation.create_ankihub_deck" ) create_collaborative_deck() @@ -2204,7 +2182,7 @@ def test_with_two_reviews_for_one_deck( anki_session_with_addon_data: AnkiSession, install_ah_deck: InstallAHDeck, import_ah_note: ImportAHNote, - mock_function: MockFunction, + mocker: MockerFixture, ) -> None: with anki_session_with_addon_data.profile_loaded(): ah_did = install_ah_deck() @@ -2217,7 +2195,7 @@ def test_with_two_reviews_for_one_deck( second_review_time = first_review_time + timedelta(days=1) record_review_for_anki_nid(NoteId(note_info_2.anki_nid), second_review_time) - send_card_review_data_mock = mock_function( + send_card_review_data_mock = mocker.patch.object( AnkiHubClient, "send_card_review_data" ) @@ -2243,14 +2221,14 @@ def test_without_reviews( self, anki_session_with_addon_data: AnkiSession, install_ah_deck: InstallAHDeck, - mock_function: MockFunction, + mocker: MockerFixture, ) -> None: with anki_session_with_addon_data.profile_loaded(): # We install the deck so that we get coverage for the case where a deck # has no reviews. install_ah_deck() - send_card_review_data_mock = mock_function( + send_card_review_data_mock = mocker.patch.object( AnkiHubClient, "send_card_review_data" ) @@ -2397,7 +2375,7 @@ def test_submit_tags_for_validated_groups( qtbot: QtBot, install_ah_deck: InstallAHDeck, import_ah_note: ImportAHNote, - mock_function: MockFunction, + mocker: MockerFixture, ): with anki_session_with_addon_data.profile_loaded(): ah_did = install_ah_deck() @@ -2434,12 +2412,12 @@ def test_submit_tags_for_validated_groups( ) validation_responses.append(validation_reponse) - get_deck_extensions_mock = mock_function( + get_deck_extensions_mock = mocker.patch( "ankihub.gui.optional_tag_suggestion_dialog.AnkiHubClient.get_deck_extensions", return_value=deck_extensions, ) - prevalidate_tag_groups_mock = mock_function( + prevalidate_tag_groups_mock = mocker.patch( "ankihub.main.optional_tag_suggestions.AnkiHubClient.prevalidate_tag_groups", return_value=validation_responses, ) @@ -2451,7 +2429,7 @@ def test_submit_tags_for_validated_groups( ) # Mock the suggest_tags_for_groups method which is called when the submit button is clicked - suggest_tags_for_groups_mock = mock_function( + suggest_tags_for_groups_mock = mocker.patch.object( dialog._optional_tags_helper, "suggest_tags_for_groups" ) @@ -2485,7 +2463,7 @@ def test_submit_without_review_checkbox_hidden_when_user_cant_use_it( qtbot: QtBot, install_ah_deck: InstallAHDeck, import_ah_note: ImportAHNote, - mock_function: MockFunction, + mocker: MockerFixture, user_relation: UserDeckExtensionRelation, expected_checkbox_is_visible: bool, ): @@ -2507,12 +2485,12 @@ def test_submit_without_review_checkbox_hidden_when_user_cant_use_it( deck_extension_id=deck_extension.id, ) - mock_function( + mocker.patch( "ankihub.gui.optional_tag_suggestion_dialog.AnkiHubClient.get_deck_extensions", return_value=[deck_extension], ) - mock_function( + mocker.patch( "ankihub.main.optional_tag_suggestions.AnkiHubClient.prevalidate_tag_groups", return_value=[validation_reponse], ) diff --git a/tests/client/test_client.py b/tests/client/test_client.py index feafe31dc..e678b27bc 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -10,11 +10,11 @@ from datetime import datetime, timedelta, timezone from pathlib import Path from typing import Callable, Generator, List, cast -from unittest.mock import MagicMock, Mock import pytest import requests_mock -from pytest import FixtureRequest, MonkeyPatch +from pytest import FixtureRequest +from pytest_mock import MockerFixture from requests_mock import Mocker from vcr import VCR # type: ignore @@ -402,13 +402,13 @@ def test_client_login_and_signout_with_email(client_with_server_setup): @pytest.mark.vcr() def test_download_deck( - authorized_client_for_user_test1: AnkiHubClient, monkeypatch: MonkeyPatch + authorized_client_for_user_test1: AnkiHubClient, mocker: MockerFixture ): client = authorized_client_for_user_test1 - - get_presigned_url_suffix = MagicMock() - get_presigned_url_suffix.return_value = "/fake_key" - monkeypatch.setattr(client, "_get_presigned_url_suffix", get_presigned_url_suffix) + presigned_url_suffix = "/fake_key" + mocker.patch.object( + client, "_get_presigned_url_suffix", return_value=presigned_url_suffix + ) original_get_deck_by_id = client.get_deck_by_id @@ -417,11 +417,11 @@ def get_deck_by_id(*args, **kwargs) -> Deck: result.csv_notes_filename = "notes.csv" return result - monkeypatch.setattr(client, "get_deck_by_id", get_deck_by_id) + mocker.patch.object(client, "get_deck_by_id", side_effect=get_deck_by_id) with requests_mock.Mocker(real_http=True) as m: m.get( - f"{client.s3_bucket_url}{get_presigned_url_suffix.return_value}", + f"{client.s3_bucket_url}{presigned_url_suffix}", content=DECK_CSV.read_bytes(), ) notes_data = client.download_deck(ah_did=ID_OF_DECK_OF_USER_TEST1) @@ -431,13 +431,15 @@ def get_deck_by_id(*args, **kwargs) -> Deck: @pytest.mark.vcr() def test_download_compressed_deck( - authorized_client_for_user_test1: AnkiHubClient, monkeypatch: MonkeyPatch + authorized_client_for_user_test1: AnkiHubClient, + mocker: MockerFixture, ): client = authorized_client_for_user_test1 - get_presigned_url_suffix = MagicMock() - get_presigned_url_suffix.return_value = "/fake_key" - monkeypatch.setattr(client, "_get_presigned_url_suffix", get_presigned_url_suffix) + presigned_url_suffix = "/fake_key" + mocker.patch.object( + client, "_get_presigned_url_suffix", return_value=presigned_url_suffix + ) original_get_deck_by_id = client.get_deck_by_id @@ -446,11 +448,11 @@ def get_deck_by_id(*args, **kwargs) -> Deck: result.csv_notes_filename = "notes.csv.gz" return result - monkeypatch.setattr(client, "get_deck_by_id", get_deck_by_id) + mocker.patch.object(client, "get_deck_by_id", side_effect=get_deck_by_id) with requests_mock.Mocker(real_http=True) as m: m.get( - f"{client.s3_bucket_url}{get_presigned_url_suffix.return_value}", + f"{client.s3_bucket_url}{presigned_url_suffix}", content=DECK_CSV_GZ.read_bytes(), ) notes_data = client.download_deck(ah_did=ID_OF_DECK_OF_USER_TEST1) @@ -460,13 +462,14 @@ def get_deck_by_id(*args, **kwargs) -> Deck: @pytest.mark.vcr() def test_download_deck_with_progress( - authorized_client_for_user_test1: AnkiHubClient, monkeypatch: MonkeyPatch + authorized_client_for_user_test1: AnkiHubClient, mocker: MockerFixture ): client = authorized_client_for_user_test1 - get_presigned_url_suffix = MagicMock() - get_presigned_url_suffix.return_value = "/fake_key" - monkeypatch.setattr(client, "_get_presigned_url_suffix", get_presigned_url_suffix) + presigned_url_suffix = "/fake_key" + mocker.patch.object( + client, "_get_presigned_url_suffix", return_value=presigned_url_suffix + ) original_get_deck_by_id = client.get_deck_by_id @@ -475,11 +478,11 @@ def get_deck_by_id(*args, **kwargs) -> Deck: result.csv_notes_filename = "notes.csv" return result - monkeypatch.setattr(client, "get_deck_by_id", get_deck_by_id) + mocker.patch.object(client, "get_deck_by_id", side_effect=get_deck_by_id) with requests_mock.Mocker(real_http=True) as m: m.get( - f"{client.s3_bucket_url}{get_presigned_url_suffix.return_value}", + f"{client.s3_bucket_url}{presigned_url_suffix}", content=DECK_CSV.read_bytes(), headers={"content-length": "1000000"}, ) @@ -516,7 +519,7 @@ def create_note_on_ankihub_and_assert( def test_upload_deck( authorized_client_for_user_test1: AnkiHubClient, next_deterministic_id: Callable[[], int], - monkeypatch: MonkeyPatch, + mocker: MockerFixture, ): client = authorized_client_for_user_test1 @@ -525,20 +528,16 @@ def test_upload_deck( # create the deck on AnkiHub # upload to s3 is mocked out, this will potentially cause errors on the locally running AnkiHub # because the deck will not be uploaded to s3, but we don't care about that here - upload_to_s3_mock = Mock() - with monkeypatch.context() as m: - m.setattr(client, "_upload_to_s3", upload_to_s3_mock) - m.setattr( - client, "_get_presigned_url_suffix", lambda *args, **kwargs: "fake_key" - ) - - client.upload_deck( - deck_name="test deck", - notes_data=[note_data], - note_types_data=[], - anki_deck_id=next_deterministic_id(), - private=False, - ) + upload_to_s3_mock = mocker.patch.object(client, "_upload_to_s3") + mocker.patch.object(client, "_get_presigned_url_suffix", return_value="fake_key") + + client.upload_deck( + deck_name="test deck", + notes_data=[note_data], + note_types_data=[], + anki_deck_id=next_deterministic_id(), + private=False, + ) # check that the deck would be uploaded to s3 assert upload_to_s3_mock.call_count == 1 @@ -858,14 +857,12 @@ def test_basic( class TestGetDeckUpdates: @pytest.mark.vcr() def test_get_deck_updates( - self, - authorized_client_for_user_test2: AnkiHubClient, - monkeypatch: MonkeyPatch, + self, authorized_client_for_user_test2: AnkiHubClient, mocker: MockerFixture ): client = authorized_client_for_user_test2 page_size = 5 - monkeypatch.setattr( + mocker.patch( "ankihub.ankihub_client.ankihub_client.DECK_UPDATE_PAGE_SIZE", page_size ) update_chunks: List[DeckUpdateChunk] = list( @@ -980,15 +977,13 @@ def test_get_media_since( @pytest.mark.vcr() def test_pagination( - self, - authorized_client_for_user_test1: AnkiHubClient, - monkeypatch: MonkeyPatch, + self, authorized_client_for_user_test1: AnkiHubClient, mocker: MockerFixture ): client = authorized_client_for_user_test1 # Set page size to 1 so that we can test pagination page_size = 1 - monkeypatch.setattr( + mocker.patch( "ankihub.ankihub_client.ankihub_client.DECK_MEDIA_UPDATE_PAGE_SIZE", page_size, ) @@ -1085,9 +1080,7 @@ def test_get_note_customizations_by_deck_extension_id( @pytest.mark.vcr() -def test_get_media_disabled_fields( - authorized_client_for_user_test1: AnkiHubClient, monkeypatch: MonkeyPatch -): +def test_get_media_disabled_fields(authorized_client_for_user_test1: AnkiHubClient): client = authorized_client_for_user_test1 deck_uuid = ID_OF_DECK_OF_USER_TEST1 @@ -1125,14 +1118,16 @@ def test_media_upload_finished(authorized_client_for_user_test1: AnkiHubClient): @pytest.mark.vcr() def test_get_note_customizations_by_deck_extension_id_in_multiple_chunks( - authorized_client_for_user_test1: AnkiHubClient, monkeypatch: MonkeyPatch + authorized_client_for_user_test1: AnkiHubClient, mocker: MockerFixture ): client = authorized_client_for_user_test1 deck_extension_id = 999 - monkeypatch.setattr( - "ankihub.ankihub_client.ankihub_client.DECK_EXTENSION_UPDATE_PAGE_SIZE", 1 + page_size = 1 + mocker.patch( + "ankihub.ankihub_client.ankihub_client.DECK_EXTENSION_UPDATE_PAGE_SIZE", + page_size, ) expected_chunk_1 = DeckExtensionUpdateChunk( @@ -1307,7 +1302,7 @@ def test_upload_media_for_suggestion( self, suggestion_type: str, requests_mock: Mocker, - monkeypatch, + mocker: MockerFixture, next_deterministic_uuid: Callable[[], uuid.UUID], remove_generated_media_files, request: FixtureRequest, @@ -1332,10 +1327,10 @@ def test_upload_media_for_suggestion( suggestion_request_mock = None - monkeypatch.setattr( + mocker.patch.object( AnkiHubClient, "_get_presigned_url_for_multiple_uploads", - lambda *args, **kwargs: { + return_value={ "url": fake_presigned_url, "fields": { "key": "deck_images/test/${filename}", @@ -1448,25 +1443,20 @@ def notes_data_with_many_media_files(self) -> List[NoteInfo]: return notes_data def test_zips_media_files_from_deck_notes( - self, next_deterministic_uuid: Callable[[], uuid.UUID], monkeypatch: MonkeyPatch + self, + next_deterministic_uuid: Callable[[], uuid.UUID], + mocker: MockerFixture, ): client = AnkiHubClient(local_media_dir_path_cb=lambda: TEST_MEDIA_PATH) notes_data = self.notes_data_with_many_media_files() - # Mock os.remove so the zip is not deleted - os_remove_mock = MagicMock() - monkeypatch.setattr(os, "remove", os_remove_mock) - # Mock upload-related stuff - monkeypatch.setattr( - client, "_get_presigned_url_for_multiple_uploads", MagicMock() - ) - monkeypatch.setattr( - client, "_upload_file_to_s3_with_reusable_presigned_url", MagicMock() - ) + mocker.patch.object(client, "_get_presigned_url_for_multiple_uploads") + mocker.patch.object(client, "_upload_file_to_s3_with_reusable_presigned_url") deck_id = next_deterministic_uuid() + remove_mock = mocker.patch("os.remove") self._upload_media_for_notes_data(client, notes_data, deck_id) # We will create and check for just one chunk in this test @@ -1481,12 +1471,12 @@ def test_zips_media_files_from_deck_notes( assert set(zip_ref.namelist()) == set(all_media_names_in_notes) # Remove the zipped file at the end of the test - monkeypatch.undo() + mocker.stop(remove_mock) os.remove(path_to_created_zip_file) assert path_to_created_zip_file.is_file() is False def test_uploads_generated_zipped_file( - self, next_deterministic_uuid: Callable[[], uuid.UUID], monkeypatch: MonkeyPatch + self, next_deterministic_uuid: Callable[[], uuid.UUID], mocker: MockerFixture ): client = AnkiHubClient(local_media_dir_path_cb=lambda: TEST_MEDIA_PATH) @@ -1507,17 +1497,14 @@ def test_uploads_generated_zipped_file( "x-amz-signature": "test_822ac386d1ece605db8cfca", }, } - get_presigned_url_mock = MagicMock() - get_presigned_url_mock.return_value = s3_info_mocked_value - monkeypatch.setattr( - client, "_get_presigned_url_for_multiple_uploads", get_presigned_url_mock + get_presigned_url_mock = mocker.patch.object( + client, + "_get_presigned_url_for_multiple_uploads", + return_value=s3_info_mocked_value, ) - - mocked_upload_file_to_s3 = MagicMock() - monkeypatch.setattr( + mocked_upload_file_to_s3 = mocker.patch.object( client, "_upload_file_to_s3_with_reusable_presigned_url", - mocked_upload_file_to_s3, ) self._upload_media_for_notes_data(client, notes_data, deck_id) @@ -1529,19 +1516,15 @@ def test_uploads_generated_zipped_file( ) def test_removes_zipped_file_after_upload( - self, next_deterministic_uuid: Callable[[], uuid.UUID], monkeypatch: MonkeyPatch + self, next_deterministic_uuid: Callable[[], uuid.UUID], mocker: MockerFixture ): client = AnkiHubClient(local_media_dir_path_cb=lambda: TEST_MEDIA_PATH) notes_data = self.notes_data_with_many_media_files() # Mock upload-related stuff - monkeypatch.setattr( - client, "_get_presigned_url_for_multiple_uploads", MagicMock() - ) - monkeypatch.setattr( - client, "_upload_file_to_s3_with_reusable_presigned_url", MagicMock() - ) + mocker.patch.object(client, "_get_presigned_url_for_multiple_uploads") + mocker.patch.object(client, "_upload_file_to_s3_with_reusable_presigned_url") deck_id = next_deterministic_uuid() self._upload_media_for_notes_data(client, notes_data, deck_id) diff --git a/tests/fixtures.py b/tests/fixtures.py index 71bdf0e4a..2361f2e5e 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -137,39 +137,6 @@ def mock_all_feature_flags_to_default_values_inner() -> None: return mock_all_feature_flags_to_default_values_inner -class MockFunction(Protocol): - def __call__( - self, - *args, - return_value: Optional[Any] = None, - side_effect: Optional[Callable] = None, - ) -> Mock: - ... - - -@pytest.fixture -def mock_function( - monkeypatch: MonkeyPatch, -) -> MockFunction: - def _mock_function( - *args, - return_value: Optional[Any] = None, - side_effect: Optional[Callable] = None, - ) -> Mock: - # The args can be either an object and a function name or the full path to the function as a string. - assert len(args) in [1, 2] - mock = Mock() - mock.return_value = return_value - monkeypatch.setattr( # type: ignore - *args, - mock, - ) - mock.side_effect = side_effect - return mock - - return _mock_function - - class ImportAHNote(Protocol): def __call__( self, From f99ed5c30baecb9cee72fef9b1245871d977507b Mon Sep 17 00:00:00 2001 From: Jakub Fidler <31575114+RisingOrange@users.noreply.github.com> Date: Tue, 2 Jan 2024 13:19:29 +0100 Subject: [PATCH 4/6] chore: fix flaky performance test by increasing allowed duration of the operation (#858) --- tests/addon/performance/test_exporting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/addon/performance/test_exporting.py b/tests/addon/performance/test_exporting.py index 77c42b6eb..704e58f4f 100644 --- a/tests/addon/performance/test_exporting.py +++ b/tests/addon/performance/test_exporting.py @@ -58,4 +58,4 @@ def export_notes(): duration = profile(export_notes) print(f"Exporting {len(notes)} notes took {duration} seconds") - assert duration < 0.1 + assert duration < 0.2 From 648c92305cf18437be3216c8fe5037f764b0ff96 Mon Sep 17 00:00:00 2001 From: Heitor Carvalho Date: Tue, 2 Jan 2024 10:45:57 -0300 Subject: [PATCH 5/6] feat: add confirmation dialog before sign out (#854) --- ankihub/gui/menu.py | 14 ++++++++- tests/addon/test_unit.py | 63 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 75 insertions(+), 2 deletions(-) diff --git a/ankihub/gui/menu.py b/ankihub/gui/menu.py index bdc5fbd1d..53c96c31a 100644 --- a/ankihub/gui/menu.py +++ b/ankihub/gui/menu.py @@ -192,6 +192,18 @@ def _create_collaborative_deck_setup(parent: QMenu): parent.addAction(q_action) +def _confirm_sign_out(): + confirm = ask_user( + "Are you sure you want to Sign out?", + yes_button_label="Sign Out", + no_button_label="Cancel", + ) + if not confirm: + return + + _sign_out_action() + + def _sign_out_action(): try: AnkiHubClient().signout() @@ -459,7 +471,7 @@ def _trigger_install_release_version(): def _ankihub_logout_setup(parent: QMenu): q_action = QAction("🔑 Sign out", aqt.mw) - qconnect(q_action.triggered, _sign_out_action) + qconnect(q_action.triggered, _confirm_sign_out) parent.addAction(q_action) diff --git a/tests/addon/test_unit.py b/tests/addon/test_unit.py index f220db865..8de76459a 100644 --- a/tests/addon/test_unit.py +++ b/tests/addon/test_unit.py @@ -17,16 +17,22 @@ from anki.decks import DeckId from anki.models import NotetypeDict from anki.notes import Note, NoteId +from aqt import QMenu from aqt.qt import QDialog, QDialogButtonBox, Qt, QTimer, QWidget +from pytest import MonkeyPatch from pytest_anki import AnkiSession from pytest_mock import MockerFixture from pytestqt.qtbot import QtBot # type: ignore from requests import Response +from requests_mock import Mocker +from ankihub.ankihub_client.ankihub_client import DEFAULT_API_URL from ankihub.ankihub_client.models import ( # type: ignore CardReviewData, UserDeckExtensionRelation, ) +from ankihub.gui import menu +from ankihub.gui.config_dialog import setup_config_dialog_manager from ..factories import ( DeckExtensionFactory, @@ -72,7 +78,7 @@ upload_logs_in_background, ) from ankihub.gui.media_sync import media_sync -from ankihub.gui.menu import AnkiHubLogin +from ankihub.gui.menu import AnkiHubLogin, menu_state, refresh_ankihub_menu from ankihub.gui.operations import AddonQueryOp from ankihub.gui.operations.deck_creation import ( DeckCreationConfirmationDialog, @@ -498,6 +504,61 @@ def test_add_subdeck_tags_to_notes_with_spaces_in_deck_name( assert note3.tags == [f"{SUBDECK_TAG}::AA::b_b::c_c"] +class TestAnkiHubSignOut: + @pytest.mark.parametrize( + "confirmed_sign_out,expected_logged_in_state", [(True, False), (False, True)] + ) + def test_sign_out( + self, + monkeypatch: MonkeyPatch, + anki_session_with_addon_data: AnkiSession, + requests_mock: Mocker, + confirmed_sign_out: bool, + expected_logged_in_state: bool, + ): + anki_session = anki_session_with_addon_data + + with anki_session.profile_loaded(): + user_token = "random_token_382fasfkjep1flaksnioqwndjk&@*(%248)" + # This means user is logged in + config._private_config.token = user_token + + mw = anki_session.mw + menu_state.ankihub_menu = QMenu("&AnkiHub", parent=aqt.mw) + mw.form.menubar.addMenu(menu_state.ankihub_menu) + setup_config_dialog_manager() + refresh_ankihub_menu() + + sign_out_action = [ + action + for action in menu_state.ankihub_menu.actions() + if action.text() == "🔑 Sign out" + ][0] + + assert sign_out_action is not None + + ask_user_mock = Mock(return_value=confirmed_sign_out) + monkeypatch.setattr(menu, "ask_user", ask_user_mock) + + requests_mock.post(f"{DEFAULT_API_URL}/logout/", status_code=204, json=[]) + + sign_out_action.trigger() + + ask_user_mock.assert_called_once_with( + "Are you sure you want to Sign out?", + yes_button_label="Sign Out", + no_button_label="Cancel", + ) + + if expected_logged_in_state: + expected_token = user_token + else: + expected_token = "" + + assert config.token() == expected_token + assert config.is_logged_in() is expected_logged_in_state + + class TestAnkiHubLoginDialog: def test_login( self, From b31038d561e2229f0c4c340456c22e08baf98510 Mon Sep 17 00:00:00 2001 From: Jakub Fidler <31575114+RisingOrange@users.noreply.github.com> Date: Wed, 3 Jan 2024 12:51:43 +0100 Subject: [PATCH 6/6] chore: Fix flaky optional tag suggestion dialog test (#857) --- tests/addon/test_unit.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/addon/test_unit.py b/tests/addon/test_unit.py index 8de76459a..46d3f3fb3 100644 --- a/tests/addon/test_unit.py +++ b/tests/addon/test_unit.py @@ -2563,5 +2563,6 @@ def test_submit_without_review_checkbox_hidden_when_user_cant_use_it( ) dialog.show() + qtbot.wait(500) assert dialog.auto_accept_cb.isVisible() == expected_checkbox_is_visible