Skip to content

Commit

Permalink
replace thread killers with a polling thread killer
Browse files Browse the repository at this point in the history
  • Loading branch information
Gal Ben David committed Nov 29, 2020
1 parent 9ae3cd3 commit e6d74ea
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 273 deletions.
41 changes: 16 additions & 25 deletions sergeant/executor/threaded.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,19 @@ def __init__(

has_soft_timeout = self.worker_object.config.timeouts.soft_timeout > 0
self.should_use_a_killer = has_soft_timeout
self.thread_killers: typing.Dict[int, killer.thread.Killer] = {}
self.thread_killer = killer.thread.Killer(
exception=worker.WorkerSoftTimedout,
sleep_interval=0.1,
)
self.interrupt_exception: typing.Optional[Exception] = None

def execute_tasks(
self,
tasks: typing.Iterable[objects.Task],
) -> None:
if self.should_use_a_killer:
self.thread_killer.start()

with concurrent.futures.ThreadPoolExecutor(
max_workers=self.number_of_threads,
) as executor:
Expand All @@ -48,9 +54,8 @@ def execute_tasks(
if self.interrupt_exception:
break

for thread_killer in self.thread_killers.values():
thread_killer.kill()
self.thread_killers.clear()
if self.should_use_a_killer:
self.thread_killer.stop()

if self.interrupt_exception:
raise self.interrupt_exception
Expand Down Expand Up @@ -154,18 +159,10 @@ def pre_work(
)

if self.should_use_a_killer:
current_thread_id = threading.get_ident()

if current_thread_id in self.thread_killers:
self.thread_killers[current_thread_id].reset()
self.thread_killers[current_thread_id].resume()
else:
self.thread_killers[current_thread_id] = killer.thread.Killer(
thread_id=current_thread_id,
timeout=self.worker_object.config.timeouts.soft_timeout,
exception=worker.WorkerSoftTimedout,
)
self.thread_killers[current_thread_id].start()
self.thread_killer.add(
thread_id=threading.get_ident(),
timeout=self.worker_object.config.timeouts.soft_timeout,
)

def post_work(
self,
Expand All @@ -174,7 +171,9 @@ def post_work(
exception: typing.Optional[Exception] = None,
) -> None:
if self.should_use_a_killer:
self.thread_killers[threading.get_ident()].suspend()
self.thread_killer.remove(
thread_id=threading.get_ident(),
)

try:
self.worker_object.post_work(
Expand All @@ -191,11 +190,3 @@ def post_work(
'exception': exception,
},
)

def __del__(
self,
) -> None:
for thread_killer in self.thread_killers.values():
thread_killer.kill()

self.thread_killers.clear()
78 changes: 43 additions & 35 deletions sergeant/killer/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,71 +9,79 @@ class Killer(
):
def __init__(
self,
thread_id: int,
timeout: float,
exception: typing.Type[BaseException],
sleep_interval: float = 0.1,
) -> None:
super().__init__()

self.timeout = timeout
self.exception = exception
self.thread_id = thread_id
self.sleep_interval = sleep_interval

self.time_elapsed = 0.0
self.thread_to_end_time: typing.Dict[int, float] = {}
self.enabled = True
self.running = True
self.started = False

self.finished = threading.Event()
self.lock = threading.Lock()

def run(
self,
) -> None:
while self.enabled:
if not self.running:
time.sleep(self.sleep_interval)

continue

if self.time_elapsed < self.timeout:
self.time_elapsed += self.sleep_interval
time.sleep(self.sleep_interval)

continue
self.started = True

while self.enabled:
with self.lock:
if self.running:
self.running = False
for thread_id, end_time in list(self.thread_to_end_time.items()):
if time.time() > end_time:
del self.thread_to_end_time[thread_id]
self.raise_exception_in_thread(
thread_id=thread_id,
exception=self.exception,
)

ctypes.pythonapi.PyThreadState_SetAsyncExc(
ctypes.c_ulong(self.thread_id),
ctypes.py_object(self.exception),
)
time.sleep(self.sleep_interval)

def kill(
self,
) -> None:
self.enabled = False
self.finished.set()

def suspend(
def raise_exception_in_thread(
self,
thread_id: int,
exception: typing.Type[BaseException],
) -> bool:
try:
ctypes.pythonapi.PyThreadState_SetAsyncExc(
ctypes.c_ulong(thread_id),
ctypes.py_object(exception),
)

return True
except Exception:
return False

def stop(
self,
) -> None:
with self.lock:
self.running = False
if self.started:
self.enabled = False
self.finished.wait()

def resume(
def remove(
self,
thread_id: int,
) -> None:
with self.lock:
self.running = True
if thread_id in self.thread_to_end_time:
del self.thread_to_end_time[thread_id]

def reset(
def add(
self,
thread_id: int,
timeout: float,
) -> None:
self.time_elapsed = 0
with self.lock:
self.thread_to_end_time[thread_id] = time.time() + timeout

def __del__(
self,
) -> None:
self.kill()
self.stop()
147 changes: 75 additions & 72 deletions tests/executor/test_threaded.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ def setUp(
type='',
params={},
),
timeouts=sergeant.config.Timeouts(
soft_timeout=1.0,
)
)
self.worker.work = unittest.mock.MagicMock(
return_value=True,
Expand All @@ -40,95 +43,95 @@ def setUp(
def test_pre_work(
self,
):
threaded_executor = sergeant.executor.threaded.ThreadedExecutor(
worker_object=self.worker,
number_of_threads=1,
)

task = sergeant.objects.Task()
threaded_executor.pre_work(
task=task,
)
self.worker.pre_work.assert_called_once_with(
task=task,
)
self.worker.logger.error.assert_not_called()

self.worker.pre_work.side_effect = Exception('exception message')
threaded_executor.pre_work(
task=task,
)
self.worker.logger.error.assert_called_once_with(
msg='pre_work has failed: exception message',
extra={
'task': task,
},
)

threaded_executor.should_use_a_killer = True
with unittest.mock.patch(
target='sergeant.killer.thread.Killer',
):
threaded_executor = sergeant.executor.threaded.ThreadedExecutor(
worker_object=self.worker,
number_of_threads=1,
)

task = sergeant.objects.Task()
threaded_executor.pre_work(
task=task,
)
self.worker.pre_work.assert_called_once_with(
task=task,
)
self.worker.logger.error.assert_not_called()

self.worker.pre_work.side_effect = Exception('exception message')
threaded_executor.pre_work(
task=task,
)
threaded_executor.thread_killers[threading.get_ident()].start.assert_called_once()
threaded_executor.thread_killers[threading.get_ident()].start.reset_mock()
self.worker.logger.error.assert_called_once_with(
msg='pre_work has failed: exception message',
extra={
'task': task,
},
)

threaded_executor.should_use_a_killer = True
threaded_executor.thread_killer.add.reset_mock()
threaded_executor.pre_work(
task=task,
)
threaded_executor.thread_killers[threading.get_ident()].reset.assert_called_once()
threaded_executor.thread_killers[threading.get_ident()].resume.assert_called_once()
threaded_executor.thread_killers[threading.get_ident()].start.assert_not_called()
threaded_executor.thread_killer.add.assert_called_once_with(
thread_id=threading.get_ident(),
timeout=self.worker.config.timeouts.soft_timeout,
)

def test_post_work(
self,
):
threaded_executor = sergeant.executor.threaded.ThreadedExecutor(
worker_object=self.worker,
number_of_threads=1,
)
threaded_executor.thread_killers[threading.get_ident()] = unittest.mock.MagicMock()
with unittest.mock.patch(
target='sergeant.killer.thread.Killer',
):
threaded_executor = sergeant.executor.threaded.ThreadedExecutor(
worker_object=self.worker,
number_of_threads=1,
)
threaded_executor.should_use_a_killer = False

task = sergeant.objects.Task()
threaded_executor.post_work(
task=task,
success=True,
exception=None,
)
self.worker.post_work.assert_called_once_with(
task=task,
success=True,
exception=None,
)
threaded_executor.thread_killers[threading.get_ident()].suspend.assert_not_called()
self.worker.logger.error.assert_not_called()
task = sergeant.objects.Task()
threaded_executor.post_work(
task=task,
success=True,
exception=None,
)
self.worker.post_work.assert_called_once_with(
task=task,
success=True,
exception=None,
)
threaded_executor.thread_killer.remove.assert_not_called()
self.worker.logger.error.assert_not_called()

exception = Exception('exception message')
self.worker.post_work.side_effect = exception
threaded_executor.post_work(
task=task,
success=True,
exception=None,
)
self.worker.logger.error.assert_called_once_with(
msg='post_work has failed: exception message',
extra={
'task': task,
'success': True,
'exception': exception,
},
)
exception = Exception('exception message')
self.worker.post_work.side_effect = exception
threaded_executor.post_work(
task=task,
success=True,
exception=None,
)
self.worker.logger.error.assert_called_once_with(
msg='post_work has failed: exception message',
extra={
'task': task,
'success': True,
'exception': exception,
},
)

threaded_executor.thread_killers[threading.get_ident()] = unittest.mock.MagicMock()
threaded_executor.should_use_a_killer = True
threaded_executor.post_work(
task=task,
success=True,
exception=None,
)
threaded_executor.thread_killers[threading.get_ident()].suspend.assert_called_once()
threaded_executor.should_use_a_killer = True
threaded_executor.post_work(
task=task,
success=True,
exception=None,
)
threaded_executor.thread_killer.remove.assert_called_once_with(
thread_id=threading.get_ident(),
)

def test_success(
self,
Expand Down
Loading

0 comments on commit e6d74ea

Please sign in to comment.