Skip to content

Commit 9e9c40e

Browse files
committed
Make outputs go to correct cell when generated in threads/asyncio
1 parent 6d97970 commit 9e9c40e

File tree

3 files changed

+140
-36
lines changed

3 files changed

+140
-36
lines changed

ipykernel/iostream.py

Lines changed: 62 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55

66
import asyncio
77
import atexit
8+
import contextvars
89
import io
910
import os
1011
import sys
1112
import threading
1213
import traceback
1314
import warnings
1415
from binascii import b2a_hex
15-
from collections import deque
16+
from collections import defaultdict, deque
1617
from io import StringIO, TextIOBase
1718
from threading import local
1819
from typing import Any, Callable, Deque, Dict, Optional
@@ -412,7 +413,7 @@ def __init__(
412413
name : str {'stderr', 'stdout'}
413414
the name of the standard stream to replace
414415
pipe : object
415-
the pip object
416+
the pipe object
416417
echo : bool
417418
whether to echo output
418419
watchfd : bool (default, True)
@@ -446,13 +447,18 @@ def __init__(
446447
self.pub_thread = pub_thread
447448
self.name = name
448449
self.topic = b"stream." + name.encode()
449-
self.parent_header = {}
450+
self._parent_header: contextvars.ContextVar[Dict[str, Any]] = contextvars.ContextVar(
451+
"parent_header"
452+
)
453+
self._parent_header.set({})
454+
self._thread_parents = {}
455+
self._parent_header_global = {}
450456
self._master_pid = os.getpid()
451457
self._flush_pending = False
452458
self._subprocess_flush_pending = False
453459
self._io_loop = pub_thread.io_loop
454460
self._buffer_lock = threading.RLock()
455-
self._buffer = StringIO()
461+
self._buffers = defaultdict(StringIO)
456462
self.echo = None
457463
self._isatty = bool(isatty)
458464
self._should_watch = False
@@ -495,6 +501,24 @@ def __init__(
495501
msg = "echo argument must be a file-like object"
496502
raise ValueError(msg)
497503

504+
@property
def parent_header(self):
    """Return the parent header that output written right now belongs to.

    The lookup falls through three levels:

    1. the value of the ``_parent_header`` context variable in the current
       context (asyncio-specific);
    2. the header stored in ``_thread_parents`` under the calling thread's
       identifier (thread-specific);
    3. ``_parent_header_global``, the most recently assigned header
       (global fallback).
    """
    try:
        # asyncio-specific
        return self._parent_header.get()
    except LookupError:
        # No value was set in this context; try the per-thread mapping.
        pass
    try:
        # thread-specific; `get_ident()` is the identifier that
        # `Thread.ident` reports, without building a thread object first.
        return self._thread_parents[threading.get_ident()]
    except KeyError:
        # global (fallback)
        return self._parent_header_global

@parent_header.setter
def parent_header(self, value):
    """Store ``value`` as the global fallback and in the current context."""
    self._parent_header_global = value
    # The token returned by `ContextVar.set` is intentionally dropped:
    # whatever a property setter returns is discarded by the assignment.
    self._parent_header.set(value)
521+
498522
def isatty(self):
499523
"""Return a bool indicating whether this is an 'interactive' stream.
500524
@@ -598,28 +622,28 @@ def _flush(self):
598622
if self.echo is not sys.__stderr__:
599623
print(f"Flush failed: {e}", file=sys.__stderr__)
600624

601-
data = self._flush_buffer()
602-
if data:
603-
# FIXME: this disables Session's fork-safe check,
604-
# since pub_thread is itself fork-safe.
605-
# There should be a better way to do this.
606-
self.session.pid = os.getpid()
607-
content = {"name": self.name, "text": data}
608-
msg = self.session.msg("stream", content, parent=self.parent_header)
609-
610-
# Each transform either returns a new
611-
# message or None. If None is returned,
612-
# the message has been 'used' and we return.
613-
for hook in self._hooks:
614-
msg = hook(msg)
615-
if msg is None:
616-
return
617-
618-
self.session.send(
619-
self.pub_thread,
620-
msg,
621-
ident=self.topic,
622-
)
625+
for parent, data in self._flush_buffers():
626+
if data:
627+
# FIXME: this disables Session's fork-safe check,
628+
# since pub_thread is itself fork-safe.
629+
# There should be a better way to do this.
630+
self.session.pid = os.getpid()
631+
content = {"name": self.name, "text": data}
632+
msg = self.session.msg("stream", content, parent=parent)
633+
634+
# Each transform either returns a new
635+
# message or None. If None is returned,
636+
# the message has been 'used' and we return.
637+
for hook in self._hooks:
638+
msg = hook(msg)
639+
if msg is None:
640+
return
641+
642+
self.session.send(
643+
self.pub_thread,
644+
msg,
645+
ident=self.topic,
646+
)
623647

624648
def write(self, string: str) -> Optional[int]: # type:ignore[override]
625649
"""Write to current stream after encoding if necessary
@@ -630,6 +654,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override]
630654
number of items from input parameter written to stream.
631655
632656
"""
657+
parent = self.parent_header
633658

634659
if not isinstance(string, str):
635660
msg = f"write() argument must be str, not {type(string)}" # type:ignore[unreachable]
@@ -649,7 +674,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override]
649674
is_child = not self._is_master_process()
650675
# only touch the buffer in the IO thread to avoid races
651676
with self._buffer_lock:
652-
self._buffer.write(string)
677+
self._buffers[frozenset(parent.items())].write(string)
653678
if is_child:
654679
# mp.Pool cannot be trusted to flush promptly (or ever),
655680
# and this helps.
@@ -675,19 +700,20 @@ def writable(self):
675700
"""Test whether the stream is writable."""
676701
return True
677702

678-
def _flush_buffer(self):
703+
def _flush_buffers(self):
679704
"""Clear the current buffers and yield a (parent header, data) pair for each of them."""
680-
buf = self._rotate_buffer()
681-
data = buf.getvalue()
682-
buf.close()
683-
return data
705+
buffers = self._rotate_buffers()
706+
for frozen_parent, buffer in buffers.items():
707+
data = buffer.getvalue()
708+
buffer.close()
709+
yield dict(frozen_parent), data
684710

685-
def _rotate_buffer(self):
711+
def _rotate_buffers(self):
686712
"""Return the current mapping of buffers and replace it with an empty one."""
687713
with self._buffer_lock:
688-
old_buffer = self._buffer
689-
self._buffer = StringIO()
690-
return old_buffer
714+
old_buffers = self._buffers
715+
self._buffers = defaultdict(StringIO)
716+
return old_buffers
691717

692718
@property
693719
def _hooks(self):

ipykernel/ipkernel.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import builtins
5+
import gc
56
import getpass
67
import os
78
import signal
@@ -14,6 +15,7 @@
1415
import comm
1516
from IPython.core import release
1617
from IPython.utils.tokenutil import line_at_cursor, token_at_cursor
18+
from jupyter_client.session import extract_header
1719
from traitlets import Any, Bool, HasTraits, Instance, List, Type, observe, observe_compat
1820
from zmq.eventloop.zmqstream import ZMQStream
1921

@@ -22,6 +24,7 @@
2224
from .compiler import XCachingCompiler
2325
from .debugger import Debugger, _is_debugpy_available
2426
from .eventloops import _use_appnope
27+
from .iostream import OutStream
2528
from .kernelbase import Kernel as KernelBase
2629
from .kernelbase import _accepts_parameters
2730
from .zmqshell import ZMQInteractiveShell
@@ -66,6 +69,10 @@ def _get_comm_manager(*args, **kwargs):
6669
comm.create_comm = _create_comm
6770
comm.get_comm_manager = _get_comm_manager
6871

72+
import threading
73+
74+
threading_start = threading.Thread.start
75+
6976

7077
class IPythonKernel(KernelBase):
7178
"""The IPython Kernel class."""
@@ -151,6 +158,11 @@ def __init__(self, **kwargs):
151158

152159
appnope.nope()
153160

161+
if hasattr(gc, "callbacks"):
162+
# while `gc.callbacks` exists since Python 3.3, pypy does not
163+
# implement it even as of 3.9.
164+
gc.callbacks.append(self._clean_thread_parent_frames)
165+
154166
help_links = List(
155167
[
156168
{
@@ -341,6 +353,12 @@ def set_sigint_result():
341353
# restore the previous sigint handler
342354
signal.signal(signal.SIGINT, save_sigint)
343355

356+
async def execute_request(self, stream, ident, parent):
    """Handle an execute request, tying new threads' output to its cell.

    The header extracted from ``parent`` is handed to
    ``_associate_identity_of_new_threads_with`` before the request is
    delegated to the base-class implementation.
    """
    self._associate_identity_of_new_threads_with(extract_header(parent))
    await super().execute_request(stream, ident, parent)
361+
344362
async def do_execute(
345363
self,
346364
code,
@@ -706,6 +724,58 @@ def do_clear(self):
706724
self.shell.reset(False)
707725
return dict(status="ok")
708726

727+
def _associate_identity_of_new_threads_with(self, parent_header):
    """Associate threads started after this call with ``parent_header``.

    ``threading.Thread.start`` is replaced by a wrapper which makes every
    newly started thread record its identity in the ``_thread_parents``
    mapping of ``self._stdout`` and ``self._stderr``. This allows the
    outputs of the thread to be directed to the cell which started it.

    Streams which are not instances of `OutStream` are left untouched.
    """
    stdout = self._stdout
    stderr = self._stderr

    def start_closure(thread: threading.Thread):
        """Wrap `threading.Thread.start` to intercept thread identity.

        This is needed because there is no "start" hook yet, but there
        might be one in the future: https://bugs.python.org/issue14073
        """
        run = thread.run

        def run_with_parent():
            # Executed inside the new thread before its real `run()`:
            # registering here (rather than after `start()` returns in the
            # starting thread) guarantees the mapping exists before the
            # thread can write anything.
            identity = threading.get_ident()
            for stream in (stdout, stderr):
                if isinstance(stream, OutStream):
                    stream._thread_parents[identity] = parent_header
            run()

        thread.run = run_with_parent  # type:ignore[method-assign]
        threading_start(thread)

    threading.Thread.start = start_closure  # type:ignore[method-assign]
752+
753+
def _clean_thread_parent_frames(
754+
self, phase: t.Literal["start", "stop"], info: t.Dict[str, t.Any]
755+
):
756+
"""Clean parent frames of threads which are no longer running.
757+
This is meant to be invoked by garbage collector callback hook.
758+
759+
The implementation enumerates the threads because there is no "exit" hook yet,
760+
but there might be one in the future: https://bugs.python.org/issue14073
761+
762+
This is a no-op if the `self._stdout` and `self._stderr` are not
763+
sub-classes of `OutStream`.
764+
"""
765+
# Only run before the garbage collector starts
766+
if phase != "start":
767+
return
768+
active_threads = {thread.ident for thread in threading.enumerate()}
769+
for stream in [self._stdout, self._stderr]:
770+
if isinstance(stream, OutStream):
771+
thread_parents = stream._thread_parents
772+
for identity in list(thread_parents.keys()):
773+
if identity not in active_threads:
774+
try:
775+
del thread_parents[identity]
776+
except KeyError:
777+
pass
778+
709779

710780
# This exists only for backwards compatibility - use IPythonKernel instead
711781

ipykernel/kernelbase.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from ipykernel.jsonutil import json_clean
6262

6363
from ._version import kernel_protocol_version
64+
from .iostream import OutStream
6465

6566

6667
def _accepts_parameters(meth, param_names):
@@ -272,6 +273,13 @@ def _parent_header(self):
272273
def __init__(self, **kwargs):
273274
"""Initialize the kernel."""
274275
super().__init__(**kwargs)
276+
277+
# Kernel application may swap stdout and stderr to OutStream,
278+
# which is the case in `IPKernelApp.init_io`, hence `sys.stdout`
279+
# can already be different from TextIO at initialization time.
280+
self._stdout: OutStream | t.TextIO = sys.stdout
281+
self._stderr: OutStream | t.TextIO = sys.stderr
282+
275283
# Build dict of handlers for message types
276284
self.shell_handlers = {}
277285
for msg_type in self.msg_types:

0 commit comments

Comments
 (0)