File: config.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (225 lines) | stat: -rw-r--r-- 7,834 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union

from torch_geometric.utils.mixin import CastMixin


class ExplanationType(Enum):
    """Enumerates what an explanation is computed with respect to."""
    # Explain the model prediction (losses use the model output).
    model = 'model'
    # Explain the phenomenon the model is trying to predict (losses use the
    # target output).
    phenomenon = 'phenomenon'


class MaskType(Enum):
    """Enumerates the kinds of mask that can be applied to nodes or edges."""
    # Mask each object (e.g., each node).
    object = 'object'
    # Mask each feature.
    common_attributes = 'common_attributes'
    # Mask each feature across all nodes.
    attributes = 'attributes'


class ModelMode(Enum):
    """Enum class for the model mode (the kind of prediction it makes)."""
    # A binary classification model.
    binary_classification = 'binary_classification'
    # A multiclass classification model.
    multiclass_classification = 'multiclass_classification'
    # A regression model.
    regression = 'regression'


class ModelTaskLevel(Enum):
    """Enumerates the level at which a model makes its predictions."""
    # A node-level prediction model.
    node = 'node'
    # An edge-level prediction model.
    edge = 'edge'
    # A graph-level prediction model.
    graph = 'graph'


class ModelReturnType(Enum):
    """Enumerates the form in which a model returns its outputs."""
    # The model returns raw values.
    raw = 'raw'
    # The model returns probabilities.
    probs = 'probs'
    # The model returns log-probabilities.
    log_probs = 'log_probs'


class ThresholdType(Enum):
    """Enumerates the thresholding strategies that can be applied to masks."""
    # Binarize: elements below the threshold value become 0, others 1.
    hard = 'hard'
    # Keep the top-k elements of each mask, set the rest to 0.
    topk = 'topk'
    # Like `topk`, but the kept elements are set to 1.
    topk_hard = 'topk_hard'
    # TODO: add a 'connected' threshold type.


@dataclass
class ExplainerConfig(CastMixin):
    r"""Configuration class to store and validate high level explanation
    parameters.

    Args:
        explanation_type (ExplanationType or str): The type of explanation to
            compute. The possible values are:

                - :obj:`"model"`: Explains the model prediction.

                - :obj:`"phenomenon"`: Explains the phenomenon that the model
                  is trying to predict.

            In practice, this means that the explanation algorithm will either
            compute its losses with respect to the model output
            (:obj:`"model"`) or the target output (:obj:`"phenomenon"`).

        node_mask_type (MaskType or str, optional): The type of mask to apply
            on nodes. The possible values are (default: :obj:`None`):

                - :obj:`None`: Will not apply any mask on nodes.

                - :obj:`"object"`: Will mask each node.

                - :obj:`"common_attributes"`: Will mask each feature.

                - :obj:`"attributes"`: Will mask each feature across all nodes.

        edge_mask_type (MaskType or str, optional): The type of mask to apply
            on edges. Unlike :obj:`node_mask_type`, only :obj:`None` (no mask
            on edges) and :obj:`"object"` are supported.
            (default: :obj:`None`)

        At least one of :obj:`node_mask_type` and :obj:`edge_mask_type` must
        be given.

    Raises:
        ValueError: If an argument is not a valid member of its enum, if
            :obj:`edge_mask_type` is neither :obj:`None` nor
            :obj:`"object"`, or if both :obj:`node_mask_type` and
            :obj:`edge_mask_type` are :obj:`None`.
    """
    explanation_type: ExplanationType
    node_mask_type: Optional[MaskType]
    edge_mask_type: Optional[MaskType]

    def __init__(
        self,
        explanation_type: Union[ExplanationType, str],
        node_mask_type: Optional[Union[MaskType, str]] = None,
        edge_mask_type: Optional[Union[MaskType, str]] = None,
    ):
        # Normalize strings to enum members; unknown values raise ValueError.
        if node_mask_type is not None:
            node_mask_type = MaskType(node_mask_type)
        if edge_mask_type is not None:
            edge_mask_type = MaskType(edge_mask_type)

        # Edges only support per-object masks.
        if edge_mask_type is not None and edge_mask_type != MaskType.object:
            raise ValueError(f"'edge_mask_type' needs to be None or of type "
                             f"'object' (got '{edge_mask_type.value}')")

        # An explanation without any mask would have nothing to explain.
        if node_mask_type is None and edge_mask_type is None:
            raise ValueError("Either 'node_mask_type' or 'edge_mask_type' "
                             "must be provided")

        self.explanation_type = ExplanationType(explanation_type)
        self.node_mask_type = node_mask_type
        self.edge_mask_type = edge_mask_type


@dataclass
class ModelConfig(CastMixin):
    r"""Configuration class to store model parameters.

    Args:
        mode (ModelMode or str): The mode of the model. The possible values
            are:

                - :obj:`"binary_classification"`: A binary classification
                  model.

                - :obj:`"multiclass_classification"`: A multiclass
                  classification model.

                - :obj:`"regression"`: A regression model.

        task_level (ModelTaskLevel or str): The task-level of the model.
            The possible values are:

                - :obj:`"node"`: A node-level prediction model.

                - :obj:`"edge"`: An edge-level prediction model.

                - :obj:`"graph"`: A graph-level prediction model.

        return_type (ModelReturnType or str, optional): The return type of the
            model. The possible values are (default: :obj:`None`):

                - :obj:`"raw"`: The model returns raw values.

                - :obj:`"probs"`: The model returns probabilities.

                - :obj:`"log_probs"`: The model returns log-probabilities.

            If :obj:`None`, it defaults to :obj:`"raw"` for regression models
            and must be given explicitly for classification models.
            Regression models must return :obj:`"raw"` outputs, and binary
            classification models must return :obj:`"raw"` or
            :obj:`"probs"` outputs.

    Raises:
        ValueError: If an argument is not a valid member of its enum, if
            :obj:`return_type` is :obj:`None` for a non-regression model, or
            if :obj:`return_type` is incompatible with :obj:`mode`.
    """
    mode: ModelMode
    task_level: ModelTaskLevel
    return_type: ModelReturnType

    def __init__(
        self,
        mode: Union[ModelMode, str],
        task_level: Union[ModelTaskLevel, str],
        return_type: Optional[Union[ModelReturnType, str]] = None,
    ):
        self.mode = ModelMode(mode)
        self.task_level = ModelTaskLevel(task_level)

        if return_type is None:
            # Without this check, `ModelReturnType(None)` would fail with an
            # uninformative "None is not a valid ModelReturnType" message.
            if self.mode != ModelMode.regression:
                raise ValueError(f"'return_type' needs to be provided for a "
                                 f"model in '{self.mode.value}' mode")
            # Regression models may only return raw outputs (checked below),
            # so that is the only sensible default.
            return_type = ModelReturnType.raw

        self.return_type = ModelReturnType(return_type)

        if (self.mode == ModelMode.regression
                and self.return_type != ModelReturnType.raw):
            raise ValueError(f"A model for regression needs to return raw "
                             f"outputs (got {self.return_type.value})")

        if (self.mode == ModelMode.binary_classification and self.return_type
                not in [ModelReturnType.raw, ModelReturnType.probs]):
            raise ValueError(
                f"A model for binary classification needs to return raw "
                f"outputs or probabilities (got {self.return_type.value})")


@dataclass
class ThresholdConfig(CastMixin):
    r"""Configuration class to store and validate threshold parameters.

    Args:
        threshold_type (ThresholdType or str): The type of threshold to apply.
            The possible values are:

                - :obj:`"hard"`: A hard threshold is applied to each mask.
                  The elements of the mask with a value below the :obj:`value`
                  are set to :obj:`0`, the others are set to :obj:`1`.

                - :obj:`"topk"`: A soft threshold is applied to each mask.
                  The top :obj:`value` elements of each mask are kept, the
                  others are set to :obj:`0`.

                - :obj:`"topk_hard"`: Same as :obj:`"topk"` but values are set
                  to :obj:`1` for all elements which are kept.

            :obj:`None` is not a valid threshold type for this class.

        value (int or float): The value to use when thresholding. For
            :obj:`"hard"` it must lie in :obj:`[0, 1]`; for :obj:`"topk"` and
            :obj:`"topk_hard"` it must be a positive :obj:`int`. Booleans are
            not accepted.

    Raises:
        ValueError: If :obj:`threshold_type` is not a valid
            :class:`ThresholdType`, or if :obj:`value` has the wrong type or
            is out of range for the chosen threshold type.
    """
    type: ThresholdType
    value: Union[float, int]

    def __init__(
        self,
        threshold_type: Union[ThresholdType, str],
        value: Union[float, int],
    ):
        self.type = ThresholdType(threshold_type)
        self.value = value

        # `bool` is a subclass of `int`, so it has to be excluded explicitly:
        # a boolean threshold is almost certainly a caller mistake.
        if (not isinstance(self.value, (int, float))
                or isinstance(self.value, bool)):
            raise ValueError(f"Threshold value must be a float or int "
                             f"(got {type(self.value)}).")

        # A chained comparison also rejects NaN, for which every comparison
        # evaluates to False.
        if self.type == ThresholdType.hard and not 0 <= self.value <= 1:
            raise ValueError(f"Threshold value must be between 0 and 1 "
                             f"(got {self.value})")

        if self.type in (ThresholdType.topk, ThresholdType.topk_hard):
            if not isinstance(self.value, int):
                raise ValueError(f"Threshold value needs to be an integer "
                                 f"(got {type(self.value)}).")
            if self.value <= 0:
                raise ValueError(f"Threshold value needs to be positive "
                                 f"(got {self.value}).")