diff --git a/redis/asyncio/retry.py b/redis/asyncio/retry.py index a20f8b4849..87e66543b6 100644 --- a/redis/asyncio/retry.py +++ b/redis/asyncio/retry.py @@ -1,62 +1,20 @@ from asyncio import sleep -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Tuple, Type, TypeVar +from typing import Any, Awaitable, Callable, Tuple, Type, TypeVar from redis.exceptions import ConnectionError, RedisError, TimeoutError - -if TYPE_CHECKING: - from redis.backoff import AbstractBackoff - +from redis.retry import AbstractRetry T = TypeVar("T") -class Retry: - """Retry a specific number of times after a failure""" - - __slots__ = "_backoff", "_retries", "_supported_errors" - - def __init__( - self, - backoff: "AbstractBackoff", - retries: int, - supported_errors: Tuple[Type[RedisError], ...] = ( - ConnectionError, - TimeoutError, - ), - ): - """ - Initialize a `Retry` object with a `Backoff` object - that retries a maximum of `retries` times. - `retries` can be negative to retry forever. - You can specify the types of supported errors which trigger - a retry with the `supported_errors` parameter. - """ - self._backoff = backoff - self._retries = retries - self._supported_errors = supported_errors - - def update_supported_errors(self, specified_errors: list): - """ - Updates the supported errors with the specified error types - """ - self._supported_errors = tuple( - set(self._supported_errors + tuple(specified_errors)) - ) - - def get_retries(self) -> int: - """ - Get the number of retries. - """ - return self._retries - - def update_retries(self, value: int) -> None: - """ - Set the number of retries. - """ - self._retries = value +class Retry(AbstractRetry): + _supported_errors: Tuple[Type[RedisError], ...] = ( + ConnectionError, + TimeoutError, + ) async def call_with_retry( - self, do: Callable[[], Awaitable[T]], fail: Callable[[RedisError], Any] + self, do: Callable[[], Awaitable[T]], fail: Callable[[Exception], Any] ) -> T: """ Execute an operation that might fail and returns its result, or diff --git a/redis/backoff.py b/redis/backoff.py index 22a3ed0abb..6e1f68a7ba 100644 --- a/redis/backoff.py +++ b/redis/backoff.py @@ -170,7 +170,7 @@ def __hash__(self) -> int: return hash((self._base, self._cap)) def __eq__(self, other) -> bool: - if not isinstance(other, EqualJitterBackoff): + if not isinstance(other, ExponentialWithJitterBackoff): return NotImplemented return self._base == other._base and self._cap == other._cap diff --git a/redis/retry.py b/redis/retry.py index c93f34e65f..ca79575246 100644 --- a/redis/retry.py +++ b/redis/retry.py @@ -1,6 +1,6 @@ import socket from time import sleep -from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, TypeVar, Union from redis.exceptions import ConnectionError, TimeoutError @@ -10,18 +10,17 @@ from redis.backoff import AbstractBackoff -class Retry: +class AbstractRetry: """Retry a specific number of times after a failure""" + __slots__ = "_backoff", "_retries", "_supported_errors" + _supported_errors: Tuple[Type[Exception], ...] + def __init__( self, backoff: "AbstractBackoff", retries: int, - supported_errors: Tuple[Type[Exception], ...] = ( - ConnectionError, - TimeoutError, - socket.timeout, - ), + supported_errors: Union[Tuple[Type[Exception], ...], None] = None, ): """ Initialize a `Retry` object with a `Backoff` object @@ -32,10 +31,11 @@ def __init__( """ self._backoff = backoff self._retries = retries - self._supported_errors = supported_errors + if supported_errors: + self._supported_errors = supported_errors def __eq__(self, other: Any) -> bool: - if not isinstance(other, Retry): + if not isinstance(other, AbstractRetry): return NotImplemented return ( @@ -69,6 +69,14 @@ def update_retries(self, value: int) -> None: """ self._retries = value + +class Retry(AbstractRetry): + _supported_errors: Tuple[Type[Exception], ...] = ( + ConnectionError, + TimeoutError, + socket.timeout, + ) + def call_with_retry( self, do: Callable[[], T], diff --git a/tests/test_retry.py b/tests/test_retry.py index 4f4f04caca..9c0ca65d81 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -1,6 +1,7 @@ from unittest.mock import patch import pytest +from redis.asyncio.retry import Retry as AsyncRetry from redis.backoff import ( AbstractBackoff, ConstantBackoff, @@ -89,6 +90,7 @@ def test_retry_on_error_retry(self, Class, retries): assert c.retry._retries == retries +@pytest.mark.parametrize("retry_class", [Retry, AsyncRetry]) @pytest.mark.parametrize( "args", [ @@ -108,8 +110,8 @@ def test_retry_on_error_retry(self, Class, retries): for backoff in ((Backoff(), 2), (Backoff(25), 5), (Backoff(25, 5), 5)) ], ) -def test_retry_eq_and_hashable(args): - assert Retry(*args) == Retry(*args) +def test_retry_eq_and_hashable(retry_class, args): + assert retry_class(*args) == retry_class(*args) # create another retry object with different parameters copy = list(args) @@ -118,9 +120,19 @@ def test_retry_eq_and_hashable(args): else: copy[0] = ConstantBackoff(9000) - assert Retry(*args) != Retry(*copy) - assert Retry(*copy) != Retry(*args) - assert len({Retry(*args), Retry(*args), Retry(*copy), Retry(*copy)}) == 2 + assert retry_class(*args) != retry_class(*copy) + assert retry_class(*copy) != retry_class(*args) + assert ( + len( + { + retry_class(*args), + retry_class(*args), + retry_class(*copy), + retry_class(*copy), + } + ) + == 2 + ) class TestRetry: