File: renet.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (91 lines) | stat: -rw-r--r-- 3,061 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import argparse
import os.path as osp
import time

import torch
import torch.nn.functional as F

from torch_geometric.datasets import GDELT, ICEWS18
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models.re_net import RENet

# Command-line options: the dataset to use (restricted to the two supported
# names) and the sequence length that is forwarded to RENet below.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='GDELT', type=str,
                    choices=['ICEWS18', 'GDELT'])
parser.add_argument('--seq_len', default=10, type=int)
args = parser.parse_args()

# Load the dataset and precompute history objects.
# The dataset root is ../data/<dataset>, relative to this file's directory.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',
                args.dataset)

# One pre-transform instance is created from the sequence length and handed
# to both splits; the training split is constructed first.
pre_transform = RENet.pre_transform(args.seq_len)

# `--dataset` is limited to these two names by the parser's `choices`.
dataset_cls = {'ICEWS18': ICEWS18, 'GDELT': GDELT}[args.dataset]
train_dataset = dataset_cls(path, pre_transform=pre_transform)
test_dataset = dataset_cls(path, split='test', pre_transform=pre_transform)

def _make_loader(dataset, shuffle):
    """Build a DataLoader with the settings shared by both splits."""
    return DataLoader(dataset, batch_size=1024, shuffle=shuffle,
                      follow_batch=['h_sub', 'h_obj'], num_workers=6)


# Create the loaders; only the training loader shuffles its samples.
train_loader = _make_loader(train_dataset, shuffle=True)
test_loader = _make_loader(test_dataset, shuffle=False)

# Select the compute device: CUDA when it is available, otherwise the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

# Build the model from the training dataset's node and relation counts, then
# move it onto the selected device.
model = RENet(train_dataset.num_nodes, train_dataset.num_rels,
              hidden_channels=200, seq_len=args.seq_len, dropout=0.5)
model = model.to(device)

# Adam optimizer over all model parameters, with weight decay.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)


def train():
    """Run one optimisation pass over ``train_loader``.

    Each batch is moved to ``device`` and fed to the model, which yields
    log-probabilities for the object and for the subject. The two negative
    log-likelihood terms (against ``batch.obj`` and ``batch.sub``) are summed
    into a single loss, followed by one optimizer step per batch.

    Returns:
        None. The model and optimizer are updated in place.
    """
    model.train()

    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        obj_log_prob, sub_log_prob = model(batch)
        # Multi-class classification loss: object term plus subject term.
        total_loss = (F.nll_loss(obj_log_prob, batch.obj) +
                      F.nll_loss(sub_log_prob, batch.sub))

        total_loss.backward()
        optimizer.step()


@torch.no_grad()
def test(loader):
    """Evaluate the model on every batch of ``loader``.

    For each batch, the values returned by ``model.test`` for the object
    and the subject predictions are scaled by the number of targets in the
    batch and accumulated. The running sum is finally divided by twice the
    dataset size, since every sample contributes one object and one subject
    prediction.

    Args:
        loader: Data loader to evaluate on; its ``dataset`` attribute is
            used to obtain the total number of samples.

    Returns:
        A list of four floats: Mean Reciprocal Rank (MRR), Hits@1, Hits@3
        and Hits@10, in that order.
    """
    model.eval()

    totals = torch.zeros(4, dtype=torch.float)
    for batch in loader:
        batch = batch.to(device)
        obj_log_prob, sub_log_prob = model(batch)

        # Object first, then subject, each weighted by its target count.
        scored_pairs = ((obj_log_prob, batch.obj), (sub_log_prob, batch.sub))
        for log_prob, target in scored_pairs:
            totals += model.test(log_prob, target) * target.size(0)

    return (totals / (2 * len(loader.dataset))).tolist()


# Only start training when this file is executed as a script. The data
# loaders above use worker processes; on platforms that start workers with
# the "spawn" method (e.g. Windows and macOS), every worker re-imports this
# module, and an unguarded loop would try to launch training again from
# inside each worker.
if __name__ == '__main__':
    times = []  # Per-epoch durations in seconds.
    for epoch in range(1, 21):
        # perf_counter is monotonic, so the measured duration is not affected
        # by adjustments of the system clock.
        start = time.perf_counter()
        train()
        mrr, hits1, hits3, hits10 = test(test_loader)
        print(f'Epoch: {epoch:02d}, MRR: {mrr:.4f}, Hits@1: {hits1:.4f}, '
              f'Hits@3: {hits3:.4f}, Hits@10: {hits10:.4f}')
        # The recorded time spans training, evaluation and the report above.
        times.append(time.perf_counter() - start)
    print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")