diff --git a/sergeant/executor/threaded.py b/sergeant/executor/threaded.py index 8004f6b..353061a 100644 --- a/sergeant/executor/threaded.py +++ b/sergeant/executor/threaded.py @@ -21,13 +21,19 @@ def __init__( has_soft_timeout = self.worker_object.config.timeouts.soft_timeout > 0 self.should_use_a_killer = has_soft_timeout - self.thread_killers: typing.Dict[int, killer.thread.Killer] = {} + self.thread_killer = killer.thread.Killer( + exception=worker.WorkerSoftTimedout, + sleep_interval=0.1, + ) self.interrupt_exception: typing.Optional[Exception] = None def execute_tasks( self, tasks: typing.Iterable[objects.Task], ) -> None: + if self.should_use_a_killer: + self.thread_killer.start() + with concurrent.futures.ThreadPoolExecutor( max_workers=self.number_of_threads, ) as executor: @@ -48,9 +54,8 @@ def execute_tasks( if self.interrupt_exception: break - for thread_killer in self.thread_killers.values(): - thread_killer.kill() - self.thread_killers.clear() + if self.should_use_a_killer: + self.thread_killer.stop() if self.interrupt_exception: raise self.interrupt_exception @@ -154,18 +159,10 @@ def pre_work( ) if self.should_use_a_killer: - current_thread_id = threading.get_ident() - - if current_thread_id in self.thread_killers: - self.thread_killers[current_thread_id].reset() - self.thread_killers[current_thread_id].resume() - else: - self.thread_killers[current_thread_id] = killer.thread.Killer( - thread_id=current_thread_id, - timeout=self.worker_object.config.timeouts.soft_timeout, - exception=worker.WorkerSoftTimedout, - ) - self.thread_killers[current_thread_id].start() + self.thread_killer.add( + thread_id=threading.get_ident(), + timeout=self.worker_object.config.timeouts.soft_timeout, + ) def post_work( self, @@ -174,7 +171,9 @@ def post_work( exception: typing.Optional[Exception] = None, ) -> None: if self.should_use_a_killer: - self.thread_killers[threading.get_ident()].suspend() + self.thread_killer.remove( + thread_id=threading.get_ident(), + ) try: self.worker_object.post_work( @@ -191,11 +190,3 @@ def post_work( 'exception': exception, }, ) - - def __del__( - self, - ) -> None: - for thread_killer in self.thread_killers.values(): - thread_killer.kill() - - self.thread_killers.clear() diff --git a/sergeant/killer/thread.py b/sergeant/killer/thread.py index 2ed919c..1552fe3 100644 --- a/sergeant/killer/thread.py +++ b/sergeant/killer/thread.py @@ -9,71 +9,79 @@ class Killer( ): def __init__( self, - thread_id: int, - timeout: float, exception: typing.Type[BaseException], sleep_interval: float = 0.1, ) -> None: super().__init__() - self.timeout = timeout self.exception = exception - self.thread_id = thread_id self.sleep_interval = sleep_interval - self.time_elapsed = 0.0 + self.thread_to_end_time: typing.Dict[int, float] = {} self.enabled = True - self.running = True + self.started = False + self.finished = threading.Event() self.lock = threading.Lock() def run( self, ) -> None: - while self.enabled: - if not self.running: - time.sleep(self.sleep_interval) - - continue - - if self.time_elapsed < self.timeout: - self.time_elapsed += self.sleep_interval - time.sleep(self.sleep_interval) - - continue + self.started = True + while self.enabled: with self.lock: - if self.running: - self.running = False + for thread_id, end_time in list(self.thread_to_end_time.items()): + if time.time() > end_time: + del self.thread_to_end_time[thread_id] + self.raise_exception_in_thread( + thread_id=thread_id, + exception=self.exception, + ) - ctypes.pythonapi.PyThreadState_SetAsyncExc( - ctypes.c_ulong(self.thread_id), - ctypes.py_object(self.exception), - ) + time.sleep(self.sleep_interval) - def kill( - self, - ) -> None: - self.enabled = False + self.finished.set() - def suspend( + def raise_exception_in_thread( + self, + thread_id: int, + exception: typing.Type[BaseException], + ) -> bool: + try: + ctypes.pythonapi.PyThreadState_SetAsyncExc( + ctypes.c_ulong(thread_id), + ctypes.py_object(exception), + ) + + return True + except Exception: + return False + + def stop( self, ) -> None: - with self.lock: - self.running = False + if self.started: + self.enabled = False + self.finished.wait() - def resume( + def remove( self, + thread_id: int, ) -> None: with self.lock: - self.running = True + if thread_id in self.thread_to_end_time: + del self.thread_to_end_time[thread_id] - def reset( + def add( self, + thread_id: int, + timeout: float, ) -> None: - self.time_elapsed = 0 + with self.lock: + self.thread_to_end_time[thread_id] = time.time() + timeout def __del__( self, ) -> None: - self.kill() + self.stop() diff --git a/tests/executor/test_threaded.py b/tests/executor/test_threaded.py index f287fca..d044931 100644 --- a/tests/executor/test_threaded.py +++ b/tests/executor/test_threaded.py @@ -21,6 +21,9 @@ def setUp( type='', params={}, ), + timeouts=sergeant.config.Timeouts( + soft_timeout=1.0, + ) ) self.worker.work = unittest.mock.MagicMock( return_value=True, @@ -40,95 +43,95 @@ def setUp( def test_pre_work( self, ): - threaded_executor = sergeant.executor.threaded.ThreadedExecutor( - worker_object=self.worker, - number_of_threads=1, - ) - - task = sergeant.objects.Task() - threaded_executor.pre_work( - task=task, - ) - self.worker.pre_work.assert_called_once_with( - task=task, - ) - self.worker.logger.error.assert_not_called() - - self.worker.pre_work.side_effect = Exception('exception message') - threaded_executor.pre_work( - task=task, - ) - self.worker.logger.error.assert_called_once_with( - msg='pre_work has failed: exception message', - extra={ - 'task': task, - }, - ) - - threaded_executor.should_use_a_killer = True with unittest.mock.patch( target='sergeant.killer.thread.Killer', ): + threaded_executor = sergeant.executor.threaded.ThreadedExecutor( + worker_object=self.worker, + number_of_threads=1, + ) + + task = sergeant.objects.Task() + threaded_executor.pre_work( + task=task, + ) + self.worker.pre_work.assert_called_once_with( + task=task, + ) + self.worker.logger.error.assert_not_called() + + self.worker.pre_work.side_effect = Exception('exception message') threaded_executor.pre_work( task=task, ) - threaded_executor.thread_killers[threading.get_ident()].start.assert_called_once() - threaded_executor.thread_killers[threading.get_ident()].start.reset_mock() + self.worker.logger.error.assert_called_once_with( + msg='pre_work has failed: exception message', + extra={ + 'task': task, + }, + ) + threaded_executor.should_use_a_killer = True + threaded_executor.thread_killer.add.reset_mock() threaded_executor.pre_work( task=task, ) - threaded_executor.thread_killers[threading.get_ident()].reset.assert_called_once() - threaded_executor.thread_killers[threading.get_ident()].resume.assert_called_once() - threaded_executor.thread_killers[threading.get_ident()].start.assert_not_called() + threaded_executor.thread_killer.add.assert_called_once_with( + thread_id=threading.get_ident(), + timeout=self.worker.config.timeouts.soft_timeout, + ) def test_post_work( self, ): - threaded_executor = sergeant.executor.threaded.ThreadedExecutor( - worker_object=self.worker, - number_of_threads=1, - ) - threaded_executor.thread_killers[threading.get_ident()] = unittest.mock.MagicMock() + with unittest.mock.patch( + target='sergeant.killer.thread.Killer', + ): + threaded_executor = sergeant.executor.threaded.ThreadedExecutor( + worker_object=self.worker, + number_of_threads=1, + ) + threaded_executor.should_use_a_killer = False - task = sergeant.objects.Task() - threaded_executor.post_work( - task=task, - success=True, - exception=None, - ) - self.worker.post_work.assert_called_once_with( - task=task, - success=True, - exception=None, - ) - threaded_executor.thread_killers[threading.get_ident()].suspend.assert_not_called() - self.worker.logger.error.assert_not_called() + task = sergeant.objects.Task() + threaded_executor.post_work( + task=task, + success=True, + exception=None, + ) + self.worker.post_work.assert_called_once_with( + task=task, + success=True, + exception=None, + ) + threaded_executor.thread_killer.remove.assert_not_called() + self.worker.logger.error.assert_not_called() - exception = Exception('exception message') - self.worker.post_work.side_effect = exception - threaded_executor.post_work( - task=task, - success=True, - exception=None, - ) - self.worker.logger.error.assert_called_once_with( - msg='post_work has failed: exception message', - extra={ - 'task': task, - 'success': True, - 'exception': exception, - }, - ) + exception = Exception('exception message') + self.worker.post_work.side_effect = exception + threaded_executor.post_work( + task=task, + success=True, + exception=None, + ) + self.worker.logger.error.assert_called_once_with( + msg='post_work has failed: exception message', + extra={ + 'task': task, + 'success': True, + 'exception': exception, + }, + ) - threaded_executor.thread_killers[threading.get_ident()] = unittest.mock.MagicMock() - threaded_executor.should_use_a_killer = True - threaded_executor.post_work( - task=task, - success=True, - exception=None, - ) - threaded_executor.thread_killers[threading.get_ident()].suspend.assert_called_once() + threaded_executor.should_use_a_killer = True + threaded_executor.post_work( + task=task, + success=True, + exception=None, + ) + threaded_executor.thread_killer.remove.assert_called_once_with( + thread_id=threading.get_ident(), + ) def test_success( self, diff --git a/tests/killer/test_thread.py b/tests/killer/test_thread.py index 94a8734..1513d5b 100644 --- a/tests/killer/test_thread.py +++ b/tests/killer/test_thread.py @@ -12,6 +12,7 @@ def setUp( self, ): self.raised_exception = None + self.thread_enabled = True def tearDown( self, @@ -22,7 +23,7 @@ def thread_function( self, ): try: - while True: + while self.thread_enabled: time.sleep(0.05) except Exception as exception: self.raised_exception = exception @@ -36,11 +37,13 @@ def test_simple( thread.start() killer = sergeant.killer.thread.Killer( - thread_id=thread.ident, - timeout=0.5, exception=ExceptionTest, ) killer.start() + killer.add( + thread_id=thread.ident, + timeout=0.5, + ) self.assertTrue( expr=thread.is_alive(), ) @@ -52,23 +55,20 @@ def test_simple( obj=self.raised_exception, cls=ExceptionTest, ) + self.assertEqual( - first=killer.time_elapsed, - second=0.5, - ) - self.assertFalse( - expr=killer.running, + first=killer.thread_to_end_time, + second={}, ) self.assertTrue( expr=killer.enabled, ) - - killer.kill() + killer.stop() self.assertFalse( expr=killer.enabled, ) - def test_suspend_resume( + def test_remove_while_running( self, ): thread = threading.Thread( @@ -77,156 +77,39 @@ def test_suspend_resume( thread.start() killer = sergeant.killer.thread.Killer( - thread_id=thread.ident, - timeout=0.7, exception=ExceptionTest, ) killer.start() - self.assertTrue( - expr=thread.is_alive(), - ) - - time.sleep(0.5) - killer.suspend() - self.assertEqual( - first=killer.time_elapsed, - second=0.5, - ) - self.assertFalse( - expr=killer.running, - ) - self.assertTrue( - expr=killer.enabled, - ) - - time.sleep(0.3) - self.assertTrue( - expr=thread.is_alive(), - ) - self.assertEqual( - first=killer.time_elapsed, - second=0.5, - ) - self.assertFalse( - expr=killer.running, - ) - self.assertTrue( - expr=killer.enabled, - ) - - killer.resume() - time.sleep(0.4) - self.assertFalse( - expr=thread.is_alive(), - ) - self.assertEqual( - first=killer.time_elapsed, - second=0.7, - ) - self.assertFalse( - expr=killer.running, - ) - self.assertTrue( - expr=killer.enabled, - ) - self.assertIsInstance( - obj=self.raised_exception, - cls=ExceptionTest, - ) - - killer.kill() - self.assertFalse( - expr=killer.enabled, - ) - - def test_reset_while_running( - self, - ): - thread = threading.Thread( - target=self.thread_function, - ) - thread.start() - - killer = sergeant.killer.thread.Killer( + killer.add( thread_id=thread.ident, timeout=0.5, - exception=ExceptionTest, ) - killer.start() self.assertTrue( expr=thread.is_alive(), ) - time.sleep(0.3) - killer.reset() time.sleep(0.3) - self.assertTrue( - expr=thread.is_alive(), - ) - time.sleep(0.4) - self.assertFalse( - expr=thread.is_alive(), - ) - self.assertIsInstance( - obj=self.raised_exception, - cls=ExceptionTest, - ) - killer.kill() - - def test_reuse_after_kill( - self, - ): - thread = threading.Thread( - target=self.thread_function, + self.assertNotEqual( + first=killer.thread_to_end_time, + second={}, ) - thread.start() - - killer = sergeant.killer.thread.Killer( + killer.remove( thread_id=thread.ident, - timeout=0.3, - exception=ExceptionTest, - ) - killer.start() - - self.assertTrue( - expr=thread.is_alive(), - ) - time.sleep(0.5) - self.assertFalse( - expr=thread.is_alive(), - ) - self.assertIsInstance( - obj=self.raised_exception, - cls=ExceptionTest, ) - self.assertFalse( - expr=killer.running, - ) - - thread = threading.Thread( - target=self.thread_function, + self.assertEqual( + first=killer.thread_to_end_time, + second={}, ) - thread.start() + time.sleep(0.3) - killer.thread_id = thread.ident - killer.reset() - killer.resume() self.assertTrue( expr=thread.is_alive(), ) - time.sleep(0.5) - self.assertFalse( - expr=thread.is_alive(), - ) - self.assertIsInstance( - obj=self.raised_exception, - cls=ExceptionTest, - ) - self.assertFalse( - expr=killer.running, - ) - killer.kill() + + killer.stop() + self.thread_enabled = False + thread.join() class ExceptionTest(