File: running_average.py

import warnings
from typing import Any, Callable, cast, Optional, Union

import torch

import ignite.distributed as idist
from ignite.engine import Engine, Events
from ignite.metrics.metric import Metric, MetricUsage, reinit__is_reduced, RunningBatchWise, SingleEpochRunningBatchWise

__all__ = ["RunningAverage"]


class RunningAverage(Metric):
    """Compute running average of a metric or the output of process function.

    Args:
        src: input source: an instance of :class:`~ignite.metrics.metric.Metric` or None. The latter
            corresponds to `engine.state.output`, which holds the output of the process function.
        alpha: running average decay factor, default 0.98.
        output_transform: a function used to transform the output if `src` is None and
            corresponds to the output of the process function. Otherwise it should be None.
        epoch_bound: whether the running average should be reset after each epoch. It is deprecated in favor of
            the ``usage`` argument of the :meth:`attach` method. Setting ``epoch_bound`` to ``True`` is equivalent to
            ``usage=SingleEpochRunningBatchWise()`` and setting it to ``False`` is equivalent to
            ``usage=RunningBatchWise()`` in the :meth:`attach` method. Default None.
        device: specifies which device updates are accumulated on. Should be
            None when ``src`` is an instance of :class:`~ignite.metrics.metric.Metric`, as the running average will
            use the ``src``'s device. Otherwise, defaults to CPU. Only applicable when the computed value
            from the metric is a tensor.
        skip_unrolling: specifies whether the output should be unrolled before being fed to the update method.
            Should be true for multi-output models, for example, if ``y_pred`` contains multi-output as
            ``(y_pred_a, y_pred_b)``. Alternatively, ``output_transform`` can be used to handle this.
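
    The running average is an exponential moving average: with decay factor ``alpha``, each update computes
    ``value = alpha * value + (1 - alpha) * new_value``, and the first update simply stores the incoming value.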

    Examples:

        For more information on how the metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode:: 1

            default_trainer = get_default_trainer()

            accuracy = Accuracy()
            metric = RunningAverage(accuracy)
            metric.attach(default_trainer, 'running_avg_accuracy')

            @default_trainer.on(Events.ITERATION_COMPLETED)
            def log_running_avg_metrics():
                print(default_trainer.state.metrics['running_avg_accuracy'])

            y_true = [torch.tensor(y) for y in [[0], [1], [0], [1], [0], [1]]]
            y_pred = [torch.tensor(y) for y in [[0], [0], [0], [1], [1], [1]]]

            state = default_trainer.run(zip(y_pred, y_true))

        .. testoutput:: 1

            1.0
            0.98
            0.98039...
            0.98079...
            0.96117...
            0.96195...

        .. testcode:: 2

            default_trainer = get_default_trainer()

            metric = RunningAverage(output_transform=lambda x: x.item())
            metric.attach(default_trainer, 'running_avg_accuracy')

            @default_trainer.on(Events.ITERATION_COMPLETED)
            def log_running_avg_metrics():
                print(default_trainer.state.metrics['running_avg_accuracy'])

            y = [torch.tensor(y) for y in [[0], [1], [0], [1], [0], [1]]]

            state = default_trainer.run(y)

        .. testoutput:: 2

            0.0
            0.020000...
            0.019600...
            0.039208...
            0.038423...
            0.057655...

    .. versionchanged:: 0.5.1
        ``skip_unrolling`` argument is added.
    """

    # required_output_keys=None passes the raw engine.state.output straight to output_transform.
    required_output_keys = None
    _state_dict_all_req_keys = ("_value", "src")

    def __init__(
        self,
        src: Optional[Metric] = None,
        alpha: float = 0.98,
        output_transform: Optional[Callable] = None,
        epoch_bound: Optional[bool] = None,
        device: Optional[Union[str, torch.device]] = None,
        skip_unrolling: bool = False,
    ):
        if not (isinstance(src, Metric) or src is None):
            raise TypeError("Argument src should be a Metric or None.")
        if not (0.0 < alpha <= 1.0):
            raise ValueError("Argument alpha should be a float between 0.0 and 1.0.")

        if isinstance(src, Metric):
            if output_transform is not None:
                raise ValueError("Argument output_transform should be None if src is a Metric.")

            def output_transform(x: Any) -> Any:
                return x

            if device is not None:
                raise ValueError("Argument device should be None if src is a Metric.")
            self.src: Union[Metric, None] = src
            device = src._device
        else:
            if output_transform is None:
                raise ValueError(
                    "Argument output_transform should not be None if src corresponds "
                    "to the output of process function."
                )
            self.src = None
            if device is None:
                device = torch.device("cpu")

        if epoch_bound is not None:
            warnings.warn(
                "`epoch_bound` is deprecated and will be removed in the future. Consider using the `usage` argument"
                " of the `attach` method instead. `epoch_bound=True` is equivalent to"
                " `usage=SingleEpochRunningBatchWise()` and `epoch_bound=False` is equivalent to"
                " `usage=RunningBatchWise()`."
            )
        self.epoch_bound = epoch_bound
        self.alpha = alpha
        super(RunningAverage, self).__init__(
            output_transform=output_transform, device=device, skip_unrolling=skip_unrolling
        )

    @reinit__is_reduced
    def reset(self) -> None:
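        # Drop the accumulated running value and reset the wrapped metric, if any.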
        self._value: Optional[Union[float, torch.Tensor]] = None
        if isinstance(self.src, Metric):
            self.src.reset()

    @reinit__is_reduced
    def update(self, output: Union[torch.Tensor, float]) -> None:
        if self.src is None:
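            # engine.state.output (after output_transform) is the value; average it across
            # distributed processes via all_reduce.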
            output = output.detach().to(self._device, copy=True) if isinstance(output, torch.Tensor) else output
            value = idist.all_reduce(output) / idist.get_world_size()
        else:
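            # Take the wrapped metric's current value, then reset it so the next update
            # reflects only the batches seen since.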
            value = self.src.compute()
            self.src.reset()

        # Exponential moving average: the first update stores the value as-is; later
        # updates blend it as alpha * old + (1 - alpha) * new.
        if self._value is None:
            self._value = value
        else:
            self._value = self._value * self.alpha + (1.0 - self.alpha) * value

    def compute(self) -> Union[torch.Tensor, float]:
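        # The cast is safe once update() has run at least once; before that _value is None.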
        return cast(Union[torch.Tensor, float], self._value)

    def attach(self, engine: Engine, name: str, usage: Union[str, MetricUsage] = RunningBatchWise()) -> None:
        r"""
        Attach the metric to the ``engine`` using the events determined by the ``usage``.

        Args:
            engine: the engine to get attached to.
            name: the name by which the metric is inserted into the ``engine.state.metrics`` dictionary.
            usage: the usage determining on which events the metric is reset, updated and computed. It should be an
                instance of one of the :class:`~ignite.metrics.metric.MetricUsage` classes in the following table.

                ======================================================= ===========================================
                ``usage`` **class**                                     **Description**
                ======================================================= ===========================================
                :class:`~.metrics.metric.RunningBatchWise`              Running average of the ``src`` metric or
                                                                        ``engine.state.output`` is computed across
                                                                        batches. In the former case, on each batch,
                                                                        ``src`` is reset, updated and computed then
                                                                        its value is retrieved. Default.
                :class:`~.metrics.metric.SingleEpochRunningBatchWise`   Same as above but the running average is
                                                                        computed across batches in an epoch so it
                                                                        is reset at the end of the epoch.
                :class:`~.metrics.metric.RunningEpochWise`              Running average of the ``src`` metric or
                                                                        ``engine.state.output`` is computed across
                                                                        epochs. In the former case, ``src`` works
                                                                        as if it was attached in a
                                                                        :class:`~ignite.metrics.metric.EpochWise`
                                                                        manner and its computed value is retrieved
                                                                        at the end of the epoch. The latter case
                                                                        doesn't make much sense for this usage as
                                                                        the ``engine.state.output`` of the last
                                                                        batch is retrieved then.
                ======================================================= ===========================================

        If ``src`` is not given, ``RunningAverage`` updates itself with ``engine.state.output`` at each
        ``usage.ITERATION_COMPLETED`` event; otherwise it updates itself there by manually calling ``src``'s
        ``compute`` method. In both cases, the running average is written to ``engine.state.metrics`` at the
        ``usage.COMPLETED`` event.
        Also, if ``src`` is given, it is updated at ``usage.ITERATION_COMPLETED``, but its reset event is determined
        by the ``usage`` type: ``src`` is reset on ``BatchWise().STARTED`` if ``isinstance(usage, BatchWise)`` holds
        true, otherwise on ``EpochWise().STARTED`` if ``isinstance(usage, EpochWise)``.
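
        For example, a minimal sketch of attaching with a non-default usage (assuming a trainer ``Engine`` named
        ``trainer`` already exists):

        .. code-block:: python

            from ignite.metrics import Accuracy, RunningAverage
            from ignite.metrics.metric import SingleEpochRunningBatchWise

            # Running accuracy whose average is reset at the end of every epoch.
            acc_avg = RunningAverage(Accuracy())
            acc_avg.attach(trainer, "running_acc", usage=SingleEpochRunningBatchWise())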

        .. versionchanged:: 0.5.1
            Added ``usage`` argument.
        """
        usage = self._check_usage(usage)
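        # The deprecated epoch_bound flag, when set, still overrides the requested usage.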
        if self.epoch_bound is not None:
            usage = SingleEpochRunningBatchWise() if self.epoch_bound else RunningBatchWise()

        # Attach the wrapped metric's own update handler once so that it accumulates
        # over iterations between its reset events.
        if isinstance(self.src, Metric) and not engine.has_event_handler(
            self.src.iteration_completed, Events.ITERATION_COMPLETED
        ):
            engine.add_event_handler(Events.ITERATION_COMPLETED, self.src.iteration_completed)

        super().attach(engine, name, usage)

    def detach(self, engine: Engine, usage: Union[str, MetricUsage] = RunningBatchWise()) -> None:
        usage = self._check_usage(usage)
        if self.epoch_bound is not None:
            usage = SingleEpochRunningBatchWise() if self.epoch_bound else RunningBatchWise()

        if isinstance(self.src, Metric) and engine.has_event_handler(
            self.src.iteration_completed, Events.ITERATION_COMPLETED
        ):
            engine.remove_event_handler(self.src.iteration_completed, Events.ITERATION_COMPLETED)

        super().detach(engine, usage)