Move task scheduling to a separate thread #63

Merged · 3 commits · Nov 8, 2023
Changes from 1 commit
139 changes: 4 additions & 135 deletions alts/scheduler/app.py
@@ -5,9 +5,7 @@
"""AlmaLinux Test System tasks scheduler application."""

import logging
import random
import signal
import uuid
from threading import Event

from celery.exceptions import TimeoutError
@@ -26,14 +24,8 @@
from alts.scheduler.monitoring import TasksMonitor
from alts.shared.constants import API_VERSION
from alts.shared.exceptions import ALTSBaseError
from alts.shared.models import (
TaskRequestPayload,
TaskRequestResponse,
TaskResultResponse,
)
from alts.shared.models import TaskResultResponse
from alts.worker.app import celery_app
from alts.worker.mappings import RUNNER_MAPPING
from alts.worker.tasks import run_tests

app = FastAPI()
monitor = None
@@ -163,7 +155,9 @@ async def authenticate_user(credentials: str = Depends(http_bearer_scheme)):
else:
token = credentials.credentials
return jwt.decode(
token, CONFIG.jwt_secret, algorithms=[CONFIG.hashing_algorithm]
token,
CONFIG.jwt_secret,
algorithms=[CONFIG.hashing_algorithm],
)
except JWTError:
raise HTTPException(
@@ -196,128 +190,3 @@ async def get_task_result(
task_result = get_celery_task_result(task_id)
task_result['api_version'] = API_VERSION
return JSONResponse(content=task_result)


@app.post(
'/tasks/schedule',
response_model=TaskRequestResponse,
responses={
201: {'model': TaskRequestResponse},
400: {'model': TaskRequestResponse},
},
)
async def schedule_task(
task_data: TaskRequestPayload,
_=Depends(authenticate_user),
) -> JSONResponse:
"""
Schedules new tasks in Test System.

Parameters
----------
task_data : TaskRequestPayload
Loader task data in appropriate for request form.
b_tasks : BackgroundTasks
Tasks running in background.
_ : dict
Authenticated user's token.

Returns
-------
JSONResponse
JSON-encoded response if task executed successfully or not.
"""
# Get only supported runners mapping based on the config
if (
isinstance(CONFIG.supported_runners, str)
and CONFIG.supported_runners == 'all'
):
runner_mapping = RUNNER_MAPPING
elif isinstance(CONFIG.supported_runners, list):
runner_mapping = {
key: value
for key, value in RUNNER_MAPPING.items()
if key in CONFIG.supported_runners
}
else:
raise ValueError(
f'Misconfiguration found: supported_runners is '
f'{CONFIG.supported_runners}'
)
runner_type = task_data.runner_type
if runner_type == 'any':
runner_type = random.choice(list(runner_mapping.keys()))
runner_class = RUNNER_MAPPING[runner_type]

if task_data.dist_arch not in CONFIG.supported_architectures:
raise ValueError(f'Unknown architecture: {task_data.dist_arch}')
if task_data.dist_name not in CONFIG.supported_distributions:
raise ValueError(f'Unknown distribution: {task_data.dist_name}')

# TODO: Make decision on what queue to use for particular task based on
# queues load
queue_arch = None
for arch, supported_arches in runner_class.ARCHITECTURES_MAPPING.items():
if task_data.dist_arch in supported_arches:
queue_arch = arch

if not queue_arch:
raise ValueError(
'Cannot map requested architecture to any '
'host architecture, possible coding error'
)

# Make sure all repositories have their names
# (needed only for RHEL-like distributions)
# Convert repositories structures to dictionaries
repositories = []
repo_counter = 0
for repository in task_data.repositories:
if not repository.name:
repo_name = f'repo-{repo_counter}'
repo_counter += 1
else:
repo_name = repository.name
repositories.append({'url': repository.baseurl, 'name': repo_name})

queue_name = f'{runner_type}-{queue_arch}-{runner_class.COST}'
task_id = str(uuid.uuid4())
response_content = {'api_version': API_VERSION}
task_params = task_data.model_dump()
task_params['task_id'] = task_id
task_params['runner_type'] = runner_type
task_params['repositories'] = repositories
try:
run_tests.apply_async(
(task_params,),
task_id=task_id,
queue=queue_name,
)
except Exception as e:
logging.exception('Cannot launch the task:')
response_content.update(
{'success': False, 'error_description': str(e)}
)
return JSONResponse(status_code=400, content=response_content)
with Session() as session:
with session.begin():
try:
task_record = Task(
task_id=task_id,
queue_name=queue_name,
status='NEW',
)
session.add(task_record)
session.commit()
response_content.update({'success': True, 'task_id': task_id})
return JSONResponse(status_code=201, content=response_content)
except Exception as e:
logging.exception('Cannot save task data into DB:')
response_content.update(
{
'success': False,
'task_id': task_id,
'error_description': str(e),
}
)
return JSONResponse(status_code=400, content=response_content)
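The hunk that actually starts the monitor is collapsed in this view, so here is a minimal sketch of how the TasksMonitor thread could be wired up at scheduler startup, based only on what stays visible in app.py (the signal, Event, TasksMonitor, and celery_app imports plus the module-level monitor = None). The handler and event names are illustrative, not taken from the PR:

import signal
from threading import Event

from alts.scheduler.monitoring import TasksMonitor
from alts.worker.app import celery_app

# Two events drive the monitor's lifecycle: one asks the loop to finish
# its current iteration, the other interrupts the wait between iterations.
graceful_terminate_event = Event()
terminated_event = Event()

def handle_termination(signum, frame):
    # Setting both events makes the run() loop condition
    # (not graceful_terminate or not terminated) go false,
    # so the thread exits after its current pass.
    graceful_terminate_event.set()
    terminated_event.set()

signal.signal(signal.SIGTERM, handle_termination)
signal.signal(signal.SIGINT, handle_termination)

monitor = TasksMonitor(
    terminated_event,
    graceful_terminate_event,
    celery_app,
    get_result_timeout=1,
)
monitor.start()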
199 changes: 168 additions & 31 deletions alts/scheduler/monitoring.py
@@ -2,51 +2,188 @@
import random
import threading
import time
import urllib.parse
import uuid
from typing import List

import requests
from celery.exceptions import TimeoutError
from celery.states import READY_STATES

from alts.scheduler import CONFIG
from alts.scheduler.db import Session, Task
from alts.shared.models import TaskRequestPayload
from alts.worker.mappings import RUNNER_MAPPING
from alts.worker.tasks import run_tests


class TasksMonitor(threading.Thread):
def __init__(self, terminated_event: threading.Event,
graceful_terminate: threading.Event, celery_app,
get_result_timeout: int = 1):
def __init__(
self,
terminated_event: threading.Event,
graceful_terminate: threading.Event,
celery_app,
get_result_timeout: int = 1,
):
super().__init__()
self.__terminated_event = terminated_event
self.__graceful_terminate = graceful_terminate
self.__celery = celery_app
self.__get_result_timeout = get_result_timeout
self.logger = logging.getLogger(__file__)

def get_available_test_tasks(self) -> List[dict]:
response = []
try:
self.logger.info('Getting new available test tasks')
response = requests.get(
urllib.parse.urljoin(
CONFIG.bs_host,
CONFIG.bs_get_task_endpoint,
),
headers={'Authorization': f'Bearer {CONFIG.bs_token}'},
).json()
if not response:
self.logger.info('There are no available test tasks')
except Exception:
self.logger.exception('Cannot get available test tasks:')
return response

def schedule_test_task(self, payload: TaskRequestPayload):
"""
Schedules new tasks in Test System.

Parameters
----------
payload : TaskRequestPayload
Loader task data in appropriate for request form.

Returns
-------
JSONResponse
JSON-encoded response if task executed successfully or not.
"""
# Get only supported runners mapping based on the config
if (
isinstance(CONFIG.supported_runners, str)
and CONFIG.supported_runners == 'all'
):
runner_mapping = RUNNER_MAPPING
elif isinstance(CONFIG.supported_runners, list):
runner_mapping = {
key: value
for key, value in RUNNER_MAPPING.items()
if key in CONFIG.supported_runners
}
else:
self.logger.error(
'Misconfiguration found: supported_runners is %s',
CONFIG.supported_runners,
)
return
runner_type = payload.runner_type
if runner_type == 'any':
runner_type = random.choice(list(runner_mapping.keys()))
runner_class = RUNNER_MAPPING[runner_type]

if payload.dist_arch not in CONFIG.supported_architectures:
self.logger.error('Unknown architecture: %s', payload.dist_arch)
return
if payload.dist_name not in CONFIG.supported_distributions:
self.logger.error('Unknown distribution: %s', payload.dist_name)
return

# TODO: Make decision on what queue to use for particular task based on
# queues load
queue_arch = None
for (
arch,
supported_arches,
) in runner_class.ARCHITECTURES_MAPPING.items():
if payload.dist_arch in supported_arches:
queue_arch = arch

if not queue_arch:
self.logger.error(
'Cannot map requested architecture to any '
'host architecture, possible coding error'
)
return

# Make sure all repositories have their names
# (needed only for RHEL-like distributions)
# Convert repositories structures to dictionaries
repositories = []
repo_counter = 0
for repository in payload.repositories:
repo_name = repository.name
if not repo_name:
repo_name = f'repo-{repo_counter}'
repo_counter += 1
repositories.append({'url': repository.baseurl, 'name': repo_name})

queue_name = f'{runner_type}-{queue_arch}-{runner_class.COST}'
task_id = str(uuid.uuid4())
task_params = payload.model_dump()
task_params['task_id'] = task_id
task_params['runner_type'] = runner_type
task_params['repositories'] = repositories
try:
run_tests.apply_async(
(task_params,),
task_id=task_id,
queue=queue_name,
)
except Exception:
# TODO: report error to the web server
self.logger.exception('Cannot launch the task:')
with Session() as session, session.begin():
try:
task_record = Task(
task_id=task_id,
queue_name=queue_name,
status='NEW',
)
session.add(task_record)
session.commit()
except Exception:
self.logger.exception('Cannot save task data into DB:')

def run(self) -> None:
while not self.__graceful_terminate.is_set() or \
not self.__terminated_event.is_set():
while (
not self.__graceful_terminate.is_set()
or not self.__terminated_event.is_set()
):
for test_task_payload in self.get_available_test_tasks():
self.schedule_test_task(
TaskRequestPayload(**test_task_payload)
)

updated_tasks = []
with Session() as session:
with session.begin():
for task in session.query(Task).filter(
Task.status.notin_(READY_STATES)):
task_result = self.__celery.AsyncResult(task.task_id)
# Ensure that task state will be updated
# by getting task result
try:
_ = task_result.get(
timeout=self.__get_result_timeout)
except TimeoutError:
pass
if task_result.state != task.status:
self.logger.info('Updating task %s status to %s',
task.task_id, task_result.state)
task.status = task_result.state
updated_tasks.append(task)
time.sleep(0.5)
if updated_tasks:
try:
session.add_all(updated_tasks)
session.commit()
except Exception as e:
self.logger.exception(
'Cannot update tasks statuses:')
self.__terminated_event.wait(random.randint(5, 10))
with Session() as session, session.begin():
for task in session.query(Task).filter(
Task.status.notin_(READY_STATES)
):
task_result = self.__celery.AsyncResult(task.task_id)
# Ensure that task state will be updated
# by getting task result
try:
_ = task_result.get(timeout=self.__get_result_timeout)
except TimeoutError:
pass
if task_result.state != task.status:
self.logger.info(
'Updating task %s status to %s',
task.task_id,
task_result.state,
)
task.status = task_result.state
updated_tasks.append(task)
time.sleep(0.5)
if updated_tasks:
try:
session.add_all(updated_tasks)
session.commit()
except Exception:
self.logger.exception('Cannot update tasks statuses:')
self.__terminated_event.wait(random.randint(10, 15))
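run() feeds schedule_test_task() with whatever JSON objects the build system returns, via TaskRequestPayload(**test_task_payload). As an illustration only, a hypothetical task could look like the following; the keys shown are the ones the scheduling code above actually reads, the values are made up, and the real model may require additional fields:

from alts.shared.models import TaskRequestPayload

task = {
    'runner_type': 'any',      # 'any' resolves to a random supported runner
    'dist_name': 'almalinux',  # must be in CONFIG.supported_distributions
    'dist_arch': 'x86_64',     # must be in CONFIG.supported_architectures
    'repositories': [
        # repositories without names get generated ones: repo-0, repo-1, ...
        {'name': '', 'baseurl': 'https://repo.example.com/test'},
    ],
}
monitor.schedule_test_task(TaskRequestPayload(**task))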
1 change: 1 addition & 0 deletions alts/shared/models.py
@@ -274,6 +274,7 @@ def __init__(self, **data):
ssh_public_key_path: str = '~/.ssh/id_rsa.pub'
# Build system settings
bs_host: typing.Optional[str] = None
bs_get_task_endpoint: str = '/api/v1/tests/get_test_tasks/'
bs_token: typing.Optional[str] = None
# Log uploader settings
logs_uploader_config: typing.Optional[
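One detail of the new bs_get_task_endpoint setting worth noting: it is joined to bs_host with urllib.parse.urljoin in TasksMonitor.get_available_test_tasks, and because the default starts with a slash, urljoin replaces any path already present on the host. A quick illustration with a placeholder host:

from urllib.parse import urljoin

endpoint = '/api/v1/tests/get_test_tasks/'  # the new default

print(urljoin('https://bs.example.com', endpoint))
# https://bs.example.com/api/v1/tests/get_test_tasks/

# A leading '/' means any existing path on bs_host is replaced, not extended:
print(urljoin('https://bs.example.com/some/prefix', endpoint))
# https://bs.example.com/api/v1/tests/get_test_tasks/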