from typing import Callable, Sequence, Union
import torch
from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
__all__ = ["MultiLabelConfusionMatrix"]
class MultiLabelConfusionMatrix(Metric):
"""Calculates a confusion matrix for multi-labelled, multi-class data.
- ``update`` must receive output of the form ``(y_pred, y)``.
    - ``y_pred`` must contain 0s and 1s and have the following shape (batch_size, num_classes, ...).
      For example, ``y_pred[i, j] = 1`` denotes that the j'th class is one of the predicted labels of the i'th sample.
    - ``y`` must have the following shape (batch_size, num_classes, ...) and contain 0s and 1s. For example,
      ``y[i, j] = 1`` denotes that the j'th class is one of the ground-truth labels of the i'th sample.
    - Both ``y`` and ``y_pred`` must be torch tensors with one of the following dtypes:
      {torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}. They must have the same shape.
    - The confusion matrix ``M`` is a tensor of shape (num_classes, 2, 2).
* M[i, 0, 0] corresponds to count/rate of true negatives of class i
* M[i, 0, 1] corresponds to count/rate of false positives of class i
* M[i, 1, 0] corresponds to count/rate of false negatives of class i
* M[i, 1, 1] corresponds to count/rate of true positives of class i
    - The classes in ``M`` are indexed 0, ..., num_classes - 1, as can be inferred from above.
Args:
num_classes: Number of classes, should be > 1.
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
device: specifies which device updates are accumulated on. Setting the metric's
device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
default, CPU.
        normalized: whether to normalize each class's confusion matrix by its own sum, so that entries are rates rather than counts.
        skip_unrolling: specifies whether the output should be unrolled before being fed to the update method. Should be
            true for multi-output models, for example, if ``y_pred`` contains multi-output as ``(y_pred_a, y_pred_b)``.
            Alternatively, ``output_transform`` can be used to handle this.
Example:
    For more information on how the metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.
.. include:: defaults.rst
:start-after: :orphan:
.. testcode::
metric = MultiLabelConfusionMatrix(num_classes=3)
metric.attach(default_evaluator, "mlcm")
y_true = torch.tensor([
[0, 0, 1],
[0, 0, 0],
[0, 0, 0],
[1, 0, 0],
[0, 1, 1],
])
y_pred = torch.tensor([
[1, 1, 0],
[1, 0, 1],
[1, 0, 0],
[1, 0, 1],
[1, 1, 0],
])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics["mlcm"])
.. testoutput::
tensor([[[0, 4],
[0, 1]],
[[3, 1],
[0, 1]],
[[1, 2],
[2, 0]]])
.. versionadded:: 0.4.5
.. versionchanged:: 0.5.1
``skip_unrolling`` argument is added.
"""
_state_dict_all_req_keys = ("confusion_matrix", "_num_examples")
def __init__(
self,
num_classes: int,
output_transform: Callable = lambda x: x,
device: Union[str, torch.device] = torch.device("cpu"),
normalized: bool = False,
skip_unrolling: bool = False,
):
if num_classes <= 1:
raise ValueError("Argument num_classes needs to be > 1")
self.num_classes = num_classes
self._num_examples = 0
self.normalized = normalized
super(MultiLabelConfusionMatrix, self).__init__(
output_transform=output_transform, device=device, skip_unrolling=skip_unrolling
)
@reinit__is_reduced
def reset(self) -> None:
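        # Per-class 2x2 count matrix, accumulated as int64 on the metric's device.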
self.confusion_matrix = torch.zeros(self.num_classes, 2, 2, dtype=torch.int64, device=self._device)
self._num_examples = 0
@reinit__is_reduced
def update(self, output: Sequence[torch.Tensor]) -> None:
self._check_input(output)
y_pred, y = output[0].detach(), output[1].detach()
self._num_examples += y.shape[0]
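        # Reshape both tensors to (num_classes, batch_size * extra_dims) so that
        # per-class counts can be computed with vectorized reductions along dim=1.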
y_reshaped = y.transpose(0, 1).reshape(self.num_classes, -1)
y_pred_reshaped = y_pred.transpose(0, 1).reshape(self.num_classes, -1)
y_total = y_reshaped.sum(dim=1)
y_pred_total = y_pred_reshaped.sum(dim=1)
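        # True positives are positions where prediction and target are both 1; the
        # other counts follow from the per-class totals and the number of positions.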
tp = (y_reshaped * y_pred_reshaped).sum(dim=1)
fp = y_pred_total - tp
fn = y_total - tp
tn = y_reshaped.shape[1] - tp - fp - fn
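        # Arrange the counts as [[tn, fp], [fn, tp]] per class, i.e. shape (num_classes, 2, 2).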
self.confusion_matrix += torch.stack([tn, fp, fn, tp], dim=1).reshape(-1, 2, 2).to(self._device)
@sync_all_reduce("confusion_matrix", "_num_examples")
def compute(self) -> torch.Tensor:
if self._num_examples == 0:
raise NotComputableError("Confusion matrix must have at least one example before it can be computed.")
if self.normalized:
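            # Divide each class's 2x2 matrix by its own total so that entries become rates.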
conf = self.confusion_matrix.to(dtype=torch.float64)
sums = conf.sum(dim=(1, 2))
return conf / sums[:, None, None]
return self.confusion_matrix
def _check_input(self, output: Sequence[torch.Tensor]) -> None:
y_pred, y = output[0].detach(), output[1].detach()
if y_pred.ndimension() < 2:
raise ValueError(
f"y_pred must at least have shape (batch_size, num_classes (currently set to {self.num_classes}), ...)"
)
if y.ndimension() < 2:
raise ValueError(
f"y must at least have shape (batch_size, num_classes (currently set to {self.num_classes}), ...)"
)
if y_pred.shape[0] != y.shape[0]:
raise ValueError(f"y_pred and y have different batch size: {y_pred.shape[0]} vs {y.shape[0]}")
if y_pred.shape[1] != self.num_classes:
raise ValueError(f"y_pred does not have correct number of classes: {y_pred.shape[1]} vs {self.num_classes}")
if y.shape[1] != self.num_classes:
raise ValueError(f"y does not have correct number of classes: {y.shape[1]} vs {self.num_classes}")
if y.shape != y_pred.shape:
raise ValueError("y and y_pred shapes must match.")
valid_types = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
if y_pred.dtype not in valid_types:
raise ValueError(f"y_pred must be of any type: {valid_types}")
if y.dtype not in valid_types:
raise ValueError(f"y must be of any type: {valid_types}")
if not torch.equal(y_pred, y_pred**2):
raise ValueError("y_pred must be a binary tensor")
if not torch.equal(y, y**2):
raise ValueError("y must be a binary tensor")