1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151
|
"""Tests for gromov._lowrank.py"""
# Author: Laurène DAVID <laurene.david@ip-paris.fr>
#
# License: MIT License
import ot
import numpy as np
import pytest
def test__flat_product_operator():
    """Check ``_flat_product_operator`` against the element-wise squared cost.

    A low-rank factorization ``(A1, A2)`` of the cost between a deterministic
    point cloud and itself is built with
    ``ot.lowrank.compute_lr_sqeuclidean_matrix``. After each factor is passed
    through ``_flat_product_operator``, the product of the transformed factors
    must match ``ot.dist(X, X) ** 2``.
    """
    n_points, dim = 100, 2
    grid = np.arange(2 * n_points) * 1.0
    X = grid.reshape((n_points, dim))
    left, right = ot.lowrank.compute_lr_sqeuclidean_matrix(
        X, X, rescale_cost=False
    )

    flat_left = ot.gromov._lowrank._flat_product_operator(left)
    flat_right = ot.gromov._lowrank._flat_product_operator(right)
    expected = ot.dist(X, X) ** 2

    # the flattened factors must reproduce the squared cost
    reconstructed = np.dot(flat_left, flat_right.T)
    np.testing.assert_allclose(expected, reconstructed, atol=1e-05)
def test_lowrank_gromov_wasserstein_samples():
    """Validate low-rank GW outputs and its non-convergence warnings.

    Checks that:

    * the materialized ``log["lazy_plan"]`` satisfies both marginals;
    * it equals the dense plan rebuilt from the factors as ``Q diag(1/g) R^T``;
    * a ``UserWarning`` is emitted when the outer loop is given a zero
      tolerance and a single iteration with ``warn=True``;
    * a ``UserWarning`` is emitted when the Dykstra loop is given a zero
      tolerance and a single iteration with ``warn_dykstra=True``.
    """
    n_samples = 20  # nb samples
    X_s = ot.datasets.make_2D_samples_gauss(
        n_samples, np.array([0, 0]), np.array([[1, 0], [0, 1]]), random_state=1
    )
    X_t = X_s[::-1].copy()
    a, b = ot.unif(n_samples), ot.unif(n_samples)
    problem = (X_s, X_t, a, b)

    Q, R, g, log = ot.gromov.lowrank_gromov_wasserstein_samples(
        *problem, reg=0.1, log=True, rescale_cost=False
    )
    plan = log["lazy_plan"][:]

    # marginal constraints on the materialized plan
    np.testing.assert_allclose(a, plan.sum(1), atol=1e-04)
    np.testing.assert_allclose(b, plan.sum(0), atol=1e-04)

    # the lazy plan must agree with the dense reconstruction from the factors
    dense_plan = np.dot(Q, np.dot(np.diag(1 / g), R.T))
    np.testing.assert_allclose(plan, dense_plan, atol=1e-05)

    # force the outer low-rank GW loop to stop early: zero tolerance, 1 iteration
    with pytest.warns(UserWarning):
        ot.gromov.lowrank_gromov_wasserstein_samples(
            *problem,
            reg=0.1,
            stopThr=0,
            numItermax=1,
            warn=True,
            warn_dykstra=False,
        )

    # force the inner Dykstra loop to stop early: zero tolerance, 1 iteration
    with pytest.warns(UserWarning):
        ot.gromov.lowrank_gromov_wasserstein_samples(
            *problem,
            reg=0.1,
            stopThr_dykstra=0,
            numItermax_dykstra=1,
            warn=False,
            warn_dykstra=True,
        )
@pytest.mark.parametrize("alpha, rank", [(0.8, 2), (0.5, 3), (0.2, 6), (0.1, -1)])
def test_lowrank_gromov_wasserstein_samples_alpha_error(alpha, rank):
    """Invalid ``(alpha, rank)`` combinations must raise ``ValueError``.

    This is an exception check, not a warning check. The pairs appear to be
    chosen so that ``alpha`` exceeds ``1 / rank``, or ``rank`` is negative.
    That presumably mirrors the solver's validity constraint; confirm against
    ``lowrank_gromov_wasserstein_samples``.
    """
    n_samples = 20  # nb samples
    X_s = ot.datasets.make_2D_samples_gauss(
        n_samples, np.array([0, 0]), np.array([[1, 0], [0, 1]]), random_state=1
    )
    X_t = X_s[::-1].copy()
    a, b = ot.unif(n_samples), ot.unif(n_samples)

    with pytest.raises(ValueError):
        ot.gromov.lowrank_gromov_wasserstein_samples(
            X_s,
            X_t,
            a,
            b,
            reg=0.1,
            rank=rank,
            alpha=alpha,
            warn=False,
        )
@pytest.mark.parametrize("gamma_init", ["rescale", "theory", "other"])
def test_lowrank_wasserstein_samples_gamma_init(gamma_init):
    """Exercise the ``gamma_init`` strategies of low-rank Gromov-Wasserstein.

    The supported strategies ("rescale", "theory") must produce a plan whose
    materialized ``log["lazy_plan"]`` satisfies both marginals. Any other
    value must raise ``NotImplementedError``.
    """
    n_samples = 20  # nb samples
    X_s = ot.datasets.make_2D_samples_gauss(
        n_samples, np.array([0, 0]), np.array([[1, 0], [0, 1]]), random_state=1
    )
    X_t = X_s[::-1].copy()
    a, b = ot.unif(n_samples), ot.unif(n_samples)

    def run_solver():
        # identical call for both the supported and unsupported strategies
        return ot.gromov.lowrank_gromov_wasserstein_samples(
            X_s, X_t, a, b, reg=0.1, gamma_init=gamma_init, log=True
        )

    if gamma_init in ("rescale", "theory"):
        _, _, _, log = run_solver()
        plan = log["lazy_plan"][:]
        # marginal constraints on the materialized plan
        np.testing.assert_allclose(a, plan.sum(1), atol=1e-04)
        np.testing.assert_allclose(b, plan.sum(0), atol=1e-04)
    else:
        with pytest.raises(NotImplementedError):
            run_solver()
@pytest.skip_backend("tf")
def test_lowrank_gromov_wasserstein_samples_backends(nx):
    """Check that low-rank Gromov-Wasserstein runs on each backend.

    ``nx`` is the backend fixture, presumably supplied by the test suite's
    conftest. Inputs are generated with NumPy and converted with
    ``nx.from_numpy``. The materialized plan is converted back with
    ``nx.to_numpy`` before comparison, so the assertions do not rely on
    ``np.testing`` implicitly converting backend-native arrays. That implicit
    conversion fails, e.g., for device-resident arrays.
    """
    n_samples = 20  # nb samples
    mu_s = np.array([0, 0])
    cov_s = np.array([[1, 0], [0, 1]])
    X_s = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=1)
    X_t = X_s[::-1].copy()
    a = ot.unif(n_samples)
    b = ot.unif(n_samples)
    ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t)

    _, _, _, log = ot.gromov.lowrank_gromov_wasserstein_samples(
        X_sb, X_tb, ab, bb, reg=0.1, log=True
    )
    # materialize the lazy plan, then bring it back to NumPy for comparison
    P = nx.to_numpy(log["lazy_plan"][:])

    # check marginal constraints against the original NumPy weights
    np.testing.assert_allclose(a, P.sum(1), atol=1e-04)
    np.testing.assert_allclose(b, P.sum(0), atol=1e-04)
|