File: confusion_matrix.py

import numbers
from typing import Callable, Optional, Sequence, Tuple, Union

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
from ignite.metrics.metrics_lambda import MetricsLambda

__all__ = ["ConfusionMatrix", "mIoU", "IoU", "DiceCoefficient", "cmAccuracy", "cmPrecision", "cmRecall", "JaccardIndex"]


class ConfusionMatrix(Metric):
    """Calculates confusion matrix for multi-class data.

    - ``update`` must receive output of the form ``(y_pred, y)``.
    - `y_pred` must contain logits and have the following shape (batch_size, num_classes, ...).
      If you are doing binary classification, see the Note for an example of how to get this.
    - `y` should have the following shape (batch_size, ...) and contain ground-truth class indices
      with or without the background class. During the computation, the argmax of `y_pred` is taken to
      determine the predicted classes.

    Args:
        num_classes: Number of classes, should be > 1. See notes for more details.
        average: confusion matrix values averaging scheme: None, "samples", "recall", "precision".
            Default is None. If `average="samples"` then confusion matrix values are normalized by the number of seen
            samples. If `average="recall"` then confusion matrix values are normalized such that diagonal values
            represent class recalls. If `average="precision"` then confusion matrix values are normalized such that
            diagonal values represent class precisions.
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
        device: specifies which device updates are accumulated on. Setting the metric's
            device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
            default, CPU.
        skip_unrolling: specifies whether the output should be unrolled before being fed to the update method.
            Should be true for multi-output models, for example, if ``y_pred`` contains multi-output results as
            ``(y_pred_a, y_pred_b)``. Alternatively, ``output_transform`` can be used to handle this.

    Note:
        The confusion matrix is formatted such that columns are predictions and rows are targets.
        For example, if you were to plot the matrix, you could correctly assign to the horizontal axis
        the label "predicted values" and to the vertical axis the label "actual values".

    Note:
        When targets `y` are given in `(batch_size, ...)` format, only target indices in the range
        [0, `num_classes`) contribute to the confusion matrix; all other indices are ignored. For example,
        if `num_classes=20` and a target index equal to 255 is encountered, it is filtered out.

    Examples:

        For more information on how the metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode:: 1

            metric = ConfusionMatrix(num_classes=3)
            metric.attach(default_evaluator, 'cm')
            y_true = torch.tensor([0, 1, 0, 1, 2])
            y_pred = torch.tensor([
                [0.0, 1.0, 0.0],
                [0.0, 1.0, 0.0],
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
                [0.0, 1.0, 0.0],
            ])
            state = default_evaluator.run([[y_pred, y_true]])
            print(state.metrics['cm'])

        .. testoutput:: 1

            tensor([[1, 1, 0],
                    [0, 2, 0],
                    [0, 1, 0]])

        If you are doing binary classification with a single output unit, you may have to transform your network
        output so that there is one value for each class. For example, you can transform your network output into
        a one-hot vector with:

        .. testcode:: 2

            def binary_one_hot_output_transform(output):
                from ignite import utils
                y_pred, y = output
                y_pred = torch.sigmoid(y_pred).round().long()
                y_pred = utils.to_onehot(y_pred, 2)
                y = y.long()
                return y_pred, y

            metric = ConfusionMatrix(num_classes=2, output_transform=binary_one_hot_output_transform)
            metric.attach(default_evaluator, 'cm')
            y_true = torch.tensor([0, 1, 0, 1, 0])
            y_pred = torch.tensor([0, 0, 1, 1, 0])
            state = default_evaluator.run([[y_pred, y_true]])
            print(state.metrics['cm'])

        .. testoutput:: 2

            tensor([[2, 1],
                    [1, 1]])

    .. versionchanged:: 0.5.1
        ``skip_unrolling`` argument is added.
    """

    _state_dict_all_req_keys = ("confusion_matrix", "_num_examples")

    def __init__(
        self,
        num_classes: int,
        average: Optional[str] = None,
        output_transform: Callable = lambda x: x,
        device: Union[str, torch.device] = torch.device("cpu"),
        skip_unrolling: bool = True,
    ):
        if average is not None and average not in ("samples", "recall", "precision"):
            raise ValueError("Argument average can None or one of 'samples', 'recall', 'precision'")

        if num_classes <= 1:
            raise ValueError("Argument num_classes needs to be > 1")

        self.num_classes = num_classes
        self._num_examples = 0
        self.average = average
        super(ConfusionMatrix, self).__init__(
            output_transform=output_transform, device=device, skip_unrolling=skip_unrolling
        )

    @reinit__is_reduced
    def reset(self) -> None:
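        # Accumulator layout: rows index ground-truth classes, columns index predicted
        # classes; counts are kept as int64 on the metric's device.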
        self.confusion_matrix = torch.zeros(self.num_classes, self.num_classes, dtype=torch.int64, device=self._device)
        self._num_examples = 0

    def _check_shape(self, output: Sequence[torch.Tensor]) -> None:
        y_pred, y = output[0].detach(), output[1].detach()

        if y_pred.ndimension() < 2:
            raise ValueError(
                f"y_pred must have shape (batch_size, num_classes (currently set to {self.num_classes}), ...), "
                f"but given {y_pred.shape}"
            )

        if y_pred.shape[1] != self.num_classes:
            raise ValueError(f"y_pred does not have correct number of classes: {y_pred.shape[1]} vs {self.num_classes}")

        if not (y.ndimension() + 1 == y_pred.ndimension()):
            raise ValueError(
                f"y_pred must have shape (batch_size, num_classes (currently set to {self.num_classes}), ...) "
                "and y must have shape of (batch_size, ...), "
                f"but given {y.shape} vs {y_pred.shape}."
            )

        y_shape = y.shape
        y_pred_shape: Tuple[int, ...] = y_pred.shape

        if y.ndimension() + 1 == y_pred.ndimension():
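            # Drop the class dimension of y_pred so the remaining dimensions can be
            # compared element-wise against the shape of y.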
            y_pred_shape = (y_pred_shape[0],) + y_pred_shape[2:]

        if y_shape != y_pred_shape:
            raise ValueError("y and y_pred must have compatible shapes.")

    @reinit__is_reduced
    def update(self, output: Sequence[torch.Tensor]) -> None:
        self._check_shape(output)
        y_pred, y = output[0].detach(), output[1].detach()

        self._num_examples += y_pred.shape[0]

        # target is (batch_size, ...)
        y_pred = torch.argmax(y_pred, dim=1).flatten()
        y = y.flatten()

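        # Keep only targets inside [0, num_classes); out-of-range indices
        # (e.g. an ignore label such as 255) do not contribute to the matrix.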
        target_mask = (y >= 0) & (y < self.num_classes)
        y = y[target_mask]
        y_pred = y_pred[target_mask]

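        # Encode each (target, prediction) pair as a single integer in [0, num_classes**2),
        # count occurrences with bincount, and reshape the counts into a matrix with
        # rows as targets and columns as predictions.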
        indices = self.num_classes * y + y_pred
        m = torch.bincount(indices, minlength=self.num_classes**2).reshape(self.num_classes, self.num_classes)
        self.confusion_matrix += m.to(self.confusion_matrix)

    @sync_all_reduce("confusion_matrix", "_num_examples")
    def compute(self) -> torch.Tensor:
        if self._num_examples == 0:
            raise NotComputableError("Confusion matrix must have at least one example before it can be computed.")
        if self.average:
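            # Note: averaging converts the accumulated integer matrix to float in place
            # before normalizing.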
            self.confusion_matrix = self.confusion_matrix.float()
            if self.average == "samples":
                return self.confusion_matrix / self._num_examples
            else:
                return self.normalize(self.confusion_matrix, self.average)
        return self.confusion_matrix

    @staticmethod
    def normalize(matrix: torch.Tensor, average: str) -> torch.Tensor:
        """Normalize given `matrix` with given `average`."""
        if average == "recall":
            return matrix / (matrix.sum(dim=1).unsqueeze(1) + 1e-15)
        elif average == "precision":
            return matrix / (matrix.sum(dim=0) + 1e-15)
        else:
            raise ValueError("Argument average should be one of 'samples', 'recall', 'precision'")


def IoU(cm: ConfusionMatrix, ignore_index: Optional[int] = None) -> MetricsLambda:
    r"""Calculates Intersection over Union using :class:`~ignite.metrics.confusion_matrix.ConfusionMatrix` metric.

    .. math:: \text{J}(A, B) = \frac{ \lvert A \cap B \rvert }{ \lvert A \cup B \rvert }

    Args:
        cm: instance of confusion matrix metric
        ignore_index: index to ignore, e.g. background index

    Returns:
        MetricsLambda

    Examples:

        For more information on how the metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode::

            cm = ConfusionMatrix(num_classes=3)
            metric = IoU(cm)
            metric.attach(default_evaluator, 'iou')
            y_true = torch.tensor([0, 1, 0, 1, 2])
            y_pred = torch.tensor([
                [0.0, 1.0, 0.0],
                [0.0, 1.0, 0.0],
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
                [0.0, 1.0, 0.0],
            ])
            state = default_evaluator.run([[y_pred, y_true]])
            print(state.metrics['iou'])

        .. testoutput::

            tensor([0.5000, 0.5000, 0.0000], dtype=torch.float64)
    """
    if not isinstance(cm, ConfusionMatrix):
        raise TypeError(f"Argument cm should be instance of ConfusionMatrix, but given {type(cm)}")

    if cm.average not in (None, "samples"):
        raise ValueError("ConfusionMatrix should have the average attribute set to either None or 'samples'")

    if ignore_index is not None:
        if not (isinstance(ignore_index, numbers.Integral) and 0 <= ignore_index < cm.num_classes):
            raise ValueError(
                f"ignore_index should be integer and in the range of [0, {cm.num_classes}), but given {ignore_index}"
            )

    # Increase floating point precision of the computed confusion matrix
    cm = cm.to(torch.double)
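    # Per class: the diagonal entry is the intersection (true positives), while
    # row sum + column sum - diagonal is the union (TP + FP + FN). The small epsilon
    # guards against division by zero for classes that never appear.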
    iou: MetricsLambda = cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) - cm.diag() + 1e-15)
    if ignore_index is not None:
        ignore_idx: int = ignore_index  # used due to typing issues with mypy

        def ignore_index_fn(iou_vector: torch.Tensor) -> torch.Tensor:
            if ignore_idx >= len(iou_vector):
                raise ValueError(f"ignore_index {ignore_idx} is larger than the length of IoU vector {len(iou_vector)}")
            indices = list(range(len(iou_vector)))
            indices.remove(ignore_idx)
            return iou_vector[indices]

        return MetricsLambda(ignore_index_fn, iou)
    else:
        return iou


def mIoU(cm: ConfusionMatrix, ignore_index: Optional[int] = None) -> MetricsLambda:
    """Calculates mean Intersection over Union using :class:`~ignite.metrics.confusion_matrix.ConfusionMatrix` metric.

    Args:
        cm: instance of confusion matrix metric
        ignore_index: index to ignore, e.g. background index

    Returns:
        MetricsLambda

    Examples:

        For more information on how the metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode::

            cm = ConfusionMatrix(num_classes=3)
            metric = mIoU(cm, ignore_index=0)
            metric.attach(default_evaluator, 'miou')
            y_true = torch.tensor([0, 1, 0, 1, 2])
            y_pred = torch.tensor([
                [0.0, 1.0, 0.0],
                [0.0, 1.0, 0.0],
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
                [0.0, 1.0, 0.0],
            ])
            state = default_evaluator.run([[y_pred, y_true]])
            print(state.metrics['miou'])

        .. testoutput::

            0.24999...
    """
    iou: MetricsLambda = IoU(cm=cm, ignore_index=ignore_index).mean()
    return iou


def cmAccuracy(cm: ConfusionMatrix) -> MetricsLambda:
    """Calculates accuracy using :class:`~ignite.metrics.metric.ConfusionMatrix` metric.

    Args:
        cm: instance of confusion matrix metric

    Returns:
        MetricsLambda
    """
    # Increase floating point precision of the computed confusion matrix
    cm = cm.to(torch.double)
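    # Diagonal entries count correctly classified samples; dividing their sum by the
    # total number of entries in the matrix gives the overall accuracy.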
    accuracy: MetricsLambda = cm.diag().sum() / (cm.sum() + 1e-15)
    return accuracy


def cmPrecision(cm: ConfusionMatrix, average: bool = True) -> MetricsLambda:
    """Calculates precision using :class:`~ignite.metrics.metric.ConfusionMatrix` metric.

    Args:
        cm: instance of confusion matrix metric
        average: if True metric value is averaged over all classes
    Returns:
        MetricsLambda
    """

    # Increase floating point precision of the computed confusion matrix
    cm = cm.to(torch.double)
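    # Each column sums to TP + FP for that class, so diag / column sums yields per-class precision.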
    precision: MetricsLambda = cm.diag() / (cm.sum(dim=0) + 1e-15)
    if average:
        mean: MetricsLambda = precision.mean()
        return mean
    return precision


def cmRecall(cm: ConfusionMatrix, average: bool = True) -> MetricsLambda:
    """
    Calculates recall using :class:`~ignite.metrics.confusion_matrix.ConfusionMatrix` metric.
    Args:
        cm: instance of confusion matrix metric
        average: if True metric value is averaged over all classes
    Returns:
        MetricsLambda
    """

    # Increase floating point precision of the computed confusion matrix
    cm = cm.to(torch.double)
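    # Each row sums to TP + FN for that class, so diag / row sums yields per-class recall.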
    recall: MetricsLambda = cm.diag() / (cm.sum(dim=1) + 1e-15)
    if average:
        mean: MetricsLambda = recall.mean()
        return mean
    return recall


def DiceCoefficient(cm: ConfusionMatrix, ignore_index: Optional[int] = None) -> MetricsLambda:
    """Calculates Dice Coefficient for a given :class:`~ignite.metrics.confusion_matrix.ConfusionMatrix` metric.

    Args:
        cm: instance of confusion matrix metric
        ignore_index: index to ignore, e.g. background index

    Examples:

        For more information on how the metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode::

            cm = ConfusionMatrix(num_classes=3)
            metric = DiceCoefficient(cm, ignore_index=0)
            metric.attach(default_evaluator, 'dice')
            y_true = torch.tensor([0, 1, 0, 1, 2])
            y_pred = torch.tensor([
                [0.0, 1.0, 0.0],
                [0.0, 1.0, 0.0],
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
                [0.0, 1.0, 0.0],
            ])
            state = default_evaluator.run([[y_pred, y_true]])
            print(state.metrics['dice'])

        .. testoutput::

            tensor([0.6667, 0.0000], dtype=torch.float64)
    """

    if not isinstance(cm, ConfusionMatrix):
        raise TypeError(f"Argument cm should be instance of ConfusionMatrix, but given {type(cm)}")

    if ignore_index is not None:
        if not (isinstance(ignore_index, numbers.Integral) and 0 <= ignore_index < cm.num_classes):
            raise ValueError(
                f"ignore_index should be integer and in the range of [0, {cm.num_classes}), but given {ignore_index}"
            )

    # Increase floating point precision of the computed confusion matrix
    cm = cm.to(torch.double)
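    # Dice = 2 * TP / (2 * TP + FP + FN): for each class, row sum + column sum equals
    # 2 * TP + FP + FN, and the diagonal entry equals TP.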
    dice: MetricsLambda = 2.0 * cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) + 1e-15)

    if ignore_index is not None:
        ignore_idx: int = ignore_index  # used due to typing issues with mypy

        def ignore_index_fn(dice_vector: torch.Tensor) -> torch.Tensor:
            if ignore_idx >= len(dice_vector):
                raise ValueError(
                    f"ignore_index {ignore_idx} is larger than the length of Dice vector {len(dice_vector)}"
                )
            indices = list(range(len(dice_vector)))
            indices.remove(ignore_idx)
            return dice_vector[indices]

        return MetricsLambda(ignore_index_fn, dice)
    else:
        return dice


def JaccardIndex(cm: ConfusionMatrix, ignore_index: Optional[int] = None) -> MetricsLambda:
    r"""Calculates the Jaccard Index using :class:`~ignite.metrics.confusion_matrix.ConfusionMatrix` metric.
    Implementation is based on :func:`~ignite.metrics.IoU`.

    .. math:: \text{J}(A, B) = \frac{ \lvert A \cap B \rvert }{ \lvert A \cup B \rvert }

    Args:
        cm: instance of confusion matrix metric
        ignore_index: index to ignore, e.g. background index

    Returns:
        MetricsLambda

    Examples:

        For more information on how the metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode::

            cm = ConfusionMatrix(num_classes=3)
            metric = JaccardIndex(cm, ignore_index=0)
            metric.attach(default_evaluator, 'jac')
            y_true = torch.tensor([0, 1, 0, 1, 2])
            y_pred = torch.tensor([
                [0.0, 1.0, 0.0],
                [0.0, 1.0, 0.0],
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
                [0.0, 1.0, 0.0],
            ])
            state = default_evaluator.run([[y_pred, y_true]])
            print(state.metrics['jac'])

        .. testoutput::

            tensor([0.5000, 0.0000], dtype=torch.float64)
    """
    return IoU(cm, ignore_index)