1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132
|
# -*- coding: utf-8 -*-
"""
==================
OT distances in 1D
==================
Shows how to compute multiple Wasserstein and Sinkhorn distances with two
different ground metrics and plot their values for different distributions.
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
#
# License: MIT License
# sphinx_gallery_thumbnail_number = 2
import numpy as np
import matplotlib.pylab as pl
import ot
from ot.datasets import make_1D_gauss as gauss
##############################################################################
# Generate data
# -------------
# %% parameters
n = 100  # number of histogram bins
n_target = 20  # number of target distributions

# bin positions: 0, 1, ..., n - 1 as floats
x = np.arange(n, dtype=np.float64)

# means of the target Gaussians, evenly spaced between 20 and 90
lst_m = np.linspace(20, 90, n_target)

# source Gaussian (m = mean, s = std)
a = gauss(n, m=20, s=5)

# target Gaussians, stored one per column of B
B = np.zeros((n, n_target))
for col, mean in enumerate(lst_m):
    B[:, col] = gauss(n, m=mean, s=5)

# Ground-cost matrices between bin positions. Each one is divided by
# 0.1 times its own maximum. The two arguments of every ot.dist call are
# deliberately built as separate column-vector views of x.
M = ot.dist(x[:, np.newaxis], x[:, np.newaxis], "euclidean")
M /= 0.1 * M.max()
M2 = ot.dist(x[:, np.newaxis], x[:, np.newaxis], "sqeuclidean")
M2 /= 0.1 * M2.max()
##############################################################################
# Plot data
# ---------
# %% plot the distributions
pl.figure(1)

# Top panel: the source distribution, drawn in red.
pl.subplot(2, 1, 1)
pl.plot(x, a, "r", label="Source distribution")
pl.title("Source distribution")

# Bottom panel: every column of B in blue, with opacity growing with the
# column index (the first column is drawn with alpha 0).
pl.subplot(2, 1, 2)
for i, target in enumerate(B.T):
    pl.plot(x, target, "b", alpha=i / n_target)
# Redraw the last target at full opacity; it carries the label.
pl.plot(x, B[:, -1], "b", label="Target distributions")
pl.title("Target distributions")

pl.tight_layout()
##############################################################################
# Compute EMD for the different losses
# ------------------------------------
# %% Compute the OT losses and total variation, then plot them with the distributions
# Exact OT loss from the source to the targets under each ground cost.
# B is handed over whole, with one target histogram per column.
d_emd = ot.emd2(a, B, M)  # Euclidean ground cost
d_emd2 = ot.emd2(a, B, M2)  # squared Euclidean ground cost

# Sum of absolute differences between the source and every target, computed
# in one broadcast operation instead of a Python loop over the columns.
# The result is a 1-D array with one value per column of B.
d_tv = np.abs(a[:, np.newaxis] - B).sum(axis=0)

pl.figure(2)

# Top panel: source in red, targets in blue with increasing opacity.
pl.subplot(2, 1, 1)
pl.plot(x, a, "r", label="Source distribution")
pl.title("Distributions")
for i, target in enumerate(B.T):
    pl.plot(x, target, "b", alpha=i / n_target)
# Redraw the last target at full opacity; it carries the legend label.
pl.plot(x, B[:, -1], "b", label="Target distributions")
pl.ylim((-0.01, 0.13))
pl.xticks(())
pl.legend()

# Bottom panel: each divergence plotted against the target index.
pl.subplot(2, 1, 2)
pl.plot(d_emd, label="Euclidean OT")
pl.plot(d_emd2, label="Squared Euclidean OT")
pl.plot(d_tv, label="Total Variation (TV)")
pl.xlabel("Displacement")
pl.title("Divergences")
pl.legend()
##############################################################################
# Compute Sinkhorn for the different losses
# -----------------------------------------
# %%
# Regularization value passed to both Sinkhorn computations below.
reg = 1e-1
d_sinkhorn = ot.sinkhorn2(a, B, M, reg)  # ground cost M
d_sinkhorn2 = ot.sinkhorn2(a, B, M2, reg)  # ground cost M2

pl.figure(3)
pl.clf()

# Top panel: source in red, targets in blue with increasing opacity.
pl.subplot(2, 1, 1)
pl.plot(x, a, "r", label="Source distribution")
pl.title("Distributions")
for i, target in enumerate(B.T):
    pl.plot(x, target, "b", alpha=i / n_target)
# Redraw the last target at full opacity; it carries the legend label.
pl.plot(x, B[:, -1], "b", label="Target distributions")
pl.ylim((-0.01, 0.13))
pl.xticks(())
pl.legend()

# Bottom panel: exact OT losses as lines, Sinkhorn losses as "+" markers,
# and the total-variation curve last. Each entry is (positional args, label)
# and the entries are drawn in the order listed.
pl.subplot(2, 1, 2)
curves = (
    ((d_emd,), "Euclidean OT"),
    ((d_emd2,), "Squared Euclidean OT"),
    ((d_sinkhorn, "+"), "Euclidean Sinkhorn"),
    ((d_sinkhorn2, "+"), "Squared Euclidean Sinkhorn"),
    ((d_tv,), "Total Variation (TV)"),
)
for plot_args, curve_label in curves:
    pl.plot(*plot_args, label=curve_label)
pl.xlabel("Displacement")
pl.title("Divergences")
pl.legend()

pl.show()
|