From 6c7198767b6769f855af3bcc8f4eb5eef32e0dae Mon Sep 17 00:00:00 2001 From: pylakey Date: Wed, 19 Jun 2024 13:57:07 +0300 Subject: [PATCH] Reraise the exception instead of raising a SystemExit if the authorization update cannot be handled, as it may be managed within user-level code. Ensure resources are cleaned up properly and propagate asyncio.CancelledError whenever it is caught. --- aiotdlib/client.py | 87 ++++++++++++++++++++++------------------------ aiotdlib/tdjson.py | 13 +++++-- 2 files changed, 53 insertions(+), 47 deletions(-) diff --git a/aiotdlib/client.py b/aiotdlib/client.py index b055323..32b7128 100644 --- a/aiotdlib/client.py +++ b/aiotdlib/client.py @@ -113,7 +113,6 @@ def __init__(self, settings: Optional[ClientSettings] = None): settings (ClientSettings): Settings for client, if not provided default settings will be used, including environment variables """ - self._current_authorization_state = None self._authorized_event = asyncio.Event() self._running = False self._pending_requests: dict[str, PendingRequest] = {} @@ -207,6 +206,9 @@ async def __aenter__(self) -> 'Client': async def __aexit__(self, exc_type, exc_val, exc_tb): await self.stop() + if bool(exc_val): + raise exc_val + async def _call_handler(self, handler: Handler, update: TDLibObject): try: await handler(self, update) @@ -278,43 +280,39 @@ async def _handle_pending_request(self, update: TDLibObject): ) async def _updates_loop(self): - try: - async for packet in self.tdjson_client.receive(): - if not bool(packet): - continue - - try: - update = parse_tdlib_object(packet) - except pydantic.ValidationError as e: - self.logger.error(f'Unable to parse incoming update: {packet}! {e}', exc_info=True) - continue - - if isinstance(update, UpdateAuthorizationState): - try: - await self._on_authorization_state_update(update.authorization_state) - except asyncio.CancelledError: - raise - except Exception as e: - self.logger.error( - f'Unable to handle authorization state update {update.model_dump_json()}! {e}', - exc_info=True - ) - raise SystemExit from e + async for packet in self.tdjson_client.receive(): + if not bool(packet): + continue - continue + try: + update = parse_tdlib_object(packet) + except pydantic.ValidationError as e: + self.logger.error(f'Unable to parse incoming update: {packet}! {e}', exc_info=True) + continue + if isinstance(update, UpdateAuthorizationState): try: - await self._handle_pending_request(update) + await self._on_authorization_state_update(update.authorization_state) except asyncio.CancelledError: + self.logger.info("Authorization process has been cancelled") raise except Exception as e: - self.logger.error(f'Unable to handle pending request {update}! {e}', exc_info=True) + self.logger.error( + f'Unable to handle authorization state update {update.model_dump_json()}! {e}', + exc_info=True + ) + raise + else: + continue - self._create_handler_task(self._handle_update(update)) - except asyncio.CancelledError: - self._pending_requests.clear() - self._pending_messages.clear() - raise + try: + await self._handle_pending_request(update) + except asyncio.CancelledError: + raise + except Exception as e: + self.logger.error(f'Unable to handle pending request {update}! {e}', exc_info=True) + + self._create_handler_task(self._handle_update(update)) async def _setup_proxy(self): if not bool(self.settings.proxy_settings): @@ -569,14 +567,13 @@ async def _on_authorization_state_update(self, authorization_state: Authorizatio await action() async def _cleanup(self): + self._pending_requests.clear() + self._pending_messages.clear() + if bool(self._update_task) and not self._update_task.cancelled(): self.logger.info("Cancelling updates loop task") self._update_task.cancel() - - try: - await self._update_task - except asyncio.CancelledError: - pass + await self._update_task # Cancel all background handlers tasks if bool(self._handlers_tasks): @@ -683,7 +680,7 @@ async def authorize(self): GetAuthorizationState() ) - self.logger.info('Waiting for authorization to be completed...') + self.logger.info('Waiting for authorization to be completed') await self._authorized_event.wait() async def start(self) -> 'Client': @@ -708,22 +705,22 @@ async def start(self) -> 'Client': await self.authorize() except asyncio.CancelledError: await self._cleanup() + raise else: - self.logger.info('Authorization is completed...') + self.logger.info('Authorization is completed') return self async def idle(self): - try: - while True: + while True: + try: await asyncio.sleep(0.1) - except asyncio.CancelledError: - pass - finally: - self.logger.info('Stop Idling...') + except asyncio.CancelledError: + self.logger.info('Stop Idling') + raise async def stop(self): - self.logger.info('Stopping telegram client...') + self.logger.info('Stopping telegram client') await self._cleanup() # Cache related methods diff --git a/aiotdlib/tdjson.py b/aiotdlib/tdjson.py index 75e5cc8..ac215db 100644 --- a/aiotdlib/tdjson.py +++ b/aiotdlib/tdjson.py @@ -231,6 +231,11 @@ def _unsubscribe(self): self.td_json.unsubscribe_updates(self.client_id) self.client_id = None + def _cleanup(self): + # Clear the queue in case of cancellation + while not self._updates_queue.empty(): + self._updates_queue.get_nowait() + async def send(self, query: TDJsonQuery): self._subscribe() return await asyncio.to_thread(self.td_json.send, self.client_id, query) @@ -247,8 +252,12 @@ async def close(self): async def receive(self) -> typing.AsyncGenerator[dict, None]: while True: - message = await self._updates_queue.get() - yield message + try: + message = await self._updates_queue.get() + yield message + except asyncio.CancelledError: + self._cleanup() + raise async def enqueue_update(self, update: dict) -> None: await self._updates_queue.put(update)