QA round #1
yocalebo committed Jun 25, 2024
1 parent 18a10a6 commit 351af9e
Showing 4 changed files with 61 additions and 71 deletions.
8 changes: 4 additions & 4 deletions src/middlewared/middlewared/main.py
@@ -24,7 +24,7 @@
from .utils.plugins import LoadPluginsMixin
from .utils.privilege import credential_has_full_admin
from .utils.profile import profile_wrap
from .utils.rate_limit import RateLimitCache
from .utils.rate_limit.cache import RateLimitCache
from .utils.service.call import ServiceCallMixin
from .utils.syslog import syslog_message
from .utils.threading import set_thread_name, IoThreadPoolExecutor, io_thread_pool_executor
@@ -359,10 +359,10 @@ async def on_message(self, message: typing.Dict[str, typing.Any]):
self.send_error(message, e.errno, str(e), sys.exc_info(), extra=e.extra)
error = True

auth_required = not hasattr(methodobj, '_no_auth_required')
if not error:
auth_required = not hasattr(methodobj, '_no_auth_required')
if not auth_required:
ip_added = RateLimitCache.add(message['method'], self.origin)
ip_added = await RateLimitCache.add(message['method'], self.origin)
if ip_added is not None:
if any((
RateLimitCache.max_entries_reached,
@@ -375,7 +375,7 @@ async def on_message(self, message: typing.Dict[str, typing.Any]):
# origin IP address
# In either scenario, sleep a random delay and send an error
await self.__log_audit_message_for_method(message, methodobj, False, True, False)
await RateLimitCache.sleep_random()
await RateLimitCache.random_sleep()
self.send_error('Rate Limit Exceeded', errno.EBUSY)
error = True
else:
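
The change above makes the cache calls awaitable inside `on_message`: for methods that do not require authentication, the origin IP is recorded and, once a limit trips, the handler sleeps a random delay and answers with `EBUSY`. A rough sketch of that gate follows, assuming the `middlewared` package is importable; the helper name and its parameters are invented, and the second condition inside `any((...))` — elided in the hunk — is assumed to be the `rate_limit_exceeded()` check defined in `cache.py` below.

```python
# Hypothetical, simplified version of the gate wired into on_message() above.
import errno

from middlewared.utils.rate_limit.cache import RateLimitCache


async def gate_unauthenticated_call(method_name, origin, send_error, log_audit) -> bool:
    """Return True if the call should be rejected because a rate limit tripped."""
    ip_added = await RateLimitCache.add(method_name, origin)
    if ip_added is None:
        # Non-TCP/IP origins and HA connections are never cached, so they bypass the limiter.
        return False

    if any((
        RateLimitCache.max_entries_reached,
        RateLimitCache.rate_limit_exceeded(method_name, ip_added),  # assumed: elided in the hunk
    )):
        # Too many unique consumers overall, or too many calls from this origin IP:
        # audit the rejection, back off a random delay, then report EBUSY.
        await log_audit(method_name, authenticated=False)
        await RateLimitCache.random_sleep()
        send_error('Rate Limit Exceeded', errno.EBUSY)
        return True

    return False
```
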
4 changes: 2 additions & 2 deletions src/middlewared/middlewared/plugins/rate_limit/__init__.py
@@ -1,6 +1,6 @@
from middlewared.service import periodic, Service

from middlewared.utils.rate_limit import RateLimitCache
from middlewared.utils.rate_limit.cache import RateLimitCache

CLEAR_CACHE_INTERVAL = 600

@@ -22,4 +22,4 @@ async def clear_cache(self):
# store a maximum of amount of entries in the cache and
# then refuse to honor any more requests for all consumers.
# This is required for STIG purposes.
RateLimitCache.clear()
await RateLimitCache.clear()
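
Only the import and the final `await` are visible in this hunk; below is a sketch of what the surrounding periodic service plausibly looks like, with the class name and the exact `@periodic` usage as assumptions.

```python
# Plausible shape of the cache-clearing plugin touched above; the class name
# and the @periodic call signature are assumptions, not shown in this diff.
from middlewared.service import periodic, Service
from middlewared.utils.rate_limit.cache import RateLimitCache

CLEAR_CACHE_INTERVAL = 600


class RateLimitService(Service):  # hypothetical name

    @periodic(CLEAR_CACHE_INTERVAL)
    async def clear_cache(self):
        # The cache holds a bounded number of entries and refuses further
        # requests once full (a STIG requirement), so it is flushed on a
        # timer to let legitimate consumers back in.
        await RateLimitCache.clear()
```
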
1 change: 0 additions & 1 deletion src/middlewared/middlewared/service/__init__.py
@@ -17,7 +17,6 @@
from .service_part import ServicePartBase # noqa
from .sharing_service import SharingService, SharingTaskService, TaskPathService # noqa
from .system_service import SystemServiceService # noqa
from .throttle import throttle # noqa


ABSTRACT_SERVICES = ( # noqa
119 changes: 55 additions & 64 deletions src/middlewared/middlewared/utils/rate_limit/cache.py
@@ -1,84 +1,77 @@
from asyncio import sleep
from dataclasses import dataclass
from random import uniform
from threading import RLock
from time import monotonic
from typing import Self, TypedDict

from middlewared.auth import is_ha_connection
from middlewared.utils.origin import TCPIPOrigin

__all__ = ('RateLimitCache')

"""The maximum number of calls per unique consumer of the endpoint."""
MAX_CALLS: int = 20
"""The maximum time in seconds that a unique consumer may request an
endpoint that is being rate limited."""
MAX_PERIOD: int = 60


@dataclass(slots=True, kw_only=True)
class RateLimitObject:
"""A per-{endpoint/consumer} re-entrant lock so that a
global lock is not shared between all (potential)
consumers hitting the same endpoint."""
lock: RLock
"""The number of times this method was called by the consumer."""
num_times_called: int = 0
"""The monotonic time representing when this particular cache
entry was last reset."""
last_reset: float = monotonic()


@dataclass(slots=True)
class RateLimitCachedObjects:
@dataclass(frozen=True)
class RateLimitConfig:
"""The maximum number of calls per unique consumer of the endpoint."""
max_calls: int = 20
"""The maximum time in seconds that a unique consumer may request an
endpoint that is being rate limited."""
max_period: int = 60
"""The maximum number of unique entries the cache supports"""
MAX_CACHE_ENTRIES: int = 100
max_cache_entries: int = 100
"""The value used to separate the unique values when generating
a unique key to be used to store the cached information."""
SEPARATOR: str = '_'
"""The global cache object used to store the information about
all endpoints/consumers being rate limited."""
CACHE: dict[str, RateLimitObject] = dict()
separator: str = '_##_'
"""The starting decimal value for the time to be slept in the event
rate limit thresholds for a particular consumer has been met."""
RANDOM_START: float = 1.0
sleep_start: float = 1.0
"""The ending decimal value for the time to be slept in the event
rate limit thresholds for a particular consumer has been met."""
RANDOM_END: float = 10.0
sleep_end: float = 10.0


class RateLimitObject(TypedDict):
"""The number of times this method was called by the consumer."""
num_times_called: int
"""The monotonic time representing when this particular cache
entry was last reset."""
last_reset: float

@property
def max_entries_reached(self) -> bool:
"""Return a boolean indicating if the total number of entries
in the global cache has reached `self.MAX_CACHE_ENTRIES`."""
return len(self.CACHE) == self.MAX_CACHE_ENTRIES

RL_CACHE: dict[str, RateLimitObject] = dict()



class RateLimit:
def cache_key(self, method_name: str, ip: str) -> str:
"""Generate a unique key per endpoint/consumer"""
return f'{method_name}{self.SEPARATOR}{ip}'
return f'{method_name}{RateLimitConfig.separator}{ip}'

def rate_limit_exceeded(self, method_name: str, ip: str) -> bool:
"""Return a boolean indicating if the total number of calls
per unique endpoint/consumer has been reached."""
key = self.cache_key(method_name, ip)
try:
with self.CACHE[key].lock:
now: float = monotonic()
if MAX_PERIOD - (now - self.CACHE[key].last_reset) <= 0:
# time window elapsed, so time to reset
self.CACHE[key].num_times_called = 0
self.CACHE[key].last_reset = now

# always increment
self.CACHE[key].num_times_called += 1
return self.CACHE[key].num_times_called > MAX_CALLS
now: float = monotonic()
if RateLimitConfig.max_period - (now - RL_CACHE[key]['last_reset']) <= 0:
# time window elapsed, so time to reset
RL_CACHE[key]['num_times_called'] = 0
RL_CACHE[key]['last_reset'] = now

# always increment
RL_CACHE[key]['num_times_called'] += 1
return RL_CACHE[key]['num_times_called'] > RateLimitConfig.max_calls
except KeyError:
pass

return False

def add(self, method_name: str, origin: TCPIPOrigin) -> str | None:
async def add(self, method_name: str, origin: TCPIPOrigin) -> str | None:
"""Add an entry to the cache. Returns the IP address of
origin of the request if it has been cached, returns None otherwise"""
if not isinstance(origin, TCPIPOrigin):
return None

ip, port = origin.addr, origin.port
if any((ip is None, port is None)) or is_ha_connection(ip, port):
# Short-circuit if:
@@ -89,29 +82,27 @@ def add(self, method_name: str, origin: TCPIPOrigin) -> str | None:
return None
else:
key = self.cache_key(method_name, ip)
if key not in self.CACHE:
self.CACHE[key] = RateLimitObject(lock=RLock())
return ip

return None
if key not in RL_CACHE:
RL_CACHE[key] = RateLimitObject(num_times_called=0, last_reset=monotonic())
return ip

def pop(self, method_name: str, ip: str) -> None:
async def pop(self, method_name: str, ip: str) -> None:
"""Pop (remove) an entry from the cache."""
self.CACHE.pop(self.cache_key(method_name, ip), None)
RL_CACHE.pop(self.cache_key(method_name, ip), None)

def clear(self) -> None:
async def clear(self) -> None:
"""Clear all entries from the cache."""
self.CACHE.clear()

@property
def random_range(self) -> float:
"""Return a random float within self.RANDOM_START and self.RANDOM_END
rounded to the 100th decimal point"""
return round(uniform(self.RANDOM_START, self.RANDOM_END), 2)
RL_CACHE.clear()

async def random_sleep(self) -> None:
"""Sleep a random amount of seconds within range of `self.random_range`."""
await sleep(self.random_range)
"""Sleep a random amount of seconds."""
await sleep(round(uniform(RateLimitConfig.sleep_start, RateLimitConfig.sleep_end), 2))

@property
def max_entries_reached(self) -> bool:
"""Return a boolean indicating if the total number of entries
in the global cache has reached `self.max_cache_entries`."""
return len(RL_CACHE) == RateLimitConfig.max_cache_entries


RateLimitCache = RateLimitCachedObjects()
RateLimitCache = RateLimit()
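
The rewritten `rate_limit_exceeded()` keeps a fixed-window counter per `method_##_ip` key: the counter resets once `max_period` seconds have passed since `last_reset`, and the call is rejected after more than `max_calls` increments within the window. A self-contained rerun of that arithmetic for clarity (no middlewared imports; the method name and IP in the key are made up):

```python
# Standalone illustration of the fixed-window accounting used by
# rate_limit_exceeded() above; constants mirror RateLimitConfig and the key
# uses the new '_##_' separator.
from time import monotonic

MAX_CALLS = 20    # RateLimitConfig.max_calls
MAX_PERIOD = 60   # RateLimitConfig.max_period

cache: dict[str, dict] = {}


def exceeded(key: str) -> bool:
    try:
        entry = cache[key]
    except KeyError:
        # Keys that were never added are not rate limited.
        return False

    now = monotonic()
    if MAX_PERIOD - (now - entry['last_reset']) <= 0:
        # The window has elapsed, so start a fresh one before counting this call.
        entry['num_times_called'] = 0
        entry['last_reset'] = now

    # Every call inside the window increments the counter.
    entry['num_times_called'] += 1
    return entry['num_times_called'] > MAX_CALLS


key = 'system.info_##_192.0.2.7'
cache[key] = {'num_times_called': 0, 'last_reset': monotonic()}
results = [exceeded(key) for _ in range(21)]
print(results[19], results[20])  # False True -> the 21st call in the window trips the limit
```

Because the reset happens lazily on the next call rather than on a timer, a consumer that stops calling keeps its stale entry until the periodic `clear_cache()` job above (or an explicit `pop()`) removes it.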

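And a hypothetical end-to-end driver for the new awaitable API; it assumes `middlewared` is importable and that `TCPIPOrigin` can be built from an address/port pair, which this diff does not show:

```python
# Hypothetical driver; the method name, address, and the TCPIPOrigin
# constructor arguments are assumptions for illustration only.
import asyncio

from middlewared.utils.origin import TCPIPOrigin
from middlewared.utils.rate_limit.cache import RateLimitCache


async def main() -> None:
    origin = TCPIPOrigin('192.0.2.7', 52000)  # assumed constructor signature
    ip = await RateLimitCache.add('some.method', origin)
    if ip is not None:  # None would mean a non-TCP/IP or HA origin
        print(RateLimitCache.rate_limit_exceeded('some.method', ip))  # False on the first call
        print(RateLimitCache.max_entries_reached)                     # False with a single entry
        await RateLimitCache.pop('some.method', ip)
    await RateLimitCache.clear()


asyncio.run(main())
```
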