Skip to content

Commit

Permalink
Overhaul the scheduler.
Browse files Browse the repository at this point in the history
- Improve the handling of tasks, machines, and how they are updated in
  the database. Use transactions and locks appropriately so that changes
  are more atomic.
- Remove the need for holding the machine_lock for long periods of time
  in the main loop and remove the need for batch scheduling tasks.
- Make the code a lot cleaner and readable, including separation of concerns
  among various classes. Introduce a MachineryManager class that does
  what the name suggests. In the future, this could have an API added
  that could provide us a way to dynamically update machines in the
  database without having to update a conf file and restart cuckoo.py.
  • Loading branch information
Tommy Beadle committed Apr 1, 2024
1 parent 8392843 commit 96fddc7
Show file tree
Hide file tree
Showing 26 changed files with 1,111 additions and 1,101 deletions.
5 changes: 0 additions & 5 deletions conf/default/cuckoo.conf.default
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,6 @@ machinery_screenshots = off
scaling_semaphore = off
# A configurable wait time between updating the limit value of the scaling bounded semaphore
scaling_semaphore_update_timer = 10
# Allow more than one task scheduled to be assigned at once for better scaling
# A switch to allow batch task assignment, a method that can more efficiently assign tasks to available machines
batch_scheduling = off
# The maximum number of tasks assigned to machines per batch, optimal value dependent on deployment
max_batch_count = 20

# Enable creation of memory dump of the analysis machine before shutting
# down. Even if turned off, this functionality can also be enabled at
Expand Down
2 changes: 1 addition & 1 deletion cuckoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def cuckoo_main(max_analysis_count=0):
parser.add_argument("-v", "--version", action="version", version="You are running Cuckoo Sandbox {0}".format(CUCKOO_VERSION))
parser.add_argument("-a", "--artwork", help="Show artwork", action="store_true", required=False)
parser.add_argument("-t", "--test", help="Test startup", action="store_true", required=False)
parser.add_argument("-m", "--max-analysis-count", help="Maximum number of analyses", type=int, required=False)
parser.add_argument("-m", "--max-analysis-count", help="Maximum number of analyses", type=int, required=False, default=0)
parser.add_argument(
"-s",
"--stop",
Expand Down
145 changes: 65 additions & 80 deletions lib/cuckoo/common/abstracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# See the file 'docs/LICENSE' for copying permission.

import datetime
import inspect
import io
import logging
import os
Expand Down Expand Up @@ -38,7 +37,7 @@
from lib.cuckoo.common.path_utils import path_exists, path_mkdir
from lib.cuckoo.common.url_validate import url as url_validator
from lib.cuckoo.common.utils import create_folder, get_memdump_path, load_categories
from lib.cuckoo.core.database import Database
from lib.cuckoo.core.database import Database, Machine, _Database

try:
import re2 as re
Expand Down Expand Up @@ -107,42 +106,48 @@ class Machinery:
# Default label used in machinery configuration file to supply virtual
# machine name/label/vmx path. Override it if you dubbed it in another
# way.
LABEL = "label"
LABEL: str = "label"

# This must be defined in sub-classes.
module_name: str

def __init__(self):
self.module_name = ""
self.options = None
# Database pointer.
self.db = Database()
# Machine table is cleaned to be filled from configuration file
# at each start.
self.db.clean_machines()
self.db: _Database = Database()

# Find its configuration file.
conf = os.path.join(CUCKOO_ROOT, "conf", f"{self.module_name}.conf")
if not path_exists(conf):
raise CuckooCriticalError(
f'The configuration file for machine manager "{self.module_name}" does not exist at path: {conf}'
)
self.set_options(Config(self.module_name))

def set_options(self, options: dict):
def set_options(self, options: dict) -> None:
"""Set machine manager options.
@param options: machine manager options dict.
"""
self.options = options
mmanager_opts = self.options.get(self.module_name)
if not isinstance(mmanager_opts["machines"], list):
mmanager_opts["machines"] = str(mmanager_opts["machines"]).strip().split(",")

def initialize(self) -> None:
"""Read, load, and verify machines configuration."""
# Machine table is cleaned to be filled from configuration file
# at each start.
self.db.clean_machines()

def initialize(self, module_name):
"""Read, load, and verify machines configuration.
@param module_name: module name.
"""
# Load.
self._initialize(module_name)
self._initialize()

# Run initialization checks.
self._initialize_check()

def _initialize(self, module_name):
"""Read configuration.
@param module_name: module name.
"""
self.module_name = module_name
mmanager_opts = self.options.get(module_name)
if not isinstance(mmanager_opts["machines"], list):
mmanager_opts["machines"] = str(mmanager_opts["machines"]).strip().split(",")

def _initialize(self) -> None:
"""Read configuration."""
mmanager_opts = self.options.get(self.module_name)
for machine_id in mmanager_opts["machines"]:
try:
machine_opts = self.options.get(machine_id.strip())
Expand Down Expand Up @@ -198,7 +203,7 @@ def _initialize(self, module_name):
log.warning("Configuration details about machine %s are missing: %s", machine_id.strip(), e)
continue

def _initialize_check(self):
def _initialize_check(self) -> None:
"""Runs checks against virtualization software when a machine manager is initialized.
@note: in machine manager modules you may override or superclass his method.
@raise CuckooMachineError: if a misconfiguration or a unkown vm state is found.
Expand All @@ -208,20 +213,24 @@ def _initialize_check(self):
except NotImplementedError:
return

self.shutdown_running_machines(configured_vms)
self.check_screenshot_support()

if not cfg.timeouts.vm_state:
raise CuckooCriticalError("Virtual machine state change timeout setting not found, please add it to the config file")

def check_screenshot_support(self) -> None:
# If machinery_screenshots are enabled, check the machinery supports it.
if cfg.cuckoo.machinery_screenshots:
# inspect function members available on the machinery class
cls_members = inspect.getmembers(self.__class__, predicate=inspect.isfunction)
for name, function in cls_members:
if name != Machinery.screenshot.__name__:
continue
if Machinery.screenshot == function:
msg = f"machinery {self.module_name} does not support machinery screenshots"
raise CuckooCriticalError(msg)
break
else:
raise NotImplementedError(f"missing machinery method: {Machinery.screenshot.__name__}")
if not cfg.cuckoo.machinery_screenshots:
return

# inspect function members available on the machinery class
func = getattr(self.__class__, "screenshot")
if func == Machinery.screenshot:
msg = f"machinery {self.module_name} does not support machinery screenshots"
raise CuckooCriticalError(msg)

def shutdown_running_machines(self, configured_vms: List[str]) -> None:
for machine in self.machines():
# If this machine is already in the "correct" state, then we
# go on to the next machine.
Expand All @@ -236,16 +245,13 @@ def _initialize_check(self):
msg = f"Please update your configuration. Unable to shut '{machine.label}' down or find the machine in its proper state: {e}"
raise CuckooCriticalError(msg) from e

if not cfg.timeouts.vm_state:
raise CuckooCriticalError("Virtual machine state change timeout setting not found, please add it to the config file")

def machines(self):
"""List virtual machines.
@return: virtual machines list
"""
return self.db.list_machines(include_reserved=True)

def availables(self, label=None, platform=None, tags=None, arch=None, include_reserved=False, os_version=[]):
def availables(self, label=None, platform=None, tags=None, arch=None, include_reserved=False, os_version=None):
"""How many (relevant) machines are free.
@param label: machine ID.
@param platform: machine platform.
Expand All @@ -257,39 +263,25 @@ def availables(self, label=None, platform=None, tags=None, arch=None, include_re
label=label, platform=platform, tags=tags, arch=arch, include_reserved=include_reserved, os_version=os_version
)

def acquire(self, machine_id=None, platform=None, tags=None, arch=None, os_version=[], need_scheduled=False):
"""Acquire a machine to start analysis.
@param machine_id: machine ID.
@param platform: machine platform.
@param tags: machine tags
@param arch: machine arch
@param os_version: tags to filter per OS version. Ex: winxp, win7, win10, win11
@param need_scheduled: should the result be filtered on 'scheduled' machine status
@return: machine or None.
"""
if machine_id:
return self.db.lock_machine(label=machine_id, need_scheduled=need_scheduled)
elif platform:
return self.db.lock_machine(
platform=platform, tags=tags, arch=arch, os_version=os_version, need_scheduled=need_scheduled
)
return self.db.lock_machine(tags=tags, arch=arch, os_version=os_version, need_scheduled=need_scheduled)

def get_machines_scheduled(self):
return self.db.get_machines_scheduled()
def scale_pool(self, machine: Machine) -> None:
"""This can be overridden in sub-classes to scale the pool of machines once one has been acquired."""
return

def release(self, label=None):
def release(self, machine: Machine) -> Machine:
"""Release a machine.
@param label: machine name.
"""
self.db.unlock_machine(label)
return self.db.unlock_machine(machine)

def running(self):
"""Returns running virtual machines.
@return: running virtual machines list.
"""
return self.db.list_machines(locked=True)

def running_count(self):
return self.db.count_machines_running()

def screenshot(self, label, path):
"""Screenshot a running virtual machine.
@param label: machine name
Expand All @@ -302,9 +294,10 @@ def shutdown(self):
"""Shutdown the machine manager. Kills all alive machines.
@raise CuckooMachineError: if unable to stop machine.
"""
if len(self.running()) > 0:
log.info("Still %d guests still alive, shutting down...", len(self.running()))
for machine in self.running():
running = self.running()
if len(running) > 0:
log.info("Still %d guests still alive, shutting down...", len(running))
for machine in running:
try:
self.stop(machine.label)
except CuckooMachineError as e:
Expand Down Expand Up @@ -389,23 +382,12 @@ class LibVirtMachinery(Machinery):
ABORTED = "abort"

def __init__(self):

if not categories_need_VM:
return

if not HAVE_LIBVIRT:
raise CuckooDependencyError(
"Unable to import libvirt. Ensure that you properly installed it by running: cd /opt/CAPEv2/ ; sudo -u cape poetry run extra/libvirt_installer.sh"
)

super(LibVirtMachinery, self).__init__()

def initialize(self, module):
"""Initialize machine manager module. Override default to set proper
connection string.
@param module: machine manager module
"""
super(LibVirtMachinery, self).initialize(module)
super().__init__()

def _initialize_check(self):
"""Runs all checks when a machine manager is initialized.
Expand All @@ -420,7 +402,7 @@ def _initialize_check(self):

# Base checks. Also attempts to shutdown any machines which are
# currently still active.
super(LibVirtMachinery, self)._initialize_check()
super()._initialize_check()

def start(self, label):
"""Starts a virtual machine.
Expand All @@ -429,14 +411,17 @@ def start(self, label):
"""
log.debug("Starting machine %s", label)

vm_info = self.db.view_machine_by_label(label)
if vm_info is None:
msg = f"Unable to find machine with label {label} in database."
raise CuckooMachineError(msg)

if self._status(label) != self.POWEROFF:
msg = f"Trying to start a virtual machine that has not been turned off {label}"
raise CuckooMachineError(msg)

conn = self._connect(label)

vm_info = self.db.view_machine_by_label(label)

snapshot_list = self.vms[label].snapshotListNames(flags=0)

# If a snapshot is configured try to use it.
Expand Down
11 changes: 11 additions & 0 deletions lib/cuckoo/common/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ class CuckooStartupError(CuckooCriticalError):
pass


class CuckooDatabaseInitializationError(CuckooCriticalError):
def __str__(self):
return "The database has not been initialized yet. You must call init_database before attempting to use it."


class CuckooDatabaseError(CuckooCriticalError):
"""Cuckoo database error."""

Expand All @@ -33,6 +38,12 @@ class CuckooOperationalError(Exception):
pass


class CuckooUnserviceableTaskError(CuckooOperationalError):
"""There are no machines in the pool that can service the task."""

pass


class CuckooMachineError(CuckooOperationalError):
"""Error managing analysis machine."""

Expand Down
8 changes: 5 additions & 3 deletions lib/cuckoo/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import zipfile
from datetime import datetime
from io import BytesIO
from typing import Tuple, Union
from typing import Final, List, Tuple, Union

from data.family_detection_names import family_detection_names
from lib.cuckoo.common import utils_dicts
Expand Down Expand Up @@ -89,10 +89,12 @@ def arg_name_clscontext(arg_val):
sanitize_len = config.cuckoo.get("sanitize_len", 32)
sanitize_to_len = config.cuckoo.get("sanitize_to_len", 24)

CATEGORIES_NEEDING_VM: Final[Tuple[str]] = ("file", "url")

def load_categories():

def load_categories() -> Tuple[List[str], bool]:
analyzing_categories = [category.strip() for category in config.cuckoo.categories.split(",")]
needs_VM = any([category in analyzing_categories for category in ("file", "url")])
needs_VM = any(category in analyzing_categories for category in CATEGORIES_NEEDING_VM)
return analyzing_categories, needs_VM


Expand Down
9 changes: 5 additions & 4 deletions lib/cuckoo/common/web_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,12 +413,14 @@ def statistics(s_days: int) -> dict:
details[module_name.split(".")[-1]].setdefault(name, entry)

top_samples = {}
session = db.Session()
added_tasks = (
session.query(Task).join(Sample, Task.sample_id == Sample.id).filter(Task.added_on.between(date_since, date_till)).all()
db.session.query(Task).join(Sample, Task.sample_id == Sample.id).filter(Task.added_on.between(date_since, date_till)).all()
)
tasks = (
session.query(Task).join(Sample, Task.sample_id == Sample.id).filter(Task.completed_on.between(date_since, date_till)).all()
db.session.query(Task)
.join(Sample, Task.sample_id == Sample.id)
.filter(Task.completed_on.between(date_since, date_till))
.all()
)
details["total"] = len(tasks)
details["average"] = f"{round(details['total'] / s_days, 2):.2f}"
Expand Down Expand Up @@ -487,7 +489,6 @@ def statistics(s_days: int) -> dict:
details["detections"] = top_detections(date_since=date_since)
details["asns"] = top_asn(date_since=date_since)

session.close()
return details


Expand Down
Loading

0 comments on commit 96fddc7

Please sign in to comment.