1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added GraphLand benchmark via `GraphLandDataset` ([#10458](https://github.com/pyg-team/pytorch_geometric/pull/10458))
- Added llm generated explanations to `TAGDataset` ([#9918](https://github.com/pyg-team/pytorch_geometric/pull/9918))
- Added `torch_geometric.llm` and its examples ([#10436](https://github.com/pyg-team/pytorch_geometric/pull/10436))
- Added support for negative weights in `sparse_cross_entropy` ([#10432](https://github.com/pyg-team/pytorch_geometric/pull/10432))
2 changes: 2 additions & 0 deletions examples/README.md
@@ -24,6 +24,8 @@ For examples on [Open Graph Benchmark](https://ogb.stanford.edu/) datasets, see

For an example on [Relational Deep Learning](https://arxiv.org/abs/2312.04615) with the [RelBench datasets](https://relbench.stanford.edu/), see [`rdl.py`](./rdl.py).

For an example on using [GraphLand datasets](https://arxiv.org/abs/2409.14500) for node property prediction, see [`graphland.py`](./graphland.py).

For examples on using `torch.compile`, see the examples under [`examples/compile`](./compile).

For examples on scaling PyG up via multi-GPUs, see the examples under [`examples/multi_gpu`](./multi_gpu).
147 changes: 147 additions & 0 deletions examples/graphland.py
@@ -0,0 +1,147 @@
import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import accuracy_score, average_precision_score, r2_score
from tqdm import tqdm

from torch_geometric.datasets import GraphLandDataset
from torch_geometric.nn import GCNConv


class Model(torch.nn.Module):
def __init__(
self,
in_channels: int,
hidden_channels: int,
out_channels: int,
) -> None:
super().__init__()
self.conv = GCNConv(in_channels, hidden_channels)
self.head = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels, out_channels),
)

def forward(
self,
x: torch.Tensor,
edge_index: torch.Tensor,
) -> torch.Tensor:
return self.head(self.conv(x, edge_index))


def _get_num_classes(dataset: GraphLandDataset) -> int:
assert dataset.task != 'regression'
targets = torch.cat([data.y for data in dataset], dim=0)
return len(torch.unique(targets[~torch.isnan(targets)]))


def _train_step(
model: nn.Module,
dataset: GraphLandDataset,
optimizer: optim.Optimizer,
) -> torch.Tensor:
data = dataset[0]
mask = data.train_mask if dataset.split != 'THI' else data.mask
optimizer.zero_grad()
outputs = model(data.x, data.edge_index).squeeze()

if dataset.task == 'regression':
loss = F.mse_loss(outputs[mask], data.y[mask])
else:
loss = F.cross_entropy(outputs[mask], data.y[mask].long())

loss.backward()
optimizer.step()
return loss


@torch.no_grad()
def _eval_step(
model: nn.Module,
dataset: GraphLandDataset,
) -> dict[str, float]:
def _compute_metric(outputs: np.ndarray, targets: np.ndarray) -> float:
if dataset.task == 'regression':
return float(r2_score(targets, outputs))

elif dataset.task == 'binary_classification':
predictions = outputs[:, 1]
return float(average_precision_score(targets, predictions))

else:
predictions = np.argmax(outputs, axis=1)
return float(accuracy_score(targets, predictions))

metrics = dict()
for idx, part in enumerate(['train', 'val', 'test']):
if dataset.split == 'THI':
data = dataset[idx]
mask = data.mask
else:
data = dataset[0]
mask = getattr(data, f'{part}_mask')

outputs = model(data.x, data.edge_index).squeeze()
metrics[part] = _compute_metric(
outputs[mask].detach().cpu().numpy(),
data.y[mask].cpu().numpy(),
)
return metrics


def _format_metrics(metrics: dict[str, float]) -> str:
return ', '.join(f'{part}={metrics[part] * 100.0:.2f}'
for part in ['train', 'val', 'test'])


def run_experiment(name: str, split: str) -> dict[str, float]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_steps = 100
dataset = GraphLandDataset(
root='./data',
split=split,
name=name,
to_undirected=True,
).to(device)
model = Model(
in_channels=dataset[0].x.shape[1],
hidden_channels=256,
out_channels=(_get_num_classes(dataset)
if dataset.task != 'regression' else 1),
).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

best_metrics = {part: -float('inf') for part in ['train', 'val', 'test']}
pbar = tqdm(range(n_steps))
for _ in pbar:
loss = _train_step(model, dataset, optimizer)
curr_metrics = _eval_step(model, dataset)
pbar.set_postfix_str(f'loss={loss.detach().cpu().item():.4f}, ' +
_format_metrics(curr_metrics))
if curr_metrics['val'] > best_metrics['val']:
best_metrics = curr_metrics

print('Best metrics: ' + _format_metrics(best_metrics))
return best_metrics


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--name',
choices=list(GraphLandDataset.GRAPHLAND_DATASETS.keys()),
        help='The name of the dataset.',
required=True,
)
parser.add_argument(
'--split',
choices=['RL', 'RH', 'TH', 'THI'],
help='The type of data split.',
required=True,
)
args = parser.parse_args()
run_experiment(args.name, args.split)
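A usage note (not part of the diff): with the required arguments above, the script can be run from the `examples/` directory as `python graphland.py --name tolokers-2 --split RL`, where the dataset name is taken from the tests below and assumed to be a valid `--name` choice. It can also be driven programmatically, as a rough sketch:

from graphland import run_experiment

# Trains for 100 full-batch steps and returns the metrics of the step with
# the best validation score (see `run_experiment` above).
best = run_experiment(name='tolokers-2', split='RL')
print(f"test metric: {best['test'] * 100.0:.2f}")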
72 changes: 72 additions & 0 deletions test/datasets/test_graphland.py
@@ -0,0 +1,72 @@
import pytest
import torch

from torch_geometric.datasets import GraphLandDataset
from torch_geometric.testing import onlyOnline, withPackage


@onlyOnline
@withPackage('pandas', 'sklearn', 'yaml')
@pytest.mark.parametrize('name', [
'hm-categories',
'tolokers-2',
'avazu-ctr',
])
def test_transductive_graphland(name: str):
dataset = GraphLandDataset(
root='./datasets',
split='RL',
name=name,
to_undirected=True,
)
assert len(dataset) == 1

data = dataset[0]
assert data.num_nodes == data.x.shape[0] == data.y.shape[0]

    assert not (data.train_mask & data.val_mask).any().item()
    assert not (data.train_mask & data.test_mask).any().item()
    assert not (data.val_mask & data.test_mask).any().item()

labeled_mask = data.train_mask | data.val_mask | data.test_mask
assert not torch.isnan(data.y[labeled_mask]).any().item()
assert not torch.isnan(data.x).any().item()

assert not (data.x_numerical_mask & data.x_fraction_mask
& data.x_categorical_mask).any().item()

assert (data.x_numerical_mask | data.x_fraction_mask
| data.x_categorical_mask).all().item()


@onlyOnline
@withPackage('pandas', 'sklearn', 'yaml')
@pytest.mark.parametrize('name', [
'hm-categories',
'tolokers-2',
'avazu-ctr',
])
def test_inductive_graphland(name: str):
base_data = GraphLandDataset(
root='./datasets',
split='TH',
name=name,
to_undirected=True,
)[0]
num_nodes = base_data.num_nodes
num_edges = base_data.num_edges
del base_data

dataset = GraphLandDataset(
root='./datasets',
split='THI',
name=name,
to_undirected=True,
)
assert len(dataset) == 3

data_train, data_val, data_test = dataset
assert num_nodes == data_test.num_nodes == data_test.node_id.shape[0]
assert num_edges == data_test.num_edges

assert not torch.isnan(data_train.y[data_train.mask]).any().item()
assert not torch.isnan(data_val.y[data_val.mask]).any().item()
assert not torch.isnan(data_test.y[data_test.mask]).any().item()
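For reference, a condensed sketch of the two access patterns these tests cover (the root directory and dataset name below mirror the tests and are assumptions, not part of the diff):

from torch_geometric.datasets import GraphLandDataset

# Transductive splits ('RL', 'RH', 'TH'): a single graph carrying
# train/val/test node masks.
dataset = GraphLandDataset(root='./datasets', name='tolokers-2', split='RL',
                           to_undirected=True)
data = dataset[0]
train_labels = data.y[data.train_mask]

# Inductive split ('THI'): three graph snapshots (train/val/test), each with
# its own `mask` attribute.
dataset = GraphLandDataset(root='./datasets', name='tolokers-2', split='THI',
                           to_undirected=True)
data_train, data_val, data_test = dataset
test_labels = data_test.y[data_test.mask]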
2 changes: 2 additions & 0 deletions torch_geometric/datasets/__init__.py
@@ -85,6 +85,7 @@
from .tag_dataset import TAGDataset
from .city import CityNetwork
from .teeth3ds import Teeth3DS
from .graphland import GraphLandDataset

from .dbp15k import DBP15K
from .aminer import AMiner
@@ -207,6 +208,7 @@
'TAGDataset',
'CityNetwork',
'Teeth3DS',
'GraphLandDataset',
]

hetero_datasets = [
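A quick import check after this registration (a sketch, not part of the diff):

from torch_geometric.datasets import GraphLandDataset

# The available dataset names, as accepted by the `--name` flag in
# `examples/graphland.py`.
print(sorted(GraphLandDataset.GRAPHLAND_DATASETS.keys()))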