File: epoch_metric.py

package info (click to toggle)
pytorch-ignite 0.5.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 11,712 kB
  • sloc: python: 46,874; sh: 376; makefile: 27
file content (173 lines) | stat: -rw-r--r-- 7,050 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import warnings
from typing import Callable, cast, List, Optional, Tuple, Union

import torch

import ignite.distributed as idist
from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced

__all__ = ["EpochMetric"]


class EpochMetric(Metric):
    """Class for metrics that should be computed on the entire output history of a model.
    Model's output and targets must be of shape ``(batch_size, n_targets)`` or ``(batch_size, )``; a trailing
    dimension of size 1 (i.e. ``(batch_size, 1)``) is squeezed to ``(batch_size, )`` on ``update``. Output
    datatype should be `float32`. Target datatype should be `long` for classification and `float` for regression.

    .. warning::

        Current implementation stores all input data (output and target) as tensors before computing a metric.
        This can potentially lead to a memory error if the input data is larger than available RAM.

        In distributed configuration, all stored data (output and target) is mutually collected across all processes
        using all gather collective operation. This can potentially lead to a memory error.
        Compute method executes ``compute_fn`` on zero rank process only and final result is broadcasted to
        all processes.

    - ``update`` must receive output of the form ``(y_pred, y)``.
    - Every batch passed to ``update`` must keep the same dtypes as the previously stored batches.
    - The value returned by ``compute`` is cached until the next ``update`` or ``reset``.

    Args:
        compute_fn: a callable which receives two tensors as the `predictions` and `targets`
            and returns a scalar. Input tensors will be on specified ``device`` (see arg below).
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
        check_compute_fn: if True, ``compute_fn`` is run on the first batch of data to ensure there are no
            issues. If issues exist, user is warned that there might be an issue with the ``compute_fn``.
            Default, True.
        device: optional device specification for internal storage.
        skip_unrolling: forwarded unchanged to the :class:`~ignite.metrics.metric.Metric` base class; see its
            documentation for the exact semantics. Default, False.

    Raises:
        TypeError: if ``compute_fn`` is not callable.

    Example:

        For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode::

            def mse_fn(y_preds, y_targets):
                return torch.mean(((y_preds - y_targets.type_as(y_preds)) ** 2)).item()

            metric = EpochMetric(mse_fn)
            metric.attach(default_evaluator, "mse")
            y_true = torch.tensor([0, 1, 2, 3, 4, 5])
            y_pred = y_true * 0.75
            state = default_evaluator.run([[y_pred, y_true]])
            print(state.metrics["mse"])

        .. testoutput::

            0.5729...

    Warnings:
        EpochMetricWarning: User is warned that there are issues with ``compute_fn`` on a batch of data processed.
        To disable the warning, set ``check_compute_fn=False``.

    .. versionchanged:: 0.5.1
        ``skip_unrolling`` argument is added.
    """

    _state_dict_all_req_keys = ("_predictions", "_targets")

    def __init__(
        self,
        compute_fn: Callable[[torch.Tensor, torch.Tensor], float],
        output_transform: Callable = lambda x: x,
        check_compute_fn: bool = True,
        device: Union[str, torch.device] = torch.device("cpu"),
        skip_unrolling: bool = False,
    ) -> None:
        if not callable(compute_fn):
            raise TypeError("Argument compute_fn should be callable.")

        # Set before the base constructor runs, in case base-class initialisation
        # (e.g. a call to ``reset``) needs them — NOTE(review): confirm against Metric.__init__.
        self.compute_fn = compute_fn
        self._check_compute_fn = check_compute_fn

        super(EpochMetric, self).__init__(
            output_transform=output_transform, device=device, skip_unrolling=skip_unrolling
        )

    @reinit__is_reduced
    def reset(self) -> None:
        """Drop all stored predictions/targets and the cached result."""
        self._predictions: List[torch.Tensor] = []
        self._targets: List[torch.Tensor] = []
        self._result: Optional[float] = None

    def _check_shape(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
        """Raise ``ValueError`` unless both ``y_pred`` and ``y`` are 1D or 2D tensors."""
        y_pred, y = output
        if y_pred.ndimension() not in (1, 2):
            raise ValueError("Predictions should be of shape (batch_size, n_targets) or (batch_size, ).")

        if y.ndimension() not in (1, 2):
            raise ValueError("Targets should be of shape (batch_size, n_targets) or (batch_size, ).")

    def _check_type(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
        """Raise ``ValueError`` if the new batch's dtypes differ from the already stored batches.

        The first batch is always accepted, since there is nothing to compare it against.
        """
        y_pred, y = output
        if len(self._predictions) < 1:
            return
        dtype_preds = self._predictions[-1].dtype
        if dtype_preds != y_pred.dtype:
            raise ValueError(
                f"Incoherent types between input y_pred and stored predictions: {dtype_preds} vs {y_pred.dtype}"
            )

        dtype_targets = self._targets[-1].dtype
        if dtype_targets != y.dtype:
            raise ValueError(f"Incoherent types between input y and stored targets: {dtype_targets} vs {y.dtype}")

    @reinit__is_reduced
    def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
        """Store a copy of one batch of predictions and targets on ``self._device``.

        Args:
            output: a pair ``(y_pred, y)`` of 1D or 2D tensors.

        Raises:
            ValueError: if either tensor is not 1D/2D, or if its dtype differs from previously stored batches.
        """
        self._check_shape(output)
        y_pred, y = output[0].detach(), output[1].detach()

        # Normalise ``(batch_size, 1)`` to ``(batch_size,)`` so such batches concatenate like 1D ones.
        if y_pred.ndimension() == 2 and y_pred.shape[1] == 1:
            y_pred = y_pred.squeeze(dim=-1)

        if y.ndimension() == 2 and y.shape[1] == 1:
            y = y.squeeze(dim=-1)

        # ``detach`` shares storage with the caller's tensor; ``clone`` gives the stored history its own copy.
        y_pred = y_pred.clone().to(self._device)
        y = y.clone().to(self._device)

        self._check_type((y_pred, y))
        self._predictions.append(y_pred)
        self._targets.append(y)
        # New data invalidates any value cached by an earlier ``compute``; without this, calling ``compute``
        # between updates would keep returning the stale result.
        self._result = None

        # Check once the signature and execution of compute_fn
        if len(self._predictions) == 1 and self._check_compute_fn:
            try:
                self.compute_fn(self._predictions[0], self._targets[0])
            except Exception as e:
                # Best-effort sanity check only: warn instead of failing the run.
                warnings.warn(f"Probably, there can be a problem with `compute_fn`:\n {e}.", EpochMetricWarning)

    def compute(self) -> float:
        """Apply ``compute_fn`` to all stored predictions and targets.

        In distributed configuration the stored tensors are all-gathered, ``compute_fn`` runs on rank 0 only
        and its result is broadcast to every process. The value is cached until the next ``update`` or ``reset``.

        Returns:
            the value produced by ``compute_fn``.

        Raises:
            NotComputableError: if no batch has been stored since the last ``reset``.
        """
        if len(self._predictions) < 1 or len(self._targets) < 1:
            raise NotComputableError("EpochMetric must have at least one example before it can be computed.")

        if self._result is None:
            _prediction_tensor = torch.cat(self._predictions, dim=0)
            _target_tensor = torch.cat(self._targets, dim=0)

            ws = idist.get_world_size()
            if ws > 1:
                # All gather across all processes
                _prediction_tensor = cast(torch.Tensor, idist.all_gather(_prediction_tensor))
                _target_tensor = cast(torch.Tensor, idist.all_gather(_target_tensor))

            # Placeholder on non-zero ranks; overwritten by the broadcast below.
            self._result = 0.0
            if idist.get_rank() == 0:
                # Run compute_fn on zero rank only
                self._result = self.compute_fn(_prediction_tensor, _target_tensor)

            if ws > 1:
                # broadcast result to all processes
                self._result = cast(float, idist.broadcast(self._result, src=0))

        return self._result


class EpochMetricWarning(UserWarning):
    """Warning category emitted by :class:`EpochMetric` when the trial run of ``compute_fn`` fails."""