File: plot_otda_laplacian.py

package info (click to toggle)
python-pot 0.9.5%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 3,884 kB
  • sloc: python: 56,498; cpp: 2,310; makefile: 265; sh: 19
file content (142 lines) | stat: -rw-r--r-- 3,729 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# -*- coding: utf-8 -*-
"""
======================================================
OT with Laplacian regularization for domain adaptation
======================================================

This example illustrates domain adaptation in a 2D setting using optimal
transport for domain adaptation (OTDA) with Laplacian regularization, compared
against plain EMD transport and Sinkhorn transport.

"""

# Authors: Ievgen Redko <ievgen.redko@univ-st-etienne.fr>

# License: MIT License

import matplotlib.pylab as pl
import ot

##############################################################################
# Generate data
# -------------

# Both domains get the same number of labelled samples.
n_source_samples, n_target_samples = 150, 150

# Source and target are drawn from two different toy problems provided by
# ot.datasets ("3gauss" for the source, "3gauss2" for the target); the
# returned labels are used below only to colour the scatter plots.
Xs, ys = ot.datasets.make_data_classif("3gauss", n_source_samples)
Xt, yt = ot.datasets.make_data_classif("3gauss2", n_target_samples)


##############################################################################
# Instantiate the different transport algorithms and fit them
# -----------------------------------------------------------

# Build the three domain-adaptation models that are compared in this example:
#   - EMD transport,
#   - Sinkhorn transport (reg_e=0.01),
#   - EMD transport with Laplacian regularization (reg_lap=100, reg_src=1);
#     see the ot.da.EMDLaplaceTransport docs for the meaning of these weights.
ot_emd = ot.da.EMDTransport()
ot_sinkhorn = ot.da.SinkhornTransport(reg_e=0.01)
ot_emd_laplace = ot.da.EMDLaplaceTransport(reg_lap=100, reg_src=1)

# Every model learns its coupling from the unlabelled source/target samples.
for _transport in (ot_emd, ot_sinkhorn, ot_emd_laplace):
    _transport.fit(Xs=Xs, Xt=Xt)

# Map the source samples onto the target domain with each fitted model
# (evaluated in the same order: EMD, Sinkhorn, EMD + Laplacian).
transp_Xs_emd, transp_Xs_sinkhorn, transp_Xs_emd_laplace = [
    _transport.transform(Xs=Xs)
    for _transport in (ot_emd, ot_sinkhorn, ot_emd_laplace)
]

##############################################################################
# Fig 1 : plots source and target samples
# ---------------------------------------

# Figure 1: the two domains side by side, each point coloured by its class
# label (source samples drawn as "+", target samples as "o").
pl.figure(1, figsize=(10, 5))

# Left panel: source domain.
pl.subplot(1, 2, 1)
pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker="+", label="Source samples")
pl.xticks([])
pl.yticks([])
pl.legend(loc=0)
pl.title("Source samples")

# Right panel: target domain.
pl.subplot(1, 2, 2)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples")
pl.xticks([])
pl.yticks([])
pl.legend(loc=0)
pl.title("Target samples")
pl.tight_layout()


##############################################################################
# Fig 2 : plot optimal couplings and transported samples
# ------------------------------------------------------

param_img = {"interpolation": "nearest"}

# One column per method: (fitted transport, transported source samples, name).
# Driving every panel from this single table keeps the three columns'
# titles and styling identical instead of maintaining three hand-copied
# versions of the same plotting code.
methods = [
    (ot_emd, transp_Xs_emd, "EMDTransport"),
    (ot_sinkhorn, transp_Xs_sinkhorn, "SinkhornTransport"),
    (ot_emd_laplace, transp_Xs_emd_laplace, "EMDLaplaceTransport"),
]

# Create figure 2 exactly once; calling pl.figure(2, ...) again would only
# re-activate the existing figure and silently ignore its figsize.
pl.figure(2, figsize=(15, 8))
for col, (model, transp_Xs, name) in enumerate(methods, start=1):
    # Top row: the coupling matrix stored on the fitted model.
    pl.subplot(2, 3, col)
    pl.imshow(model.coupling_, **param_img)
    pl.xticks([])
    pl.yticks([])
    pl.title(f"Optimal coupling\n{name}")

    # Bottom row: transported source samples drawn over faded target samples.
    pl.subplot(2, 3, col + 3)
    pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.3)
    pl.scatter(
        transp_Xs[:, 0],
        transp_Xs[:, 1],
        c=ys,
        marker="+",
        label="Transp samples",
        s=30,
    )
    pl.xticks([])
    pl.yticks([])
    pl.title(f"Transported samples\n{name}")
    if col == 1:
        # A single legend suffices: all bottom panels use the same markers.
        pl.legend(loc="lower left")

pl.tight_layout()

pl.show()