File: wl_kernel.py

Package: pytorch-geometric 2.6.1-7

# Weisfeiler-Lehman subtree kernel baseline: WLConv color refinement on the
# ENZYMES graphs, with per-graph color histograms fed to a linear SVM.
import argparse
import os.path as osp
import warnings

import torch
from sklearn.exceptions import ConvergenceWarning
from sklearn.metrics import accuracy_score
from sklearn.svm import LinearSVC

from torch_geometric.data import Batch
from torch_geometric.datasets import TUDataset
from torch_geometric.nn import WLConv

warnings.filterwarnings('ignore', category=ConvergenceWarning)

parser = argparse.ArgumentParser()
parser.add_argument('--runs', type=int, default=10)
args = parser.parse_args()

torch.manual_seed(42)

# Download/load the ENZYMES graph classification benchmark and merge all
# graphs into one disjoint batch, so WL colors are assigned consistently
# across the whole dataset.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'TU')
dataset = TUDataset(path, name='ENZYMES')
data = Batch.from_data_list(dataset)


class WL(torch.nn.Module):
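    """A stack of parameter-free WL color-refinement layers.

    Each WLConv relabels every node by hashing its current color together with
    the multiset of its neighbors' colors; ``histogram`` then summarizes the
    resulting colors as one (optionally normalized) histogram per graph.
    """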
    def __init__(self, num_layers):
        super().__init__()
        self.convs = torch.nn.ModuleList([WLConv() for _ in range(num_layers)])

    def forward(self, x, edge_index, batch=None):
        hists = []
        for conv in self.convs:
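            # One more refinement step, followed by the per-graph normalized
            # color histogram at this depth.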
            x = conv(x, edge_index)
            hists.append(conv.histogram(x, batch, norm=True))
        return hists


# WL has no trainable parameters, so the histograms for all refinement depths
# can be computed once up front and reused across runs.
wl = WL(num_layers=5)
hists = wl(data.x, data.edge_index, data.batch)

test_accs = torch.empty(args.runs, dtype=torch.float)

for run in range(1, args.runs + 1):
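    # Random split per run: first 10% validation, next 10% test, rest training.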
    perm = torch.randperm(data.num_graphs)
    val_index = perm[:data.num_graphs // 10]
    test_index = perm[data.num_graphs // 10:data.num_graphs // 5]
    train_index = perm[data.num_graphs // 5:]

    # Best validation accuracy seen across all WL depths and values of C.
    best_val_acc = 0

    for hist in hists:
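        # Each histogram corresponds to one WL refinement depth (1..num_layers).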
        train_hist, train_y = hist[train_index], data.y[train_index]
        val_hist, val_y = hist[val_index], data.y[val_index]
        test_hist, test_y = hist[test_index], data.y[test_index]

        # Grid-search the SVM regularization strength C; keep the test accuracy
        # of whichever (depth, C) pair scores best on the validation set.
        for C in [10**3, 10**2, 10**1, 10**0, 10**-1, 10**-2, 10**-3]:
            model = LinearSVC(C=C, tol=0.01, dual=True)
            model.fit(train_hist, train_y)
            val_acc = accuracy_score(val_y, model.predict(val_hist))
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                test_acc = accuracy_score(test_y, model.predict(test_hist))
                test_accs[run - 1] = test_acc

    print(f'Run: {run:02d}, Val: {best_val_acc:.4f}, Test: {test_acc:.4f}')

print(f'Final Test Performance: {test_accs.mean():.4f}±{test_accs.std():.4f}')