1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
|
from typing import Callable, Optional, Union
import torch
from ignite.exceptions import NotComputableError
from ignite.metrics.gan.utils import _BaseInceptionMetric, InceptionModel
# These decorators help with distributed settings
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce
# Public API of this module: the only name exported by a star-import.
__all__ = ["InceptionScore"]
class InceptionScore(_BaseInceptionMetric):
    r"""Calculates Inception Score.

    .. math::
       \text{IS(G)} = \exp(\frac{1}{N}\sum_{i=1}^{N} D_{KL} (p(y|x^{(i)}) \parallel \hat{p}(y)))

    where :math:`p(y|x^{(i)})` is the class distribution predicted for the i-th generated image,
    :math:`\hat{p}(y)` is the marginal class distribution (the mean of the predicted distributions over
    all :math:`N` images seen so far), `G` refers to the generator that produced the images and
    :math:`D_{KL}` refers to the KL Divergence between those two distributions.

    More details can be found in `Barratt et al. 2018`__.

    __ https://arxiv.org/pdf/1801.01973.pdf


    Args:
        num_features: number of features predicted by the model or number of classes of the model.
            Set to 1000 when neither ``num_features`` nor ``feature_extractor`` is given.
        feature_extractor: a torch Module for predicting the probabilities from the input data.
            It returns a tensor of shape (batch_size, num_features).
            If neither ``num_features`` nor ``feature_extractor`` are defined, by default we use an ImageNet
            pretrained Inception Model. If only ``num_features`` is defined but ``feature_extractor`` is not
            defined, ``feature_extractor`` is assigned Identity Function.
            Please note that the class object will be implicitly converted to device mentioned in the
            ``device`` argument.
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
            By default, metrics require the output as ``y_pred``.
        device: specifies which device updates are accumulated on. Setting the
            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
            non-blocking. By default, CPU.

    .. note::

        The default Inception model requires the `torchvision` module to be installed.

    Examples:

        For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. code-block:: python

            metric = InceptionScore()
            metric.attach(default_evaluator, "is")
            y = torch.rand(10, 3, 299, 299)
            state = default_evaluator.run([y])
            print(state.metrics["is"])

        .. testcode::

            metric = InceptionScore(num_features=1, feature_extractor=default_model)
            metric.attach(default_evaluator, "is")
            y = torch.zeros(10, 4)
            state = default_evaluator.run([y])
            print(state.metrics["is"])

        .. testoutput::

            1.0

    .. versionadded:: 0.4.6
    """

    # Attribute names that make up this metric's accumulated state.
    _state_dict_all_req_keys = ("_num_examples", "_prob_total", "_total_kl_d")

    def __init__(
        self,
        num_features: Optional[int] = None,
        feature_extractor: Optional[torch.nn.Module] = None,
        output_transform: Callable = lambda x: x,
        device: Union[str, torch.device] = torch.device("cpu"),
    ) -> None:
        """Configure the metric, falling back to the default Inception model when nothing is specified."""
        nothing_specified = num_features is None and feature_extractor is None
        if nothing_specified:
            # Default setup: class probabilities (not pooled features) from ``InceptionModel``,
            # paired with a feature count of 1000.
            feature_extractor = InceptionModel(return_features=False, device=device)
            num_features = 1000

        # Small constant added to probabilities before every ``torch.log`` call,
        # so a probability of exactly zero does not evaluate log(0).
        self._eps = 1e-16

        super().__init__(
            num_features=num_features,
            feature_extractor=feature_extractor,
            output_transform=output_transform,
            device=device,
        )

    @reinit__is_reduced
    def reset(self) -> None:
        """Clear the example counter and zero both float64 per-class accumulators."""
        self._num_examples = 0

        # Running per-class sum of predicted probabilities.
        self._prob_total = torch.zeros(self._num_features, dtype=torch.float64, device=self._device)
        # Running per-class sum of p * log(p); same shape, dtype and device as ``_prob_total``.
        self._total_kl_d = torch.zeros_like(self._prob_total)

        super().reset()  # type: ignore

    @reinit__is_reduced
    def update(self, output: torch.Tensor) -> None:
        """Fold one batch into the running sums.

        Args:
            output: batch passed to the inherited ``_extract_features``; the result is treated as
                per-example class probabilities with the batch along dimension 0.
        """
        probs = self._extract_features(output)

        log_probs = torch.log(probs + self._eps)
        if log_probs.dtype != probs.dtype:
            # Bring the logs back to the dtype/device of the probabilities before multiplying.
            log_probs = log_probs.to(probs)

        # Both batch reductions are finished before any state is touched and are carried out in float64.
        batch_prob_sum = probs.sum(dim=0, dtype=torch.float64)
        batch_kl_sum = (probs * log_probs).sum(dim=0, dtype=torch.float64)

        self._num_examples += probs.shape[0]
        self._prob_total += batch_prob_sum
        self._total_kl_d += batch_kl_sum

    @sync_all_reduce("_num_examples", "_prob_total", "_total_kl_d")
    def compute(self) -> float:
        """Return the Inception Score of everything accumulated so far.

        Returns:
            ``exp`` of the average KL divergence between each example's predicted distribution
            and the mean predicted distribution.

        Raises:
            NotComputableError: if no example has been accumulated yet.
        """
        num_examples = self._num_examples
        if num_examples == 0:
            raise NotComputableError("InceptionScore must have at least one example before it can be computed.")

        # Marginal class distribution: mean of all predicted distributions.
        marginal = self._prob_total / num_examples
        log_marginal = torch.log(marginal + self._eps)
        if log_marginal.dtype != self._prob_total.dtype:
            log_marginal = log_marginal.to(self._prob_total)

        # sum_i sum_y p_i(y) * log(marginal(y)) collapses to prob_total * log(marginal),
        # so the summed KL divergence is the difference of the two accumulators.
        summed_kl = (self._total_kl_d - self._prob_total * log_marginal).sum()
        mean_kl = summed_kl / num_examples

        return torch.exp(mean_kl).item()
|