File: multilabel_confusion_matrix.py

Package: pytorch-ignite 0.5.1-1

from typing import Callable, Sequence, Union

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["MultiLabelConfusionMatrix"]


class MultiLabelConfusionMatrix(Metric):
    """Calculates a confusion matrix for multi-labelled, multi-class data.

    - ``update`` must receive output of the form ``(y_pred, y)``.
    - `y_pred` must contain 0s and 1s and have the following shape (batch_size, num_classes, ...).
      For example, `y_pred[i, j]` = 1 denotes that the j'th class is one of the predicted labels of the i'th sample.
    - `y` should have the following shape (batch_size, num_classes, ...) with 0s and 1s. For example,
      `y[i, j]` = 1 denotes that the j'th class is one of the labels of the i'th sample according to the ground truth.
    - both `y` and `y_pred` must be torch Tensors of any of the following types:
      {torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}. They must have the same dimensions.
    - The confusion matrix 'M' is of shape (num_classes, 2, 2).

      * M[i, 0, 0] corresponds to count/rate of true negatives of class i
      * M[i, 0, 1] corresponds to count/rate of false positives of class i
      * M[i, 1, 0] corresponds to count/rate of false negatives of class i
      * M[i, 1, 1] corresponds to count/rate of true positives of class i

    - The classes present in M are indexed 0, ..., num_classes - 1, as can be inferred from above. For
      example, with `num_classes=3`, `M[2, 1, 1]` holds the count (or rate, when `normalized=True`) of
      true positives of class 2.

    Args:
        num_classes: Number of classes, should be > 1.
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
        device: specifies which device updates are accumulated on. Setting the metric's
            device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
            default, CPU.
        normalized: whether to normalize confusion matrix by its sum or not.
        skip_unrolling: specifies whether output should be unrolled before being fed to the update method. Should be
            true for multi-output models, for example, if ``y_pred`` contains multiple outputs as ``(y_pred_a, y_pred_b)``.
            Alternatively, ``output_transform`` can be used to handle this.

    Example:

        For more information on how the metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode::

            metric = MultiLabelConfusionMatrix(num_classes=3)
            metric.attach(default_evaluator, "mlcm")
            y_true = torch.tensor([
                [0, 0, 1],
                [0, 0, 0],
                [0, 0, 0],
                [1, 0, 0],
                [0, 1, 1],
            ])
            y_pred = torch.tensor([
                [1, 1, 0],
                [1, 0, 1],
                [1, 0, 0],
                [1, 0, 1],
                [1, 1, 0],
            ])
            state = default_evaluator.run([[y_pred, y_true]])
            print(state.metrics["mlcm"])

        .. testoutput::

            tensor([[[0, 4],
                     [0, 1]],

                    [[3, 1],
                     [0, 1]],

                    [[1, 2],
                     [2, 0]]])
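
        A sketch of the same data with ``normalized=True``: each class's 2x2 matrix is divided by
        its own sum (exact tensor formatting may vary across torch versions):

        .. code-block:: python

            metric = MultiLabelConfusionMatrix(num_classes=3, normalized=True)
            metric.attach(default_evaluator, "mlcm_norm")
            state = default_evaluator.run([[y_pred, y_true]])
            print(state.metrics["mlcm_norm"])
            # tensor([[[0.0000, 0.8000],
            #          [0.0000, 0.2000]],
            #
            #         [[0.6000, 0.2000],
            #          [0.0000, 0.2000]],
            #
            #         [[0.2000, 0.4000],
            #          [0.4000, 0.0000]]], dtype=torch.float64)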

    .. versionadded:: 0.4.5

    .. versionchanged:: 0.5.1
        ``skip_unrolling`` argument is added.
    """

    _state_dict_all_req_keys = ("confusion_matrix", "_num_examples")

    def __init__(
        self,
        num_classes: int,
        output_transform: Callable = lambda x: x,
        device: Union[str, torch.device] = torch.device("cpu"),
        normalized: bool = False,
        skip_unrolling: bool = False,
    ):
        if num_classes <= 1:
            raise ValueError("Argument num_classes needs to be > 1")

        self.num_classes = num_classes
        self._num_examples = 0
        self.normalized = normalized
        super(MultiLabelConfusionMatrix, self).__init__(
            output_transform=output_transform, device=device, skip_unrolling=skip_unrolling
        )

    @reinit__is_reduced
    def reset(self) -> None:
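        # One 2x2 block [[tn, fp], [fn, tp]] per class, accumulated as int64 counts.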
        self.confusion_matrix = torch.zeros(self.num_classes, 2, 2, dtype=torch.int64, device=self._device)
        self._num_examples = 0

    @reinit__is_reduced
    def update(self, output: Sequence[torch.Tensor]) -> None:
        self._check_input(output)
        y_pred, y = output[0].detach(), output[1].detach()

        self._num_examples += y.shape[0]
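
        # Bring the class axis to the front and flatten the remaining dims, so row i
        # holds every prediction/target value for class i: (num_classes, batch_size * extra_dims).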
        y_reshaped = y.transpose(0, 1).reshape(self.num_classes, -1)
        y_pred_reshaped = y_pred.transpose(0, 1).reshape(self.num_classes, -1)

        y_total = y_reshaped.sum(dim=1)
        y_pred_total = y_pred_reshaped.sum(dim=1)

        tp = (y_reshaped * y_pred_reshaped).sum(dim=1)  # predicted and actual both positive
        fp = y_pred_total - tp  # predicted positive, actually negative
        fn = y_total - tp  # actually positive, predicted negative
        tn = y_reshaped.shape[1] - tp - fp - fn  # the remainder of each row

        # Stacking in [tn, fp, fn, tp] order makes each class reshape to [[tn, fp], [fn, tp]].
        self.confusion_matrix += torch.stack([tn, fp, fn, tp], dim=1).reshape(-1, 2, 2).to(self._device)

    @sync_all_reduce("confusion_matrix", "_num_examples")
    def compute(self) -> torch.Tensor:
        if self._num_examples == 0:
            raise NotComputableError("Confusion matrix must have at least one example before it can be computed.")

        if self.normalized:
            conf = self.confusion_matrix.to(dtype=torch.float64)
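            # Divide each class's 2x2 block by that block's total count.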
            sums = conf.sum(dim=(1, 2))
            return conf / sums[:, None, None]

        return self.confusion_matrix

    def _check_input(self, output: Sequence[torch.Tensor]) -> None:
        y_pred, y = output[0].detach(), output[1].detach()

        if y_pred.ndimension() < 2:
            raise ValueError(
                f"y_pred must at least have shape (batch_size, num_classes (currently set to {self.num_classes}), ...)"
            )

        if y.ndimension() < 2:
            raise ValueError(
                f"y must at least have shape (batch_size, num_classes (currently set to {self.num_classes}), ...)"
            )

        if y_pred.shape[0] != y.shape[0]:
            raise ValueError(f"y_pred and y have different batch size: {y_pred.shape[0]} vs {y.shape[0]}")

        if y_pred.shape[1] != self.num_classes:
            raise ValueError(f"y_pred does not have correct number of classes: {y_pred.shape[1]} vs {self.num_classes}")

        if y.shape[1] != self.num_classes:
            raise ValueError(f"y does not have correct number of classes: {y.shape[1]} vs {self.num_classes}")

        if y.shape != y_pred.shape:
            raise ValueError("y and y_pred shapes must match.")

        valid_types = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
        if y_pred.dtype not in valid_types:
            raise ValueError(f"y_pred must have one of the following dtypes: {valid_types}")

        if y.dtype not in valid_types:
            raise ValueError(f"y must have one of the following dtypes: {valid_types}")

        # A tensor equals its elementwise square iff it contains only 0s and 1s.
        if not torch.equal(y_pred, y_pred**2):
            raise ValueError("y_pred must be a binary tensor")

        if not torch.equal(y, y**2):
            raise ValueError("y must be a binary tensor")
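

# A minimal standalone sketch (illustrative, not part of the library API): the metric
# can also be driven directly through reset/update/compute, without an Engine.
if __name__ == "__main__":
    mlcm = MultiLabelConfusionMatrix(num_classes=2)
    mlcm.reset()
    y_pred = torch.tensor([[1, 1], [0, 1]], dtype=torch.int64)
    y_true = torch.tensor([[1, 0], [0, 1]], dtype=torch.int64)
    mlcm.update((y_pred, y_true))
    # Class 0: tp=1, fp=0, fn=0, tn=1 -> [[1, 0], [0, 1]]
    # Class 1: tp=1, fp=1, fn=0, tn=0 -> [[0, 1], [0, 1]]
    print(mlcm.compute())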