File: inception_score.py

package info (click to toggle)
pytorch-ignite 0.5.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 11,712 kB
  • sloc: python: 46,874; sh: 376; makefile: 27
file content (137 lines) | stat: -rw-r--r-- 5,581 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from typing import Callable, Optional, Union

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics.gan.utils import _BaseInceptionMetric, InceptionModel

# These decorators help with distributed settings
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce

__all__ = ["InceptionScore"]


class InceptionScore(_BaseInceptionMetric):
    r"""Calculates Inception Score.

    .. math::
       \text{IS(G)} = \exp\left(\frac{1}{N}\sum_{i=1}^{N} D_{KL}\left(p(y|x^{(i)}) \parallel \hat{p}(y)\right)\right)

    where :math:`p(y|x)` is the conditional class distribution predicted for image :math:`x`,
    :math:`\hat{p}(y) = \frac{1}{N}\sum_{i=1}^{N} p(y|x^{(i)})` is the marginal class distribution
    estimated over all :math:`N` generated images, `G` refers to the generator and :math:`D_{KL}`
    refers to the KL divergence between these two distributions.

    More details can be found in `Barratt et al. 2018`__.

    __ https://arxiv.org/pdf/1801.01973.pdf

    Args:
        num_features: number of features predicted by the model or number of classes of the model.
            If both ``num_features`` and ``feature_extractor`` are ``None``, it is set to 1000.
        feature_extractor: a torch Module for predicting the probabilities from the input data.
            It returns a tensor of shape (batch_size, num_features).
            If neither ``num_features`` nor ``feature_extractor`` are defined, by default we use an ImageNet
            pretrained Inception Model. If only ``num_features`` is defined but ``feature_extractor`` is not
            defined, ``feature_extractor`` is assigned Identity Function.
            Please note that the class object will be implicitly converted to device mentioned in the
            ``device`` argument.
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
            By default, metrics require the output as ``y_pred``.
        device: specifies which device updates are accumulated on. Setting the
            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
            non-blocking. By default, CPU.

    .. note::
        The default Inception model requires the `torchvision` module to be installed.

    .. note::
        Probabilities are accumulated and the score is computed in ``float64``.

    Examples:

        For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. code-block:: python

            metric = InceptionScore()
            metric.attach(default_evaluator, "is")
            y = torch.rand(10, 3, 299, 299)
            state = default_evaluator.run([y])
            print(state.metrics["is"])

        .. testcode::

            metric = InceptionScore(num_features=1, feature_extractor=default_model)
            metric.attach(default_evaluator, "is")
            y = torch.zeros(10, 4)
            state = default_evaluator.run([y])
            print(state.metrics["is"])

        .. testoutput::

            1.0

    .. versionadded:: 0.4.6
    """

    _state_dict_all_req_keys = ("_num_examples", "_prob_total", "_total_kl_d")

    def __init__(
        self,
        num_features: Optional[int] = None,
        feature_extractor: Optional[torch.nn.Module] = None,
        output_transform: Callable = lambda x: x,
        device: Union[str, torch.device] = torch.device("cpu"),
    ) -> None:
        if num_features is None and feature_extractor is None:
            # Fall back to the pretrained Inception classifier with 1000 output classes.
            num_features = 1000
            feature_extractor = InceptionModel(return_features=False, device=device)

        # Added inside every log() so that zero probabilities do not produce -inf.
        # Defined before the base constructor runs so it exists for anything that constructor may invoke.
        self._eps = 1e-16

        super(InceptionScore, self).__init__(
            num_features=num_features,
            feature_extractor=feature_extractor,
            output_transform=output_transform,
            device=device,
        )

    @reinit__is_reduced
    def reset(self) -> None:
        """Reset the example counter and the per-class float64 accumulators."""
        self._num_examples = 0

        # Per-class running sum of p(y|x) over every example seen since the last reset.
        self._prob_total = torch.zeros(self._num_features, dtype=torch.float64, device=self._device)
        # Per-class running sum of p(y|x) * log p(y|x) over every example seen since the last reset.
        self._total_kl_d = torch.zeros(self._num_features, dtype=torch.float64, device=self._device)

        super(InceptionScore, self).reset()  # type: ignore

    @reinit__is_reduced
    def update(self, output: torch.Tensor) -> None:
        """Accumulate class-probability statistics for one batch.

        Args:
            output: batch passed through the feature extractor; the extracted probabilities
                are expected to have shape ``(batch_size, num_features)``.
        """
        # Upcast explicitly instead of relying on the extractor's output dtype: in float16 the
        # 1e-16 epsilon underflows to zero, turning 0 * log(0) into NaN. float64 also matches
        # the dtype of the accumulators allocated in ``reset``.
        probabilities = self._extract_features(output).to(dtype=torch.float64)

        prob_sum = torch.sum(probabilities, 0)
        kl_sum = torch.sum(probabilities * torch.log(probabilities + self._eps), 0)

        self._num_examples += probabilities.shape[0]
        self._prob_total += prob_sum
        self._total_kl_d += kl_sum

    @sync_all_reduce("_num_examples", "_prob_total", "_total_kl_d")
    def compute(self) -> float:
        """Return the Inception Score over all examples accumulated since the last reset.

        Raises:
            NotComputableError: if no example has been passed to ``update`` yet.
        """
        if self._num_examples == 0:
            raise NotComputableError("InceptionScore must have at least one example before it can be computed.")

        # Marginal class distribution \hat{p}(y), averaged over all accumulated examples.
        mean_probs = self._prob_total / self._num_examples
        log_mean_probs = torch.log(mean_probs + self._eps)
        # sum_i KL(p_i || \hat{p}) = sum_y [sum_i p_iy log p_iy] - sum_y [sum_i p_iy] log \hat{p}_y
        excess_entropy = self._prob_total * log_mean_probs
        avg_kl_d = torch.sum(self._total_kl_d - excess_entropy) / self._num_examples

        return torch.exp(avg_kl_d).item()