File: metrics_lambda.py

import itertools
from typing import Any, Callable, Optional, Union

import torch

from ignite.engine import Engine
from ignite.metrics.metric import EpochWise, Metric, MetricUsage, reinit__is_reduced

__all__ = ["MetricsLambda"]


class MetricsLambda(Metric):
    """
    Apply a function to other metrics to obtain a new metric.
    The result of the new metric is defined to be the result
    of applying the function to the results of the argument metrics.

    When updated, this metric recursively updates the metrics
    it depends on. When reset, all of its dependency metrics are
    reset as well. When attached, all of its dependency metrics are attached
    automatically (but only partially, e.g. :meth:`~ignite.metrics.metric.Metric.is_attached` will return False).

    Args:
        f: the function that defines the computation
        args: Sequence of other metrics or any other values
            that will be fed to ``f`` as positional arguments.
        kwargs: Mapping of other metrics or any other values
            that will be fed to ``f`` as keyword arguments.

    Examples:

        For more information on how the metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode::

            precision = Precision(average=False)
            recall = Recall(average=False)

            def Fbeta(r, p, beta):
                return torch.mean((1 + beta ** 2) * p * r / (beta ** 2 * p + r + 1e-20)).item()

            F1 = MetricsLambda(Fbeta, recall, precision, 1)
            F2 = MetricsLambda(Fbeta, recall, precision, 2)
            F3 = MetricsLambda(Fbeta, recall, precision, 3)
            F4 = MetricsLambda(Fbeta, recall, precision, 4)

            F1.attach(default_evaluator, "F1")
            F2.attach(default_evaluator, "F2")
            F3.attach(default_evaluator, "F3")
            F4.attach(default_evaluator, "F4")

            y_true = torch.tensor([1, 0, 1, 0, 0, 1])
            y_pred = torch.tensor([1, 0, 1, 0, 1, 1])
            state = default_evaluator.run([[y_pred, y_true]])
            print(state.metrics["F1"])
            print(state.metrics["F2"])
            print(state.metrics["F3"])
            print(state.metrics["F4"])

        .. testoutput::

            0.8571...
            0.9375...
            0.9677...
            0.9807...

        When checking whether the metric is attached: if any of its dependency
        metrics is detached, the metric is considered detached as well.

        .. code-block:: python

            engine = ...
            precision = Precision(average=False)

            aP = precision.mean()

            aP.attach(engine, "aP")

            assert aP.is_attached(engine)
            # partially attached
            assert not precision.is_attached(engine)

            precision.detach(engine)

            assert not aP.is_attached(engine)
            # fully detached
            assert not precision.is_attached(engine)
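
        ``MetricsLambda`` can also be used standalone through the usual
        ``reset``/``update``/``compute`` API, as long as it is not attached to
        an engine. A minimal sketch, reusing ``Fbeta``, ``y_pred`` and
        ``y_true`` from the example above:

        .. code-block:: python

            precision = Precision(average=False)
            recall = Recall(average=False)
            # keyword arguments are forwarded to ``f`` as well
            F1 = MetricsLambda(Fbeta, recall, precision, beta=1)

            F1.reset()
            # recursively updates ``precision`` and ``recall``
            F1.update((y_pred, y_true))
            print(F1.compute())

        Note that the two APIs are mutually exclusive: once ``update`` has been
        called, ``attach`` raises a ``ValueError`` until ``reset`` is called,
        and ``update`` raises while the metric is attached to an engine.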
    """

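    # attribute names that state_dict/load_state_dict must save and restore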
    _state_dict_all_req_keys = ("_updated", "args", "kwargs")

    def __init__(self, f: Callable, *args: Any, **kwargs: Any) -> None:
        self.function = f
        self.args = list(args)  # args must be a list (not a tuple) for the state_dict/load_state_dict feature
        self.kwargs = kwargs
        self.engine: Optional[Engine] = None
        self._updated = False
        super(MetricsLambda, self).__init__(device="cpu")

    @reinit__is_reduced
    def reset(self) -> None:
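        # reset all dependency metrics, then clear the updated flag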
        for i in itertools.chain(self.args, self.kwargs.values()):
            if isinstance(i, Metric):
                i.reset()
        self._updated = False

    @reinit__is_reduced
    def update(self, output: Any) -> None:
        if self.engine:
            raise ValueError(
                "MetricsLambda is already attached to an engine, "
                "and MetricsLambda can't use update API while it's attached."
            )

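        # propagate the update to every dependency metric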
        for i in itertools.chain(self.args, self.kwargs.values()):
            if isinstance(i, Metric):
                i.update(output)

        self._updated = True

    def compute(self) -> Any:
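        # materialize the dependencies: compute() each Metric argument, moving
        # tensor results to CPU, then apply the function to the plain values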
        materialized = [_get_value_on_cpu(i) for i in self.args]
        materialized_kwargs = {k: _get_value_on_cpu(v) for k, v in self.kwargs.items()}
        return self.function(*materialized, **materialized_kwargs)

    def _internal_attach(self, engine: Engine, usage: MetricUsage) -> None:
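        # remember the engine, then walk the dependencies: nested MetricsLambda
        # instances are attached recursively, plain Metrics only partially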
        self.engine = engine
        for index, metric in enumerate(itertools.chain(self.args, self.kwargs.values())):
            if isinstance(metric, MetricsLambda):
                metric._internal_attach(engine, usage)
            elif isinstance(metric, Metric):
                # NB: the dependency metric is attached only partially
                # We must not use is_attached() but rather check whether these event handlers exist
                if not engine.has_event_handler(metric.started, usage.STARTED):
                    engine.add_event_handler(usage.STARTED, metric.started)
                if not engine.has_event_handler(metric.iteration_completed, usage.ITERATION_COMPLETED):
                    engine.add_event_handler(usage.ITERATION_COMPLETED, metric.iteration_completed)

    def attach(self, engine: Engine, name: str, usage: Union[str, MetricUsage] = EpochWise()) -> None:
        if self._updated:
            raise ValueError(
                "The underlying metrics are already updated, can't attach while using reset/update/compute API."
            )
        usage = self._check_usage(usage)
        # recursively attach all its dependencies (partially)
        self._internal_attach(engine, usage)
        # attach only the COMPLETED handler (EPOCH_COMPLETED for the default EpochWise usage)
        engine.add_event_handler(usage.COMPLETED, self.completed, name)

    def detach(self, engine: Engine, usage: Union[str, MetricUsage] = EpochWise()) -> None:
        usage = self._check_usage(usage)
        # remove from engine
        super(MetricsLambda, self).detach(engine, usage)
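        # only this metric's own handlers are removed; handlers added for the
        # dependency metrics in _internal_attach remain on the engine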
        self.engine = None

    def is_attached(self, engine: Engine, usage: Union[str, MetricUsage] = EpochWise()) -> bool:
        usage = self._check_usage(usage)
        # also check the dependencies recursively
        return super(MetricsLambda, self).is_attached(engine, usage) and self._internal_is_attached(engine, usage)

    def _internal_is_attached(self, engine: Engine, usage: MetricUsage) -> bool:
        # if there is no engine, the metric is not attached
        if engine is None:
            return False
        # recursively check whether all dependency metrics are attached
        is_detached = False
        for metric in itertools.chain(self.args, self.kwargs.values()):
            if isinstance(metric, MetricsLambda):
                if not metric._internal_is_attached(engine, usage):
                    is_detached = True
            elif isinstance(metric, Metric):
                if not engine.has_event_handler(metric.started, usage.STARTED):
                    is_detached = True
                if not engine.has_event_handler(metric.iteration_completed, usage.ITERATION_COMPLETED):
                    is_detached = True
        return not is_detached


def _get_value_on_cpu(v: Any) -> Any:
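    """Compute ``v`` if it is a Metric and move tensor results to CPU; return other values unchanged."""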
    if isinstance(v, Metric):
        v = v.compute()
    if isinstance(v, torch.Tensor):
        v = v.cpu()
    return v