diff --git a/CHANGES.rst b/CHANGES.rst index 5bef1c3bfb..99918c812b 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,3 +1,15 @@ +.. scm-version-title:: v8.4.4 + +- :issue:`304` via :pr:`309`: Refactored :py:class:`~\ + cheroot.connections.ConnectionManager` to use :py:meth:`~\ + selectors.BaseSelector.get_map` and reorganized the + readable connection tracking -- by :user:`liamstask`. + +- :issue:`304` via :pr:`309`: Fixed the server shutdown + sequence to avoid race condition resulting in accepting + new connections while it is being terminated + -- by :user:`liamstask`. + .. scm-version-title:: v8.4.3 - :pr:`282`: Fixed a race condition happening when an HTTP diff --git a/cheroot/connections.py b/cheroot/connections.py index 89fd204e56..b416077c8c 100644 --- a/cheroot/connections.py +++ b/cheroot/connections.py @@ -3,6 +3,7 @@ from __future__ import absolute_import, division, print_function __metaclass__ = type +import collections import io import os import socket @@ -10,6 +11,7 @@ from . import errors from ._compat import selectors +from ._compat import suppress from .makefile import MakeFile import six @@ -61,9 +63,14 @@ def __init__(self, server): that uses this ConnectionManager instance. """ self.server = server - self.connections = [] + self._readable_conns = collections.deque() self._selector = selectors.DefaultSelector() + self._selector.register( + server.socket.fileno(), + selectors.EVENT_READ, data=server, + ) + def put(self, conn): """Put idle connection into the ConnectionManager to be managed. @@ -72,8 +79,14 @@ def put(self, conn): to be managed. """ conn.last_used = time.time() - conn.ready_with_data = conn.rfile.has_data() - self.connections.append(conn) + # if this conn doesn't have any more data waiting to be read, + # register it with the selector. + if conn.rfile.has_data(): + self._readable_conns.append(conn) + else: + self._selector.register( + conn.socket.fileno(), selectors.EVENT_READ, data=conn, + ) def expire(self): """Expire least recently used connections. @@ -83,21 +96,19 @@ def expire(self): This should be called periodically. """ - if not self.connections: - return - - # Look at the first connection - if it can be closed, then do - # that, and wait for get_conn to return it. - conn = self.connections[0] - if conn.closeable: - return - - # Connection too old? - if (conn.last_used + self.server.timeout) < time.time(): - conn.closeable = True - return + # find any connections still registered with the selector + # that have not been active recently enough. + threshold = time.time() - self.server.timeout + timed_out_connections = ( + (sock_fd, conn) + for _, (_, sock_fd, _, conn) in self._selector.get_map().items() + if conn != self.server and conn.last_used < threshold + ) + for sock_fd, conn in timed_out_connections: + self._selector.unregister(sock_fd) + conn.close() - def get_conn(self, server_socket): + def get_conn(self): """Return a HTTPConnection object which is ready to be handled. A connection returned by this method should be ready for a worker @@ -107,88 +118,57 @@ def get_conn(self, server_socket): Any connection returned by this method will need to be `put` back if it should be examined again for another request. - Args: - server_socket (socket.socket): Socket to listen to for new - connections. Returns: cheroot.server.HTTPConnection instance, or None. """ - # Grab file descriptors from sockets, but stop if we find a - # connection which is already marked as ready. - socket_dict = {} - for conn in self.connections: - if conn.closeable or conn.ready_with_data: - break - socket_dict[conn.socket.fileno()] = conn - else: - # No ready connection. - conn = None - - # We have a connection ready for use. - if conn: - self.connections.remove(conn) - return conn + # return a readable connection if any exist + with suppress(IndexError): + return self._readable_conns.popleft() # Will require a select call. - ss_fileno = server_socket.fileno() - socket_dict[ss_fileno] = server_socket try: - for fno in socket_dict: - self._selector.register(fno, selectors.EVENT_READ) # The timeout value impacts performance and should be carefully # chosen. Ref: # github.com/cherrypy/cheroot/issues/305#issuecomment-663985165 rlist = [ - key.fd for key, _event + key for key, _ in self._selector.select(timeout=0.01) ] except OSError: - # Mark any connection which no longer appears valid. - for fno, conn in list(socket_dict.items()): + # Mark any connection which no longer appears valid + for _, key in self._selector.get_map().items(): # If the server socket is invalid, we'll just ignore it and # wait to be shutdown. - if fno == ss_fileno: + if key.data == self.server: continue + try: - os.fstat(fno) + os.fstat(key.fd) except OSError: - # Socket is invalid, close the connection, insert at - # the front. - self.connections.remove(conn) - self.connections.insert(0, conn) - conn.closeable = True + # Socket is invalid, close the connection + self._selector.unregister(key.fd) + conn = key.data + conn.close() # Wait for the next tick to occur. return None - finally: - for fno in socket_dict: - self._selector.unregister(fno) - - try: - # See if we have a new connection coming in. - rlist.remove(ss_fileno) - except ValueError: - # If we didn't get any readable sockets, wait for the next tick - if not rlist: - return None - - # No new connection, but reuse an existing socket. - conn = socket_dict[rlist.pop()] - else: - # If we have a new connection, reuse the server socket - conn = server_socket - # All remaining connections in rlist should be marked as ready. - for fno in rlist: - socket_dict[fno].ready_with_data = True + for key in rlist: + if key.data is self.server: + # New connection + return self._from_server_socket(self.server.socket) - # New connection. - if conn is server_socket: - return self._from_server_socket(server_socket) + conn = key.data + # unregister connection from the selector until the server + # has read from it and returned it via put() + self._selector.unregister(key.fd) + self._readable_conns.append(conn) - self.connections.remove(conn) - return conn + try: + return self._readable_conns.popleft() + except IndexError: + return None def _from_server_socket(self, server_socket): try: @@ -282,12 +262,28 @@ def _from_server_socket(self, server_socket): def close(self): """Close all monitored connections.""" - for conn in self.connections[:]: + for conn in self._readable_conns: conn.close() - self.connections = [] + self._readable_conns.clear() + + for _, key in self._selector.get_map().items(): + if key.data != self.server: # server closes its own socket + key.data.socket.close() + + self._selector.close() + + @property + def _num_connections(self): # noqa: D401 + """The current number of connections. + + Includes any in the readable list or registered with the selector, + minus one for the server socket, which is always registered + with the selector. + """ + return len(self._readable_conns) + len(self._selector.get_map()) - 1 @property def can_add_keepalive_connection(self): """Flag whether it is allowed to add a new keep-alive connection.""" ka_limit = self.server.keep_alive_conn_limit - return ka_limit is None or len(self.connections) < ka_limit + return ka_limit is None or self._num_connections < ka_limit diff --git a/cheroot/server.py b/cheroot/server.py index d6f89dbcdb..44075958cb 100644 --- a/cheroot/server.py +++ b/cheroot/server.py @@ -1231,9 +1231,7 @@ class HTTPConnection: peercreds_resolve_enabled = False # Fields set by ConnectionManager. - closeable = False last_used = None - ready_with_data = False def __init__(self, server, sock, makefile=MakeFile): """Initialize HTTPConnection instance. @@ -1587,7 +1585,7 @@ def __init__( self.requests = threadpool.ThreadPool( self, min=minthreads or 1, max=maxthreads, ) - self.connections = connections.ConnectionManager(self) + self.serving = False if not server_name: server_name = self.version @@ -1781,6 +1779,8 @@ def prepare(self): self.socket.settimeout(1) self.socket.listen(self.request_queue_size) + self.connections = connections.ConnectionManager(self) + # Create worker threads self.requests.start() @@ -1789,6 +1789,7 @@ def prepare(self): def serve(self): """Serve requests, after invoking :func:`prepare()`.""" + self.serving = True while self.ready: try: self.tick() @@ -1800,12 +1801,7 @@ def serve(self): traceback=True, ) - if self.interrupt: - while self.interrupt is True: - # Wait for self.stop() to complete. See _set_interrupt. - time.sleep(0.1) - if self.interrupt: - raise self.interrupt + self.serving = False def start(self): """Run the server forever. @@ -2023,10 +2019,7 @@ def resolve_real_bind_addr(socket_): def tick(self): """Accept a new connection and put it on the Queue.""" - if not self.ready: - return - - conn = self.connections.get_conn(self.socket) + conn = self.connections.get_conn() if conn: try: self.requests.put(conn) @@ -2047,6 +2040,8 @@ def interrupt(self, interrupt): self._interrupt = True self.stop() self._interrupt = interrupt + if self._interrupt: + raise self.interrupt def stop(self): """Gracefully shutdown a server that is serving forever.""" @@ -2055,6 +2050,10 @@ def stop(self): self._run_time += (time.time() - self._start_time) self._start_time = None + # ensure serve is no longer accessing socket, connections + while self.serving: + time.sleep(0.1) + sock = getattr(self, 'socket', None) if sock: if not isinstance( diff --git a/cheroot/workers/threadpool.py b/cheroot/workers/threadpool.py index 4466e7a143..b9987d9de3 100644 --- a/cheroot/workers/threadpool.py +++ b/cheroot/workers/threadpool.py @@ -111,11 +111,6 @@ def run(self): if conn is _SHUTDOWNREQUEST: return - # Just close the connection and move on. - if conn.closeable: - conn.close() - continue - self.conn = conn is_stats_enabled = self.server.stats['Enabled'] if is_stats_enabled: diff --git a/docs/conf.py b/docs/conf.py index cca80af33a..60959201d4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -92,6 +92,7 @@ # NOTE: consider having a separate ignore file # Ref: https://stackoverflow.com/a/30624034/595220 nitpick_ignore = [ + ('py:class', 'cheroot.connections.ConnectionManager'), ('py:const', 'socket.SO_PEERCRED'), ('py:class', '_pyio.BufferedWriter'), ('py:class', '_pyio.BufferedReader'),