File: test_dropout.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (102 lines) | stat: -rw-r--r-- 3,581 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import pytest
import torch

from torch_geometric.testing import withPackage
from torch_geometric.utils import (
    dropout_adj,
    dropout_edge,
    dropout_node,
    dropout_path,
)


def test_dropout_adj():
    edge_index = torch.tensor([
        [0, 1, 1, 2, 2, 3],
        [1, 0, 2, 1, 3, 2],
    ])
    edge_attr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])

    with pytest.warns(UserWarning, match="'dropout_adj' is deprecated"):
        out = dropout_adj(edge_index, edge_attr, training=False)
    assert edge_index.tolist() == out[0].tolist()
    assert edge_attr.tolist() == out[1].tolist()

    torch.manual_seed(5)
    with pytest.warns(UserWarning, match="'dropout_adj' is deprecated"):
        out = dropout_adj(edge_index, edge_attr)
    assert out[0].tolist() == [[0, 1, 2, 2], [1, 2, 1, 3]]
    assert out[1].tolist() == [1, 3, 4, 5]

    torch.manual_seed(6)
    with pytest.warns(UserWarning, match="'dropout_adj' is deprecated"):
        out = dropout_adj(edge_index, edge_attr, force_undirected=True)
    assert out[0].tolist() == [[0, 1, 1, 2], [1, 2, 0, 1]]
    assert out[1].tolist() == [1, 3, 1, 3]


def test_dropout_node():
    edge_index = torch.tensor([
        [0, 1, 1, 2, 2, 3],
        [1, 0, 2, 1, 3, 2],
    ])

    out = dropout_node(edge_index, training=False)
    assert edge_index.tolist() == out[0].tolist()
    assert out[1].tolist() == [True, True, True, True, True, True]
    assert out[2].tolist() == [True, True, True, True]

    torch.manual_seed(5)
    out = dropout_node(edge_index)
    assert out[0].tolist() == [[2, 3], [3, 2]]
    assert out[1].tolist() == [False, False, False, False, True, True]
    assert out[2].tolist() == [True, False, True, True]


def test_dropout_edge():
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])

    out = dropout_edge(edge_index, training=False)
    assert edge_index.tolist() == out[0].tolist()
    assert out[1].tolist() == [True, True, True, True, True, True]

    torch.manual_seed(5)
    out = dropout_edge(edge_index)
    assert out[0].tolist() == [[0, 1, 2, 2], [1, 2, 1, 3]]
    assert out[1].tolist() == [True, False, True, True, True, False]

    torch.manual_seed(6)
    out = dropout_edge(edge_index, force_undirected=True)
    assert out[0].tolist() == [[0, 1, 1, 2], [1, 2, 0, 1]]
    assert out[1].tolist() == [0, 2, 0, 2]


@withPackage('torch_cluster')
def test_dropout_path():
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])

    out = dropout_path(edge_index, training=False)
    assert edge_index.tolist() == out[0].tolist()
    assert out[1].tolist() == [True, True, True, True, True, True]

    torch.manual_seed(4)
    out = dropout_path(edge_index, p=0.2)
    assert out[0].tolist() == [[0, 1], [1, 0]]
    assert out[1].tolist() == [True, True, False, False, False, False]
    assert edge_index[:, out[1]].tolist() == out[0].tolist()

    # test with unsorted edges
    torch.manual_seed(5)
    edge_index = torch.tensor([[3, 5, 2, 2, 2, 1], [1, 0, 0, 1, 3, 2]])
    out = dropout_path(edge_index, p=0.2)
    assert out[0].tolist() == [[3, 2, 2, 1], [1, 1, 3, 2]]
    assert out[1].tolist() == [True, False, False, True, True, True]
    assert edge_index[:, out[1]].tolist() == out[0].tolist()

    # test with isolated nodes
    torch.manual_seed(7)
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 2, 4]])
    out = dropout_path(edge_index, p=0.2)
    assert out[0].tolist() == [[2, 3], [2, 4]]
    assert out[1].tolist() == [False, False, True, True]
    assert edge_index[:, out[1]].tolist() == out[0].tolist()