1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204
|
import pytest
import torch
from torch_geometric.data import Data, HeteroData
from torch_geometric.explain.algorithm.captum import to_captum_input
from torch_geometric.nn import GAT, GCN, SAGEConv
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.models import to_captum_model
from torch_geometric.testing import withPackage
# Shared fixtures for every test in this module.
#
# ``x`` holds random 3-dim features for 8 nodes; ``edge_index`` connects the
# nodes as a path 0-1-2-...-7 with every edge stored in both directions
# (14 directed edges in total).
x = torch.randn(8, 3, requires_grad=True)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],
                           [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]])

# NOTE: these assignments deliberately rebind the imported ``GCN`` / ``GAT``
# class names to model *instances*; the tests below use them as models with
# 7 output channels.
# Both models are switched to eval mode: the GCN uses ``dropout=0.5``, and in
# training mode every forward pass would be random, which would make the
# ``out != pre_out`` checks in ``test_to_captum`` pass trivially regardless of
# whether the mask actually changed the output.
GCN = GCN(3, 16, 2, 7, dropout=0.5).eval()
GAT = GAT(3, 16, 2, 7, heads=2, concat=False).eval()

# Mask granularities passed to ``to_captum_model`` / ``to_captum_input``.
mask_types = ['edge', 'node_and_edge', 'node']

# Names of ``captum.attr`` attribution classes exercised by
# ``test_captum_attribution_methods``.
methods = [
    'Saliency',
    'InputXGradient',
    'Deconvolution',
    'FeatureAblation',
    'ShapleyValueSampling',
    'IntegratedGradients',
    'GradientShap',
    'Occlusion',
    'GuidedBackprop',
    'KernelShap',
    'Lime',
]
@pytest.mark.parametrize('mask_type', mask_types)
@pytest.mark.parametrize('model', [GCN, GAT])
@pytest.mark.parametrize('output_idx', [None, 1])
def test_to_captum(model, mask_type, output_idx):
    """Wrap ``model`` with ``to_captum_model`` and call it with a fixed mask.

    Node masks are all zeros and edge masks are all 0.5. The wrapped output
    must keep the expected shape (all 8 node rows, or only row
    ``output_idx`` when one is given) and must differ somewhere from the
    unmasked ``model`` output.
    """
    captum_model = to_captum_model(model, mask_type=mask_type,
                                   output_idx=output_idx)
    pre_out = model(x, edge_index)

    def half_edge_mask():
        # Batched (1, num_edges) edge mask with every entry equal to 0.5.
        return (torch.ones(edge_index.shape[1], dtype=torch.float,
                           requires_grad=True) * 0.5).unsqueeze(0)

    if mask_type == 'node':
        out = captum_model((x * 0.0).unsqueeze(0), edge_index)
    elif mask_type == 'edge':
        out = captum_model(half_edge_mask(), x, edge_index)
    elif mask_type == 'node_and_edge':
        out = captum_model((x * 0.0).unsqueeze(0), half_edge_mask(),
                           edge_index)

    if output_idx is None:
        expected_shape, reference = (8, 7), pre_out
    else:
        expected_shape, reference = (1, 7), pre_out[[output_idx]]

    assert out.shape == expected_shape
    assert torch.any(out != reference)
@withPackage('captum')
@pytest.mark.parametrize('mask_type', mask_types)
@pytest.mark.parametrize('method', methods)
def test_captum_attribution_methods(mask_type, method):
    """Run a ``captum.attr`` method on the Captum-wrapped GCN.

    Attributions for target 0 must come back with one tensor per mask input,
    shaped like that mask: (1, 8, 3) for node masks and (1, 14) for edge
    masks.
    """
    from captum import attr  # noqa

    captum_model = to_captum_model(GCN, mask_type, 0)
    explainer = getattr(attr, method)(captum_model)
    data = Data(x, edge_index)
    inputs, additional_forward_args = to_captum_input(data.x, data.edge_index,
                                                      mask_type)

    # Keyword arguments shared by every ``attribute`` call.
    attribute_kwargs = {
        'target': 0,
        'additional_forward_args': additional_forward_args,
    }
    # Methods asked for a convergence delta return ``(attributions, delta)``.
    with_delta = False

    if method == 'IntegratedGradients':
        attribute_kwargs['internal_batch_size'] = 1
        with_delta = True
    elif method == 'GradientShap':
        attribute_kwargs['baselines'] = inputs
        attribute_kwargs['n_samples'] = 1
        with_delta = True
    elif method in ('DeepLiftShap', 'DeepLift'):
        attribute_kwargs['baselines'] = inputs
        with_delta = True
    elif method == 'Occlusion':
        # One sliding window per mask input.
        attribute_kwargs['sliding_window_shapes'] = {
            'node': (3, 3),
            'edge': (5, ),
            'node_and_edge': ((3, 3), (5, )),
        }[mask_type]

    if with_delta:
        attribute_kwargs['return_convergence_delta'] = True
        attributions, _ = explainer.attribute(inputs, **attribute_kwargs)
    else:
        attributions = explainer.attribute(inputs, **attribute_kwargs)

    if mask_type == 'node':
        expected_shapes = [(1, 8, 3)]
    elif mask_type == 'edge':
        expected_shapes = [(1, 14)]
    else:
        expected_shapes = [(1, 8, 3), (1, 14)]

    for i, shape in enumerate(expected_shapes):
        assert attributions[i].shape == shape
def test_custom_explain_message():
    """A custom ``explain_message`` bound onto a ``SAGEConv`` instance is
    invoked when the layer runs with ``explain = True``. It receives the
    per-edge ``x_i`` / ``x_j`` node features, with ``inputs`` equal to
    ``x_j``.
    """
    feats = torch.randn(4, 8)
    edges = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])
    layer = SAGEConv(8, 32)

    # NOTE(review): the parameter names (``inputs``, ``x_i``, ``x_j``) are
    # presumably matched by name by MessagePassing, so keep them unchanged.
    def explain_message(self, inputs, x_i, x_j):
        assert isinstance(self, SAGEConv)
        assert inputs.size() == (6, 8)
        assert inputs.size() == x_i.size() == x_j.size()
        assert torch.allclose(inputs, x_j)
        # Stash what we received so it can be checked after the forward.
        self.x_i = x_i
        self.x_j = x_j
        return inputs

    # Bind the function as a method of this particular layer instance,
    # then enable explain mode.
    layer.explain_message = explain_message.__get__(layer, MessagePassing)
    layer.explain = True

    layer(feats, edges)

    src, dst = edges
    assert torch.allclose(layer.x_i, feats[dst])
    assert torch.allclose(layer.x_j, feats[src])
@withPackage('captum')
@pytest.mark.parametrize('mask_type', ['node', 'edge', 'node_and_edge'])
def test_to_captum_input(mask_type):
    """Check the ``(inputs, additional_forward_args)`` pair built by
    ``to_captum_input`` for both a ``Data`` and a ``HeteroData`` graph.

    ``inputs`` must contain one batched mask tensor per node type (node
    masks) and/or per edge type (edge masks). Edge masks have shape
    ``(1, num_edges)`` and sum to ``num_edges``, consistent with an all-ones
    mask. ``additional_forward_args`` must carry the graph tensors that are
    not being masked, followed by the extra positional argument ``args``.
    """
    num_nodes = x.shape[0]
    num_node_feats = x.shape[1]
    num_edges = edge_index.shape[1]
    args = 'test_args'

    # Check for Data:
    data = Data(x, edge_index)
    inputs, additional_forward_args = to_captum_input(data.x, data.edge_index,
                                                      mask_type, args)
    if mask_type == 'node':
        assert len(inputs) == 1
        assert inputs[0].shape == (1, num_nodes, num_node_feats)
        assert len(additional_forward_args) == 2
        assert torch.allclose(additional_forward_args[0], edge_index)
    elif mask_type == 'edge':
        assert len(inputs) == 1
        assert inputs[0].shape == (1, num_edges)
        assert inputs[0].sum() == num_edges
        assert len(additional_forward_args) == 3
        assert torch.allclose(additional_forward_args[0], x)
        assert torch.allclose(additional_forward_args[1], edge_index)
    else:
        assert len(inputs) == 2
        assert inputs[0].shape == (1, num_nodes, num_node_feats)
        assert inputs[1].shape == (1, num_edges)
        assert inputs[1].sum() == num_edges
        assert len(additional_forward_args) == 2
        assert torch.allclose(additional_forward_args[0], edge_index)
    # The extra positional argument is forwarded after the graph tensors.
    assert additional_forward_args[-1] == args

    # Check for HeteroData:
    data = HeteroData()
    x2 = torch.rand(8, 3)
    data['paper'].x = x
    data['author'].x = x2
    data['paper', 'to', 'author'].edge_index = edge_index
    data['author', 'to', 'paper'].edge_index = edge_index.flip([0])
    inputs, additional_forward_args = to_captum_input(data.x_dict,
                                                      data.edge_index_dict,
                                                      mask_type, args)
    # NOTE: every ``torch.allclose`` below is asserted; a bare call would
    # discard its result and silently check nothing.
    if mask_type == 'node':
        assert len(inputs) == 2
        assert inputs[0].shape == (1, num_nodes, num_node_feats)
        assert inputs[1].shape == (1, num_nodes, num_node_feats)
        assert len(additional_forward_args) == 2
        for key in data.edge_types:
            assert torch.allclose(additional_forward_args[0][key],
                                  data[key].edge_index)
    elif mask_type == 'edge':
        assert len(inputs) == 2
        assert inputs[0].shape == (1, num_edges)
        assert inputs[1].shape == (1, num_edges)
        assert inputs[1].sum() == inputs[0].sum() == num_edges
        assert len(additional_forward_args) == 3
        for key in data.node_types:
            assert torch.allclose(additional_forward_args[0][key],
                                  data[key].x)
        for key in data.edge_types:
            assert torch.allclose(additional_forward_args[1][key],
                                  data[key].edge_index)
    else:
        assert len(inputs) == 4
        assert inputs[0].shape == (1, num_nodes, num_node_feats)
        assert inputs[1].shape == (1, num_nodes, num_node_feats)
        assert inputs[2].shape == (1, num_edges)
        assert inputs[3].shape == (1, num_edges)
        assert inputs[3].sum() == inputs[2].sum() == num_edges
        assert len(additional_forward_args) == 2
        for key in data.edge_types:
            assert torch.allclose(additional_forward_args[0][key],
                                  data[key].edge_index)
    # The extra positional argument is forwarded after the graph dicts.
    assert additional_forward_args[-1] == args
|