Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions examples/hetero_conv_tensorboard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""Example: Using HeteroConv with TensorBoard Visualization.

==========================================================

This example demonstrates how to use the `jit_trace_friendly` wrapper
to visualize HeteroConv models with TensorBoard, resolving the issue
where torch.jit.trace doesn't support tuple dictionary keys.

Issue: https://github.com/pyg-team/pytorch_geometric/issues/10421
"""

import torch
from torch.utils.tensorboard import SummaryWriter

import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import GATConv, GCNConv, HeteroConv, Linear, SAGEConv


# Define a heterogeneous GNN model
class HeteroGNN(torch.nn.Module):
def __init__(self, hidden_channels, out_channels, num_layers):
super().__init__()
self.convs = torch.nn.ModuleList()
for _ in range(num_layers):
conv = HeteroConv(
{
('paper', 'cites', 'paper'):
GCNConv(-1, hidden_channels),
('author', 'writes', 'paper'):
SAGEConv((-1, -1), hidden_channels),
('paper', 'rev_writes', 'author'):
GATConv((-1, -1), hidden_channels, add_self_loops=False),
}, aggr='sum')
self.convs.append(conv)
self.lin = Linear(hidden_channels, out_channels)

def forward(self, x_dict, edge_index_dict):
for conv in self.convs:
x_dict = conv(x_dict, edge_index_dict)
x_dict = {key: x.relu() for key, x in x_dict.items()}
return self.lin(x_dict['author'])


def main():
# Load a heterogeneous graph dataset
print("Loading dataset...")
dataset = OGB_MAG(root='./data', preprocess='metapath2vec',
transform=T.ToUndirected())
data = dataset[0]

# Instantiate the model and initialize lazy modules
print("Initializing model...")
model = HeteroGNN(hidden_channels=64, out_channels=dataset.num_classes,
num_layers=2)
with torch.no_grad():
_ = model(data.x_dict, data.edge_index_dict)

# Prepare list-based inputs for the wrapper
x_list = list(data.x_dict.values())
x_dict_keys = list(data.x_dict.keys())
edge_index_list = list(data.edge_index_dict.values())
edge_index_dict_keys = list(data.edge_index_dict.keys())

print(f"\nNode types: {x_dict_keys}")
print(f"Number of edge types: {len(edge_index_dict_keys)}")

# Create a JIT-friendly wrapper for a single HeteroConv layer
print("\nCreating JIT-friendly wrapper...")
single_conv = model.convs[0]
wrapped_conv = single_conv.jit_trace_friendly(x_dict_keys,
edge_index_dict_keys)

# Test forward pass
print("Testing forward pass...")
with torch.no_grad():
out_list = wrapped_conv(x_list, edge_index_list)
print("✅ Forward pass successful!")
print(f" Output list length: {len(out_list)}")
print(f" Output shapes: {[o.shape for o in out_list]}")

# Visualize with TensorBoard
print("\nAdding graph to TensorBoard...")
writer = SummaryWriter('runs/hetero_conv_example')
try:
writer.add_graph(wrapped_conv, (x_list, edge_index_list))
print("✅ TensorBoard visualization successful!")
print(" Run 'tensorboard --logdir=runs' to view the graph")
except Exception as e:
print(f"❌ Error: {e}")
finally:
writer.close()

# Test torch.jit.trace compatibility
print("\nTesting torch.jit.trace compatibility...")
try:
traced_model = torch.jit.trace(wrapped_conv, (x_list, edge_index_list))
print("✅ torch.jit.trace successful!")

# Verify traced model works
with torch.no_grad():
traced_model(x_list, edge_index_list)
print("✅ Traced model forward pass successful!")
except Exception as e:
print(f"❌ Error: {e}")

print("\n" + "=" * 70)
print("Example completed successfully!")
print("=" * 70)


if __name__ == '__main__':
main()
87 changes: 87 additions & 0 deletions test/nn/conv/test_hetero_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,93 @@ def test_hetero_conv_with_dot_syntax_node_types():
assert out_dict['author'].size() == (30, 64)


def test_hetero_conv_jit_trace_friendly():
"""Test the jit_trace_friendly wrapper for TensorBoard compatibility."""
data = HeteroData()
data['paper'].x = torch.randn(50, 32)
data['author'].x = torch.randn(30, 64)
data['paper', 'paper'].edge_index = get_random_edge_index(50, 50, 200)
data['paper', 'author'].edge_index = get_random_edge_index(50, 30, 100)
data['author', 'paper'].edge_index = get_random_edge_index(30, 50, 100)

conv = HeteroConv({
('paper', 'to', 'paper'):
GCNConv(-1, 64),
('author', 'to', 'paper'):
SAGEConv((-1, -1), 64),
('paper', 'to', 'author'):
GATConv((-1, -1), 64, add_self_loops=False),
})

# Initialize lazy modules
_ = conv(data.x_dict, data.edge_index_dict)

# Prepare list-based inputs
x_list = list(data.x_dict.values())
x_dict_keys = list(data.x_dict.keys())
edge_index_list = list(data.edge_index_dict.values())
edge_index_dict_keys = list(data.edge_index_dict.keys())

# Create wrapper
wrapper = conv.jit_trace_friendly(x_dict_keys, edge_index_dict_keys)
assert str(wrapper).startswith('HeteroConvWrapper')

# Test forward pass
out_list = wrapper(x_list, edge_index_list)
assert isinstance(out_list, list)
assert len(out_list) == 2 # Two destination node types

# Test that output matches original conv
out_dict = conv(data.x_dict, data.edge_index_dict)
out_keys = sorted(out_dict.keys())
for i, key in enumerate(out_keys):
assert torch.allclose(out_list[i], out_dict[key])

# Test torch.jit.trace compatibility
traced = torch.jit.trace(wrapper, (x_list, edge_index_list))
traced_out = traced(x_list, edge_index_list)
assert isinstance(traced_out, list)
assert len(traced_out) == len(out_list)
for i in range(len(out_list)):
assert torch.allclose(traced_out[i], out_list[i], atol=1e-6)


@withPackage('tensorboard')
def test_hetero_conv_tensorboard():
"""Test TensorBoard add_graph with HeteroConv wrapper."""
from torch.utils.tensorboard import SummaryWriter

data = HeteroData()
data['paper'].x = torch.randn(50, 32)
data['author'].x = torch.randn(30, 64)
data['paper', 'paper'].edge_index = get_random_edge_index(50, 50, 200)
data['paper', 'author'].edge_index = get_random_edge_index(50, 30, 100)

conv = HeteroConv({
('paper', 'to', 'paper'): GCNConv(-1, 64),
('paper', 'to', 'author'): SAGEConv((-1, -1), 64),
})

# Initialize lazy modules
_ = conv(data.x_dict, data.edge_index_dict)

# Prepare list-based inputs
x_list = list(data.x_dict.values())
x_dict_keys = list(data.x_dict.keys())
edge_index_list = list(data.edge_index_dict.values())
edge_index_dict_keys = list(data.edge_index_dict.keys())

# Create wrapper and test with TensorBoard
wrapper = conv.jit_trace_friendly(x_dict_keys, edge_index_dict_keys)

writer = SummaryWriter()
try:
writer.add_graph(wrapper, (x_list, edge_index_list))
# If we get here without exception, the test passes
finally:
writer.close()


@withDevice
@onlyLinux
@withPackage('torch>=2.1.0')
Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/nn/conv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from .general_conv import GeneralConv
from .hgt_conv import HGTConv
from .heat_conv import HEATConv
from .hetero_conv import HeteroConv
from .hetero_conv import HeteroConv, HeteroConvWrapper
from .han_conv import HANConv
from .lg_conv import LGConv
from .ssg_conv import SSGConv
Expand Down Expand Up @@ -125,6 +125,7 @@
'HGTConv',
'HEATConv',
'HeteroConv',
'HeteroConvWrapper',
'HANConv',
'LGConv',
'PointGNNConv',
Expand Down
132 changes: 132 additions & 0 deletions torch_geometric/nn/conv/hetero_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,28 @@ class HeteroConv(torch.nn.Module):
print(list(out_dict.keys()))
>>> ['paper', 'author']

For TensorBoard visualization or :meth:`torch.jit.trace` compatibility,
use the :meth:`jit_trace_friendly` method:

.. code-block:: python

from torch.utils.tensorboard import SummaryWriter

# Prepare list-based inputs
x_list = list(x_dict.values())
x_dict_keys = list(x_dict.keys())
edge_index_list = list(edge_index_dict.values())
edge_index_dict_keys = list(edge_index_dict.keys())

# Create JIT-friendly wrapper
wrapper = hetero_conv.jit_trace_friendly(x_dict_keys,
edge_index_dict_keys)

# Now compatible with TensorBoard and torch.jit.trace
writer = SummaryWriter()
writer.add_graph(wrapper, (x_list, edge_index_list))
writer.close()

Args:
convs (Dict[Tuple[str, str, str], MessagePassing]): A dictionary
holding a bipartite
Expand Down Expand Up @@ -168,5 +190,115 @@ def forward(

return out_dict

def jit_trace_friendly(
self,
x_dict_keys: List[NodeType],
edge_index_dict_keys: List[EdgeType],
) -> 'HeteroConvWrapper':
r"""Returns a JIT trace-friendly wrapper for this module.

This method creates a wrapper that accepts lists of tensors instead
of dictionaries with tuple keys, making it compatible with
:meth:`torch.jit.trace` and
:meth:`torch.utils.tensorboard.SummaryWriter.add_graph`.

Args:
x_dict_keys (List[str]): The ordered list of node types
corresponding to the node feature tensors.
edge_index_dict_keys (List[Tuple[str, str, str]]): The ordered
list of edge types corresponding to the edge index tensors.

Returns:
HeteroConvWrapper: A wrapper module that can be traced by
:meth:`torch.jit.trace`.

Example:
>>> from torch_geometric.nn import HeteroConv, GCNConv
>>> conv = HeteroConv({
... ('paper', 'cites', 'paper'): GCNConv(-1, 64),
... ('author', 'writes', 'paper'): GCNConv((-1, -1), 64),
... })
>>> x_dict_keys = ['paper', 'author']
>>> edge_index_dict_keys = [('paper', 'cites', 'paper'),
... ('author', 'writes', 'paper')]
>>> wrapper = conv.jit_trace_friendly(x_dict_keys,
... edge_index_dict_keys)
>>> # Now you can use torch.jit.trace or TensorBoard's add_graph
>>> traced = torch.jit.trace(wrapper, (x_list, edge_index_list))
"""
return HeteroConvWrapper(self, x_dict_keys, edge_index_dict_keys)

def __repr__(self) -> str:
return f'{self.__class__.__name__}(num_relations={len(self.convs)})'


class HeteroConvWrapper(torch.nn.Module):
r"""A wrapper for :class:`HeteroConv` that is compatible with
:meth:`torch.jit.trace`.

This wrapper converts list-based inputs to dictionary-based inputs
internally, allowing the module to be traced by
:meth:`torch.jit.trace` and visualized with
:meth:`torch.utils.tensorboard.SummaryWriter.add_graph`.

.. note::
This class is typically created via
:meth:`HeteroConv.jit_trace_friendly` and should not be instantiated
directly.

Args:
hetero_conv (HeteroConv): The :class:`HeteroConv` module to wrap.
x_dict_keys (List[str]): The ordered list of node types.
edge_index_dict_keys (List[Tuple[str, str, str]]): The ordered list
of edge types.
"""
def __init__(
self,
hetero_conv: HeteroConv,
x_dict_keys: List[NodeType],
edge_index_dict_keys: List[EdgeType],
):
super().__init__()
self.hetero_conv = hetero_conv
self.x_dict_keys = x_dict_keys
self.edge_index_dict_keys = edge_index_dict_keys

def forward(
self,
x_list: List[Tensor],
edge_index_list: List[Tensor],
) -> List[Tensor]:
r"""Forward pass that converts lists to dicts and back.

Args:
x_list (List[Tensor]): List of node feature tensors, ordered
according to :attr:`x_dict_keys`.
edge_index_list (List[Tensor]): List of edge index tensors,
ordered according to :attr:`edge_index_dict_keys`.

Returns:
List[Tensor]: List of output node feature tensors, ordered
according to the unique destination node types in
:attr:`edge_index_dict_keys`.
"""
# Reconstruct dictionaries from lists
x_dict = {key: tensor for key, tensor in zip(self.x_dict_keys, x_list)}
edge_index_dict = {
key: tensor
for key, tensor in zip(self.edge_index_dict_keys, edge_index_list)
}

# Call the original HeteroConv forward method
out_dict = self.hetero_conv(x_dict, edge_index_dict)

# Convert output dict back to list (maintain consistent ordering)
out_keys = sorted(out_dict.keys())
out_list = [out_dict[key] for key in out_keys]

return out_list

def __repr__(self) -> str:
return (f'{self.__class__.__name__}('
f'hetero_conv={self.hetero_conv}, '
f'num_node_types={len(self.x_dict_keys)}, '
f'num_edge_types={len(self.edge_index_dict_keys)})')