Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Include/internal/pycore_global_objects_fini_generated.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Include/internal/pycore_global_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ struct _Py_global_strings {
STRUCT_FOR_ID(coro)
STRUCT_FOR_ID(count)
STRUCT_FOR_ID(covariant)
STRUCT_FOR_ID(cpu_time)
STRUCT_FOR_ID(ctx)
STRUCT_FOR_ID(cwd)
STRUCT_FOR_ID(d_parameter_type)
Expand Down
1 change: 1 addition & 0 deletions Include/internal/pycore_runtime_init_generated.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions Include/internal/pycore_unicodeobject_generated.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 15 additions & 1 deletion Lib/profiling/sampling/collector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
from abc import ABC, abstractmethod

# Thread states are plain module-level integers rather than an Enum,
# because enums are slow.
THREAD_STATE_RUNNING = 0
THREAD_STATE_IDLE = 1
THREAD_STATE_GIL_WAIT = 2
THREAD_STATE_UNKNOWN = 3

# Human-readable label for every thread state value defined above.
STATUS = {
    THREAD_STATE_RUNNING: "running",
    THREAD_STATE_IDLE: "idle",
    THREAD_STATE_GIL_WAIT: "gil_wait",
    THREAD_STATE_UNKNOWN: "unknown",
}

class Collector(ABC):
@abstractmethod
Expand All @@ -10,10 +22,12 @@ def collect(self, stack_frames):
def export(self, filename):
"""Export collected data to a file."""

def _iter_all_frames(self, stack_frames):
def _iter_all_frames(self, stack_frames, skip_idle=False):
    """Yield the non-empty frame stack of each thread in every interpreter.

    *stack_frames* is an iterable of interpreter records; each record
    exposes ``threads``, and each thread exposes ``status`` and
    ``frame_info``.

    When *skip_idle* is true, every thread whose status is not
    THREAD_STATE_RUNNING is left out (not only idle ones).  Threads with
    an empty ``frame_info`` are always left out.
    """
    for interp in stack_frames:
        for thread in interp.threads:
            # The status is only consulted when filtering was requested.
            if skip_idle and thread.status != THREAD_STATE_RUNNING:
                continue
            if stack := thread.frame_info:
                yield stack
5 changes: 3 additions & 2 deletions Lib/profiling/sampling/pstats_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


class PstatsCollector(Collector):
def __init__(self, sample_interval_usec):
def __init__(self, sample_interval_usec, *, skip_idle=False):
self.result = collections.defaultdict(
lambda: dict(total_rec_calls=0, direct_calls=0, cumulative_calls=0)
)
Expand All @@ -14,6 +14,7 @@ def __init__(self, sample_interval_usec):
self.callers = collections.defaultdict(
lambda: collections.defaultdict(int)
)
self.skip_idle = skip_idle

def _process_frames(self, frames):
"""Process a single thread's frame stack."""
Expand All @@ -40,7 +41,7 @@ def _process_frames(self, frames):
self.callers[callee][caller] += 1

def collect(self, stack_frames):
for frames in self._iter_all_frames(stack_frames):
for frames in self._iter_all_frames(stack_frames, skip_idle=self.skip_idle):
self._process_frames(frames)

def export(self, filename):
Expand Down
30 changes: 23 additions & 7 deletions Lib/profiling/sampling/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,18 +120,18 @@ def _run_with_sync(original_cmd):


class SampleProfiler:
def __init__(self, pid, sample_interval_usec, all_threads):
def __init__(self, pid, sample_interval_usec, all_threads, *, cpu_time=False):
self.pid = pid
self.sample_interval_usec = sample_interval_usec
self.all_threads = all_threads
if _FREE_THREADED_BUILD:
self.unwinder = _remote_debugging.RemoteUnwinder(
self.pid, all_threads=self.all_threads
self.pid, all_threads=self.all_threads, cpu_time=cpu_time
)
else:
only_active_threads = bool(self.all_threads)
self.unwinder = _remote_debugging.RemoteUnwinder(
self.pid, only_active_thread=only_active_threads
self.pid, only_active_thread=only_active_threads, cpu_time=cpu_time
)
# Track sample intervals and total sample count
self.sample_intervals = deque(maxlen=100)
Expand Down Expand Up @@ -596,21 +596,22 @@ def sample(
show_summary=True,
output_format="pstats",
realtime_stats=False,
skip_idle=False,
):
profiler = SampleProfiler(
pid, sample_interval_usec, all_threads=all_threads
pid, sample_interval_usec, all_threads=all_threads, cpu_time=skip_idle
)
profiler.realtime_stats = realtime_stats

collector = None
match output_format:
case "pstats":
collector = PstatsCollector(sample_interval_usec)
collector = PstatsCollector(sample_interval_usec, skip_idle=skip_idle)
case "collapsed":
collector = CollapsedStackCollector()
collector = CollapsedStackCollector(skip_idle=skip_idle)
filename = filename or f"collapsed.{pid}.txt"
case "flamegraph":
collector = FlamegraphCollector()
collector = FlamegraphCollector(skip_idle=skip_idle)
filename = filename or f"flamegraph.{pid}.html"
case _:
raise ValueError(f"Invalid output format: {output_format}")
Expand Down Expand Up @@ -660,6 +661,7 @@ def wait_for_process_and_sample(pid, sort_value, args):
filename = args.outfile
if not filename and args.format == "collapsed":
filename = f"collapsed.{pid}.txt"
# Only the "cpu" mode skips threads that are not running.
skip_idle = args.mode == "cpu"

sample(
pid,
Expand All @@ -672,6 +674,7 @@ def wait_for_process_and_sample(pid, sort_value, args):
show_summary=not args.no_summary,
output_format=args.format,
realtime_stats=args.realtime_stats,
skip_idle=skip_idle,
)


Expand Down Expand Up @@ -726,6 +729,15 @@ def main():
help="Print real-time sampling statistics (Hz, mean, min, max, stdev) during profiling",
)

# Mode options
mode_group = parser.add_argument_group("Mode options")
mode_group.add_argument(
    "--mode",
    choices=["wall", "cpu"],
    # argparse does not validate the default against ``choices``, so the
    # default has to be kept in sync by hand: it must be a value the user
    # could also pass explicitly.
    default="wall",
    help="Sampling mode: wall (wall-time, default, skip_idle=False) or cpu (cpu-time, skip_idle=True)",
)

# Output format selection
output_group = parser.add_argument_group("Output options")
output_format = output_group.add_mutually_exclusive_group()
Expand Down Expand Up @@ -850,6 +862,9 @@ def main():
elif target_count > 1:
parser.error("only one target type can be specified: -p/--pid, -m/--module, or script")

# Set skip_idle based on mode: only "cpu" skips threads that are not running.
skip_idle = args.mode == "cpu"

if args.pid:
sample(
args.pid,
Expand All @@ -862,6 +877,7 @@ def main():
show_summary=not args.no_summary,
output_format=args.format,
realtime_stats=args.realtime_stats,
skip_idle=skip_idle,
)
elif args.module or args.args:
if args.module:
Expand Down
9 changes: 5 additions & 4 deletions Lib/profiling/sampling/stack_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@


class StackTraceCollector(Collector):
def __init__(self):
def __init__(self, *, skip_idle=False):
self.call_trees = []
self.function_samples = collections.defaultdict(int)
self.skip_idle = skip_idle

def _process_frames(self, frames):
"""Process a single thread's frame stack."""
Expand All @@ -28,7 +29,7 @@ def _process_frames(self, frames):
self.function_samples[frame] += 1

def collect(self, stack_frames):
for frames in self._iter_all_frames(stack_frames):
for frames in self._iter_all_frames(stack_frames, skip_idle=self.skip_idle):
self._process_frames(frames)


Expand All @@ -49,8 +50,8 @@ def export(self, filename):


class FlamegraphCollector(StackTraceCollector):
def __init__(self):
super().__init__()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.stats = {}

def set_stats(self, sample_interval_usec, duration_sec, sample_rate, error_rate=None):
Expand Down
109 changes: 109 additions & 0 deletions Lib/test/test_external_inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1670,6 +1670,115 @@ def test_unsupported_platform_error(self):
str(cm.exception)
)

class TestDetectionOfThreadStatus(unittest.TestCase):
    """Check the per-thread status reported by ``RemoteUnwinder``.

    A child process starts a sleeping thread and a busy thread, and reports
    their native thread ids over a socket.  The test then samples the child
    with ``cpu_time=True`` and compares the status reported for each id.
    """

    @unittest.skipIf(
        sys.platform not in ("linux", "darwin", "win32"),
        "Test only runs on Linux, macOS, or Windows",
    )
    @unittest.skipIf(sys.platform == "android", "Android raises Linux-specific exception")
    def test_thread_status_detection(self):
        """The sleeping thread must be idle (1) and the busy thread running (0)."""
        port = find_unused_port()
        # NOTE(review): the trailing time.sleep(0.5) in busy() is placed
        # after the infinite loop (so it is never reached and the thread
        # keeps spinning) -- confirm against the original indentation.
        script = textwrap.dedent(
            f"""\
            import time, sys, socket, threading
            import os

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            def sleeper():
                tid = threading.get_native_id()
                sock.sendall(f'ready:sleeper:{{tid}}\\n'.encode())
                time.sleep(10000)

            def busy():
                tid = threading.get_native_id()
                sock.sendall(f'ready:busy:{{tid}}\\n'.encode())
                x = 0
                while True:
                    x = x + 1
                time.sleep(0.5)

            t1 = threading.Thread(target=sleeper)
            t2 = threading.Thread(target=busy)
            t1.start()
            t2.start()
            sock.sendall(b'ready:main\\n')
            t1.join()
            t2.join()
            sock.close()
            """
        )
        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)
            server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            server_socket.bind(("localhost", port))
            server_socket.settimeout(SHORT_TIMEOUT)
            server_socket.listen(1)

            script_name = _make_test_script(script_dir, "thread_status_script", script)
            client_socket = None
            # Pre-bind so the cleanup in ``finally`` is safe even when
            # Popen() itself raises.
            p = None
            try:
                p = subprocess.Popen([sys.executable, script_name])
                client_socket, _ = server_socket.accept()
                server_socket.close()
                response = b""
                sleeper_tid = None
                busy_tid = None
                while True:
                    chunk = client_socket.recv(1024)
                    if not chunk:
                        # recv() returns b"" once the peer has closed the
                        # connection; without this check the loop would
                        # never terminate.
                        self.fail(
                            "Child process closed the connection before "
                            "all threads reported ready"
                        )
                    response += chunk
                    if (
                        b"ready:main" in response
                        and b"ready:sleeper" in response
                        and b"ready:busy" in response
                    ):
                        # Parse TIDs from the response
                        for line in response.split(b"\n"):
                            if line.startswith(b"ready:sleeper:"):
                                try:
                                    sleeper_tid = int(line.split(b":")[-1])
                                except ValueError:
                                    # Left as None; reported by the
                                    # assertIsNotNone check below.
                                    pass
                            elif line.startswith(b"ready:busy:"):
                                try:
                                    busy_tid = int(line.split(b":")[-1])
                                except ValueError:
                                    # Left as None; reported by the
                                    # assertIsNotNone check below.
                                    pass
                        break

                attempts = 10
                try:
                    unwinder = RemoteUnwinder(p.pid, all_threads=True, cpu_time=True)
                    for _ in range(attempts):
                        traces = unwinder.get_stack_trace()
                        # Check if any thread is running (status 0)
                        if any(
                            thread_info.status == 0
                            for interpreter_info in traces
                            for thread_info in interpreter_info.threads
                        ):
                            break
                        time.sleep(0.5)  # Give a bit of time to let threads settle
                except PermissionError:
                    self.skipTest(
                        "Insufficient permissions to read the stack trace"
                    )

                # Map each sampled thread id to its reported status
                statuses = {}
                for interpreter_info in traces:
                    for thread_info in interpreter_info.threads:
                        statuses[thread_info.thread_id] = thread_info.status

                self.assertIsNotNone(sleeper_tid, "Sleeper thread id not received")
                self.assertIsNotNone(busy_tid, "Busy thread id not received")
                self.assertIn(sleeper_tid, statuses, "Sleeper tid not found in sampled threads")
                self.assertIn(busy_tid, statuses, "Busy tid not found in sampled threads")
                self.assertEqual(statuses[sleeper_tid], 1, "Sleeper thread should be idle (1)")
                self.assertEqual(statuses[busy_tid], 0, "Busy thread should be running (0)")

            finally:
                if client_socket is not None:
                    client_socket.close()
                # Closing an already-closed socket is a no-op; this covers
                # the path where accept() raised before the early close.
                server_socket.close()
                if p is not None:
                    p.terminate()
                    p.wait(timeout=SHORT_TIMEOUT)

# Run the unittest test runner when this file is executed directly.
if __name__ == "__main__":
unittest.main()
Loading
Loading