File: test_captum_hetero.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (106 lines) | stat: -rw-r--r-- 3,979 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import pytest

from torch_geometric.explain.algorithm.captum import (
    CaptumHeteroModel,
    captum_output_to_dicts,
    to_captum_input,
)
from torch_geometric.nn import to_captum_model
from torch_geometric.testing import withPackage

# Mask configurations the test is parametrized over: attributions for node
# features only, for edge masks only, or for both at once.
mask_types = ['node', 'edge', 'node_and_edge']

# Names of the attribution classes to exercise. Each name is resolved with
# `getattr` on the `captum.attr` module inside the test, so every entry must
# match a class exported there.
methods = [
    'Saliency', 'InputXGradient', 'Deconvolution', 'FeatureAblation',
    'ShapleyValueSampling', 'IntegratedGradients', 'GradientShap',
    'Occlusion', 'GuidedBackprop', 'KernelShap', 'Lime',
]


def _assert_node_attr_shapes(data, node_types, x_attr_dict):
    """Check each node attribution is shaped (num_nodes, num_node_feats)."""
    for node_type in node_types:
        store = data[node_type]
        expected_shape = (store.num_nodes, store.x.shape[1])
        assert x_attr_dict[node_type].shape == expected_shape


def _assert_edge_attr_shapes(data, edge_types, edge_attr_dict):
    """Check each edge attribution is shaped (num_edges, )."""
    for edge_type in edge_types:
        num_edges = data[edge_type].edge_index.shape[1]
        assert edge_attr_dict[edge_type].shape == (num_edges, )


@withPackage('captum')
@pytest.mark.parametrize('mask_type', mask_types)
@pytest.mark.parametrize('method', methods)
def test_captum_attribution_methods_hetero(mask_type, method, hetero_data,
                                           hetero_model):
    """Run one Captum attribution method on a heterogeneous model.

    For every combination of ``mask_type`` and ``method`` the test wraps the
    model with :func:`to_captum_model`, computes attributions, and verifies
    that there is one attribution per masked input and that, once converted
    back to dictionaries, each has the shape of the tensor it explains.

    Args:
        mask_type: One of the entries of ``mask_types``.
        method: Name of an attribution class in ``captum.attr``.
        hetero_data: Fixture (defined outside this file) providing the
            heterogeneous graph; must expose ``metadata()``, ``x_dict`` and
            ``edge_index_dict``.
        hetero_model: Fixture (defined outside this file) that builds the
            model when called with the graph metadata.
    """
    from captum import attr  # noqa

    data = hetero_data
    metadata = data.metadata()
    node_types, edge_types = metadata[0], metadata[1]

    # Any mask type other than 'node' or 'edge' is the combined case, which
    # carries both node and edge inputs.
    with_nodes = mask_type != 'edge'
    with_edges = mask_type != 'node'

    model = hetero_model(metadata)
    captum_model = to_captum_model(model, mask_type, 0, metadata)
    # Verify the wrapper type before handing it to Captum, so that a wrong
    # type is reported directly rather than as an error inside Captum.
    assert isinstance(captum_model, CaptumHeteroModel)
    explainer = getattr(attr, method)(captum_model)

    inputs, additional_forward_args = to_captum_input(
        data.x_dict,
        data.edge_index_dict,
        mask_type,
        'additional_arg',
    )

    # Only the attributions are inspected below; convergence deltas are
    # requested to exercise that code path but are otherwise discarded.
    if method == 'IntegratedGradients':
        attributions, _ = explainer.attribute(
            inputs, target=0, internal_batch_size=1,
            additional_forward_args=additional_forward_args,
            return_convergence_delta=True)
    elif method == 'GradientShap':
        attributions, _ = explainer.attribute(
            inputs, target=0, return_convergence_delta=True, baselines=inputs,
            n_samples=1, additional_forward_args=additional_forward_args)
    elif method in ('DeepLiftShap', 'DeepLift'):
        # NOTE(review): neither name is listed in `methods`, so this branch
        # is not reached by the current parametrization.
        attributions, _ = explainer.attribute(
            inputs, target=0, return_convergence_delta=True, baselines=inputs,
            additional_forward_args=additional_forward_args)
    elif method == 'Occlusion':
        # One window per input tensor, node inputs first and edge inputs
        # after them. The counts are taken from the metadata rather than
        # hard-coded so they follow the fixture.
        sliding_window_shapes = ()
        if with_nodes:
            sliding_window_shapes += ((3, 3), ) * len(node_types)
        if with_edges:
            sliding_window_shapes += ((5, ), ) * len(edge_types)
        attributions = explainer.attribute(
            inputs, target=0, sliding_window_shapes=sliding_window_shapes,
            additional_forward_args=additional_forward_args)
    else:
        attributions = explainer.attribute(
            inputs, target=0, additional_forward_args=additional_forward_args)

    # Expect exactly one attribution per masked input.
    expected_len = 0
    if with_nodes:
        expected_len += len(node_types)
    if with_edges:
        expected_len += len(edge_types)
    assert len(attributions) == expected_len

    x_attr_dict, edge_attr_dict = captum_output_to_dicts(
        attributions, mask_type, metadata)
    if with_edges:
        _assert_edge_attr_shapes(data, edge_types, edge_attr_dict)
    if with_nodes:
        _assert_node_attr_shapes(data, node_types, x_attr_dict)