File: plot_otda_mapping.py

package info (click to toggle)
python-pot 0.9.5%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 3,884 kB
  • sloc: python: 56,498; cpp: 2,310; makefile: 265; sh: 19
file content (142 lines) | stat: -rw-r--r-- 4,108 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# -*- coding: utf-8 -*-
"""
===========================================
OT mapping estimation for domain adaptation
===========================================

This example shows how to use MappingTransport to jointly estimate the
transport coupling and an approximation of the transport map, using either
a linear or a kernelized mapping, as introduced in [8].

[8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
"Mapping estimation for discrete optimal transport",
Neural Information Processing Systems (NIPS), 2016.
"""

# Authors: Remi Flamary <remi.flamary@unice.fr>
#          Stanislas Chambon <stan.chambon@gmail.com>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 2

import numpy as np
import matplotlib.pylab as pl
import ot


##############################################################################
# Generate data
# -------------

# Sizes of the two domains and parameters of the synthetic problem.
n_source_samples = 100
n_target_samples = 100
theta = 2 * np.pi / 20  # rotation angle passed to the target generator (18 degrees)
noise_level = 0.1

# Keyword arguments shared by every call to the "gaussrot" generator.
gaussrot_options = {"nz": noise_level}

# Source samples used for fitting, followed by a second, independent draw of
# the same size that plays the role of unseen ("out of sample") source data.
# The call order is kept as-is: no random state is passed, so reordering the
# draws could change the generated data.
Xs, ys = ot.datasets.make_data_classif(
    "gaussrot", n_source_samples, **gaussrot_options
)
Xs_new, _ = ot.datasets.make_data_classif(
    "gaussrot", n_source_samples, **gaussrot_options
)

# Target samples: same generator, with the rotation angle theta applied.
Xt, yt = ot.datasets.make_data_classif(
    "gaussrot", n_target_samples, theta=theta, **gaussrot_options
)

# One of the target modes changes its variance (so no linear mapping fits):
# the samples labelled 2 are multiplied by 3, then all targets are shifted by 4.
second_mode = yt == 2
Xt[second_mode] = 3 * Xt[second_mode]
Xt = Xt + 4

##############################################################################
# Plot data
# ---------

# Figure 1: raw source ("+") and target ("o") samples, coloured by class label.
pl.figure(1, figsize=(10, 5))
pl.clf()
for samples, labels, marker, legend_entry in (
    (Xs, ys, "+", "Source samples"),
    (Xt, yt, "o", "Target samples"),
):
    pl.scatter(
        samples[:, 0], samples[:, 1], c=labels, marker=marker, label=legend_entry
    )
pl.legend(loc=0)
pl.title("Source and target distributions")


##############################################################################
# Instantiate the different transport algorithms and fit them
# -----------------------------------------------------------

# --- MappingTransport with a linear kernel ---------------------------------
linear_settings = dict(
    kernel="linear", mu=1e0, eta=1e-8, bias=True, max_iter=20, verbose=True
)
ot_mapping_linear = ot.da.MappingTransport(**linear_settings)
ot_mapping_linear.fit(Xs=Xs, Xt=Xt)

# Samples that were used for fitting are moved by the barycentric mapping ...
transp_Xs_linear = ot_mapping_linear.transform(Xs=Xs)

# ... whereas out-of-sample source points go through the learned linear map.
transp_Xs_linear_new = ot_mapping_linear.transform(Xs=Xs_new)


# --- MappingTransport with a gaussian kernel -------------------------------
gaussian_settings = dict(
    kernel="gaussian",
    eta=1e-5,
    mu=1e-1,
    bias=True,
    sigma=1,
    max_iter=10,
    verbose=True,
)
ot_mapping_gaussian = ot.da.MappingTransport(**gaussian_settings)
ot_mapping_gaussian.fit(Xs=Xs, Xt=Xt)

# Samples that were used for fitting are moved by the barycentric mapping ...
transp_Xs_gaussian = ot_mapping_gaussian.transform(Xs=Xs)

# ... whereas out-of-sample source points go through the learned gaussian map.
transp_Xs_gaussian_new = ot_mapping_gaussian.transform(Xs=Xs_new)


##############################################################################
# Plot transported samples
# ------------------------

# Figure 2: a 2x2 grid. Each panel overlays one set of transported source
# samples ("+") on the faded target samples ("o").
pl.figure(2)
pl.clf()

# (transported samples, legend label, panel title), in subplot order.
panels = (
    (transp_Xs_linear, "Mapped source samples", "Bary. mapping (linear)"),
    (transp_Xs_linear_new, "Learned mapping", "Estim. mapping (linear)"),
    (transp_Xs_gaussian, "barycentric mapping", "Bary. mapping (kernel)"),
    (transp_Xs_gaussian_new, "Learned mapping", "Estim. mapping (kernel)"),
)

for position, (mapped, mapped_label, panel_title) in enumerate(panels, start=1):
    pl.subplot(2, 2, position)
    pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.2)
    # NOTE(review): ys is also used to colour the out-of-sample points, whose
    # own labels were discarded at generation time -- this presumably relies on
    # the generator returning the same label ordering for equal sample counts;
    # confirm.
    pl.scatter(mapped[:, 0], mapped[:, 1], c=ys, marker="+", label=mapped_label)
    pl.title(panel_title)
    if position == 1:
        # Only the first panel displays a legend.
        pl.legend(loc=0)

pl.tight_layout()

pl.show()