Skip to content

Commit

Permalink
Do not import middleware API in ZFS process pool (#15455)
Browse files (browse the repository at this point in the history)
  • Loading branch information
themylogin authored Jan 23, 2025
1 parent 911a966 commit b7f0128
Show file tree
Hide file tree
Showing 17 changed files with 72 additions and 57 deletions.
2 changes: 2 additions & 0 deletions src/middlewared/middlewared/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .base.decorator import *

API_LOADING_FORBIDDEN = False
6 changes: 4 additions & 2 deletions src/middlewared/middlewared/api/base/types/urls.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Annotated
from typing import Annotated, Literal, TypeAlias

from pydantic import AfterValidator, HttpUrl

from middlewared.api.base.validators import https_only_check

__all__ = ["HttpsOnlyURL"]
__all__ = ["HttpsOnlyURL", "HttpVerb"]


HttpsOnlyURL = Annotated[HttpUrl, AfterValidator(https_only_check)]

HttpVerb: TypeAlias = Literal["GET", "POST", "PUT", "DELETE", "CALL", "SUBSCRIBE", "*"]
7 changes: 7 additions & 0 deletions src/middlewared/middlewared/api/current.py
Original file line number Diff line number Diff line change
@@ -1 +1,8 @@
from . import API_LOADING_FORBIDDEN
if API_LOADING_FORBIDDEN:
raise RuntimeError(
"Middleware API loading forbidden in this code path as it is too resource-consuming. Please, inspect the "
"provided traceback and ensure that nothing is imported from `middlewared.api.current`."
)

from .v25_04_0 import * # noqa
7 changes: 2 additions & 5 deletions src/middlewared/middlewared/api/v25_04_0/api_key.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
from datetime import datetime
from typing import Annotated, Literal, TypeAlias
from typing import Annotated, Literal

from pydantic import Secret, StringConstraints

from middlewared.api.base import (
BaseModel, Excluded, excluded_field, ForUpdateMetaclass, NonEmptyString,
LocalUsername, RemoteUsername
LocalUsername, RemoteUsername, HttpVerb,
)


HttpVerb: TypeAlias = Literal["GET", "POST", "PUT", "DELETE", "CALL", "SUBSCRIBE", "*"]


class AllowListItem(BaseModel):
method: HttpVerb
resource: NonEmptyString
Expand Down
7 changes: 3 additions & 4 deletions src/middlewared/middlewared/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .schema import OROperator
import middlewared.service
from .service_exception import CallError, ErrnoMixin
from .utils import MIDDLEWARE_RUN_DIR, sw_version
from .utils import MIDDLEWARE_RUN_DIR, MIDDLEWARE_STARTED_SENTINEL_PATH, sw_version
from .utils.audit import audit_username_from_session
from .utils.debug import get_threads_stacks
from .utils.limits import MsgSizeError, MsgSizeLimit, parse_message
Expand All @@ -27,7 +27,6 @@
from .utils.rate_limit.cache import RateLimitCache
from .utils.service.call import ServiceCallMixin
from .utils.service.crud import real_crud_method
from .utils.syslog import syslog_message
from .utils.threading import set_thread_name, IoThreadPoolExecutor, io_thread_pool_executor
from .utils.time_utils import utc_now
from .utils.type import copy_function_metadata
Expand Down Expand Up @@ -200,7 +199,7 @@ def _add_api_route(self, version: str, api: API):
self.app.router.add_route('GET', f'/api/{version}', RpcWebSocketHandler(self, api.methods))

def __init_services(self):
from middlewared.service import CoreService
from middlewared.service.core_service import CoreService
self.add_service(CoreService(self))
self.event_register('core.environ', 'Send on middleware process environment changes.', private=True)

Expand Down Expand Up @@ -448,7 +447,7 @@ def __notify_startup_progress(self):
systemd_notify(f'EXTEND_TIMEOUT_USEC={SYSTEMD_EXTEND_USECS}')

def __notify_startup_complete(self):
with open(middlewared.service.MIDDLEWARE_STARTED_SENTINEL_PATH, 'w'):
with open(MIDDLEWARE_STARTED_SENTINEL_PATH, 'w'):
pass

systemd_notify('READY=1')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from middlewared.service import private, Service

from middlewared.plugins.config import FREENAS_DATABASE
from middlewared.utils.db import FREENAS_DATABASE

thread_pool = ThreadPoolExecutor(1)

Expand Down
2 changes: 0 additions & 2 deletions src/middlewared/middlewared/plugins/datastore/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@

from sqlalchemy import and_, func, select
from sqlalchemy.sql import Alias
from sqlalchemy.sql.elements import UnaryExpression
from sqlalchemy.sql.expression import nullsfirst, nullslast
from sqlalchemy.sql.operators import desc_op, nullsfirst_op, nullslast_op

from middlewared.schema import accepts, Bool, Dict, Int, List, Ref, Str
from middlewared.service import Service
Expand Down
3 changes: 2 additions & 1 deletion src/middlewared/middlewared/plugins/zettarepl.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@
from zettarepl.zettarepl import create_zettarepl

from middlewared.logger import setup_logging
from middlewared.service import CallError, Service
from middlewared.service.service import Service
from middlewared.service_exception import CallError
from middlewared.utils.cgroups import move_to_root_cgroups
from middlewared.utils.prctl import die_with_parent
from middlewared.utils.size import format_size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from middlewared.service import CoreService
from middlewared.service.core_service import CoreService


@pytest.mark.parametrize("doc,names,descriptions", [
Expand Down
1 change: 0 additions & 1 deletion src/middlewared/middlewared/service/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from .compound_service import CompoundService # noqa
from .config_service import ConfigService # noqa
from .core_service import CoreService, MIDDLEWARE_RUN_DIR, MIDDLEWARE_STARTED_SENTINEL_PATH # noqa
from .crud_service import CRUDService # noqa
from .decorators import ( # noqa
cli_private, filterable, filterable_returns, item_method, job, lock, no_auth_required,
Expand Down
5 changes: 1 addition & 4 deletions src/middlewared/middlewared/service/core_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from middlewared.pipe import Pipes
from middlewared.schema import accepts, Any, Bool, Datetime, Dict, Int, List, Str
from middlewared.service_exception import CallError, ValidationErrors
from middlewared.utils import BOOTREADY, filter_list, MIDDLEWARE_RUN_DIR
from middlewared.utils import BOOTREADY, filter_list, MIDDLEWARE_STARTED_SENTINEL_PATH
from middlewared.utils.debug import get_frame_details, get_threads_stacks
from middlewared.validators import IpAddress, Range

Expand All @@ -44,9 +44,6 @@
from .service import Service


MIDDLEWARE_STARTED_SENTINEL_PATH = os.path.join(MIDDLEWARE_RUN_DIR, 'middlewared-started')


def is_service_class(service, klass):
return (
isinstance(service, klass) or
Expand Down
5 changes: 3 additions & 2 deletions src/middlewared/middlewared/service/crud_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

from pydantic import create_model, Field

from middlewared.api import api_method
from middlewared.api import API_LOADING_FORBIDDEN, api_method
from middlewared.api.base.model import BaseModel, query_result, query_result_item
from middlewared.api.current import QueryArgs, QueryOptions
if not API_LOADING_FORBIDDEN:
from middlewared.api.current import QueryArgs, QueryOptions
from middlewared.service_exception import CallError, InstanceNotFound
from middlewared.schema import accepts, Any, Bool, convert_schema, Dict, Int, List, OROperator, Patch, Ref, returns
from middlewared.utils import filter_list
Expand Down
5 changes: 3 additions & 2 deletions src/middlewared/middlewared/service/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from collections import defaultdict, namedtuple
from functools import wraps

from middlewared.api import api_method
from middlewared.api import API_LOADING_FORBIDDEN, api_method
from middlewared.api.base import query_result
from middlewared.api.current import QueryArgs, GenericQueryResult
if not API_LOADING_FORBIDDEN:
from middlewared.api.current import QueryArgs, GenericQueryResult
from middlewared.schema import accepts, Int, List, OROperator, Ref, returns


Expand Down
1 change: 1 addition & 0 deletions src/middlewared/middlewared/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class ProductNames:

MID_PID = None
MIDDLEWARE_RUN_DIR = '/var/run/middleware'
MIDDLEWARE_STARTED_SENTINEL_PATH = f'{MIDDLEWARE_RUN_DIR}/middlewared-started'
BOOTREADY = f'{MIDDLEWARE_RUN_DIR}/.bootready'
MANIFEST_FILE = '/data/manifest.json'
BRAND = ProductName.PRODUCT_NAME
Expand Down
2 changes: 1 addition & 1 deletion src/middlewared/middlewared/utils/allowlist.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import fnmatch
import re

from middlewared.api.current import HttpVerb
from middlewared.api.base.types import HttpVerb
from middlewared.utils.privilege_constants import ALLOW_LIST_FULL_ADMIN


Expand Down
16 changes: 9 additions & 7 deletions src/middlewared/middlewared/utils/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
logger = logging.getLogger(__name__)


def load_modules(directory, base=None, depth=0):
def load_modules(directory, base=None, depth=0, whitelist=None):
directory = os.path.normpath(directory)
if base is None:
middlewared_root = os.path.dirname(os.path.dirname(__file__))
Expand All @@ -32,13 +32,15 @@ def load_modules(directory, base=None, depth=0):
base = '.'.join(os.path.relpath(directory, new_module_path).split('/'))

_, dirs, files = next(os.walk(directory))
for f in filter(lambda x: x[-3:] == '.py' and x.find('_freebsd') == -1, files):
yield importlib.import_module(base if f == '__init__.py' else f'{base}.{f[:-3]}')
for f in filter(lambda x: x[-3:] == '.py', files):
module_name = base if f == '__init__.py' else f'{base}.{f[:-3]}'
if whitelist is None or any(module_name.startswith(w) for w in whitelist):
yield importlib.import_module(module_name)

for f in filter(lambda x: x.find('_freebsd') == -1, dirs):
for f in dirs:
if depth > 0:
path = os.path.join(directory, f)
yield from load_modules(path, f'{base}.{f}', depth - 1)
yield from load_modules(path, f'{base}.{f}', depth - 1, whitelist)


def load_classes(module, base, blacklist):
Expand Down Expand Up @@ -92,15 +94,15 @@ def __init__(self):
self._services_aliases = {}
super().__init__()

def _load_plugins(self, on_module_begin=None, on_module_end=None, on_modules_loaded=None):
def _load_plugins(self, on_module_begin=None, on_module_end=None, on_modules_loaded=None, whitelist=None):
from middlewared.service import Service, CompoundService, ABSTRACT_SERVICES

services = []
plugins_dir = os.path.realpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'plugins'))
if not os.path.exists(plugins_dir):
raise ValueError(f'plugins dir not found: {plugins_dir}')

for mod in load_modules(plugins_dir, depth=1):
for mod in load_modules(plugins_dir, depth=1, whitelist=whitelist):
if on_module_begin:
on_module_begin(mod)

Expand Down
56 changes: 32 additions & 24 deletions src/middlewared/middlewared/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from truenas_api_client import Client

import middlewared.api
from . import logger
from .common.environ import environ_update
from .utils import MIDDLEWARE_RUN_DIR
Expand Down Expand Up @@ -49,29 +50,32 @@ def call_sync(self, method, *params, timeout=None, **kwargs):
"""
Calls a method using middleware client
"""
serviceobj, methodobj = self.get_method(method)

if serviceobj._config.process_pool and not hasattr(method, '_job'):
if asyncio.iscoroutinefunction(methodobj):
try:
# Search for a synchronous implementation of the asynchronous method (i.e. `get_instance`).
# Why is this needed? Imagine we have a `ZFSSnapshot` service that uses a process pool. Let's say
# its `create` method calls `zfs.snapshot.get_instance` to return the result. That call will have
# to be forwarded to the main middleware process, which will call `zfs.snapshot.query` in the
# process pool. If the process pool is already exhausted, it will lead to a deadlock.
# By executing a synchronous implementation of the same method in the same process pool we
# eliminate `Hold and wait` condition and prevent deadlock situation from arising.
_, sync_methodobj = self.get_method(f'{method}__sync')
except MethodNotFoundError:
# FIXME: Make this an exception in 22.MM
self.logger.warning('Service uses a process pool but has an asynchronous method: %r', method)
sync_methodobj = None
else:
sync_methodobj = methodobj

if sync_methodobj is not None:
self.logger.trace('Calling %r in current process', method)
return sync_methodobj(*params)
try:
serviceobj, methodobj = self.get_method(method)
except Exception:
pass
else:
if serviceobj._config.process_pool and not hasattr(method, '_job'):
if asyncio.iscoroutinefunction(methodobj):
try:
# Search for a synchronous implementation of the asynchronous method (e.g. `get_instance`).
# Why is this needed? Imagine we have a `ZFSSnapshot` service that uses a process pool. Let's say
# its `create` method calls `zfs.snapshot.get_instance` to return the result. That call will have
# to be forwarded to the main middleware process, which will call `zfs.snapshot.query` in the
# process pool. If the process pool is already exhausted, this will lead to a deadlock.
# By executing a synchronous implementation of the same method in the same process pool, we
# eliminate the `Hold and wait` condition and prevent a deadlock from arising.
_, sync_methodobj = self.get_method(f'{method}__sync')
except MethodNotFoundError:
# FIXME: Make this an exception in 22.MM
self.logger.warning('Service uses a process pool but has an asynchronous method: %r', method)
sync_methodobj = None
else:
sync_methodobj = methodobj

if sync_methodobj is not None:
self.logger.trace('Calling %r in current process', method)
return sync_methodobj(*params)

return self.client.call(method, *params, timeout=timeout, **kwargs)

Expand Down Expand Up @@ -128,9 +132,13 @@ def receive_events():

def worker_init(debug_level, log_handler):
global MIDDLEWARE
middlewared.api.API_LOADING_FORBIDDEN = True
MIDDLEWARE = FakeMiddleware()
os.environ['MIDDLEWARED_LOADING'] = 'True'
MIDDLEWARE._load_plugins()
MIDDLEWARE._load_plugins(whitelist=[
'middlewared.plugins.datastore',
'middlewared.plugins.zfs_',
])
os.environ['MIDDLEWARED_LOADING'] = 'False'
setproctitle.setproctitle('middlewared (worker)')
die_with_parent()
Expand Down

0 comments on commit b7f0128

Please sign in to comment.