# -*- coding: utf-8 -*-
"""
========================================
Low rank Sinkhorn
========================================
This example illustrates the computation of Low Rank Sinkhorn [65] and
compares it with entropic Sinkhorn.

[65] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
"Low-rank Sinkhorn factorization". In International Conference on Machine Learning.
"""
# Author: Laurène David <laurene.david@ip-paris.fr>
#
# License: MIT License
#
# sphinx_gallery_thumbnail_number = 2
import numpy as np
import matplotlib.pylab as pl
import ot.plot
from ot.datasets import make_1D_gauss as gauss
##############################################################################
# Generate data
# -------------
# %% parameters
n = 100
m = 120
# Source and target weights: normalized mixtures of two 1D Gaussians
a = gauss(n, m=int(n / 3), s=25 / np.sqrt(2)) + 1.5 * gauss(
    n, m=int(5 * n / 6), s=15 / np.sqrt(2)
)
a = a / np.sum(a)
b = 2 * gauss(m, m=int(m / 5), s=30 / np.sqrt(2)) + gauss(
    m, m=int(m / 2), s=35 / np.sqrt(2)
)
b = b / np.sum(b)
# Source and target sample positions (the 1D supports of a and b)
X = np.arange(n).reshape(-1, 1)
Y = np.arange(m).reshape(-1, 1)
##############################################################################
# Solve low rank Sinkhorn
# -----------------------
# %%
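# Solve low rank Sinkhorn with rank 10 directly from the samples X and Y
Q, R, g, log = ot.lowrank_sinkhorn(
    X,
    Y,
    a,
    b,
    rank=10,
    init="random",
    gamma_init="rescale",
    rescale_cost=True,
    warn=False,
    log=True,
)
# Materialize the dense n x m plan from its lazy low rank factorization
P = log["lazy_plan"][:]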
ot.plot.plot1D_mat(a, b, P, "OT matrix Low rank")
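# %%
# Low rank Sinkhorn [65] restricts the plan to the factored form
# ``P = Q diag(1/g) R^T``. Here ``Q`` (n x rank) has marginals ``a`` and ``g``,
# ``R`` (m x rank) has marginals ``b`` and ``g``, and ``g`` is a shared inner
# marginal of length ``rank``. The numpy sketch below only illustrates that
# structure and is not POT's solver. It builds hypothetical feasible factors
# by Sinkhorn scaling of random positive matrices, then checks that their
# product has marginals ``a_demo`` and ``b_demo`` and rank at most ``r_demo``.
# The helper ``scale_to_marginals`` and all ``*_demo`` names are hypothetical,
# and the uniform ``g_demo`` is an assumption.

import numpy as np


def scale_to_marginals(K, row, col, n_iter=1000):
    """Hypothetical helper: Sinkhorn scaling of a positive matrix K.

    Returns diag(u) K diag(v), whose column sums equal `col` and whose row
    sums approach `row` as n_iter grows.
    """
    u = np.ones(K.shape[0])
    v = np.ones(K.shape[1])
    for _ in range(n_iter):
        u = row / (K @ v)
        v = col / (K.T @ u)
    return u[:, None] * K * v[None, :]


rng_demo = np.random.default_rng(0)
n_demo, m_demo, r_demo = 30, 40, 5
a_demo = rng_demo.random(n_demo) + 0.1
a_demo /= a_demo.sum()
b_demo = rng_demo.random(m_demo) + 0.1
b_demo /= b_demo.sum()
g_demo = np.full(r_demo, 1.0 / r_demo)  # assumed uniform inner marginal

# Hypothetical feasible factors: Q has marginals (a, g), R has marginals (b, g)
Q_demo = scale_to_marginals(rng_demo.random((n_demo, r_demo)) + 0.1, a_demo, g_demo)
R_demo = scale_to_marginals(rng_demo.random((m_demo, r_demo)) + 0.1, b_demo, g_demo)

# Low rank plan P = Q diag(1/g) R^T (dividing by g scales the columns of Q)
P_demo = (Q_demo / g_demo) @ R_demo.T

print("row marginal error:", np.abs(P_demo.sum(axis=1) - a_demo).max())
print("column marginal error:", np.abs(P_demo.sum(axis=0) - b_demo).max())
print("rank:", np.linalg.matrix_rank(P_demo), "<=", r_demo)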
##############################################################################
# Sinkhorn vs low rank Sinkhorn
# -----------------------------
# Compare entropic Sinkhorn for several regularization strengths with low rank
# Sinkhorn for several ranks.
# %% Sinkhorn
# Squared Euclidean cost matrix (the ot.dist default), rescaled so its maximum is 1
M = ot.dist(X, Y)
M = M / np.max(M)
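# %%
# ``ot.dist`` uses the squared Euclidean metric by default, so for 1D
# positions ``M[i, j] = (X[i] - Y[j])**2``. Dividing by the maximum puts all
# costs in [0, 1], which means the regularization strength ``reg`` is measured
# relative to the largest transport cost. A minimal numpy sketch of the same
# computation uses broadcasting on hypothetical toy positions. ``x_demo``,
# ``y_demo`` and ``C_demo`` are illustrative names.

import numpy as np

x_demo = np.arange(5, dtype=float).reshape(-1, 1)  # 5 toy source positions
y_demo = np.arange(7, dtype=float).reshape(-1, 1)  # 7 toy target positions

# Squared Euclidean cost: C[i, j] = ||x_i - y_j||^2, summed over coordinates
C_demo = ((x_demo[:, None, :] - y_demo[None, :, :]) ** 2).sum(axis=-1)
# Rescale so the largest cost equals 1
C_demo = C_demo / C_demo.max()
print(C_demo.shape, C_demo.min(), C_demo.max())

# %%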
# Solve entropic OT (Sinkhorn) for each regularization strength with ot.solve
list_reg = [0.05, 0.005, 0.001]
list_P_Sin = []
for reg in list_reg:
    P = ot.solve(M, a, b, reg=reg, max_iter=2000, tol=1e-8).plan
    list_P_Sin.append(P)
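# %%
# For a fixed ``reg``, the entropic plan has the scaling form
# ``P = diag(u) K diag(v)`` with Gibbs kernel ``K = exp(-M / reg)``, so a
# smaller ``reg`` gives sharper, lower-entropy plans. The sketch below is our
# own log-domain Sinkhorn illustration, not the solver that ``ot.solve`` calls.
# Working in the log domain avoids the underflow of ``exp(-M / reg)`` at
# ``reg=0.001``. The sketch runs on a hypothetical uniform toy problem and
# prints the entropy of the plan for the same three ``reg`` values. All
# ``*_demo`` names are illustrative.

import numpy as np


def logsumexp_demo(Z, axis):
    """Numerically stable log(sum(exp(Z))) along `axis`."""
    Z_max = Z.max(axis=axis, keepdims=True)
    return np.log(np.exp(Z - Z_max).sum(axis=axis)) + np.squeeze(Z_max, axis=axis)


def sinkhorn_log_demo(mu, nu, C, eps, n_iter=5000):
    """Log-domain Sinkhorn: returns diag(u) K diag(v) with K = exp(-C / eps).

    The potentials f = eps * log(u) and g = eps * log(v) are updated in turn
    so that the plan matches the row marginal mu, then the column marginal nu.
    """
    f = np.zeros(len(mu))
    g = np.zeros(len(nu))
    for _ in range(n_iter):
        f = eps * (np.log(mu) - logsumexp_demo((g[None, :] - C) / eps, axis=1))
        g = eps * (np.log(nu) - logsumexp_demo((f[:, None] - C) / eps, axis=0))
    return np.exp((f[:, None] + g[None, :] - C) / eps)


# Hypothetical toy problem: uniform weights on two 1D grids, normalized cost
xs_demo = np.arange(20, dtype=float)
ys_demo = np.arange(25, dtype=float)
mu_demo = np.full(20, 1.0 / 20)
nu_demo = np.full(25, 1.0 / 25)
C_demo = (xs_demo[:, None] - ys_demo[None, :]) ** 2
C_demo = C_demo / C_demo.max()

plans_demo, entropies_demo = [], []
for eps_demo in [0.05, 0.005, 0.001]:
    plan_demo = sinkhorn_log_demo(mu_demo, nu_demo, C_demo, eps_demo)
    p_pos = plan_demo[plan_demo > 0]  # skip exact zeros (0 * log 0 = 0)
    plans_demo.append(plan_demo)
    entropies_demo.append(-np.sum(p_pos * np.log(p_pos)))
    print(f"reg={eps_demo}: entropy of plan = {entropies_demo[-1]:.3f}")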
# %% Low rank Sinkhorn
# Solve low rank Sinkhorn for each rank with ot.solve_sample
list_rank = [3, 10, 50]
list_P_LR = []
for rank in list_rank:
    P = ot.solve_sample(X, Y, a, b, method="lowrank", rank=rank).plan
    P = P[:]  # make sure P is a plain dense array for plotting
    list_P_LR.append(P)
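# %%
# Keeping the plan in factored form is one reason low rank Sinkhorn is
# attractive for large problems. Storing ``Q``, ``R`` and ``g`` takes
# ``(n + m + 1) * rank`` numbers instead of ``n * m``. A product with the plan
# can be evaluated factor by factor as ``P v = Q (diag(1/g) (R^T v))`` in
# ``O((n + m) * rank)`` operations. In this example the dense plans are built
# only for plotting. The numpy sketch below checks the factored product
# against the dense one using hypothetical random factors. All ``*_lr`` names
# are illustrative, and the factors need not be feasible couplings to
# illustrate the arithmetic.

import numpy as np

rng_lr = np.random.default_rng(1)
n_lr, m_lr, r_lr = 500, 800, 10

# Hypothetical positive factors standing in for Q (n x r), R (m x r), g (r,)
Q_lr = rng_lr.random((n_lr, r_lr))
R_lr = rng_lr.random((m_lr, r_lr))
g_lr = rng_lr.random(r_lr) + 0.5
v_lr = rng_lr.random(m_lr)

# Factored product: never forms the n x m matrix
Pv_factored = Q_lr @ ((R_lr.T @ v_lr) / g_lr)

# Dense reference: builds P = Q diag(1/g) R^T explicitly (n * m entries)
P_dense_lr = (Q_lr / g_lr) @ R_lr.T
Pv_dense = P_dense_lr @ v_lr

print("factored == dense:", np.allclose(Pv_factored, Pv_dense))
print("stored numbers, factored:", (n_lr + m_lr + 1) * r_lr, "dense:", n_lr * m_lr)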
# %%
# Plot Sinkhorn plans (top row) vs low rank Sinkhorn plans (bottom row)
pl.figure(figsize=(10, 8))
pl.subplot(2, 3, 1)
pl.imshow(list_P_Sin[0], interpolation="nearest")
pl.axis("off")
pl.title("Sinkhorn (reg=0.05)")
pl.subplot(2, 3, 2)
pl.imshow(list_P_Sin[1], interpolation="nearest")
pl.axis("off")
pl.title("Sinkhorn (reg=0.005)")
pl.subplot(2, 3, 3)
pl.imshow(list_P_Sin[2], interpolation="nearest")
pl.axis("off")
pl.title("Sinkhorn (reg=0.001)")
pl.subplot(2, 3, 4)
pl.imshow(list_P_LR[0], interpolation="nearest")
pl.axis("off")
pl.title("Low rank (rank=3)")
pl.subplot(2, 3, 5)
pl.imshow(list_P_LR[1], interpolation="nearest")
pl.axis("off")
pl.title("Low rank (rank=10)")
pl.subplot(2, 3, 6)
pl.imshow(list_P_LR[2], interpolation="nearest")
pl.axis("off")
pl.title("Low rank (rank=50)")
pl.tight_layout()
pl.show()