119 changes: 119 additions & 0 deletions test/nn/conv/test_hetero_conv.py
@@ -14,6 +14,11 @@
MessagePassing,
SAGEConv,
)
from torch_geometric.nn.aggr import (
AttentionalAggregation,
MaxAggregation,
MeanAggregation,
)
from torch_geometric.profile import benchmark
from torch_geometric.testing import (
get_random_edge_index,
@@ -178,6 +183,120 @@ def test_hetero_conv_with_dot_syntax_node_types():
assert out_dict['author'].size() == (30, 64)


def test_hetero_conv_with_attentional_aggregation():
data = HeteroData()
data['paper'].x = torch.randn(50, 32)
data['author'].x = torch.randn(30, 64)
data['paper', 'paper'].edge_index = get_random_edge_index(50, 50, 200)
data['paper', 'author'].edge_index = get_random_edge_index(50, 30, 100)
data['author', 'paper'].edge_index = get_random_edge_index(30, 50, 100)

# Test with AttentionalAggregation
gate_nn = torch.nn.Linear(64, 1)
aggr = AttentionalAggregation(gate_nn)

conv = HeteroConv(
{
('paper', 'to', 'paper'): GCNConv(-1, 64),
('author', 'to', 'paper'): SAGEConv((-1, -1), 64),
('paper', 'to', 'author'): GATConv(
(-1, -1), 64, add_self_loops=False),
},
aggr=aggr,
)

# Check that parameters include both conv and aggregation parameters
assert len(list(conv.parameters())) > 0
assert str(conv) == 'HeteroConv(num_relations=3)'

out_dict = conv(
data.x_dict,
data.edge_index_dict,
)

assert len(out_dict) == 2
assert out_dict['paper'].size() == (50, 64)
assert out_dict['author'].size() == (30, 64)

# Test reset_parameters
conv.reset_parameters()
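    # (this should also reset the AttentionalAggregation's parameters)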


def test_hetero_conv_with_attentional_aggregation_and_nn():
data = HeteroData()
data['paper'].x = torch.randn(50, 32)
data['author'].x = torch.randn(30, 64)
data['paper', 'paper'].edge_index = get_random_edge_index(50, 50, 200)
data['paper', 'author'].edge_index = get_random_edge_index(50, 30, 100)
data['author', 'paper'].edge_index = get_random_edge_index(30, 50, 100)

    # Test AttentionalAggregation with both gate_nn and nn
gate_nn = torch.nn.Linear(64, 1)
nn = torch.nn.Linear(64, 64)
aggr = AttentionalAggregation(gate_nn, nn)

conv = HeteroConv(
{
('paper', 'to', 'paper'): GCNConv(-1, 64),
('author', 'to', 'paper'): SAGEConv((-1, -1), 64),
('paper', 'to', 'author'): GATConv(
(-1, -1), 64, add_self_loops=False),
},
aggr=aggr,
)

out_dict = conv(
data.x_dict,
data.edge_index_dict,
)

assert len(out_dict) == 2
assert out_dict['paper'].size() == (50, 64)
assert out_dict['author'].size() == (30, 64)


def test_hetero_conv_with_aggregation_modules():
"""Test HeteroConv with various Aggregation module types."""
data = HeteroData()
data['paper'].x = torch.randn(50, 32)
data['author'].x = torch.randn(30, 64)
data['paper', 'paper'].edge_index = get_random_edge_index(50, 50, 200)
data['paper', 'author'].edge_index = get_random_edge_index(50, 30, 100)
data['author', 'paper'].edge_index = get_random_edge_index(30, 50, 100)

# Test with MaxAggregation
conv = HeteroConv(
{
('paper', 'to', 'paper'): GCNConv(-1, 64),
('author', 'to', 'paper'): SAGEConv((-1, -1), 64),
('paper', 'to', 'author'): GATConv(
(-1, -1), 64, add_self_loops=False),
},
aggr=MaxAggregation(),
)

out_dict = conv(data.x_dict, data.edge_index_dict)
assert len(out_dict) == 2
assert out_dict['paper'].size() == (50, 64)
assert out_dict['author'].size() == (30, 64)

# Test with MeanAggregation
conv = HeteroConv(
{
('paper', 'to', 'paper'): GCNConv(-1, 64),
('author', 'to', 'paper'): SAGEConv((-1, -1), 64),
('paper', 'to', 'author'): GATConv(
(-1, -1), 64, add_self_loops=False),
},
aggr=MeanAggregation(),
)

out_dict = conv(data.x_dict, data.edge_index_dict)
assert len(out_dict) == 2
assert out_dict['paper'].size() == (50, 64)
assert out_dict['author'].size() == (30, 64)


@withDevice
@onlyLinux
@withPackage('torch>=2.1.0')
71 changes: 62 additions & 9 deletions torch_geometric/nn/conv/hetero_conv.py
@@ -1,22 +1,45 @@
import warnings
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union

import torch
from torch import Tensor

from torch_geometric.nn.aggr import Aggregation
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.module_dict import ModuleDict
from torch_geometric.nn.resolver import aggregation_resolver as aggr_resolver
from torch_geometric.typing import EdgeType, NodeType
from torch_geometric.utils.hetero import check_add_self_loops


def group(xs: List[Tensor], aggr: Optional[str]) -> Optional[Tensor]:
def group(
xs: List[Tensor],
aggr: Optional[Union[str, Aggregation]],
) -> Optional[Tensor]:
if len(xs) == 0:
return None
elif aggr is None:
return torch.stack(xs, dim=1)
elif len(xs) == 1:
return xs[0]
elif isinstance(aggr, Aggregation):
        # For Aggregation modules, treat each relation as a separate
        # element to aggregate per node.
        # Stack into shape [num_nodes, num_relations, num_features]:
        out = torch.stack(xs, dim=1)
        batch_size, num_relations, num_features = out.size()

        # Reshape to [num_nodes * num_relations, num_features]:
        out = out.view(-1, num_features)

        # Create an index that maps each row back to its node
index = torch.arange(
batch_size, device=out.device).repeat_interleave(num_relations)
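        # e.g., for batch_size=2 and num_relations=3 this yields
        # index = [0, 0, 0, 1, 1, 1], matching the node-major row order
        # produced by the view above.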

        # Aggregate the per-relation rows of each node into one row
return aggr(out, index, dim_size=batch_size, dim=0)
elif aggr == "cat":
return torch.cat(xs, dim=-1)
else:
@@ -50,20 +73,35 @@ class HeteroConv(torch.nn.Module):
print(list(out_dict.keys()))
>>> ['paper', 'author']

Alternatively, you can use an :class:`~torch_geometric.nn.aggr.Aggregation`
module for more complex aggregation schemes:

.. code-block:: python

from torch_geometric.nn import AttentionalAggregation

gate_nn = torch.nn.Linear(64, 1)
hetero_conv = HeteroConv({
('paper', 'cites', 'paper'): GCNConv(-1, 64),
('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
}, aggr=AttentionalAggregation(gate_nn))
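
    When an :class:`~torch_geometric.nn.aggr.Aggregation` module is given,
    the outputs of all relations pointing to the same destination node
    type are stacked and aggregated per node by that module.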

Args:
convs (Dict[Tuple[str, str, str], MessagePassing]): A dictionary
holding a bipartite
:class:`~torch_geometric.nn.conv.MessagePassing` layer for each
individual edge type.
aggr (str, optional): The aggregation scheme to use for grouping node
embeddings generated by different relations
(:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
:obj:`"cat"`, :obj:`None`). (default: :obj:`"sum"`)
aggr (str or Aggregation, optional): The aggregation scheme to use
for grouping node embeddings generated by different relations.
Can be one of :obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`,
:obj:`"max"`, :obj:`"cat"`, :obj:`None`, or an instance of
:class:`~torch_geometric.nn.aggr.Aggregation`.
(default: :obj:`"sum"`)
"""
def __init__(
self,
convs: Dict[EdgeType, MessagePassing],
aggr: Optional[str] = "sum",
aggr: Optional[Union[str, Aggregation]] = "sum",
):
super().__init__()

@@ -81,12 +119,25 @@ def __init__(
stacklevel=2)

self.convs = ModuleDict(convs)
self.aggr = aggr

# Resolve aggregation to support both string and Aggregation instances
if isinstance(aggr, str) or aggr is None:
self.aggr = aggr
self.aggr_module = None
elif isinstance(aggr, Aggregation):
self.aggr = None
self.aggr_module = aggr
else:
# Try to resolve using aggr_resolver for other types
self.aggr_module = aggr_resolver(aggr)
self.aggr = None if self.aggr_module is not None else aggr

def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
for conv in self.convs.values():
conv.reset_parameters()
if self.aggr_module is not None:
self.aggr_module.reset_parameters()

def forward(
self,
@@ -163,8 +214,10 @@ def forward(
else:
out_dict[dst].append(out)

        # Use the aggr_module if available; otherwise fall back to the
        # aggr string
aggr = self.aggr_module if self.aggr_module is not None else self.aggr
for key, value in out_dict.items():
out_dict[key] = group(value, self.aggr)
out_dict[key] = group(value, aggr)

return out_dict

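Taken together, the change lets HeteroConv accept an Aggregation module in
place of the usual aggregation string. A minimal usage sketch with the patch
applied; graph sizes are illustrative and mirror the tests above:

    import torch

    from torch_geometric.data import HeteroData
    from torch_geometric.nn import GCNConv, HeteroConv, SAGEConv
    from torch_geometric.nn.aggr import AttentionalAggregation
    from torch_geometric.testing import get_random_edge_index

    data = HeteroData()
    data['paper'].x = torch.randn(50, 32)
    data['author'].x = torch.randn(30, 64)
    data['paper', 'to', 'paper'].edge_index = get_random_edge_index(50, 50, 200)
    data['author', 'to', 'paper'].edge_index = get_random_edge_index(30, 50, 100)

    # Both relations target 'paper', so their outputs are stacked and
    # combined per node by an attention-weighted sum instead of a plain
    # reduction:
    conv = HeteroConv({
        ('paper', 'to', 'paper'): GCNConv(-1, 64),
        ('author', 'to', 'paper'): SAGEConv((-1, -1), 64),
    }, aggr=AttentionalAggregation(gate_nn=torch.nn.Linear(64, 1)))

    out_dict = conv(data.x_dict, data.edge_index_dict)
    assert out_dict['paper'].size() == (50, 64)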