from typing import Sequence
import torch
from ignite.metrics.metric import reinit__is_reduced
from ignite.metrics.precision import _BasePrecisionRecall
__all__ = ["Recall"]
class Recall(_BasePrecisionRecall):
r"""Calculates recall for binary, multiclass and multilabel data.
.. math:: \text{Recall} = \frac{ TP }{ TP + FN }
where :math:`\text{TP}` is true positives and :math:`\text{FN}` is false negatives.
- ``update`` must receive output of the form ``(y_pred, y)``.
- ``y_pred`` must have shape (batch_size, num_categories, ...) or (batch_size, ...).
- ``y`` must have shape (batch_size, ...).
Args:
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
average: available options are
False
default option. For multiclass and multilabel inputs, per-class and per-label
metrics are returned respectively.
None
like the `False` option, except that a per-class metric is returned for binary data
as well, for compatibility with the scikit-learn API.
'micro'
Metric is computed counting stats of classes/labels altogether.
.. math::
\text{Micro Recall} = \frac{\sum_{k=1}^C TP_k}{\sum_{k=1}^C TP_k+FN_k}
where :math:`C` is the number of classes/labels (2 in binary case). :math:`k` in
:math:`TP_k` and :math:`FN_k` means that the measures are computed for class/label :math:`k` (in
a one-vs-rest sense in multiclass case).
For binary and multiclass inputs, this is equivalent to accuracy,
so use :class:`~ignite.metrics.accuracy.Accuracy`.
'samples'
for multilabel input, recall is first computed on a per-sample basis
and then averaged across samples.
.. math::
\text{Sample-averaged Recall} = \frac{\sum_{n=1}^N \frac{TP_n}{TP_n+FN_n}}{N}
where :math:`N` is the number of samples. :math:`n` in :math:`TP_n` and :math:`FN_n`
means that the measures are computed for sample :math:`n`, across labels.
Incompatible with binary and multiclass inputs.
'weighted'
like macro recall but accounts for class/label imbalance. For binary and multiclass
input, it computes the metric for each class, then returns their average weighted by
the support of each class (the number of actual samples in that class). For multilabel input,
it computes recall for each label, then returns their average weighted by the support
of each label (the number of actual positive samples for that label).
.. math::
Recall_k = \frac{TP_k}{TP_k+FN_k}
.. math::
\text{Weighted Recall} = \frac{\sum_{k=1}^C P_k * Recall_k}{N}
where :math:`C` is the number of classes (2 in binary case) and :math:`N` is the sum of the
supports :math:`P_k`. :math:`P_k` is the number of samples belonging to class :math:`k` in the
binary and multiclass cases, and the number of positive samples belonging to label :math:`k`
in the multilabel case. Note that for binary and multiclass data, weighted recall is equivalent
to accuracy, so use :class:`~ignite.metrics.accuracy.Accuracy`. A usage sketch is included
at the end of the Examples section below.
'macro'
computes macro recall, which is the unweighted average of the metric computed across
classes or labels.
.. math::
\text{Macro Recall} = \frac{\sum_{k=1}^C Recall_k}{C}
where :math:`C` is the number of classes (2 in binary case).
True
like the 'macro' option; kept for backward compatibility.
is_multilabel: flag to use in multilabel case. By default, value is False.
device: specifies which device updates are accumulated on. Setting the metric's
device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
default, CPU.
skip_unrolling: specifies whether output should be unrolled before being fed to the update method. Should be
true for multi-output models, for example, if ``y_pred`` contains multi-output as ``(y_pred_a, y_pred_b)``.
Alternatively, ``output_transform`` can be used to handle this.
Examples:
For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.
.. include:: defaults.rst
:start-after: :orphan:
Binary case. In binary and multilabel cases, the elements of
`y` and `y_pred` should have 0 or 1 values.
.. testcode:: 1
metric = Recall()
two_class_metric = Recall(average=None) # Returns recall for both classes
metric.attach(default_evaluator, "recall")
two_class_metric.attach(default_evaluator, "both classes recall")
y_true = torch.tensor([1, 0, 1, 1, 0, 1])
y_pred = torch.tensor([1, 0, 1, 0, 1, 1])
state = default_evaluator.run([[y_pred, y_true]])
print(f"Recall: {state.metrics['recall']}")
print(f"Recall for class 0 and class 1: {state.metrics['both classes recall']}")
.. testoutput:: 1
Recall: 0.75
Recall for class 0 and class 1: tensor([0.5000, 0.7500], dtype=torch.float64)
Multiclass case
.. testcode:: 2
metric = Recall()
macro_metric = Recall(average=True)
metric.attach(default_evaluator, "recall")
macro_metric.attach(default_evaluator, "macro recall")
y_true = torch.tensor([2, 0, 2, 1, 0])
y_pred = torch.tensor([
[0.0266, 0.1719, 0.3055],
[0.6886, 0.3978, 0.8176],
[0.9230, 0.0197, 0.8395],
[0.1785, 0.2670, 0.6084],
[0.8448, 0.7177, 0.7288]
])
state = default_evaluator.run([[y_pred, y_true]])
print(f"Recall: {state.metrics['recall']}")
print(f"Macro Recall: {state.metrics['macro recall']}")
.. testoutput:: 2
Recall: tensor([0.5000, 0.0000, 0.5000], dtype=torch.float64)
Macro Recall: 0.3333333333333333
Multilabel case, the shapes must be (batch_size, num_categories, ...)
.. testcode:: 3
metric = Recall(is_multilabel=True)
micro_metric = Recall(is_multilabel=True, average='micro')
macro_metric = Recall(is_multilabel=True, average=True)
samples_metric = Recall(is_multilabel=True, average='samples')
metric.attach(default_evaluator, "recall")
micro_metric.attach(default_evaluator, "micro recall")
macro_metric.attach(default_evaluator, "macro recall")
samples_metric.attach(default_evaluator, "samples recall")
y_true = torch.tensor([
[0, 0, 1],
[0, 0, 0],
[0, 0, 0],
[1, 0, 0],
[0, 1, 1],
])
y_pred = torch.tensor([
[1, 1, 0],
[1, 0, 1],
[1, 0, 0],
[1, 0, 1],
[1, 1, 0],
])
state = default_evaluator.run([[y_pred, y_true]])
print(f"Recall: {state.metrics['recall']}")
print(f"Micro Recall: {state.metrics['micro recall']}")
print(f"Macro Recall: {state.metrics['macro recall']}")
print(f"Samples Recall: {state.metrics['samples recall']}")
.. testoutput:: 3
Recall: tensor([1., 1., 0.], dtype=torch.float64)
Micro Recall: 0.5
Macro Recall: 0.6666666666666666
Samples Recall: 0.3
Thresholding of predictions can be done as below:
.. testcode:: 4
def thresholded_output_transform(output):
y_pred, y = output
y_pred = torch.round(y_pred)
return y_pred, y
metric = Recall(output_transform=thresholded_output_transform)
metric.attach(default_evaluator, "recall")
y_true = torch.tensor([1, 0, 1, 1, 0, 1])
y_pred = torch.tensor([0.6, 0.2, 0.9, 0.4, 0.7, 0.65])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics['recall'])
.. testoutput:: 4
0.75
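The ``average='weighted'`` option is not covered by the tested examples above. The following is a
minimal sketch (a plain code block rather than a doctest) reusing the multiclass data from the
second example; since weighted recall is equivalent to accuracy for multiclass input, the result
should be approximately 0.4.
.. code-block:: python
metric = Recall(average='weighted')
metric.attach(default_evaluator, "weighted recall")
y_true = torch.tensor([2, 0, 2, 1, 0])
y_pred = torch.tensor([
[0.0266, 0.1719, 0.3055],
[0.6886, 0.3978, 0.8176],
[0.9230, 0.0197, 0.8395],
[0.1785, 0.2670, 0.6084],
[0.8448, 0.7177, 0.7288]
])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics["weighted recall"])  # expected: approximately 0.4 (equal to accuracy here)
The metric can also be used without an :class:`~ignite.engine.engine.Engine` through the standard
``reset`` / ``update`` / ``compute`` API. A minimal sketch with the binary data from the first
example (expected value approximately 0.75):
.. code-block:: python
metric = Recall()
metric.reset()
metric.update((torch.tensor([1, 0, 1, 0, 1, 1]), torch.tensor([1, 0, 1, 1, 0, 1])))
print(metric.compute())  # expected: approximately 0.75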
.. versionchanged:: 0.4.10
Some new options were added to the `average` parameter.
.. versionchanged:: 0.5.1
``skip_unrolling`` argument is added.
"""
@reinit__is_reduced
def update(self, output: Sequence[torch.Tensor]) -> None:
self._check_shape(output)
self._check_type(output)
_, y, correct = self._prepare_output(output)
if self._average == "samples":
# Sample-averaged recall: accumulate the sum of per-sample TP / P ratios and the sample count.
actual_positives = y.sum(dim=1)
true_positives = correct.sum(dim=1)
self._numerator += torch.sum(true_positives / (actual_positives + self.eps))
self._denominator += y.size(0)
elif self._average == "micro":
# Micro recall: accumulate global TP and actual-positive counts over all classes/labels.
self._denominator += y.sum()
self._numerator += correct.sum()
else:  # _average in [False, 'macro', 'weighted']
# Per-class/label TP and actual-positive counts; any averaging is applied in compute().
self._denominator += y.sum(dim=0)
self._numerator += correct.sum(dim=0)
if self._average == "weighted":
# Class/label supports used as weights in the weighted average.
self._weight += y.sum(dim=0)
self._updated = True