@@ -1,11 +1,13 @@
 import argparse
+import os
 import os.path as osp
 import time
 
 import cupy
 import psutil
 import rmm
 import torch
+import torch.distributed as dist
 from rmm.allocators.cupy import rmm_cupy_allocator
 from rmm.allocators.torch import rmm_torch_allocator
 
@@ -31,6 +33,59 @@
 cudf.set_option("spill", True)
 
 
+# ---------------- Distributed helpers ----------------
+def safe_get_rank():
+    return dist.get_rank() if dist.is_initialized() else 0
+
+
+def safe_get_world_size():
+    return dist.get_world_size() if dist.is_initialized() else 1
+
+
+def init_distributed():
+    """Initialize distributed training if environment variables are set.
+    Fallback to single-GPU mode otherwise.
+    """
+    # Already initialized? nothing to do
+    if dist.is_available() and dist.is_initialized():
+        return
+
+    # Default env vars for single-GPU / single-process fallback
+    default_env = {
+        "RANK": "0",
+        "LOCAL_RANK": "0",
+        "WORLD_SIZE": "1",
+        "LOCAL_WORLD_SIZE": "1",
+        "MASTER_ADDR": "127.0.0.1",
+        "MASTER_PORT": "29500"
+    }
+
+    # Update environment only if keys are missing
+    for k, v in default_env.items():
+        os.environ.setdefault(k, v)
+
+    # Set CUDA device
+    if torch.cuda.is_available():
+        local_rank = int(os.environ["LOCAL_RANK"])
+        torch.cuda.set_device(local_rank)
+
+    # Initialize distributed only if world_size > 1
+    world_size = int(os.environ["WORLD_SIZE"])
+    if world_size > 1:
+        dist.init_process_group(backend="nccl", init_method="env://")
+        rank = os.environ['RANK']
+        print(f"Initialized distributed: rank {rank}, world_size {world_size}")
+    else:
+        print("Running in single-GPU / single-process mode")
+
+    if not dist.is_initialized():
+        dist.init_process_group(backend="nccl", init_method="env://", rank=0,
                                world_size=1)
+
+
+# ------------------------------------------------------
+
+
 def arg_parse():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
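Note for reviewers: `init_distributed()` is driven entirely by environment variables, so the same entry point works both under `torchrun` (which exports `RANK`, `LOCAL_RANK`, `WORLD_SIZE`, `MASTER_ADDR`, and `MASTER_PORT`) and as a plain `python` invocation. Below is a minimal sketch of the single-process fallback path, assuming a CUDA-capable GPU (the helper uses the NCCL backend) and a hypothetical module name `gnn_train` for the patched script; neither assumption is part of this PR.

# Sketch only: exercise the single-process fallback of init_distributed().
import os

import torch.distributed as dist

from gnn_train import init_distributed, safe_get_rank, safe_get_world_size

init_distributed()                       # no RANK/WORLD_SIZE set -> defaults apply
assert os.environ["WORLD_SIZE"] == "1"   # filled in via os.environ.setdefault
assert safe_get_rank() == 0              # the only rank in single-process mode
assert safe_get_world_size() == 1

if dist.is_initialized():                # a one-rank NCCL group was still created
    dist.barrier()                       # returns immediately with a single rank
    dist.destroy_process_group()         # mirrors the cleanup at the end of the script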
@@ -98,15 +153,16 @@ def arg_parse():
 
 
 def create_loader(
+    input_nodes,
+    stage_name,
     data,
     num_neighbors,
-    input_nodes,
     replace,
     batch_size,
-    stage_name,
     shuffle=False,
 ):
-    print(f'Creating {stage_name} loader...')
+    if safe_get_rank() == 0:
+        print(f'Creating {stage_name} loader...')
 
     return NeighborLoader(
         data,
@@ -118,7 +174,7 @@ def create_loader(
     )
 
 
-def train(model, train_loader):
+def train(model, train_loader, optimizer):
     model.train()
 
     total_loss = total_correct = total_examples = 0
@@ -156,17 +212,26 @@ def test(model, loader):
 
 
 if __name__ == '__main__':
+    # init DDP if needed
+    init_distributed()
+
     args = arg_parse()
     torch_geometric.seed_everything(123)
+
     if "papers" in str(args.dataset) and (psutil.virtual_memory().total /
                                           (1024**3)) < 390:
-        print("Warning: may not have enough RAM to use this many GPUs.")
-        print("Consider upgrading RAM if an error occurs.")
-        print("Estimated RAM Needed: ~390GB.")
+        if safe_get_rank() == 0:
+            print("Warning: may not have enough RAM to use this many GPUs.")
+            print("Consider upgrading RAM if an error occurs.")
+            print("Estimated RAM Needed: ~390GB.")
+
     wall_clock_start = time.perf_counter()
 
     root = osp.join(args.dataset_dir, args.dataset_subdir)
-    print('The root is: ', root)
+
+    if safe_get_rank() == 0:
+        print('The root is: ', root)
+
     dataset = PygNodePropPredDataset(name=args.dataset, root=root)
     split_idx = dataset.get_idx_split()
 
@@ -188,33 +253,30 @@ def test(model, loader):
         size=(data.num_nodes, data.num_nodes),
     )] = data.edge_index
 
-    feature_store = cugraph_pyg.data.TensorDictFeatureStore()
+    feature_store = cugraph_pyg.data.FeatureStore()
     feature_store['node', 'x', None] = data.x
     feature_store['node', 'y', None] = data.y
 
     data = (feature_store, graph_store)
 
-    print(f"Training {args.dataset} with {args.model} model.")
+    if safe_get_rank() == 0:
+        print(f"Training {args.dataset} with {args.model} model.")
+
     if args.model == "GAT":
         model = torch_geometric.nn.models.GAT(dataset.num_features,
                                               args.hidden_channels,
                                               args.num_layers,
                                               dataset.num_classes,
                                               heads=args.num_heads).cuda()
     elif args.model == "GCN":
-        model = torch_geometric.nn.models.GCN(
-            dataset.num_features,
-            args.hidden_channels,
-            args.num_layers,
-            dataset.num_classes,
-        ).cuda()
+        model = torch_geometric.nn.models.GCN(dataset.num_features,
+                                              args.hidden_channels,
+                                              args.num_layers,
+                                              dataset.num_classes).cuda()
     elif args.model == "SAGE":
         model = torch_geometric.nn.models.GraphSAGE(
-            dataset.num_features,
-            args.hidden_channels,
-            args.num_layers,
-            dataset.num_classes,
-        ).cuda()
+            dataset.num_features, args.hidden_channels, args.num_layers,
+            dataset.num_classes).cuda()
     elif args.model == 'SGFormer':
         # TODO add support for this with disjoint sampling
         model = torch_geometric.nn.models.SGFormer(
@@ -227,7 +289,7 @@ def test(model, loader):
             gnn_dropout=args.dropout,
         ).cuda()
     else:
-        raise ValueError('Unsupported model type: {args.model}')
+        raise ValueError(f'Unsupported model type: {args.model}')
 
     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                  weight_decay=args.wd)
@@ -239,69 +301,54 @@ def test(model, loader):
         batch_size=args.batch_size,
     )
 
-    train_loader = create_loader(
-        input_nodes=split_idx['train'],
-        stage_name='train',
-        shuffle=True,
-        **loader_kwargs,
-    )
+    train_loader = create_loader(split_idx['train'], 'train', **loader_kwargs,
+                                 shuffle=True)
+    val_loader = create_loader(split_idx['valid'], 'val', **loader_kwargs)
+    test_loader = create_loader(split_idx['test'], 'test', **loader_kwargs)
 
-    val_loader = create_loader(
-        input_nodes=split_idx['valid'],
-        stage_name='val',
-        **loader_kwargs,
-    )
+    if dist.is_initialized():
+        dist.barrier()  # sync before training
 
-    test_loader = create_loader(
-        input_nodes=split_idx['test'],
-        stage_name='test',
-        **loader_kwargs,
-    )
-    prep_time = round(time.perf_counter() - wall_clock_start, 2)
-    print("Total time before training begins (prep_time) =", prep_time,
-          "seconds")
-    print("Beginning training...")
-    val_accs = []
-    times = []
-    train_times = []
-    inference_times = []
+    if safe_get_rank() == 0:
+        prep_time = round(time.perf_counter() - wall_clock_start, 2)
+        print("Total time before training begins (prep_time) =", prep_time,
+              "seconds")
+        print("Beginning training...")
+
+    val_accs, times, train_times, inference_times = [], [], [], []
     best_val = 0.
     start = time.perf_counter()
-    epochs = args.epochs
-    for epoch in range(1, epochs + 1):
+    for epoch in range(1, args.epochs + 1):
         train_start = time.perf_counter()
-        loss, train_acc = train(model, train_loader)
+        loss, train_acc = train(model, train_loader, optimizer)
         train_end = time.perf_counter()
         train_times.append(train_end - train_start)
         inference_start = time.perf_counter()
         train_acc = test(model, train_loader)
         val_acc = test(model, val_loader)
-
         inference_times.append(time.perf_counter() - inference_start)
         val_accs.append(val_acc)
-        print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, Approx. Train:'
-              f' {train_acc:.4f} Time: {train_end - train_start:.4f}s')
-        print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, ')
+
+        if safe_get_rank() == 0:
+            print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, '
+                  f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
+                  f'Time: {train_end - train_start:.4f}s')
 
         times.append(time.perf_counter() - train_start)
-        if val_acc > best_val:
-            best_val = val_acc
-
-    print(f"Total time used: is {time.perf_counter()-start:.4f}")
-    val_acc = torch.tensor(val_accs)
-    print('============================')
-    print("Average Epoch Time on training: {:.4f}".format(
-        torch.tensor(train_times).mean()))
-    print("Average Epoch Time on inference: {:.4f}".format(
-        torch.tensor(inference_times).mean()))
-    print(f"Average Epoch Time: {torch.tensor(times).mean():.4f}")
-    print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")
-    print(f'Final Validation: {val_acc.mean():.4f} ± {val_acc.std():.4f}')
-    print(f"Best validation accuracy: {best_val:.4f}")
-
-    print("Testing...")
-    final_test_acc = test(model, test_loader)
-    print(f'Test Accuracy: {final_test_acc:.4f}')
-
-    total_time = round(time.perf_counter() - wall_clock_start, 2)
-    print("Total Program Runtime (total_time) =", total_time, "seconds")
+        best_val = max(best_val, val_acc)
+
+    if safe_get_rank() == 0:
+        print(f"Total time used: {time.perf_counter()-start:.4f}")
+        print("Final Validation: {:.4f} ± {:.4f}".format(
+            torch.tensor(val_accs).mean(),
+            torch.tensor(val_accs).std()))
+        print(f"Best validation accuracy: {best_val:.4f}")
+        print("Testing...")
+        final_test_acc = test(model, test_loader)
+        print(f'Test Accuracy: {final_test_acc:.4f}')
+        total_time = round(time.perf_counter() - wall_clock_start, 2)
+        print("Total Program Runtime (total_time) =", total_time, "seconds")
+
+    if dist.is_initialized():
+        dist.barrier()
+        dist.destroy_process_group()
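In multi-GPU runs the script is expected to be launched with something like `torchrun --nproc_per_node=<num_gpus> <script>.py`, so every rank executes the same `__main__` block and the `if safe_get_rank() == 0:` guards keep the log output to a single copy; a plain `python` run falls back to the single-process defaults shown earlier. The sketch below (not part of the patch, names are illustrative) expresses that repeated guard as one helper:

# Sketch only, not part of this PR: a rank-zero logging helper equivalent to
# the repeated `if safe_get_rank() == 0:` guards introduced by this diff.
import torch.distributed as dist


def rank_zero_print(*args, **kwargs):
    """Print only on rank 0 so multi-GPU runs do not duplicate log lines."""
    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    if rank == 0:
        print(*args, **kwargs)


# Example usage, matching one of the patched calls:
# rank_zero_print(f"Training {args.dataset} with {args.model} model.")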