Skip to content

Commit

Permalink
Merge pull request #178 from alces-software/dev/tweak-async
Browse files Browse the repository at this point in the history
Clean up the async `start_task`
  • Loading branch information
DavidMarchant authored Nov 19, 2018
2 parents 7aab33f + 8741ede commit 56b07c8
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 15 deletions.
32 changes: 20 additions & 12 deletions src/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,21 +126,30 @@ async def start_tasks(tasks):
active_tasks = []
def remove_done_tasks():
for active_task in active_tasks:
if active_task._state == 'FINISHED':
if active_task.finished:
active_tasks.remove(active_task)
break
for task in tasks:
while len(active_tasks) > max_ssh:

async def add_tasks():
for task in tasks:
while len(active_tasks) > max_ssh:
remove_done_tasks()
await asyncio.sleep(0.01)
asyncio.ensure_future(task, loop = loop)
active_tasks.append(task)
run_print('Starting Job: {}'.format(task.node))
await(asyncio.sleep(start_delay))

async def await_finished():
while len(active_tasks) > 0:
remove_done_tasks()
await asyncio.sleep(0.01)
asyncio.ensure_future(task, loop = loop)
active_tasks.append(task)
run_print('Starting Job: {}'.format(task.node))
await(asyncio.sleep(start_delay))
run_print('Waiting for jobs to finish...')
while len(active_tasks) > 0:
remove_done_tasks()
await asyncio.sleep(0.01)

try:
await add_tasks()
run_print('Waiting for jobs to finish...')
except concurrent.futures.CancelledError: pass
finally: await await_finished()

session = Session()
try:
Expand All @@ -153,7 +162,6 @@ def remove_done_tasks():
run_print('Executing: {}'.format(batch.name()))
tasks = map(lambda j: j.task(thread_pool = pool), batch.jobs)
loop.run_until_complete(start_tasks(tasks))
except concurrent.futures.CancelledError: pass
finally:
run_print('Cleaning up...')
pool.shutdown(wait = True)
Expand Down
12 changes: 9 additions & 3 deletions src/models/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,17 @@ def __init__(self, job, thread_pool = None):
self.thread_pool = thread_pool
super().__init__(self.run_async())
self.job = job
self.add_job_callback(lambda job: job.connection().close())
self.add_job_callback(type(self).close)
self.add_job_callback(lambda job: job.set_ssh_results())
self.add_done_callback(type(self).report_results)

def close(self):
try: job.connection.close()
except: pass

def finished(self):
return self._state == 'FINISHED'

def __getattr__(self, attr):
return getattr(self.job, attr)

Expand Down Expand Up @@ -117,8 +124,7 @@ def catch_errors(func, *args):

async def run_async(self):
if self.check_command():
try: await self._run_thread(self.connection().open)
except concurrent.futures.CancelledError as e: raise e
await self._run_thread(self.connection().open)

if self.connection().is_connected:
await self._run_thread(self.run, self.batch)
Expand Down

0 comments on commit 56b07c8

Please sign in to comment.