Skip to content

Commit

Permalink
Reraise the exception instead of raising a SystemExit if the authoriz…
Browse files Browse the repository at this point in the history
…ation update cannot be handled, as it may be managed within user-level code.

Ensure resources are cleaned up properly and propagate asyncio.CancelledError whenever it is caught.
  • Loading branch information
pylakey committed Jun 19, 2024
1 parent d0086a7 commit 6c71987
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 47 deletions.
87 changes: 42 additions & 45 deletions aiotdlib/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ def __init__(self, settings: Optional[ClientSettings] = None):
settings (ClientSettings): Settings for client,
if not provided default settings will be used, including environment variables
"""
self._current_authorization_state = None
self._authorized_event = asyncio.Event()
self._running = False
self._pending_requests: dict[str, PendingRequest] = {}
Expand Down Expand Up @@ -207,6 +206,9 @@ async def __aenter__(self) -> 'Client':
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.stop()

if bool(exc_val):
raise exc_val

async def _call_handler(self, handler: Handler, update: TDLibObject):
try:
await handler(self, update)
Expand Down Expand Up @@ -278,43 +280,39 @@ async def _handle_pending_request(self, update: TDLibObject):
)

async def _updates_loop(self):
try:
async for packet in self.tdjson_client.receive():
if not bool(packet):
continue

try:
update = parse_tdlib_object(packet)
except pydantic.ValidationError as e:
self.logger.error(f'Unable to parse incoming update: {packet}! {e}', exc_info=True)
continue

if isinstance(update, UpdateAuthorizationState):
try:
await self._on_authorization_state_update(update.authorization_state)
except asyncio.CancelledError:
raise
except Exception as e:
self.logger.error(
f'Unable to handle authorization state update {update.model_dump_json()}! {e}',
exc_info=True
)
raise SystemExit from e
async for packet in self.tdjson_client.receive():
if not bool(packet):
continue

continue
try:
update = parse_tdlib_object(packet)
except pydantic.ValidationError as e:
self.logger.error(f'Unable to parse incoming update: {packet}! {e}', exc_info=True)
continue

if isinstance(update, UpdateAuthorizationState):
try:
await self._handle_pending_request(update)
await self._on_authorization_state_update(update.authorization_state)
except asyncio.CancelledError:
self.logger.info("Authorization process has been cancelled")
raise
except Exception as e:
self.logger.error(f'Unable to handle pending request {update}! {e}', exc_info=True)
self.logger.error(
f'Unable to handle authorization state update {update.model_dump_json()}! {e}',
exc_info=True
)
raise
else:
continue

self._create_handler_task(self._handle_update(update))
except asyncio.CancelledError:
self._pending_requests.clear()
self._pending_messages.clear()
raise
try:
await self._handle_pending_request(update)
except asyncio.CancelledError:
raise
except Exception as e:
self.logger.error(f'Unable to handle pending request {update}! {e}', exc_info=True)

self._create_handler_task(self._handle_update(update))

async def _setup_proxy(self):
if not bool(self.settings.proxy_settings):
Expand Down Expand Up @@ -569,14 +567,13 @@ async def _on_authorization_state_update(self, authorization_state: Authorizatio
await action()

async def _cleanup(self):
self._pending_requests.clear()
self._pending_messages.clear()

if bool(self._update_task) and not self._update_task.cancelled():
self.logger.info("Cancelling updates loop task")
self._update_task.cancel()

try:
await self._update_task
except asyncio.CancelledError:
pass
await self._update_task

# Cancel all background handlers tasks
if bool(self._handlers_tasks):
Expand Down Expand Up @@ -683,7 +680,7 @@ async def authorize(self):
GetAuthorizationState()
)

self.logger.info('Waiting for authorization to be completed...')
self.logger.info('Waiting for authorization to be completed')
await self._authorized_event.wait()

async def start(self) -> 'Client':
Expand All @@ -708,22 +705,22 @@ async def start(self) -> 'Client':
await self.authorize()
except asyncio.CancelledError:
await self._cleanup()
raise
else:
self.logger.info('Authorization is completed...')
self.logger.info('Authorization is completed')

return self

async def idle(self):
try:
while True:
while True:
try:
await asyncio.sleep(0.1)
except asyncio.CancelledError:
pass
finally:
self.logger.info('Stop Idling...')
except asyncio.CancelledError:
self.logger.info('Stop Idling')
raise

async def stop(self):
self.logger.info('Stopping telegram client...')
self.logger.info('Stopping telegram client')
await self._cleanup()

# Cache related methods
Expand Down
13 changes: 11 additions & 2 deletions aiotdlib/tdjson.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,11 @@ def _unsubscribe(self):
self.td_json.unsubscribe_updates(self.client_id)
self.client_id = None

def _cleanup(self):
# Clear the queue in case of cancellation
while not self._updates_queue.empty():
self._updates_queue.get_nowait()

async def send(self, query: TDJsonQuery):
self._subscribe()
return await asyncio.to_thread(self.td_json.send, self.client_id, query)
Expand All @@ -247,8 +252,12 @@ async def close(self):

async def receive(self) -> typing.AsyncGenerator[dict, None]:
while True:
message = await self._updates_queue.get()
yield message
try:
message = await self._updates_queue.get()
yield message
except asyncio.CancelledError:
self._cleanup()
raise

async def enqueue_update(self, update: dict) -> None:
await self._updates_queue.put(update)

0 comments on commit 6c71987

Please sign in to comment.