# Reaches around 0.7945 ± 0.0059 test accuracy (mean ± std over 10 runs).
import os.path as osp
import torch
import torch.nn.functional as F
from ogb.nodeproppred import Evaluator, PygNodePropPredDataset
from torch.nn import Linear as Lin
from tqdm import tqdm
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import GATConv
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# The dataset lives in ../data/products relative to this script's directory.
script_dir = osp.dirname(osp.realpath(__file__))
root = osp.join(script_dir, '..', 'data', 'products')
dataset = PygNodePropPredDataset('ogbn-products', root)
split_idx = dataset.get_idx_split()
evaluator = Evaluator(name='ogbn-products')

# Only the node features ('x') and labels ('y') are moved to ``device``;
# other graph attributes (e.g. the edge index) are left where they were
# loaded and are moved per batch later on.
data = dataset[0].to(device, 'x', 'y')
# Worker settings shared by both loaders.
loader_kwargs = dict(num_workers=12, persistent_workers=True)

# Mini-batch training loader: shuffled batches of 512 training nodes, with
# a neighbour fan-out of 10, 10 and 5 for the three sampling hops.
train_loader = NeighborLoader(
    data,
    input_nodes=split_idx['train'],
    num_neighbors=[10, 10, 5],
    batch_size=512,
    shuffle=True,
    **loader_kwargs,
)

# Evaluation loader over every node, taking all (-1) 1-hop neighbours.
# ``shuffle`` is left at its default; GAT.inference concatenates the
# per-batch outputs and assumes they come back in node-index order.
subgraph_loader = NeighborLoader(
    data,
    input_nodes=None,
    num_neighbors=[-1],
    batch_size=2048,
    **loader_kwargs,
)
class GAT(torch.nn.Module):
    r"""Multi-layer GAT with a linear skip connection around every layer.

    Each layer computes ``conv(x, edge_index) + skip(x)``. Every layer except
    the last is followed by ELU and dropout (``p=0.5``, active only in
    training mode); the last layer's output is returned as-is.

    Args:
        in_channels (int): Size of each input node feature vector.
        hidden_channels (int): Per-head width of the hidden layers. Hidden
            layers concatenate their heads, so their output width is
            ``heads * hidden_channels``.
        out_channels (int): Width of the final output (number of classes).
        num_layers (int): Total number of GAT layers; must be at least 2.
        heads (int): Number of attention heads in every layer. The last
            layer is built with ``concat=False`` so that its output has
            ``out_channels`` features.

    Raises:
        ValueError: If ``num_layers < 2``.
    """
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 heads):
        super().__init__()
        if num_layers < 2:
            # The construction below always creates a separate input and
            # output layer, so fewer than two layers cannot be represented
            # consistently with ``self.num_layers``.
            raise ValueError(f"'num_layers' must be at least 2 "
                             f"(got {num_layers})")
        self.num_layers = num_layers
        hidden_width = heads * hidden_channels

        self.convs = torch.nn.ModuleList()
        self.convs.append(GATConv(in_channels, hidden_channels, heads))
        for _ in range(num_layers - 2):
            self.convs.append(GATConv(hidden_width, hidden_channels, heads))
        self.convs.append(
            GATConv(hidden_width, out_channels, heads, concat=False))

        # One linear projection per conv, sized so that ``skip(x)`` has the
        # same width as the corresponding conv output and can be added to it.
        self.skips = torch.nn.ModuleList()
        self.skips.append(Lin(in_channels, hidden_width))
        for _ in range(num_layers - 2):
            self.skips.append(Lin(hidden_width, hidden_width))
        self.skips.append(Lin(hidden_width, out_channels))

    def reset_parameters(self):
        """Re-initialize every conv and skip layer in place."""
        for conv in self.convs:
            conv.reset_parameters()
        for skip in self.skips:
            skip.reset_parameters()

    def forward(self, x, edge_index):
        """Apply all layers to a (sampled) subgraph.

        Args:
            x (Tensor): Node features of the subgraph.
            edge_index (Tensor): Edge indices of the subgraph.

        Returns:
            Tensor: Output of the last layer for every node in ``x``.
        """
        for i, (conv, skip) in enumerate(zip(self.convs, self.skips)):
            x = conv(x, edge_index) + skip(x)
            if i != self.num_layers - 1:
                x = F.elu(x)
                x = F.dropout(x, p=0.5, training=self.training)
        return x

    def inference(self, x_all):
        """Compute final outputs for all nodes, one layer at a time.

        Uses the module-level ``subgraph_loader`` and ``device``. After each
        layer the per-batch outputs are moved to the CPU and concatenated,
        which assumes ``subgraph_loader`` yields its seed nodes in
        node-index order (it is not shuffled).

        Args:
            x_all (Tensor): Input features for every node in the graph.

        Returns:
            Tensor: Output of the last layer for every node, on the CPU.
        """
        # ``with`` guarantees the progress bar is closed even on error.
        with tqdm(total=x_all.size(0) * self.num_layers) as pbar:
            pbar.set_description('Evaluating')
            # Compute representations of nodes layer by layer, using *all*
            # available edges. This leads to faster computation in contrast
            # to immediately computing the final representations of each
            # batch.
            for i in range(self.num_layers):
                xs = []
                for batch in subgraph_loader:
                    x = x_all[batch.n_id].to(device)
                    edge_index = batch.edge_index.to(device)
                    x = self.convs[i](x, edge_index) + self.skips[i](x)
                    # Keep only the seed nodes of this batch; the remaining
                    # rows are neighbours that only served as context.
                    x = x[:batch.batch_size]
                    if i != self.num_layers - 1:
                        x = F.elu(x)
                    xs.append(x.cpu())
                    pbar.update(batch.batch_size)
                x_all = torch.cat(xs, dim=0)
        return x_all
# Three GAT layers with 4 attention heads and 128 hidden channels per head.
model = GAT(
    in_channels=dataset.num_features,
    hidden_channels=128,
    out_channels=dataset.num_classes,
    num_layers=3,
    heads=4,
).to(device)
def train(epoch):
    """Run one training epoch over ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``train_loader``,
    ``split_idx`` and ``device``.

    Args:
        epoch (int): Epoch number; only used for the progress-bar label.

    Returns:
        tuple: ``(loss, approx_acc)``. ``loss`` is the cross-entropy averaged
        over all training (seed) nodes seen in the epoch. ``approx_acc`` is
        the fraction of training nodes classified correctly. It is approximate
        because it is collected on sampled subgraphs while the weights are
        still being updated.
    """
    model.train()
    num_train = split_idx['train'].size(0)

    pbar = tqdm(total=num_train)
    pbar.set_description(f'Epoch {epoch:02d}')

    total_loss = total_correct = total_examples = 0
    for batch in train_loader:
        optimizer.zero_grad()
        # Only the first ``batch_size`` rows are the batch's input (seed)
        # nodes; the remaining rows are sampled neighbours used as context.
        out = model(batch.x, batch.edge_index.to(device))[:batch.batch_size]
        # ``view(-1)`` instead of ``squeeze()``: squeeze would turn a batch
        # containing a single node into a 0-d target and break cross_entropy.
        y = batch.y[:batch.batch_size].view(-1)
        loss = F.cross_entropy(out, y)
        loss.backward()
        optimizer.step()

        # Weight each batch mean by its size so that the smaller final batch
        # does not skew the epoch average.
        total_loss += float(loss) * batch.batch_size
        total_examples += batch.batch_size
        total_correct += int(out.argmax(dim=-1).eq(y).sum())
        pbar.update(batch.batch_size)
    pbar.close()

    loss = total_loss / total_examples
    approx_acc = total_correct / num_train
    return loss, approx_acc
@torch.no_grad()
def test():
    """Score the model on the train, valid and test splits.

    Runs layer-wise full-graph inference through ``model.inference`` and
    evaluates the arg-max predictions with the OGB ``evaluator`` on each
    index set in ``split_idx``.

    Returns:
        tuple: ``(train_acc, val_acc, test_acc)``.
    """
    model.eval()
    logits = model.inference(data.x)
    labels = data.y.cpu()
    preds = logits.argmax(dim=-1, keepdim=True)

    def _split_accuracy(split):
        # Accuracy restricted to the nodes of one split.
        idx = split_idx[split]
        return evaluator.eval({
            'y_true': labels[idx],
            'y_pred': preds[idx],
        })['acc']

    return tuple(_split_accuracy(name) for name in ('train', 'valid', 'test'))
# Epochs at which the full evaluation (``test()``) is run: 60, 70, ..., 100.
eval_epochs = range(60, 101, 10)

# Ten independent runs; each run records the test accuracy of the evaluated
# epoch with the highest validation accuracy.
test_accs = []
for run in range(1, 11):
    print(f'\nRun {run:02d}:\n')
    # Fresh weights and fresh optimizer state for every run.
    model.reset_parameters()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    best_val_acc = final_test_acc = 0.0
    for epoch in range(1, 101):
        loss, acc = train(epoch)
        print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, Approx. Train: {acc:.4f}')

        if epoch not in eval_epochs:
            continue

        train_acc, val_acc, test_acc = test()
        print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
              f'Test: {test_acc:.4f}')
        # Model selection on validation accuracy (strict improvement only).
        if val_acc > best_val_acc:
            best_val_acc, final_test_acc = val_acc, test_acc

    test_accs.append(final_test_acc)

test_acc = torch.tensor(test_accs)
print('============================')
print(f'Final Test: {test_acc.mean():.4f} ± {test_acc.std():.4f}')
|