diff --git a/socketio/sgunicorn.py b/socketio/sgunicorn.py index a5a993a..1ee2d38 100644 --- a/socketio/sgunicorn.py +++ b/socketio/sgunicorn.py @@ -49,90 +49,14 @@ def __init__(self, age, ppid, socket, app, timeout, cfg, log): super(GeventSocketIOBaseWorker, self).__init__( age, ppid, socket, app, timeout, cfg, log) - def run(self): - if gunicorn_version >= (0, 17, 0): - servers = [] - ssl_args = {} - - if self.cfg.is_ssl: - ssl_args = dict( - server_side=True, - do_handshake_on_connect=False, - **self.cfg.ssl_options - ) - - for s in self.sockets: - s.setblocking(1) - pool = Pool(self.worker_connections) - if self.server_class is not None: - self.server_class.base_env['wsgi.multiprocess'] = \ - self.cfg.workers > 1 - - server = self.server_class( - s, - application=self.wsgi, - spawn=pool, - resource=self.resource, - log=self.log, - policy_server=self.policy_server, - handler_class=self.wsgi_handler, - ws_handler_class=self.ws_wsgi_handler, - **ssl_args - ) - else: - hfun = partial(self.handle, s) - server = StreamServer( - s, handle=hfun, spawn=pool, **ssl_args) - - server.start() - servers.append(server) - - pid = os.getpid() - try: - while self.alive: - self.notify() - - if pid == os.getpid() and self.ppid != os.getppid(): - self.log.info( - "Parent changed, shutting down: %s", self) - break - - gevent.sleep(1.0) - - except KeyboardInterrupt: - pass - - try: - # Stop accepting requests - [server.stop_accepting() for server in servers] - - # Handle current requests until graceful_timeout - ts = time.time() - while time.time() - ts <= self.cfg.graceful_timeout: - accepting = 0 - for server in servers: - if server.pool.free_count() != server.pool.size: - accepting += 1 - - if not accepting: - return - - self.notify() - gevent.sleep(1.0) - - # Force kill all active the handlers - self.log.warning("Worker graceful timeout (pid:%s)" % self.pid) - [server.stop(timeout=1) for server in servers] - except: - pass - else: - self.socket.setblocking(1) - pool = Pool(self.worker_connections) - self.server_class.base_env['wsgi.multiprocess'] = \ - self.cfg.workers > 1 + def _start_server(self, socket, ssl_args={}): + socket.setblocking(1) + pool = Pool(self.worker_connections) + if self.server_class is not None: + self.server_class.base_env['wsgi.multiprocess'] = self.cfg.workers > 1 server = self.server_class( - self.socket, + socket, application=self.wsgi, spawn=pool, resource=self.resource, @@ -140,43 +64,67 @@ def run(self): policy_server=self.policy_server, handler_class=self.wsgi_handler, ws_handler_class=self.ws_wsgi_handler, + **ssl_args ) + else: + hfun = partial(self.handle, socket) + server = StreamServer(socket, handle=hfun, spawn=pool, **ssl_args) + + server.start() + return server + + def _run_server(self): + pid = os.getpid() + try: + while self.alive: + self.notify() + + if pid == os.getpid() and self.ppid != os.getppid(): + self.log.info("Parent changed, shutting down: %s", self) + break + + gevent.sleep(1.0) + + except KeyboardInterrupt: + pass + + def _stop_servers(self, servers): + try: + # Stop accepting requests + for server in servers: + server.stop_accepting() + + # Handle current requests until graceful_timeout + ts = time.time() + while time.time() - ts <= self.cfg.graceful_timeout: + servers = [server for server in servers + if server.pool.free_count() != server.pool.size] + if not servers: + break + + self.notify() + gevent.sleep(1.0) + else: + # Force kill all active the handlers + self.log.warning("Worker graceful timeout (pid:%s)" % self.pid) + for server in servers: + server.stop(timeout=1) + except: + pass - server.start() - pid = os.getpid() - - try: - while self.alive: - self.notify() - - if pid == os.getpid() and self.ppid != os.getppid(): - self.log.info( - "Parent changed, shutting down: %s", self) - break - - gevent.sleep(1.0) - - except KeyboardInterrupt: - pass - - try: - # Stop accepting requests - server.kill() - - # Handle current requests until graceful_timeout - ts = time.time() - while time.time() - ts <= self.cfg.graceful_timeout: - if server.pool.free_count() == server.pool.size: - return # all requests was handled + def run(self): + if gunicorn_version >= (0, 17, 0): + ssl_args = {} - self.notify() - gevent.sleep(1.0) + if self.cfg.is_ssl: + ssl_args = dict(server_side=True, + do_handshake_on_connect=False, **self.cfg.ssl_options) - # Force kill all active the handlers - self.log.warning("Worker graceful timeout (pid:%s)" % self.pid) - server.stop(timeout=1) - except: - pass + servers = [self._start_server(s, ssl_args) for s in self.sockets] + else: + servers = [self._start_server(self.socket)] + self._run_server() + self._stop_servers(servers) class GeventSocketIOWorker(GeventSocketIOBaseWorker):