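"""Link prediction with (variational) graph autoencoders on Planetoid.

Trains a GAE or VGAE (Kipf & Welling, "Variational Graph Auto-Encoders") with
either a two-layer GCN encoder or a one-layer "linear" encoder on Cora,
CiteSeer, or PubMed, and reports AUC / average precision on held-out edges.

Example usage (the script name is an assumption about your local layout; the
dataset is downloaded to ``../data/Planetoid`` relative to this file):

    python autoencoder.py --dataset Cora --variational --epochs 400
"""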
import argparse
import os.path as osp
import time

import torch

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GAE, VGAE, GCNConv

parser = argparse.ArgumentParser()
parser.add_argument('--variational', action='store_true')
parser.add_argument('--linear', action='store_true')
parser.add_argument('--dataset', type=str, default='Cora',
                    choices=['Cora', 'CiteSeer', 'PubMed'])
parser.add_argument('--epochs', type=int, default=400)
args = parser.parse_args()
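
# Prefer CUDA if available, then Apple Silicon MPS, otherwise fall back to CPU.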
if torch.cuda.is_available():
    device = torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
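
# RandomLinkSplit holds out 5% of edges for validation and 10% for testing.
# With split_labels=True, the validation and test splits carry
# `pos_edge_label_index` / `neg_edge_label_index` used for evaluation; negative
# training edges are not pre-sampled (add_negative_train_samples=False) because
# `recon_loss` samples them on the fly.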
transform = T.Compose([
    T.NormalizeFeatures(),
    T.ToDevice(device),
    T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,
                      split_labels=True, add_negative_train_samples=False),
])
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')
dataset = Planetoid(path, args.dataset, transform=transform)
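# The transform produces a (train, val, test) tuple of Data objects per graph.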
train_data, val_data, test_data = dataset[0]
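

# Encoder variants: a two-layer GCN encoder and a one-layer "linear" encoder,
# each in a plain form and a variational form that outputs (mu, log-std).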
class GCNEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 2 * out_channels)
        self.conv2 = GCNConv(2 * out_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)


class VariationalGCNEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 2 * out_channels)
        self.conv_mu = GCNConv(2 * out_channels, out_channels)
        self.conv_logstd = GCNConv(2 * out_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)


class LinearEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = GCNConv(in_channels, out_channels)

    def forward(self, x, edge_index):
        return self.conv(x, edge_index)


class VariationalLinearEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_mu = GCNConv(in_channels, out_channels)
        self.conv_logstd = GCNConv(in_channels, out_channels)

    def forward(self, x, edge_index):
        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)
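

# Latent dimensionality of 16; the command-line flags select GAE vs. VGAE and
# GCN vs. linear encoder.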
in_channels, out_channels = dataset.num_features, 16
if not args.variational and not args.linear:
    model = GAE(GCNEncoder(in_channels, out_channels))
elif not args.variational and args.linear:
    model = GAE(LinearEncoder(in_channels, out_channels))
elif args.variational and not args.linear:
    model = VGAE(VariationalGCNEncoder(in_channels, out_channels))
elif args.variational and args.linear:
    model = VGAE(VariationalLinearEncoder(in_channels, out_channels))
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
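

# One full-batch training step: reconstruction loss on the positive training
# edges (negatives are sampled inside `recon_loss`), plus a node-normalized
# KL term when training the variational model.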
def train():
    model.train()
    optimizer.zero_grad()
    z = model.encode(train_data.x, train_data.edge_index)
    loss = model.recon_loss(z, train_data.pos_edge_label_index)
    if args.variational:
        loss = loss + (1 / train_data.num_nodes) * model.kl_loss()
    loss.backward()
    optimizer.step()
    return float(loss)
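

# Evaluation: encode the full graph and score held-out positive/negative edges;
# `model.test` returns ROC-AUC and average precision.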
@torch.no_grad()
def test(data):
    model.eval()
    z = model.encode(data.x, data.edge_index)
    return model.test(z, data.pos_edge_label_index, data.neg_edge_label_index)
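

# Train for the requested number of epochs, evaluating on the test split after
# every epoch and reporting the median wall-clock time per epoch at the end.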
times = []
for epoch in range(1, args.epochs + 1):
    start = time.time()
    loss = train()
    auc, ap = test(test_data)
    print(f'Epoch: {epoch:03d}, AUC: {auc:.4f}, AP: {ap:.4f}')
    times.append(time.time() - start)
print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")