1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106
|
import pytest
from torch_geometric.explain.algorithm.captum import (
CaptumHeteroModel,
captum_output_to_dicts,
to_captum_input,
)
from torch_geometric.nn import to_captum_model
from torch_geometric.testing import withPackage
# Explanation mask variants passed as ``mask_type`` to ``to_captum_model`` and
# ``to_captum_input`` in the parametrized test.
mask_types = ['node', 'edge', 'node_and_edge']
# Attribution class names from ``captum.attr``; the parametrized test
# instantiates each one via ``getattr(attr, method)(captum_model)``.
methods = [
    'Saliency', 'InputXGradient', 'Deconvolution', 'FeatureAblation',
    'ShapleyValueSampling', 'IntegratedGradients', 'GradientShap',
    'Occlusion', 'GuidedBackprop', 'KernelShap', 'Lime',
]
@withPackage('captum')
@pytest.mark.parametrize('mask_type', mask_types)
@pytest.mark.parametrize('method', methods)
def test_captum_attribution_methods_hetero(mask_type, method, hetero_data,
                                           hetero_model):
    """Run a Captum attribution method on a heterogeneous model and check
    that the attributions map back to correctly shaped per-type tensors.

    Args:
        mask_type: One of ``mask_types``; selects whether node features
            (``'node'``), edge masks (``'edge'``) or both are attributed.
        method: Name of an attribution class in ``captum.attr``.
        hetero_data: Fixture graph exposing ``metadata()``, ``x_dict``,
            ``edge_index_dict`` and per-type ``num_nodes`` / ``x`` /
            ``edge_index``.
        hetero_model: Fixture callable that builds a model from metadata.
    """
    from captum import attr  # noqa

    data = hetero_data
    metadata = data.metadata()
    node_types, edge_types = metadata[0], metadata[1]
    model = hetero_model(metadata)
    captum_model = to_captum_model(model, mask_type, 0, metadata)
    explainer = getattr(attr, method)(captum_model)
    assert isinstance(captum_model, CaptumHeteroModel)

    inputs, additional_forward_args = to_captum_input(
        data.x_dict,
        data.edge_index_dict,
        mask_type,
        'additional_arg',
    )

    # 'node' attributes only node features, 'edge' only edge masks, and any
    # other mask type (i.e. 'node_and_edge') attributes both.
    uses_nodes = mask_type != 'edge'
    uses_edges = mask_type != 'node'

    # Occlusion needs one sliding window per input tensor, with node inputs
    # ordered before edge inputs. Derive the count from the metadata instead
    # of hard-coding how many node/edge types the fixture happens to have.
    sliding_window_shapes = ()
    if uses_nodes:
        sliding_window_shapes += ((3, 3), ) * len(node_types)
    if uses_edges:
        sliding_window_shapes += ((5, ), ) * len(edge_types)

    if method == 'IntegratedGradients':
        attributions, _ = explainer.attribute(
            inputs, target=0, internal_batch_size=1,
            additional_forward_args=additional_forward_args,
            return_convergence_delta=True)
    elif method == 'GradientShap':
        attributions, _ = explainer.attribute(
            inputs, target=0, return_convergence_delta=True, baselines=inputs,
            n_samples=1, additional_forward_args=additional_forward_args)
    elif method in ('DeepLiftShap', 'DeepLift'):
        # Not currently listed in ``methods``; kept so they can be enabled.
        attributions, _ = explainer.attribute(
            inputs, target=0, return_convergence_delta=True, baselines=inputs,
            additional_forward_args=additional_forward_args)
    elif method == 'Occlusion':
        attributions = explainer.attribute(
            inputs, target=0, sliding_window_shapes=sliding_window_shapes,
            additional_forward_args=additional_forward_args)
    else:
        attributions = explainer.attribute(
            inputs, target=0, additional_forward_args=additional_forward_args)

    # One attribution tensor per attributed input tensor.
    expected_num_inputs = ((len(node_types) if uses_nodes else 0) +
                           (len(edge_types) if uses_edges else 0))
    assert len(attributions) == expected_num_inputs

    x_attr_dict, edge_attr_dict = captum_output_to_dicts(
        attributions, mask_type, metadata)

    if uses_nodes:
        for node_type in node_types:
            num_nodes = data[node_type].num_nodes
            num_node_feats = data[node_type].x.shape[1]
            assert x_attr_dict[node_type].shape == (num_nodes, num_node_feats)
    if uses_edges:
        for edge_type in edge_types:
            num_edges = data[edge_type].edge_index.shape[1]
            assert edge_attr_dict[edge_type].shape == (num_edges, )
|