import os.path as osp
import time

import torch
import torch.nn.functional as F
from torch.nn import BatchNorm1d, LeakyReLU, Linear

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import DeepGCNLayer, GATConv, MemPooling
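
# Load PROTEINS_full with its continuous node attributes and split it into
# 10% test, 10% validation, and 80% training graphs.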
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'TUD')
dataset = TUDataset(path, name='PROTEINS_full', use_node_attr=True)
dataset._data.x = dataset._data.x[:, :-3]  # Drop the trailing (binary) one-hot node labels.
dataset = dataset.shuffle()

n = len(dataset) // 10
test_dataset = dataset[:n]
val_dataset = dataset[n:2 * n]
train_dataset = dataset[2 * n:]

test_loader = DataLoader(test_dataset, batch_size=20)
val_loader = DataLoader(val_dataset, batch_size=20)
train_loader = DataLoader(train_dataset, batch_size=20)
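

# A DeepGCN-style network: two residual GATConv blocks followed by two
# MemPooling stages. The first stage coarsens each graph's nodes into 10
# clusters; the second reduces those to a single graph-level vector. Each
# stage also returns its soft cluster assignment, used for an auxiliary
# KL divergence loss.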
class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, dropout):
        super().__init__()
        self.dropout = dropout

        self.lin = Linear(in_channels, hidden_channels)

        self.convs = torch.nn.ModuleList()
        for _ in range(2):
            conv = GATConv(hidden_channels, hidden_channels, dropout=dropout)
            norm = BatchNorm1d(hidden_channels)
            act = LeakyReLU()
            self.convs.append(
                DeepGCNLayer(conv, norm, act, block='res+', dropout=dropout))

        self.mem1 = MemPooling(hidden_channels, 80, heads=5, num_clusters=10)
        self.mem2 = MemPooling(80, out_channels, heads=5, num_clusters=1)
    def forward(self, x, edge_index, batch):
        x = self.lin(x)
        for conv in self.convs:
            x = conv(x, edge_index)
        # First pooling stage: coarsen nodes into 10 clusters per graph.
        x, S1 = self.mem1(x, batch)
        x = F.leaky_relu(x)
        # Pass the training flag so dropout is disabled during evaluation.
        x = F.dropout(x, p=self.dropout, training=self.training)
        # Second pooling stage: reduce to a single vector per graph.
        x, S2 = self.mem2(x)

        return (
            F.log_softmax(x.squeeze(1), dim=-1),
            MemPooling.kl_loss(S1) + MemPooling.kl_loss(S2),
        )
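

# Set up the model and a weight-decayed Adam optimizer, on GPU if available.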
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Net(dataset.num_features, 32, dataset.num_classes, dropout=0.1)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=4e-5)
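

# Training runs in two phases per epoch: first the classification (NLL) loss
# is optimized with the MemPooling keys frozen, then the KL clustering loss is
# accumulated over the whole epoch and applied with the keys unfrozen, so the
# keys are trained by the clustering objective only.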
def train():
    model.train()

    # Phase 1: optimize the classification loss with the memory keys frozen.
    model.mem1.k.requires_grad = False
    model.mem2.k.requires_grad = False
    for data in train_loader:
        optimizer.zero_grad()
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)[0]
        loss = F.nll_loss(out, data.y)
        loss.backward()
        optimizer.step()

    # Phase 2: unfreeze the keys and apply the KL clustering loss,
    # accumulated over the full epoch in a single optimizer step.
    kl_loss = 0.
    model.mem1.k.requires_grad = True
    model.mem2.k.requires_grad = True
    optimizer.zero_grad()
    for data in train_loader:
        data = data.to(device)
        kl_loss += model(data.x, data.edge_index, data.batch)[1]

    kl_loss /= len(train_loader.dataset)
    kl_loss.backward()
    optimizer.step()
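

# Evaluation: classification accuracy over all graphs in a loader.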
@torch.no_grad()
def test(loader):
    model.eval()

    total_correct = 0
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)[0]
        total_correct += int((out.argmax(dim=-1) == data.y).sum())
    return total_correct / len(loader.dataset)
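

# Main loop: the learning rate is halved every 500 epochs, and training stops
# early after 250 epochs without improvement in validation accuracy.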
times = []
patience = start_patience = 250
test_acc = best_val_acc = 0.
for epoch in range(1, 2001):
    start = time.time()
    train()
    val_acc = test(val_loader)
    if epoch % 500 == 0:
        optimizer.param_groups[0]['lr'] *= 0.5
    if best_val_acc < val_acc:
        patience = start_patience
        best_val_acc = val_acc
        test_acc = test(test_loader)
    else:
        patience -= 1
    print(f'Epoch {epoch:02d}, Val: {val_acc:.3f}, Test: {test_acc:.3f}')
    if patience <= 0:
        break
    times.append(time.time() - start)

print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")