File: fid.py

package info (click to toggle)
pytorch-ignite 0.5.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 11,712 kB
  • sloc: python: 46,874; sh: 376; makefile: 27
file content (267 lines) | stat: -rw-r--r-- 9,994 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
import warnings
from typing import Callable, Optional, Sequence, Union

import torch
from packaging.version import Version

from ignite.metrics.gan.utils import _BaseInceptionMetric, InceptionModel
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce

# Public names exported by this module.
__all__ = ["FID"]


# Outer-product function used by the metric: ``torch.ger`` for torch versions
# up to and including 1.7.0, ``torch.outer`` for anything newer.
torch_outer = torch.ger if Version(torch.__version__) <= Version("1.7.0") else torch.outer


def fid_score(
    mu1: torch.Tensor, mu2: torch.Tensor, sigma1: torch.Tensor, sigma2: torch.Tensor, eps: float = 1e-6
) -> float:
    """Compute the Frechet distance from two means and two covariance matrices.

    The returned value is::

        (mu1 - mu2) . (mu1 - mu2) + Tr(sigma1) + Tr(sigma2) - 2 * Tr(sqrtm(sigma1 @ sigma2))

    Args:
        mu1: mean vector of the first distribution.
        mu2: mean vector of the second distribution.
        sigma1: covariance matrix of the first distribution.
        sigma2: covariance matrix of the second distribution.
        eps: scaling constant used only in the fallback trace estimate, which is taken
            when the matrix square root contains non-finite values.

    Returns:
        The distance as a Python ``float``.

    Raises:
        ModuleNotFoundError: if ``numpy`` or ``scipy`` cannot be imported.
        ValueError: if the diagonal of the matrix square root has an imaginary part
            that is not close to zero (absolute tolerance ``1e-3``).
    """
    try:
        import numpy as np
    except ImportError:
        raise ModuleNotFoundError("fid_score requires numpy to be installed.")

    try:
        import scipy.linalg
    except ImportError:
        raise ModuleNotFoundError("fid_score requires scipy to be installed.")

    # numpy / scipy consume the tensors below, so bring everything to CPU first.
    mu1 = mu1.cpu()
    mu2 = mu2.cpu()
    sigma1 = sigma1.cpu()
    sigma2 = sigma2.cpu()

    mean_gap = mu1 - mu2

    # The covariance product might be almost singular.
    sqrt_product, _ = scipy.linalg.sqrtm(sigma1.mm(sigma2), disp=False)

    # Numerical error might leave a slight imaginary component; tolerate it only
    # when the diagonal is real up to the tolerance, then drop the imaginary part.
    if np.iscomplexobj(sqrt_product):
        diagonal_imag = np.diagonal(sqrt_product).imag
        if not np.allclose(diagonal_imag, 0, atol=1e-3):
            largest = np.max(np.abs(sqrt_product.imag))
            raise ValueError("Imaginary component {}".format(largest))
        sqrt_product = sqrt_product.real

    tr_sqrt_product = np.trace(sqrt_product)

    # Fallback: when the square root is not finite everywhere, estimate the trace
    # from the diagonals of the two covariance matrices instead.
    if not np.isfinite(sqrt_product).all():
        scaled_diag1 = np.diag(sigma1) * eps
        scaled_diag2 = np.diag(sigma2) * eps
        tr_sqrt_product = np.sum(np.sqrt((scaled_diag1 * scaled_diag2) / (eps * eps)))

    distance = mean_gap.dot(mean_gap).item() + torch.trace(sigma1)
    distance = distance + torch.trace(sigma2)
    distance = distance - 2 * tr_sqrt_product
    return float(distance)


class FID(_BaseInceptionMetric):
    r"""Calculates Frechet Inception Distance.

    .. math::
       \text{FID} = \|\mu_{1} - \mu_{2}\|^2 + \text{Tr}(\sigma_{1} + \sigma_{2} - {2}\sqrt{\sigma_1*\sigma_2})

    where :math:`\mu_1` and :math:`\sigma_1` refer to the mean and covariance of the train data and
    :math:`\mu_2` and :math:`\sigma_2` refer to the mean and covariance of the test data.

    More details can be found in `Heusel et al. 2017`__

    __ https://arxiv.org/pdf/1706.08500.pdf

    In addition, a faster and online computation approach can be found in `Mathiasen et al. 2020`__

    __ https://arxiv.org/pdf/2009.14075.pdf

    Remark:

        This implementation is inspired by `pytorch_fid` package which can be found `here`__

        __ https://github.com/mseitzer/pytorch-fid

    .. note::
        The default Inception model requires the `torchvision` module to be installed.
        FID also requires `scipy` library for matrix square root calculations.

    Args:
        num_features: number of features predicted by the model or the reduced feature vector of the image.
            When neither ``num_features`` nor ``feature_extractor`` is given, it is set to 1000.
        feature_extractor: a torch Module for extracting the features from the input data.
            It returns a tensor of shape (batch_size, num_features).
            If neither ``num_features`` nor ``feature_extractor`` are defined, by default we use an ImageNet
            pretrained Inception Model and use model's output logits as features.
            If only ``num_features`` is defined but ``feature_extractor`` is not defined,
            ``feature_extractor`` is assigned Identity Function.
            Please note that the model will be implicitly converted to device mentioned in the ``device``
            argument.
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
            By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
        device: specifies which device updates are accumulated on. Setting the
            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
            non-blocking. By default, CPU.

    Examples:

        For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode::

            metric = FID(num_features=1, feature_extractor=default_model)
            metric.attach(default_evaluator, "fid")
            y_true = torch.ones(10, 4)
            y_pred = torch.ones(10, 4)
            state = default_evaluator.run([[y_pred, y_true]])
            print(state.metrics["fid"])

        .. testoutput::

            0.0

    .. note::

        The default `torchvision` model used is InceptionV3 pretrained on ImageNet.
        This can lead to differences in results with `pytorch_fid`. To find comparable results,
        the following model wrapper should be used:

        .. code::

            import torch.nn as nn

            # wrapper class as feature_extractor
            class WrapperInceptionV3(nn.Module):

                def __init__(self, fid_incv3):
                    super().__init__()
                    self.fid_incv3 = fid_incv3

                @torch.no_grad()
                def forward(self, x):
                    y = self.fid_incv3(x)
                    y = y[0]
                    y = y[:, :, 0, 0]
                    return y

            # use cpu rather than cuda to get comparable results
            device = "cpu"

            # pytorch_fid model
            dims = 2048
            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
            model = InceptionV3([block_idx]).to(device)

            # wrapper model to pytorch_fid model
            wrapper_model = WrapperInceptionV3(model)
            wrapper_model.eval();

            # comparable metric
            pytorch_fid_metric = FID(num_features=dims, feature_extractor=wrapper_model)

        Important, `pytorch_fid` results depend on the batch size if the device is `cuda`.

    .. versionadded:: 0.4.6
    """

    # Names of the accumulator attributes; the same set is reduced in ``compute``.
    _state_dict_all_req_keys = ("_num_examples", "_train_total", "_test_total", "_train_sigma", "_test_sigma")

    def __init__(
        self,
        num_features: Optional[int] = None,
        feature_extractor: Optional[torch.nn.Module] = None,
        output_transform: Callable = lambda x: x,
        device: Union[str, torch.device] = torch.device("cpu"),
    ) -> None:
        # Fail at construction time if the numerical dependencies are missing.
        try:
            import numpy as np  # noqa: F401
        except ImportError:
            raise ModuleNotFoundError("This module requires numpy to be installed.")

        try:
            import scipy  # noqa: F401
        except ImportError:
            raise ModuleNotFoundError("This module requires scipy to be installed.")

        # With nothing specified, fall back to the Inception model and its 1000 outputs.
        nothing_specified = num_features is None and feature_extractor is None
        if nothing_specified:
            feature_extractor = InceptionModel(return_features=False, device=device)
            num_features = 1000

        # Forwarded to ``fid_score`` as its ``eps`` argument.
        self._eps = 1e-6

        super().__init__(
            num_features=num_features,
            feature_extractor=feature_extractor,
            output_transform=output_transform,
            device=device,
        )

    @staticmethod
    def _online_update(features: torch.Tensor, total: torch.Tensor, sigma: torch.Tensor) -> None:
        r"""
        Accumulates one feature vector in place: ``features`` is added to ``total`` and the
        outer product of ``features`` with itself is added to ``sigma``
        """
        total.add_(features)
        sigma.add_(torch_outer(features, features))

    def _get_covariance(self, sigma: torch.Tensor, total: torch.Tensor) -> torch.Tensor:
        r"""
        Calculates covariance from the sum of the variables (``total``) and the sum of their
        outer products (``sigma``), using the number of examples seen so far
        """
        count = self._num_examples
        mean_correction = torch_outer(total, total) / count

        return (sigma - mean_correction) / (count - 1)

    @reinit__is_reduced
    def reset(self) -> None:
        r"""
        Zeroes the train/test accumulators and the example counter, then resets the parent
        """
        size = self._num_features
        tensor_options = {"dtype": torch.float64, "device": self._device}

        self._train_sigma = torch.zeros((size, size), **tensor_options)
        self._train_total = torch.zeros(size, **tensor_options)

        self._test_sigma = torch.zeros((size, size), **tensor_options)
        self._test_total = torch.zeros(size, **tensor_options)

        self._num_examples: int = 0

        super().reset()  # type: ignore

    @reinit__is_reduced
    def update(self, output: Sequence[torch.Tensor]) -> None:
        r"""
        Extracts features from the ``(train, test)`` pair in ``output`` and accumulates them.

        Raises ``ValueError`` when the two feature tensors differ in their first or second dimension
        """
        train, test = output
        train_features = self._extract_features(train)
        test_features = self._extract_features(test)

        same_batch_and_width = (
            train_features.shape[0] == test_features.shape[0] and train_features.shape[1] == test_features.shape[1]
        )
        if not same_batch_and_width:
            raise ValueError(
                f"""
    Number of Training Features and Testing Features should be equal ({train_features.shape} != {test_features.shape})
                """
            )

        # Train features are accumulated first, then test features, one sample at a time.
        accumulators = (
            (train_features, self._train_total, self._train_sigma),
            (test_features, self._test_total, self._test_sigma),
        )
        for batch, running_total, running_sigma in accumulators:
            for sample in batch:
                self._online_update(sample, running_total, running_sigma)

        self._num_examples += train_features.shape[0]

    @sync_all_reduce("_num_examples", "_train_total", "_test_total", "_train_sigma", "_test_sigma")
    def compute(self) -> float:
        r"""
        Computes the FID from the accumulated sums and warns when the result is NaN or infinite
        """
        count = self._num_examples

        train_mean = self._train_total / count
        test_mean = self._test_total / count
        train_cov = self._get_covariance(self._train_sigma, self._train_total)
        test_cov = self._get_covariance(self._test_sigma, self._test_total)

        fid = fid_score(mu1=train_mean, mu2=test_mean, sigma1=train_cov, sigma2=test_cov, eps=self._eps)

        # Not finite means NaN or +/-inf.
        if not torch.isfinite(torch.tensor(fid)):
            warnings.warn("The product of covariance of train and test features is out of bounds.")

        return fid