File: ogbn_train_cugraph.py

package info (click to toggle)
pytorch-geometric 2.7.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 14,172 kB
  • sloc: python: 144,911; sh: 247; cpp: 27; makefile: 18; javascript: 16
file content (392 lines) | stat: -rw-r--r-- 12,605 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
"""Single-node, multi-GPU example."""

import argparse
import os
import os.path as osp
import time

import cupy
import rmm
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
from cugraph.gnn import (
    cugraph_comms_create_unique_id,
    cugraph_comms_init,
    cugraph_comms_shutdown,
)
from ogb.nodeproppred import PygNodePropPredDataset
from torch.nn.parallel import DistributedDataParallel

import torch_geometric
from torch_geometric import seed_everything
from torch_geometric.utils import (
    add_self_loops,
    remove_self_loops,
    to_undirected,
)

# RAPIDS switches, set when this module is imported:
#  * CUDF_SPILL lets cuDF work on objects that are larger than GPU memory,
#    see https://docs.rapids.ai/api/cudf/stable/developer_guide/library_design/#spilling-to-host-memory
#  * RAPIDS_NO_INITIALIZE keeps the import of RAPIDS from creating a CUDA
#    context, so that PyTorch creates the context instead.
os.environ.update({
    'CUDF_SPILL': '1',
    'RAPIDS_NO_INITIALIZE': '1',
})


def arg_parse(argv=None):
    """Define the command-line interface of this example and parse it.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` falls back to ``sys.argv[1:]``, so a
            plain ``arg_parse()`` call behaves exactly as before. Passing an
            explicit list allows the parser to be driven programmatically.

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
    parser.add_argument(
        '--dataset',
        type=str,
        default='ogbn-arxiv',
        choices=['ogbn-papers100M', 'ogbn-products', 'ogbn-arxiv'],
        help='Dataset name.',
    )
    parser.add_argument(
        '--dataset_dir',
        type=str,
        default='/workspace/data',
        help='Root directory of dataset.',
    )
    parser.add_argument(
        "--dataset_subdir",
        type=str,
        default="ogbn-arxiv",
        help="directory of dataset.",
    )
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--wd', type=float, default=0.000)
    parser.add_argument('-e', '--epochs', type=int, default=50)
    parser.add_argument('-b', '--batch_size', type=int, default=1024)
    parser.add_argument('--fan_out', type=int, default=10)
    parser.add_argument('--warmup_steps', type=int, default=20)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument(
        '--use_directed_graph',
        action='store_true',
        help='Whether or not to use directed graph',
    )
    parser.add_argument(
        '--add_self_loop',
        action='store_true',
        help='Whether or not to add self loop',
    )
    parser.add_argument(
        "--model",
        type=str,
        default='GCN',
        choices=[
            'SAGE',
            'GAT',
            'GCN',
            # TODO: Uncomment when we add support for disjoint sampling
            # 'SGFormer',
        ],
        help="Model used for training, default GCN",
    )
    parser.add_argument(
        "--num_heads",
        type=int,
        default=1,
        help="If using GATConv or GT, number of attention heads to use",
    )
    parser.add_argument(
        '--num_devices',
        type=int,
        default=-1,
        help='How many GPUs to use. Defaults to all available GPUs',
    )
    parser.add_argument(
        '--verbose',
        action='store_true',
        help='Whether or not to generate statistical report',
    )
    return parser.parse_args(argv)


def evaluate(rank, loader, model):
    """Return the classification accuracy of ``model`` over ``loader``.

    For every batch only the first ``batch.batch_size`` rows of the model
    output and of ``batch.y`` are compared; the remaining rows are ignored.
    Inference runs without gradient tracking and, when ``model`` is a
    ``torch.nn.Module``, in eval mode. The module's previous training mode
    is restored before returning.

    Args:
        rank: Device (index) passed to ``batch.to(...)``.
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``y``,
            ``batch_size`` and ``to()``.
        model: Callable invoked as ``model(batch.x, batch.edge_index)`` that
            returns one row of class scores per node.

    Returns:
        float: Fraction of correctly predicted examples, or ``0.0`` when the
        loader yields no examples at all.
    """
    # Remember the mode so that a model that was being trained keeps training
    # after evaluation, regardless of what the caller does afterwards.
    is_module = isinstance(model, torch.nn.Module)
    was_training = is_module and model.training
    if is_module:
        model.eval()

    total_correct = 0
    total_examples = 0
    try:
        with torch.no_grad():
            for batch in loader:
                batch = batch.to(rank)
                batch_size = batch.batch_size

                out = model(batch.x, batch.edge_index)[:batch_size]

                pred = out.argmax(dim=-1)
                y = batch.y[:batch_size].view(-1).to(torch.long)

                # Stays a tensor while iterating; converted once at the end.
                total_correct += (pred == y).sum()
                total_examples += y.size(0)
    finally:
        if was_training:
            model.train()

    # An empty loader (e.g. everything dropped by ``drop_last``) would
    # otherwise divide by zero.
    if total_examples == 0:
        return 0.0
    return int(total_correct) / total_examples


def init_pytorch_worker(rank, world_size, cugraph_id):
    """Set up GPU memory management and communication for one worker.

    The steps, in order: re-initialize RMM on device ``rank`` with managed
    memory and a pool allocator, route CuPy allocations through RMM, enable
    cuDF spilling, select the CUDA device for PyTorch, create the NCCL
    process group and finally initialize the cuGraph communicator.

    Args:
        rank: Index of this worker; also used as its CUDA device index.
        world_size: Total number of workers.
        cugraph_id: Unique id handed to ``cugraph_comms_init`` (created by
            the caller with ``cugraph_comms_create_unique_id``).
    """
    rmm.reinitialize(
        devices=rank,
        managed_memory=True,
        pool_allocator=True,
    )

    cupy.cuda.Device(rank).use()
    from rmm.allocators.cupy import rmm_cupy_allocator

    cupy.cuda.set_allocator(rmm_cupy_allocator)

    # NOTE(review): cudf is imported here rather than at module level,
    # presumably so the import happens after RMM is configured - confirm.
    import cudf
    cudf.set_option("spill", True)
    torch.cuda.set_device(rank)

    # Only fill in the rendezvous address/port when the environment does not
    # provide them, so a launcher or user can pick a free port instead of
    # colliding with another job on the hard-coded one.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')
    dist.init_process_group('nccl', rank=rank, world_size=world_size)

    cugraph_comms_init(rank=rank, world_size=world_size, uid=cugraph_id,
                       device=rank)


def run_train(rank, args, data, world_size, cugraph_id, model, split_idx,
              num_classes, wall_clock_start):
    """Train, validate and test ``model`` on GPU ``rank``.

    The worker initializes its communication backends, wraps the model in
    ``DistributedDataParallel``, stores its slice of the edges in a
    multi-GPU cuGraph ``GraphStore`` and the node features/labels in a
    ``TensorDictFeatureStore``, then builds neighbor loaders over its share
    of the train/valid/test indices. After the epoch loop it evaluates on
    the test loader, prints timing statistics on rank 0 and shuts the
    communication backends down.

    Accuracies are computed from this rank's loaders only; nothing is
    reduced across ranks before printing.

    Args:
        rank: Index of this worker; also used as its CUDA device index.
        args: Parsed command-line options. Reads ``epochs``, ``batch_size``,
            ``fan_out``, ``num_layers``, ``lr``, ``wd`` and ``warmup_steps``.
        data: Graph object providing ``edge_index``, ``num_nodes``, ``x``
            and ``y``.
        world_size: Total number of workers.
        cugraph_id: Unique id used to initialize the cuGraph communicator.
        model: The (unwrapped) model to train.
        split_idx: Mapping with ``'train'``, ``'valid'`` and ``'test'``
            index tensors.
        num_classes: Not used by this function; kept so existing callers
            keep working.
        wall_clock_start: ``time.perf_counter()`` value taken at program
            start, used for the reported preparation and total times.
    """
    epochs = args.epochs
    batch_size = args.batch_size
    fan_out = args.fan_out
    num_layers = args.num_layers
    warmup_steps = args.warmup_steps

    init_pytorch_worker(
        rank,
        world_size,
        cugraph_id,
    )

    model = model.to(rank)
    model = DistributedDataParallel(model, device_ids=[rank])
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 weight_decay=args.wd)

    # Sampling options shared by all three loaders: ``fan_out`` neighbors
    # for each of the ``num_layers`` hops.
    kwargs = dict(
        num_neighbors=[fan_out] * num_layers,
        batch_size=batch_size,
    )
    from cugraph_pyg.data import GraphStore, TensorDictFeatureStore
    from cugraph_pyg.loader import NeighborLoader

    # Every rank contributes one column-wise slice of the edge list.
    graph_store = GraphStore(is_multi_gpu=True)
    ixr = torch.tensor_split(data.edge_index, world_size, dim=1)[rank]
    graph_store[dict(
        edge_type=('node', 'rel', 'node'),
        layout='coo',
        is_sorted=False,
        size=(data.num_nodes, data.num_nodes),
    )] = ixr

    feature_store = TensorDictFeatureStore()
    feature_store['node', 'x', None] = data.x
    feature_store['node', 'y', None] = data.y

    dist.barrier()

    # Each rank seeds its loaders with its own share of the split indices.
    ix_train = torch.tensor_split(split_idx['train'], world_size)[rank].cuda()
    train_loader = NeighborLoader(
        (feature_store, graph_store),
        input_nodes=ix_train,
        shuffle=True,
        drop_last=True,
        **kwargs,
    )

    ix_val = torch.tensor_split(split_idx['valid'], world_size)[rank].cuda()
    val_loader = NeighborLoader(
        (feature_store, graph_store),
        input_nodes=ix_val,
        drop_last=True,
        **kwargs,
    )

    ix_test = torch.tensor_split(split_idx['test'], world_size)[rank].cuda()
    test_loader = NeighborLoader(
        (feature_store, graph_store),
        input_nodes=ix_test,
        drop_last=True,
        local_seeds_per_call=80000,
        **kwargs,
    )

    # Make sure every rank has finished its set-up before timing starts.
    dist.barrier()
    torch.cuda.synchronize()

    if rank == 0:
        prep_time = time.perf_counter() - wall_clock_start
        print("Total time before training begins (prep_time) =", prep_time,
              "seconds")
        print("Beginning training...")

    val_accs = []
    times = []
    train_times = []
    inference_times = []
    best_val = 0.
    start = time.perf_counter()
    for epoch in range(1, epochs + 1):
        train_start = time.perf_counter()
        # Explicitly (re-)enter training mode in case evaluation left the
        # model in eval mode.
        model.train()
        total_loss = 0.0
        num_batches = 0
        for i, batch in enumerate(train_loader):
            if i == warmup_steps:
                torch.cuda.synchronize()
            batch = batch.to(rank)
            num_seeds = batch.batch_size
            batch.y = batch.y.to(torch.long)
            optimizer.zero_grad()
            out = model(batch.x, batch.edge_index)
            # Only the first ``num_seeds`` rows take part in the loss.
            loss = F.cross_entropy(out[:num_seeds], batch.y[:num_seeds])
            loss.backward()
            optimizer.step()
            # Detach so the running sum does not keep the autograd history
            # of every step alive for the whole epoch.
            total_loss += loss.detach()
            num_batches += 1
        train_end = time.perf_counter()
        train_times.append(train_end - train_start)
        # Mean loss per batch; an empty loader reports 0 instead of dividing
        # by zero.
        total_loss = float(total_loss) / max(num_batches, 1)
        dist.barrier()
        torch.cuda.synchronize()

        inference_start = time.perf_counter()
        train_acc = evaluate(rank, train_loader, model)
        dist.barrier()
        val_acc = evaluate(rank, val_loader, model)
        dist.barrier()

        inference_times.append(time.perf_counter() - inference_start)
        val_accs.append(val_acc)
        if rank == 0:
            print(f'Epoch {epoch:02d}, Loss: {total_loss:.4f}, Approx. Train:'
                  f' {train_acc:.4f} Time: {train_end - train_start:.4f}s')
            print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, ')

        times.append(time.perf_counter() - train_start)
        if val_acc > best_val:
            best_val = val_acc

    print(f'Total time used for rank: {rank:02d} is '
          f'{time.perf_counter()-start:.4f}')
    if rank == 0:
        val_acc_history = torch.tensor(val_accs)
        print('============================')
        print("Average Epoch Time on training: {:.4f}".format(
            torch.tensor(train_times).mean()))
        print("Average Epoch Time on inference: {:.4f}".format(
            torch.tensor(inference_times).mean()))
        print(f"Average Epoch Time: {torch.tensor(times).mean():.4f}")
        print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")
        print(f'Final Validation: {val_acc_history.mean():.4f} ± '
              f'{val_acc_history.std():.4f}')
        print(f"Best validation accuracy: {best_val:.4f}")
        print("Testing...")

    final_test_acc = evaluate(rank, test_loader, model)
    dist.barrier()
    if rank == 0:
        print(f'Test Accuracy: {final_test_acc:.4f} for rank: {rank:02d}')
        total_time = time.perf_counter() - wall_clock_start
        print(f"Total Training Runtime: {total_time - prep_time}s")
        print(f"Total Program Runtime: {total_time}s")

    cugraph_comms_shutdown()
    dist.destroy_process_group()


if __name__ == '__main__':
    # Script entry point: parse options, load and preprocess the OGB
    # dataset, build the requested model and run one training worker per
    # GPU.

    args = arg_parse()

    # Fail early with a clear message instead of ending up with a world size
    # of zero, which only surfaces later as an obscure error in a worker.
    num_gpus = torch.cuda.device_count()
    if num_gpus < 1:
        raise RuntimeError(
            'This example requires at least one CUDA device, but none '
            'was found.')

    seed_everything(123)
    wall_clock_start = time.perf_counter()

    root = osp.join(args.dataset_dir, args.dataset_subdir)
    dataset = PygNodePropPredDataset(name=args.dataset, root=root)
    split_idx = dataset.get_idx_split()
    data = dataset[0]
    if not args.use_directed_graph:
        data.edge_index = to_undirected(data.edge_index, reduce="mean")
    if args.add_self_loop:
        # Drop existing self loops first so that adding them back does not
        # create duplicates.
        data.edge_index, _ = remove_self_loops(data.edge_index)
        data.edge_index, _ = add_self_loops(data.edge_index,
                                            num_nodes=data.num_nodes)
    data.y = data.y.reshape(-1)

    print(f"Training {args.dataset} with {args.model} model.")
    if args.model == "GAT":
        model = torch_geometric.nn.models.GAT(dataset.num_features,
                                              args.hidden_channels,
                                              args.num_layers,
                                              dataset.num_classes,
                                              heads=args.num_heads)
    elif args.model == "GCN":
        model = torch_geometric.nn.models.GCN(
            dataset.num_features,
            args.hidden_channels,
            args.num_layers,
            dataset.num_classes,
        )
    elif args.model == "SAGE":
        model = torch_geometric.nn.models.GraphSAGE(
            dataset.num_features,
            args.hidden_channels,
            args.num_layers,
            dataset.num_classes,
        )
    elif args.model == 'SGFormer':
        # TODO add support for this with disjoint sampling
        # (not selectable from the command line until then).
        model = torch_geometric.nn.models.SGFormer(
            in_channels=dataset.num_features,
            hidden_channels=args.hidden_channels,
            out_channels=dataset.num_classes,
            trans_num_heads=args.num_heads,
            trans_dropout=args.dropout,
            gnn_num_layers=args.num_layers,
            gnn_dropout=args.dropout,
        )
    else:
        raise ValueError(f'Unsupported model type: {args.model}')

    print("Data =", data)
    # A non-positive request means "all GPUs"; larger requests are clamped
    # to what is actually available.
    if args.num_devices < 1:
        world_size = num_gpus
    else:
        world_size = min(args.num_devices, num_gpus)
    print('Let\'s use', world_size, 'GPUs!')

    # Create the uid needed for cuGraph comms
    cugraph_id = cugraph_comms_create_unique_id()

    if world_size > 1:
        mp.spawn(
            run_train,
            args=(args, data, world_size, cugraph_id, model, split_idx,
                  dataset.num_classes, wall_clock_start),
            nprocs=world_size,
            join=True,
        )
    else:
        # Single GPU: run the worker in this process as rank 0.
        run_train(0, args, data, world_size, cugraph_id, model, split_idx,
                  dataset.num_classes, wall_clock_start)

    total_time = round(time.perf_counter() - wall_clock_start, 2)
    print("Total Program Runtime (total_time) =", total_time, "seconds")