Add AttentionalAggregation support to HeteroConv #10486
Open
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary
This PR adds support for using
Aggregation
modules (likeAttentionalAggregation
) inHeteroConv
for aggregating node embeddings from different edge types, addressing feature request #7414.Motivation
Previously,
HeteroConv
only supported simple string-based aggregations ("sum"
,"mean"
,"min"
,"max"
,"cat"
,None
) when combining node features from multiple edge types. This limitation prevented users from leveraging more sophisticated aggregation schemes likeAttentionalAggregation
, which can learn to weight the importance of different edge types dynamically.Changes
Core Implementation
torch_geometric/nn/conv/hetero_conv.py
:group()
function to handleAggregation
module instancesHeteroConv.__init__()
to accept both string andAggregation
instancesaggr_module
attribute to storeAggregation
instancesreset_parameters()
to reset learnable parameters in aggregation modulesTesting
test/nn/conv/test_hetero_conv.py
:test_hetero_conv_with_attentional_aggregation()
- Tests basicAttentionalAggregation
usagetest_hetero_conv_with_attentional_aggregation_and_nn()
- TestsAttentionalAggregation
with bothgate_nn
andnn
parameterstest_hetero_conv_with_aggregation_modules()
- Tests otherAggregation
modules (MaxAggregation
,MeanAggregation
)Usage Example
Backward Compatibility
✅ All existing tests pass without modification
✅ String-based aggregations continue to work as before
✅ No breaking changes to the API
Testing
All tests pass:
test_hetero_conv.py
Related Issues
Fixes #7414
Checklist