1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
|
"""
Factored OT solvers (low rank, cost or OT plan)
"""
# Author: Remi Flamary <remi.flamary@polytehnique.edu>
#
# License: MIT License
from .backend import get_backend
from .utils import dist
from .lp import emd
from .bregman import sinkhorn
__all__ = ['factored_optimal_transport']
def factored_optimal_transport(Xa, Xb, a=None, b=None, reg=0.0, r=100, X0=None, stopThr=1e-7, numItermax=100, verbose=False, log=False, **kwargs):
r"""Solves factored OT problem and return OT plans and intermediate distribution
This function solve the following OT problem [40]_
.. math::
\mathop{\arg \min}_\mu \quad W_2^2(\mu_a,\mu)+ W_2^2(\mu,\mu_b)
where :
- :math:`\mu_a` and :math:`\mu_b` are empirical distributions.
- :math:`\mu` is an empirical distribution with r samples
And returns the two OT plans between
.. note:: This function is backend-compatible and will work on arrays
from all compatible backends. But the algorithm uses the C++ CPU backend
which can lead to copy overhead on GPU arrays.
Uses the conditional gradient algorithm to solve the problem proposed in
:ref:`[39] <references-weak>`.
Parameters
----------
Xa : (ns,d) array-like, float
Source samples
Xb : (nt,d) array-like, float
Target samples
a : (ns,) array-like, float
Source histogram (uniform weight if empty list)
b : (nt,) array-like, float
Target histogram (uniform weight if empty list))
numItermax : int, optional
Max number of iterations
stopThr : float, optional
Stop threshold on the relative variation (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
Returns
-------
Ga: array-like, shape (ns, r)
Optimal transportation matrix between source and the intermediate
distribution
Gb: array-like, shape (r, nt)
Optimal transportation matrix between the intermediate and target
distribution
X: array-like, shape (r, d)
Support of the intermediate distribution
log: dict, optional
If input log is true, a dictionary containing the cost and dual
variables and exit status
.. _references-factored:
References
----------
.. [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger,
G., & Weed, J. (2019, April). Statistical optimal transport via factored
couplings. In The 22nd International Conference on Artificial
Intelligence and Statistics (pp. 2454-2465). PMLR.
See Also
--------
ot.bregman.sinkhorn : Entropic regularized OT ot.optim.cg : General
regularized OT
"""
nx = get_backend(Xa, Xb)
n_a = Xa.shape[0]
n_b = Xb.shape[0]
d = Xa.shape[1]
if a is None:
a = nx.ones((n_a), type_as=Xa) / n_a
if b is None:
b = nx.ones((n_b), type_as=Xb) / n_b
if X0 is None:
X = nx.randn(r, d, type_as=Xa)
else:
X = X0
w = nx.ones(r, type_as=Xa) / r
def solve_ot(X1, X2, w1, w2):
M = dist(X1, X2)
if reg > 0:
G, log = sinkhorn(w1, w2, M, reg, log=True, **kwargs)
log['cost'] = nx.sum(G * M)
return G, log
else:
return emd(w1, w2, M, log=True, **kwargs)
norm_delta = []
# solve the barycenter
for i in range(numItermax):
old_X = X
# solve OT with template
Ga, loga = solve_ot(Xa, X, a, w)
Gb, logb = solve_ot(X, Xb, w, b)
X = 0.5 * (nx.dot(Ga.T, Xa) + nx.dot(Gb, Xb)) * r
delta = nx.norm(X - old_X)
if delta < stopThr:
break
if log:
norm_delta.append(delta)
if log:
log_dic = {'delta_iter': norm_delta,
'ua': loga['u'],
'va': loga['v'],
'ub': logb['u'],
'vb': logb['v'],
'costa': loga['cost'],
'costb': logb['cost'],
}
return Ga, Gb, X, log_dic
return Ga, Gb, X
|