1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
|
import os.path as osp
import time
from math import ceil
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import DenseGraphConv, GCNConv, dense_mincut_pool
from torch_geometric.utils import to_dense_adj, to_dense_batch
# The dataset is stored under ../data/PROTEINS relative to this script's
# resolved location.
script_dir = osp.dirname(osp.realpath(__file__))
path = osp.join(script_dir, '..', 'data', 'PROTEINS')
dataset = TUDataset(path, name='PROTEINS')
dataset = dataset.shuffle()

# Mean number of nodes per graph, truncated to an int. Net uses it to choose
# how many clusters each pooling layer produces.
total_nodes = dataset._data.x.size(0)
avg_num_nodes = int(total_nodes / len(dataset))

# One tenth of the graphs (rounded up) goes to the test split, the next tenth
# to validation, and everything that remains to training.
n = (len(dataset) + 9) // 10
test_dataset, val_dataset, train_dataset = (
    dataset[:n],
    dataset[n:2 * n],
    dataset[2 * n:],
)

# All three loaders use the same batch size of 20 graphs.
test_loader, val_loader, train_loader = (
    DataLoader(split, batch_size=20)
    for split in (test_dataset, val_dataset, train_dataset)
)
class Net(torch.nn.Module):
    """Graph classifier built from three convolutions and two MinCut poolings.

    A sparse ``GCNConv`` runs on the input graph, after which the batch is
    converted to dense form and coarsened twice with ``dense_mincut_pool``.
    The pooled node features are averaged per graph and passed through a
    two-layer MLP.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Number of output classes.
        hidden_channels: Width of every hidden layer (default 32).
    """

    def __init__(self, in_channels, out_channels, hidden_channels=32):
        super().__init__()

        # Cluster counts: half the average graph size for the first pooling,
        # then half of that again for the second.
        clusters_first = ceil(0.5 * avg_num_nodes)
        clusters_second = ceil(0.5 * clusters_first)

        # Submodules are created in the same order as before so that
        # parameter initialisation draws from the RNG identically.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.pool1 = Linear(hidden_channels, clusters_first)
        self.conv2 = DenseGraphConv(hidden_channels, hidden_channels)
        self.pool2 = Linear(hidden_channels, clusters_second)
        self.conv3 = DenseGraphConv(hidden_channels, hidden_channels)

        self.lin1 = Linear(hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        """Run the network on one mini-batch.

        Args:
            x: Node features of the batched graphs.
            edge_index: Sparse connectivity passed to ``GCNConv`` and
                ``to_dense_adj``.
            batch: Node-to-graph assignment vector.

        Returns:
            A 3-tuple ``(log_probs, mc_loss, o_loss)``: the per-graph class
            log-probabilities, and the sums over both pooling layers of the
            two auxiliary terms returned by ``dense_mincut_pool``.
        """
        # Sparse message passing on the original graphs.
        h = F.relu(self.conv1(x, edge_index))

        # Switch to the dense (batched, padded) representation.
        dense_adj = to_dense_adj(edge_index, batch)
        h, mask = to_dense_batch(h, batch)

        # First coarsening step; the mask marks which padded slots are real.
        assign = self.pool1(h)
        h, dense_adj, mc1, o1 = dense_mincut_pool(h, dense_adj, assign, mask)
        h = F.relu(self.conv2(h, dense_adj))

        # Second coarsening step; no mask is passed here.
        assign = self.pool2(h)
        h, dense_adj, mc2, o2 = dense_mincut_pool(h, dense_adj, assign)
        h = self.conv3(h, dense_adj)

        # Mean readout over the remaining nodes, then the MLP head.
        graph_emb = h.mean(dim=1)
        graph_emb = F.relu(self.lin1(graph_emb))
        logits = self.lin2(graph_emb)

        return F.log_softmax(logits, dim=-1), mc1 + mc2, o1 + o2
# Use the GPU when PyTorch reports one as available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build the network from the dataset's feature and class counts, then move it
# to the chosen device.
model = Net(dataset.num_features, dataset.num_classes)
model = model.to(device)

# Adam with a small learning rate and L2 weight decay.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=5e-4,
    weight_decay=1e-4,
)
def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    Args:
        epoch: Current epoch number. It is accepted for the caller's
            convenience and is not used inside the function.

    Returns:
        The training loss summed over batches (each batch weighted by its
        number of graphs) and divided by ``len(train_dataset)``.
    """
    model.train()

    weighted_loss_sum = 0
    for batch_data in train_loader:
        batch_data = batch_data.to(device)
        targets = batch_data.y.view(-1)

        optimizer.zero_grad()
        log_probs, mc_loss, o_loss = model(
            batch_data.x, batch_data.edge_index, batch_data.batch)

        # Classification loss plus the two auxiliary pooling losses.
        loss = F.nll_loss(log_probs, targets)
        loss = loss + mc_loss + o_loss

        loss.backward()
        optimizer.step()

        # Weight by batch size so the final division gives a per-graph mean.
        weighted_loss_sum += batch_data.y.size(0) * float(loss)

    return weighted_loss_sum / len(train_dataset)
@torch.no_grad()
def test(loader):
    """Evaluate the model on every batch yielded by ``loader``.

    Args:
        loader: Data loader to evaluate on. Its ``dataset`` attribute supplies
            the number of graphs used for normalisation.

    Returns:
        A pair ``(mean_loss, accuracy)``, both normalised by
        ``len(loader.dataset)``. The loss includes the auxiliary pooling
        losses, exactly as during training.
    """
    model.eval()

    num_graphs = len(loader.dataset)
    weighted_loss_sum = 0
    num_correct = 0

    for batch_data in loader:
        batch_data = batch_data.to(device)
        targets = batch_data.y.view(-1)

        log_probs, mc_loss, o_loss = model(
            batch_data.x, batch_data.edge_index, batch_data.batch)

        loss = F.nll_loss(log_probs, targets)
        loss = loss + mc_loss + o_loss
        weighted_loss_sum += batch_data.y.size(0) * float(loss)

        # The predicted class is the index of the largest log-probability.
        predicted = log_probs.max(dim=1).indices
        num_correct += int((predicted == targets).sum())

    return weighted_loss_sum / num_graphs, num_correct / num_graphs
# Training driver with early stopping on the validation loss.
#
# Each epoch trains once, then evaluates on the train and validation splits.
# Whenever the validation loss reaches a new minimum, the test split is
# evaluated and the patience counter is reset. Otherwise patience is
# decremented, and training stops once it reaches zero.
times = []  # wall-clock seconds for each completed epoch
best_val_acc = test_acc = 0
# test_loss is pre-initialised so the progress line below can always be
# formatted, even if the very first validation loss is NaN and therefore
# never compares as an improvement.
best_val_loss = test_loss = float('inf')
patience = start_patience = 50
for epoch in range(1, 15001):
    start = time.time()
    train_loss = train(epoch)
    _, train_acc = test(train_loader)
    val_loss, val_acc = test(val_loader)
    if val_loss < best_val_loss:
        # Record the new best. Without this update the comparison above is
        # true on every epoch (anything is below +inf), so patience would
        # never decrease and early stopping would never trigger.
        best_val_loss = val_loss
        best_val_acc = val_acc
        # Test metrics are refreshed only at the best-validation checkpoint.
        test_loss, test_acc = test(test_loader)
        patience = start_patience
    else:
        patience -= 1
        if patience == 0:
            break
    print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.3f}, '
          f'Train Acc: {train_acc:.3f}, Val Loss: {val_loss:.3f}, '
          f'Val Acc: {val_acc:.3f}, Test Loss: {test_loss:.3f}, '
          f'Test Acc: {test_acc:.3f}')
    times.append(time.time() - start)
print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")
|