# -*- coding: utf-8 -*-
"""
Dictionary Learning based on Bregman projections for entropic regularized OT
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
# Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License
import warnings
from ..utils import list_to_array
from ..backend import get_backend
from ._utils import projC, projR
def unmix(
a,
D,
M,
M0,
h0,
reg,
reg0,
alpha,
numItermax=1000,
stopThr=1e-3,
verbose=False,
log=False,
warn=True,
):
r"""
    Compute the unmixing of an observation with a given dictionary using Wasserstein distance.

    The function solves the following optimization problem:

.. math::
\mathbf{h} = \mathop{\arg \min}_\mathbf{h} \quad
(1 - \alpha) W_{\mathbf{M}, \mathrm{reg}}(\mathbf{a}, \mathbf{Dh}) +
\alpha W_{\mathbf{M_0}, \mathrm{reg}_0}(\mathbf{h}_0, \mathbf{h})
    where:
- :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance
with :math:`\mathbf{M}` loss matrix (see :py:func:`ot.bregman.sinkhorn`)
- :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`,
its expected shape is `(dim_a, n_atoms)`
- :math:`\mathbf{h}` is the estimated unmixing of dimension `n_atoms`
- :math:`\mathbf{a}` is an observed distribution of dimension `dim_a`
- :math:`\mathbf{h}_0` is a prior on :math:`\mathbf{h}` of dimension `dim_prior`
- `reg` and :math:`\mathbf{M}` are respectively the regularization term and the
cost matrix (`dim_a`, `dim_a`) for OT data fitting
    - `reg`:math:`_0` and :math:`\mathbf{M_0}` are respectively the regularization
      term and the cost matrix (`n_atoms`, `dim_prior`) for the regularization with
      the prior :math:`\mathbf{h}_0`
    - :math:`\alpha` weights the data fitting and the regularization terms
    The optimization problem is solved following the algorithm described
    in :ref:`[4] <references-unmix>`.
Parameters
----------
    a : array-like, shape (dim_a,)
observed distribution (histogram, sums to 1)
D : array-like, shape (dim_a, n_atoms)
dictionary matrix
M : array-like, shape (dim_a, dim_a)
loss matrix
M0 : array-like, shape (n_atoms, dim_prior)
loss matrix
h0 : array-like, shape (n_atoms,)
prior on the estimated unmixing h
reg : float
Regularization term >0 (Wasserstein data fitting)
reg0 : float
Regularization term >0 (Wasserstein reg with h0)
alpha : float
How much should we trust the prior ([0,1])
numItermax : int, optional
Max number of iterations
stopThr : float, optional
Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.
Returns
-------
    h : array-like, shape (n_atoms,)
        estimated unmixing of :math:`\mathbf{a}` over the dictionary :math:`\mathbf{D}`
    log : dict
        log dictionary returned only if `log` is True in parameters
.. _references-unmix:
References
----------
.. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti,
Supervised planetary unmixing with optimal transport, Workshop
on Hyperspectral Image and Signal Processing :
Evolution in Remote Sensing (WHISPERS), 2016.
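
    Examples
    --------
    A minimal sketch on toy data. The arrays below are illustrative
    assumptions (identity dictionary, 0/1 cost matrices, uniform prior),
    not values taken from the reference:

    >>> import numpy as np
    >>> import ot
    >>> a = np.array([0.4, 0.6])  # observed histogram
    >>> D = np.array([[1.0, 0.0], [0.0, 1.0]])  # two atoms of dimension 2
    >>> M = np.array([[0.0, 1.0], [1.0, 0.0]])  # data-fitting cost matrix
    >>> M0 = np.array([[0.0, 1.0], [1.0, 0.0]])  # cost matrix for the prior
    >>> h0 = np.array([0.5, 0.5])  # uniform prior on the unmixing
    >>> h = ot.bregman.unmix(a, D, M, M0, h0, reg=1e-1, reg0=1e-1, alpha=0.5)
    >>> h.shape
    (2,)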
"""
a, D, M, M0, h0 = list_to_array(a, D, M, M0, h0)
nx = get_backend(a, D, M, M0, h0)
    # Gibbs kernel for the data-fitting OT term
K = nx.exp(-M / reg)
    # Gibbs kernel for the regularization term with the prior h0
K0 = nx.exp(-M0 / reg0)
old = h0
err = 1
if log:
log = {"err": []}
for ii in range(numItermax):
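        # KL projections of the plans onto the column marginals a and h0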
K = projC(K, a)
K0 = projC(K0, h0)
new = nx.sum(K0, axis=1)
# we recombine the current selection from dictionary
inv_new = nx.dot(D, new)
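        # row marginal of the data-fitting plan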
other = nx.sum(K, axis=1)
# geometric interpolation
delta = nx.exp(alpha * nx.log(other) + (1 - alpha) * nx.log(inv_new))
K = projR(K, delta)
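        # pull the correction delta / inv_new back through the dictionary to rescale the rows of K0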
K0 = nx.dot(D.T, delta / inv_new)[:, None] * K0
err = nx.norm(nx.sum(K0, axis=1) - old)
old = new
if log:
log["err"].append(err)
if verbose:
if ii % 200 == 0:
print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
print("{:5d}|{:8e}|".format(ii, err))
if err < stopThr:
break
else:
if warn:
warnings.warn(
"Unmixing algorithm did not converge. You might want to "
"increase the number of iterations `numItermax` "
"or the regularization parameter `reg`."
)
if log:
log["niter"] = ii
return nx.sum(K0, axis=1), log
else:
return nx.sum(K0, axis=1)