1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232
|
import warnings
from typing import Any, Callable, cast, Optional, Union
import torch
import ignite.distributed as idist
from ignite.engine import Engine, Events
from ignite.metrics.metric import Metric, MetricUsage, reinit__is_reduced, RunningBatchWise, SingleEpochRunningBatchWise
# Names exported by ``from <this module> import *``.
__all__ = ["RunningAverage"]
class RunningAverage(Metric):
    """Compute running average of a metric or the output of process function.

    Args:
        src: input source: an instance of :class:`~ignite.metrics.metric.Metric` or None. The latter
            corresponds to `engine.state.output` which holds the output of process function.
        alpha: running average decay factor, default 0.98. Each new value ``v`` updates the average as
            ``avg = alpha * avg + (1 - alpha) * v``; must lie in ``(0.0, 1.0]``.
        output_transform: a function to use to transform the output if `src` is None and
            corresponds to the output of process function. Otherwise it should be None.
        epoch_bound: whether the running average should be reset after each epoch. It is deprecated in favor of
            ``usage`` argument in :meth:`attach` method. Setting ``epoch_bound`` to ``True`` is equivalent to
            ``usage=SingleEpochRunningBatchWise()`` and setting it to ``False`` is equivalent to
            ``usage=RunningBatchWise()`` in the :meth:`attach` method. Default None.
        device: specifies which device updates are accumulated on. Should be
            None when ``src`` is an instance of :class:`~ignite.metrics.metric.Metric`, as the running average will
            use the ``src``'s device. Otherwise, defaults to CPU. Only applicable when the computed value
            from the metric is a tensor.
        skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be
            true for multi-output model, for example, if ``y_pred`` contains multi-output as
            ``(y_pred_a, y_pred_b)``. Alternatively, ``output_transform`` can be used to handle this.

    Examples:

        For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode:: 1

            default_trainer = get_default_trainer()

            accuracy = Accuracy()
            metric = RunningAverage(accuracy)
            metric.attach(default_trainer, 'running_avg_accuracy')

            @default_trainer.on(Events.ITERATION_COMPLETED)
            def log_running_avg_metrics():
                print(default_trainer.state.metrics['running_avg_accuracy'])

            y_true = [torch.tensor(y) for y in [[0], [1], [0], [1], [0], [1]]]
            y_pred = [torch.tensor(y) for y in [[0], [0], [0], [1], [1], [1]]]

            state = default_trainer.run(zip(y_pred, y_true))

        .. testoutput:: 1

            1.0
            0.98
            0.98039...
            0.98079...
            0.96117...
            0.96195...

        .. testcode:: 2

            default_trainer = get_default_trainer()

            metric = RunningAverage(output_transform=lambda x: x.item())
            metric.attach(default_trainer, 'running_avg_accuracy')

            @default_trainer.on(Events.ITERATION_COMPLETED)
            def log_running_avg_metrics():
                print(default_trainer.state.metrics['running_avg_accuracy'])

            y = [torch.tensor(y) for y in [[0], [1], [0], [1], [0], [1]]]

            state = default_trainer.run(y)

        .. testoutput:: 2

            0.0
            0.020000...
            0.019600...
            0.039208...
            0.038423...
            0.057655...

    .. versionchanged:: 0.5.1
        ``skip_unrolling`` argument is added.
    """

    # NOTE(review): ``None`` presumably tells the ``Metric`` base class to hand the engine output to
    # ``update`` as-is instead of extracting specific keys from dict outputs -- confirm in ``Metric``.
    required_output_keys = None
    # Attributes presumably captured by ``Metric.state_dict`` / restored by ``load_state_dict`` -- confirm.
    _state_dict_all_req_keys = ("_value", "src")

    def __init__(
        self,
        src: Optional[Metric] = None,
        alpha: float = 0.98,
        output_transform: Optional[Callable] = None,
        epoch_bound: Optional[bool] = None,
        device: Optional[Union[str, torch.device]] = None,
        skip_unrolling: bool = False,
    ):
        if not (isinstance(src, Metric) or src is None):
            raise TypeError("Argument src should be a Metric or None.")
        if not (0.0 < alpha <= 1.0):
            raise ValueError("Argument alpha should be a float between 0.0 and 1.0.")

        if isinstance(src, Metric):
            if output_transform is not None:
                raise ValueError("Argument output_transform should be None if src is a Metric.")

            # The wrapped metric reads the engine output itself, so this metric's own transform is the identity.
            def output_transform(x: Any) -> Any:
                return x

            if device is not None:
                raise ValueError("Argument device should be None if src is a Metric.")
            # ``self.src`` is set before ``super().__init__`` because ``reset`` (below) reads it; presumably the
            # base constructor calls ``reset`` -- confirm in ``Metric.__init__``.
            self.src: Union[Metric, None] = src
            # Accumulate on the wrapped metric's device.
            device = src._device
        else:
            if output_transform is None:
                raise ValueError(
                    "Argument output_transform should not be None if src corresponds "
                    "to the output of process function."
                )
            self.src = None
            if device is None:
                device = torch.device("cpu")

        if epoch_bound is not None:
            warnings.warn(
                "`epoch_bound` is deprecated and will be removed in the future. Consider using `usage` argument of "
                "`attach` method instead. `epoch_bound=True` is equivalent with `usage=SingleEpochRunningBatchWise()`"
                " and `epoch_bound=False` is equivalent with `usage=RunningBatchWise()`.",
                stacklevel=2,
            )
        self.epoch_bound = epoch_bound
        self.alpha = alpha
        super(RunningAverage, self).__init__(
            output_transform=output_transform, device=device, skip_unrolling=skip_unrolling
        )

    @reinit__is_reduced
    def reset(self) -> None:
        """Clear the running average and, when a ``src`` metric is wrapped, reset it as well."""
        self._value: Optional[Union[float, torch.Tensor]] = None
        if isinstance(self.src, Metric):
            self.src.reset()

    @reinit__is_reduced
    def update(self, output: Union[torch.Tensor, float]) -> None:
        """Fold one new value into the running average.

        Without ``src``, ``output`` is detached and copied to this metric's device (if it is a tensor), then
        averaged across processes. With ``src``, the wrapped metric is computed and immediately reset, so each
        call only sees what ``src`` accumulated since the previous call.

        The first value after :meth:`reset` seeds the average. Each later value ``v`` updates it as
        ``alpha * avg + (1 - alpha) * v``.
        """
        if self.src is None:
            output = output.detach().to(self._device, copy=True) if isinstance(output, torch.Tensor) else output
            # Mean across processes; assumes ``idist.all_reduce`` performs a SUM reduction by default -- confirm.
            value = idist.all_reduce(output) / idist.get_world_size()
        else:
            value = self.src.compute()
            self.src.reset()

        if self._value is None:
            self._value = value
        else:
            self._value = self._value * self.alpha + (1.0 - self.alpha) * value

    def compute(self) -> Union[torch.Tensor, float]:
        """Return the current running average.

        The stored value is ``None`` until :meth:`update` has been called at least once after :meth:`reset`.
        In that case ``None`` is returned as-is.
        """
        return cast(Union[torch.Tensor, float], self._value)

    def _resolve_usage(self, usage: Union[str, MetricUsage]) -> MetricUsage:
        """Validate ``usage``; the deprecated ``epoch_bound`` argument, when given, overrides it."""
        usage = self._check_usage(usage)
        if self.epoch_bound is not None:
            usage = SingleEpochRunningBatchWise() if self.epoch_bound else RunningBatchWise()
        return usage

    def attach(self, engine: Engine, name: str, usage: Union[str, MetricUsage] = RunningBatchWise()) -> None:
        r"""
        Attach the metric to the ``engine`` using the events determined by the ``usage``.

        Args:
            engine: the engine to get attached to.
            name: by which, the metric is inserted into ``engine.state.metrics`` dictionary.
            usage: the usage determining on which events the metric is reset, updated and computed. It should be an
                instance of the :class:`~ignite.metrics.metric.MetricUsage`\ s in the following table.

                ======================================================= ===========================================
                ``usage`` **class**                                     **Description**
                ======================================================= ===========================================
                :class:`~.metrics.metric.RunningBatchWise`              Running average of the ``src`` metric or
                                                                        ``engine.state.output`` is computed across
                                                                        batches. In the former case, on each batch,
                                                                        ``src`` is reset, updated and computed then
                                                                        its value is retrieved. Default.
                :class:`~.metrics.metric.SingleEpochRunningBatchWise`   Same as above but the running average is
                                                                        computed across batches in an epoch so it
                                                                        is reset at the end of the epoch.
                :class:`~.metrics.metric.RunningEpochWise`              Running average of the ``src`` metric or
                                                                        ``engine.state.output`` is computed across
                                                                        epochs. In the former case, ``src`` works
                                                                        as if it was attached in a
                                                                        :class:`~ignite.metrics.metric.EpochWise`
                                                                        manner and its computed value is retrieved
                                                                        at the end of the epoch. The latter case
                                                                        doesn't make much sense for this usage as
                                                                        the ``engine.state.output`` of the last
                                                                        batch is retrieved then.
                ======================================================= ===========================================

                If the deprecated ``epoch_bound`` argument was given to the constructor, it overrides ``usage``:
                ``True`` selects ``SingleEpochRunningBatchWise()`` and ``False`` selects ``RunningBatchWise()``.

        If ``src`` is not given, ``RunningAverage`` retrieves ``engine.state.output`` at
        ``usage.ITERATION_COMPLETED``. If ``src`` is given, its value is computed by manually calling its
        ``compute`` method at the ``usage.COMPLETED`` event.
        Also if ``src`` is given, it is updated at ``usage.ITERATION_COMPLETED``, but its reset event is determined by
        ``usage`` type. If ``isinstance(usage, BatchWise)`` holds true, ``src`` is reset on ``BatchWise().STARTED``,
        otherwise on ``EpochWise().STARTED`` if ``isinstance(usage, EpochWise)``.

        .. versionchanged:: 0.5.1
            Added `usage` argument
        """
        usage = self._resolve_usage(usage)
        # Feed ``src`` on every iteration, without registering the same handler twice.
        if isinstance(self.src, Metric) and not engine.has_event_handler(
            self.src.iteration_completed, Events.ITERATION_COMPLETED
        ):
            engine.add_event_handler(Events.ITERATION_COMPLETED, self.src.iteration_completed)
        super().attach(engine, name, usage)

    def detach(self, engine: Engine, usage: Union[str, MetricUsage] = RunningBatchWise()) -> None:
        """Detach the metric from ``engine``, undoing :meth:`attach`.

        Args:
            engine: the engine to detach from.
            usage: the usage the metric was attached with. As in :meth:`attach`, the deprecated ``epoch_bound``
                constructor argument overrides it when given.
        """
        usage = self._resolve_usage(usage)
        # Remove the per-iteration ``src`` handler that ``attach`` may have registered.
        if isinstance(self.src, Metric) and engine.has_event_handler(
            self.src.iteration_completed, Events.ITERATION_COMPLETED
        ):
            engine.remove_event_handler(self.src.iteration_completed, Events.ITERATION_COMPLETED)
        super().detach(engine, usage)
|