Skip to content

Commit

Permalink
improved exception handling/logging in apply meta task
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-berg committed Mar 31, 2022
1 parent 80fc91e commit 0daed04
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 36 deletions.
4 changes: 2 additions & 2 deletions src/asyncio_taskpool/internals/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from asyncio.streams import StreamReader, StreamWriter
from pathlib import Path
from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, TypeVar, Union
from typing import Any, Awaitable, Callable, Coroutine, Iterable, Mapping, Tuple, TypeVar, Union


T = TypeVar('T')
Expand All @@ -32,7 +32,7 @@
KwArgsT = Mapping[str, Any]

AnyCallableT = Callable[[...], Union[T, Awaitable[T]]]
CoroutineFunc = Callable[[...], Awaitable[Any]]
CoroutineFunc = Callable[[...], Coroutine]

EndCB = Callable
CancelCB = Callable
Expand Down
47 changes: 33 additions & 14 deletions src/asyncio_taskpool/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,8 @@ async def gather_and_close(self, return_exceptions: bool = False):
self._tasks_cancelled.clear()
self._tasks_running.clear()
self._closed = True
# TODO: Turn the `_closed` attribute into an `Event` and add something like a `until_closed` method that will
# await it to allow blocking until a closing command comes from a server.


class TaskPool(BaseTaskPool):
Expand Down Expand Up @@ -566,36 +568,51 @@ def _generate_group_name(self, prefix: str, coroutine_function: CoroutineFunc) -
return name
i += 1

async def _apply_num(self, group_name: str, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
num: int = 1, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
async def _apply_spawner(self, group_name: str, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
num: int = 1, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
"""
Creates a coroutine with the supplied arguments and runs it as a new task in the pool.
Creates coroutines with the supplied arguments and runs them as new tasks in the pool.
This method blocks, **only if** the pool has not enough room to accommodate `num` new tasks.
Args:
group_name:
Name of the task group to add the new task to.
Name of the task group to add the new tasks to.
func:
The coroutine function to be run as a task within the task pool.
The coroutine function to be run in `num` tasks within the task pool.
args (optional):
The positional arguments to pass into the function call.
The positional arguments to pass into each function call.
kwargs (optional):
The keyword-arguments to pass into the function call.
The keyword-arguments to pass into each function call.
num (optional):
The number of tasks to spawn with the specified parameters.
end_callback (optional):
A callback to execute after the task has ended.
A callback to execute after each task has ended.
It is run with the task's ID as its only positional argument.
cancel_callback (optional):
A callback to execute after cancellation of the task.
A callback to execute after cancellation of each task.
It is run with the task's ID as its only positional argument.
"""
if kwargs is None:
kwargs = {}
# TODO: Add exception logging
await gather(*(self._start_task(func(*args, **kwargs), group_name=group_name, end_callback=end_callback,
cancel_callback=cancel_callback) for _ in range(num)))
for i in range(num):
try:
coroutine = func(*args, **kwargs)
except Exception as e:
# This means there was probably something wrong with the function arguments.
log.exception("%s occurred in group '%s' while trying to create coroutine: %s(*%s, **%s)",
str(e.__class__.__name__), group_name, func.__name__, repr(args), repr(kwargs))
continue
try:
await self._start_task(coroutine, group_name=group_name, end_callback=end_callback,
cancel_callback=cancel_callback)
except CancelledError:
# Either the task group or all tasks were cancelled, so this meta tasks is not supposed to spawn any
# more tasks and can return immediately.
log.debug("Cancelled spawning tasks in group '%s' after %s out of %s tasks have been spawned",
group_name, i, num)
coroutine.close()
return

def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1, group_name: str = None,
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
Expand Down Expand Up @@ -650,8 +667,8 @@ def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, n
raise exceptions.InvalidGroupName(f"Group named {group_name} already exists!")
self._task_groups.setdefault(group_name, TaskGroupRegister())
meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set())
meta_tasks.add(create_task(self._apply_num(group_name, func, args, kwargs, num,
end_callback=end_callback, cancel_callback=cancel_callback)))
meta_tasks.add(create_task(self._apply_spawner(group_name, func, args, kwargs, num,
end_callback=end_callback, cancel_callback=cancel_callback)))
return group_name

@staticmethod
Expand Down Expand Up @@ -696,6 +713,8 @@ async def _arg_consumer(self, group_name: str, num_concurrent: int, func: Corout
# When the number of running tasks spawned by this method reaches the specified maximum,
# this next line will block, until one of them ends and releases the semaphore.
await map_semaphore.acquire()
# TODO: Clean up exception handling/logging. Cancellation can also occur while awaiting the semaphore.
# Wrap `star_function` call in a separate `try` block (similar to `_apply_spawner`).
try:
await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name,
ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback)
Expand Down
57 changes: 37 additions & 20 deletions tests/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,49 +455,65 @@ def test__generate_group_name(self):
self.assertEqual(expected_output, output)

@patch.object(pool.TaskPool, '_start_task')
async def test__apply_num(self, mock__start_task: AsyncMock):
group_name = FOO + BAR
mock_awaitable = object()
mock_func = MagicMock(return_value=mock_awaitable)
args, kwargs, num = (FOO, BAR), {'a': 1, 'b': 2}, 3
async def test__apply_spawner(self, mock__start_task: AsyncMock):
grp_name = FOO + BAR
mock_awaitable1, mock_awaitable2 = object(), object()
mock_func = MagicMock(side_effect=[mock_awaitable1, Exception(), mock_awaitable2], __name__='func')
args, kw, num = (FOO, BAR), {'a': 1, 'b': 2}, 3
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, kwargs, num, end_cb, cancel_cb))
mock_func.assert_has_calls(3 * [call(*args, **kwargs)])
mock__start_task.assert_has_awaits(3 * [
call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb)
self.assertIsNone(await self.task_pool._apply_spawner(grp_name, mock_func, args, kw, num, end_cb, cancel_cb))
mock_func.assert_has_calls(num * [call(*args, **kw)])
mock__start_task.assert_has_awaits([
call(mock_awaitable1, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
call(mock_awaitable2, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
])

mock_func.reset_mock()
mock_func.reset_mock(side_effect=True)
mock__start_task.reset_mock()

self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, None, num, end_cb, cancel_cb))
mock_func.assert_has_calls(num * [call(*args)])
mock__start_task.assert_has_awaits(num * [
call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb)
# Simulate cancellation while the second task is being started.
mock__start_task.side_effect = [None, CancelledError, None]
mock_coroutine_to_close = MagicMock()
mock_func.side_effect = [mock_awaitable1, mock_coroutine_to_close, 'never called']
self.assertIsNone(await self.task_pool._apply_spawner(grp_name, mock_func, args, None, num, end_cb, cancel_cb))
mock_func.assert_has_calls(2 * [call(*args)])
mock__start_task.assert_has_awaits([
call(mock_awaitable1, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
call(mock_coroutine_to_close, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
])
mock_coroutine_to_close.close.assert_called_once_with()

@patch.object(pool, 'create_task')
@patch.object(pool.TaskPool, '_apply_num', new_callable=MagicMock())
@patch.object(pool.TaskPool, '_apply_spawner', new_callable=MagicMock())
@patch.object(pool, 'TaskGroupRegister')
@patch.object(pool.TaskPool, '_generate_group_name')
@patch.object(pool.BaseTaskPool, '_check_start')
def test_apply(self, mock__check_start: MagicMock, mock__generate_group_name: MagicMock,
mock_reg_cls: MagicMock, mock__apply_num: MagicMock, mock_create_task: MagicMock):
mock_reg_cls: MagicMock, mock__apply_spawner: MagicMock, mock_create_task: MagicMock):
mock__generate_group_name.return_value = generated_name = 'name 123'
mock_group_reg = set_up_mock_group_register(mock_reg_cls)
mock__apply_num.return_value = mock_apply_coroutine = object()
mock__apply_spawner.return_value = mock_apply_coroutine = object()
mock_create_task.return_value = fake_task = object()
mock_func, num, group_name = MagicMock(), 3, FOO + BAR
args, kwargs = (FOO, BAR), {'a': 1, 'b': 2}
end_cb, cancel_cb = MagicMock(), MagicMock()

self.task_pool._task_groups = {group_name: 'causes error'}
with self.assertRaises(exceptions.InvalidGroupName):
self.task_pool.apply(mock_func, args, kwargs, num, group_name, end_cb, cancel_cb)
mock__check_start.assert_called_once_with(function=mock_func)
mock__apply_spawner.assert_not_called()
mock_create_task.assert_not_called()

mock__check_start.reset_mock()
self.task_pool._task_groups = {}

def check_assertions(_group_name, _output):
self.assertEqual(_group_name, _output)
mock__check_start.assert_called_once_with(function=mock_func)
self.assertEqual(mock_group_reg, self.task_pool._task_groups[_group_name])
mock__apply_num.assert_called_once_with(_group_name, mock_func, args, kwargs, num,
end_callback=end_cb, cancel_callback=cancel_cb)
mock__apply_spawner.assert_called_once_with(_group_name, mock_func, args, kwargs, num,
end_callback=end_cb, cancel_callback=cancel_cb)
mock_create_task.assert_called_once_with(mock_apply_coroutine)
self.assertSetEqual({fake_task}, self.task_pool._group_meta_tasks_running[group_name])

Expand All @@ -507,7 +523,7 @@ def check_assertions(_group_name, _output):

mock__check_start.reset_mock()
self.task_pool._task_groups.clear()
mock__apply_num.reset_mock()
mock__apply_spawner.reset_mock()
mock_create_task.reset_mock()

output = self.task_pool.apply(mock_func, args, kwargs, num, None, end_cb, cancel_cb)
Expand Down Expand Up @@ -695,6 +711,7 @@ def setUp(self) -> None:

def tearDown(self) -> None:
self.base_class_init_patcher.stop()
super().tearDown()

def test_init(self):
self.assertEqual(self.TEST_POOL_FUNC, self.task_pool._func)
Expand Down

0 comments on commit 0daed04

Please sign in to comment.