File: test_explainer.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (123 lines) | stat: -rw-r--r-- 4,117 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import pytest
import torch

from torch_geometric.explain import DummyExplainer, Explainer, Explanation
from torch_geometric.explain.config import ExplanationType


class DummyModel(torch.nn.Module):
    """Parameter-free stand-in model for the explainer tests.

    The forward pass reduces the node features to their overall mean and
    returns it as a one-element, one-dimensional tensor.
    """
    def forward(self, x, edge_index):
        # `edge_index` is accepted but never read by this model.
        pooled = torch.mean(x)
        return pooled.view(-1)


def test_get_prediction(data):
    """Check `Explainer.get_prediction` on the dummy model.

    The prediction must have size `(1, )`, and the model must still be in
    training mode after the call.

    `data` is a fixture defined outside this file; only its `x` and
    `edge_index` attributes are read here.
    """
    net = DummyModel()
    assert net.training

    model_config = dict(mode='regression', task_level='graph')
    explainer = Explainer(
        net,
        algorithm=DummyExplainer(),
        explanation_type='phenomenon',
        node_mask_type='object',
        model_config=model_config,
    )

    prediction = explainer.get_prediction(data.x, data.edge_index)

    # The training flag asserted above must be unchanged after the call.
    assert net.training
    assert prediction.size() == (1, )


@pytest.mark.parametrize('target', [None, torch.randn(2)])
@pytest.mark.parametrize('explanation_type', list(ExplanationType))
def test_forward(data, target, explanation_type):
    """Run the explainer for every explanation type, with and without a
    target.

    A phenomenon explanation without a target is expected to raise
    `ValueError`. In every other combination the call must return an
    `Explanation` holding the inputs, a target, and a node mask shaped
    like `data.x`, while the model stays in training mode.

    `data` is a fixture defined outside this file; only its `x` and
    `edge_index` attributes are read here.
    """
    net = DummyModel()
    assert net.training

    explainer = Explainer(
        net,
        algorithm=DummyExplainer(),
        explanation_type=explanation_type,
        node_mask_type='attributes',
        model_config=dict(mode='regression', task_level='graph'),
    )

    is_phenomenon = explanation_type == ExplanationType.phenomenon

    # Error path: a phenomenon explanation called with `target=None`.
    if is_phenomenon and target is None:
        with pytest.raises(ValueError):
            explainer(data.x, data.edge_index, target=target)
        return

    # The target is only forwarded for phenomenon explanations; otherwise
    # the explainer is called with `target=None`.
    explanation = explainer(
        data.x,
        data.edge_index,
        target=target if is_phenomenon else None,
    )

    assert net.training
    assert isinstance(explanation, Explanation)
    for key in ('x', 'edge_index', 'target'):
        assert key in explanation
    assert 'node_mask' in explanation.available_explanations
    assert explanation.node_mask.size() == data.x.size()


@pytest.mark.parametrize('threshold_value', [0.2, 0.5, 0.8])
@pytest.mark.parametrize('node_mask_type', ['object', 'attributes'])
def test_hard_threshold(data, threshold_value, node_mask_type):
    """Check that a `'hard'` threshold leaves only 0/1 values in the masks.

    Both a node mask and an edge mask must be available, and every
    available mask may contain no values other than 0 and 1.

    `data` is a fixture defined outside this file; only its `x` and
    `edge_index` attributes are read here.
    """
    model_config = dict(mode='regression', task_level='graph')
    explainer = Explainer(
        DummyModel(),
        algorithm=DummyExplainer(),
        explanation_type='model',
        node_mask_type=node_mask_type,
        edge_mask_type='object',
        model_config=model_config,
        threshold_config=('hard', threshold_value),
    )
    explanation = explainer(data.x, data.edge_index)

    available = explanation.available_explanations
    assert 'node_mask' in available
    assert 'edge_mask' in available

    for key in available:
        distinct_values = set(explanation[key].unique().tolist())
        assert distinct_values <= {0, 1}


@pytest.mark.parametrize('threshold_value', [1, 5, 10])
@pytest.mark.parametrize('threshold_type', ['topk', 'topk_hard'])
@pytest.mark.parametrize('node_mask_type', ['object', 'attributes'])
def test_topk_threshold(data, threshold_value, threshold_type, node_mask_type):
    """Check how many mask entries survive `'topk'` / `'topk_hard'`.

    For each available mask, the number of surviving entries must equal
    `threshold_value`, capped at the mask's element count, and all
    remaining entries must be zero.

    `data` is a fixture defined outside this file; only its `x` and
    `edge_index` attributes are read here.
    """
    model_config = dict(mode='regression', task_level='graph')
    explainer = Explainer(
        DummyModel(),
        algorithm=DummyExplainer(),
        explanation_type='model',
        node_mask_type=node_mask_type,
        edge_mask_type='object',
        model_config=model_config,
        threshold_config=(threshold_type, threshold_value),
    )
    explanation = explainer(data.x, data.edge_index)

    available = explanation.available_explanations
    assert 'node_mask' in available
    assert 'edge_mask' in available

    for key in available:
        mask = explanation[key]
        total = mask.numel()
        # A mask with fewer than `threshold_value` entries keeps them all.
        expected_kept = min(total, threshold_value)

        if threshold_type == 'topk':
            # Surviving entries are counted as strictly positive values.
            kept = (mask > 0).sum()
        else:
            # 'topk_hard': surviving entries are counted as exactly 1.
            kept = (mask == 1).sum()

        assert kept == expected_kept
        assert (mask == 0).sum() == total - expected_kept