File: mem_pool.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (122 lines) | stat: -rw-r--r-- 3,781 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os.path as osp
import time

import torch
import torch.nn.functional as F
from torch.nn import BatchNorm1d, LeakyReLU, Linear

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import DeepGCNLayer, GATConv, MemPooling

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'TUD')
dataset = TUDataset(path, name="PROTEINS_full", use_node_attr=True)
# Drop the last 3 feature columns: only use non-binary features.
# NOTE(review): presumably these are the one-hot node labels that
# ``use_node_attr=True`` appends after the continuous attributes — confirm.
dataset._data.x = dataset._data.x[:, :-3]
dataset = dataset.shuffle()

# Split the shuffled graphs into 10% test, 10% validation and the remaining
# ~80% for training.
n = len(dataset) // 10
test_dataset = dataset[:n]
val_dataset = dataset[n:2 * n]
train_dataset = dataset[2 * n:]

# Evaluation order is irrelevant, but the training loader must reshuffle every
# epoch; otherwise each epoch replays exactly the same mini-batches in the same
# order, since ``dataset.shuffle()`` above only permutes the graphs once.
test_loader = DataLoader(test_dataset, batch_size=20)
val_loader = DataLoader(val_dataset, batch_size=20)
train_loader = DataLoader(train_dataset, batch_size=20, shuffle=True)


class Net(torch.nn.Module):
    r"""Graph classifier built from GAT layers and two MemPooling layers.

    Node features are projected to ``hidden_channels`` by a linear layer,
    refined by two residual (``block='res+'``) :class:`DeepGCNLayer` blocks
    wrapping :class:`GATConv`, then pooled by ``mem1`` (80 output channels,
    10 clusters) and by ``mem2`` (``out_channels`` outputs, a single cluster).

    Args:
        in_channels (int): Size of each input node feature vector.
        hidden_channels (int): Width of the linear projection and GAT layers.
        out_channels (int): Number of output classes.
        dropout (float): Dropout probability used in the GAT layers, the
            DeepGCN blocks and between the two pooling layers.
    """
    def __init__(self, in_channels, hidden_channels, out_channels, dropout):
        super().__init__()
        self.dropout = dropout

        self.lin = Linear(in_channels, hidden_channels)

        self.convs = torch.nn.ModuleList()
        for _ in range(2):
            conv = GATConv(hidden_channels, hidden_channels, dropout=dropout)
            norm = BatchNorm1d(hidden_channels)
            act = LeakyReLU()
            self.convs.append(
                DeepGCNLayer(conv, norm, act, block='res+', dropout=dropout))

        self.mem1 = MemPooling(hidden_channels, 80, heads=5, num_clusters=10)
        self.mem2 = MemPooling(80, out_channels, heads=5, num_clusters=1)

    def forward(self, x, edge_index, batch):
        r"""Classify every graph in a mini-batch.

        Args:
            x (torch.Tensor): Node feature matrix.
            edge_index (torch.Tensor): Graph connectivity.
            batch (torch.Tensor): Per-node graph assignment (``data.batch``).

        Returns:
            tuple: ``(log_probs, kl_loss)``, where ``log_probs`` are the
            log-softmax class scores per graph and ``kl_loss`` is the sum of
            :meth:`MemPooling.kl_loss` for both pooling layers.
        """
        x = self.lin(x)
        for conv in self.convs:
            x = conv(x, edge_index)

        x, S1 = self.mem1(x, batch)
        x = F.leaky_relu(x)
        # Tie dropout to the module's mode: the functional form defaults to
        # ``training=True`` and would otherwise stay active after
        # ``model.eval()``, making validation/test predictions random.
        x = F.dropout(x, p=self.dropout, training=self.training)
        x, S2 = self.mem2(x)

        # mem2 has a single cluster, so the cluster dimension (dim 1) is
        # squeezed away to leave one row of class scores per graph.
        return (
            F.log_softmax(x.squeeze(1), dim=-1),
            MemPooling.kl_loss(S1) + MemPooling.kl_loss(S2),
        )


# Prefer the GPU whenever PyTorch can see one, otherwise run on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Hidden width 32, dropout 0.1; ``Module.to`` moves parameters in place and
# returns the module itself.
model = Net(
    dataset.num_features,
    32,
    dataset.num_classes,
    dropout=0.1,
).to(device)

# Adam with a small L2 weight decay over every model parameter.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=4e-5,
)


def train():
    """Run one training epoch of the global ``model`` in two phases.

    Phase 1 takes one optimizer step per mini-batch on the supervised NLL
    loss. The MemPooling keys ``mem1.k`` and ``mem2.k`` are frozen during this
    phase.

    Phase 2 re-enables the keys and takes a single optimizer step on the KL
    clustering loss. That loss is summed over all training batches and divided
    by the number of training graphs. Its gradient also reaches every
    parameter upstream of the pooling layers, so those are updated as well.
    """
    model.train()

    # Phase 1: supervised loss with the cluster keys frozen.
    model.mem1.k.requires_grad = False
    model.mem2.k.requires_grad = False
    for data in train_loader:
        optimizer.zero_grad()
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)[0]
        loss = F.nll_loss(out, data.y)
        loss.backward()
        optimizer.step()

    # Phase 2: one full-pass step on the KL loss with the keys trainable.
    model.mem1.k.requires_grad = True
    model.mem2.k.requires_grad = True
    optimizer.zero_grad()
    num_graphs = len(train_loader.dataset)
    for data in train_loader:
        data = data.to(device)
        kl_loss = model(data.x, data.edge_index, data.batch)[1]
        # Gradients accumulate across ``backward`` calls, so back-propagating
        # each batch's share right away gives the same gradient as
        # ``(sum of batch losses / num_graphs).backward()``. It also frees
        # each batch's autograd graph immediately instead of holding the
        # graphs of the entire training set until the end of the pass.
        (kl_loss / num_graphs).backward()
    optimizer.step()


@torch.no_grad()
def test(loader):
    """Return the fraction of graphs in ``loader`` that ``model`` classifies
    correctly.

    The model is switched to eval mode and gradients are disabled. The count
    of argmax predictions matching ``y`` is divided by the number of graphs in
    ``loader.dataset``.
    """
    model.eval()
    correct = 0
    for graphs in loader:
        graphs = graphs.to(device)
        log_probs, _ = model(graphs.x, graphs.edge_index, graphs.batch)
        predictions = log_probs.argmax(dim=-1)
        correct += int(predictions.eq(graphs.y).sum())
    num_graphs = len(loader.dataset)
    return correct / num_graphs


# Early-stopping loop.
# - Training stops once the validation accuracy has failed to improve for
#   ``start_patience`` consecutive epochs.
# - The reported test accuracy is measured only at the epoch with the best
#   validation accuracy so far.
times = []
start_patience = 250
patience = start_patience
best_val_acc = 0.
test_acc = 0.
for epoch in range(1, 2001):
    start = time.time()
    train()
    val_acc = test(val_loader)

    # Halve the learning rate every 500 epochs.
    if epoch % 500 == 0:
        optimizer.param_groups[0]['lr'] *= 0.5

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience = start_patience
        test_acc = test(test_loader)
    else:
        patience -= 1

    print(f'Epoch {epoch:02d}, Val: {val_acc:.3f}, Test: {test_acc:.3f}')
    if patience <= 0:
        break
    # The epoch that triggers early stopping is not timed.
    times.append(time.time() - start)

print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")