Skip to content

Conversation

DhyeyMavani2003
Copy link

Description

Fixes #10421

This PR resolves the issue where torch.utils.tensorboard.SummaryWriter.add_graph fails with HeteroConv models due to torch.jit.trace not supporting dictionaries with tuple keys.

Problem

When attempting to visualize a HeteroConv model with TensorBoard, the operation fails with:

RuntimeError: Cannot create dict for key type '(str, str, str)', only int, float, complex, Tensor, device and string keys are supported

This occurs because:

  • TensorBoard's add_graph uses torch.jit.trace internally
  • torch.jit.trace doesn't support dictionaries with tuple keys
  • HeteroConv requires edge_index_dict with tuple keys like (str, str, str) for edge types

Solution

Added a jit_trace_friendly() method to HeteroConv that returns a HeteroConvWrapper instance. This wrapper:

  • Accepts lists of tensors instead of dictionaries with tuple keys
  • Internally converts lists to dictionaries for processing
  • Converts output dictionaries back to lists
  • Is fully compatible with torch.jit.trace and TensorBoard's add_graph

Usage Example

from torch.utils.tensorboard import SummaryWriter
from torch_geometric.nn import HeteroConv, GCNConv

# Create HeteroConv layer
conv = HeteroConv({
    ('paper', 'cites', 'paper'): GCNConv(-1, 64),
    ('author', 'writes', 'paper'): GCNConv((-1, -1), 64),
})

# Prepare list-based inputs
x_list = list(x_dict.values())
x_dict_keys = list(x_dict.keys())
edge_index_list = list(edge_index_dict.values())
edge_index_dict_keys = list(edge_index_dict.keys())

# Create JIT-friendly wrapper
wrapper = conv.jit_trace_friendly(x_dict_keys, edge_index_dict_keys)

# Now compatible with TensorBoard
writer = SummaryWriter()
writer.add_graph(wrapper, (x_list, edge_index_list))
writer.close()

Changes

  • torch_geometric/nn/conv/hetero_conv.py: Added jit_trace_friendly() method and HeteroConvWrapper class
  • torch_geometric/nn/conv/init.py: Exported HeteroConvWrapper
  • test/nn/conv/test_hetero_conv.py: Added comprehensive unit tests
  • examples/hetero_conv_tensorboard.py: Added example demonstrating TensorBoard visualization

Testing

All tests pass successfully:

  • ✅ 13 total tests in test_hetero_conv.py - all passing
  • ✅ 2 new tests specifically for this feature
  • ✅ Verified backward compatibility - no breaking changes
  • ✅ Tested with the exact example from the issue

Benefits

  • Non-breaking: Existing HeteroConv API remains unchanged
  • Opt-in: Users only use the wrapper when needed for TensorBoard/JIT tracing
  • Well-tested: Comprehensive unit tests and integration tests included
  • Fully documented: Docstrings and usage examples provided
  • Clean API: Simple method call to get a compatible wrapper

…visualization

Fixes pyg-team#10421

This commit resolves the issue where torch.utils.tensorboard.SummaryWriter.add_graph
fails with HeteroConv models due to torch.jit.trace not supporting dictionaries
with tuple keys.

Changes:
- Added jit_trace_friendly() method to HeteroConv class that returns a
  HeteroConvWrapper instance
- Created HeteroConvWrapper class that accepts list-based inputs instead of
  dict-based inputs, making it compatible with torch.jit.trace
- The wrapper internally converts lists to dicts for processing and converts
  output dicts back to lists
- Updated HeteroConv docstring with TensorBoard usage example
- Exported HeteroConvWrapper in torch_geometric.nn.conv.__init__
- Added comprehensive unit tests for the new functionality
- Added example script demonstrating TensorBoard visualization

The solution is:
- Non-breaking: existing HeteroConv API remains unchanged
- Opt-in: users only use the wrapper when needed for TensorBoard/JIT tracing
- Well-tested: includes unit tests and integration tests
- Fully documented: comprehensive docstrings and usage examples

Co-authored-by: Ona <[email protected]>
pre-commit-ci bot and others added 2 commits October 4, 2025 07:30
- Fix f-string without placeholders in examples/hetero_conv_tensorboard.py
- Fix line length violations in torch_geometric/nn/conv/hetero_conv.py
- Add period to docstring first line

Co-authored-by: Ona <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

tensorboard writer.add_graph fails with HeteroConv due to jit.trace not supporting tuple dictionary keys
1 participant