File: dummy_explainer.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (72 lines) | stat: -rw-r--r-- 2,872 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from collections import defaultdict
from typing import Dict, Optional, Union

import torch
from torch import Tensor

from torch_geometric.explain import Explanation, HeteroExplanation
from torch_geometric.explain.algorithm import ExplainerAlgorithm
from torch_geometric.explain.config import MaskType
from torch_geometric.typing import EdgeType, NodeType


class DummyExplainer(ExplainerAlgorithm):
    r"""A dummy explainer that returns random explanations (useful for testing
    purposes).
    """
    def forward(
        self,
        model: torch.nn.Module,
        x: Union[Tensor, Dict[NodeType, Tensor]],
        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
        edge_attr: Optional[Union[Tensor, Dict[EdgeType, Tensor]]] = None,
        **kwargs,
    ) -> Union[Explanation, HeteroExplanation]:
        assert isinstance(x, (Tensor, dict))

        node_mask_type = self.explainer_config.node_mask_type
        edge_mask_type = self.explainer_config.edge_mask_type

        if isinstance(x, Tensor):  # Homogeneous graph.
            assert isinstance(edge_index, Tensor)

            node_mask = None
            if node_mask_type == MaskType.object:
                node_mask = torch.rand(x.size(0), 1, device=x.device)
            elif node_mask_type == MaskType.common_attributes:
                node_mask = torch.rand(1, x.size(1), device=x.device)
            elif node_mask_type == MaskType.attributes:
                node_mask = torch.rand_like(x)

            edge_mask = None
            if edge_mask_type == MaskType.object:
                edge_mask = torch.rand(edge_index.size(1), device=x.device)

            return Explanation(node_mask=node_mask, edge_mask=edge_mask)
        else:  # isinstance(x, dict):  # Heterogeneous graph.
            assert isinstance(edge_index, dict)

            node_dict = defaultdict(dict)
            for k, v in x.items():
                node_mask = None
                if node_mask_type == MaskType.object:
                    node_mask = torch.rand(v.size(0), 1, device=v.device)
                elif node_mask_type == MaskType.common_attributes:
                    node_mask = torch.rand(1, v.size(1), device=v.device)
                elif node_mask_type == MaskType.attributes:
                    node_mask = torch.rand_like(v)
                if node_mask is not None:
                    node_dict[k]['node_mask'] = node_mask

            edge_dict = defaultdict(dict)
            for k, v in edge_index.items():
                edge_mask = None
                if edge_mask_type == MaskType.object:
                    edge_mask = torch.rand(v.size(1), device=v.device)
                if edge_mask is not None:
                    edge_dict[k]['edge_mask'] = edge_mask

            return HeteroExplanation({**node_dict, **edge_dict})

    def supports(self) -> bool:
        return True