File: main_performance.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (148 lines) | stat: -rw-r--r-- 5,544 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import argparse
from itertools import product

import torch
from datasets import get_dataset
from gcn import GCN
from gin import GIN
from graph_sage import GraphSAGE
from train_eval import eval_acc, inference_run, train

from torch_geometric import seed_everything
from torch_geometric.loader import DataLoader
from torch_geometric.profile import rename_profile_file, timeit, torch_profile

# Seed with a fixed value before any dataset or model is created.
seed_everything(0)

# Command-line interface. The list-valued options (`--datasets`, `--models`,
# `--layers`, `--hiddens`) are combined into a grid by the run functions.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--datasets', type=str, nargs='+',
    default=['MUTAG', 'PROTEINS', 'IMDB-BINARY', 'REDDIT-BINARY'])
parser.add_argument('--models', type=str, nargs='+',
                    default=['GCN', 'GraphSAGE', 'GIN'])
parser.add_argument('--layers', type=int, nargs='+', default=[1, 2, 3])
parser.add_argument('--hiddens', type=int, nargs='+', default=[16, 32])
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--warmup_profile', type=int, default=1,
                    help='Skip the first few runs')
# Parsed as a float: with `type=int`, a value such as `0.9` was rejected by
# argparse. Integer input (e.g. `1`) is still accepted.
# NOTE(review): presumably a fraction in [0, 1] given the default -- confirm.
parser.add_argument('--goal_accuracy', type=float, default=1.0,
                    help='The goal test accuracy')
parser.add_argument('--inference', action='store_true')
parser.add_argument('--profile', action='store_true')
parser.add_argument('--bf16', action='store_true')
parser.add_argument('--compile', action='store_true')
args = parser.parse_args()

# Device preference order: CUDA, then MPS (only probed when this torch build
# exposes `torch.backends.mps`), then CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

# Autocast context entered by `run_inference`. Autocast is always disabled
# when CUDA is available; otherwise CPU autocast is enabled by `--bf16`.
# Both branches use the unified `torch.amp.autocast` entry point instead of
# the deprecated `torch.cpu.amp.autocast`; bfloat16 is spelled out because it
# is what the `--bf16` flag asks for.
if torch.cuda.is_available():
    amp = torch.amp.autocast('cuda', enabled=False)
else:
    amp = torch.amp.autocast('cpu', dtype=torch.bfloat16, enabled=args.bf16)

# Maps the names accepted by `--models` to the model classes to instantiate.
MODELS = dict(GCN=GCN, GraphSAGE=GraphSAGE, GIN=GIN)


def prepare_dataloader(dataset_name):
    """Load a dataset and build train/validation/test loaders for it.

    The dataset is split by position: the first 80% of the entries form the
    training set, the next 10% the validation set, and whatever remains the
    test set. Only the training loader shuffles.

    Args:
        dataset_name: Name passed to ``get_dataset``.

    Returns:
        A tuple ``(dataset, train_loader, val_loader, test_loader)``, where
        ``dataset`` is the full, unsplit dataset.
    """
    dataset = get_dataset(dataset_name, sparse=True)
    total = len(dataset)
    train_end = int(total * 0.8)
    val_end = train_end + int(total * 0.1)

    # (subset, shuffle) pairs in train / val / test order.
    splits = (
        (dataset[:train_end], True),
        (dataset[train_end:val_end], False),
        (dataset[val_end:], False),
    )
    train_loader, val_loader, test_loader = [
        DataLoader(subset, batch_size=args.batch_size, shuffle=shuffle)
        for subset, shuffle in splits
    ]
    return dataset, train_loader, val_loader, test_loader


def run_train():
    """Run the training benchmark over the full configuration grid.

    For every (dataset, model) pair and every (num_layers, hidden) pair taken
    from the command line, a fresh model and Adam optimizer are created and
    trained for ``args.epochs`` epochs. Only the final epoch's training pass
    is wrapped in ``timeit()``. After each epoch the validation and test
    loaders are evaluated under ``timeit(log=False)``; from epoch
    ``args.warmup_profile`` onwards the loss and evaluation results are
    recorded. With ``--profile``, one extra training pass runs under
    ``torch_profile()`` and the profile file is renamed after the
    configuration.
    """
    for dataset_name, model_name in product(args.datasets, args.models):
        loaders = prepare_dataloader(dataset_name)
        dataset, train_loader, val_loader, test_loader = loaders
        model_cls = MODELS[model_name]

        for num_layers, hidden in product(args.layers, args.hiddens):
            print('--')
            print(f'{dataset_name} - {model_name}- {num_layers} - {hidden}')

            model = model_cls(dataset, num_layers, hidden).to(device)
            # The optimizer is built from the parameters before the model is
            # optionally wrapped by `torch.compile`.
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
            if args.compile:
                model = torch.compile(model)

            # NOTE(review): these records are filled below but not read again
            # in this function -- confirm whether reporting is missing.
            losses = []
            eval_records = []
            last_epoch = args.epochs
            for epoch in range(1, last_epoch + 1):
                if epoch != last_epoch:
                    loss = train(model, optimizer, train_loader)
                else:
                    # Only the final epoch's training pass is timed.
                    with timeit():
                        loss = train(model, optimizer, train_loader)

                with timeit(log=False) as val_timer:
                    val_acc = eval_acc(model, val_loader)
                val_time = val_timer.duration
                with timeit(log=False) as test_timer:
                    test_acc = eval_acc(model, test_loader)
                test_time = test_timer.duration

                # Epochs before the warm-up threshold are not recorded.
                if epoch < args.warmup_profile:
                    continue
                losses.append(loss)
                eval_records.append([val_acc, val_time, test_acc, test_time])

            if args.profile:
                with torch_profile():
                    train(model, optimizer, train_loader)
                rename_profile_file(model_name, dataset_name, str(num_layers),
                                    str(hidden), 'train')


@torch.no_grad()
def run_inference():
    """Run the inference benchmark over the full configuration grid.

    For every (dataset, model) pair and every (num_layers, hidden) pair taken
    from the command line, a fresh (untrained) model is created and
    ``inference_run`` is called on the test loader ``args.epochs`` times
    inside the module-level ``amp`` context. Only the last of these passes is
    wrapped in ``timeit()``. With ``--profile``, one extra pass runs under
    ``torch_profile()`` and the profile file is renamed after the
    configuration.
    """
    for dataset_name, model_name in product(args.datasets, args.models):
        dataset, _train, _val, test_loader = prepare_dataloader(dataset_name)
        model_cls = MODELS[model_name]

        for num_layers, hidden in product(args.layers, args.hiddens):
            print('--')
            print(f'{dataset_name} - {model_name}- {num_layers} - {hidden}')

            model = model_cls(dataset, num_layers, hidden).to(device)
            if args.compile:
                model = torch.compile(model)

            with amp:
                # All passes except the last run untimed.
                for _ in range(args.epochs - 1):
                    inference_run(model, test_loader, args.bf16)
                # The final pass is timed; nothing runs when epochs < 1.
                if args.epochs >= 1:
                    with timeit():
                        inference_run(model, test_loader, args.bf16)

                if args.profile:
                    with torch_profile():
                        inference_run(model, test_loader, args.bf16)
                    rename_profile_file(model_name, dataset_name,
                                        str(num_layers), str(hidden),
                                        'inference')


# Entry point: benchmark inference when `--inference` is given, otherwise
# benchmark training.
if args.inference:
    run_inference()
else:
    run_train()