Skip to content

Conversation

DhyeyMavani2003
Copy link

Summary

This PR adds support for using Aggregation modules (like AttentionalAggregation) in HeteroConv for aggregating node embeddings from different edge types, addressing feature request #7414.

Motivation

Previously, HeteroConv only supported simple string-based aggregations ("sum", "mean", "min", "max", "cat", None) when combining node features from multiple edge types. This limitation prevented users from leveraging more sophisticated aggregation schemes like AttentionalAggregation, which can learn to weight the importance of different edge types dynamically.

Changes

Core Implementation

torch_geometric/nn/conv/hetero_conv.py:

  • Modified group() function to handle Aggregation module instances
    • Properly reshapes tensors to aggregate per-node across different relations
    • Creates appropriate index tensors for the aggregation operation
  • Updated HeteroConv.__init__() to accept both string and Aggregation instances
    • Added aggr_module attribute to store Aggregation instances
    • Maintains backward compatibility with string-based aggregations
  • Enhanced reset_parameters() to reset learnable parameters in aggregation modules
  • Updated docstring with example usage and expanded parameter documentation

Testing

test/nn/conv/test_hetero_conv.py:

  • Added test_hetero_conv_with_attentional_aggregation() - Tests basic AttentionalAggregation usage
  • Added test_hetero_conv_with_attentional_aggregation_and_nn() - Tests AttentionalAggregation with both gate_nn and nn parameters
  • Added test_hetero_conv_with_aggregation_modules() - Tests other Aggregation modules (MaxAggregation, MeanAggregation)

Usage Example

from torch_geometric.nn import HeteroConv, GraphConv, AttentionalAggregation
import torch

# Create AttentionalAggregation with a gate network
gate_nn = torch.nn.Linear(64, 1)
aggr = AttentionalAggregation(gate_nn)

# Use it in HeteroConv
conv = HeteroConv({
    ('paper', 'cites', 'paper'): GraphConv(-1, 64),
    ('author', 'writes', 'paper'): GraphConv(-1, 64),
}, aggr=aggr)

# Forward pass
out_dict = conv(x_dict, edge_index_dict)

Backward Compatibility

✅ All existing tests pass without modification
✅ String-based aggregations continue to work as before
✅ No breaking changes to the API

Testing

All tests pass:

  • ✅ 14/14 tests in test_hetero_conv.py
  • ✅ Pre-commit hooks (formatting, linting, type checking)
  • ✅ Aggregation module tests

Related Issues

Fixes #7414

Checklist

  • Implementation follows existing code patterns
  • Comprehensive tests added
  • Documentation updated with examples
  • All tests pass
  • Pre-commit hooks pass
  • Backward compatibility maintained

This commit adds support for using Aggregation modules (like
AttentionalAggregation) in HeteroConv for aggregating node embeddings
from different edge types.

Changes:
- Modified HeteroConv to accept Aggregation instances in addition to
  string aggregation types
- Updated the group() function to handle Aggregation modules by
  properly reshaping tensors and creating index tensors for per-node
  aggregation across different relations
- Added reset_parameters() support for aggregation modules
- Updated documentation with example usage
- Added comprehensive tests for AttentionalAggregation and other
  Aggregation modules

Fixes pyg-team#7414

Co-authored-by: Ona <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support AttentionalAggregation for HeteroConv
1 participant