Skip to content

Commit 5a1838b

Browse files
committed
add executor.map() functionality
1 parent 474d51d commit 5a1838b

File tree

2 files changed

+59
-2
lines changed

2 files changed

+59
-2
lines changed

src/qasync/__init__.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,25 @@ def submit(self, callback, *args, **kwargs):
214214
return future
215215

216216
def map(self, func, *iterables, timeout=None):
217-
raise NotImplementedError("use as_completed on the event loop")
217+
deadline = time.monotonic() + timeout if timeout is not None else None
218+
futures = [self.submit(func, *args) for args in zip(*iterables)]
219+
220+
# must have generator as a closure so that the submit occurs before first iteration
221+
def generator():
222+
try:
223+
futures.reverse()
224+
while futures:
225+
if deadline is not None:
226+
yield _result_or_cancel(
227+
futures.pop(), timeout=deadline - time.monotonic()
228+
)
229+
else:
230+
yield _result_or_cancel(futures.pop())
231+
finally:
232+
for future in futures:
233+
future.cancel()
234+
235+
return generator()
218236

219237
def shutdown(self, wait=True, *, cancel_futures=False):
220238
with self.__shutdown_lock:
@@ -241,6 +259,16 @@ def __exit__(self, *args):
241259
self.shutdown()
242260

243261

262+
def _result_or_cancel(fut, timeout=None):
263+
try:
264+
try:
265+
return fut.result(timeout)
266+
finally:
267+
fut.cancel()
268+
finally:
269+
del fut # break reference cycle in exceptions
270+
271+
244272
def _format_handle(handle: asyncio.Handle):
245273
cb = getattr(handle, "_callback", None)
246274
if isinstance(getattr(cb, "__self__", None), asyncio.tasks.Task):

tests/test_qthreadexec.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import threading
77
import time
88
import weakref
9-
from concurrent.futures import CancelledError
9+
from concurrent.futures import CancelledError, TimeoutError
1010

1111
import pytest
1212

@@ -167,3 +167,32 @@ def task():
167167
assert cancels > 0
168168
else:
169169
assert cancels == 0
170+
171+
172+
def test_map(executor):
173+
"""Basic test of executor map functionality"""
174+
results = list(executor.map(lambda x: x + 1, range(10)))
175+
assert results == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
176+
177+
178+
def test_map_timeout(executor):
179+
"""Test that map with timeout raises TimeoutError and cancels futures"""
180+
results = []
181+
182+
def func(x):
183+
nonlocal results
184+
time.sleep(0.05)
185+
results.append(x)
186+
return x
187+
188+
start = time.monotonic()
189+
with pytest.raises(TimeoutError):
190+
list(executor.map(func, range(10), timeout=0.01))
191+
duration = time.monotonic() - start
192+
assert duration < 0.05
193+
194+
executor.shutdown(wait=True)
195+
# only about half of the tasks should have completed
196+
# because the max number of workers is 5 and the rest of
197+
# the tasks were not started at the time of the cancel.
198+
assert set(results) != {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}

0 commit comments

Comments
 (0)