Skip to content

Commit 5dae96f

Browse files
committed
Make outputs go to correct cell when generated in threads/asyncio
1 parent de45c7a commit 5dae96f

File tree

3 files changed

+145
-46
lines changed

3 files changed

+145
-46
lines changed

ipykernel/iostream.py

+60-36
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55

66
import asyncio
77
import atexit
8+
import contextvars
89
import io
910
import os
1011
import sys
1112
import threading
1213
import traceback
1314
import warnings
1415
from binascii import b2a_hex
15-
from collections import deque
16+
from collections import defaultdict, deque
1617
from io import StringIO, TextIOBase
1718
from threading import local
1819
from typing import Any, Callable, Deque, Dict, Optional
@@ -412,7 +413,7 @@ def __init__(
412413
name : str {'stderr', 'stdout'}
413414
the name of the standard stream to replace
414415
pipe : object
415-
the pip object
416+
the pipe object
416417
echo : bool
417418
whether to echo output
418419
watchfd : bool (default, True)
@@ -446,13 +447,16 @@ def __init__(
446447
self.pub_thread = pub_thread
447448
self.name = name
448449
self.topic = b"stream." + name.encode()
449-
self.parent_header = {}
450+
self._parent_header: Dict[str, Any] = contextvars.ContextVar("parent_header")
451+
self._parent_header.set({})
452+
self._thread_parents = {}
453+
self._parent_header_global = {}
450454
self._master_pid = os.getpid()
451455
self._flush_pending = False
452456
self._subprocess_flush_pending = False
453457
self._io_loop = pub_thread.io_loop
454458
self._buffer_lock = threading.RLock()
455-
self._buffer = StringIO()
459+
self._buffers = defaultdict(StringIO)
456460
self.echo = None
457461
self._isatty = bool(isatty)
458462
self._should_watch = False
@@ -495,6 +499,24 @@ def __init__(
495499
msg = "echo argument must be a file-like object"
496500
raise ValueError(msg)
497501

502+
@property
503+
def parent_header(self):
504+
try:
505+
# asyncio-specific
506+
return self._parent_header.get()
507+
except LookupError:
508+
try:
509+
# thread-specific
510+
return self._thread_parents[threading.current_thread().ident]
511+
except KeyError:
512+
# global (fallback)
513+
return self._parent_header_global
514+
515+
@parent_header.setter
516+
def parent_header(self, value):
517+
self._parent_header_global = value
518+
return self._parent_header.set(value)
519+
498520
def isatty(self):
499521
"""Return a bool indicating whether this is an 'interactive' stream.
500522
@@ -598,28 +620,28 @@ def _flush(self):
598620
if self.echo is not sys.__stderr__:
599621
print(f"Flush failed: {e}", file=sys.__stderr__)
600622

601-
data = self._flush_buffer()
602-
if data:
603-
# FIXME: this disables Session's fork-safe check,
604-
# since pub_thread is itself fork-safe.
605-
# There should be a better way to do this.
606-
self.session.pid = os.getpid()
607-
content = {"name": self.name, "text": data}
608-
msg = self.session.msg("stream", content, parent=self.parent_header)
609-
610-
# Each transform either returns a new
611-
# message or None. If None is returned,
612-
# the message has been 'used' and we return.
613-
for hook in self._hooks:
614-
msg = hook(msg)
615-
if msg is None:
616-
return
617-
618-
self.session.send(
619-
self.pub_thread,
620-
msg,
621-
ident=self.topic,
622-
)
623+
for parent, data in self._flush_buffers():
624+
if data:
625+
# FIXME: this disables Session's fork-safe check,
626+
# since pub_thread is itself fork-safe.
627+
# There should be a better way to do this.
628+
self.session.pid = os.getpid()
629+
content = {"name": self.name, "text": data}
630+
msg = self.session.msg("stream", content, parent=parent)
631+
632+
# Each transform either returns a new
633+
# message or None. If None is returned,
634+
# the message has been 'used' and we return.
635+
for hook in self._hooks:
636+
msg = hook(msg)
637+
if msg is None:
638+
return
639+
640+
self.session.send(
641+
self.pub_thread,
642+
msg,
643+
ident=self.topic,
644+
)
623645

624646
def write(self, string: str) -> Optional[int]: # type:ignore[override]
625647
"""Write to current stream after encoding if necessary
@@ -630,6 +652,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override]
630652
number of items from input parameter written to stream.
631653
632654
"""
655+
parent = self.parent_header
633656

634657
if not isinstance(string, str):
635658
msg = f"write() argument must be str, not {type(string)}" # type:ignore[unreachable]
@@ -649,7 +672,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override]
649672
is_child = not self._is_master_process()
650673
# only touch the buffer in the IO thread to avoid races
651674
with self._buffer_lock:
652-
self._buffer.write(string)
675+
self._buffers[frozenset(parent.items())].write(string)
653676
if is_child:
654677
# mp.Pool cannot be trusted to flush promptly (or ever),
655678
# and this helps.
@@ -675,19 +698,20 @@ def writable(self):
675698
"""Test whether the stream is writable."""
676699
return True
677700

678-
def _flush_buffer(self):
701+
def _flush_buffers(self):
679702
"""clear the current buffer and return the current buffer data."""
680-
buf = self._rotate_buffer()
681-
data = buf.getvalue()
682-
buf.close()
683-
return data
703+
buffers = self._rotate_buffers()
704+
for frozen_parent, buffer in buffers.items():
705+
data = buffer.getvalue()
706+
buffer.close()
707+
yield dict(frozen_parent), data
684708

685-
def _rotate_buffer(self):
709+
def _rotate_buffers(self):
686710
"""Returns the current buffer and replaces it with an empty buffer."""
687711
with self._buffer_lock:
688-
old_buffer = self._buffer
689-
self._buffer = StringIO()
690-
return old_buffer
712+
old_buffers = self._buffers
713+
self._buffers = defaultdict(StringIO)
714+
return old_buffers
691715

692716
@property
693717
def _hooks(self):

ipykernel/ipkernel.py

+67
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import builtins
5+
import gc
56
import getpass
67
import os
78
import signal
@@ -14,6 +15,7 @@
1415
import comm
1516
from IPython.core import release
1617
from IPython.utils.tokenutil import line_at_cursor, token_at_cursor
18+
from jupyter_client.session import extract_header
1719
from traitlets import Any, Bool, HasTraits, Instance, List, Type, observe, observe_compat
1820
from zmq.eventloop.zmqstream import ZMQStream
1921

@@ -22,6 +24,7 @@
2224
from .compiler import XCachingCompiler
2325
from .debugger import Debugger, _is_debugpy_available
2426
from .eventloops import _use_appnope
27+
from .iostream import OutStream
2528
from .kernelbase import Kernel as KernelBase
2629
from .kernelbase import _accepts_parameters
2730
from .zmqshell import ZMQInteractiveShell
@@ -66,6 +69,10 @@ def _get_comm_manager(*args, **kwargs):
6669
comm.create_comm = _create_comm
6770
comm.get_comm_manager = _get_comm_manager
6871

72+
import threading
73+
74+
threading_start = threading.Thread.start
75+
6976

7077
class IPythonKernel(KernelBase):
7178
"""The IPython Kernel class."""
@@ -151,6 +158,8 @@ def __init__(self, **kwargs):
151158

152159
appnope.nope()
153160

161+
gc.callbacks.append(self._clean_thread_parent_frames)
162+
154163
help_links = List(
155164
[
156165
{
@@ -341,6 +350,12 @@ def set_sigint_result():
341350
# restore the previous sigint handler
342351
signal.signal(signal.SIGINT, save_sigint)
343352

353+
async def execute_request(self, stream, ident, parent):
354+
"""Override for cell output - cell reconciliation."""
355+
parent_header = extract_header(parent)
356+
self._associate_identity_of_new_threads_with(parent_header)
357+
await super().execute_request(stream, ident, parent)
358+
344359
async def do_execute(
345360
self,
346361
code,
@@ -706,6 +721,58 @@ def do_clear(self):
706721
self.shell.reset(False)
707722
return dict(status="ok")
708723

724+
def _associate_identity_of_new_threads_with(self, parent_header):
725+
"""Intercept the identity of any thread started after this method finished,
726+
727+
and associate the thread's output with the parent header frame, which allows
728+
to direct the outputs to the cell which started the thread.
729+
730+
This is a no-op if the `self.stdout` and `self.stderr` are not
731+
sub-classes of `OutStream`.
732+
"""
733+
stdout = self.stdout
734+
stderr = self.stderr
735+
736+
def start_closure(self: threading.Thread):
737+
"""Wrap the `threading.Thread.start` to intercept thread identity.
738+
739+
This is needed because there is no "start" hook yet, but there
740+
might be one in the future: https://bugs.python.org/issue14073
741+
"""
742+
743+
threading_start(self)
744+
for stream in [stdout, stderr]:
745+
if isinstance(stream, OutStream):
746+
stream._thread_parents[self.ident] = parent_header
747+
748+
threading.Thread.start = start_closure # type:ignore[method-assign]
749+
750+
def _clean_thread_parent_frames(
751+
self, phase: t.Literal["start", "stop"], info: t.Dict[str, t.Any]
752+
):
753+
"""Clean parent frames of threads which are no longer running.
754+
This is meant to be invoked by garbage collector callback hook.
755+
756+
The implementation enumerates the threads because there is no "exit" hook yet,
757+
but there might be one in the future: https://bugs.python.org/issue14073
758+
759+
This is a no-op if the `self.stdout` and `self.stderr` are not
760+
sub-classes of `OutStream`.
761+
"""
762+
# Only run before the garbage collector starts
763+
if phase != "start":
764+
return
765+
active_threads = {thread.ident for thread in threading.enumerate()}
766+
for stream in [self.stdout, self.stderr]:
767+
if isinstance(stream, OutStream):
768+
thread_parents = stream._thread_parents
769+
for identity in list(thread_parents.keys()):
770+
if identity not in active_threads:
771+
try:
772+
del thread_parents[identity]
773+
except KeyError:
774+
pass
775+
709776

710777
# This exists only for backwards compatibility - use IPythonKernel instead
711778

ipykernel/kernelbase.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from ipykernel.jsonutil import json_clean
6262

6363
from ._version import kernel_protocol_version
64+
from .iostream import OutStream
6465

6566

6667
def _accepts_parameters(meth, param_names):
@@ -272,6 +273,13 @@ def _parent_header(self):
272273
def __init__(self, **kwargs):
273274
"""Initialize the kernel."""
274275
super().__init__(**kwargs)
276+
277+
# Kernel application may swap stdout and stderr to OutStream,
278+
# which is the case in `IPKernelApp.init_io`, hence `sys.stdout`
279+
# can already by different from TextIO at initialization time.
280+
self.stdout: t.Union[OutStream, t.TextIO] = sys.stdout
281+
self.stderr: t.Union[OutStream, t.TextIO] = sys.stderr
282+
275283
# Build dict of handlers for message types
276284
self.shell_handlers = {}
277285
for msg_type in self.msg_types:
@@ -355,8 +363,8 @@ async def process_control(self, msg):
355363
except Exception:
356364
self.log.error("Exception in control handler:", exc_info=True) # noqa: G201
357365

358-
sys.stdout.flush()
359-
sys.stderr.flush()
366+
self.stdout.flush()
367+
self.stderr.flush()
360368
self._publish_status("idle", "control")
361369
# flush to ensure reply is sent
362370
if self.control_stream:
@@ -438,8 +446,8 @@ async def dispatch_shell(self, msg):
438446
except Exception:
439447
self.log.debug("Unable to signal in post_handler_hook:", exc_info=True)
440448

441-
sys.stdout.flush()
442-
sys.stderr.flush()
449+
self.stdout.flush()
450+
self.stderr.flush()
443451
self._publish_status("idle", "shell")
444452
# flush to ensure reply is sent before
445453
# handling the next request
@@ -767,8 +775,8 @@ async def execute_request(self, stream, ident, parent):
767775
reply_content = await reply_content
768776

769777
# Flush output before sending the reply.
770-
sys.stdout.flush()
771-
sys.stderr.flush()
778+
self.stdout.flush()
779+
self.stderr.flush()
772780
# FIXME: on rare occasions, the flush doesn't seem to make it to the
773781
# clients... This seems to mitigate the problem, but we definitely need
774782
# to better understand what's going on.
@@ -1102,8 +1110,8 @@ async def apply_request(self, stream, ident, parent): # pragma: no cover
11021110
reply_content, result_buf = self.do_apply(content, bufs, msg_id, md)
11031111

11041112
# flush i/o
1105-
sys.stdout.flush()
1106-
sys.stderr.flush()
1113+
self.stdout.flush()
1114+
self.stderr.flush()
11071115

11081116
md = self.finish_metadata(parent, md, reply_content)
11091117
if not self.session:
@@ -1268,8 +1276,8 @@ def raw_input(self, prompt=""):
12681276

12691277
def _input_request(self, prompt, ident, parent, password=False):
12701278
# Flush output before making the request.
1271-
sys.stderr.flush()
1272-
sys.stdout.flush()
1279+
self.stderr.flush()
1280+
self.stdout.flush()
12731281

12741282
# flush the stdin socket, to purge stale replies
12751283
while True:

0 commit comments

Comments
 (0)