File: plot_semirelaxed_gromov_wasserstein_barycenter.py

# -*- coding: utf-8 -*-

r"""
==========================================================================
Semi-relaxed (Fused) Gromov-Wasserstein Barycenter as Dictionary Learning
==========================================================================

In this example, we illustrate how to learn a semi-relaxed Gromov-Wasserstein
(srGW) barycenter using a Block-Coordinate Descent algorithm, on a dataset of
structured data such as graphs, denoted :math:`\{ \mathbf{C_s} \}_{s \in [S]}`,
whose nodes carry uniform weights :math:`\{ \mathbf{p_s} \}_{s \in [S]}`.
Given a barycenter structure matrix :math:`\mathbf{C}` with :math:`N` nodes,
each graph :math:`(\mathbf{C_s}, \mathbf{p_s})` is modeled as a reweighted subgraph
with structure :math:`\mathbf{C}` and weights :math:`\mathbf{w_s} \in \Sigma_N`,
where each :math:`\mathbf{w_s}` corresponds to the second marginal of the OT plan
:math:`\mathbf{T_s}` (s.t. :math:`\mathbf{w_s} = \mathbf{T_s}^\top \mathbf{1}`)
minimizing the srGW loss between the :math:`s^{th}` input and the barycenter.
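
Schematically, with the squared loss used in this example, the srGW barycenter
problem solved below reads

.. math::
    \min_{\mathbf{C} \in \mathbb{R}^{N \times N}} \quad \sum_{s \in [S]} \lambda_s
    \min_{\mathbf{T_s} \geq 0, \, \mathbf{T_s} \mathbf{1}_N = \mathbf{p_s}}
    \sum_{i, j, k, l} \left( \mathbf{C_s}[i, k] - \mathbf{C}[j, l] \right)^2
    \mathbf{T_s}[i, j] \mathbf{T_s}[k, l],

where only the first marginal of each transport plan :math:`\mathbf{T_s}` is
constrained, hence the name *semi-relaxed*.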


First, we consider a dataset composed of graphs generated by Stochastic Block Models
with sizes drawn from :math:`\{30, \dots, 50\}` and numbers of clusters varying in
:math:`\{1, 2, 3\}`, with random cluster proportions. We learn a srGW barycenter
with 3 nodes and visualize the learned structure and the embeddings of some inputs.

Second, we illustrate the extension of this framework to graphs endowed
with node features, using the semi-relaxed Fused Gromov-Wasserstein
divergence (srFGW). Starting from the aforementioned dataset of unattributed graphs,
we add discrete node labels encoding the true cluster assignments, and then
conduct an analogous analysis.


[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs".
International Conference on Learning Representations (ICLR), 2022.

"""
# Author: Cédric Vincent-Cuaz <cedric.vincent-cuaz@inria.fr>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 2

import numpy as np
import matplotlib.pylab as pl
from sklearn.manifold import MDS
from ot.gromov import semirelaxed_gromov_barycenters, semirelaxed_fgw_barycenters
import ot
import networkx
from networkx.generators.community import stochastic_block_model as sbm

#############################################################################
#
# Generate a dataset composed of graphs following Stochastic Block models of 1, 2 and 3 clusters.
# -----------------------------------------------------------------------------------------------

np.random.seed(42)

n_samples = 60  # number of graphs in the dataset
# For every number of clusters, we generate SBM graphs with fixed inter/intra-cluster
# probabilities and variable cluster proportions.
clusters = [1, 2, 3]
Nc = n_samples // len(clusters)  # number of graphs by cluster
nlabels = len(clusters)
dataset = []
node_labels = []
labels = []

p_inter = 0.1
p_intra = 0.9
for n_cluster in clusters:
    for i in range(Nc):
        n_nodes = int(np.random.uniform(low=30, high=50))

        if n_cluster > 1:
            P = p_inter * np.ones((n_cluster, n_cluster))
            np.fill_diagonal(P, p_intra)
            props = np.random.uniform(0.2, 1, size=(n_cluster,))
            props /= props.sum()
            sizes = np.round(n_nodes * props).astype(np.int32)
        else:
            P = p_intra * np.eye(1)
            sizes = [n_nodes]

        G = sbm(sizes, P, seed=i, directed=False)
        part = np.array([G.nodes[i]["block"] for i in range(np.sum(sizes))])
        C = networkx.to_numpy_array(G)
        dataset.append(C)
        node_labels.append(part)
        labels.append(n_cluster)
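
# The dataset now contains 60 adjacency matrices with 30 to 50 nodes each,
# grouped into three balanced classes according to their true number of SBM clusters.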


# Visualize samples


def plot_graph(x, C, binary=True, color="C0", s=None):
    for j in range(C.shape[0]):
        for i in range(j):
            if binary:
                if C[i, j] > 0:
                    pl.plot(
                        [x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color="k"
                    )
            else:  # connection intensity proportional to C[i,j]
                pl.plot(
                    [x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=C[i, j], color="k"
                )

    pl.scatter(
        x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors="k", cmap="tab10", vmax=9
    )


pl.figure(1, (12, 8))
pl.clf()
for idx_c, c in enumerate(clusters):
    C = dataset[(c - 1) * Nc]  # sample with c clusters
    # get 2d position for nodes
    x = MDS(dissimilarity="precomputed", random_state=0).fit_transform(1 - C)
    pl.subplot(2, nlabels, c)
    pl.title("(graph) sample from label " + str(c), fontsize=14)
    plot_graph(x, C, binary=True, color="C0", s=50.0)
    pl.axis("off")
    pl.subplot(2, nlabels, nlabels + c)
    pl.title("(matrix) sample from label %s \n" % c, fontsize=14)
    pl.imshow(C, interpolation="nearest")
    pl.axis("off")
pl.tight_layout()
pl.show()

#############################################################################
#
# Estimate the srGW barycenter from the dataset and visualize embeddings
# ------------------------------------------------------------------------


np.random.seed(0)
ps = [ot.unif(C.shape[0]) for C in dataset]  # uniform weights on input nodes
lambdas = [1.0 / n_samples for _ in range(n_samples)]  # uniform weights on the input graphs
N = 3  # 3 nodes in the barycenter

# Here we use the Fluid partitioning method to deduce initial transport plans
# for the barycenter problem. An initial barycenter structure is also deduced
# from these initial transport plans. Then a warmstart strategy is used to
# iteratively initialize each individual srGW problem within the BCD algorithm.

init_plan = "fluid"  # notice that several init options are implemented in `ot.gromov.semirelaxed_init_plan`
warmstartT = True

C, log = semirelaxed_gromov_barycenters(
    N=N,
    Cs=dataset,
    ps=ps,
    lambdas=lambdas,
    loss_fun="square_loss",
    tol=1e-6,
    stop_criterion="loss",
    warmstartT=warmstartT,
    log=True,
    G0=init_plan,
    verbose=False,
)

print("barycenter structure:", C)

unmixings = log["p"]
# Compute a 2D representation of the embeddings, which live in the probability 2-simplex
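# Each weight vector w is mapped to the barycentric combination
# w[0] * (0, 0) + w[1] * (1, 0) + w[2] * (0.5, sqrt(3) / 2), i.e. a point of the
# equilateral triangle whose vertices represent the three barycenter nodes.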
unmixings2D = np.zeros(shape=(n_samples, 2))
for i, w in enumerate(unmixings):
    unmixings2D[i, 0] = (2.0 * w[1] + w[2]) / 2.0
    unmixings2D[i, 1] = (np.sqrt(3.0) * w[2]) / 2.0
x = [0.0, 0.0]
y = [1.0, 0.0]
z = [0.5, np.sqrt(3) / 2.0]
extremities = np.stack([x, y, z])

pl.figure(2, (4, 4))
pl.clf()
pl.title("Embedding space", fontsize=14)
for cluster in range(nlabels):
    start, end = Nc * cluster, Nc * (cluster + 1)
    if cluster == 0:
        pl.scatter(
            unmixings2D[start:end, 0],
            unmixings2D[start:end, 1],
            c="C" + str(cluster),
            marker="o",
            s=80.0,
            label="1 cluster",
        )
    else:
        pl.scatter(
            unmixings2D[start:end, 0],
            unmixings2D[start:end, 1],
            c="C" + str(cluster),
            marker="o",
            s=80.0,
            label="%s clusters" % (cluster + 1),
        )
pl.scatter(
    extremities[:, 0],
    extremities[:, 1],
    c="black",
    marker="x",
    s=100.0,
    label="bary. nodes",
)
pl.plot([x[0], y[0]], [x[1], y[1]], color="black", linewidth=2.0)
pl.plot([x[0], z[0]], [x[1], z[1]], color="black", linewidth=2.0)
pl.plot([y[0], z[0]], [y[1], z[1]], color="black", linewidth=2.0)
pl.axis("off")
pl.legend(fontsize=11)
pl.tight_layout()
pl.show()
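
#############################################################################
#
# For illustration, a new graph can be embedded onto the learned barycenter by
# solving a single srGW problem: its embedding is the second marginal of the
# resulting transport plan. Here we assume that
# ``ot.gromov.semirelaxed_gromov_wasserstein(C1, C2, p)`` returns the optimal
# plan between the input graph ``(C1, p)`` and the target structure ``C2``.

T0 = ot.gromov.semirelaxed_gromov_wasserstein(
    dataset[0], C, ps[0], loss_fun="square_loss"
)
w0 = T0.sum(axis=0)  # second marginal T0^T 1: weights induced on the barycenter nodes
print("embedding of the first graph onto the barycenter nodes:", w0)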

#############################################################################
#
# Endow the dataset with node features
# ------------------------------------
# Node labels, corresponding to the true SBM cluster assignments,
# are set for each graph as one-hot encoded node features.
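# Each feature matrix F has shape (n_nodes, 3): row i is the one-hot encoding
# of the SBM block of node i.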

dataset_features = []
for i in range(len(dataset)):
    n = dataset[i].shape[0]
    F = np.zeros((n, 3))
    F[np.arange(n), node_labels[i]] = 1.0
    dataset_features.append(F)

pl.figure(3, (12, 8))
pl.clf()
for idx_c, c in enumerate(clusters):
    C = dataset[(c - 1) * Nc]  # sample with c clusters
    F = dataset_features[(c - 1) * Nc]
    colors = [f"C{labels[i]}" for i in range(F.shape[0])]
    # get 2d position for nodes
    x = MDS(dissimilarity="precomputed", random_state=0).fit_transform(1 - C)
    pl.subplot(2, nlabels, c)
    pl.title("(graph) sample from label " + str(c), fontsize=14)
    plot_graph(x, C, binary=True, color=colors, s=50)
    pl.axis("off")
    pl.subplot(2, nlabels, nlabels + c)
    pl.title("(matrix) sample from label %s \n" % c, fontsize=14)
    pl.imshow(C, interpolation="nearest")
    pl.axis("off")
pl.tight_layout()
pl.show()

#############################################################################
#
# Estimate the srFGW barycenter from the attributed graphs and visualize embeddings
# ------------------------------------------------------------------------------------
# We emphasize the dependence on the trade-off parameter alpha, which weights the
# relative importance of the structures (alpha=1) and of the features (alpha=0),
# knowing that embeddings which perfectly cluster graphs w.r.t. their features
# should ease the identification of the number of clusters in the graphs.

list_alphas = [0.0001, 0.5, 0.9999]
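# alpha = 0.0001 (resp. 0.9999) is close to a purely feature-based (resp. purely
# structure-based) matching, cf. the discussion above.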
list_unmixings2D = []

for ialpha, alpha in enumerate(list_alphas):
    print("--- alpha:", alpha)
    C, F, log = semirelaxed_fgw_barycenters(
        N=N,
        Ys=dataset_features,
        Cs=dataset,
        ps=ps,
        lambdas=lambdas,
        alpha=alpha,
        loss_fun="square_loss",
        tol=1e-6,
        stop_criterion="loss",
        warmstartT=warmstartT,
        log=True,
        G0=init_plan,
    )

    print("barycenter structure:", C)
    print("barycenter features:", F)

    unmixings = log["p"]
    # Compute a 2D representation of the embeddings, which live in the probability 2-simplex
    unmixings2D = np.zeros(shape=(n_samples, 2))
    for i, w in enumerate(unmixings):
        unmixings2D[i, 0] = (2.0 * w[1] + w[2]) / 2.0
        unmixings2D[i, 1] = (np.sqrt(3.0) * w[2]) / 2.0
    list_unmixings2D.append(unmixings2D.copy())

x = [0.0, 0.0]
y = [1.0, 0.0]
z = [0.5, np.sqrt(3) / 2.0]
extremities = np.stack([x, y, z])

pl.figure(4, (12, 4))
pl.clf()
pl.suptitle("Embedding spaces", fontsize=14)
for ialpha, alpha in enumerate(list_alphas):
    pl.subplot(1, len(list_alphas), ialpha + 1)
    pl.title(f"alpha = {alpha}", fontsize=14)
    for cluster in range(nlabels):
        start, end = Nc * cluster, Nc * (cluster + 1)
        if cluster == 0:
            pl.scatter(
                list_unmixings2D[ialpha][start:end, 0],
                list_unmixings2D[ialpha][start:end, 1],
                c="C" + str(cluster),
                marker="o",
                s=80.0,
                label="1 cluster",
            )
        else:
            pl.scatter(
                list_unmixings2D[ialpha][start:end, 0],
                list_unmixings2D[ialpha][start:end, 1],
                c="C" + str(cluster),
                marker="o",
                s=80.0,
                label="%s clusters" % (cluster + 1),
            )
    pl.scatter(
        extremities[:, 0],
        extremities[:, 1],
        c="black",
        marker="x",
        s=100.0,
        label="bary. nodes",
    )
    pl.plot([x[0], y[0]], [x[1], y[1]], color="black", linewidth=2.0)
    pl.plot([x[0], z[0]], [x[1], z[1]], color="black", linewidth=2.0)
    pl.plot([y[0], z[0]], [y[1], z[1]], color="black", linewidth=2.0)
    pl.axis("off")
    pl.legend(fontsize=11)
pl.tight_layout()
pl.show()
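
#############################################################################
#
# Similarly, for illustration, an attributed graph can be embedded onto the last
# learned srFGW barycenter ``(C, F)`` (here the one fitted with the last value of
# alpha) by solving a single srFGW problem. We assume that
# ``ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, alpha=...)``
# takes the feature cost matrix ``M`` as first argument and returns the optimal plan.

M0 = ot.dist(dataset_features[0], F)  # pairwise feature costs to the barycenter features
T0 = ot.gromov.semirelaxed_fused_gromov_wasserstein(
    M0, dataset[0], C, ps[0], loss_fun="square_loss", alpha=list_alphas[-1]
)
print("srFGW embedding of the first graph onto the barycenter nodes:", T0.sum(axis=0))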