5
5
6
6
import asyncio
7
7
import atexit
8
+ import contextvars
8
9
import io
9
10
import os
10
11
import sys
11
12
import threading
12
13
import traceback
13
14
import warnings
14
15
from binascii import b2a_hex
15
- from collections import deque
16
+ from collections import defaultdict , deque
16
17
from io import StringIO , TextIOBase
17
18
from threading import local
18
19
from typing import Any , Callable , Deque , Dict , Optional
@@ -412,7 +413,7 @@ def __init__(
412
413
name : str {'stderr', 'stdout'}
413
414
the name of the standard stream to replace
414
415
pipe : object
415
- the pip object
416
+ the pipe object
416
417
echo : bool
417
418
whether to echo output
418
419
watchfd : bool (default, True)
@@ -446,13 +447,16 @@ def __init__(
446
447
self .pub_thread = pub_thread
447
448
self .name = name
448
449
self .topic = b"stream." + name .encode ()
449
- self .parent_header = {}
450
+ self ._parent_header : Dict [str , Any ] = contextvars .ContextVar ("parent_header" )
451
+ self ._parent_header .set ({})
452
+ self ._thread_parents = {}
453
+ self ._parent_header_global = {}
450
454
self ._master_pid = os .getpid ()
451
455
self ._flush_pending = False
452
456
self ._subprocess_flush_pending = False
453
457
self ._io_loop = pub_thread .io_loop
454
458
self ._buffer_lock = threading .RLock ()
455
- self ._buffer = StringIO ( )
459
+ self ._buffers = defaultdict ( StringIO )
456
460
self .echo = None
457
461
self ._isatty = bool (isatty )
458
462
self ._should_watch = False
@@ -495,6 +499,24 @@ def __init__(
495
499
msg = "echo argument must be a file-like object"
496
500
raise ValueError (msg )
497
501
502
+ @property
503
+ def parent_header (self ):
504
+ try :
505
+ # asyncio-specific
506
+ return self ._parent_header .get ()
507
+ except LookupError :
508
+ try :
509
+ # thread-specific
510
+ return self ._thread_parents [threading .current_thread ().ident ]
511
+ except KeyError :
512
+ # global (fallback)
513
+ return self ._parent_header_global
514
+
515
+ @parent_header .setter
516
+ def parent_header (self , value ):
517
+ self ._parent_header_global = value
518
+ return self ._parent_header .set (value )
519
+
498
520
def isatty (self ):
499
521
"""Return a bool indicating whether this is an 'interactive' stream.
500
522
@@ -598,28 +620,28 @@ def _flush(self):
598
620
if self .echo is not sys .__stderr__ :
599
621
print (f"Flush failed: { e } " , file = sys .__stderr__ )
600
622
601
- data = self ._flush_buffer ()
602
- if data :
603
- # FIXME: this disables Session's fork-safe check,
604
- # since pub_thread is itself fork-safe.
605
- # There should be a better way to do this.
606
- self .session .pid = os .getpid ()
607
- content = {"name" : self .name , "text" : data }
608
- msg = self .session .msg ("stream" , content , parent = self . parent_header )
609
-
610
- # Each transform either returns a new
611
- # message or None. If None is returned,
612
- # the message has been 'used' and we return.
613
- for hook in self ._hooks :
614
- msg = hook (msg )
615
- if msg is None :
616
- return
617
-
618
- self .session .send (
619
- self .pub_thread ,
620
- msg ,
621
- ident = self .topic ,
622
- )
623
+ for parent , data in self ._flush_buffers ():
624
+ if data :
625
+ # FIXME: this disables Session's fork-safe check,
626
+ # since pub_thread is itself fork-safe.
627
+ # There should be a better way to do this.
628
+ self .session .pid = os .getpid ()
629
+ content = {"name" : self .name , "text" : data }
630
+ msg = self .session .msg ("stream" , content , parent = parent )
631
+
632
+ # Each transform either returns a new
633
+ # message or None. If None is returned,
634
+ # the message has been 'used' and we return.
635
+ for hook in self ._hooks :
636
+ msg = hook (msg )
637
+ if msg is None :
638
+ return
639
+
640
+ self .session .send (
641
+ self .pub_thread ,
642
+ msg ,
643
+ ident = self .topic ,
644
+ )
623
645
624
646
def write (self , string : str ) -> Optional [int ]: # type:ignore[override]
625
647
"""Write to current stream after encoding if necessary
@@ -630,6 +652,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override]
630
652
number of items from input parameter written to stream.
631
653
632
654
"""
655
+ parent = self .parent_header
633
656
634
657
if not isinstance (string , str ):
635
658
msg = f"write() argument must be str, not { type (string )} " # type:ignore[unreachable]
@@ -649,7 +672,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override]
649
672
is_child = not self ._is_master_process ()
650
673
# only touch the buffer in the IO thread to avoid races
651
674
with self ._buffer_lock :
652
- self ._buffer .write (string )
675
+ self ._buffers [ frozenset ( parent . items ())] .write (string )
653
676
if is_child :
654
677
# mp.Pool cannot be trusted to flush promptly (or ever),
655
678
# and this helps.
@@ -675,19 +698,20 @@ def writable(self):
675
698
"""Test whether the stream is writable."""
676
699
return True
677
700
678
- def _flush_buffer (self ):
701
+ def _flush_buffers (self ):
679
702
"""clear the current buffer and return the current buffer data."""
680
- buf = self ._rotate_buffer ()
681
- data = buf .getvalue ()
682
- buf .close ()
683
- return data
703
+ buffers = self ._rotate_buffers ()
704
+ for frozen_parent , buffer in buffers .items ():
705
+ data = buffer .getvalue ()
706
+ buffer .close ()
707
+ yield dict (frozen_parent ), data
684
708
685
- def _rotate_buffer (self ):
709
+ def _rotate_buffers (self ):
686
710
"""Returns the current buffer and replaces it with an empty buffer."""
687
711
with self ._buffer_lock :
688
- old_buffer = self ._buffer
689
- self ._buffer = StringIO ( )
690
- return old_buffer
712
+ old_buffers = self ._buffers
713
+ self ._buffers = defaultdict ( StringIO )
714
+ return old_buffers
691
715
692
716
@property
693
717
def _hooks (self ):
0 commit comments