1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122
|
# -*- coding: utf-8 -*-
"""
===============================================================
Translation Invariant Sinkhorn for Unbalanced Optimal Transport
===============================================================
This example illustrates the faster convergence of the translation
invariant Sinkhorn algorithm proposed in [73], compared to the classical
Sinkhorn algorithm.
[73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022).
Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe.
In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.
"""
# Author: Clément Bonet <clement.bonet@ensae.fr>
# License: MIT License
import numpy as np
import matplotlib.pylab as pl
import ot
##############################################################################
# Setting parameters
# ------------------
# %% parameters
# Experiment size
n_iter = 50  # number of random repetitions (one dataset per seed), not solver iterations
n = 40  # Gaussian samples drawn per distribution, before outliers are appended
num_iter_max = 100  # iterations run by each Sinkhorn solver
n_noise = 10  # uniform outliers appended to each distribution

# Solver parameters
reg = 0.005  # regularization strength (``reg`` argument of sinkhorn_unbalanced)
reg_m_kl = 0.05  # marginal relaxation (``reg_m`` argument of sinkhorn_unbalanced)

# Source and target Gaussians: means and covariance matrices
mu_s = np.array([-1, -1])
cov_s = np.array([[1, 0], [0, 1]])
mu_t = np.array([4, 4])
cov_t = np.array([[1, -0.8], [-0.8, 1]])
##############################################################################
# Compute entropic KL-regularized UOT with Sinkhorn and Translation Invariant Sinkhorn
# ------------------------------------------------------------------------------------
# Run both solvers on ``n_iter`` independently seeded datasets and record, for
# each run, the per-iteration error the solver stores in its log under "err".
# Row ``seed`` of each error array holds the ``num_iter_max`` values of that run.
#
# ``n`` itself is never modified, so every repetition solves a problem of the
# same size (``n`` Gaussian samples + ``n_noise`` outliers per distribution)
# and averaging across seeds compares like with like.
n_total = n + n_noise  # points per distribution after appending the outliers
err_sinkhorn_uot = np.empty((n_iter, num_iter_max))
err_sinkhorn_uot_ti = np.empty((n_iter, num_iter_max))
for seed in range(n_iter):
    np.random.seed(seed)
    xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
    xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)

    # Append n_noise outliers drawn uniformly in the unit square, shifted by
    # -4 (source) and +6 (target).
    xs = np.concatenate((xs, np.random.rand(n_noise, 2) - 4), axis=0)
    xt = np.concatenate((xt, np.random.rand(n_noise, 2) + 6), axis=0)

    # Uniform weights on all n_total points of each distribution.
    a = np.ones((n_total,)) / n_total
    b = np.ones((n_total,)) / n_total

    # Loss matrix, rescaled so that its largest entry is 1.
    M = ot.dist(xs, xt)
    M /= M.max()

    # stopThr=0 is meant to disable early stopping, so that each log["err"]
    # holds exactly num_iter_max values as the row assignments below require.
    entropic_kl_uot, log_uot = ot.unbalanced.sinkhorn_unbalanced(
        a,
        b,
        M,
        reg,
        reg_m_kl,
        reg_type="kl",
        log=True,
        numItermax=num_iter_max,
        stopThr=0,
    )
    entropic_kl_uot_ti, log_uot_ti = ot.unbalanced.sinkhorn_unbalanced(
        a,
        b,
        M,
        reg,
        reg_m_kl,
        reg_type="kl",
        method="sinkhorn_translation_invariant",
        log=True,
        numItermax=num_iter_max,
        stopThr=0,
    )

    err_sinkhorn_uot[seed] = log_uot["err"]
    err_sinkhorn_uot_ti[seed] = log_uot_ti["err"]
##############################################################################
# Plot the results
# ----------------
# Summarise the error curves across repetitions: for every iteration index,
# take the mean and the standard deviation over the seeds (axis 0).
mean_sinkh, std_sinkh = err_sinkhorn_uot.mean(axis=0), err_sinkhorn_uot.std(axis=0)
mean_sinkh_ti = err_sinkhorn_uot_ti.mean(axis=0)
std_sinkh_ti = err_sinkhorn_uot_ti.std(axis=0)
absc = list(range(num_iter_max))  # iteration indices 0 .. num_iter_max - 1

# Each solver is drawn as its mean curve followed by a shaded band of
# mean +/- 2 standard deviations, Sinkhorn first.
# NOTE(review): mean - 2 * std may be <= 0, which a log-scaled axis cannot show.
for curve_label, curve_mean, curve_std in (
    ("Sinkhorn", mean_sinkh, std_sinkh),
    ("Translation Invariant Sinkhorn", mean_sinkh_ti, std_sinkh_ti),
):
    pl.plot(absc, curve_mean, label=curve_label)
    pl.fill_between(
        absc, curve_mean - 2 * curve_std, curve_mean + 2 * curve_std, alpha=0.5
    )

# Logarithmic y-axis, legend, axis labels and grid, then display the figure.
pl.yscale("log")
pl.legend()
pl.xlabel("Number of Iterations")
pl.ylabel(r"$\|u-v\|_\infty$")
pl.grid(True)
pl.show()
|