File: classification_report.py

import json
from typing import Callable, Collection, Dict, List, Optional, Union

import torch

from ignite.metrics.fbeta import Fbeta
from ignite.metrics.metrics_lambda import MetricsLambda
from ignite.metrics.precision import Precision
from ignite.metrics.recall import Recall

__all__ = ["ClassificationReport"]


def ClassificationReport(
    beta: int = 1,
    output_dict: bool = False,
    output_transform: Callable = lambda x: x,
    device: Union[str, torch.device] = torch.device("cpu"),
    is_multilabel: bool = False,
    labels: Optional[List[str]] = None,
) -> MetricsLambda:
    r"""Build a text report showing the main classification metrics. The report resembles in functionality to
    `scikit-learn classification_report
    <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html#sklearn.metrics.classification_report>`_
    The underlying implementation doesn't use the sklearn function.

    Args:
        beta: weight of recall relative to precision in the F-score; ``beta=1`` gives the F1-score.
        output_dict: If True, return the report as a dict, otherwise return it as a JSON string.
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
        device: optional device specification for internal storage.
        is_multilabel: If True, the tensors are assumed to be multilabel.
        labels: Optional list of label names used as the per-class keys of the report instead of the class indices.

    Examples:

        For more information on how the metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

        .. include:: defaults.rst
            :start-after: :orphan:

        Multiclass case

        .. testcode:: 1

            metric = ClassificationReport(output_dict=True)
            metric.attach(default_evaluator, "cr")
            y_true = torch.tensor([2, 0, 2, 1, 0, 1])
            y_pred = torch.tensor([
                [0.0266, 0.1719, 0.3055],
                [0.6886, 0.3978, 0.8176],
                [0.9230, 0.0197, 0.8395],
                [0.1785, 0.2670, 0.6084],
                [0.8448, 0.7177, 0.7288],
                [0.7748, 0.9542, 0.8573],
            ])
            state = default_evaluator.run([[y_pred, y_true]])
            print(state.metrics["cr"].keys())
            print(state.metrics["cr"]["0"])
            print(state.metrics["cr"]["1"])
            print(state.metrics["cr"]["2"])
            print(state.metrics["cr"]["macro avg"])

        .. testoutput:: 1

            dict_keys(['0', '1', '2', 'macro avg'])
            {'precision': 0.5, 'recall': 0.5, 'f1-score': 0.4999...}
            {'precision': 1.0, 'recall': 0.5, 'f1-score': 0.6666...}
            {'precision': 0.3333..., 'recall': 0.5, 'f1-score': 0.3999...}
            {'precision': 0.6111..., 'recall': 0.5, 'f1-score': 0.5222...}
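
        With the default ``output_dict=False``, the same report is returned as a JSON string.
        A minimal sketch (reusing ``y_pred``, ``y_true`` and ``default_evaluator`` from the
        example above; exact float values may differ):

        .. code-block:: python

            import json

            metric = ClassificationReport()  # output_dict=False by default
            metric.attach(default_evaluator, "cr_str")
            state = default_evaluator.run([[y_pred, y_true]])
            report = json.loads(state.metrics["cr_str"])  # parse the JSON string back into a dict
            print(sorted(report.keys()))  # ['0', '1', '2', 'macro avg']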

        Multilabel case. The shapes must be ``(batch_size, num_categories, ...)``.

        .. testcode:: 2

            metric = ClassificationReport(output_dict=True, is_multilabel=True)
            metric.attach(default_evaluator, "cr")
            y_true = torch.tensor([
                [0, 0, 1],
                [0, 0, 0],
                [0, 0, 0],
                [1, 0, 0],
                [0, 1, 1],
            ])
            y_pred = torch.tensor([
                [1, 1, 0],
                [1, 0, 1],
                [1, 0, 0],
                [1, 0, 1],
                [1, 1, 0],
            ])
            state = default_evaluator.run([[y_pred, y_true]])
            print(state.metrics["cr"].keys())
            print(state.metrics["cr"]["0"])
            print(state.metrics["cr"]["1"])
            print(state.metrics["cr"]["2"])
            print(state.metrics["cr"]["macro avg"])

        .. testoutput:: 2

            dict_keys(['0', '1', '2', 'macro avg'])
            {'precision': 0.2, 'recall': 1.0, 'f1-score': 0.3333...}
            {'precision': 0.5, 'recall': 1.0, 'f1-score': 0.6666...}
            {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0}
            {'precision': 0.2333..., 'recall': 0.6666..., 'f1-score': 0.3333...}
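
        Custom label names can be passed through the ``labels`` argument; they are used as the
        per-class keys of the report instead of the class indices. With a different ``beta``,
        the F-score key becomes ``f{beta}-score`` (e.g. ``f2-score`` for ``beta=2``). A minimal
        sketch, assuming a three-class setup like the multiclass example above:

        .. code-block:: python

            metric = ClassificationReport(output_dict=True, labels=["cat", "dog", "bird"])
            metric.attach(default_evaluator, "cr_named")
            state = default_evaluator.run([[y_pred, y_true]])
            print(state.metrics["cr_named"].keys())
            # dict_keys(['cat', 'dog', 'bird', 'macro avg'])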

    """

    # set up the underlying per-class metrics and their macro averages
    precision = Precision(average=False, is_multilabel=is_multilabel, output_transform=output_transform, device=device)
    recall = Recall(average=False, is_multilabel=is_multilabel, output_transform=output_transform, device=device)
    fbeta = Fbeta(beta, average=False, precision=precision, recall=recall)
    averaged_precision = precision.mean()
    averaged_recall = recall.mean()
    averaged_fbeta = fbeta.mean()

    # the argument order must match the MetricsLambda call below:
    # recall, precision, fbeta, then their macro-averaged counterparts
    def _wrapper(
        re: torch.Tensor, pr: torch.Tensor, f: torch.Tensor, a_re: torch.Tensor, a_pr: torch.Tensor, a_f: torch.Tensor
    ) -> Union[Collection[str], Dict]:
        if pr.shape != re.shape:
            raise ValueError(
                "Internal error: Precision and Recall have mismatched shapes: "
                f"{pr.shape} vs {re.shape}. Please, open an issue "
                "with a reference on this error. Thank you!"
            )
        dict_obj = {}
        for idx, p_label in enumerate(pr):
            dict_obj[_get_label_for_class(idx)] = {
                "precision": p_label.item(),
                "recall": re[idx].item(),
                f"f{beta}-score": f[idx].item(),
            }
        dict_obj["macro avg"] = {
            "precision": a_pr.item(),
            "recall": a_re.item(),
            f"f{beta}-score": a_f.item(),
        }
        return dict_obj if output_dict else json.dumps(dict_obj)

    # helper to get the display label for a given class index
    def _get_label_for_class(idx: int) -> str:
        return labels[idx] if labels else str(idx)

    return MetricsLambda(_wrapper, recall, precision, fbeta, averaged_recall, averaged_precision, averaged_fbeta)