Skip to content

Commit

Permalink
Add test for tktrio
Browse files Browse the repository at this point in the history
  • Loading branch information
CoolCat467 committed Jan 17, 2025
1 parent ce475e2 commit 6b0edcb
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 26 deletions.
2 changes: 1 addition & 1 deletion src/idlemypyextension/annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def read_fstring(
def tokenize_definition(
start_line: int,
get_line: Callable[[int], str],
) -> tuple[list[Token], int]:
) -> tuple[list[Token], int]: # pragma: nocover
"""Return list of Tokens and number of lines after start."""
current_line_no = start_line

Expand Down
49 changes: 24 additions & 25 deletions src/idlemypyextension/tktrio.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from enum import IntEnum, auto
from functools import partial, wraps
from tkinter import messagebox
from traceback import format_exception, print_exception
from traceback import format_exception
from typing import TYPE_CHECKING, Any, TypeGuard

from idlemypyextension import mttkinter, utils
Expand All @@ -41,21 +41,23 @@
with guard_imports({"trio", "exceptiongroup"}):
import trio

if sys.version_info < (3, 11):
if sys.version_info < (3, 11): # pragma: nocover
from exceptiongroup import ExceptionGroup

if TYPE_CHECKING:
from collections.abc import Awaitable, Callable

from outcome import Outcome
from typing_extensions import Self
from typing_extensions import Self, TypeVarTuple, Unpack

PosArgT = TypeVarTuple("PosArgT")


# Use mttkinter somewhere so pycln doesn't eat it
assert mttkinter.TRUE


def debug(message: object) -> None:
def debug(message: object) -> None: # pragma: nocover
"""Print debug message."""
# TODO: Censor username/user files
print(f"\n[{__title__}] DEBUG: {message}")
Expand Down Expand Up @@ -97,7 +99,7 @@ def uninstall_protocol_override(root: tk.Wm) -> None:
If root.protocol has the __wrapped__ attribute, reset it
to what it was originally
"""
if hasattr(root.protocol, "__wrapped__"):
if hasattr(root.protocol, "__wrapped__"): # pragma: nocover
root.protocol = root.protocol.__wrapped__ # type: ignore


Expand All @@ -117,7 +119,7 @@ class RunStatus(IntEnum):
NO_TASK = auto()


def evil_does_trio_have_runner() -> bool:
def evil_does_trio_have_runner() -> bool: # pragma: nocover
"""Evil function to see if trio has a runner."""
core = getattr(trio, "_core", None)
if core is None:
Expand Down Expand Up @@ -183,8 +185,6 @@ def __init__(
restore_close: Callable[[], Any] | None = None,
) -> None:
"""Initialize trio runner."""
if not is_tk_wm_and_misc_subclass(root):
raise ValueError("Must be subclass of both tk.Misc and tk.Wm")
if (
hasattr(root, "__trio__")
and getattr(root, "__trio__", lambda: None)() is self
Expand All @@ -203,8 +203,8 @@ def __init__(

def schedule_task_threadsafe(
self,
function: Callable[..., Any],
*args: Any,
function: Callable[[Unpack[PosArgT]], object],
*args: Unpack[PosArgT],
) -> None:
"""Schedule task in Tkinter's event loop."""
try:
Expand All @@ -213,7 +213,6 @@ def schedule_task_threadsafe(
debug(f"Exception scheduling task {function = }")
# probably "main thread is not in main loop" error
# mtTkinter is supposed to fix this sort of issue
print_exception(exc)
utils.extension_log_exception(exc)

self.cancel_current_task()
Expand Down Expand Up @@ -253,7 +252,6 @@ def cancel_current_task(self) -> None:
# because the exception that triggered this was from
# a start group tick failing because of start_soon
# not running from main thread because thread lock shenanigans
print_exception(exc)
utils.extension_log_exception(exc)

# We can't even show an error properly because of the same
Expand All @@ -263,11 +261,23 @@ def cancel_current_task(self) -> None:
"".join(format_exception(exc)),
)
except RuntimeError as exc:
print_exception(exc)
utils.extension_log_exception(exc)
else:
self.run_status = RunStatus.TRIO_RUNNING_CANCELED

def _done_callback(self, outcome: Outcome[None]) -> None:
"""Handle when trio is done running."""
assert self.run_status in {
RunStatus.TRIO_RUNNING_CANCELED,
RunStatus.TRIO_RUNNING,
}
self.run_status = RunStatus.NO_TASK
del self.nursery
try:
outcome.unwrap()
except ExceptionGroup as exc:
utils.extension_log_exception(exc)

def _start_async_task(
self,
function: Callable[[], Awaitable[Any]],
Expand All @@ -290,18 +300,7 @@ async def run_nursery() -> None:
return

def done_callback(outcome: Outcome[None]) -> None:
"""Handle when trio is done running."""
assert self.run_status in {
RunStatus.TRIO_RUNNING_CANCELED,
RunStatus.TRIO_RUNNING,
}
self.run_status = RunStatus.NO_TASK
del self.nursery
try:
outcome.unwrap()
except ExceptionGroup as exc:
print_exception(exc)
utils.extension_log_exception(exc)
self._done_callback(outcome)

if self.run_status != RunStatus.NO_TASK:
raise RuntimeError(
Expand Down
21 changes: 21 additions & 0 deletions tests/test_annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from io import StringIO
from tokenize import generate_tokens
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch

import pytest

Expand All @@ -15,6 +16,17 @@
from collections.abc import Collection, Sequence


@pytest.fixture(autouse=True)
def mock_extension_log() -> MagicMock:
"""Fixture to override extension_log with an empty function."""
with patch(
"idlemypyextension.utils.extension_log",
return_value=None,
) as mock_log:
with patch("idlemypyextension.annotate.extension_log", new=mock_log):
yield mock_log


def test_parse_error() -> None:
with pytest.raises(annotate.ParseError, match=""):
raise annotate.ParseError()
Expand Down Expand Up @@ -613,6 +625,15 @@ def get_line(line_no: int) -> str:
"""def bad_default_arg(name_map: set[str] = {"jerald", "cat", "bob"}) -> str:""",
None,
),
(
"""def lambda_arg_test(call_func = lambda x, y: x+y):""",
[
"Callable[[int, int], int]",
],
"int",
"""def lambda_arg_test(call_func: Callable[[int, int], int] = lambda x, y: x+y) -> int:""",
None,
),
],
)
def test_get_annotation(
Expand Down
182 changes: 182 additions & 0 deletions tests/test_tktrio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
from __future__ import annotations

from collections.abc import Callable
from functools import partial
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch

import pytest
import trio

from idlemypyextension.tktrio import RunStatus, TkTrioRunner

if TYPE_CHECKING:
import tkinter as tk
from collections.abc import Callable

from typing_extensions import TypeVarTuple, Unpack

PosArgT = TypeVarTuple("PosArgT")


@pytest.fixture(autouse=True)
def mock_extension_log() -> MagicMock:
"""Fixture to override extension_log with an empty function."""
with patch(
"idlemypyextension.utils.extension_log",
return_value=None,
) as mock_log:
yield mock_log


class FakeTK:
"""Fake tkinter root."""

__slots__ = ("__dict__", "should_stop", "tasks")

def __init__(self) -> None:
"""Initialize self."""
self.tasks: list[Callable[..., object]] = []
self.should_stop = False

def after_idle(
self,
function: Callable[[Unpack[PosArgT]], object],
*args: Unpack[PosArgT],
) -> None:
"""Add function to run queue."""
self.tasks.append(partial(function, *args))

def update(self) -> None:
"""Run one task from queue."""
if self.should_stop:
raise RuntimeError(f"{self.should_stop = }")
if self.tasks:
self.tasks.pop(0)()

def destroy(self) -> None:
"""Mark update to fail if called again."""
self.should_stop = True

def withdraw(self) -> None:
"""Fake hide window."""

def protocol(
self,
name: str | None,
func: Callable[[], None] | None = None,
) -> None:
"""Fake protocol."""


@pytest.fixture
def mock_tk() -> FakeTK:
"""Fixture to create a mock tkinter root."""
return FakeTK()


@pytest.fixture
def trio_runner(mock_tk: tk.Tk) -> TkTrioRunner:
"""Fixture to create a TkTrioRunner instance."""
with patch(
"idlemypyextension.tktrio.is_tk_wm_and_misc_subclass",
return_value=True,
):
return TkTrioRunner(mock_tk)


def test_initialization(trio_runner: TkTrioRunner) -> None:
"""Test initialization of TkTrioRunner."""
assert trio_runner.run_status == RunStatus.NO_TASK
assert trio_runner.root is not None


def test_new_gives_copy(mock_tk: tk.Tk) -> None:
with patch(
"idlemypyextension.tktrio.is_tk_wm_and_misc_subclass",
return_value=True,
):
runner = TkTrioRunner(mock_tk)
runner2 = TkTrioRunner(mock_tk)
assert runner is runner2


def test_invalid_initialization() -> None:
"""Test initialization with invalid root."""
with pytest.raises(
ValueError,
match=r"^Must be subclass of both tk\.Misc and tk\.Wm$",
):
TkTrioRunner(None)


def test_schedule_task_threadsafe(trio_runner: TkTrioRunner) -> None:
"""Test scheduling a task in the Tkinter event loop."""
mock_function = MagicMock()
trio_runner.schedule_task_threadsafe(mock_function)

# Process the scheduled tasks
trio_runner.root.update()

mock_function.assert_called_once()


def test_cancel_current_task(trio_runner: TkTrioRunner) -> None:
"""Test canceling the current task."""
assert trio_runner.run_status != RunStatus.TRIO_RUNNING

async def test() -> None:
while True:
await trio.lowlevel.checkpoint()

trio_runner(test)

while trio_runner.run_status == RunStatus.TRIO_STARTING:
trio_runner.root.update()

nursery = trio_runner.nursery

trio_runner.cancel_current_task()

assert trio_runner.run_status == RunStatus.TRIO_RUNNING_CANCELED

while trio_runner.run_status != RunStatus.NO_TASK:
trio_runner.root.update()

assert nursery.cancel_scope.cancel_called

trio_runner.cancel_current_task()


def test_get_del_window_proto(trio_runner: TkTrioRunner) -> None:
"""Test the WM_DELETE_WINDOW protocol."""
new_protocol = MagicMock()
shutdown_function = trio_runner.get_del_window_proto(new_protocol)

# Call the shutdown function
shutdown_function()

# Check if the new protocol was called
new_protocol.assert_called_once()


def test_show_warning_trio_already_running(trio_runner: TkTrioRunner) -> None:
"""Test showing warning when Trio is already running."""
with patch("tkinter.messagebox.showerror") as mock_showerror:
trio_runner.show_warning_trio_already_running()
mock_showerror.assert_called_once_with(
title="Error: Trio is already running",
message="Trio is already running from somewhere else, please try again later.",
parent=trio_runner.root,
)


def test_no_start_trio_is_stopping(trio_runner: TkTrioRunner) -> None:
"""Test showing warning when Trio is stopping."""
with patch("tkinter.messagebox.showwarning") as mock_showwarning:
trio_runner.no_start_trio_is_stopping()
mock_showwarning.assert_called_once_with(
title="Warning: Trio is stopping a previous run",
message="Trio is in the process of stopping, please try again later.",
parent=trio_runner.root,
)

0 comments on commit 6b0edcb

Please sign in to comment.