"""Train RENet on the ICEWS18 or GDELT temporal knowledge-graph dataset."""
import argparse
import os.path as osp
import time
import torch
import torch.nn.functional as F
from torch_geometric.datasets import GDELT, ICEWS18
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models.re_net import RENet
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='GDELT',
                    choices=['ICEWS18', 'GDELT'])
parser.add_argument('--seq_len', type=int, default=10)
args = parser.parse_args()

# Resolve the dataset root relative to this script, and precompute the
# per-event history objects RENet needs during training:
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',
                args.dataset)
pre_transform = RENet.pre_transform(args.seq_len)

# Both dataset classes share the same constructor signature, so we can
# dispatch on the dataset name directly:
dataset_cls = {'ICEWS18': ICEWS18, 'GDELT': GDELT}[args.dataset]
train_dataset = dataset_cls(path, pre_transform=pre_transform)
test_dataset = dataset_cls(path, split='test', pre_transform=pre_transform)

# Data loaders; `follow_batch` keeps batch assignment vectors for the
# subject/object history tensors attached to each example:
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True,
                          follow_batch=['h_sub', 'h_obj'], num_workers=6)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False,
                         follow_batch=['h_sub', 'h_obj'], num_workers=6)

# Initialize model and optimizer:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RENet(
    train_dataset.num_nodes,
    train_dataset.num_rels,
    hidden_channels=200,
    seq_len=args.seq_len,
    dropout=0.5,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001,
                             weight_decay=0.00001)
def train():
    """Run one training epoch and return the mean loss per example.

    Trains RENet via multi-class classification of the ground-truth
    object and subject entities (NLL loss on both directions, summed).

    Returns:
        float: Average (object + subject) loss per training example.
            The original version returned ``None``; callers that ignore
            the return value are unaffected.
    """
    model.train()
    total_loss = total_examples = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        log_prob_obj, log_prob_sub = model(data)
        loss_obj = F.nll_loss(log_prob_obj, data.obj)
        loss_sub = F.nll_loss(log_prob_sub, data.sub)
        loss = loss_obj + loss_sub
        loss.backward()
        optimizer.step()
        # Track the epoch loss so callers can monitor training progress
        # (the original discarded it).
        total_loss += float(loss) * data.obj.size(0)
        total_examples += data.obj.size(0)
    # Guard against an empty loader to avoid ZeroDivisionError.
    return total_loss / max(total_examples, 1)
@torch.no_grad()
def test(loader):
    """Evaluate the model on ``loader``.

    Returns:
        list[float]: ``[MRR, Hits@1, Hits@3, Hits@10]``, averaged over
        both the object- and subject-prediction directions.
    """
    model.eval()
    totals = torch.zeros(4)
    for batch in loader:
        batch = batch.to(device)
        log_prob_obj, log_prob_sub = model(batch)
        totals += model.test(log_prob_obj, batch.obj) * batch.obj.size(0)
        totals += model.test(log_prob_sub, batch.sub) * batch.sub.size(0)
    # Every example contributes twice: once per prediction direction.
    return (totals / (2 * len(loader.dataset))).tolist()
# Main loop: train for 20 epochs, evaluating on the test split after each
# one and recording the wall-clock time per epoch.
epoch_times = []
for epoch in range(1, 21):
    tic = time.time()
    train()
    mrr, hits1, hits3, hits10 = test(test_loader)
    print(f'Epoch: {epoch:02d}, MRR: {mrr:.4f}, Hits@1: {hits1:.4f}, '
          f'Hits@3: {hits3:.4f}, Hits@10: {hits10:.4f}')
    epoch_times.append(time.time() - tic)
print(f"Median time per epoch: {torch.tensor(epoch_times).median():.4f}s")