File: maximum_mean_discrepancy.py

package info (click to toggle)
pytorch-ignite 0.5.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 11,712 kB
  • sloc: python: 46,874; sh: 376; makefile: 27
file content (148 lines) | stat: -rw-r--r-- 6,452 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from typing import Callable, Sequence

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["MaximumMeanDiscrepancy"]


class MaximumMeanDiscrepancy(Metric):
    r"""Calculates the mean of `maximum mean discrepancy (MMD)
    <https://www.onurtunali.com/ml/2019/03/08/maximum-mean-discrepancy-in-machine-learning.html>`_.

    .. math::
       \begin{align*}
           \text{MMD}^2 (P,Q) &= \underset{\| f \| \leq 1}{\text{sup}} | \mathbb{E}_{X\sim P}[f(X)]
           - \mathbb{E}_{Y\sim Q}[f(Y)] |^2 \\
           &\approx \frac{1}{B(B-1)} \sum_{i=1}^B \sum_{\substack{j=1 \\ j\neq i}}^B k(\mathbf{x}_i,\mathbf{x}_j)
           -\frac{2}{B^2}\sum_{i=1}^B \sum_{j=1}^B k(\mathbf{x}_i,\mathbf{y}_j)
           + \frac{1}{B(B-1)} \sum_{i=1}^B \sum_{\substack{j=1 \\ j\neq i}}^B k(\mathbf{y}_i,\mathbf{y}_j)
       \end{align*}

    where :math:`B` is the batch size, and :math:`\mathbf{x}_i` and :math:`\mathbf{y}_j` are
    feature vectors sampled from :math:`P` and :math:`Q`, respectively.
    :math:`k(\mathbf{x},\mathbf{y})=\exp(-\| \mathbf{x}-\mathbf{y} \|^2/ 2\sigma^2)` is the Gaussian RBF kernel.

    This metric computes the MMD for each batch and takes the average.

    More details can be found in `Gretton et al. 2012`__.

    __ https://www.jmlr.org/papers/volume13/gretton12a/gretton12a.pdf

    - ``update`` must receive output of the form ``(x, y)``.
    - ``x`` and ``y`` are expected to be in the same shape :math:`(B, \ldots)`.

    Args:
        var: the bandwidth :math:`\sigma^2` of the kernel. Default: 1.0
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
            By default, this metric requires the output as ``(x, y)``.
        device: specifies which device updates are accumulated on. Setting the
            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
            non-blocking. By default, CPU.
        skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be
            true for multi-output model, for example, if ``y_pred`` contains multi-ouput as ``(y_pred_a, y_pred_b)``
            Alternatively, ``output_transform`` can be used to handle this.

    Examples:
        To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
        The output of the engine's ``process_function`` needs to be in the format of
        ``(x, y)``. If not, ``output_tranform`` can be added
        to the metric to transform the output into the form expected by the metric.

        For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode::

            metric = MaximumMeanDiscrepancy()
            metric.attach(default_evaluator, "mmd")
            x = torch.tensor([[-0.80324818, -0.95768364, -0.03807209],
                            [-0.11059691, -0.38230813, -0.4111988],
                            [-0.8864329, -0.02890403, -0.60119252],
                            [-0.68732452, -0.12854739, -0.72095073],
                            [-0.62604613, -0.52368328, -0.24112842]])
            y = torch.tensor([[0.0686768, 0.80502737, 0.53321717],
                            [0.83849465, 0.59099726, 0.76385441],
                            [0.68688272, 0.56833803, 0.98100778],
                            [0.55267761, 0.13084654, 0.45382906],
                            [0.0754253, 0.70317304, 0.4756805]])
            state = default_evaluator.run([[x, y]])
            print(state.metrics["mmd"])

        .. testoutput::

           1.072697639465332

    .. versionchanged:: 0.5.1
        ``skip_unrolling`` argument is added.
    """

    _state_dict_all_req_keys = ("_xx_sum", "_yy_sum", "_xy_sum", "_num_batches")

    def __init__(
        self,
        var: float = 1.0,
        output_transform: Callable = lambda x: x,
        device: torch.device = torch.device("cpu"),
        skip_unrolling: bool = False,
    ):
        self.var = var
        super().__init__(output_transform, device, skip_unrolling=skip_unrolling)

    @reinit__is_reduced
    def reset(self) -> None:
        self._xx_sum = torch.tensor(0.0, device=self._device)
        self._yy_sum = torch.tensor(0.0, device=self._device)
        self._xy_sum = torch.tensor(0.0, device=self._device)
        self._num_batches = 0

    @reinit__is_reduced
    def update(self, output: Sequence[torch.Tensor]) -> None:
        x, y = output[0].detach(), output[1].detach()
        if x.shape != y.shape:
            raise ValueError(f"x and y must be in the same shape, got {x.shape} != {y.shape}.")

        if x.ndim >= 3:
            x = x.flatten(start_dim=1)
            y = y.flatten(start_dim=1)
        elif x.ndim == 1:
            raise ValueError(f"x must be in the shape of (B, ...), got {x.shape}.")

        xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())
        rx = xx.diag().unsqueeze(0).expand_as(xx)
        ry = yy.diag().unsqueeze(0).expand_as(yy)

        dxx = rx.t() + rx - 2.0 * xx
        dyy = ry.t() + ry - 2.0 * yy
        dxy = rx.t() + ry - 2.0 * zz

        v = self.var
        XX = torch.exp(-0.5 * dxx / v)
        YY = torch.exp(-0.5 * dyy / v)
        XY = torch.exp(-0.5 * dxy / v)

        # unbiased
        n = x.shape[0]
        XX = (XX.sum() - n) / (n * (n - 1))
        YY = (YY.sum() - n) / (n * (n - 1))
        XY = XY.sum() / (n * n)

        self._xx_sum += XX.to(self._device)
        self._yy_sum += YY.to(self._device)
        self._xy_sum += XY.to(self._device)

        self._num_batches += 1

    @sync_all_reduce("_xx_sum", "_yy_sum", "_xy_sum", "_num_batches")
    def compute(self) -> float:
        if self._num_batches == 0:
            raise NotComputableError("MaximumMeanDiscrepacy must have at least one batch before it can be computed.")
        mmd2 = (self._xx_sum + self._yy_sum - 2.0 * self._xy_sum).clamp(min=0.0) / self._num_batches
        return mmd2.sqrt().item()