File: plot_compare_gpr_krr.py

"""
==========================================================
Comparison of kernel ridge and Gaussian process regression
==========================================================

This example illustrates differences between a kernel ridge regression and a
Gaussian process regression.

Both kernel ridge regression and Gaussian process regression use a
so-called "kernel trick" to make their models expressive enough to fit
the training data. However, the machine learning problems solved by the two
methods are drastically different.

Kernel ridge regression will find the target function that minimizes a loss
function (the mean squared error).

Instead of finding a single target function, Gaussian process regression
employs a probabilistic approach: a Gaussian posterior distribution over
target functions is defined via Bayes' theorem. Thus, prior
probabilities on target functions are combined with a likelihood function
defined by the observed training data to provide estimates of the posterior
distributions.

We will illustrate these differences with an example and we will also focus on
tuning the kernel hyperparameters.
"""

# Authors: Jan Hendrik Metzen <jhm@informatik.uni-bremen.de>
#          Guillaume Lemaitre <g.lemaitre58@gmail.com>
# License: BSD 3 clause

# %%
# Generating a dataset
# --------------------
#
# We create a synthetic dataset. The true generative process will take a 1-D
# vector and compute its sine. Note that the period of this sine is thus
# :math:`2 \pi`. We will reuse this information later in this example.
import numpy as np

rng = np.random.RandomState(0)
data = np.linspace(0, 30, num=1_000).reshape(-1, 1)
target = np.sin(data).ravel()

# %%
# Now, we can imagine a scenario where we get observations from this true
# process. However, we will add some challenges:
#
# - the measurements will be noisy;
# - only samples from the beginning of the signal will be available.
training_sample_indices = rng.choice(np.arange(0, 400), size=40, replace=False)
training_data = data[training_sample_indices]
training_noisy_target = target[training_sample_indices] + 0.5 * rng.randn(
    len(training_sample_indices)
)

# %%
# Let's plot the true signal and the noisy measurements available for training.
import matplotlib.pyplot as plt

plt.plot(data, target, label="True signal", linewidth=2)
plt.scatter(
    training_data,
    training_noisy_target,
    color="black",
    label="Noisy measurements",
)
plt.legend()
plt.xlabel("data")
plt.ylabel("target")
_ = plt.title(
    "Illustration of the true generative process and \n"
    "noisy measurements available during training"
)

# %%
# Limitations of a simple linear model
# ------------------------------------
#
# First, we would like to highlight the limitations of a linear model given
# our dataset. We fit a :class:`~sklearn.linear_model.Ridge` and check the
# predictions of this model on our dataset.
from sklearn.linear_model import Ridge

ridge = Ridge().fit(training_data, training_noisy_target)

plt.plot(data, target, label="True signal", linewidth=2)
plt.scatter(
    training_data,
    training_noisy_target,
    color="black",
    label="Noisy measurements",
)
plt.plot(data, ridge.predict(data), label="Ridge regression")
plt.legend()
plt.xlabel("data")
plt.ylabel("target")
_ = plt.title("Limitation of a linear model such as ridge")

# %%
# Such a ridge regressor underfits the data since it is not expressive enough.
#
# Kernel methods: kernel ridge and Gaussian process
# -------------------------------------------------
#
# Kernel ridge
# ............
#
# We can make the previous linear model more expressive by using a so-called
# kernel. A kernel is an embedding from the original feature space to another
# one. Simply put, it is used to map our original data into a new, more
# complex feature space. This new space is determined by the choice of
# kernel.
#
# In our case, we know that the true generative process is a periodic function.
# We can use a :class:`~sklearn.gaussian_process.kernels.ExpSineSquared` kernel
# which allows recovering the periodicity. The class
# :class:`~sklearn.kernel_ridge.KernelRidge` will accept such a kernel.
#
# Using this model together with a kernel is equivalent to embedding the data
# using the mapping function of the kernel and then applying a ridge regression.
# In practice, the data are not mapped explicitly; instead the dot product
# between samples in the higher dimensional feature space is computed using the
# "kernel trick".
#
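# As a small illustrative aside, note that a kernel such as
# :class:`~sklearn.gaussian_process.kernels.ExpSineSquared` can be evaluated
# directly on pairs of samples: it returns their similarities in the implicit
# feature space without ever computing the mapping explicitly.
from sklearn.gaussian_process.kernels import ExpSineSquared

# Pairwise similarities (Gram matrix) between the first three training samples
print(ExpSineSquared()(training_data[:3]))

# %%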
# Thus, let's use such a :class:`~sklearn.kernel_ridge.KernelRidge`.
import time

from sklearn.gaussian_process.kernels import ExpSineSquared
from sklearn.kernel_ridge import KernelRidge

kernel_ridge = KernelRidge(kernel=ExpSineSquared())

start_time = time.time()
kernel_ridge.fit(training_data, training_noisy_target)
print(
    f"Fitting KernelRidge with default kernel: {time.time() - start_time:.3f} seconds"
)

# %%
plt.plot(data, target, label="True signal", linewidth=2, linestyle="dashed")
plt.scatter(
    training_data,
    training_noisy_target,
    color="black",
    label="Noisy measurements",
)
plt.plot(
    data,
    kernel_ridge.predict(data),
    label="Kernel ridge",
    linewidth=2,
    linestyle="dashdot",
)
plt.legend(loc="lower right")
plt.xlabel("data")
plt.ylabel("target")
_ = plt.title(
    "Kernel ridge regression with an exponential sine squared\n "
    "kernel using default hyperparameters"
)

# %%
# This fitted model is not accurate. Indeed, we did not set the parameters of
# the kernel and instead used the default ones. We can inspect them.
kernel_ridge.kernel

# %%
# Our kernel has two parameters: the length-scale and the periodicity. For our
# dataset, we use `sin` as the generative process, implying a
# :math:`2 \pi`-periodicity for the signal. Since the default value of the
# periodicity parameter is :math:`1`, it explains the high frequency observed
# in the predictions of our model.
# Similar conclusions can be drawn with the length-scale parameter. This tells
# us that the kernel parameters need to be tuned. We will use a randomized
# search to tune the different parameters of the kernel ridge model: the
# `alpha` parameter and the kernel parameters.

# %%
from scipy.stats import loguniform

from sklearn.model_selection import RandomizedSearchCV

param_distributions = {
    "alpha": loguniform(1e0, 1e3),
    "kernel__length_scale": loguniform(1e-2, 1e2),
    "kernel__periodicity": loguniform(1e0, 1e1),
}
kernel_ridge_tuned = RandomizedSearchCV(
    kernel_ridge,
    param_distributions=param_distributions,
    n_iter=500,
    random_state=0,
)
start_time = time.time()
kernel_ridge_tuned.fit(training_data, training_noisy_target)
print(f"Time for KernelRidge fitting: {time.time() - start_time:.3f} seconds")

# %%
# Fitting the model is now more computationally expensive since we have to try
# several combinations of hyperparameters. We can have a look at the
# hyperparameters found to get some intuition.
kernel_ridge_tuned.best_params_
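
# %%
# As a quick illustrative check (the exact value depends on the randomized
# search), we can compare the tuned periodicity with the period of the true
# generative process, :math:`2 \pi`:
print(
    "Tuned periodicity: "
    f"{kernel_ridge_tuned.best_params_['kernel__periodicity']:.3f} "
    f"(2 * pi is approximately {2 * np.pi:.3f})"
)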

# %%
# Looking at the best parameters, we see that they are different from the
# defaults. We also see that the periodicity is closer to the expected value:
# :math:`2 \pi`. We can now inspect the predictions of our tuned kernel ridge.
start_time = time.time()
predictions_kr = kernel_ridge_tuned.predict(data)
print(f"Time for KernelRidge predict: {time.time() - start_time:.3f} seconds")

# %%
plt.plot(data, target, label="True signal", linewidth=2, linestyle="dashed")
plt.scatter(
    training_data,
    training_noisy_target,
    color="black",
    label="Noisy measurements",
)
plt.plot(
    data,
    predictions_kr,
    label="Kernel ridge",
    linewidth=2,
    linestyle="dashdot",
)
plt.legend(loc="lower right")
plt.xlabel("data")
plt.ylabel("target")
_ = plt.title(
    "Kernel ridge regression with an exponential sine squared\n "
    "kernel using tuned hyperparameters"
)

# %%
# We get a much more accurate model. We still observe some errors mainly due to
# the noise added to the dataset.
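#
# As an illustrative check (the exact value depends on the randomized search),
# we can also compute the coefficient of determination of the tuned kernel
# ridge against the noise-free signal:
print(
    "R2 score of the tuned kernel ridge on the true signal: "
    f"{kernel_ridge_tuned.score(data, target):.3f}"
)

# %%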
#
# Gaussian process regression
# ...........................
#
# Now, we will use a
# :class:`~sklearn.gaussian_process.GaussianProcessRegressor` to fit the same
# dataset. When training a Gaussian process, the hyperparameters of the kernel
# are optimized during the fitting process. There is no need for an external
# hyperparameter search. Here, we create a slightly more complex kernel than
# for the kernel ridge regressor: we add a
# :class:`~sklearn.gaussian_process.kernels.WhiteKernel` that is used to
# estimate the noise in the dataset.
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import WhiteKernel

kernel = 1.0 * ExpSineSquared(1.0, 5.0, periodicity_bounds=(1e-2, 1e1)) + WhiteKernel(
    1e-1
)
gaussian_process = GaussianProcessRegressor(kernel=kernel)
start_time = time.time()
gaussian_process.fit(training_data, training_noisy_target)
print(
    f"Time for GaussianProcessRegressor fitting: {time.time() - start_time:.3f} seconds"
)

# %%
# The computational cost of training a Gaussian process is much lower than
# that of the kernel ridge tuned with a randomized search. We can check the
# parameters of the fitted kernel.
gaussian_process.kernel_
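
# %%
# During `fit`, the Gaussian process maximizes the log-marginal-likelihood of
# the training data with respect to the kernel hyperparameters. As an
# illustrative check (the exact value depends on the optimization), we can
# print the value reached for the fitted hyperparameters:
print(
    "Log-marginal-likelihood of the fitted kernel: "
    f"{gaussian_process.log_marginal_likelihood(gaussian_process.kernel_.theta):.3f}"
)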

# %%
# Indeed, we see that the parameters have been optimized. Looking at the
# `periodicity` parameter, we see that we found a period close to the
# theoretical value :math:`2 \pi`. We can now have a look at the predictions
# of our model.
start_time = time.time()
mean_predictions_gpr, std_predictions_gpr = gaussian_process.predict(
    data,
    return_std=True,
)
print(
    f"Time for GaussianProcessRegressor predict: {time.time() - start_time:.3f} seconds"
)
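
# %%
# In addition to the standard deviation, the Gaussian process can also return
# the full posterior covariance matrix of the predictions. As a small sketch,
# we request it on a handful of query points (one every 100 samples):
_, cov_predictions_gpr = gaussian_process.predict(data[::100], return_cov=True)
print(f"Shape of the posterior covariance matrix: {cov_predictions_gpr.shape}")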

# %%
plt.plot(data, target, label="True signal", linewidth=2, linestyle="dashed")
plt.scatter(
    training_data,
    training_noisy_target,
    color="black",
    label="Noisy measurements",
)
# Plot the predictions of the kernel ridge
plt.plot(
    data,
    predictions_kr,
    label="Kernel ridge",
    linewidth=2,
    linestyle="dashdot",
)
# Plot the predictions of the gaussian process regressor
plt.plot(
    data,
    mean_predictions_gpr,
    label="Gaussian process regressor",
    linewidth=2,
    linestyle="dotted",
)
plt.fill_between(
    data.ravel(),
    mean_predictions_gpr - std_predictions_gpr,
    mean_predictions_gpr + std_predictions_gpr,
    color="tab:green",
    alpha=0.2,
)
plt.legend(loc="lower right")
plt.xlabel("data")
plt.ylabel("target")
_ = plt.title("Comparison between kernel ridge and gaussian process regressor")

# %%
# We observe that the results of the kernel ridge and the Gaussian process
# regressor are close. However, the Gaussian process regressor also provides
# uncertainty information that is not available with a kernel ridge.
# Due to the probabilistic formulation of the target functions, the
# Gaussian process can output the standard deviation (or the covariance)
# together with the mean predictions of the target functions.
#
# However, it comes at a cost: the time to compute the predictions is higher
# with a Gaussian process.
#
# Final conclusion
# ----------------
#
# We can give a final word regarding the ability of the two models to
# extrapolate. Indeed, we only provided the beginning of the signal as a
# training set. Using a periodic kernel forces our model to repeat the pattern
# found on the training set. Using this kernel information together with the
# ability of both models to extrapolate, we observe that the models
# continue to predict the sine pattern.
#
# A Gaussian process allows kernels to be combined. Thus, we could combine
# the exponential sine squared kernel with a radial basis function
# kernel.
from sklearn.gaussian_process.kernels import RBF

kernel = 1.0 * ExpSineSquared(1.0, 5.0, periodicity_bounds=(1e-2, 1e1)) * RBF(
    length_scale=15, length_scale_bounds="fixed"
) + WhiteKernel(1e-1)
gaussian_process = GaussianProcessRegressor(kernel=kernel)
gaussian_process.fit(training_data, training_noisy_target)
mean_predictions_gpr, std_predictions_gpr = gaussian_process.predict(
    data,
    return_std=True,
)

# %%
plt.plot(data, target, label="True signal", linewidth=2, linestyle="dashed")
plt.scatter(
    training_data,
    training_noisy_target,
    color="black",
    label="Noisy measurements",
)
# Plot the predictions of the kernel ridge
plt.plot(
    data,
    predictions_kr,
    label="Kernel ridge",
    linewidth=2,
    linestyle="dashdot",
)
# Plot the predictions of the gaussian process regressor
plt.plot(
    data,
    mean_predictions_gpr,
    label="Gaussian process regressor",
    linewidth=2,
    linestyle="dotted",
)
plt.fill_between(
    data.ravel(),
    mean_predictions_gpr - std_predictions_gpr,
    mean_predictions_gpr + std_predictions_gpr,
    color="tab:green",
    alpha=0.2,
)
plt.legend(loc="lower right")
plt.xlabel("data")
plt.ylabel("target")
_ = plt.title("Effect of using a radial basis function kernel")

# %%
# Using a radial basis function kernel attenuates the periodicity effect once
# no samples are available in the training range.
# As test samples get further away from the training ones, the predictions
# converge towards their mean and their standard deviation
# also increases.
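#
# As a small illustrative check of this statement, we can compare the predicted
# mean and standard deviation close to the training range (around x = 5) and
# far away from it (around x = 25); far from the training samples, the mean is
# expected to move towards the prior mean (zero here) and the standard
# deviation to grow (the exact values depend on the fitted hyperparameters):
near_idx = np.argmin(np.abs(data.ravel() - 5.0))
far_idx = np.argmin(np.abs(data.ravel() - 25.0))
print(
    f"x = 5:  mean = {mean_predictions_gpr[near_idx]:.3f}, "
    f"std = {std_predictions_gpr[near_idx]:.3f}"
)
print(
    f"x = 25: mean = {mean_predictions_gpr[far_idx]:.3f}, "
    f"std = {std_predictions_gpr[far_idx]:.3f}"
)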