# -*- coding: utf-8 -*-
r"""
=================================
Wasserstein unmixing with PyTorch
=================================

In this example we estimate mixing parameters from distributions that minimize
the Wasserstein distance. In other words we suppose that a target
distribution :math:`\mu^t` can be expressed as a weighted sum of source
distributions :math:`\mu^s_k` with the following model:

.. math::

    \mu^t = \sum_{k=1}^K w_k\mu^s_k

where :math:`\mathbf{w}` is a vector of size :math:`K` that belongs to the
distribution simplex :math:`\Delta_K`.

In order to estimate this weight vector, we optimize the Wasserstein
distance between the model and the observed :math:`\mu^t` with respect to
:math:`\mathbf{w}`. This leads to the following optimization problem:

.. math::

    \min_{\mathbf{w}\in\Delta_K} \quad W\left(\mu^t,\sum_{k=1}^K w_k\mu^s_k\right)

This minimization is done in this example with a simple projected gradient
descent in PyTorch. We use the automatic backend of POT, which makes the
Wasserstein distance computed by :any:`ot.emd2` differentiable so that it can
be used directly as a loss.
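
Each iteration of the descent, as implemented in the loop below, performs a
gradient step followed by a projection back onto the simplex:

.. math::

    \mathbf{w} \leftarrow \mathrm{Proj}_{\Delta_K}\left(\mathbf{w} - \eta
    \nabla_\mathbf{w} W\left(\mu^t,\sum_{k=1}^K w_k\mu^s_k\right)\right)

where :math:`\eta` is the learning rate.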
"""
# Author: Remi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License
# sphinx_gallery_thumbnail_number = 2
import numpy as np
import matplotlib.pylab as pl
import ot
import torch
##############################################################################
# Generate data
# -------------
# %% Data
nt = 100  # number of target samples
nt1 = 10  # number of target samples drawn near the first mode
ns1 = 50  # number of samples in each source distribution
ns = 2 * ns1
rng = np.random.RandomState(2)

# target: nt1 samples near (1, 0) and nt - nt1 samples near (0, 1)
xt = rng.randn(nt, 2) * 0.2
xt[:nt1, 0] += 1
xt[nt1:, 1] += 1

# first source distribution, centered near (1, 0)
xs1 = rng.randn(ns1, 2) * 0.2
xs1[:, 0] += 1

# second source distribution, centered near (0, 1)
xs2 = rng.randn(ns1, 2) * 0.2
xs2[:, 1] += 1
xs = np.concatenate((xs1, xs2))
# Sample reweighting matrix H
H = np.zeros((ns, 2))
H[:ns1, 0] = 1 / ns1
H[ns1:, 1] = 1 / ns1
# each column sums to 1 and has nonzero weights only for samples from the
# corresponding source distribution
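# sanity check: the columns of H are valid distributions, so H @ w remains a
# probability vector for any w in the simplex
assert np.allclose(H.sum(axis=0), 1.0)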
M = ot.dist(xs, xt)  # cost matrix (squared Euclidean distances by default)
##############################################################################
# Plot data
# ---------
# %% plot the distributions
pl.figure(1)
pl.scatter(xt[:, 0], xt[:, 1], label=r"Target $\mu^t$", alpha=0.5)
pl.scatter(xs1[:, 0], xs1[:, 1], label=r"Source $\mu^s_1$", alpha=0.5)
pl.scatter(xs2[:, 0], xs2[:, 1], label=r"Source $\mu^s_2$", alpha=0.5)
pl.title("Sources and Target distributions")
pl.legend()
##############################################################################
# Optimization of the model with respect to the Wasserstein distance
# -------------------------------------------------------------------
# %% Weights optimization with gradient descent
# convert numpy arrays to torch tensors
H2 = torch.tensor(H)
M2 = torch.tensor(M)
# weights for the source distributions
w = torch.tensor(ot.unif(2), requires_grad=True)
# uniform weights for target
b = torch.tensor(ot.unif(nt))
lr = 2e-3 # learning rate
niter = 500 # number of iterations
losses = [] # loss along the iterations
# loss for the minimal Wasserstein estimator
def get_loss(w):
    a = torch.mv(H2, w)  # mix the source distributions: a = H w
    return ot.emd2(a, b, M2)  # squared Wasserstein 2 (exact OT loss)
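
# note: since a, b and M2 are torch tensors, ot.emd2 dispatches to the PyTorch
# backend and the returned loss is differentiable with respect to w
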
for i in range(niter):
    loss = get_loss(w)
    losses.append(float(loss))
    loss.backward()
    with torch.no_grad():
        w -= lr * w.grad  # gradient step
        w[:] = ot.utils.proj_simplex(w)  # projection onto the simplex
    w.grad.zero_()
##############################################################################
# Estimated weights and convergence of the objective
# --------------------------------------------------
we = w.detach().numpy()
print("Estimated mixture:", we)
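
# for reference, the ground-truth proportions follow from the data generation
# above: nt1 of the nt target samples were drawn near the first mode
print("Ground truth mixture:", [nt1 / nt, (nt - nt1) / nt])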
pl.figure(2)
pl.semilogy(losses)
pl.grid()
pl.title("Wasserstein distance")
pl.xlabel("Iterations")
##############################################################################
# Plotting the reweighted source distribution
# -------------------------------------------
pl.figure(3)
# per-sample weights of the fitted mixture (columns of H mixed by we)
ws = H.dot(we)
pl.scatter(xt[:, 0], xt[:, 1], label=r"Target $\mu^t$", alpha=0.5)
pl.scatter(
    xs[:, 0],
    xs[:, 1],
    color="C3",
    s=ws * 20 * ns,
    label=r"Weighted sources $\sum_{k} w_k\mu^s_k$",
    alpha=0.5,
)
pl.title("Target and reweighted source distributions")
pl.legend()
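
# show the figures when the script is run standalone (sphinx-gallery captures
# them automatically when building the documentation)
pl.show()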