diff --git a/test/nn/conv/test_hetero_conv.py b/test/nn/conv/test_hetero_conv.py
index a2def84a0b5a..44c797f36993 100644
--- a/test/nn/conv/test_hetero_conv.py
+++ b/test/nn/conv/test_hetero_conv.py
@@ -14,6 +14,11 @@
     MessagePassing,
     SAGEConv,
 )
+from torch_geometric.nn.aggr import (
+    AttentionalAggregation,
+    MaxAggregation,
+    MeanAggregation,
+)
 from torch_geometric.profile import benchmark
 from torch_geometric.testing import (
     get_random_edge_index,
@@ -178,6 +183,125 @@ def test_hetero_conv_with_dot_syntax_node_types():
     assert out_dict['author'].size() == (30, 64)
 
 
+def test_hetero_conv_with_attentional_aggregation():
+    data = HeteroData()
+    data['paper'].x = torch.randn(50, 32)
+    data['author'].x = torch.randn(30, 64)
+    data['paper', 'paper'].edge_index = get_random_edge_index(50, 50, 200)
+    data['paper', 'author'].edge_index = get_random_edge_index(50, 30, 100)
+    data['author', 'paper'].edge_index = get_random_edge_index(30, 50, 100)
+
+    # Test with AttentionalAggregation
+    gate_nn = torch.nn.Linear(64, 1)
+    aggr = AttentionalAggregation(gate_nn)
+
+    conv = HeteroConv(
+        {
+            ('paper', 'to', 'paper'): GCNConv(-1, 64),
+            ('author', 'to', 'paper'): SAGEConv((-1, -1), 64),
+            ('paper', 'to', 'author'): GATConv(
+                (-1, -1), 64, add_self_loops=False),
+        },
+        aggr=aggr,
+    )
+
+    # The aggregation module needs to be registered as a sub-module, such
+    # that its parameters are part of `conv.parameters()`:
+    param_ids = {id(param) for param in conv.parameters()}
+    assert all(id(param) in param_ids for param in gate_nn.parameters())
+    assert str(conv) == 'HeteroConv(num_relations=3)'
+
+    out_dict = conv(
+        data.x_dict,
+        data.edge_index_dict,
+    )
+
+    assert len(out_dict) == 2
+    assert out_dict['paper'].size() == (50, 64)
+    assert out_dict['author'].size() == (30, 64)
+
+    # Test reset_parameters
+    conv.reset_parameters()
+
+
+def test_hetero_conv_with_attentional_aggregation_and_nn():
+    data = HeteroData()
+    data['paper'].x = torch.randn(50, 32)
+    data['author'].x = torch.randn(30, 64)
+    data['paper', 'paper'].edge_index = get_random_edge_index(50, 50, 200)
+    data['paper', 'author'].edge_index = get_random_edge_index(50, 30, 100)
+    data['author', 'paper'].edge_index = get_random_edge_index(30, 50, 100)
+
+    # Test with AttentionalAggregation with both gate_nn and nn. `nn` changes
+    # the feature size, such that the output shape reveals its application:
+    gate_nn = torch.nn.Linear(64, 1)
+    nn = torch.nn.Linear(64, 32)
+    aggr = AttentionalAggregation(gate_nn, nn)
+
+    conv = HeteroConv(
+        {
+            ('paper', 'to', 'paper'): GCNConv(-1, 64),
+            ('author', 'to', 'paper'): SAGEConv((-1, -1), 64),
+            ('paper', 'to', 'author'): GATConv(
+                (-1, -1), 64, add_self_loops=False),
+        },
+        aggr=aggr,
+    )
+
+    out_dict = conv(
+        data.x_dict,
+        data.edge_index_dict,
+    )
+
+    assert len(out_dict) == 2
+    assert out_dict['paper'].size() == (50, 32)
+    # 'author' only receives messages from a single relation, but still needs
+    # to be processed by the aggregation module:
+    assert out_dict['author'].size() == (30, 32)
+
+
+def test_hetero_conv_with_aggregation_modules():
+    """Test HeteroConv with various Aggregation module types."""
+    data = HeteroData()
+    data['paper'].x = torch.randn(50, 32)
+    data['author'].x = torch.randn(30, 64)
+    data['paper', 'paper'].edge_index = get_random_edge_index(50, 50, 200)
+    data['paper', 'author'].edge_index = get_random_edge_index(50, 30, 100)
+    data['author', 'paper'].edge_index = get_random_edge_index(30, 50, 100)
+
+    # Test with MaxAggregation
+    conv = HeteroConv(
+        {
+            ('paper', 'to', 'paper'): GCNConv(-1, 64),
+            ('author', 'to', 'paper'): SAGEConv((-1, -1), 64),
+            ('paper', 'to', 'author'): GATConv(
+                (-1, -1), 64, add_self_loops=False),
+        },
+        aggr=MaxAggregation(),
+    )
+
+    out_dict = conv(data.x_dict, data.edge_index_dict)
+    assert len(out_dict) == 2
+    assert out_dict['paper'].size() == (50, 64)
+    assert out_dict['author'].size() == (30, 64)
+
+    # Test with MeanAggregation
+    conv = HeteroConv(
+        {
+            ('paper', 'to', 'paper'): GCNConv(-1, 64),
+            ('author', 'to', 'paper'): SAGEConv((-1, -1), 64),
+            ('paper', 'to', 'author'): GATConv(
+                (-1, -1), 64, add_self_loops=False),
+        },
+        aggr=MeanAggregation(),
+    )
+
+    out_dict = conv(data.x_dict, data.edge_index_dict)
+    assert len(out_dict) == 2
+    assert out_dict['paper'].size() == (50, 64)
+    assert out_dict['author'].size() == (30, 64)
+
+
 @withDevice
 @onlyLinux
 @withPackage('torch>=2.1.0')
diff --git a/torch_geometric/nn/conv/hetero_conv.py b/torch_geometric/nn/conv/hetero_conv.py
index 9f823ea78627..05aa0862ac46 100644
--- a/torch_geometric/nn/conv/hetero_conv.py
+++ b/torch_geometric/nn/conv/hetero_conv.py
@@ -1,22 +1,39 @@
 import warnings
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union
 
 import torch
 from torch import Tensor
 
+from torch_geometric.nn.aggr import Aggregation
 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.module_dict import ModuleDict
 from torch_geometric.typing import EdgeType, NodeType
 from torch_geometric.utils.hetero import check_add_self_loops
 
 
-def group(xs: List[Tensor], aggr: Optional[str]) -> Optional[Tensor]:
+def group(
+    xs: List[Tensor],
+    aggr: Optional[Union[str, Aggregation]],
+) -> Optional[Tensor]:
     if len(xs) == 0:
         return None
     elif aggr is None:
         return torch.stack(xs, dim=1)
+    elif isinstance(aggr, Aggregation):
+        # NOTE: This branch needs to precede the `len(xs) == 1` shortcut
+        # below. Aggregation modules may hold learnable transformations
+        # (e.g., `nn` in `AttentionalAggregation`), which need to be applied
+        # to every node type, even if it only receives a single relation.
+        # Each relation is treated as one element of a per-node set:
+        out = torch.stack(xs, dim=1)  # [num_nodes, num_relations, channels]
+        num_nodes, num_relations = out.size(0), out.size(1)
+        out = out.flatten(0, 1)  # [num_nodes * num_relations, channels]
+        # Sorted index that assigns all relations of a node to the same set:
+        index = torch.arange(num_nodes, device=out.device)
+        index = index.repeat_interleave(num_relations)
+        return aggr(out, index, dim_size=num_nodes, dim=0)
     elif len(xs) == 1:
         return xs[0]
     elif aggr == "cat":
         return torch.cat(xs, dim=-1)
     else:
@@ -50,20 +67,36 @@ class HeteroConv(torch.nn.Module):
         print(list(out_dict.keys()))
         >>> ['paper', 'author']
 
+    Alternatively, you can use an :class:`~torch_geometric.nn.aggr.Aggregation`
+    module for more complex aggregation schemes:
+
+    .. code-block:: python
+
+        from torch_geometric.nn import AttentionalAggregation
+
+        gate_nn = torch.nn.Linear(64, 1)
+        hetero_conv = HeteroConv({
+            ('paper', 'cites', 'paper'): GCNConv(-1, 64),
+            ('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
+        }, aggr=AttentionalAggregation(gate_nn))
+
     Args:
         convs (Dict[Tuple[str, str, str], MessagePassing]): A dictionary
             holding a bipartite
             :class:`~torch_geometric.nn.conv.MessagePassing` layer for each
             individual edge type.
-        aggr (str, optional): The aggregation scheme to use for grouping node
-            embeddings generated by different relations
-            (:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
-            :obj:`"cat"`, :obj:`None`). (default: :obj:`"sum"`)
+        aggr (str or Aggregation, optional): The aggregation scheme to use
+            for grouping node embeddings generated by different relations.
+            Can be one of :obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`,
+            :obj:`"max"`, :obj:`"cat"`, :obj:`None`, or an instance of
+            :class:`~torch_geometric.nn.aggr.Aggregation`, which is then
+            applied to every destination node type (even those with a
+            single incoming relation). (default: :obj:`"sum"`)
     """
     def __init__(
         self,
         convs: Dict[EdgeType, MessagePassing],
-        aggr: Optional[str] = "sum",
+        aggr: Optional[Union[str, Aggregation]] = "sum",
     ):
         super().__init__()
 
@@ -81,12 +114,23 @@ def __init__(
                 stacklevel=2)
 
         self.convs = ModuleDict(convs)
-        self.aggr = aggr
+
+        # String-based schemes (and `None`) are kept in `self.aggr`, while
+        # `Aggregation` modules get registered as a sub-module such that their
+        # parameters are tracked by `parameters()`, `to()`, `state_dict()`:
+        if isinstance(aggr, Aggregation):
+            self.aggr = None
+            self.aggr_module = aggr
+        else:
+            self.aggr = aggr
+            self.aggr_module = None
 
     def reset_parameters(self):
         r"""Resets all learnable parameters of the module."""
         for conv in self.convs.values():
             conv.reset_parameters()
+        if self.aggr_module is not None:
+            self.aggr_module.reset_parameters()
 
     def forward(
         self,
@@ -163,7 +207,9 @@ def forward(
             else:
                 out_dict[dst].append(out)
 
+        # Use aggr_module if available, otherwise use aggr string
+        aggr = self.aggr_module if self.aggr_module is not None else self.aggr
         for key, value in out_dict.items():
-            out_dict[key] = group(value, self.aggr)
+            out_dict[key] = group(value, aggr)
 
         return out_dict