File: plot_GMMOT_plan.py

package info (click to toggle)
python-pot 0.9.5%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 3,884 kB
  • sloc: python: 56,498; cpp: 2,310; makefile: 265; sh: 19
file content (100 lines) | stat: -rw-r--r-- 2,895 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# %%
# -*- coding: utf-8 -*-
r"""
====================================================
GMM Plan 1D
====================================================

Illustration of the GMM plan for
the Mixture Wasserstein distance between two GMMs in 1D,
as well as the two maps T_mean and T_rand.
T_mean is the barycentric projection of the GMM coupling,
and T_rand takes a random Gaussian image between two components,
according to the coupling and the GMMs.
See [69] for details.

.. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970.

"""

# Author: Eloi Tanguy <eloi.tanguy@u-paris>
#         Remi Flamary <remi.flamary@polytechnique.edu>
#         Julie Delon <julie.delon@math.cnrs.fr>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 1

import numpy as np
from ot.plot import plot1D_mat, rescale_for_imshow_plot
from ot.gmm import gmm_ot_plan_density, gmm_pdf, gmm_ot_apply_map
import matplotlib.pyplot as plt

##############################################################################
# Generate GMMOT plan and plot it
# -------------------------------
# Scalar settings. NOTE(review): none of these four names is read by the
# calls in this section; they are kept so the module namespace is unchanged.
ks, kt = 2, 3
d = 1
eps = 0.1

# Source mixture: 2 components in 1D -> means (2, 1), covariances (2, 1, 1),
# weights (2,).
m_s = np.array([[1], [2]])
C_s = np.array([[[0.05]], [[0.06]]])
w_s = np.array([0.4, 0.6])

# Target mixture: 3 components in 1D -> means (3, 1), covariances (3, 1, 1),
# weights (3,).
m_t = np.array([[3], [4.2], [5]])
C_t = np.array([[[0.03]], [[0.07]], [[0.04]]])
w_t = np.array([0.4, 0.2, 0.4])

# Evaluation grids: n evenly spaced points on [a_x, b_x] for the source axis
# and on [a_y, b_y] for the target axis.
n = 500
a_x = 0
b_x = 3
a_y = 2
b_y = 6
x = np.linspace(a_x, b_x, n)
y = np.linspace(a_y, b_y, n)

# The GMM helpers are called with column-shaped (n, 1) inputs.
x_col = x[:, None]
y_col = y[:, None]

# Plan density evaluated on the (x, y) grid, followed by the two mixture
# densities that are plotted as "Source" and "Target" distributions.
plan_density = gmm_ot_plan_density(
    x_col, y_col, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, atol=2e-2
)
a = gmm_pdf(x_col, m_s, C_s, w_s)
b = gmm_pdf(y_col, m_t, C_t, w_t)

# First figure: the plan together with both densities.
plt.figure(figsize=(8, 8))
axis_labels = {
    "a_label": "Source distribution",
    "b_label": "Target distribution",
}
plot1D_mat(a, b, plan_density, title="GMM OT plan", plot_style="xy", **axis_labels)


##############################################################################
# Generate GMMOT maps and plot them over plan
# -------------------------------------------
# Second figure: same plan plot as before; keep the returned axes so the maps
# can be drawn on the plan axis (ax_M).
plt.figure(figsize=(8, 8))
plan_axes = plot1D_mat(
    a,
    b,
    plan_density,
    title="GMM OT plan with T_mean and T_rand maps",
    plot_style="xy",
    a_label="Source distribution",
    b_label="Target distribution",
)
ax_s, ax_t, ax_M = plan_axes

# Arguments shared by both map evaluations.
gmm_params = (m_s, m_t, C_s, C_t, w_s, w_t)
x_column = x[:, None]

# T_mean: map obtained with method="bary"; column 0 is kept since data is 1D.
mean_image = gmm_ot_apply_map(x_column, *gmm_params, method="bary")
T_mean = mean_image[:, 0]
# NOTE(review): presumably converts data coordinates to the coordinates of the
# plan image drawn on ax_M -- confirm against ot.plot.rescale_for_imshow_plot.
x_rescaled, T_mean_rescaled = rescale_for_imshow_plot(
    x, T_mean, n, a_y=a_y, b_y=b_y
)
ax_M.plot(
    x_rescaled,
    T_mean_rescaled,
    color="aqua",
    linewidth=5,
    alpha=0.5,
    label="T_mean",
)

# T_rand: map obtained with method="rand"; seed=0 is passed explicitly.
rand_image = gmm_ot_apply_map(x_column, *gmm_params, method="rand", seed=0)
T_rand = rand_image[:, 0]
x_rescaled, T_rand_rescaled = rescale_for_imshow_plot(
    x, T_rand, n, a_y=a_y, b_y=b_y
)
ax_M.scatter(
    x_rescaled,
    T_rand_rescaled,
    color="orange",
    s=20,
    alpha=0.5,
    label="T_rand",
)

ax_M.legend(fontsize=13, loc="upper left")

# %%