5
5
6
6
import asyncio
7
7
import atexit
8
+ import contextvars
8
9
import io
9
10
import os
10
11
import sys
11
12
import threading
12
13
import traceback
13
14
import warnings
14
15
from binascii import b2a_hex
15
- from collections import deque
16
+ from collections import defaultdict , deque
16
17
from io import StringIO , TextIOBase
17
18
from threading import local
18
19
from typing import Any , Callable , Deque , Dict , Optional
@@ -412,7 +413,7 @@ def __init__(
412
413
name : str {'stderr', 'stdout'}
413
414
the name of the standard stream to replace
414
415
pipe : object
415
- the pip object
416
+ the pipe object
416
417
echo : bool
417
418
whether to echo output
418
419
watchfd : bool (default, True)
@@ -446,13 +447,18 @@ def __init__(
446
447
self .pub_thread = pub_thread
447
448
self .name = name
448
449
self .topic = b"stream." + name .encode ()
449
- self .parent_header = {}
450
+ self ._parent_header : contextvars .ContextVar [Dict [str , Any ]] = contextvars .ContextVar (
451
+ "parent_header"
452
+ )
453
+ self ._parent_header .set ({})
454
+ self ._thread_parents = {}
455
+ self ._parent_header_global = {}
450
456
self ._master_pid = os .getpid ()
451
457
self ._flush_pending = False
452
458
self ._subprocess_flush_pending = False
453
459
self ._io_loop = pub_thread .io_loop
454
460
self ._buffer_lock = threading .RLock ()
455
- self ._buffer = StringIO ( )
461
+ self ._buffers = defaultdict ( StringIO )
456
462
self .echo = None
457
463
self ._isatty = bool (isatty )
458
464
self ._should_watch = False
@@ -495,6 +501,24 @@ def __init__(
495
501
msg = "echo argument must be a file-like object"
496
502
raise ValueError (msg )
497
503
504
+ @property
505
+ def parent_header (self ):
506
+ try :
507
+ # asyncio-specific
508
+ return self ._parent_header .get ()
509
+ except LookupError :
510
+ try :
511
+ # thread-specific
512
+ return self ._thread_parents [threading .current_thread ().ident ]
513
+ except KeyError :
514
+ # global (fallback)
515
+ return self ._parent_header_global
516
+
517
+ @parent_header .setter
518
+ def parent_header (self , value ):
519
+ self ._parent_header_global = value
520
+ return self ._parent_header .set (value )
521
+
498
522
def isatty (self ):
499
523
"""Return a bool indicating whether this is an 'interactive' stream.
500
524
@@ -598,28 +622,28 @@ def _flush(self):
598
622
if self .echo is not sys .__stderr__ :
599
623
print (f"Flush failed: { e } " , file = sys .__stderr__ )
600
624
601
- data = self ._flush_buffer ()
602
- if data :
603
- # FIXME: this disables Session's fork-safe check,
604
- # since pub_thread is itself fork-safe.
605
- # There should be a better way to do this.
606
- self .session .pid = os .getpid ()
607
- content = {"name" : self .name , "text" : data }
608
- msg = self .session .msg ("stream" , content , parent = self . parent_header )
609
-
610
- # Each transform either returns a new
611
- # message or None. If None is returned,
612
- # the message has been 'used' and we return.
613
- for hook in self ._hooks :
614
- msg = hook (msg )
615
- if msg is None :
616
- return
617
-
618
- self .session .send (
619
- self .pub_thread ,
620
- msg ,
621
- ident = self .topic ,
622
- )
625
+ for parent , data in self ._flush_buffers ():
626
+ if data :
627
+ # FIXME: this disables Session's fork-safe check,
628
+ # since pub_thread is itself fork-safe.
629
+ # There should be a better way to do this.
630
+ self .session .pid = os .getpid ()
631
+ content = {"name" : self .name , "text" : data }
632
+ msg = self .session .msg ("stream" , content , parent = parent )
633
+
634
+ # Each transform either returns a new
635
+ # message or None. If None is returned,
636
+ # the message has been 'used' and we return.
637
+ for hook in self ._hooks :
638
+ msg = hook (msg )
639
+ if msg is None :
640
+ return
641
+
642
+ self .session .send (
643
+ self .pub_thread ,
644
+ msg ,
645
+ ident = self .topic ,
646
+ )
623
647
624
648
def write (self , string : str ) -> Optional [int ]: # type:ignore[override]
625
649
"""Write to current stream after encoding if necessary
@@ -630,6 +654,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override]
630
654
number of items from input parameter written to stream.
631
655
632
656
"""
657
+ parent = self .parent_header
633
658
634
659
if not isinstance (string , str ):
635
660
msg = f"write() argument must be str, not { type (string )} " # type:ignore[unreachable]
@@ -649,7 +674,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override]
649
674
is_child = not self ._is_master_process ()
650
675
# only touch the buffer in the IO thread to avoid races
651
676
with self ._buffer_lock :
652
- self ._buffer .write (string )
677
+ self ._buffers [ frozenset ( parent . items ())] .write (string )
653
678
if is_child :
654
679
# mp.Pool cannot be trusted to flush promptly (or ever),
655
680
# and this helps.
@@ -675,19 +700,20 @@ def writable(self):
675
700
"""Test whether the stream is writable."""
676
701
return True
677
702
678
- def _flush_buffer (self ):
703
+ def _flush_buffers (self ):
679
704
"""clear the current buffer and return the current buffer data."""
680
- buf = self ._rotate_buffer ()
681
- data = buf .getvalue ()
682
- buf .close ()
683
- return data
705
+ buffers = self ._rotate_buffers ()
706
+ for frozen_parent , buffer in buffers .items ():
707
+ data = buffer .getvalue ()
708
+ buffer .close ()
709
+ yield dict (frozen_parent ), data
684
710
685
- def _rotate_buffer (self ):
711
+ def _rotate_buffers (self ):
686
712
"""Returns the current buffer and replaces it with an empty buffer."""
687
713
with self ._buffer_lock :
688
- old_buffer = self ._buffer
689
- self ._buffer = StringIO ( )
690
- return old_buffer
714
+ old_buffers = self ._buffers
715
+ self ._buffers = defaultdict ( StringIO )
716
+ return old_buffers
691
717
692
718
@property
693
719
def _hooks (self ):
0 commit comments