1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195
|
# This code achieves a performance of around 96.60%. However, it is not
# directly comparable to the results reported by the TGN paper since a
# slightly different evaluation setup is used here.
# In particular, predictions in the same batch are made in parallel, i.e.
# predictions for interactions later in the batch have no access to any
# information whatsoever about previous interactions in the same batch.
# On the contrary, when sampling node neighborhoods for interactions later in
# the batch, the TGN paper code has access to previous interactions in the
# batch.
# While both approaches are correct, together with the authors of the paper we
# decided to present this version here as it is more realistic and a better
# test bed for future methods.
import os.path as osp
import torch
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn import Linear
from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TGNMemory, TransformerConv
from torch_geometric.nn.models.tgn import (
IdentityMessage,
LastAggregator,
LastNeighborLoader,
)
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The dataset lives in ``../data/JODIE`` relative to this script.
script_dir = osp.dirname(osp.realpath(__file__))
path = osp.join(script_dir, '..', 'data', 'JODIE')
dataset = JODIEDataset(path, name='wikipedia')

# For small datasets, we can put the whole dataset on GPU and thus avoid
# expensive memory transfer costs for mini-batches:
data = dataset[0].to(device)

# Split the events using a validation ratio and a test ratio of 0.15 each.
train_data, val_data, test_data = data.train_val_test_split(
    val_ratio=0.15,
    test_ratio=0.15,
)

# All three loaders share the same batch size and negative sampling ratio.
loader_kwargs = dict(batch_size=200, neg_sampling_ratio=1.0)
train_loader = TemporalDataLoader(train_data, **loader_kwargs)
val_loader = TemporalDataLoader(val_data, **loader_kwargs)
test_loader = TemporalDataLoader(test_data, **loader_kwargs)

# Neighbor sampler over all nodes of the dataset.
# NOTE(review): ``size=10`` is presumably the number of most recent
# neighbors kept per node -- confirm against the LastNeighborLoader docs.
neighbor_loader = LastNeighborLoader(data.num_nodes, size=10, device=device)
class GraphAttentionEmbedding(torch.nn.Module):
    """Embed nodes with one attention layer over time-aware edge features.

    Each edge is described by the encoding of its relative time
    concatenated with its message; these edge features are handed to a
    ``TransformerConv`` together with the node features.

    Args:
        in_channels: Size of the input node features.
        out_channels: Requested embedding size. The convolution is built
            with two heads of ``out_channels // 2`` channels each.
        msg_dim: Size of the message attached to each edge.
        time_enc: Module used to encode relative times. It must expose an
            ``out_channels`` attribute and be callable on a tensor.
    """

    def __init__(self, in_channels, out_channels, msg_dim, time_enc):
        super().__init__()
        self.time_enc = time_enc
        # Edge features are laid out as [time encoding | message].
        self.conv = TransformerConv(
            in_channels,
            out_channels // 2,
            heads=2,
            dropout=0.1,
            edge_dim=msg_dim + time_enc.out_channels,
        )

    def forward(self, x, last_update, edge_index, t, msg):
        """Compute node embeddings.

        Args:
            x: Node features.
            last_update: Per-node timestamps, indexed by the first row of
                ``edge_index``.
            edge_index: Edge index; only its first row is read here before
                it is forwarded to the convolution.
            t: Per-edge timestamps.
            msg: Per-edge messages.

        Returns:
            The output of the convolution for ``x``.
        """
        # Relative time per edge, cast to the dtype of the node features
        # before it is encoded.
        elapsed = (last_update[edge_index[0]] - t).to(x.dtype)
        edge_attr = torch.cat((self.time_enc(elapsed), msg), dim=-1)
        return self.conv(x, edge_index, edge_attr)
class LinkPredictor(torch.nn.Module):
    """Score (source, destination) pairs of node embeddings.

    Both embeddings are projected by their own linear layer, the
    projections are summed and passed through a ReLU, and a final linear
    layer maps the result to a single value per pair. No sigmoid is
    applied inside this module.

    Args:
        in_channels: Size of the incoming node embeddings.
    """

    def __init__(self, in_channels):
        super().__init__()
        # One projection per endpoint, followed by a scalar output head.
        self.lin_src = Linear(in_channels, in_channels)
        self.lin_dst = Linear(in_channels, in_channels)
        self.lin_final = Linear(in_channels, 1)

    def forward(self, z_src, z_dst):
        """Return one score per ``(z_src[i], z_dst[i])`` pair."""
        src_proj = self.lin_src(z_src)
        dst_proj = self.lin_dst(z_dst)
        hidden = torch.relu(src_proj + dst_proj)
        return self.lin_final(hidden)
# The memory, the time encoding and the final embedding share one width.
memory_dim = time_dim = embedding_dim = 100

# Node memory; messages are built by IdentityMessage and aggregated by
# LastAggregator.
memory = TGNMemory(
    data.num_nodes,
    data.msg.size(-1),
    memory_dim,
    time_dim,
    message_module=IdentityMessage(data.msg.size(-1), memory_dim, time_dim),
    aggregator_module=LastAggregator(),
).to(device)

# The embedding module reuses the time encoder owned by the memory module.
gnn = GraphAttentionEmbedding(
    in_channels=memory_dim,
    out_channels=embedding_dim,
    msg_dim=data.msg.size(-1),
    time_enc=memory.time_enc,
).to(device)

link_pred = LinkPredictor(in_channels=embedding_dim).to(device)

# ``gnn`` shares ``memory.time_enc``, so the same parameters are reported by
# both modules and have to be deduplicated before they reach the optimizer.
# Deduplicate by identity while keeping first-seen order: torch.optim
# requires a parameter collection whose ordering is deterministic and
# consistent between runs (the optimizer state dict refers to parameters by
# position), which a ``set`` of tensors does not provide.
trainable_params = list({
    id(param): param
    for module in (memory, gnn, link_pred)
    for param in module.parameters()
}.values())
optimizer = torch.optim.Adam(trainable_params, lr=0.0001)

# The link predictor's outputs are fed to this loss without a sigmoid.
criterion = torch.nn.BCEWithLogitsLoss()

# Helper vector to map global node indices to local ones.
assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
def train():
    """Run one training epoch over ``train_loader``.

    The memory and the neighbor loader are reset at the start of the epoch.
    For every batch the positive and negative destinations are scored, the
    memory and neighbor loader are updated with the batch's events, and
    one optimizer step is taken.

    Returns:
        The sum of the per-batch losses, each weighted by the number of
        events in its batch, divided by ``train_data.num_events``.
    """
    for module in (memory, gnn, link_pred):
        module.train()

    memory.reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    loss_sum = 0
    for batch in train_loader:
        optimizer.zero_grad()
        batch = batch.to(device)

        # Sample the neighborhood of the batch nodes and record where each
        # global node id sits in the local ``n_id`` ordering.
        n_id, edge_index, e_id = neighbor_loader(batch.n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Get updated memory of all nodes involved in the computation.
        z, last_update = memory(n_id)
        edge_t = data.t[e_id].to(device)
        edge_msg = data.msg[e_id].to(device)
        z = gnn(z, last_update, edge_index, edge_t, edge_msg)

        pos_out = link_pred(z[assoc[batch.src]], z[assoc[batch.dst]])
        neg_out = link_pred(z[assoc[batch.src]], z[assoc[batch.neg_dst]])

        # Positive pairs are labelled 1, negative pairs 0.
        pos_loss = criterion(pos_out, torch.ones_like(pos_out))
        neg_loss = criterion(neg_out, torch.zeros_like(neg_out))
        loss = pos_loss + neg_loss

        # Update memory and neighbor loader with ground-truth state.
        memory.update_state(batch.src, batch.dst, batch.t, batch.msg)
        neighbor_loader.insert(batch.src, batch.dst)

        loss.backward()
        optimizer.step()
        # Detach the memory after the step, before the next batch reads it.
        memory.detach()

        loss_sum += loss.item() * batch.num_events

    return loss_sum / train_data.num_events
@torch.no_grad()
def test(loader):
    """Evaluate link prediction on the batches of ``loader``.

    The memory and the neighbor loader are not reset here; they are read
    in whatever state the previous call left them and are updated with
    every evaluated batch.

    Args:
        loader: Iterable of batches providing ``n_id``, ``src``, ``dst``,
            ``neg_dst``, ``t`` and ``msg``.

    Returns:
        A tuple ``(ap, auc)`` holding the mean over batches of the average
        precision and of the ROC AUC.
    """
    for module in (memory, gnn, link_pred):
        module.eval()

    torch.manual_seed(12345)  # Ensure deterministic sampling across epochs.

    ap_scores = []
    auc_scores = []
    for batch in loader:
        batch = batch.to(device)

        # Sample the neighborhood of the batch nodes and record where each
        # global node id sits in the local ``n_id`` ordering.
        n_id, edge_index, e_id = neighbor_loader(batch.n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        z, last_update = memory(n_id)
        edge_t = data.t[e_id].to(device)
        edge_msg = data.msg[e_id].to(device)
        z = gnn(z, last_update, edge_index, edge_t, edge_msg)

        pos_out = link_pred(z[assoc[batch.src]], z[assoc[batch.dst]])
        neg_out = link_pred(z[assoc[batch.src]], z[assoc[batch.neg_dst]])

        # Positives first, then negatives -- the labels follow that order.
        logits = torch.cat([pos_out, neg_out], dim=0)
        y_pred = logits.sigmoid().cpu()
        labels = [torch.ones(pos_out.size(0)), torch.zeros(neg_out.size(0))]
        y_true = torch.cat(labels, dim=0)

        ap_scores.append(average_precision_score(y_true, y_pred))
        auc_scores.append(roc_auc_score(y_true, y_pred))

        # Feed the evaluated events into the memory and neighbor loader.
        memory.update_state(batch.src, batch.dst, batch.t, batch.msg)
        neighbor_loader.insert(batch.src, batch.dst)

    mean_ap = float(torch.tensor(ap_scores).mean())
    mean_auc = float(torch.tensor(auc_scores).mean())
    return mean_ap, mean_auc
# Train for a fixed number of epochs. After each epoch, evaluate on the
# validation loader first and on the test loader second; test() does not
# reset the memory, so this order is kept as is.
num_epochs = 50
for epoch in range(1, num_epochs + 1):
    loss = train()
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')

    val_ap, val_auc = test(val_loader)
    test_ap, test_auc = test(test_loader)

    print(f'Val AP: {val_ap:.4f}, Val AUC: {val_auc:.4f}')
    print(f'Test AP: {test_ap:.4f}, Test AUC: {test_auc:.4f}')
|