File: plot_rbf_parameters.py

package info (click to toggle)
scikit-learn 0.11.0-2+deb7u1
  • links: PTS, VCS
  • area: main
  • in suites: wheezy
  • size: 13,900 kB
  • sloc: python: 34,740; ansic: 8,860; cpp: 8,849; pascal: 230; makefile: 211; sh: 14
file content (66 lines) | stat: -rw-r--r-- 1,684 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""
==================
RBF SVM parameters
==================

This example illustrates the effect of the parameters `gamma`
and `C` of the rbf kernel SVM.

Intuitively, the `gamma` parameter defines how far the influence
of a single training example reaches, with low values meaning 'far'
and high values meaning 'close'.
The `C` parameter trades off misclassification of training examples
against simplicity of the decision surface. A low C makes
the decision surface smooth, while a high C aims at classifying
all training examples correctly.
"""
print __doc__

import numpy as np
import pylab as pl

from sklearn import svm
from sklearn.datasets import load_iris
from sklearn.preprocessing import Scaler

iris = load_iris()
X = iris.data[:, :2]  # Take only 2 dimensions
y = iris.target
X = X[y > 0]
y = y[y > 0]
y -= 1

scaler = Scaler()
X = scaler.fit_transform(X)

xx, yy = np.meshgrid(np.linspace(-5, 5, 200), np.linspace(-5, 5, 200))

np.random.seed(0)

gamma_range = [1e-1, 1, 1e1]
C_range = [1, 1e2, 1e4]

pl.figure()
k = 1

for C in C_range:
    for gamma in gamma_range:
        # fit the model
        clf = svm.SVC(gamma=gamma, C=C)
        clf.fit(X, y)

        # plot the decision function for each datapoint on the grid
        Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
        Z = Z.reshape(xx.shape)

        pl.subplot(3, 3, k)
        pl.title("gamma %.1f, C %.2f" % (gamma, C))
        k += 1
        pl.pcolormesh(xx, yy, -Z, cmap=pl.cm.jet)
        pl.scatter(X[:, 0], X[:, 1], c=y, cmap=pl.cm.jet)
        pl.xticks(())
        pl.yticks(())
        pl.axis('tight')

pl.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95)
pl.show()