import json
from typing import Callable, Collection, Dict, List, Optional, Union

import torch

from ignite.metrics.fbeta import Fbeta
from ignite.metrics.metrics_lambda import MetricsLambda
from ignite.metrics.precision import Precision
from ignite.metrics.recall import Recall

__all__ = ["ClassificationReport"]

def ClassificationReport(
    beta: int = 1,
    output_dict: bool = False,
    output_transform: Callable = lambda x: x,
    device: Union[str, torch.device] = torch.device("cpu"),
    is_multilabel: bool = False,
    labels: Optional[List[str]] = None,
) -> MetricsLambda:
r"""Build a text report showing the main classification metrics. The report resembles in functionality to
`scikit-learn classification_report
<https://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html#sklearn.metrics.classification_report>`_
The underlying implementation doesn't use the sklearn function.
Args:
beta: weight of precision in harmonic mean
output_dict: If True, return output as dict, otherwise return a str
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
is_multilabel: If True, the tensors are assumed to be multilabel.
device: optional device specification for internal storage.
labels: Optional list of label indices to include in the report

    Examples:

        For more information on how the metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.
        .. include:: defaults.rst
            :start-after: :orphan:

        Multiclass case

        .. testcode:: 1

            metric = ClassificationReport(output_dict=True)
            metric.attach(default_evaluator, "cr")
            y_true = torch.tensor([2, 0, 2, 1, 0, 1])
            y_pred = torch.tensor([
                [0.0266, 0.1719, 0.3055],
                [0.6886, 0.3978, 0.8176],
                [0.9230, 0.0197, 0.8395],
                [0.1785, 0.2670, 0.6084],
                [0.8448, 0.7177, 0.7288],
                [0.7748, 0.9542, 0.8573],
            ])
            state = default_evaluator.run([[y_pred, y_true]])
            print(state.metrics["cr"].keys())
            print(state.metrics["cr"]["0"])
            print(state.metrics["cr"]["1"])
            print(state.metrics["cr"]["2"])
            print(state.metrics["cr"]["macro avg"])

        .. testoutput:: 1

            dict_keys(['0', '1', '2', 'macro avg'])
            {'precision': 0.5, 'recall': 0.5, 'f1-score': 0.4999...}
            {'precision': 1.0, 'recall': 0.5, 'f1-score': 0.6666...}
            {'precision': 0.3333..., 'recall': 0.5, 'f1-score': 0.3999...}
            {'precision': 0.6111..., 'recall': 0.5, 'f1-score': 0.5222...}

        Multilabel case, the shapes must be (batch_size, num_categories, ...)

        .. testcode:: 2

            metric = ClassificationReport(output_dict=True, is_multilabel=True)
            metric.attach(default_evaluator, "cr")
            y_true = torch.tensor([
                [0, 0, 1],
                [0, 0, 0],
                [0, 0, 0],
                [1, 0, 0],
                [0, 1, 1],
            ])
            y_pred = torch.tensor([
                [1, 1, 0],
                [1, 0, 1],
                [1, 0, 0],
                [1, 0, 1],
                [1, 1, 0],
            ])
            state = default_evaluator.run([[y_pred, y_true]])
            print(state.metrics["cr"].keys())
            print(state.metrics["cr"]["0"])
            print(state.metrics["cr"]["1"])
            print(state.metrics["cr"]["2"])
            print(state.metrics["cr"]["macro avg"])

        .. testoutput:: 2

            dict_keys(['0', '1', '2', 'macro avg'])
            {'precision': 0.2, 'recall': 1.0, 'f1-score': 0.3333...}
            {'precision': 0.5, 'recall': 1.0, 'f1-score': 0.6666...}
            {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0}
            {'precision': 0.2333..., 'recall': 0.6666..., 'f1-score': 0.3333...}
    """
    # setup all the underlying metrics
    precision = Precision(average=False, is_multilabel=is_multilabel, output_transform=output_transform, device=device)
    recall = Recall(average=False, is_multilabel=is_multilabel, output_transform=output_transform, device=device)
    fbeta = Fbeta(beta, average=False, precision=precision, recall=recall)
    averaged_precision = precision.mean()
    averaged_recall = recall.mean()
    averaged_fbeta = fbeta.mean()
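    # precision/recall/fbeta yield per-class tensors; their .mean() calls build MetricsLambda
    # objects that macro-average those per-class values, reported as "macro avg" below.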

    def _wrapper(
        re: torch.Tensor, pr: torch.Tensor, f: torch.Tensor, a_re: torch.Tensor, a_pr: torch.Tensor, a_f: torch.Tensor
    ) -> Union[Collection[str], Dict]:
        if pr.shape != re.shape:
            raise ValueError(
                "Internal error: Precision and Recall have mismatched shapes: "
                f"{pr.shape} vs {re.shape}. Please, open an issue "
                "with a reference on this error. Thank you!"
            )
        dict_obj = {}
        for idx, p_label in enumerate(pr):
            dict_obj[_get_label_for_class(idx)] = {
                "precision": p_label.item(),
                "recall": re[idx].item(),
                f"f{beta}-score": f[idx].item(),
            }
        dict_obj["macro avg"] = {
            "precision": a_pr.item(),
            "recall": a_re.item(),
            f"f{beta}-score": a_f.item(),
        }
        return dict_obj if output_dict else json.dumps(dict_obj)

    # helper method to get a label for a given class
    def _get_label_for_class(idx: int) -> str:
        return labels[idx] if labels else str(idx)

    return MetricsLambda(_wrapper, recall, precision, fbeta, averaged_recall, averaged_precision, averaged_fbeta)
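

# Illustrative usage sketch (not part of the upstream module): it shows the JSON-string output
# (output_dict=False) and the ``labels`` argument by attaching the metric to a bare
# ignite Engine whose process function simply forwards (y_pred, y) batches. The engine setup
# and the label names ("cat", "dog", "bird") are assumptions for this demo only; the tensors
# reuse the multiclass example from the docstring above.
if __name__ == "__main__":
    from ignite.engine import Engine

    def _eval_step(engine, batch):
        # each batch is already a (y_pred, y) pair
        return batch

    demo_evaluator = Engine(_eval_step)
    report = ClassificationReport(output_dict=False, labels=["cat", "dog", "bird"])
    report.attach(demo_evaluator, "cr")

    y_true = torch.tensor([2, 0, 2, 1, 0, 1])
    # per-class scores; the argmax of each row is the predicted class
    y_pred = torch.tensor(
        [
            [0.0266, 0.1719, 0.3055],
            [0.6886, 0.3978, 0.8176],
            [0.9230, 0.0197, 0.8395],
            [0.1785, 0.2670, 0.6084],
            [0.8448, 0.7177, 0.7288],
            [0.7748, 0.9542, 0.8573],
        ]
    )
    state = demo_evaluator.run([(y_pred, y_true)])
    print(state.metrics["cr"])  # JSON string keyed by "cat", "dog", "bird" and "macro avg"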