1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232
|
import math
import os.path as osp
import time
from itertools import chain
import numpy as np
import torch
import torch.nn.functional as F
from scipy.sparse.csgraph import shortest_path
from sklearn.metrics import roc_auc_score
from torch.nn import BCEWithLogitsLoss, Conv1d, MaxPool1d, ModuleList
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MLP, GCNConv, SortAggregation
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import k_hop_subgraph, to_scipy_sparse_matrix
class SEALDataset(InMemoryDataset):
    """SEAL link-prediction dataset built from a single-graph dataset.

    When processed, the edges of ``dataset[0]`` are split with
    ``RandomLinkSplit`` (5% validation, 10% test, undirected). For every
    positive and negative target link of each split, the ``num_hops``-hop
    enclosing subgraph around its two endpoints is extracted, the target link
    itself is removed from it, and every node is labeled with DRNL. Node
    features are then replaced by one-hot encodings of those labels, so the
    model learns from structure only. All three splits are saved by a single
    ``process()`` call; ``split`` selects which saved file this instance loads.

    NOTE(review): the processed file names do not encode ``num_hops``, so a
    change of ``num_hops`` presumably only takes effect after deleting the
    processed directory.

    Args:
        dataset: Source dataset whose first graph is used; its ``root`` is
            also used as this dataset's root directory.
        num_hops: Radius (in hops) of each enclosing subgraph.
        split: One of ``'train'``, ``'val'`` or ``'test'``.

    Raises:
        ValueError: If ``split`` is not one of the three names above.
    """

    # Split names, in the same order as ``processed_file_names``.
    _SPLITS = ('train', 'val', 'test')

    def __init__(self, dataset, num_hops, split='train'):
        # Validate up front so a typo fails before any (slow) processing.
        if split not in self._SPLITS:
            raise ValueError(
                f"split must be one of {self._SPLITS}, got {split!r}")
        # process() reads both attributes, and PyG's base-class constructor
        # calls process() when the processed files are missing, so they must
        # be set before super().__init__().
        self._data = dataset[0]
        self.num_hops = num_hops
        super().__init__(dataset.root)
        self.load(self.processed_paths[self._SPLITS.index(split)])

    @property
    def processed_file_names(self):
        return ['SEAL_train_data.pt', 'SEAL_val_data.pt', 'SEAL_test_data.pt']

    def process(self):
        """Split the edges, extract labeled subgraphs and save all splits."""
        transform = RandomLinkSplit(num_val=0.05, num_test=0.1,
                                    is_undirected=True, split_labels=True)
        split_graphs = transform(self._data)  # (train, val, test)

        # Running maximum DRNL label, updated by every drnl_node_labeling()
        # call; it fixes the one-hot width shared by all splits.
        self._max_z = 0

        # Collect positive (y=1) then negative (y=0) subgraphs per split.
        # Subgraphs are extracted from each split's own ``edge_index``.
        data_lists = []
        for graph in split_graphs:
            pos = self.extract_enclosing_subgraphs(
                graph.edge_index, graph.pos_edge_label_index, 1)
            neg = self.extract_enclosing_subgraphs(
                graph.edge_index, graph.neg_edge_label_index, 0)
            data_lists.append(pos + neg)

        # Second pass: the one-hot width is only known once every split has
        # been labeled. We solely learn links from structure, dropping any
        # node features.
        for data in chain.from_iterable(data_lists):
            data.x = F.one_hot(data.z, self._max_z + 1).to(torch.float)

        for data_list, out_path in zip(data_lists, self.processed_paths):
            self.save(data_list, out_path)

    def extract_enclosing_subgraphs(self, edge_index, edge_label_index, y):
        """Build one labeled ``Data`` object per target link.

        Args:
            edge_index: Connectivity the subgraphs are extracted from.
            edge_label_index: Target links, one column ``[src, dst]`` each.
            y: Label stored on every produced subgraph (``process`` passes 1
                for positive links and 0 for negative links).

        Returns:
            A list of ``Data`` objects carrying ``x`` (original features of
            the subgraph nodes), ``z`` (DRNL labels), ``edge_index`` (the
            relabeled subgraph without the target link) and ``y``.
        """
        data_list = []
        for src, dst in edge_label_index.t().tolist():
            sub_nodes, sub_edge_index, mapping, _ = k_hop_subgraph(
                [src, dst], self.num_hops, edge_index, relabel_nodes=True)
            # Endpoint indices local to the relabeled subgraph.
            src, dst = mapping.tolist()

            # Remove target link from the subgraph (both directions), so the
            # model cannot simply observe the link it has to predict.
            mask1 = (sub_edge_index[0] != src) | (sub_edge_index[1] != dst)
            mask2 = (sub_edge_index[0] != dst) | (sub_edge_index[1] != src)
            sub_edge_index = sub_edge_index[:, mask1 & mask2]

            # Calculate node labeling.
            z = self.drnl_node_labeling(sub_edge_index, src, dst,
                                        num_nodes=sub_nodes.size(0))

            data = Data(x=self._data.x[sub_nodes], z=z,
                        edge_index=sub_edge_index, y=y)
            data_list.append(data)
        return data_list

    def drnl_node_labeling(self, edge_index, src, dst, num_nodes=None):
        """Compute double-radius node labels (DRNL) for one subgraph.

        With ``d_s`` the hop distance to ``src`` in the subgraph without
        ``dst``, ``d_t`` the hop distance to ``dst`` in the subgraph without
        ``src``, and ``d = d_s + d_t``, a node's label is
        ``1 + min(d_s, d_t) + (d // 2) * (d // 2 + d % 2 - 1)``.

        Args:
            edge_index: Subgraph connectivity (target link already removed).
            src: Local index of one endpoint of the target link.
            dst: Local index of the other endpoint.
            num_nodes: Number of nodes in the subgraph. If ``None``, it is
                inferred from ``edge_index`` and extended to cover both
                endpoints.

        Returns:
            A long tensor with one label per node: 1 for both endpoints,
            0 for nodes with an infinite ``d_s`` or ``d_t``, and the formula
            above otherwise.

        Side effects:
            Updates ``self._max_z`` with the largest label seen so far.
        """
        # Canonical order (src < dst) keeps the index shifts below simple.
        src, dst = (dst, src) if src > dst else (src, dst)
        if num_nodes is None:
            # Infer the size from the edges, but always cover both endpoints:
            # one of them may have lost its last edge when the target link was
            # removed, which would otherwise put it out of range below.
            num_nodes = int(edge_index.max()) + 1 if edge_index.numel() else 0
            num_nodes = max(num_nodes, dst + 1)
        adj = to_scipy_sparse_matrix(edge_index, num_nodes=num_nodes).tocsr()

        # Adjacency with src removed (row and column); nodes after src shift
        # down by one, so dst sits at index dst - 1 here.
        idx = list(range(src)) + list(range(src + 1, adj.shape[0]))
        adj_wo_src = adj[idx, :][:, idx]

        # Adjacency with dst removed; src < dst keeps its index.
        idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))
        adj_wo_dst = adj[idx, :][:, idx]

        # Hop distance to src in the graph without dst; re-insert a 0 at
        # dst's position so the array is indexed like the full subgraph.
        dist2src = shortest_path(adj_wo_dst, directed=False, unweighted=True,
                                 indices=src)
        dist2src = np.insert(dist2src, dst, 0, axis=0)
        dist2src = torch.from_numpy(dist2src)

        # Hop distance to dst in the graph without src (dst sits at dst - 1).
        dist2dst = shortest_path(adj_wo_src, directed=False, unweighted=True,
                                 indices=dst - 1)
        dist2dst = np.insert(dist2dst, src, 0, axis=0)
        dist2dst = torch.from_numpy(dist2dst)

        dist = dist2src + dist2dst
        dist_over_2, dist_mod_2 = dist // 2, dist % 2

        z = 1 + torch.min(dist2src, dist2dst)
        z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)
        z[src] = 1.
        z[dst] = 1.
        # shortest_path reports unreachable nodes as inf, which becomes NaN in
        # the arithmetic above (inf % 2); such nodes get the reserved label 0.
        z[torch.isnan(z)] = 0.

        self._max_z = max(int(z.max()), self._max_z)
        return z.to(torch.long)
# Cora citation graph; the SEAL datasets below use the same Planetoid root
# directory for their processed subgraph files.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')
dataset = Planetoid(path, name='Cora')

# One SEALDataset per split, built in train -> val -> test order. A single
# process() call writes all three files, so presumably only the first
# construction does the (slow) subgraph extraction and the rest just load.
train_dataset, val_dataset, test_dataset = (
    SEALDataset(dataset, num_hops=2, split=split_name)
    for split_name in ('train', 'val', 'test'))

# Only the training loader shuffles; evaluation order does not matter.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader, test_loader = (
    DataLoader(eval_dataset, batch_size=32)
    for eval_dataset in (val_dataset, test_dataset))
class DGCNN(torch.nn.Module):
    """DGCNN graph classifier that scores SEAL enclosing subgraphs.

    Node embeddings from a stack of GNN layers are concatenated, reduced to a
    fixed number ``k`` of nodes by ``SortAggregation``, passed through two
    Conv1d layers with max pooling in between, and mapped by an MLP to one
    logit per graph.

    Args:
        hidden_channels: Width of each hidden GNN layer.
        num_layers: Number of ``hidden_channels``-wide GNN layers; one extra
            single-channel GNN layer is appended after them.
        GNN: GNN layer class, constructed as ``GNN(in_channels, out_channels)``.
        k: Number of nodes kept per graph by ``SortAggregation``. A value
            below 1 is read as a percentile of the subgraph sizes in
            ``reference_dataset`` and then clamped to at least 10.
        reference_dataset: Dataset used to size the first GNN layer
            (``num_features``) and to resolve a percentile ``k``. Defaults to
            the module-level ``train_dataset``.

    Raises:
        ValueError: If the resolved ``k`` is below 10. The Conv1d/MaxPool1d
            stack below then has no valid output length (``conv2`` has
            kernel size 5 after halving).
    """

    def __init__(self, hidden_channels, num_layers, GNN=GCNConv, k=0.6,
                 reference_dataset=None):
        super().__init__()
        if reference_dataset is None:
            reference_dataset = train_dataset

        if k < 1:  # Transform percentile to number.
            num_nodes = sorted(data.num_nodes for data in reference_dataset)
            k = num_nodes[int(math.ceil(k * len(num_nodes))) - 1]
            k = int(max(10, k))
        if k < 10:
            raise ValueError(f"k must resolve to at least 10 nodes, got {k}")

        self.convs = ModuleList()
        self.convs.append(GNN(reference_dataset.num_features, hidden_channels))
        for _ in range(num_layers - 1):
            self.convs.append(GNN(hidden_channels, hidden_channels))
        self.convs.append(GNN(hidden_channels, 1))

        conv1d_channels = [16, 32]
        # Per-node feature width after concatenating all GNN layer outputs.
        total_latent_dim = hidden_channels * num_layers + 1
        conv1d_kws = [total_latent_dim, 5]
        # Kernel == stride == total_latent_dim: conv1 reads one node at a time.
        self.conv1 = Conv1d(1, conv1d_channels[0], conv1d_kws[0],
                            conv1d_kws[0])
        self.pool = SortAggregation(k)
        self.maxpool1d = MaxPool1d(2, 2)
        self.conv2 = Conv1d(conv1d_channels[0], conv1d_channels[1],
                            conv1d_kws[1], 1)
        # Sequence length after MaxPool1d(2, 2), then after conv2 (kernel 5).
        dense_dim = int((k - 2) / 2 + 1)
        dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1]
        self.mlp = MLP([dense_dim, 128, 1], dropout=0.5, norm=None)

    def forward(self, x, edge_index, batch):
        """Return one link logit per graph, shaped ``[num_graphs, 1]``."""
        xs = [x]
        for conv in self.convs:
            xs += [conv(xs[-1], edge_index).tanh()]
        # Concatenate every layer's output (the raw input features excluded):
        # [num_nodes, total_latent_dim].
        x = torch.cat(xs[1:], dim=-1)

        # Global pooling: sort the nodes of each graph and keep k of them.
        x = self.pool(x, batch)  # [num_graphs, k * total_latent_dim]
        x = x.unsqueeze(1)  # [num_graphs, 1, k * total_latent_dim]
        x = self.conv1(x).relu()  # [num_graphs, 16, k]
        x = self.maxpool1d(x)
        x = self.conv2(x).relu()
        x = x.view(x.size(0), -1)  # [num_graphs, dense_dim]
        return self.mlp(x)
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = (torch.device('cuda') if torch.cuda.is_available()
          else torch.device('cpu'))

model = DGCNN(hidden_channels=32, num_layers=3)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# The model emits raw logits; this loss applies the sigmoid internally.
criterion = BCEWithLogitsLoss()
def train():
    """Run one optimisation epoch over ``train_loader``.

    Returns:
        The training loss averaged over all subgraphs in ``train_dataset``
        (each batch's mean loss is weighted by its number of graphs).
    """
    model.train()
    running_loss = 0.0
    for graph_batch in train_loader:
        graph_batch = graph_batch.to(device)
        optimizer.zero_grad()
        logits = model(graph_batch.x, graph_batch.edge_index,
                       graph_batch.batch).view(-1)
        targets = graph_batch.y.to(torch.float)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()
        running_loss += float(loss) * graph_batch.num_graphs
    return running_loss / len(train_dataset)
@torch.no_grad()
def test(loader):
    """Return the ROC-AUC of ``model`` over every subgraph yielded by *loader*.

    Runs in eval mode without gradient tracking. Raw logits are used as
    scores, which suffices because ROC-AUC depends only on their ranking.
    """
    model.eval()
    scores, targets = [], []
    for graph_batch in loader:
        graph_batch = graph_batch.to(device)
        out = model(graph_batch.x, graph_batch.edge_index, graph_batch.batch)
        scores.append(out.view(-1).cpu())
        targets.append(graph_batch.y.view(-1).cpu().to(torch.float))
    y_true = torch.cat(targets)
    y_score = torch.cat(scores)
    return roc_auc_score(y_true, y_score)
# Train for 50 epochs and report the test AUC of the best-validation epoch.
times = []
best_val_auc = test_auc = 0
for epoch in range(1, 51):
    # perf_counter is monotonic and high-resolution, unlike time.time(), so
    # the measured epoch durations cannot jump with wall-clock adjustments.
    start = time.perf_counter()
    loss = train()
    val_auc = test(val_loader)
    # Model selection: only re-evaluate on the test split when validation AUC
    # improves, so the reported test AUC belongs to the best-validation epoch.
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        test_auc = test(test_loader)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
          f'Test: {test_auc:.4f}')
    times.append(time.perf_counter() - start)
# np.median averages the two middle values for an even number of epochs;
# torch.Tensor.median would silently report the lower of the two instead.
print(f"Median time per epoch: {np.median(times):.4f}s")
|