1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225
|
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union
from torch_geometric.utils.mixin import CastMixin
class ExplanationType(Enum):
    """Enumerates what an explanation is computed with respect to."""

    # Explain the model's prediction (losses w.r.t. the model output).
    model = 'model'
    # Explain the phenomenon the model tries to predict (losses w.r.t. the
    # target output).
    phenomenon = 'phenomenon'
class MaskType(Enum):
    """Enumerates the supported layouts of an explanation mask."""

    object = 'object'  # one mask entry per object (e.g. per node)
    common_attributes = 'common_attributes'
    attributes = 'attributes'
class ModelMode(Enum):
    """Enum class for the model mode, i.e., the kind of prediction task the
    model solves (binary classification, multiclass classification or
    regression)."""
    binary_classification = 'binary_classification'
    multiclass_classification = 'multiclass_classification'
    regression = 'regression'
class ModelTaskLevel(Enum):
    """Granularity at which a model makes its predictions."""

    node = 'node'  # node-level prediction model
    edge = 'edge'  # edge-level prediction model
    graph = 'graph'  # graph-level prediction model
class ModelReturnType(Enum):
    """Form of the values a model returns."""

    raw = 'raw'  # raw (unnormalized) values
    probs = 'probs'  # probabilities
    log_probs = 'log_probs'  # log-probabilities
class ThresholdType(Enum):
    """Strategies available for thresholding an explanation mask."""

    hard = 'hard'
    topk = 'topk'
    topk_hard = 'topk_hard'
    # TODO: connected = 'connected'
@dataclass
class ExplainerConfig(CastMixin):
    r"""Configuration class to store and validate high level explanation
    parameters.

    Args:
        explanation_type (ExplanationType or str): The type of explanation to
            compute. The possible values are:

                - :obj:`"model"`: Explains the model prediction.

                - :obj:`"phenomenon"`: Explains the phenomenon that the model
                  is trying to predict.

            In practice, this means that the explanation algorithm will
            compute its losses with respect to either the model output
            (:obj:`"model"`) or the target output (:obj:`"phenomenon"`).
        node_mask_type (MaskType or str, optional): The type of mask to apply
            on nodes. The possible values are (default: :obj:`None`):

                - :obj:`None`: Will not apply any mask on nodes.

                - :obj:`"object"`: Will mask each node.

                - :obj:`"common_attributes"`: Will mask each feature.

                - :obj:`"attributes"`: Will mask each feature across all nodes.

        edge_mask_type (MaskType or str, optional): The type of mask to apply
            on edges. Unlike :obj:`node_mask_type`, only :obj:`None` (no mask
            on edges) and :obj:`"object"` are accepted.
            (default: :obj:`None`)

    At least one of :obj:`node_mask_type` and :obj:`edge_mask_type` must be
    given.

    Raises:
        ValueError: If an argument is not a valid value of its enum, if
            :obj:`edge_mask_type` is neither :obj:`None` nor :obj:`"object"`,
            or if both :obj:`node_mask_type` and :obj:`edge_mask_type` are
            :obj:`None`.
    """
    explanation_type: ExplanationType
    node_mask_type: Optional[MaskType]
    edge_mask_type: Optional[MaskType]

    def __init__(
        self,
        explanation_type: Union[ExplanationType, str],
        node_mask_type: Optional[Union[MaskType, str]] = None,
        edge_mask_type: Optional[Union[MaskType, str]] = None,
    ):
        # Normalize string inputs to enum members (raises ValueError on
        # unknown values).
        if node_mask_type is not None:
            node_mask_type = MaskType(node_mask_type)
        if edge_mask_type is not None:
            edge_mask_type = MaskType(edge_mask_type)

        # Edges only support per-object masking.
        if edge_mask_type is not None and edge_mask_type != MaskType.object:
            raise ValueError(f"'edge_mask_type' needs be None or of type "
                             f"'object' (got '{edge_mask_type.value}')")

        # An explanation without any mask would have nothing to explain.
        if node_mask_type is None and edge_mask_type is None:
            raise ValueError("Either 'node_mask_type' or 'edge_mask_type' "
                             "must be provided")

        self.explanation_type = ExplanationType(explanation_type)
        self.node_mask_type = node_mask_type
        self.edge_mask_type = edge_mask_type
@dataclass
class ModelConfig(CastMixin):
    r"""Configuration class to store and validate model parameters.

    Args:
        mode (ModelMode or str): The mode of the model. The possible values
            are:

                - :obj:`"binary_classification"`: A binary classification
                  model.

                - :obj:`"multiclass_classification"`: A multiclass
                  classification model.

                - :obj:`"regression"`: A regression model.

        task_level (ModelTaskLevel or str): The task-level of the model.
            The possible values are:

                - :obj:`"node"`: A node-level prediction model.

                - :obj:`"edge"`: An edge-level prediction model.

                - :obj:`"graph"`: A graph-level prediction model.

        return_type (ModelReturnType or str, optional): The return type of the
            model. The possible values are (default: :obj:`None`):

                - :obj:`"raw"`: The model returns raw values.

                - :obj:`"probs"`: The model returns probabilities.

                - :obj:`"log_probs"`: The model returns log-probabilities.

            If :obj:`None`, it defaults to :obj:`"raw"` for regression models
            and must be given explicitly for all other modes. Regression
            models must return :obj:`"raw"` values, and binary classification
            models must return either :obj:`"raw"` values or :obj:`"probs"`.

    Raises:
        ValueError: If an argument is not a valid value of its enum, if
            :obj:`return_type` is :obj:`None` for a non-regression model, or
            if :obj:`return_type` is incompatible with :obj:`mode`.
    """
    mode: ModelMode
    task_level: ModelTaskLevel
    return_type: ModelReturnType

    def __init__(
        self,
        mode: Union[ModelMode, str],
        task_level: Union[ModelTaskLevel, str],
        return_type: Optional[Union[ModelReturnType, str]] = None,
    ):
        self.mode = ModelMode(mode)
        self.task_level = ModelTaskLevel(task_level)

        if return_type is None:
            if self.mode != ModelMode.regression:
                # Previously this fell through to `ModelReturnType(None)`,
                # which failed with an uninformative message.
                raise ValueError(f"'return_type' needs to be provided for a "
                                 f"model of mode '{self.mode.value}'")
            # Regression models may only return raw outputs (checked below),
            # so 'raw' is the only valid default.
            return_type = ModelReturnType.raw

        self.return_type = ModelReturnType(return_type)

        if (self.mode == ModelMode.regression
                and self.return_type != ModelReturnType.raw):
            raise ValueError(f"A model for regression needs to return raw "
                             f"outputs (got {self.return_type.value})")

        if (self.mode == ModelMode.binary_classification and self.return_type
                not in [ModelReturnType.raw, ModelReturnType.probs]):
            raise ValueError(
                f"A model for binary classification needs to return raw "
                f"outputs or probabilities (got {self.return_type.value})")
@dataclass
class ThresholdConfig(CastMixin):
    r"""Configuration class to store and validate threshold parameters.

    Args:
        threshold_type (ThresholdType or str): The type of threshold to apply.
            The possible values are:

                - :obj:`"hard"`: A hard threshold is applied to each mask.
                  The elements of the mask with a value below the :obj:`value`
                  are set to :obj:`0`, the others are set to :obj:`1`.

                - :obj:`"topk"`: A soft threshold is applied to each mask.
                  The top :obj:`value` elements of each mask are kept, the
                  others are set to :obj:`0`.

                - :obj:`"topk_hard"`: Same as :obj:`"topk"` but values are set
                  to :obj:`1` for all elements which are kept.

            Note that :obj:`None` is not accepted by this constructor.
        value (int or float): The value to use when thresholding. For
            :obj:`"hard"`, it must lie in :math:`[0, 1]`. For :obj:`"topk"`
            and :obj:`"topk_hard"`, it must be a positive :obj:`int`.

    Raises:
        ValueError: If :obj:`threshold_type` is not a valid
            :class:`ThresholdType`, if :obj:`value` is not an :obj:`int` or
            :obj:`float`, or if :obj:`value` violates the constraints of the
            chosen threshold type.
    """
    type: ThresholdType
    value: Union[float, int]

    def __init__(
        self,
        threshold_type: Union[ThresholdType, str],
        value: Union[float, int],
    ):
        self.type = ThresholdType(threshold_type)
        self.value = value

        if not isinstance(self.value, (int, float)):
            raise ValueError(f"Threshold value must be a float or int "
                             f"(got {type(self.value)}).")

        # Hard thresholds compare mask entries against `value` directly.
        if (self.type == ThresholdType.hard
                and (self.value < 0 or self.value > 1)):
            raise ValueError(f"Threshold value must be between 0 and 1 "
                             f"(got {self.value})")

        # Top-k thresholds interpret `value` as a count of kept elements.
        if self.type in [ThresholdType.topk, ThresholdType.topk_hard]:
            if not isinstance(self.value, int):
                raise ValueError(f"Threshold value needs to be an integer "
                                 f"(got {type(self.value)}).")
            if self.value <= 0:
                raise ValueError(f"Threshold value needs to be positive "
                                 f"(got {self.value}).")
|