File: test_train_test_split_edges.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (28 lines) | stat: -rw-r--r-- 1,159 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import pytest
import torch

from torch_geometric.data import Data
from torch_geometric.utils import train_test_split_edges


def test_train_test_split_edges():
    """Check the outputs of the deprecated ``train_test_split_edges`` helper.

    A small graph with edge attributes is split using ``val_ratio=0.2`` and
    ``test_ratio=0.3``. The call is expected to emit a ``UserWarning`` whose
    message matches ``'deprecated'``, and the returned object is expected to
    expose ten entries whose sizes are asserted below.
    """
    # Path graph: edge i -> i + 1 for i in 0..9, i.e. 10 edges over nodes 0..10.
    sources = torch.arange(10)
    edge_index = torch.stack([sources, sources + 1])
    edge_attr = torch.arange(edge_index.size(1))
    num_nodes = edge_index.max().item() + 1

    data = Data(edge_index=edge_index, edge_attr=edge_attr)
    data.num_nodes = num_nodes

    with pytest.warns(UserWarning, match='deprecated'):
        data = train_test_split_edges(data, val_ratio=0.2, test_ratio=0.3)

    assert len(data) == 10

    # Expected sizes of the edge-index tensors and the negative adjacency mask.
    index_sizes = {
        'val_pos_edge_index': (2, 2),
        'val_neg_edge_index': (2, 2),
        'test_pos_edge_index': (2, 3),
        'test_neg_edge_index': (2, 3),
        'train_pos_edge_index': (2, 10),
        'train_neg_adj_mask': (11, 11),
    }
    for name, size in index_sizes.items():
        assert getattr(data, name).size() == size

    # 55 unordered node pairs minus 4 + 6 + 5 = 15 entries leaves 40 set bits.
    # NOTE(review): the 4 / 6 / 5 terms presumably count excluded positive and
    # sampled negative pairs -- confirm against the implementation.
    expected_mask_count = (11**2 - 11) / 2 - 4 - 6 - 5
    assert data.train_neg_adj_mask.sum().item() == expected_mask_count

    # Edge attributes must follow their edges into each split.
    attr_sizes = {
        'train_pos_edge_attr': (10, ),
        'val_pos_edge_attr': (2, ),
        'test_pos_edge_attr': (3, ),
    }
    for name, size in attr_sizes.items():
        assert getattr(data, name).size() == size