File: utils.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (76 lines) | stat: -rw-r--r-- 2,564 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from typing import Dict, Union

import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn import MessagePassing
from torch_geometric.typing import EdgeType


def set_masks(
    model: torch.nn.Module,
    mask: Union[Tensor, Parameter],
    edge_index: Tensor,
    apply_sigmoid: bool = True,
):
    r"""Apply mask to every graph layer in the :obj:`model`.

    Every :class:`MessagePassing` submodule of :obj:`model` whose
    :obj:`explain` attribute is not explicitly :obj:`False` is switched to
    :obj:`explain = True`. Each such layer also receives the same edge mask,
    a self-loop mask, and the :obj:`apply_sigmoid` flag.

    Args:
        model (torch.nn.Module): The model whose message passing layers
            receive the mask.
        mask (torch.Tensor or torch.nn.Parameter): The edge mask to attach.
            If a layer already has :obj:`_edge_mask` registered as a
            parameter, a plain tensor is wrapped into a
            :class:`~torch.nn.Parameter` first.
        edge_index (torch.Tensor): The edge indices. Row 0 is compared
            element-wise with row 1 to build the self-loop mask.
        apply_sigmoid (bool, optional): Stored on each layer as
            :obj:`_apply_sigmoid`. It is presumably read by the layer to
            decide whether the mask is passed through a sigmoid.
            (default: :obj:`True`)
    """
    # `True` wherever the two endpoints of an edge differ, i.e. for every
    # edge that is not a self-loop (assumes `edge_index` is [2, num_edges]).
    loop_mask = edge_index[0] != edge_index[1]

    # Message passing layers, minus those that explicitly opted out of
    # explanation by having `explain` set to `False`.
    targets = (
        layer for layer in model.modules()
        if isinstance(layer, MessagePassing) and layer.explain is not False
    )

    for layer in targets:
        # PyTorch does not allow assigning a pure tensor to an attribute
        # that is registered as a parameter. In that case, promote the mask
        # to a `Parameter`. The promoted mask is then reused as-is for all
        # remaining layers.
        if ('_edge_mask' in layer._parameters
                and not isinstance(mask, Parameter)):
            mask = Parameter(mask)

        layer.explain = True
        layer._edge_mask = mask
        layer._loop_mask = loop_mask
        layer._apply_sigmoid = apply_sigmoid


def set_hetero_masks(
    model: torch.nn.Module,
    mask_dict: Dict[EdgeType, Union[Tensor, Parameter]],
    edge_index_dict: Dict[EdgeType, Tensor],
    apply_sigmoid: bool = True,
):
    r"""Apply masks to every heterogeneous graph layer in the :obj:`model`
    according to edge types.

    Every :class:`torch.nn.ModuleDict` inside :obj:`model` is searched, for
    each edge type in :obj:`mask_dict`. An entry is matched if it is keyed
    either by the edge type itself or by its :obj:`'__'`-joined string form.
    Each matched entry is passed to :meth:`set_masks` together with that
    edge type's mask and edge indices. Edge types without a matching entry
    are skipped.

    Args:
        model (torch.nn.Module): The heterogeneous model to mask.
        mask_dict (Dict[EdgeType, torch.Tensor or torch.nn.Parameter]): One
            edge mask per edge type.
        edge_index_dict (Dict[EdgeType, torch.Tensor]): Edge indices per
            edge type. It must contain every edge type that gets matched in
            some :class:`torch.nn.ModuleDict`.
        apply_sigmoid (bool, optional): Forwarded unchanged to
            :meth:`set_masks`. (default: :obj:`True`)
    """
    module_dicts = (
        candidate for candidate in model.modules()
        if isinstance(candidate, torch.nn.ModuleDict)
    )

    for module_dict in module_dicts:
        for edge_type, edge_mask in mask_dict.items():
            # Prefer the raw edge type as key. Otherwise fall back to the
            # '__'-joined string form.
            if edge_type in module_dict:
                key = edge_type
            else:
                key = '__'.join(edge_type)
                if key not in module_dict:
                    continue

            set_masks(
                module_dict[key],
                edge_mask,
                edge_index_dict[edge_type],
                apply_sigmoid=apply_sigmoid,
            )


def clear_masks(model: torch.nn.Module):
    r"""Clear all masks from the model.

    Every :class:`MessagePassing` submodule of :obj:`model` is reset as
    follows:

    * :obj:`_edge_mask` and :obj:`_loop_mask` are set to :obj:`None`.
    * :obj:`_apply_sigmoid` is restored to :obj:`True`.
    * An :obj:`explain` flag of :obj:`True` is reset to :obj:`None`.
      Layers that explicitly opted out with :obj:`explain = False` keep
      that setting.

    Args:
        model (torch.nn.Module): The model to clear.

    Returns:
        torch.nn.Module: The same :obj:`model` that was passed in.
    """
    for module in model.modules():
        if not isinstance(module, MessagePassing):
            continue

        # Only undo an `explain = True` (as written by `set_masks`); an
        # explicit `False` opt-out must survive clearing.
        if module.explain is True:
            module.explain = None
        module._edge_mask = None
        module._loop_mask = None
        module._apply_sigmoid = True

    # Return the model itself. The previous `return module` leaked the loop
    # variable, handing back whichever submodule happened to be visited last.
    return model