File: plot_digits_denoising.py

package info (click to toggle)
scikit-learn 1.4.2%2Bdfsg-8
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 25,036 kB
  • sloc: python: 201,105; cpp: 5,790; ansic: 854; makefile: 304; sh: 56; javascript: 20
file content (151 lines) | stat: -rw-r--r-- 5,242 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""
================================
Image denoising using kernel PCA
================================

This example shows how to use :class:`~sklearn.decomposition.KernelPCA` to
denoise images. In short, we take advantage of the approximation function
learned during `fit` to reconstruct the original image.

We will compare the results with an exact reconstruction using
:class:`~sklearn.decomposition.PCA`.

We will use USPS digits dataset to reproduce presented in Sect. 4 of [1]_.

.. topic:: References

   .. [1] `Bakır, Gökhan H., Jason Weston, and Bernhard Schölkopf.
      "Learning to find pre-images."
      Advances in neural information processing systems 16 (2004): 449-456.
      <https://papers.nips.cc/paper/2003/file/ac1ad983e08ad3304a97e147f522747e-Paper.pdf>`_

"""

# Authors: Guillaume Lemaitre <guillaume.lemaitre@inria.fr>
# Licence: BSD 3 clause

# %%
# Load the dataset via OpenML
# ---------------------------
#
# The USPS digits datasets is available in OpenML. We use
# :func:`~sklearn.datasets.fetch_openml` to get this dataset. In addition, we
# normalize the dataset such that all pixel values are in the range (0, 1).
import numpy as np

from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

X, y = fetch_openml(data_id=41082, as_frame=False, return_X_y=True)
X = MinMaxScaler().fit_transform(X)

# %%
# The idea will be to learn a PCA basis (with and without a kernel) on
# noisy images and then use these models to reconstruct and denoise these
# images.
#
# Thus, we split our dataset into a training and testing set composed of 1,000
# samples for the training and 100 samples for testing. These images are
# noise-free and we will use them to evaluate the efficiency of the denoising
# approaches. In addition, we create a copy of the original dataset and add a
# Gaussian noise.
#
# The idea of this application, is to show that we can denoise corrupted images
# by learning a PCA basis on some uncorrupted images. We will use both a PCA
# and a kernel-based PCA to solve this problem.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, random_state=0, train_size=1_000, test_size=100
)

rng = np.random.RandomState(0)
noise = rng.normal(scale=0.25, size=X_test.shape)
X_test_noisy = X_test + noise

noise = rng.normal(scale=0.25, size=X_train.shape)
X_train_noisy = X_train + noise

# %%
# In addition, we will create a helper function to qualitatively assess the
# image reconstruction by plotting the test images.
import matplotlib.pyplot as plt


def plot_digits(X, title):
    """Small helper function to plot 100 digits."""
    fig, axs = plt.subplots(nrows=10, ncols=10, figsize=(8, 8))
    for img, ax in zip(X, axs.ravel()):
        ax.imshow(img.reshape((16, 16)), cmap="Greys")
        ax.axis("off")
    fig.suptitle(title, fontsize=24)


# %%
# In addition, we will use the mean squared error (MSE) to quantitatively
# assess the image reconstruction.
#
# Let's first have a look to see the difference between noise-free and noisy
# images. We will check the test set in this regard.
plot_digits(X_test, "Uncorrupted test images")
plot_digits(
    X_test_noisy, f"Noisy test images\nMSE: {np.mean((X_test - X_test_noisy) ** 2):.2f}"
)

# %%
# Learn the `PCA` basis
# ---------------------
#
# We can now learn our PCA basis using both a linear PCA and a kernel PCA that
# uses a radial basis function (RBF) kernel.
from sklearn.decomposition import PCA, KernelPCA

pca = PCA(n_components=32, random_state=42)
kernel_pca = KernelPCA(
    n_components=400,
    kernel="rbf",
    gamma=1e-3,
    fit_inverse_transform=True,
    alpha=5e-3,
    random_state=42,
)

pca.fit(X_train_noisy)
_ = kernel_pca.fit(X_train_noisy)

# %%
# Reconstruct and denoise test images
# -----------------------------------
#
# Now, we can transform and reconstruct the noisy test set. Since we used less
# components than the number of original features, we will get an approximation
# of the original set. Indeed, by dropping the components explaining variance
# in PCA the least, we hope to remove noise. Similar thinking happens in kernel
# PCA; however, we expect a better reconstruction because we use a non-linear
# kernel to learn the PCA basis and a kernel ridge to learn the mapping
# function.
X_reconstructed_kernel_pca = kernel_pca.inverse_transform(
    kernel_pca.transform(X_test_noisy)
)
X_reconstructed_pca = pca.inverse_transform(pca.transform(X_test_noisy))

# %%
plot_digits(X_test, "Uncorrupted test images")
plot_digits(
    X_reconstructed_pca,
    f"PCA reconstruction\nMSE: {np.mean((X_test - X_reconstructed_pca) ** 2):.2f}",
)
plot_digits(
    X_reconstructed_kernel_pca,
    (
        "Kernel PCA reconstruction\n"
        f"MSE: {np.mean((X_test - X_reconstructed_kernel_pca) ** 2):.2f}"
    ),
)

# %%
# PCA has a lower MSE than kernel PCA. However, the qualitative analysis might
# not favor PCA instead of kernel PCA. We observe that kernel PCA is able to
# remove background noise and provide a smoother image.
#
# However, it should be noted that the results of the denoising with kernel PCA
# will depend of the parameters `n_components`, `gamma`, and `alpha`.