diff --git a/src/middlewared/middlewared/main.py b/src/middlewared/middlewared/main.py
index 979c8a3c41d74..3d9cad504b855 100644
--- a/src/middlewared/middlewared/main.py
+++ b/src/middlewared/middlewared/main.py
@@ -24,7 +24,7 @@
 from .utils.plugins import LoadPluginsMixin
 from .utils.privilege import credential_has_full_admin
 from .utils.profile import profile_wrap
-from .utils.rate_limit import RateLimitCache
+from .utils.rate_limit.cache import RateLimitCache
 from .utils.service.call import ServiceCallMixin
 from .utils.syslog import syslog_message
 from .utils.threading import set_thread_name, IoThreadPoolExecutor, io_thread_pool_executor
@@ -359,10 +359,10 @@ async def on_message(self, message: typing.Dict[str, typing.Any]):
             self.send_error(message, e.errno, str(e), sys.exc_info(), extra=e.extra)
             error = True
 
-        auth_required = not hasattr(methodobj, '_no_auth_required')
         if not error:
+            auth_required = not hasattr(methodobj, '_no_auth_required')
             if not auth_required:
-                ip_added = RateLimitCache.add(message['method'], self.origin)
+                ip_added = await RateLimitCache.add(message['method'], self.origin)
                 if ip_added is not None:
                     if any((
                         RateLimitCache.max_entries_reached,
@@ -375,8 +375,8 @@ async def on_message(self, message: typing.Dict[str, typing.Any]):
                         # origin IP address
                         # In either scenario, sleep a random delay and send an error
                         await self.__log_audit_message_for_method(message, methodobj, False, True, False)
-                        await RateLimitCache.sleep_random()
-                        self.send_error('Rate Limit Exceeded', errno.EBUSY)
+                        await RateLimitCache.random_sleep()
+                        self.send_error(message, errno.EBUSY, 'Rate Limit Exceeded')
                         error = True
                     else:
                         # was added to rate limit cache but rate limit thresholds haven't
diff --git a/src/middlewared/middlewared/plugins/rate_limit/__init__.py b/src/middlewared/middlewared/plugins/rate_limit/__init__.py
index b94f403d79578..2d6a5dc0f2cc9 100644
--- a/src/middlewared/middlewared/plugins/rate_limit/__init__.py
+++ b/src/middlewared/middlewared/plugins/rate_limit/__init__.py
@@ -1,6 +1,6 @@
 from middlewared.service import periodic, Service
-from middlewared.utils.rate_limit import RateLimitCache
+from middlewared.utils.rate_limit.cache import RateLimitCache
 
 CLEAR_CACHE_INTERVAL = 600
 
 
@@ -22,4 +22,4 @@ async def clear_cache(self):
         # store a maximum of amount of entries in the cache and
         # then refuse to honor any more requests for all consumers.
         # This is required for STIG purposes.
-        RateLimitCache.clear()
+        await RateLimitCache.clear()
diff --git a/src/middlewared/middlewared/service/__init__.py b/src/middlewared/middlewared/service/__init__.py
index 6fd9ff666eee3..43074c7b6ef03 100644
--- a/src/middlewared/middlewared/service/__init__.py
+++ b/src/middlewared/middlewared/service/__init__.py
@@ -17,7 +17,6 @@
 from .service_part import ServicePartBase  # noqa
 from .sharing_service import SharingService, SharingTaskService, TaskPathService  # noqa
 from .system_service import SystemServiceService  # noqa
-from .throttle import throttle  # noqa
 
 
 ABSTRACT_SERVICES = (  # noqa
diff --git a/src/middlewared/middlewared/utils/rate_limit/cache.py b/src/middlewared/middlewared/utils/rate_limit/cache.py
index ca60bf1ab7102..d1ea957caf0f6 100644
--- a/src/middlewared/middlewared/utils/rate_limit/cache.py
+++ b/src/middlewared/middlewared/utils/rate_limit/cache.py
@@ -1,84 +1,77 @@
 from asyncio import sleep
 from dataclasses import dataclass
 from random import uniform
-from threading import RLock
 from time import monotonic
+from typing import Self, TypedDict
 
 from middlewared.auth import is_ha_connection
 from middlewared.utils.origin import TCPIPOrigin
 
 __all__ = ('RateLimitCache')
 
-"""The maximum number of calls per unique consumer of the endpoint."""
-MAX_CALLS: int = 20
-"""The maximum time in seconds that a unique consumer may request an
-endpoint that is being rate limited."""
-MAX_PERIOD: int = 60
-
-
-@dataclass(slots=True, kw_only=True)
-class RateLimitObject:
-    """A per-{endpoint/consumer} re-entrant lock so that a
-    global lock is not shared between all (potential)
-    consumers hitting the same endpoint."""
-    lock: RLock
-    """The number of times this method was called by the consumer."""
-    num_times_called: int = 0
-    """The monotonic time representing when this particular cache
-    entry was last reset."""
-    last_reset: float = monotonic()
-
-
-@dataclass(slots=True)
-class RateLimitCachedObjects:
+
+@dataclass(frozen=True)
+class RateLimitConfig:
+    """The maximum number of calls per unique consumer of the endpoint."""
+    max_calls: int = 20
+    """The maximum time in seconds that a unique consumer may request an
+    endpoint that is being rate limited."""
+    max_period: int = 60
     """The maximum number of unique entries the cache supports"""
-    MAX_CACHE_ENTRIES: int = 100
+    max_cache_entries: int = 100
     """The value used to separate the unique values when generating
     a unique key to be used to store the cached information."""
-    SEPARATOR: str = '_'
-    """The global cache object used to store the information about
-    all endpoints/consumers being rate limited."""
-    CACHE: dict[str, RateLimitObject] = dict()
+    separator: str = '_##_'
     """The starting decimal value for the time to be slept in the
     event rate limit thresholds for a particular consumer has been met."""
-    RANDOM_START: float = 1.0
+    sleep_start: float = 1.0
     """The ending decimal value for the time to be slept in the
     event rate limit thresholds for a particular consumer has been met."""
-    RANDOM_END: float = 10.0
+    sleep_end: float = 10.0
+
+
+class RateLimitObject(TypedDict):
+    """The number of times this method was called by the consumer."""
+    num_times_called: int
+    """The monotonic time representing when this particular cache
+    entry was last reset."""
+    last_reset: float
 
-    @property
-    def max_entries_reached(self) -> bool:
-        """Return a boolean indicating if the total number of entries
-        in the global cache has reached `self.MAX_CACHE_ENTRIES`."""
-        return len(self.CACHE) == self.MAX_CACHE_ENTRIES
 
+RL_CACHE: dict[str, RateLimitObject] = dict()
+
+
+
+class RateLimit:
     def cache_key(self, method_name: str, ip: str) -> str:
         """Generate a unique key per endpoint/consumer"""
-        return f'{method_name}{self.SEPARATOR}{ip}'
+        return f'{method_name}{RateLimitConfig.separator}{ip}'
 
     def rate_limit_exceeded(self, method_name: str, ip: str) -> bool:
         """Return a boolean indicating if the total number of calls
         per unique endpoint/consumer has been reached."""
         key = self.cache_key(method_name, ip)
         try:
-            with self.CACHE[key].lock:
-                now: float = monotonic()
-                if MAX_PERIOD - (now - self.CACHE[key].last_reset) <= 0:
-                    # time window elapsed, so time to reset
-                    self.CACHE[key].num_times_called = 0
-                    self.CACHE[key].last_reset = now
-
-                # always increment
-                self.CACHE[key].num_times_called += 1
-                return self.CACHE[key].num_times_called > MAX_CALLS
+            now: float = monotonic()
+            if RateLimitConfig.max_period - (now - RL_CACHE[key]['last_reset']) <= 0:
+                # time window elapsed, so time to reset
+                RL_CACHE[key]['num_times_called'] = 0
+                RL_CACHE[key]['last_reset'] = now
+
+            # always increment
+            RL_CACHE[key]['num_times_called'] += 1
+            return RL_CACHE[key]['num_times_called'] > RateLimitConfig.max_calls
         except KeyError:
             pass
 
         return False
 
-    def add(self, method_name: str, origin: TCPIPOrigin) -> str | None:
+    async def add(self, method_name: str, origin: TCPIPOrigin) -> str | None:
         """Add an entry to the cache. Returns the IP address of origin
         of the request if it has been cached, returns None otherwise"""
+        if not isinstance(origin, TCPIPOrigin):
+            return None
+
         ip, port = origin.addr, origin.port
         if any((ip is None, port is None)) or is_ha_connection(ip, port):
             # Short-circuit if:
@@ -89,29 +82,27 @@ def add(self, method_name: str, origin: TCPIPOrigin) -> str | None:
             return None
         else:
             key = self.cache_key(method_name, ip)
-            if key not in self.CACHE:
-                self.CACHE[key] = RateLimitObject(lock=RLock())
-                return ip
-
-        return None
+            if key not in RL_CACHE:
+                RL_CACHE[key] = RateLimitObject(num_times_called=0, last_reset=monotonic())
+            return ip
 
-    def pop(self, method_name: str, ip: str) -> None:
+    async def pop(self, method_name: str, ip: str) -> None:
         """Pop (remove) an entry from the cache."""
-        self.CACHE.pop(self.cache_key(method_name, ip), None)
+        RL_CACHE.pop(self.cache_key(method_name, ip), None)
 
-    def clear(self) -> None:
+    async def clear(self) -> None:
         """Clear all entries from the cache."""
-        self.CACHE.clear()
-
-    @property
-    def random_range(self) -> float:
-        """Return a random float within self.RANDOM_START and self.RANDOM_END
-        rounded to the 100th decimal point"""
-        return round(uniform(self.RANDOM_START, self.RANDOM_END), 2)
+        RL_CACHE.clear()
 
     async def random_sleep(self) -> None:
-        """Sleep a random amount of seconds within range of `self.random_range`."""
-        await sleep(self.random_range)
+        """Sleep a random amount of seconds."""
+        await sleep(round(uniform(RateLimitConfig.sleep_start, RateLimitConfig.sleep_end), 2))
+
+    @property
+    def max_entries_reached(self) -> bool:
+        """Return a boolean indicating if the total number of entries
+        in the global cache has reached `self.max_cache_entries`."""
+        return len(RL_CACHE) == RateLimitConfig.max_cache_entries
 
 
-RateLimitCache = RateLimitCachedObjects()
+RateLimitCache = RateLimit()
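
Reviewer note: a minimal, illustrative sketch of how the reworked async cache is expected to be driven, mirroring the on_message() flow above. The direct TCPIPOrigin(addr=..., port=...) construction and the 'auth.login' method name are assumptions made for the example only and are not part of this change.

import asyncio

from middlewared.utils.origin import TCPIPOrigin
from middlewared.utils.rate_limit.cache import RateLimitCache


async def demo() -> None:
    # Hypothetical origin; real code receives this from the accepted connection.
    origin = TCPIPOrigin(addr='192.0.2.10', port=51234)
    # add() returns the caller's IP once the entry is cached, or None for
    # HA connections / origins without a usable address and port.
    ip = await RateLimitCache.add('auth.login', origin)
    if ip is not None and (
        RateLimitCache.max_entries_reached
        or RateLimitCache.rate_limit_exceeded('auth.login', ip)
    ):
        # random 1-10 second back-off before the caller sends EBUSY
        await RateLimitCache.random_sleep()
    await RateLimitCache.clear()


asyncio.run(demo())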