Skip to content

Commit 7f72b80

Browse files
committed
Allow asynchronous callbacks for async retries.
Fixes #249
1 parent 18d05a6 commit 7f72b80

File tree

4 files changed

+194
-25
lines changed

4 files changed

+194
-25
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
---
2+
features:
3+
- Allow asynchronous callbacks for `before`, `after`, `retry_error_callback`, `wait`, and `before_sleep` parameters.

tenacity/_asyncio.py

Lines changed: 58 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,35 +17,32 @@
1717

1818
import functools
1919
import sys
20-
import typing
21-
from asyncio import sleep
20+
from asyncio import sleep as aio_sleep
21+
from collections.abc import Awaitable
22+
from inspect import iscoroutinefunction
23+
from typing import Union, Callable, Any, TypeVar
2224

23-
from tenacity import AttemptManager
24-
from tenacity import BaseRetrying
25-
from tenacity import DoAttempt
26-
from tenacity import DoSleep
27-
from tenacity import RetryCallState
25+
from tenacity import AttemptManager, BaseRetrying, DoAttempt, DoSleep, RetryCallState, RetryAction, TryAgain
2826

29-
WrappedFn = typing.TypeVar("WrappedFn", bound=typing.Callable)
30-
_RetValT = typing.TypeVar("_RetValT")
27+
WrappedFn = TypeVar("WrappedFn", bound=Callable)
28+
_RetValT = TypeVar("_RetValT")
3129

3230

3331
class AsyncRetrying(BaseRetrying):
34-
def __init__(self, sleep: typing.Callable[[float], typing.Awaitable] = sleep, **kwargs: typing.Any) -> None:
32+
def __init__(self, sleep: Callable[[float], Awaitable] = aio_sleep, **kwargs: Any) -> None:
3533
super().__init__(**kwargs)
3634
self.sleep = sleep
3735

3836
async def __call__( # type: ignore # Change signature from supertype
3937
self,
40-
fn: typing.Callable[..., typing.Awaitable[_RetValT]],
41-
*args: typing.Any,
42-
**kwargs: typing.Any,
38+
fn: Callable[..., Awaitable[_RetValT]],
39+
*args: Any,
40+
**kwargs: Any,
4341
) -> _RetValT:
4442
self.begin()
45-
4643
retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs)
4744
while True:
48-
do = self.iter(retry_state=retry_state)
45+
do = await self.iter(retry_state=retry_state)
4946
if isinstance(do, DoAttempt):
5047
try:
5148
result = await fn(*args, **kwargs)
@@ -64,9 +61,9 @@ def __aiter__(self) -> "AsyncRetrying":
6461
self._retry_state = RetryCallState(self, fn=None, args=(), kwargs={})
6562
return self
6663

67-
async def __anext__(self) -> typing.Union[AttemptManager, typing.Any]:
64+
async def __anext__(self) -> Union[AttemptManager, Any]:
6865
while True:
69-
do = self.iter(retry_state=self._retry_state)
66+
do = await self.iter(retry_state=self._retry_state)
7067
if do is None:
7168
raise StopAsyncIteration
7269
elif isinstance(do, DoAttempt):
@@ -82,11 +79,54 @@ def wraps(self, fn: WrappedFn) -> WrappedFn:
8279
# Ensure wrapper is recognized as a coroutine function.
8380

8481
@functools.wraps(fn)
85-
async def async_wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
82+
async def async_wrapped(*args: Any, **kwargs: Any) -> Any:
8683
return await fn(*args, **kwargs)
8784

8885
# Preserve attributes
8986
async_wrapped.retry = fn.retry
9087
async_wrapped.retry_with = fn.retry_with
9188

9289
return async_wrapped
90+
91+
@staticmethod
92+
async def handle_custom_function(func: Union[Callable, Awaitable], retry_state: RetryCallState) -> Any:
93+
if iscoroutinefunction(func):
94+
return await func(retry_state)
95+
return func(retry_state)
96+
97+
async def iter(self, retry_state: "RetryCallState") -> Union[DoAttempt, DoSleep, Any]:
98+
fut = retry_state.outcome
99+
if fut is None:
100+
if self.before is not None:
101+
await self.handle_custom_function(self.before, retry_state)
102+
return DoAttempt()
103+
104+
is_explicit_retry = retry_state.outcome.failed and isinstance(retry_state.outcome.exception(), TryAgain)
105+
if not (is_explicit_retry or self.retry(retry_state=retry_state)):
106+
return fut.result()
107+
108+
if self.after is not None:
109+
await self.handle_custom_function(self.after, retry_state)
110+
111+
self.statistics["delay_since_first_attempt"] = retry_state.seconds_since_start
112+
if self.stop(retry_state=retry_state):
113+
if self.retry_error_callback:
114+
return await self.handle_custom_function(self.retry_error_callback, retry_state)
115+
retry_exc = self.retry_error_cls(fut)
116+
if self.reraise:
117+
raise retry_exc.reraise()
118+
raise retry_exc from fut.exception()
119+
120+
if self.wait:
121+
sleep = await self.handle_custom_function(self.wait, retry_state=retry_state)
122+
else:
123+
sleep = 0.0
124+
retry_state.next_action = RetryAction(sleep)
125+
retry_state.idle_for += sleep
126+
self.statistics["idle_for"] += sleep
127+
self.statistics["attempt_number"] += 1
128+
129+
if self.before_sleep is not None:
130+
await self.handle_custom_function(self.before_sleep, retry_state)
131+
132+
return DoSleep(sleep)

tests/test_asyncio.py

Lines changed: 131 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@
1515

1616
import asyncio
1717
import inspect
18+
import logging
1819
import unittest
1920
from functools import wraps
2021

21-
from tenacity import AsyncRetrying, RetryError
22+
import tenacity
2223
from tenacity import _asyncio as tasyncio
2324
from tenacity import retry, stop_after_attempt
2425
from tenacity.wait import wait_fixed
25-
26-
from .test_tenacity import NoIOErrorAfterCount, current_time_ms
26+
from .test_tenacity import CapturingHandler, NoneReturnUntilAfterCount, NoIOErrorAfterCount, current_time_ms
2727

2828

2929
def asynctest(callable_):
@@ -67,7 +67,7 @@ async def test_iscoroutinefunction(self):
6767
@asynctest
6868
async def test_retry_using_async_retying(self):
6969
thing = NoIOErrorAfterCount(5)
70-
retrying = AsyncRetrying()
70+
retrying = tenacity.AsyncRetrying()
7171
await retrying(_async_function, thing)
7272
assert thing.counter == thing.count
7373

@@ -76,7 +76,7 @@ async def test_stop_after_attempt(self):
7676
thing = NoIOErrorAfterCount(2)
7777
try:
7878
await _retryable_coroutine_with_2_attempts(thing)
79-
except RetryError:
79+
except tenacity.RetryError:
8080
assert thing.counter == 2
8181

8282
def test_repr(self):
@@ -86,6 +86,31 @@ def test_retry_attributes(self):
8686
assert hasattr(_retryable_coroutine, "retry")
8787
assert hasattr(_retryable_coroutine, "retry_with")
8888

89+
@asynctest
90+
async def test_async_retry_error_callback_handler(self):
91+
num_attempts = 3
92+
self.attempt_counter = 0
93+
94+
async def _retry_error_callback_handler(retry_state: tenacity.RetryCallState):
95+
_retry_error_callback_handler.called_times += 1
96+
return retry_state.outcome
97+
98+
_retry_error_callback_handler.called_times = 0
99+
100+
@retry(
101+
stop=stop_after_attempt(num_attempts),
102+
retry_error_callback=_retry_error_callback_handler,
103+
)
104+
async def _foobar():
105+
self.attempt_counter += 1
106+
raise Exception("This exception should not be raised")
107+
108+
result = await _foobar()
109+
110+
self.assertEqual(_retry_error_callback_handler.called_times, 1)
111+
self.assertEqual(num_attempts, self.attempt_counter)
112+
self.assertIsInstance(result, tenacity.Future)
113+
89114
@asynctest
90115
async def test_attempt_number_is_correct_for_interleaved_coroutines(self):
91116

@@ -125,7 +150,7 @@ async def test_do_max_attempts(self):
125150
with attempt:
126151
attempts += 1
127152
raise Exception
128-
except RetryError:
153+
except tenacity.RetryError:
129154
pass
130155

131156
assert attempts == 3
@@ -151,11 +176,110 @@ async def test_sleeps(self):
151176
async for attempt in tasyncio.AsyncRetrying(stop=stop_after_attempt(1), wait=wait_fixed(1)):
152177
with attempt:
153178
raise Exception()
154-
except RetryError:
179+
except tenacity.RetryError:
155180
pass
156181
t = current_time_ms() - start
157182
self.assertLess(t, 1.1)
158183

159184

185+
class TestAsyncBeforeAfterAttempts(unittest.TestCase):
186+
_attempt_number = 0
187+
188+
@asynctest
189+
async def test_before_attempts(self):
190+
TestAsyncBeforeAfterAttempts._attempt_number = 0
191+
192+
async def _before(retry_state):
193+
TestAsyncBeforeAfterAttempts._attempt_number = retry_state.attempt_number
194+
195+
@retry(
196+
wait=tenacity.wait_fixed(1),
197+
stop=tenacity.stop_after_attempt(1),
198+
before=_before,
199+
)
200+
async def _test_before():
201+
pass
202+
203+
await _test_before()
204+
205+
self.assertTrue(TestAsyncBeforeAfterAttempts._attempt_number == 1)
206+
207+
@asynctest
208+
async def test_after_attempts(self):
209+
TestAsyncBeforeAfterAttempts._attempt_number = 0
210+
211+
async def _after(retry_state):
212+
TestAsyncBeforeAfterAttempts._attempt_number = retry_state.attempt_number
213+
214+
@retry(
215+
wait=tenacity.wait_fixed(0.1),
216+
stop=tenacity.stop_after_attempt(3),
217+
after=_after,
218+
)
219+
async def _test_after():
220+
if TestAsyncBeforeAfterAttempts._attempt_number < 2:
221+
raise Exception("testing after_attempts handler")
222+
else:
223+
pass
224+
225+
await _test_after()
226+
227+
self.assertTrue(TestAsyncBeforeAfterAttempts._attempt_number == 2)
228+
229+
@asynctest
230+
async def test_before_sleep(self):
231+
async def _before_sleep(retry_state):
232+
self.assertGreater(retry_state.next_action.sleep, 0)
233+
_before_sleep.attempt_number = retry_state.attempt_number
234+
235+
_before_sleep.attempt_number = 0
236+
237+
@retry(
238+
wait=tenacity.wait_fixed(0.01),
239+
stop=tenacity.stop_after_attempt(3),
240+
before_sleep=_before_sleep,
241+
)
242+
async def _test_before_sleep():
243+
if _before_sleep.attempt_number < 2:
244+
raise Exception("testing before_sleep_attempts handler")
245+
246+
await _test_before_sleep()
247+
self.assertEqual(_before_sleep.attempt_number, 2)
248+
249+
async def _test_before_sleep_log_returns(self, exc_info):
250+
thing = NoneReturnUntilAfterCount(2)
251+
logger = logging.getLogger(self.id())
252+
logger.propagate = False
253+
logger.setLevel(logging.INFO)
254+
handler = CapturingHandler()
255+
logger.addHandler(handler)
256+
try:
257+
_before_sleep = tenacity.before_sleep_log(logger, logging.INFO, exc_info=exc_info)
258+
_retry = tenacity.retry_if_result(lambda result: result is None)
259+
retrying = tenacity.AsyncRetrying(
260+
wait=tenacity.wait_fixed(0.01),
261+
stop=tenacity.stop_after_attempt(3),
262+
retry=_retry,
263+
before_sleep=_before_sleep,
264+
)
265+
await retrying(_async_function, thing)
266+
finally:
267+
logger.removeHandler(handler)
268+
269+
etalon_re = r"^Retrying .* in 0\.01 seconds as it returned None\.$"
270+
self.assertEqual(len(handler.records), 2)
271+
fmt = logging.Formatter().format
272+
self.assertRegex(fmt(handler.records[0]), etalon_re)
273+
self.assertRegex(fmt(handler.records[1]), etalon_re)
274+
275+
@asynctest
276+
async def test_before_sleep_log_returns_without_exc_info(self):
277+
await self._test_before_sleep_log_returns(exc_info=False)
278+
279+
@asynctest
280+
async def test_before_sleep_log_returns_with_exc_info(self):
281+
await self._test_before_sleep_log_returns(exc_info=True)
282+
283+
160284
if __name__ == "__main__":
161285
unittest.main()

tests/test_tenacity.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,6 +1093,8 @@ def _before_sleep(retry_state):
10931093
self.assertGreater(retry_state.next_action.sleep, 0)
10941094
_before_sleep.attempt_number = retry_state.attempt_number
10951095

1096+
_before_sleep.attempt_number = 0
1097+
10961098
@retry(
10971099
wait=tenacity.wait_fixed(0.01),
10981100
stop=tenacity.stop_after_attempt(3),

0 commit comments

Comments
 (0)