From e13511914bc154cfde8fbcb0561de3b26137735b Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Mon, 23 Sep 2024 15:47:12 -0700 Subject: [PATCH] Cleanup after clients upon disconnection --- neon_hana/app/routers/node_server.py | 2 +- neon_hana/mq_websocket_api.py | 20 +++++++++++++++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/neon_hana/app/routers/node_server.py b/neon_hana/app/routers/node_server.py index c17d9b3..6badae3 100644 --- a/neon_hana/app/routers/node_server.py +++ b/neon_hana/app/routers/node_server.py @@ -65,7 +65,7 @@ async def node_v1_endpoint(websocket: WebSocket, token: str): socket_api.handle_client_input(client_in, client_id) except WebSocketDisconnect: disconnect_event.set() - # TODO: Delete client from socket_api + socket_api.end_session(session_id=client_id) @node_route.websocket("/v1/stream") diff --git a/neon_hana/mq_websocket_api.py b/neon_hana/mq_websocket_api.py index 6e87323..c4e285d 100644 --- a/neon_hana/mq_websocket_api.py +++ b/neon_hana/mq_websocket_api.py @@ -71,6 +71,20 @@ def new_connection(self, ws: WebSocket, session_id: str): "socket": ws, "user": self.user_config} + def end_session(self, session_id: str): + """ + End a client connection upon WS disconnection + """ + session: Optional[dict] = self._sessions.pop(session_id, None) + if not session: + LOG.error(f"Ended session is not established {session_id}") + return + stream: RemoteStreamHandler = session.get('stream') + if stream: + stream.shutdown() + stream.join() + LOG.info(f"Ended stream handler for: {session_id}") + def get_session(self, session_id: str) -> dict: """ Get the latest session context for the given session_id. @@ -245,7 +259,7 @@ def start(self): pass def stop(self): - pass + self.queue.put(None) def read_chunk(self) -> Optional[bytes]: return self.queue.get() @@ -302,6 +316,10 @@ def on_audio(self, audio_bytes: bytes, context: dict): def on_chunk(self, chunk: ChunkInfo): LOG.debug(f"Chunk: {chunk}") + def shutdown(self): + self.mic.stop() + self.voice_loop.stop() + class MockTransformers(Mock): def transform(self, chunk):