diff --git a/examples/hetero_conv_tensorboard.py b/examples/hetero_conv_tensorboard.py new file mode 100644 index 000000000000..6a7bf8e97fcc --- /dev/null +++ b/examples/hetero_conv_tensorboard.py @@ -0,0 +1,113 @@ +"""Example: Using HeteroConv with TensorBoard Visualization. + +========================================================== + +This example demonstrates how to use the `jit_trace_friendly` wrapper +to visualize HeteroConv models with TensorBoard, resolving the issue +where torch.jit.trace doesn't support tuple dictionary keys. + +Issue: https://github.com/pyg-team/pytorch_geometric/issues/10421 +""" + +import torch +from torch.utils.tensorboard import SummaryWriter + +import torch_geometric.transforms as T +from torch_geometric.datasets import OGB_MAG +from torch_geometric.nn import GATConv, GCNConv, HeteroConv, Linear, SAGEConv + + +# Define a heterogeneous GNN model +class HeteroGNN(torch.nn.Module): + def __init__(self, hidden_channels, out_channels, num_layers): + super().__init__() + self.convs = torch.nn.ModuleList() + for _ in range(num_layers): + conv = HeteroConv( + { + ('paper', 'cites', 'paper'): + GCNConv(-1, hidden_channels), + ('author', 'writes', 'paper'): + SAGEConv((-1, -1), hidden_channels), + ('paper', 'rev_writes', 'author'): + GATConv((-1, -1), hidden_channels, add_self_loops=False), + }, aggr='sum') + self.convs.append(conv) + self.lin = Linear(hidden_channels, out_channels) + + def forward(self, x_dict, edge_index_dict): + for conv in self.convs: + x_dict = conv(x_dict, edge_index_dict) + x_dict = {key: x.relu() for key, x in x_dict.items()} + return self.lin(x_dict['author']) + + +def main(): + # Load a heterogeneous graph dataset + print("Loading dataset...") + dataset = OGB_MAG(root='./data', preprocess='metapath2vec', + transform=T.ToUndirected()) + data = dataset[0] + + # Instantiate the model and initialize lazy modules + print("Initializing model...") + model = HeteroGNN(hidden_channels=64, out_channels=dataset.num_classes, + num_layers=2) + with torch.no_grad(): + _ = model(data.x_dict, data.edge_index_dict) + + # Prepare list-based inputs for the wrapper + x_list = list(data.x_dict.values()) + x_dict_keys = list(data.x_dict.keys()) + edge_index_list = list(data.edge_index_dict.values()) + edge_index_dict_keys = list(data.edge_index_dict.keys()) + + print(f"\nNode types: {x_dict_keys}") + print(f"Number of edge types: {len(edge_index_dict_keys)}") + + # Create a JIT-friendly wrapper for a single HeteroConv layer + print("\nCreating JIT-friendly wrapper...") + single_conv = model.convs[0] + wrapped_conv = single_conv.jit_trace_friendly(x_dict_keys, + edge_index_dict_keys) + + # Test forward pass + print("Testing forward pass...") + with torch.no_grad(): + out_list = wrapped_conv(x_list, edge_index_list) + print("✅ Forward pass successful!") + print(f" Output list length: {len(out_list)}") + print(f" Output shapes: {[o.shape for o in out_list]}") + + # Visualize with TensorBoard + print("\nAdding graph to TensorBoard...") + writer = SummaryWriter('runs/hetero_conv_example') + try: + writer.add_graph(wrapped_conv, (x_list, edge_index_list)) + print("✅ TensorBoard visualization successful!") + print(" Run 'tensorboard --logdir=runs' to view the graph") + except Exception as e: + print(f"❌ Error: {e}") + finally: + writer.close() + + # Test torch.jit.trace compatibility + print("\nTesting torch.jit.trace compatibility...") + try: + traced_model = torch.jit.trace(wrapped_conv, (x_list, edge_index_list)) + print("✅ torch.jit.trace successful!") + + # Verify traced model works + with torch.no_grad(): + traced_model(x_list, edge_index_list) + print("✅ Traced model forward pass successful!") + except Exception as e: + print(f"❌ Error: {e}") + + print("\n" + "=" * 70) + print("Example completed successfully!") + print("=" * 70) + + +if __name__ == '__main__': + main() diff --git a/test/nn/conv/test_hetero_conv.py b/test/nn/conv/test_hetero_conv.py index a2def84a0b5a..aff8d432719a 100644 --- a/test/nn/conv/test_hetero_conv.py +++ b/test/nn/conv/test_hetero_conv.py @@ -178,6 +178,93 @@ def test_hetero_conv_with_dot_syntax_node_types(): assert out_dict['author'].size() == (30, 64) +def test_hetero_conv_jit_trace_friendly(): + """Test the jit_trace_friendly wrapper for TensorBoard compatibility.""" + data = HeteroData() + data['paper'].x = torch.randn(50, 32) + data['author'].x = torch.randn(30, 64) + data['paper', 'paper'].edge_index = get_random_edge_index(50, 50, 200) + data['paper', 'author'].edge_index = get_random_edge_index(50, 30, 100) + data['author', 'paper'].edge_index = get_random_edge_index(30, 50, 100) + + conv = HeteroConv({ + ('paper', 'to', 'paper'): + GCNConv(-1, 64), + ('author', 'to', 'paper'): + SAGEConv((-1, -1), 64), + ('paper', 'to', 'author'): + GATConv((-1, -1), 64, add_self_loops=False), + }) + + # Initialize lazy modules + _ = conv(data.x_dict, data.edge_index_dict) + + # Prepare list-based inputs + x_list = list(data.x_dict.values()) + x_dict_keys = list(data.x_dict.keys()) + edge_index_list = list(data.edge_index_dict.values()) + edge_index_dict_keys = list(data.edge_index_dict.keys()) + + # Create wrapper + wrapper = conv.jit_trace_friendly(x_dict_keys, edge_index_dict_keys) + assert str(wrapper).startswith('HeteroConvWrapper') + + # Test forward pass + out_list = wrapper(x_list, edge_index_list) + assert isinstance(out_list, list) + assert len(out_list) == 2 # Two destination node types + + # Test that output matches original conv + out_dict = conv(data.x_dict, data.edge_index_dict) + out_keys = sorted(out_dict.keys()) + for i, key in enumerate(out_keys): + assert torch.allclose(out_list[i], out_dict[key]) + + # Test torch.jit.trace compatibility + traced = torch.jit.trace(wrapper, (x_list, edge_index_list)) + traced_out = traced(x_list, edge_index_list) + assert isinstance(traced_out, list) + assert len(traced_out) == len(out_list) + for i in range(len(out_list)): + assert torch.allclose(traced_out[i], out_list[i], atol=1e-6) + + +@withPackage('tensorboard') +def test_hetero_conv_tensorboard(): + """Test TensorBoard add_graph with HeteroConv wrapper.""" + from torch.utils.tensorboard import SummaryWriter + + data = HeteroData() + data['paper'].x = torch.randn(50, 32) + data['author'].x = torch.randn(30, 64) + data['paper', 'paper'].edge_index = get_random_edge_index(50, 50, 200) + data['paper', 'author'].edge_index = get_random_edge_index(50, 30, 100) + + conv = HeteroConv({ + ('paper', 'to', 'paper'): GCNConv(-1, 64), + ('paper', 'to', 'author'): SAGEConv((-1, -1), 64), + }) + + # Initialize lazy modules + _ = conv(data.x_dict, data.edge_index_dict) + + # Prepare list-based inputs + x_list = list(data.x_dict.values()) + x_dict_keys = list(data.x_dict.keys()) + edge_index_list = list(data.edge_index_dict.values()) + edge_index_dict_keys = list(data.edge_index_dict.keys()) + + # Create wrapper and test with TensorBoard + wrapper = conv.jit_trace_friendly(x_dict_keys, edge_index_dict_keys) + + writer = SummaryWriter() + try: + writer.add_graph(wrapper, (x_list, edge_index_list)) + # If we get here without exception, the test passes + finally: + writer.close() + + @withDevice @onlyLinux @withPackage('torch>=2.1.0') diff --git a/torch_geometric/nn/conv/__init__.py b/torch_geometric/nn/conv/__init__.py index 871374060055..e65a059ec715 100644 --- a/torch_geometric/nn/conv/__init__.py +++ b/torch_geometric/nn/conv/__init__.py @@ -52,7 +52,7 @@ from .general_conv import GeneralConv from .hgt_conv import HGTConv from .heat_conv import HEATConv -from .hetero_conv import HeteroConv +from .hetero_conv import HeteroConv, HeteroConvWrapper from .han_conv import HANConv from .lg_conv import LGConv from .ssg_conv import SSGConv @@ -125,6 +125,7 @@ 'HGTConv', 'HEATConv', 'HeteroConv', + 'HeteroConvWrapper', 'HANConv', 'LGConv', 'PointGNNConv', diff --git a/torch_geometric/nn/conv/hetero_conv.py b/torch_geometric/nn/conv/hetero_conv.py index 9f823ea78627..74b575e871db 100644 --- a/torch_geometric/nn/conv/hetero_conv.py +++ b/torch_geometric/nn/conv/hetero_conv.py @@ -50,6 +50,28 @@ class HeteroConv(torch.nn.Module): print(list(out_dict.keys())) >>> ['paper', 'author'] + For TensorBoard visualization or :meth:`torch.jit.trace` compatibility, + use the :meth:`jit_trace_friendly` method: + + .. code-block:: python + + from torch.utils.tensorboard import SummaryWriter + + # Prepare list-based inputs + x_list = list(x_dict.values()) + x_dict_keys = list(x_dict.keys()) + edge_index_list = list(edge_index_dict.values()) + edge_index_dict_keys = list(edge_index_dict.keys()) + + # Create JIT-friendly wrapper + wrapper = hetero_conv.jit_trace_friendly(x_dict_keys, + edge_index_dict_keys) + + # Now compatible with TensorBoard and torch.jit.trace + writer = SummaryWriter() + writer.add_graph(wrapper, (x_list, edge_index_list)) + writer.close() + Args: convs (Dict[Tuple[str, str, str], MessagePassing]): A dictionary holding a bipartite @@ -168,5 +190,115 @@ def forward( return out_dict + def jit_trace_friendly( + self, + x_dict_keys: List[NodeType], + edge_index_dict_keys: List[EdgeType], + ) -> 'HeteroConvWrapper': + r"""Returns a JIT trace-friendly wrapper for this module. + + This method creates a wrapper that accepts lists of tensors instead + of dictionaries with tuple keys, making it compatible with + :meth:`torch.jit.trace` and + :meth:`torch.utils.tensorboard.SummaryWriter.add_graph`. + + Args: + x_dict_keys (List[str]): The ordered list of node types + corresponding to the node feature tensors. + edge_index_dict_keys (List[Tuple[str, str, str]]): The ordered + list of edge types corresponding to the edge index tensors. + + Returns: + HeteroConvWrapper: A wrapper module that can be traced by + :meth:`torch.jit.trace`. + + Example: + >>> from torch_geometric.nn import HeteroConv, GCNConv + >>> conv = HeteroConv({ + ... ('paper', 'cites', 'paper'): GCNConv(-1, 64), + ... ('author', 'writes', 'paper'): GCNConv((-1, -1), 64), + ... }) + >>> x_dict_keys = ['paper', 'author'] + >>> edge_index_dict_keys = [('paper', 'cites', 'paper'), + ... ('author', 'writes', 'paper')] + >>> wrapper = conv.jit_trace_friendly(x_dict_keys, + ... edge_index_dict_keys) + >>> # Now you can use torch.jit.trace or TensorBoard's add_graph + >>> traced = torch.jit.trace(wrapper, (x_list, edge_index_list)) + """ + return HeteroConvWrapper(self, x_dict_keys, edge_index_dict_keys) + def __repr__(self) -> str: return f'{self.__class__.__name__}(num_relations={len(self.convs)})' + + +class HeteroConvWrapper(torch.nn.Module): + r"""A wrapper for :class:`HeteroConv` that is compatible with + :meth:`torch.jit.trace`. + + This wrapper converts list-based inputs to dictionary-based inputs + internally, allowing the module to be traced by + :meth:`torch.jit.trace` and visualized with + :meth:`torch.utils.tensorboard.SummaryWriter.add_graph`. + + .. note:: + This class is typically created via + :meth:`HeteroConv.jit_trace_friendly` and should not be instantiated + directly. + + Args: + hetero_conv (HeteroConv): The :class:`HeteroConv` module to wrap. + x_dict_keys (List[str]): The ordered list of node types. + edge_index_dict_keys (List[Tuple[str, str, str]]): The ordered list + of edge types. + """ + def __init__( + self, + hetero_conv: HeteroConv, + x_dict_keys: List[NodeType], + edge_index_dict_keys: List[EdgeType], + ): + super().__init__() + self.hetero_conv = hetero_conv + self.x_dict_keys = x_dict_keys + self.edge_index_dict_keys = edge_index_dict_keys + + def forward( + self, + x_list: List[Tensor], + edge_index_list: List[Tensor], + ) -> List[Tensor]: + r"""Forward pass that converts lists to dicts and back. + + Args: + x_list (List[Tensor]): List of node feature tensors, ordered + according to :attr:`x_dict_keys`. + edge_index_list (List[Tensor]): List of edge index tensors, + ordered according to :attr:`edge_index_dict_keys`. + + Returns: + List[Tensor]: List of output node feature tensors, ordered + according to the unique destination node types in + :attr:`edge_index_dict_keys`. + """ + # Reconstruct dictionaries from lists + x_dict = {key: tensor for key, tensor in zip(self.x_dict_keys, x_list)} + edge_index_dict = { + key: tensor + for key, tensor in zip(self.edge_index_dict_keys, edge_index_list) + } + + # Call the original HeteroConv forward method + out_dict = self.hetero_conv(x_dict, edge_index_dict) + + # Convert output dict back to list (maintain consistent ordering) + out_keys = sorted(out_dict.keys()) + out_list = [out_dict[key] for key in out_keys] + + return out_list + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}(' + f'hetero_conv={self.hetero_conv}, ' + f'num_node_types={len(self.x_dict_keys)}, ' + f'num_edge_types={len(self.edge_index_dict_keys)})')