diff --git a/tosfs/retry.py b/tosfs/retry.py index 4bab0f1..1d9621d 100644 --- a/tosfs/retry.py +++ b/tosfs/retry.py @@ -15,9 +15,10 @@ """The module contains retry utility functions for the tosfs stability.""" import math import time -from typing import Any, Optional, Tuple +from typing import Any, Optional, Tuple, Union import requests +import urllib3.exceptions from requests import RequestException from requests.exceptions import ( ChunkedEncodingError, @@ -66,6 +67,7 @@ Timeout, ConnectTimeout, ReadTimeout, + urllib3.exceptions.ReadTimeoutError, StreamConsumedError, RetryError, InterruptedError, @@ -98,38 +100,42 @@ def retryable_func_executor( try: return func(*args, **kwargs) except TosError as e: - from tosfs.core import logger - - if attempt >= max_retry_num: - logger.error("Retry exhausted after %d times.", max_retry_num) - raise e - - if is_retryable_exception(e): - logger.warning( - "Retry TOS request in the %d times, error: %s", attempt, e - ) - try: - sleep_time = _get_sleep_time(e, attempt) - time.sleep(sleep_time) - except InterruptedError as ie: - raise TosfsError(f"Request {func} interrupted.") from ie - else: - _rethrow_retryable_exception(e) - # Note: maybe not all the retryable exceptions are warped by `TosError` - # Will pay attention to those cases + _do_retry(e, func, attempt, max_retry_num) except Exception as e: - raise TosfsError(f"{e}") from e + _do_retry(e, func, attempt, max_retry_num) + + +def _do_retry( + e: Union[TosError, Exception], func: Any, attempt: int, max_retry_num: int +) -> None: + from tosfs.core import logger + + if attempt >= max_retry_num: + logger.error("Retry exhausted after %d times.", max_retry_num) + raise e + + if is_retryable_exception(e): + logger.warning("Retry TOS request in the %d times, error: %s", attempt, e) + try: + sleep_time = _get_sleep_time(e, attempt) + time.sleep(sleep_time) + except InterruptedError as ie: + raise TosfsError(f"Request {func} interrupted.") from ie + else: + _rethrow_retryable_exception(e) -def _rethrow_retryable_exception(e: TosError) -> None: +def _rethrow_retryable_exception(e: Union[Exception, TosError]) -> None: """For debug purpose.""" raise e def is_retryable_exception(e: TosError) -> bool: """Check if the exception is retryable.""" - return _is_retryable_tos_server_exception(e) or _is_retryable_tos_client_exception( - e + return ( + _is_retryable_tos_server_exception(e) + or _is_retryable_tos_client_exception(e) + or _is_retryable_general_client_exception(e) ) @@ -162,6 +168,15 @@ def _is_retryable_tos_client_exception(e: TosError) -> bool: return False +def _is_retryable_general_client_exception(e: Optional[Exception]) -> bool: + while e is not None: + for excp in TOS_CLIENT_RETRYABLE_EXCEPTIONS: + if isinstance(e, excp): + return True + e = getattr(e, "cause", None) + return False + + def _get_sleep_time(err: TosError, retry_count: int) -> float: sleep_time = SLEEP_BASE_SECONDS * math.pow(2, retry_count) sleep_time = min(sleep_time, SLEEP_MAX_SECONDS) diff --git a/tosfs/tests/test_retry.py b/tosfs/tests/test_retry.py index cbf5096..79590d8 100644 --- a/tosfs/tests/test_retry.py +++ b/tosfs/tests/test_retry.py @@ -18,7 +18,8 @@ import requests from tos.exceptions import TosClientError, TosServerError from tos.http import Response -from urllib3.exceptions import ProtocolError +from urllib3 import HTTPConnectionPool +from urllib3.exceptions import ProtocolError, ReadTimeoutError from tosfs.retry import _get_sleep_time, is_retryable_exception @@ -79,6 +80,16 @@ ), True, ), + ( + requests.exceptions.ConnectionError( + HTTPConnectionPool(host="proton-ci.tos-cn-beijing.volces.com", port=80) + ), + True, + ), + ( + ReadTimeoutError(None, message="", url=""), + True, + ), ], ) def test_is_retry_exception(