Move task scheduling to a separate thread
anfimovdm committed Oct 30, 2023
1 parent cc65377 commit 0a1ec1f
Showing 6 changed files with 197 additions and 179 deletions.
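
This commit removes the POST /tasks/schedule endpoint from the FastAPI application and moves all scheduling logic into the TasksMonitor thread, which now polls the build system for new test tasks. Below is a minimal sketch, not part of this diff, of how the scheduler could wire the monitor up at startup; the actual startup code in app.py is collapsed in this view, so the event handling shown here is an assumption:

from threading import Event

from alts.scheduler.monitoring import TasksMonitor
from alts.worker.app import celery_app

# Events used to stop the monitor loop; task scheduling and status polling
# both run on this thread instead of inside an HTTP request handler.
terminated_event = Event()
graceful_terminate = Event()

monitor = TasksMonitor(
    terminated_event,
    graceful_terminate,
    celery_app,
    get_result_timeout=1,
)
monitor.start()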
139 changes: 4 additions & 135 deletions alts/scheduler/app.py
@@ -5,9 +5,7 @@
"""AlmaLinux Test System tasks scheduler application."""

import logging
import random
import signal
import uuid
from threading import Event

from celery.exceptions import TimeoutError
@@ -26,14 +24,8 @@
from alts.scheduler.monitoring import TasksMonitor
from alts.shared.constants import API_VERSION
from alts.shared.exceptions import ALTSBaseError
from alts.shared.models import (
TaskRequestPayload,
TaskRequestResponse,
TaskResultResponse,
)
from alts.shared.models import TaskResultResponse
from alts.worker.app import celery_app
from alts.worker.mappings import RUNNER_MAPPING
from alts.worker.tasks import run_tests

app = FastAPI()
monitor = None
@@ -163,7 +155,9 @@ async def authenticate_user(credentials: str = Depends(http_bearer_scheme)):
else:
token = credentials.credentials
return jwt.decode(
token, CONFIG.jwt_secret, algorithms=[CONFIG.hashing_algorithm]
token,
CONFIG.jwt_secret,
algorithms=[CONFIG.hashing_algorithm],
)
except JWTError:
raise HTTPException(
@@ -196,128 +190,3 @@ async def get_task_result(
task_result = get_celery_task_result(task_id)
task_result['api_version'] = API_VERSION
return JSONResponse(content=task_result)


@app.post(
'/tasks/schedule',
response_model=TaskRequestResponse,
responses={
201: {'model': TaskRequestResponse},
400: {'model': TaskRequestResponse},
},
)
async def schedule_task(
task_data: TaskRequestPayload,
_=Depends(authenticate_user),
) -> JSONResponse:
"""
Schedules new tasks in Test System.
Parameters
----------
task_data : TaskRequestPayload
Loader task data in appropriate for request form.
b_tasks : BackgroundTasks
Tasks running in background.
_ : dict
Authenticated user's token.
Returns
-------
JSONResponse
JSON-encoded response if task executed successfully or not.
"""
# Get only supported runners mapping based on the config
if (
isinstance(CONFIG.supported_runners, str)
and CONFIG.supported_runners == 'all'
):
runner_mapping = RUNNER_MAPPING
elif isinstance(CONFIG.supported_runners, list):
runner_mapping = {
key: value
for key, value in RUNNER_MAPPING.items()
if key in CONFIG.supported_runners
}
else:
raise ValueError(
f'Misconfiguration found: supported_runners is '
f'{CONFIG.supported_runners}'
)
runner_type = task_data.runner_type
if runner_type == 'any':
runner_type = random.choice(list(runner_mapping.keys()))
runner_class = RUNNER_MAPPING[runner_type]

if task_data.dist_arch not in CONFIG.supported_architectures:
raise ValueError(f'Unknown architecture: {task_data.dist_arch}')
if task_data.dist_name not in CONFIG.supported_distributions:
raise ValueError(f'Unknown distribution: {task_data.dist_name}')

# TODO: Make decision on what queue to use for particular task based on
# queues load
queue_arch = None
for arch, supported_arches in runner_class.ARCHITECTURES_MAPPING.items():
if task_data.dist_arch in supported_arches:
queue_arch = arch

if not queue_arch:
raise ValueError(
'Cannot map requested architecture to any '
'host architecture, possible coding error'
)

# Make sure all repositories have their names
# (needed only for RHEL-like distributions)
# Convert repositories structures to dictionaries
repositories = []
repo_counter = 0
for repository in task_data.repositories:
if not repository.name:
repo_name = f'repo-{repo_counter}'
repo_counter += 1
else:
repo_name = repository.name
repositories.append({'url': repository.baseurl, 'name': repo_name})

queue_name = f'{runner_type}-{queue_arch}-{runner_class.COST}'
task_id = str(uuid.uuid4())
response_content = {'api_version': API_VERSION}
task_params = task_data.model_dump()
task_params['task_id'] = task_id
task_params['runner_type'] = runner_type
task_params['repositories'] = repositories
try:
run_tests.apply_async(
(task_params,),
task_id=task_id,
queue=queue_name,
)
except Exception as e:
logging.exception('Cannot launch the task:')
response_content.update(
{'success': False, 'error_description': str(e)}
)
return JSONResponse(status_code=400, content=response_content)
with Session() as session:
with session.begin():
try:
task_record = Task(
task_id=task_id,
queue_name=queue_name,
status='NEW',
)
session.add(task_record)
session.commit()
response_content.update({'success': True, 'task_id': task_id})
return JSONResponse(status_code=201, content=response_content)
except Exception as e:
logging.exception('Cannot save task data into DB:')
response_content.update(
{
'success': False,
'task_id': task_id,
'error_description': str(e),
}
)
return JSONResponse(status_code=400, content=response_content)
199 changes: 168 additions & 31 deletions alts/scheduler/monitoring.py
@@ -2,51 +2,188 @@
import random
import threading
import time
import urllib.parse
import uuid
from typing import List

import requests
from celery.exceptions import TimeoutError
from celery.states import READY_STATES

from alts.scheduler import CONFIG
from alts.scheduler.db import Session, Task
from alts.shared.models import TaskRequestPayload
from alts.worker.mappings import RUNNER_MAPPING
from alts.worker.tasks import run_tests


class TasksMonitor(threading.Thread):
def __init__(self, terminated_event: threading.Event,
graceful_terminate: threading.Event, celery_app,
get_result_timeout: int = 1):
def __init__(
self,
terminated_event: threading.Event,
graceful_terminate: threading.Event,
celery_app,
get_result_timeout: int = 1,
):
super().__init__()
self.__terminated_event = terminated_event
self.__graceful_terminate = graceful_terminate
self.__celery = celery_app
self.__get_result_timeout = get_result_timeout
self.logger = logging.getLogger(__file__)

def get_available_test_tasks(self) -> List[dict]:
response = []
try:
self.logger.info('Getting new available test tasks')
response = requests.get(
urllib.parse.urljoin(
CONFIG.bs_host,
CONFIG.bs_get_task_endpoint,
),
headers={'Authorization': f'Bearer {CONFIG.bs_token}'},
).json()
if not response:
                self.logger.info('There are no available test tasks')
except Exception:
self.logger.exception('Cannot get available test tasks:')
return response

def schedule_test_task(self, payload: TaskRequestPayload):
"""
Schedules new tasks in Test System.
Parameters
----------
payload : TaskRequestPayload
Loader task data in appropriate for request form.
Returns
-------
JSONResponse
JSON-encoded response if task executed successfully or not.
"""
# Get only supported runners mapping based on the config
if (
isinstance(CONFIG.supported_runners, str)
and CONFIG.supported_runners == 'all'
):
runner_mapping = RUNNER_MAPPING
elif isinstance(CONFIG.supported_runners, list):
runner_mapping = {
key: value
for key, value in RUNNER_MAPPING.items()
if key in CONFIG.supported_runners
}
else:
self.logger.error(
'Misconfiguration found: supported_runners is %s',
CONFIG.supported_runners,
)
return
runner_type = payload.runner_type
if runner_type == 'any':
runner_type = random.choice(list(runner_mapping.keys()))
runner_class = RUNNER_MAPPING[runner_type]

if payload.dist_arch not in CONFIG.supported_architectures:
self.logger.error('Unknown architecture: %s', payload.dist_arch)
return
if payload.dist_name not in CONFIG.supported_distributions:
self.logger.error('Unknown distribution: %s', payload.dist_name)
return

# TODO: Make decision on what queue to use for particular task based on
# queues load
queue_arch = None
for (
arch,
supported_arches,
) in runner_class.ARCHITECTURES_MAPPING.items():
if payload.dist_arch in supported_arches:
queue_arch = arch

if not queue_arch:
self.logger.error(
'Cannot map requested architecture to any '
'host architecture, possible coding error'
)
return

# Make sure all repositories have their names
# (needed only for RHEL-like distributions)
# Convert repositories structures to dictionaries
repositories = []
repo_counter = 0
for repository in payload.repositories:
repo_name = repository.name
if not repo_name:
repo_name = f'repo-{repo_counter}'
repo_counter += 1
repositories.append({'url': repository.baseurl, 'name': repo_name})

queue_name = f'{runner_type}-{queue_arch}-{runner_class.COST}'
task_id = str(uuid.uuid4())
task_params = payload.model_dump()
task_params['task_id'] = task_id
task_params['runner_type'] = runner_type
task_params['repositories'] = repositories
try:
run_tests.apply_async(
(task_params,),
task_id=task_id,
queue=queue_name,
)
except Exception:
# TODO: report error to the web server
self.logger.exception('Cannot launch the task:')
with Session() as session, session.begin():
try:
task_record = Task(
task_id=task_id,
queue_name=queue_name,
status='NEW',
)
session.add(task_record)
session.commit()
except Exception:
self.logger.exception('Cannot save task data into DB:')

def run(self) -> None:
while not self.__graceful_terminate.is_set() or \
not self.__terminated_event.is_set():
while (
not self.__graceful_terminate.is_set()
or not self.__terminated_event.is_set()
):
for test_task_payload in self.get_available_test_tasks():
self.schedule_test_task(
TaskRequestPayload(**test_task_payload)
)

updated_tasks = []
with Session() as session:
with session.begin():
for task in session.query(Task).filter(
Task.status.notin_(READY_STATES)):
task_result = self.__celery.AsyncResult(task.task_id)
# Ensure that task state will be updated
# by getting task result
try:
_ = task_result.get(
timeout=self.__get_result_timeout)
except TimeoutError:
pass
if task_result.state != task.status:
self.logger.info('Updating task %s status to %s',
task.task_id, task_result.state)
task.status = task_result.state
updated_tasks.append(task)
time.sleep(0.5)
if updated_tasks:
try:
session.add_all(updated_tasks)
session.commit()
except Exception as e:
self.logger.exception(
'Cannot update tasks statuses:')
self.__terminated_event.wait(random.randint(5, 10))
with Session() as session, session.begin():
for task in session.query(Task).filter(
Task.status.notin_(READY_STATES)
):
task_result = self.__celery.AsyncResult(task.task_id)
# Ensure that task state will be updated
# by getting task result
try:
_ = task_result.get(timeout=self.__get_result_timeout)
except TimeoutError:
pass
if task_result.state != task.status:
self.logger.info(
'Updating task %s status to %s',
task.task_id,
task_result.state,
)
task.status = task_result.state
updated_tasks.append(task)
time.sleep(0.5)
if updated_tasks:
try:
session.add_all(updated_tasks)
session.commit()
except Exception:
self.logger.exception('Cannot update tasks statuses:')
self.__terminated_event.wait(random.randint(10, 15))
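
For reference, schedule_test_task above consumes the JSON entries returned by get_available_test_tasks directly as TaskRequestPayload keyword arguments. A hypothetical entry containing only the fields referenced in this diff (the full schema lives in alts/shared/models.py and may include more):

# Hypothetical test task entry as it might be returned by the build system;
# only the fields used by schedule_test_task are shown.
test_task_payload = {
    'runner_type': 'any',   # 'any' makes the monitor pick a random supported runner
    'dist_name': 'almalinux',
    'dist_arch': 'x86_64',
    'repositories': [
        # Repositories without a name get one generated ('repo-0', 'repo-1', ...)
        {'name': '', 'baseurl': 'https://repo.example.com/BaseOS/x86_64/os/'},
    ],
}
# The monitor then calls, for each entry:
#   self.schedule_test_task(TaskRequestPayload(**test_task_payload))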
1 change: 1 addition & 0 deletions alts/shared/models.py
@@ -274,6 +274,7 @@ def __init__(self, **data):
ssh_public_key_path: str = '~/.ssh/id_rsa.pub'
# Build system settings
bs_host: typing.Optional[str] = None
bs_get_task_endpoint: str = '/api/v1/tests/get_test_tasks/'
bs_token: typing.Optional[str] = None
# Log uploader settings
logs_uploader_config: typing.Optional[
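
The new bs_get_task_endpoint setting added to alts/shared/models.py is joined onto bs_host with urllib.parse.urljoin in TasksMonitor.get_available_test_tasks. A small illustration of the resulting request URL, using a hypothetical bs_host value (in the scheduler both values come from CONFIG):

import urllib.parse

# Hypothetical configuration values for illustration only.
bs_host = 'https://build.example.com'
bs_get_task_endpoint = '/api/v1/tests/get_test_tasks/'

# With a leading-slash endpoint, urljoin resolves the path against the host root.
url = urllib.parse.urljoin(bs_host, bs_get_task_endpoint)
print(url)  # https://build.example.com/api/v1/tests/get_test_tasks/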