File: attention_explainer.py

package info (click to toggle)
pytorch-geometric 2.7.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 14,172 kB
  • sloc: python: 144,911; sh: 247; cpp: 27; makefile: 18; javascript: 16
file content (304 lines) | stat: -rw-r--r-- 11,304 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
import logging
from typing import Dict, List, Optional, Union, overload

import torch
from torch import Tensor

from torch_geometric.explain import Explanation, HeteroExplanation
from torch_geometric.explain.algorithm import ExplainerAlgorithm
from torch_geometric.explain.config import ExplanationType, ModelTaskLevel
from torch_geometric.nn.conv.message_passing import MessagePassing
from torch_geometric.typing import EdgeType, NodeType


class AttentionExplainer(ExplainerAlgorithm):
    r"""An explainer that uses the attention coefficients produced by an
    attention-based GNN (*e.g.*,
    :class:`~torch_geometric.nn.conv.GATConv`,
    :class:`~torch_geometric.nn.conv.GATv2Conv`, or
    :class:`~torch_geometric.nn.conv.TransformerConv`) as edge explanation.
    Attention scores across layers and heads will be aggregated according to
    the :obj:`reduce` argument.

    Both homogeneous inputs (:obj:`x` and :obj:`edge_index` given as
    tensors) and heterogeneous inputs (dictionaries keyed by node and edge
    type) are supported. Only model explanations without node feature masks
    are supported (see :meth:`supports`).

    Args:
        reduce (str, optional): The method to reduce the attention scores
            across layers and heads. It must name a function of the
            :obj:`torch` namespace that accepts a :obj:`dim` argument,
            *e.g.*, :obj:`"max"`, :obj:`"mean"` or :obj:`"sum"`.
            (default: :obj:`"max"`)
    """
    def __init__(self, reduce: str = 'max'):
        super().__init__()
        self.reduce = reduce
        # Re-determined on every call to `forward` from the type of `x`.
        self.is_hetero = False

    @overload
    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> Explanation:
        ...

    @overload
    def forward(
        self,
        model: torch.nn.Module,
        x: Dict[NodeType, Tensor],
        edge_index: Dict[EdgeType, Tensor],
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> HeteroExplanation:
        ...

    def forward(
        self,
        model: torch.nn.Module,
        x: Union[Tensor, Dict[NodeType, Tensor]],
        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> Union[Explanation, HeteroExplanation]:
        """Generate explanations based on attention coefficients.

        Args:
            model: The model to explain. It is called once as
                ``model(x, edge_index, **kwargs)``.
            x: Node features, either a single tensor (homogeneous) or a
                dictionary mapping node types to tensors (heterogeneous).
            edge_index: Edge indices, either a single tensor or a
                dictionary mapping edge types to tensors.
            target: Part of the explainer interface; not used by this
                algorithm.
            index: Indices of the outputs to explain. For node-level tasks
                it is passed to :meth:`_get_hard_masks` to restrict the
                resulting edge mask.
            **kwargs: Additional arguments forwarded to :obj:`model`.

        Returns:
            An :class:`Explanation` with an :obj:`edge_mask` for homogeneous
            inputs, or a :class:`HeteroExplanation` with one :obj:`edge_mask`
            per edge type for heterogeneous inputs.

        Raises:
            ValueError: If no attention coefficients could be collected.
        """
        self.is_hetero = isinstance(x, dict)

        # Collect attention coefficients
        alphas_dict = self._collect_attention_coefficients(
            model, x, edge_index, **kwargs)

        # Process attention coefficients
        if self.is_hetero:
            return self._create_hetero_explanation(model, alphas_dict,
                                                   edge_index, index, x)
        return self._create_homo_explanation(model, alphas_dict, edge_index,
                                             index, x)

    @overload
    def _collect_attention_coefficients(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        **kwargs,
    ) -> List[Tensor]:
        ...

    @overload
    def _collect_attention_coefficients(
        self,
        model: torch.nn.Module,
        x: Dict[NodeType, Tensor],
        edge_index: Dict[EdgeType, Tensor],
        **kwargs,
    ) -> Dict[EdgeType, List[Tensor]]:
        ...

    def _collect_attention_coefficients(
        self,
        model: torch.nn.Module,
        x: Union[Tensor, Dict[NodeType, Tensor]],
        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
        **kwargs,
    ) -> Union[List[Tensor], Dict[EdgeType, List[Tensor]]]:
        """Collect attention coefficients from model layers.

        A message-forward hook is registered on every
        :class:`MessagePassing` submodule of :obj:`model` whose
        :obj:`explain` attribute is not :obj:`False`. Then a single forward
        pass is run, and each hooked layer's coefficients are recorded in
        call order. The hooks are always removed afterwards, even if the
        forward pass raises.

        Returns:
            A list of detached coefficient tensors (homogeneous), or a
            dictionary mapping each matched edge type to such a list
            (heterogeneous). Only edge types with at least one collected
            tensor appear in the dictionary.

        Raises:
            ValueError: If no attention coefficients could be collected.
        """
        is_hetero = self.is_hetero
        alphas: List[Tensor] = []
        alphas_dict: Dict[EdgeType, List[Tensor]] = {}
        edge_types: List[EdgeType] = (list(edge_index.keys())
                                      if is_hetero else [])

        def make_hook(module_name: str):
            # The module name is bound in this closure rather than stored on
            # the module, so user modules are left unmodified.
            def hook(module, msg_kwargs, out):
                alpha = self._extract_alpha(module, msg_kwargs)
                if alpha is None:
                    return
                if not is_hetero:
                    alphas.append(alpha)
                    return
                edge_type = self._match_edge_type(module_name, edge_types)
                if edge_type is not None:
                    # Only create an entry once an alpha exists, so that an
                    # empty result is detected by the check below.
                    alphas_dict.setdefault(edge_type, []).append(alpha)

            return hook

        hook_handles = []
        try:
            # Register hooks for all message passing modules
            for name, module in model.named_modules():
                if (isinstance(module, MessagePassing)
                        and module.explain is not False):
                    hook_handles.append(
                        module.register_message_forward_hook(
                            make_hook(name)))

            # Forward pass to collect attention coefficients.
            model(x, edge_index, **kwargs)
        finally:
            # Remove hooks even if the forward pass failed, so they do not
            # accumulate on the user's modules.
            for handle in hook_handles:
                handle.remove()

        # Check if we collected any attention coefficients.
        collected = alphas_dict if is_hetero else alphas
        if not collected:
            raise ValueError(
                "Could not collect any attention coefficients. "
                "Please ensure that your model is using "
                "attention-based GNN layers.")
        return collected

    @staticmethod
    def _extract_alpha(module: torch.nn.Module,
                       msg_kwargs) -> Optional[Tensor]:
        """Return the detached attention coefficients of one layer call.

        The :obj:`'alpha'` entry of the message kwargs is preferred. If it is
        missing or :obj:`None`, the module's :obj:`_alpha` attribute is used
        as a fallback. Returns :obj:`None` if neither is available.
        """
        alpha = msg_kwargs[0].get('alpha')
        if alpha is None:
            alpha = getattr(module, '_alpha', None)
        return None if alpha is None else alpha.detach()

    @staticmethod
    def _match_edge_type(
        module_name: str,
        edge_types: List[EdgeType],
    ) -> Optional[EdgeType]:
        """Infer the edge type a submodule belongs to from its name.

        Returns :obj:`None` if no edge type matches.
        """
        # First, prefer an exact occurrence of the whole edge type joined by
        # '__' or '___'. NOTE(review): these are presumably the separators
        # used when naming per-edge-type submodules (e.g. by `to_hetero` /
        # `HeteroConv`); confirm. This avoids confusing relations that share
        # a prefix (e.g. 'cites' vs. 'cites_rev').
        for edge_type in edge_types:
            if ('__'.join(edge_type) in module_name
                    or '___'.join(edge_type) in module_name):
                return edge_type

        # Fallback: all three components appear in the name, in order.
        for edge_type in edge_types:
            src_type, edge_name, dst_type = edge_type
            try:
                src_idx = module_name.index(src_type)
                edge_idx = module_name.index(edge_name, src_idx)
                dst_idx = module_name.index(dst_type, edge_idx)
            except ValueError:  # Component not found
                continue
            if src_idx < edge_idx < dst_idx:
                return edge_type
        return None

    def _reduce_scores(self, src: Tensor, dim: int) -> Tensor:
        """Apply ``torch.<self.reduce>`` along :obj:`dim`.

        Reductions that return a ``(values, indices)`` tuple (such as
        :obj:`torch.max`) are unwrapped to their values.
        """
        out = getattr(torch, self.reduce)(src, dim=dim)
        if isinstance(out, tuple):  # Handle torch.max output
            out = out[0]
        return out

    def _process_attention_coefficients(
        self,
        alphas: List[Tensor],
        edge_index_size: int,
    ) -> Tensor:
        """Process collected attention coefficients into a single mask.

        Each layer's coefficients are truncated to :obj:`edge_index_size`
        entries. Two-dimensional coefficients (presumably
        ``[num_edges, heads]``) are reduced over the last dimension. All
        layers are then stacked and reduced into one score per edge. The
        input list is not modified.

        Raises:
            ValueError: If :obj:`alphas` is empty or a coefficient tensor
                has more than two dimensions.
        """
        if not alphas:
            raise ValueError("No attention coefficients to process.")

        per_layer: List[Tensor] = []
        for alpha in alphas:
            # Drop trailing coefficients beyond the input edges. NOTE(review):
            # assumes extra entries (e.g. added self-loops) are appended
            # after the original edges.
            alpha = alpha[:edge_index_size]

            # Reduce multi-head attention
            if alpha.dim() == 2:
                alpha = self._reduce_scores(alpha, dim=-1)
            elif alpha.dim() > 2:
                raise ValueError(f"Cannot reduce attention coefficients of "
                                 f"shape {list(alpha.size())}")
            per_layer.append(alpha)

        # Combine attention coefficients across layers
        if len(per_layer) == 1:
            return per_layer[0]
        return self._reduce_scores(torch.stack(per_layer, dim=-1), dim=-1)

    def _create_homo_explanation(
        self,
        model: torch.nn.Module,
        alphas: List[Tensor],
        edge_index: Tensor,
        index: Optional[Union[int, Tensor]],
        x: Tensor,
    ) -> Explanation:
        """Create explanation for homogeneous graph."""
        # Get hard edge mask for node-level tasks
        hard_edge_mask = None
        if self.model_config.task_level == ModelTaskLevel.node:
            _, hard_edge_mask = self._get_hard_masks(model, index, edge_index,
                                                     num_nodes=x.size(0))

        # Process attention coefficients
        alpha = self._process_attention_coefficients(alphas,
                                                     edge_index.size(1))

        # Post-process mask with hard edge mask if needed. The reduced
        # attention scores are used as-is, without a sigmoid.
        alpha = self._post_process_mask(alpha, hard_edge_mask,
                                        apply_sigmoid=False)

        return Explanation(edge_mask=alpha)

    def _create_hetero_explanation(
        self,
        model: torch.nn.Module,
        alphas_dict: Dict[EdgeType, List[Tensor]],
        edge_index: Dict[EdgeType, Tensor],
        index: Optional[Union[int, Tensor]],
        x: Dict[NodeType, Tensor],
    ) -> HeteroExplanation:
        """Create explanation for heterogeneous graph.

        Only edge types for which attention coefficients were collected
        receive an :obj:`edge_mask`.
        """
        edge_masks_dict = {}

        # Process each edge type separately
        for edge_type, alphas in alphas_dict.items():
            if not alphas:
                continue

            # Get hard edge mask for node-level tasks
            hard_edge_mask = None
            if self.model_config.task_level == ModelTaskLevel.node:
                src_type, _, dst_type = edge_type
                # Source and destination node types may differ in size; use
                # the larger count so both endpoint sets are in range.
                _, hard_edge_mask = self._get_hard_masks(
                    model, index, edge_index[edge_type],
                    num_nodes=max(x[src_type].size(0), x[dst_type].size(0)))

            # Process attention coefficients for this edge type
            alpha = self._process_attention_coefficients(
                alphas, edge_index[edge_type].size(1))

            # Apply hard mask if available
            edge_masks_dict[edge_type] = self._post_process_mask(
                alpha, hard_edge_mask, apply_sigmoid=False)

        # Create heterogeneous explanation
        explanation = HeteroExplanation()
        explanation.set_value_dict('edge_mask', edge_masks_dict)
        return explanation

    def supports(self) -> bool:
        """Return whether the current explainer configuration is supported.

        Only :obj:`ExplanationType.model` is supported, and no node mask
        may be requested. Unsupported configurations are logged as errors.
        """
        explanation_type = self.explainer_config.explanation_type
        if explanation_type != ExplanationType.model:
            logging.error(f"'{self.__class__.__name__}' only supports "
                          f"model explanations "
                          f"got (`explanation_type={explanation_type.value}`)")
            return False

        node_mask_type = self.explainer_config.node_mask_type
        if node_mask_type is not None:
            logging.error(f"'{self.__class__.__name__}' does not support "
                          f"explaining input node features "
                          f"got (`node_mask_type={node_mask_type.value}`)")
            return False

        return True