File: test_dmmot.py

package info (click to toggle)
python-pot 0.9.5+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 3,884 kB
  • sloc: python: 56,498; cpp: 2,310; makefile: 265; sh: 19
file content (82 lines) | stat: -rw-r--r-- 2,399 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""Tests for ot.lp.dmmot module"""

# Author: Ronak Mehta <ronakrm@cs.wisc.edu>
#         Xizheng Yu <xyu354@wisc.edu>
#
# License: MIT License

import numpy as np
import ot


def create_test_data(nx):
    """Build the small two-histogram problem shared by the tests below.

    Two histograms are produced with ``ot.datasets.make_1D_gauss`` on a
    4-point grid (``m=20, s=5`` and ``m=60, s=8``). They are stacked as the
    columns of ``A``. The grid coordinates ``0, 1, 2, 3`` are returned as a
    float64 column ``x`` of shape ``(4, 1)``.

    Parameters
    ----------
    nx : backend
        Backend used to convert both numpy arrays via ``nx.from_numpy``.

    Returns
    -------
    A : backend array
        Histograms stacked column-wise (one column per histogram).
    x : backend array
        Grid support, shape ``(4, 1)``.
    """
    num_bins = 4
    histograms = [
        ot.datasets.make_1D_gauss(num_bins, m=20, s=5),
        ot.datasets.make_1D_gauss(num_bins, m=60, s=8),
    ]
    stacked = np.vstack(histograms).T
    support = np.arange(num_bins, dtype=np.float64).reshape((num_bins, 1))
    hist_matrix, grid = nx.from_numpy(stacked, support)
    return hist_matrix, grid


def test_dmmot_monge_1dgrid_loss(nx):
    """Check ``ot.lp.dmmot_monge_1dgrid_loss`` two ways.

    1. Against a stored reference value (regression check).
    2. Against the sum, over every input histogram, of the exact OT cost
       (``ot.lp.emd``) to a barycenter computed by ``ot.barycenter``. The
       ground metric for this check is cityblock (absolute difference) on
       the grid.
    """
    A, x = create_test_data(nx)

    # Compute loss using dmmot_monge_1dgrid_loss
    primal_obj = nx.to_numpy(ot.lp.dmmot_monge_1dgrid_loss(A))
    expected_primal_obj = 0.13667759626298503

    np.testing.assert_allclose(
        primal_obj,
        expected_primal_obj,
        rtol=1e-7,
        err_msg="Test failed: dmmot_monge_1dgrid_loss differs from the "
        "stored reference value",
    )

    # Cross-check with the exact OT solver and an absolute ground metric.
    A, x = nx.to_numpy(A, x)
    M = ot.utils.dist(x, metric="cityblock")  # absolute ground metric
    # Third positional argument (1e-2) is passed as-is to ot.barycenter;
    # presumably its regularization strength.
    bary, _ = ot.barycenter(A, M, 1e-2, weights=None, verbose=False, log=True)
    # The EMD target is identical for every histogram, so normalize it once
    # instead of on every loop iteration.
    target = np.array(bary / np.sum(bary))

    ot_obj = 0.0
    for hist in A.T:
        # A C-contiguity error was observed with the tensorflow backend.
        # Depending on how the backend round-trips A, its columns (rows of
        # A.T) may be strided views, so make a contiguous copy before
        # calling emd.
        hist = np.ascontiguousarray(hist)
        _, log = ot.lp.emd(hist, target, M, log=True)
        ot_obj += log["cost"]

    np.testing.assert_allclose(
        primal_obj,
        ot_obj,
        rtol=1e-7,
        err_msg="Test failed: dmmot_monge_1dgrid_loss disagrees with the "
        "summed EMD costs to the barycenter",
    )


def test_dmmot_monge_1dgrid_optimize(nx):
    """Regression test for ``ot.lp.dmmot_monge_1dgrid_optimize``.

    Runs 10 iterations with ``lr_init=1e-3`` and ``lr_decay=1`` on the data
    from ``create_test_data``. Checks that one vector is returned per
    reference row and that each matches the stored values to 1e-7.
    """
    A, _ = create_test_data(nx)
    niters = 10
    result = ot.lp.dmmot_monge_1dgrid_optimize(A, niters, lr_init=1e-3, lr_decay=1)

    # Stored reference output; presumably one row per input histogram
    # (column of A).
    expected_obj = np.array(
        [
            [0.05553516, 0.13082618, 0.27327479, 0.54036388],
            [0.04185365, 0.09570724, 0.24384705, 0.61859206],
        ]
    )
    d = expected_obj.shape[0]

    assert len(result) == d, f"Test failed: Expected a list of length {d}"
    # Lengths were checked above, so zip does not silently truncate.
    for i, (got, expected) in enumerate(zip(result, expected_obj)):
        np.testing.assert_allclose(
            got,
            expected,
            atol=1e-7,
            rtol=1e-7,
            err_msg=f"Test failed: optimized vector {i} differs from "
            "the stored reference",
        )