diff --git a/arq/worker.py b/arq/worker.py
index e4ac1b7b..2bdab0f0 100644
--- a/arq/worker.py
+++ b/arq/worker.py
@@ -239,7 +239,11 @@ def __init__(
         self.on_job_start = on_job_start
         self.on_job_end = on_job_end
         self.after_job_end = after_job_end
-        self.sem = asyncio.BoundedSemaphore(max_jobs)
+
+        self.max_jobs = max_jobs
+        self.sem = asyncio.BoundedSemaphore(max_jobs + 1)
+        self.job_counter: int = 0
+
         self.job_timeout_s = to_seconds(job_timeout)
         self.keep_result_s = to_seconds(keep_result)
         self.keep_result_forever = keep_result_forever
@@ -377,13 +381,13 @@ async def _poll_iteration(self) -> None:
                 return
             count = min(burst_jobs_remaining, count)
 
         if self.allow_pick_jobs:
-            async with self.sem:  # don't bother with zrangebyscore until we have "space" to run the jobs
+            if self.job_counter < self.max_jobs:
                 now = timestamp_ms()
                 job_ids = await self.pool.zrangebyscore(
                     self.queue_name, min=float('-inf'), start=self._queue_read_offset, num=count, max=now
                 )
-            await self.start_jobs(job_ids)
+                await self.start_jobs(job_ids)
 
         if self.allow_abort_jobs:
             await self._cancel_aborted_jobs()
@@ -422,12 +426,23 @@ async def _cancel_aborted_jobs(self) -> None:
             self.aborting_tasks.update(aborted)
             await self.pool.zrem(abort_jobs_ss, *aborted)
 
+    def _release_sem_dec_counter_on_complete(self) -> None:
+        self.job_counter = self.job_counter - 1
+        self.sem.release()
+
     async def start_jobs(self, job_ids: List[bytes]) -> None:
         """
         For each job id, get the job definition, check it's not running and start it in a task
         """
         for job_id_b in job_ids:
             await self.sem.acquire()
+
+            if self.job_counter >= self.max_jobs:
+                self.sem.release()
+                return None
+
+            self.job_counter = self.job_counter + 1
+
             job_id = job_id_b.decode()
             in_progress_key = in_progress_key_prefix + job_id
             async with self.pool.pipeline(transaction=True) as pipe:
@@ -436,6 +451,7 @@ async def start_jobs(self, job_ids: List[bytes]) -> None:
                 score = await pipe.zscore(self.queue_name, job_id)
                 if ongoing_exists or not score:
                     # job already started elsewhere, or already finished and removed from queue
+                    self.job_counter = self.job_counter - 1
                     self.sem.release()
                     logger.debug('job %s already running elsewhere', job_id)
                     continue
@@ -448,11 +464,12 @@ async def start_jobs(self, job_ids: List[bytes]) -> None:
                     await pipe.execute()
                 except (ResponseError, WatchError):
                     # job already started elsewhere since we got 'existing'
+                    self.job_counter = self.job_counter - 1
                     self.sem.release()
                     logger.debug('multi-exec error, job %s already started elsewhere', job_id)
                 else:
                     t = self.loop.create_task(self.run_job(job_id, int(score)))
-                    t.add_done_callback(lambda _: self.sem.release())
+                    t.add_done_callback(lambda _: self._release_sem_dec_counter_on_complete())
                     self.tasks[job_id] = t
 
     async def run_job(self, job_id: str, score: int) -> None:  # noqa: C901
diff --git a/tests/test_worker.py b/tests/test_worker.py
index aa56085b..23dd91d2 100644
--- a/tests/test_worker.py
+++ b/tests/test_worker.py
@@ -984,6 +984,36 @@ async def test(ctx):
     assert result['called'] == 4
 
 
+async def test_job_cancel_on_max_jobs(arq_redis: ArqRedis, worker, caplog):
+    async def longfunc(ctx):
+        await asyncio.sleep(3600)
+
+    async def wait_and_abort(job, delay=0.1):
+        await asyncio.sleep(delay)
+        assert await job.abort() is True
+
+    caplog.set_level(logging.INFO)
+    await arq_redis.zadd(abort_jobs_ss, {b'foobar': int(1e9)})
+    job = await arq_redis.enqueue_job('longfunc', _job_id='testing')
+
+    worker: Worker = worker(
+        functions=[func(longfunc, name='longfunc')], allow_abort_jobs=True, poll_delay=0.1, max_jobs=1
+    )
+    assert worker.jobs_complete == 0
+    assert worker.jobs_failed == 0
+    assert worker.jobs_retried == 0
+    await asyncio.gather(wait_and_abort(job), worker.main())
+    await worker.main()
+    assert worker.jobs_complete == 0
+    assert worker.jobs_failed == 1
+    assert worker.jobs_retried == 0
+    log = re.sub(r'\d+.\d\ds', 'X.XXs', '\n'.join(r.message for r in caplog.records))
+    assert 'X.XXs → testing:longfunc()\n  X.XXs ⊘ testing:longfunc aborted' in log
+    assert worker.aborting_tasks == set()
+    assert worker.tasks == {}
+    assert worker.job_tasks == {}
+
+
 async def test_worker_timezone_defaults_to_system_timezone(worker):
     worker = worker(functions=[func(foobar)])
     assert worker.timezone is not None
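
For reviewers, a standalone sketch of the concurrency pattern this patch introduces. It is not arq's API: MiniWorker, fake_job, poll_iteration and MAX_JOBS are illustrative names standing in for the real worker internals. The counter gives the poll loop a non-blocking capacity check; the bounded semaphore, sized max_jobs + 1 as in the patch, still caps how many tasks get started; both are released together when a task completes.

import asyncio

MAX_JOBS = 2


class MiniWorker:
    def __init__(self) -> None:
        # one spare permit so a pickup attempt can still acquire (and then
        # hit the counter guard) when MAX_JOBS tasks already hold permits
        self.sem = asyncio.BoundedSemaphore(MAX_JOBS + 1)
        self.job_counter = 0

    def _release_on_complete(self, _task: asyncio.Task) -> None:
        # mirrors _release_sem_dec_counter_on_complete in the patch
        self.job_counter -= 1
        self.sem.release()

    async def fake_job(self, i: int) -> None:
        await asyncio.sleep(0.1)
        print(f'job {i} done')

    async def start_jobs(self, n: int) -> None:
        for i in range(n):
            await self.sem.acquire()
            if self.job_counter >= MAX_JOBS:
                # at capacity: undo the acquire and stop picking up jobs
                self.sem.release()
                return
            self.job_counter += 1
            task = asyncio.get_running_loop().create_task(self.fake_job(i))
            task.add_done_callback(self._release_on_complete)

    async def poll_iteration(self) -> None:
        if self.job_counter < MAX_JOBS:  # cheap, non-blocking "space" check
            await self.start_jobs(5)


async def main() -> None:
    worker = MiniWorker()
    await worker.poll_iteration()  # starts only MAX_JOBS of the 5 candidates
    await asyncio.sleep(0.5)  # let the running jobs finish and release

asyncio.run(main())

Running this starts only two of the five candidate jobs; the third pickup attempt hits the counter guard, releases its permit and backs off, which is the behaviour the new test exercises with max_jobs=1.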