1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
|
"""
Comparison between grid search and successive halving
=====================================================
This example compares the parameter search performed by
:class:`~sklearn.model_selection.HalvingGridSearchCV` and
:class:`~sklearn.model_selection.GridSearchCV`.
"""
from time import time
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.svm import SVC
from sklearn import datasets
from sklearn.model_selection import GridSearchCV
from sklearn.experimental import enable_halving_search_cv # noqa
from sklearn.model_selection import HalvingGridSearchCV
# %%
# We first define the parameter space for an :class:`~sklearn.svm.SVC`
# estimator, and compute the time required to train a
# :class:`~sklearn.model_selection.HalvingGridSearchCV` instance, as well as a
# :class:`~sklearn.model_selection.GridSearchCV` instance.
# One RandomState instance is shared by the dataset generator, the SVC and the
# halving search, so the statements below must keep their order.
rng = np.random.RandomState(0)
X, y = datasets.make_classification(n_samples=1000, random_state=rng)

# Candidate values for the two SVC hyper-parameters explored by both searches.
gammas = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7]
Cs = [1, 10, 100, 1e3, 1e4, 1e5]
param_grid = dict(gamma=gammas, C=Cs)

clf = SVC(random_state=rng)

# Successive halving: the measured time covers construction and fitting.
halving_start = time()
gsh = HalvingGridSearchCV(
    estimator=clf,
    param_grid=param_grid,
    factor=2,
    random_state=rng,
)
gsh.fit(X, y)
gsh_time = time() - halving_start

# Exhaustive grid search over the same grid, timed the same way.
grid_start = time()
gs = GridSearchCV(estimator=clf, param_grid=param_grid)
gs.fit(X, y)
gs_time = time() - grid_start
# %%
# We now plot heatmaps for both search estimators.
def make_heatmap(ax, gs, is_sh=False, make_cbar=False):
    """Draw the mean test scores of a fitted parameter search as a heatmap.

    Rows correspond to the distinct ``gamma`` values and columns to the
    distinct ``C`` values found in ``gs.cv_results_``. pandas returns the
    pivoted table with both axes sorted in ascending order, so the tick
    labels are read back from that table (not from the order in which the
    grid was declared) to keep every label next to the row/column it names.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes on which the heatmap is drawn.
    gs : fitted search object
        Must expose ``cv_results_`` containing ``param_C``, ``param_gamma``
        and ``mean_test_score`` (and ``iter`` when ``is_sh`` is true).
    is_sh : bool, default=False
        Set to True for successive-halving results: the score shown for each
        combination is the one from the highest iteration it took part in,
        and that iteration number is written in the cell.
    make_cbar : bool, default=False
        If True, shrink the subplot area of the figure owning ``ax`` and add
        a colorbar for the heatmap on its right-hand side.
    """
    results = pd.DataFrame(gs.cv_results_)
    results[["param_C", "param_gamma"]] = results[["param_C", "param_gamma"]].astype(
        np.float64
    )
    if is_sh:
        # SH dataframe: get mean_test_score values for the highest iter
        scores_matrix = results.sort_values("iter").pivot_table(
            index="param_gamma",
            columns="param_C",
            values="mean_test_score",
            aggfunc="last",
        )
    else:
        scores_matrix = results.pivot(
            index="param_gamma", columns="param_C", values="mean_test_score"
        )

    im = ax.imshow(scores_matrix)

    # Ticks and labels come from the matrix that was actually drawn, so they
    # always match its shape and its (ascending) row/column order.
    n_rows, n_cols = scores_matrix.shape
    ax.set_xticks(np.arange(n_cols))
    ax.set_xticklabels(["{:.0E}".format(c) for c in scores_matrix.columns])
    ax.set_xlabel("C", fontsize=15)

    ax.set_yticks(np.arange(n_rows))
    ax.set_yticklabels(["{:.0E}".format(g) for g in scores_matrix.index])
    ax.set_ylabel("gamma", fontsize=15)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    if is_sh:
        # Same index/columns as scores_matrix, hence the same cell layout.
        iterations = results.pivot_table(
            index="param_gamma", columns="param_C", values="iter", aggfunc="max"
        ).values
        for (row, col), last_iter in np.ndenumerate(iterations):
            # imshow puts columns on the x axis and rows on the y axis.
            ax.text(
                col,
                row,
                last_iter,
                ha="center",
                va="center",
                color="w",
                fontsize=20,
            )

    if make_cbar:
        # Use the figure that owns ``ax`` rather than a module-level global.
        fig = ax.figure
        fig.subplots_adjust(right=0.8)
        cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
        fig.colorbar(im, cax=cbar_ax)
        cbar_ax.set_ylabel("mean_test_score", rotation=-90, va="bottom", fontsize=15)
# Two side-by-side panels sharing the gamma axis: halving on the left,
# exhaustive grid search (with the colorbar) on the right.
fig, axes = plt.subplots(ncols=2, sharey=True)
ax1, ax2 = axes

make_heatmap(ax1, gsh, is_sh=True)
make_heatmap(ax2, gs, make_cbar=True)

# Each title reports the time measured for the corresponding search.
halving_title = "Successive Halving\ntime = {:.3f}s".format(gsh_time)
grid_title = "GridSearch\ntime = {:.3f}s".format(gs_time)
ax1.set_title(halving_title, fontsize=15)
ax2.set_title(grid_title, fontsize=15)

plt.show()
# %%
# The heatmaps show the mean test score of the parameter combinations for an
# :class:`~sklearn.svm.SVC` instance. The
# :class:`~sklearn.model_selection.HalvingGridSearchCV` also shows the
# iteration at which the combinations were last used. The combinations marked
# as ``0`` were only evaluated at the first iteration, while the ones with
# ``5`` are the parameter combinations that are considered the best ones.
#
# We can see that the :class:`~sklearn.model_selection.HalvingGridSearchCV`
# class is able to find parameter combinations that are just as accurate as
# :class:`~sklearn.model_selection.GridSearchCV`, in much less time.
|