File: papers100m_gcn_cugraph.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (272 lines) | stat: -rw-r--r-- 8,499 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
"""Single-node, multi-GPU example."""

import argparse
import os
import os.path as osp
import tempfile
import time

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
from ogb.nodeproppred import PygNodePropPredDataset
from torch.nn.parallel import DistributedDataParallel

import torch_geometric

# Enable cuDF spilling so that computation can run on objects that are larger
# than GPU memory (see "spilling to host memory" in the cuDF developer guide):
# https://docs.rapids.ai/api/cudf/stable/developer_guide/library_design/#spilling-to-host-memory
os.environ['CUDF_SPILL'] = '1'

# Ensure that a CUDA context is not created when RAPIDS is imported, which
# allows PyTorch to create the context instead.
os.environ['RAPIDS_NO_INITIALIZE'] = '1'


def init_pytorch_worker(rank, world_size, cugraph_id):
    """Set up the memory allocators and communicators of one worker process.

    The steps run in a fixed order: RMM is re-initialized for the worker's
    device, CuPy is pointed at that device and at the RMM allocator,
    spilling is enabled, the PyTorch CUDA device is selected, the NCCL
    process group is created and finally cuGraph comms are initialized.

    Args:
        rank: Index of this worker. It is also used as the CUDA device
            index for RMM, CuPy, PyTorch and cuGraph.
        world_size: Total number of worker processes in the group.
        cugraph_id: Unique id forwarded to ``cugraph_comms_init`` as ``uid``.
    """
    # The imports are kept local so that the RAPIDS libraries are only
    # loaded inside the worker process that calls this function.
    import rmm

    rmm.reinitialize(
        devices=rank,
        managed_memory=True,
        pool_allocator=True,
    )

    import cupy

    cupy.cuda.Device(rank).use()
    from rmm.allocators.cupy import rmm_cupy_allocator

    # Route CuPy allocations through RMM as well.
    cupy.cuda.set_allocator(rmm_cupy_allocator)

    from cugraph.testing.mg_utils import enable_spilling

    enable_spilling()

    torch.cuda.set_device(rank)

    # This is a single-node example, so the rendezvous address is always the
    # local host.  The port keeps its historical default but can now be
    # overridden through the environment, which avoids collisions when
    # several jobs share one machine.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ.setdefault('MASTER_PORT', '12355')
    dist.init_process_group('nccl', rank=rank, world_size=world_size)

    from cugraph.gnn import cugraph_comms_init
    cugraph_comms_init(rank=rank, world_size=world_size, uid=cugraph_id,
                       device=rank)


def _evaluate(model, loader, rank, max_steps=None):
    """Count correct seed-node predictions of ``model`` over ``loader``.

    The model is switched to evaluation mode and gradients are disabled for
    the duration of the loop.

    Args:
        model: Module called as ``model(batch.x, batch.edge_index)``.
        loader: Iterable of mini-batches exposing ``x``, ``y``,
            ``edge_index`` and ``batch_size``.
        rank: Device index the batches are moved to.
        max_steps: Optional cap on the number of batches that are scored;
            ``None`` scores every batch.

    Returns:
        A ``(num_correct, num_examples)`` tuple of Python ints.
    """
    model.eval()
    num_correct = num_examples = 0
    with torch.no_grad():
        for step, batch in enumerate(loader):
            if max_steps is not None and step >= max_steps:
                break

            batch = batch.to(rank)
            # Only the first ``batch.batch_size`` rows are scored, the same
            # slice that the training loss in ``run`` uses.
            num_seeds = batch.batch_size

            logits = model(batch.x, batch.edge_index)[:num_seeds]
            pred = logits.argmax(dim=-1)
            y = batch.y[:num_seeds].view(-1).to(torch.long)

            num_correct += int((pred == y).sum())
            num_examples += y.size(0)

    return num_correct, num_examples


def run(rank, data, world_size, cugraph_id, model, epochs, batch_size, fan_out,
        split_idx, num_classes, wall_clock_start, tempdir=None, num_layers=3):
    """Train and evaluate ``model`` on one GPU of a multi-GPU job.

    Each worker initializes its communicators, loads its slice of the edge
    list into a multi-GPU graph store, builds neighbor loaders over its
    shard of the train/valid/test node indices, trains for ``epochs``
    epochs with validation after each one, and finally runs the test
    loader.  Accuracies are computed on the local shard only and are
    printed by rank 0; no cross-rank reduction is performed.

    Args:
        rank: Index of this worker, also used as the CUDA device index.
        data: Graph object providing ``edge_index``, ``x``, ``y`` and
            ``num_nodes``.
        world_size: Total number of worker processes.
        cugraph_id: Unique id forwarded to ``init_pytorch_worker``.
        model: Module to train; it is moved to device ``rank`` and wrapped
            in ``DistributedDataParallel``.
        epochs: Number of training epochs.
        batch_size: Batch size handed to every neighbor loader.
        fan_out: Number of neighbors sampled per layer.
        split_idx: Mapping with ``'train'``, ``'valid'`` and ``'test'``
            index tensors.
        num_classes: Unused; kept for interface compatibility.
        wall_clock_start: ``time.perf_counter()`` reading taken by the
            launcher, used to report preparation and total run time.
        tempdir: Directory under which the per-rank loader directories
            are created.  When ``None`` a private temporary directory is
            created and removed again at the end of the run.
        num_layers: Number of sampling hops (length of ``num_neighbors``).
    """
    init_pytorch_worker(rank, world_size, cugraph_id)

    # The signature advertises ``tempdir=None``, so honour it instead of
    # failing inside ``osp.join``.
    scratch = None
    if tempdir is None:
        scratch = tempfile.TemporaryDirectory()
        tempdir = scratch.name

    model = model.to(rank)
    model = DistributedDataParallel(model, device_ids=[rank])
    # NOTE(review): the learning rate is fixed here and is not configurable
    # through the function's parameters.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01,
                                 weight_decay=0.0005)

    loader_kwargs = dict(
        num_neighbors=[fan_out] * num_layers,
        batch_size=batch_size,
    )
    from cugraph_pyg.data import GraphStore, TensorDictFeatureStore
    from cugraph_pyg.loader import NeighborLoader

    graph_store = GraphStore(is_multi_gpu=True)
    # Each rank contributes only its own slice of the edge list.
    local_edges = torch.tensor_split(data.edge_index, world_size, dim=1)[rank]
    graph_store[dict(
        edge_type=('node', 'rel', 'node'),
        layout='coo',
        is_sorted=False,
        size=(data.num_nodes, data.num_nodes),
    )] = local_edges

    feature_store = TensorDictFeatureStore()
    feature_store['node', 'x'] = data.x
    feature_store['node', 'y'] = data.y

    dist.barrier()

    def build_loader(split_key, dir_prefix, **extra):
        """Create the loader for this rank's shard of one node split."""
        shards = torch.tensor_split(split_idx[split_key], world_size)
        seeds = shards[rank].cuda()
        directory = osp.join(tempdir, f'{dir_prefix}_{rank}')
        os.mkdir(directory)
        return NeighborLoader(
            (feature_store, graph_store),
            input_nodes=seeds,
            directory=directory,
            drop_last=True,
            **extra,
            **loader_kwargs,
        )

    train_loader = build_loader('train', 'train', shuffle=True)
    val_loader = build_loader('valid', 'val')
    test_loader = build_loader('test', 'test', local_seeds_per_call=80000)

    eval_steps = 1000  # upper bound on validation batches per epoch
    warmup_steps = 20  # iterations excluded from the timing average
    dist.barrier()
    torch.cuda.synchronize()

    prep_time = time.perf_counter() - wall_clock_start
    if rank == 0:
        print(f"Preparation time: {prep_time:.2f}s")

    for epoch in range(1, epochs + 1):
        model.train()
        # ``start`` stays ``None`` when the epoch is shorter than the
        # warm-up, so the timing report below can never hit an unbound name
        # or a division by zero.
        start = None
        num_iters = 0
        for i, batch in enumerate(train_loader):
            if i == warmup_steps:
                torch.cuda.synchronize()
                start = time.time()

            batch = batch.to(rank)
            # Separate name: ``batch_size`` is a parameter of this function.
            num_seeds = batch.batch_size

            batch.y = batch.y.to(torch.long)
            optimizer.zero_grad()
            out = model(batch.x, batch.edge_index)
            loss = F.cross_entropy(out[:num_seeds], batch.y[:num_seeds])
            loss.backward()
            optimizer.step()
            if rank == 0 and i % 10 == 0:
                print(f"Epoch: {epoch:02d}, Iteration: {i}, Loss: {loss:.4f}")
            num_iters = i + 1

        if rank == 0:
            if start is None:
                print(f"Avg Training Iteration Time: n/a "
                      f"({num_iters} iterations, {warmup_steps} warmup)")
            else:
                timed_iters = num_iters - warmup_steps
                print(f"Avg Training Iteration Time: "
                      f"{(time.time() - start) / timed_iters:.4f} s/iter")

        correct, examples = _evaluate(model, val_loader, rank, eval_steps)
        # An empty loader (possible with ``drop_last=True``) yields NaN
        # rather than a ZeroDivisionError.
        acc_val = correct / examples if examples else float('nan')
        if rank == 0:
            print(f"Validation Accuracy: {acc_val * 100:.2f}%")

        torch.cuda.synchronize()

    correct, examples = _evaluate(model, test_loader, rank)
    acc_test = correct / examples if examples else float('nan')
    if rank == 0:
        print(f"Test Accuracy: {acc_test * 100:.2f}%")

    if rank == 0:
        total_time = time.perf_counter() - wall_clock_start
        print(f"Total Training Runtime: {total_time - prep_time}s")
        print(f"Total Program Runtime: {total_time}s")

    from cugraph.gnn import cugraph_comms_shutdown
    cugraph_comms_shutdown()
    dist.destroy_process_group()

    if scratch is not None:
        scratch.cleanup()


if __name__ == '__main__':
    # Script entry point: parse the options, load the dataset, build the
    # model and spawn one ``run`` worker per GPU.
    parser = argparse.ArgumentParser()
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--num_layers', type=int, default=2)
    # NOTE(review): ``--lr`` is parsed but never forwarded to ``run``, whose
    # optimizer uses a fixed learning rate; confirm whether that is intended.
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--epochs', type=int, default=4)
    parser.add_argument('--batch_size', type=int, default=1024)
    parser.add_argument('--fan_out', type=int, default=30)
    parser.add_argument('--dataset', type=str, default='ogbn-papers100M')
    parser.add_argument('--root', type=str, default='dataset')
    parser.add_argument('--tempdir_root', type=str, default=None)
    parser.add_argument(
        '--n_devices',
        type=int,
        default=-1,
        help='How many GPUs to use. Defaults to all available GPUs',
    )

    args = parser.parse_args()
    wall_clock_start = time.perf_counter()

    from rmm.allocators.torch import rmm_torch_allocator

    torch.cuda.memory.change_current_allocator(rmm_torch_allocator)

    # Resolve and validate the worker count before the dataset is loaded,
    # so that a bad GPU configuration fails immediately and with a clear
    # message instead of spawning zero workers or crashing inside one.
    available_devices = torch.cuda.device_count()
    if available_devices < 1:
        raise RuntimeError(
            "No CUDA devices found; this example requires at least one GPU.")
    if args.n_devices < 1:
        world_size = available_devices
    else:
        world_size = args.n_devices
    if world_size > available_devices:
        parser.error(f"--n_devices={world_size} exceeds the "
                     f"{available_devices} visible GPU(s)")
    print(f"Using {world_size} GPUs...")

    dataset = PygNodePropPredDataset(name=args.dataset, root=args.root)
    split_idx = dataset.get_idx_split()
    data = dataset[0]
    # Flatten the labels to a 1-D tensor.
    data.y = data.y.reshape(-1)

    model = torch_geometric.nn.models.GCN(
        dataset.num_features,
        args.hidden_channels,
        args.num_layers,
        dataset.num_classes,
    )

    # Create the uid needed for cuGraph comms
    from cugraph.gnn import cugraph_comms_create_unique_id
    cugraph_id = cugraph_comms_create_unique_id()

    # The temporary directory is handed to the workers and is removed once
    # all of them have joined.
    with tempfile.TemporaryDirectory(dir=args.tempdir_root) as tempdir:
        mp.spawn(
            run,
            args=(data, world_size, cugraph_id, model, args.epochs,
                  args.batch_size, args.fan_out, split_idx,
                  dataset.num_classes, wall_clock_start, tempdir,
                  args.num_layers),
            nprocs=world_size,
            join=True,
        )