File: plot_gaussian_barycenter.py

package info (click to toggle)
python-pot 0.9.5%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 3,884 kB
  • sloc: python: 56,498; cpp: 2,310; makefile: 265; sh: 19
file content (133 lines) | stat: -rw-r--r-- 3,015 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# -*- coding: utf-8 -*-
"""
========================================================
Gaussian Bures-Wasserstein barycenters
========================================================

Illustration of Gaussian Bures-Wasserstein barycenters.

"""

# Authors: Rémi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 2
# %%
from matplotlib import colors
from matplotlib.patches import Ellipse
import numpy as np
import matplotlib.pylab as pl
import ot


# %%
# Define Gaussian Covariances and distributions
# ---------------------------------------------

# Covariance matrices of the four input Gaussians: two with correlated
# coordinates (C1, C2) and two diagonal ones (C3, C4).
C1 = np.array(
    [
        [0.5, -0.4],
        [-0.4, 0.5],
    ]
)
C2 = np.array(
    [
        [1, 0.3],
        [0.3, 1],
    ]
)
C3 = np.diag([1.5, 0.5])
C4 = np.diag([0.5, 1.5])

# All covariances in a single array, one 2x2 matrix per entry of the first axis.
C = np.stack([C1, C2, C3, C4], axis=0)

# Means of the four Gaussians, placed on the corners of a square of side 4.
m1 = np.array((0, 0))
m2 = np.array((0, 4))
m3 = np.array((4, 0))
m4 = np.array((4, 4))

# All means in a single array, one row per Gaussian.
m = np.stack([m1, m2, m3, m4], axis=0)

# %%
# Plot the distributions
# ----------------------


def draw_cov(mu, C, color=None, label=None, nstd=1):
    """Draw a Gaussian as a filled, half-transparent ellipse on the current axes.

    The ellipse is centred on the mean and oriented along the eigenvectors of
    the covariance. Each semi-axis is ``nstd`` times the square root of the
    corresponding eigenvalue.

    Parameters
    ----------
    mu : sequence
        Mean of the Gaussian; only ``mu[0]`` and ``mu[1]`` are read.
    C : array-like
        Covariance matrix, decomposed with ``np.linalg.eigh``.
    color : matplotlib color, optional
        Used for both the face and the edge of the ellipse.
    label : str, optional
        Label forwarded to the ``Ellipse`` patch.
    nstd : float, optional
        Number of standard deviations covered by each semi-axis.
    """
    eigvals, eigvecs = np.linalg.eigh(C)

    # Reorder the decomposition so the largest eigenvalue comes first.
    descending = eigvals.argsort()[::-1]
    eigvals = eigvals[descending]
    eigvecs = eigvecs[:, descending]

    # Angle of the leading eigenvector; its components are fed to arctan2 in
    # (y, x) order, hence the reversed rows.
    angle_deg = np.degrees(np.arctan2(*eigvecs[::-1, 0]))

    # Full axis lengths of the ellipse.
    width, height = 2 * nstd * np.sqrt(eigvals)

    patch = Ellipse(
        xy=(mu[0], mu[1]),
        width=width,
        height=height,
        angle=angle_deg,
        facecolor=color,
        edgecolor=color,
        alpha=0.5,
        fill=True,
        label=label,
    )
    pl.gca().add_artist(patch)


# Axis limits [xmin, xmax, ymin, ymax] shared by all the panels.
axis = [-1.5, 5.5, -1.5, 5.5]

pl.figure(1, (8, 2))
pl.clf()

# One panel per input Gaussian, drawn with the default cycle colors "C0".."C3".
# The title template is a raw string so that the LaTeX backslashes (\mathcal,
# \Sigma) are not parsed as invalid Python escape sequences, which raise a
# SyntaxWarning on recent Python versions.
for k, (mean_k, cov_k) in enumerate(zip(m, C), start=1):
    pl.subplot(1, len(m), k)
    draw_cov(mean_k, cov_k, color="C%d" % (k - 1))
    pl.axis(axis)
    pl.title(r"$\mathcal{N}(m_%d,\Sigma_%d)$" % (k, k))

# %%
# Compute Bures-Wasserstein barycenters and plot them
# ---------------------------------------------------

# basis for bilinear interpolation: vk puts all the weight on the k-th Gaussian
v1 = np.array((1, 0, 0, 0))
v2 = np.array((0, 1, 0, 0))
v3 = np.array((0, 0, 1, 0))
v4 = np.array((0, 0, 0, 1))


# RGB colors of the four input Gaussians, one row per Gaussian. The array has
# a name of its own so that the imported ``matplotlib.colors`` module is not
# shadowed: rebinding ``colors`` makes this cell fail when it is run twice.
corner_colors = np.stack([colors.to_rgb("C%d" % k) for k in range(4)])

pl.figure(2, (8, 8))

# number of interpolation steps along each direction of the grid
nb_interp = 6

for i in range(nb_interp):
    tx = float(i) / (nb_interp - 1)

    # interpolation along the first direction only depends on i
    tmp1 = (1 - tx) * v1 + tx * v2
    tmp2 = (1 - tx) * v3 + tx * v4

    for j in range(nb_interp):
        ty = float(j) / (nb_interp - 1)

        # weights are constructed by bilinear interpolation
        weights = (1 - ty) * tmp1 + ty * tmp2

        # blend the four corner colors with the same weights
        color = np.dot(corner_colors.T, weights)

        # barycenter of the four Gaussians for these weights
        mb, Cb = ot.gaussian.bures_wasserstein_barycenter(m, C, weights)

        draw_cov(mb, Cb, color=color, nstd=0.3)

pl.axis(axis)
pl.axis("off")
pl.tight_layout()