1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149
|
import argparse
import os.path as osp
from typing import Any, Dict, Optional
import torch
from torch.nn import (
BatchNorm1d,
Embedding,
Linear,
ModuleList,
ReLU,
Sequential,
)
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch_geometric.transforms as T
from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv, GPSConv, global_add_pool
from torch_geometric.nn.attention import PerformerAttention
# Parse command-line options before touching the dataset so that `--help`
# or an invalid flag exits immediately, instead of only after the
# (potentially slow) ZINC download and pre-processing below.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--attn_type', default='multihead',
    help="Global attention type such as 'multihead' or 'performer'.")
args = parser.parse_args()

# ZINC subset stored under <script dir>/../data/ZINC-PE. Every graph gets a
# 20-step random-walk positional encoding in `data.pe`. This width must match
# the PE input size expected by the GPS model (20 features for its PE
# BatchNorm1d / Linear layers).
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ZINC-PE')
transform = T.AddRandomWalkPE(walk_length=20, attr_name='pe')
train_dataset = ZINC(path, subset=True, split='train', pre_transform=transform)
val_dataset = ZINC(path, subset=True, split='val', pre_transform=transform)
test_dataset = ZINC(path, subset=True, split='test', pre_transform=transform)

# Only the training split is shuffled; evaluation uses larger batches.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)
test_loader = DataLoader(test_dataset, batch_size=64)
class GPS(torch.nn.Module):
    """GPS graph-transformer regressor producing one scalar per graph.

    Node-type embeddings are concatenated with a linear projection of the
    batch-normalised positional encodings, refined by a stack of ``GPSConv``
    layers (each wrapping a ``GINEConv`` local layer plus ``attn_type``
    global attention), sum-pooled per graph and mapped to a scalar by an MLP.

    Args:
        channels: Hidden width of every GPS layer.
        pe_dim: Width of the projected positional encoding. The node-type
            embedding gets the remaining ``channels - pe_dim`` features.
        num_layers: Number of ``GPSConv`` layers.
        attn_type: Global attention type forwarded to ``GPSConv``.
        attn_kwargs: Extra keyword arguments forwarded to ``GPSConv``'s
            attention module.
        pe_in_dim: Width of the raw ``pe`` tensor passed to ``forward``. It
            must equal the random-walk length used to compute it.
        num_node_types: Size of the node-type embedding vocabulary.
        num_edge_types: Size of the edge-type embedding vocabulary.
        heads: Number of attention heads per ``GPSConv`` layer.
        redraw_interval: Calls between Performer projection redraws. It is
            only used when ``attn_type == 'performer'``.

    Raises:
        ValueError: If ``pe_dim`` exceeds ``channels``.
    """

    def __init__(self, channels: int, pe_dim: int, num_layers: int,
                 attn_type: str, attn_kwargs: Dict[str, Any],
                 pe_in_dim: int = 20, num_node_types: int = 28,
                 num_edge_types: int = 4, heads: int = 4,
                 redraw_interval: int = 1000):
        super().__init__()
        # Fail with a clear message instead of a negative-dimension error
        # from Embedding(num_node_types, channels - pe_dim).
        if pe_dim > channels:
            raise ValueError(
                f"'pe_dim' ({pe_dim}) must not exceed 'channels' "
                f"({channels})")

        self.node_emb = Embedding(num_node_types, channels - pe_dim)
        self.pe_lin = Linear(pe_in_dim, pe_dim)
        self.pe_norm = BatchNorm1d(pe_in_dim)
        self.edge_emb = Embedding(num_edge_types, channels)

        self.convs = ModuleList()
        for _ in range(num_layers):
            gin_mlp = Sequential(
                Linear(channels, channels),
                ReLU(),
                Linear(channels, channels),
            )
            conv = GPSConv(channels, GINEConv(gin_mlp), heads=heads,
                           attn_type=attn_type, attn_kwargs=attn_kwargs)
            self.convs.append(conv)

        # Readout head: channels -> channels/2 -> channels/4 -> 1.
        self.mlp = Sequential(
            Linear(channels, channels // 2),
            ReLU(),
            Linear(channels // 2, channels // 4),
            ReLU(),
            Linear(channels // 4, 1),
        )
        # Projection redrawing only applies to Performer attention; a None
        # interval turns the helper into a no-op for other attention types.
        self.redraw_projection = RedrawProjection(
            self.convs,
            redraw_interval=redraw_interval if attn_type == 'performer'
            else None)

    def forward(self, x, pe, edge_index, edge_attr, batch):
        """Predict one scalar per graph.

        Args:
            x: Node-type indices. The trailing dim is squeezed before
                embedding (presumably shape ``[num_nodes, 1]``; confirm).
            pe: Raw positional encodings with ``pe_in_dim`` features per node.
            edge_index: Connectivity passed to every ``GPSConv`` layer.
            edge_attr: Edge-type indices, embedded to ``channels`` features.
            batch: Node-to-graph assignment used by the convs and pooling.

        Returns:
            Tensor with one row per graph and a single output column.
        """
        x_pe = self.pe_norm(pe)
        x = torch.cat((self.node_emb(x.squeeze(-1)), self.pe_lin(x_pe)), 1)
        edge_attr = self.edge_emb(edge_attr)
        for conv in self.convs:
            x = conv(x, edge_index, batch, edge_attr=edge_attr)
        x = global_add_pool(x, batch)
        return self.mlp(x)
class RedrawProjection:
    """Periodically resample the projection matrices of Performer attention.

    Call :meth:`redraw_projections` once per training step. After
    ``redraw_interval`` counted calls, the following call redraws the
    projection matrix of every ``PerformerAttention`` submodule of ``model``
    and resets the counter.

    Args:
        model: Module whose ``PerformerAttention`` submodules are refreshed.
        redraw_interval: Number of counted calls between redraws, or ``None``
            to disable redrawing entirely.
    """

    def __init__(self, model: torch.nn.Module,
                 redraw_interval: Optional[int] = None):
        self.model = model
        self.redraw_interval = redraw_interval
        # Calls counted since the last redraw (or since construction).
        self.num_last_redraw = 0

    def redraw_projections(self):
        """Advance the counter, redrawing projections once one is due.

        This does nothing while ``model`` is in eval mode or when redrawing
        is disabled (``redraw_interval is None``).
        """
        if not (self.model.training and self.redraw_interval is not None):
            return

        # Not yet due: just count this call.
        if self.num_last_redraw < self.redraw_interval:
            self.num_last_redraw += 1
            return

        performers = [
            submodule for submodule in self.model.modules()
            if isinstance(submodule, PerformerAttention)
        ]
        for attention in performers:
            attention.redraw_projection_matrix()
        self.num_last_redraw = 0
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = (torch.device('cuda') if torch.cuda.is_available()
          else torch.device('cpu'))

# Extra keyword arguments handed to each GPSConv's attention module.
attn_kwargs = dict(dropout=0.5)
model = GPS(
    channels=64,
    pe_dim=8,
    num_layers=10,
    attn_type=args.attn_type,
    attn_kwargs=attn_kwargs,
)
model = model.to(device)

# Adam with light weight decay; the LR is halved whenever the monitored
# (minimised) metric stalls for 20 steps, down to a floor of 1e-5.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=20,
    min_lr=1e-5,
)
def train():
    """Run one training epoch over ``train_loader``.

    Returns:
        The per-graph mean L1 loss over the training set: each batch's mean
        loss is weighted by its number of graphs, then divided by the dataset
        size.
    """
    model.train()
    running_loss = 0
    for batch_data in train_loader:
        batch_data = batch_data.to(device)
        optimizer.zero_grad()
        # Resample Performer projections when due (a no-op otherwise).
        model.redraw_projection.redraw_projections()
        pred = model(batch_data.x, batch_data.pe, batch_data.edge_index,
                     batch_data.edge_attr, batch_data.batch)
        batch_loss = (pred.squeeze() - batch_data.y).abs().mean()
        batch_loss.backward()
        running_loss += batch_loss.item() * batch_data.num_graphs
        optimizer.step()
    num_graphs_total = len(train_loader.dataset)
    return running_loss / num_graphs_total
@torch.no_grad()
def test(loader):
    """Return the mean absolute error of ``model`` over ``loader``'s dataset.

    The model runs in eval mode with autograd disabled. Absolute errors are
    summed over all graphs and divided by the dataset size.
    """
    model.eval()
    abs_err_sum = 0
    for batch_data in loader:
        batch_data = batch_data.to(device)
        pred = model(batch_data.x, batch_data.pe, batch_data.edge_index,
                     batch_data.edge_attr, batch_data.batch)
        residual = pred.squeeze() - batch_data.y
        abs_err_sum += residual.abs().sum().item()
    return abs_err_sum / len(loader.dataset)
# Train for 100 epochs, reporting MAE on the validation and test splits
# after each one.
for epoch in range(1, 101):
    train_loss = train()
    val_mae, test_mae = test(val_loader), test(test_loader)
    # The plateau scheduler (mode='min') watches validation MAE.
    scheduler.step(val_mae)
    summary = (f'Epoch: {epoch:02d}, Loss: {train_loss:.4f}, '
               f'Val: {val_mae:.4f}, Test: {test_mae:.4f}')
    print(summary)
|