import inspect
import logging
import warnings
from typing import Any, Dict, Optional, Union

import torch
from torch import Tensor

from torch_geometric.explain import Explanation, HeteroExplanation
from torch_geometric.explain.algorithm import ExplainerAlgorithm
from torch_geometric.explain.algorithm.captum import (
CaptumHeteroModel,
CaptumModel,
MaskLevelType,
convert_captum_output,
to_captum_input,
)
from torch_geometric.explain.config import MaskType, ModelMode, ModelReturnType
from torch_geometric.typing import EdgeType, NodeType


class CaptumExplainer(ExplainerAlgorithm):
"""A `Captum <https://captum.ai>`__-based explainer for identifying compact
    subgraph structures and node features that play a crucial role in the
    predictions made by a GNN.

    This explainer algorithm uses :captum:`null` `Captum <https://captum.ai/>`_
    to compute attributions.

    Currently, the following attribution methods are supported:

    * :class:`captum.attr.IntegratedGradients`
    * :class:`captum.attr.Saliency`
    * :class:`captum.attr.InputXGradient`
    * :class:`captum.attr.Deconvolution`
    * :class:`captum.attr.ShapleyValueSampling`
    * :class:`captum.attr.GuidedBackprop`

    Args:
        attribution_method (Attribution or str): The Captum attribution method
            to use. Can be a string or a :class:`captum.attr` method.
        **kwargs: Additional arguments for the Captum attribution method.
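
    Example:
        A minimal usage sketch, assuming a trained :obj:`model` and a
        :obj:`data` object with :obj:`x` and :obj:`edge_index` (these names
        are illustrative):

        .. code-block:: python

            from torch_geometric.explain import CaptumExplainer, Explainer

            explainer = Explainer(
                model=model,
                algorithm=CaptumExplainer('IntegratedGradients'),
                explanation_type='model',
                node_mask_type='attributes',
                edge_mask_type='object',
                model_config=dict(
                    mode='multiclass_classification',
                    task_level='node',
                    return_type='log_probs',
                ),
            )

            explanation = explainer(data.x, data.edge_index)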
"""

    SUPPORTED_METHODS = [  # TODO: Add support for more methods.
'IntegratedGradients',
'Saliency',
'InputXGradient',
'Deconvolution',
'ShapleyValueSampling',
'GuidedBackprop',
]

    def __init__(
self,
attribution_method: Union[str, Any],
**kwargs,
):
super().__init__()
import captum.attr
if isinstance(attribution_method, str):
self.attribution_method_class = getattr(
captum.attr,
attribution_method,
)
else:
self.attribution_method_class = attribution_method
if not self._is_supported_attribution_method():
raise ValueError(f"{self.__class__.__name__} does not support "
f"attribution method "
f"{self.attribution_method_class.__name__}")
if kwargs.get('internal_batch_size', 1) != 1:
warnings.warn("Overriding 'internal_batch_size' to 1",
stacklevel=2)
if 'internal_batch_size' in self._get_attribute_parameters():
kwargs['internal_batch_size'] = 1
self.kwargs = kwargs

    def _get_mask_type(self) -> MaskLevelType:
r"""Based on the explainer config, return the mask type."""
node_mask_type = self.explainer_config.node_mask_type
edge_mask_type = self.explainer_config.edge_mask_type
if node_mask_type is not None and edge_mask_type is not None:
mask_type = MaskLevelType.node_and_edge
elif node_mask_type is not None:
mask_type = MaskLevelType.node
elif edge_mask_type is not None:
mask_type = MaskLevelType.edge
else:
raise ValueError("Neither node mask type nor "
"edge mask type is specified.")
return mask_type

    def _get_attribute_parameters(self) -> Dict[str, Any]:
r"""Returns the attribute arguments."""
signature = inspect.signature(self.attribution_method_class.attribute)
return signature.parameters

    def _needs_baseline(self) -> bool:
r"""Checks if the method needs a baseline."""
parameters = self._get_attribute_parameters()
if 'baselines' in parameters:
param = parameters['baselines']
if param.default is inspect.Parameter.empty:
return True
return False

    def _is_supported_attribution_method(self) -> bool:
r"""Returns :obj:`True` if `self.attribution_method` is supported."""
        # This check is currently redundant, since none of the supported
        # methods require a mandatory baseline:
if self._needs_baseline():
return False
elif self.attribution_method_class.__name__ in self.SUPPORTED_METHODS:
return True
return False

    def forward(
self,
model: torch.nn.Module,
x: Union[Tensor, Dict[NodeType, Tensor]],
edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
*,
target: Tensor,
index: Optional[Union[int, Tensor]] = None,
**kwargs,
) -> Union[Explanation, HeteroExplanation]:
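        r"""Computes attributions with the configured Captum attribution
        method and converts them into an :class:`Explanation` (or a
        :class:`HeteroExplanation` for heterogeneous inputs) holding the
        resulting node- and/or edge-level masks.
        """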
mask_type = self._get_mask_type()
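
        # Flatten the inputs into the tuple-based format expected by Captum;
        # any remaining keyword arguments are passed through as additional
        # forward arguments: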
inputs, add_forward_args = to_captum_input(
x,
edge_index,
mask_type,
*kwargs.values(),
)
if isinstance(x, dict): # Heterogeneous GNN:
metadata = (list(x.keys()), list(edge_index.keys()))
captum_model = CaptumHeteroModel(
model,
mask_type,
index,
metadata,
self.model_config,
)
else: # Homogeneous GNN:
metadata = None
captum_model = CaptumModel(
model,
mask_type,
index,
self.model_config,
)
self.attribution_method_instance = self.attribution_method_class(
captum_model)
# In Captum, the target is the class index for which the attribution is
# computed. Within CaptumModel, we transform the binary classification
# into a multi-class classification task.
if self.model_config.mode == ModelMode.regression:
target = None
elif index is not None:
target = target[index]
attributions = self.attribution_method_instance.attribute(
inputs=inputs,
target=target,
additional_forward_args=add_forward_args,
**self.kwargs,
)
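
        # Convert the Captum attributions back into (optional) node-level and
        # edge-level masks: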
node_mask, edge_mask = convert_captum_output(
attributions,
mask_type,
metadata,
)
if not isinstance(x, dict):
return Explanation(node_mask=node_mask, edge_mask=edge_mask)
explanation = HeteroExplanation()
explanation.set_value_dict('node_mask', node_mask)
explanation.set_value_dict('edge_mask', edge_mask)
return explanation

    def supports(self) -> bool:
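        r"""Checks if the user-defined explainer and model configuration is
        supported by this algorithm.
        """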
node_mask_type = self.explainer_config.node_mask_type
if node_mask_type not in [None, MaskType.attributes]:
logging.error(f"'{self.__class__.__name__}' expects "
f"'node_mask_type' to be 'None' or 'attributes' "
f"(got '{node_mask_type.value}')")
return False
return_type = self.model_config.return_type
if (self.model_config.mode == ModelMode.binary_classification
and return_type != ModelReturnType.probs):
logging.error(f"'{self.__class__.__name__}' expects "
f"'return_type' to be 'probs' for binary "
f"classification tasks (got '{return_type.value}')")
return False
# TODO (ramona) Confirm that output type is valid.
return True