File: label_prop.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (32 lines) | stat: -rw-r--r-- 931 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import os.path as osp

from ogb.nodeproppred import Evaluator, PygNodePropPredDataset

import torch_geometric.transforms as T
from torch_geometric.nn import LabelPropagation

# Label-propagation baseline on ogbn-arxiv: spread the training labels over
# the graph with LabelPropagation and report OGB accuracy on the validation
# and test splits.

# Dataset files are stored under <directory of this script>/../data/OGB.
root = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'OGB')

# Make the graph undirected, then replace `edge_index` with a sparse
# adjacency matrix, which is what the script later reads as `data.adj_t`.
preprocess = T.Compose([
    T.ToUndirected(),
    T.ToSparseTensor(),
])
dataset = PygNodePropPredDataset('ogbn-arxiv', root, transform=preprocess)
split_idx = dataset.get_idx_split()
evaluator = Evaluator(name='ogbn-arxiv')
data = dataset[0]

# Only the labels of the training nodes are supplied as known labels (via
# `mask`); everything else is inferred by propagation over `adj_t`.
model = LabelPropagation(num_layers=3, alpha=0.9)
out = model(data.y, data.adj_t, mask=split_idx['train'])

# One predicted class per node, kept as a column (keepdim=True) so it lines up
# row-for-row with `data.y` when both are handed to the evaluator.
y_pred = out.argmax(dim=-1, keepdim=True)


def _split_accuracy(split):
    """Return the OGB accuracy of the module-level ``y_pred`` on ``split``.

    ``split`` is a key of ``split_idx`` (e.g. ``'valid'`` or ``'test'``);
    the same node indices select both the ground truth and the predictions.
    """
    nodes = split_idx[split]
    return evaluator.eval({
        'y_true': data.y[nodes],
        'y_pred': y_pred[nodes],
    })['acc']


# Evaluated in order: validation first, then test.
val_acc, test_acc = (_split_accuracy(name) for name in ('valid', 'test'))

print(f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')