1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169
|
# -*- coding: utf-8 -*-
"""
=======================================
Generalized Wasserstein Barycenter Demo
=======================================

This example illustrates the computation of the Generalized Wasserstein
Barycenter as proposed in [42].

[42] Delon, J., Gozlan, N., and Saint-Dizier, A.
Generalized Wasserstein barycenters between probability measures living on different subspaces.
arXiv preprint arXiv:2105.09755, 2021.
"""
# Author: Eloi Tanguy <eloi.tanguy@polytechnique.edu>
#
# License: MIT License
# sphinx_gallery_thumbnail_number = 2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.pylab as pl
import ot
import matplotlib.animation as animation
########################
# Generate and plot data
# ----------------------

# Input measures: every ``sub_sample_factor``-th pixel of the third channel
# (index 2) of each image. The tooth and heart images are read with a
# negative row step, i.e. flipped vertically.
sub_sample_factor = 8
I1 = plt.imread("../../data/redcross.png").astype(np.float64)[
    ::sub_sample_factor, ::sub_sample_factor, 2
]
I2 = plt.imread("../../data/tooth.png").astype(np.float64)[
    ::-sub_sample_factor, ::sub_sample_factor, 2
]
I3 = plt.imread("../../data/heart.png").astype(np.float64)[
    ::-sub_sample_factor, ::sub_sample_factor, 2
]


def _support_points(im):
    """Return the float (column, row) coordinates of the pixels of ``im`` equal to 0.

    The coordinates are derived from ``im``'s own shape, so the images need
    neither be square nor share a common size. Points are listed in
    row-major pixel order.
    """
    rows, cols = np.nonzero(im == 0)
    return np.stack((cols, rows), axis=1) * 1.0


# Input measure locations in their respective 2D spaces
X_list = [_support_points(im) for im in [I1, I2, I3]]
# Input measure weights
a_list = [ot.unif(x.shape[0]) for x in X_list]

# Projections 3D -> 2D: each P_i keeps two of the three coordinates
# (P1: x and y, P2: y and z, P3: x and z).
P1 = np.array([[1, 0, 0], [0, 1, 0]])
P2 = np.array([[0, 1, 0], [0, 0, 1]])
P3 = np.array([[1, 0, 0], [0, 0, 1]])
P_list = [P1, P2, P3]

# Barycenter weights, one per input measure
weights = np.array([1 / 3, 1 / 3, 1 / 3])

# Number of barycenter points to compute
n_samples_bary = 150

# Lift each 2D measure into 3D for visualization: a row x becomes x @ P_i
# (i.e. P_i^T x), placing it in the coordinate plane that P_i projects onto.
X_visu = [Xi @ Pi for (Xi, Pi) in zip(X_list, P_list)]

# Plot the input data
fig = plt.figure(figsize=(3, 3))
axis = fig.add_subplot(1, 1, 1, projection="3d")
for Xi in X_visu:
    axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker="o", alpha=0.6)
axis.view_init(azim=45)
axis.set_xticks([])
axis.set_yticks([])
axis.set_zticks([])
plt.show()
#################################
# Barycenter computation and plot
# -------------------------------

# ``weights`` holds one coefficient per input measure. It is passed
# explicitly so that changing it above actually changes the barycenter.
# ``Y`` holds the barycenter support points in 3D (columns 0-2 are plotted).
Y = ot.lp.generalized_free_support_barycenter(
    X_list, a_list, P_list, n_samples_bary, weights=weights
)

# Input measures (lifted to 3D) together with the computed barycenter.
fig = plt.figure(figsize=(3, 3))
axis = fig.add_subplot(1, 1, 1, projection="3d")
for Xi in X_visu:
    axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker="o", alpha=0.6)
axis.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker="o", alpha=0.6)
axis.view_init(azim=45)
axis.set_xticks([])
axis.set_yticks([])
axis.set_zticks([])
plt.show()
#############################
# Plotting projection matches
# ---------------------------

# The same 3D scene (input measures + barycenter) seen from three camera
# angles given as (elev, azim): two side views 90 degrees apart and one
# top view, so the barycenter can be compared with each input measure.
fig = plt.figure(figsize=(9, 3))
for k, (elev, azim) in enumerate([(0, 0), (0, 90), (90, 0)], start=1):
    ax = fig.add_subplot(1, 3, k, projection="3d")
    for Xi in X_visu:
        ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker="o", alpha=0.6)
    ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker="o", alpha=0.6)
    ax.view_init(elev=elev, azim=azim)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
plt.tight_layout()
plt.show()
##############################################
# Rotation animation
# --------------------------------------------

fig = plt.figure(figsize=(7, 7))
ax = fig.add_subplot(1, 1, 1, projection="3d")

# Draw the static content once; the animation only moves the camera.
# Matplotlib may call the init function several times (e.g. on every repeat
# or when saving), so scattering there would stack duplicate point clouds
# with new colors from the property cycle.
for Xi in X_visu:
    ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker="o", alpha=0.6)
ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker="o", alpha=0.6)
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])


def _init():
    """Reset the camera to the starting viewpoint (elev=0, azim=0)."""
    ax.view_init(elev=0, azim=0)
    return (fig,)


def _update_plot(i):
    """Move the camera for frame ``i``.

    The azimuth advances 4 degrees per frame. For the first 45 frames the
    elevation stays at 0; afterwards it rises by one degree per frame,
    reaching 90 on the last frame (i = 135).
    """
    elev = 0 if i < 45 else i - 45
    ax.view_init(elev=elev, azim=4 * i)
    return (fig,)


ani = animation.FuncAnimation(
    fig,
    _update_plot,
    init_func=_init,
    frames=136,
    interval=50,
    # Every frame changes the whole 3D view, so blitting saves nothing.
    # Blitting also expects the callbacks to return the Axes artists to
    # redraw, not the Figure, so it is disabled.
    blit=False,
    repeat_delay=2000,
)
|