Add jit_trace_friendly wrapper for HeteroConv to support TensorBoard #10481
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description
Fixes #10421
This PR resolves the issue where
torch.utils.tensorboard.SummaryWriter.add_graph
fails withHeteroConv
models due totorch.jit.trace
not supporting dictionaries with tuple keys.Problem
When attempting to visualize a
HeteroConv
model with TensorBoard, the operation fails with:This occurs because:
add_graph
usestorch.jit.trace
internallytorch.jit.trace
doesn't support dictionaries with tuple keysHeteroConv
requiresedge_index_dict
with tuple keys like(str, str, str)
for edge typesSolution
Added a
jit_trace_friendly()
method toHeteroConv
that returns aHeteroConvWrapper
instance. This wrapper:torch.jit.trace
and TensorBoard'sadd_graph
Usage Example
Changes
jit_trace_friendly()
method andHeteroConvWrapper
classHeteroConvWrapper
Testing
All tests pass successfully:
test_hetero_conv.py
- all passingBenefits
HeteroConv
API remains unchanged