from abc import abstractmethod
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor

from torch_geometric.explain import Explanation, HeteroExplanation
from torch_geometric.explain.config import (
    ExplainerConfig,
    ModelConfig,
    ModelReturnType,
)
from torch_geometric.nn import MessagePassing
from torch_geometric.typing import EdgeType, NodeType
from torch_geometric.utils import k_hop_subgraph


class ExplainerAlgorithm(torch.nn.Module):
    r"""An abstract base class for implementing explainer algorithms."""
    @abstractmethod
    def forward(
        self,
        model: torch.nn.Module,
        x: Union[Tensor, Dict[NodeType, Tensor]],
        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> Union[Explanation, HeteroExplanation]:
        r"""Computes the explanation.

        Args:
            model (torch.nn.Module): The model to explain.
            x (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]): The input
                node features of a homogeneous or heterogeneous graph.
            edge_index (Union[torch.Tensor, Dict[EdgeType, torch.Tensor]]):
                The input edge indices of a homogeneous or heterogeneous
                graph.
            target (torch.Tensor): The target of the model.
            index (Union[int, Tensor], optional): The index of the model
                output to explain. Can be a single index or a tensor of
                indices. (default: :obj:`None`)
            **kwargs (optional): Additional keyword arguments passed to
                :obj:`model`.
        """

    @abstractmethod
    def supports(self) -> bool:
        r"""Checks whether the explainer supports the user-defined settings
        provided in :obj:`self.explainer_config` and
        :obj:`self.model_config`.
        """

    ###########################################################################

    @property
    def explainer_config(self) -> ExplainerConfig:
        r"""Returns the connected explainer configuration."""
        if not hasattr(self, '_explainer_config'):
            raise ValueError(
                f"The explanation algorithm '{self.__class__.__name__}' is "
                f"not yet connected to any explainer configuration. Please "
                f"call `{self.__class__.__name__}.connect(...)` before "
                f"proceeding.")
        return self._explainer_config

    @property
    def model_config(self) -> ModelConfig:
        r"""Returns the connected model configuration."""
        if not hasattr(self, '_model_config'):
            raise ValueError(
                f"The explanation algorithm '{self.__class__.__name__}' is "
                f"not yet connected to any model configuration. Please call "
                f"`{self.__class__.__name__}.connect(...)` before "
                f"proceeding.")
        return self._model_config

    def connect(
        self,
        explainer_config: ExplainerConfig,
        model_config: ModelConfig,
    ):
        r"""Connects an explainer and model configuration to the explainer
        algorithm.
        """
        self._explainer_config = ExplainerConfig.cast(explainer_config)
        self._model_config = ModelConfig.cast(model_config)

        if not self.supports():
            raise ValueError(
                f"The explanation algorithm '{self.__class__.__name__}' does "
                f"not support the given explanation settings.")

    # Helper functions ########################################################

    @staticmethod
    def _post_process_mask(
        mask: Optional[Tensor],
        hard_mask: Optional[Tensor] = None,
        apply_sigmoid: bool = True,
    ) -> Optional[Tensor]:
r""""Post processes any mask to not include any attributions of
elements not involved during message passing.
"""
        if mask is None:
            return mask

        mask = mask.detach()

        if apply_sigmoid:
            mask = mask.sigmoid()

        if hard_mask is not None and mask.size(0) == hard_mask.size(0):
            mask[~hard_mask] = 0.

        return mask
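
    # A small worked sketch of the masking behavior (values illustrative):
    #
    #     mask = torch.tensor([2.0, -1.0, 0.5])
    #     hard_mask = torch.tensor([True, False, True])
    #     ExplainerAlgorithm._post_process_mask(mask, hard_mask)
    #     # tensor([0.8808, 0.0000, 0.6225])  # sigmoid applied; the element
    #     #                                   # excluded by `hard_mask` zeroed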

    @staticmethod
    def _get_hard_masks(
        model: torch.nn.Module,
        node_index: Optional[Union[int, Tensor]],
        edge_index: Tensor,
        num_nodes: int,
    ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
        r"""Returns hard node and edge masks that only include the nodes and
        edges visited during message passing.
        """
        if node_index is None:
            return None, None  # Consider all nodes and edges.

        index, _, _, edge_mask = k_hop_subgraph(
            node_index,
            num_hops=ExplainerAlgorithm._num_hops(model),
            edge_index=edge_index,
            num_nodes=num_nodes,
            flow=ExplainerAlgorithm._flow(model),
        )

        node_mask = edge_index.new_zeros(num_nodes, dtype=torch.bool)
        node_mask[index] = True

        return node_mask, edge_mask
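
    # Sketch: for a path graph 0 - 1 - 2 and a model with a single message
    # passing layer, explaining node 0 keeps only its 1-hop neighborhood
    # (`model` is assumed to be such a one-layer GNN):
    #
    #     edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    #     node_mask, edge_mask = ExplainerAlgorithm._get_hard_masks(
    #         model, node_index=0, edge_index=edge_index, num_nodes=3)
    #     # node_mask: tensor([True, True, False])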

    @staticmethod
    def _num_hops(model: torch.nn.Module) -> int:
        r"""Returns the number of hops the :obj:`model` is aggregating
        information from.
        """
        num_hops = 0
        for module in model.modules():
            if isinstance(module, MessagePassing):
                num_hops += 1
        return num_hops
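
    # This counts `MessagePassing` submodules, which assumes the model
    # applies its layers sequentially. Sketch:
    #
    #     from torch_geometric.nn import GCNConv
    #
    #     class GCN(torch.nn.Module):
    #         def __init__(self):
    #             super().__init__()
    #             self.conv1 = GCNConv(16, 32)
    #             self.conv2 = GCNConv(32, 7)
    #
    #     ExplainerAlgorithm._num_hops(GCN())  # 2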

    @staticmethod
    def _flow(model: torch.nn.Module) -> str:
        r"""Determines the message passing flow of the :obj:`model`."""
        for module in model.modules():
            if isinstance(module, MessagePassing):
                return module.flow
        return 'source_to_target'

    def _loss_binary_classification(self, y_hat: Tensor, y: Tensor) -> Tensor:
        if self.model_config.return_type == ModelReturnType.raw:
            loss_fn = F.binary_cross_entropy_with_logits
        elif self.model_config.return_type == ModelReturnType.probs:
            loss_fn = F.binary_cross_entropy
        else:
            raise AssertionError()

        return loss_fn(y_hat.view_as(y), y.float())
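
    # The two branches compute the same quantity on matching inputs; the
    # logits variant is just the numerically stable form. Sketch:
    #
    #     logits, y = torch.tensor([1.5]), torch.tensor([1.0])
    #     F.binary_cross_entropy_with_logits(logits, y)  # ~0.2014
    #     F.binary_cross_entropy(logits.sigmoid(), y)    # ~0.2014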

    def _loss_multiclass_classification(
        self,
        y_hat: Tensor,
        y: Tensor,
    ) -> Tensor:
        if self.model_config.return_type == ModelReturnType.raw:
            loss_fn = F.cross_entropy
        elif self.model_config.return_type == ModelReturnType.probs:
            loss_fn = F.nll_loss
            y_hat = y_hat.log()
        elif self.model_config.return_type == ModelReturnType.log_probs:
            loss_fn = F.nll_loss
        else:
            raise AssertionError()

        return loss_fn(y_hat, y)
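
    # Taking `log()` of probabilities first makes `F.nll_loss` agree with
    # `F.cross_entropy` on the corresponding logits. Worked sketch:
    #
    #     probs, y = torch.tensor([[0.7, 0.2, 0.1]]), torch.tensor([0])
    #     F.nll_loss(probs.log(), y)  # == -log(0.7) ~= 0.3567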

    def _loss_regression(self, y_hat: Tensor, y: Tensor) -> Tensor:
        assert self.model_config.return_type == ModelReturnType.raw
        return F.mse_loss(y_hat, y)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'