Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
---
features:
- Allow asynchronous callbacks for `before`, `after`, `retry_error_callback`, `wait`, and `before_sleep` parameters.
75 changes: 57 additions & 18 deletions tenacity/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,35 +17,31 @@

import functools
import sys
import typing
from asyncio import sleep
from asyncio import sleep as aio_sleep
from inspect import iscoroutinefunction
from typing import Any, Awaitable, Callable, TypeVar, Union

from tenacity import AttemptManager
from tenacity import BaseRetrying
from tenacity import DoAttempt
from tenacity import DoSleep
from tenacity import RetryCallState
from tenacity import AttemptManager, BaseRetrying, DoAttempt, DoSleep, RetryAction, RetryCallState, TryAgain

WrappedFn = typing.TypeVar("WrappedFn", bound=typing.Callable)
_RetValT = typing.TypeVar("_RetValT")
WrappedFn = TypeVar("WrappedFn", bound=Callable)
_RetValT = TypeVar("_RetValT")


class AsyncRetrying(BaseRetrying):
def __init__(self, sleep: typing.Callable[[float], typing.Awaitable] = sleep, **kwargs: typing.Any) -> None:
def __init__(self, sleep: Callable[[float], Awaitable] = aio_sleep, **kwargs: Any) -> None:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are too many changes that are unrelated to this PR.
Could you avoid changing code that does not need to be changed?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was giving warnings because variable and function had same names (sleep), should I make another PR for this?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, writing typing. before types every time made some lines too long, and it was looking unnecessary, should I revert that one too?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, anything that is not related to the change should be avoided.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry for late response, I will try to find time in upcoming days and split this PR into 2 different PRs.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I created a new PR with AsyncRetries related changes #363
Will make another PR for coding style improvements after that.

Closing this PR (will use this branch as template)

super().__init__(**kwargs)
self.sleep = sleep

async def __call__( # type: ignore # Change signature from supertype
self,
fn: typing.Callable[..., typing.Awaitable[_RetValT]],
*args: typing.Any,
**kwargs: typing.Any,
fn: Callable[..., Awaitable[_RetValT]],
*args: Any,
**kwargs: Any,
) -> _RetValT:
self.begin()

retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs)
while True:
do = self.iter(retry_state=retry_state)
do = await self.iter(retry_state=retry_state)
if isinstance(do, DoAttempt):
try:
result = await fn(*args, **kwargs)
Expand All @@ -64,9 +60,9 @@ def __aiter__(self) -> "AsyncRetrying":
self._retry_state = RetryCallState(self, fn=None, args=(), kwargs={})
return self

async def __anext__(self) -> typing.Union[AttemptManager, typing.Any]:
async def __anext__(self) -> Union[AttemptManager, Any]:
while True:
do = self.iter(retry_state=self._retry_state)
do = await self.iter(retry_state=self._retry_state)
if do is None:
raise StopAsyncIteration
elif isinstance(do, DoAttempt):
Expand All @@ -82,11 +78,54 @@ def wraps(self, fn: WrappedFn) -> WrappedFn:
# Ensure wrapper is recognized as a coroutine function.

@functools.wraps(fn)
async def async_wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
async def async_wrapped(*args: Any, **kwargs: Any) -> Any:
return await fn(*args, **kwargs)

# Preserve attributes
async_wrapped.retry = fn.retry
async_wrapped.retry_with = fn.retry_with

return async_wrapped

@staticmethod
async def handle_custom_function(func: Union[Callable, Awaitable], retry_state: RetryCallState) -> Any:
if iscoroutinefunction(func):
return await func(retry_state)
return func(retry_state)

async def iter(self, retry_state: "RetryCallState") -> Union[DoAttempt, DoSleep, Any]: # noqa
fut = retry_state.outcome
if fut is None:
if self.before is not None:
await self.handle_custom_function(self.before, retry_state)
return DoAttempt()

is_explicit_retry = retry_state.outcome.failed and isinstance(retry_state.outcome.exception(), TryAgain)
if not (is_explicit_retry or self.retry(retry_state=retry_state)):
return fut.result()

if self.after is not None:
await self.handle_custom_function(self.after, retry_state)

self.statistics["delay_since_first_attempt"] = retry_state.seconds_since_start
if self.stop(retry_state=retry_state):
if self.retry_error_callback:
return await self.handle_custom_function(self.retry_error_callback, retry_state)
retry_exc = self.retry_error_cls(fut)
if self.reraise:
raise retry_exc.reraise()
raise retry_exc from fut.exception()

if self.wait:
sleep = await self.handle_custom_function(self.wait, retry_state=retry_state)
else:
sleep = 0.0
retry_state.next_action = RetryAction(sleep)
retry_state.idle_for += sleep
self.statistics["idle_for"] += sleep
self.statistics["attempt_number"] += 1

if self.before_sleep is not None:
await self.handle_custom_function(self.before_sleep, retry_state)

return DoSleep(sleep)
137 changes: 131 additions & 6 deletions tests/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@

import asyncio
import inspect
import logging
import unittest
from functools import wraps

from tenacity import AsyncRetrying, RetryError
import tenacity
from tenacity import _asyncio as tasyncio
from tenacity import retry, stop_after_attempt
from tenacity.wait import wait_fixed

from .test_tenacity import NoIOErrorAfterCount, current_time_ms
from .test_tenacity import CapturingHandler, NoIOErrorAfterCount, NoneReturnUntilAfterCount, current_time_ms


def asynctest(callable_):
Expand Down Expand Up @@ -67,7 +68,7 @@ async def test_iscoroutinefunction(self):
@asynctest
async def test_retry_using_async_retying(self):
thing = NoIOErrorAfterCount(5)
retrying = AsyncRetrying()
retrying = tenacity.AsyncRetrying()
await retrying(_async_function, thing)
assert thing.counter == thing.count

Expand All @@ -76,7 +77,7 @@ async def test_stop_after_attempt(self):
thing = NoIOErrorAfterCount(2)
try:
await _retryable_coroutine_with_2_attempts(thing)
except RetryError:
except tenacity.RetryError:
assert thing.counter == 2

def test_repr(self):
Expand All @@ -86,6 +87,31 @@ def test_retry_attributes(self):
assert hasattr(_retryable_coroutine, "retry")
assert hasattr(_retryable_coroutine, "retry_with")

@asynctest
async def test_async_retry_error_callback_handler(self):
num_attempts = 3
self.attempt_counter = 0

async def _retry_error_callback_handler(retry_state: tenacity.RetryCallState):
_retry_error_callback_handler.called_times += 1
return retry_state.outcome

_retry_error_callback_handler.called_times = 0

@retry(
stop=stop_after_attempt(num_attempts),
retry_error_callback=_retry_error_callback_handler,
)
async def _foobar():
self.attempt_counter += 1
raise Exception("This exception should not be raised")

result = await _foobar()

self.assertEqual(_retry_error_callback_handler.called_times, 1)
self.assertEqual(num_attempts, self.attempt_counter)
self.assertIsInstance(result, tenacity.Future)

@asynctest
async def test_attempt_number_is_correct_for_interleaved_coroutines(self):

Expand Down Expand Up @@ -125,7 +151,7 @@ async def test_do_max_attempts(self):
with attempt:
attempts += 1
raise Exception
except RetryError:
except tenacity.RetryError:
pass

assert attempts == 3
Expand All @@ -151,11 +177,110 @@ async def test_sleeps(self):
async for attempt in tasyncio.AsyncRetrying(stop=stop_after_attempt(1), wait=wait_fixed(1)):
with attempt:
raise Exception()
except RetryError:
except tenacity.RetryError:
pass
t = current_time_ms() - start
self.assertLess(t, 1.1)


class TestAsyncBeforeAfterAttempts(unittest.TestCase):
_attempt_number = 0

@asynctest
async def test_before_attempts(self):
TestAsyncBeforeAfterAttempts._attempt_number = 0

async def _before(retry_state):
TestAsyncBeforeAfterAttempts._attempt_number = retry_state.attempt_number

@retry(
wait=tenacity.wait_fixed(1),
stop=tenacity.stop_after_attempt(1),
before=_before,
)
async def _test_before():
pass

await _test_before()

self.assertTrue(TestAsyncBeforeAfterAttempts._attempt_number == 1)

@asynctest
async def test_after_attempts(self):
TestAsyncBeforeAfterAttempts._attempt_number = 0

async def _after(retry_state):
TestAsyncBeforeAfterAttempts._attempt_number = retry_state.attempt_number

@retry(
wait=tenacity.wait_fixed(0.1),
stop=tenacity.stop_after_attempt(3),
after=_after,
)
async def _test_after():
if TestAsyncBeforeAfterAttempts._attempt_number < 2:
raise Exception("testing after_attempts handler")
else:
pass

await _test_after()

self.assertTrue(TestAsyncBeforeAfterAttempts._attempt_number == 2)

@asynctest
async def test_before_sleep(self):
async def _before_sleep(retry_state):
self.assertGreater(retry_state.next_action.sleep, 0)
_before_sleep.attempt_number = retry_state.attempt_number

_before_sleep.attempt_number = 0

@retry(
wait=tenacity.wait_fixed(0.01),
stop=tenacity.stop_after_attempt(3),
before_sleep=_before_sleep,
)
async def _test_before_sleep():
if _before_sleep.attempt_number < 2:
raise Exception("testing before_sleep_attempts handler")

await _test_before_sleep()
self.assertEqual(_before_sleep.attempt_number, 2)

async def _test_before_sleep_log_returns(self, exc_info):
thing = NoneReturnUntilAfterCount(2)
logger = logging.getLogger(self.id())
logger.propagate = False
logger.setLevel(logging.INFO)
handler = CapturingHandler()
logger.addHandler(handler)
try:
_before_sleep = tenacity.before_sleep_log(logger, logging.INFO, exc_info=exc_info)
_retry = tenacity.retry_if_result(lambda result: result is None)
retrying = tenacity.AsyncRetrying(
wait=tenacity.wait_fixed(0.01),
stop=tenacity.stop_after_attempt(3),
retry=_retry,
before_sleep=_before_sleep,
)
await retrying(_async_function, thing)
finally:
logger.removeHandler(handler)

etalon_re = r"^Retrying .* in 0\.01 seconds as it returned None\.$"
self.assertEqual(len(handler.records), 2)
fmt = logging.Formatter().format
self.assertRegex(fmt(handler.records[0]), etalon_re)
self.assertRegex(fmt(handler.records[1]), etalon_re)

@asynctest
async def test_before_sleep_log_returns_without_exc_info(self):
await self._test_before_sleep_log_returns(exc_info=False)

@asynctest
async def test_before_sleep_log_returns_with_exc_info(self):
await self._test_before_sleep_log_returns(exc_info=True)


if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions tests/test_tenacity.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,6 +1093,8 @@ def _before_sleep(retry_state):
self.assertGreater(retry_state.next_action.sleep, 0)
_before_sleep.attempt_number = retry_state.attempt_number

_before_sleep.attempt_number = 0

@retry(
wait=tenacity.wait_fixed(0.01),
stop=tenacity.stop_after_attempt(3),
Expand Down