Skip to content

Mapping between runtime and aot intermediate outputs #11624

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 179 additions & 1 deletion devtools/inspector/_inspector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import math
import sys
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union

Expand Down Expand Up @@ -72,6 +73,25 @@ class TimeScale(Enum):
}


class NodeSource(Enum):
    """Identifies which set of intermediate outputs a graph node was built from."""

    # Node created from an entry of the AOT intermediate outputs.
    AOT = 1
    # Node created from an entry of the runtime intermediate outputs.
    RUNTIME = 2


@dataclass
class NodeData:
    """
    A single node in the AOT/runtime debug-handle overlap graph.

    Attributes:
        source: Which side produced the node, NodeSource.AOT or NodeSource.RUNTIME.
        debug_handle: A tuple of one or more integer debug handles identifying the output.
        output: The intermediate output recorded for that debug handle.
    """

    source: NodeSource
    # Variable-length tuple, matching the Tuple[int, ...] keys used by the
    # intermediate-output dictionaries these nodes are built from.
    debug_handle: Tuple[int, ...]
    output: Any


def calculate_time_scale_factor(
source_time_scale: TimeScale, target_time_scale: TimeScale
) -> float:
Expand Down Expand Up @@ -489,7 +509,7 @@ def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...],
"""
Merge overlapping debug handles int a single key
"""
if not intermediate_outputs:
if len(intermediate_outputs) == 0:
return
# Extract and normalize into (start, end, val)
intervals = [(min(key), max(key), val) for key, val in intermediate_outputs.items()]
Expand All @@ -512,3 +532,161 @@ def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...],
intermediate_outputs.clear()
for start, end, val in merged_intermediate_outputs:
intermediate_outputs[tuple(range(start, end + 1))] = val


def _debug_handles_have_overlap(
aot_debug_hanlde: Tuple[int, ...], runtime_debug_handle: Tuple[int, ...]
) -> bool:
"""
Check if the AOT debug handle and the runtime debug handle have any overlap.
"""
aot_set = set(aot_debug_hanlde)
runtime_set = set(runtime_debug_handle)
return len(aot_set.intersection(runtime_set)) > 0


def _combine_debug_hanldes(debug_handles: List[Tuple[int, ...]]) -> Tuple[int, ...]:
"""Combine multiple debug handles into one debug handle"""
combined_debug_handles_set = set()
for debug_handle in debug_handles:
combined_debug_handles_set.update(set(debug_handle))
return tuple(sorted(combined_debug_handles_set))


def _combine_overlapped_intermediate_outputs(
    nodes: List[Tuple[Tuple[int, ...], Any]]
) -> Tuple[Tuple[int, ...], Any]:
    """
    Collapse several overlapping (debug_handle, output) pairs into one pair.

    The resulting debug handle is the combination of every input debug handle,
    and the resulting output is the output of the last pair in ``nodes``.
    """
    merged_debug_handle = _combine_debug_hanldes([handle for handle, _ in nodes])
    # The final entry's output represents the whole combined group.
    _, last_output = nodes[-1]
    return merged_debug_handle, last_output


def _create_debug_handle_overlap_graph(
    aot_intermediate_outputs: Dict[Tuple[int, ...], Any],
    runtime_intermediate_outputs: Dict[Tuple[int, ...], Any],
) -> Tuple[List[NodeData], Dict[int, List[int]]]:
    """
    Build an undirected graph linking AOT and runtime outputs whose debug handles overlap.

    Nodes are listed with all AOT entries first, followed by all runtime entries,
    each in the iteration order of its dictionary. An edge is only ever created
    between an AOT node and a runtime node, never between two nodes of the same source.

    Returns:
        A tuple ``(nodes, edges)`` where:
        - ``nodes`` is the list of NodeData instances.
        - ``edges`` maps every node index to the list of node indices it overlaps with.
    """
    aot_nodes = [
        NodeData(NodeSource.AOT, debug_handle, output)
        for debug_handle, output in aot_intermediate_outputs.items()
    ]
    runtime_nodes = [
        NodeData(NodeSource.RUNTIME, debug_handle, output)
        for debug_handle, output in runtime_intermediate_outputs.items()
    ]
    nodes = aot_nodes + runtime_nodes

    edges = {index: [] for index in range(len(nodes))}
    # Runtime nodes sit after the AOT nodes, so their indices start at this offset.
    runtime_offset = len(aot_nodes)
    for aot_index, aot_node in enumerate(aot_nodes):
        for runtime_index, runtime_node in enumerate(runtime_nodes, runtime_offset):
            if _debug_handles_have_overlap(
                aot_node.debug_handle, runtime_node.debug_handle
            ):
                edges[aot_index].append(runtime_index)
                edges[runtime_index].append(aot_index)
    return (nodes, edges)


def _find_connected_components(
nodes: List[NodeData], edges: Dict[int, List[int]]
) -> List[List[int]]:
"""
Find groups of connected nodes in a graph using DFS.
Parameters:
- nodes: A list of nodes in the graph.
- edges: A dictionary where each key is a node index, and the value is a list
of indices of connected nodes.
Returns:
- A list of connected components, each represented as a list of node indices.
"""
visited = [False] * len(nodes)
connected_components = []

def dfs(node_id, component):
visited[node_id] = True
component.append(node_id)
# Iterate over all neighbors of the current node
for neighbor_node_id in edges[node_id]:
# If a neighbor has not been visited yet, recursively visit it
if not visited[neighbor_node_id]:
dfs(neighbor_node_id, component)

# Perform DFS on all nodes to find connected components
for i in range(len(nodes)):
# If a node has not been visited yet, start a new DFS from it
if not visited[i]:
component = []
dfs(i, component)
# After visiting all reachable nodes, add the current component to the list
connected_components.append(component)
return connected_components


def map_runtime_aot_intermediate_outputs(
    aot_intermediate_outputs: Dict[Tuple[int, ...], Any],
    runtime_intermediate_outputs: Dict[Tuple[int, ...], Any],
) -> Dict[Tuple[Tuple[int, ...], Any], Tuple[Tuple[int, ...], Any]]:
    """
    Pair AOT intermediate outputs with the runtime intermediate outputs they overlap.

    Both input dictionaries are first passed to merge_overlapping_debug_handles,
    which rewrites them in place. Entries from the two sides whose debug handles
    overlap, directly or through a chain of overlaps, are then grouped together,
    and each group is collapsed into one combined AOT entry and one combined
    runtime entry.

    Returns:
        Dict[Tuple[Tuple[int, ...], Any], Tuple[Tuple[int, ...], Any]] - Mapping
        from (combined AOT debug handle, AOT output) to
        (combined runtime debug handle, runtime output). Groups that lack either
        an AOT entry or a runtime entry are left out.
    """
    # Normalize each side so that its own debug handles no longer overlap.
    merge_overlapping_debug_handles(aot_intermediate_outputs)
    merge_overlapping_debug_handles(runtime_intermediate_outputs)

    # Link AOT and runtime entries that share debug handles, then group them.
    nodes, edges = _create_debug_handle_overlap_graph(
        aot_intermediate_outputs, runtime_intermediate_outputs
    )

    aot_runtime_mapping = {}
    for component in _find_connected_components(nodes, edges):
        # Split the group's members by the side they came from.
        aot_entries = []
        runtime_entries = []
        for node_id in component:
            node = nodes[node_id]
            entry = (node.debug_handle, node.output)
            if node.source == NodeSource.AOT:
                aot_entries.append(entry)
            elif node.source == NodeSource.RUNTIME:
                runtime_entries.append(entry)

        # A group with only one side present has nothing to be mapped to.
        if not aot_entries or not runtime_entries:
            continue

        aot_key = _combine_overlapped_intermediate_outputs(aot_entries)
        runtime_value = _combine_overlapped_intermediate_outputs(runtime_entries)
        aot_runtime_mapping[aot_key] = runtime_value

    return aot_runtime_mapping
79 changes: 79 additions & 0 deletions devtools/inspector/tests/inspector_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
find_populated_event,
gen_graphs_from_etrecord,
is_inference_output_equal,
map_runtime_aot_intermediate_outputs,
merge_overlapping_debug_handles,
TimeScale,
)
Expand Down Expand Up @@ -238,6 +239,84 @@ def test_merge_overlapping_debug_handles(self):
self.assertEqual(intermediate_outputs, expected_intermediate_outputs)
self.assertIs(expected_intermediate_outputs[(10, 11, 12)], big_tensor)

def test_map_runtime_aot_intermediate_outputs_empty_inputs(self):
    # Empty AOT and runtime inputs must yield an empty mapping.
    mapping = map_runtime_aot_intermediate_outputs({}, {})
    self.assertEqual(mapping, {})

def test_map_runtime_aot_intermediate_outputs_single_element_tuple(self):
    # Every debug handle is a single-element tuple and matches one-to-one.
    aot = {(0,): 100, (1,): 200, (2,): 300}
    runtime = {(0,): 150, (1,): 250, (2,): 350}
    mapping = map_runtime_aot_intermediate_outputs(aot, runtime)
    self.assertEqual(
        mapping,
        {
            ((0,), 100): ((0,), 150),
            ((1,), 200): ((1,), 250),
            ((2,), 300): ((2,), 350),
        },
    )

def test_map_runtime_aot_intermediate_outputs_exact_match(self):
    # AOT and runtime use exactly the same debug handles.
    aot = {(0, 1): 100, (2, 3): 200, (4, 5): 300}
    runtime = {(0, 1): 150, (2, 3): 200, (4, 5): 300}
    mapping = map_runtime_aot_intermediate_outputs(aot, runtime)
    self.assertEqual(
        mapping,
        {
            ((0, 1), 100): ((0, 1), 150),
            ((2, 3), 200): ((2, 3), 200),
            ((4, 5), 300): ((4, 5), 300),
        },
    )

def test_map_runtime_aot_intermediate_outputs_no_overlaps(self):
    # Nothing is mapped when AOT and runtime debug handles never overlap.
    aot = {(0, 1): 100, (4, 5): 300}
    runtime = {(2, 3): 200, (8, 9): 300}
    mapping = map_runtime_aot_intermediate_outputs(aot, runtime)
    self.assertEqual(mapping, {})

def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime(self):
    # Two AOT debug handles overlap the same runtime debug handle.
    aot = {(0, 1, 2): 100, (3, 4): 300}
    runtime = {(1, 2, 3): 250, (8, 9): 300}
    mapping = map_runtime_aot_intermediate_outputs(aot, runtime)
    self.assertEqual(mapping, {((0, 1, 2, 3, 4), 300): ((1, 2, 3), 250)})

def test_map_runtime_aot_intermediate_outputs_one_aot_to_multiple_runtime(self):
    # A single AOT debug handle overlaps several runtime debug handles.
    aot = {(0, 1, 2, 3, 4): 100, (8, 9): 300}
    runtime = {(0, 1): 150, (2, 3): 200, (4, 5): 300}
    mapping = map_runtime_aot_intermediate_outputs(aot, runtime)
    self.assertEqual(mapping, {((0, 1, 2, 3, 4), 100): ((0, 1, 2, 3, 4, 5), 300)})

def test_map_runtime_aot_intermediate_outputs_complex_chain(self):
    # Overlaps chain across every entry, giving an N-to-N mapping.
    aot = {(1, 2): 100, (3, 4): 200, (5, 6): 300}
    runtime = {(2, 3): 150, (4, 5): 250, (6, 7): 350}
    mapping = map_runtime_aot_intermediate_outputs(aot, runtime)
    self.assertEqual(
        mapping, {((1, 2, 3, 4, 5, 6), 300): ((2, 3, 4, 5, 6, 7), 350)}
    )


def gen_mock_operator_graph_with_expected_map() -> (
Tuple[OperatorGraph, Dict[int, OperatorNode]]
Expand Down
Loading