File: training_benchmark_cuda.py

Package: pytorch-geometric 2.6.1-7
import argparse
import os
from typing import Union

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from benchmark.multi_gpu.training.common import (
    get_predefined_args,
    run,
    supported_sets,
)
from benchmark.utils import get_dataset
from torch_geometric.data import Data, HeteroData


def run_cuda(rank: int, world_size: int, args: argparse.Namespace,
             num_classes: int, data: Union[Data, HeteroData]):
    # Single-node rendezvous: every worker joins the process group via the
    # same address/port. NCCL is the standard backend for GPU collectives.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    run(rank, world_size, args, num_classes, data)


if __name__ == '__main__':
    argparser = get_predefined_args()
    argparser.add_argument('--n-gpus', default=1, type=int)
    args = argparser.parse_args()
    args.device = 'cuda'

    assert args.dataset in supported_sets, \
        f"Dataset {args.dataset} isn't supported."
    data, num_classes = get_dataset(args.dataset, args.root)

    # Clamp the requested world size to the number of visible devices.
    max_world_size = torch.cuda.device_count()
    chosen_world_size = args.n_gpus
    if chosen_world_size <= max_world_size:
        world_size = chosen_world_size
    else:
        print(f'User selected {chosen_world_size} GPUs, '
              f'but only {max_world_size} are available; '
              f'using {max_world_size}.')
        world_size = max_world_size
    print(f"Let's use {world_size} GPUs!")

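    # Spawn one worker process per GPU; mp.spawn calls run_cuda(rank, *args),
    # prepending each process's rank to the tuple passed via `args`.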
    mp.spawn(
        run_cuda,
        args=(world_size, args, num_classes, data),
        nprocs=world_size,
        join=True,
    )