File: attention_explainer.py

import logging
from typing import List, Optional, Union

import torch
from torch import Tensor

from torch_geometric.explain import Explanation
from torch_geometric.explain.algorithm import ExplainerAlgorithm
from torch_geometric.explain.config import ExplanationType, ModelTaskLevel
from torch_geometric.nn.conv.message_passing import MessagePassing


class AttentionExplainer(ExplainerAlgorithm):
    r"""An explainer that uses the attention coefficients produced by an
    attention-based GNN (*e.g.*,
    :class:`~torch_geometric.nn.conv.GATConv`,
    :class:`~torch_geometric.nn.conv.GATv2Conv`, or
    :class:`~torch_geometric.nn.conv.TransformerConv`) as its edge
    explanation. Attention scores are aggregated across layers and heads
    according to the :obj:`reduce` argument.

    Args:
        reduce (str, optional): The method to reduce the attention scores
            across layers and heads. (default: :obj:`"max"`)
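
    Example (a minimal usage sketch; :obj:`model`, :obj:`x`, and
    :obj:`edge_index` are placeholders for a node-level classifier built
    from attention-based layers such as
    :class:`~torch_geometric.nn.conv.GATConv` and its inputs):

    .. code-block:: python

        from torch_geometric.explain import Explainer

        explainer = Explainer(
            model=model,
            algorithm=AttentionExplainer(reduce='max'),
            explanation_type='model',
            edge_mask_type='object',
            model_config=dict(
                mode='multiclass_classification',
                task_level='node',
                return_type='raw',
            ),
        )
        # Explain the prediction for node 10; the resulting `edge_mask`
        # holds one aggregated attention score per edge:
        explanation = explainer(x, edge_index, index=10)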
    """
    def __init__(self, reduce: str = 'max'):
        super().__init__()
        self.reduce = reduce

    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> Explanation:
        if isinstance(x, dict) or isinstance(edge_index, dict):
            raise ValueError(f"Heterogeneous graphs not yet supported in "
                             f"'{self.__class__.__name__}'")

        hard_edge_mask = None
        if self.model_config.task_level == ModelTaskLevel.node:
            # Compute the hard edge mask so that attributions of edges that
            # do not reach the target node(s) can be zeroed out afterwards:
            _, hard_edge_mask = self._get_hard_masks(model, index, edge_index,
                                                     num_nodes=x.size(0))

        alphas: List[Tensor] = []

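        # Attention-based layers either pass their coefficients to
        # `message()` via an `alpha` argument or cache them in an `_alpha`
        # attribute, so a message forward hook can capture both cases: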
        def hook(module, msg_kwargs, out):
            if 'alpha' in msg_kwargs[0]:
                alphas.append(msg_kwargs[0]['alpha'].detach())
            elif getattr(module, '_alpha', None) is not None:
                alphas.append(module._alpha.detach())

        hook_handles = []
        for module in model.modules():  # Register message forward hooks:
            if (isinstance(module, MessagePassing)
                    and module.explain is not False):
                hook_handles.append(module.register_message_forward_hook(hook))

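        # Run a single forward pass so the hooks can record the attention
        # coefficients of every attention-based layer: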
        model(x, edge_index, **kwargs)

        for handle in hook_handles:  # Remove hooks:
            handle.remove()

        if len(alphas) == 0:
            raise ValueError("Could not collect any attention coefficients. "
                             "Please ensure that your model is using "
                             "attention-based GNN layers.")

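        # First, reduce the attention heads within each layer to a single
        # score per edge: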
        for i, alpha in enumerate(alphas):
            # Layers may append self-loops internally; keep only the scores
            # of the original edges:
            alpha = alpha[:edge_index.size(1)]
            if alpha.dim() == 2:
                alpha = getattr(torch, self.reduce)(alpha, dim=-1)
                # `torch.max` returns a `(values, indices)` tuple:
                if isinstance(alpha, tuple):
                    alpha = alpha[0]
            elif alpha.dim() > 2:
                raise ValueError(f"Cannot reduce attention coefficients of "
                                 f"shape {list(alpha.size())}")
            alphas[i] = alpha

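        # Then, aggregate the per-layer scores across all layers: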
        if len(alphas) > 1:
            alpha = torch.stack(alphas, dim=-1)
            alpha = getattr(torch, self.reduce)(alpha, dim=-1)
            # `torch.max` returns a `(values, indices)` tuple:
            if isinstance(alpha, tuple):
                alpha = alpha[0]
        else:
            alpha = alphas[0]

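        # Zero out scores of edges outside the hard edge mask (node-level
        # tasks only), without applying a sigmoid: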
        alpha = self._post_process_mask(alpha, hard_edge_mask,
                                        apply_sigmoid=False)

        return Explanation(edge_mask=alpha)

    def supports(self) -> bool:
        explanation_type = self.explainer_config.explanation_type
        if explanation_type != ExplanationType.model:
            logging.error(f"'{self.__class__.__name__}' only supports "
                          f"model explanations "
                          f"(got `explanation_type={explanation_type.value}`)")
            return False

        node_mask_type = self.explainer_config.node_mask_type
        if node_mask_type is not None:
            logging.error(f"'{self.__class__.__name__}' does not support "
                          f"explaining input node features "
                          f"(got `node_mask_type={node_mask_type.value}`)")
            return False

        return True