Skip to content

Commit 0fea80f

Browse files
authored
Merge branch 'master' into aki/rev-10422
2 parents d926f42 + e2a1675 commit 0fea80f

File tree

2 files changed

+122
-74
lines changed

2 files changed

+122
-74
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
77

88
### Fixed
99

10+
- Fixed `ogbn_train_cugraph` example for distributed cuGraph ([#10439](https://github.com/pyg-team/pytorch_geometric/pull/10439))
1011
- Fixed importing PyTorch Lightning in `torch_geometric.graphgym` and `torch_geometric.data.lightning` when using `lightning` instead of `pytorch-lightning` ([#10404](https://github.com/pyg-team/pytorch_geometric/pull/10404), [#10417](https://github.com/pyg-team/pytorch_geometric/pull/10417))
1112
- Fixed `detach()` warnings in example scripts involving tensor conversions ([#10357](https://github.com/pyg-team/pytorch_geometric/pull/10357))
1213
- Fixed non-tuple indexing to resolve PyTorch deprecation warning ([#10389](https://github.com/pyg-team/pytorch_geometric/pull/10389))

examples/ogbn_train_cugraph.py

Lines changed: 121 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import argparse
2+
import os
23
import os.path as osp
34
import time
45

56
import cupy
67
import psutil
78
import rmm
89
import torch
10+
import torch.distributed as dist
911
from rmm.allocators.cupy import rmm_cupy_allocator
1012
from rmm.allocators.torch import rmm_torch_allocator
1113

@@ -31,6 +33,59 @@
3133
cudf.set_option("spill", True)
3234

3335

36+
# ---------------- Distributed helpers ----------------
37+
def safe_get_rank():
    """Return this process's distributed rank, or 0 without a process group.

    Uses the same ``is_available() and is_initialized()`` guard as
    ``init_distributed`` so the helper can be called before (or without)
    setting up ``torch.distributed``.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
39+
40+
41+
def safe_get_world_size():
    """Return the distributed world size, or 1 without a process group.

    Uses the same ``is_available() and is_initialized()`` guard as
    ``init_distributed`` so the helper can be called before (or without)
    setting up ``torch.distributed``.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
43+
44+
45+
def init_distributed():
    """Ensure a ``torch.distributed`` process group exists for this process.

    Returns immediately if a process group is already initialized.
    Otherwise it fills in any missing rendezvous environment variables
    (``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE``, ``LOCAL_WORLD_SIZE``,
    ``MASTER_ADDR``, ``MASTER_PORT``) with single-process defaults, selects
    the CUDA device given by ``LOCAL_RANK`` when CUDA is available, and
    creates an NCCL process group.

    Note that a process group is created even when ``WORLD_SIZE`` is 1: the
    single-process fallback is a one-rank group, not "no group".
    """
    # Already initialized: nothing to do.
    if dist.is_available() and dist.is_initialized():
        return

    # Defaults describing a single process on the local machine.
    default_env = {
        "RANK": "0",
        "LOCAL_RANK": "0",
        "WORLD_SIZE": "1",
        "LOCAL_WORLD_SIZE": "1",
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": "29500",
    }

    # setdefault() only fills in keys that are missing, so values already
    # present in the environment are left untouched.
    for key, value in default_env.items():
        os.environ.setdefault(key, value)

    # Bind this process to the GPU named by LOCAL_RANK.
    if torch.cuda.is_available():
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    world_size = int(os.environ["WORLD_SIZE"])
    if world_size > 1:
        # Multi-process run: rank and world size are read from the environment.
        dist.init_process_group(backend="nccl", init_method="env://")
        rank = os.environ['RANK']
        print(f"Initialized distributed: rank {rank}, world_size {world_size}")
    else:
        # Single-process run: still create a one-rank group, with rank and
        # world size passed explicitly rather than taken from the environment.
        print("Running in single-GPU / single-process mode")
        dist.init_process_group(backend="nccl", init_method="env://", rank=0,
                                world_size=1)
84+
85+
86+
# ------------------------------------------------------
87+
88+
3489
def arg_parse():
3590
parser = argparse.ArgumentParser(
3691
formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
@@ -98,15 +153,16 @@ def arg_parse():
98153

99154

100155
def create_loader(
156+
input_nodes,
157+
stage_name,
101158
data,
102159
num_neighbors,
103-
input_nodes,
104160
replace,
105161
batch_size,
106-
stage_name,
107162
shuffle=False,
108163
):
109-
print(f'Creating {stage_name} loader...')
164+
if safe_get_rank() == 0:
165+
print(f'Creating {stage_name} loader...')
110166

111167
return NeighborLoader(
112168
data,
@@ -118,7 +174,7 @@ def create_loader(
118174
)
119175

120176

121-
def train(model, train_loader):
177+
def train(model, train_loader, optimizer):
122178
model.train()
123179

124180
total_loss = total_correct = total_examples = 0
@@ -156,17 +212,26 @@ def test(model, loader):
156212

157213

158214
if __name__ == '__main__':
215+
# init DDP if needed
216+
init_distributed()
217+
159218
args = arg_parse()
160219
torch_geometric.seed_everything(123)
220+
161221
if "papers" in str(args.dataset) and (psutil.virtual_memory().total /
162222
(1024**3)) < 390:
163-
print("Warning: may not have enough RAM to use this many GPUs.")
164-
print("Consider upgrading RAM if an error occurs.")
165-
print("Estimated RAM Needed: ~390GB.")
223+
if safe_get_rank() == 0:
224+
print("Warning: may not have enough RAM to use this many GPUs.")
225+
print("Consider upgrading RAM if an error occurs.")
226+
print("Estimated RAM Needed: ~390GB.")
227+
166228
wall_clock_start = time.perf_counter()
167229

168230
root = osp.join(args.dataset_dir, args.dataset_subdir)
169-
print('The root is: ', root)
231+
232+
if safe_get_rank() == 0:
233+
print('The root is: ', root)
234+
170235
dataset = PygNodePropPredDataset(name=args.dataset, root=root)
171236
split_idx = dataset.get_idx_split()
172237

@@ -188,33 +253,30 @@ def test(model, loader):
188253
size=(data.num_nodes, data.num_nodes),
189254
)] = data.edge_index
190255

191-
feature_store = cugraph_pyg.data.TensorDictFeatureStore()
256+
feature_store = cugraph_pyg.data.FeatureStore()
192257
feature_store['node', 'x', None] = data.x
193258
feature_store['node', 'y', None] = data.y
194259

195260
data = (feature_store, graph_store)
196261

197-
print(f"Training {args.dataset} with {args.model} model.")
262+
if safe_get_rank() == 0:
263+
print(f"Training {args.dataset} with {args.model} model.")
264+
198265
if args.model == "GAT":
199266
model = torch_geometric.nn.models.GAT(dataset.num_features,
200267
args.hidden_channels,
201268
args.num_layers,
202269
dataset.num_classes,
203270
heads=args.num_heads).cuda()
204271
elif args.model == "GCN":
205-
model = torch_geometric.nn.models.GCN(
206-
dataset.num_features,
207-
args.hidden_channels,
208-
args.num_layers,
209-
dataset.num_classes,
210-
).cuda()
272+
model = torch_geometric.nn.models.GCN(dataset.num_features,
273+
args.hidden_channels,
274+
args.num_layers,
275+
dataset.num_classes).cuda()
211276
elif args.model == "SAGE":
212277
model = torch_geometric.nn.models.GraphSAGE(
213-
dataset.num_features,
214-
args.hidden_channels,
215-
args.num_layers,
216-
dataset.num_classes,
217-
).cuda()
278+
dataset.num_features, args.hidden_channels, args.num_layers,
279+
dataset.num_classes).cuda()
218280
elif args.model == 'SGFormer':
219281
# TODO add support for this with disjoint sampling
220282
model = torch_geometric.nn.models.SGFormer(
@@ -227,7 +289,7 @@ def test(model, loader):
227289
gnn_dropout=args.dropout,
228290
).cuda()
229291
else:
230-
raise ValueError('Unsupported model type: {args.model}')
292+
raise ValueError(f'Unsupported model type: {args.model}')
231293

232294
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
233295
weight_decay=args.wd)
@@ -239,69 +301,54 @@ def test(model, loader):
239301
batch_size=args.batch_size,
240302
)
241303

242-
train_loader = create_loader(
243-
input_nodes=split_idx['train'],
244-
stage_name='train',
245-
shuffle=True,
246-
**loader_kwargs,
247-
)
304+
train_loader = create_loader(split_idx['train'], 'train', **loader_kwargs,
305+
shuffle=True)
306+
val_loader = create_loader(split_idx['valid'], 'val', **loader_kwargs)
307+
test_loader = create_loader(split_idx['test'], 'test', **loader_kwargs)
248308

249-
val_loader = create_loader(
250-
input_nodes=split_idx['valid'],
251-
stage_name='val',
252-
**loader_kwargs,
253-
)
309+
if dist.is_initialized():
310+
dist.barrier() # sync before training
254311

255-
test_loader = create_loader(
256-
input_nodes=split_idx['test'],
257-
stage_name='test',
258-
**loader_kwargs,
259-
)
260-
prep_time = round(time.perf_counter() - wall_clock_start, 2)
261-
print("Total time before training begins (prep_time) =", prep_time,
262-
"seconds")
263-
print("Beginning training...")
264-
val_accs = []
265-
times = []
266-
train_times = []
267-
inference_times = []
312+
if safe_get_rank() == 0:
313+
prep_time = round(time.perf_counter() - wall_clock_start, 2)
314+
print("Total time before training begins (prep_time) =", prep_time,
315+
"seconds")
316+
print("Beginning training...")
317+
318+
val_accs, times, train_times, inference_times = [], [], [], []
268319
best_val = 0.
269320
start = time.perf_counter()
270-
epochs = args.epochs
271-
for epoch in range(1, epochs + 1):
321+
for epoch in range(1, args.epochs + 1):
272322
train_start = time.perf_counter()
273-
loss, train_acc = train(model, train_loader)
323+
loss, train_acc = train(model, train_loader, optimizer)
274324
train_end = time.perf_counter()
275325
train_times.append(train_end - train_start)
276326
inference_start = time.perf_counter()
277327
train_acc = test(model, train_loader)
278328
val_acc = test(model, val_loader)
279-
280329
inference_times.append(time.perf_counter() - inference_start)
281330
val_accs.append(val_acc)
282-
print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, Approx. Train:'
283-
f' {train_acc:.4f} Time: {train_end - train_start:.4f}s')
284-
print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, ')
331+
332+
if safe_get_rank() == 0:
333+
print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, '
334+
f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
335+
f'Time: {train_end - train_start:.4f}s')
285336

286337
times.append(time.perf_counter() - train_start)
287-
if val_acc > best_val:
288-
best_val = val_acc
289-
290-
print(f"Total time used: is {time.perf_counter()-start:.4f}")
291-
val_acc = torch.tensor(val_accs)
292-
print('============================')
293-
print("Average Epoch Time on training: {:.4f}".format(
294-
torch.tensor(train_times).mean()))
295-
print("Average Epoch Time on inference: {:.4f}".format(
296-
torch.tensor(inference_times).mean()))
297-
print(f"Average Epoch Time: {torch.tensor(times).mean():.4f}")
298-
print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")
299-
print(f'Final Validation: {val_acc.mean():.4f} ± {val_acc.std():.4f}')
300-
print(f"Best validation accuracy: {best_val:.4f}")
301-
302-
print("Testing...")
303-
final_test_acc = test(model, test_loader)
304-
print(f'Test Accuracy: {final_test_acc:.4f}')
305-
306-
total_time = round(time.perf_counter() - wall_clock_start, 2)
307-
print("Total Program Runtime (total_time) =", total_time, "seconds")
338+
best_val = max(best_val, val_acc)
339+
340+
if safe_get_rank() == 0:
341+
print(f"Total time used: {time.perf_counter()-start:.4f}")
342+
print("Final Validation: {:.4f} ± {:.4f}".format(
343+
torch.tensor(val_accs).mean(),
344+
torch.tensor(val_accs).std()))
345+
print(f"Best validation accuracy: {best_val:.4f}")
346+
print("Testing...")
347+
final_test_acc = test(model, test_loader)
348+
print(f'Test Accuracy: {final_test_acc:.4f}')
349+
total_time = round(time.perf_counter() - wall_clock_start, 2)
350+
print("Total Program Runtime (total_time) =", total_time, "seconds")
351+
352+
if dist.is_initialized():
353+
dist.barrier()
354+
dist.destroy_process_group()

0 commit comments

Comments
 (0)