Skip to content

Commit 608b713

Browse files
committed
Protect shutdown with lock. Allow shutdown more than once.
1 parent 57a7e2a commit 608b713

File tree

2 files changed

+65
-31
lines changed

2 files changed

+65
-31
lines changed

src/qasync/__init__.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import time
2323
from concurrent.futures import Future
2424
from queue import Queue
25+
from threading import Lock
2526

2627
logger = logging.getLogger(__name__)
2728

@@ -191,45 +192,42 @@ def __init__(self, max_workers=10, stack_size=None):
191192
self.__workers = [
192193
_QThreadWorker(self.__queue, i + 1, stack_size) for i in range(max_workers)
193194
]
195+
self.__shutdown_lock = Lock()
194196
self.__been_shutdown = False
195197

196198
for w in self.__workers:
197199
w.start()
198200

199201
def submit(self, callback, *args, **kwargs):
200-
if self.__been_shutdown:
201-
raise RuntimeError("QThreadExecutor has been shutdown")
202+
with self.__shutdown_lock:
203+
if self.__been_shutdown:
204+
raise RuntimeError("QThreadExecutor has been shutdown")
202205

203-
future = Future()
204-
self._logger.debug(
205-
"Submitting callback %s with args %s and kwargs %s to thread worker queue",
206-
callback,
207-
args,
208-
kwargs,
209-
)
210-
self.__queue.put((future, callback, args, kwargs))
211-
return future
206+
future = Future()
207+
self._logger.debug(
208+
"Submitting callback %s with args %s and kwargs %s to thread worker queue",
209+
callback,
210+
args,
211+
kwargs,
212+
)
213+
self.__queue.put((future, callback, args, kwargs))
214+
return future
212215

213216
def map(self, func, *iterables, timeout=None):
214217
raise NotImplementedError("use as_completed on the event loop")
215218

216219
def shutdown(self, wait=True):
217-
if self.__been_shutdown:
218-
raise RuntimeError("QThreadExecutor has been shutdown")
219-
220-
self.__been_shutdown = True
221-
222-
self._logger.debug("Shutting down")
223-
for i in range(len(self.__workers)):
224-
# Signal workers to stop
225-
self.__queue.put(None)
226-
if wait:
227-
for w in self.__workers:
228-
w.wait()
220+
with self.__shutdown_lock:
221+
self.__been_shutdown = True
222+
self._logger.debug("Shutting down")
223+
for i in range(len(self.__workers)):
224+
# Signal workers to stop
225+
self.__queue.put(None)
226+
if wait:
227+
for w in self.__workers:
228+
w.wait()
229229

230230
def __enter__(self, *args):
231-
if self.__been_shutdown:
232-
raise RuntimeError("QThreadExecutor has been shutdown")
233231
return self
234232

235233
def __exit__(self, *args):

tests/test_qthreadexec.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,16 @@ def shutdown_executor():
4444
return exe
4545

4646

47-
def test_shutdown_after_shutdown(shutdown_executor):
48-
with pytest.raises(RuntimeError):
49-
shutdown_executor.shutdown()
47+
@pytest.mark.parametrize("wait", [True, False])
48+
def test_shutdown_after_shutdown(shutdown_executor, wait):
49+
# it is safe to shutdown twice
50+
shutdown_executor.shutdown(wait=wait)
5051

5152

5253
def test_ctx_after_shutdown(shutdown_executor):
53-
with pytest.raises(RuntimeError):
54-
with shutdown_executor:
55-
pass
54+
# it is safe to enter and exit the context after shutdown
55+
with shutdown_executor:
56+
pass
5657

5758

5859
def test_submit_after_shutdown(shutdown_executor):
@@ -104,3 +105,38 @@ def test_no_stale_reference_as_result(executor, disable_executor_logging):
104105
assert collected is True, (
105106
"Stale reference to executor result not collected within timeout."
106107
)
108+
109+
110+
@pytest.mark.parametrize("cancel", [True, False])
111+
def test_shutdown_cancel_futures(executor, cancel):
112+
"""Test that shutdown with cancel_futures=True cancels all remaining futures in the queue."""
113+
114+
def task():
115+
time.sleep(0.01)
116+
117+
# Submit ten tasks to the executor
118+
futures = [executor.submit(task) for _ in range(10)]
119+
# shut it down
120+
executor.shutdown(cancel_futures=cancel)
121+
122+
cancels = 0
123+
for future in futures:
124+
try:
125+
future.result(timeout=0.01)
126+
except CancelledError:
127+
cancels += 1
128+
129+
if cancel:
130+
assert cancels > 0
131+
else:
132+
assert cancels == 0
133+
134+
135+
def test_context(executor):
136+
"""Test that the context manager will shutdown executor"""
137+
with executor:
138+
f = executor.submit(lambda: 42)
139+
assert f.result() == 42
140+
141+
with pytest.raises(RuntimeError):
142+
executor.submit(lambda: 42)

0 commit comments

Comments
 (0)